From: Matt Joiner Date: Wed, 27 Aug 2014 23:45:20 +0000 (+1000) Subject: Keep track of ongoing handshakes and add timeouts to connection sockets X-Git-Tag: v1.0.0~1582 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=e79f1bcbf7ca50aa745ea149e1db4b5a19c402ad;p=btrtrc.git Keep track of ongoing handshakes and add timeouts to connection sockets --- diff --git a/client.go b/client.go index 31b7cabe..7a4ee86f 100644 --- a/client.go +++ b/client.go @@ -109,8 +109,9 @@ type Client struct { event sync.Cond quit chan struct{} - halfOpen int - torrents map[InfoHash]*torrent + halfOpen int + handshaking int + torrents map[InfoHash]*torrent dataWaiterMutex sync.Mutex dataWaiter chan struct{} @@ -130,6 +131,7 @@ func (cl *Client) WriteStatus(w io.Writer) { } fmt.Fprintf(w, "Peer ID: %q\n", cl.peerID) fmt.Fprintf(w, "Half open outgoing connections: %d\n", cl.halfOpen) + fmt.Fprintf(w, "Handshaking: %d\n", cl.handshaking) if cl.dHT != nil { fmt.Fprintf(w, "DHT nodes: %d\n", cl.dHT.NumNodes()) fmt.Fprintf(w, "DHT Server ID: %x\n", cl.dHT.IDString()) @@ -474,8 +476,26 @@ func handshake(sock io.ReadWriteCloser, ih *InfoHash, peerID [20]byte) (res hand return } +type peerConn struct { + net.Conn +} + +func (pc peerConn) Read(b []byte) (n int, err error) { + err = pc.Conn.SetReadDeadline(time.Now().Add(150 * time.Second)) + if err != nil { + return + } + n, err = pc.Conn.Read(b) + return +} + func (me *Client) runConnection(sock net.Conn, torrent *torrent, discovery peerSource) (err error) { defer sock.Close() + me.mu.Lock() + me.handshaking++ + me.mu.Unlock() + // One minute to complete handshake. + sock.SetDeadline(time.Now().Add(time.Minute)) hsRes, ok, err := handshake(sock, func() *InfoHash { if torrent == nil { return nil @@ -483,6 +503,12 @@ func (me *Client) runConnection(sock net.Conn, torrent *torrent, discovery peerS return &torrent.InfoHash } }(), me.peerID) + me.mu.Lock() + defer me.mu.Unlock() + if me.handshaking == 0 { + panic("handshake count invariant is broken") + } + me.handshaking-- if err != nil { err = fmt.Errorf("error during handshake: %s", err) return @@ -490,12 +516,12 @@ func (me *Client) runConnection(sock net.Conn, torrent *torrent, discovery peerS if !ok { return } - me.mu.Lock() - defer me.mu.Unlock() torrent = me.torrent(hsRes.InfoHash) if torrent == nil { return } + sock.SetWriteDeadline(time.Time{}) + sock = peerConn{sock} conn := newConnection(sock, hsRes.peerExtensionBytes, hsRes.peerID) defer conn.Close() conn.Discovery = discovery