]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Fix request/chunk confusion, missing outgoing message prefix, protocol tests; improve...
authorMatt Joiner <anacrolix@gmail.com>
Tue, 1 Oct 2013 08:43:18 +0000 (18:43 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Tue, 1 Oct 2013 08:43:18 +0000 (18:43 +1000)
client.go
cmd/torrent/main.go
peer_protocol/protocol.go
peer_protocol/protocol_test.go

index 55bf8d1a6e52ea816a917f261991f2dbbfe5adf4..fddefa378cade07a3d4a7a94eba396c275a8bc98 100644 (file)
--- a/client.go
+++ b/client.go
@@ -49,17 +49,18 @@ const (
 )
 
 type piece struct {
-       State         pieceState
-       Hash          pieceSum
-       PendingChunks map[chunk]struct{}
+       State             pieceState
+       Hash              pieceSum
+       PendingChunkSpecs map[chunkSpec]struct{}
 }
 
-type chunk struct {
+type chunkSpec struct {
        Begin, Length peer_protocol.Integer
 }
 
 type request struct {
-       Index, Begin, Length peer_protocol.Integer
+       Index peer_protocol.Integer
+       chunkSpec
 }
 
 type connection struct {
@@ -102,6 +103,9 @@ func (c *connection) Request(chunk request) bool {
                        Length: chunk.Length,
                })
        }
+       if c.Requests == nil {
+               c.Requests = make(map[request]struct{}, maxRequests)
+       }
        c.Requests[chunk] = struct{}{}
        return true
 }
@@ -114,8 +118,9 @@ func (c *connection) SetInterested(interested bool) {
                Type: func() peer_protocol.MessageType {
                        if interested {
                                return peer_protocol.Interested
+                       } else {
+                               return peer_protocol.NotInterested
                        }
-                       return peer_protocol.NotInterested
                }(),
        })
        c.Interested = interested
@@ -124,7 +129,6 @@ func (c *connection) SetInterested(interested bool) {
 func (conn *connection) writer() {
        for {
                b := <-conn.write
-               log.Printf("writing %#v", string(b))
                n, err := conn.Socket.Write(b)
                if err != nil {
                        log.Print(err)
@@ -134,6 +138,7 @@ func (conn *connection) writer() {
                if n != len(b) {
                        panic("didn't write all bytes")
                }
+               log.Printf("wrote %#v", string(b))
        }
 }
 
@@ -144,7 +149,6 @@ func (conn *connection) writeOptimizer() {
                write := conn.write
                if pending.Len() == 0 {
                        write = nil
-                       nextWrite = nil
                } else {
                        var err error
                        nextWrite, err = pending.Front().Value.(encoding.BinaryMarshaler).MarshalBinary()
@@ -177,9 +181,9 @@ func (t *torrent) bitfield() (bf []bool) {
        return
 }
 
-func (t *torrent) pieceChunks(index int) (cs map[chunk]struct{}) {
-       cs = make(map[chunk]struct{}, (t.MetaInfo.PieceLength+chunkSize-1)/chunkSize)
-       c := chunk{
+func (t *torrent) pieceChunkSpecs(index int) (cs map[chunkSpec]struct{}) {
+       cs = make(map[chunkSpec]struct{}, (t.MetaInfo.PieceLength+chunkSize-1)/chunkSize)
+       c := chunkSpec{
                Begin: 0,
        }
        for left := peer_protocol.Integer(t.PieceSize(index)); left > 0; left -= c.Length {
@@ -193,7 +197,7 @@ func (t *torrent) pieceChunks(index int) (cs map[chunk]struct{}) {
        return
 }
 
-func (t *torrent) chunkHeat() (ret map[request]int) {
+func (t *torrent) requestHeat() (ret map[request]int) {
        ret = make(map[request]int)
        for _, conn := range t.Conns {
                for req, _ := range conn.Requests {
@@ -352,6 +356,15 @@ func (me *client) initiateConn(peer Peer, torrent *torrent) {
        }()
 }
 
+func (me *torrent) haveAnyPieces() bool {
+       for _, piece := range me.Pieces {
+               if piece.State == pieceStateComplete {
+                       return true
+               }
+       }
+       return false
+}
+
 func (me *client) handshake(sock net.Conn, torrent *torrent, peerId [20]byte) {
        conn := &connection{
                Socket:     sock,
@@ -400,10 +413,12 @@ func (me *client) handshake(sock net.Conn, torrent *torrent, peerId [20]byte) {
        }
        me.withContext(func() {
                me.addConnection(torrent, conn)
-               conn.Post(peer_protocol.Message{
-                       Type:     peer_protocol.Bitfield,
-                       Bitfield: torrent.bitfield(),
-               })
+               if torrent.haveAnyPieces() {
+                       conn.Post(peer_protocol.Message{
+                               Type:     peer_protocol.Bitfield,
+                               Bitfield: torrent.bitfield(),
+                       })
+               }
                go func() {
                        defer me.withContext(func() {
                                me.dropConnection(torrent, conn)
@@ -417,29 +432,22 @@ func (me *client) handshake(sock net.Conn, torrent *torrent, peerId [20]byte) {
 }
 
 func (me *client) peerGotPiece(torrent *torrent, conn *connection, piece int) {
-       if torrent.Pieces[piece].State != pieceStateIncomplete {
-               return
+       if conn.PeerPieces == nil {
+               conn.PeerPieces = make([]bool, len(torrent.Pieces))
+       }
+       conn.PeerPieces[piece] = true
+       if torrent.wantPiece(piece) {
+               conn.SetInterested(true)
+               me.replenishConnRequests(torrent, conn)
        }
-       conn.SetInterested(true)
+}
+
+func (t *torrent) wantPiece(index int) bool {
+       return t.Pieces[index].State == pieceStateIncomplete
 }
 
 func (me *client) peerUnchoked(torrent *torrent, conn *connection) {
-       chunkHeatMap := torrent.chunkHeat()
-       for index, has := range conn.PeerPieces {
-               if !has {
-                       continue
-               }
-               for chunk, _ := range torrent.Pieces[index].PendingChunks {
-                       if _, ok := chunkHeatMap[chunk]; ok {
-                               continue
-                       }
-                       conn.SetInterested(true)
-                       if !conn.Request(chunk) {
-                               return
-                       }
-               }
-       }
-       conn.SetInterested(false)
+       me.replenishConnRequests(torrent, conn)
 }
 
 func (me *client) runConnection(torrent *torrent, conn *connection) error {
@@ -457,6 +465,7 @@ func (me *client) runConnection(torrent *torrent, conn *connection) error {
                        continue
                }
                go me.withContext(func() {
+                       log.Print(msg)
                        var err error
                        switch msg.Type {
                        case peer_protocol.Choke:
@@ -470,12 +479,10 @@ func (me *client) runConnection(torrent *torrent, conn *connection) error {
                                conn.PeerInterested = false
                        case peer_protocol.Have:
                                me.peerGotPiece(torrent, conn, int(msg.Index))
-                               conn.PeerPieces[msg.Index] = true
                        case peer_protocol.Request:
                                conn.PeerRequests[request{
-                                       Index:  msg.Index,
-                                       Begin:  msg.Begin,
-                                       Length: msg.Length,
+                                       Index:     msg.Index,
+                                       chunkSpec: chunkSpec{msg.Begin, msg.Length},
                                }] = struct{}{}
                        case peer_protocol.Bitfield:
                                if len(msg.Bitfield) < len(torrent.Pieces) {
@@ -589,6 +596,33 @@ func (me *client) withContext(f func()) {
        me.actorTask <- f
 }
 
+func (me *client) replenishConnRequests(torrent *torrent, conn *connection) {
+       if len(conn.Requests) >= maxRequests {
+               return
+       }
+       if conn.PeerChoked {
+               return
+       }
+       requestHeatMap := torrent.requestHeat()
+       for index, has := range conn.PeerPieces {
+               if !has {
+                       continue
+               }
+               for chunkSpec, _ := range torrent.Pieces[index].PendingChunkSpecs {
+                       request := request{peer_protocol.Integer(index), chunkSpec}
+                       if heat := requestHeatMap[request]; heat > 0 {
+                               continue
+                       }
+                       conn.SetInterested(true)
+                       if !conn.Request(request) {
+                               return
+                       }
+               }
+       }
+       //conn.SetInterested(false)
+
+}
+
 func (me *client) pieceHashed(ih InfoHash, piece int, correct bool) {
        torrent := me.torrents[ih]
        newState := func() pieceState {
@@ -604,7 +638,7 @@ func (me *client) pieceHashed(ih InfoHash, piece int, correct bool) {
        }
        torrent.Pieces[piece].State = newState
        if newState == pieceStateIncomplete {
-               torrent.Pieces[piece].PendingChunks = torrent.pieceChunks(piece)
+               torrent.Pieces[piece].PendingChunkSpecs = torrent.pieceChunkSpecs(piece)
        }
        for _, conn := range torrent.Conns {
                if correct {
@@ -614,7 +648,7 @@ func (me *client) pieceHashed(ih InfoHash, piece int, correct bool) {
                        })
                } else {
                        if conn.PeerHasPiece(piece) {
-                               conn.SetInterested(true)
+                               me.replenishConnRequests(torrent, conn)
                        }
                }
        }
index 74380e9d8012d3b85d6228bff2c4586a6560d862..76a9d943111f8040003a4fc65eb2c909175b97f0 100644 (file)
@@ -30,7 +30,7 @@ func main() {
                }
                err = client.AddPeers(torrent.BytesInfoHash(metaInfo.InfoHash), []torrent.Peer{{
                        IP:   net.IPv4(127, 0, 0, 1),
-                       Port: 53219,
+                       Port: 50933,
                }})
                if err != nil {
                        log.Fatal(err)
index dfef720e3a148ada07030e194b49e7c8f8d126dd..4020c5240d3f8066b1799f87993f2c973ea4db1f 100644 (file)
@@ -68,7 +68,11 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
        default:
                err = errors.New("unknown message type")
        }
-       data = buf.Bytes()
+       data = make([]byte, 4+buf.Len())
+       binary.BigEndian.PutUint32(data, uint32(buf.Len()))
+       if buf.Len() != copy(data[4:], buf.Bytes()) {
+               panic("bad copy")
+       }
        return
 }
 
@@ -113,16 +117,6 @@ func (d *Decoder) Decode(msg *Message) (err error) {
        return
 }
 
-func encodeMessage(type_ MessageType, data interface{}) []byte {
-       w := &bytes.Buffer{}
-       w.WriteByte(byte(type_))
-       err := binary.Write(w, binary.BigEndian, data)
-       if err != nil {
-               panic(err)
-       }
-       return w.Bytes()
-}
-
 type Bytes []byte
 
 func (b Bytes) MarshalBinary() ([]byte, error) {
index 7a7b8dd1ac1d109caafaa8811c6c5a86e8a68fb4..3c04aff85f23e9258a095cde464e4c15cda6bfc1 100644 (file)
@@ -10,22 +10,47 @@ func TestConstants(t *testing.T) {
                t.FailNow()
        }
 }
+
 func TestBitfieldEncode(t *testing.T) {
-       bm := make(Bitfield, 37)
-       bm[2] = true
-       bm[7] = true
-       bm[32] = true
-       s := string(bm.Encode())
+       bf := make([]bool, 37)
+       bf[2] = true
+       bf[7] = true
+       bf[32] = true
+       s := string(marshalBitfield(bf))
        const expected = "\x21\x00\x00\x00\x80"
        if s != expected {
                t.Fatalf("got %#v, expected %#v", s, expected)
        }
 }
 
+func TestBitfieldUnmarshal(t *testing.T) {
+       bf := unmarshalBitfield([]byte("\x81\x06"))
+       expected := make([]bool, 16)
+       expected[0] = true
+       expected[7] = true
+       expected[13] = true
+       expected[14] = true
+       if len(bf) != len(expected) {
+               t.FailNow()
+       }
+       for i := range expected {
+               if bf[i] != expected[i] {
+                       t.FailNow()
+               }
+       }
+}
+
 func TestHaveEncode(t *testing.T) {
-       actual := string(Have(42).Encode())
+       actualBytes, err := Message{
+               Type:  Have,
+               Index: 42,
+       }.MarshalBinary()
+       if err != nil {
+               t.Fatal(err)
+       }
+       actualString := string(actualBytes)
        expected := "\x04\x00\x00\x00\x2a"
-       if actual != expected {
-               t.Fatalf("expected %#v, got %#v", expected, actual)
+       if actualString != expected {
+               t.Fatalf("expected %#v, got %#v", expected, actualString)
        }
 }