"errors"
"fmt"
"io"
+ "io/ioutil"
)
type (
if length > d.MaxLength {
return errors.New("message too long")
}
+ r := bufio.NewReader(io.LimitReader(d.R, int64(length)))
if length == 0 {
msg.Keepalive = true
return
}
msg.Keepalive = false
- c, err := d.R.ReadByte()
+ c, err := r.ReadByte()
if err != nil {
return
}
msg.Type = MessageType(c)
+ defer func() {
+ written, _ := io.Copy(ioutil.Discard, r)
+ if written != 0 && err != nil {
+ err = fmt.Errorf("short read on message type %d, left %d bytes", msg.Type, written)
+ }
+ }()
switch msg.Type {
case Choke, Unchoke, Interested, NotInterested:
return
case Have:
- err = msg.Index.Read(d.R)
+ err = msg.Index.Read(r)
case Request, Cancel:
- err = binary.Read(d.R, binary.BigEndian, []*Integer{&msg.Index, &msg.Begin, &msg.Length})
+ err = binary.Read(r, binary.BigEndian, []*Integer{&msg.Index, &msg.Begin, &msg.Length})
case Bitfield:
b := make([]byte, length-1)
- _, err = io.ReadFull(d.R, b)
+ _, err = io.ReadFull(r, b)
msg.Bitfield = unmarshalBitfield(b)
+ case Piece:
+ for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
+ err = pi.Read(r)
+ if err != nil {
+ break
+ }
+ }
+ if err != nil {
+ break
+ }
+ msg.Piece, err = ioutil.ReadAll(r)
default:
err = fmt.Errorf("unknown message type %#v", c)
}
+ if err != nil {
+ err = fmt.Errorf("decoding type %d: %s", msg.Type, err)
+ }
return
}