]> Sergey Matveev's repositories - btrtrc.git/blob - bencode/decode.go
53ce6efd9035f9bb1034f3652b7c59e45bfd54be
[btrtrc.git] / bencode / decode.go
1 package bencode
2
3 import (
4         "bytes"
5         "errors"
6         "fmt"
7         "io"
8         "math/big"
9         "reflect"
10         "runtime"
11         "strconv"
12         "sync"
13 )
14
15 type Decoder struct {
16         r interface {
17                 io.ByteScanner
18                 io.Reader
19         }
20         // Sum of bytes used to Decode values.
21         Offset int64
22         buf    bytes.Buffer
23 }
24
25 func (d *Decoder) Decode(v interface{}) (err error) {
26         defer func() {
27                 if err != nil {
28                         return
29                 }
30                 r := recover()
31                 _, ok := r.(runtime.Error)
32                 if ok {
33                         panic(r)
34                 }
35                 err, ok = r.(error)
36                 if !ok && r != nil {
37                         panic(r)
38                 }
39         }()
40
41         pv := reflect.ValueOf(v)
42         if pv.Kind() != reflect.Ptr || pv.IsNil() {
43                 return &UnmarshalInvalidArgError{reflect.TypeOf(v)}
44         }
45
46         ok, err := d.parseValue(pv.Elem())
47         if err != nil {
48                 return
49         }
50         if !ok {
51                 d.throwSyntaxError(d.Offset-1, errors.New("unexpected 'e'"))
52         }
53         return
54 }
55
56 func checkForUnexpectedEOF(err error, offset int64) {
57         if err == io.EOF {
58                 panic(&SyntaxError{
59                         Offset: offset,
60                         What:   io.ErrUnexpectedEOF,
61                 })
62         }
63 }
64
65 func (d *Decoder) readByte() byte {
66         b, err := d.r.ReadByte()
67         if err != nil {
68                 checkForUnexpectedEOF(err, d.Offset)
69                 panic(err)
70         }
71
72         d.Offset++
73         return b
74 }
75
76 // reads data writing it to 'd.buf' until 'sep' byte is encountered, 'sep' byte
77 // is consumed, but not included into the 'd.buf'
78 func (d *Decoder) readUntil(sep byte) {
79         for {
80                 b := d.readByte()
81                 if b == sep {
82                         return
83                 }
84                 d.buf.WriteByte(b)
85         }
86 }
87
88 func checkForIntParseError(err error, offset int64) {
89         if err != nil {
90                 panic(&SyntaxError{
91                         Offset: offset,
92                         What:   err,
93                 })
94         }
95 }
96
97 func (d *Decoder) throwSyntaxError(offset int64, err error) {
98         panic(&SyntaxError{
99                 Offset: offset,
100                 What:   err,
101         })
102 }
103
104 // Assume the 'i' is already consumed. Read and validate the rest of an int into the buffer.
105 func (d *Decoder) readInt() error {
106         // start := d.Offset - 1
107         d.readUntil('e')
108         if err := d.checkBufferedInt(); err != nil {
109                 return err
110         }
111         // if d.buf.Len() == 0 {
112         //      panic(&SyntaxError{
113         //              Offset: start,
114         //              What:   errors.New("empty integer value"),
115         //      })
116         // }
117         return nil
118 }
119
120 // called when 'i' was consumed, for the integer type in v.
121 func (d *Decoder) parseInt(v reflect.Value) error {
122         start := d.Offset - 1
123
124         if err := d.readInt(); err != nil {
125                 return err
126         }
127         s := bytesAsString(d.buf.Bytes())
128
129         switch v.Kind() {
130         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
131                 n, err := strconv.ParseInt(s, 10, 64)
132                 checkForIntParseError(err, start)
133
134                 if v.OverflowInt(n) {
135                         return &UnmarshalTypeError{
136                                 BencodeTypeName:     "int",
137                                 UnmarshalTargetType: v.Type(),
138                         }
139                 }
140                 v.SetInt(n)
141         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
142                 n, err := strconv.ParseUint(s, 10, 64)
143                 checkForIntParseError(err, start)
144
145                 if v.OverflowUint(n) {
146                         return &UnmarshalTypeError{
147                                 BencodeTypeName:     "int",
148                                 UnmarshalTargetType: v.Type(),
149                         }
150                 }
151                 v.SetUint(n)
152         case reflect.Bool:
153                 v.SetBool(s != "0")
154         default:
155                 return &UnmarshalTypeError{
156                         BencodeTypeName:     "int",
157                         UnmarshalTargetType: v.Type(),
158                 }
159         }
160         d.buf.Reset()
161         return nil
162 }
163
164 func (d *Decoder) checkBufferedInt() error {
165         b := d.buf.Bytes()
166         if len(b) <= 1 {
167                 return nil
168         }
169         if b[0] == '-' {
170                 b = b[1:]
171         }
172         if b[0] < '1' || b[0] > '9' {
173                 return errors.New("invalid leading digit")
174         }
175         return nil
176 }
177
178 func (d *Decoder) parseStringLength() (uint64, error) {
179         // We should have already consumed the first byte of the length into the Decoder buf.
180         start := d.Offset - 1
181         d.readUntil(':')
182         if err := d.checkBufferedInt(); err != nil {
183                 return 0, err
184         }
185         length, err := strconv.ParseUint(bytesAsString(d.buf.Bytes()), 10, 32)
186         checkForIntParseError(err, start)
187         d.buf.Reset()
188         return length, err
189 }
190
191 func (d *Decoder) parseString(v reflect.Value) error {
192         length, err := d.parseStringLength()
193         if err != nil {
194                 return err
195         }
196         defer d.buf.Reset()
197         read := func(b []byte) {
198                 n, err := io.ReadFull(d.r, b)
199                 d.Offset += int64(n)
200                 if err != nil {
201                         checkForUnexpectedEOF(err, d.Offset)
202                         panic(&SyntaxError{
203                                 Offset: d.Offset,
204                                 What:   errors.New("unexpected I/O error: " + err.Error()),
205                         })
206                 }
207         }
208
209         switch v.Kind() {
210         case reflect.String:
211                 b := make([]byte, length)
212                 read(b)
213                 v.SetString(bytesAsString(b))
214                 return nil
215         case reflect.Slice:
216                 if v.Type().Elem().Kind() != reflect.Uint8 {
217                         break
218                 }
219                 b := make([]byte, length)
220                 read(b)
221                 v.SetBytes(b)
222                 return nil
223         case reflect.Array:
224                 if v.Type().Elem().Kind() != reflect.Uint8 {
225                         break
226                 }
227                 d.buf.Grow(int(length))
228                 b := d.buf.Bytes()[:length]
229                 read(b)
230                 reflect.Copy(v, reflect.ValueOf(b))
231                 return nil
232         }
233         d.buf.Grow(int(length))
234         read(d.buf.Bytes()[:length])
235         // I believe we return here to support "ignore_unmarshal_type_error".
236         return &UnmarshalTypeError{
237                 BencodeTypeName:     "string",
238                 UnmarshalTargetType: v.Type(),
239         }
240 }
241
242 // Info for parsing a dict value.
243 type dictField struct {
244         Type reflect.Type
245         Get  func(value reflect.Value) func(reflect.Value)
246         Tags tag
247 }
248
249 // Returns specifics for parsing a dict field value.
250 func getDictField(dict reflect.Type, key string) (_ dictField, err error) {
251         // get valuev as a map value or as a struct field
252         switch k := dict.Kind(); k {
253         case reflect.Map:
254                 return dictField{
255                         Type: dict.Elem(),
256                         Get: func(mapValue reflect.Value) func(reflect.Value) {
257                                 return func(value reflect.Value) {
258                                         if mapValue.IsNil() {
259                                                 mapValue.Set(reflect.MakeMap(dict))
260                                         }
261                                         // Assigns the value into the map.
262                                         // log.Printf("map type: %v", mapValue.Type())
263                                         mapValue.SetMapIndex(reflect.ValueOf(key).Convert(dict.Key()), value)
264                                 }
265                         },
266                 }, nil
267         case reflect.Struct:
268                 return getStructFieldForKey(dict, key), nil
269                 // if sf.r.PkgPath != "" {
270                 //      panic(&UnmarshalFieldError{
271                 //              Key:   key,
272                 //              Type:  dict.Type(),
273                 //              Field: sf.r,
274                 //      })
275                 // }
276         default:
277                 err = fmt.Errorf("can't assign bencode dict items into a %v", k)
278                 return
279         }
280 }
281
282 var (
283         structFieldsMu sync.Mutex
284         structFields   = map[reflect.Type]map[string]dictField{}
285 )
286
287 func parseStructFields(struct_ reflect.Type, each func(key string, df dictField)) {
288         for _i, n := 0, struct_.NumField(); _i < n; _i++ {
289                 i := _i
290                 f := struct_.Field(i)
291                 if f.Anonymous {
292                         t := f.Type
293                         if t.Kind() == reflect.Ptr {
294                                 t = t.Elem()
295                         }
296                         parseStructFields(t, func(key string, df dictField) {
297                                 innerGet := df.Get
298                                 df.Get = func(value reflect.Value) func(reflect.Value) {
299                                         anonPtr := value.Field(i)
300                                         if anonPtr.Kind() == reflect.Ptr && anonPtr.IsNil() {
301                                                 anonPtr.Set(reflect.New(f.Type.Elem()))
302                                                 anonPtr = anonPtr.Elem()
303                                         }
304                                         return innerGet(anonPtr)
305                                 }
306                                 each(key, df)
307                         })
308                         continue
309                 }
310                 tagStr := f.Tag.Get("bencode")
311                 if tagStr == "-" {
312                         continue
313                 }
314                 tag := parseTag(tagStr)
315                 key := tag.Key()
316                 if key == "" {
317                         key = f.Name
318                 }
319                 each(key, dictField{f.Type, func(value reflect.Value) func(reflect.Value) {
320                         return value.Field(i).Set
321                 }, tag})
322         }
323 }
324
325 func saveStructFields(struct_ reflect.Type) {
326         m := make(map[string]dictField)
327         parseStructFields(struct_, func(key string, sf dictField) {
328                 m[key] = sf
329         })
330         structFields[struct_] = m
331 }
332
333 func getStructFieldForKey(struct_ reflect.Type, key string) (f dictField) {
334         structFieldsMu.Lock()
335         if _, ok := structFields[struct_]; !ok {
336                 saveStructFields(struct_)
337         }
338         f, ok := structFields[struct_][key]
339         structFieldsMu.Unlock()
340         if !ok {
341                 var discard interface{}
342                 return dictField{
343                         Type: reflect.TypeOf(discard),
344                         Get:  func(reflect.Value) func(reflect.Value) { return func(reflect.Value) {} },
345                         Tags: nil,
346                 }
347         }
348         return
349 }
350
351 func (d *Decoder) parseDict(v reflect.Value) error {
352         // At this point 'd' byte was consumed, now read key/value pairs
353         for {
354                 var keyStr string
355                 keyValue := reflect.ValueOf(&keyStr).Elem()
356                 ok, err := d.parseValue(keyValue)
357                 if err != nil {
358                         return fmt.Errorf("error parsing dict key: %w", err)
359                 }
360                 if !ok {
361                         return nil
362                 }
363
364                 df, err := getDictField(v.Type(), keyStr)
365                 if err != nil {
366                         return fmt.Errorf("parsing bencode dict into %v: %w", v.Type(), err)
367                 }
368
369                 // now we need to actually parse it
370                 if df.Type == nil {
371                         // Discard the value, there's nowhere to put it.
372                         var if_ interface{}
373                         if_, ok = d.parseValueInterface()
374                         if if_ == nil {
375                                 return fmt.Errorf("error parsing value for key %q", keyStr)
376                         }
377                         if !ok {
378                                 return fmt.Errorf("missing value for key %q", keyStr)
379                         }
380                         continue
381                 }
382                 setValue := reflect.New(df.Type).Elem()
383                 // log.Printf("parsing into %v", setValue.Type())
384                 ok, err = d.parseValue(setValue)
385                 if err != nil {
386                         var target *UnmarshalTypeError
387                         if !(errors.As(err, &target) && df.Tags.IgnoreUnmarshalTypeError()) {
388                                 return fmt.Errorf("parsing value for key %q: %w", keyStr, err)
389                         }
390                 }
391                 if !ok {
392                         return fmt.Errorf("missing value for key %q", keyStr)
393                 }
394                 df.Get(v)(setValue)
395         }
396 }
397
398 func (d *Decoder) parseList(v reflect.Value) error {
399         switch v.Kind() {
400         default:
401                 // If the list is a singleton of the expected type, use that value. See
402                 // https://github.com/anacrolix/torrent/issues/297.
403                 l := reflect.New(reflect.SliceOf(v.Type()))
404                 if err := d.parseList(l.Elem()); err != nil {
405                         return err
406                 }
407                 if l.Elem().Len() != 1 {
408                         return &UnmarshalTypeError{
409                                 BencodeTypeName:     "list",
410                                 UnmarshalTargetType: v.Type(),
411                         }
412                 }
413                 v.Set(l.Elem().Index(0))
414                 return nil
415         case reflect.Array, reflect.Slice:
416                 // We can work with this. Normal case, fallthrough.
417         }
418
419         i := 0
420         for ; ; i++ {
421                 if v.Kind() == reflect.Slice && i >= v.Len() {
422                         v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
423                 }
424
425                 if i < v.Len() {
426                         ok, err := d.parseValue(v.Index(i))
427                         if err != nil {
428                                 return err
429                         }
430                         if !ok {
431                                 break
432                         }
433                 } else {
434                         _, ok := d.parseValueInterface()
435                         if !ok {
436                                 break
437                         }
438                 }
439         }
440
441         if i < v.Len() {
442                 if v.Kind() == reflect.Array {
443                         z := reflect.Zero(v.Type().Elem())
444                         for n := v.Len(); i < n; i++ {
445                                 v.Index(i).Set(z)
446                         }
447                 } else {
448                         v.SetLen(i)
449                 }
450         }
451
452         if i == 0 && v.Kind() == reflect.Slice {
453                 v.Set(reflect.MakeSlice(v.Type(), 0, 0))
454         }
455         return nil
456 }
457
458 func (d *Decoder) readOneValue() bool {
459         b, err := d.r.ReadByte()
460         if err != nil {
461                 panic(err)
462         }
463         if b == 'e' {
464                 d.r.UnreadByte()
465                 return false
466         } else {
467                 d.Offset++
468                 d.buf.WriteByte(b)
469         }
470
471         switch b {
472         case 'd', 'l':
473                 // read until there is nothing to read
474                 for d.readOneValue() {
475                 }
476                 // consume 'e' as well
477                 b = d.readByte()
478                 d.buf.WriteByte(b)
479         case 'i':
480                 d.readUntil('e')
481                 d.buf.WriteString("e")
482         default:
483                 if b >= '0' && b <= '9' {
484                         start := d.buf.Len() - 1
485                         d.readUntil(':')
486                         length, err := strconv.ParseInt(bytesAsString(d.buf.Bytes()[start:]), 10, 64)
487                         checkForIntParseError(err, d.Offset-1)
488
489                         d.buf.WriteString(":")
490                         n, err := io.CopyN(&d.buf, d.r, length)
491                         d.Offset += n
492                         if err != nil {
493                                 checkForUnexpectedEOF(err, d.Offset)
494                                 panic(&SyntaxError{
495                                         Offset: d.Offset,
496                                         What:   errors.New("unexpected I/O error: " + err.Error()),
497                                 })
498                         }
499                         break
500                 }
501
502                 d.raiseUnknownValueType(b, d.Offset-1)
503         }
504
505         return true
506 }
507
508 func (d *Decoder) parseUnmarshaler(v reflect.Value) bool {
509         if !v.Type().Implements(unmarshalerType) {
510                 if v.Addr().Type().Implements(unmarshalerType) {
511                         v = v.Addr()
512                 } else {
513                         return false
514                 }
515         }
516         d.buf.Reset()
517         if !d.readOneValue() {
518                 return false
519         }
520         m := v.Interface().(Unmarshaler)
521         err := m.UnmarshalBencode(d.buf.Bytes())
522         if err != nil {
523                 panic(&UnmarshalerError{v.Type(), err})
524         }
525         return true
526 }
527
528 // Returns true if there was a value and it's now stored in 'v', otherwise
529 // there was an end symbol ("e") and no value was stored.
530 func (d *Decoder) parseValue(v reflect.Value) (bool, error) {
531         // we support one level of indirection at the moment
532         if v.Kind() == reflect.Ptr {
533                 // if the pointer is nil, allocate a new element of the type it
534                 // points to
535                 if v.IsNil() {
536                         v.Set(reflect.New(v.Type().Elem()))
537                 }
538                 v = v.Elem()
539         }
540
541         if d.parseUnmarshaler(v) {
542                 return true, nil
543         }
544
545         // common case: interface{}
546         if v.Kind() == reflect.Interface && v.NumMethod() == 0 {
547                 iface, _ := d.parseValueInterface()
548                 v.Set(reflect.ValueOf(iface))
549                 return true, nil
550         }
551
552         b, err := d.r.ReadByte()
553         if err != nil {
554                 panic(err)
555         }
556         d.Offset++
557
558         switch b {
559         case 'e':
560                 return false, nil
561         case 'd':
562                 return true, d.parseDict(v)
563         case 'l':
564                 return true, d.parseList(v)
565         case 'i':
566                 return true, d.parseInt(v)
567         default:
568                 if b >= '0' && b <= '9' {
569                         // It's a string.
570                         d.buf.Reset()
571                         // Write the first digit of the length to the buffer.
572                         d.buf.WriteByte(b)
573                         return true, d.parseString(v)
574                 }
575
576                 d.raiseUnknownValueType(b, d.Offset-1)
577         }
578         panic("unreachable")
579 }
580
581 // An unknown bencode type character was encountered.
582 func (d *Decoder) raiseUnknownValueType(b byte, offset int64) {
583         panic(&SyntaxError{
584                 Offset: offset,
585                 What:   fmt.Errorf("unknown value type %+q", b),
586         })
587 }
588
589 func (d *Decoder) parseValueInterface() (interface{}, bool) {
590         b, err := d.r.ReadByte()
591         if err != nil {
592                 panic(err)
593         }
594         d.Offset++
595
596         switch b {
597         case 'e':
598                 return nil, false
599         case 'd':
600                 return d.parseDictInterface(), true
601         case 'l':
602                 return d.parseListInterface(), true
603         case 'i':
604                 return d.parseIntInterface(), true
605         default:
606                 if b >= '0' && b <= '9' {
607                         // string
608                         // append first digit of the length to the buffer
609                         d.buf.WriteByte(b)
610                         return d.parseStringInterface(), true
611                 }
612
613                 d.raiseUnknownValueType(b, d.Offset-1)
614                 panic("unreachable")
615         }
616 }
617
618 // Called after 'i', for an arbitrary integer size.
619 func (d *Decoder) parseIntInterface() (ret interface{}) {
620         start := d.Offset - 1
621
622         if err := d.readInt(); err != nil {
623                 panic(err)
624         }
625         n, err := strconv.ParseInt(d.buf.String(), 10, 64)
626         if ne, ok := err.(*strconv.NumError); ok && ne.Err == strconv.ErrRange {
627                 i := new(big.Int)
628                 _, ok := i.SetString(d.buf.String(), 10)
629                 if !ok {
630                         panic(&SyntaxError{
631                                 Offset: start,
632                                 What:   errors.New("failed to parse integer"),
633                         })
634                 }
635                 ret = i
636         } else {
637                 checkForIntParseError(err, start)
638                 ret = n
639         }
640
641         d.buf.Reset()
642         return
643 }
644
645 func (d *Decoder) readBytes(length int) []byte {
646         b, err := io.ReadAll(io.LimitReader(d.r, int64(length)))
647         if err != nil {
648                 panic(err)
649         }
650         if len(b) != length {
651                 panic(fmt.Errorf("read %v bytes expected %v", len(b), length))
652         }
653         return b
654 }
655
656 func (d *Decoder) parseStringInterface() string {
657         length, err := d.parseStringLength()
658         if err != nil {
659                 panic(err)
660         }
661         b := d.readBytes(int(length))
662         d.Offset += int64(len(b))
663         if err != nil {
664                 panic(&SyntaxError{Offset: d.Offset, What: err})
665         }
666         return bytesAsString(b)
667 }
668
669 func (d *Decoder) parseDictInterface() interface{} {
670         dict := make(map[string]interface{})
671         lastKey := ""
672         for {
673                 start := d.Offset
674                 keyi, ok := d.parseValueInterface()
675                 if !ok {
676                         break
677                 }
678
679                 key, ok := keyi.(string)
680                 if !ok {
681                         panic(&SyntaxError{
682                                 Offset: d.Offset,
683                                 What:   errors.New("non-string key in a dict"),
684                         })
685                 }
686                 if key <= lastKey {
687                         d.throwSyntaxError(start, fmt.Errorf("dict keys unsorted: %q <= %q", key, lastKey))
688                 }
689                 start = d.Offset
690                 valuei, ok := d.parseValueInterface()
691                 if !ok {
692                         d.throwSyntaxError(start, fmt.Errorf("dict elem missing value [key=%v]", key))
693                 }
694
695                 lastKey = key
696                 dict[key] = valuei
697         }
698         return dict
699 }
700
701 func (d *Decoder) parseListInterface() (list []interface{}) {
702         list = []interface{}{}
703         valuei, ok := d.parseValueInterface()
704         for ok {
705                 list = append(list, valuei)
706                 valuei, ok = d.parseValueInterface()
707         }
708         return
709 }