--- /dev/null
+package dht
+
+import (
+ "math/big"
+)
+
+// TODO: The bitcounting is a relic of the old and incorrect distance
+// calculation. It is still useful in some tests but should eventually be
+// replaced with actual distances.
+
+// How many bits?
+func bitCount(n big.Int) int {
+ var count int = 0
+ for _, b := range n.Bytes() {
+ count += int(bitCounts[b])
+ }
+ return count
+}
+
+// The bit counts for each byte value (0 - 255).
+var bitCounts = []int8{
+ // Generated by Java BitCount of all values from 0 to 255
+ 0, 1, 1, 2, 1, 2, 2, 3,
+ 1, 2, 2, 3, 2, 3, 3, 4,
+ 1, 2, 2, 3, 2, 3, 3, 4,
+ 2, 3, 3, 4, 3, 4, 4, 5,
+ 1, 2, 2, 3, 2, 3, 3, 4,
+ 2, 3, 3, 4, 3, 4, 4, 5,
+ 2, 3, 3, 4, 3, 4, 4, 5,
+ 3, 4, 4, 5, 4, 5, 5, 6,
+ 1, 2, 2, 3, 2, 3, 3, 4,
+ 2, 3, 3, 4, 3, 4, 4, 5,
+ 2, 3, 3, 4, 3, 4, 4, 5,
+ 3, 4, 4, 5, 4, 5, 5, 6,
+ 2, 3, 3, 4, 3, 4, 4, 5,
+ 3, 4, 4, 5, 4, 5, 5, 6,
+ 3, 4, 4, 5, 4, 5, 5, 6,
+ 4, 5, 5, 6, 5, 6, 6, 7,
+ 1, 2, 2, 3, 2, 3, 3, 4,
+ 2, 3, 3, 4, 3, 4, 4, 5,
+ 2, 3, 3, 4, 3, 4, 4, 5,
+ 3, 4, 4, 5, 4, 5, 5, 6,
+ 2, 3, 3, 4, 3, 4, 4, 5,
+ 3, 4, 4, 5, 4, 5, 5, 6,
+ 3, 4, 4, 5, 4, 5, 5, 6,
+ 4, 5, 5, 6, 5, 6, 6, 7,
+ 2, 3, 3, 4, 3, 4, 4, 5,
+ 3, 4, 4, 5, 4, 5, 5, 6,
+ 3, 4, 4, 5, 4, 5, 5, 6,
+ 4, 5, 5, 6, 5, 6, 6, 7,
+ 3, 4, 4, 5, 4, 5, 5, 6,
+ 4, 5, 5, 6, 5, 6, 6, 7,
+ 4, 5, 5, 6, 5, 6, 6, 7,
+ 5, 6, 6, 7, 6, 7, 7, 8,
+}
package dht
import (
+ "bitbucket.org/anacrolix/go.torrent/iplist"
+ "bitbucket.org/anacrolix/go.torrent/logonce"
+ "bitbucket.org/anacrolix/go.torrent/util"
+ "bitbucket.org/anacrolix/sync"
"crypto"
_ "crypto/sha1"
"encoding/binary"
"errors"
"fmt"
+ "github.com/anacrolix/libtorgo/bencode"
"io"
"log"
"math/big"
"net"
"os"
"time"
-
- "bitbucket.org/anacrolix/sync"
-
- "bitbucket.org/anacrolix/go.torrent/iplist"
-
- "bitbucket.org/anacrolix/go.torrent/logonce"
- "bitbucket.org/anacrolix/go.torrent/util"
- "github.com/anacrolix/libtorgo/bencode"
)
const maxNodes = 10000
type dHTAddr interface {
net.Addr
+ UDPAddr() *net.UDPAddr
}
-func newDHTAddr(addr *net.UDPAddr) (ret dHTAddr) {
- ret = addr
- return
+type cachedAddr struct {
+ a net.Addr
+ s string
+}
+
+func (ca cachedAddr) Network() string {
+ return ca.a.Network()
+}
+
+func (ca cachedAddr) String() string {
+ return ca.s
+}
+
+func (ca cachedAddr) UDPAddr() *net.UDPAddr {
+ return ca.a.(*net.UDPAddr)
+}
+
+func newDHTAddr(addr *net.UDPAddr) dHTAddr {
+ return cachedAddr{addr, addr.String()}
}
type ServerConfig struct {
return fmt.Sprintf("dht server on %s", s.socket.LocalAddr())
}
+type nodeID struct {
+ i big.Int
+ set bool
+}
+
+func (nid *nodeID) IsUnset() bool {
+ return !nid.set
+}
+
+func nodeIDFromString(s string) (ret nodeID) {
+ if s == "" {
+ return
+ }
+ ret.i.SetBytes([]byte(s))
+ ret.set = true
+ return
+}
+
+func (nid0 *nodeID) Distance(nid1 *nodeID) (ret big.Int) {
+ if nid0.IsUnset() != nid1.IsUnset() {
+ ret = maxDistance
+ return
+ }
+ ret.Xor(&nid0.i, &nid1.i)
+ return
+}
+
+func (nid *nodeID) String() string {
+ return string(nid.i.Bytes())
+}
+
type Node struct {
addr dHTAddr
- id string
+ id nodeID
announceToken string
lastGotQuery time.Time
lastSentQuery time.Time
}
+func (n *Node) idString() string {
+ return n.id.String()
+}
+
+func (n *Node) SetIDFromBytes(b []byte) {
+ n.id.i.SetBytes(b)
+ n.id.set = true
+}
+
+func (n *Node) SetIDFromString(s string) {
+ n.id.i.SetBytes([]byte(s))
+}
+
+func (n *Node) IDNotSet() bool {
+ return n.id.i.Int64() == 0
+}
+
func (n *Node) NodeInfo() (ret NodeInfo) {
ret.Addr = n.addr
- if n := copy(ret.ID[:], n.id); n != 20 {
+ if n := copy(ret.ID[:], n.idString()); n != 20 {
panic(n)
}
return
}
func (n *Node) DefinitelyGood() bool {
- if len(n.id) != 20 {
+ if len(n.idString()) != 20 {
return false
}
// No reason to think ill of them if they've never been queried.
return
}
+func (m Msg) ID() string {
+ defer func() {
+ recover()
+ }()
+ return m[m["y"].(string)].(map[string]interface{})["id"].(string)
+}
+
func (m Msg) Nodes() []NodeInfo {
var r findNodeResponse
if err := r.UnmarshalKRPCMsg(m); err != nil {
s.nodes = make(map[string]*Node)
}
n := s.getNode(ni.Addr)
- if n.id == "" {
- n.id = string(ni.ID[:])
+ if n.IDNotSet() {
+ n.SetIDFromBytes(ni.ID[:])
}
}
func (s *Server) nodeByID(id string) *Node {
for _, node := range s.nodes {
- if node.id == id {
+ if node.idString() == id {
return node
}
}
func (s *Server) handleQuery(source dHTAddr, m Msg) {
args := m["a"].(map[string]interface{})
node := s.getNode(source)
- node.id = args["id"].(string)
+ node.SetIDFromString(args["id"].(string))
node.lastGotQuery = time.Now()
// Don't respond.
if s.passive {
switch m["q"] {
case "ping":
s.reply(source, m["t"].(string), nil)
- case "get_peers":
+ case "get_peers": // TODO: Extract common behaviour with find_node.
targetID := args["info_hash"].(string)
if len(targetID) != 20 {
break
"nodes": string(nodesBytes),
"token": "hi",
})
- case "find_node":
+ case "find_node": // TODO: Extract common behaviour with get_peers.
targetID := args["target"].(string)
if len(targetID) != 20 {
log.Printf("bad DHT query: %v", m)
}
nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
for i, ni := range rNodes {
+ // TODO: Put IPv6 nodes into the correct dict element.
+ if ni.Addr.UDPAddr().IP.To4() == nil {
+ continue
+ }
err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
if err != nil {
- panic(err)
+ log.Printf("error compacting %#v: %s", ni, err)
+ continue
}
}
s.reply(source, m["t"].(string), map[string]interface{}{
}
func (s *Server) getNode(addr dHTAddr) (n *Node) {
- n = s.nodes[addr.String()]
+ addrStr := addr.String()
+ n = s.nodes[addrStr]
if n == nil {
n = &Node{
addr: addr,
}
if len(s.nodes) < maxNodes {
- s.nodes[addr.String()] = n
+ s.nodes[addrStr] = n
}
}
return
func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) {
if list := s.ipBlockList; list != nil {
- if r := list.Lookup(util.AddrIP(node)); r != nil {
+ if r := list.Lookup(util.AddrIP(node.UDPAddr())); r != nil {
err = fmt.Errorf("write to %s blocked: %s", node, r.Description)
return
}
}
- n, err := s.socket.WriteTo(b, node)
+ n, err := s.socket.WriteTo(b, node.UDPAddr())
if err != nil {
err = fmt.Errorf("error writing %d bytes to %s: %s", len(b), node, err)
return
}
ip := util.AddrIP(ni.Addr).To4()
if len(ip) != 4 {
- panic(ip)
+ return errors.New("expected ipv4 address")
}
if n := copy(b[20:], ip); n != 4 {
panic(n)
func (s *Server) AnnouncePeer(port int, impliedPort bool, infoHash string) (err error) {
s.mu.Lock()
defer s.mu.Unlock()
- for _, node := range s.closestNodes(160, infoHash, func(n *Node) bool {
+ for _, node := range s.closestNodes(160, nodeIDFromString(infoHash), func(n *Node) bool {
return n.announceToken != ""
}) {
err = s.announcePeer(node.addr, infoHash, port, node.announceToken, impliedPort)
continue
}
n := s.getNode(cni.Addr)
- n.id = string(cni.ID[:])
+ n.SetIDFromBytes(cni.ID[:])
}
// log.Printf("lifted %d nodes", len(r.Nodes))
}
ni := NodeInfo{
Addr: node.addr,
}
- if n := copy(ni.ID[:], node.id); n != 20 && n != 0 {
+ if n := copy(ni.ID[:], node.idString()); n != 20 && n != 0 {
panic(n)
}
nis = append(nis, ni)
s.mu.Unlock()
}
-type distance interface {
- Cmp(distance) int
- BitCount() int
- IsZero() bool
-}
-
-type bigIntDistance struct {
- big.Int
-}
-
-// How many bits?
-func bitCount(n *big.Int) int {
- var count int = 0
- for _, b := range n.Bytes() {
- count += int(bitCounts[b])
- }
- return count
-}
-
-// The bit counts for each byte value (0 - 255).
-var bitCounts = []int8{
- // Generated by Java BitCount of all values from 0 to 255
- 0, 1, 1, 2, 1, 2, 2, 3,
- 1, 2, 2, 3, 2, 3, 3, 4,
- 1, 2, 2, 3, 2, 3, 3, 4,
- 2, 3, 3, 4, 3, 4, 4, 5,
- 1, 2, 2, 3, 2, 3, 3, 4,
- 2, 3, 3, 4, 3, 4, 4, 5,
- 2, 3, 3, 4, 3, 4, 4, 5,
- 3, 4, 4, 5, 4, 5, 5, 6,
- 1, 2, 2, 3, 2, 3, 3, 4,
- 2, 3, 3, 4, 3, 4, 4, 5,
- 2, 3, 3, 4, 3, 4, 4, 5,
- 3, 4, 4, 5, 4, 5, 5, 6,
- 2, 3, 3, 4, 3, 4, 4, 5,
- 3, 4, 4, 5, 4, 5, 5, 6,
- 3, 4, 4, 5, 4, 5, 5, 6,
- 4, 5, 5, 6, 5, 6, 6, 7,
- 1, 2, 2, 3, 2, 3, 3, 4,
- 2, 3, 3, 4, 3, 4, 4, 5,
- 2, 3, 3, 4, 3, 4, 4, 5,
- 3, 4, 4, 5, 4, 5, 5, 6,
- 2, 3, 3, 4, 3, 4, 4, 5,
- 3, 4, 4, 5, 4, 5, 5, 6,
- 3, 4, 4, 5, 4, 5, 5, 6,
- 4, 5, 5, 6, 5, 6, 6, 7,
- 2, 3, 3, 4, 3, 4, 4, 5,
- 3, 4, 4, 5, 4, 5, 5, 6,
- 3, 4, 4, 5, 4, 5, 5, 6,
- 4, 5, 5, 6, 5, 6, 6, 7,
- 3, 4, 4, 5, 4, 5, 5, 6,
- 4, 5, 5, 6, 5, 6, 6, 7,
- 4, 5, 5, 6, 5, 6, 6, 7,
- 5, 6, 6, 7, 6, 7, 7, 8,
-}
-
-func (me bigIntDistance) BitCount() int {
- return bitCount(&me.Int)
-}
-
-func (me bigIntDistance) Cmp(d bigIntDistance) int {
- return me.Int.Cmp(&d.Int)
-}
-
-func (me bigIntDistance) IsZero() bool {
- var zero big.Int
- return me.Int.Cmp(&zero) == 0
-}
-
-type bitCountDistance int
-
-func (me bitCountDistance) BitCount() int { return int(me) }
-
-func (me bitCountDistance) Cmp(rhs distance) int {
- rhs_ := rhs.(bitCountDistance)
- if me < rhs_ {
- return -1
- } else if me == rhs_ {
- return 0
- } else {
- return 1
- }
-}
-
-func (me bitCountDistance) IsZero() bool {
- return me == 0
-}
-
-// Below are 2 versions of idDistance. Only one can be active.
var maxDistance big.Int
func init() {
maxDistance.SetBit(&zero, 160, 1)
}
-// If we don't know the ID for a node, then its distance is more than the
-// furthest possible distance otherwise.
-func idDistance(a, b string) (ret bigIntDistance) {
- if a == "" && b == "" {
- return
- }
- if a == "" {
- if len(b) != 20 {
- panic(b)
- }
- ret.Set(&maxDistance)
- return
- }
- if b == "" {
- if len(a) != 20 {
- panic(a)
- }
- ret.Set(&maxDistance)
- return
- }
- if len(a) != 20 {
- panic(a)
- }
- if len(b) != 20 {
- panic(b)
- }
- var x, y big.Int
- x.SetBytes([]byte(a))
- y.SetBytes([]byte(b))
- ret.Int.Xor(&x, &y)
- return ret
-}
-
-// func idDistance(a, b string) bitCountDistance {
-// ret := 0
-// for i := 0; i < 20; i++ {
-// for j := uint(0); j < 8; j++ {
-// ret += int(a[i]>>j&1 ^ b[i]>>j&1)
-// }
-// }
-// return bitCountDistance(ret)
-// }
-
func (s *Server) closestGoodNodes(k int, targetID string) []*Node {
- return s.closestNodes(k, targetID, func(n *Node) bool { return n.DefinitelyGood() })
+ return s.closestNodes(k, nodeIDFromString(targetID), func(n *Node) bool { return n.DefinitelyGood() })
}
-func (s *Server) closestNodes(k int, targetID string, filter func(*Node) bool) []*Node {
- sel := newKClosestNodesSelector(k, targetID)
+func (s *Server) closestNodes(k int, target nodeID, filter func(*Node) bool) []*Node {
+ sel := newKClosestNodesSelector(k, target)
idNodes := make(map[string]*Node, len(s.nodes))
for _, node := range s.nodes {
if !filter(node) {
continue
}
sel.Push(node.id)
- idNodes[node.id] = node
+ idNodes[node.idString()] = node
}
ids := sel.IDs()
ret := make([]*Node, 0, len(ids))
for _, id := range ids {
- ret = append(ret, idNodes[id])
+ ret = append(ret, idNodes[id.String()])
}
return ret
}
const zeroID = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
-var testIDs = []string{
- zeroID,
- "\x03" + zeroID[1:],
- "\x03" + zeroID[1:18] + "\x55\xf0",
- "\x55" + zeroID[1:17] + "\xff\x55\x0f",
- "\x54" + zeroID[1:18] + "\x50\x0f",
- "",
+var testIDs []nodeID
+
+func init() {
+ for _, s := range []string{
+ zeroID,
+ "\x03" + zeroID[1:],
+ "\x03" + zeroID[1:18] + "\x55\xf0",
+ "\x55" + zeroID[1:17] + "\xff\x55\x0f",
+ "\x54" + zeroID[1:18] + "\x50\x0f",
+ "",
+ } {
+ testIDs = append(testIDs, nodeIDFromString(s))
+ }
}
func TestDistances(t *testing.T) {
- if idDistance(testIDs[3], testIDs[0]).BitCount() != 4+8+4+4 {
- t.FailNow()
- }
- if idDistance(testIDs[3], testIDs[1]).BitCount() != 4+8+4+4 {
- t.FailNow()
- }
- if idDistance(testIDs[3], testIDs[2]).BitCount() != 4+8+8 {
- t.FailNow()
+ expectBitcount := func(i big.Int, count int) {
+ if bitCount(i) != count {
+ t.Fatalf("expected bitcount of %d: got %d", count, bitCount(i))
+ }
}
+ expectBitcount(testIDs[3].Distance(&testIDs[0]), 4+8+4+4)
+ expectBitcount(testIDs[3].Distance(&testIDs[1]), 4+8+4+4)
+ expectBitcount(testIDs[3].Distance(&testIDs[2]), 4+8+8)
for i := 0; i < 5; i++ {
- dist := idDistance(testIDs[i], testIDs[5]).Int
+ dist := testIDs[i].Distance(&testIDs[5])
if dist.Cmp(&maxDistance) != 0 {
- t.FailNow()
+ t.Fatal("expected max distance for comparison with unset node id")
}
}
}
}
}
-func TestBadIdStrings(t *testing.T) {
- var a, b string
- idDistance(a, b)
- idDistance(a, zeroID)
- idDistance(zeroID, b)
- recoverPanicOrDie(t, func() {
- idDistance("when", a)
- })
- recoverPanicOrDie(t, func() {
- idDistance(a, "bad")
- })
- recoverPanicOrDie(t, func() {
- idDistance("meets", "evil")
- })
- for _, id := range testIDs {
- if !idDistance(id, id).IsZero() {
- t.Fatal("identical IDs should have distance 0")
- }
- }
- a = "\x03" + zeroID[1:]
- b = zeroID
- if idDistance(a, b).BitCount() != 2 {
- t.FailNow()
- }
- a = "\x03" + zeroID[1:18] + "\x55\xf0"
- b = "\x55" + zeroID[1:17] + "\xff\x55\x0f"
- if c := idDistance(a, b).BitCount(); c != 20 {
- t.Fatal(c)
- }
-}
-
func TestClosestNodes(t *testing.T) {
cn := newKClosestNodesSelector(2, testIDs[3])
for _, i := range rand.Perm(len(testIDs)) {
}
m := map[string]bool{}
for _, id := range cn.IDs() {
- m[id] = true
+ m[id.String()] = true
}
- if !m[testIDs[3]] || !m[testIDs[4]] {
+ if !m[testIDs[3].String()] || !m[testIDs[4].String()] {
t.FailNow()
}
}
}
s.Close()
}
+
+func TestPing(t *testing.T) {
+ srv, err := NewServer(nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer srv.Close()
+ srv0, err := NewServer(nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer srv0.Close()
+ tn, err := srv.Ping(&net.UDPAddr{
+ IP: []byte{127, 0, 0, 1},
+ Port: srv0.LocalAddr().(*net.UDPAddr).Port,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer tn.Close()
+ msg := <-tn.Response
+ if msg.ID() != srv0.IDString() {
+ t.FailNow()
+ }
+}