From: Matt Joiner Date: Thu, 22 May 2014 14:36:47 +0000 (+1000) Subject: Avoid rebuffering in peer_protocol.Decode X-Git-Tag: v1.0.0~1735 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=dd30d144ae5e502a874d3b5c94538591aae3d9bc;p=btrtrc.git Avoid rebuffering in peer_protocol.Decode --- diff --git a/peer_protocol/protocol.go b/peer_protocol/protocol.go index a7fcbfc2..75e33cab 100644 --- a/peer_protocol/protocol.go +++ b/peer_protocol/protocol.go @@ -105,7 +105,21 @@ func (d *Decoder) Decode(msg *Message) (err error) { if length > d.MaxLength { return errors.New("message too long") } - r := bufio.NewReader(io.LimitReader(d.R, int64(length))) + if length == 0 { + msg.Keepalive = true + return + } + msg.Keepalive = false + b := make([]byte, length) + _, err = io.ReadFull(d.R, b) + if err == io.EOF { + err = io.ErrUnexpectedEOF + return + } + if err != nil { + return + } + r := bytes.NewReader(b) defer func() { written, _ := io.Copy(ioutil.Discard, r) if written != 0 && err == nil { @@ -114,10 +128,6 @@ func (d *Decoder) Decode(msg *Message) (err error) { err = io.ErrUnexpectedEOF } }() - if length == 0 { - msg.Keepalive = true - return - } msg.Keepalive = false c, err := r.ReadByte() if err != nil { diff --git a/peer_protocol/protocol_test.go b/peer_protocol/protocol_test.go index b818a2e2..169399ad 100644 --- a/peer_protocol/protocol_test.go +++ b/peer_protocol/protocol_test.go @@ -104,7 +104,7 @@ func TestUnexpectedEOF(t *testing.T) { } err := dec.Decode(msg) if err != io.ErrUnexpectedEOF { - t.Fatal(err) + t.Fatalf("expected ErrUnexpectedEOF decoding %q, got %s", stream, err) } } }