peer_protocol/protocol.go | 20 +++++++++++++++----- peer_protocol/protocol_test.go | 2 +- diff --git a/peer_protocol/protocol.go b/peer_protocol/protocol.go index a7fcbfc2b2da780121f61329eb476d9ea6bf5b3d..75e33cab453abd5db59bbf9a0dd1fef07cc9f4c8 100644 --- a/peer_protocol/protocol.go +++ b/peer_protocol/protocol.go @@ -105,7 +105,21 @@ } if length > d.MaxLength { return errors.New("message too long") } - r := bufio.NewReader(io.LimitReader(d.R, int64(length))) + if length == 0 { + msg.Keepalive = true + return + } + msg.Keepalive = false + b := make([]byte, length) + _, err = io.ReadFull(d.R, b) + if err == io.EOF { + err = io.ErrUnexpectedEOF + return + } + if err != nil { + return + } + r := bytes.NewReader(b) defer func() { written, _ := io.Copy(ioutil.Discard, r) if written != 0 && err == nil { @@ -114,10 +128,6 @@ } else if err == io.EOF { err = io.ErrUnexpectedEOF } }() - if length == 0 { - msg.Keepalive = true - return - } msg.Keepalive = false c, err := r.ReadByte() if err != nil { diff --git a/peer_protocol/protocol_test.go b/peer_protocol/protocol_test.go index b818a2e260be064d003dbb633103899185073b22..169399adebdde884cfeb62a018f859b269382f5d 100644 --- a/peer_protocol/protocol_test.go +++ b/peer_protocol/protocol_test.go @@ -104,7 +104,7 @@ MaxLength: 42, } err := dec.Decode(msg) if err != io.ErrUnexpectedEOF { - t.Fatal(err) + t.Fatalf("expected ErrUnexpectedEOF decoding %q, got %s", stream, err) } } }