Sergey Matveev's repositories - btrtrc.git/blob - bencode/decode.go
bencode.scanner.ReadByte returned errors when it shouldn't have
[btrtrc.git] / bencode / decode.go
1 package bencode
2
3 import (
4         "bytes"
5         "errors"
6         "fmt"
7         "io"
8         "math/big"
9         "reflect"
10         "runtime"
11         "strconv"
12         "strings"
13 )
14
// Decoder reads bencoded values from an input stream and stores them into Go
// values via reflection.
type Decoder struct {
	// r is the input. Both byte-at-a-time scanning (with one byte of
	// push-back) and bulk reads are required.
	r interface {
		io.Reader
		io.ByteScanner
	}

	// Offset is the total number of input bytes consumed so far while
	// decoding values.
	Offset int64

	// buf is scratch space shared by the parsing helpers; it holds the bytes
	// of the token currently being read.
	buf bytes.Buffer

	// key holds the most recently decoded dictionary key.
	key string
}
25
26 func (d *Decoder) Decode(v interface{}) (err error) {
27         defer func() {
28                 if e := recover(); e != nil {
29                         if _, ok := e.(runtime.Error); ok {
30                                 panic(e)
31                         }
32                         err = e.(error)
33                 }
34         }()
35
36         pv := reflect.ValueOf(v)
37         if pv.Kind() != reflect.Ptr || pv.IsNil() {
38                 return &UnmarshalInvalidArgError{reflect.TypeOf(v)}
39         }
40
41         if !d.parseValue(pv.Elem()) {
42                 d.throwSyntaxError(d.Offset-1, errors.New("unexpected 'e'"))
43         }
44         return nil
45 }
46
47 func checkForUnexpectedEOF(err error, offset int64) {
48         if err == io.EOF {
49                 panic(&SyntaxError{
50                         Offset: offset,
51                         What:   io.ErrUnexpectedEOF,
52                 })
53         }
54 }
55
56 func (d *Decoder) readByte() byte {
57         b, err := d.r.ReadByte()
58         if err != nil {
59                 checkForUnexpectedEOF(err, d.Offset)
60                 panic(err)
61         }
62
63         d.Offset++
64         return b
65 }
66
67 // reads data writing it to 'd.buf' until 'sep' byte is encountered, 'sep' byte
68 // is consumed, but not included into the 'd.buf'
69 func (d *Decoder) readUntil(sep byte) {
70         for {
71                 b := d.readByte()
72                 if b == sep {
73                         return
74                 }
75                 d.buf.WriteByte(b)
76         }
77 }
78
79 func checkForIntParseError(err error, offset int64) {
80         if err != nil {
81                 panic(&SyntaxError{
82                         Offset: offset,
83                         What:   err,
84                 })
85         }
86 }
87
88 func (d *Decoder) throwSyntaxError(offset int64, err error) {
89         panic(&SyntaxError{
90                 Offset: offset,
91                 What:   err,
92         })
93 }
94
95 // called when 'i' was consumed
96 func (d *Decoder) parseInt(v reflect.Value) {
97         start := d.Offset - 1
98         d.readUntil('e')
99         if d.buf.Len() == 0 {
100                 panic(&SyntaxError{
101                         Offset: start,
102                         What:   errors.New("empty integer value"),
103                 })
104         }
105
106         s := d.buf.String()
107
108         switch v.Kind() {
109         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
110                 n, err := strconv.ParseInt(s, 10, 64)
111                 checkForIntParseError(err, start)
112
113                 if v.OverflowInt(n) {
114                         panic(&UnmarshalTypeError{
115                                 Value: "integer " + s,
116                                 Type:  v.Type(),
117                         })
118                 }
119                 v.SetInt(n)
120         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
121                 n, err := strconv.ParseUint(s, 10, 64)
122                 checkForIntParseError(err, start)
123
124                 if v.OverflowUint(n) {
125                         panic(&UnmarshalTypeError{
126                                 Value: "integer " + s,
127                                 Type:  v.Type(),
128                         })
129                 }
130                 v.SetUint(n)
131         case reflect.Bool:
132                 v.SetBool(s != "0")
133         default:
134                 panic(&UnmarshalTypeError{
135                         Value: "integer " + s,
136                         Type:  v.Type(),
137                 })
138         }
139         d.buf.Reset()
140 }
141
142 func (d *Decoder) parseString(v reflect.Value) {
143         start := d.Offset - 1
144
145         // read the string length first
146         d.readUntil(':')
147         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
148         checkForIntParseError(err, start)
149
150         d.buf.Reset()
151         n, err := io.CopyN(&d.buf, d.r, length)
152         d.Offset += n
153         if err != nil {
154                 checkForUnexpectedEOF(err, d.Offset)
155                 panic(&SyntaxError{
156                         Offset: d.Offset,
157                         What:   errors.New("unexpected I/O error: " + err.Error()),
158                 })
159         }
160
161         switch v.Kind() {
162         case reflect.String:
163                 v.SetString(d.buf.String())
164         case reflect.Slice:
165                 if v.Type().Elem().Kind() != reflect.Uint8 {
166                         panic(&UnmarshalTypeError{
167                                 Value: "string",
168                                 Type:  v.Type(),
169                         })
170                 }
171                 sl := make([]byte, len(d.buf.Bytes()))
172                 copy(sl, d.buf.Bytes())
173                 v.Set(reflect.ValueOf(sl))
174         default:
175                 panic(&UnmarshalTypeError{
176                         Value: "string",
177                         Type:  v.Type(),
178                 })
179         }
180
181         d.buf.Reset()
182 }
183
184 func (d *Decoder) parseDict(v reflect.Value) {
185         switch v.Kind() {
186         case reflect.Map:
187                 t := v.Type()
188                 if t.Key().Kind() != reflect.String {
189                         panic(&UnmarshalTypeError{
190                                 Value: "object",
191                                 Type:  t,
192                         })
193                 }
194                 if v.IsNil() {
195                         v.Set(reflect.MakeMap(t))
196                 }
197         case reflect.Struct:
198         default:
199                 panic(&UnmarshalTypeError{
200                         Value: "object",
201                         Type:  v.Type(),
202                 })
203         }
204
205         var mapElem reflect.Value
206
207         // so, at this point 'd' byte was consumed, let's just read key/value
208         // pairs one by one
209         for {
210                 var valuev reflect.Value
211                 keyv := reflect.ValueOf(&d.key).Elem()
212                 if !d.parseValue(keyv) {
213                         return
214                 }
215
216                 // get valuev as a map value or as a struct field
217                 switch v.Kind() {
218                 case reflect.Map:
219                         elem_type := v.Type().Elem()
220                         if !mapElem.IsValid() {
221                                 mapElem = reflect.New(elem_type).Elem()
222                         } else {
223                                 mapElem.Set(reflect.Zero(elem_type))
224                         }
225                         valuev = mapElem
226                 case reflect.Struct:
227                         var f reflect.StructField
228                         var ok bool
229
230                         t := v.Type()
231                         for i, n := 0, t.NumField(); i < n; i++ {
232                                 f = t.Field(i)
233                                 tag := f.Tag.Get("bencode")
234                                 if tag == "-" {
235                                         continue
236                                 }
237                                 if f.Anonymous {
238                                         continue
239                                 }
240
241                                 tag_name, _ := parseTag(tag)
242                                 if tag_name == d.key {
243                                         ok = true
244                                         break
245                                 }
246
247                                 if f.Name == d.key {
248                                         ok = true
249                                         break
250                                 }
251
252                                 if strings.EqualFold(f.Name, d.key) {
253                                         ok = true
254                                         break
255                                 }
256                         }
257
258                         if ok {
259                                 if f.PkgPath != "" {
260                                         panic(&UnmarshalFieldError{
261                                                 Key:   d.key,
262                                                 Type:  v.Type(),
263                                                 Field: f,
264                                         })
265                                 } else {
266                                         valuev = v.FieldByIndex(f.Index)
267                                 }
268                         } else {
269                                 _, ok := d.parseValueInterface()
270                                 if !ok {
271                                         return
272                                 }
273                                 continue
274                         }
275                 }
276
277                 // now we need to actually parse it
278                 if !d.parseValue(valuev) {
279                         return
280                 }
281
282                 if v.Kind() == reflect.Map {
283                         v.SetMapIndex(keyv, valuev)
284                 }
285         }
286 }
287
288 func (d *Decoder) parseList(v reflect.Value) {
289         switch v.Kind() {
290         case reflect.Array, reflect.Slice:
291         default:
292                 panic(&UnmarshalTypeError{
293                         Value: "array",
294                         Type:  v.Type(),
295                 })
296         }
297
298         i := 0
299         for {
300                 if v.Kind() == reflect.Slice && i >= v.Len() {
301                         v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
302                 }
303
304                 ok := false
305                 if i < v.Len() {
306                         ok = d.parseValue(v.Index(i))
307                 } else {
308                         _, ok = d.parseValueInterface()
309                 }
310
311                 if !ok {
312                         break
313                 }
314
315                 i++
316         }
317
318         if i < v.Len() {
319                 if v.Kind() == reflect.Array {
320                         z := reflect.Zero(v.Type().Elem())
321                         for n := v.Len(); i < n; i++ {
322                                 v.Index(i).Set(z)
323                         }
324                 } else {
325                         v.SetLen(i)
326                 }
327         }
328
329         if i == 0 && v.Kind() == reflect.Slice {
330                 v.Set(reflect.MakeSlice(v.Type(), 0, 0))
331         }
332 }
333
334 func (d *Decoder) readOneValue() bool {
335         b, err := d.r.ReadByte()
336         if err != nil {
337                 panic(err)
338         }
339         if b == 'e' {
340                 d.r.UnreadByte()
341                 return false
342         } else {
343                 d.Offset++
344                 d.buf.WriteByte(b)
345         }
346
347         switch b {
348         case 'd', 'l':
349                 // read until there is nothing to read
350                 for d.readOneValue() {
351                 }
352                 // consume 'e' as well
353                 b = d.readByte()
354                 d.buf.WriteByte(b)
355         case 'i':
356                 d.readUntil('e')
357                 d.buf.WriteString("e")
358         default:
359                 if b >= '0' && b <= '9' {
360                         start := d.buf.Len() - 1
361                         d.readUntil(':')
362                         length, err := strconv.ParseInt(d.buf.String()[start:], 10, 64)
363                         checkForIntParseError(err, d.Offset-1)
364
365                         d.buf.WriteString(":")
366                         n, err := io.CopyN(&d.buf, d.r, length)
367                         d.Offset += n
368                         if err != nil {
369                                 checkForUnexpectedEOF(err, d.Offset)
370                                 panic(&SyntaxError{
371                                         Offset: d.Offset,
372                                         What:   errors.New("unexpected I/O error: " + err.Error()),
373                                 })
374                         }
375                         break
376                 }
377
378                 d.raiseUnknownValueType(b, d.Offset-1)
379         }
380
381         return true
382
383 }
384
385 func (d *Decoder) parseUnmarshaler(v reflect.Value) bool {
386         m, ok := v.Interface().(Unmarshaler)
387         if !ok {
388                 // T doesn't work, try *T
389                 if v.Kind() != reflect.Ptr && v.CanAddr() {
390                         m, ok = v.Addr().Interface().(Unmarshaler)
391                         if ok {
392                                 v = v.Addr()
393                         }
394                 }
395         }
396         if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
397                 if d.readOneValue() {
398                         err := m.UnmarshalBencode(d.buf.Bytes())
399                         d.buf.Reset()
400                         if err != nil {
401                                 panic(&UnmarshalerError{v.Type(), err})
402                         }
403                         return true
404                 }
405                 d.buf.Reset()
406         }
407
408         return false
409 }
410
411 // Returns true if there was a value and it's now stored in 'v', otherwise
412 // there was an end symbol ("e") and no value was stored.
413 func (d *Decoder) parseValue(v reflect.Value) bool {
414         // we support one level of indirection at the moment
415         if v.Kind() == reflect.Ptr {
416                 // if the pointer is nil, allocate a new element of the type it
417                 // points to
418                 if v.IsNil() {
419                         v.Set(reflect.New(v.Type().Elem()))
420                 }
421                 v = v.Elem()
422         }
423
424         if d.parseUnmarshaler(v) {
425                 return true
426         }
427
428         // common case: interface{}
429         if v.Kind() == reflect.Interface && v.NumMethod() == 0 {
430                 iface, _ := d.parseValueInterface()
431                 v.Set(reflect.ValueOf(iface))
432                 return true
433         }
434
435         b, err := d.r.ReadByte()
436         if err != nil {
437                 panic(err)
438         }
439         d.Offset++
440
441         switch b {
442         case 'e':
443                 return false
444         case 'd':
445                 d.parseDict(v)
446         case 'l':
447                 d.parseList(v)
448         case 'i':
449                 d.parseInt(v)
450         default:
451                 if b >= '0' && b <= '9' {
452                         // string
453                         // append first digit of the length to the buffer
454                         d.buf.WriteByte(b)
455                         d.parseString(v)
456                         break
457                 }
458
459                 d.raiseUnknownValueType(b, d.Offset-1)
460         }
461
462         return true
463 }
464
465 // An unknown bencode type character was encountered.
466 func (d *Decoder) raiseUnknownValueType(b byte, offset int64) {
467         panic(&SyntaxError{
468                 Offset: offset,
469                 What:   fmt.Errorf("unknown value type %+q", b),
470         })
471 }
472
473 func (d *Decoder) parseValueInterface() (interface{}, bool) {
474         b, err := d.r.ReadByte()
475         if err != nil {
476                 panic(err)
477         }
478         d.Offset++
479
480         switch b {
481         case 'e':
482                 return nil, false
483         case 'd':
484                 return d.parseDictInterface(), true
485         case 'l':
486                 return d.parseListInterface(), true
487         case 'i':
488                 return d.parseIntInterface(), true
489         default:
490                 if b >= '0' && b <= '9' {
491                         // string
492                         // append first digit of the length to the buffer
493                         d.buf.WriteByte(b)
494                         return d.parseStringInterface(), true
495                 }
496
497                 d.raiseUnknownValueType(b, d.Offset-1)
498                 panic("unreachable")
499         }
500 }
501
502 func (d *Decoder) parseIntInterface() (ret interface{}) {
503         start := d.Offset - 1
504         d.readUntil('e')
505         if d.buf.Len() == 0 {
506                 panic(&SyntaxError{
507                         Offset: start,
508                         What:   errors.New("empty integer value"),
509                 })
510         }
511
512         n, err := strconv.ParseInt(d.buf.String(), 10, 64)
513         if ne, ok := err.(*strconv.NumError); ok && ne.Err == strconv.ErrRange {
514                 i := new(big.Int)
515                 _, ok := i.SetString(d.buf.String(), 10)
516                 if !ok {
517                         panic(&SyntaxError{
518                                 Offset: start,
519                                 What:   errors.New("failed to parse integer"),
520                         })
521                 }
522                 ret = i
523         } else {
524                 checkForIntParseError(err, start)
525                 ret = n
526         }
527
528         d.buf.Reset()
529         return
530 }
531
532 func (d *Decoder) parseStringInterface() interface{} {
533         start := d.Offset - 1
534
535         // read the string length first
536         d.readUntil(':')
537         length, err := strconv.ParseInt(d.buf.String(), 10, 64)
538         checkForIntParseError(err, start)
539
540         d.buf.Reset()
541         n, err := io.CopyN(&d.buf, d.r, length)
542         d.Offset += n
543         if err != nil {
544                 checkForUnexpectedEOF(err, d.Offset)
545                 panic(&SyntaxError{
546                         Offset: d.Offset,
547                         What:   errors.New("unexpected I/O error: " + err.Error()),
548                 })
549         }
550
551         s := d.buf.String()
552         d.buf.Reset()
553         return s
554 }
555
556 func (d *Decoder) parseDictInterface() interface{} {
557         dict := make(map[string]interface{})
558         for {
559                 keyi, ok := d.parseValueInterface()
560                 if !ok {
561                         break
562                 }
563
564                 key, ok := keyi.(string)
565                 if !ok {
566                         panic(&SyntaxError{
567                                 Offset: d.Offset,
568                                 What:   errors.New("non-string key in a dict"),
569                         })
570                 }
571
572                 valuei, ok := d.parseValueInterface()
573                 if !ok {
574                         break
575                 }
576
577                 dict[key] = valuei
578         }
579         return dict
580 }
581
582 func (d *Decoder) parseListInterface() interface{} {
583         var list []interface{}
584         for {
585                 valuei, ok := d.parseValueInterface()
586                 if !ok {
587                         break
588                 }
589
590                 list = append(list, valuei)
591         }
592         if list == nil {
593                 list = make([]interface{}, 0, 0)
594         }
595         return list
596 }