]> Sergey Matveev's repositories - btrtrc.git/blobdiff - client.go
Implementing bitfields and connection message handling
[btrtrc.git] / client.go
index 284c2c1cb4e7891919da4c18cf5943f6595c28d9..55bf8d1a6e52ea816a917f261991f2dbbfe5adf4 100644 (file)
--- a/client.go
+++ b/client.go
@@ -2,9 +2,11 @@ package torrent
 
 import (
        "bitbucket.org/anacrolix/go.torrent/peer_protocol"
+       "bufio"
        "container/list"
        "crypto"
        "crypto/rand"
+       "encoding"
        "errors"
        metainfo "github.com/nsf/libtorgo/torrent"
        "io"
@@ -16,7 +18,9 @@ import (
 )
 
 const (
-       PieceHash = crypto.SHA1
+       PieceHash   = crypto.SHA1
+       maxRequests = 10
+       chunkSize   = 0x4000 // 16KiB
 )
 
 type InfoHash [20]byte
@@ -45,24 +49,76 @@ const (
 )
 
 type piece struct {
-       State pieceState
-       Hash  pieceSum
+       State         pieceState
+       Hash          pieceSum
+       PendingChunks map[chunk]struct{}
+}
+
+type chunk struct {
+       Begin, Length peer_protocol.Integer
+}
+
+type request struct {
+       Index, Begin, Length peer_protocol.Integer
 }
 
 type connection struct {
        Socket net.Conn
-       post   chan peer_protocol.Message
+       post   chan encoding.BinaryMarshaler
        write  chan []byte
 
        Interested bool
        Choked     bool
-       Requests   []peer_protocol.Request
+       Requests   map[request]struct{}
 
        PeerId         [20]byte
        PeerInterested bool
        PeerChoked     bool
-       PeerRequests   []peer_protocol.Request
+       PeerRequests   map[request]struct{}
        PeerExtensions [8]byte
+       PeerPieces     []bool
+}
+
+func (c *connection) PeerHasPiece(index int) bool {
+       if c.PeerPieces == nil {
+               return false
+       }
+       return c.PeerPieces[index]
+}
+
+func (c *connection) Post(msg encoding.BinaryMarshaler) {
+       c.post <- msg
+}
+
+func (c *connection) Request(chunk request) bool {
+       if len(c.Requests) >= maxRequests {
+               return false
+       }
+       if _, ok := c.Requests[chunk]; !ok {
+               c.Post(peer_protocol.Message{
+                       Type:   peer_protocol.Request,
+                       Index:  chunk.Index,
+                       Begin:  chunk.Begin,
+                       Length: chunk.Length,
+               })
+       }
+       c.Requests[chunk] = struct{}{}
+       return true
+}
+
+func (c *connection) SetInterested(interested bool) {
+       if c.Interested == interested {
+               return
+       }
+       c.Post(peer_protocol.Message{
+               Type: func() peer_protocol.MessageType {
+                       if interested {
+                               return peer_protocol.Interested
+                       }
+                       return peer_protocol.NotInterested
+               }(),
+       })
+       c.Interested = interested
 }
 
 func (conn *connection) writer() {
@@ -90,7 +146,11 @@ func (conn *connection) writeOptimizer() {
                        write = nil
                        nextWrite = nil
                } else {
-                       nextWrite = pending.Front().Value.(peer_protocol.Message).Encode()
+                       var err error
+                       nextWrite, err = pending.Front().Value.(encoding.BinaryMarshaler).MarshalBinary()
+                       if err != nil {
+                               panic(err)
+                       }
                }
                select {
                case msg := <-conn.post:
@@ -106,10 +166,43 @@ type torrent struct {
        Pieces   []piece
        Data     MMapSpan
        MetaInfo *metainfo.MetaInfo
-       Conns    []connection
+       Conns    []*connection
        Peers    []Peer
 }
 
+func (t *torrent) bitfield() (bf []bool) {
+       for _, p := range t.Pieces {
+               bf = append(bf, p.State == pieceStateComplete)
+       }
+       return
+}
+
+func (t *torrent) pieceChunks(index int) (cs map[chunk]struct{}) {
+       cs = make(map[chunk]struct{}, (t.MetaInfo.PieceLength+chunkSize-1)/chunkSize)
+       c := chunk{
+               Begin: 0,
+       }
+       for left := peer_protocol.Integer(t.PieceSize(index)); left > 0; left -= c.Length {
+               c.Length = left
+               if c.Length > chunkSize {
+                       c.Length = chunkSize
+               }
+               cs[c] = struct{}{}
+               c.Begin += c.Length
+       }
+       return
+}
+
+func (t *torrent) chunkHeat() (ret map[request]int) {
+       ret = make(map[request]int)
+       for _, conn := range t.Conns {
+               for req, _ := range conn.Requests {
+                       ret[req]++
+               }
+       }
+       return
+}
+
 type Peer struct {
        Id   [20]byte
        IP   net.IP
@@ -143,6 +236,8 @@ func (t *torrent) HashPiece(piece int) (ps pieceSum) {
        return
 }
 
+// func (t *torrent) bitfield
+
 type client struct {
        DataDir       string
        HalfOpenLimit int
@@ -252,7 +347,7 @@ func (me *client) initiateConn(peer Peer, torrent *torrent) {
                        log.Printf("error connecting to peer: %s", err)
                        return
                }
-               log.Printf("connected to %s", sock.RemoteAddr())
+               log.Printf("connected to %s", conn.RemoteAddr())
                me.handshake(conn, torrent, peer.Id)
        }()
 }
@@ -263,7 +358,7 @@ func (me *client) handshake(sock net.Conn, torrent *torrent, peerId [20]byte) {
                Choked:     true,
                PeerChoked: true,
                write:      make(chan []byte),
-               post:       make(chan peer_protocol.Message),
+               post:       make(chan encoding.BinaryMarshaler),
        }
        go conn.writer()
        go conn.writeOptimizer()
@@ -303,6 +398,135 @@ func (me *client) handshake(sock net.Conn, torrent *torrent, peerId [20]byte) {
                conn.post <- peer_protocol.Bytes(torrent.InfoHash[:])
                conn.post <- peer_protocol.Bytes(me.PeerId[:])
        }
+       me.withContext(func() {
+               me.addConnection(torrent, conn)
+               conn.Post(peer_protocol.Message{
+                       Type:     peer_protocol.Bitfield,
+                       Bitfield: torrent.bitfield(),
+               })
+               go func() {
+                       defer me.withContext(func() {
+                               me.dropConnection(torrent, conn)
+                       })
+                       err := me.runConnection(torrent, conn)
+                       if err != nil {
+                               log.Print(err)
+                       }
+               }()
+       })
+}
+
+func (me *client) peerGotPiece(torrent *torrent, conn *connection, piece int) {
+       if torrent.Pieces[piece].State != pieceStateIncomplete {
+               return
+       }
+       conn.SetInterested(true)
+}
+
+func (me *client) peerUnchoked(torrent *torrent, conn *connection) {
+       chunkHeatMap := torrent.chunkHeat()
+       for index, has := range conn.PeerPieces {
+               if !has {
+                       continue
+               }
+               for chunk, _ := range torrent.Pieces[index].PendingChunks {
+                       if _, ok := chunkHeatMap[chunk]; ok {
+                               continue
+                       }
+                       conn.SetInterested(true)
+                       if !conn.Request(chunk) {
+                               return
+                       }
+               }
+       }
+       conn.SetInterested(false)
+}
+
+func (me *client) runConnection(torrent *torrent, conn *connection) error {
+       decoder := peer_protocol.Decoder{
+               R:         bufio.NewReader(conn.Socket),
+               MaxLength: 256 * 1024,
+       }
+       for {
+               msg := new(peer_protocol.Message)
+               err := decoder.Decode(msg)
+               if err != nil {
+                       return err
+               }
+               if msg.Keepalive {
+                       continue
+               }
+               go me.withContext(func() {
+                       var err error
+                       switch msg.Type {
+                       case peer_protocol.Choke:
+                               conn.PeerChoked = true
+                       case peer_protocol.Unchoke:
+                               conn.PeerChoked = false
+                               me.peerUnchoked(torrent, conn)
+                       case peer_protocol.Interested:
+                               conn.PeerInterested = true
+                       case peer_protocol.NotInterested:
+                               conn.PeerInterested = false
+                       case peer_protocol.Have:
+                               me.peerGotPiece(torrent, conn, int(msg.Index))
+                               conn.PeerPieces[msg.Index] = true
+                       case peer_protocol.Request:
+                               conn.PeerRequests[request{
+                                       Index:  msg.Index,
+                                       Begin:  msg.Begin,
+                                       Length: msg.Length,
+                               }] = struct{}{}
+                       case peer_protocol.Bitfield:
+                               if len(msg.Bitfield) < len(torrent.Pieces) {
+                                       err = errors.New("received invalid bitfield")
+                                       break
+                               }
+                               if conn.PeerPieces != nil {
+                                       err = errors.New("received unexpected bitfield")
+                                       break
+                               }
+                               conn.PeerPieces = msg.Bitfield[:len(torrent.Pieces)]
+                               for index, has := range conn.PeerPieces {
+                                       if has {
+                                               me.peerGotPiece(torrent, conn, index)
+                                       }
+                               }
+                       default:
+                               log.Printf("received unknown message type: %#v", msg.Type)
+                       }
+                       if err != nil {
+                               log.Print(err)
+                               me.dropConnection(torrent, conn)
+                       }
+               })
+       }
+}
+
+func (me *client) dropConnection(torrent *torrent, conn *connection) {
+       conn.Socket.Close()
+       for i0, c := range torrent.Conns {
+               if c != conn {
+                       continue
+               }
+               i1 := len(torrent.Conns) - 1
+               if i0 != i1 {
+                       torrent.Conns[i0] = torrent.Conns[i1]
+               }
+               torrent.Conns = torrent.Conns[:i1]
+               return
+       }
+       panic("no such connection")
+}
+
+func (me *client) addConnection(t *torrent, c *connection) bool {
+       for _, c := range t.Conns {
+               if c.PeerId == c.PeerId {
+                       return false
+               }
+       }
+       t.Conns = append(t.Conns, c)
+       return true
 }
 
 func (me *client) openNewConns() {
@@ -367,21 +591,34 @@ func (me *client) withContext(f func()) {
 
 func (me *client) pieceHashed(ih InfoHash, piece int, correct bool) {
        torrent := me.torrents[ih]
-       torrent.Pieces[piece].State = func() pieceState {
+       newState := func() pieceState {
                if correct {
                        return pieceStateComplete
                } else {
                        return pieceStateIncomplete
                }
        }()
-       for _, piece := range torrent.Pieces {
-               if piece.State == pieceStateUnknown {
-                       return
+       oldState := torrent.Pieces[piece].State
+       if newState == oldState {
+               return
+       }
+       torrent.Pieces[piece].State = newState
+       if newState == pieceStateIncomplete {
+               torrent.Pieces[piece].PendingChunks = torrent.pieceChunks(piece)
+       }
+       for _, conn := range torrent.Conns {
+               if correct {
+                       conn.Post(peer_protocol.Message{
+                               Type:  peer_protocol.Have,
+                               Index: peer_protocol.Integer(piece),
+                       })
+               } else {
+                       if conn.PeerHasPiece(piece) {
+                               conn.SetInterested(true)
+                       }
                }
        }
-       go func() {
-               me.torrentFinished <- ih
-       }()
+
 }
 
 func (me *client) run() {