From beb599698f8b0f9f9984c8bfdc73e6df5775c124 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Wed, 2 Oct 2013 17:57:19 +1000 Subject: [PATCH] Decoding of Piece messages, and checking entire message is consumed --- peer_protocol/protocol.go | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/peer_protocol/protocol.go b/peer_protocol/protocol.go index 4020c524..2f9547ef 100644 --- a/peer_protocol/protocol.go +++ b/peer_protocol/protocol.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "io/ioutil" ) type ( @@ -90,30 +91,51 @@ func (d *Decoder) Decode(msg *Message) (err error) { if length > d.MaxLength { return errors.New("message too long") } + r := bufio.NewReader(io.LimitReader(d.R, int64(length))) if length == 0 { msg.Keepalive = true return } msg.Keepalive = false - c, err := d.R.ReadByte() + c, err := r.ReadByte() if err != nil { return } msg.Type = MessageType(c) + defer func() { + written, _ := io.Copy(ioutil.Discard, r) + if written != 0 && err != nil { + err = fmt.Errorf("short read on message type %d, left %d bytes", msg.Type, written) + } + }() switch msg.Type { case Choke, Unchoke, Interested, NotInterested: return case Have: - err = msg.Index.Read(d.R) + err = msg.Index.Read(r) case Request, Cancel: - err = binary.Read(d.R, binary.BigEndian, []*Integer{&msg.Index, &msg.Begin, &msg.Length}) + err = binary.Read(r, binary.BigEndian, []*Integer{&msg.Index, &msg.Begin, &msg.Length}) case Bitfield: b := make([]byte, length-1) - _, err = io.ReadFull(d.R, b) + _, err = io.ReadFull(r, b) msg.Bitfield = unmarshalBitfield(b) + case Piece: + for _, pi := range []*Integer{&msg.Index, &msg.Begin} { + err = pi.Read(r) + if err != nil { + break + } + } + if err != nil { + break + } + msg.Piece, err = ioutil.ReadAll(r) default: err = fmt.Errorf("unknown message type %#v", c) } + if err != nil { + err = fmt.Errorf("decoding type %d: %s", msg.Type, err) + } return } -- 2.44.0