From afe4d8795daa14a80d3d1ed1e72d1261f566a1b7 Mon Sep 17 00:00:00 2001
From: Matt Joiner <anacrolix@gmail.com>
Date: Thu, 20 Feb 2020 17:46:29 +1100
Subject: [PATCH] Support custom DHT servers

Addresses #266.
---
 client.go      | 32 ++++++++++++++++---------------
 client_test.go |  8 ++++----
 connection.go  |  5 ++---
 dht.go         | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++
 torrent.go     |  6 +++---
 5 files changed, 77 insertions(+), 25 deletions(-)
 create mode 100644 dht.go

diff --git a/client.go b/client.go
index 8bb1ca38..e5dfa045 100644
--- a/client.go
+++ b/client.go
@@ -58,7 +58,7 @@ type Client struct {
 	onClose        []func()
 	dialers        []Dialer
 	listeners      []Listener
-	dhtServers     []*dht.Server
+	dhtServers     []DhtServer
 	ipBlockList    iplist.Ranger
 	// Our BitTorrent protocol extension bytes, sent in our BT handshakes.
 	extensionBytes pp.PeerExtensionBits
@@ -101,12 +101,10 @@ func (cl *Client) LocalPort() (port int) {
 	return
 }
 
-func writeDhtServerStatus(w io.Writer, s *dht.Server) {
+func writeDhtServerStatus(w io.Writer, s DhtServer) {
 	dhtStats := s.Stats()
-	fmt.Fprintf(w, "\t# Nodes: %d (%d good, %d banned)\n", dhtStats.Nodes, dhtStats.GoodNodes, dhtStats.BadNodes)
 	fmt.Fprintf(w, "\tServer ID: %x\n", s.ID())
-	fmt.Fprintf(w, "\tAnnounces: %d\n", dhtStats.SuccessfulOutboundAnnouncePeerQueries)
-	fmt.Fprintf(w, "\tOutstanding transactions: %d\n", dhtStats.OutstandingTransactions)
+	spew.Fdump(w, dhtStats)
 }
 
 // Writes out a human readable status of the client, such as for writing to a
@@ -120,7 +118,7 @@ func (cl *Client) WriteStatus(_w io.Writer) {
 	fmt.Fprintf(w, "Peer ID: %+q\n", cl.PeerID())
 	fmt.Fprintf(w, "Announce key: %x\n", cl.announceKey())
 	fmt.Fprintf(w, "Banned IPs: %d\n", len(cl.badPeerIPsLocked()))
-	cl.eachDhtServer(func(s *dht.Server) {
+	cl.eachDhtServer(func(s DhtServer) {
 		fmt.Fprintf(w, "%s DHT server at %s:\n", s.Addr().Network(), s.Addr().String())
 		writeDhtServerStatus(w, s)
 	})
@@ -237,11 +235,11 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
 	if !cfg.NoDHT {
 		for _, s := range sockets {
 			if pc, ok := s.(net.PacketConn); ok {
-				ds, err := cl.newDhtServer(pc)
+				ds, err := cl.newAnacrolixDhtServer(pc)
 				if err != nil {
 					panic(err)
 				}
-				cl.dhtServers = append(cl.dhtServers, ds)
+				cl.dhtServers = append(cl.dhtServers, anacrolixDhtServerWrapper{ds})
 				cl.onClose = append(cl.onClose, func() { ds.Close() })
 			}
 		}
@@ -250,6 +248,10 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
 	return
 }
 
+func (cl *Client) AddDhtServer(d DhtServer) {
+	cl.dhtServers = append(cl.dhtServers, d)
+}
+
 // Adds a Dialer for outgoing connections. All Dialers are used when attempting to connect to a
 // given address for any Torrent.
 func (cl *Client) AddDialer(d Dialer) {
@@ -300,7 +302,7 @@ func (cl *Client) listenNetworks() (ns []network) {
 	return
 }
 
-func (cl *Client) newDhtServer(conn net.PacketConn) (s *dht.Server, err error) {
+func (cl *Client) newAnacrolixDhtServer(conn net.PacketConn) (s *dht.Server, err error) {
 	cfg := dht.ServerConfig{
 		IPBlocklist:    cl.ipBlockList,
 		Conn:           conn,
@@ -335,7 +337,7 @@ func (cl *Client) Closed() <-chan struct{} {
 	return cl.closed.C()
 }
 
-func (cl *Client) eachDhtServer(f func(*dht.Server)) {
+func (cl *Client) eachDhtServer(f func(DhtServer)) {
 	for _, ds := range cl.dhtServers {
 		f(ds)
 	}
@@ -929,14 +931,14 @@ func (cl *Client) sendInitialMessages(conn *connection, torrent *Torrent) {
 }
 
 func (cl *Client) dhtPort() (ret uint16) {
-	cl.eachDhtServer(func(s *dht.Server) {
+	cl.eachDhtServer(func(s DhtServer) {
 		ret = uint16(missinggo.AddrPort(s.Addr()))
 	})
 	return
 }
 
 func (cl *Client) haveDhtServer() (ret bool) {
-	cl.eachDhtServer(func(_ *dht.Server) {
+	cl.eachDhtServer(func(_ DhtServer) {
 		ret = true
 	})
 	return
@@ -1071,7 +1073,7 @@ func (cl *Client) AddTorrentInfoHashWithStorage(infoHash metainfo.Hash, specStor
 	new = true
 
 	t = cl.newTorrent(infoHash, specStorage)
-	cl.eachDhtServer(func(s *dht.Server) {
+	cl.eachDhtServer(func(s DhtServer) {
 		go t.dhtAnnouncer(s)
 	})
 	cl.torrents[infoHash] = t
@@ -1188,7 +1190,7 @@ func (cl *Client) AddTorrentFromFile(filename string) (T *Torrent, err error) {
 	return cl.AddTorrent(mi)
 }
 
-func (cl *Client) DhtServers() []*dht.Server {
+func (cl *Client) DhtServers() []DhtServer {
 	return cl.dhtServers
 }
 
@@ -1206,7 +1208,7 @@ func (cl *Client) AddDHTNodes(nodes []string) {
 				Port: hmp.Port,
 			},
 		}
-		cl.eachDhtServer(func(s *dht.Server) {
+		cl.eachDhtServer(func(s DhtServer) {
 			s.AddNode(ni)
 		})
 	}
diff --git a/client_test.go b/client_test.go
index 7af30321..6748bd30 100644
--- a/client_test.go
+++ b/client_test.go
@@ -316,8 +316,8 @@ func TestDHTInheritBlocklist(t *testing.T) {
 	require.NoError(t, err)
 	defer cl.Close()
 	numServers := 0
-	cl.eachDhtServer(func(s *dht.Server) {
-		assert.Equal(t, ipl, s.IPBlocklist())
+	cl.eachDhtServer(func(s DhtServer) {
+		assert.Equal(t, ipl, s.(anacrolixDhtServerWrapper).IPBlocklist())
 		numServers++
 	})
 	assert.EqualValues(t, 2, numServers)
@@ -434,8 +434,8 @@ func TestAddMetainfoWithNodes(t *testing.T) {
 	require.NoError(t, err)
 	defer cl.Close()
 	sum := func() (ret int64) {
-		cl.eachDhtServer(func(s *dht.Server) {
-			ret += s.Stats().OutboundQueriesAttempted
+		cl.eachDhtServer(func(s DhtServer) {
+			ret += s.Stats().(dht.ServerStats).OutboundQueriesAttempted
 		})
 		return
 	}
diff --git a/connection.go b/connection.go
index a7f8edf2..edddebd8 100644
--- a/connection.go
+++ b/connection.go
@@ -12,7 +12,6 @@ import (
 	"sync"
 	"time"
 
-	"github.com/anacrolix/dht/v2"
 	"github.com/anacrolix/log"
 	"github.com/anacrolix/missinggo"
 	"github.com/anacrolix/missinggo/iter"
@@ -1060,8 +1059,8 @@ func (c *connection) mainReadLoop() (err error) {
 			if msg.Port != 0 {
 				pingAddr.Port = int(msg.Port)
 			}
-			cl.eachDhtServer(func(s *dht.Server) {
-				go s.Ping(&pingAddr, nil)
+			cl.eachDhtServer(func(s DhtServer) {
+				go s.Ping(&pingAddr)
 			})
 		case pp.Suggest:
 			torrent.Add("suggests received", 1)
diff --git a/dht.go b/dht.go
new file mode 100644
index 00000000..da79aee4
--- /dev/null
+++ b/dht.go
@@ -0,0 +1,51 @@
+package torrent
+
+import (
+	"io"
+	"net"
+
+	"github.com/anacrolix/dht/v2"
+	"github.com/anacrolix/dht/v2/krpc"
+)
+
+type DhtServer interface {
+	Stats() interface{}
+	ID() [20]byte
+	Addr() net.Addr
+	AddNode(ni krpc.NodeInfo) error
+	Ping(addr *net.UDPAddr)
+	Announce(hash [20]byte, port int, impliedPort bool) (DhtAnnounce, error)
+	WriteStatus(io.Writer)
+}
+
+type DhtAnnounce interface {
+	Close()
+	Peers() <-chan dht.PeersValues
+}
+
+type anacrolixDhtServerWrapper struct {
+	*dht.Server
+}
+
+func (me anacrolixDhtServerWrapper) Stats() interface{} {
+	return me.Server.Stats()
+}
+
+type anacrolixDhtAnnounceWrapper struct {
+	*dht.Announce
+}
+
+func (me anacrolixDhtAnnounceWrapper) Peers() <-chan dht.PeersValues {
+	return me.Announce.Peers
+}
+
+func (me anacrolixDhtServerWrapper) Announce(hash [20]byte, port int, impliedPort bool) (DhtAnnounce, error) {
+	ann, err := me.Server.Announce(hash, port, impliedPort)
+	return anacrolixDhtAnnounceWrapper{ann}, err
+}
+
+func (me anacrolixDhtServerWrapper) Ping(addr *net.UDPAddr) {
+	me.Server.Ping(addr, nil)
+}
+
+var _ DhtServer = anacrolixDhtServerWrapper{}
diff --git a/torrent.go b/torrent.go
index 474acfd1..4cd6cd4b 100644
--- a/torrent.go
+++ b/torrent.go
@@ -1359,12 +1359,12 @@ func (t *Torrent) consumeDhtAnnouncePeers(pvs <-chan dht.PeersValues) {
 	}
 }
 
-func (t *Torrent) announceToDht(impliedPort bool, s *dht.Server) error {
+func (t *Torrent) announceToDht(impliedPort bool, s DhtServer) error {
 	ps, err := s.Announce(t.infoHash, t.cl.incomingPeerPort(), impliedPort)
 	if err != nil {
 		return err
 	}
-	go t.consumeDhtAnnouncePeers(ps.Peers)
+	go t.consumeDhtAnnouncePeers(ps.Peers())
 	select {
 	case <-t.closed.LockedChan(t.cl.locker()):
 	case <-time.After(5 * time.Minute):
@@ -1373,7 +1373,7 @@ func (t *Torrent) announceToDht(impliedPort bool, s *dht.Server) error {
 	return nil
 }
 
-func (t *Torrent) dhtAnnouncer(s *dht.Server) {
+func (t *Torrent) dhtAnnouncer(s DhtServer) {
 	cl := t.cl
 	for {
 		select {
-- 
2.51.0