From ef69e827655fb65196f78aa0262961e3a5399e46 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Thu, 29 May 2014 01:22:51 +1000 Subject: [PATCH] Keepalives weren't marshalled correctly --- peer_protocol/protocol.go | 62 ++++++++++++++++------------------ peer_protocol/protocol_test.go | 14 ++++++++ 2 files changed, 44 insertions(+), 32 deletions(-) diff --git a/peer_protocol/protocol.go b/peer_protocol/protocol.go index 75e33cab..90b01359 100644 --- a/peer_protocol/protocol.go +++ b/peer_protocol/protocol.go @@ -45,43 +45,41 @@ type Message struct { func (msg Message) MarshalBinary() (data []byte, err error) { buf := &bytes.Buffer{} - if msg.Keepalive { - data = buf.Bytes() - return - } - err = buf.WriteByte(byte(msg.Type)) - if err != nil { - return - } - switch msg.Type { - case Choke, Unchoke, Interested, NotInterested: - case Have: - err = binary.Write(buf, binary.BigEndian, msg.Index) - case Request, Cancel: - for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} { - err = binary.Write(buf, binary.BigEndian, i) + if !msg.Keepalive { + err = buf.WriteByte(byte(msg.Type)) + if err != nil { + return + } + switch msg.Type { + case Choke, Unchoke, Interested, NotInterested: + case Have: + err = binary.Write(buf, binary.BigEndian, msg.Index) + case Request, Cancel: + for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} { + err = binary.Write(buf, binary.BigEndian, i) + if err != nil { + break + } + } + case Bitfield: + _, err = buf.Write(marshalBitfield(msg.Bitfield)) + case Piece: + for _, i := range []Integer{msg.Index, msg.Begin} { + err = binary.Write(buf, binary.BigEndian, i) + if err != nil { + return + } + } + n, err := buf.Write(msg.Piece) if err != nil { break } - } - case Bitfield: - _, err = buf.Write(marshalBitfield(msg.Bitfield)) - case Piece: - for _, i := range []Integer{msg.Index, msg.Begin} { - err = binary.Write(buf, binary.BigEndian, i) - if err != nil { - return + if n != len(msg.Piece) { + panic(n) } + default: + err = fmt.Errorf("unknown message type: %s", msg.Type) } - n, err := buf.Write(msg.Piece) - if err != nil { - break - } - if n != len(msg.Piece) { - panic(n) - } - default: - err = fmt.Errorf("unknown message type: %s", msg.Type) } data = make([]byte, 4+buf.Len()) binary.BigEndian.PutUint32(data, uint32(buf.Len())) diff --git a/peer_protocol/protocol_test.go b/peer_protocol/protocol_test.go index 169399ad..ca65d32b 100644 --- a/peer_protocol/protocol_test.go +++ b/peer_protocol/protocol_test.go @@ -108,3 +108,17 @@ func TestUnexpectedEOF(t *testing.T) { } } } + +func TestMarshalKeepalive(t *testing.T) { + b, err := (Message{ + Keepalive: true, + }).MarshalBinary() + if err != nil { + t.Fatalf("error marshalling keepalive: %s", err) + } + bs := string(b) + const expected = "\x00\x00\x00\x00" + if bs != expected { + t.Fatalf("marshalled keepalive is %q, expected %q", bs, expected) + } +} -- 2.44.0