]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Keep track of ongoing handshakes and add timeouts to connection sockets
authorMatt Joiner <anacrolix@gmail.com>
Wed, 27 Aug 2014 23:45:20 +0000 (09:45 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Wed, 27 Aug 2014 23:45:20 +0000 (09:45 +1000)
client.go

index 31b7cabee8b4eca0bbf4a19e83dced50c8dc9c17..7a4ee86f3c670172ed944774b9b1ad590b85d9b6 100644 (file)
--- a/client.go
+++ b/client.go
@@ -109,8 +109,9 @@ type Client struct {
        event sync.Cond
        quit  chan struct{}
 
-       halfOpen int
-       torrents map[InfoHash]*torrent
+       halfOpen    int
+       handshaking int
+       torrents    map[InfoHash]*torrent
 
        dataWaiterMutex sync.Mutex
        dataWaiter      chan struct{}
@@ -130,6 +131,7 @@ func (cl *Client) WriteStatus(w io.Writer) {
        }
        fmt.Fprintf(w, "Peer ID: %q\n", cl.peerID)
        fmt.Fprintf(w, "Half open outgoing connections: %d\n", cl.halfOpen)
+       fmt.Fprintf(w, "Handshaking: %d\n", cl.handshaking)
        if cl.dHT != nil {
                fmt.Fprintf(w, "DHT nodes: %d\n", cl.dHT.NumNodes())
                fmt.Fprintf(w, "DHT Server ID: %x\n", cl.dHT.IDString())
@@ -474,8 +476,26 @@ func handshake(sock io.ReadWriteCloser, ih *InfoHash, peerID [20]byte) (res hand
        return
 }
 
+type peerConn struct {
+       net.Conn
+}
+
+func (pc peerConn) Read(b []byte) (n int, err error) {
+       err = pc.Conn.SetReadDeadline(time.Now().Add(150 * time.Second))
+       if err != nil {
+               return
+       }
+       n, err = pc.Conn.Read(b)
+       return
+}
+
 func (me *Client) runConnection(sock net.Conn, torrent *torrent, discovery peerSource) (err error) {
        defer sock.Close()
+       me.mu.Lock()
+       me.handshaking++
+       me.mu.Unlock()
+       // One minute to complete handshake.
+       sock.SetDeadline(time.Now().Add(time.Minute))
        hsRes, ok, err := handshake(sock, func() *InfoHash {
                if torrent == nil {
                        return nil
@@ -483,6 +503,12 @@ func (me *Client) runConnection(sock net.Conn, torrent *torrent, discovery peerS
                        return &torrent.InfoHash
                }
        }(), me.peerID)
+       me.mu.Lock()
+       defer me.mu.Unlock()
+       if me.handshaking == 0 {
+               panic("handshake count invariant is broken")
+       }
+       me.handshaking--
        if err != nil {
                err = fmt.Errorf("error during handshake: %s", err)
                return
@@ -490,12 +516,12 @@ func (me *Client) runConnection(sock net.Conn, torrent *torrent, discovery peerS
        if !ok {
                return
        }
-       me.mu.Lock()
-       defer me.mu.Unlock()
        torrent = me.torrent(hsRes.InfoHash)
        if torrent == nil {
                return
        }
+       sock.SetWriteDeadline(time.Time{})
+       sock = peerConn{sock}
        conn := newConnection(sock, hsRes.peerExtensionBytes, hsRes.peerID)
        defer conn.Close()
        conn.Discovery = discovery