]> Sergey Matveev's repositories - btrtrc.git/blobdiff - bencode/decode.go
Attribute accepted connection to holepunching when connect message is late
[btrtrc.git] / bencode / decode.go
index 31ce157fece95b9d505dfdf64bb5bf6375351e41..3839b849c21155cffb15132d12b310d3c3936900 100644 (file)
@@ -12,7 +12,16 @@ import (
        "sync"
 )
 
+// The default bencode string length limit. This is a poor attempt to prevent excessive memory
+// allocation when parsing, but also leaves the window open to implement a better solution.
+const DefaultDecodeMaxStrLen = 1<<27 - 1 // ~128MiB
+
+type MaxStrLen = int64
+
 type Decoder struct {
+       // Maximum parsed bencode string length. Defaults to DefaultMaxStrLen if zero.
+       MaxStrLen MaxStrLen
+
        r interface {
                io.ByteScanner
                io.Reader
@@ -28,14 +37,21 @@ func (d *Decoder) Decode(v interface{}) (err error) {
                        return
                }
                r := recover()
+               if r == nil {
+                       return
+               }
                _, ok := r.(runtime.Error)
                if ok {
                        panic(r)
                }
-               err, ok = r.(error)
-               if !ok && r != nil {
+               if err, ok = r.(error); !ok {
                        panic(r)
                }
+               // Errors thrown from deeper in parsing are unexpected. At value boundaries, errors should
+               // be returned directly (at least until all the panic nonsense is removed entirely).
+               if err == io.EOF {
+                       err = io.ErrUnexpectedEOF
+               }
        }()
 
        pv := reflect.ValueOf(v)
@@ -105,7 +121,7 @@ func (d *Decoder) throwSyntaxError(offset int64, err error) {
 func (d *Decoder) readInt() error {
        // start := d.Offset - 1
        d.readUntil('e')
-       if err := d.bufLeadingZero(); err != nil {
+       if err := d.checkBufferedInt(); err != nil {
                return err
        }
        // if d.buf.Len() == 0 {
@@ -161,25 +177,37 @@ func (d *Decoder) parseInt(v reflect.Value) error {
        return nil
 }
 
-func (d *Decoder) bufLeadingZero() error {
+func (d *Decoder) checkBufferedInt() error {
        b := d.buf.Bytes()
-       if len(b) > 1 && b[0] == '0' {
-               return fmt.Errorf("non-zero integer has leading zeroes: %q", b)
+       if len(b) <= 1 {
+               return nil
+       }
+       if b[0] == '-' {
+               b = b[1:]
+       }
+       if b[0] < '1' || b[0] > '9' {
+               return errors.New("invalid leading digit")
        }
        return nil
 }
 
-func (d *Decoder) parseStringLength() (uint64, error) {
+func (d *Decoder) parseStringLength() (int, error) {
        // We should have already consumed the first byte of the length into the Decoder buf.
        start := d.Offset - 1
        d.readUntil(':')
-       if err := d.bufLeadingZero(); err != nil {
+       if err := d.checkBufferedInt(); err != nil {
                return 0, err
        }
-       length, err := strconv.ParseUint(bytesAsString(d.buf.Bytes()), 10, 32)
+       // Really the limit should be the uint size for the platform. But we can't pass in an allocator,
+       // or limit total memory use in Go, the best we might hope to do is limit the size of a single
+       // decoded value (by reading it in in-place and then operating on a view).
+       length, err := strconv.ParseInt(bytesAsString(d.buf.Bytes()), 10, 0)
        checkForIntParseError(err, start)
+       if int64(length) > d.getMaxStrLen() {
+               err = fmt.Errorf("parsed string length %v exceeds limit (%v)", length, DefaultDecodeMaxStrLen)
+       }
        d.buf.Reset()
-       return length, err
+       return int(length), err
 }
 
 func (d *Decoder) parseString(v reflect.Value) error {
@@ -218,13 +246,25 @@ func (d *Decoder) parseString(v reflect.Value) error {
                if v.Type().Elem().Kind() != reflect.Uint8 {
                        break
                }
-               d.buf.Grow(int(length))
+               d.buf.Grow(length)
                b := d.buf.Bytes()[:length]
                read(b)
                reflect.Copy(v, reflect.ValueOf(b))
                return nil
+       case reflect.Bool:
+               d.buf.Grow(length)
+               b := d.buf.Bytes()[:length]
+               read(b)
+               x, err := strconv.ParseBool(bytesAsString(b))
+               if err != nil {
+                       x = length != 0
+               }
+               v.SetBool(x)
+               return nil
        }
-       d.buf.Grow(int(length))
+       // Can't move this into default clause because some cases above fail through to here after
+       // additional checks.
+       d.buf.Grow(length)
        read(d.buf.Bytes()[:length])
        // I believe we return here to support "ignore_unmarshal_type_error".
        return &UnmarshalTypeError{
@@ -545,7 +585,7 @@ func (d *Decoder) parseValue(v reflect.Value) (bool, error) {
 
        b, err := d.r.ReadByte()
        if err != nil {
-               panic(err)
+               return false, err
        }
        d.Offset++
 
@@ -636,14 +676,24 @@ func (d *Decoder) parseIntInterface() (ret interface{}) {
        return
 }
 
+func (d *Decoder) readBytes(length int) []byte {
+       b, err := io.ReadAll(io.LimitReader(d.r, int64(length)))
+       if err != nil {
+               panic(err)
+       }
+       if len(b) != length {
+               panic(fmt.Errorf("read %v bytes expected %v", len(b), length))
+       }
+       return b
+}
+
 func (d *Decoder) parseStringInterface() string {
        length, err := d.parseStringLength()
        if err != nil {
                panic(err)
        }
-       b := make([]byte, length)
-       n, err := io.ReadFull(d.r, b)
-       d.Offset += int64(n)
+       b := d.readBytes(int(length))
+       d.Offset += int64(len(b))
        if err != nil {
                panic(&SyntaxError{Offset: d.Offset, What: err})
        }
@@ -652,7 +702,10 @@ func (d *Decoder) parseStringInterface() string {
 
 func (d *Decoder) parseDictInterface() interface{} {
        dict := make(map[string]interface{})
+       var lastKey string
+       lastKeyOk := false
        for {
+               start := d.Offset
                keyi, ok := d.parseValueInterface()
                if !ok {
                        break
@@ -665,12 +718,17 @@ func (d *Decoder) parseDictInterface() interface{} {
                                What:   errors.New("non-string key in a dict"),
                        })
                }
-
+               if lastKeyOk && key <= lastKey {
+                       d.throwSyntaxError(start, fmt.Errorf("dict keys unsorted: %q <= %q", key, lastKey))
+               }
+               start = d.Offset
                valuei, ok := d.parseValueInterface()
                if !ok {
-                       break
+                       d.throwSyntaxError(start, fmt.Errorf("dict elem missing value [key=%v]", key))
                }
 
+               lastKey = key
+               lastKeyOk = true
                dict[key] = valuei
        }
        return dict
@@ -685,3 +743,10 @@ func (d *Decoder) parseListInterface() (list []interface{}) {
        }
        return
 }
+
+func (d *Decoder) getMaxStrLen() int64 {
+       if d.MaxStrLen == 0 {
+               return DefaultDecodeMaxStrLen
+       }
+       return d.MaxStrLen
+}