From: Matt Joiner Date: Thu, 21 Aug 2014 08:12:49 +0000 (+1000) Subject: Improve the internal connection and handshake logic X-Git-Tag: v1.0.0~1635 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=52fc7c72050f20a7ba7d66fef3abb6799f64ea90;p=btrtrc.git Improve the internal connection and handshake logic --- diff --git a/client.go b/client.go index 8e44663a..68d89965 100644 --- a/client.go +++ b/client.go @@ -41,6 +41,8 @@ import ( _ "bitbucket.org/anacrolix/go.torrent/tracker/udp" ) +const extensionBytes = "\x00\x00\x00\x00\x00\x10\x00\x00" + // Currently doesn't really queue, but should in the future. func (cl *Client) queuePieceCheck(t *torrent, pieceIndex pp.Integer) { piece := t.Pieces[pieceIndex] @@ -284,7 +286,10 @@ func (me *Client) initiateConn(peer Peer, torrent *torrent) { IP: peer.IP, Port: peer.Port, } - // Binding to the listener address and dialing via net.Dialer gives "address in use" error. It seems it's not possible to dial out from this address so that peers associate our local address with our listen address. + // Binding to the listener address and dialing via net.Dialer gives + // "address in use" error. It seems it's not possible to dial out from + // this address so that peers associate our local address with our + // listen address. conn, err := net.DialTimeout(addr.Network(), addr.String(), dialTimeout) // Whether or not the connection attempt succeeds, the half open @@ -352,72 +357,115 @@ func addrCompactIP(addr net.Addr) (string, error) { } } -func (me *Client) runConnection(sock net.Conn, torrent *torrent, discovery peerSource) (err error) { - conn := &connection{ - Discovery: discovery, - Socket: sock, - Choked: true, - PeerChoked: true, - writeCh: make(chan []byte), - PeerMaxRequests: 250, // Default in libtorrent is 250. - } - go conn.writer() +func handshakeWriter(w io.WriteCloser, bb <-chan []byte, done chan<- error) { + var err error + for b := range bb { + _, err = w.Write(b) + if err != nil { + w.Close() + break + } + } + done <- err +} + +type peerExtensionBytes [8]byte +type peerID [20]byte + +type handshakeResult struct { + peerExtensionBytes + peerID + InfoHash +} + +func handshake(sock io.ReadWriteCloser, ih *InfoHash, peerID [20]byte) (res handshakeResult, ok bool, err error) { + // Bytes to be sent to the peer. Should never block the sender. + postCh := make(chan []byte, 4) + // A single error value sent when the writer completes. + writeDone := make(chan error, 1) + // Performs writes to the socket and ensures posts don't block. + go handshakeWriter(sock, postCh, writeDone) + defer func() { - // There's a lock and deferred unlock later in this function. The - // client will not be locked when this deferred is invoked. - me.mu.Lock() - defer me.mu.Unlock() - conn.Close() + close(postCh) // Done writing. + if !ok { + return + } + if err != nil { + panic(err) + } + // Wait until writes complete before returning from handshake. + err = <-writeDone + if err != nil { + err = fmt.Errorf("error writing during handshake: %s", err) + } }() - // go conn.writeOptimizer() - conn.write(pp.Bytes(pp.Protocol)) - conn.write(pp.Bytes("\x00\x00\x00\x00\x00\x10\x00\x00")) - if torrent != nil { - conn.write(pp.Bytes(torrent.InfoHash[:])) - conn.write(pp.Bytes(me.PeerId[:])) - } - var b [28]byte - _, err = io.ReadFull(conn.Socket, b[:]) - if err == io.EOF { - return nil + + post := func(bb []byte) { + select { + case postCh <- bb: + default: + panic("mustn't block while posting") + } } + + post([]byte(pp.Protocol)) + post([]byte(extensionBytes)) + if ih != nil { // We already know what we want. + post(ih[:]) + post(peerID[:]) + } + var b [68]byte + _, err = io.ReadFull(sock, b[:68]) if err != nil { - err = fmt.Errorf("when reading protocol and extensions: %s", err) + err = nil return } if string(b[:20]) != pp.Protocol { - // err = fmt.Errorf("wrong protocol: %#v", string(b[:20])) return } - if 8 != copy(conn.PeerExtensions[:], b[20:]) { - panic("wtf") - } - // log.Printf("peer extensions: %#v", string(conn.PeerExtensions[:])) - var infoHash [20]byte - _, err = io.ReadFull(conn.Socket, infoHash[:]) - if err != nil { - return fmt.Errorf("reading peer info hash: %s", err) - } - _, err = io.ReadFull(conn.Socket, conn.PeerId[:]) - if err != nil { - return fmt.Errorf("reading peer id: %s", err) + copy(res.peerExtensionBytes[:], b[20:28]) + copy(res.InfoHash[:], b[28:48]) + copy(res.peerID[:], b[48:68]) + + if ih == nil { // We were waiting for the peer to tell us what they wanted. + post(res.InfoHash[:]) + post(peerID[:]) } - if torrent == nil { - torrent = me.torrent(infoHash) + + ok = true + return +} + +func (me *Client) runConnection(sock net.Conn, torrent *torrent, discovery peerSource) (err error) { + defer sock.Close() + hsRes, ok, err := handshake(sock, func() *InfoHash { if torrent == nil { - return + return nil + } else { + return &torrent.InfoHash } - conn.write(pp.Bytes(torrent.InfoHash[:])) - conn.write(pp.Bytes(me.PeerId[:])) + }(), me.peerID) + if err != nil { + err = fmt.Errorf("error during handshake: %s", err) + return + } + if !ok { + return } me.mu.Lock() defer me.mu.Unlock() + torrent = me.torrent(hsRes.InfoHash) + if torrent == nil { + return + } + conn := newConnection(sock, hsRes.peerExtensionBytes, hsRes.peerID) + defer conn.Close() + conn.Discovery = discovery if !me.addConnection(torrent, conn) { return } - conn.post = make(chan pp.Message) - go conn.writeOptimizer(time.Minute) - if conn.PeerExtensions[5]&0x10 != 0 { + if conn.PeerExtensionBytes[5]&0x10 != 0 { conn.Post(pp.Message{ Type: pp.Extended, ExtendedID: pp.HandshakeExtendedID, @@ -459,7 +507,7 @@ func (me *Client) runConnection(sock net.Conn, torrent *torrent, discovery peerS } err = me.connectionLoop(torrent, conn) if err != nil { - err = fmt.Errorf("during Connection loop: %s", err) + err = fmt.Errorf("during Connection loop with peer %q: %s", conn.PeerID, err) } me.dropConnection(torrent, conn) return @@ -527,7 +575,7 @@ func (cl *Client) completedMetadata(t *torrent) { h := sha1.New() h.Write(t.MetaData) var ih InfoHash - copy(ih[:], h.Sum(nil)[:]) + CopyExact(&ih, h.Sum(nil)) if ih != t.InfoHash { log.Print("bad metadata") t.InvalidateMetadata() @@ -551,6 +599,7 @@ func (cl *Client) completedMetadata(t *torrent) { log.Printf("%s: got metadata from peers", t) } +// Process incoming ut_metadata message. func (cl *Client) gotMetadataExtensionMsg(payload []byte, t *torrent, c *connection) (err error) { var d map[string]int err = bencode.Unmarshal(payload, &d) @@ -579,7 +628,8 @@ func (cl *Client) gotMetadataExtensionMsg(payload []byte, t *torrent, c *connect c.Post(t.NewMetadataExtensionMessage(c, pp.RejectMetadataExtensionMsgType, d["piece"], nil)) break } - c.Post(t.NewMetadataExtensionMessage(c, pp.DataMetadataExtensionMsgType, piece, t.MetaData[(1<<14)*piece:(1<<14)*piece+t.metadataPieceSize(piece)])) + start := (1 << 14) * piece + c.Post(t.NewMetadataExtensionMessage(c, pp.DataMetadataExtensionMsgType, piece, t.MetaData[start:start+t.metadataPieceSize(piece)])) case pp.RejectMetadataExtensionMsgType: default: err = errors.New("unknown msg_type value") @@ -588,9 +638,9 @@ func (cl *Client) gotMetadataExtensionMsg(payload []byte, t *torrent, c *connect } type peerExchangeMessage struct { - Added util.CompactPeers `bencode:"added"` - AddedFlags []byte `bencode:"added.f"` - Dropped []tracker.Peer `bencode:"dropped"` + Added CompactPeers `bencode:"added"` + AddedFlags []byte `bencode:"added.f"` + Dropped []tracker.Peer `bencode:"dropped"` } // Processes incoming bittorrent messages. The client lock is held upon entry @@ -605,7 +655,7 @@ func (me *Client) connectionLoop(t *torrent, c *connection) error { var msg pp.Message err := decoder.Decode(&msg) me.mu.Lock() - if c.closed { + if c.getClosed() { return nil } if err != nil { @@ -773,6 +823,9 @@ func (me *Client) connectionLoop(t *torrent, c *connection) error { default: err = fmt.Errorf("unexpected extended message ID: %v", msg.ExtendedID) } + if err != nil { + log.Printf("peer extension map: %#v", c.PeerExtensionIDs) + } default: err = fmt.Errorf("received unknown message type: %#v", msg.Type) } @@ -806,7 +859,7 @@ func (me *Client) addConnection(t *torrent, c *connection) bool { return false } for _, c0 := range t.Conns { - if c.PeerId == c0.PeerId { + if c.PeerID == c0.PeerID { // Already connected to a client with that ID. return false } diff --git a/connection.go b/connection.go index f966774f..4a7104d9 100644 --- a/connection.go +++ b/connection.go @@ -11,7 +11,7 @@ import ( "sync" "time" - "bitbucket.org/anacrolix/go.torrent/peer_protocol" + pp "bitbucket.org/anacrolix/go.torrent/peer_protocol" ) type peerSource byte @@ -26,9 +26,9 @@ const ( type connection struct { Socket net.Conn Discovery peerSource - closed bool + closed chan struct{} mu sync.Mutex // Only for closing. - post chan peer_protocol.Message + post chan pp.Message writeCh chan []byte // Stuff controlled by the local peer. @@ -37,11 +37,11 @@ type connection struct { Requests map[request]struct{} // Stuff controlled by the remote peer. - PeerId [20]byte - PeerInterested bool - PeerChoked bool - PeerRequests map[request]struct{} - PeerExtensions [8]byte + PeerID [20]byte + PeerInterested bool + PeerChoked bool + PeerRequests map[request]struct{} + PeerExtensionBytes peerExtensionBytes // Whether the peer has the given piece. nil if they've not sent any // related messages yet. PeerPieces []bool @@ -50,10 +50,22 @@ type connection struct { PeerClientName string } -func (cn *connection) write(b []byte) { - cn.mu.Lock() - cn.writeCh <- b - cn.mu.Unlock() +func newConnection(sock net.Conn, peb peerExtensionBytes, peerID [20]byte) (c *connection) { + c = &connection{ + Socket: sock, + Choked: true, + PeerChoked: true, + PeerMaxRequests: 250, + PeerExtensionBytes: peb, + PeerID: peerID, + + closed: make(chan struct{}), + writeCh: make(chan []byte), + post: make(chan pp.Message), + } + go c.writer() + go c.writeOptimizer(time.Minute) + return } func (cn *connection) completedString() string { @@ -105,7 +117,7 @@ func (cn *connection) setNumPieces(num int) error { } func (cn *connection) WriteStatus(w io.Writer) { - fmt.Fprintf(w, "%q: %s-%s: %s completed, reqs: %d-%d, flags: ", cn.PeerId, cn.Socket.LocalAddr(), cn.Socket.RemoteAddr(), cn.completedString(), len(cn.Requests), len(cn.PeerRequests)) + fmt.Fprintf(w, "%q: %s-%s: %s completed, reqs: %d-%d, flags: ", cn.PeerID, cn.Socket.LocalAddr(), cn.Socket.RemoteAddr(), cn.completedString(), len(cn.Requests), len(cn.PeerRequests)) c := func(b byte) { fmt.Fprintf(w, "%c", b) } @@ -139,28 +151,23 @@ func (cn *connection) WriteStatus(w io.Writer) { func (c *connection) Close() { c.mu.Lock() defer c.mu.Unlock() - if c.closed { + if c.getClosed() { return } + close(c.closed) c.Socket.Close() - if c.post == nil { - // writeOptimizer isn't running, so we need to signal the writer to - // stop. - close(c.writeCh) - } else { - // This will kill the writeOptimizer, and it kills the writer. - close(c.post) - } - c.closed = true } func (c *connection) getClosed() bool { - c.mu.Lock() - defer c.mu.Unlock() - return c.closed + select { + case <-c.closed: + return true + default: + return false + } } -func (c *connection) PeerHasPiece(index peer_protocol.Integer) bool { +func (c *connection) PeerHasPiece(index pp.Integer) bool { if c.PeerPieces == nil { return false } @@ -170,8 +177,11 @@ func (c *connection) PeerHasPiece(index peer_protocol.Integer) bool { return c.PeerPieces[index] } -func (c *connection) Post(msg peer_protocol.Message) { - c.post <- msg +func (c *connection) Post(msg pp.Message) { + select { + case c.post <- msg: + case <-c.closed: + } } func (c *connection) RequestPending(r request) bool { @@ -198,8 +208,8 @@ func (c *connection) Request(chunk request) bool { c.Requests = make(map[request]struct{}, c.PeerMaxRequests) } c.Requests[chunk] = struct{}{} - c.Post(peer_protocol.Message{ - Type: peer_protocol.Request, + c.Post(pp.Message{ + Type: pp.Request, Index: chunk.Index, Begin: chunk.Begin, Length: chunk.Length, @@ -216,8 +226,8 @@ func (c *connection) Cancel(r request) bool { return false } delete(c.Requests, r) - c.Post(peer_protocol.Message{ - Type: peer_protocol.Cancel, + c.Post(pp.Message{ + Type: pp.Cancel, Index: r.Index, Begin: r.Begin, Length: r.Length, @@ -241,8 +251,8 @@ func (c *connection) Choke() { if c.Choked { return } - c.Post(peer_protocol.Message{ - Type: peer_protocol.Choke, + c.Post(pp.Message{ + Type: pp.Choke, }) c.Choked = true } @@ -251,8 +261,8 @@ func (c *connection) Unchoke() { if !c.Choked { return } - c.Post(peer_protocol.Message{ - Type: peer_protocol.Unchoke, + c.Post(pp.Message{ + Type: pp.Unchoke, }) c.Choked = false } @@ -261,12 +271,12 @@ func (c *connection) SetInterested(interested bool) { if c.Interested == interested { return } - c.Post(peer_protocol.Message{ - Type: func() peer_protocol.MessageType { + c.Post(pp.Message{ + Type: func() pp.MessageType { if interested { - return peer_protocol.Interested + return pp.Interested } else { - return peer_protocol.NotInterested + return pp.NotInterested } }(), }) @@ -280,17 +290,21 @@ var ( // Writes buffers to the socket from the write channel. func (conn *connection) writer() { - for b := range conn.writeCh { - _, err := conn.Socket.Write(b) - // log.Printf("wrote %q to %s", b, conn.Socket.RemoteAddr()) - if err != nil { - if !conn.getClosed() { - log.Print(err) + for { + select { + case b, ok := <-conn.writeCh: + if !ok { + return + } + _, err := conn.Socket.Write(b) + if err != nil { + conn.Close() + return } - break + case <-conn.closed: + return } } - conn.Close() } func (conn *connection) writeOptimizer(keepAliveDelay time.Duration) { @@ -322,15 +336,15 @@ func (conn *connection) writeOptimizer(keepAliveDelay time.Duration) { timer.Reset(keepAliveTime.Sub(time.Now())) break } - pending.PushBack(peer_protocol.Message{Keepalive: true}) + pending.PushBack(pp.Message{Keepalive: true}) case msg, ok := <-conn.post: if !ok { return } - if msg.Type == peer_protocol.Cancel { + if msg.Type == pp.Cancel { for e := pending.Back(); e != nil; e = e.Prev() { - elemMsg := e.Value.(peer_protocol.Message) - if elemMsg.Type == peer_protocol.Request && msg.Index == elemMsg.Index && msg.Begin == elemMsg.Begin && msg.Length == elemMsg.Length { + elemMsg := e.Value.(pp.Message) + if elemMsg.Type == pp.Request && msg.Index == elemMsg.Index && msg.Begin == elemMsg.Begin && msg.Length == elemMsg.Length { pending.Remove(e) log.Printf("optimized cancel! %v", msg) break event @@ -345,6 +359,8 @@ func (conn *connection) writeOptimizer(keepAliveDelay time.Duration) { if pending.Len() == 0 { timer.Reset(keepAliveDelay) } + case <-conn.closed: + return } } }