From a69044b9ea222f2c8685b8c8c8d87c54734b3068 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Tue, 17 May 2016 16:40:08 +1000 Subject: [PATCH] Break out KRPC stuff from dht into new package --- client.go | 7 +-- cmd/dht-get-peers/main.go | 7 +-- cmd/dht-ping/main.go | 7 +-- cmd/dht-server/main.go | 7 +-- dht/addr.go | 5 +- dht/announce.go | 15 +++--- dht/dht.go | 7 +-- dht/dht_test.go | 16 +++--- .../compact_node_info.go} | 6 +-- dht/{krpcError.go => krpc/error.go} | 2 +- dht/{ => krpc}/msg.go | 2 +- dht/{ => krpc}/msg_test.go | 6 +-- dht/{ => krpc}/nodeinfo.go | 12 ++--- dht/server.go | 49 ++++++++++--------- dht/transaction.go | 12 +++-- 15 files changed, 87 insertions(+), 73 deletions(-) rename dht/{compactNodeInfo.go => krpc/compact_node_info.go} (88%) rename dht/{krpcError.go => krpc/error.go} (98%) rename dht/{ => krpc}/msg.go (99%) rename dht/{ => krpc}/msg_test.go (97%) rename dht/{ => krpc}/nodeinfo.go (84%) diff --git a/client.go b/client.go index 44298663..ceea2af8 100644 --- a/client.go +++ b/client.go @@ -26,6 +26,7 @@ import ( "github.com/anacrolix/torrent/bencode" "github.com/anacrolix/torrent/dht" + "github.com/anacrolix/torrent/dht/krpc" "github.com/anacrolix/torrent/iplist" "github.com/anacrolix/torrent/metainfo" "github.com/anacrolix/torrent/mse" @@ -2001,11 +2002,11 @@ func (cl *Client) AddDHTNodes(nodes []string) { log.Printf("won't add DHT node with bad IP: %q", hmp.Host) continue } - ni := dht.NodeInfo{ - Addr: dht.NewAddr(&net.UDPAddr{ + ni := krpc.NodeInfo{ + Addr: &net.UDPAddr{ IP: ip, Port: hmp.Port, - }), + }, } cl.DHT().AddNode(ni) } diff --git a/cmd/dht-get-peers/main.go b/cmd/dht-get-peers/main.go index a63cd52e..32e9fe61 100644 --- a/cmd/dht-get-peers/main.go +++ b/cmd/dht-get-peers/main.go @@ -12,6 +12,7 @@ import ( _ "github.com/anacrolix/envpprof" "github.com/anacrolix/torrent/dht" + "github.com/anacrolix/torrent/dht/krpc" ) var ( @@ -38,7 +39,7 @@ func loadTable() error { defer f.Close() added := 0 for { - b := make([]byte, dht.CompactIPv4NodeInfoLen) + b := make([]byte, krpc.CompactIPv4NodeInfoLen) _, err := io.ReadFull(f, b) if err == io.EOF { break @@ -46,7 +47,7 @@ func loadTable() error { if err != nil { return fmt.Errorf("error reading table file: %s", err) } - var ni dht.NodeInfo + var ni krpc.NodeInfo err = ni.UnmarshalCompactIPv4(b) if err != nil { return fmt.Errorf("error unmarshaling compact node info: %s", err) @@ -100,7 +101,7 @@ func saveTable() error { } defer f.Close() for _, nodeInfo := range goodNodes { - var b [dht.CompactIPv4NodeInfoLen]byte + var b [krpc.CompactIPv4NodeInfoLen]byte err := nodeInfo.PutCompact(b[:]) if err != nil { return fmt.Errorf("error compacting node info: %s", err) diff --git a/cmd/dht-ping/main.go b/cmd/dht-ping/main.go index b23aa0cf..97edf72d 100644 --- a/cmd/dht-ping/main.go +++ b/cmd/dht-ping/main.go @@ -13,6 +13,7 @@ import ( "github.com/bradfitz/iter" "github.com/anacrolix/torrent/dht" + "github.com/anacrolix/torrent/dht/krpc" ) func main() { @@ -65,7 +66,7 @@ func startPings(s *dht.Server, pongChan chan pong, nodes []string) { type pong struct { addr string - krpc dht.Msg + krpc krpc.Msg msgOk bool rtt time.Duration } @@ -80,8 +81,8 @@ func ping(netloc string, pongChan chan pong, s *dht.Server) { log.Fatal(err) } start := time.Now() - t.SetResponseHandler(func(addr string) func(dht.Msg, bool) { - return func(resp dht.Msg, ok bool) { + t.SetResponseHandler(func(addr string) func(krpc.Msg, bool) { + return func(resp krpc.Msg, ok bool) { pongChan <- pong{ addr: addr, krpc: resp, diff --git a/cmd/dht-server/main.go b/cmd/dht-server/main.go index f0792b8d..15c523a5 100644 --- a/cmd/dht-server/main.go +++ b/cmd/dht-server/main.go @@ -9,6 +9,7 @@ import ( "os/signal" "github.com/anacrolix/torrent/dht" + "github.com/anacrolix/torrent/dht/krpc" ) var ( @@ -32,7 +33,7 @@ func loadTable() error { defer f.Close() added := 0 for { - b := make([]byte, dht.CompactIPv4NodeInfoLen) + b := make([]byte, krpc.CompactIPv4NodeInfoLen) _, err := io.ReadFull(f, b) if err == io.EOF { break @@ -40,7 +41,7 @@ func loadTable() error { if err != nil { return fmt.Errorf("error reading table file: %s", err) } - var ni dht.NodeInfo + var ni krpc.NodeInfo err = ni.UnmarshalCompactIPv4(b) if err != nil { return fmt.Errorf("error unmarshaling compact node info: %s", err) @@ -84,7 +85,7 @@ func saveTable() error { } defer f.Close() for _, nodeInfo := range goodNodes { - var b [dht.CompactIPv4NodeInfoLen]byte + var b [krpc.CompactIPv4NodeInfoLen]byte err := nodeInfo.PutCompact(b[:]) if err != nil { return fmt.Errorf("error compacting node info: %s", err) diff --git a/dht/addr.go b/dht/addr.go index 2495b150..5121ffc5 100644 --- a/dht/addr.go +++ b/dht/addr.go @@ -2,7 +2,10 @@ package dht import "net" -// Used internally to refer to node network addresses. +// Used internally to refer to node network addresses. String() is called a +// lot, and so can be optimized. Network() is not exposed, so that the +// interface does not satisfy net.Addr, as the underlying type must be passed +// to any OS-level function that take net.Addr. type Addr interface { UDPAddr() *net.UDPAddr String() string diff --git a/dht/announce.go b/dht/announce.go index c6b64579..b6fd14f8 100644 --- a/dht/announce.go +++ b/dht/announce.go @@ -9,6 +9,7 @@ import ( "github.com/anacrolix/sync" "github.com/willf/bloom" + "github.com/anacrolix/torrent/dht/krpc" "github.com/anacrolix/torrent/logonce" ) @@ -168,8 +169,8 @@ func (a *Announce) transactionClosed() { a.maybeClose() } -func (a *Announce) responseNode(node NodeInfo) { - a.gotNodeAddr(node.Addr) +func (a *Announce) responseNode(node krpc.NodeInfo) { + a.gotNodeAddr(NewAddr(node.Addr)) } func (a *Announce) closingCh() chan struct{} { @@ -201,7 +202,7 @@ func (a *Announce) getPeers(addr Addr) error { if err != nil { return err } - t.SetResponseHandler(func(m Msg, ok bool) { + t.SetResponseHandler(func(m krpc.Msg, ok bool) { // Register suggested nodes closer to the target info-hash. if m.R != nil { a.mu.Lock() @@ -211,8 +212,8 @@ func (a *Announce) getPeers(addr Addr) error { a.mu.Unlock() if vs := m.R.Values; len(vs) != 0 { - nodeInfo := NodeInfo{ - Addr: t.remoteAddr, + nodeInfo := krpc.NodeInfo{ + Addr: t.remoteAddr.UDPAddr(), } copy(nodeInfo.ID[:], m.SenderID()) select { @@ -243,8 +244,8 @@ func (a *Announce) getPeers(addr Addr) error { // peers that a node has reported as being in the swarm for a queried info // hash. type PeersValues struct { - Peers []Peer // Peers given in get_peers response. - NodeInfo // The node that gave the response. + Peers []Peer // Peers given in get_peers response. + krpc.NodeInfo // The node that gave the response. } // Stop the announce. diff --git a/dht/dht.go b/dht/dht.go index 61caad34..3b2eb858 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -9,6 +9,7 @@ import ( "strconv" "time" + "github.com/anacrolix/torrent/dht/krpc" "github.com/anacrolix/torrent/iplist" ) @@ -64,7 +65,7 @@ type ServerConfig struct { // Used to secure the server's ID. Defaults to the Conn's LocalAddr(). PublicIP net.IP - OnQuery func(*Msg, net.Addr) bool + OnQuery func(*krpc.Msg, net.Addr) bool } // ServerStats instance is returned by Server.Stats() and stores Server metrics @@ -162,8 +163,8 @@ func (n *node) IDNotSet() bool { return n.id.i.Int64() == 0 } -func (n *node) NodeInfo() (ret NodeInfo) { - ret.Addr = n.addr +func (n *node) NodeInfo() (ret krpc.NodeInfo) { + ret.Addr = n.addr.UDPAddr() if n := copy(ret.ID[:], n.idString()); n != 20 { panic(n) } diff --git a/dht/dht_test.go b/dht/dht_test.go index 6ed8c4aa..1f394fde 100644 --- a/dht/dht_test.go +++ b/dht/dht_test.go @@ -12,6 +12,8 @@ import ( _ "github.com/anacrolix/envpprof" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/anacrolix/torrent/dht/krpc" ) func TestSetNilBigInt(t *testing.T) { @@ -20,13 +22,13 @@ func TestSetNilBigInt(t *testing.T) { } func TestMarshalCompactNodeInfo(t *testing.T) { - cni := NodeInfo{ + cni := krpc.NodeInfo{ ID: [20]byte{'a', 'b', 'c'}, } addr, err := net.ResolveUDPAddr("udp4", "1.2.3.4:5") require.NoError(t, err) - cni.Addr = NewAddr(addr) - var b [CompactIPv4NodeInfoLen]byte + cni.Addr = addr + var b [krpc.CompactIPv4NodeInfoLen]byte err = cni.PutCompact(b[:]) require.NoError(t, err) var bb [26]byte @@ -129,7 +131,7 @@ func TestPing(t *testing.T) { require.NoError(t, err) defer tn.Close() ok := make(chan bool) - tn.SetResponseHandler(func(msg Msg, msgOk bool) { + tn.SetResponseHandler(func(msg krpc.Msg, msgOk bool) { ok <- msg.SenderID() == srv0.ID() }) if !<-ok { @@ -169,7 +171,7 @@ func TestAnnounceTimeout(t *testing.T) { } func TestEqualPointers(t *testing.T) { - assert.EqualValues(t, &Msg{R: &Return{}}, &Msg{R: &Return{}}) + assert.EqualValues(t, &krpc.Msg{R: &krpc.Return{}}, &krpc.Msg{R: &krpc.Return{}}) } func TestHook(t *testing.T) { @@ -185,7 +187,7 @@ func TestHook(t *testing.T) { srv0, err := NewServer(&ServerConfig{ Addr: "127.0.0.1:5679", BootstrapNodes: []string{"127.0.0.1:5678"}, - OnQuery: func(m *Msg, addr net.Addr) bool { + OnQuery: func(m *krpc.Msg, addr net.Addr) bool { if m.Q == "ping" { hookCalled <- true } @@ -203,7 +205,7 @@ func TestHook(t *testing.T) { assert.NoError(t, err) defer tn.Close() // Await response from hooked server - tn.SetResponseHandler(func(msg Msg, b bool) { + tn.SetResponseHandler(func(msg krpc.Msg, b bool) { t.Log("TestHook: Sender received response from pinged hook server, so normal execution resumed.") }) // Await signal that hook has been called. diff --git a/dht/compactNodeInfo.go b/dht/krpc/compact_node_info.go similarity index 88% rename from dht/compactNodeInfo.go rename to dht/krpc/compact_node_info.go index 0ed970a9..82f36ff7 100644 --- a/dht/compactNodeInfo.go +++ b/dht/krpc/compact_node_info.go @@ -1,4 +1,4 @@ -package dht +package krpc import ( "bytes" @@ -42,8 +42,8 @@ func (i CompactIPv4NodeInfo) MarshalBencode() (ret []byte, err error) { err = errors.New("nil addr in node info") return } - buf.Write(ni.Addr.UDPAddr().IP.To4()) - binary.Write(&buf, binary.BigEndian, uint16(ni.Addr.UDPAddr().Port)) + buf.Write(ni.Addr.IP.To4()) + binary.Write(&buf, binary.BigEndian, uint16(ni.Addr.Port)) } return bencode.Marshal(buf.Bytes()) } diff --git a/dht/krpcError.go b/dht/krpc/error.go similarity index 98% rename from dht/krpcError.go rename to dht/krpc/error.go index 4118b8f8..fddd0f48 100644 --- a/dht/krpcError.go +++ b/dht/krpc/error.go @@ -1,4 +1,4 @@ -package dht +package krpc import ( "fmt" diff --git a/dht/msg.go b/dht/krpc/msg.go similarity index 99% rename from dht/msg.go rename to dht/krpc/msg.go index 094653bd..b765b9f4 100644 --- a/dht/msg.go +++ b/dht/krpc/msg.go @@ -1,4 +1,4 @@ -package dht +package krpc import ( "fmt" diff --git a/dht/msg_test.go b/dht/krpc/msg_test.go similarity index 97% rename from dht/msg_test.go rename to dht/krpc/msg_test.go index ca446de3..7b30cf02 100644 --- a/dht/msg_test.go +++ b/dht/krpc/msg_test.go @@ -1,4 +1,4 @@ -package dht +package krpc import ( "net" @@ -46,10 +46,10 @@ func TestMarshalUnmarshalMsg(t *testing.T) { R: &Return{ Nodes: CompactIPv4NodeInfo{ NodeInfo{ - Addr: NewAddr(&net.UDPAddr{ + Addr: &net.UDPAddr{ IP: net.IPv4(1, 2, 3, 4).To4(), Port: 0x1234, - }), + }, }, }, }, diff --git a/dht/nodeinfo.go b/dht/krpc/nodeinfo.go similarity index 84% rename from dht/nodeinfo.go rename to dht/krpc/nodeinfo.go index 04708bb6..a7e0afa1 100644 --- a/dht/nodeinfo.go +++ b/dht/krpc/nodeinfo.go @@ -1,4 +1,4 @@ -package dht +package krpc import ( "encoding/binary" @@ -13,7 +13,7 @@ const CompactIPv4NodeInfoLen = 26 type NodeInfo struct { ID [20]byte - Addr Addr + Addr *net.UDPAddr } // Writes the node info to its compact binary representation in b. See @@ -22,14 +22,14 @@ func (ni *NodeInfo) PutCompact(b []byte) error { if n := copy(b[:], ni.ID[:]); n != 20 { panic(n) } - ip := ni.Addr.UDPAddr().IP.To4() + ip := ni.Addr.IP.To4() if len(ip) != 4 { return errors.New("expected ipv4 address") } if n := copy(b[20:], ip); n != 4 { panic(n) } - binary.BigEndian.PutUint16(b[24:], uint16(ni.Addr.UDPAddr().Port)) + binary.BigEndian.PutUint16(b[24:], uint16(ni.Addr.Port)) return nil } @@ -38,9 +38,9 @@ func (ni *NodeInfo) UnmarshalCompactIPv4(b []byte) error { return errors.New("expected 26 bytes") } missinggo.CopyExact(ni.ID[:], b[:20]) - ni.Addr = NewAddr(&net.UDPAddr{ + ni.Addr = &net.UDPAddr{ IP: append(make([]byte, 0, 4), b[20:24]...), Port: int(binary.BigEndian.Uint16(b[24:26])), - }) + } return nil } diff --git a/dht/server.go b/dht/server.go index c0f3127c..56917292 100644 --- a/dht/server.go +++ b/dht/server.go @@ -17,6 +17,7 @@ import ( "github.com/tylertreat/BoomFilters" "github.com/anacrolix/torrent/bencode" + "github.com/anacrolix/torrent/dht/krpc" "github.com/anacrolix/torrent/iplist" "github.com/anacrolix/torrent/logonce" ) @@ -153,7 +154,7 @@ func (s *Server) processPacket(b []byte, addr Addr) { readNotKRPCDict.Add(1) return } - var d Msg + var d krpc.Msg err := bencode.Unmarshal(b, &d) if err != nil { readUnmarshalError.Add(1) @@ -232,13 +233,13 @@ func (s *Server) ipBlocked(ip net.IP) (blocked bool) { } // Adds directly to the node table. -func (s *Server) AddNode(ni NodeInfo) { +func (s *Server) AddNode(ni krpc.NodeInfo) { s.mu.Lock() defer s.mu.Unlock() if s.nodes == nil { s.nodes = make(map[string]*node) } - s.getNode(ni.Addr, string(ni.ID[:])) + s.getNode(NewAddr(ni.Addr), string(ni.ID[:])) } func (s *Server) nodeByID(id string) *node { @@ -250,7 +251,7 @@ func (s *Server) nodeByID(id string) *node { return nil } -func (s *Server) handleQuery(source Addr, m Msg) { +func (s *Server) handleQuery(source Addr, m krpc.Msg) { node := s.getNode(source, m.SenderID()) node.lastGotQuery = time.Now() if s.config.OnQuery != nil { @@ -266,18 +267,18 @@ func (s *Server) handleQuery(source Addr, m Msg) { args := m.A switch m.Q { case "ping": - s.reply(source, m.T, Return{}) + s.reply(source, m.T, krpc.Return{}) case "get_peers": // TODO: Extract common behaviour with find_node. targetID := args.InfoHash if len(targetID) != 20 { break } - var rNodes []NodeInfo + var rNodes []krpc.NodeInfo // TODO: Reply with "values" list if we have peers instead. for _, node := range s.closestGoodNodes(8, targetID) { rNodes = append(rNodes, node.NodeInfo()) } - s.reply(source, m.T, Return{ + s.reply(source, m.T, krpc.Return{ Nodes: rNodes, // TODO: Generate this dynamically, and store it for the source. Token: "hi", @@ -288,7 +289,7 @@ func (s *Server) handleQuery(source Addr, m Msg) { log.Printf("bad DHT query: %v", m) return } - var rNodes []NodeInfo + var rNodes []krpc.NodeInfo if node := s.nodeByID(targetID); node != nil { rNodes = append(rNodes, node.NodeInfo()) } else { @@ -297,7 +298,7 @@ func (s *Server) handleQuery(source Addr, m Msg) { rNodes = append(rNodes, node.NodeInfo()) } } - s.reply(source, m.T, Return{ + s.reply(source, m.T, krpc.Return{ Nodes: rNodes, }) case "announce_peer": @@ -311,9 +312,9 @@ func (s *Server) handleQuery(source Addr, m Msg) { } } -func (s *Server) reply(addr Addr, t string, r Return) { +func (s *Server) reply(addr Addr, t string, r krpc.Return) { r.ID = s.ID() - m := Msg{ + m := krpc.Msg{ T: t, Y: "r", R: &r, @@ -425,7 +426,7 @@ func (s *Server) ID() string { return s.id } -func (s *Server) query(node Addr, q string, a map[string]interface{}, onResponse func(Msg)) (t *Transaction, err error) { +func (s *Server) query(node Addr, q string, a map[string]interface{}, onResponse func(krpc.Msg)) (t *Transaction, err error) { tid := s.nextTransactionID() if a == nil { a = make(map[string]interface{}, 1) @@ -449,7 +450,7 @@ func (s *Server) query(node Addr, q string, a map[string]interface{}, onResponse _t := &Transaction{ remoteAddr: node, t: tid, - response: make(chan Msg, 1), + response: make(chan krpc.Msg, 1), done: make(chan struct{}), queryPacket: b, s: s, @@ -490,7 +491,7 @@ func (s *Server) announcePeer(node Addr, infoHash string, port int, token string "info_hash": infoHash, "port": port, "token": token, - }, func(m Msg) { + }, func(m krpc.Msg) { if err := m.Error(); err != nil { announceErrors.Add(1) // log.Print(token) @@ -503,26 +504,26 @@ func (s *Server) announcePeer(node Addr, infoHash string, port int, token string } // Add response nodes to node table. -func (s *Server) liftNodes(d Msg) { +func (s *Server) liftNodes(d krpc.Msg) { if d.Y != "r" { return } for _, cni := range d.R.Nodes { - if cni.Addr.UDPAddr().Port == 0 { + if cni.Addr.Port == 0 { // TODO: Why would people even do this? continue } - if s.ipBlocked(cni.Addr.UDPAddr().IP) { + if s.ipBlocked(cni.Addr.IP) { continue } - n := s.getNode(cni.Addr, string(cni.ID[:])) + n := s.getNode(NewAddr(cni.Addr), string(cni.ID[:])) n.SetIDFromBytes(cni.ID[:]) } } // Sends a find_node query to addr. targetID is the node we're looking for. func (s *Server) findNode(addr Addr, targetID string) (t *Transaction, err error) { - t, err = s.query(addr, "find_node", map[string]interface{}{"target": targetID}, func(d Msg) { + t, err = s.query(addr, "find_node", map[string]interface{}{"target": targetID}, func(d krpc.Msg) { // Scrape peers from the response to put in the server's table before // handing the response back to the caller. s.liftNodes(d) @@ -578,7 +579,7 @@ func (s *Server) bootstrap() (err error) { return } outstanding.Add(1) - t.SetResponseHandler(func(Msg, bool) { + t.SetResponseHandler(func(krpc.Msg, bool) { outstanding.Done() }) } @@ -621,15 +622,15 @@ func (s *Server) NumNodes() int { } // Exports the current node table. -func (s *Server) Nodes() (nis []NodeInfo) { +func (s *Server) Nodes() (nis []krpc.NodeInfo) { s.mu.Lock() defer s.mu.Unlock() for _, node := range s.nodes { // if !node.Good() { // continue // } - ni := NodeInfo{ - Addr: node.addr, + ni := krpc.NodeInfo{ + Addr: node.addr.UDPAddr(), } if n := copy(ni.ID[:], node.idString()); n != 20 && n != 0 { panic(n) @@ -682,7 +683,7 @@ func (s *Server) getPeers(addr Addr, infoHash string) (t *Transaction, err error err = fmt.Errorf("infohash has bad length") return } - t, err = s.query(addr, "get_peers", map[string]interface{}{"info_hash": infoHash}, func(m Msg) { + t, err = s.query(addr, "get_peers", map[string]interface{}{"info_hash": infoHash}, func(m krpc.Msg) { s.liftNodes(m) if m.R != nil && m.R.Token != "" { s.getNode(addr, m.SenderID()).announceToken = m.R.Token diff --git a/dht/transaction.go b/dht/transaction.go index 0c007f00..3698a262 100644 --- a/dht/transaction.go +++ b/dht/transaction.go @@ -3,6 +3,8 @@ package dht import ( "sync" "time" + + "github.com/anacrolix/torrent/dht/krpc" ) // Transaction keeps track of a message exchange between nodes, such as a @@ -11,20 +13,20 @@ type Transaction struct { mu sync.Mutex remoteAddr Addr t string - response chan Msg - onResponse func(Msg) // Called with the server locked. + response chan krpc.Msg + onResponse func(krpc.Msg) // Called with the server locked. done chan struct{} queryPacket []byte timer *time.Timer s *Server retries int lastSend time.Time - userOnResponse func(Msg, bool) + userOnResponse func(krpc.Msg, bool) } // SetResponseHandler sets up a function to be called when the query response // is available. -func (t *Transaction) SetResponseHandler(f func(Msg, bool)) { +func (t *Transaction) SetResponseHandler(f func(krpc.Msg, bool)) { t.mu.Lock() defer t.mu.Unlock() t.userOnResponse = f @@ -124,7 +126,7 @@ func (t *Transaction) Close() { t.close() } -func (t *Transaction) handleResponse(m Msg) { +func (t *Transaction) handleResponse(m krpc.Msg) { t.mu.Lock() if t.closing() { t.mu.Unlock() -- 2.48.1