]> Sergey Matveev's repositories - btrtrc.git/blob - bencode/decode.go
bencode: Simplify parse_int
[btrtrc.git] / bencode / decode.go
1 package bencode
2
3 import (
4         "bufio"
5         "bytes"
6         "errors"
7         "io"
8         "reflect"
9         "runtime"
10         "strconv"
11         "strings"
12 )
13
// decoder carries the state of one bencode decoding pass over a buffered
// input stream.
type decoder struct {
	// Embedded buffered source that the bencoded bytes are read from.
	*bufio.Reader
	// Number of bytes consumed so far; used as the position reported in
	// SyntaxError values.
	offset int64
	// Scratch buffer holding the raw bytes of the token being parsed.
	buf    bytes.Buffer
	// Storage that parse_dict decodes each dictionary key into.
	key    string
}
20
21 func (d *decoder) decode(v interface{}) (err error) {
22         defer func() {
23                 if e := recover(); e != nil {
24                         if _, ok := e.(runtime.Error); ok {
25                                 panic(e)
26                         }
27                         err = e.(error)
28                 }
29         }()
30
31         pv := reflect.ValueOf(v)
32         if pv.Kind() != reflect.Ptr || pv.IsNil() {
33                 return &UnmarshalInvalidArgError{reflect.TypeOf(v)}
34         }
35
36         if !d.parse_value(pv.Elem()) {
37                 d.throwSyntaxError(d.offset-1, errors.New("unexpected 'e'"))
38         }
39         return nil
40 }
41
42 func check_for_unexpected_eof(err error, offset int64) {
43         if err == io.EOF {
44                 panic(&SyntaxError{
45                         Offset: offset,
46                         What:   io.ErrUnexpectedEOF,
47                 })
48         }
49 }
50
51 func (d *decoder) read_byte() byte {
52         b, err := d.ReadByte()
53         if err != nil {
54                 check_for_unexpected_eof(err, d.offset)
55                 panic(err)
56         }
57
58         d.offset++
59         return b
60 }
61
62 // reads data writing it to 'd.buf' until 'sep' byte is encountered, 'sep' byte
63 // is consumed, but not included into the 'd.buf'
64 func (d *decoder) read_until(sep byte) {
65         for {
66                 b := d.read_byte()
67                 if b == sep {
68                         return
69                 }
70                 d.buf.WriteByte(b)
71         }
72 }
73
74 func check_for_int_parse_error(err error, offset int64) {
75         if err != nil {
76                 panic(&SyntaxError{
77                         Offset: offset,
78                         What:   err,
79                 })
80         }
81 }
82
83 func (d *decoder) throwSyntaxError(offset int64, err error) {
84         panic(&SyntaxError{
85                 Offset: offset,
86                 What:   err,
87         })
88 }
89
90 // called when 'i' was consumed
91 func (d *decoder) parse_int(v reflect.Value) {
92         start := d.offset - 1
93         d.read_until('e')
94         if d.buf.Len() == 0 {
95                 panic(&SyntaxError{
96                         Offset: start,
97                         What:   errors.New("empty integer value"),
98                 })
99         }
100
101         s := d.buf.String()
102
103         switch v.Kind() {
104         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
105                 n, err := strconv.ParseInt(s, 10, 64)
106                 check_for_int_parse_error(err, start)
107
108                 if v.OverflowInt(n) {
109                         panic(&UnmarshalTypeError{
110                                 Value: "integer " + s,
111                                 Type:  v.Type(),
112                         })
113                 }
114                 v.SetInt(n)
115         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
116                 n, err := strconv.ParseUint(s, 10, 64)
117                 check_for_int_parse_error(err, start)
118
119                 if v.OverflowUint(n) {
120                         panic(&UnmarshalTypeError{
121                                 Value: "integer " + s,
122                                 Type:  v.Type(),
123                         })
124                 }
125                 v.SetUint(n)
126         case reflect.Bool:
127                 v.SetBool(s != "0")
128         default:
129                 panic(&UnmarshalTypeError{
130                         Value: "integer " + s,
131                         Type:  v.Type(),
132                 })
133         }
134         d.buf.Reset()
135 }
136
137 func (d *decoder) parse_string(v reflect.Value) {
138         start := d.offset - 1
139
140         // read the string length first
141         d.read_until(':')
142         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
143         check_for_int_parse_error(err, start)
144
145         d.buf.Reset()
146         n, err := io.CopyN(&d.buf, d, length)
147         d.offset += n
148         if err != nil {
149                 check_for_unexpected_eof(err, d.offset)
150                 panic(&SyntaxError{
151                         Offset: d.offset,
152                         What:   errors.New("unexpected I/O error: " + err.Error()),
153                 })
154         }
155
156         switch v.Kind() {
157         case reflect.String:
158                 v.SetString(d.buf.String())
159         case reflect.Slice:
160                 if v.Type().Elem().Kind() != reflect.Uint8 {
161                         panic(&UnmarshalTypeError{
162                                 Value: "string",
163                                 Type:  v.Type(),
164                         })
165                 }
166                 sl := make([]byte, len(d.buf.Bytes()))
167                 copy(sl, d.buf.Bytes())
168                 v.Set(reflect.ValueOf(sl))
169         default:
170                 panic(&UnmarshalTypeError{
171                         Value: "string",
172                         Type:  v.Type(),
173                 })
174         }
175
176         d.buf.Reset()
177 }
178
179 func (d *decoder) parse_dict(v reflect.Value) {
180         switch v.Kind() {
181         case reflect.Map:
182                 t := v.Type()
183                 if t.Key().Kind() != reflect.String {
184                         panic(&UnmarshalTypeError{
185                                 Value: "object",
186                                 Type:  t,
187                         })
188                 }
189                 if v.IsNil() {
190                         v.Set(reflect.MakeMap(t))
191                 }
192         case reflect.Struct:
193         default:
194                 panic(&UnmarshalTypeError{
195                         Value: "object",
196                         Type:  v.Type(),
197                 })
198         }
199
200         var map_elem reflect.Value
201
202         // so, at this point 'd' byte was consumed, let's just read key/value
203         // pairs one by one
204         for {
205                 var valuev reflect.Value
206                 keyv := reflect.ValueOf(&d.key).Elem()
207                 if !d.parse_value(keyv) {
208                         return
209                 }
210
211                 // get valuev as a map value or as a struct field
212                 switch v.Kind() {
213                 case reflect.Map:
214                         elem_type := v.Type().Elem()
215                         if !map_elem.IsValid() {
216                                 map_elem = reflect.New(elem_type).Elem()
217                         } else {
218                                 map_elem.Set(reflect.Zero(elem_type))
219                         }
220                         valuev = map_elem
221                 case reflect.Struct:
222                         var f reflect.StructField
223                         var ok bool
224
225                         t := v.Type()
226                         for i, n := 0, t.NumField(); i < n; i++ {
227                                 f = t.Field(i)
228                                 tag := f.Tag.Get("bencode")
229                                 if tag == "-" {
230                                         continue
231                                 }
232                                 if f.Anonymous {
233                                         continue
234                                 }
235
236                                 tag_name, _ := parse_tag(tag)
237                                 if tag_name == d.key {
238                                         ok = true
239                                         break
240                                 }
241
242                                 if f.Name == d.key {
243                                         ok = true
244                                         break
245                                 }
246
247                                 if strings.EqualFold(f.Name, d.key) {
248                                         ok = true
249                                         break
250                                 }
251                         }
252
253                         if ok {
254                                 if f.PkgPath != "" {
255                                         panic(&UnmarshalFieldError{
256                                                 Key:   d.key,
257                                                 Type:  v.Type(),
258                                                 Field: f,
259                                         })
260                                 } else {
261                                         valuev = v.FieldByIndex(f.Index)
262                                 }
263                         } else {
264                                 _, ok := d.parse_value_interface()
265                                 if !ok {
266                                         panic(&SyntaxError{
267                                                 Offset: d.offset,
268                                                 What:   errors.New("unexpected end of dict, no matching value for a given key"),
269                                         })
270                                 }
271                                 continue
272                         }
273                 }
274
275                 // now we need to actually parse it
276                 if !d.parse_value(valuev) {
277                         panic(&SyntaxError{
278                                 Offset: d.offset,
279                                 What:   errors.New("unexpected end of dict, no matching value for a given key"),
280                         })
281                 }
282
283                 if v.Kind() == reflect.Map {
284                         v.SetMapIndex(keyv, valuev)
285                 }
286         }
287 }
288
289 func (d *decoder) parse_list(v reflect.Value) {
290         switch v.Kind() {
291         case reflect.Array, reflect.Slice:
292         default:
293                 panic(&UnmarshalTypeError{
294                         Value: "array",
295                         Type:  v.Type(),
296                 })
297         }
298
299         i := 0
300         for {
301                 if v.Kind() == reflect.Slice && i >= v.Len() {
302                         v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
303                 }
304
305                 ok := false
306                 if i < v.Len() {
307                         ok = d.parse_value(v.Index(i))
308                 } else {
309                         _, ok = d.parse_value_interface()
310                 }
311
312                 if !ok {
313                         break
314                 }
315
316                 i++
317         }
318
319         if i < v.Len() {
320                 if v.Kind() == reflect.Array {
321                         z := reflect.Zero(v.Type().Elem())
322                         for n := v.Len(); i < n; i++ {
323                                 v.Index(i).Set(z)
324                         }
325                 } else {
326                         v.SetLen(i)
327                 }
328         }
329
330         if i == 0 && v.Kind() == reflect.Slice {
331                 v.Set(reflect.MakeSlice(v.Type(), 0, 0))
332         }
333 }
334
335 func (d *decoder) read_one_value() bool {
336         b, err := d.ReadByte()
337         if err != nil {
338                 panic(err)
339         }
340         if b == 'e' {
341                 d.UnreadByte()
342                 return false
343         } else {
344                 d.offset++
345                 d.buf.WriteByte(b)
346         }
347
348         switch b {
349         case 'd', 'l':
350                 // read until there is nothing to read
351                 for d.read_one_value() {
352                 }
353                 // consume 'e' as well
354                 b = d.read_byte()
355                 d.buf.WriteByte(b)
356         case 'i':
357                 d.read_until('e')
358                 d.buf.WriteString("e")
359         default:
360                 if b >= '0' && b <= '9' {
361                         start := d.buf.Len() - 1
362                         d.read_until(':')
363                         length, err := strconv.ParseInt(d.buf.String()[start:], 10, 64)
364                         check_for_int_parse_error(err, d.offset-1)
365
366                         d.buf.WriteString(":")
367                         n, err := io.CopyN(&d.buf, d, length)
368                         d.offset += n
369                         if err != nil {
370                                 check_for_unexpected_eof(err, d.offset)
371                                 panic(&SyntaxError{
372                                         Offset: d.offset,
373                                         What:   errors.New("unexpected I/O error: " + err.Error()),
374                                 })
375                         }
376                         break
377                 }
378
379                 // unknown value
380                 panic(&SyntaxError{
381                         Offset: d.offset - 1,
382                         What:   errors.New("unknown value type (invalid bencode?)"),
383                 })
384         }
385
386         return true
387
388 }
389
390 func (d *decoder) parse_unmarshaler(v reflect.Value) bool {
391         m, ok := v.Interface().(Unmarshaler)
392         if !ok {
393                 // T doesn't work, try *T
394                 if v.Kind() != reflect.Ptr && v.CanAddr() {
395                         m, ok = v.Addr().Interface().(Unmarshaler)
396                         if ok {
397                                 v = v.Addr()
398                         }
399                 }
400         }
401         if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
402                 if d.read_one_value() {
403                         err := m.UnmarshalBencode(d.buf.Bytes())
404                         d.buf.Reset()
405                         if err != nil {
406                                 panic(&UnmarshalerError{v.Type(), err})
407                         }
408                         return true
409                 }
410                 d.buf.Reset()
411         }
412
413         return false
414 }
415
416 // returns true if there was a value and it's now stored in 'v', otherwise there
417 // was an end symbol ("e") and no value was stored
418 func (d *decoder) parse_value(v reflect.Value) bool {
419         // we support one level of indirection at the moment
420         if v.Kind() == reflect.Ptr {
421                 // if the pointer is nil, allocate a new element of the type it
422                 // points to
423                 if v.IsNil() {
424                         v.Set(reflect.New(v.Type().Elem()))
425                 }
426                 v = v.Elem()
427         }
428
429         if d.parse_unmarshaler(v) {
430                 return true
431         }
432
433         // common case: interface{}
434         if v.Kind() == reflect.Interface && v.NumMethod() == 0 {
435                 iface, _ := d.parse_value_interface()
436                 v.Set(reflect.ValueOf(iface))
437                 return true
438         }
439
440         b, err := d.ReadByte()
441         if err != nil {
442                 panic(err)
443         }
444         d.offset++
445
446         switch b {
447         case 'e':
448                 return false
449         case 'd':
450                 d.parse_dict(v)
451         case 'l':
452                 d.parse_list(v)
453         case 'i':
454                 d.parse_int(v)
455         default:
456                 if b >= '0' && b <= '9' {
457                         // string
458                         // append first digit of the length to the buffer
459                         d.buf.WriteByte(b)
460                         d.parse_string(v)
461                         break
462                 }
463
464                 // unknown value
465                 panic(&SyntaxError{
466                         Offset: d.offset - 1,
467                         What:   errors.New("unknown value type (invalid bencode?)"),
468                 })
469         }
470
471         return true
472 }
473
474 func (d *decoder) parse_value_interface() (interface{}, bool) {
475         b, err := d.ReadByte()
476         if err != nil {
477                 panic(err)
478         }
479         d.offset++
480
481         switch b {
482         case 'e':
483                 return nil, false
484         case 'd':
485                 return d.parse_dict_interface(), true
486         case 'l':
487                 return d.parse_list_interface(), true
488         case 'i':
489                 return d.parse_int_interface(), true
490         default:
491                 if b >= '0' && b <= '9' {
492                         // string
493                         // append first digit of the length to the buffer
494                         d.buf.WriteByte(b)
495                         return d.parse_string_interface(), true
496                 }
497
498                 // unknown value
499                 panic(&SyntaxError{
500                         Offset: d.offset - 1,
501                         What:   errors.New("unknown value type (invalid bencode?)"),
502                 })
503         }
504 }
505
506 func (d *decoder) parse_int_interface() interface{} {
507         start := d.offset - 1
508         d.read_until('e')
509         if d.buf.Len() == 0 {
510                 panic(&SyntaxError{
511                         Offset: start,
512                         What:   errors.New("empty integer value"),
513                 })
514         }
515
516         n, err := strconv.ParseInt(d.buf.String(), 10, 64)
517         check_for_int_parse_error(err, start)
518         d.buf.Reset()
519         return n
520 }
521
522 func (d *decoder) parse_string_interface() interface{} {
523         start := d.offset - 1
524
525         // read the string length first
526         d.read_until(':')
527         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
528         check_for_int_parse_error(err, start)
529
530         d.buf.Reset()
531         n, err := io.CopyN(&d.buf, d, length)
532         d.offset += n
533         if err != nil {
534                 check_for_unexpected_eof(err, d.offset)
535                 panic(&SyntaxError{
536                         Offset: d.offset,
537                         What:   errors.New("unexpected I/O error: " + err.Error()),
538                 })
539         }
540
541         s := d.buf.String()
542         d.buf.Reset()
543         return s
544 }
545
546 func (d *decoder) parse_dict_interface() interface{} {
547         dict := make(map[string]interface{})
548         for {
549                 keyi, ok := d.parse_value_interface()
550                 if !ok {
551                         break
552                 }
553
554                 key, ok := keyi.(string)
555                 if !ok {
556                         panic(&SyntaxError{
557                                 Offset: d.offset,
558                                 What:   errors.New("non-string key in a dict"),
559                         })
560                 }
561
562                 valuei, ok := d.parse_value_interface()
563                 if !ok {
564                         panic(&SyntaxError{
565                                 Offset: d.offset,
566                                 What:   errors.New("unexpected end of dict, no matching value for a given key"),
567                         })
568                 }
569
570                 dict[key] = valuei
571         }
572         return dict
573 }
574
575 func (d *decoder) parse_list_interface() interface{} {
576         var list []interface{}
577         for {
578                 valuei, ok := d.parse_value_interface()
579                 if !ok {
580                         break
581                 }
582
583                 list = append(list, valuei)
584         }
585         if list == nil {
586                 list = make([]interface{}, 0, 0)
587         }
588         return list
589 }