]> Sergey Matveev's repositories - btrtrc.git/blobdiff - bencode/decode.go
Attribute accepted connection to holepunching when connect message is late
[btrtrc.git] / bencode / decode.go
index e891f5be0dc8b868d136d3003721a45729e365e8..3839b849c21155cffb15132d12b310d3c3936900 100644 (file)
@@ -12,7 +12,16 @@ import (
        "sync"
 )
 
+// The default bencode string length limit. This is a poor attempt to prevent excessive memory
+// allocation when parsing, but also leaves the window open to implement a better solution.
+const DefaultDecodeMaxStrLen = 1<<27 - 1 // ~128MiB
+
+type MaxStrLen = int64
+
 type Decoder struct {
+       // Maximum parsed bencode string length. Defaults to DefaultMaxStrLen if zero.
+       MaxStrLen MaxStrLen
+
        r interface {
                io.ByteScanner
                io.Reader
@@ -28,14 +37,21 @@ func (d *Decoder) Decode(v interface{}) (err error) {
                        return
                }
                r := recover()
+               if r == nil {
+                       return
+               }
                _, ok := r.(runtime.Error)
                if ok {
                        panic(r)
                }
-               err, ok = r.(error)
-               if !ok && r != nil {
+               if err, ok = r.(error); !ok {
                        panic(r)
                }
+               // Errors thrown from deeper in parsing are unexpected. At value boundaries, errors should
+               // be returned directly (at least until all the panic nonsense is removed entirely).
+               if err == io.EOF {
+                       err = io.ErrUnexpectedEOF
+               }
        }()
 
        pv := reflect.ValueOf(v)
@@ -101,17 +117,29 @@ func (d *Decoder) throwSyntaxError(offset int64, err error) {
        })
 }
 
-// called when 'i' was consumed
-func (d *Decoder) parseInt(v reflect.Value) {
-       start := d.Offset - 1
+// Assume the 'i' is already consumed. Read and validate the rest of an int into the buffer.
+func (d *Decoder) readInt() error {
+       // start := d.Offset - 1
        d.readUntil('e')
-       if d.buf.Len() == 0 {
-               panic(&SyntaxError{
-                       Offset: start,
-                       What:   errors.New("empty integer value"),
-               })
-       }
+       if err := d.checkBufferedInt(); err != nil {
+               return err
+       }
+       // if d.buf.Len() == 0 {
+       //      panic(&SyntaxError{
+       //              Offset: start,
+       //              What:   errors.New("empty integer value"),
+       //      })
+       // }
+       return nil
+}
 
+// called when 'i' was consumed, for the integer type in v.
+func (d *Decoder) parseInt(v reflect.Value) error {
+       start := d.Offset - 1
+
+       if err := d.readInt(); err != nil {
+               return err
+       }
        s := bytesAsString(d.buf.Bytes())
 
        switch v.Kind() {
@@ -120,10 +148,10 @@ func (d *Decoder) parseInt(v reflect.Value) {
                checkForIntParseError(err, start)
 
                if v.OverflowInt(n) {
-                       panic(&UnmarshalTypeError{
-                               Value: "integer " + s,
-                               Type:  v.Type(),
-                       })
+                       return &UnmarshalTypeError{
+                               BencodeTypeName:     "int",
+                               UnmarshalTargetType: v.Type(),
+                       }
                }
                v.SetInt(n)
        case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
@@ -131,33 +159,63 @@ func (d *Decoder) parseInt(v reflect.Value) {
                checkForIntParseError(err, start)
 
                if v.OverflowUint(n) {
-                       panic(&UnmarshalTypeError{
-                               Value: "integer " + s,
-                               Type:  v.Type(),
-                       })
+                       return &UnmarshalTypeError{
+                               BencodeTypeName:     "int",
+                               UnmarshalTargetType: v.Type(),
+                       }
                }
                v.SetUint(n)
        case reflect.Bool:
                v.SetBool(s != "0")
        default:
-               panic(&UnmarshalTypeError{
-                       Value: "integer " + s,
-                       Type:  v.Type(),
-               })
+               return &UnmarshalTypeError{
+                       BencodeTypeName:     "int",
+                       UnmarshalTargetType: v.Type(),
+               }
        }
        d.buf.Reset()
+       return nil
 }
 
-func (d *Decoder) parseString(v reflect.Value) error {
-       start := d.Offset - 1
+func (d *Decoder) checkBufferedInt() error {
+       b := d.buf.Bytes()
+       if len(b) <= 1 {
+               return nil
+       }
+       if b[0] == '-' {
+               b = b[1:]
+       }
+       if b[0] < '1' || b[0] > '9' {
+               return errors.New("invalid leading digit")
+       }
+       return nil
+}
 
-       // read the string length first
+func (d *Decoder) parseStringLength() (int, error) {
+       // We should have already consumed the first byte of the length into the Decoder buf.
+       start := d.Offset - 1
        d.readUntil(':')
+       if err := d.checkBufferedInt(); err != nil {
+               return 0, err
+       }
+       // Really the limit should be the uint size for the platform. But we can't pass in an allocator,
+       // or limit total memory use in Go, the best we might hope to do is limit the size of a single
+       // decoded value (by reading it in in-place and then operating on a view).
        length, err := strconv.ParseInt(bytesAsString(d.buf.Bytes()), 10, 0)
        checkForIntParseError(err, start)
+       if int64(length) > d.getMaxStrLen() {
+               err = fmt.Errorf("parsed string length %v exceeds limit (%v)", length, DefaultDecodeMaxStrLen)
+       }
+       d.buf.Reset()
+       return int(length), err
+}
 
+func (d *Decoder) parseString(v reflect.Value) error {
+       length, err := d.parseStringLength()
+       if err != nil {
+               return err
+       }
        defer d.buf.Reset()
-
        read := func(b []byte) {
                n, err := io.ReadFull(d.r, b)
                d.Offset += int64(n)
@@ -188,18 +246,30 @@ func (d *Decoder) parseString(v reflect.Value) error {
                if v.Type().Elem().Kind() != reflect.Uint8 {
                        break
                }
-               d.buf.Grow(int(length))
+               d.buf.Grow(length)
                b := d.buf.Bytes()[:length]
                read(b)
                reflect.Copy(v, reflect.ValueOf(b))
                return nil
+       case reflect.Bool:
+               d.buf.Grow(length)
+               b := d.buf.Bytes()[:length]
+               read(b)
+               x, err := strconv.ParseBool(bytesAsString(b))
+               if err != nil {
+                       x = length != 0
+               }
+               v.SetBool(x)
+               return nil
        }
-       d.buf.Grow(int(length))
+       // Can't move this into default clause because some cases above fail through to here after
+       // additional checks.
+       d.buf.Grow(length)
        read(d.buf.Bytes()[:length])
        // I believe we return here to support "ignore_unmarshal_type_error".
        return &UnmarshalTypeError{
-               Value: "string",
-               Type:  v.Type(),
+               BencodeTypeName:     "string",
+               UnmarshalTargetType: v.Type(),
        }
 }
 
@@ -211,9 +281,9 @@ type dictField struct {
 }
 
 // Returns specifics for parsing a dict field value.
-func getDictField(dict reflect.Type, key string) dictField {
+func getDictField(dict reflect.Type, key string) (_ dictField, err error) {
        // get valuev as a map value or as a struct field
-       switch dict.Kind() {
+       switch k := dict.Kind(); k {
        case reflect.Map:
                return dictField{
                        Type: dict.Elem(),
@@ -223,23 +293,23 @@ func getDictField(dict reflect.Type, key string) dictField {
                                                mapValue.Set(reflect.MakeMap(dict))
                                        }
                                        // Assigns the value into the map.
-                                       //log.Printf("map type: %v", mapValue.Type())
+                                       // log.Printf("map type: %v", mapValue.Type())
                                        mapValue.SetMapIndex(reflect.ValueOf(key).Convert(dict.Key()), value)
                                }
                        },
-               }
+               }, nil
        case reflect.Struct:
-               return getStructFieldForKey(dict, key)
-               //if sf.r.PkgPath != "" {
+               return getStructFieldForKey(dict, key), nil
+               // if sf.r.PkgPath != "" {
                //      panic(&UnmarshalFieldError{
                //              Key:   key,
                //              Type:  dict.Type(),
                //              Field: sf.r,
                //      })
-               //}
+               // }
        default:
-               panic("unimplemented")
-               return dictField{}
+               err = fmt.Errorf("can't assign bencode dict items into a %v", k)
+               return
        }
 }
 
@@ -253,14 +323,19 @@ func parseStructFields(struct_ reflect.Type, each func(key string, df dictField)
                i := _i
                f := struct_.Field(i)
                if f.Anonymous {
-                       parseStructFields(f.Type.Elem(), func(key string, df dictField) {
+                       t := f.Type
+                       if t.Kind() == reflect.Ptr {
+                               t = t.Elem()
+                       }
+                       parseStructFields(t, func(key string, df dictField) {
                                innerGet := df.Get
                                df.Get = func(value reflect.Value) func(reflect.Value) {
                                        anonPtr := value.Field(i)
-                                       if anonPtr.IsNil() {
+                                       if anonPtr.Kind() == reflect.Ptr && anonPtr.IsNil() {
                                                anonPtr.Set(reflect.New(f.Type.Elem()))
+                                               anonPtr = anonPtr.Elem()
                                        }
-                                       return innerGet(anonPtr.Elem())
+                                       return innerGet(anonPtr)
                                }
                                each(key, df)
                        })
@@ -308,20 +383,22 @@ func getStructFieldForKey(struct_ reflect.Type, key string) (f dictField) {
 }
 
 func (d *Decoder) parseDict(v reflect.Value) error {
-       // so, at this point 'd' byte was consumed, let's just read key/value
-       // pairs one by one
+       // At this point 'd' byte was consumed, now read key/value pairs
        for {
                var keyStr string
                keyValue := reflect.ValueOf(&keyStr).Elem()
                ok, err := d.parseValue(keyValue)
                if err != nil {
-                       return fmt.Errorf("error parsing dict key: %s", err)
+                       return fmt.Errorf("error parsing dict key: %w", err)
                }
                if !ok {
                        return nil
                }
 
-               df := getDictField(v.Type(), keyStr)
+               df, err := getDictField(v.Type(), keyStr)
+               if err != nil {
+                       return fmt.Errorf("parsing bencode dict into %v: %w", v.Type(), err)
+               }
 
                // now we need to actually parse it
                if df.Type == nil {
@@ -337,11 +414,12 @@ func (d *Decoder) parseDict(v reflect.Value) error {
                        continue
                }
                setValue := reflect.New(df.Type).Elem()
-               //log.Printf("parsing into %v", setValue.Type())
+               // log.Printf("parsing into %v", setValue.Type())
                ok, err = d.parseValue(setValue)
                if err != nil {
-                       if _, ok := err.(*UnmarshalTypeError); !ok || !df.Tags.IgnoreUnmarshalTypeError() {
-                               return fmt.Errorf("parsing value for key %q: %s", keyStr, err)
+                       var target *UnmarshalTypeError
+                       if !(errors.As(err, &target) && df.Tags.IgnoreUnmarshalTypeError()) {
+                               return fmt.Errorf("parsing value for key %q: %w", keyStr, err)
                        }
                }
                if !ok {
@@ -362,8 +440,8 @@ func (d *Decoder) parseList(v reflect.Value) error {
                }
                if l.Elem().Len() != 1 {
                        return &UnmarshalTypeError{
-                               Value: "list",
-                               Type:  v.Type(),
+                               BencodeTypeName:     "list",
+                               UnmarshalTargetType: v.Type(),
                        }
                }
                v.Set(l.Elem().Index(0))
@@ -459,7 +537,6 @@ func (d *Decoder) readOneValue() bool {
        }
 
        return true
-
 }
 
 func (d *Decoder) parseUnmarshaler(v reflect.Value) bool {
@@ -508,7 +585,7 @@ func (d *Decoder) parseValue(v reflect.Value) (bool, error) {
 
        b, err := d.r.ReadByte()
        if err != nil {
-               panic(err)
+               return false, err
        }
        d.Offset++
 
@@ -520,8 +597,7 @@ func (d *Decoder) parseValue(v reflect.Value) (bool, error) {
        case 'l':
                return true, d.parseList(v)
        case 'i':
-               d.parseInt(v)
-               return true, nil
+               return true, d.parseInt(v)
        default:
                if b >= '0' && b <= '9' {
                        // It's a string.
@@ -573,16 +649,13 @@ func (d *Decoder) parseValueInterface() (interface{}, bool) {
        }
 }
 
+// Called after 'i', for an arbitrary integer size.
 func (d *Decoder) parseIntInterface() (ret interface{}) {
        start := d.Offset - 1
-       d.readUntil('e')
-       if d.buf.Len() == 0 {
-               panic(&SyntaxError{
-                       Offset: start,
-                       What:   errors.New("empty integer value"),
-               })
-       }
 
+       if err := d.readInt(); err != nil {
+               panic(err)
+       }
        n, err := strconv.ParseInt(d.buf.String(), 10, 64)
        if ne, ok := err.(*strconv.NumError); ok && ne.Err == strconv.ErrRange {
                i := new(big.Int)
@@ -603,33 +676,36 @@ func (d *Decoder) parseIntInterface() (ret interface{}) {
        return
 }
 
-func (d *Decoder) parseStringInterface() interface{} {
-       start := d.Offset - 1
-
-       // read the string length first
-       d.readUntil(':')
-       length, err := strconv.ParseInt(d.buf.String(), 10, 64)
-       checkForIntParseError(err, start)
-
-       d.buf.Reset()
-       n, err := io.CopyN(&d.buf, d.r, length)
-       d.Offset += n
+func (d *Decoder) readBytes(length int) []byte {
+       b, err := io.ReadAll(io.LimitReader(d.r, int64(length)))
        if err != nil {
-               checkForUnexpectedEOF(err, d.Offset)
-               panic(&SyntaxError{
-                       Offset: d.Offset,
-                       What:   errors.New("unexpected I/O error: " + err.Error()),
-               })
+               panic(err)
+       }
+       if len(b) != length {
+               panic(fmt.Errorf("read %v bytes expected %v", len(b), length))
        }
+       return b
+}
 
-       s := d.buf.String()
-       d.buf.Reset()
-       return s
+func (d *Decoder) parseStringInterface() string {
+       length, err := d.parseStringLength()
+       if err != nil {
+               panic(err)
+       }
+       b := d.readBytes(int(length))
+       d.Offset += int64(len(b))
+       if err != nil {
+               panic(&SyntaxError{Offset: d.Offset, What: err})
+       }
+       return bytesAsString(b)
 }
 
 func (d *Decoder) parseDictInterface() interface{} {
        dict := make(map[string]interface{})
+       var lastKey string
+       lastKeyOk := false
        for {
+               start := d.Offset
                keyi, ok := d.parseValueInterface()
                if !ok {
                        break
@@ -642,29 +718,35 @@ func (d *Decoder) parseDictInterface() interface{} {
                                What:   errors.New("non-string key in a dict"),
                        })
                }
-
+               if lastKeyOk && key <= lastKey {
+                       d.throwSyntaxError(start, fmt.Errorf("dict keys unsorted: %q <= %q", key, lastKey))
+               }
+               start = d.Offset
                valuei, ok := d.parseValueInterface()
                if !ok {
-                       break
+                       d.throwSyntaxError(start, fmt.Errorf("dict elem missing value [key=%v]", key))
                }
 
+               lastKey = key
+               lastKeyOk = true
                dict[key] = valuei
        }
        return dict
 }
 
-func (d *Decoder) parseListInterface() interface{} {
-       var list []interface{}
-       for {
-               valuei, ok := d.parseValueInterface()
-               if !ok {
-                       break
-               }
-
+func (d *Decoder) parseListInterface() (list []interface{}) {
+       list = []interface{}{}
+       valuei, ok := d.parseValueInterface()
+       for ok {
                list = append(list, valuei)
+               valuei, ok = d.parseValueInterface()
        }
-       if list == nil {
-               list = make([]interface{}, 0, 0)
+       return
+}
+
+func (d *Decoder) getMaxStrLen() int64 {
+       if d.MaxStrLen == 0 {
+               return DefaultDecodeMaxStrLen
        }
-       return list
+       return d.MaxStrLen
 }