From 5bf56f6d8d786215b36e514667a18cb43ac27413 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Sun, 16 Nov 2014 21:22:29 -0600 Subject: [PATCH] Tunnel addrs through dht as an internal interface to make switching easier --- dht/dht.go | 62 ++++++++++++++++++++++++++++--------------------- dht/dht_test.go | 4 ++-- 2 files changed, 37 insertions(+), 29 deletions(-) diff --git a/dht/dht.go b/dht/dht.go index 5f671e60..90169ff4 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -14,23 +14,31 @@ import ( "time" "bitbucket.org/anacrolix/go.torrent/logonce" - "bitbucket.org/anacrolix/go.torrent/util" - "github.com/nsf/libtorgo/bencode" + "github.com/anacrolix/libtorgo/bencode" ) type Server struct { id string - socket *net.UDPConn + socket net.PacketConn transactions []*transaction transactionIDInt uint64 - nodes map[string]*Node // Keyed by *net.UDPAddr.String(). + nodes map[string]*Node // Keyed by dHTAddr.String(). mu sync.Mutex closed chan struct{} NumConfirmedAnnounces int } +type dHTAddr interface { + net.Addr +} + +func newDHTAddr(addr *net.UDPAddr) (ret dHTAddr) { + ret = addr + return +} + type ServerConfig struct { Addr string } @@ -86,7 +94,7 @@ func (s *Server) String() string { } type Node struct { - addr *net.UDPAddr + addr dHTAddr id string lastHeardFrom time.Time lastSentTo time.Time @@ -155,7 +163,7 @@ func (m Msg) AnnounceToken() string { } type transaction struct { - remoteAddr net.Addr + remoteAddr dHTAddr t string Response chan Msg onResponse func(Msg) @@ -214,7 +222,7 @@ func (s *Server) init() (err error) { func (s *Server) serve() error { for { var b [0x10000]byte - n, addr, err := s.socket.ReadFromUDP(b[:]) + n, addr_, err := s.socket.ReadFrom(b[:]) if err != nil { return err } @@ -226,6 +234,7 @@ func (s *Server) serve() error { } continue } + addr := newDHTAddr(addr_.(*net.UDPAddr)) s.mu.Lock() if d["y"] == "q" { s.handleQuery(addr, d) @@ -268,7 +277,7 @@ func (s *Server) nodeByID(id string) *Node { return nil } -func (s *Server) handleQuery(source *net.UDPAddr, m Msg) { +func (s *Server) handleQuery(source dHTAddr, m Msg) { args := m["a"].(map[string]interface{}) s.heardFromNode(source, args["id"].(string)) switch m["q"] { @@ -327,7 +336,7 @@ func (s *Server) handleQuery(source *net.UDPAddr, m Msg) { } } -func (s *Server) reply(addr *net.UDPAddr, t string, r map[string]interface{}) { +func (s *Server) reply(addr dHTAddr, t string, r map[string]interface{}) { if r == nil { r = make(map[string]interface{}, 1) } @@ -347,7 +356,7 @@ func (s *Server) reply(addr *net.UDPAddr, t string, r map[string]interface{}) { } } -func (s *Server) heardFromNode(addr *net.UDPAddr, id string) { +func (s *Server) heardFromNode(addr dHTAddr, id string) { n := s.getNode(addr) if len(id) == 20 { n.id = id @@ -355,8 +364,8 @@ func (s *Server) heardFromNode(addr *net.UDPAddr, id string) { n.lastHeardFrom = time.Now() } -func (s *Server) getNode(addr *net.UDPAddr) (n *Node) { - if addr.Port == 0 { +func (s *Server) getNode(addr dHTAddr) (n *Node) { + if util.AddrPort(addr) == 0 { panic(addr) } n = s.nodes[addr.String()] @@ -369,7 +378,7 @@ func (s *Server) getNode(addr *net.UDPAddr) (n *Node) { return } -func (s *Server) writeToNode(b []byte, node *net.UDPAddr) (err error) { +func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) { n, err := s.socket.WriteTo(b, node) if err != nil { err = fmt.Errorf("error writing %d bytes to %s: %s", len(b), node, err) @@ -383,12 +392,12 @@ func (s *Server) writeToNode(b []byte, node *net.UDPAddr) (err error) { return } -func (s *Server) sentToNode(addr *net.UDPAddr) { +func (s *Server) sentToNode(addr dHTAddr) { n := s.getNode(addr) n.lastSentTo = time.Now() } -func (s *Server) findResponseTransaction(transactionID string, sourceNode net.Addr) *transaction { +func (s *Server) findResponseTransaction(transactionID string, sourceNode dHTAddr) *transaction { for _, t := range s.transactions { if t.t == transactionID && t.remoteAddr.String() == sourceNode.String() { return t @@ -444,7 +453,7 @@ func (s *Server) timeoutTransaction(t *transaction) { s.removeTransaction(t) } -func (s *Server) query(node *net.UDPAddr, q string, a map[string]interface{}) (t *transaction, err error) { +func (s *Server) query(node dHTAddr, q string, a map[string]interface{}) (t *transaction, err error) { tid := s.nextTransactionID() if a == nil { a = make(map[string]interface{}, 1) @@ -480,7 +489,7 @@ const CompactNodeInfoLen = 26 type NodeInfo struct { ID [20]byte - Addr *net.UDPAddr + Addr dHTAddr } func (ni *NodeInfo) PutCompact(b []byte) error { @@ -505,18 +514,17 @@ func (cni *NodeInfo) UnmarshalCompact(b []byte) error { if 20 != copy(cni.ID[:], b[:20]) { panic("impossibru!") } - if cni.Addr == nil { - cni.Addr = &net.UDPAddr{} - } - cni.Addr.IP = net.IPv4(b[20], b[21], b[22], b[23]) - cni.Addr.Port = int(binary.BigEndian.Uint16(b[24:26])) + cni.Addr = newDHTAddr(&net.UDPAddr{ + IP: net.IPv4(b[20], b[21], b[22], b[23]), + Port: int(binary.BigEndian.Uint16(b[24:26])), + }) return nil } func (s *Server) Ping(node *net.UDPAddr) (*transaction, error) { s.mu.Lock() defer s.mu.Unlock() - return s.query(node, "ping", nil) + return s.query(newDHTAddr(node), "ping", nil) } // Announce a local peer. This can only be done to nodes that gave us an @@ -536,7 +544,7 @@ func (s *Server) AnnouncePeer(port int, impliedPort bool, infoHash string) (err return } -func (s *Server) announcePeer(node *net.UDPAddr, infoHash string, port int, token string, impliedPort bool) error { +func (s *Server) announcePeer(node dHTAddr, infoHash string, port int, token string, impliedPort bool) error { t, err := s.query(node, "announce_peer", map[string]interface{}{ "implied_port": func() int { if impliedPort { @@ -663,7 +671,7 @@ func (s *Server) liftNodes(d Msg) { } // Sends a find_node query to addr. targetID is the node we're looking for. -func (s *Server) findNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) { +func (s *Server) findNode(addr dHTAddr, targetID string) (t *transaction, err error) { t, err = s.query(addr, "find_node", map[string]interface{}{"target": targetID}) if err != nil { return @@ -770,7 +778,7 @@ func (s *Server) GetPeers(infoHash string) (ps *peerStream, err error) { return } -func (s *Server) getPeers(addr *net.UDPAddr, infoHash string) (t *transaction, err error) { +func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *transaction, err error) { if len(infoHash) != 20 { err = fmt.Errorf("infohash has bad length") return @@ -792,7 +800,7 @@ func (s *Server) addRootNode() error { return err } s.nodes[addr.String()] = &Node{ - addr: addr, + addr: newDHTAddr(addr), } return nil } diff --git a/dht/dht_test.go b/dht/dht_test.go index 48aaa1e5..71afc5eb 100644 --- a/dht/dht_test.go +++ b/dht/dht_test.go @@ -10,11 +10,11 @@ func TestMarshalCompactNodeInfo(t *testing.T) { cni := NodeInfo{ ID: [20]byte{'a', 'b', 'c'}, } - var err error - cni.Addr, err = net.ResolveUDPAddr("udp4", "1.2.3.4:5") + addr, err := net.ResolveUDPAddr("udp4", "1.2.3.4:5") if err != nil { t.Fatal(err) } + cni.Addr = newDHTAddr(addr) var b [CompactNodeInfoLen]byte cni.PutCompact(b[:]) if err != nil { -- 2.48.1