Sergey Matveev's repositories - btrtrc.git/blob - bencode/decode.go
bencode: Don't allow extraneous trailing 'e's
[btrtrc.git] / bencode / decode.go
1 package bencode
2
3 import (
4         "bufio"
5         "bytes"
6         "errors"
7         "io"
8         "reflect"
9         "runtime"
10         "strconv"
11         "strings"
12 )
13
// decoder carries the state of one bencode decoding pass. It embeds the
// *bufio.Reader the input is read from; offset counts the bytes consumed so
// far and is the position reported in SyntaxError values; buf is scratch space
// shared by the parsing helpers; key receives each dictionary key read by
// parse_dict.
14 type decoder struct {
15         *bufio.Reader
16         offset int64
17         buf    bytes.Buffer
18         key    string
19 }
20
21 func (d *decoder) decode(v interface{}) (err error) {
22         defer func() {
23                 if e := recover(); e != nil {
24                         if _, ok := e.(runtime.Error); ok {
25                                 panic(e)
26                         }
27                         err = e.(error)
28                 }
29         }()
30
31         pv := reflect.ValueOf(v)
32         if pv.Kind() != reflect.Ptr || pv.IsNil() {
33                 return &UnmarshalInvalidArgError{reflect.TypeOf(v)}
34         }
35
36         if !d.parse_value(pv.Elem()) {
37                 d.throwSyntaxError(d.offset-1, errors.New("unexpected 'e'"))
38         }
39         return nil
40 }
41
42 func check_for_unexpected_eof(err error, offset int64) {
43         if err == io.EOF {
44                 panic(&SyntaxError{
45                         Offset: offset,
46                         What:   io.ErrUnexpectedEOF,
47                 })
48         }
49 }
50
51 func (d *decoder) read_byte() byte {
52         b, err := d.ReadByte()
53         if err != nil {
54                 check_for_unexpected_eof(err, d.offset)
55                 panic(err)
56         }
57
58         d.offset++
59         return b
60 }
61
62 // reads data writing it to 'd.buf' until 'sep' byte is encountered, 'sep' byte
63 // is consumed, but not included into the 'd.buf'
64 func (d *decoder) read_until(sep byte) {
65         for {
66                 b := d.read_byte()
67                 if b == sep {
68                         return
69                 }
70                 d.buf.WriteByte(b)
71         }
72 }
73
74 func check_for_int_parse_error(err error, offset int64) {
75         if err != nil {
76                 panic(&SyntaxError{
77                         Offset: offset,
78                         What:   err,
79                 })
80         }
81 }
82
83 func (d *decoder) throwSyntaxError(offset int64, err error) {
84         panic(&SyntaxError{
85                 Offset: offset,
86                 What:   err,
87         })
88 }
89
90 // called when 'i' was consumed
91 func (d *decoder) parse_int(v reflect.Value) {
92         start := d.offset - 1
93         d.read_until('e')
94         if d.buf.Len() == 0 {
95                 panic(&SyntaxError{
96                         Offset: start,
97                         What:   errors.New("empty integer value"),
98                 })
99         }
100
101         switch v.Kind() {
102         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
103                 n, err := strconv.ParseInt(d.buf.String(), 10, 64)
104                 check_for_int_parse_error(err, start)
105
106                 if v.OverflowInt(n) {
107                         panic(&UnmarshalTypeError{
108                                 Value: "integer " + d.buf.String(),
109                                 Type:  v.Type(),
110                         })
111                 }
112                 v.SetInt(n)
113         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
114                 n, err := strconv.ParseUint(d.buf.String(), 10, 64)
115                 check_for_int_parse_error(err, start)
116
117                 if v.OverflowUint(n) {
118                         panic(&UnmarshalTypeError{
119                                 Value: "integer " + d.buf.String(),
120                                 Type:  v.Type(),
121                         })
122                 }
123                 v.SetUint(n)
124         case reflect.Bool:
125                 v.SetBool(d.buf.String() != "0")
126         default:
127                 panic(&UnmarshalTypeError{
128                         Value: "integer " + d.buf.String(),
129                         Type:  v.Type(),
130                 })
131         }
132         d.buf.Reset()
133 }
134
135 func (d *decoder) parse_string(v reflect.Value) {
136         start := d.offset - 1
137
138         // read the string length first
139         d.read_until(':')
140         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
141         check_for_int_parse_error(err, start)
142
143         d.buf.Reset()
144         n, err := io.CopyN(&d.buf, d, length)
145         d.offset += n
146         if err != nil {
147                 check_for_unexpected_eof(err, d.offset)
148                 panic(&SyntaxError{
149                         Offset: d.offset,
150                         What:   errors.New("unexpected I/O error: " + err.Error()),
151                 })
152         }
153
154         switch v.Kind() {
155         case reflect.String:
156                 v.SetString(d.buf.String())
157         case reflect.Slice:
158                 if v.Type().Elem().Kind() != reflect.Uint8 {
159                         panic(&UnmarshalTypeError{
160                                 Value: "string",
161                                 Type:  v.Type(),
162                         })
163                 }
164                 sl := make([]byte, len(d.buf.Bytes()))
165                 copy(sl, d.buf.Bytes())
166                 v.Set(reflect.ValueOf(sl))
167         default:
168                 panic(&UnmarshalTypeError{
169                         Value: "string",
170                         Type:  v.Type(),
171                 })
172         }
173
174         d.buf.Reset()
175 }
176
177 func (d *decoder) parse_dict(v reflect.Value) {
178         switch v.Kind() {
179         case reflect.Map:
180                 t := v.Type()
181                 if t.Key().Kind() != reflect.String {
182                         panic(&UnmarshalTypeError{
183                                 Value: "object",
184                                 Type:  t,
185                         })
186                 }
187                 if v.IsNil() {
188                         v.Set(reflect.MakeMap(t))
189                 }
190         case reflect.Struct:
191         default:
192                 panic(&UnmarshalTypeError{
193                         Value: "object",
194                         Type:  v.Type(),
195                 })
196         }
197
198         var map_elem reflect.Value
199
200         // so, at this point 'd' byte was consumed, let's just read key/value
201         // pairs one by one
202         for {
203                 var valuev reflect.Value
204                 keyv := reflect.ValueOf(&d.key).Elem()
205                 if !d.parse_value(keyv) {
206                         return
207                 }
208
209                 // get valuev as a map value or as a struct field
210                 switch v.Kind() {
211                 case reflect.Map:
212                         elem_type := v.Type().Elem()
213                         if !map_elem.IsValid() {
214                                 map_elem = reflect.New(elem_type).Elem()
215                         } else {
216                                 map_elem.Set(reflect.Zero(elem_type))
217                         }
218                         valuev = map_elem
219                 case reflect.Struct:
220                         var f reflect.StructField
221                         var ok bool
222
223                         t := v.Type()
224                         for i, n := 0, t.NumField(); i < n; i++ {
225                                 f = t.Field(i)
226                                 tag := f.Tag.Get("bencode")
227                                 if tag == "-" {
228                                         continue
229                                 }
230                                 if f.Anonymous {
231                                         continue
232                                 }
233
234                                 tag_name, _ := parse_tag(tag)
235                                 if tag_name == d.key {
236                                         ok = true
237                                         break
238                                 }
239
240                                 if f.Name == d.key {
241                                         ok = true
242                                         break
243                                 }
244
245                                 if strings.EqualFold(f.Name, d.key) {
246                                         ok = true
247                                         break
248                                 }
249                         }
250
251                         if ok {
252                                 if f.PkgPath != "" {
253                                         panic(&UnmarshalFieldError{
254                                                 Key:   d.key,
255                                                 Type:  v.Type(),
256                                                 Field: f,
257                                         })
258                                 } else {
259                                         valuev = v.FieldByIndex(f.Index)
260                                 }
261                         } else {
262                                 _, ok := d.parse_value_interface()
263                                 if !ok {
264                                         panic(&SyntaxError{
265                                                 Offset: d.offset,
266                                                 What:   errors.New("unexpected end of dict, no matching value for a given key"),
267                                         })
268                                 }
269                                 continue
270                         }
271                 }
272
273                 // now we need to actually parse it
274                 if !d.parse_value(valuev) {
275                         panic(&SyntaxError{
276                                 Offset: d.offset,
277                                 What:   errors.New("unexpected end of dict, no matching value for a given key"),
278                         })
279                 }
280
281                 if v.Kind() == reflect.Map {
282                         v.SetMapIndex(keyv, valuev)
283                 }
284         }
285 }
286
287 func (d *decoder) parse_list(v reflect.Value) {
288         switch v.Kind() {
289         case reflect.Array, reflect.Slice:
290         default:
291                 panic(&UnmarshalTypeError{
292                         Value: "array",
293                         Type:  v.Type(),
294                 })
295         }
296
297         i := 0
298         for {
299                 if v.Kind() == reflect.Slice && i >= v.Len() {
300                         v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
301                 }
302
303                 ok := false
304                 if i < v.Len() {
305                         ok = d.parse_value(v.Index(i))
306                 } else {
307                         _, ok = d.parse_value_interface()
308                 }
309
310                 if !ok {
311                         break
312                 }
313
314                 i++
315         }
316
317         if i < v.Len() {
318                 if v.Kind() == reflect.Array {
319                         z := reflect.Zero(v.Type().Elem())
320                         for n := v.Len(); i < n; i++ {
321                                 v.Index(i).Set(z)
322                         }
323                 } else {
324                         v.SetLen(i)
325                 }
326         }
327
328         if i == 0 && v.Kind() == reflect.Slice {
329                 v.Set(reflect.MakeSlice(v.Type(), 0, 0))
330         }
331 }
332
333 func (d *decoder) read_one_value() bool {
334         b, err := d.ReadByte()
335         if err != nil {
336                 panic(err)
337         }
338         if b == 'e' {
339                 d.UnreadByte()
340                 return false
341         } else {
342                 d.offset++
343                 d.buf.WriteByte(b)
344         }
345
346         switch b {
347         case 'd', 'l':
348                 // read until there is nothing to read
349                 for d.read_one_value() {
350                 }
351                 // consume 'e' as well
352                 b = d.read_byte()
353                 d.buf.WriteByte(b)
354         case 'i':
355                 d.read_until('e')
356                 d.buf.WriteString("e")
357         default:
358                 if b >= '0' && b <= '9' {
359                         start := d.buf.Len() - 1
360                         d.read_until(':')
361                         length, err := strconv.ParseInt(d.buf.String()[start:], 10, 64)
362                         check_for_int_parse_error(err, d.offset-1)
363
364                         d.buf.WriteString(":")
365                         n, err := io.CopyN(&d.buf, d, length)
366                         d.offset += n
367                         if err != nil {
368                                 check_for_unexpected_eof(err, d.offset)
369                                 panic(&SyntaxError{
370                                         Offset: d.offset,
371                                         What:   errors.New("unexpected I/O error: " + err.Error()),
372                                 })
373                         }
374                         break
375                 }
376
377                 // unknown value
378                 panic(&SyntaxError{
379                         Offset: d.offset - 1,
380                         What:   errors.New("unknown value type (invalid bencode?)"),
381                 })
382         }
383
384         return true
385
386 }
387
388 func (d *decoder) parse_unmarshaler(v reflect.Value) bool {
389         m, ok := v.Interface().(Unmarshaler)
390         if !ok {
391                 // T doesn't work, try *T
392                 if v.Kind() != reflect.Ptr && v.CanAddr() {
393                         m, ok = v.Addr().Interface().(Unmarshaler)
394                         if ok {
395                                 v = v.Addr()
396                         }
397                 }
398         }
399         if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
400                 if d.read_one_value() {
401                         err := m.UnmarshalBencode(d.buf.Bytes())
402                         d.buf.Reset()
403                         if err != nil {
404                                 panic(&UnmarshalerError{v.Type(), err})
405                         }
406                         return true
407                 }
408                 d.buf.Reset()
409         }
410
411         return false
412 }
413
414 // returns true if there was a value and it's now stored in 'v', otherwise there
415 // was an end symbol ("e") and no value was stored
416 func (d *decoder) parse_value(v reflect.Value) bool {
417         // we support one level of indirection at the moment
418         if v.Kind() == reflect.Ptr {
419                 // if the pointer is nil, allocate a new element of the type it
420                 // points to
421                 if v.IsNil() {
422                         v.Set(reflect.New(v.Type().Elem()))
423                 }
424                 v = v.Elem()
425         }
426
427         if d.parse_unmarshaler(v) {
428                 return true
429         }
430
431         // common case: interface{}
432         if v.Kind() == reflect.Interface && v.NumMethod() == 0 {
433                 iface, _ := d.parse_value_interface()
434                 v.Set(reflect.ValueOf(iface))
435                 return true
436         }
437
438         b, err := d.ReadByte()
439         if err != nil {
440                 panic(err)
441         }
442         d.offset++
443
444         switch b {
445         case 'e':
446                 return false
447         case 'd':
448                 d.parse_dict(v)
449         case 'l':
450                 d.parse_list(v)
451         case 'i':
452                 d.parse_int(v)
453         default:
454                 if b >= '0' && b <= '9' {
455                         // string
456                         // append first digit of the length to the buffer
457                         d.buf.WriteByte(b)
458                         d.parse_string(v)
459                         break
460                 }
461
462                 // unknown value
463                 panic(&SyntaxError{
464                         Offset: d.offset - 1,
465                         What:   errors.New("unknown value type (invalid bencode?)"),
466                 })
467         }
468
469         return true
470 }
471
472 func (d *decoder) parse_value_interface() (interface{}, bool) {
473         b, err := d.ReadByte()
474         if err != nil {
475                 panic(err)
476         }
477         d.offset++
478
479         switch b {
480         case 'e':
481                 return nil, false
482         case 'd':
483                 return d.parse_dict_interface(), true
484         case 'l':
485                 return d.parse_list_interface(), true
486         case 'i':
487                 return d.parse_int_interface(), true
488         default:
489                 if b >= '0' && b <= '9' {
490                         // string
491                         // append first digit of the length to the buffer
492                         d.buf.WriteByte(b)
493                         return d.parse_string_interface(), true
494                 }
495
496                 // unknown value
497                 panic(&SyntaxError{
498                         Offset: d.offset - 1,
499                         What:   errors.New("unknown value type (invalid bencode?)"),
500                 })
501         }
502 }
503
504 func (d *decoder) parse_int_interface() interface{} {
505         start := d.offset - 1
506         d.read_until('e')
507         if d.buf.Len() == 0 {
508                 panic(&SyntaxError{
509                         Offset: start,
510                         What:   errors.New("empty integer value"),
511                 })
512         }
513
514         n, err := strconv.ParseInt(d.buf.String(), 10, 64)
515         check_for_int_parse_error(err, start)
516         d.buf.Reset()
517         return n
518 }
519
520 func (d *decoder) parse_string_interface() interface{} {
521         start := d.offset - 1
522
523         // read the string length first
524         d.read_until(':')
525         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
526         check_for_int_parse_error(err, start)
527
528         d.buf.Reset()
529         n, err := io.CopyN(&d.buf, d, length)
530         d.offset += n
531         if err != nil {
532                 check_for_unexpected_eof(err, d.offset)
533                 panic(&SyntaxError{
534                         Offset: d.offset,
535                         What:   errors.New("unexpected I/O error: " + err.Error()),
536                 })
537         }
538
539         s := d.buf.String()
540         d.buf.Reset()
541         return s
542 }
543
544 func (d *decoder) parse_dict_interface() interface{} {
545         dict := make(map[string]interface{})
546         for {
547                 keyi, ok := d.parse_value_interface()
548                 if !ok {
549                         break
550                 }
551
552                 key, ok := keyi.(string)
553                 if !ok {
554                         panic(&SyntaxError{
555                                 Offset: d.offset,
556                                 What:   errors.New("non-string key in a dict"),
557                         })
558                 }
559
560                 valuei, ok := d.parse_value_interface()
561                 if !ok {
562                         panic(&SyntaxError{
563                                 Offset: d.offset,
564                                 What:   errors.New("unexpected end of dict, no matching value for a given key"),
565                         })
566                 }
567
568                 dict[key] = valuei
569         }
570         return dict
571 }
572
573 func (d *decoder) parse_list_interface() interface{} {
574         var list []interface{}
575         for {
576                 valuei, ok := d.parse_value_interface()
577                 if !ok {
578                         break
579                 }
580
581                 list = append(list, valuei)
582         }
583         if list == nil {
584                 list = make([]interface{}, 0, 0)
585         }
586         return list
587 }