]> Sergey Matveev's repositories - btrtrc.git/blobdiff - peer_protocol/decoder.go
Fix incorrect EOF when decoding some peer protocol message types
[btrtrc.git] / peer_protocol / decoder.go
index 37758fd8de495b44d3175156de6c42af535b95bc..9dfe125b1a45e2960e8f746df1fc95ad17bb67d6 100644 (file)
@@ -3,29 +3,28 @@ package peer_protocol
 import (
        "bufio"
        "encoding/binary"
-       "errors"
        "fmt"
        "io"
-       "io/ioutil"
        "sync"
+
+       "github.com/pkg/errors"
 )
 
 type Decoder struct {
-       R         *bufio.Reader
+       R *bufio.Reader
+       // This must return *[]byte where the slices can fit data for piece messages. I think we store
+       // *[]byte in the pool to avoid an extra allocation every time we put the slice back into the
+       // pool. The chunk size should not change for the life of the decoder.
        Pool      *sync.Pool
        MaxLength Integer // TODO: Should this include the length header or not?
 }
 
 // io.EOF is returned if the source terminates cleanly on a message boundary.
-// TODO: Is that before or after the message?
 func (d *Decoder) Decode(msg *Message) (err error) {
        var length Integer
-       err = binary.Read(d.R, binary.BigEndian, &length)
+       err = length.Read(d.R)
        if err != nil {
-               if err != io.EOF {
-                       err = fmt.Errorf("error reading message length: %s", err)
-               }
-               return
+               return fmt.Errorf("reading message length: %w", err)
        }
        if length > d.MaxLength {
                return errors.New("message too long")
@@ -34,27 +33,27 @@ func (d *Decoder) Decode(msg *Message) (err error) {
                msg.Keepalive = true
                return
        }
-       msg.Keepalive = false
-       r := &io.LimitedReader{d.R, int64(length)}
-       // Check that all of r was utilized.
+       r := d.R
+       readByte := func() (byte, error) {
+               length--
+               return d.R.ReadByte()
+       }
+       // From this point onwards, EOF is unexpected
        defer func() {
-               if err != nil {
-                       return
-               }
-               if r.N != 0 {
-                       err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type)
+               if err == io.EOF {
+                       err = io.ErrUnexpectedEOF
                }
        }()
-       msg.Keepalive = false
-       c, err := readByte(r)
+       c, err := readByte()
        if err != nil {
                return
        }
        msg.Type = MessageType(c)
+       // Can return directly in cases when err is not nil, or length is known to be zero.
        switch msg.Type {
        case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
-               return
        case Have, AllowedFast, Suggest:
+               length -= 4
                err = msg.Index.Read(r)
        case Request, Cancel, Reject:
                for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
@@ -63,41 +62,54 @@ func (d *Decoder) Decode(msg *Message) (err error) {
                                break
                        }
                }
+               length -= 12
        case Bitfield:
-               b := make([]byte, length-1)
+               b := make([]byte, length)
                _, err = io.ReadFull(r, b)
+               length = 0
                msg.Bitfield = unmarshalBitfield(b)
+               return
        case Piece:
                for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
-                       err = pi.Read(r)
+                       err := pi.Read(r)
                        if err != nil {
-                               break
+                               return err
                        }
                }
-               if err != nil {
-                       break
-               }
-               //msg.Piece, err = ioutil.ReadAll(r)
-               b := *d.Pool.Get().(*[]byte)
-               n, err := io.ReadFull(r, b)
-               if err != nil {
-                       if err != io.ErrUnexpectedEOF || n != int(length-9) {
-                               return err
+               length -= 8
+               dataLen := int64(length)
+               if d.Pool == nil {
+                       msg.Piece = make([]byte, dataLen)
+               } else {
+                       msg.Piece = *d.Pool.Get().(*[]byte)
+                       if int64(cap(msg.Piece)) < dataLen {
+                               return errors.New("piece data longer than expected")
                        }
-                       b = b[0:n]
+                       msg.Piece = msg.Piece[:dataLen]
                }
-               msg.Piece = b
+               _, err = io.ReadFull(r, msg.Piece)
+               length = 0
+               return
        case Extended:
-               msg.ExtendedID, err = readByte(r)
+               var b byte
+               b, err = readByte()
                if err != nil {
                        break
                }
-               msg.ExtendedPayload, err = ioutil.ReadAll(r)
+               msg.ExtendedID = ExtensionNumber(b)
+               msg.ExtendedPayload = make([]byte, length)
+               _, err = io.ReadFull(r, msg.ExtendedPayload)
+               length = 0
+               return
        case Port:
                err = binary.Read(r, binary.BigEndian, &msg.Port)
+               length -= 2
        default:
                err = fmt.Errorf("unknown message type %#v", c)
        }
+       if err == nil && length != 0 {
+               err = fmt.Errorf("%v unused bytes in message type %v", length, msg.Type)
+       }
        return
 }