]> Sergey Matveev's repositories - btrtrc.git/commitdiff
dht: Various improvements and removal of cruft
authorMatt Joiner <anacrolix@gmail.com>
Fri, 26 Dec 2014 06:21:48 +0000 (17:21 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Fri, 26 Dec 2014 06:21:48 +0000 (17:21 +1100)
dht/bitcount.go [new file with mode: 0644]
dht/closest_nodes.go
dht/dht.go
dht/dht_test.go
dht/getpeers.go

diff --git a/dht/bitcount.go b/dht/bitcount.go
new file mode 100644 (file)
index 0000000..a21c8c0
--- /dev/null
@@ -0,0 +1,55 @@
+package dht
+
+import (
+       "math/big"
+)
+
+// TODO: The bitcounting is a relic of the old and incorrect distance
+// calculation. It is still useful in some tests but should eventually be
+// replaced with actual distances.
+
+// How many bits?
+func bitCount(n big.Int) int {
+       var count int = 0
+       for _, b := range n.Bytes() {
+               count += int(bitCounts[b])
+       }
+       return count
+}
+
+// The bit counts for each byte value (0 - 255).
+var bitCounts = []int8{
+       // Generated by Java BitCount of all values from 0 to 255
+       0, 1, 1, 2, 1, 2, 2, 3,
+       1, 2, 2, 3, 2, 3, 3, 4,
+       1, 2, 2, 3, 2, 3, 3, 4,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       1, 2, 2, 3, 2, 3, 3, 4,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       1, 2, 2, 3, 2, 3, 3, 4,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       4, 5, 5, 6, 5, 6, 6, 7,
+       1, 2, 2, 3, 2, 3, 3, 4,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       4, 5, 5, 6, 5, 6, 6, 7,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       4, 5, 5, 6, 5, 6, 6, 7,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       4, 5, 5, 6, 5, 6, 6, 7,
+       4, 5, 5, 6, 5, 6, 6, 7,
+       5, 6, 6, 7, 6, 7, 7, 8,
+}
index 70366dd54ce41823d05330d4d47e35dd0d7657e7..d0ccd94a700284aa192f8dbbde5a3b3ae5d8bb10 100644 (file)
@@ -5,14 +5,16 @@ import (
 )
 
 type nodeMaxHeap struct {
-       IDs    []string
-       Target string
+       IDs    []nodeID
+       Target nodeID
 }
 
 func (me nodeMaxHeap) Len() int { return len(me.IDs) }
 
 func (me nodeMaxHeap) Less(i, j int) bool {
-       return idDistance(me.IDs[i], me.Target).Cmp(idDistance(me.IDs[j], me.Target)) > 0
+       m := me.IDs[i].Distance(&me.Target)
+       n := me.IDs[j].Distance(&me.Target)
+       return m.Cmp(&n) > 0
 }
 
 func (me *nodeMaxHeap) Pop() (ret interface{}) {
@@ -20,7 +22,7 @@ func (me *nodeMaxHeap) Pop() (ret interface{}) {
        return
 }
 func (me *nodeMaxHeap) Push(val interface{}) {
-       me.IDs = append(me.IDs, val.(string))
+       me.IDs = append(me.IDs, val.(nodeID))
 }
 func (me nodeMaxHeap) Swap(i, j int) {
        me.IDs[i], me.IDs[j] = me.IDs[j], me.IDs[i]
@@ -31,18 +33,18 @@ type closestNodesSelector struct {
        k       int
 }
 
-func (me *closestNodesSelector) Push(id string) {
+func (me *closestNodesSelector) Push(id nodeID) {
        heap.Push(&me.closest, id)
        if me.closest.Len() > me.k {
                heap.Pop(&me.closest)
        }
 }
 
-func (me *closestNodesSelector) IDs() []string {
+func (me *closestNodesSelector) IDs() []nodeID {
        return me.closest.IDs
 }
 
-func newKClosestNodesSelector(k int, targetID string) (ret closestNodesSelector) {
+func newKClosestNodesSelector(k int, targetID nodeID) (ret closestNodesSelector) {
        ret.k = k
        ret.closest.Target = targetID
        return
index 19ccb667198fad1f72d5fe290ef7fb35ebf9292c..352a6f59cc3c4c5e5f1385e38c7aa5e6f67d06ad 100644 (file)
@@ -1,11 +1,16 @@
 package dht
 
 import (
+       "bitbucket.org/anacrolix/go.torrent/iplist"
+       "bitbucket.org/anacrolix/go.torrent/logonce"
+       "bitbucket.org/anacrolix/go.torrent/util"
+       "bitbucket.org/anacrolix/sync"
        "crypto"
        _ "crypto/sha1"
        "encoding/binary"
        "errors"
        "fmt"
+       "github.com/anacrolix/libtorgo/bencode"
        "io"
        "log"
        "math/big"
@@ -13,14 +18,6 @@ import (
        "net"
        "os"
        "time"
-
-       "bitbucket.org/anacrolix/sync"
-
-       "bitbucket.org/anacrolix/go.torrent/iplist"
-
-       "bitbucket.org/anacrolix/go.torrent/logonce"
-       "bitbucket.org/anacrolix/go.torrent/util"
-       "github.com/anacrolix/libtorgo/bencode"
 )
 
 const maxNodes = 10000
@@ -47,11 +44,28 @@ type Server struct {
 
 type dHTAddr interface {
        net.Addr
+       UDPAddr() *net.UDPAddr
 }
 
-func newDHTAddr(addr *net.UDPAddr) (ret dHTAddr) {
-       ret = addr
-       return
+type cachedAddr struct {
+       a net.Addr
+       s string
+}
+
+func (ca cachedAddr) Network() string {
+       return ca.a.Network()
+}
+
+func (ca cachedAddr) String() string {
+       return ca.s
+}
+
+func (ca cachedAddr) UDPAddr() *net.UDPAddr {
+       return ca.a.(*net.UDPAddr)
+}
+
+func newDHTAddr(addr *net.UDPAddr) dHTAddr {
+       return cachedAddr{addr, addr.String()}
 }
 
 type ServerConfig struct {
@@ -134,9 +148,40 @@ func (s *Server) String() string {
        return fmt.Sprintf("dht server on %s", s.socket.LocalAddr())
 }
 
+type nodeID struct {
+       i   big.Int
+       set bool
+}
+
+func (nid *nodeID) IsUnset() bool {
+       return !nid.set
+}
+
+func nodeIDFromString(s string) (ret nodeID) {
+       if s == "" {
+               return
+       }
+       ret.i.SetBytes([]byte(s))
+       ret.set = true
+       return
+}
+
+func (nid0 *nodeID) Distance(nid1 *nodeID) (ret big.Int) {
+       if nid0.IsUnset() != nid1.IsUnset() {
+               ret = maxDistance
+               return
+       }
+       ret.Xor(&nid0.i, &nid1.i)
+       return
+}
+
+func (nid *nodeID) String() string {
+       return string(nid.i.Bytes())
+}
+
 type Node struct {
        addr          dHTAddr
-       id            string
+       id            nodeID
        announceToken string
 
        lastGotQuery    time.Time
@@ -144,16 +189,33 @@ type Node struct {
        lastSentQuery   time.Time
 }
 
+func (n *Node) idString() string {
+       return n.id.String()
+}
+
+func (n *Node) SetIDFromBytes(b []byte) {
+       n.id.i.SetBytes(b)
+       n.id.set = true
+}
+
+func (n *Node) SetIDFromString(s string) {
+       n.id.i.SetBytes([]byte(s))
+}
+
+func (n *Node) IDNotSet() bool {
+       return n.id.i.Int64() == 0
+}
+
 func (n *Node) NodeInfo() (ret NodeInfo) {
        ret.Addr = n.addr
-       if n := copy(ret.ID[:], n.id); n != 20 {
+       if n := copy(ret.ID[:], n.idString()); n != 20 {
                panic(n)
        }
        return
 }
 
 func (n *Node) DefinitelyGood() bool {
-       if len(n.id) != 20 {
+       if len(n.idString()) != 20 {
                return false
        }
        // No reason to think ill of them if they've never been queried.
@@ -184,6 +246,13 @@ func (m Msg) T() (t string) {
        return
 }
 
+func (m Msg) ID() string {
+       defer func() {
+               recover()
+       }()
+       return m[m["y"].(string)].(map[string]interface{})["id"].(string)
+}
+
 func (m Msg) Nodes() []NodeInfo {
        var r findNodeResponse
        if err := r.UnmarshalKRPCMsg(m); err != nil {
@@ -447,14 +516,14 @@ func (s *Server) AddNode(ni NodeInfo) {
                s.nodes = make(map[string]*Node)
        }
        n := s.getNode(ni.Addr)
-       if n.id == "" {
-               n.id = string(ni.ID[:])
+       if n.IDNotSet() {
+               n.SetIDFromBytes(ni.ID[:])
        }
 }
 
 func (s *Server) nodeByID(id string) *Node {
        for _, node := range s.nodes {
-               if node.id == id {
+               if node.idString() == id {
                        return node
                }
        }
@@ -464,7 +533,7 @@ func (s *Server) nodeByID(id string) *Node {
 func (s *Server) handleQuery(source dHTAddr, m Msg) {
        args := m["a"].(map[string]interface{})
        node := s.getNode(source)
-       node.id = args["id"].(string)
+       node.SetIDFromString(args["id"].(string))
        node.lastGotQuery = time.Now()
        // Don't respond.
        if s.passive {
@@ -473,7 +542,7 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) {
        switch m["q"] {
        case "ping":
                s.reply(source, m["t"].(string), nil)
-       case "get_peers":
+       case "get_peers": // TODO: Extract common behaviour with find_node.
                targetID := args["info_hash"].(string)
                if len(targetID) != 20 {
                        break
@@ -494,7 +563,7 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) {
                        "nodes": string(nodesBytes),
                        "token": "hi",
                })
-       case "find_node":
+       case "find_node": // TODO: Extract common behaviour with get_peers.
                targetID := args["target"].(string)
                if len(targetID) != 20 {
                        log.Printf("bad DHT query: %v", m)
@@ -510,9 +579,14 @@ func (s *Server) handleQuery(source dHTAddr, m Msg) {
                }
                nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
                for i, ni := range rNodes {
+                       // TODO: Put IPv6 nodes into the correct dict element.
+                       if ni.Addr.UDPAddr().IP.To4() == nil {
+                               continue
+                       }
                        err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
                        if err != nil {
-                               panic(err)
+                               log.Printf("error compacting %#v: %s", ni, err)
+                               continue
                        }
                }
                s.reply(source, m["t"].(string), map[string]interface{}{
@@ -550,13 +624,14 @@ func (s *Server) reply(addr dHTAddr, t string, r map[string]interface{}) {
 }
 
 func (s *Server) getNode(addr dHTAddr) (n *Node) {
-       n = s.nodes[addr.String()]
+       addrStr := addr.String()
+       n = s.nodes[addrStr]
        if n == nil {
                n = &Node{
                        addr: addr,
                }
                if len(s.nodes) < maxNodes {
-                       s.nodes[addr.String()] = n
+                       s.nodes[addrStr] = n
                }
        }
        return
@@ -577,12 +652,12 @@ func (s *Server) nodeTimedOut(addr dHTAddr) {
 
 func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) {
        if list := s.ipBlockList; list != nil {
-               if r := list.Lookup(util.AddrIP(node)); r != nil {
+               if r := list.Lookup(util.AddrIP(node.UDPAddr())); r != nil {
                        err = fmt.Errorf("write to %s blocked: %s", node, r.Description)
                        return
                }
        }
-       n, err := s.socket.WriteTo(b, node)
+       n, err := s.socket.WriteTo(b, node.UDPAddr())
        if err != nil {
                err = fmt.Errorf("error writing %d bytes to %s: %s", len(b), node, err)
                return
@@ -672,7 +747,7 @@ func (ni *NodeInfo) PutCompact(b []byte) error {
        }
        ip := util.AddrIP(ni.Addr).To4()
        if len(ip) != 4 {
-               panic(ip)
+               return errors.New("expected ipv4 address")
        }
        if n := copy(b[20:], ip); n != 4 {
                panic(n)
@@ -707,7 +782,7 @@ func (s *Server) Ping(node *net.UDPAddr) (*transaction, error) {
 func (s *Server) AnnouncePeer(port int, impliedPort bool, infoHash string) (err error) {
        s.mu.Lock()
        defer s.mu.Unlock()
-       for _, node := range s.closestNodes(160, infoHash, func(n *Node) bool {
+       for _, node := range s.closestNodes(160, nodeIDFromString(infoHash), func(n *Node) bool {
                return n.announceToken != ""
        }) {
                err = s.announcePeer(node.addr, infoHash, port, node.announceToken, impliedPort)
@@ -841,7 +916,7 @@ func (s *Server) liftNodes(d Msg) {
                                continue
                        }
                        n := s.getNode(cni.Addr)
-                       n.id = string(cni.ID[:])
+                       n.SetIDFromBytes(cni.ID[:])
                }
                // log.Printf("lifted %d nodes", len(r.Nodes))
        }
@@ -1014,7 +1089,7 @@ func (s *Server) Nodes() (nis []NodeInfo) {
                ni := NodeInfo{
                        Addr: node.addr,
                }
-               if n := copy(ni.ID[:], node.id); n != 20 && n != 0 {
+               if n := copy(ni.ID[:], node.idString()); n != 20 && n != 0 {
                        panic(n)
                }
                nis = append(nis, ni)
@@ -1033,95 +1108,6 @@ func (s *Server) Close() {
        s.mu.Unlock()
 }
 
-type distance interface {
-       Cmp(distance) int
-       BitCount() int
-       IsZero() bool
-}
-
-type bigIntDistance struct {
-       big.Int
-}
-
-// How many bits?
-func bitCount(n *big.Int) int {
-       var count int = 0
-       for _, b := range n.Bytes() {
-               count += int(bitCounts[b])
-       }
-       return count
-}
-
-// The bit counts for each byte value (0 - 255).
-var bitCounts = []int8{
-       // Generated by Java BitCount of all values from 0 to 255
-       0, 1, 1, 2, 1, 2, 2, 3,
-       1, 2, 2, 3, 2, 3, 3, 4,
-       1, 2, 2, 3, 2, 3, 3, 4,
-       2, 3, 3, 4, 3, 4, 4, 5,
-       1, 2, 2, 3, 2, 3, 3, 4,
-       2, 3, 3, 4, 3, 4, 4, 5,
-       2, 3, 3, 4, 3, 4, 4, 5,
-       3, 4, 4, 5, 4, 5, 5, 6,
-       1, 2, 2, 3, 2, 3, 3, 4,
-       2, 3, 3, 4, 3, 4, 4, 5,
-       2, 3, 3, 4, 3, 4, 4, 5,
-       3, 4, 4, 5, 4, 5, 5, 6,
-       2, 3, 3, 4, 3, 4, 4, 5,
-       3, 4, 4, 5, 4, 5, 5, 6,
-       3, 4, 4, 5, 4, 5, 5, 6,
-       4, 5, 5, 6, 5, 6, 6, 7,
-       1, 2, 2, 3, 2, 3, 3, 4,
-       2, 3, 3, 4, 3, 4, 4, 5,
-       2, 3, 3, 4, 3, 4, 4, 5,
-       3, 4, 4, 5, 4, 5, 5, 6,
-       2, 3, 3, 4, 3, 4, 4, 5,
-       3, 4, 4, 5, 4, 5, 5, 6,
-       3, 4, 4, 5, 4, 5, 5, 6,
-       4, 5, 5, 6, 5, 6, 6, 7,
-       2, 3, 3, 4, 3, 4, 4, 5,
-       3, 4, 4, 5, 4, 5, 5, 6,
-       3, 4, 4, 5, 4, 5, 5, 6,
-       4, 5, 5, 6, 5, 6, 6, 7,
-       3, 4, 4, 5, 4, 5, 5, 6,
-       4, 5, 5, 6, 5, 6, 6, 7,
-       4, 5, 5, 6, 5, 6, 6, 7,
-       5, 6, 6, 7, 6, 7, 7, 8,
-}
-
-func (me bigIntDistance) BitCount() int {
-       return bitCount(&me.Int)
-}
-
-func (me bigIntDistance) Cmp(d bigIntDistance) int {
-       return me.Int.Cmp(&d.Int)
-}
-
-func (me bigIntDistance) IsZero() bool {
-       var zero big.Int
-       return me.Int.Cmp(&zero) == 0
-}
-
-type bitCountDistance int
-
-func (me bitCountDistance) BitCount() int { return int(me) }
-
-func (me bitCountDistance) Cmp(rhs distance) int {
-       rhs_ := rhs.(bitCountDistance)
-       if me < rhs_ {
-               return -1
-       } else if me == rhs_ {
-               return 0
-       } else {
-               return 1
-       }
-}
-
-func (me bitCountDistance) IsZero() bool {
-       return me == 0
-}
-
-// Below are 2 versions of idDistance. Only one can be active.
 var maxDistance big.Int
 
 func init() {
@@ -1129,67 +1115,24 @@ func init() {
        maxDistance.SetBit(&zero, 160, 1)
 }
 
-// If we don't know the ID for a node, then its distance is more than the
-// furthest possible distance otherwise.
-func idDistance(a, b string) (ret bigIntDistance) {
-       if a == "" && b == "" {
-               return
-       }
-       if a == "" {
-               if len(b) != 20 {
-                       panic(b)
-               }
-               ret.Set(&maxDistance)
-               return
-       }
-       if b == "" {
-               if len(a) != 20 {
-                       panic(a)
-               }
-               ret.Set(&maxDistance)
-               return
-       }
-       if len(a) != 20 {
-               panic(a)
-       }
-       if len(b) != 20 {
-               panic(b)
-       }
-       var x, y big.Int
-       x.SetBytes([]byte(a))
-       y.SetBytes([]byte(b))
-       ret.Int.Xor(&x, &y)
-       return ret
-}
-
-// func idDistance(a, b string) bitCountDistance {
-//     ret := 0
-//     for i := 0; i < 20; i++ {
-//             for j := uint(0); j < 8; j++ {
-//                     ret += int(a[i]>>j&1 ^ b[i]>>j&1)
-//             }
-//     }
-//     return bitCountDistance(ret)
-// }
-
 func (s *Server) closestGoodNodes(k int, targetID string) []*Node {
-       return s.closestNodes(k, targetID, func(n *Node) bool { return n.DefinitelyGood() })
+       return s.closestNodes(k, nodeIDFromString(targetID), func(n *Node) bool { return n.DefinitelyGood() })
 }
 
-func (s *Server) closestNodes(k int, targetID string, filter func(*Node) bool) []*Node {
-       sel := newKClosestNodesSelector(k, targetID)
+func (s *Server) closestNodes(k int, target nodeID, filter func(*Node) bool) []*Node {
+       sel := newKClosestNodesSelector(k, target)
        idNodes := make(map[string]*Node, len(s.nodes))
        for _, node := range s.nodes {
                if !filter(node) {
                        continue
                }
                sel.Push(node.id)
-               idNodes[node.id] = node
+               idNodes[node.idString()] = node
        }
        ids := sel.IDs()
        ret := make([]*Node, 0, len(ids))
        for _, id := range ids {
-               ret = append(ret, idNodes[id])
+               ret = append(ret, idNodes[id.String()])
        }
        return ret
 }
index 137aa1366ee2f6e37ef0a4bba7afd119be323f86..8a8a1f7f46d93b4914a7daea10e528ab1cf3c9b5 100644 (file)
@@ -46,29 +46,34 @@ func recoverPanicOrDie(t *testing.T, f func()) {
 
 const zeroID = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
 
-var testIDs = []string{
-       zeroID,
-       "\x03" + zeroID[1:],
-       "\x03" + zeroID[1:18] + "\x55\xf0",
-       "\x55" + zeroID[1:17] + "\xff\x55\x0f",
-       "\x54" + zeroID[1:18] + "\x50\x0f",
-       "",
+var testIDs []nodeID
+
+func init() {
+       for _, s := range []string{
+               zeroID,
+               "\x03" + zeroID[1:],
+               "\x03" + zeroID[1:18] + "\x55\xf0",
+               "\x55" + zeroID[1:17] + "\xff\x55\x0f",
+               "\x54" + zeroID[1:18] + "\x50\x0f",
+               "",
+       } {
+               testIDs = append(testIDs, nodeIDFromString(s))
+       }
 }
 
 func TestDistances(t *testing.T) {
-       if idDistance(testIDs[3], testIDs[0]).BitCount() != 4+8+4+4 {
-               t.FailNow()
-       }
-       if idDistance(testIDs[3], testIDs[1]).BitCount() != 4+8+4+4 {
-               t.FailNow()
-       }
-       if idDistance(testIDs[3], testIDs[2]).BitCount() != 4+8+8 {
-               t.FailNow()
+       expectBitcount := func(i big.Int, count int) {
+               if bitCount(i) != count {
+                       t.Fatalf("expected bitcount of %d: got %d", count, bitCount(i))
+               }
        }
+       expectBitcount(testIDs[3].Distance(&testIDs[0]), 4+8+4+4)
+       expectBitcount(testIDs[3].Distance(&testIDs[1]), 4+8+4+4)
+       expectBitcount(testIDs[3].Distance(&testIDs[2]), 4+8+8)
        for i := 0; i < 5; i++ {
-               dist := idDistance(testIDs[i], testIDs[5]).Int
+               dist := testIDs[i].Distance(&testIDs[5])
                if dist.Cmp(&maxDistance) != 0 {
-                       t.FailNow()
+                       t.Fatal("expected max distance for comparison with unset node id")
                }
        }
 }
@@ -79,37 +84,6 @@ func TestMaxDistanceString(t *testing.T) {
        }
 }
 
-func TestBadIdStrings(t *testing.T) {
-       var a, b string
-       idDistance(a, b)
-       idDistance(a, zeroID)
-       idDistance(zeroID, b)
-       recoverPanicOrDie(t, func() {
-               idDistance("when", a)
-       })
-       recoverPanicOrDie(t, func() {
-               idDistance(a, "bad")
-       })
-       recoverPanicOrDie(t, func() {
-               idDistance("meets", "evil")
-       })
-       for _, id := range testIDs {
-               if !idDistance(id, id).IsZero() {
-                       t.Fatal("identical IDs should have distance 0")
-               }
-       }
-       a = "\x03" + zeroID[1:]
-       b = zeroID
-       if idDistance(a, b).BitCount() != 2 {
-               t.FailNow()
-       }
-       a = "\x03" + zeroID[1:18] + "\x55\xf0"
-       b = "\x55" + zeroID[1:17] + "\xff\x55\x0f"
-       if c := idDistance(a, b).BitCount(); c != 20 {
-               t.Fatal(c)
-       }
-}
-
 func TestClosestNodes(t *testing.T) {
        cn := newKClosestNodesSelector(2, testIDs[3])
        for _, i := range rand.Perm(len(testIDs)) {
@@ -120,9 +94,9 @@ func TestClosestNodes(t *testing.T) {
        }
        m := map[string]bool{}
        for _, id := range cn.IDs() {
-               m[id] = true
+               m[id.String()] = true
        }
-       if !m[testIDs[3]] || !m[testIDs[4]] {
+       if !m[testIDs[3].String()] || !m[testIDs[4].String()] {
                t.FailNow()
        }
 }
@@ -154,3 +128,28 @@ func TestDHTDefaultConfig(t *testing.T) {
        }
        s.Close()
 }
+
+func TestPing(t *testing.T) {
+       srv, err := NewServer(nil)
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer srv.Close()
+       srv0, err := NewServer(nil)
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer srv0.Close()
+       tn, err := srv.Ping(&net.UDPAddr{
+               IP:   []byte{127, 0, 0, 1},
+               Port: srv0.LocalAddr().(*net.UDPAddr).Port,
+       })
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer tn.Close()
+       msg := <-tn.Response
+       if msg.ID() != srv0.IDString() {
+               t.FailNow()
+       }
+}
index eb8614c7a63d0a533c3d1d139e38cdc6b238421e..31537cc67b75f234fad6fa7ad3792c16dc590040 100644 (file)
@@ -1,12 +1,12 @@
 package dht
 
 import (
+       "log"
+       "time"
+
        "bitbucket.org/anacrolix/go.torrent/util"
        "bitbucket.org/anacrolix/sync"
        "github.com/willf/bloom"
-       "log"
-       "net"
-       "time"
 )
 
 type peerDiscovery struct {
@@ -19,7 +19,7 @@ type peerDiscovery struct {
 
 func (s *Server) GetPeers(infoHash string) (*peerStream, error) {
        s.mu.Lock()
-       startAddrs := func() (ret []net.Addr) {
+       startAddrs := func() (ret []dHTAddr) {
                for _, n := range s.closestGoodNodes(160, infoHash) {
                        ret = append(ret, n.addr)
                }
@@ -32,7 +32,7 @@ func (s *Server) GetPeers(infoHash string) (*peerStream, error) {
                        return nil, err
                }
                for _, addr := range addrs {
-                       startAddrs = append(startAddrs, addr)
+                       startAddrs = append(startAddrs, newDHTAddr(addr))
                }
        }
        disc := &peerDiscovery{
@@ -41,7 +41,7 @@ func (s *Server) GetPeers(infoHash string) (*peerStream, error) {
                        stop:   make(chan struct{}),
                        values: make(chan peerStreamValue),
                },
-               triedAddrs: bloom.NewWithEstimates(500000, 0.01),
+               triedAddrs: bloom.NewWithEstimates(10000, 0.01),
                server:     s,
                infoHash:   infoHash,
        }
@@ -72,7 +72,7 @@ func (s *Server) GetPeers(infoHash string) (*peerStream, error) {
        return disc.peerStream, nil
 }
 
-func (me *peerDiscovery) gotNodeAddr(addr net.Addr) {
+func (me *peerDiscovery) gotNodeAddr(addr dHTAddr) {
        if util.AddrPort(addr) == 0 {
                // Not a contactable address.
                return
@@ -86,7 +86,7 @@ func (me *peerDiscovery) gotNodeAddr(addr net.Addr) {
        me.contact(addr)
 }
 
-func (me *peerDiscovery) contact(addr net.Addr) {
+func (me *peerDiscovery) contact(addr dHTAddr) {
        me.triedAddrs.Add([]byte(addr.String()))
        if err := me.getPeers(addr); err != nil {
                log.Printf("error sending get_peers request to %s: %s", addr, err)
@@ -111,7 +111,7 @@ func (me *peerDiscovery) closingCh() chan struct{} {
        return me.peerStream.stop
 }
 
-func (me *peerDiscovery) getPeers(addr net.Addr) error {
+func (me *peerDiscovery) getPeers(addr dHTAddr) error {
        me.server.mu.Lock()
        defer me.server.mu.Unlock()
        t, err := me.server.getPeers(addr, me.infoHash)