]> Sergey Matveev's repositories - btrtrc.git/blob - bencode/decode.go
bencode: Unmarshal now returns an error on unused trailing bytes
[btrtrc.git] / bencode / decode.go
1 package bencode
2
3 import (
4         "bytes"
5         "errors"
6         "fmt"
7         "io"
8         "math/big"
9         "reflect"
10         "runtime"
11         "strconv"
12         "strings"
13 )
14
15 type Decoder struct {
16         r interface {
17                 io.ByteScanner
18                 io.Reader
19         }
20         // Sum of bytes used to Decode values.
21         Offset int64
22         buf    bytes.Buffer
23 }
24
25 func (d *Decoder) Decode(v interface{}) (err error) {
26         defer func() {
27                 if e := recover(); e != nil {
28                         if _, ok := e.(runtime.Error); ok {
29                                 panic(e)
30                         }
31                         err = e.(error)
32                 }
33         }()
34
35         pv := reflect.ValueOf(v)
36         if pv.Kind() != reflect.Ptr || pv.IsNil() {
37                 return &UnmarshalInvalidArgError{reflect.TypeOf(v)}
38         }
39
40         ok, err := d.parseValue(pv.Elem())
41         if err != nil {
42                 return
43         }
44         if !ok {
45                 d.throwSyntaxError(d.Offset-1, errors.New("unexpected 'e'"))
46         }
47         return
48 }
49
50 func checkForUnexpectedEOF(err error, offset int64) {
51         if err == io.EOF {
52                 panic(&SyntaxError{
53                         Offset: offset,
54                         What:   io.ErrUnexpectedEOF,
55                 })
56         }
57 }
58
59 func (d *Decoder) readByte() byte {
60         b, err := d.r.ReadByte()
61         if err != nil {
62                 checkForUnexpectedEOF(err, d.Offset)
63                 panic(err)
64         }
65
66         d.Offset++
67         return b
68 }
69
70 // reads data writing it to 'd.buf' until 'sep' byte is encountered, 'sep' byte
71 // is consumed, but not included into the 'd.buf'
72 func (d *Decoder) readUntil(sep byte) {
73         for {
74                 b := d.readByte()
75                 if b == sep {
76                         return
77                 }
78                 d.buf.WriteByte(b)
79         }
80 }
81
82 func checkForIntParseError(err error, offset int64) {
83         if err != nil {
84                 panic(&SyntaxError{
85                         Offset: offset,
86                         What:   err,
87                 })
88         }
89 }
90
91 func (d *Decoder) throwSyntaxError(offset int64, err error) {
92         panic(&SyntaxError{
93                 Offset: offset,
94                 What:   err,
95         })
96 }
97
98 // called when 'i' was consumed
99 func (d *Decoder) parseInt(v reflect.Value) {
100         start := d.Offset - 1
101         d.readUntil('e')
102         if d.buf.Len() == 0 {
103                 panic(&SyntaxError{
104                         Offset: start,
105                         What:   errors.New("empty integer value"),
106                 })
107         }
108
109         s := d.buf.String()
110
111         switch v.Kind() {
112         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
113                 n, err := strconv.ParseInt(s, 10, 64)
114                 checkForIntParseError(err, start)
115
116                 if v.OverflowInt(n) {
117                         panic(&UnmarshalTypeError{
118                                 Value: "integer " + s,
119                                 Type:  v.Type(),
120                         })
121                 }
122                 v.SetInt(n)
123         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
124                 n, err := strconv.ParseUint(s, 10, 64)
125                 checkForIntParseError(err, start)
126
127                 if v.OverflowUint(n) {
128                         panic(&UnmarshalTypeError{
129                                 Value: "integer " + s,
130                                 Type:  v.Type(),
131                         })
132                 }
133                 v.SetUint(n)
134         case reflect.Bool:
135                 v.SetBool(s != "0")
136         default:
137                 panic(&UnmarshalTypeError{
138                         Value: "integer " + s,
139                         Type:  v.Type(),
140                 })
141         }
142         d.buf.Reset()
143 }
144
145 func (d *Decoder) parseString(v reflect.Value) error {
146         start := d.Offset - 1
147
148         // read the string length first
149         d.readUntil(':')
150         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
151         checkForIntParseError(err, start)
152
153         d.buf.Reset()
154         n, err := io.CopyN(&d.buf, d.r, length)
155         d.Offset += n
156         if err != nil {
157                 checkForUnexpectedEOF(err, d.Offset)
158                 panic(&SyntaxError{
159                         Offset: d.Offset,
160                         What:   errors.New("unexpected I/O error: " + err.Error()),
161                 })
162         }
163
164         switch v.Kind() {
165         case reflect.String:
166                 v.SetString(d.buf.String())
167         case reflect.Slice:
168                 if v.Type().Elem().Kind() != reflect.Uint8 {
169                         panic(&UnmarshalTypeError{
170                                 Value: "string",
171                                 Type:  v.Type(),
172                         })
173                 }
174                 sl := make([]byte, len(d.buf.Bytes()))
175                 copy(sl, d.buf.Bytes())
176                 v.Set(reflect.ValueOf(sl))
177         default:
178                 return &UnmarshalTypeError{
179                         Value: "string",
180                         Type:  v.Type(),
181                 }
182         }
183
184         d.buf.Reset()
185         return nil
186 }
187
188 // Info for parsing a dict value.
189 type dictField struct {
190         Value reflect.Value // Storage for the parsed value.
191         // True if field value should be parsed into Value. If false, the value
192         // should be parsed and discarded.
193         Ok                       bool
194         Set                      func() // Call this after parsing into Value.
195         IgnoreUnmarshalTypeError bool
196 }
197
198 // Returns specifics for parsing a dict field value.
199 func getDictField(dict reflect.Value, key string) dictField {
200         // get valuev as a map value or as a struct field
201         switch dict.Kind() {
202         case reflect.Map:
203                 value := reflect.New(dict.Type().Elem()).Elem()
204                 return dictField{
205                         Value: value,
206                         Ok:    true,
207                         Set: func() {
208                                 // Assigns the value into the map.
209                                 dict.SetMapIndex(reflect.ValueOf(key), value)
210                         },
211                 }
212         case reflect.Struct:
213                 sf, ok := getStructFieldForKey(dict.Type(), key)
214                 if !ok {
215                         return dictField{}
216                 }
217                 if sf.PkgPath != "" {
218                         panic(&UnmarshalFieldError{
219                                 Key:   key,
220                                 Type:  dict.Type(),
221                                 Field: sf,
222                         })
223                 }
224                 return dictField{
225                         Value: dict.FieldByIndex(sf.Index),
226                         Ok:    true,
227                         Set:   func() {},
228                         IgnoreUnmarshalTypeError: getTag(sf.Tag).IgnoreUnmarshalTypeError(),
229                 }
230         default:
231                 panic(dict.Kind())
232         }
233 }
234
235 func getStructFieldForKey(struct_ reflect.Type, key string) (f reflect.StructField, ok bool) {
236         for i, n := 0, struct_.NumField(); i < n; i++ {
237                 f = struct_.Field(i)
238                 tag := f.Tag.Get("bencode")
239                 if tag == "-" {
240                         continue
241                 }
242                 if f.Anonymous {
243                         continue
244                 }
245
246                 if parseTag(tag).Key() == key {
247                         ok = true
248                         break
249                 }
250
251                 if f.Name == key {
252                         ok = true
253                         break
254                 }
255
256                 if strings.EqualFold(f.Name, key) {
257                         ok = true
258                         break
259                 }
260         }
261         return
262 }
263
264 func (d *Decoder) parseDict(v reflect.Value) error {
265         switch v.Kind() {
266         case reflect.Map:
267                 t := v.Type()
268                 if t.Key().Kind() != reflect.String {
269                         panic(&UnmarshalTypeError{
270                                 Value: "object",
271                                 Type:  t,
272                         })
273                 }
274                 if v.IsNil() {
275                         v.Set(reflect.MakeMap(t))
276                 }
277         case reflect.Struct:
278         default:
279                 panic(&UnmarshalTypeError{
280                         Value: "object",
281                         Type:  v.Type(),
282                 })
283         }
284
285         // so, at this point 'd' byte was consumed, let's just read key/value
286         // pairs one by one
287         for {
288                 var keyStr string
289                 keyValue := reflect.ValueOf(&keyStr).Elem()
290                 ok, err := d.parseValue(keyValue)
291                 if err != nil {
292                         return fmt.Errorf("error parsing dict key: %s", err)
293                 }
294                 if !ok {
295                         return nil
296                 }
297
298                 df := getDictField(v, keyStr)
299
300                 // now we need to actually parse it
301                 if df.Ok {
302                         // log.Printf("parsing ok struct field for key %q", keyStr)
303                         ok, err = d.parseValue(df.Value)
304                 } else {
305                         // Discard the value, there's nowhere to put it.
306                         var if_ interface{}
307                         if_, ok = d.parseValueInterface()
308                         if if_ == nil {
309                                 err = fmt.Errorf("error parsing value for key %q", keyStr)
310                         }
311                 }
312                 if err != nil {
313                         if _, ok := err.(*UnmarshalTypeError); !ok || !df.IgnoreUnmarshalTypeError {
314                                 return fmt.Errorf("parsing value for key %q: %s", keyStr, err)
315                         }
316                 }
317                 if !ok {
318                         return fmt.Errorf("missing value for key %q", keyStr)
319                 }
320                 if df.Ok {
321                         df.Set()
322                 }
323         }
324 }
325
326 func (d *Decoder) parseList(v reflect.Value) error {
327         switch v.Kind() {
328         case reflect.Array, reflect.Slice:
329         default:
330                 panic(&UnmarshalTypeError{
331                         Value: "array",
332                         Type:  v.Type(),
333                 })
334         }
335
336         i := 0
337         for ; ; i++ {
338                 if v.Kind() == reflect.Slice && i >= v.Len() {
339                         v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
340                 }
341
342                 if i < v.Len() {
343                         ok, err := d.parseValue(v.Index(i))
344                         if err != nil {
345                                 return err
346                         }
347                         if !ok {
348                                 break
349                         }
350                 } else {
351                         _, ok := d.parseValueInterface()
352                         if !ok {
353                                 break
354                         }
355                 }
356         }
357
358         if i < v.Len() {
359                 if v.Kind() == reflect.Array {
360                         z := reflect.Zero(v.Type().Elem())
361                         for n := v.Len(); i < n; i++ {
362                                 v.Index(i).Set(z)
363                         }
364                 } else {
365                         v.SetLen(i)
366                 }
367         }
368
369         if i == 0 && v.Kind() == reflect.Slice {
370                 v.Set(reflect.MakeSlice(v.Type(), 0, 0))
371         }
372         return nil
373 }
374
375 func (d *Decoder) readOneValue() bool {
376         b, err := d.r.ReadByte()
377         if err != nil {
378                 panic(err)
379         }
380         if b == 'e' {
381                 d.r.UnreadByte()
382                 return false
383         } else {
384                 d.Offset++
385                 d.buf.WriteByte(b)
386         }
387
388         switch b {
389         case 'd', 'l':
390                 // read until there is nothing to read
391                 for d.readOneValue() {
392                 }
393                 // consume 'e' as well
394                 b = d.readByte()
395                 d.buf.WriteByte(b)
396         case 'i':
397                 d.readUntil('e')
398                 d.buf.WriteString("e")
399         default:
400                 if b >= '0' && b <= '9' {
401                         start := d.buf.Len() - 1
402                         d.readUntil(':')
403                         length, err := strconv.ParseInt(d.buf.String()[start:], 10, 64)
404                         checkForIntParseError(err, d.Offset-1)
405
406                         d.buf.WriteString(":")
407                         n, err := io.CopyN(&d.buf, d.r, length)
408                         d.Offset += n
409                         if err != nil {
410                                 checkForUnexpectedEOF(err, d.Offset)
411                                 panic(&SyntaxError{
412                                         Offset: d.Offset,
413                                         What:   errors.New("unexpected I/O error: " + err.Error()),
414                                 })
415                         }
416                         break
417                 }
418
419                 d.raiseUnknownValueType(b, d.Offset-1)
420         }
421
422         return true
423
424 }
425
426 func (d *Decoder) parseUnmarshaler(v reflect.Value) bool {
427         m, ok := v.Interface().(Unmarshaler)
428         if !ok {
429                 // T doesn't work, try *T
430                 if v.Kind() != reflect.Ptr && v.CanAddr() {
431                         m, ok = v.Addr().Interface().(Unmarshaler)
432                         if ok {
433                                 v = v.Addr()
434                         }
435                 }
436         }
437         if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
438                 if d.readOneValue() {
439                         err := m.UnmarshalBencode(d.buf.Bytes())
440                         d.buf.Reset()
441                         if err != nil {
442                                 panic(&UnmarshalerError{v.Type(), err})
443                         }
444                         return true
445                 }
446                 d.buf.Reset()
447         }
448
449         return false
450 }
451
452 // Returns true if there was a value and it's now stored in 'v', otherwise
453 // there was an end symbol ("e") and no value was stored.
454 func (d *Decoder) parseValue(v reflect.Value) (bool, error) {
455         // we support one level of indirection at the moment
456         if v.Kind() == reflect.Ptr {
457                 // if the pointer is nil, allocate a new element of the type it
458                 // points to
459                 if v.IsNil() {
460                         v.Set(reflect.New(v.Type().Elem()))
461                 }
462                 v = v.Elem()
463         }
464
465         if d.parseUnmarshaler(v) {
466                 return true, nil
467         }
468
469         // common case: interface{}
470         if v.Kind() == reflect.Interface && v.NumMethod() == 0 {
471                 iface, _ := d.parseValueInterface()
472                 v.Set(reflect.ValueOf(iface))
473                 return true, nil
474         }
475
476         b, err := d.r.ReadByte()
477         if err != nil {
478                 panic(err)
479         }
480         d.Offset++
481
482         switch b {
483         case 'e':
484                 return false, nil
485         case 'd':
486                 return true, d.parseDict(v)
487         case 'l':
488                 return true, d.parseList(v)
489         case 'i':
490                 d.parseInt(v)
491                 return true, nil
492         default:
493                 if b >= '0' && b <= '9' {
494                         // string
495                         // append first digit of the length to the buffer
496                         d.buf.WriteByte(b)
497                         return true, d.parseString(v)
498                 }
499
500                 d.raiseUnknownValueType(b, d.Offset-1)
501         }
502         panic("unreachable")
503 }
504
505 // An unknown bencode type character was encountered.
506 func (d *Decoder) raiseUnknownValueType(b byte, offset int64) {
507         panic(&SyntaxError{
508                 Offset: offset,
509                 What:   fmt.Errorf("unknown value type %+q", b),
510         })
511 }
512
513 func (d *Decoder) parseValueInterface() (interface{}, bool) {
514         b, err := d.r.ReadByte()
515         if err != nil {
516                 panic(err)
517         }
518         d.Offset++
519
520         switch b {
521         case 'e':
522                 return nil, false
523         case 'd':
524                 return d.parseDictInterface(), true
525         case 'l':
526                 return d.parseListInterface(), true
527         case 'i':
528                 return d.parseIntInterface(), true
529         default:
530                 if b >= '0' && b <= '9' {
531                         // string
532                         // append first digit of the length to the buffer
533                         d.buf.WriteByte(b)
534                         return d.parseStringInterface(), true
535                 }
536
537                 d.raiseUnknownValueType(b, d.Offset-1)
538                 panic("unreachable")
539         }
540 }
541
542 func (d *Decoder) parseIntInterface() (ret interface{}) {
543         start := d.Offset - 1
544         d.readUntil('e')
545         if d.buf.Len() == 0 {
546                 panic(&SyntaxError{
547                         Offset: start,
548                         What:   errors.New("empty integer value"),
549                 })
550         }
551
552         n, err := strconv.ParseInt(d.buf.String(), 10, 64)
553         if ne, ok := err.(*strconv.NumError); ok && ne.Err == strconv.ErrRange {
554                 i := new(big.Int)
555                 _, ok := i.SetString(d.buf.String(), 10)
556                 if !ok {
557                         panic(&SyntaxError{
558                                 Offset: start,
559                                 What:   errors.New("failed to parse integer"),
560                         })
561                 }
562                 ret = i
563         } else {
564                 checkForIntParseError(err, start)
565                 ret = n
566         }
567
568         d.buf.Reset()
569         return
570 }
571
572 func (d *Decoder) parseStringInterface() interface{} {
573         start := d.Offset - 1
574
575         // read the string length first
576         d.readUntil(':')
577         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
578         checkForIntParseError(err, start)
579
580         d.buf.Reset()
581         n, err := io.CopyN(&d.buf, d.r, length)
582         d.Offset += n
583         if err != nil {
584                 checkForUnexpectedEOF(err, d.Offset)
585                 panic(&SyntaxError{
586                         Offset: d.Offset,
587                         What:   errors.New("unexpected I/O error: " + err.Error()),
588                 })
589         }
590
591         s := d.buf.String()
592         d.buf.Reset()
593         return s
594 }
595
596 func (d *Decoder) parseDictInterface() interface{} {
597         dict := make(map[string]interface{})
598         for {
599                 keyi, ok := d.parseValueInterface()
600                 if !ok {
601                         break
602                 }
603
604                 key, ok := keyi.(string)
605                 if !ok {
606                         panic(&SyntaxError{
607                                 Offset: d.Offset,
608                                 What:   errors.New("non-string key in a dict"),
609                         })
610                 }
611
612                 valuei, ok := d.parseValueInterface()
613                 if !ok {
614                         break
615                 }
616
617                 dict[key] = valuei
618         }
619         return dict
620 }
621
622 func (d *Decoder) parseListInterface() interface{} {
623         var list []interface{}
624         for {
625                 valuei, ok := d.parseValueInterface()
626                 if !ok {
627                         break
628                 }
629
630                 list = append(list, valuei)
631         }
632         if list == nil {
633                 list = make([]interface{}, 0, 0)
634         }
635         return list
636 }