]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Switch dht-server to bootstrapping
authorMatt Joiner <anacrolix@gmail.com>
Sun, 25 May 2014 11:34:29 +0000 (21:34 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Sun, 25 May 2014 11:34:29 +0000 (21:34 +1000)
cmd/dht-ping/main.go [new file with mode: 0644]
cmd/dht-server/main.go
dht/dht.go
dht/dht_test.go [new file with mode: 0644]

diff --git a/cmd/dht-ping/main.go b/cmd/dht-ping/main.go
new file mode 100644 (file)
index 0000000..833ca6f
--- /dev/null
@@ -0,0 +1,58 @@
+package main
+
+import (
+       "bitbucket.org/anacrolix/go.torrent/dht"
+       "flag"
+       "log"
+       "net"
+       "os"
+)
+
+type pingResponse struct {
+       addr string
+       krpc dht.Msg
+}
+
+func main() {
+       log.SetFlags(log.LstdFlags | log.Lshortfile)
+       flag.Parse()
+       pingStrAddrs := flag.Args()
+       if len(pingStrAddrs) == 0 {
+               os.Stderr.WriteString("u must specify addrs of nodes to ping e.g. router.bittorrent.com:6881\n")
+               os.Exit(2)
+       }
+       s := dht.Server{}
+       var err error
+       s.Socket, err = net.ListenPacket("udp4", "")
+       if err != nil {
+               log.Fatal(err)
+       }
+       log.Printf("dht server on %s", s.Socket.LocalAddr())
+       s.Init()
+       go func() {
+               err := s.Serve()
+               if err != nil {
+                       log.Fatal(err)
+               }
+       }()
+       pingResponses := make(chan pingResponse)
+       for _, netloc := range pingStrAddrs {
+               addr, err := net.ResolveUDPAddr("udp4", netloc)
+               if err != nil {
+                       log.Fatal(err)
+               }
+               t, err := s.Ping(addr)
+               if err != nil {
+                       log.Fatal(err)
+               }
+               go func(addr string) {
+                       pingResponses <- pingResponse{
+                               addr: addr,
+                               krpc: <-t.Response,
+                       }
+               }(netloc)
+       }
+       for _ = range pingStrAddrs {
+               log.Print(<-pingResponses)
+       }
+}
index c539b2ee6f35ebe5cab8d66c832efbfd82f93750..1e6abe7245be74322e068544bf39802ddea1c982 100644 (file)
@@ -15,10 +15,11 @@ func main() {
        log.SetFlags(log.LstdFlags | log.Lshortfile)
        s := dht.Server{}
        var err error
-       s.Socket, err = net.ListenPacket("udp4", "")
+       s.Socket, err = net.ListenUDP("udp4", nil)
        if err != nil {
                log.Fatal(err)
        }
+       s.Init()
        log.Printf("dht server on %s", s.Socket.LocalAddr())
        go func() {
                err := s.Serve()
@@ -26,28 +27,9 @@ func main() {
                        log.Fatal(err)
                }
        }()
-       pingResponses := make(chan pingResponse)
-       pingStrAddrs := []string{
-               "router.utorrent.com:6881",
-               "router.bittorrent.com:6881",
-       }
-       for _, netloc := range pingStrAddrs {
-               addr, err := net.ResolveUDPAddr("udp4", netloc)
-               if err != nil {
-                       log.Fatal(err)
-               }
-               t, err := s.Ping(addr)
-               if err != nil {
-                       log.Fatal(err)
-               }
-               go func(addr string) {
-                       pingResponses <- pingResponse{
-                               addr: addr,
-                               krpc: <-t.Response,
-                       }
-               }(netloc)
-       }
-       for _ = range pingStrAddrs {
-               log.Print(<-pingResponses)
+       err = s.Bootstrap()
+       if err != nil {
+               log.Fatal(err)
        }
+       select {}
 }
index c6b5634f5ad0aea560f3c73a8c1ae9c7391d9a94..27c1a79e54f3ff9460b454688783b920381e2b41 100644 (file)
@@ -3,6 +3,7 @@ package dht
 import (
        "crypto/rand"
        "encoding/binary"
+       "errors"
        "fmt"
        "github.com/nsf/libtorgo/bencode"
        "io"
@@ -13,14 +14,14 @@ import (
 
 type Server struct {
        ID               string
-       Socket           net.PacketConn
+       Socket           *net.UDPConn
        transactions     []*transaction
        transactionIDInt uint64
        nodes            map[string]*Node
 }
 
 type Node struct {
-       addr          net.Addr
+       addr          *net.UDPAddr
        id            string
        lastHeardFrom time.Time
        lastSentTo    time.Time
@@ -38,6 +39,27 @@ type transaction struct {
        remoteAddr net.Addr
        t          string
        Response   chan Msg
+       response   chan Msg
+}
+
+func (s *Server) WriteNodes(w io.Writer) (n int, err error) {
+       for _, node := range s.nodes {
+               cni := compactNodeInfo{
+                       Addr: node.addr,
+               }
+               if n := copy(cni.ID[:], node.id); n != 20 {
+                       panic(n)
+               }
+               var b [26]byte
+               cni.PutBinary(b[:])
+               var nn int
+               nn, err = w.Write(b[:])
+               if err != nil {
+                       return
+               }
+               n += nn
+       }
+       return
 }
 
 func (s *Server) setDefaults() {
@@ -51,16 +73,15 @@ func (s *Server) setDefaults() {
        }
 }
 
-func (s *Server) init() {
+func (s *Server) Init() {
+       s.setDefaults()
        s.nodes = make(map[string]*Node, 1000)
 }
 
 func (s *Server) Serve() error {
-       s.setDefaults()
-       s.init()
        for {
                var b [1500]byte
-               n, addr, err := s.Socket.ReadFrom(b[:])
+               n, addr, err := s.Socket.ReadFromUDP(b[:])
                if err != nil {
                        return err
                }
@@ -71,7 +92,7 @@ func (s *Server) Serve() error {
                        continue
                }
                t := s.findResponseTransaction(d["t"].(string), addr)
-               t.Response <- d
+               t.response <- d
                s.removeTransaction(t)
                id := ""
                if d["y"] == "r" {
@@ -81,13 +102,13 @@ func (s *Server) Serve() error {
        }
 }
 
-func (s *Server) heardFromNode(addr net.Addr, id string) {
+func (s *Server) heardFromNode(addr *net.UDPAddr, id string) {
        n := s.getNode(addr)
        n.id = id
        n.lastHeardFrom = time.Now()
 }
 
-func (s *Server) getNode(addr net.Addr) (n *Node) {
+func (s *Server) getNode(addr *net.UDPAddr) (n *Node) {
        n = s.nodes[addr.String()]
        if n == nil {
                n = &Node{
@@ -98,7 +119,7 @@ func (s *Server) getNode(addr net.Addr) (n *Node) {
        return
 }
 
-func (s *Server) sentToNode(addr net.Addr) {
+func (s *Server) sentToNode(addr *net.UDPAddr) {
        n := s.getNode(addr)
        n.lastSentTo = time.Now()
 }
@@ -142,7 +163,7 @@ func (s *Server) IDString() string {
        return s.ID
 }
 
-func (s *Server) query(node net.Addr, q string, a map[string]string) (t *transaction, err error) {
+func (s *Server) query(node *net.UDPAddr, q string, a map[string]string) (t *transaction, err error) {
        tid := s.nextTransactionID()
        if a == nil {
                a = make(map[string]string, 1)
@@ -163,6 +184,7 @@ func (s *Server) query(node net.Addr, q string, a map[string]string) (t *transac
                t:          tid,
                Response:   make(chan Msg, 1),
        }
+       t.response = t.Response
        s.addTransaction(t)
        n, err := s.Socket.WriteTo(b, node)
        if err != nil {
@@ -178,10 +200,124 @@ func (s *Server) query(node net.Addr, q string, a map[string]string) (t *transac
        return
 }
 
-func (s *Server) GetPeers(node *net.UDPAddr, targetInfoHash [20]byte) {
+const compactAddrInfoLen = 26
+
+type compactAddrInfo *net.UDPAddr
+
+type compactNodeInfo struct {
+       ID   [20]byte
+       Addr compactAddrInfo
+}
+
+func (cni *compactNodeInfo) PutBinary(b []byte) {
+       if n := copy(b[:], cni.ID[:]); n != 20 {
+               panic(n)
+       }
+       ip := cni.Addr.IP.To4()
+       if len(ip) != 4 {
+               panic(ip)
+       }
+       if n := copy(b[20:], ip); n != 4 {
+               panic(n)
+       }
+       binary.BigEndian.PutUint16(b[24:], uint16(cni.Addr.Port))
+}
 
+func (cni *compactNodeInfo) UnmarshalBinary(b []byte) error {
+       if len(b) != 26 {
+               return errors.New("expected 26 bytes")
+       }
+       if 20 != copy(cni.ID[:], b[:20]) {
+               panic("impossibru!")
+       }
+       if cni.Addr == nil {
+               cni.Addr = &net.UDPAddr{}
+       }
+       cni.Addr.IP = net.IPv4(b[20], b[21], b[22], b[23])
+       cni.Addr.Port = int(binary.BigEndian.Uint16(b[24:26]))
+       return nil
 }
 
-func (s *Server) Ping(node net.Addr) (*transaction, error) {
+func (s *Server) Ping(node *net.UDPAddr) (*transaction, error) {
        return s.query(node, "ping", nil)
 }
+
+type findNodeResponse struct {
+       Nodes []compactNodeInfo
+}
+
+func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
+       b := m["r"].(map[string]interface{})["nodes"].(string)
+       log.Printf("%q", b)
+       for i := 0; i < len(b); i += 26 {
+               var n compactNodeInfo
+               err := n.UnmarshalBinary([]byte(b[i : i+26]))
+               if err != nil {
+                       return err
+               }
+               me.Nodes = append(me.Nodes, n)
+       }
+       return nil
+}
+
+func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
+       log.Print(addr)
+       t, err = s.query(addr, "find_node", map[string]string{"target": targetID})
+       if err != nil {
+               return
+       }
+       ch := make(chan Msg)
+       t.response = ch
+       go func() {
+               d, ok := <-t.response
+               if !ok {
+                       close(t.Response)
+                       return
+               }
+               if d["y"] == "r" {
+                       var r findNodeResponse
+                       err = r.UnmarshalKRPCMsg(d)
+                       if err != nil {
+                               log.Print(err)
+                       } else {
+                               for _, cni := range r.Nodes {
+                                       n := s.getNode(cni.Addr)
+                                       n.id = string(cni.ID[:])
+                               }
+                       }
+               }
+               t.Response <- d
+       }()
+       return
+}
+
+func (s *Server) Bootstrap() error {
+       if len(s.nodes) == 0 {
+               addr, err := net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881")
+               if err != nil {
+                       return err
+               }
+               s.nodes[addr.String()] = &Node{
+                       addr: addr,
+               }
+       }
+       queriedNodes := make(map[string]bool, 1000)
+       for {
+               for _, node := range s.nodes {
+                       if queriedNodes[node.addr.String()] {
+                               log.Printf("skipping already queried: %s", node.addr)
+                               continue
+                       }
+                       t, err := s.FindNode(node.addr, s.ID)
+                       if err != nil {
+                               return err
+                       }
+                       queriedNodes[node.addr.String()] = true
+                       go func() {
+                               log.Print(<-t.Response)
+                       }()
+               }
+               time.Sleep(3 * time.Second)
+       }
+       return nil
+}
diff --git a/dht/dht_test.go b/dht/dht_test.go
new file mode 100644 (file)
index 0000000..83bd6e8
--- /dev/null
@@ -0,0 +1,28 @@
+package dht
+
+import (
+       "net"
+       "testing"
+)
+
+func TestMarshalCompactNodeInfo(t *testing.T) {
+       cni := compactNodeInfo{
+               ID: [20]byte{'a', 'b', 'c'},
+       }
+       var err error
+       cni.Addr, err = net.ResolveUDPAddr("udp4", "1.2.3.4:5")
+       if err != nil {
+               t.Fatal(err)
+       }
+       var b [compactAddrInfoLen]byte
+       cni.PutBinary(b[:])
+       if err != nil {
+               t.Fatal(err)
+       }
+       var bb [26]byte
+       copy(bb[:], []byte("abc"))
+       copy(bb[20:], []byte("\x01\x02\x03\x04\x00\x05"))
+       if b != bb {
+               t.FailNow()
+       }
+}