From 35ba3c44e157d70222e49cf5f2f746bc2eb85d6b Mon Sep 17 00:00:00 2001
From: Matt Joiner <anacrolix@gmail.com>
Date: Sat, 6 Dec 2014 21:21:20 -0600
Subject: [PATCH] dht: Retry queries twice before timing out

---
 dht/dht.go | 144 ++++++++++++++++++++++++++++++-----------------------
 1 file changed, 81 insertions(+), 63 deletions(-)

diff --git a/dht/dht.go b/dht/dht.go
index 827a6046..65447144 100644
--- a/dht/dht.go
+++ b/dht/dht.go
@@ -9,6 +9,7 @@ import (
 	"io"
 	"log"
 	"math/big"
+	"math/rand"
 	"net"
 	"os"
 	"sync"
@@ -127,9 +128,11 @@ func (s *Server) String() string {
 type Node struct {
 	addr          dHTAddr
 	id            string
-	lastHeardFrom time.Time
-	lastSentTo    time.Time
 	announceToken string
+
+	lastGotQuery    time.Time // set in handleQuery when a query arrives from this node
+	lastGotResponse time.Time // set in processPacket when a packet from this node is matched to a transaction
+	lastSentQuery   time.Time // set in query after the first send succeeds; retries do not update it
 }
 
 func (n *Node) NodeInfo() (ret NodeInfo) {
@@ -144,13 +147,15 @@ func (n *Node) Good() bool {
 	if len(n.id) != 20 {
 		return false
 	}
-	if n.lastSentTo.IsZero() {
+	// No reason to think ill of them if they've never been queried.
+	if n.lastSentQuery.IsZero() {
 		return true
 	}
-	if n.lastSentTo.Before(n.lastHeardFrom) {
+	// They answered our last query.
+	if n.lastSentQuery.Before(n.lastGotResponse) {
 		return true
 	}
-	if time.Now().Sub(n.lastHeardFrom) >= 1*time.Minute {
+	if time.Now().Sub(n.lastSentQuery) >= 2*time.Minute {
 		return false
 	}
 	return true
@@ -217,16 +222,63 @@ func (m Msg) AnnounceToken() string {
 }
 
 type transaction struct {
-	mu         sync.Mutex
-	remoteAddr dHTAddr
-	t          string
-	Response   chan Msg
-	onResponse func(Msg)
-	done       chan struct{}
+	mu          sync.Mutex
+	remoteAddr  dHTAddr
+	t           string
+	Response    chan Msg
+	onResponse  func(Msg)
+	done        chan struct{}
+	queryPacket []byte      // raw query bytes handed to writeToNode; kept so the query can be resent
+	timer       *time.Timer // retry/timeout timer created by startTimer and stopped by close
+	s           *Server     // server used to send the packet and to unregister the transaction
+	retries     int         // number of resends performed so far; timerCallback gives up at 2
+}
+
+func jitterDuration(average time.Duration, plusMinus time.Duration) time.Duration {
+	return average - plusMinus/2 + time.Duration(rand.Int63n(int64(plusMinus))) // uniform over a plusMinus-wide window starting at average-plusMinus/2; Int63n panics if plusMinus <= 0
+}
+
+func (t *transaction) startTimer() {
+	t.timer = time.AfterFunc(jitterDuration(20*time.Second, time.Second), t.timerCallback) // first timerCallback runs after 19.5s-20.5s
+}
+
+func (t *transaction) timerCallback() {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	select {
+	case <-t.done: // done is closed by close(); the transaction is finished, nothing to retry
+		return
+	default:
+	}
+	if t.retries == 2 { // the query has already been resent twice: give up
+		t.timeout()
+		return
+	}
+	t.retries++
+	t.sendQuery() // the send error is discarded; a failed resend is handled by the next timer expiry
+	if t.timer.Reset(jitterDuration(20*time.Second, time.Second)) { // Reset reports true if the timer was still active
+		panic("timer should have fired to get here")
+	}
+}
+
+func (t *transaction) sendQuery() error {
+	return t.s.writeToNode(t.queryPacket, t.remoteAddr) // used by query for the first send and by timerCallback for resends
 }
 
 func (t *transaction) timeout() {
-	t.Close()
+	t.close() // not Close(): timerCallback calls timeout with t.mu held, and Close locks t.mu itself
+}
+
+func (t *transaction) close() {
+	if t.closing() { // presumably true once done has been closed, making repeated calls no-ops; confirm against closing()
+		return
+	}
+	close(t.Response)
+	close(t.done)
+	t.timer.Stop()
+	t.s.mu.Lock() // NOTE(review): Close and timerCallback reach here holding t.mu, so the order is t.mu then s.mu; confirm no path locks them in reverse
+	defer t.s.mu.Unlock()
+	t.s.removeTransaction(t)
 }
 
 func (t *transaction) closing() bool {
@@ -241,11 +293,7 @@ func (t *transaction) closing() bool {
 func (t *transaction) Close() {
 	t.mu.Lock()
 	defer t.mu.Unlock()
-	if t.closing() {
-		return
-	}
-	close(t.Response)
-	close(t.done)
+	t.close() // close does the work without locking t.mu; timeout() calls it too, with t.mu already held
 }
 
 func (t *transaction) handleResponse(m Msg) {
@@ -338,13 +386,9 @@ func (s *Server) processPacket(b []byte, addr dHTAddr) {
 		//log.Printf("unexpected message: %#v", d)
 		return
 	}
+	s.getNode(addr).lastGotResponse = time.Now()
 	t.handleResponse(d)
 	s.removeTransaction(t)
-	id := ""
-	if d["y"] == "r" {
-		id = d["r"].(map[string]interface{})["id"].(string)
-	}
-	s.heardFromNode(addr, id)
 }
 
 func (s *Server) serve() error {
@@ -392,7 +436,10 @@ func (s *Server) nodeByID(id string) *Node {
 
 func (s *Server) handleQuery(source dHTAddr, m Msg) {
 	args := m["a"].(map[string]interface{})
-	s.heardFromNode(source, args["id"].(string))
+	node := s.getNode(source)
+	node.id = args["id"].(string)
+	node.lastGotQuery = time.Now()
+	// Don't respond.
 	if s.passive {
 		return
 	}
@@ -472,14 +519,6 @@ func (s *Server) reply(addr dHTAddr, t string, r map[string]interface{}) {
 	}
 }
 
-func (s *Server) heardFromNode(addr dHTAddr, id string) {
-	n := s.getNode(addr)
-	if len(id) == 20 {
-		n.id = id
-	}
-	n.lastHeardFrom = time.Now()
-}
-
 func (s *Server) getNode(addr dHTAddr) (n *Node) {
 	n = s.nodes[addr.String()]
 	if n == nil {
@@ -507,15 +546,9 @@ func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) {
 		err = io.ErrShortWrite
 		return
 	}
-	s.sentToNode(node)
 	return
 }
 
-func (s *Server) sentToNode(addr dHTAddr) {
-	n := s.getNode(addr)
-	n.lastSentTo = time.Now()
-}
-
 func (s *Server) findResponseTransaction(transactionID string, sourceNode dHTAddr) *transaction {
 	for _, t := range s.transactions {
 		if t.t == transactionID && t.remoteAddr.String() == sourceNode.String() {
@@ -555,23 +588,6 @@ func (s *Server) IDString() string {
 	return s.id
 }
 
-func (s *Server) timeoutTransaction(t *transaction) {
-	select {
-	case <-t.done:
-		return
-	case <-time.After(time.Minute):
-	}
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	select {
-	case <-t.done:
-		return
-	default:
-	}
-	t.timeout()
-	s.removeTransaction(t)
-}
-
 func (s *Server) query(node dHTAddr, q string, a map[string]interface{}) (t *transaction, err error) {
 	tid := s.nextTransactionID()
 	if a == nil {
@@ -589,18 +605,20 @@ func (s *Server) query(node dHTAddr, q string, a map[string]interface{}) (t *tra
 		return
 	}
 	t = &transaction{
-		remoteAddr: node,
-		t:          tid,
-		Response:   make(chan Msg, 1),
-		done:       make(chan struct{}),
-	}
-	s.addTransaction(t)
-	err = s.writeToNode(b, node)
+		remoteAddr:  node,
+		t:           tid,
+		Response:    make(chan Msg, 1),
+		done:        make(chan struct{}),
+		queryPacket: b,
+		s:           s,
+	}
+	err = t.sendQuery()
 	if err != nil {
-		s.removeTransaction(t)
 		return
 	}
-	go s.timeoutTransaction(t)
+	s.getNode(node).lastSentQuery = time.Now()
+	t.startTimer()
+	s.addTransaction(t)
 	return
 }
 
-- 
2.51.0