From 8cb39521f2ae36404d647571bd0acf6c29ba75eb Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Fri, 21 Mar 2014 00:42:40 +1100 Subject: [PATCH] Fix short read and report unexpected EOFs decoding peer protocol --- peer_protocol/protocol.go | 19 +++++++++--------- peer_protocol/protocol_test.go | 36 ++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/peer_protocol/protocol.go b/peer_protocol/protocol.go index 281d8486..a7fcbfc2 100644 --- a/peer_protocol/protocol.go +++ b/peer_protocol/protocol.go @@ -93,7 +93,7 @@ func (msg Message) MarshalBinary() (data []byte, err error) { type Decoder struct { R *bufio.Reader - MaxLength Integer + MaxLength Integer // TODO: Should this include the length header or not? } func (d *Decoder) Decode(msg *Message) (err error) { @@ -106,6 +106,14 @@ func (d *Decoder) Decode(msg *Message) (err error) { return errors.New("message too long") } r := bufio.NewReader(io.LimitReader(d.R, int64(length))) + defer func() { + written, _ := io.Copy(ioutil.Discard, r) + if written != 0 && err == nil { + err = fmt.Errorf("short read on message type %d, left %d bytes", msg.Type, written) + } else if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() if length == 0 { msg.Keepalive = true return @@ -116,12 +124,6 @@ func (d *Decoder) Decode(msg *Message) (err error) { return } msg.Type = MessageType(c) - defer func() { - written, _ := io.Copy(ioutil.Discard, r) - if written != 0 && err != nil { - err = fmt.Errorf("short read on message type %d, left %d bytes", msg.Type, written) - } - }() switch msg.Type { case Choke, Unchoke, Interested, NotInterested: return @@ -152,9 +154,6 @@ func (d *Decoder) Decode(msg *Message) (err error) { default: err = fmt.Errorf("unknown message type %#v", c) } - if err != nil { - err = fmt.Errorf("decoding type %d: %s", msg.Type, err) - } return } diff --git a/peer_protocol/protocol_test.go b/peer_protocol/protocol_test.go index e9e45352..b818a2e2 100644 --- a/peer_protocol/protocol_test.go +++ b/peer_protocol/protocol_test.go @@ -1,7 +1,10 @@ package peer_protocol import ( + "bufio" "bytes" + "io" + "strings" "testing" ) @@ -72,3 +75,36 @@ func TestHaveEncode(t *testing.T) { t.Fatalf("expected %#v, got %#v", expected, actualString) } } + +func TestShortRead(t *testing.T) { + dec := Decoder{ + R: bufio.NewReader(bytes.NewBufferString("\x00\x00\x00\x02\x00!")), + MaxLength: 2, + } + msg := new(Message) + err := dec.Decode(msg) + if !strings.Contains(err.Error(), "short read") { + t.Fatal(err) + } +} + +func TestUnexpectedEOF(t *testing.T) { + msg := new(Message) + for _, stream := range []string{ + "\x00\x00\x00", // Header truncated. + "\x00\x00\x00\x01", // Expecting 1 more byte. + // Request with wrong length, and too short anyway. + "\x00\x00\x00\x06\x06\x00\x00\x00\x00\x00", + // Request truncated. + "\x00\x00\x00\x0b\x06\x00\x00\x00\x00\x00", + } { + dec := Decoder{ + R: bufio.NewReader(bytes.NewBufferString(stream)), + MaxLength: 42, + } + err := dec.Decode(msg) + if err != io.ErrUnexpectedEOF { + t.Fatal(err) + } + } +} -- 2.48.1