]> Sergey Matveev's repositories - btrtrc.git/blobdiff - tracker/udp/client.go
Don't reconnect before sending requests with current conn ID
[btrtrc.git] / tracker / udp / client.go
index 42f0d14e6a99e6a4a226a25adef43d91c4469b14..d570b1a0ee0c4e007667e9ffc4349a9d6ca5024a 100644 (file)
@@ -19,8 +19,11 @@ type Client struct {
        mu           sync.Mutex
        connId       ConnectionId
        connIdIssued time.Time
-       Dispatcher   *Dispatcher
-       Writer       io.Writer
+
+       shouldReconnectOverride func() bool
+
+       Dispatcher *Dispatcher
+       Writer     io.Writer
 }
 
 func (cl *Client) Announce(
@@ -81,14 +84,26 @@ func (cl *Client) Scrape(
        return
 }
 
+func (cl *Client) shouldReconnectDefault() bool {
+       return cl.connIdIssued.IsZero() || time.Since(cl.connIdIssued) >= time.Minute
+}
+
+func (cl *Client) shouldReconnect() bool {
+       if cl.shouldReconnectOverride != nil {
+               return cl.shouldReconnectOverride()
+       }
+       return cl.shouldReconnectDefault()
+}
+
 func (cl *Client) connect(ctx context.Context) (err error) {
-       // We could get fancier here and use RWMutex, and even fire off the connection asynchronously
-       // and provide a grace period while it resolves.
-       cl.mu.Lock()
-       defer cl.mu.Unlock()
-       if !cl.connIdIssued.IsZero() && time.Since(cl.connIdIssued) < time.Minute {
+       if !cl.shouldReconnect() {
                return nil
        }
+       return cl.doConnectRoundTrip(ctx)
+}
+
+// This just does the connect request and updates local state if it succeeds.
+func (cl *Client) doConnectRoundTrip(ctx context.Context) (err error) {
        respBody, _, err := cl.request(ctx, ActionConnect, nil)
        if err != nil {
                return err
@@ -100,6 +115,7 @@ func (cl *Client) connect(ctx context.Context) (err error) {
        }
        cl.connId = connResp.ConnectionId
        cl.connIdIssued = time.Now()
+       //log.Printf("conn id set to %x", cl.connId)
        return
 }
 
@@ -116,25 +132,45 @@ func (cl *Client) connIdForRequest(ctx context.Context, action Action) (id Conne
        return
 }
 
-func (cl *Client) requestWriter(ctx context.Context, action Action, body []byte, tId TransactionId) (err error) {
-       var buf bytes.Buffer
-       for n := 0; ; n++ {
-               var connId ConnectionId
+func (cl *Client) writeRequest(
+       ctx context.Context, action Action, body []byte, tId TransactionId, buf *bytes.Buffer,
+) (
+       err error,
+) {
+       var connId ConnectionId
+       if action == ActionConnect {
+               connId = ConnectRequestConnectionId
+       } else {
+               // We lock here while establishing a connection ID, and then ensuring that the request is
+               // written before allowing the connection ID to change again. This is to ensure the server
+               // doesn't assign us another ID before we've sent this request. Note that this doesn't allow
+               // for us to return if the context is cancelled while we wait to obtain a new ID.
+               cl.mu.Lock()
+               defer cl.mu.Unlock()
                connId, err = cl.connIdForRequest(ctx, action)
                if err != nil {
                        return
                }
-               buf.Reset()
-               err = Write(&buf, RequestHeader{
-                       ConnectionId:  connId,
-                       Action:        action,
-                       TransactionId: tId,
-               })
-               if err != nil {
-                       panic(err)
-               }
-               buf.Write(body)
-               _, err = cl.Writer.Write(buf.Bytes())
+       }
+       buf.Reset()
+       err = Write(buf, RequestHeader{
+               ConnectionId:  connId,
+               Action:        action,
+               TransactionId: tId,
+       })
+       if err != nil {
+               panic(err)
+       }
+       buf.Write(body)
+       _, err = cl.Writer.Write(buf.Bytes())
+       //log.Printf("sent request with conn id %x", connId)
+       return
+}
+
+func (cl *Client) requestWriter(ctx context.Context, action Action, body []byte, tId TransactionId) (err error) {
+       var buf bytes.Buffer
+       for n := 0; ; n++ {
+               err = cl.writeRequest(ctx, action, body, tId, &buf)
                if err != nil {
                        return
                }