]> Sergey Matveev's repositories - btrtrc.git/blob - bencode/decode.go
ebd4684e2c956f140571d5526ecc9842de4b8e80
[btrtrc.git] / bencode / decode.go
1 package bencode
2
3 import (
4         "bytes"
5         "errors"
6         "fmt"
7         "io"
8         "math/big"
9         "reflect"
10         "runtime"
11         "strconv"
12         "strings"
13 )
14
// decoder holds the input and the scratch state used by the parse_* methods
// while decoding bencoded values.
type decoder struct {
	// r is the input. ReadByte/UnreadByte provide single-byte lookahead;
	// Read is used (through io.CopyN) to copy string payloads.
	r interface {
		io.ByteScanner
		io.Reader
	}
	// offset counts the bytes consumed from r; it is the position reported
	// in *SyntaxError values.
	offset int64
	// buf is scratch space for integer digits, string lengths and payloads,
	// and the raw bytes handed to Unmarshaler implementations.
	buf    bytes.Buffer
	// key receives each dictionary key as it is parsed (see parse_dict).
	key    string
}
24
25 func (d *decoder) decode(v interface{}) (err error) {
26         defer func() {
27                 if e := recover(); e != nil {
28                         if _, ok := e.(runtime.Error); ok {
29                                 panic(e)
30                         }
31                         err = e.(error)
32                 }
33         }()
34
35         pv := reflect.ValueOf(v)
36         if pv.Kind() != reflect.Ptr || pv.IsNil() {
37                 return &UnmarshalInvalidArgError{reflect.TypeOf(v)}
38         }
39
40         if !d.parse_value(pv.Elem()) {
41                 d.throwSyntaxError(d.offset-1, errors.New("unexpected 'e'"))
42         }
43         return nil
44 }
45
46 func check_for_unexpected_eof(err error, offset int64) {
47         if err == io.EOF {
48                 panic(&SyntaxError{
49                         Offset: offset,
50                         What:   io.ErrUnexpectedEOF,
51                 })
52         }
53 }
54
55 func (d *decoder) read_byte() byte {
56         b, err := d.r.ReadByte()
57         if err != nil {
58                 check_for_unexpected_eof(err, d.offset)
59                 panic(err)
60         }
61
62         d.offset++
63         return b
64 }
65
66 // reads data writing it to 'd.buf' until 'sep' byte is encountered, 'sep' byte
67 // is consumed, but not included into the 'd.buf'
68 func (d *decoder) read_until(sep byte) {
69         for {
70                 b := d.read_byte()
71                 if b == sep {
72                         return
73                 }
74                 d.buf.WriteByte(b)
75         }
76 }
77
78 func check_for_int_parse_error(err error, offset int64) {
79         if err != nil {
80                 panic(&SyntaxError{
81                         Offset: offset,
82                         What:   err,
83                 })
84         }
85 }
86
87 func (d *decoder) throwSyntaxError(offset int64, err error) {
88         panic(&SyntaxError{
89                 Offset: offset,
90                 What:   err,
91         })
92 }
93
94 // called when 'i' was consumed
95 func (d *decoder) parse_int(v reflect.Value) {
96         start := d.offset - 1
97         d.read_until('e')
98         if d.buf.Len() == 0 {
99                 panic(&SyntaxError{
100                         Offset: start,
101                         What:   errors.New("empty integer value"),
102                 })
103         }
104
105         s := d.buf.String()
106
107         switch v.Kind() {
108         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
109                 n, err := strconv.ParseInt(s, 10, 64)
110                 check_for_int_parse_error(err, start)
111
112                 if v.OverflowInt(n) {
113                         panic(&UnmarshalTypeError{
114                                 Value: "integer " + s,
115                                 Type:  v.Type(),
116                         })
117                 }
118                 v.SetInt(n)
119         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
120                 n, err := strconv.ParseUint(s, 10, 64)
121                 check_for_int_parse_error(err, start)
122
123                 if v.OverflowUint(n) {
124                         panic(&UnmarshalTypeError{
125                                 Value: "integer " + s,
126                                 Type:  v.Type(),
127                         })
128                 }
129                 v.SetUint(n)
130         case reflect.Bool:
131                 v.SetBool(s != "0")
132         default:
133                 panic(&UnmarshalTypeError{
134                         Value: "integer " + s,
135                         Type:  v.Type(),
136                 })
137         }
138         d.buf.Reset()
139 }
140
141 func (d *decoder) parse_string(v reflect.Value) {
142         start := d.offset - 1
143
144         // read the string length first
145         d.read_until(':')
146         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
147         check_for_int_parse_error(err, start)
148
149         d.buf.Reset()
150         n, err := io.CopyN(&d.buf, d.r, length)
151         d.offset += n
152         if err != nil {
153                 check_for_unexpected_eof(err, d.offset)
154                 panic(&SyntaxError{
155                         Offset: d.offset,
156                         What:   errors.New("unexpected I/O error: " + err.Error()),
157                 })
158         }
159
160         switch v.Kind() {
161         case reflect.String:
162                 v.SetString(d.buf.String())
163         case reflect.Slice:
164                 if v.Type().Elem().Kind() != reflect.Uint8 {
165                         panic(&UnmarshalTypeError{
166                                 Value: "string",
167                                 Type:  v.Type(),
168                         })
169                 }
170                 sl := make([]byte, len(d.buf.Bytes()))
171                 copy(sl, d.buf.Bytes())
172                 v.Set(reflect.ValueOf(sl))
173         default:
174                 panic(&UnmarshalTypeError{
175                         Value: "string",
176                         Type:  v.Type(),
177                 })
178         }
179
180         d.buf.Reset()
181 }
182
183 func (d *decoder) parse_dict(v reflect.Value) {
184         switch v.Kind() {
185         case reflect.Map:
186                 t := v.Type()
187                 if t.Key().Kind() != reflect.String {
188                         panic(&UnmarshalTypeError{
189                                 Value: "object",
190                                 Type:  t,
191                         })
192                 }
193                 if v.IsNil() {
194                         v.Set(reflect.MakeMap(t))
195                 }
196         case reflect.Struct:
197         default:
198                 panic(&UnmarshalTypeError{
199                         Value: "object",
200                         Type:  v.Type(),
201                 })
202         }
203
204         var map_elem reflect.Value
205
206         // so, at this point 'd' byte was consumed, let's just read key/value
207         // pairs one by one
208         for {
209                 var valuev reflect.Value
210                 keyv := reflect.ValueOf(&d.key).Elem()
211                 if !d.parse_value(keyv) {
212                         return
213                 }
214
215                 // get valuev as a map value or as a struct field
216                 switch v.Kind() {
217                 case reflect.Map:
218                         elem_type := v.Type().Elem()
219                         if !map_elem.IsValid() {
220                                 map_elem = reflect.New(elem_type).Elem()
221                         } else {
222                                 map_elem.Set(reflect.Zero(elem_type))
223                         }
224                         valuev = map_elem
225                 case reflect.Struct:
226                         var f reflect.StructField
227                         var ok bool
228
229                         t := v.Type()
230                         for i, n := 0, t.NumField(); i < n; i++ {
231                                 f = t.Field(i)
232                                 tag := f.Tag.Get("bencode")
233                                 if tag == "-" {
234                                         continue
235                                 }
236                                 if f.Anonymous {
237                                         continue
238                                 }
239
240                                 tag_name, _ := parse_tag(tag)
241                                 if tag_name == d.key {
242                                         ok = true
243                                         break
244                                 }
245
246                                 if f.Name == d.key {
247                                         ok = true
248                                         break
249                                 }
250
251                                 if strings.EqualFold(f.Name, d.key) {
252                                         ok = true
253                                         break
254                                 }
255                         }
256
257                         if ok {
258                                 if f.PkgPath != "" {
259                                         panic(&UnmarshalFieldError{
260                                                 Key:   d.key,
261                                                 Type:  v.Type(),
262                                                 Field: f,
263                                         })
264                                 } else {
265                                         valuev = v.FieldByIndex(f.Index)
266                                 }
267                         } else {
268                                 _, ok := d.parse_value_interface()
269                                 if !ok {
270                                         return
271                                 }
272                                 continue
273                         }
274                 }
275
276                 // now we need to actually parse it
277                 if !d.parse_value(valuev) {
278                         return
279                 }
280
281                 if v.Kind() == reflect.Map {
282                         v.SetMapIndex(keyv, valuev)
283                 }
284         }
285 }
286
287 func (d *decoder) parse_list(v reflect.Value) {
288         switch v.Kind() {
289         case reflect.Array, reflect.Slice:
290         default:
291                 panic(&UnmarshalTypeError{
292                         Value: "array",
293                         Type:  v.Type(),
294                 })
295         }
296
297         i := 0
298         for {
299                 if v.Kind() == reflect.Slice && i >= v.Len() {
300                         v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
301                 }
302
303                 ok := false
304                 if i < v.Len() {
305                         ok = d.parse_value(v.Index(i))
306                 } else {
307                         _, ok = d.parse_value_interface()
308                 }
309
310                 if !ok {
311                         break
312                 }
313
314                 i++
315         }
316
317         if i < v.Len() {
318                 if v.Kind() == reflect.Array {
319                         z := reflect.Zero(v.Type().Elem())
320                         for n := v.Len(); i < n; i++ {
321                                 v.Index(i).Set(z)
322                         }
323                 } else {
324                         v.SetLen(i)
325                 }
326         }
327
328         if i == 0 && v.Kind() == reflect.Slice {
329                 v.Set(reflect.MakeSlice(v.Type(), 0, 0))
330         }
331 }
332
// read_one_value copies the raw bytes of the next complete bencoded value from
// d.r onto the end of d.buf without decoding it; parse_unmarshaler hands the
// result to UnmarshalBencode. Dicts and lists are copied recursively, together
// with their closing 'e'. d.buf is not reset here, so nested calls append to
// the bytes already collected for the enclosing value.
//
// It returns false if the next byte is 'e'. In that case the 'e' is pushed back
// with UnreadByte for the caller to consume, and d.buf and d.offset are not
// modified. An error from the first ReadByte (including io.EOF) is panicked
// as-is. Malformed or truncated input panics with *SyntaxError.
func (d *decoder) read_one_value() bool {
	b, err := d.r.ReadByte()
	if err != nil {
		panic(err)
	}
	if b == 'e' {
		// leave the 'e' for the caller; the UnreadByte error is not checked
		d.r.UnreadByte()
		return false
	} else {
		d.offset++
		d.buf.WriteByte(b)
	}

	switch b {
	case 'd', 'l':
		// read until there is nothing to read (dict keys and values are
		// copied as a flat sequence of values; pairing is not checked here)
		for d.read_one_value() {
		}
		// consume 'e' as well
		b = d.read_byte()
		d.buf.WriteByte(b)
	case 'i':
		d.read_until('e')
		// read_until drops the separator, so put the 'e' back
		d.buf.WriteString("e")
	default:
		if b >= '0' && b <= '9' {
			// d.buf may already hold bytes of enclosing values, so parse
			// the length starting from this value's first digit
			start := d.buf.Len() - 1
			d.read_until(':')
			length, err := strconv.ParseInt(d.buf.String()[start:], 10, 64)
			check_for_int_parse_error(err, d.offset-1)

			// read_until dropped the ':'; restore it, then copy the payload
			d.buf.WriteString(":")
			n, err := io.CopyN(&d.buf, d.r, length)
			d.offset += n
			if err != nil {
				check_for_unexpected_eof(err, d.offset)
				panic(&SyntaxError{
					Offset: d.offset,
					What:   errors.New("unexpected I/O error: " + err.Error()),
				})
			}
			break
		}

		d.raiseUnknownValueType(b, d.offset-1)
	}

	return true

}
383
384 func (d *decoder) parse_unmarshaler(v reflect.Value) bool {
385         m, ok := v.Interface().(Unmarshaler)
386         if !ok {
387                 // T doesn't work, try *T
388                 if v.Kind() != reflect.Ptr && v.CanAddr() {
389                         m, ok = v.Addr().Interface().(Unmarshaler)
390                         if ok {
391                                 v = v.Addr()
392                         }
393                 }
394         }
395         if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
396                 if d.read_one_value() {
397                         err := m.UnmarshalBencode(d.buf.Bytes())
398                         d.buf.Reset()
399                         if err != nil {
400                                 panic(&UnmarshalerError{v.Type(), err})
401                         }
402                         return true
403                 }
404                 d.buf.Reset()
405         }
406
407         return false
408 }
409
410 // Returns true if there was a value and it's now stored in 'v', otherwise
411 // there was an end symbol ("e") and no value was stored.
412 func (d *decoder) parse_value(v reflect.Value) bool {
413         // we support one level of indirection at the moment
414         if v.Kind() == reflect.Ptr {
415                 // if the pointer is nil, allocate a new element of the type it
416                 // points to
417                 if v.IsNil() {
418                         v.Set(reflect.New(v.Type().Elem()))
419                 }
420                 v = v.Elem()
421         }
422
423         if d.parse_unmarshaler(v) {
424                 return true
425         }
426
427         // common case: interface{}
428         if v.Kind() == reflect.Interface && v.NumMethod() == 0 {
429                 iface, _ := d.parse_value_interface()
430                 v.Set(reflect.ValueOf(iface))
431                 return true
432         }
433
434         b, err := d.r.ReadByte()
435         if err != nil {
436                 panic(err)
437         }
438         d.offset++
439
440         switch b {
441         case 'e':
442                 return false
443         case 'd':
444                 d.parse_dict(v)
445         case 'l':
446                 d.parse_list(v)
447         case 'i':
448                 d.parse_int(v)
449         default:
450                 if b >= '0' && b <= '9' {
451                         // string
452                         // append first digit of the length to the buffer
453                         d.buf.WriteByte(b)
454                         d.parse_string(v)
455                         break
456                 }
457
458                 d.raiseUnknownValueType(b, d.offset-1)
459         }
460
461         return true
462 }
463
464 // An unknown bencode type character was encountered.
465 func (d *decoder) raiseUnknownValueType(b byte, offset int64) {
466         panic(&SyntaxError{
467                 Offset: offset,
468                 What:   fmt.Errorf("unknown value type %+q", b),
469         })
470 }
471
472 func (d *decoder) parse_value_interface() (interface{}, bool) {
473         b, err := d.r.ReadByte()
474         if err != nil {
475                 panic(err)
476         }
477         d.offset++
478
479         switch b {
480         case 'e':
481                 return nil, false
482         case 'd':
483                 return d.parse_dict_interface(), true
484         case 'l':
485                 return d.parse_list_interface(), true
486         case 'i':
487                 return d.parse_int_interface(), true
488         default:
489                 if b >= '0' && b <= '9' {
490                         // string
491                         // append first digit of the length to the buffer
492                         d.buf.WriteByte(b)
493                         return d.parse_string_interface(), true
494                 }
495
496                 d.raiseUnknownValueType(b, d.offset-1)
497                 panic("unreachable")
498         }
499 }
500
501 func (d *decoder) parse_int_interface() (ret interface{}) {
502         start := d.offset - 1
503         d.read_until('e')
504         if d.buf.Len() == 0 {
505                 panic(&SyntaxError{
506                         Offset: start,
507                         What:   errors.New("empty integer value"),
508                 })
509         }
510
511         n, err := strconv.ParseInt(d.buf.String(), 10, 64)
512         if ne, ok := err.(*strconv.NumError); ok && ne.Err == strconv.ErrRange {
513                 i := new(big.Int)
514                 _, ok := i.SetString(d.buf.String(), 10)
515                 if !ok {
516                         panic(&SyntaxError{
517                                 Offset: start,
518                                 What:   errors.New("failed to parse integer"),
519                         })
520                 }
521                 ret = i
522         } else {
523                 check_for_int_parse_error(err, start)
524                 ret = n
525         }
526
527         d.buf.Reset()
528         return
529 }
530
531 func (d *decoder) parse_string_interface() interface{} {
532         start := d.offset - 1
533
534         // read the string length first
535         d.read_until(':')
536         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
537         check_for_int_parse_error(err, start)
538
539         d.buf.Reset()
540         n, err := io.CopyN(&d.buf, d.r, length)
541         d.offset += n
542         if err != nil {
543                 check_for_unexpected_eof(err, d.offset)
544                 panic(&SyntaxError{
545                         Offset: d.offset,
546                         What:   errors.New("unexpected I/O error: " + err.Error()),
547                 })
548         }
549
550         s := d.buf.String()
551         d.buf.Reset()
552         return s
553 }
554
555 func (d *decoder) parse_dict_interface() interface{} {
556         dict := make(map[string]interface{})
557         for {
558                 keyi, ok := d.parse_value_interface()
559                 if !ok {
560                         break
561                 }
562
563                 key, ok := keyi.(string)
564                 if !ok {
565                         panic(&SyntaxError{
566                                 Offset: d.offset,
567                                 What:   errors.New("non-string key in a dict"),
568                         })
569                 }
570
571                 valuei, ok := d.parse_value_interface()
572                 if !ok {
573                         break
574                 }
575
576                 dict[key] = valuei
577         }
578         return dict
579 }
580
581 func (d *decoder) parse_list_interface() interface{} {
582         var list []interface{}
583         for {
584                 valuei, ok := d.parse_value_interface()
585                 if !ok {
586                         break
587                 }
588
589                 list = append(list, valuei)
590         }
591         if list == nil {
592                 list = make([]interface{}, 0, 0)
593         }
594         return list
595 }