]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Augment dialed connection timeouts with context.Context
authorMatt Joiner <anacrolix@gmail.com>
Wed, 16 Aug 2017 07:05:05 +0000 (17:05 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Wed, 16 Aug 2017 07:05:05 +0000 (17:05 +1000)
Works toward a dial timeout for half open connections. https://github.com/anacrolix/torrent/issues/169

client.go

index 890ff8cff4dfd15ae24fc8d80cf153bd2706fbf6..13f5d24abac96a2c4970f293da741a284decca5c 100644 (file)
--- a/client.go
+++ b/client.go
@@ -472,20 +472,12 @@ type dialResult struct {
        UTP  bool
 }
 
-func doDial(dial func(string, *Torrent) (net.Conn, error), ch chan dialResult, utp bool, addr string, t *Torrent) {
-       conn, err := dial(addr, t)
-       if err != nil {
-               if conn != nil {
-                       conn.Close()
-               }
-               conn = nil // Pedantic
-       }
-       ch <- dialResult{conn, utp}
+func countDialResult(err error) {
        if err == nil {
                successfulDials.Add(1)
-               return
+       } else {
+               unsuccessfulDials.Add(1)
        }
-       unsuccessfulDials.Add(1)
 }
 
 func reducedDialTimeout(max time.Duration, halfOpenLimit int, pendingPeers int) (ret time.Duration) {
@@ -526,12 +518,12 @@ func (cl *Client) dialTimeout(t *Torrent) time.Duration {
        return reducedDialTimeout(nominalDialTimeout, cl.halfOpenLimit, pendingPeers)
 }
 
-func (cl *Client) dialTCP(addr string, t *Torrent) (c net.Conn, err error) {
+func (cl *Client) dialTCP(ctx context.Context, addr string) (c net.Conn, err error) {
        d := net.Dialer{
-               // LocalAddr: cl.tcpListener.Addr(),
-               Timeout: cl.dialTimeout(t),
+       // LocalAddr: cl.tcpListener.Addr(),
        }
-       c, err = d.Dial("tcp", addr)
+       c, err = d.DialContext(ctx, "tcp", addr)
+       countDialResult(err)
        if err == nil {
                c.(*net.TCPConn).SetLinger(0)
        }
@@ -539,25 +531,32 @@ func (cl *Client) dialTCP(addr string, t *Torrent) (c net.Conn, err error) {
        return
 }
 
-func (cl *Client) dialUTP(addr string, t *Torrent) (net.Conn, error) {
-       ctx, cancel := context.WithTimeout(context.Background(), cl.dialTimeout(t))
-       defer cancel()
-       return cl.utpSock.DialContext(ctx, addr)
+func (cl *Client) dialUTP(ctx context.Context, addr string) (c net.Conn, err error) {
+       c, err = cl.utpSock.DialContext(ctx, addr)
+       countDialResult(err)
+       return
 }
 
 // Returns a connection over UTP or TCP, whichever is first to connect.
-func (cl *Client) dialFirst(addr string, t *Torrent) (conn net.Conn, utp bool) {
-       // Initiate connections via TCP and UTP simultaneously. Use the first one
-       // that succeeds.
+func (cl *Client) dialFirst(ctx context.Context, addr string) (conn net.Conn, utp bool) {
+       ctx, cancel := context.WithCancel(ctx)
+       // As soon as we return one connection, cancel the others.
+       defer cancel()
        left := 0
        resCh := make(chan dialResult, left)
        if !cl.config.DisableUTP {
                left++
-               go doDial(cl.dialUTP, resCh, true, addr, t)
+               go func() {
+                       c, _ := cl.dialUTP(ctx, addr)
+                       resCh <- dialResult{c, true}
+               }()
        }
        if !cl.config.DisableTCP {
                left++
-               go doDial(cl.dialTCP, resCh, false, addr, t)
+               go func() {
+                       c, _ := cl.dialTCP(ctx, addr)
+                       resCh <- dialResult{c, false}
+               }()
        }
        var res dialResult
        // Wait for a successful connection.
@@ -590,15 +589,21 @@ func (cl *Client) noLongerHalfOpen(t *Torrent, addr string) {
 
 // Performs initiator handshakes and returns a connection. Returns nil
 // *connection if no connection for valid reasons.
-func (cl *Client) handshakesConnection(nc net.Conn, t *Torrent, encrypted, utp bool) (c *connection, err error) {
+func (cl *Client) handshakesConnection(ctx context.Context, nc net.Conn, t *Torrent, encrypted, utp bool) (c *connection, err error) {
        c = cl.newConnection(nc)
        c.encrypted = encrypted
        c.uTP = utp
-       err = nc.SetDeadline(time.Now().Add(handshakesTimeout))
+       ctx, cancel := context.WithTimeout(ctx, handshakesTimeout)
+       defer cancel()
+       dl, ok := ctx.Deadline()
+       if !ok {
+               panic(ctx)
+       }
+       err = nc.SetDeadline(dl)
        if err != nil {
-               return
+               panic(err)
        }
-       ok, err := cl.initiateHandshakes(c, t)
+       ok, err = cl.initiateHandshakes(c, t)
        if !ok {
                c = nil
        }
@@ -608,12 +613,14 @@ func (cl *Client) handshakesConnection(nc net.Conn, t *Torrent, encrypted, utp b
 // Returns nil connection and nil error if no connection could be established
 // for valid reasons.
 func (cl *Client) establishOutgoingConn(t *Torrent, addr string) (c *connection, err error) {
-       nc, utp := cl.dialFirst(addr, t)
+       ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+       defer cancel()
+       nc, utp := cl.dialFirst(ctx, addr)
        if nc == nil {
                return
        }
        encryptFirst := !cl.config.DisableEncryption && !cl.config.PreferNoEncryption
-       c, err = cl.handshakesConnection(nc, t, encryptFirst, utp)
+       c, err = cl.handshakesConnection(ctx, nc, t, encryptFirst, utp)
        if err != nil {
                nc.Close()
                return
@@ -628,15 +635,15 @@ func (cl *Client) establishOutgoingConn(t *Torrent, addr string) (c *connection,
        // Try again with encryption if we didn't earlier, or without if we did,
        // using whichever protocol type worked last time.
        if utp {
-               nc, err = cl.dialUTP(addr, t)
+               nc, err = cl.dialUTP(ctx, addr)
        } else {
-               nc, err = cl.dialTCP(addr, t)
+               nc, err = cl.dialTCP(ctx, addr)
        }
        if err != nil {
                err = fmt.Errorf("error dialing for unencrypted connection: %s", err)
                return
        }
-       c, err = cl.handshakesConnection(nc, t, !encryptFirst, utp)
+       c, err = cl.handshakesConnection(ctx, nc, t, !encryptFirst, utp)
        if err != nil || c == nil {
                nc.Close()
        }