]> Sergey Matveev's repositories - btrtrc.git/blob - bencode/decode.go
6300f5cd7a02ad11e8288680faa1a7f1412759cc
[btrtrc.git] / bencode / decode.go
1 package bencode
2
3 import (
4         "bytes"
5         "errors"
6         "fmt"
7         "io"
8         "math/big"
9         "reflect"
10         "runtime"
11         "strconv"
12         "sync"
13         "unsafe"
14 )
15
16 // The default bencode string length limit. This is a poor attempt to prevent excessive memory
17 // allocation when parsing, but also leaves the window open to implement a better solution.
18 const DefaultDecodeMaxStrLen = 1<<27 - 1 // ~128MiB
19
20 type MaxStrLen = int64
21
22 type Decoder struct {
23         // Maximum parsed bencode string length. Defaults to DefaultMaxStrLen if zero.
24         MaxStrLen MaxStrLen
25
26         r interface {
27                 io.ByteScanner
28                 io.Reader
29         }
30         // Sum of bytes used to Decode values.
31         Offset int64
32         buf    bytes.Buffer
33 }
34
35 func (d *Decoder) Decode(v interface{}) (err error) {
36         defer func() {
37                 if err != nil {
38                         return
39                 }
40                 r := recover()
41                 if r == nil {
42                         return
43                 }
44                 _, ok := r.(runtime.Error)
45                 if ok {
46                         panic(r)
47                 }
48                 if err, ok = r.(error); !ok {
49                         panic(r)
50                 }
51                 // Errors thrown from deeper in parsing are unexpected. At value boundaries, errors should
52                 // be returned directly (at least until all the panic nonsense is removed entirely).
53                 if err == io.EOF {
54                         err = io.ErrUnexpectedEOF
55                 }
56         }()
57
58         pv := reflect.ValueOf(v)
59         if pv.Kind() != reflect.Ptr || pv.IsNil() {
60                 return &UnmarshalInvalidArgError{reflect.TypeOf(v)}
61         }
62
63         ok, err := d.parseValue(pv.Elem())
64         if err != nil {
65                 return
66         }
67         if !ok {
68                 d.throwSyntaxError(d.Offset-1, errors.New("unexpected 'e'"))
69         }
70         return
71 }
72
73 func checkForUnexpectedEOF(err error, offset int64) {
74         if err == io.EOF {
75                 panic(&SyntaxError{
76                         Offset: offset,
77                         What:   io.ErrUnexpectedEOF,
78                 })
79         }
80 }
81
82 func (d *Decoder) readByte() byte {
83         b, err := d.r.ReadByte()
84         if err != nil {
85                 checkForUnexpectedEOF(err, d.Offset)
86                 panic(err)
87         }
88
89         d.Offset++
90         return b
91 }
92
93 // reads data writing it to 'd.buf' until 'sep' byte is encountered, 'sep' byte
94 // is consumed, but not included into the 'd.buf'
95 func (d *Decoder) readUntil(sep byte) {
96         for {
97                 b := d.readByte()
98                 if b == sep {
99                         return
100                 }
101                 d.buf.WriteByte(b)
102         }
103 }
104
105 func checkForIntParseError(err error, offset int64) {
106         if err != nil {
107                 panic(&SyntaxError{
108                         Offset: offset,
109                         What:   err,
110                 })
111         }
112 }
113
114 func (d *Decoder) throwSyntaxError(offset int64, err error) {
115         panic(&SyntaxError{
116                 Offset: offset,
117                 What:   err,
118         })
119 }
120
121 // Assume the 'i' is already consumed. Read and validate the rest of an int into the buffer.
122 func (d *Decoder) readInt() error {
123         // start := d.Offset - 1
124         d.readUntil('e')
125         if err := d.checkBufferedInt(); err != nil {
126                 return err
127         }
128         // if d.buf.Len() == 0 {
129         //      panic(&SyntaxError{
130         //              Offset: start,
131         //              What:   errors.New("empty integer value"),
132         //      })
133         // }
134         return nil
135 }
136
137 // called when 'i' was consumed, for the integer type in v.
138 func (d *Decoder) parseInt(v reflect.Value) error {
139         start := d.Offset - 1
140
141         if err := d.readInt(); err != nil {
142                 return err
143         }
144         s := bytesAsString(d.buf.Bytes())
145
146         switch v.Kind() {
147         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
148                 n, err := strconv.ParseInt(s, 10, 64)
149                 checkForIntParseError(err, start)
150
151                 if v.OverflowInt(n) {
152                         return &UnmarshalTypeError{
153                                 BencodeTypeName:     "int",
154                                 UnmarshalTargetType: v.Type(),
155                         }
156                 }
157                 v.SetInt(n)
158         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
159                 n, err := strconv.ParseUint(s, 10, 64)
160                 checkForIntParseError(err, start)
161
162                 if v.OverflowUint(n) {
163                         return &UnmarshalTypeError{
164                                 BencodeTypeName:     "int",
165                                 UnmarshalTargetType: v.Type(),
166                         }
167                 }
168                 v.SetUint(n)
169         case reflect.Bool:
170                 v.SetBool(s != "0")
171         default:
172                 return &UnmarshalTypeError{
173                         BencodeTypeName:     "int",
174                         UnmarshalTargetType: v.Type(),
175                 }
176         }
177         d.buf.Reset()
178         return nil
179 }
180
181 func (d *Decoder) checkBufferedInt() error {
182         b := d.buf.Bytes()
183         if len(b) <= 1 {
184                 return nil
185         }
186         if b[0] == '-' {
187                 b = b[1:]
188         }
189         if b[0] < '1' || b[0] > '9' {
190                 return errors.New("invalid leading digit")
191         }
192         return nil
193 }
194
195 func (d *Decoder) parseStringLength() (int, error) {
196         // We should have already consumed the first byte of the length into the Decoder buf.
197         start := d.Offset - 1
198         d.readUntil(':')
199         if err := d.checkBufferedInt(); err != nil {
200                 return 0, err
201         }
202         // Really the limit should be the uint size for the platform. But we can't pass in an allocator,
203         // or limit total memory use in Go, the best we might hope to do is limit the size of a single
204         // decoded value (by reading it in in-place and then operating on a view).
205         length, err := strconv.ParseInt(bytesAsString(d.buf.Bytes()), 10, 0)
206         checkForIntParseError(err, start)
207         if int64(length) > d.getMaxStrLen() {
208                 err = fmt.Errorf("parsed string length %v exceeds limit (%v)", length, DefaultDecodeMaxStrLen)
209         }
210         d.buf.Reset()
211         return int(length), err
212 }
213
214 func (d *Decoder) parseString(v reflect.Value) error {
215         length, err := d.parseStringLength()
216         if err != nil {
217                 return err
218         }
219         defer d.buf.Reset()
220         read := func(b []byte) {
221                 n, err := io.ReadFull(d.r, b)
222                 d.Offset += int64(n)
223                 if err != nil {
224                         checkForUnexpectedEOF(err, d.Offset)
225                         panic(&SyntaxError{
226                                 Offset: d.Offset,
227                                 What:   errors.New("unexpected I/O error: " + err.Error()),
228                         })
229                 }
230         }
231
232         switch v.Kind() {
233         case reflect.String:
234                 b := make([]byte, length)
235                 read(b)
236                 v.SetString(bytesAsString(b))
237                 return nil
238         case reflect.Slice:
239                 if v.Type().Elem().Kind() != reflect.Uint8 {
240                         break
241                 }
242                 b := make([]byte, length)
243                 read(b)
244                 v.SetBytes(b)
245                 return nil
246         case reflect.Array:
247                 if v.Type().Elem().Kind() != reflect.Uint8 {
248                         break
249                 }
250                 d.buf.Grow(length)
251                 b := d.buf.Bytes()[:length]
252                 read(b)
253                 reflect.Copy(v, reflect.ValueOf(b))
254                 return nil
255         case reflect.Bool:
256                 d.buf.Grow(length)
257                 b := d.buf.Bytes()[:length]
258                 read(b)
259                 x, err := strconv.ParseBool(unsafe.String(unsafe.SliceData(b), len(b)))
260                 if err != nil {
261                         x = length != 0
262                 }
263                 v.SetBool(x)
264                 return nil
265         }
266         // Can't move this into default clause because some cases above fail through to here after
267         // additional checks.
268         d.buf.Grow(length)
269         read(d.buf.Bytes()[:length])
270         // I believe we return here to support "ignore_unmarshal_type_error".
271         return &UnmarshalTypeError{
272                 BencodeTypeName:     "string",
273                 UnmarshalTargetType: v.Type(),
274         }
275 }
276
277 // Info for parsing a dict value.
278 type dictField struct {
279         Type reflect.Type
280         Get  func(value reflect.Value) func(reflect.Value)
281         Tags tag
282 }
283
284 // Returns specifics for parsing a dict field value.
285 func getDictField(dict reflect.Type, key string) (_ dictField, err error) {
286         // get valuev as a map value or as a struct field
287         switch k := dict.Kind(); k {
288         case reflect.Map:
289                 return dictField{
290                         Type: dict.Elem(),
291                         Get: func(mapValue reflect.Value) func(reflect.Value) {
292                                 return func(value reflect.Value) {
293                                         if mapValue.IsNil() {
294                                                 mapValue.Set(reflect.MakeMap(dict))
295                                         }
296                                         // Assigns the value into the map.
297                                         // log.Printf("map type: %v", mapValue.Type())
298                                         mapValue.SetMapIndex(reflect.ValueOf(key).Convert(dict.Key()), value)
299                                 }
300                         },
301                 }, nil
302         case reflect.Struct:
303                 return getStructFieldForKey(dict, key), nil
304                 // if sf.r.PkgPath != "" {
305                 //      panic(&UnmarshalFieldError{
306                 //              Key:   key,
307                 //              Type:  dict.Type(),
308                 //              Field: sf.r,
309                 //      })
310                 // }
311         default:
312                 err = fmt.Errorf("can't assign bencode dict items into a %v", k)
313                 return
314         }
315 }
316
317 var (
318         structFieldsMu sync.Mutex
319         structFields   = map[reflect.Type]map[string]dictField{}
320 )
321
322 func parseStructFields(struct_ reflect.Type, each func(key string, df dictField)) {
323         for _i, n := 0, struct_.NumField(); _i < n; _i++ {
324                 i := _i
325                 f := struct_.Field(i)
326                 if f.Anonymous {
327                         t := f.Type
328                         if t.Kind() == reflect.Ptr {
329                                 t = t.Elem()
330                         }
331                         parseStructFields(t, func(key string, df dictField) {
332                                 innerGet := df.Get
333                                 df.Get = func(value reflect.Value) func(reflect.Value) {
334                                         anonPtr := value.Field(i)
335                                         if anonPtr.Kind() == reflect.Ptr && anonPtr.IsNil() {
336                                                 anonPtr.Set(reflect.New(f.Type.Elem()))
337                                                 anonPtr = anonPtr.Elem()
338                                         }
339                                         return innerGet(anonPtr)
340                                 }
341                                 each(key, df)
342                         })
343                         continue
344                 }
345                 tagStr := f.Tag.Get("bencode")
346                 if tagStr == "-" {
347                         continue
348                 }
349                 tag := parseTag(tagStr)
350                 key := tag.Key()
351                 if key == "" {
352                         key = f.Name
353                 }
354                 each(key, dictField{f.Type, func(value reflect.Value) func(reflect.Value) {
355                         return value.Field(i).Set
356                 }, tag})
357         }
358 }
359
360 func saveStructFields(struct_ reflect.Type) {
361         m := make(map[string]dictField)
362         parseStructFields(struct_, func(key string, sf dictField) {
363                 m[key] = sf
364         })
365         structFields[struct_] = m
366 }
367
368 func getStructFieldForKey(struct_ reflect.Type, key string) (f dictField) {
369         structFieldsMu.Lock()
370         if _, ok := structFields[struct_]; !ok {
371                 saveStructFields(struct_)
372         }
373         f, ok := structFields[struct_][key]
374         structFieldsMu.Unlock()
375         if !ok {
376                 var discard interface{}
377                 return dictField{
378                         Type: reflect.TypeOf(discard),
379                         Get:  func(reflect.Value) func(reflect.Value) { return func(reflect.Value) {} },
380                         Tags: nil,
381                 }
382         }
383         return
384 }
385
386 func (d *Decoder) parseDict(v reflect.Value) error {
387         // At this point 'd' byte was consumed, now read key/value pairs
388         for {
389                 var keyStr string
390                 keyValue := reflect.ValueOf(&keyStr).Elem()
391                 ok, err := d.parseValue(keyValue)
392                 if err != nil {
393                         return fmt.Errorf("error parsing dict key: %w", err)
394                 }
395                 if !ok {
396                         return nil
397                 }
398
399                 df, err := getDictField(v.Type(), keyStr)
400                 if err != nil {
401                         return fmt.Errorf("parsing bencode dict into %v: %w", v.Type(), err)
402                 }
403
404                 // now we need to actually parse it
405                 if df.Type == nil {
406                         // Discard the value, there's nowhere to put it.
407                         var if_ interface{}
408                         if_, ok = d.parseValueInterface()
409                         if if_ == nil {
410                                 return fmt.Errorf("error parsing value for key %q", keyStr)
411                         }
412                         if !ok {
413                                 return fmt.Errorf("missing value for key %q", keyStr)
414                         }
415                         continue
416                 }
417                 setValue := reflect.New(df.Type).Elem()
418                 // log.Printf("parsing into %v", setValue.Type())
419                 ok, err = d.parseValue(setValue)
420                 if err != nil {
421                         var target *UnmarshalTypeError
422                         if !(errors.As(err, &target) && df.Tags.IgnoreUnmarshalTypeError()) {
423                                 return fmt.Errorf("parsing value for key %q: %w", keyStr, err)
424                         }
425                 }
426                 if !ok {
427                         return fmt.Errorf("missing value for key %q", keyStr)
428                 }
429                 df.Get(v)(setValue)
430         }
431 }
432
433 func (d *Decoder) parseList(v reflect.Value) error {
434         switch v.Kind() {
435         default:
436                 // If the list is a singleton of the expected type, use that value. See
437                 // https://github.com/anacrolix/torrent/issues/297.
438                 l := reflect.New(reflect.SliceOf(v.Type()))
439                 if err := d.parseList(l.Elem()); err != nil {
440                         return err
441                 }
442                 if l.Elem().Len() != 1 {
443                         return &UnmarshalTypeError{
444                                 BencodeTypeName:     "list",
445                                 UnmarshalTargetType: v.Type(),
446                         }
447                 }
448                 v.Set(l.Elem().Index(0))
449                 return nil
450         case reflect.Array, reflect.Slice:
451                 // We can work with this. Normal case, fallthrough.
452         }
453
454         i := 0
455         for ; ; i++ {
456                 if v.Kind() == reflect.Slice && i >= v.Len() {
457                         v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
458                 }
459
460                 if i < v.Len() {
461                         ok, err := d.parseValue(v.Index(i))
462                         if err != nil {
463                                 return err
464                         }
465                         if !ok {
466                                 break
467                         }
468                 } else {
469                         _, ok := d.parseValueInterface()
470                         if !ok {
471                                 break
472                         }
473                 }
474         }
475
476         if i < v.Len() {
477                 if v.Kind() == reflect.Array {
478                         z := reflect.Zero(v.Type().Elem())
479                         for n := v.Len(); i < n; i++ {
480                                 v.Index(i).Set(z)
481                         }
482                 } else {
483                         v.SetLen(i)
484                 }
485         }
486
487         if i == 0 && v.Kind() == reflect.Slice {
488                 v.Set(reflect.MakeSlice(v.Type(), 0, 0))
489         }
490         return nil
491 }
492
493 func (d *Decoder) readOneValue() bool {
494         b, err := d.r.ReadByte()
495         if err != nil {
496                 panic(err)
497         }
498         if b == 'e' {
499                 d.r.UnreadByte()
500                 return false
501         } else {
502                 d.Offset++
503                 d.buf.WriteByte(b)
504         }
505
506         switch b {
507         case 'd', 'l':
508                 // read until there is nothing to read
509                 for d.readOneValue() {
510                 }
511                 // consume 'e' as well
512                 b = d.readByte()
513                 d.buf.WriteByte(b)
514         case 'i':
515                 d.readUntil('e')
516                 d.buf.WriteString("e")
517         default:
518                 if b >= '0' && b <= '9' {
519                         start := d.buf.Len() - 1
520                         d.readUntil(':')
521                         length, err := strconv.ParseInt(bytesAsString(d.buf.Bytes()[start:]), 10, 64)
522                         checkForIntParseError(err, d.Offset-1)
523
524                         d.buf.WriteString(":")
525                         n, err := io.CopyN(&d.buf, d.r, length)
526                         d.Offset += n
527                         if err != nil {
528                                 checkForUnexpectedEOF(err, d.Offset)
529                                 panic(&SyntaxError{
530                                         Offset: d.Offset,
531                                         What:   errors.New("unexpected I/O error: " + err.Error()),
532                                 })
533                         }
534                         break
535                 }
536
537                 d.raiseUnknownValueType(b, d.Offset-1)
538         }
539
540         return true
541 }
542
543 func (d *Decoder) parseUnmarshaler(v reflect.Value) bool {
544         if !v.Type().Implements(unmarshalerType) {
545                 if v.Addr().Type().Implements(unmarshalerType) {
546                         v = v.Addr()
547                 } else {
548                         return false
549                 }
550         }
551         d.buf.Reset()
552         if !d.readOneValue() {
553                 return false
554         }
555         m := v.Interface().(Unmarshaler)
556         err := m.UnmarshalBencode(d.buf.Bytes())
557         if err != nil {
558                 panic(&UnmarshalerError{v.Type(), err})
559         }
560         return true
561 }
562
563 // Returns true if there was a value and it's now stored in 'v', otherwise
564 // there was an end symbol ("e") and no value was stored.
565 func (d *Decoder) parseValue(v reflect.Value) (bool, error) {
566         // we support one level of indirection at the moment
567         if v.Kind() == reflect.Ptr {
568                 // if the pointer is nil, allocate a new element of the type it
569                 // points to
570                 if v.IsNil() {
571                         v.Set(reflect.New(v.Type().Elem()))
572                 }
573                 v = v.Elem()
574         }
575
576         if d.parseUnmarshaler(v) {
577                 return true, nil
578         }
579
580         // common case: interface{}
581         if v.Kind() == reflect.Interface && v.NumMethod() == 0 {
582                 iface, _ := d.parseValueInterface()
583                 v.Set(reflect.ValueOf(iface))
584                 return true, nil
585         }
586
587         b, err := d.r.ReadByte()
588         if err != nil {
589                 return false, err
590         }
591         d.Offset++
592
593         switch b {
594         case 'e':
595                 return false, nil
596         case 'd':
597                 return true, d.parseDict(v)
598         case 'l':
599                 return true, d.parseList(v)
600         case 'i':
601                 return true, d.parseInt(v)
602         default:
603                 if b >= '0' && b <= '9' {
604                         // It's a string.
605                         d.buf.Reset()
606                         // Write the first digit of the length to the buffer.
607                         d.buf.WriteByte(b)
608                         return true, d.parseString(v)
609                 }
610
611                 d.raiseUnknownValueType(b, d.Offset-1)
612         }
613         panic("unreachable")
614 }
615
616 // An unknown bencode type character was encountered.
617 func (d *Decoder) raiseUnknownValueType(b byte, offset int64) {
618         panic(&SyntaxError{
619                 Offset: offset,
620                 What:   fmt.Errorf("unknown value type %+q", b),
621         })
622 }
623
624 func (d *Decoder) parseValueInterface() (interface{}, bool) {
625         b, err := d.r.ReadByte()
626         if err != nil {
627                 panic(err)
628         }
629         d.Offset++
630
631         switch b {
632         case 'e':
633                 return nil, false
634         case 'd':
635                 return d.parseDictInterface(), true
636         case 'l':
637                 return d.parseListInterface(), true
638         case 'i':
639                 return d.parseIntInterface(), true
640         default:
641                 if b >= '0' && b <= '9' {
642                         // string
643                         // append first digit of the length to the buffer
644                         d.buf.WriteByte(b)
645                         return d.parseStringInterface(), true
646                 }
647
648                 d.raiseUnknownValueType(b, d.Offset-1)
649                 panic("unreachable")
650         }
651 }
652
653 // Called after 'i', for an arbitrary integer size.
654 func (d *Decoder) parseIntInterface() (ret interface{}) {
655         start := d.Offset - 1
656
657         if err := d.readInt(); err != nil {
658                 panic(err)
659         }
660         n, err := strconv.ParseInt(d.buf.String(), 10, 64)
661         if ne, ok := err.(*strconv.NumError); ok && ne.Err == strconv.ErrRange {
662                 i := new(big.Int)
663                 _, ok := i.SetString(d.buf.String(), 10)
664                 if !ok {
665                         panic(&SyntaxError{
666                                 Offset: start,
667                                 What:   errors.New("failed to parse integer"),
668                         })
669                 }
670                 ret = i
671         } else {
672                 checkForIntParseError(err, start)
673                 ret = n
674         }
675
676         d.buf.Reset()
677         return
678 }
679
680 func (d *Decoder) readBytes(length int) []byte {
681         b, err := io.ReadAll(io.LimitReader(d.r, int64(length)))
682         if err != nil {
683                 panic(err)
684         }
685         if len(b) != length {
686                 panic(fmt.Errorf("read %v bytes expected %v", len(b), length))
687         }
688         return b
689 }
690
691 func (d *Decoder) parseStringInterface() string {
692         length, err := d.parseStringLength()
693         if err != nil {
694                 panic(err)
695         }
696         b := d.readBytes(int(length))
697         d.Offset += int64(len(b))
698         if err != nil {
699                 panic(&SyntaxError{Offset: d.Offset, What: err})
700         }
701         return bytesAsString(b)
702 }
703
704 func (d *Decoder) parseDictInterface() interface{} {
705         dict := make(map[string]interface{})
706         var lastKey string
707         lastKeyOk := false
708         for {
709                 start := d.Offset
710                 keyi, ok := d.parseValueInterface()
711                 if !ok {
712                         break
713                 }
714
715                 key, ok := keyi.(string)
716                 if !ok {
717                         panic(&SyntaxError{
718                                 Offset: d.Offset,
719                                 What:   errors.New("non-string key in a dict"),
720                         })
721                 }
722                 if lastKeyOk && key <= lastKey {
723                         d.throwSyntaxError(start, fmt.Errorf("dict keys unsorted: %q <= %q", key, lastKey))
724                 }
725                 start = d.Offset
726                 valuei, ok := d.parseValueInterface()
727                 if !ok {
728                         d.throwSyntaxError(start, fmt.Errorf("dict elem missing value [key=%v]", key))
729                 }
730
731                 lastKey = key
732                 lastKeyOk = true
733                 dict[key] = valuei
734         }
735         return dict
736 }
737
738 func (d *Decoder) parseListInterface() (list []interface{}) {
739         list = []interface{}{}
740         valuei, ok := d.parseValueInterface()
741         for ok {
742                 list = append(list, valuei)
743                 valuei, ok = d.parseValueInterface()
744         }
745         return
746 }
747
748 func (d *Decoder) getMaxStrLen() int64 {
749         if d.MaxStrLen == 0 {
750                 return DefaultDecodeMaxStrLen
751         }
752         return d.MaxStrLen
753 }