]> Sergey Matveev's repositories - btrtrc.git/blob - bencode/decode.go
bencode: Support unmarshalling strings into slices of kind Uint8
[btrtrc.git] / bencode / decode.go
1 package bencode
2
3 import (
4         "bytes"
5         "errors"
6         "fmt"
7         "io"
8         "math/big"
9         "reflect"
10         "runtime"
11         "strconv"
12         "strings"
13 )
14
15 type Decoder struct {
16         r interface {
17                 io.ByteScanner
18                 io.Reader
19         }
20         // Sum of bytes used to Decode values.
21         Offset int64
22         buf    bytes.Buffer
23 }
24
25 func (d *Decoder) Decode(v interface{}) (err error) {
26         defer func() {
27                 if e := recover(); e != nil {
28                         if _, ok := e.(runtime.Error); ok {
29                                 panic(e)
30                         }
31                         err = e.(error)
32                 }
33         }()
34
35         pv := reflect.ValueOf(v)
36         if pv.Kind() != reflect.Ptr || pv.IsNil() {
37                 return &UnmarshalInvalidArgError{reflect.TypeOf(v)}
38         }
39
40         ok, err := d.parseValue(pv.Elem())
41         if err != nil {
42                 return
43         }
44         if !ok {
45                 d.throwSyntaxError(d.Offset-1, errors.New("unexpected 'e'"))
46         }
47         return
48 }
49
50 func checkForUnexpectedEOF(err error, offset int64) {
51         if err == io.EOF {
52                 panic(&SyntaxError{
53                         Offset: offset,
54                         What:   io.ErrUnexpectedEOF,
55                 })
56         }
57 }
58
59 func (d *Decoder) readByte() byte {
60         b, err := d.r.ReadByte()
61         if err != nil {
62                 checkForUnexpectedEOF(err, d.Offset)
63                 panic(err)
64         }
65
66         d.Offset++
67         return b
68 }
69
70 // reads data writing it to 'd.buf' until 'sep' byte is encountered, 'sep' byte
71 // is consumed, but not included into the 'd.buf'
72 func (d *Decoder) readUntil(sep byte) {
73         for {
74                 b := d.readByte()
75                 if b == sep {
76                         return
77                 }
78                 d.buf.WriteByte(b)
79         }
80 }
81
82 func checkForIntParseError(err error, offset int64) {
83         if err != nil {
84                 panic(&SyntaxError{
85                         Offset: offset,
86                         What:   err,
87                 })
88         }
89 }
90
91 func (d *Decoder) throwSyntaxError(offset int64, err error) {
92         panic(&SyntaxError{
93                 Offset: offset,
94                 What:   err,
95         })
96 }
97
98 // called when 'i' was consumed
99 func (d *Decoder) parseInt(v reflect.Value) {
100         start := d.Offset - 1
101         d.readUntil('e')
102         if d.buf.Len() == 0 {
103                 panic(&SyntaxError{
104                         Offset: start,
105                         What:   errors.New("empty integer value"),
106                 })
107         }
108
109         s := d.buf.String()
110
111         switch v.Kind() {
112         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
113                 n, err := strconv.ParseInt(s, 10, 64)
114                 checkForIntParseError(err, start)
115
116                 if v.OverflowInt(n) {
117                         panic(&UnmarshalTypeError{
118                                 Value: "integer " + s,
119                                 Type:  v.Type(),
120                         })
121                 }
122                 v.SetInt(n)
123         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
124                 n, err := strconv.ParseUint(s, 10, 64)
125                 checkForIntParseError(err, start)
126
127                 if v.OverflowUint(n) {
128                         panic(&UnmarshalTypeError{
129                                 Value: "integer " + s,
130                                 Type:  v.Type(),
131                         })
132                 }
133                 v.SetUint(n)
134         case reflect.Bool:
135                 v.SetBool(s != "0")
136         default:
137                 panic(&UnmarshalTypeError{
138                         Value: "integer " + s,
139                         Type:  v.Type(),
140                 })
141         }
142         d.buf.Reset()
143 }
144
145 func (d *Decoder) parseString(v reflect.Value) error {
146         start := d.Offset - 1
147
148         // read the string length first
149         d.readUntil(':')
150         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
151         checkForIntParseError(err, start)
152
153         d.buf.Reset()
154         n, err := io.CopyN(&d.buf, d.r, length)
155         d.Offset += n
156         if err != nil {
157                 checkForUnexpectedEOF(err, d.Offset)
158                 panic(&SyntaxError{
159                         Offset: d.Offset,
160                         What:   errors.New("unexpected I/O error: " + err.Error()),
161                 })
162         }
163
164         switch v.Kind() {
165         case reflect.String:
166                 v.SetString(d.buf.String())
167         case reflect.Slice:
168                 if v.Type().Elem().Kind() != reflect.Uint8 {
169                         panic(&UnmarshalTypeError{
170                                 Value: "string",
171                                 Type:  v.Type(),
172                         })
173                 }
174                 v.SetBytes(append([]byte(nil), d.buf.Bytes()...))
175         default:
176                 return &UnmarshalTypeError{
177                         Value: "string",
178                         Type:  v.Type(),
179                 }
180         }
181
182         d.buf.Reset()
183         return nil
184 }
185
186 // Info for parsing a dict value.
187 type dictField struct {
188         Value reflect.Value // Storage for the parsed value.
189         // True if field value should be parsed into Value. If false, the value
190         // should be parsed and discarded.
191         Ok                       bool
192         Set                      func() // Call this after parsing into Value.
193         IgnoreUnmarshalTypeError bool
194 }
195
196 // Returns specifics for parsing a dict field value.
197 func getDictField(dict reflect.Value, key string) dictField {
198         // get valuev as a map value or as a struct field
199         switch dict.Kind() {
200         case reflect.Map:
201                 value := reflect.New(dict.Type().Elem()).Elem()
202                 return dictField{
203                         Value: value,
204                         Ok:    true,
205                         Set: func() {
206                                 // Assigns the value into the map.
207                                 dict.SetMapIndex(reflect.ValueOf(key), value)
208                         },
209                 }
210         case reflect.Struct:
211                 sf, ok := getStructFieldForKey(dict.Type(), key)
212                 if !ok {
213                         return dictField{}
214                 }
215                 if sf.PkgPath != "" {
216                         panic(&UnmarshalFieldError{
217                                 Key:   key,
218                                 Type:  dict.Type(),
219                                 Field: sf,
220                         })
221                 }
222                 return dictField{
223                         Value: dict.FieldByIndex(sf.Index),
224                         Ok:    true,
225                         Set:   func() {},
226                         IgnoreUnmarshalTypeError: getTag(sf.Tag).IgnoreUnmarshalTypeError(),
227                 }
228         default:
229                 panic(dict.Kind())
230         }
231 }
232
233 func getStructFieldForKey(struct_ reflect.Type, key string) (f reflect.StructField, ok bool) {
234         for i, n := 0, struct_.NumField(); i < n; i++ {
235                 f = struct_.Field(i)
236                 tag := f.Tag.Get("bencode")
237                 if tag == "-" {
238                         continue
239                 }
240                 if f.Anonymous {
241                         continue
242                 }
243
244                 if parseTag(tag).Key() == key {
245                         ok = true
246                         break
247                 }
248
249                 if f.Name == key {
250                         ok = true
251                         break
252                 }
253
254                 if strings.EqualFold(f.Name, key) {
255                         ok = true
256                         break
257                 }
258         }
259         return
260 }
261
262 func (d *Decoder) parseDict(v reflect.Value) error {
263         switch v.Kind() {
264         case reflect.Map:
265                 t := v.Type()
266                 if t.Key().Kind() != reflect.String {
267                         panic(&UnmarshalTypeError{
268                                 Value: "object",
269                                 Type:  t,
270                         })
271                 }
272                 if v.IsNil() {
273                         v.Set(reflect.MakeMap(t))
274                 }
275         case reflect.Struct:
276         default:
277                 panic(&UnmarshalTypeError{
278                         Value: "object",
279                         Type:  v.Type(),
280                 })
281         }
282
283         // so, at this point 'd' byte was consumed, let's just read key/value
284         // pairs one by one
285         for {
286                 var keyStr string
287                 keyValue := reflect.ValueOf(&keyStr).Elem()
288                 ok, err := d.parseValue(keyValue)
289                 if err != nil {
290                         return fmt.Errorf("error parsing dict key: %s", err)
291                 }
292                 if !ok {
293                         return nil
294                 }
295
296                 df := getDictField(v, keyStr)
297
298                 // now we need to actually parse it
299                 if df.Ok {
300                         // log.Printf("parsing ok struct field for key %q", keyStr)
301                         ok, err = d.parseValue(df.Value)
302                 } else {
303                         // Discard the value, there's nowhere to put it.
304                         var if_ interface{}
305                         if_, ok = d.parseValueInterface()
306                         if if_ == nil {
307                                 err = fmt.Errorf("error parsing value for key %q", keyStr)
308                         }
309                 }
310                 if err != nil {
311                         if _, ok := err.(*UnmarshalTypeError); !ok || !df.IgnoreUnmarshalTypeError {
312                                 return fmt.Errorf("parsing value for key %q: %s", keyStr, err)
313                         }
314                 }
315                 if !ok {
316                         return fmt.Errorf("missing value for key %q", keyStr)
317                 }
318                 if df.Ok {
319                         df.Set()
320                 }
321         }
322 }
323
324 func (d *Decoder) parseList(v reflect.Value) error {
325         switch v.Kind() {
326         case reflect.Array, reflect.Slice:
327         default:
328                 panic(&UnmarshalTypeError{
329                         Value: "array",
330                         Type:  v.Type(),
331                 })
332         }
333
334         i := 0
335         for ; ; i++ {
336                 if v.Kind() == reflect.Slice && i >= v.Len() {
337                         v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
338                 }
339
340                 if i < v.Len() {
341                         ok, err := d.parseValue(v.Index(i))
342                         if err != nil {
343                                 return err
344                         }
345                         if !ok {
346                                 break
347                         }
348                 } else {
349                         _, ok := d.parseValueInterface()
350                         if !ok {
351                                 break
352                         }
353                 }
354         }
355
356         if i < v.Len() {
357                 if v.Kind() == reflect.Array {
358                         z := reflect.Zero(v.Type().Elem())
359                         for n := v.Len(); i < n; i++ {
360                                 v.Index(i).Set(z)
361                         }
362                 } else {
363                         v.SetLen(i)
364                 }
365         }
366
367         if i == 0 && v.Kind() == reflect.Slice {
368                 v.Set(reflect.MakeSlice(v.Type(), 0, 0))
369         }
370         return nil
371 }
372
373 func (d *Decoder) readOneValue() bool {
374         b, err := d.r.ReadByte()
375         if err != nil {
376                 panic(err)
377         }
378         if b == 'e' {
379                 d.r.UnreadByte()
380                 return false
381         } else {
382                 d.Offset++
383                 d.buf.WriteByte(b)
384         }
385
386         switch b {
387         case 'd', 'l':
388                 // read until there is nothing to read
389                 for d.readOneValue() {
390                 }
391                 // consume 'e' as well
392                 b = d.readByte()
393                 d.buf.WriteByte(b)
394         case 'i':
395                 d.readUntil('e')
396                 d.buf.WriteString("e")
397         default:
398                 if b >= '0' && b <= '9' {
399                         start := d.buf.Len() - 1
400                         d.readUntil(':')
401                         length, err := strconv.ParseInt(d.buf.String()[start:], 10, 64)
402                         checkForIntParseError(err, d.Offset-1)
403
404                         d.buf.WriteString(":")
405                         n, err := io.CopyN(&d.buf, d.r, length)
406                         d.Offset += n
407                         if err != nil {
408                                 checkForUnexpectedEOF(err, d.Offset)
409                                 panic(&SyntaxError{
410                                         Offset: d.Offset,
411                                         What:   errors.New("unexpected I/O error: " + err.Error()),
412                                 })
413                         }
414                         break
415                 }
416
417                 d.raiseUnknownValueType(b, d.Offset-1)
418         }
419
420         return true
421
422 }
423
424 func (d *Decoder) parseUnmarshaler(v reflect.Value) bool {
425         m, ok := v.Interface().(Unmarshaler)
426         if !ok {
427                 // T doesn't work, try *T
428                 if v.Kind() != reflect.Ptr && v.CanAddr() {
429                         m, ok = v.Addr().Interface().(Unmarshaler)
430                         if ok {
431                                 v = v.Addr()
432                         }
433                 }
434         }
435         if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
436                 if d.readOneValue() {
437                         err := m.UnmarshalBencode(d.buf.Bytes())
438                         d.buf.Reset()
439                         if err != nil {
440                                 panic(&UnmarshalerError{v.Type(), err})
441                         }
442                         return true
443                 }
444                 d.buf.Reset()
445         }
446
447         return false
448 }
449
450 // Returns true if there was a value and it's now stored in 'v', otherwise
451 // there was an end symbol ("e") and no value was stored.
452 func (d *Decoder) parseValue(v reflect.Value) (bool, error) {
453         // we support one level of indirection at the moment
454         if v.Kind() == reflect.Ptr {
455                 // if the pointer is nil, allocate a new element of the type it
456                 // points to
457                 if v.IsNil() {
458                         v.Set(reflect.New(v.Type().Elem()))
459                 }
460                 v = v.Elem()
461         }
462
463         if d.parseUnmarshaler(v) {
464                 return true, nil
465         }
466
467         // common case: interface{}
468         if v.Kind() == reflect.Interface && v.NumMethod() == 0 {
469                 iface, _ := d.parseValueInterface()
470                 v.Set(reflect.ValueOf(iface))
471                 return true, nil
472         }
473
474         b, err := d.r.ReadByte()
475         if err != nil {
476                 panic(err)
477         }
478         d.Offset++
479
480         switch b {
481         case 'e':
482                 return false, nil
483         case 'd':
484                 return true, d.parseDict(v)
485         case 'l':
486                 return true, d.parseList(v)
487         case 'i':
488                 d.parseInt(v)
489                 return true, nil
490         default:
491                 if b >= '0' && b <= '9' {
492                         // string
493                         // append first digit of the length to the buffer
494                         d.buf.WriteByte(b)
495                         return true, d.parseString(v)
496                 }
497
498                 d.raiseUnknownValueType(b, d.Offset-1)
499         }
500         panic("unreachable")
501 }
502
503 // An unknown bencode type character was encountered.
504 func (d *Decoder) raiseUnknownValueType(b byte, offset int64) {
505         panic(&SyntaxError{
506                 Offset: offset,
507                 What:   fmt.Errorf("unknown value type %+q", b),
508         })
509 }
510
511 func (d *Decoder) parseValueInterface() (interface{}, bool) {
512         b, err := d.r.ReadByte()
513         if err != nil {
514                 panic(err)
515         }
516         d.Offset++
517
518         switch b {
519         case 'e':
520                 return nil, false
521         case 'd':
522                 return d.parseDictInterface(), true
523         case 'l':
524                 return d.parseListInterface(), true
525         case 'i':
526                 return d.parseIntInterface(), true
527         default:
528                 if b >= '0' && b <= '9' {
529                         // string
530                         // append first digit of the length to the buffer
531                         d.buf.WriteByte(b)
532                         return d.parseStringInterface(), true
533                 }
534
535                 d.raiseUnknownValueType(b, d.Offset-1)
536                 panic("unreachable")
537         }
538 }
539
540 func (d *Decoder) parseIntInterface() (ret interface{}) {
541         start := d.Offset - 1
542         d.readUntil('e')
543         if d.buf.Len() == 0 {
544                 panic(&SyntaxError{
545                         Offset: start,
546                         What:   errors.New("empty integer value"),
547                 })
548         }
549
550         n, err := strconv.ParseInt(d.buf.String(), 10, 64)
551         if ne, ok := err.(*strconv.NumError); ok && ne.Err == strconv.ErrRange {
552                 i := new(big.Int)
553                 _, ok := i.SetString(d.buf.String(), 10)
554                 if !ok {
555                         panic(&SyntaxError{
556                                 Offset: start,
557                                 What:   errors.New("failed to parse integer"),
558                         })
559                 }
560                 ret = i
561         } else {
562                 checkForIntParseError(err, start)
563                 ret = n
564         }
565
566         d.buf.Reset()
567         return
568 }
569
570 func (d *Decoder) parseStringInterface() interface{} {
571         start := d.Offset - 1
572
573         // read the string length first
574         d.readUntil(':')
575         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
576         checkForIntParseError(err, start)
577
578         d.buf.Reset()
579         n, err := io.CopyN(&d.buf, d.r, length)
580         d.Offset += n
581         if err != nil {
582                 checkForUnexpectedEOF(err, d.Offset)
583                 panic(&SyntaxError{
584                         Offset: d.Offset,
585                         What:   errors.New("unexpected I/O error: " + err.Error()),
586                 })
587         }
588
589         s := d.buf.String()
590         d.buf.Reset()
591         return s
592 }
593
594 func (d *Decoder) parseDictInterface() interface{} {
595         dict := make(map[string]interface{})
596         for {
597                 keyi, ok := d.parseValueInterface()
598                 if !ok {
599                         break
600                 }
601
602                 key, ok := keyi.(string)
603                 if !ok {
604                         panic(&SyntaxError{
605                                 Offset: d.Offset,
606                                 What:   errors.New("non-string key in a dict"),
607                         })
608                 }
609
610                 valuei, ok := d.parseValueInterface()
611                 if !ok {
612                         break
613                 }
614
615                 dict[key] = valuei
616         }
617         return dict
618 }
619
620 func (d *Decoder) parseListInterface() interface{} {
621         var list []interface{}
622         for {
623                 valuei, ok := d.parseValueInterface()
624                 if !ok {
625                         break
626                 }
627
628                 list = append(list, valuei)
629         }
630         if list == nil {
631                 list = make([]interface{}, 0, 0)
632         }
633         return list
634 }