]> Sergey Matveev's repositories - btrtrc.git/commitdiff
peer_protocol.Decoder.Decode: Avoid allocating another intermediate reader
authorMatt Joiner <anacrolix@gmail.com>
Sun, 25 Sep 2016 00:31:43 +0000 (10:31 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Sun, 25 Sep 2016 00:31:43 +0000 (10:31 +1000)
On my system, BenchmarkConnectionMainReadLoop goes from 596 to 1311 MB/s.

peer_protocol/protocol.go

index 1e572bd89354bee443a703ff07e3aac501d5f5d7..f1f392f9a9c414d3ebb5702b38f77db29ff85c62 100644 (file)
@@ -128,6 +128,20 @@ type Decoder struct {
        MaxLength Integer // TODO: Should this include the length header or not?
 }
 
+func readByte(r io.Reader) (b byte, err error) {
+       var arr [1]byte
+       n, err := r.Read(arr[:])
+       b = arr[0]
+       if n == 1 {
+               err = nil
+               return
+       }
+       if err == nil {
+               panic(err)
+       }
+       return
+}
+
 // io.EOF is returned if the source terminates cleanly on a message boundary.
 func (d *Decoder) Decode(msg *Message) (err error) {
        var length Integer
@@ -146,29 +160,18 @@ func (d *Decoder) Decode(msg *Message) (err error) {
                return
        }
        msg.Keepalive = false
-       b := make([]byte, length)
-       _, err = io.ReadFull(d.R, b)
-       if err != nil {
-               if err == io.EOF {
-                       err = io.ErrUnexpectedEOF
-               }
-               if err != io.ErrUnexpectedEOF {
-                       err = fmt.Errorf("error reading message: %s", err)
-               }
-               return
-       }
-       r := bytes.NewReader(b)
+       r := &io.LimitedReader{d.R, int64(length)}
        // Check that all of r was utilized.
        defer func() {
                if err != nil {
                        return
                }
-               if r.Len() != 0 {
-                       err = fmt.Errorf("%d bytes unused in message type %d", r.Len(), msg.Type)
+               if r.N != 0 {
+                       err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type)
                }
        }()
        msg.Keepalive = false
-       c, err := r.ReadByte()
+       c, err := readByte(r)
        if err != nil {
                return
        }
@@ -210,7 +213,7 @@ func (d *Decoder) Decode(msg *Message) (err error) {
                }
                msg.Piece = b
        case Extended:
-               msg.ExtendedID, err = r.ReadByte()
+               msg.ExtendedID, err = readByte(r)
                if err != nil {
                        break
                }