]> Sergey Matveev's repositories - btrtrc.git/commitdiff
dht: Retry queries twice before timing out
authorMatt Joiner <anacrolix@gmail.com>
Sun, 7 Dec 2014 03:21:20 +0000 (21:21 -0600)
committerMatt Joiner <anacrolix@gmail.com>
Sun, 7 Dec 2014 03:21:20 +0000 (21:21 -0600)
dht/dht.go

index 827a6046994452eb26674b96c1cdda316c2228b0..654471443fe664218618ca661f463e5e66b0c1c3 100644 (file)
@@ -9,6 +9,7 @@ import (
        "io"
        "log"
        "math/big"
+       "math/rand"
        "net"
        "os"
        "sync"
@@ -127,9 +128,11 @@ func (s *Server) String() string {
 type Node struct {
        addr          dHTAddr
        id            string
-       lastHeardFrom time.Time
-       lastSentTo    time.Time
        announceToken string
+
+       lastGotQuery    time.Time
+       lastGotResponse time.Time
+       lastSentQuery   time.Time
 }
 
 func (n *Node) NodeInfo() (ret NodeInfo) {
@@ -144,13 +147,15 @@ func (n *Node) Good() bool {
        if len(n.id) != 20 {
                return false
        }
-       if n.lastSentTo.IsZero() {
+       // No reason to think ill of them if they've never responded.
+       if n.lastSentQuery.IsZero() {
                return true
        }
-       if n.lastSentTo.Before(n.lastHeardFrom) {
+       // They answered our last query.
+       if n.lastSentQuery.Before(n.lastGotResponse) {
                return true
        }
-       if time.Now().Sub(n.lastHeardFrom) >= 1*time.Minute {
+       if time.Now().Sub(n.lastSentQuery) >= 2*time.Minute {
                return false
        }
        return true
@@ -217,16 +222,63 @@ func (m Msg) AnnounceToken() string {
 }
 
 type transaction struct {
-       mu         sync.Mutex
-       remoteAddr dHTAddr
-       t          string
-       Response   chan Msg
-       onResponse func(Msg)
-       done       chan struct{}
+       mu          sync.Mutex
+       remoteAddr  dHTAddr
+       t           string
+       Response    chan Msg
+       onResponse  func(Msg)
+       done        chan struct{}
+       queryPacket []byte
+       timer       *time.Timer
+       s           *Server
+       retries     int
+}
+
+func jitterDuration(average time.Duration, plusMinus time.Duration) time.Duration {
+       return average - plusMinus/2 + time.Duration(rand.Int63n(int64(plusMinus)))
+}
+
+func (t *transaction) startTimer() {
+       t.timer = time.AfterFunc(jitterDuration(20*time.Second, time.Second), t.timerCallback)
+}
+
+func (t *transaction) timerCallback() {
+       t.mu.Lock()
+       defer t.mu.Unlock()
+       select {
+       case <-t.done:
+               return
+       default:
+       }
+       if t.retries == 2 {
+               t.timeout()
+               return
+       }
+       t.retries++
+       t.sendQuery()
+       if t.timer.Reset(jitterDuration(20*time.Second, time.Second)) {
+               panic("timer should have fired to get here")
+       }
+}
+
+func (t *transaction) sendQuery() error {
+       return t.s.writeToNode(t.queryPacket, t.remoteAddr)
 }
 
 func (t *transaction) timeout() {
-       t.Close()
+       t.close()
+}
+
+func (t *transaction) close() {
+       if t.closing() {
+               return
+       }
+       close(t.Response)
+       close(t.done)
+       t.timer.Stop()
+       t.s.mu.Lock()
+       defer t.s.mu.Unlock()
+       t.s.removeTransaction(t)
 }
 
 func (t *transaction) closing() bool {
@@ -241,11 +293,7 @@ func (t *transaction) closing() bool {
 func (t *transaction) Close() {
        t.mu.Lock()
        defer t.mu.Unlock()
-       if t.closing() {
-               return
-       }
-       close(t.Response)
-       close(t.done)
+       t.close()
 }
 
 func (t *transaction) handleResponse(m Msg) {
@@ -338,13 +386,9 @@ func (s *Server) processPacket(b []byte, addr dHTAddr) {
                //log.Printf("unexpected message: %#v", d)
                return
        }
+       s.getNode(addr).lastGotResponse = time.Now()
        t.handleResponse(d)
        s.removeTransaction(t)
-       id := ""
-       if d["y"] == "r" {
-               id = d["r"].(map[string]interface{})["id"].(string)
-       }
-       s.heardFromNode(addr, id)
 }
 
 func (s *Server) serve() error {
@@ -392,7 +436,10 @@ func (s *Server) nodeByID(id string) *Node {
 
 func (s *Server) handleQuery(source dHTAddr, m Msg) {
        args := m["a"].(map[string]interface{})
-       s.heardFromNode(source, args["id"].(string))
+       node := s.getNode(source)
+       node.id = args["id"].(string)
+       node.lastGotQuery = time.Now()
+       // Don't respond.
        if s.passive {
                return
        }
@@ -472,14 +519,6 @@ func (s *Server) reply(addr dHTAddr, t string, r map[string]interface{}) {
        }
 }
 
-func (s *Server) heardFromNode(addr dHTAddr, id string) {
-       n := s.getNode(addr)
-       if len(id) == 20 {
-               n.id = id
-       }
-       n.lastHeardFrom = time.Now()
-}
-
 func (s *Server) getNode(addr dHTAddr) (n *Node) {
        n = s.nodes[addr.String()]
        if n == nil {
@@ -507,15 +546,9 @@ func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) {
                err = io.ErrShortWrite
                return
        }
-       s.sentToNode(node)
        return
 }
 
-func (s *Server) sentToNode(addr dHTAddr) {
-       n := s.getNode(addr)
-       n.lastSentTo = time.Now()
-}
-
 func (s *Server) findResponseTransaction(transactionID string, sourceNode dHTAddr) *transaction {
        for _, t := range s.transactions {
                if t.t == transactionID && t.remoteAddr.String() == sourceNode.String() {
@@ -555,23 +588,6 @@ func (s *Server) IDString() string {
        return s.id
 }
 
-func (s *Server) timeoutTransaction(t *transaction) {
-       select {
-       case <-t.done:
-               return
-       case <-time.After(time.Minute):
-       }
-       s.mu.Lock()
-       defer s.mu.Unlock()
-       select {
-       case <-t.done:
-               return
-       default:
-       }
-       t.timeout()
-       s.removeTransaction(t)
-}
-
 func (s *Server) query(node dHTAddr, q string, a map[string]interface{}) (t *transaction, err error) {
        tid := s.nextTransactionID()
        if a == nil {
@@ -589,18 +605,20 @@ func (s *Server) query(node dHTAddr, q string, a map[string]interface{}) (t *tra
                return
        }
        t = &transaction{
-               remoteAddr: node,
-               t:          tid,
-               Response:   make(chan Msg, 1),
-               done:       make(chan struct{}),
-       }
-       s.addTransaction(t)
-       err = s.writeToNode(b, node)
+               remoteAddr:  node,
+               t:           tid,
+               Response:    make(chan Msg, 1),
+               done:        make(chan struct{}),
+               queryPacket: b,
+               s:           s,
+       }
+       err = t.sendQuery()
        if err != nil {
-               s.removeTransaction(t)
                return
        }
-       go s.timeoutTransaction(t)
+       s.getNode(node).lastSentQuery = time.Now()
+       t.startTimer()
+       s.addTransaction(t)
        return
 }