]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Keepalives weren't marshalled correctly
authorMatt Joiner <anacrolix@gmail.com>
Wed, 28 May 2014 15:22:51 +0000 (01:22 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Wed, 28 May 2014 15:22:51 +0000 (01:22 +1000)
peer_protocol/protocol.go
peer_protocol/protocol_test.go

index 75e33cab453abd5db59bbf9a0dd1fef07cc9f4c8..90b01359e6db8b539cb532d8526375e88056d82e 100644 (file)
@@ -45,43 +45,41 @@ type Message struct {
 
 func (msg Message) MarshalBinary() (data []byte, err error) {
        buf := &bytes.Buffer{}
-       if msg.Keepalive {
-               data = buf.Bytes()
-               return
-       }
-       err = buf.WriteByte(byte(msg.Type))
-       if err != nil {
-               return
-       }
-       switch msg.Type {
-       case Choke, Unchoke, Interested, NotInterested:
-       case Have:
-               err = binary.Write(buf, binary.BigEndian, msg.Index)
-       case Request, Cancel:
-               for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
-                       err = binary.Write(buf, binary.BigEndian, i)
+       if !msg.Keepalive {
+               err = buf.WriteByte(byte(msg.Type))
+               if err != nil {
+                       return
+               }
+               switch msg.Type {
+               case Choke, Unchoke, Interested, NotInterested:
+               case Have:
+                       err = binary.Write(buf, binary.BigEndian, msg.Index)
+               case Request, Cancel:
+                       for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
+                               err = binary.Write(buf, binary.BigEndian, i)
+                               if err != nil {
+                                       break
+                               }
+                       }
+               case Bitfield:
+                       _, err = buf.Write(marshalBitfield(msg.Bitfield))
+               case Piece:
+                       for _, i := range []Integer{msg.Index, msg.Begin} {
+                               err = binary.Write(buf, binary.BigEndian, i)
+                               if err != nil {
+                                       return
+                               }
+                       }
+                       n, err := buf.Write(msg.Piece)
                        if err != nil {
                                break
                        }
-               }
-       case Bitfield:
-               _, err = buf.Write(marshalBitfield(msg.Bitfield))
-       case Piece:
-               for _, i := range []Integer{msg.Index, msg.Begin} {
-                       err = binary.Write(buf, binary.BigEndian, i)
-                       if err != nil {
-                               return
+                       if n != len(msg.Piece) {
+                               panic(n)
                        }
+               default:
+                       err = fmt.Errorf("unknown message type: %s", msg.Type)
                }
-               n, err := buf.Write(msg.Piece)
-               if err != nil {
-                       break
-               }
-               if n != len(msg.Piece) {
-                       panic(n)
-               }
-       default:
-               err = fmt.Errorf("unknown message type: %s", msg.Type)
        }
        data = make([]byte, 4+buf.Len())
        binary.BigEndian.PutUint32(data, uint32(buf.Len()))
index 169399adebdde884cfeb62a018f859b269382f5d..ca65d32b57e93fd51e3795f4758eb4656ca1e8cf 100644 (file)
@@ -108,3 +108,17 @@ func TestUnexpectedEOF(t *testing.T) {
                }
        }
 }
+
+func TestMarshalKeepalive(t *testing.T) {
+       b, err := (Message{
+               Keepalive: true,
+       }).MarshalBinary()
+       if err != nil {
+               t.Fatalf("error marshalling keepalive: %s", err)
+       }
+       bs := string(b)
+       const expected = "\x00\x00\x00\x00"
+       if bs != expected {
+               t.Fatalf("marshalled keepalive is %q, expected %q", bs, expected)
+       }
+}