]> Sergey Matveev's repositories - btrtrc.git/blobdiff - peer_protocol/decoder.go
Fix incorrect EOF when decoding some peer protocol message types
[btrtrc.git] / peer_protocol / decoder.go
index f4432f64d6d1750afe4c90b184557cf535b933cd..9dfe125b1a45e2960e8f746df1fc95ad17bb67d6 100644 (file)
@@ -11,7 +11,10 @@ import (
 )
 
 type Decoder struct {
-       R         *bufio.Reader
+       R *bufio.Reader
+       // This must return *[]byte where the slices can fit data for piece messages. I think we store
+       // *[]byte in the pool to avoid an extra allocation every time we put the slice back into the
+       // pool. The chunk size should not change for the life of the decoder.
        Pool      *sync.Pool
        MaxLength Integer // TODO: Should this include the length header or not?
 }
@@ -35,11 +38,18 @@ func (d *Decoder) Decode(msg *Message) (err error) {
                length--
                return d.R.ReadByte()
        }
+       // From this point onwards, EOF is unexpected
+       defer func() {
+               if err == io.EOF {
+                       err = io.ErrUnexpectedEOF
+               }
+       }()
        c, err := readByte()
        if err != nil {
                return
        }
        msg.Type = MessageType(c)
+       // Can return directly in cases when err is not nil, or length is known to be zero.
        switch msg.Type {
        case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
        case Have, AllowedFast, Suggest:
@@ -58,6 +68,7 @@ func (d *Decoder) Decode(msg *Message) (err error) {
                _, err = io.ReadFull(r, b)
                length = 0
                msg.Bitfield = unmarshalBitfield(b)
+               return
        case Piece:
                for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
                        err := pi.Read(r)
@@ -67,16 +78,18 @@ func (d *Decoder) Decode(msg *Message) (err error) {
                }
                length -= 8
                dataLen := int64(length)
-               msg.Piece = *d.Pool.Get().(*[]byte)
-               if int64(cap(msg.Piece)) < dataLen {
-                       return errors.New("piece data longer than expected")
-               }
-               msg.Piece = msg.Piece[:dataLen]
-               _, err := io.ReadFull(r, msg.Piece)
-               if err != nil {
-                       return fmt.Errorf("reading piece data: %w", err)
+               if d.Pool == nil {
+                       msg.Piece = make([]byte, dataLen)
+               } else {
+                       msg.Piece = *d.Pool.Get().(*[]byte)
+                       if int64(cap(msg.Piece)) < dataLen {
+                               return errors.New("piece data longer than expected")
+                       }
+                       msg.Piece = msg.Piece[:dataLen]
                }
+               _, err = io.ReadFull(r, msg.Piece)
                length = 0
+               return
        case Extended:
                var b byte
                b, err = readByte()
@@ -86,10 +99,8 @@ func (d *Decoder) Decode(msg *Message) (err error) {
                msg.ExtendedID = ExtensionNumber(b)
                msg.ExtendedPayload = make([]byte, length)
                _, err = io.ReadFull(r, msg.ExtendedPayload)
-               if err == io.EOF {
-                       err = io.ErrUnexpectedEOF
-               }
                length = 0
+               return
        case Port:
                err = binary.Read(r, binary.BigEndian, &msg.Port)
                length -= 2