Sergey Matveev's repositories - btrtrc.git/blob - bencode/decode.go
67554bcc8b68b0ab8d2327ca71dd2086e3cb20d0
[btrtrc.git] / bencode / decode.go
1 package bencode
2
3 import (
4         "bytes"
5         "errors"
6         "fmt"
7         "io"
8         "math/big"
9         "reflect"
10         "runtime"
11         "strconv"
12         "strings"
13 )
14
// Decoder reads bencoded values from an input stream and stores them into Go
// values via Decode.
type Decoder struct {
	// r is the input source. It must support both single-byte reads with
	// push-back and bulk reads.
	r interface {
		io.ByteScanner
		io.Reader
	}
	// Offset is the sum of bytes used to Decode values.
	Offset int64
	// buf is scratch space shared by the parsing helpers; each helper
	// accumulates the raw bytes of the token it is reading here.
	buf bytes.Buffer
	// key receives each dictionary key as it is parsed.
	key string
}
25
26 func (d *Decoder) Decode(v interface{}) (err error) {
27         defer func() {
28                 if e := recover(); e != nil {
29                         if _, ok := e.(runtime.Error); ok {
30                                 panic(e)
31                         }
32                         err = e.(error)
33                 }
34         }()
35
36         pv := reflect.ValueOf(v)
37         if pv.Kind() != reflect.Ptr || pv.IsNil() {
38                 return &UnmarshalInvalidArgError{reflect.TypeOf(v)}
39         }
40
41         ok, err := d.parseValue(pv.Elem())
42         if err != nil {
43                 return
44         }
45         if !ok {
46                 d.throwSyntaxError(d.Offset-1, errors.New("unexpected 'e'"))
47         }
48         return
49 }
50
51 func checkForUnexpectedEOF(err error, offset int64) {
52         if err == io.EOF {
53                 panic(&SyntaxError{
54                         Offset: offset,
55                         What:   io.ErrUnexpectedEOF,
56                 })
57         }
58 }
59
60 func (d *Decoder) readByte() byte {
61         b, err := d.r.ReadByte()
62         if err != nil {
63                 checkForUnexpectedEOF(err, d.Offset)
64                 panic(err)
65         }
66
67         d.Offset++
68         return b
69 }
70
71 // reads data writing it to 'd.buf' until 'sep' byte is encountered, 'sep' byte
72 // is consumed, but not included into the 'd.buf'
73 func (d *Decoder) readUntil(sep byte) {
74         for {
75                 b := d.readByte()
76                 if b == sep {
77                         return
78                 }
79                 d.buf.WriteByte(b)
80         }
81 }
82
83 func checkForIntParseError(err error, offset int64) {
84         if err != nil {
85                 panic(&SyntaxError{
86                         Offset: offset,
87                         What:   err,
88                 })
89         }
90 }
91
92 func (d *Decoder) throwSyntaxError(offset int64, err error) {
93         panic(&SyntaxError{
94                 Offset: offset,
95                 What:   err,
96         })
97 }
98
99 // called when 'i' was consumed
100 func (d *Decoder) parseInt(v reflect.Value) {
101         start := d.Offset - 1
102         d.readUntil('e')
103         if d.buf.Len() == 0 {
104                 panic(&SyntaxError{
105                         Offset: start,
106                         What:   errors.New("empty integer value"),
107                 })
108         }
109
110         s := d.buf.String()
111
112         switch v.Kind() {
113         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
114                 n, err := strconv.ParseInt(s, 10, 64)
115                 checkForIntParseError(err, start)
116
117                 if v.OverflowInt(n) {
118                         panic(&UnmarshalTypeError{
119                                 Value: "integer " + s,
120                                 Type:  v.Type(),
121                         })
122                 }
123                 v.SetInt(n)
124         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
125                 n, err := strconv.ParseUint(s, 10, 64)
126                 checkForIntParseError(err, start)
127
128                 if v.OverflowUint(n) {
129                         panic(&UnmarshalTypeError{
130                                 Value: "integer " + s,
131                                 Type:  v.Type(),
132                         })
133                 }
134                 v.SetUint(n)
135         case reflect.Bool:
136                 v.SetBool(s != "0")
137         default:
138                 panic(&UnmarshalTypeError{
139                         Value: "integer " + s,
140                         Type:  v.Type(),
141                 })
142         }
143         d.buf.Reset()
144 }
145
146 func (d *Decoder) parseString(v reflect.Value) error {
147         start := d.Offset - 1
148
149         // read the string length first
150         d.readUntil(':')
151         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
152         checkForIntParseError(err, start)
153
154         d.buf.Reset()
155         n, err := io.CopyN(&d.buf, d.r, length)
156         d.Offset += n
157         if err != nil {
158                 checkForUnexpectedEOF(err, d.Offset)
159                 panic(&SyntaxError{
160                         Offset: d.Offset,
161                         What:   errors.New("unexpected I/O error: " + err.Error()),
162                 })
163         }
164
165         switch v.Kind() {
166         case reflect.String:
167                 v.SetString(d.buf.String())
168         case reflect.Slice:
169                 if v.Type().Elem().Kind() != reflect.Uint8 {
170                         panic(&UnmarshalTypeError{
171                                 Value: "string",
172                                 Type:  v.Type(),
173                         })
174                 }
175                 sl := make([]byte, len(d.buf.Bytes()))
176                 copy(sl, d.buf.Bytes())
177                 v.Set(reflect.ValueOf(sl))
178         default:
179                 return &UnmarshalTypeError{
180                         Value: "string",
181                         Type:  v.Type(),
182                 }
183         }
184
185         d.buf.Reset()
186         return nil
187 }
188
189 func (d *Decoder) parseDict(v reflect.Value) error {
190         switch v.Kind() {
191         case reflect.Map:
192                 t := v.Type()
193                 if t.Key().Kind() != reflect.String {
194                         panic(&UnmarshalTypeError{
195                                 Value: "object",
196                                 Type:  t,
197                         })
198                 }
199                 if v.IsNil() {
200                         v.Set(reflect.MakeMap(t))
201                 }
202         case reflect.Struct:
203         default:
204                 panic(&UnmarshalTypeError{
205                         Value: "object",
206                         Type:  v.Type(),
207                 })
208         }
209
210         var mapElem reflect.Value
211
212         // so, at this point 'd' byte was consumed, let's just read key/value
213         // pairs one by one
214         for {
215                 var valuev reflect.Value
216                 keyv := reflect.ValueOf(&d.key).Elem()
217                 ok, err := d.parseValue(keyv)
218                 if err != nil {
219                         return fmt.Errorf("error parsing dict key: %s", err)
220                 }
221                 if !ok {
222                         return nil
223                 }
224
225                 // get valuev as a map value or as a struct field
226                 switch v.Kind() {
227                 case reflect.Map:
228                         elem_type := v.Type().Elem()
229                         if !mapElem.IsValid() {
230                                 mapElem = reflect.New(elem_type).Elem()
231                         } else {
232                                 mapElem.Set(reflect.Zero(elem_type))
233                         }
234                         valuev = mapElem
235                 case reflect.Struct:
236                         var f reflect.StructField
237                         var ok bool
238
239                         t := v.Type()
240                         for i, n := 0, t.NumField(); i < n; i++ {
241                                 f = t.Field(i)
242                                 tag := f.Tag.Get("bencode")
243                                 if tag == "-" {
244                                         continue
245                                 }
246                                 if f.Anonymous {
247                                         continue
248                                 }
249
250                                 tag_name, _ := parseTag(tag)
251                                 if tag_name == d.key {
252                                         ok = true
253                                         break
254                                 }
255
256                                 if f.Name == d.key {
257                                         ok = true
258                                         break
259                                 }
260
261                                 if strings.EqualFold(f.Name, d.key) {
262                                         ok = true
263                                         break
264                                 }
265                         }
266
267                         if ok {
268                                 if f.PkgPath != "" {
269                                         panic(&UnmarshalFieldError{
270                                                 Key:   d.key,
271                                                 Type:  v.Type(),
272                                                 Field: f,
273                                         })
274                                 } else {
275                                         valuev = v.FieldByIndex(f.Index)
276                                 }
277                         } else {
278                                 _, ok := d.parseValueInterface()
279                                 if !ok {
280                                         return fmt.Errorf("error parsing dict value for key %q", d.key)
281                                 }
282                                 continue
283                         }
284                 }
285
286                 // now we need to actually parse it
287                 ok, err = d.parseValue(valuev)
288                 if err != nil {
289                         return fmt.Errorf("parsing value for key %q: %s", d.key, err)
290                 }
291                 if !ok {
292                         return fmt.Errorf("missing value for key %q", d.key)
293                 }
294                 if v.Kind() == reflect.Map {
295                         v.SetMapIndex(keyv, valuev)
296                 }
297         }
298 }
299
300 func (d *Decoder) parseList(v reflect.Value) error {
301         switch v.Kind() {
302         case reflect.Array, reflect.Slice:
303         default:
304                 panic(&UnmarshalTypeError{
305                         Value: "array",
306                         Type:  v.Type(),
307                 })
308         }
309
310         i := 0
311         for ; ; i++ {
312                 if v.Kind() == reflect.Slice && i >= v.Len() {
313                         v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
314                 }
315
316                 if i < v.Len() {
317                         ok, err := d.parseValue(v.Index(i))
318                         if err != nil {
319                                 return err
320                         }
321                         if !ok {
322                                 break
323                         }
324                 } else {
325                         _, ok := d.parseValueInterface()
326                         if !ok {
327                                 break
328                         }
329                 }
330         }
331
332         if i < v.Len() {
333                 if v.Kind() == reflect.Array {
334                         z := reflect.Zero(v.Type().Elem())
335                         for n := v.Len(); i < n; i++ {
336                                 v.Index(i).Set(z)
337                         }
338                 } else {
339                         v.SetLen(i)
340                 }
341         }
342
343         if i == 0 && v.Kind() == reflect.Slice {
344                 v.Set(reflect.MakeSlice(v.Type(), 0, 0))
345         }
346         return nil
347 }
348
349 func (d *Decoder) readOneValue() bool {
350         b, err := d.r.ReadByte()
351         if err != nil {
352                 panic(err)
353         }
354         if b == 'e' {
355                 d.r.UnreadByte()
356                 return false
357         } else {
358                 d.Offset++
359                 d.buf.WriteByte(b)
360         }
361
362         switch b {
363         case 'd', 'l':
364                 // read until there is nothing to read
365                 for d.readOneValue() {
366                 }
367                 // consume 'e' as well
368                 b = d.readByte()
369                 d.buf.WriteByte(b)
370         case 'i':
371                 d.readUntil('e')
372                 d.buf.WriteString("e")
373         default:
374                 if b >= '0' && b <= '9' {
375                         start := d.buf.Len() - 1
376                         d.readUntil(':')
377                         length, err := strconv.ParseInt(d.buf.String()[start:], 10, 64)
378                         checkForIntParseError(err, d.Offset-1)
379
380                         d.buf.WriteString(":")
381                         n, err := io.CopyN(&d.buf, d.r, length)
382                         d.Offset += n
383                         if err != nil {
384                                 checkForUnexpectedEOF(err, d.Offset)
385                                 panic(&SyntaxError{
386                                         Offset: d.Offset,
387                                         What:   errors.New("unexpected I/O error: " + err.Error()),
388                                 })
389                         }
390                         break
391                 }
392
393                 d.raiseUnknownValueType(b, d.Offset-1)
394         }
395
396         return true
397
398 }
399
400 func (d *Decoder) parseUnmarshaler(v reflect.Value) bool {
401         m, ok := v.Interface().(Unmarshaler)
402         if !ok {
403                 // T doesn't work, try *T
404                 if v.Kind() != reflect.Ptr && v.CanAddr() {
405                         m, ok = v.Addr().Interface().(Unmarshaler)
406                         if ok {
407                                 v = v.Addr()
408                         }
409                 }
410         }
411         if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
412                 if d.readOneValue() {
413                         err := m.UnmarshalBencode(d.buf.Bytes())
414                         d.buf.Reset()
415                         if err != nil {
416                                 panic(&UnmarshalerError{v.Type(), err})
417                         }
418                         return true
419                 }
420                 d.buf.Reset()
421         }
422
423         return false
424 }
425
426 // Returns true if there was a value and it's now stored in 'v', otherwise
427 // there was an end symbol ("e") and no value was stored.
428 func (d *Decoder) parseValue(v reflect.Value) (bool, error) {
429         // we support one level of indirection at the moment
430         if v.Kind() == reflect.Ptr {
431                 // if the pointer is nil, allocate a new element of the type it
432                 // points to
433                 if v.IsNil() {
434                         v.Set(reflect.New(v.Type().Elem()))
435                 }
436                 v = v.Elem()
437         }
438
439         if d.parseUnmarshaler(v) {
440                 return true, nil
441         }
442
443         // common case: interface{}
444         if v.Kind() == reflect.Interface && v.NumMethod() == 0 {
445                 iface, _ := d.parseValueInterface()
446                 v.Set(reflect.ValueOf(iface))
447                 return true, nil
448         }
449
450         b, err := d.r.ReadByte()
451         if err != nil {
452                 panic(err)
453         }
454         d.Offset++
455
456         switch b {
457         case 'e':
458                 return false, nil
459         case 'd':
460                 return true, d.parseDict(v)
461         case 'l':
462                 return true, d.parseList(v)
463         case 'i':
464                 d.parseInt(v)
465                 return true, nil
466         default:
467                 if b >= '0' && b <= '9' {
468                         // string
469                         // append first digit of the length to the buffer
470                         d.buf.WriteByte(b)
471                         return true, d.parseString(v)
472                 }
473
474                 d.raiseUnknownValueType(b, d.Offset-1)
475         }
476         panic("unreachable")
477 }
478
479 // An unknown bencode type character was encountered.
480 func (d *Decoder) raiseUnknownValueType(b byte, offset int64) {
481         panic(&SyntaxError{
482                 Offset: offset,
483                 What:   fmt.Errorf("unknown value type %+q", b),
484         })
485 }
486
487 func (d *Decoder) parseValueInterface() (interface{}, bool) {
488         b, err := d.r.ReadByte()
489         if err != nil {
490                 panic(err)
491         }
492         d.Offset++
493
494         switch b {
495         case 'e':
496                 return nil, false
497         case 'd':
498                 return d.parseDictInterface(), true
499         case 'l':
500                 return d.parseListInterface(), true
501         case 'i':
502                 return d.parseIntInterface(), true
503         default:
504                 if b >= '0' && b <= '9' {
505                         // string
506                         // append first digit of the length to the buffer
507                         d.buf.WriteByte(b)
508                         return d.parseStringInterface(), true
509                 }
510
511                 d.raiseUnknownValueType(b, d.Offset-1)
512                 panic("unreachable")
513         }
514 }
515
516 func (d *Decoder) parseIntInterface() (ret interface{}) {
517         start := d.Offset - 1
518         d.readUntil('e')
519         if d.buf.Len() == 0 {
520                 panic(&SyntaxError{
521                         Offset: start,
522                         What:   errors.New("empty integer value"),
523                 })
524         }
525
526         n, err := strconv.ParseInt(d.buf.String(), 10, 64)
527         if ne, ok := err.(*strconv.NumError); ok && ne.Err == strconv.ErrRange {
528                 i := new(big.Int)
529                 _, ok := i.SetString(d.buf.String(), 10)
530                 if !ok {
531                         panic(&SyntaxError{
532                                 Offset: start,
533                                 What:   errors.New("failed to parse integer"),
534                         })
535                 }
536                 ret = i
537         } else {
538                 checkForIntParseError(err, start)
539                 ret = n
540         }
541
542         d.buf.Reset()
543         return
544 }
545
546 func (d *Decoder) parseStringInterface() interface{} {
547         start := d.Offset - 1
548
549         // read the string length first
550         d.readUntil(':')
551         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
552         checkForIntParseError(err, start)
553
554         d.buf.Reset()
555         n, err := io.CopyN(&d.buf, d.r, length)
556         d.Offset += n
557         if err != nil {
558                 checkForUnexpectedEOF(err, d.Offset)
559                 panic(&SyntaxError{
560                         Offset: d.Offset,
561                         What:   errors.New("unexpected I/O error: " + err.Error()),
562                 })
563         }
564
565         s := d.buf.String()
566         d.buf.Reset()
567         return s
568 }
569
570 func (d *Decoder) parseDictInterface() interface{} {
571         dict := make(map[string]interface{})
572         for {
573                 keyi, ok := d.parseValueInterface()
574                 if !ok {
575                         break
576                 }
577
578                 key, ok := keyi.(string)
579                 if !ok {
580                         panic(&SyntaxError{
581                                 Offset: d.Offset,
582                                 What:   errors.New("non-string key in a dict"),
583                         })
584                 }
585
586                 valuei, ok := d.parseValueInterface()
587                 if !ok {
588                         break
589                 }
590
591                 dict[key] = valuei
592         }
593         return dict
594 }
595
596 func (d *Decoder) parseListInterface() interface{} {
597         var list []interface{}
598         for {
599                 valuei, ok := d.parseValueInterface()
600                 if !ok {
601                         break
602                 }
603
604                 list = append(list, valuei)
605         }
606         if list == nil {
607                 list = make([]interface{}, 0, 0)
608         }
609         return list
610 }