"sync"
)
+// The default bencode string length limit. This is a poor attempt to prevent excessive memory
+// allocation when parsing, but also leaves the window open to implement a better solution.
+const DefaultDecodeMaxStrLen = 1<<27 - 1 // ~128MiB
+
+type MaxStrLen = int64
+
type Decoder struct {
+ // Maximum parsed bencode string length. Defaults to DefaultMaxStrLen if zero.
+ MaxStrLen MaxStrLen
+
r interface {
io.ByteScanner
io.Reader
return
}
r := recover()
+ if r == nil {
+ return
+ }
_, ok := r.(runtime.Error)
if ok {
panic(r)
}
- err, ok = r.(error)
- if !ok && r != nil {
+ if err, ok = r.(error); !ok {
panic(r)
}
+ // Errors thrown from deeper in parsing are unexpected. At value boundaries, errors should
+ // be returned directly (at least until all the panic nonsense is removed entirely).
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
}()
pv := reflect.ValueOf(v)
func (d *Decoder) readInt() error {
// start := d.Offset - 1
d.readUntil('e')
- if err := d.bufLeadingZero(); err != nil {
+ if err := d.checkBufferedInt(); err != nil {
return err
}
// if d.buf.Len() == 0 {
return nil
}
-func (d *Decoder) bufLeadingZero() error {
+func (d *Decoder) checkBufferedInt() error {
b := d.buf.Bytes()
- if len(b) > 1 && b[0] == '0' {
- return fmt.Errorf("non-zero integer has leading zeroes: %q", b)
+ if len(b) <= 1 {
+ return nil
+ }
+ if b[0] == '-' {
+ b = b[1:]
+ }
+ if b[0] < '1' || b[0] > '9' {
+ return errors.New("invalid leading digit")
}
return nil
}
-func (d *Decoder) parseStringLength() (uint64, error) {
+func (d *Decoder) parseStringLength() (int, error) {
// We should have already consumed the first byte of the length into the Decoder buf.
start := d.Offset - 1
d.readUntil(':')
- if err := d.bufLeadingZero(); err != nil {
+ if err := d.checkBufferedInt(); err != nil {
return 0, err
}
- length, err := strconv.ParseUint(bytesAsString(d.buf.Bytes()), 10, 32)
+ // Really the limit should be the uint size for the platform. But we can't pass in an allocator,
+ // or limit total memory use in Go, the best we might hope to do is limit the size of a single
+ // decoded value (by reading it in in-place and then operating on a view).
+ length, err := strconv.ParseInt(bytesAsString(d.buf.Bytes()), 10, 0)
checkForIntParseError(err, start)
+ if int64(length) > d.getMaxStrLen() {
+ err = fmt.Errorf("parsed string length %v exceeds limit (%v)", length, DefaultDecodeMaxStrLen)
+ }
d.buf.Reset()
- return length, err
+ return int(length), err
}
func (d *Decoder) parseString(v reflect.Value) error {
if v.Type().Elem().Kind() != reflect.Uint8 {
break
}
- d.buf.Grow(int(length))
+ d.buf.Grow(length)
b := d.buf.Bytes()[:length]
read(b)
reflect.Copy(v, reflect.ValueOf(b))
return nil
+ case reflect.Bool:
+ d.buf.Grow(length)
+ b := d.buf.Bytes()[:length]
+ read(b)
+ x, err := strconv.ParseBool(bytesAsString(b))
+ if err != nil {
+ x = length != 0
+ }
+ v.SetBool(x)
+ return nil
}
- d.buf.Grow(int(length))
+ // Can't move this into default clause because some cases above fail through to here after
+ // additional checks.
+ d.buf.Grow(length)
read(d.buf.Bytes()[:length])
// I believe we return here to support "ignore_unmarshal_type_error".
return &UnmarshalTypeError{
b, err := d.r.ReadByte()
if err != nil {
- panic(err)
+ return false, err
}
d.Offset++
return
}
+func (d *Decoder) readBytes(length int) []byte {
+ b, err := io.ReadAll(io.LimitReader(d.r, int64(length)))
+ if err != nil {
+ panic(err)
+ }
+ if len(b) != length {
+ panic(fmt.Errorf("read %v bytes expected %v", len(b), length))
+ }
+ return b
+}
+
func (d *Decoder) parseStringInterface() string {
length, err := d.parseStringLength()
if err != nil {
panic(err)
}
- b := make([]byte, length)
- n, err := io.ReadFull(d.r, b)
- d.Offset += int64(n)
+ b := d.readBytes(int(length))
+ d.Offset += int64(len(b))
if err != nil {
panic(&SyntaxError{Offset: d.Offset, What: err})
}
func (d *Decoder) parseDictInterface() interface{} {
dict := make(map[string]interface{})
- lastKey := ""
+ var lastKey string
+ lastKeyOk := false
for {
start := d.Offset
keyi, ok := d.parseValueInterface()
What: errors.New("non-string key in a dict"),
})
}
- if key <= lastKey {
+ if lastKeyOk && key <= lastKey {
d.throwSyntaxError(start, fmt.Errorf("dict keys unsorted: %q <= %q", key, lastKey))
}
start = d.Offset
}
lastKey = key
+ lastKeyOk = true
dict[key] = valuei
}
return dict
}
return
}
+
+func (d *Decoder) getMaxStrLen() int64 {
+ if d.MaxStrLen == 0 {
+ return DefaultDecodeMaxStrLen
+ }
+ return d.MaxStrLen
+}