]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Fix short read and report unexpected EOFs decoding peer protocol
authorMatt Joiner <anacrolix@gmail.com>
Thu, 20 Mar 2014 13:42:40 +0000 (00:42 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Thu, 20 Mar 2014 13:42:40 +0000 (00:42 +1100)
peer_protocol/protocol.go
peer_protocol/protocol_test.go

index 281d84869a9666929ccedc6aa91081bc256dd5ae..a7fcbfc2b2da780121f61329eb476d9ea6bf5b3d 100644 (file)
@@ -93,7 +93,7 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
 
 type Decoder struct {
        R         *bufio.Reader
-       MaxLength Integer
+       MaxLength Integer // TODO: Should this include the length header or not?
 }
 
 func (d *Decoder) Decode(msg *Message) (err error) {
@@ -106,6 +106,14 @@ func (d *Decoder) Decode(msg *Message) (err error) {
                return errors.New("message too long")
        }
        r := bufio.NewReader(io.LimitReader(d.R, int64(length)))
+       defer func() {
+               written, _ := io.Copy(ioutil.Discard, r)
+               if written != 0 && err == nil {
+                       err = fmt.Errorf("short read on message type %d, left %d bytes", msg.Type, written)
+               } else if err == io.EOF {
+                       err = io.ErrUnexpectedEOF
+               }
+       }()
        if length == 0 {
                msg.Keepalive = true
                return
@@ -116,12 +124,6 @@ func (d *Decoder) Decode(msg *Message) (err error) {
                return
        }
        msg.Type = MessageType(c)
-       defer func() {
-               written, _ := io.Copy(ioutil.Discard, r)
-               if written != 0 && err != nil {
-                       err = fmt.Errorf("short read on message type %d, left %d bytes", msg.Type, written)
-               }
-       }()
        switch msg.Type {
        case Choke, Unchoke, Interested, NotInterested:
                return
@@ -152,9 +154,6 @@ func (d *Decoder) Decode(msg *Message) (err error) {
        default:
                err = fmt.Errorf("unknown message type %#v", c)
        }
-       if err != nil {
-               err = fmt.Errorf("decoding type %d: %s", msg.Type, err)
-       }
        return
 }
 
index e9e453521d351558e5916154666c0b070d1d0191..b818a2e260be064d003dbb633103899185073b22 100644 (file)
@@ -1,7 +1,10 @@
 package peer_protocol
 
 import (
+       "bufio"
        "bytes"
+       "io"
+       "strings"
        "testing"
 )
 
@@ -72,3 +75,36 @@ func TestHaveEncode(t *testing.T) {
                t.Fatalf("expected %#v, got %#v", expected, actualString)
        }
 }
+
+func TestShortRead(t *testing.T) {
+       dec := Decoder{
+               R:         bufio.NewReader(bytes.NewBufferString("\x00\x00\x00\x02\x00!")),
+               MaxLength: 2,
+       }
+       msg := new(Message)
+       err := dec.Decode(msg)
+       if !strings.Contains(err.Error(), "short read") {
+               t.Fatal(err)
+       }
+}
+
+func TestUnexpectedEOF(t *testing.T) {
+       msg := new(Message)
+       for _, stream := range []string{
+               "\x00\x00\x00",     // Header truncated.
+               "\x00\x00\x00\x01", // Expecting 1 more byte.
+               // Request with wrong length, and too short anyway.
+               "\x00\x00\x00\x06\x06\x00\x00\x00\x00\x00",
+               // Request truncated.
+               "\x00\x00\x00\x0b\x06\x00\x00\x00\x00\x00",
+       } {
+               dec := Decoder{
+                       R:         bufio.NewReader(bytes.NewBufferString(stream)),
+                       MaxLength: 42,
+               }
+               err := dec.Decode(msg)
+               if err != io.ErrUnexpectedEOF {
+                       t.Fatal(err)
+               }
+       }
+}