peer_protocol/protocol.go | 30 ++++++++++++++++++++++++++---- diff --git a/peer_protocol/protocol.go b/peer_protocol/protocol.go index 4020c5240d3f8066b1799f87993f2c973ea4db1f..2f9547ef29dcf439686da9e610cbfe5b3280c52b 100644 --- a/peer_protocol/protocol.go +++ b/peer_protocol/protocol.go @@ -7,6 +7,7 @@ "encoding/binary" "errors" "fmt" "io" + "io/ioutil" ) type ( @@ -90,29 +91,50 @@ } if length > d.MaxLength { return errors.New("message too long") } + r := bufio.NewReader(io.LimitReader(d.R, int64(length))) if length == 0 { msg.Keepalive = true return } msg.Keepalive = false - c, err := d.R.ReadByte() + c, err := r.ReadByte() if err != nil { return } msg.Type = MessageType(c) + defer func() { + written, _ := io.Copy(ioutil.Discard, r) + if written != 0 && err != nil { + err = fmt.Errorf("short read on message type %d, left %d bytes", msg.Type, written) + } + }() switch msg.Type { case Choke, Unchoke, Interested, NotInterested: return case Have: - err = msg.Index.Read(d.R) + err = msg.Index.Read(r) case Request, Cancel: - err = binary.Read(d.R, binary.BigEndian, []*Integer{&msg.Index, &msg.Begin, &msg.Length}) + err = binary.Read(r, binary.BigEndian, []*Integer{&msg.Index, &msg.Begin, &msg.Length}) case Bitfield: b := make([]byte, length-1) - _, err = io.ReadFull(d.R, b) + _, err = io.ReadFull(r, b) msg.Bitfield = unmarshalBitfield(b) + case Piece: + for _, pi := range []*Integer{&msg.Index, &msg.Begin} { + err = pi.Read(r) + if err != nil { + break + } + } + if err != nil { + break + } + msg.Piece, err = ioutil.ReadAll(r) default: err = fmt.Errorf("unknown message type %#v", c) + } + if err != nil { + err = fmt.Errorf("decoding type %d: %s", msg.Type, err) } return }