Sergey Matveev's repositories - btrtrc.git/blob - bencode/decode.go
bencode: Decode very large integers as big.Int if they overflow int64
[btrtrc.git] / bencode / decode.go
1 package bencode
2
3 import (
4         "bufio"
5         "bytes"
6         "errors"
7         "io"
8         "math/big"
9         "reflect"
10         "runtime"
11         "strconv"
12         "strings"
13 )
14
// decoder holds the state of one bencode decoding pass over a buffered
// input stream. Parse helpers report failures by panicking; decode recovers
// those panics and turns them into returned errors.
type decoder struct {
	// Embedded reader that supplies the input; its ReadByte, UnreadByte and
	// Read methods are called directly on the decoder.
	*bufio.Reader
	// Number of input bytes consumed so far; used for SyntaxError offsets.
	offset int64
	// Scratch space for the bytes of the integer, string or raw value that
	// is currently being parsed.
	buf    bytes.Buffer
	// Most recently decoded dict key; parse_dict decodes each key into it.
	key    string
}
21
22 func (d *decoder) decode(v interface{}) (err error) {
23         defer func() {
24                 if e := recover(); e != nil {
25                         if _, ok := e.(runtime.Error); ok {
26                                 panic(e)
27                         }
28                         err = e.(error)
29                 }
30         }()
31
32         pv := reflect.ValueOf(v)
33         if pv.Kind() != reflect.Ptr || pv.IsNil() {
34                 return &UnmarshalInvalidArgError{reflect.TypeOf(v)}
35         }
36
37         if !d.parse_value(pv.Elem()) {
38                 d.throwSyntaxError(d.offset-1, errors.New("unexpected 'e'"))
39         }
40         return nil
41 }
42
43 func check_for_unexpected_eof(err error, offset int64) {
44         if err == io.EOF {
45                 panic(&SyntaxError{
46                         Offset: offset,
47                         What:   io.ErrUnexpectedEOF,
48                 })
49         }
50 }
51
52 func (d *decoder) read_byte() byte {
53         b, err := d.ReadByte()
54         if err != nil {
55                 check_for_unexpected_eof(err, d.offset)
56                 panic(err)
57         }
58
59         d.offset++
60         return b
61 }
62
63 // reads data writing it to 'd.buf' until 'sep' byte is encountered, 'sep' byte
64 // is consumed, but not included into the 'd.buf'
65 func (d *decoder) read_until(sep byte) {
66         for {
67                 b := d.read_byte()
68                 if b == sep {
69                         return
70                 }
71                 d.buf.WriteByte(b)
72         }
73 }
74
75 func check_for_int_parse_error(err error, offset int64) {
76         if err != nil {
77                 panic(&SyntaxError{
78                         Offset: offset,
79                         What:   err,
80                 })
81         }
82 }
83
84 func (d *decoder) throwSyntaxError(offset int64, err error) {
85         panic(&SyntaxError{
86                 Offset: offset,
87                 What:   err,
88         })
89 }
90
91 // called when 'i' was consumed
92 func (d *decoder) parse_int(v reflect.Value) {
93         start := d.offset - 1
94         d.read_until('e')
95         if d.buf.Len() == 0 {
96                 panic(&SyntaxError{
97                         Offset: start,
98                         What:   errors.New("empty integer value"),
99                 })
100         }
101
102         s := d.buf.String()
103
104         switch v.Kind() {
105         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
106                 n, err := strconv.ParseInt(s, 10, 64)
107                 check_for_int_parse_error(err, start)
108
109                 if v.OverflowInt(n) {
110                         panic(&UnmarshalTypeError{
111                                 Value: "integer " + s,
112                                 Type:  v.Type(),
113                         })
114                 }
115                 v.SetInt(n)
116         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
117                 n, err := strconv.ParseUint(s, 10, 64)
118                 check_for_int_parse_error(err, start)
119
120                 if v.OverflowUint(n) {
121                         panic(&UnmarshalTypeError{
122                                 Value: "integer " + s,
123                                 Type:  v.Type(),
124                         })
125                 }
126                 v.SetUint(n)
127         case reflect.Bool:
128                 v.SetBool(s != "0")
129         default:
130                 panic(&UnmarshalTypeError{
131                         Value: "integer " + s,
132                         Type:  v.Type(),
133                 })
134         }
135         d.buf.Reset()
136 }
137
138 func (d *decoder) parse_string(v reflect.Value) {
139         start := d.offset - 1
140
141         // read the string length first
142         d.read_until(':')
143         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
144         check_for_int_parse_error(err, start)
145
146         d.buf.Reset()
147         n, err := io.CopyN(&d.buf, d, length)
148         d.offset += n
149         if err != nil {
150                 check_for_unexpected_eof(err, d.offset)
151                 panic(&SyntaxError{
152                         Offset: d.offset,
153                         What:   errors.New("unexpected I/O error: " + err.Error()),
154                 })
155         }
156
157         switch v.Kind() {
158         case reflect.String:
159                 v.SetString(d.buf.String())
160         case reflect.Slice:
161                 if v.Type().Elem().Kind() != reflect.Uint8 {
162                         panic(&UnmarshalTypeError{
163                                 Value: "string",
164                                 Type:  v.Type(),
165                         })
166                 }
167                 sl := make([]byte, len(d.buf.Bytes()))
168                 copy(sl, d.buf.Bytes())
169                 v.Set(reflect.ValueOf(sl))
170         default:
171                 panic(&UnmarshalTypeError{
172                         Value: "string",
173                         Type:  v.Type(),
174                 })
175         }
176
177         d.buf.Reset()
178 }
179
180 func (d *decoder) parse_dict(v reflect.Value) {
181         switch v.Kind() {
182         case reflect.Map:
183                 t := v.Type()
184                 if t.Key().Kind() != reflect.String {
185                         panic(&UnmarshalTypeError{
186                                 Value: "object",
187                                 Type:  t,
188                         })
189                 }
190                 if v.IsNil() {
191                         v.Set(reflect.MakeMap(t))
192                 }
193         case reflect.Struct:
194         default:
195                 panic(&UnmarshalTypeError{
196                         Value: "object",
197                         Type:  v.Type(),
198                 })
199         }
200
201         var map_elem reflect.Value
202
203         // so, at this point 'd' byte was consumed, let's just read key/value
204         // pairs one by one
205         for {
206                 var valuev reflect.Value
207                 keyv := reflect.ValueOf(&d.key).Elem()
208                 if !d.parse_value(keyv) {
209                         return
210                 }
211
212                 // get valuev as a map value or as a struct field
213                 switch v.Kind() {
214                 case reflect.Map:
215                         elem_type := v.Type().Elem()
216                         if !map_elem.IsValid() {
217                                 map_elem = reflect.New(elem_type).Elem()
218                         } else {
219                                 map_elem.Set(reflect.Zero(elem_type))
220                         }
221                         valuev = map_elem
222                 case reflect.Struct:
223                         var f reflect.StructField
224                         var ok bool
225
226                         t := v.Type()
227                         for i, n := 0, t.NumField(); i < n; i++ {
228                                 f = t.Field(i)
229                                 tag := f.Tag.Get("bencode")
230                                 if tag == "-" {
231                                         continue
232                                 }
233                                 if f.Anonymous {
234                                         continue
235                                 }
236
237                                 tag_name, _ := parse_tag(tag)
238                                 if tag_name == d.key {
239                                         ok = true
240                                         break
241                                 }
242
243                                 if f.Name == d.key {
244                                         ok = true
245                                         break
246                                 }
247
248                                 if strings.EqualFold(f.Name, d.key) {
249                                         ok = true
250                                         break
251                                 }
252                         }
253
254                         if ok {
255                                 if f.PkgPath != "" {
256                                         panic(&UnmarshalFieldError{
257                                                 Key:   d.key,
258                                                 Type:  v.Type(),
259                                                 Field: f,
260                                         })
261                                 } else {
262                                         valuev = v.FieldByIndex(f.Index)
263                                 }
264                         } else {
265                                 _, ok := d.parse_value_interface()
266                                 if !ok {
267                                         panic(&SyntaxError{
268                                                 Offset: d.offset,
269                                                 What:   errors.New("unexpected end of dict, no matching value for a given key"),
270                                         })
271                                 }
272                                 continue
273                         }
274                 }
275
276                 // now we need to actually parse it
277                 if !d.parse_value(valuev) {
278                         panic(&SyntaxError{
279                                 Offset: d.offset,
280                                 What:   errors.New("unexpected end of dict, no matching value for a given key"),
281                         })
282                 }
283
284                 if v.Kind() == reflect.Map {
285                         v.SetMapIndex(keyv, valuev)
286                 }
287         }
288 }
289
290 func (d *decoder) parse_list(v reflect.Value) {
291         switch v.Kind() {
292         case reflect.Array, reflect.Slice:
293         default:
294                 panic(&UnmarshalTypeError{
295                         Value: "array",
296                         Type:  v.Type(),
297                 })
298         }
299
300         i := 0
301         for {
302                 if v.Kind() == reflect.Slice && i >= v.Len() {
303                         v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
304                 }
305
306                 ok := false
307                 if i < v.Len() {
308                         ok = d.parse_value(v.Index(i))
309                 } else {
310                         _, ok = d.parse_value_interface()
311                 }
312
313                 if !ok {
314                         break
315                 }
316
317                 i++
318         }
319
320         if i < v.Len() {
321                 if v.Kind() == reflect.Array {
322                         z := reflect.Zero(v.Type().Elem())
323                         for n := v.Len(); i < n; i++ {
324                                 v.Index(i).Set(z)
325                         }
326                 } else {
327                         v.SetLen(i)
328                 }
329         }
330
331         if i == 0 && v.Kind() == reflect.Slice {
332                 v.Set(reflect.MakeSlice(v.Type(), 0, 0))
333         }
334 }
335
336 func (d *decoder) read_one_value() bool {
337         b, err := d.ReadByte()
338         if err != nil {
339                 panic(err)
340         }
341         if b == 'e' {
342                 d.UnreadByte()
343                 return false
344         } else {
345                 d.offset++
346                 d.buf.WriteByte(b)
347         }
348
349         switch b {
350         case 'd', 'l':
351                 // read until there is nothing to read
352                 for d.read_one_value() {
353                 }
354                 // consume 'e' as well
355                 b = d.read_byte()
356                 d.buf.WriteByte(b)
357         case 'i':
358                 d.read_until('e')
359                 d.buf.WriteString("e")
360         default:
361                 if b >= '0' && b <= '9' {
362                         start := d.buf.Len() - 1
363                         d.read_until(':')
364                         length, err := strconv.ParseInt(d.buf.String()[start:], 10, 64)
365                         check_for_int_parse_error(err, d.offset-1)
366
367                         d.buf.WriteString(":")
368                         n, err := io.CopyN(&d.buf, d, length)
369                         d.offset += n
370                         if err != nil {
371                                 check_for_unexpected_eof(err, d.offset)
372                                 panic(&SyntaxError{
373                                         Offset: d.offset,
374                                         What:   errors.New("unexpected I/O error: " + err.Error()),
375                                 })
376                         }
377                         break
378                 }
379
380                 // unknown value
381                 panic(&SyntaxError{
382                         Offset: d.offset - 1,
383                         What:   errors.New("unknown value type (invalid bencode?)"),
384                 })
385         }
386
387         return true
388
389 }
390
391 func (d *decoder) parse_unmarshaler(v reflect.Value) bool {
392         m, ok := v.Interface().(Unmarshaler)
393         if !ok {
394                 // T doesn't work, try *T
395                 if v.Kind() != reflect.Ptr && v.CanAddr() {
396                         m, ok = v.Addr().Interface().(Unmarshaler)
397                         if ok {
398                                 v = v.Addr()
399                         }
400                 }
401         }
402         if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
403                 if d.read_one_value() {
404                         err := m.UnmarshalBencode(d.buf.Bytes())
405                         d.buf.Reset()
406                         if err != nil {
407                                 panic(&UnmarshalerError{v.Type(), err})
408                         }
409                         return true
410                 }
411                 d.buf.Reset()
412         }
413
414         return false
415 }
416
417 // returns true if there was a value and it's now stored in 'v', otherwise there
418 // was an end symbol ("e") and no value was stored
419 func (d *decoder) parse_value(v reflect.Value) bool {
420         // we support one level of indirection at the moment
421         if v.Kind() == reflect.Ptr {
422                 // if the pointer is nil, allocate a new element of the type it
423                 // points to
424                 if v.IsNil() {
425                         v.Set(reflect.New(v.Type().Elem()))
426                 }
427                 v = v.Elem()
428         }
429
430         if d.parse_unmarshaler(v) {
431                 return true
432         }
433
434         // common case: interface{}
435         if v.Kind() == reflect.Interface && v.NumMethod() == 0 {
436                 iface, _ := d.parse_value_interface()
437                 v.Set(reflect.ValueOf(iface))
438                 return true
439         }
440
441         b, err := d.ReadByte()
442         if err != nil {
443                 panic(err)
444         }
445         d.offset++
446
447         switch b {
448         case 'e':
449                 return false
450         case 'd':
451                 d.parse_dict(v)
452         case 'l':
453                 d.parse_list(v)
454         case 'i':
455                 d.parse_int(v)
456         default:
457                 if b >= '0' && b <= '9' {
458                         // string
459                         // append first digit of the length to the buffer
460                         d.buf.WriteByte(b)
461                         d.parse_string(v)
462                         break
463                 }
464
465                 // unknown value
466                 panic(&SyntaxError{
467                         Offset: d.offset - 1,
468                         What:   errors.New("unknown value type (invalid bencode?)"),
469                 })
470         }
471
472         return true
473 }
474
475 func (d *decoder) parse_value_interface() (interface{}, bool) {
476         b, err := d.ReadByte()
477         if err != nil {
478                 panic(err)
479         }
480         d.offset++
481
482         switch b {
483         case 'e':
484                 return nil, false
485         case 'd':
486                 return d.parse_dict_interface(), true
487         case 'l':
488                 return d.parse_list_interface(), true
489         case 'i':
490                 return d.parse_int_interface(), true
491         default:
492                 if b >= '0' && b <= '9' {
493                         // string
494                         // append first digit of the length to the buffer
495                         d.buf.WriteByte(b)
496                         return d.parse_string_interface(), true
497                 }
498
499                 // unknown value
500                 panic(&SyntaxError{
501                         Offset: d.offset - 1,
502                         What:   errors.New("unknown value type (invalid bencode?)"),
503                 })
504         }
505 }
506
507 func (d *decoder) parse_int_interface() (ret interface{}) {
508         start := d.offset - 1
509         d.read_until('e')
510         if d.buf.Len() == 0 {
511                 panic(&SyntaxError{
512                         Offset: start,
513                         What:   errors.New("empty integer value"),
514                 })
515         }
516
517         n, err := strconv.ParseInt(d.buf.String(), 10, 64)
518         if ne, ok := err.(*strconv.NumError); ok && ne.Err == strconv.ErrRange {
519                 i := new(big.Int)
520                 _, ok := i.SetString(d.buf.String(), 10)
521                 if !ok {
522                         panic(&SyntaxError{
523                                 Offset: start,
524                                 What:   errors.New("failed to parse integer"),
525                         })
526                 }
527                 ret = i
528         } else {
529                 check_for_int_parse_error(err, start)
530                 ret = n
531         }
532
533         d.buf.Reset()
534         return
535 }
536
537 func (d *decoder) parse_string_interface() interface{} {
538         start := d.offset - 1
539
540         // read the string length first
541         d.read_until(':')
542         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
543         check_for_int_parse_error(err, start)
544
545         d.buf.Reset()
546         n, err := io.CopyN(&d.buf, d, length)
547         d.offset += n
548         if err != nil {
549                 check_for_unexpected_eof(err, d.offset)
550                 panic(&SyntaxError{
551                         Offset: d.offset,
552                         What:   errors.New("unexpected I/O error: " + err.Error()),
553                 })
554         }
555
556         s := d.buf.String()
557         d.buf.Reset()
558         return s
559 }
560
561 func (d *decoder) parse_dict_interface() interface{} {
562         dict := make(map[string]interface{})
563         for {
564                 keyi, ok := d.parse_value_interface()
565                 if !ok {
566                         break
567                 }
568
569                 key, ok := keyi.(string)
570                 if !ok {
571                         panic(&SyntaxError{
572                                 Offset: d.offset,
573                                 What:   errors.New("non-string key in a dict"),
574                         })
575                 }
576
577                 valuei, ok := d.parse_value_interface()
578                 if !ok {
579                         panic(&SyntaxError{
580                                 Offset: d.offset,
581                                 What:   errors.New("unexpected end of dict, no matching value for a given key"),
582                         })
583                 }
584
585                 dict[key] = valuei
586         }
587         return dict
588 }
589
590 func (d *decoder) parse_list_interface() interface{} {
591         var list []interface{}
592         for {
593                 valuei, ok := d.parse_value_interface()
594                 if !ok {
595                         break
596                 }
597
598                 list = append(list, valuei)
599         }
600         if list == nil {
601                 list = make([]interface{}, 0, 0)
602         }
603         return list
604 }