From afe4d8795daa14a80d3d1ed1e72d1261f566a1b7 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Thu, 20 Feb 2020 17:46:29 +1100 Subject: [PATCH] Support custom DHT servers Addresses #266. --- client.go | 32 ++++++++++++++++--------------- client_test.go | 8 ++++---- connection.go | 5 ++--- dht.go | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++ torrent.go | 6 +++--- 5 files changed, 77 insertions(+), 25 deletions(-) create mode 100644 dht.go diff --git a/client.go b/client.go index 8bb1ca38..e5dfa045 100644 --- a/client.go +++ b/client.go @@ -58,7 +58,7 @@ type Client struct { onClose []func() dialers []Dialer listeners []Listener - dhtServers []*dht.Server + dhtServers []DhtServer ipBlockList iplist.Ranger // Our BitTorrent protocol extension bytes, sent in our BT handshakes. extensionBytes pp.PeerExtensionBits @@ -101,12 +101,10 @@ func (cl *Client) LocalPort() (port int) { return } -func writeDhtServerStatus(w io.Writer, s *dht.Server) { +func writeDhtServerStatus(w io.Writer, s DhtServer) { dhtStats := s.Stats() - fmt.Fprintf(w, "\t# Nodes: %d (%d good, %d banned)\n", dhtStats.Nodes, dhtStats.GoodNodes, dhtStats.BadNodes) fmt.Fprintf(w, "\tServer ID: %x\n", s.ID()) - fmt.Fprintf(w, "\tAnnounces: %d\n", dhtStats.SuccessfulOutboundAnnouncePeerQueries) - fmt.Fprintf(w, "\tOutstanding transactions: %d\n", dhtStats.OutstandingTransactions) + spew.Fdump(w, dhtStats) } // Writes out a human readable status of the client, such as for writing to a @@ -120,7 +118,7 @@ func (cl *Client) WriteStatus(_w io.Writer) { fmt.Fprintf(w, "Peer ID: %+q\n", cl.PeerID()) fmt.Fprintf(w, "Announce key: %x\n", cl.announceKey()) fmt.Fprintf(w, "Banned IPs: %d\n", len(cl.badPeerIPsLocked())) - cl.eachDhtServer(func(s *dht.Server) { + cl.eachDhtServer(func(s DhtServer) { fmt.Fprintf(w, "%s DHT server at %s:\n", s.Addr().Network(), s.Addr().String()) writeDhtServerStatus(w, s) }) @@ -237,11 +235,11 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) { if !cfg.NoDHT { for _, s := range sockets { if pc, ok := s.(net.PacketConn); ok { - ds, err := cl.newDhtServer(pc) + ds, err := cl.newAnacrolixDhtServer(pc) if err != nil { panic(err) } - cl.dhtServers = append(cl.dhtServers, ds) + cl.dhtServers = append(cl.dhtServers, anacrolixDhtServerWrapper{ds}) cl.onClose = append(cl.onClose, func() { ds.Close() }) } } @@ -250,6 +248,10 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) { return } +func (cl *Client) AddDhtServer(d DhtServer) { + cl.dhtServers = append(cl.dhtServers, d) +} + // Adds a Dialer for outgoing connections. All Dialers are used when attempting to connect to a // given address for any Torrent. func (cl *Client) AddDialer(d Dialer) { @@ -300,7 +302,7 @@ func (cl *Client) listenNetworks() (ns []network) { return } -func (cl *Client) newDhtServer(conn net.PacketConn) (s *dht.Server, err error) { +func (cl *Client) newAnacrolixDhtServer(conn net.PacketConn) (s *dht.Server, err error) { cfg := dht.ServerConfig{ IPBlocklist: cl.ipBlockList, Conn: conn, @@ -335,7 +337,7 @@ func (cl *Client) Closed() <-chan struct{} { return cl.closed.C() } -func (cl *Client) eachDhtServer(f func(*dht.Server)) { +func (cl *Client) eachDhtServer(f func(DhtServer)) { for _, ds := range cl.dhtServers { f(ds) } @@ -929,14 +931,14 @@ func (cl *Client) sendInitialMessages(conn *connection, torrent *Torrent) { } func (cl *Client) dhtPort() (ret uint16) { - cl.eachDhtServer(func(s *dht.Server) { + cl.eachDhtServer(func(s DhtServer) { ret = uint16(missinggo.AddrPort(s.Addr())) }) return } func (cl *Client) haveDhtServer() (ret bool) { - cl.eachDhtServer(func(_ *dht.Server) { + cl.eachDhtServer(func(_ DhtServer) { ret = true }) return @@ -1071,7 +1073,7 @@ func (cl *Client) AddTorrentInfoHashWithStorage(infoHash metainfo.Hash, specStor new = true t = cl.newTorrent(infoHash, specStorage) - cl.eachDhtServer(func(s *dht.Server) { + cl.eachDhtServer(func(s DhtServer) { go t.dhtAnnouncer(s) }) cl.torrents[infoHash] = t @@ -1188,7 +1190,7 @@ func (cl *Client) AddTorrentFromFile(filename string) (T *Torrent, err error) { return cl.AddTorrent(mi) } -func (cl *Client) DhtServers() []*dht.Server { +func (cl *Client) DhtServers() []DhtServer { return cl.dhtServers } @@ -1206,7 +1208,7 @@ func (cl *Client) AddDHTNodes(nodes []string) { Port: hmp.Port, }, } - cl.eachDhtServer(func(s *dht.Server) { + cl.eachDhtServer(func(s DhtServer) { s.AddNode(ni) }) } diff --git a/client_test.go b/client_test.go index 7af30321..6748bd30 100644 --- a/client_test.go +++ b/client_test.go @@ -316,8 +316,8 @@ func TestDHTInheritBlocklist(t *testing.T) { require.NoError(t, err) defer cl.Close() numServers := 0 - cl.eachDhtServer(func(s *dht.Server) { - assert.Equal(t, ipl, s.IPBlocklist()) + cl.eachDhtServer(func(s DhtServer) { + assert.Equal(t, ipl, s.(anacrolixDhtServerWrapper).IPBlocklist()) numServers++ }) assert.EqualValues(t, 2, numServers) @@ -434,8 +434,8 @@ func TestAddMetainfoWithNodes(t *testing.T) { require.NoError(t, err) defer cl.Close() sum := func() (ret int64) { - cl.eachDhtServer(func(s *dht.Server) { - ret += s.Stats().OutboundQueriesAttempted + cl.eachDhtServer(func(s DhtServer) { + ret += s.Stats().(dht.ServerStats).OutboundQueriesAttempted }) return } diff --git a/connection.go b/connection.go index a7f8edf2..edddebd8 100644 --- a/connection.go +++ b/connection.go @@ -12,7 +12,6 @@ import ( "sync" "time" - "github.com/anacrolix/dht/v2" "github.com/anacrolix/log" "github.com/anacrolix/missinggo" "github.com/anacrolix/missinggo/iter" @@ -1060,8 +1059,8 @@ func (c *connection) mainReadLoop() (err error) { if msg.Port != 0 { pingAddr.Port = int(msg.Port) } - cl.eachDhtServer(func(s *dht.Server) { - go s.Ping(&pingAddr, nil) + cl.eachDhtServer(func(s DhtServer) { + go s.Ping(&pingAddr) }) case pp.Suggest: torrent.Add("suggests received", 1) diff --git a/dht.go b/dht.go new file mode 100644 index 00000000..da79aee4 --- /dev/null +++ b/dht.go @@ -0,0 +1,51 @@ +package torrent + +import ( + "io" + "net" + + "github.com/anacrolix/dht/v2" + "github.com/anacrolix/dht/v2/krpc" +) + +type DhtServer interface { + Stats() interface{} + ID() [20]byte + Addr() net.Addr + AddNode(ni krpc.NodeInfo) error + Ping(addr *net.UDPAddr) + Announce(hash [20]byte, port int, impliedPort bool) (DhtAnnounce, error) + WriteStatus(io.Writer) +} + +type DhtAnnounce interface { + Close() + Peers() <-chan dht.PeersValues +} + +type anacrolixDhtServerWrapper struct { + *dht.Server +} + +func (me anacrolixDhtServerWrapper) Stats() interface{} { + return me.Server.Stats() +} + +type anacrolixDhtAnnounceWrapper struct { + *dht.Announce +} + +func (me anacrolixDhtAnnounceWrapper) Peers() <-chan dht.PeersValues { + return me.Announce.Peers +} + +func (me anacrolixDhtServerWrapper) Announce(hash [20]byte, port int, impliedPort bool) (DhtAnnounce, error) { + ann, err := me.Server.Announce(hash, port, impliedPort) + return anacrolixDhtAnnounceWrapper{ann}, err +} + +func (me anacrolixDhtServerWrapper) Ping(addr *net.UDPAddr) { + me.Server.Ping(addr, nil) +} + +var _ DhtServer = anacrolixDhtServerWrapper{} diff --git a/torrent.go b/torrent.go index 474acfd1..4cd6cd4b 100644 --- a/torrent.go +++ b/torrent.go @@ -1359,12 +1359,12 @@ func (t *Torrent) consumeDhtAnnouncePeers(pvs <-chan dht.PeersValues) { } } -func (t *Torrent) announceToDht(impliedPort bool, s *dht.Server) error { +func (t *Torrent) announceToDht(impliedPort bool, s DhtServer) error { ps, err := s.Announce(t.infoHash, t.cl.incomingPeerPort(), impliedPort) if err != nil { return err } - go t.consumeDhtAnnouncePeers(ps.Peers) + go t.consumeDhtAnnouncePeers(ps.Peers()) select { case <-t.closed.LockedChan(t.cl.locker()): case <-time.After(5 * time.Minute): @@ -1373,7 +1373,7 @@ func (t *Torrent) announceToDht(impliedPort bool, s *dht.Server) error { return nil } -func (t *Torrent) dhtAnnouncer(s *dht.Server) { +func (t *Torrent) dhtAnnouncer(s DhtServer) { cl := t.cl for { select { -- 2.44.0