From: Matt Joiner <anacrolix@gmail.com>
Date: Wed, 28 May 2014 15:22:51 +0000 (+1000)
Subject: Keepalives weren't marshalled correctly
X-Git-Tag: v1.0.0~1722
X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=ef69e827655fb65196f78aa0262961e3a5399e46;p=btrtrc.git

Keepalives weren't marshalled correctly
---

diff --git a/peer_protocol/protocol.go b/peer_protocol/protocol.go
index 75e33cab..90b01359 100644
--- a/peer_protocol/protocol.go
+++ b/peer_protocol/protocol.go
@@ -45,43 +45,41 @@ type Message struct {
 
 func (msg Message) MarshalBinary() (data []byte, err error) {
 	buf := &bytes.Buffer{}
-	if msg.Keepalive {
-		data = buf.Bytes()
-		return
-	}
-	err = buf.WriteByte(byte(msg.Type))
-	if err != nil {
-		return
-	}
-	switch msg.Type {
-	case Choke, Unchoke, Interested, NotInterested:
-	case Have:
-		err = binary.Write(buf, binary.BigEndian, msg.Index)
-	case Request, Cancel:
-		for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
-			err = binary.Write(buf, binary.BigEndian, i)
+	if !msg.Keepalive {
+		err = buf.WriteByte(byte(msg.Type))
+		if err != nil {
+			return
+		}
+		switch msg.Type {
+		case Choke, Unchoke, Interested, NotInterested:
+		case Have:
+			err = binary.Write(buf, binary.BigEndian, msg.Index)
+		case Request, Cancel:
+			for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
+				err = binary.Write(buf, binary.BigEndian, i)
+				if err != nil {
+					break
+				}
+			}
+		case Bitfield:
+			_, err = buf.Write(marshalBitfield(msg.Bitfield))
+		case Piece:
+			for _, i := range []Integer{msg.Index, msg.Begin} {
+				err = binary.Write(buf, binary.BigEndian, i)
+				if err != nil {
+					return
+				}
+			}
+			n, err := buf.Write(msg.Piece)
 			if err != nil {
 				break
 			}
-		}
-	case Bitfield:
-		_, err = buf.Write(marshalBitfield(msg.Bitfield))
-	case Piece:
-		for _, i := range []Integer{msg.Index, msg.Begin} {
-			err = binary.Write(buf, binary.BigEndian, i)
-			if err != nil {
-				return
+			if n != len(msg.Piece) {
+				panic(n)
 			}
+		default:
+			err = fmt.Errorf("unknown message type: %s", msg.Type)
 		}
-		n, err := buf.Write(msg.Piece)
-		if err != nil {
-			break
-		}
-		if n != len(msg.Piece) {
-			panic(n)
-		}
-	default:
-		err = fmt.Errorf("unknown message type: %s", msg.Type)
 	}
 	data = make([]byte, 4+buf.Len())
 	binary.BigEndian.PutUint32(data, uint32(buf.Len()))
diff --git a/peer_protocol/protocol_test.go b/peer_protocol/protocol_test.go
index 169399ad..ca65d32b 100644
--- a/peer_protocol/protocol_test.go
+++ b/peer_protocol/protocol_test.go
@@ -108,3 +108,17 @@ func TestUnexpectedEOF(t *testing.T) {
 		}
 	}
 }
+
+func TestMarshalKeepalive(t *testing.T) {
+	b, err := (Message{
+		Keepalive: true,
+	}).MarshalBinary()
+	if err != nil {
+		t.Fatalf("error marshalling keepalive: %s", err)
+	}
+	bs := string(b)
+	const expected = "\x00\x00\x00\x00"
+	if bs != expected {
+		t.Fatalf("marshalled keepalive is %q, expected %q", bs, expected)
+	}
+}