]> Sergey Matveev's repositories - btrtrc.git/blob - bencode/decode.go
Add more comments to the bencode library. Implement UnmarshalerError.
[btrtrc.git] / bencode / decode.go
1 package bencode
2
3 import "reflect"
4 import "runtime"
5 import "bufio"
6 import "bytes"
7 import "strconv"
8 import "strings"
9 import "io"
10
// decoder carries the state of one bencode decoding pass.
type decoder struct {
	// the input source, embedded so that ReadByte/UnreadByte/Read are
	// available directly on the decoder
	*bufio.Reader

	// number of input bytes consumed so far, reported in syntax errors
	offset int64

	// scratch buffer shared by the parsing helpers, each helper is expected
	// to leave it empty when it's done
	buf bytes.Buffer

	// the dictionary key which was parsed most recently
	key string
}
17
18 func (d *decoder) decode(v interface{}) (err error) {
19         defer func() {
20                 if e := recover(); e != nil {
21                         if _, ok := e.(runtime.Error); ok {
22                                 panic(e)
23                         }
24                         err = e.(error)
25                 }
26         }()
27
28         pv := reflect.ValueOf(v)
29         if pv.Kind() != reflect.Ptr || pv.IsNil() {
30                 return &UnmarshalInvalidArgError{reflect.TypeOf(v)}
31         }
32
33         d.parse_value(pv.Elem())
34         return nil
35 }
36
37 func check_for_unexpected_eof(err error, offset int64) {
38         if err == io.EOF {
39                 panic(&SyntaxError{
40                         Offset: offset,
41                         what:   "unexpected EOF",
42                 })
43         }
44 }
45
46 func (d *decoder) read_byte() byte {
47         b, err := d.ReadByte()
48         if err != nil {
49                 check_for_unexpected_eof(err, d.offset)
50                 panic(err)
51         }
52
53         d.offset++
54         return b
55 }
56
57 // reads data writing it to 'd.buf' until 'sep' byte is encountered, 'sep' byte
58 // is consumed, but not included into the 'd.buf'
59 func (d *decoder) read_until(sep byte) {
60         for {
61                 b := d.read_byte()
62                 if b == sep {
63                         return
64                 }
65                 d.buf.WriteByte(b)
66         }
67 }
68
69 func check_for_int_parse_error(err error, offset int64) {
70         if err != nil {
71                 panic(&SyntaxError{
72                         Offset: offset,
73                         what:   err.Error(),
74                 })
75         }
76 }
77
78 // called when 'i' was consumed
79 func (d *decoder) parse_int(v reflect.Value) {
80         start := d.offset - 1
81         d.read_until('e')
82         if d.buf.Len() == 0 {
83                 panic(&SyntaxError{
84                         Offset: start,
85                         what:   "empty integer value",
86                 })
87         }
88
89         switch v.Kind() {
90         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
91                 n, err := strconv.ParseInt(d.buf.String(), 10, 64)
92                 check_for_int_parse_error(err, start)
93
94                 if v.OverflowInt(n) {
95                         panic(&UnmarshalTypeError{
96                                 Value: "integer " + d.buf.String(),
97                                 Type:  v.Type(),
98                         })
99                 }
100                 v.SetInt(n)
101         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
102                 n, err := strconv.ParseUint(d.buf.String(), 10, 64)
103                 check_for_int_parse_error(err, start)
104
105                 if v.OverflowUint(n) {
106                         panic(&UnmarshalTypeError{
107                                 Value: "integer " + d.buf.String(),
108                                 Type:  v.Type(),
109                         })
110                 }
111                 v.SetUint(n)
112         case reflect.Bool:
113                 if d.buf.Len() == 1 && d.buf.Bytes()[0] == '0' {
114                         v.SetBool(false)
115                 }
116                 v.SetBool(true)
117         default:
118                 panic(&UnmarshalTypeError{
119                         Value: "integer " + d.buf.String(),
120                         Type:  v.Type(),
121                 })
122         }
123         d.buf.Reset()
124 }
125
126 func (d *decoder) parse_string(v reflect.Value) {
127         start := d.offset - 1
128
129         // read the string length first
130         d.read_until(':')
131         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
132         check_for_int_parse_error(err, start)
133
134         d.buf.Reset()
135         n, err := io.CopyN(&d.buf, d, length)
136         d.offset += n
137         if err != nil {
138                 check_for_unexpected_eof(err, d.offset)
139                 panic(&SyntaxError{
140                         Offset: d.offset,
141                         what:   "unexpected I/O error: " + err.Error(),
142                 })
143         }
144
145         switch v.Kind() {
146         case reflect.String:
147                 v.SetString(d.buf.String())
148         case reflect.Slice:
149                 if v.Type().Elem().Kind() != reflect.Uint8 {
150                         panic(&UnmarshalTypeError{
151                                 Value: "string",
152                                 Type:  v.Type(),
153                         })
154                 }
155                 v.Set(reflect.ValueOf(d.buf.Bytes()))
156         default:
157                 panic(&UnmarshalTypeError{
158                         Value: "string",
159                         Type:  v.Type(),
160                 })
161         }
162
163         d.buf.Reset()
164 }
165
166 func (d *decoder) parse_dict(v reflect.Value) {
167         switch v.Kind() {
168         case reflect.Map:
169                 t := v.Type()
170                 if t.Key().Kind() != reflect.String {
171                         panic(&UnmarshalTypeError{
172                                 Value: "object",
173                                 Type:  t,
174                         })
175                 }
176                 if v.IsNil() {
177                         v.Set(reflect.MakeMap(t))
178                 }
179         case reflect.Struct:
180         default:
181                 panic(&UnmarshalTypeError{
182                         Value: "object",
183                         Type:  v.Type(),
184                 })
185         }
186
187         var map_elem reflect.Value
188
189         // so, at this point 'd' byte was consumed, let's just read key/value
190         // pairs one by one
191         for {
192                 var valuev reflect.Value
193                 keyv := reflect.ValueOf(&d.key).Elem()
194                 if !d.parse_value(keyv) {
195                         return
196                 }
197
198                 // get valuev as a map value or as a struct field
199                 switch v.Kind() {
200                 case reflect.Map:
201                         elem_type := v.Type().Elem()
202                         if !map_elem.IsValid() {
203                                 map_elem = reflect.New(elem_type).Elem()
204                         } else {
205                                 map_elem.Set(reflect.Zero(elem_type))
206                         }
207                         valuev = map_elem
208                 case reflect.Struct:
209                         var f reflect.StructField
210                         var ok bool
211
212                         t := v.Type()
213                         for i, n := 0, t.NumField(); i < n; i++ {
214                                 f = t.Field(i)
215                                 tag := f.Tag.Get("bencode")
216                                 if tag == "-" {
217                                         continue
218                                 }
219                                 if f.Anonymous {
220                                         continue
221                                 }
222
223                                 tag_name, _ := parse_tag(tag)
224                                 if tag_name == d.key {
225                                         ok = true
226                                         break
227                                 }
228
229                                 if f.Name == d.key {
230                                         ok = true
231                                         break
232                                 }
233
234                                 if strings.EqualFold(f.Name, d.key) {
235                                         ok = true
236                                         break
237                                 }
238                         }
239
240                         if ok {
241                                 if f.PkgPath != "" {
242                                         panic(&UnmarshalFieldError{
243                                                 Key:   d.key,
244                                                 Type:  v.Type(),
245                                                 Field: f,
246                                         })
247                                 } else {
248                                         valuev = v.FieldByIndex(f.Index)
249                                 }
250                         } else {
251                                 _, ok := d.parse_value_interface()
252                                 if !ok {
253                                         panic(&SyntaxError{
254                                                 Offset: d.offset,
255                                                 what:   "unexpected end of dict, no matching value for a given key",
256                                         })
257                                 }
258                                 continue
259                         }
260                 }
261
262                 // now we need to actually parse it
263                 if !d.parse_value(valuev) {
264                         panic(&SyntaxError{
265                                 Offset: d.offset,
266                                 what:   "unexpected end of dict, no matching value for a given key",
267                         })
268                 }
269
270                 if v.Kind() == reflect.Map {
271                         v.SetMapIndex(keyv, valuev)
272                 }
273         }
274 }
275
276 func (d *decoder) parse_list(v reflect.Value) {
277         switch v.Kind() {
278         case reflect.Array, reflect.Slice:
279         default:
280                 panic(&UnmarshalTypeError{
281                         Value: "array",
282                         Type:  v.Type(),
283                 })
284         }
285
286         i := 0
287         for {
288                 if v.Kind() == reflect.Slice && i >= v.Len() {
289                         v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
290                 }
291
292                 ok := false
293                 if i < v.Len() {
294                         ok = d.parse_value(v.Index(i))
295                 } else {
296                         _, ok = d.parse_value_interface()
297                 }
298
299                 if !ok {
300                         break
301                 }
302
303                 i++
304         }
305
306         if i < v.Len() {
307                 if v.Kind() == reflect.Array {
308                         z := reflect.Zero(v.Type().Elem())
309                         for n := v.Len(); i < n; i++ {
310                                 v.Index(i).Set(z)
311                         }
312                 } else {
313                         v.SetLen(i)
314                 }
315         }
316
317         if i == 0 && v.Kind() == reflect.Slice {
318                 v.Set(reflect.MakeSlice(v.Type(), 0, 0))
319         }
320 }
321
322 func (d *decoder) read_one_value() bool {
323         b, err := d.ReadByte()
324         if err != nil {
325                 panic(err)
326         }
327         if b == 'e' {
328                 d.UnreadByte()
329                 return false
330         } else {
331                 d.offset++
332                 d.buf.WriteByte(b)
333         }
334
335         switch b {
336         case 'd', 'l':
337                 // read until there is nothing to read
338                 for d.read_one_value() {}
339                 // consume 'e' as well
340                 b = d.read_byte()
341                 d.buf.WriteByte(b)
342         case 'i':
343                 d.read_until('e')
344                 d.buf.WriteString("e")
345         default:
346                 if b >= '0' && b <= '9' {
347                         start := d.buf.Len() - 1
348                         d.read_until(':')
349                         length, err := strconv.ParseInt(d.buf.String()[start:], 10, 64)
350                         check_for_int_parse_error(err, d.offset - 1)
351
352                         d.buf.WriteString(":")
353                         n, err := io.CopyN(&d.buf, d, length)
354                         d.offset += n
355                         if err != nil {
356                                 check_for_unexpected_eof(err, d.offset)
357                                 panic(&SyntaxError{
358                                         Offset: d.offset,
359                                         what:   "unexpected I/O error: " + err.Error(),
360                                 })
361                         }
362                         break
363                 }
364
365                 // unknown value
366                 panic(&SyntaxError{
367                         Offset: d.offset - 1,
368                         what:   "unknown value type (invalid bencode?)",
369                 })
370         }
371
372         return true
373
374 }
375
376 func (d *decoder) parse_unmarshaler(v reflect.Value) bool {
377         m, ok := v.Interface().(Unmarshaler)
378         if !ok {
379                 // T doesn't work, try *T
380                 if v.Kind() != reflect.Ptr && v.CanAddr() {
381                         m, ok = v.Addr().Interface().(Unmarshaler)
382                         if ok {
383                                 v = v.Addr()
384                         }
385                 }
386         }
387         if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
388                 if d.read_one_value() {
389                         err := m.UnmarshalBencode(d.buf.Bytes())
390                         d.buf.Reset()
391                         if err != nil {
392                                 panic(&UnmarshalerError{v.Type(), err})
393                         }
394                         return true
395                 }
396                 d.buf.Reset()
397         }
398
399         return false
400 }
401
402 // returns true if there was a value and it's now stored in 'v', otherwise there
403 // was an end symbol ("e") and no value was stored
404 func (d *decoder) parse_value(v reflect.Value) bool {
405         // we support one level of indirection at the moment
406         if v.Kind() == reflect.Ptr {
407                 // if the pointer is nil, allocate a new element of the type it
408                 // points to
409                 if v.IsNil() {
410                         v.Set(reflect.New(v.Type().Elem()))
411                 }
412                 v = v.Elem()
413         }
414
415         if d.parse_unmarshaler(v) {
416                 return true
417         }
418
419         // common case: interface{}
420         if v.Kind() == reflect.Interface && v.NumMethod() == 0 {
421                 iface, _ := d.parse_value_interface()
422                 v.Set(reflect.ValueOf(iface))
423                 return true
424         }
425
426         b, err := d.ReadByte()
427         if err != nil {
428                 panic(err)
429         }
430         d.offset++
431
432         switch b {
433         case 'e':
434                 return false
435         case 'd':
436                 d.parse_dict(v)
437         case 'l':
438                 d.parse_list(v)
439         case 'i':
440                 d.parse_int(v)
441         default:
442                 if b >= '0' && b <= '9' {
443                         // string
444                         // append first digit of the length to the buffer
445                         d.buf.WriteByte(b)
446                         d.parse_string(v)
447                         break
448                 }
449
450                 // unknown value
451                 panic(&SyntaxError{
452                         Offset: d.offset - 1,
453                         what:   "unknown value type (invalid bencode?)",
454                 })
455         }
456
457         return true
458 }
459
460 func (d *decoder) parse_value_interface() (interface{}, bool) {
461         b, err := d.ReadByte()
462         if err != nil {
463                 panic(err)
464         }
465         d.offset++
466
467         switch b {
468         case 'e':
469                 return nil, false
470         case 'd':
471                 return d.parse_dict_interface(), true
472         case 'l':
473                 return d.parse_list_interface(), true
474         case 'i':
475                 return d.parse_int_interface(), true
476         default:
477                 if b >= '0' && b <= '9' {
478                         // string
479                         // append first digit of the length to the buffer
480                         d.buf.WriteByte(b)
481                         return d.parse_string_interface(), true
482                 }
483
484                 // unknown value
485                 panic(&SyntaxError{
486                         Offset: d.offset - 1,
487                         what:   "unknown value type (invalid bencode?)",
488                 })
489         }
490         panic("unreachable")
491 }
492
493 func (d *decoder) parse_int_interface() interface{} {
494         start := d.offset - 1
495         d.read_until('e')
496         if d.buf.Len() == 0 {
497                 panic(&SyntaxError{
498                         Offset: start,
499                         what:   "empty integer value",
500                 })
501         }
502
503         n, err := strconv.ParseInt(d.buf.String(), 10, 64)
504         check_for_int_parse_error(err, start)
505         d.buf.Reset()
506         return n
507 }
508
509 func (d *decoder) parse_string_interface() interface{} {
510         start := d.offset - 1
511
512         // read the string length first
513         d.read_until(':')
514         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
515         check_for_int_parse_error(err, start)
516
517         d.buf.Reset()
518         n, err := io.CopyN(&d.buf, d, length)
519         d.offset += n
520         if err != nil {
521                 check_for_unexpected_eof(err, d.offset)
522                 panic(&SyntaxError{
523                         Offset: d.offset,
524                         what:   "unexpected I/O error: " + err.Error(),
525                 })
526         }
527
528         s := d.buf.String()
529         d.buf.Reset()
530         return s
531 }
532
533 func (d *decoder) parse_dict_interface() interface{} {
534         dict := make(map[string]interface{})
535         for {
536                 keyi, ok := d.parse_value_interface()
537                 if !ok {
538                         break
539                 }
540
541                 key, ok := keyi.(string)
542                 if !ok {
543                         panic(&SyntaxError{
544                                 Offset: d.offset,
545                                 what:   "non-string key in a dict",
546                         })
547                 }
548
549                 valuei, ok := d.parse_value_interface()
550                 if !ok {
551                         panic(&SyntaxError{
552                                 Offset: d.offset,
553                                 what:   "unexpected end of dict, no matching value for a given key",
554                         })
555                 }
556
557                 dict[key] = valuei
558         }
559         return dict
560 }
561
562 func (d *decoder) parse_list_interface() interface{} {
563         var list []interface{}
564         for {
565                 valuei, ok := d.parse_value_interface()
566                 if !ok {
567                         break
568                 }
569
570                 list = append(list, valuei)
571         }
572         if list == nil {
573                 list = make([]interface{}, 0, 0)
574         }
575         return list
576 }