From: Matt Joiner Date: Tue, 1 Oct 2013 08:43:18 +0000 (+1000) Subject: Fix request/chunk confusion, missing outgoing message prefix, protocol tests; improve... X-Git-Tag: v1.0.0~1809 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=28531a4fccf0bdff0d40257a8d31bb0d4e9a0a64;p=btrtrc.git Fix request/chunk confusion, missing outgoing message prefix, protocol tests; improve request triggering --- diff --git a/client.go b/client.go index 55bf8d1a..fddefa37 100644 --- a/client.go +++ b/client.go @@ -49,17 +49,18 @@ const ( ) type piece struct { - State pieceState - Hash pieceSum - PendingChunks map[chunk]struct{} + State pieceState + Hash pieceSum + PendingChunkSpecs map[chunkSpec]struct{} } -type chunk struct { +type chunkSpec struct { Begin, Length peer_protocol.Integer } type request struct { - Index, Begin, Length peer_protocol.Integer + Index peer_protocol.Integer + chunkSpec } type connection struct { @@ -102,6 +103,9 @@ func (c *connection) Request(chunk request) bool { Length: chunk.Length, }) } + if c.Requests == nil { + c.Requests = make(map[request]struct{}, maxRequests) + } c.Requests[chunk] = struct{}{} return true } @@ -114,8 +118,9 @@ func (c *connection) SetInterested(interested bool) { Type: func() peer_protocol.MessageType { if interested { return peer_protocol.Interested + } else { + return peer_protocol.NotInterested } - return peer_protocol.NotInterested }(), }) c.Interested = interested @@ -124,7 +129,6 @@ func (c *connection) SetInterested(interested bool) { func (conn *connection) writer() { for { b := <-conn.write - log.Printf("writing %#v", string(b)) n, err := conn.Socket.Write(b) if err != nil { log.Print(err) @@ -134,6 +138,7 @@ func (conn *connection) writer() { if n != len(b) { panic("didn't write all bytes") } + log.Printf("wrote %#v", string(b)) } } @@ -144,7 +149,6 @@ func (conn *connection) writeOptimizer() { write := conn.write if pending.Len() == 0 { write = nil - nextWrite = nil } else { var err error nextWrite, err = pending.Front().Value.(encoding.BinaryMarshaler).MarshalBinary() @@ -177,9 +181,9 @@ func (t *torrent) bitfield() (bf []bool) { return } -func (t *torrent) pieceChunks(index int) (cs map[chunk]struct{}) { - cs = make(map[chunk]struct{}, (t.MetaInfo.PieceLength+chunkSize-1)/chunkSize) - c := chunk{ +func (t *torrent) pieceChunkSpecs(index int) (cs map[chunkSpec]struct{}) { + cs = make(map[chunkSpec]struct{}, (t.MetaInfo.PieceLength+chunkSize-1)/chunkSize) + c := chunkSpec{ Begin: 0, } for left := peer_protocol.Integer(t.PieceSize(index)); left > 0; left -= c.Length { @@ -193,7 +197,7 @@ func (t *torrent) pieceChunks(index int) (cs map[chunk]struct{}) { return } -func (t *torrent) chunkHeat() (ret map[request]int) { +func (t *torrent) requestHeat() (ret map[request]int) { ret = make(map[request]int) for _, conn := range t.Conns { for req, _ := range conn.Requests { @@ -352,6 +356,15 @@ func (me *client) initiateConn(peer Peer, torrent *torrent) { }() } +func (me *torrent) haveAnyPieces() bool { + for _, piece := range me.Pieces { + if piece.State == pieceStateComplete { + return true + } + } + return false +} + func (me *client) handshake(sock net.Conn, torrent *torrent, peerId [20]byte) { conn := &connection{ Socket: sock, @@ -400,10 +413,12 @@ func (me *client) handshake(sock net.Conn, torrent *torrent, peerId [20]byte) { } me.withContext(func() { me.addConnection(torrent, conn) - conn.Post(peer_protocol.Message{ - Type: peer_protocol.Bitfield, - Bitfield: torrent.bitfield(), - }) + if torrent.haveAnyPieces() { + conn.Post(peer_protocol.Message{ + Type: peer_protocol.Bitfield, + Bitfield: torrent.bitfield(), + }) + } go func() { defer me.withContext(func() { me.dropConnection(torrent, conn) @@ -417,29 +432,22 @@ func (me *client) handshake(sock net.Conn, torrent *torrent, peerId [20]byte) { } func (me *client) peerGotPiece(torrent *torrent, conn *connection, piece int) { - if torrent.Pieces[piece].State != pieceStateIncomplete { - return + if conn.PeerPieces == nil { + conn.PeerPieces = make([]bool, len(torrent.Pieces)) + } + conn.PeerPieces[piece] = true + if torrent.wantPiece(piece) { + conn.SetInterested(true) + me.replenishConnRequests(torrent, conn) } - conn.SetInterested(true) +} + +func (t *torrent) wantPiece(index int) bool { + return t.Pieces[index].State == pieceStateIncomplete } func (me *client) peerUnchoked(torrent *torrent, conn *connection) { - chunkHeatMap := torrent.chunkHeat() - for index, has := range conn.PeerPieces { - if !has { - continue - } - for chunk, _ := range torrent.Pieces[index].PendingChunks { - if _, ok := chunkHeatMap[chunk]; ok { - continue - } - conn.SetInterested(true) - if !conn.Request(chunk) { - return - } - } - } - conn.SetInterested(false) + me.replenishConnRequests(torrent, conn) } func (me *client) runConnection(torrent *torrent, conn *connection) error { @@ -457,6 +465,7 @@ func (me *client) runConnection(torrent *torrent, conn *connection) error { continue } go me.withContext(func() { + log.Print(msg) var err error switch msg.Type { case peer_protocol.Choke: @@ -470,12 +479,10 @@ func (me *client) runConnection(torrent *torrent, conn *connection) error { conn.PeerInterested = false case peer_protocol.Have: me.peerGotPiece(torrent, conn, int(msg.Index)) - conn.PeerPieces[msg.Index] = true case peer_protocol.Request: conn.PeerRequests[request{ - Index: msg.Index, - Begin: msg.Begin, - Length: msg.Length, + Index: msg.Index, + chunkSpec: chunkSpec{msg.Begin, msg.Length}, }] = struct{}{} case peer_protocol.Bitfield: if len(msg.Bitfield) < len(torrent.Pieces) { @@ -589,6 +596,33 @@ func (me *client) withContext(f func()) { me.actorTask <- f } +func (me *client) replenishConnRequests(torrent *torrent, conn *connection) { + if len(conn.Requests) >= maxRequests { + return + } + if conn.PeerChoked { + return + } + requestHeatMap := torrent.requestHeat() + for index, has := range conn.PeerPieces { + if !has { + continue + } + for chunkSpec, _ := range torrent.Pieces[index].PendingChunkSpecs { + request := request{peer_protocol.Integer(index), chunkSpec} + if heat := requestHeatMap[request]; heat > 0 { + continue + } + conn.SetInterested(true) + if !conn.Request(request) { + return + } + } + } + //conn.SetInterested(false) + +} + func (me *client) pieceHashed(ih InfoHash, piece int, correct bool) { torrent := me.torrents[ih] newState := func() pieceState { @@ -604,7 +638,7 @@ func (me *client) pieceHashed(ih InfoHash, piece int, correct bool) { } torrent.Pieces[piece].State = newState if newState == pieceStateIncomplete { - torrent.Pieces[piece].PendingChunks = torrent.pieceChunks(piece) + torrent.Pieces[piece].PendingChunkSpecs = torrent.pieceChunkSpecs(piece) } for _, conn := range torrent.Conns { if correct { @@ -614,7 +648,7 @@ func (me *client) pieceHashed(ih InfoHash, piece int, correct bool) { }) } else { if conn.PeerHasPiece(piece) { - conn.SetInterested(true) + me.replenishConnRequests(torrent, conn) } } } diff --git a/cmd/torrent/main.go b/cmd/torrent/main.go index 74380e9d..76a9d943 100644 --- a/cmd/torrent/main.go +++ b/cmd/torrent/main.go @@ -30,7 +30,7 @@ func main() { } err = client.AddPeers(torrent.BytesInfoHash(metaInfo.InfoHash), []torrent.Peer{{ IP: net.IPv4(127, 0, 0, 1), - Port: 53219, + Port: 50933, }}) if err != nil { log.Fatal(err) diff --git a/peer_protocol/protocol.go b/peer_protocol/protocol.go index dfef720e..4020c524 100644 --- a/peer_protocol/protocol.go +++ b/peer_protocol/protocol.go @@ -68,7 +68,11 @@ func (msg Message) MarshalBinary() (data []byte, err error) { default: err = errors.New("unknown message type") } - data = buf.Bytes() + data = make([]byte, 4+buf.Len()) + binary.BigEndian.PutUint32(data, uint32(buf.Len())) + if buf.Len() != copy(data[4:], buf.Bytes()) { + panic("bad copy") + } return } @@ -113,16 +117,6 @@ func (d *Decoder) Decode(msg *Message) (err error) { return } -func encodeMessage(type_ MessageType, data interface{}) []byte { - w := &bytes.Buffer{} - w.WriteByte(byte(type_)) - err := binary.Write(w, binary.BigEndian, data) - if err != nil { - panic(err) - } - return w.Bytes() -} - type Bytes []byte func (b Bytes) MarshalBinary() ([]byte, error) { diff --git a/peer_protocol/protocol_test.go b/peer_protocol/protocol_test.go index 7a7b8dd1..3c04aff8 100644 --- a/peer_protocol/protocol_test.go +++ b/peer_protocol/protocol_test.go @@ -10,22 +10,47 @@ func TestConstants(t *testing.T) { t.FailNow() } } + func TestBitfieldEncode(t *testing.T) { - bm := make(Bitfield, 37) - bm[2] = true - bm[7] = true - bm[32] = true - s := string(bm.Encode()) + bf := make([]bool, 37) + bf[2] = true + bf[7] = true + bf[32] = true + s := string(marshalBitfield(bf)) const expected = "\x21\x00\x00\x00\x80" if s != expected { t.Fatalf("got %#v, expected %#v", s, expected) } } +func TestBitfieldUnmarshal(t *testing.T) { + bf := unmarshalBitfield([]byte("\x81\x06")) + expected := make([]bool, 16) + expected[0] = true + expected[7] = true + expected[13] = true + expected[14] = true + if len(bf) != len(expected) { + t.FailNow() + } + for i := range expected { + if bf[i] != expected[i] { + t.FailNow() + } + } +} + func TestHaveEncode(t *testing.T) { - actual := string(Have(42).Encode()) + actualBytes, err := Message{ + Type: Have, + Index: 42, + }.MarshalBinary() + if err != nil { + t.Fatal(err) + } + actualString := string(actualBytes) expected := "\x04\x00\x00\x00\x2a" - if actual != expected { - t.Fatalf("expected %#v, got %#v", expected, actual) + if actualString != expected { + t.Fatalf("expected %#v, got %#v", expected, actualString) } }