]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Avoid rebuffering in peer_protocol.Decode
authorMatt Joiner <anacrolix@gmail.com>
Thu, 22 May 2014 14:36:47 +0000 (00:36 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Thu, 22 May 2014 14:36:47 +0000 (00:36 +1000)
peer_protocol/protocol.go
peer_protocol/protocol_test.go

index a7fcbfc2b2da780121f61329eb476d9ea6bf5b3d..75e33cab453abd5db59bbf9a0dd1fef07cc9f4c8 100644 (file)
@@ -105,7 +105,21 @@ func (d *Decoder) Decode(msg *Message) (err error) {
        if length > d.MaxLength {
                return errors.New("message too long")
        }
-       r := bufio.NewReader(io.LimitReader(d.R, int64(length)))
+       if length == 0 {
+               msg.Keepalive = true
+               return
+       }
+       msg.Keepalive = false
+       b := make([]byte, length)
+       _, err = io.ReadFull(d.R, b)
+       if err == io.EOF {
+               err = io.ErrUnexpectedEOF
+               return
+       }
+       if err != nil {
+               return
+       }
+       r := bytes.NewReader(b)
        defer func() {
                written, _ := io.Copy(ioutil.Discard, r)
                if written != 0 && err == nil {
@@ -114,10 +128,6 @@ func (d *Decoder) Decode(msg *Message) (err error) {
                        err = io.ErrUnexpectedEOF
                }
        }()
-       if length == 0 {
-               msg.Keepalive = true
-               return
-       }
        msg.Keepalive = false
        c, err := r.ReadByte()
        if err != nil {
index b818a2e260be064d003dbb633103899185073b22..169399adebdde884cfeb62a018f859b269382f5d 100644 (file)
@@ -104,7 +104,7 @@ func TestUnexpectedEOF(t *testing.T) {
                }
                err := dec.Decode(msg)
                if err != io.ErrUnexpectedEOF {
-                       t.Fatal(err)
+                       t.Fatalf("expected ErrUnexpectedEOF decoding %q, got %s", stream, err)
                }
        }
 }