Sergey Matveev's repositories - btrtrc.git/blob - bencode/decode.go
Merge pull request #9 from gitter-badger/gitter-badge
[btrtrc.git] / bencode / decode.go
1 package bencode
2
3 import (
4         "bufio"
5         "bytes"
6         "errors"
7         "io"
8         "reflect"
9         "runtime"
10         "strconv"
11         "strings"
12 )
13
// decoder carries the state of a bencode decoding pass over a buffered
// input stream.
type decoder struct {
	// Reader is the embedded buffered source the raw bytes are pulled from.
	*bufio.Reader

	// offset is the number of input bytes consumed so far; it is used as the
	// position reported in syntax errors.
	offset int64

	// buf is scratch space shared by the parsing helpers (integer digits,
	// string lengths and payloads, raw values handed to unmarshalers).
	buf bytes.Buffer

	// key receives each dictionary key as it is parsed.
	key string
}
20
21 func (d *decoder) decode(v interface{}) (err error) {
22         defer func() {
23                 if e := recover(); e != nil {
24                         if _, ok := e.(runtime.Error); ok {
25                                 panic(e)
26                         }
27                         err = e.(error)
28                 }
29         }()
30
31         pv := reflect.ValueOf(v)
32         if pv.Kind() != reflect.Ptr || pv.IsNil() {
33                 return &UnmarshalInvalidArgError{reflect.TypeOf(v)}
34         }
35
36         d.parse_value(pv.Elem())
37         return nil
38 }
39
40 func check_for_unexpected_eof(err error, offset int64) {
41         if err == io.EOF {
42                 panic(&SyntaxError{
43                         Offset: offset,
44                         What:   io.ErrUnexpectedEOF,
45                 })
46         }
47 }
48
49 func (d *decoder) read_byte() byte {
50         b, err := d.ReadByte()
51         if err != nil {
52                 check_for_unexpected_eof(err, d.offset)
53                 panic(err)
54         }
55
56         d.offset++
57         return b
58 }
59
60 // reads data writing it to 'd.buf' until 'sep' byte is encountered, 'sep' byte
61 // is consumed, but not included into the 'd.buf'
62 func (d *decoder) read_until(sep byte) {
63         for {
64                 b := d.read_byte()
65                 if b == sep {
66                         return
67                 }
68                 d.buf.WriteByte(b)
69         }
70 }
71
72 func check_for_int_parse_error(err error, offset int64) {
73         if err != nil {
74                 panic(&SyntaxError{
75                         Offset: offset,
76                         What:   err,
77                 })
78         }
79 }
80
81 // called when 'i' was consumed
82 func (d *decoder) parse_int(v reflect.Value) {
83         start := d.offset - 1
84         d.read_until('e')
85         if d.buf.Len() == 0 {
86                 panic(&SyntaxError{
87                         Offset: start,
88                         What:   errors.New("empty integer value"),
89                 })
90         }
91
92         switch v.Kind() {
93         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
94                 n, err := strconv.ParseInt(d.buf.String(), 10, 64)
95                 check_for_int_parse_error(err, start)
96
97                 if v.OverflowInt(n) {
98                         panic(&UnmarshalTypeError{
99                                 Value: "integer " + d.buf.String(),
100                                 Type:  v.Type(),
101                         })
102                 }
103                 v.SetInt(n)
104         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
105                 n, err := strconv.ParseUint(d.buf.String(), 10, 64)
106                 check_for_int_parse_error(err, start)
107
108                 if v.OverflowUint(n) {
109                         panic(&UnmarshalTypeError{
110                                 Value: "integer " + d.buf.String(),
111                                 Type:  v.Type(),
112                         })
113                 }
114                 v.SetUint(n)
115         case reflect.Bool:
116                 v.SetBool(d.buf.String() != "0")
117         default:
118                 panic(&UnmarshalTypeError{
119                         Value: "integer " + d.buf.String(),
120                         Type:  v.Type(),
121                 })
122         }
123         d.buf.Reset()
124 }
125
126 func (d *decoder) parse_string(v reflect.Value) {
127         start := d.offset - 1
128
129         // read the string length first
130         d.read_until(':')
131         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
132         check_for_int_parse_error(err, start)
133
134         d.buf.Reset()
135         n, err := io.CopyN(&d.buf, d, length)
136         d.offset += n
137         if err != nil {
138                 check_for_unexpected_eof(err, d.offset)
139                 panic(&SyntaxError{
140                         Offset: d.offset,
141                         What:   errors.New("unexpected I/O error: " + err.Error()),
142                 })
143         }
144
145         switch v.Kind() {
146         case reflect.String:
147                 v.SetString(d.buf.String())
148         case reflect.Slice:
149                 if v.Type().Elem().Kind() != reflect.Uint8 {
150                         panic(&UnmarshalTypeError{
151                                 Value: "string",
152                                 Type:  v.Type(),
153                         })
154                 }
155                 sl := make([]byte, len(d.buf.Bytes()))
156                 copy(sl, d.buf.Bytes())
157                 v.Set(reflect.ValueOf(sl))
158         default:
159                 panic(&UnmarshalTypeError{
160                         Value: "string",
161                         Type:  v.Type(),
162                 })
163         }
164
165         d.buf.Reset()
166 }
167
168 func (d *decoder) parse_dict(v reflect.Value) {
169         switch v.Kind() {
170         case reflect.Map:
171                 t := v.Type()
172                 if t.Key().Kind() != reflect.String {
173                         panic(&UnmarshalTypeError{
174                                 Value: "object",
175                                 Type:  t,
176                         })
177                 }
178                 if v.IsNil() {
179                         v.Set(reflect.MakeMap(t))
180                 }
181         case reflect.Struct:
182         default:
183                 panic(&UnmarshalTypeError{
184                         Value: "object",
185                         Type:  v.Type(),
186                 })
187         }
188
189         var map_elem reflect.Value
190
191         // so, at this point 'd' byte was consumed, let's just read key/value
192         // pairs one by one
193         for {
194                 var valuev reflect.Value
195                 keyv := reflect.ValueOf(&d.key).Elem()
196                 if !d.parse_value(keyv) {
197                         return
198                 }
199
200                 // get valuev as a map value or as a struct field
201                 switch v.Kind() {
202                 case reflect.Map:
203                         elem_type := v.Type().Elem()
204                         if !map_elem.IsValid() {
205                                 map_elem = reflect.New(elem_type).Elem()
206                         } else {
207                                 map_elem.Set(reflect.Zero(elem_type))
208                         }
209                         valuev = map_elem
210                 case reflect.Struct:
211                         var f reflect.StructField
212                         var ok bool
213
214                         t := v.Type()
215                         for i, n := 0, t.NumField(); i < n; i++ {
216                                 f = t.Field(i)
217                                 tag := f.Tag.Get("bencode")
218                                 if tag == "-" {
219                                         continue
220                                 }
221                                 if f.Anonymous {
222                                         continue
223                                 }
224
225                                 tag_name, _ := parse_tag(tag)
226                                 if tag_name == d.key {
227                                         ok = true
228                                         break
229                                 }
230
231                                 if f.Name == d.key {
232                                         ok = true
233                                         break
234                                 }
235
236                                 if strings.EqualFold(f.Name, d.key) {
237                                         ok = true
238                                         break
239                                 }
240                         }
241
242                         if ok {
243                                 if f.PkgPath != "" {
244                                         panic(&UnmarshalFieldError{
245                                                 Key:   d.key,
246                                                 Type:  v.Type(),
247                                                 Field: f,
248                                         })
249                                 } else {
250                                         valuev = v.FieldByIndex(f.Index)
251                                 }
252                         } else {
253                                 _, ok := d.parse_value_interface()
254                                 if !ok {
255                                         panic(&SyntaxError{
256                                                 Offset: d.offset,
257                                                 What:   errors.New("unexpected end of dict, no matching value for a given key"),
258                                         })
259                                 }
260                                 continue
261                         }
262                 }
263
264                 // now we need to actually parse it
265                 if !d.parse_value(valuev) {
266                         panic(&SyntaxError{
267                                 Offset: d.offset,
268                                 What:   errors.New("unexpected end of dict, no matching value for a given key"),
269                         })
270                 }
271
272                 if v.Kind() == reflect.Map {
273                         v.SetMapIndex(keyv, valuev)
274                 }
275         }
276 }
277
278 func (d *decoder) parse_list(v reflect.Value) {
279         switch v.Kind() {
280         case reflect.Array, reflect.Slice:
281         default:
282                 panic(&UnmarshalTypeError{
283                         Value: "array",
284                         Type:  v.Type(),
285                 })
286         }
287
288         i := 0
289         for {
290                 if v.Kind() == reflect.Slice && i >= v.Len() {
291                         v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
292                 }
293
294                 ok := false
295                 if i < v.Len() {
296                         ok = d.parse_value(v.Index(i))
297                 } else {
298                         _, ok = d.parse_value_interface()
299                 }
300
301                 if !ok {
302                         break
303                 }
304
305                 i++
306         }
307
308         if i < v.Len() {
309                 if v.Kind() == reflect.Array {
310                         z := reflect.Zero(v.Type().Elem())
311                         for n := v.Len(); i < n; i++ {
312                                 v.Index(i).Set(z)
313                         }
314                 } else {
315                         v.SetLen(i)
316                 }
317         }
318
319         if i == 0 && v.Kind() == reflect.Slice {
320                 v.Set(reflect.MakeSlice(v.Type(), 0, 0))
321         }
322 }
323
324 func (d *decoder) read_one_value() bool {
325         b, err := d.ReadByte()
326         if err != nil {
327                 panic(err)
328         }
329         if b == 'e' {
330                 d.UnreadByte()
331                 return false
332         } else {
333                 d.offset++
334                 d.buf.WriteByte(b)
335         }
336
337         switch b {
338         case 'd', 'l':
339                 // read until there is nothing to read
340                 for d.read_one_value() {
341                 }
342                 // consume 'e' as well
343                 b = d.read_byte()
344                 d.buf.WriteByte(b)
345         case 'i':
346                 d.read_until('e')
347                 d.buf.WriteString("e")
348         default:
349                 if b >= '0' && b <= '9' {
350                         start := d.buf.Len() - 1
351                         d.read_until(':')
352                         length, err := strconv.ParseInt(d.buf.String()[start:], 10, 64)
353                         check_for_int_parse_error(err, d.offset-1)
354
355                         d.buf.WriteString(":")
356                         n, err := io.CopyN(&d.buf, d, length)
357                         d.offset += n
358                         if err != nil {
359                                 check_for_unexpected_eof(err, d.offset)
360                                 panic(&SyntaxError{
361                                         Offset: d.offset,
362                                         What:   errors.New("unexpected I/O error: " + err.Error()),
363                                 })
364                         }
365                         break
366                 }
367
368                 // unknown value
369                 panic(&SyntaxError{
370                         Offset: d.offset - 1,
371                         What:   errors.New("unknown value type (invalid bencode?)"),
372                 })
373         }
374
375         return true
376
377 }
378
379 func (d *decoder) parse_unmarshaler(v reflect.Value) bool {
380         m, ok := v.Interface().(Unmarshaler)
381         if !ok {
382                 // T doesn't work, try *T
383                 if v.Kind() != reflect.Ptr && v.CanAddr() {
384                         m, ok = v.Addr().Interface().(Unmarshaler)
385                         if ok {
386                                 v = v.Addr()
387                         }
388                 }
389         }
390         if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
391                 if d.read_one_value() {
392                         err := m.UnmarshalBencode(d.buf.Bytes())
393                         d.buf.Reset()
394                         if err != nil {
395                                 panic(&UnmarshalerError{v.Type(), err})
396                         }
397                         return true
398                 }
399                 d.buf.Reset()
400         }
401
402         return false
403 }
404
405 // returns true if there was a value and it's now stored in 'v', otherwise there
406 // was an end symbol ("e") and no value was stored
407 func (d *decoder) parse_value(v reflect.Value) bool {
408         // we support one level of indirection at the moment
409         if v.Kind() == reflect.Ptr {
410                 // if the pointer is nil, allocate a new element of the type it
411                 // points to
412                 if v.IsNil() {
413                         v.Set(reflect.New(v.Type().Elem()))
414                 }
415                 v = v.Elem()
416         }
417
418         if d.parse_unmarshaler(v) {
419                 return true
420         }
421
422         // common case: interface{}
423         if v.Kind() == reflect.Interface && v.NumMethod() == 0 {
424                 iface, _ := d.parse_value_interface()
425                 v.Set(reflect.ValueOf(iface))
426                 return true
427         }
428
429         b, err := d.ReadByte()
430         if err != nil {
431                 panic(err)
432         }
433         d.offset++
434
435         switch b {
436         case 'e':
437                 return false
438         case 'd':
439                 d.parse_dict(v)
440         case 'l':
441                 d.parse_list(v)
442         case 'i':
443                 d.parse_int(v)
444         default:
445                 if b >= '0' && b <= '9' {
446                         // string
447                         // append first digit of the length to the buffer
448                         d.buf.WriteByte(b)
449                         d.parse_string(v)
450                         break
451                 }
452
453                 // unknown value
454                 panic(&SyntaxError{
455                         Offset: d.offset - 1,
456                         What:   errors.New("unknown value type (invalid bencode?)"),
457                 })
458         }
459
460         return true
461 }
462
463 func (d *decoder) parse_value_interface() (interface{}, bool) {
464         b, err := d.ReadByte()
465         if err != nil {
466                 panic(err)
467         }
468         d.offset++
469
470         switch b {
471         case 'e':
472                 return nil, false
473         case 'd':
474                 return d.parse_dict_interface(), true
475         case 'l':
476                 return d.parse_list_interface(), true
477         case 'i':
478                 return d.parse_int_interface(), true
479         default:
480                 if b >= '0' && b <= '9' {
481                         // string
482                         // append first digit of the length to the buffer
483                         d.buf.WriteByte(b)
484                         return d.parse_string_interface(), true
485                 }
486
487                 // unknown value
488                 panic(&SyntaxError{
489                         Offset: d.offset - 1,
490                         What:   errors.New("unknown value type (invalid bencode?)"),
491                 })
492         }
493 }
494
495 func (d *decoder) parse_int_interface() interface{} {
496         start := d.offset - 1
497         d.read_until('e')
498         if d.buf.Len() == 0 {
499                 panic(&SyntaxError{
500                         Offset: start,
501                         What:   errors.New("empty integer value"),
502                 })
503         }
504
505         n, err := strconv.ParseInt(d.buf.String(), 10, 64)
506         check_for_int_parse_error(err, start)
507         d.buf.Reset()
508         return n
509 }
510
511 func (d *decoder) parse_string_interface() interface{} {
512         start := d.offset - 1
513
514         // read the string length first
515         d.read_until(':')
516         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
517         check_for_int_parse_error(err, start)
518
519         d.buf.Reset()
520         n, err := io.CopyN(&d.buf, d, length)
521         d.offset += n
522         if err != nil {
523                 check_for_unexpected_eof(err, d.offset)
524                 panic(&SyntaxError{
525                         Offset: d.offset,
526                         What:   errors.New("unexpected I/O error: " + err.Error()),
527                 })
528         }
529
530         s := d.buf.String()
531         d.buf.Reset()
532         return s
533 }
534
535 func (d *decoder) parse_dict_interface() interface{} {
536         dict := make(map[string]interface{})
537         for {
538                 keyi, ok := d.parse_value_interface()
539                 if !ok {
540                         break
541                 }
542
543                 key, ok := keyi.(string)
544                 if !ok {
545                         panic(&SyntaxError{
546                                 Offset: d.offset,
547                                 What:   errors.New("non-string key in a dict"),
548                         })
549                 }
550
551                 valuei, ok := d.parse_value_interface()
552                 if !ok {
553                         panic(&SyntaxError{
554                                 Offset: d.offset,
555                                 What:   errors.New("unexpected end of dict, no matching value for a given key"),
556                         })
557                 }
558
559                 dict[key] = valuei
560         }
561         return dict
562 }
563
564 func (d *decoder) parse_list_interface() interface{} {
565         var list []interface{}
566         for {
567                 valuei, ok := d.parse_value_interface()
568                 if !ok {
569                         break
570                 }
571
572                 list = append(list, valuei)
573         }
574         if list == nil {
575                 list = make([]interface{}, 0, 0)
576         }
577         return list
578 }