Sergey Matveev's repositories - btrtrc.git/blob - bencode/decode.go
bencode: Improve unknown value type error
[btrtrc.git] / bencode / decode.go
1 package bencode
2
3 import (
4         "bufio"
5         "bytes"
6         "errors"
7         "fmt"
8         "io"
9         "math/big"
10         "reflect"
11         "runtime"
12         "strconv"
13         "strings"
14 )
15
// decoder carries the state of one bencode decoding pass over a buffered
// input stream.
type decoder struct {
	*bufio.Reader // source of the encoded bytes

	// offset is the number of input bytes consumed so far; it is used to
	// position syntax errors.
	offset int64

	// buf is scratch space for integer digits, string lengths and contents,
	// and the raw bytes handed to Unmarshalers.
	buf bytes.Buffer

	// key is string scratch storage for dictionary keys.
	key string
}
22
23 func (d *decoder) decode(v interface{}) (err error) {
24         defer func() {
25                 if e := recover(); e != nil {
26                         if _, ok := e.(runtime.Error); ok {
27                                 panic(e)
28                         }
29                         err = e.(error)
30                 }
31         }()
32
33         pv := reflect.ValueOf(v)
34         if pv.Kind() != reflect.Ptr || pv.IsNil() {
35                 return &UnmarshalInvalidArgError{reflect.TypeOf(v)}
36         }
37
38         if !d.parse_value(pv.Elem()) {
39                 d.throwSyntaxError(d.offset-1, errors.New("unexpected 'e'"))
40         }
41         return nil
42 }
43
44 func check_for_unexpected_eof(err error, offset int64) {
45         if err == io.EOF {
46                 panic(&SyntaxError{
47                         Offset: offset,
48                         What:   io.ErrUnexpectedEOF,
49                 })
50         }
51 }
52
53 func (d *decoder) read_byte() byte {
54         b, err := d.ReadByte()
55         if err != nil {
56                 check_for_unexpected_eof(err, d.offset)
57                 panic(err)
58         }
59
60         d.offset++
61         return b
62 }
63
64 // reads data writing it to 'd.buf' until 'sep' byte is encountered, 'sep' byte
65 // is consumed, but not included into the 'd.buf'
66 func (d *decoder) read_until(sep byte) {
67         for {
68                 b := d.read_byte()
69                 if b == sep {
70                         return
71                 }
72                 d.buf.WriteByte(b)
73         }
74 }
75
76 func check_for_int_parse_error(err error, offset int64) {
77         if err != nil {
78                 panic(&SyntaxError{
79                         Offset: offset,
80                         What:   err,
81                 })
82         }
83 }
84
85 func (d *decoder) throwSyntaxError(offset int64, err error) {
86         panic(&SyntaxError{
87                 Offset: offset,
88                 What:   err,
89         })
90 }
91
92 // called when 'i' was consumed
93 func (d *decoder) parse_int(v reflect.Value) {
94         start := d.offset - 1
95         d.read_until('e')
96         if d.buf.Len() == 0 {
97                 panic(&SyntaxError{
98                         Offset: start,
99                         What:   errors.New("empty integer value"),
100                 })
101         }
102
103         s := d.buf.String()
104
105         switch v.Kind() {
106         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
107                 n, err := strconv.ParseInt(s, 10, 64)
108                 check_for_int_parse_error(err, start)
109
110                 if v.OverflowInt(n) {
111                         panic(&UnmarshalTypeError{
112                                 Value: "integer " + s,
113                                 Type:  v.Type(),
114                         })
115                 }
116                 v.SetInt(n)
117         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
118                 n, err := strconv.ParseUint(s, 10, 64)
119                 check_for_int_parse_error(err, start)
120
121                 if v.OverflowUint(n) {
122                         panic(&UnmarshalTypeError{
123                                 Value: "integer " + s,
124                                 Type:  v.Type(),
125                         })
126                 }
127                 v.SetUint(n)
128         case reflect.Bool:
129                 v.SetBool(s != "0")
130         default:
131                 panic(&UnmarshalTypeError{
132                         Value: "integer " + s,
133                         Type:  v.Type(),
134                 })
135         }
136         d.buf.Reset()
137 }
138
139 func (d *decoder) parse_string(v reflect.Value) {
140         start := d.offset - 1
141
142         // read the string length first
143         d.read_until(':')
144         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
145         check_for_int_parse_error(err, start)
146
147         d.buf.Reset()
148         n, err := io.CopyN(&d.buf, d, length)
149         d.offset += n
150         if err != nil {
151                 check_for_unexpected_eof(err, d.offset)
152                 panic(&SyntaxError{
153                         Offset: d.offset,
154                         What:   errors.New("unexpected I/O error: " + err.Error()),
155                 })
156         }
157
158         switch v.Kind() {
159         case reflect.String:
160                 v.SetString(d.buf.String())
161         case reflect.Slice:
162                 if v.Type().Elem().Kind() != reflect.Uint8 {
163                         panic(&UnmarshalTypeError{
164                                 Value: "string",
165                                 Type:  v.Type(),
166                         })
167                 }
168                 sl := make([]byte, len(d.buf.Bytes()))
169                 copy(sl, d.buf.Bytes())
170                 v.Set(reflect.ValueOf(sl))
171         default:
172                 panic(&UnmarshalTypeError{
173                         Value: "string",
174                         Type:  v.Type(),
175                 })
176         }
177
178         d.buf.Reset()
179 }
180
181 func (d *decoder) parse_dict(v reflect.Value) {
182         switch v.Kind() {
183         case reflect.Map:
184                 t := v.Type()
185                 if t.Key().Kind() != reflect.String {
186                         panic(&UnmarshalTypeError{
187                                 Value: "object",
188                                 Type:  t,
189                         })
190                 }
191                 if v.IsNil() {
192                         v.Set(reflect.MakeMap(t))
193                 }
194         case reflect.Struct:
195         default:
196                 panic(&UnmarshalTypeError{
197                         Value: "object",
198                         Type:  v.Type(),
199                 })
200         }
201
202         var map_elem reflect.Value
203
204         // so, at this point 'd' byte was consumed, let's just read key/value
205         // pairs one by one
206         for {
207                 var valuev reflect.Value
208                 keyv := reflect.ValueOf(&d.key).Elem()
209                 if !d.parse_value(keyv) {
210                         return
211                 }
212
213                 // get valuev as a map value or as a struct field
214                 switch v.Kind() {
215                 case reflect.Map:
216                         elem_type := v.Type().Elem()
217                         if !map_elem.IsValid() {
218                                 map_elem = reflect.New(elem_type).Elem()
219                         } else {
220                                 map_elem.Set(reflect.Zero(elem_type))
221                         }
222                         valuev = map_elem
223                 case reflect.Struct:
224                         var f reflect.StructField
225                         var ok bool
226
227                         t := v.Type()
228                         for i, n := 0, t.NumField(); i < n; i++ {
229                                 f = t.Field(i)
230                                 tag := f.Tag.Get("bencode")
231                                 if tag == "-" {
232                                         continue
233                                 }
234                                 if f.Anonymous {
235                                         continue
236                                 }
237
238                                 tag_name, _ := parse_tag(tag)
239                                 if tag_name == d.key {
240                                         ok = true
241                                         break
242                                 }
243
244                                 if f.Name == d.key {
245                                         ok = true
246                                         break
247                                 }
248
249                                 if strings.EqualFold(f.Name, d.key) {
250                                         ok = true
251                                         break
252                                 }
253                         }
254
255                         if ok {
256                                 if f.PkgPath != "" {
257                                         panic(&UnmarshalFieldError{
258                                                 Key:   d.key,
259                                                 Type:  v.Type(),
260                                                 Field: f,
261                                         })
262                                 } else {
263                                         valuev = v.FieldByIndex(f.Index)
264                                 }
265                         } else {
266                                 _, ok := d.parse_value_interface()
267                                 if !ok {
268                                         return
269                                 }
270                                 continue
271                         }
272                 }
273
274                 // now we need to actually parse it
275                 if !d.parse_value(valuev) {
276                         return
277                 }
278
279                 if v.Kind() == reflect.Map {
280                         v.SetMapIndex(keyv, valuev)
281                 }
282         }
283 }
284
285 func (d *decoder) parse_list(v reflect.Value) {
286         switch v.Kind() {
287         case reflect.Array, reflect.Slice:
288         default:
289                 panic(&UnmarshalTypeError{
290                         Value: "array",
291                         Type:  v.Type(),
292                 })
293         }
294
295         i := 0
296         for {
297                 if v.Kind() == reflect.Slice && i >= v.Len() {
298                         v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
299                 }
300
301                 ok := false
302                 if i < v.Len() {
303                         ok = d.parse_value(v.Index(i))
304                 } else {
305                         _, ok = d.parse_value_interface()
306                 }
307
308                 if !ok {
309                         break
310                 }
311
312                 i++
313         }
314
315         if i < v.Len() {
316                 if v.Kind() == reflect.Array {
317                         z := reflect.Zero(v.Type().Elem())
318                         for n := v.Len(); i < n; i++ {
319                                 v.Index(i).Set(z)
320                         }
321                 } else {
322                         v.SetLen(i)
323                 }
324         }
325
326         if i == 0 && v.Kind() == reflect.Slice {
327                 v.Set(reflect.MakeSlice(v.Type(), 0, 0))
328         }
329 }
330
331 func (d *decoder) read_one_value() bool {
332         b, err := d.ReadByte()
333         if err != nil {
334                 panic(err)
335         }
336         if b == 'e' {
337                 d.UnreadByte()
338                 return false
339         } else {
340                 d.offset++
341                 d.buf.WriteByte(b)
342         }
343
344         switch b {
345         case 'd', 'l':
346                 // read until there is nothing to read
347                 for d.read_one_value() {
348                 }
349                 // consume 'e' as well
350                 b = d.read_byte()
351                 d.buf.WriteByte(b)
352         case 'i':
353                 d.read_until('e')
354                 d.buf.WriteString("e")
355         default:
356                 if b >= '0' && b <= '9' {
357                         start := d.buf.Len() - 1
358                         d.read_until(':')
359                         length, err := strconv.ParseInt(d.buf.String()[start:], 10, 64)
360                         check_for_int_parse_error(err, d.offset-1)
361
362                         d.buf.WriteString(":")
363                         n, err := io.CopyN(&d.buf, d, length)
364                         d.offset += n
365                         if err != nil {
366                                 check_for_unexpected_eof(err, d.offset)
367                                 panic(&SyntaxError{
368                                         Offset: d.offset,
369                                         What:   errors.New("unexpected I/O error: " + err.Error()),
370                                 })
371                         }
372                         break
373                 }
374
375                 d.raiseUnknownValueType(b, d.offset-1)
376         }
377
378         return true
379
380 }
381
382 func (d *decoder) parse_unmarshaler(v reflect.Value) bool {
383         m, ok := v.Interface().(Unmarshaler)
384         if !ok {
385                 // T doesn't work, try *T
386                 if v.Kind() != reflect.Ptr && v.CanAddr() {
387                         m, ok = v.Addr().Interface().(Unmarshaler)
388                         if ok {
389                                 v = v.Addr()
390                         }
391                 }
392         }
393         if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
394                 if d.read_one_value() {
395                         err := m.UnmarshalBencode(d.buf.Bytes())
396                         d.buf.Reset()
397                         if err != nil {
398                                 panic(&UnmarshalerError{v.Type(), err})
399                         }
400                         return true
401                 }
402                 d.buf.Reset()
403         }
404
405         return false
406 }
407
408 // Returns true if there was a value and it's now stored in 'v', otherwise
409 // there was an end symbol ("e") and no value was stored.
410 func (d *decoder) parse_value(v reflect.Value) bool {
411         // we support one level of indirection at the moment
412         if v.Kind() == reflect.Ptr {
413                 // if the pointer is nil, allocate a new element of the type it
414                 // points to
415                 if v.IsNil() {
416                         v.Set(reflect.New(v.Type().Elem()))
417                 }
418                 v = v.Elem()
419         }
420
421         if d.parse_unmarshaler(v) {
422                 return true
423         }
424
425         // common case: interface{}
426         if v.Kind() == reflect.Interface && v.NumMethod() == 0 {
427                 iface, _ := d.parse_value_interface()
428                 v.Set(reflect.ValueOf(iface))
429                 return true
430         }
431
432         b, err := d.ReadByte()
433         if err != nil {
434                 panic(err)
435         }
436         d.offset++
437
438         switch b {
439         case 'e':
440                 return false
441         case 'd':
442                 d.parse_dict(v)
443         case 'l':
444                 d.parse_list(v)
445         case 'i':
446                 d.parse_int(v)
447         default:
448                 if b >= '0' && b <= '9' {
449                         // string
450                         // append first digit of the length to the buffer
451                         d.buf.WriteByte(b)
452                         d.parse_string(v)
453                         break
454                 }
455
456                 d.raiseUnknownValueType(b, d.offset-1)
457         }
458
459         return true
460 }
461
462 // An unknown bencode type character was encountered.
463 func (d *decoder) raiseUnknownValueType(b byte, offset int64) {
464         panic(&SyntaxError{
465                 Offset: offset,
466                 What:   fmt.Errorf("unknown value type %+q", b),
467         })
468 }
469
470 func (d *decoder) parse_value_interface() (interface{}, bool) {
471         b, err := d.ReadByte()
472         if err != nil {
473                 panic(err)
474         }
475         d.offset++
476
477         switch b {
478         case 'e':
479                 return nil, false
480         case 'd':
481                 return d.parse_dict_interface(), true
482         case 'l':
483                 return d.parse_list_interface(), true
484         case 'i':
485                 return d.parse_int_interface(), true
486         default:
487                 if b >= '0' && b <= '9' {
488                         // string
489                         // append first digit of the length to the buffer
490                         d.buf.WriteByte(b)
491                         return d.parse_string_interface(), true
492                 }
493
494                 d.raiseUnknownValueType(b, d.offset-1)
495                 panic("unreachable")
496         }
497 }
498
499 func (d *decoder) parse_int_interface() (ret interface{}) {
500         start := d.offset - 1
501         d.read_until('e')
502         if d.buf.Len() == 0 {
503                 panic(&SyntaxError{
504                         Offset: start,
505                         What:   errors.New("empty integer value"),
506                 })
507         }
508
509         n, err := strconv.ParseInt(d.buf.String(), 10, 64)
510         if ne, ok := err.(*strconv.NumError); ok && ne.Err == strconv.ErrRange {
511                 i := new(big.Int)
512                 _, ok := i.SetString(d.buf.String(), 10)
513                 if !ok {
514                         panic(&SyntaxError{
515                                 Offset: start,
516                                 What:   errors.New("failed to parse integer"),
517                         })
518                 }
519                 ret = i
520         } else {
521                 check_for_int_parse_error(err, start)
522                 ret = n
523         }
524
525         d.buf.Reset()
526         return
527 }
528
529 func (d *decoder) parse_string_interface() interface{} {
530         start := d.offset - 1
531
532         // read the string length first
533         d.read_until(':')
534         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
535         check_for_int_parse_error(err, start)
536
537         d.buf.Reset()
538         n, err := io.CopyN(&d.buf, d, length)
539         d.offset += n
540         if err != nil {
541                 check_for_unexpected_eof(err, d.offset)
542                 panic(&SyntaxError{
543                         Offset: d.offset,
544                         What:   errors.New("unexpected I/O error: " + err.Error()),
545                 })
546         }
547
548         s := d.buf.String()
549         d.buf.Reset()
550         return s
551 }
552
553 func (d *decoder) parse_dict_interface() interface{} {
554         dict := make(map[string]interface{})
555         for {
556                 keyi, ok := d.parse_value_interface()
557                 if !ok {
558                         break
559                 }
560
561                 key, ok := keyi.(string)
562                 if !ok {
563                         panic(&SyntaxError{
564                                 Offset: d.offset,
565                                 What:   errors.New("non-string key in a dict"),
566                         })
567                 }
568
569                 valuei, ok := d.parse_value_interface()
570                 if !ok {
571                         break
572                 }
573
574                 dict[key] = valuei
575         }
576         return dict
577 }
578
579 func (d *decoder) parse_list_interface() interface{} {
580         var list []interface{}
581         for {
582                 valuei, ok := d.parse_value_interface()
583                 if !ok {
584                         break
585                 }
586
587                 list = append(list, valuei)
588         }
589         if list == nil {
590                 list = make([]interface{}, 0, 0)
591         }
592         return list
593 }