From bf5552ae3ca69660dcfc841debb5cdee5e69f0f8 Mon Sep 17 00:00:00 2001
From: Matt Joiner <anacrolix@gmail.com>
Date: Mon, 23 Jul 2018 13:12:14 +1000
Subject: [PATCH] bencode: Remove a lot of expensive allocations

---
 bencode/decode.go      | 82 +++++++++++++++++++++---------------------
 bencode/decode_test.go |  2 +-
 bencode/encode.go      | 32 +++++++----------
 bencode/misc.go        | 28 +++++++++++++++
 4 files changed, 83 insertions(+), 61 deletions(-)
 create mode 100644 bencode/misc.go

diff --git a/bencode/decode.go b/bencode/decode.go
index e0ff3874..c60bf17a 100644
--- a/bencode/decode.go
+++ b/bencode/decode.go
@@ -10,7 +10,6 @@ import (
 	"runtime"
 	"strconv"
 	"sync"
-	"unsafe"
 )
 
 type Decoder struct {
@@ -113,7 +112,7 @@ func (d *Decoder) parseInt(v reflect.Value) {
 		})
 	}
 
-	s := d.buf.String()
+	s := bytesAsString(d.buf.Bytes())
 
 	switch v.Kind() {
 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
@@ -154,38 +153,49 @@ func (d *Decoder) parseString(v reflect.Value) error {
 
 	// read the string length first
 	d.readUntil(':')
-	length, err := strconv.ParseInt(d.buf.String(), 10, 64)
+	length, err := strconv.ParseInt(bytesAsString(d.buf.Bytes()), 10, 0)
 	checkForIntParseError(err, start)
 
-	d.buf.Reset()
-	n, err := io.CopyN(&d.buf, d.r, length)
-	d.Offset += n
-	if err != nil {
-		checkForUnexpectedEOF(err, d.Offset)
-		panic(&SyntaxError{
-			Offset: d.Offset,
-			What:   errors.New("unexpected I/O error: " + err.Error()),
-		})
+	defer d.buf.Reset()
+
+	read := func(b []byte) {
+		n, err := io.ReadFull(d.r, b)
+		d.Offset += int64(n)
+		if err != nil {
+			checkForUnexpectedEOF(err, d.Offset)
+			panic(&SyntaxError{
+				Offset: d.Offset,
+				What:   errors.New("unexpected I/O error: " + err.Error()),
+			})
+		}
 	}
 
-	defer d.buf.Reset()
 	switch v.Kind() {
 	case reflect.String:
-		v.SetString(d.buf.String())
+		b := make([]byte, length)
+		read(b)
+		v.SetString(bytesAsString(b))
 		return nil
 	case reflect.Slice:
 		if v.Type().Elem().Kind() != reflect.Uint8 {
 			break
 		}
-		v.SetBytes(append([]byte(nil), d.buf.Bytes()...))
+		b := make([]byte, length)
+		read(b)
+		v.SetBytes(b)
 		return nil
 	case reflect.Array:
 		if v.Type().Elem().Kind() != reflect.Uint8 {
 			break
 		}
-		reflect.Copy(v, reflect.ValueOf(d.buf.Bytes()))
+		d.buf.Grow(int(length))
+		b := d.buf.Bytes()[:length]
+		read(b)
+		reflect.Copy(v, reflect.ValueOf(b))
 		return nil
 	}
+	d.buf.Grow(int(length))
+	read(d.buf.Bytes()[:length])
 	// I believe we return here to support "ignore_unmarshal_type_error".
 	return &UnmarshalTypeError{
 		Value: "string",
@@ -409,11 +419,7 @@ func (d *Decoder) readOneValue() bool {
 		if b >= '0' && b <= '9' {
 			start := d.buf.Len() - 1
 			d.readUntil(':')
-			s := reflect.StringHeader{
-				uintptr(unsafe.Pointer(&d.buf.Bytes()[start])),
-				d.buf.Len() - start,
-			}
-			length, err := strconv.ParseInt(*(*string)(unsafe.Pointer(&s)), 10, 64)
+			length, err := strconv.ParseInt(bytesAsString(d.buf.Bytes()[start:]), 10, 64)
 			checkForIntParseError(err, d.Offset-1)
 
 			d.buf.WriteString(":")
@@ -437,29 +443,23 @@ func (d *Decoder) readOneValue() bool {
 }
 
 func (d *Decoder) parseUnmarshaler(v reflect.Value) bool {
-	m, ok := v.Interface().(Unmarshaler)
-	if !ok {
-		// T doesn't work, try *T
-		if v.Kind() != reflect.Ptr && v.CanAddr() {
-			m, ok = v.Addr().Interface().(Unmarshaler)
-			if ok {
-				v = v.Addr()
-			}
+	if !v.Type().Implements(unmarshalerType) {
+		if v.Addr().Type().Implements(unmarshalerType) {
+			v = v.Addr()
+		} else {
+			return false
 		}
 	}
-	if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
-		if d.readOneValue() {
-			err := m.UnmarshalBencode(d.buf.Bytes())
-			d.buf.Reset()
-			if err != nil {
-				panic(&UnmarshalerError{v.Type(), err})
-			}
-			return true
-		}
-		d.buf.Reset()
+	d.buf.Reset()
+	if !d.readOneValue() {
+		return false
 	}
-
-	return false
+	m := v.Interface().(Unmarshaler)
+	err := m.UnmarshalBencode(d.buf.Bytes())
+	if err != nil {
+		panic(&UnmarshalerError{v.Type(), err})
+	}
+	return true
 }
 
 // Returns true if there was a value and it's now stored in 'v', otherwise
diff --git a/bencode/decode_test.go b/bencode/decode_test.go
index 13329044..f9876f5a 100644
--- a/bencode/decode_test.go
+++ b/bencode/decode_test.go
@@ -143,7 +143,7 @@ func TestIgnoreUnmarshalTypeError(t *testing.T) {
 		Normal int
 	}{}
 	require.Error(t, Unmarshal([]byte("d6:Normal5:helloe"), &s))
-	assert.Nil(t, Unmarshal([]byte("d6:Ignore5:helloe"), &s))
+	assert.NoError(t, Unmarshal([]byte("d6:Ignore5:helloe"), &s))
 	require.Nil(t, Unmarshal([]byte("d6:Ignorei42ee"), &s))
 	assert.EqualValues(t, 42, s.Ignore)
 }
diff --git a/bencode/encode.go b/bencode/encode.go
index 2bbf0bad..8542d430 100644
--- a/bencode/encode.go
+++ b/bencode/encode.go
@@ -77,29 +77,23 @@ func (e *Encoder) reflectByteSlice(s []byte) {
 	e.write(s)
 }
 
-// returns true if the value implements Marshaler interface and marshaling was
-// done successfully
+// Returns true if the value implements Marshaler interface and marshaling was
+// done successfully.
 func (e *Encoder) reflectMarshaler(v reflect.Value) bool {
-	m, ok := v.Interface().(Marshaler)
-	if !ok {
-		// T doesn't work, try *T
-		if v.Kind() != reflect.Ptr && v.CanAddr() {
-			m, ok = v.Addr().Interface().(Marshaler)
-			if ok {
-				v = v.Addr()
-			}
+	if !v.Type().Implements(marshalerType) {
+		if v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(marshalerType) {
+			v = v.Addr()
+		} else {
+			return false
 		}
 	}
-	if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
-		data, err := m.MarshalBencode()
-		if err != nil {
-			panic(&MarshalerError{v.Type(), err})
-		}
-		e.write(data)
-		return true
+	m := v.Interface().(Marshaler)
+	data, err := m.MarshalBencode()
+	if err != nil {
+		panic(&MarshalerError{v.Type(), err})
 	}
-
-	return false
+	e.write(data)
+	return true
 }
 
 var bigIntType = reflect.TypeOf(big.Int{})
diff --git a/bencode/misc.go b/bencode/misc.go
new file mode 100644
index 00000000..71199590
--- /dev/null
+++ b/bencode/misc.go
@@ -0,0 +1,28 @@
+package bencode
+
+import (
+	"reflect"
+	"unsafe"
+)
+
+// Wow Go is retarded.
+var marshalerType = reflect.TypeOf(func() *Marshaler {
+	var m Marshaler
+	return &m
+}()).Elem()
+
+// Wow Go is retarded.
+var unmarshalerType = reflect.TypeOf(func() *Unmarshaler {
+	var i Unmarshaler
+	return &i
+}()).Elem()
+
+func bytesAsString(b []byte) string {
+	if len(b) == 0 {
+		return ""
+	}
+	return *(*string)(unsafe.Pointer(&reflect.StringHeader{
+		uintptr(unsafe.Pointer(&b[0])),
+		len(b),
+	}))
+}
-- 
2.51.0