From: Matt Joiner <anacrolix@gmail.com>
Date: Thu, 22 May 2014 14:36:47 +0000 (+1000)
Subject: Avoid rebuffering in peer_protocol.Decode
X-Git-Tag: v1.0.0~1735
X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=dd30d144ae5e502a874d3b5c94538591aae3d9bc;p=btrtrc.git

Avoid rebuffering in peer_protocol.Decode
---

diff --git a/peer_protocol/protocol.go b/peer_protocol/protocol.go
index a7fcbfc2..75e33cab 100644
--- a/peer_protocol/protocol.go
+++ b/peer_protocol/protocol.go
@@ -105,7 +105,21 @@ func (d *Decoder) Decode(msg *Message) (err error) {
 	if length > d.MaxLength {
 		return errors.New("message too long")
 	}
-	r := bufio.NewReader(io.LimitReader(d.R, int64(length)))
+	if length == 0 {
+		msg.Keepalive = true
+		return
+	}
+	msg.Keepalive = false
+	b := make([]byte, length)
+	_, err = io.ReadFull(d.R, b)
+	if err == io.EOF {
+		err = io.ErrUnexpectedEOF
+		return
+	}
+	if err != nil {
+		return
+	}
+	r := bytes.NewReader(b)
 	defer func() {
 		written, _ := io.Copy(ioutil.Discard, r)
 		if written != 0 && err == nil {
@@ -114,10 +128,6 @@ func (d *Decoder) Decode(msg *Message) (err error) {
 			err = io.ErrUnexpectedEOF
 		}
 	}()
-	if length == 0 {
-		msg.Keepalive = true
-		return
-	}
 	msg.Keepalive = false
 	c, err := r.ReadByte()
 	if err != nil {
diff --git a/peer_protocol/protocol_test.go b/peer_protocol/protocol_test.go
index b818a2e2..169399ad 100644
--- a/peer_protocol/protocol_test.go
+++ b/peer_protocol/protocol_test.go
@@ -104,7 +104,7 @@ func TestUnexpectedEOF(t *testing.T) {
 		}
 		err := dec.Decode(msg)
 		if err != io.ErrUnexpectedEOF {
-			t.Fatal(err)
+			t.Fatalf("expected ErrUnexpectedEOF decoding %q, got %s", stream, err)
 		}
 	}
 }