decode.go (18014B)
1 package toml 2 3 import ( 4 "bytes" 5 "encoding" 6 "encoding/json" 7 "fmt" 8 "io" 9 "io/fs" 10 "math" 11 "os" 12 "reflect" 13 "strconv" 14 "strings" 15 "time" 16 ) 17 18 // Unmarshaler is the interface implemented by objects that can unmarshal a 19 // TOML description of themselves. 20 type Unmarshaler interface { 21 UnmarshalTOML(any) error 22 } 23 24 // Unmarshal decodes the contents of data in TOML format into a pointer v. 25 // 26 // See [Decoder] for a description of the decoding process. 27 func Unmarshal(data []byte, v any) error { 28 _, err := NewDecoder(bytes.NewReader(data)).Decode(v) 29 return err 30 } 31 32 // Decode the TOML data in to the pointer v. 33 // 34 // See [Decoder] for a description of the decoding process. 35 func Decode(data string, v any) (MetaData, error) { 36 return NewDecoder(strings.NewReader(data)).Decode(v) 37 } 38 39 // DecodeFile reads the contents of a file and decodes it with [Decode]. 40 func DecodeFile(path string, v any) (MetaData, error) { 41 fp, err := os.Open(path) 42 if err != nil { 43 return MetaData{}, err 44 } 45 defer fp.Close() 46 return NewDecoder(fp).Decode(v) 47 } 48 49 // DecodeFS reads the contents of a file from [fs.FS] and decodes it with 50 // [Decode]. 51 func DecodeFS(fsys fs.FS, path string, v any) (MetaData, error) { 52 fp, err := fsys.Open(path) 53 if err != nil { 54 return MetaData{}, err 55 } 56 defer fp.Close() 57 return NewDecoder(fp).Decode(v) 58 } 59 60 // Primitive is a TOML value that hasn't been decoded into a Go value. 61 // 62 // This type can be used for any value, which will cause decoding to be delayed. 63 // You can use [PrimitiveDecode] to "manually" decode these values. 64 // 65 // NOTE: The underlying representation of a `Primitive` value is subject to 66 // change. Do not rely on it. 67 // 68 // NOTE: Primitive values are still parsed, so using them will only avoid the 69 // overhead of reflection. They can be useful when you don't know the exact type 70 // of TOML data until runtime. 71 type Primitive struct { 72 undecoded any 73 context Key 74 } 75 76 // The significand precision for float32 and float64 is 24 and 53 bits; this is 77 // the range a natural number can be stored in a float without loss of data. 78 const ( 79 maxSafeFloat32Int = 16777215 // 2^24-1 80 maxSafeFloat64Int = int64(9007199254740991) // 2^53-1 81 ) 82 83 // Decoder decodes TOML data. 84 // 85 // TOML tables correspond to Go structs or maps; they can be used 86 // interchangeably, but structs offer better type safety. 87 // 88 // TOML table arrays correspond to either a slice of structs or a slice of maps. 89 // 90 // TOML datetimes correspond to [time.Time]. Local datetimes are parsed in the 91 // local timezone. 92 // 93 // [time.Duration] types are treated as nanoseconds if the TOML value is an 94 // integer, or they're parsed with time.ParseDuration() if they're strings. 95 // 96 // All other TOML types (float, string, int, bool and array) correspond to the 97 // obvious Go types. 98 // 99 // An exception to the above rules is if a type implements the TextUnmarshaler 100 // interface, in which case any primitive TOML value (floats, strings, integers, 101 // booleans, datetimes) will be converted to a []byte and given to the value's 102 // UnmarshalText method. See the Unmarshaler example for a demonstration with 103 // email addresses. 104 // 105 // # Key mapping 106 // 107 // TOML keys can map to either keys in a Go map or field names in a Go struct. 108 // The special `toml` struct tag can be used to map TOML keys to struct fields 109 // that don't match the key name exactly (see the example). A case insensitive 110 // match to struct names will be tried if an exact match can't be found. 111 // 112 // The mapping between TOML values and Go values is loose. That is, there may 113 // exist TOML values that cannot be placed into your representation, and there 114 // may be parts of your representation that do not correspond to TOML values. 115 // This loose mapping can be made stricter by using the IsDefined and/or 116 // Undecoded methods on the MetaData returned. 117 // 118 // This decoder does not handle cyclic types. Decode will not terminate if a 119 // cyclic type is passed. 120 type Decoder struct { 121 r io.Reader 122 } 123 124 // NewDecoder creates a new Decoder. 125 func NewDecoder(r io.Reader) *Decoder { 126 return &Decoder{r: r} 127 } 128 129 var ( 130 unmarshalToml = reflect.TypeOf((*Unmarshaler)(nil)).Elem() 131 unmarshalText = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() 132 primitiveType = reflect.TypeOf((*Primitive)(nil)).Elem() 133 ) 134 135 // Decode TOML data in to the pointer `v`. 136 func (dec *Decoder) Decode(v any) (MetaData, error) { 137 rv := reflect.ValueOf(v) 138 if rv.Kind() != reflect.Ptr { 139 s := "%q" 140 if reflect.TypeOf(v) == nil { 141 s = "%v" 142 } 143 144 return MetaData{}, fmt.Errorf("toml: cannot decode to non-pointer "+s, reflect.TypeOf(v)) 145 } 146 if rv.IsNil() { 147 return MetaData{}, fmt.Errorf("toml: cannot decode to nil value of %q", reflect.TypeOf(v)) 148 } 149 150 // Check if this is a supported type: struct, map, any, or something that 151 // implements UnmarshalTOML or UnmarshalText. 152 rv = indirect(rv) 153 rt := rv.Type() 154 if rv.Kind() != reflect.Struct && rv.Kind() != reflect.Map && 155 !(rv.Kind() == reflect.Interface && rv.NumMethod() == 0) && 156 !rt.Implements(unmarshalToml) && !rt.Implements(unmarshalText) { 157 return MetaData{}, fmt.Errorf("toml: cannot decode to type %s", rt) 158 } 159 160 // TODO: parser should read from io.Reader? Or at the very least, make it 161 // read from []byte rather than string 162 data, err := io.ReadAll(dec.r) 163 if err != nil { 164 return MetaData{}, err 165 } 166 167 p, err := parse(string(data)) 168 if err != nil { 169 return MetaData{}, err 170 } 171 172 md := MetaData{ 173 mapping: p.mapping, 174 keyInfo: p.keyInfo, 175 keys: p.ordered, 176 decoded: make(map[string]struct{}, len(p.ordered)), 177 context: nil, 178 data: data, 179 } 180 return md, md.unify(p.mapping, rv) 181 } 182 183 // PrimitiveDecode is just like the other Decode* functions, except it decodes a 184 // TOML value that has already been parsed. Valid primitive values can *only* be 185 // obtained from values filled by the decoder functions, including this method. 186 // (i.e., v may contain more [Primitive] values.) 187 // 188 // Meta data for primitive values is included in the meta data returned by the 189 // Decode* functions with one exception: keys returned by the Undecoded method 190 // will only reflect keys that were decoded. Namely, any keys hidden behind a 191 // Primitive will be considered undecoded. Executing this method will update the 192 // undecoded keys in the meta data. (See the example.) 193 func (md *MetaData) PrimitiveDecode(primValue Primitive, v any) error { 194 md.context = primValue.context 195 defer func() { md.context = nil }() 196 return md.unify(primValue.undecoded, rvalue(v)) 197 } 198 199 // markDecodedRecursive is a helper to mark any key under the given tmap as 200 // decoded, recursing as needed 201 func markDecodedRecursive(md *MetaData, tmap map[string]any) { 202 for key := range tmap { 203 md.decoded[md.context.add(key).String()] = struct{}{} 204 if tmap, ok := tmap[key].(map[string]any); ok { 205 md.context = append(md.context, key) 206 markDecodedRecursive(md, tmap) 207 md.context = md.context[0 : len(md.context)-1] 208 } 209 if tarr, ok := tmap[key].([]map[string]any); ok { 210 for _, elm := range tarr { 211 md.context = append(md.context, key) 212 markDecodedRecursive(md, elm) 213 md.context = md.context[0 : len(md.context)-1] 214 } 215 } 216 } 217 } 218 219 // unify performs a sort of type unification based on the structure of `rv`, 220 // which is the client representation. 221 // 222 // Any type mismatch produces an error. Finding a type that we don't know 223 // how to handle produces an unsupported type error. 224 func (md *MetaData) unify(data any, rv reflect.Value) error { 225 // Special case. Look for a `Primitive` value. 226 // TODO: #76 would make this superfluous after implemented. 227 if rv.Type() == primitiveType { 228 // Save the undecoded data and the key context into the primitive 229 // value. 230 context := make(Key, len(md.context)) 231 copy(context, md.context) 232 rv.Set(reflect.ValueOf(Primitive{ 233 undecoded: data, 234 context: context, 235 })) 236 return nil 237 } 238 239 rvi := rv.Interface() 240 if v, ok := rvi.(Unmarshaler); ok { 241 err := v.UnmarshalTOML(data) 242 if err != nil { 243 return md.parseErr(err) 244 } 245 // Assume the Unmarshaler decoded everything, so mark all keys under 246 // this table as decoded. 247 if tmap, ok := data.(map[string]any); ok { 248 markDecodedRecursive(md, tmap) 249 } 250 if aot, ok := data.([]map[string]any); ok { 251 for _, tmap := range aot { 252 markDecodedRecursive(md, tmap) 253 } 254 } 255 return nil 256 } 257 if v, ok := rvi.(encoding.TextUnmarshaler); ok { 258 return md.unifyText(data, v) 259 } 260 261 // TODO: 262 // The behavior here is incorrect whenever a Go type satisfies the 263 // encoding.TextUnmarshaler interface but also corresponds to a TOML hash or 264 // array. In particular, the unmarshaler should only be applied to primitive 265 // TOML values. But at this point, it will be applied to all kinds of values 266 // and produce an incorrect error whenever those values are hashes or arrays 267 // (including arrays of tables). 268 269 k := rv.Kind() 270 271 if k >= reflect.Int && k <= reflect.Uint64 { 272 return md.unifyInt(data, rv) 273 } 274 switch k { 275 case reflect.Struct: 276 return md.unifyStruct(data, rv) 277 case reflect.Map: 278 return md.unifyMap(data, rv) 279 case reflect.Array: 280 return md.unifyArray(data, rv) 281 case reflect.Slice: 282 return md.unifySlice(data, rv) 283 case reflect.String: 284 return md.unifyString(data, rv) 285 case reflect.Bool: 286 return md.unifyBool(data, rv) 287 case reflect.Interface: 288 if rv.NumMethod() > 0 { /// Only empty interfaces are supported. 289 return md.e("unsupported type %s", rv.Type()) 290 } 291 return md.unifyAnything(data, rv) 292 case reflect.Float32, reflect.Float64: 293 return md.unifyFloat64(data, rv) 294 } 295 return md.e("unsupported type %s", rv.Kind()) 296 } 297 298 func (md *MetaData) unifyStruct(mapping any, rv reflect.Value) error { 299 tmap, ok := mapping.(map[string]any) 300 if !ok { 301 if mapping == nil { 302 return nil 303 } 304 return md.e("type mismatch for %s: expected table but found %s", rv.Type().String(), fmtType(mapping)) 305 } 306 307 for key, datum := range tmap { 308 var f *field 309 fields := cachedTypeFields(rv.Type()) 310 for i := range fields { 311 ff := &fields[i] 312 if ff.name == key { 313 f = ff 314 break 315 } 316 if f == nil && strings.EqualFold(ff.name, key) { 317 f = ff 318 } 319 } 320 if f != nil { 321 subv := rv 322 for _, i := range f.index { 323 subv = indirect(subv.Field(i)) 324 } 325 326 if isUnifiable(subv) { 327 md.decoded[md.context.add(key).String()] = struct{}{} 328 md.context = append(md.context, key) 329 330 err := md.unify(datum, subv) 331 if err != nil { 332 return err 333 } 334 md.context = md.context[0 : len(md.context)-1] 335 } else if f.name != "" { 336 return md.e("cannot write unexported field %s.%s", rv.Type().String(), f.name) 337 } 338 } 339 } 340 return nil 341 } 342 343 func (md *MetaData) unifyMap(mapping any, rv reflect.Value) error { 344 keyType := rv.Type().Key().Kind() 345 if keyType != reflect.String && keyType != reflect.Interface { 346 return fmt.Errorf("toml: cannot decode to a map with non-string key type (%s in %q)", 347 keyType, rv.Type()) 348 } 349 350 tmap, ok := mapping.(map[string]any) 351 if !ok { 352 if tmap == nil { 353 return nil 354 } 355 return md.badtype("map", mapping) 356 } 357 if rv.IsNil() { 358 rv.Set(reflect.MakeMap(rv.Type())) 359 } 360 for k, v := range tmap { 361 md.decoded[md.context.add(k).String()] = struct{}{} 362 md.context = append(md.context, k) 363 364 rvval := reflect.Indirect(reflect.New(rv.Type().Elem())) 365 366 err := md.unify(v, indirect(rvval)) 367 if err != nil { 368 return err 369 } 370 md.context = md.context[0 : len(md.context)-1] 371 372 rvkey := indirect(reflect.New(rv.Type().Key())) 373 374 switch keyType { 375 case reflect.Interface: 376 rvkey.Set(reflect.ValueOf(k)) 377 case reflect.String: 378 rvkey.SetString(k) 379 } 380 381 rv.SetMapIndex(rvkey, rvval) 382 } 383 return nil 384 } 385 386 func (md *MetaData) unifyArray(data any, rv reflect.Value) error { 387 datav := reflect.ValueOf(data) 388 if datav.Kind() != reflect.Slice { 389 if !datav.IsValid() { 390 return nil 391 } 392 return md.badtype("slice", data) 393 } 394 if l := datav.Len(); l != rv.Len() { 395 return md.e("expected array length %d; got TOML array of length %d", rv.Len(), l) 396 } 397 return md.unifySliceArray(datav, rv) 398 } 399 400 func (md *MetaData) unifySlice(data any, rv reflect.Value) error { 401 datav := reflect.ValueOf(data) 402 if datav.Kind() != reflect.Slice { 403 if !datav.IsValid() { 404 return nil 405 } 406 return md.badtype("slice", data) 407 } 408 n := datav.Len() 409 if rv.IsNil() || rv.Cap() < n { 410 rv.Set(reflect.MakeSlice(rv.Type(), n, n)) 411 } 412 rv.SetLen(n) 413 return md.unifySliceArray(datav, rv) 414 } 415 416 func (md *MetaData) unifySliceArray(data, rv reflect.Value) error { 417 l := data.Len() 418 for i := 0; i < l; i++ { 419 err := md.unify(data.Index(i).Interface(), indirect(rv.Index(i))) 420 if err != nil { 421 return err 422 } 423 } 424 return nil 425 } 426 427 func (md *MetaData) unifyString(data any, rv reflect.Value) error { 428 _, ok := rv.Interface().(json.Number) 429 if ok { 430 if i, ok := data.(int64); ok { 431 rv.SetString(strconv.FormatInt(i, 10)) 432 } else if f, ok := data.(float64); ok { 433 rv.SetString(strconv.FormatFloat(f, 'g', -1, 64)) 434 } else { 435 return md.badtype("string", data) 436 } 437 return nil 438 } 439 440 if s, ok := data.(string); ok { 441 rv.SetString(s) 442 return nil 443 } 444 return md.badtype("string", data) 445 } 446 447 func (md *MetaData) unifyFloat64(data any, rv reflect.Value) error { 448 rvk := rv.Kind() 449 450 if num, ok := data.(float64); ok { 451 switch rvk { 452 case reflect.Float32: 453 if num < -math.MaxFloat32 || num > math.MaxFloat32 { 454 return md.parseErr(errParseRange{i: num, size: rvk.String()}) 455 } 456 fallthrough 457 case reflect.Float64: 458 rv.SetFloat(num) 459 default: 460 panic("bug") 461 } 462 return nil 463 } 464 465 if num, ok := data.(int64); ok { 466 if (rvk == reflect.Float32 && (num < -maxSafeFloat32Int || num > maxSafeFloat32Int)) || 467 (rvk == reflect.Float64 && (num < -maxSafeFloat64Int || num > maxSafeFloat64Int)) { 468 return md.parseErr(errUnsafeFloat{i: num, size: rvk.String()}) 469 } 470 rv.SetFloat(float64(num)) 471 return nil 472 } 473 474 return md.badtype("float", data) 475 } 476 477 func (md *MetaData) unifyInt(data any, rv reflect.Value) error { 478 _, ok := rv.Interface().(time.Duration) 479 if ok { 480 // Parse as string duration, and fall back to regular integer parsing 481 // (as nanosecond) if this is not a string. 482 if s, ok := data.(string); ok { 483 dur, err := time.ParseDuration(s) 484 if err != nil { 485 return md.parseErr(errParseDuration{s}) 486 } 487 rv.SetInt(int64(dur)) 488 return nil 489 } 490 } 491 492 num, ok := data.(int64) 493 if !ok { 494 return md.badtype("integer", data) 495 } 496 497 rvk := rv.Kind() 498 switch { 499 case rvk >= reflect.Int && rvk <= reflect.Int64: 500 if (rvk == reflect.Int8 && (num < math.MinInt8 || num > math.MaxInt8)) || 501 (rvk == reflect.Int16 && (num < math.MinInt16 || num > math.MaxInt16)) || 502 (rvk == reflect.Int32 && (num < math.MinInt32 || num > math.MaxInt32)) { 503 return md.parseErr(errParseRange{i: num, size: rvk.String()}) 504 } 505 rv.SetInt(num) 506 case rvk >= reflect.Uint && rvk <= reflect.Uint64: 507 unum := uint64(num) 508 if rvk == reflect.Uint8 && (num < 0 || unum > math.MaxUint8) || 509 rvk == reflect.Uint16 && (num < 0 || unum > math.MaxUint16) || 510 rvk == reflect.Uint32 && (num < 0 || unum > math.MaxUint32) { 511 return md.parseErr(errParseRange{i: num, size: rvk.String()}) 512 } 513 rv.SetUint(unum) 514 default: 515 panic("unreachable") 516 } 517 return nil 518 } 519 520 func (md *MetaData) unifyBool(data any, rv reflect.Value) error { 521 if b, ok := data.(bool); ok { 522 rv.SetBool(b) 523 return nil 524 } 525 return md.badtype("boolean", data) 526 } 527 528 func (md *MetaData) unifyAnything(data any, rv reflect.Value) error { 529 rv.Set(reflect.ValueOf(data)) 530 return nil 531 } 532 533 func (md *MetaData) unifyText(data any, v encoding.TextUnmarshaler) error { 534 var s string 535 switch sdata := data.(type) { 536 case Marshaler: 537 text, err := sdata.MarshalTOML() 538 if err != nil { 539 return err 540 } 541 s = string(text) 542 case encoding.TextMarshaler: 543 text, err := sdata.MarshalText() 544 if err != nil { 545 return err 546 } 547 s = string(text) 548 case fmt.Stringer: 549 s = sdata.String() 550 case string: 551 s = sdata 552 case bool: 553 s = fmt.Sprintf("%v", sdata) 554 case int64: 555 s = fmt.Sprintf("%d", sdata) 556 case float64: 557 s = fmt.Sprintf("%f", sdata) 558 default: 559 return md.badtype("primitive (string-like)", data) 560 } 561 if err := v.UnmarshalText([]byte(s)); err != nil { 562 return md.parseErr(err) 563 } 564 return nil 565 } 566 567 func (md *MetaData) badtype(dst string, data any) error { 568 return md.e("incompatible types: TOML value has type %s; destination has type %s", fmtType(data), dst) 569 } 570 571 func (md *MetaData) parseErr(err error) error { 572 k := md.context.String() 573 d := string(md.data) 574 return ParseError{ 575 Message: err.Error(), 576 err: err, 577 LastKey: k, 578 Position: md.keyInfo[k].pos.withCol(d), 579 Line: md.keyInfo[k].pos.Line, 580 input: d, 581 } 582 } 583 584 func (md *MetaData) e(format string, args ...any) error { 585 f := "toml: " 586 if len(md.context) > 0 { 587 f = fmt.Sprintf("toml: (last key %q): ", md.context) 588 p := md.keyInfo[md.context.String()].pos 589 if p.Line > 0 { 590 f = fmt.Sprintf("toml: line %d (last key %q): ", p.Line, md.context) 591 } 592 } 593 return fmt.Errorf(f+format, args...) 594 } 595 596 // rvalue returns a reflect.Value of `v`. All pointers are resolved. 597 func rvalue(v any) reflect.Value { 598 return indirect(reflect.ValueOf(v)) 599 } 600 601 // indirect returns the value pointed to by a pointer. 602 // 603 // Pointers are followed until the value is not a pointer. New values are 604 // allocated for each nil pointer. 605 // 606 // An exception to this rule is if the value satisfies an interface of interest 607 // to us (like encoding.TextUnmarshaler). 608 func indirect(v reflect.Value) reflect.Value { 609 if v.Kind() != reflect.Ptr { 610 if v.CanSet() { 611 pv := v.Addr() 612 pvi := pv.Interface() 613 if _, ok := pvi.(encoding.TextUnmarshaler); ok { 614 return pv 615 } 616 if _, ok := pvi.(Unmarshaler); ok { 617 return pv 618 } 619 } 620 return v 621 } 622 if v.IsNil() { 623 v.Set(reflect.New(v.Type().Elem())) 624 } 625 return indirect(reflect.Indirect(v)) 626 } 627 628 func isUnifiable(rv reflect.Value) bool { 629 if rv.CanSet() { 630 return true 631 } 632 rvi := rv.Interface() 633 if _, ok := rvi.(encoding.TextUnmarshaler); ok { 634 return true 635 } 636 if _, ok := rvi.(Unmarshaler); ok { 637 return true 638 } 639 return false 640 } 641 642 // fmt %T with "interface {}" replaced with "any", which is far more readable. 643 func fmtType(t any) string { 644 return strings.ReplaceAll(fmt.Sprintf("%T", t), "interface {}", "any") 645 }