From 2027028539a7cf0a26543cbec36094c5e49f51c0 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Thu, 30 Sep 2021 11:05:01 +1000 Subject: [PATCH] More optimizations in peer protocol message decoding --- peer_protocol/decoder.go | 45 +++++++++++++++++++--------------- peer_protocol/protocol_test.go | 2 +- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/peer_protocol/decoder.go b/peer_protocol/decoder.go index cba8dd08..0963f668 100644 --- a/peer_protocol/decoder.go +++ b/peer_protocol/decoder.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "fmt" "io" - "io/ioutil" "sync" "github.com/pkg/errors" @@ -31,27 +30,20 @@ func (d *Decoder) Decode(msg *Message) (err error) { msg.Keepalive = true return } - msg.Keepalive = false - r := &io.LimitedReader{R: d.R, N: int64(length)} - // Check that all of r was utilized. - defer func() { - if err != nil { - return - } - if r.N != 0 { - err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type) - } - }() - msg.Keepalive = false - c, err := readByte(r) + r := d.R + readByte := func() (byte, error) { + length-- + return d.R.ReadByte() + } + c, err := readByte() if err != nil { return } msg.Type = MessageType(c) switch msg.Type { case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone: - return case Have, AllowedFast, Suggest: + length -= 4 err = msg.Index.Read(r) case Request, Cancel, Reject: for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} { @@ -60,9 +52,11 @@ func (d *Decoder) Decode(msg *Message) (err error) { break } } + length -= 12 case Bitfield: - b := make([]byte, length-1) + b := make([]byte, length) _, err = io.ReadFull(r, b) + length = 0 msg.Bitfield = unmarshalBitfield(b) case Piece: for _, pi := range []*Integer{&msg.Index, &msg.Begin} { @@ -71,7 +65,8 @@ func (d *Decoder) Decode(msg *Message) (err error) { return err } } - dataLen := r.N + length -= 8 + dataLen := int64(length) msg.Piece = (*d.Pool.Get().(*[]byte)) if int64(cap(msg.Piece)) < dataLen { return errors.New("piece data longer than expected") @@ -79,21 +74,31 @@ func (d *Decoder) Decode(msg *Message) (err error) { msg.Piece = msg.Piece[:dataLen] _, err := io.ReadFull(r, msg.Piece) if err != nil { - return errors.Wrap(err, "reading piece data") + return fmt.Errorf("reading piece data: %w", err) } + length = 0 case Extended: var b byte - b, err = readByte(r) + b, err = readByte() if err != nil { break } msg.ExtendedID = ExtensionNumber(b) - msg.ExtendedPayload, err = ioutil.ReadAll(r) + msg.ExtendedPayload = make([]byte, length) + _, err = io.ReadFull(r, msg.ExtendedPayload) + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + length = 0 case Port: err = binary.Read(r, binary.BigEndian, &msg.Port) + length -= 2 default: err = fmt.Errorf("unknown message type %#v", c) } + if err == nil && length != 0 { + err = fmt.Errorf("%v unused bytes in message type %v", length, msg.Type) + } return } diff --git a/peer_protocol/protocol_test.go b/peer_protocol/protocol_test.go index eabbd545..df01a1a6 100644 --- a/peer_protocol/protocol_test.go +++ b/peer_protocol/protocol_test.go @@ -82,7 +82,7 @@ func TestShortRead(t *testing.T) { } msg := new(Message) err := dec.Decode(msg) - if !strings.Contains(err.Error(), "1 bytes unused in message type 0") { + if !strings.Contains(err.Error(), "1 unused bytes in message type Choke") { t.Fatal(err) } } -- 2.44.0