]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Improve the internal connection and handshake logic
authorMatt Joiner <anacrolix@gmail.com>
Thu, 21 Aug 2014 08:12:49 +0000 (18:12 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Thu, 21 Aug 2014 08:12:49 +0000 (18:12 +1000)
client.go
connection.go

index 8e44663a12f08e0a68cbc6aba46bf5a32f0f4659..68d89965a34558a52d4edea30ec605b4595292d9 100644 (file)
--- a/client.go
+++ b/client.go
@@ -41,6 +41,8 @@ import (
        _ "bitbucket.org/anacrolix/go.torrent/tracker/udp"
 )
 
+const extensionBytes = "\x00\x00\x00\x00\x00\x10\x00\x00"
+
 // Currently doesn't really queue, but should in the future.
 func (cl *Client) queuePieceCheck(t *torrent, pieceIndex pp.Integer) {
        piece := t.Pieces[pieceIndex]
@@ -284,7 +286,10 @@ func (me *Client) initiateConn(peer Peer, torrent *torrent) {
                        IP:   peer.IP,
                        Port: peer.Port,
                }
-               // Binding to the listener address and dialing via net.Dialer gives "address in use" error. It seems it's not possible to dial out from this address so that peers associate our local address with our listen address.
+               // Binding to the listener address and dialing via net.Dialer gives
+               // "address in use" error. It seems it's not possible to dial out from
+               // this address so that peers associate our local address with our
+               // listen address.
                conn, err := net.DialTimeout(addr.Network(), addr.String(), dialTimeout)
 
                // Whether or not the connection attempt succeeds, the half open
@@ -352,72 +357,115 @@ func addrCompactIP(addr net.Addr) (string, error) {
        }
 }
 
-func (me *Client) runConnection(sock net.Conn, torrent *torrent, discovery peerSource) (err error) {
-       conn := &connection{
-               Discovery:       discovery,
-               Socket:          sock,
-               Choked:          true,
-               PeerChoked:      true,
-               writeCh:         make(chan []byte),
-               PeerMaxRequests: 250, // Default in libtorrent is 250.
-       }
-       go conn.writer()
+func handshakeWriter(w io.WriteCloser, bb <-chan []byte, done chan<- error) {
+       var err error
+       for b := range bb {
+               _, err = w.Write(b)
+               if err != nil {
+                       w.Close()
+                       break
+               }
+       }
+       done <- err
+}
+
+type peerExtensionBytes [8]byte
+type peerID [20]byte
+
+type handshakeResult struct {
+       peerExtensionBytes
+       peerID
+       InfoHash
+}
+
+func handshake(sock io.ReadWriteCloser, ih *InfoHash, peerID [20]byte) (res handshakeResult, ok bool, err error) {
+       // Bytes to be sent to the peer. Should never block the sender.
+       postCh := make(chan []byte, 4)
+       // A single error value sent when the writer completes.
+       writeDone := make(chan error, 1)
+       // Performs writes to the socket and ensures posts don't block.
+       go handshakeWriter(sock, postCh, writeDone)
+
        defer func() {
-               // There's a lock and deferred unlock later in this function. The
-               // client will not be locked when this deferred is invoked.
-               me.mu.Lock()
-               defer me.mu.Unlock()
-               conn.Close()
+               close(postCh) // Done writing.
+               if !ok {
+                       return
+               }
+               if err != nil {
+                       panic(err)
+               }
+               // Wait until writes complete before returning from handshake.
+               err = <-writeDone
+               if err != nil {
+                       err = fmt.Errorf("error writing during handshake: %s", err)
+               }
        }()
-       // go conn.writeOptimizer()
-       conn.write(pp.Bytes(pp.Protocol))
-       conn.write(pp.Bytes("\x00\x00\x00\x00\x00\x10\x00\x00"))
-       if torrent != nil {
-               conn.write(pp.Bytes(torrent.InfoHash[:]))
-               conn.write(pp.Bytes(me.PeerId[:]))
-       }
-       var b [28]byte
-       _, err = io.ReadFull(conn.Socket, b[:])
-       if err == io.EOF {
-               return nil
+
+       post := func(bb []byte) {
+               select {
+               case postCh <- bb:
+               default:
+                       panic("mustn't block while posting")
+               }
        }
+
+       post([]byte(pp.Protocol))
+       post([]byte(extensionBytes))
+       if ih != nil { // We already know what we want.
+               post(ih[:])
+               post(peerID[:])
+       }
+       var b [68]byte
+       _, err = io.ReadFull(sock, b[:68])
        if err != nil {
-               err = fmt.Errorf("when reading protocol and extensions: %s", err)
+               err = nil
                return
        }
        if string(b[:20]) != pp.Protocol {
-               // err = fmt.Errorf("wrong protocol: %#v", string(b[:20]))
                return
        }
-       if 8 != copy(conn.PeerExtensions[:], b[20:]) {
-               panic("wtf")
-       }
-       // log.Printf("peer extensions: %#v", string(conn.PeerExtensions[:]))
-       var infoHash [20]byte
-       _, err = io.ReadFull(conn.Socket, infoHash[:])
-       if err != nil {
-               return fmt.Errorf("reading peer info hash: %s", err)
-       }
-       _, err = io.ReadFull(conn.Socket, conn.PeerId[:])
-       if err != nil {
-               return fmt.Errorf("reading peer id: %s", err)
+       copy(res.peerExtensionBytes[:], b[20:28])
+       copy(res.InfoHash[:], b[28:48])
+       copy(res.peerID[:], b[48:68])
+
+       if ih == nil { // We were waiting for the peer to tell us what they wanted.
+               post(res.InfoHash[:])
+               post(peerID[:])
        }
-       if torrent == nil {
-               torrent = me.torrent(infoHash)
+
+       ok = true
+       return
+}
+
+func (me *Client) runConnection(sock net.Conn, torrent *torrent, discovery peerSource) (err error) {
+       defer sock.Close()
+       hsRes, ok, err := handshake(sock, func() *InfoHash {
                if torrent == nil {
-                       return
+                       return nil
+               } else {
+                       return &torrent.InfoHash
                }
-               conn.write(pp.Bytes(torrent.InfoHash[:]))
-               conn.write(pp.Bytes(me.PeerId[:]))
+       }(), me.peerID)
+       if err != nil {
+               err = fmt.Errorf("error during handshake: %s", err)
+               return
+       }
+       if !ok {
+               return
        }
        me.mu.Lock()
        defer me.mu.Unlock()
+       torrent = me.torrent(hsRes.InfoHash)
+       if torrent == nil {
+               return
+       }
+       conn := newConnection(sock, hsRes.peerExtensionBytes, hsRes.peerID)
+       defer conn.Close()
+       conn.Discovery = discovery
        if !me.addConnection(torrent, conn) {
                return
        }
-       conn.post = make(chan pp.Message)
-       go conn.writeOptimizer(time.Minute)
-       if conn.PeerExtensions[5]&0x10 != 0 {
+       if conn.PeerExtensionBytes[5]&0x10 != 0 {
                conn.Post(pp.Message{
                        Type:       pp.Extended,
                        ExtendedID: pp.HandshakeExtendedID,
@@ -459,7 +507,7 @@ func (me *Client) runConnection(sock net.Conn, torrent *torrent, discovery peerS
        }
        err = me.connectionLoop(torrent, conn)
        if err != nil {
-               err = fmt.Errorf("during Connection loop: %s", err)
+               err = fmt.Errorf("during Connection loop with peer %q: %s", conn.PeerID, err)
        }
        me.dropConnection(torrent, conn)
        return
@@ -527,7 +575,7 @@ func (cl *Client) completedMetadata(t *torrent) {
        h := sha1.New()
        h.Write(t.MetaData)
        var ih InfoHash
-       copy(ih[:], h.Sum(nil)[:])
+       CopyExact(&ih, h.Sum(nil))
        if ih != t.InfoHash {
                log.Print("bad metadata")
                t.InvalidateMetadata()
@@ -551,6 +599,7 @@ func (cl *Client) completedMetadata(t *torrent) {
        log.Printf("%s: got metadata from peers", t)
 }
 
+// Process incoming ut_metadata message.
 func (cl *Client) gotMetadataExtensionMsg(payload []byte, t *torrent, c *connection) (err error) {
        var d map[string]int
        err = bencode.Unmarshal(payload, &d)
@@ -579,7 +628,8 @@ func (cl *Client) gotMetadataExtensionMsg(payload []byte, t *torrent, c *connect
                        c.Post(t.NewMetadataExtensionMessage(c, pp.RejectMetadataExtensionMsgType, d["piece"], nil))
                        break
                }
-               c.Post(t.NewMetadataExtensionMessage(c, pp.DataMetadataExtensionMsgType, piece, t.MetaData[(1<<14)*piece:(1<<14)*piece+t.metadataPieceSize(piece)]))
+               start := (1 << 14) * piece
+               c.Post(t.NewMetadataExtensionMessage(c, pp.DataMetadataExtensionMsgType, piece, t.MetaData[start:start+t.metadataPieceSize(piece)]))
        case pp.RejectMetadataExtensionMsgType:
        default:
                err = errors.New("unknown msg_type value")
@@ -588,9 +638,9 @@ func (cl *Client) gotMetadataExtensionMsg(payload []byte, t *torrent, c *connect
 }
 
 type peerExchangeMessage struct {
-       Added      util.CompactPeers `bencode:"added"`
-       AddedFlags []byte            `bencode:"added.f"`
-       Dropped    []tracker.Peer    `bencode:"dropped"`
+       Added      CompactPeers   `bencode:"added"`
+       AddedFlags []byte         `bencode:"added.f"`
+       Dropped    []tracker.Peer `bencode:"dropped"`
 }
 
 // Processes incoming bittorrent messages. The client lock is held upon entry
@@ -605,7 +655,7 @@ func (me *Client) connectionLoop(t *torrent, c *connection) error {
                var msg pp.Message
                err := decoder.Decode(&msg)
                me.mu.Lock()
-               if c.closed {
+               if c.getClosed() {
                        return nil
                }
                if err != nil {
@@ -773,6 +823,9 @@ func (me *Client) connectionLoop(t *torrent, c *connection) error {
                        default:
                                err = fmt.Errorf("unexpected extended message ID: %v", msg.ExtendedID)
                        }
+                       if err != nil {
+                               log.Printf("peer extension map: %#v", c.PeerExtensionIDs)
+                       }
                default:
                        err = fmt.Errorf("received unknown message type: %#v", msg.Type)
                }
@@ -806,7 +859,7 @@ func (me *Client) addConnection(t *torrent, c *connection) bool {
                return false
        }
        for _, c0 := range t.Conns {
-               if c.PeerId == c0.PeerId {
+               if c.PeerID == c0.PeerID {
                        // Already connected to a client with that ID.
                        return false
                }
index f966774fbd60d1bbe9cc6d044d6a728ab6a3e12b..4a7104d958740ba83c7e5e0c3d27a50a48bc8711 100644 (file)
@@ -11,7 +11,7 @@ import (
        "sync"
        "time"
 
-       "bitbucket.org/anacrolix/go.torrent/peer_protocol"
+       pp "bitbucket.org/anacrolix/go.torrent/peer_protocol"
 )
 
 type peerSource byte
@@ -26,9 +26,9 @@ const (
 type connection struct {
        Socket    net.Conn
        Discovery peerSource
-       closed    bool
+       closed    chan struct{}
        mu        sync.Mutex // Only for closing.
-       post      chan peer_protocol.Message
+       post      chan pp.Message
        writeCh   chan []byte
 
        // Stuff controlled by the local peer.
@@ -37,11 +37,11 @@ type connection struct {
        Requests   map[request]struct{}
 
        // Stuff controlled by the remote peer.
-       PeerId         [20]byte
-       PeerInterested bool
-       PeerChoked     bool
-       PeerRequests   map[request]struct{}
-       PeerExtensions [8]byte
+       PeerID             [20]byte
+       PeerInterested     bool
+       PeerChoked         bool
+       PeerRequests       map[request]struct{}
+       PeerExtensionBytes peerExtensionBytes
        // Whether the peer has the given piece. nil if they've not sent any
        // related messages yet.
        PeerPieces       []bool
@@ -50,10 +50,22 @@ type connection struct {
        PeerClientName   string
 }
 
-func (cn *connection) write(b []byte) {
-       cn.mu.Lock()
-       cn.writeCh <- b
-       cn.mu.Unlock()
+func newConnection(sock net.Conn, peb peerExtensionBytes, peerID [20]byte) (c *connection) {
+       c = &connection{
+               Socket:             sock,
+               Choked:             true,
+               PeerChoked:         true,
+               PeerMaxRequests:    250,
+               PeerExtensionBytes: peb,
+               PeerID:             peerID,
+
+               closed:  make(chan struct{}),
+               writeCh: make(chan []byte),
+               post:    make(chan pp.Message),
+       }
+       go c.writer()
+       go c.writeOptimizer(time.Minute)
+       return
 }
 
 func (cn *connection) completedString() string {
@@ -105,7 +117,7 @@ func (cn *connection) setNumPieces(num int) error {
 }
 
 func (cn *connection) WriteStatus(w io.Writer) {
-       fmt.Fprintf(w, "%q: %s-%s: %s completed, reqs: %d-%d, flags: ", cn.PeerId, cn.Socket.LocalAddr(), cn.Socket.RemoteAddr(), cn.completedString(), len(cn.Requests), len(cn.PeerRequests))
+       fmt.Fprintf(w, "%q: %s-%s: %s completed, reqs: %d-%d, flags: ", cn.PeerID, cn.Socket.LocalAddr(), cn.Socket.RemoteAddr(), cn.completedString(), len(cn.Requests), len(cn.PeerRequests))
        c := func(b byte) {
                fmt.Fprintf(w, "%c", b)
        }
@@ -139,28 +151,23 @@ func (cn *connection) WriteStatus(w io.Writer) {
 func (c *connection) Close() {
        c.mu.Lock()
        defer c.mu.Unlock()
-       if c.closed {
+       if c.getClosed() {
                return
        }
+       close(c.closed)
        c.Socket.Close()
-       if c.post == nil {
-               // writeOptimizer isn't running, so we need to signal the writer to
-               // stop.
-               close(c.writeCh)
-       } else {
-               // This will kill the writeOptimizer, and it kills the writer.
-               close(c.post)
-       }
-       c.closed = true
 }
 
 func (c *connection) getClosed() bool {
-       c.mu.Lock()
-       defer c.mu.Unlock()
-       return c.closed
+       select {
+       case <-c.closed:
+               return true
+       default:
+               return false
+       }
 }
 
-func (c *connection) PeerHasPiece(index peer_protocol.Integer) bool {
+func (c *connection) PeerHasPiece(index pp.Integer) bool {
        if c.PeerPieces == nil {
                return false
        }
@@ -170,8 +177,11 @@ func (c *connection) PeerHasPiece(index peer_protocol.Integer) bool {
        return c.PeerPieces[index]
 }
 
-func (c *connection) Post(msg peer_protocol.Message) {
-       c.post <- msg
+func (c *connection) Post(msg pp.Message) {
+       select {
+       case c.post <- msg:
+       case <-c.closed:
+       }
 }
 
 func (c *connection) RequestPending(r request) bool {
@@ -198,8 +208,8 @@ func (c *connection) Request(chunk request) bool {
                c.Requests = make(map[request]struct{}, c.PeerMaxRequests)
        }
        c.Requests[chunk] = struct{}{}
-       c.Post(peer_protocol.Message{
-               Type:   peer_protocol.Request,
+       c.Post(pp.Message{
+               Type:   pp.Request,
                Index:  chunk.Index,
                Begin:  chunk.Begin,
                Length: chunk.Length,
@@ -216,8 +226,8 @@ func (c *connection) Cancel(r request) bool {
                return false
        }
        delete(c.Requests, r)
-       c.Post(peer_protocol.Message{
-               Type:   peer_protocol.Cancel,
+       c.Post(pp.Message{
+               Type:   pp.Cancel,
                Index:  r.Index,
                Begin:  r.Begin,
                Length: r.Length,
@@ -241,8 +251,8 @@ func (c *connection) Choke() {
        if c.Choked {
                return
        }
-       c.Post(peer_protocol.Message{
-               Type: peer_protocol.Choke,
+       c.Post(pp.Message{
+               Type: pp.Choke,
        })
        c.Choked = true
 }
@@ -251,8 +261,8 @@ func (c *connection) Unchoke() {
        if !c.Choked {
                return
        }
-       c.Post(peer_protocol.Message{
-               Type: peer_protocol.Unchoke,
+       c.Post(pp.Message{
+               Type: pp.Unchoke,
        })
        c.Choked = false
 }
@@ -261,12 +271,12 @@ func (c *connection) SetInterested(interested bool) {
        if c.Interested == interested {
                return
        }
-       c.Post(peer_protocol.Message{
-               Type: func() peer_protocol.MessageType {
+       c.Post(pp.Message{
+               Type: func() pp.MessageType {
                        if interested {
-                               return peer_protocol.Interested
+                               return pp.Interested
                        } else {
-                               return peer_protocol.NotInterested
+                               return pp.NotInterested
                        }
                }(),
        })
@@ -280,17 +290,21 @@ var (
 
 // Writes buffers to the socket from the write channel.
 func (conn *connection) writer() {
-       for b := range conn.writeCh {
-               _, err := conn.Socket.Write(b)
-               // log.Printf("wrote %q to %s", b, conn.Socket.RemoteAddr())
-               if err != nil {
-                       if !conn.getClosed() {
-                               log.Print(err)
+       for {
+               select {
+               case b, ok := <-conn.writeCh:
+                       if !ok {
+                               return
+                       }
+                       _, err := conn.Socket.Write(b)
+                       if err != nil {
+                               conn.Close()
+                               return
                        }
-                       break
+               case <-conn.closed:
+                       return
                }
        }
-       conn.Close()
 }
 
 func (conn *connection) writeOptimizer(keepAliveDelay time.Duration) {
@@ -322,15 +336,15 @@ func (conn *connection) writeOptimizer(keepAliveDelay time.Duration) {
                                timer.Reset(keepAliveTime.Sub(time.Now()))
                                break
                        }
-                       pending.PushBack(peer_protocol.Message{Keepalive: true})
+                       pending.PushBack(pp.Message{Keepalive: true})
                case msg, ok := <-conn.post:
                        if !ok {
                                return
                        }
-                       if msg.Type == peer_protocol.Cancel {
+                       if msg.Type == pp.Cancel {
                                for e := pending.Back(); e != nil; e = e.Prev() {
-                                       elemMsg := e.Value.(peer_protocol.Message)
-                                       if elemMsg.Type == peer_protocol.Request && msg.Index == elemMsg.Index && msg.Begin == elemMsg.Begin && msg.Length == elemMsg.Length {
+                                       elemMsg := e.Value.(pp.Message)
+                                       if elemMsg.Type == pp.Request && msg.Index == elemMsg.Index && msg.Begin == elemMsg.Begin && msg.Length == elemMsg.Length {
                                                pending.Remove(e)
                                                log.Printf("optimized cancel! %v", msg)
                                                break event
@@ -345,6 +359,8 @@ func (conn *connection) writeOptimizer(keepAliveDelay time.Duration) {
                        if pending.Len() == 0 {
                                timer.Reset(keepAliveDelay)
                        }
+               case <-conn.closed:
+                       return
                }
        }
 }