import (
"bufio"
"encoding/binary"
- "errors"
"fmt"
"io"
- "io/ioutil"
"sync"
+
+ "github.com/pkg/errors"
)
type Decoder struct {
- R *bufio.Reader
+ R *bufio.Reader
+ // This must return *[]byte where the slices can fit data for piece messages. I think we store
+ // *[]byte in the pool to avoid an extra allocation every time we put the slice back into the
+ // pool. The chunk size should not change for the life of the decoder.
Pool *sync.Pool
MaxLength Integer // TODO: Should this include the length header or not?
}
// io.EOF is returned if the source terminates cleanly on a message boundary.
-// TODO: Is that before or after the message?
func (d *Decoder) Decode(msg *Message) (err error) {
var length Integer
- err = binary.Read(d.R, binary.BigEndian, &length)
+ err = length.Read(d.R)
if err != nil {
- if err != io.EOF {
- err = fmt.Errorf("error reading message length: %s", err)
- }
- return
+ return fmt.Errorf("reading message length: %w", err)
}
if length > d.MaxLength {
return errors.New("message too long")
msg.Keepalive = true
return
}
- msg.Keepalive = false
- r := &io.LimitedReader{R:d.R, N:int64(length)}
- // Check that all of r was utilized.
+ r := d.R
+ readByte := func() (byte, error) {
+ length--
+ return d.R.ReadByte()
+ }
+ // From this point onwards, EOF is unexpected
defer func() {
- if err != nil {
- return
- }
- if r.N != 0 {
- err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type)
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
}
}()
- msg.Keepalive = false
- c, err := readByte(r)
+ c, err := readByte()
if err != nil {
return
}
msg.Type = MessageType(c)
+ // Can return directly in cases when err is not nil, or length is known to be zero.
switch msg.Type {
case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
- return
case Have, AllowedFast, Suggest:
+ length -= 4
err = msg.Index.Read(r)
case Request, Cancel, Reject:
for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
break
}
}
+ length -= 12
case Bitfield:
- b := make([]byte, length-1)
+ b := make([]byte, length)
_, err = io.ReadFull(r, b)
+ length = 0
msg.Bitfield = unmarshalBitfield(b)
+ return
case Piece:
for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
- err = pi.Read(r)
+ err := pi.Read(r)
if err != nil {
- break
+ return err
}
}
- if err != nil {
- break
- }
- //msg.Piece, err = ioutil.ReadAll(r)
- b := *d.Pool.Get().(*[]byte)
- n, err := io.ReadFull(r, b)
- if err != nil {
- if err != io.ErrUnexpectedEOF || n != int(length-9) {
- return err
+ length -= 8
+ dataLen := int64(length)
+ if d.Pool == nil {
+ msg.Piece = make([]byte, dataLen)
+ } else {
+ msg.Piece = *d.Pool.Get().(*[]byte)
+ if int64(cap(msg.Piece)) < dataLen {
+ return errors.New("piece data longer than expected")
}
- b = b[0:n]
+ msg.Piece = msg.Piece[:dataLen]
}
- msg.Piece = b
+ _, err = io.ReadFull(r, msg.Piece)
+ length = 0
+ return
case Extended:
- msg.ExtendedID, err = readByte(r)
+ var b byte
+ b, err = readByte()
if err != nil {
break
}
- msg.ExtendedPayload, err = ioutil.ReadAll(r)
+ msg.ExtendedID = ExtensionNumber(b)
+ msg.ExtendedPayload = make([]byte, length)
+ _, err = io.ReadFull(r, msg.ExtendedPayload)
+ length = 0
+ return
case Port:
err = binary.Read(r, binary.BigEndian, &msg.Port)
+ length -= 2
default:
err = fmt.Errorf("unknown message type %#v", c)
}
+ if err == nil && length != 0 {
+ err = fmt.Errorf("%v unused bytes in message type %v", length, msg.Type)
+ }
return
}