]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Expose DHT ID distances as an interface and switch to big.Int and possibly the correc...
authorMatt Joiner <anacrolix@gmail.com>
Mon, 17 Nov 2014 07:47:24 +0000 (01:47 -0600)
committerMatt Joiner <anacrolix@gmail.com>
Mon, 17 Nov 2014 07:47:24 +0000 (01:47 -0600)
dht/closest_nodes.go
dht/dht.go
dht/dht_test.go

index a677a2e640bf76701069971d0a64cb7d54b23554..70366dd54ce41823d05330d4d47e35dd0d7657e7 100644 (file)
@@ -12,7 +12,7 @@ type nodeMaxHeap struct {
 func (me nodeMaxHeap) Len() int { return len(me.IDs) }
 
 func (me nodeMaxHeap) Less(i, j int) bool {
-       return idDistance(me.IDs[i], me.Target) > idDistance(me.IDs[j], me.Target)
+       return idDistance(me.IDs[i], me.Target).Cmp(idDistance(me.IDs[j], me.Target)) > 0
 }
 
 func (me *nodeMaxHeap) Pop() (ret interface{}) {
index 6e5d29cb7d02ea366df0e3b2d6866cadb325cf66..094a6bc63cfeb757ea99ead0265f02375b810976 100644 (file)
@@ -8,6 +8,7 @@ import (
        "fmt"
        "io"
        "log"
+       "math/big"
        "net"
        "os"
        "sync"
@@ -901,19 +902,116 @@ func (s *Server) Close() {
        s.mu.Unlock()
 }
 
-func idDistance(a, b string) (ret int) {
-       if len(a) != 20 {
-               panic(a)
-       }
-       if len(b) != 20 {
-               panic(b)
+type distance interface {
+       Cmp(distance) int
+       BitCount() int
+       IsZero() bool
+}
+
+type bigIntDistance struct {
+       *big.Int
+}
+
+// How many bits?
+func bitCount(n *big.Int) int {
+       var count int = 0
+       for _, b := range n.Bytes() {
+               count += int(bitCounts[b])
+       }
+       return count
+}
+
+// The bit counts for each byte value (0 - 255).
+var bitCounts = []int8{
+       // Generated by Java BitCount of all values from 0 to 255
+       0, 1, 1, 2, 1, 2, 2, 3,
+       1, 2, 2, 3, 2, 3, 3, 4,
+       1, 2, 2, 3, 2, 3, 3, 4,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       1, 2, 2, 3, 2, 3, 3, 4,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       1, 2, 2, 3, 2, 3, 3, 4,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       4, 5, 5, 6, 5, 6, 6, 7,
+       1, 2, 2, 3, 2, 3, 3, 4,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       4, 5, 5, 6, 5, 6, 6, 7,
+       2, 3, 3, 4, 3, 4, 4, 5,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       4, 5, 5, 6, 5, 6, 6, 7,
+       3, 4, 4, 5, 4, 5, 5, 6,
+       4, 5, 5, 6, 5, 6, 6, 7,
+       4, 5, 5, 6, 5, 6, 6, 7,
+       5, 6, 6, 7, 6, 7, 7, 8,
+}
+
+func (me bigIntDistance) BitCount() int {
+       return bitCount(me.Int)
+}
+
+func (me bigIntDistance) Cmp(d distance) int {
+       return me.Int.Cmp(d.(bigIntDistance).Int)
+}
+
+func (me bigIntDistance) IsZero() bool {
+       return me.Int.Cmp(big.NewInt(0)) == 0
+}
+
+type bitCountDistance int
+
+func (me bitCountDistance) BitCount() int { return int(me) }
+
+func (me bitCountDistance) Cmp(rhs distance) int {
+       rhs_ := rhs.(bitCountDistance)
+       if me < rhs_ {
+               return -1
+       } else if me == rhs_ {
+               return 0
+       } else {
+               return 1
        }
-       for i := 0; i < 20; i++ {
-               for j := uint(0); j < 8; j++ {
-                       ret += int(a[i]>>j&1 ^ b[i]>>j&1)
+}
+
+func (me bitCountDistance) IsZero() bool {
+       return me == 0
+}
+
+func idDistance(a, b string) distance {
+       if true {
+               if len(a) != 20 {
+                       panic(a)
                }
+               if len(b) != 20 {
+                       panic(b)
+               }
+               x := new(big.Int)
+               y := new(big.Int)
+               x.SetBytes([]byte(a))
+               y.SetBytes([]byte(b))
+               dist := new(big.Int)
+               return bigIntDistance{dist.Xor(x, y)}
+       } else {
+               ret := 0
+               for i := 0; i < 20; i++ {
+                       for j := uint(0); j < 8; j++ {
+                               ret += int(a[i]>>j&1 ^ b[i]>>j&1)
+                       }
+               }
+               return bitCountDistance(ret)
        }
-       return
 }
 
 func (s *Server) closestGoodNodes(k int, targetID string) []*Node {
index 71afc5ebaa0aeaf2024534efee29e8c1a9e4039f..5b5b5d1d5fdbe055a9069c97feafc9ec5b6c60ba 100644 (file)
@@ -1,11 +1,17 @@
 package dht
 
 import (
+       "math/big"
        "math/rand"
        "net"
        "testing"
 )
 
+func TestSetNilBigInt(t *testing.T) {
+       i := new(big.Int)
+       i.SetBytes(make([]byte, 2))
+}
+
 func TestMarshalCompactNodeInfo(t *testing.T) {
        cni := NodeInfo{
                ID: [20]byte{'a', 'b', 'c'},
@@ -49,13 +55,13 @@ var testIDs = []string{
 }
 
 func TestDistances(t *testing.T) {
-       if idDistance(testIDs[3], testIDs[0]) != 4+8+4+4 {
+       if idDistance(testIDs[3], testIDs[0]).BitCount() != 4+8+4+4 {
                t.FailNow()
        }
-       if idDistance(testIDs[3], testIDs[1]) != 4+8+4+4 {
+       if idDistance(testIDs[3], testIDs[1]).BitCount() != 4+8+4+4 {
                t.FailNow()
        }
-       if idDistance(testIDs[3], testIDs[2]) != 4+8+8 {
+       if idDistance(testIDs[3], testIDs[2]).BitCount() != 4+8+8 {
                t.FailNow()
        }
 }
@@ -71,17 +77,17 @@ func TestBadIdStrings(t *testing.T) {
        recoverPanicOrDie(t, func() {
                idDistance(zeroID, b)
        })
-       if idDistance(zeroID, zeroID) != 0 {
-               t.FailNow()
+       if !idDistance(zeroID, zeroID).IsZero() {
+               t.Fatal("identical IDs should have distance 0")
        }
        a = "\x03" + zeroID[1:]
        b = zeroID
-       if idDistance(a, b) != 2 {
+       if idDistance(a, b).BitCount() != 2 {
                t.FailNow()
        }
        a = "\x03" + zeroID[1:18] + "\x55\xf0"
        b = "\x55" + zeroID[1:17] + "\xff\x55\x0f"
-       if c := idDistance(a, b); c != 20 {
+       if c := idDistance(a, b).BitCount(); c != 20 {
                t.Fatal(c)
        }
 }