]> Sergey Matveev's repositories - btrtrc.git/commitdiff
More optimizations in peer protocol message decoding
authorMatt Joiner <anacrolix@gmail.com>
Thu, 30 Sep 2021 01:05:01 +0000 (11:05 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Thu, 30 Sep 2021 01:05:01 +0000 (11:05 +1000)
peer_protocol/decoder.go
peer_protocol/protocol_test.go

index cba8dd087a0c71de35776bdbf195d2125e548f74..0963f668c9810edb42ee300ff1df70050d261391 100644 (file)
@@ -5,7 +5,6 @@ import (
        "encoding/binary"
        "fmt"
        "io"
-       "io/ioutil"
        "sync"
 
        "github.com/pkg/errors"
@@ -31,27 +30,20 @@ func (d *Decoder) Decode(msg *Message) (err error) {
                msg.Keepalive = true
                return
        }
-       msg.Keepalive = false
-       r := &io.LimitedReader{R: d.R, N: int64(length)}
-       // Check that all of r was utilized.
-       defer func() {
-               if err != nil {
-                       return
-               }
-               if r.N != 0 {
-                       err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type)
-               }
-       }()
-       msg.Keepalive = false
-       c, err := readByte(r)
+       r := d.R
+       readByte := func() (byte, error) {
+               length--
+               return d.R.ReadByte()
+       }
+       c, err := readByte()
        if err != nil {
                return
        }
        msg.Type = MessageType(c)
        switch msg.Type {
        case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
-               return
        case Have, AllowedFast, Suggest:
+               length -= 4
                err = msg.Index.Read(r)
        case Request, Cancel, Reject:
                for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
@@ -60,9 +52,11 @@ func (d *Decoder) Decode(msg *Message) (err error) {
                                break
                        }
                }
+               length -= 12
        case Bitfield:
-               b := make([]byte, length-1)
+               b := make([]byte, length)
                _, err = io.ReadFull(r, b)
+               length = 0
                msg.Bitfield = unmarshalBitfield(b)
        case Piece:
                for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
@@ -71,7 +65,8 @@ func (d *Decoder) Decode(msg *Message) (err error) {
                                return err
                        }
                }
-               dataLen := r.N
+               length -= 8
+               dataLen := int64(length)
                msg.Piece = (*d.Pool.Get().(*[]byte))
                if int64(cap(msg.Piece)) < dataLen {
                        return errors.New("piece data longer than expected")
@@ -79,21 +74,31 @@ func (d *Decoder) Decode(msg *Message) (err error) {
                msg.Piece = msg.Piece[:dataLen]
                _, err := io.ReadFull(r, msg.Piece)
                if err != nil {
-                       return errors.Wrap(err, "reading piece data")
+                       return fmt.Errorf("reading piece data: %w", err)
                }
+               length = 0
        case Extended:
                var b byte
-               b, err = readByte(r)
+               b, err = readByte()
                if err != nil {
                        break
                }
                msg.ExtendedID = ExtensionNumber(b)
-               msg.ExtendedPayload, err = ioutil.ReadAll(r)
+               msg.ExtendedPayload = make([]byte, length)
+               _, err = io.ReadFull(r, msg.ExtendedPayload)
+               if err == io.EOF {
+                       err = io.ErrUnexpectedEOF
+               }
+               length = 0
        case Port:
                err = binary.Read(r, binary.BigEndian, &msg.Port)
+               length -= 2
        default:
                err = fmt.Errorf("unknown message type %#v", c)
        }
+       if err == nil && length != 0 {
+               err = fmt.Errorf("%v unused bytes in message type %v", length, msg.Type)
+       }
        return
 }
 
index eabbd54533e8f0e32ca984bebf1c5ee5fe812036..df01a1a676f49dc26907f6ab135864b590b6b577 100644 (file)
@@ -82,7 +82,7 @@ func TestShortRead(t *testing.T) {
        }
        msg := new(Message)
        err := dec.Decode(msg)
-       if !strings.Contains(err.Error(), "1 bytes unused in message type 0") {
+       if !strings.Contains(err.Error(), "1 unused bytes in message type Choke") {
                t.Fatal(err)
        }
 }