]> Sergey Matveev's repositories - btrtrc.git/blob - bencode/decode.go
2905aec94554dc56a59f3ef1211322fe3d016281
[btrtrc.git] / bencode / decode.go
1 package bencode
2
3 import (
4         "bytes"
5         "errors"
6         "fmt"
7         "io"
8         "math/big"
9         "reflect"
10         "runtime"
11         "strconv"
12         "strings"
13 )
14
15 type Decoder struct {
16         r interface {
17                 io.ByteScanner
18                 io.Reader
19         }
20         // Sum of bytes used to Decode values.
21         Offset int64
22         buf    bytes.Buffer
23 }
24
25 func (d *Decoder) Decode(v interface{}) (err error) {
26         defer func() {
27                 if err != nil {
28                         return
29                 }
30                 r := recover()
31                 _, ok := r.(runtime.Error)
32                 if ok {
33                         panic(r)
34                 }
35                 err, ok = r.(error)
36                 if !ok && r != nil {
37                         panic(r)
38                 }
39         }()
40
41         pv := reflect.ValueOf(v)
42         if pv.Kind() != reflect.Ptr || pv.IsNil() {
43                 return &UnmarshalInvalidArgError{reflect.TypeOf(v)}
44         }
45
46         ok, err := d.parseValue(pv.Elem())
47         if err != nil {
48                 return
49         }
50         if !ok {
51                 d.throwSyntaxError(d.Offset-1, errors.New("unexpected 'e'"))
52         }
53         return
54 }
55
56 func checkForUnexpectedEOF(err error, offset int64) {
57         if err == io.EOF {
58                 panic(&SyntaxError{
59                         Offset: offset,
60                         What:   io.ErrUnexpectedEOF,
61                 })
62         }
63 }
64
65 func (d *Decoder) readByte() byte {
66         b, err := d.r.ReadByte()
67         if err != nil {
68                 checkForUnexpectedEOF(err, d.Offset)
69                 panic(err)
70         }
71
72         d.Offset++
73         return b
74 }
75
76 // reads data writing it to 'd.buf' until 'sep' byte is encountered, 'sep' byte
77 // is consumed, but not included into the 'd.buf'
78 func (d *Decoder) readUntil(sep byte) {
79         for {
80                 b := d.readByte()
81                 if b == sep {
82                         return
83                 }
84                 d.buf.WriteByte(b)
85         }
86 }
87
88 func checkForIntParseError(err error, offset int64) {
89         if err != nil {
90                 panic(&SyntaxError{
91                         Offset: offset,
92                         What:   err,
93                 })
94         }
95 }
96
97 func (d *Decoder) throwSyntaxError(offset int64, err error) {
98         panic(&SyntaxError{
99                 Offset: offset,
100                 What:   err,
101         })
102 }
103
104 // called when 'i' was consumed
105 func (d *Decoder) parseInt(v reflect.Value) {
106         start := d.Offset - 1
107         d.readUntil('e')
108         if d.buf.Len() == 0 {
109                 panic(&SyntaxError{
110                         Offset: start,
111                         What:   errors.New("empty integer value"),
112                 })
113         }
114
115         s := d.buf.String()
116
117         switch v.Kind() {
118         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
119                 n, err := strconv.ParseInt(s, 10, 64)
120                 checkForIntParseError(err, start)
121
122                 if v.OverflowInt(n) {
123                         panic(&UnmarshalTypeError{
124                                 Value: "integer " + s,
125                                 Type:  v.Type(),
126                         })
127                 }
128                 v.SetInt(n)
129         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
130                 n, err := strconv.ParseUint(s, 10, 64)
131                 checkForIntParseError(err, start)
132
133                 if v.OverflowUint(n) {
134                         panic(&UnmarshalTypeError{
135                                 Value: "integer " + s,
136                                 Type:  v.Type(),
137                         })
138                 }
139                 v.SetUint(n)
140         case reflect.Bool:
141                 v.SetBool(s != "0")
142         default:
143                 panic(&UnmarshalTypeError{
144                         Value: "integer " + s,
145                         Type:  v.Type(),
146                 })
147         }
148         d.buf.Reset()
149 }
150
151 func (d *Decoder) parseString(v reflect.Value) error {
152         start := d.Offset - 1
153
154         // read the string length first
155         d.readUntil(':')
156         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
157         checkForIntParseError(err, start)
158
159         d.buf.Reset()
160         n, err := io.CopyN(&d.buf, d.r, length)
161         d.Offset += n
162         if err != nil {
163                 checkForUnexpectedEOF(err, d.Offset)
164                 panic(&SyntaxError{
165                         Offset: d.Offset,
166                         What:   errors.New("unexpected I/O error: " + err.Error()),
167                 })
168         }
169
170         switch v.Kind() {
171         case reflect.String:
172                 v.SetString(d.buf.String())
173         case reflect.Slice:
174                 if v.Type().Elem().Kind() != reflect.Uint8 {
175                         panic(&UnmarshalTypeError{
176                                 Value: "string",
177                                 Type:  v.Type(),
178                         })
179                 }
180                 v.SetBytes(append([]byte(nil), d.buf.Bytes()...))
181         default:
182                 return &UnmarshalTypeError{
183                         Value: "string",
184                         Type:  v.Type(),
185                 }
186         }
187
188         d.buf.Reset()
189         return nil
190 }
191
192 // Info for parsing a dict value.
193 type dictField struct {
194         Value reflect.Value // Storage for the parsed value.
195         // True if field value should be parsed into Value. If false, the value
196         // should be parsed and discarded.
197         Ok                       bool
198         Set                      func() // Call this after parsing into Value.
199         IgnoreUnmarshalTypeError bool
200 }
201
202 // Returns specifics for parsing a dict field value.
203 func getDictField(dict reflect.Value, key string) dictField {
204         // get valuev as a map value or as a struct field
205         switch dict.Kind() {
206         case reflect.Map:
207                 value := reflect.New(dict.Type().Elem()).Elem()
208                 return dictField{
209                         Value: value,
210                         Ok:    true,
211                         Set: func() {
212                                 // Assigns the value into the map.
213                                 dict.SetMapIndex(reflect.ValueOf(key), value)
214                         },
215                 }
216         case reflect.Struct:
217                 sf, ok := getStructFieldForKey(dict.Type(), key)
218                 if !ok {
219                         return dictField{}
220                 }
221                 if sf.PkgPath != "" {
222                         panic(&UnmarshalFieldError{
223                                 Key:   key,
224                                 Type:  dict.Type(),
225                                 Field: sf,
226                         })
227                 }
228                 return dictField{
229                         Value:                    dict.FieldByIndex(sf.Index),
230                         Ok:                       true,
231                         Set:                      func() {},
232                         IgnoreUnmarshalTypeError: getTag(sf.Tag).IgnoreUnmarshalTypeError(),
233                 }
234         default:
235                 panic(dict.Kind())
236         }
237 }
238
239 func getStructFieldForKey(struct_ reflect.Type, key string) (f reflect.StructField, ok bool) {
240         for i, n := 0, struct_.NumField(); i < n; i++ {
241                 f = struct_.Field(i)
242                 tag := f.Tag.Get("bencode")
243                 if tag == "-" {
244                         continue
245                 }
246                 if f.Anonymous {
247                         continue
248                 }
249
250                 if parseTag(tag).Key() == key {
251                         ok = true
252                         break
253                 }
254
255                 if f.Name == key {
256                         ok = true
257                         break
258                 }
259
260                 if strings.EqualFold(f.Name, key) {
261                         ok = true
262                         break
263                 }
264         }
265         return
266 }
267
268 func (d *Decoder) parseDict(v reflect.Value) error {
269         switch v.Kind() {
270         case reflect.Map:
271                 t := v.Type()
272                 if t.Key().Kind() != reflect.String {
273                         panic(&UnmarshalTypeError{
274                                 Value: "object",
275                                 Type:  t,
276                         })
277                 }
278                 if v.IsNil() {
279                         v.Set(reflect.MakeMap(t))
280                 }
281         case reflect.Struct:
282         default:
283                 panic(&UnmarshalTypeError{
284                         Value: "object",
285                         Type:  v.Type(),
286                 })
287         }
288
289         // so, at this point 'd' byte was consumed, let's just read key/value
290         // pairs one by one
291         for {
292                 var keyStr string
293                 keyValue := reflect.ValueOf(&keyStr).Elem()
294                 ok, err := d.parseValue(keyValue)
295                 if err != nil {
296                         return fmt.Errorf("error parsing dict key: %s", err)
297                 }
298                 if !ok {
299                         return nil
300                 }
301
302                 df := getDictField(v, keyStr)
303
304                 // now we need to actually parse it
305                 if df.Ok {
306                         // log.Printf("parsing ok struct field for key %q", keyStr)
307                         ok, err = d.parseValue(df.Value)
308                 } else {
309                         // Discard the value, there's nowhere to put it.
310                         var if_ interface{}
311                         if_, ok = d.parseValueInterface()
312                         if if_ == nil {
313                                 err = fmt.Errorf("error parsing value for key %q", keyStr)
314                         }
315                 }
316                 if err != nil {
317                         if _, ok := err.(*UnmarshalTypeError); !ok || !df.IgnoreUnmarshalTypeError {
318                                 return fmt.Errorf("parsing value for key %q: %s", keyStr, err)
319                         }
320                 }
321                 if !ok {
322                         return fmt.Errorf("missing value for key %q", keyStr)
323                 }
324                 if df.Ok {
325                         df.Set()
326                 }
327         }
328 }
329
330 func (d *Decoder) parseList(v reflect.Value) error {
331         switch v.Kind() {
332         case reflect.Array, reflect.Slice:
333         default:
334                 panic(&UnmarshalTypeError{
335                         Value: "array",
336                         Type:  v.Type(),
337                 })
338         }
339
340         i := 0
341         for ; ; i++ {
342                 if v.Kind() == reflect.Slice && i >= v.Len() {
343                         v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
344                 }
345
346                 if i < v.Len() {
347                         ok, err := d.parseValue(v.Index(i))
348                         if err != nil {
349                                 return err
350                         }
351                         if !ok {
352                                 break
353                         }
354                 } else {
355                         _, ok := d.parseValueInterface()
356                         if !ok {
357                                 break
358                         }
359                 }
360         }
361
362         if i < v.Len() {
363                 if v.Kind() == reflect.Array {
364                         z := reflect.Zero(v.Type().Elem())
365                         for n := v.Len(); i < n; i++ {
366                                 v.Index(i).Set(z)
367                         }
368                 } else {
369                         v.SetLen(i)
370                 }
371         }
372
373         if i == 0 && v.Kind() == reflect.Slice {
374                 v.Set(reflect.MakeSlice(v.Type(), 0, 0))
375         }
376         return nil
377 }
378
379 func (d *Decoder) readOneValue() bool {
380         b, err := d.r.ReadByte()
381         if err != nil {
382                 panic(err)
383         }
384         if b == 'e' {
385                 d.r.UnreadByte()
386                 return false
387         } else {
388                 d.Offset++
389                 d.buf.WriteByte(b)
390         }
391
392         switch b {
393         case 'd', 'l':
394                 // read until there is nothing to read
395                 for d.readOneValue() {
396                 }
397                 // consume 'e' as well
398                 b = d.readByte()
399                 d.buf.WriteByte(b)
400         case 'i':
401                 d.readUntil('e')
402                 d.buf.WriteString("e")
403         default:
404                 if b >= '0' && b <= '9' {
405                         start := d.buf.Len() - 1
406                         d.readUntil(':')
407                         length, err := strconv.ParseInt(d.buf.String()[start:], 10, 64)
408                         checkForIntParseError(err, d.Offset-1)
409
410                         d.buf.WriteString(":")
411                         n, err := io.CopyN(&d.buf, d.r, length)
412                         d.Offset += n
413                         if err != nil {
414                                 checkForUnexpectedEOF(err, d.Offset)
415                                 panic(&SyntaxError{
416                                         Offset: d.Offset,
417                                         What:   errors.New("unexpected I/O error: " + err.Error()),
418                                 })
419                         }
420                         break
421                 }
422
423                 d.raiseUnknownValueType(b, d.Offset-1)
424         }
425
426         return true
427
428 }
429
430 func (d *Decoder) parseUnmarshaler(v reflect.Value) bool {
431         m, ok := v.Interface().(Unmarshaler)
432         if !ok {
433                 // T doesn't work, try *T
434                 if v.Kind() != reflect.Ptr && v.CanAddr() {
435                         m, ok = v.Addr().Interface().(Unmarshaler)
436                         if ok {
437                                 v = v.Addr()
438                         }
439                 }
440         }
441         if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
442                 if d.readOneValue() {
443                         err := m.UnmarshalBencode(d.buf.Bytes())
444                         d.buf.Reset()
445                         if err != nil {
446                                 panic(&UnmarshalerError{v.Type(), err})
447                         }
448                         return true
449                 }
450                 d.buf.Reset()
451         }
452
453         return false
454 }
455
456 // Returns true if there was a value and it's now stored in 'v', otherwise
457 // there was an end symbol ("e") and no value was stored.
458 func (d *Decoder) parseValue(v reflect.Value) (bool, error) {
459         // we support one level of indirection at the moment
460         if v.Kind() == reflect.Ptr {
461                 // if the pointer is nil, allocate a new element of the type it
462                 // points to
463                 if v.IsNil() {
464                         v.Set(reflect.New(v.Type().Elem()))
465                 }
466                 v = v.Elem()
467         }
468
469         if d.parseUnmarshaler(v) {
470                 return true, nil
471         }
472
473         // common case: interface{}
474         if v.Kind() == reflect.Interface && v.NumMethod() == 0 {
475                 iface, _ := d.parseValueInterface()
476                 v.Set(reflect.ValueOf(iface))
477                 return true, nil
478         }
479
480         b, err := d.r.ReadByte()
481         if err != nil {
482                 panic(err)
483         }
484         d.Offset++
485
486         switch b {
487         case 'e':
488                 return false, nil
489         case 'd':
490                 return true, d.parseDict(v)
491         case 'l':
492                 return true, d.parseList(v)
493         case 'i':
494                 d.parseInt(v)
495                 return true, nil
496         default:
497                 if b >= '0' && b <= '9' {
498                         // It's a string.
499                         d.buf.Reset()
500                         // Write the  first digit of the length to the buffer.
501                         d.buf.WriteByte(b)
502                         return true, d.parseString(v)
503                 }
504
505                 d.raiseUnknownValueType(b, d.Offset-1)
506         }
507         panic("unreachable")
508 }
509
510 // An unknown bencode type character was encountered.
511 func (d *Decoder) raiseUnknownValueType(b byte, offset int64) {
512         panic(&SyntaxError{
513                 Offset: offset,
514                 What:   fmt.Errorf("unknown value type %+q", b),
515         })
516 }
517
518 func (d *Decoder) parseValueInterface() (interface{}, bool) {
519         b, err := d.r.ReadByte()
520         if err != nil {
521                 panic(err)
522         }
523         d.Offset++
524
525         switch b {
526         case 'e':
527                 return nil, false
528         case 'd':
529                 return d.parseDictInterface(), true
530         case 'l':
531                 return d.parseListInterface(), true
532         case 'i':
533                 return d.parseIntInterface(), true
534         default:
535                 if b >= '0' && b <= '9' {
536                         // string
537                         // append first digit of the length to the buffer
538                         d.buf.WriteByte(b)
539                         return d.parseStringInterface(), true
540                 }
541
542                 d.raiseUnknownValueType(b, d.Offset-1)
543                 panic("unreachable")
544         }
545 }
546
547 func (d *Decoder) parseIntInterface() (ret interface{}) {
548         start := d.Offset - 1
549         d.readUntil('e')
550         if d.buf.Len() == 0 {
551                 panic(&SyntaxError{
552                         Offset: start,
553                         What:   errors.New("empty integer value"),
554                 })
555         }
556
557         n, err := strconv.ParseInt(d.buf.String(), 10, 64)
558         if ne, ok := err.(*strconv.NumError); ok && ne.Err == strconv.ErrRange {
559                 i := new(big.Int)
560                 _, ok := i.SetString(d.buf.String(), 10)
561                 if !ok {
562                         panic(&SyntaxError{
563                                 Offset: start,
564                                 What:   errors.New("failed to parse integer"),
565                         })
566                 }
567                 ret = i
568         } else {
569                 checkForIntParseError(err, start)
570                 ret = n
571         }
572
573         d.buf.Reset()
574         return
575 }
576
577 func (d *Decoder) parseStringInterface() interface{} {
578         start := d.Offset - 1
579
580         // read the string length first
581         d.readUntil(':')
582         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
583         checkForIntParseError(err, start)
584
585         d.buf.Reset()
586         n, err := io.CopyN(&d.buf, d.r, length)
587         d.Offset += n
588         if err != nil {
589                 checkForUnexpectedEOF(err, d.Offset)
590                 panic(&SyntaxError{
591                         Offset: d.Offset,
592                         What:   errors.New("unexpected I/O error: " + err.Error()),
593                 })
594         }
595
596         s := d.buf.String()
597         d.buf.Reset()
598         return s
599 }
600
601 func (d *Decoder) parseDictInterface() interface{} {
602         dict := make(map[string]interface{})
603         for {
604                 keyi, ok := d.parseValueInterface()
605                 if !ok {
606                         break
607                 }
608
609                 key, ok := keyi.(string)
610                 if !ok {
611                         panic(&SyntaxError{
612                                 Offset: d.Offset,
613                                 What:   errors.New("non-string key in a dict"),
614                         })
615                 }
616
617                 valuei, ok := d.parseValueInterface()
618                 if !ok {
619                         break
620                 }
621
622                 dict[key] = valuei
623         }
624         return dict
625 }
626
627 func (d *Decoder) parseListInterface() interface{} {
628         var list []interface{}
629         for {
630                 valuei, ok := d.parseValueInterface()
631                 if !ok {
632                         break
633                 }
634
635                 list = append(list, valuei)
636         }
637         if list == nil {
638                 list = make([]interface{}, 0, 0)
639         }
640         return list
641 }