From e79f1bcbf7ca50aa745ea149e1db4b5a19c402ad Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Thu, 28 Aug 2014 09:45:20 +1000 Subject: [PATCH] Keep track of ongoing handshakes and add timeouts to connection sockets --- client.go | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index 31b7cabe..7a4ee86f 100644 --- a/client.go +++ b/client.go @@ -109,8 +109,9 @@ type Client struct { event sync.Cond quit chan struct{} - halfOpen int - torrents map[InfoHash]*torrent + halfOpen int + handshaking int + torrents map[InfoHash]*torrent dataWaiterMutex sync.Mutex dataWaiter chan struct{} @@ -130,6 +131,7 @@ func (cl *Client) WriteStatus(w io.Writer) { } fmt.Fprintf(w, "Peer ID: %q\n", cl.peerID) fmt.Fprintf(w, "Half open outgoing connections: %d\n", cl.halfOpen) + fmt.Fprintf(w, "Handshaking: %d\n", cl.handshaking) if cl.dHT != nil { fmt.Fprintf(w, "DHT nodes: %d\n", cl.dHT.NumNodes()) fmt.Fprintf(w, "DHT Server ID: %x\n", cl.dHT.IDString()) @@ -474,8 +476,26 @@ func handshake(sock io.ReadWriteCloser, ih *InfoHash, peerID [20]byte) (res hand return } +type peerConn struct { + net.Conn +} + +func (pc peerConn) Read(b []byte) (n int, err error) { + err = pc.Conn.SetReadDeadline(time.Now().Add(150 * time.Second)) + if err != nil { + return + } + n, err = pc.Conn.Read(b) + return +} + func (me *Client) runConnection(sock net.Conn, torrent *torrent, discovery peerSource) (err error) { defer sock.Close() + me.mu.Lock() + me.handshaking++ + me.mu.Unlock() + // One minute to complete handshake. + sock.SetDeadline(time.Now().Add(time.Minute)) hsRes, ok, err := handshake(sock, func() *InfoHash { if torrent == nil { return nil @@ -483,6 +503,12 @@ func (me *Client) runConnection(sock net.Conn, torrent *torrent, discovery peerS return &torrent.InfoHash } }(), me.peerID) + me.mu.Lock() + defer me.mu.Unlock() + if me.handshaking == 0 { + panic("handshake count invariant is broken") + } + me.handshaking-- if err != nil { err = fmt.Errorf("error during handshake: %s", err) return @@ -490,12 +516,12 @@ func (me *Client) runConnection(sock net.Conn, torrent *torrent, discovery peerS if !ok { return } - me.mu.Lock() - defer me.mu.Unlock() torrent = me.torrent(hsRes.InfoHash) if torrent == nil { return } + sock.SetWriteDeadline(time.Time{}) + sock = peerConn{sock} conn := newConnection(sock, hsRes.peerExtensionBytes, hsRes.peerID) defer conn.Close() conn.Discovery = discovery -- 2.48.1