From 1b69e69461650c574d620b8cd459f83c8a689de2 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Sun, 25 May 2014 21:34:29 +1000 Subject: [PATCH] Switch dht-server to bootstrapping --- cmd/dht-ping/main.go | 58 +++++++++++++++ cmd/dht-server/main.go | 30 ++------ dht/dht.go | 162 +++++++++++++++++++++++++++++++++++++---- dht/dht_test.go | 28 +++++++ 4 files changed, 241 insertions(+), 37 deletions(-) create mode 100644 cmd/dht-ping/main.go create mode 100644 dht/dht_test.go diff --git a/cmd/dht-ping/main.go b/cmd/dht-ping/main.go new file mode 100644 index 00000000..833ca6fd --- /dev/null +++ b/cmd/dht-ping/main.go @@ -0,0 +1,58 @@ +package main + +import ( + "bitbucket.org/anacrolix/go.torrent/dht" + "flag" + "log" + "net" + "os" +) + +type pingResponse struct { + addr string + krpc dht.Msg +} + +func main() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + flag.Parse() + pingStrAddrs := flag.Args() + if len(pingStrAddrs) == 0 { + os.Stderr.WriteString("u must specify addrs of nodes to ping e.g. router.bittorrent.com:6881\n") + os.Exit(2) + } + s := dht.Server{} + var err error + s.Socket, err = net.ListenPacket("udp4", "") + if err != nil { + log.Fatal(err) + } + log.Printf("dht server on %s", s.Socket.LocalAddr()) + s.Init() + go func() { + err := s.Serve() + if err != nil { + log.Fatal(err) + } + }() + pingResponses := make(chan pingResponse) + for _, netloc := range pingStrAddrs { + addr, err := net.ResolveUDPAddr("udp4", netloc) + if err != nil { + log.Fatal(err) + } + t, err := s.Ping(addr) + if err != nil { + log.Fatal(err) + } + go func(addr string) { + pingResponses <- pingResponse{ + addr: addr, + krpc: <-t.Response, + } + }(netloc) + } + for _ = range pingStrAddrs { + log.Print(<-pingResponses) + } +} diff --git a/cmd/dht-server/main.go b/cmd/dht-server/main.go index c539b2ee..1e6abe72 100644 --- a/cmd/dht-server/main.go +++ b/cmd/dht-server/main.go @@ -15,10 +15,11 @@ func main() { log.SetFlags(log.LstdFlags | log.Lshortfile) s := dht.Server{} var err error - s.Socket, err = net.ListenPacket("udp4", "") + s.Socket, err = net.ListenUDP("udp4", nil) if err != nil { log.Fatal(err) } + s.Init() log.Printf("dht server on %s", s.Socket.LocalAddr()) go func() { err := s.Serve() @@ -26,28 +27,9 @@ func main() { log.Fatal(err) } }() - pingResponses := make(chan pingResponse) - pingStrAddrs := []string{ - "router.utorrent.com:6881", - "router.bittorrent.com:6881", - } - for _, netloc := range pingStrAddrs { - addr, err := net.ResolveUDPAddr("udp4", netloc) - if err != nil { - log.Fatal(err) - } - t, err := s.Ping(addr) - if err != nil { - log.Fatal(err) - } - go func(addr string) { - pingResponses <- pingResponse{ - addr: addr, - krpc: <-t.Response, - } - }(netloc) - } - for _ = range pingStrAddrs { - log.Print(<-pingResponses) + err = s.Bootstrap() + if err != nil { + log.Fatal(err) } + select {} } diff --git a/dht/dht.go b/dht/dht.go index c6b5634f..27c1a79e 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -3,6 +3,7 @@ package dht import ( "crypto/rand" "encoding/binary" + "errors" "fmt" "github.com/nsf/libtorgo/bencode" "io" @@ -13,14 +14,14 @@ import ( type Server struct { ID string - Socket net.PacketConn + Socket *net.UDPConn transactions []*transaction transactionIDInt uint64 nodes map[string]*Node } type Node struct { - addr net.Addr + addr *net.UDPAddr id string lastHeardFrom time.Time lastSentTo time.Time @@ -38,6 +39,27 @@ type transaction struct { remoteAddr net.Addr t string Response chan Msg + response chan Msg +} + +func (s *Server) WriteNodes(w io.Writer) (n int, err error) { + for _, node := range s.nodes { + cni := compactNodeInfo{ + Addr: node.addr, + } + if n := copy(cni.ID[:], node.id); n != 20 { + panic(n) + } + var b [26]byte + cni.PutBinary(b[:]) + var nn int + nn, err = w.Write(b[:]) + if err != nil { + return + } + n += nn + } + return } func (s *Server) setDefaults() { @@ -51,16 +73,15 @@ func (s *Server) setDefaults() { } } -func (s *Server) init() { +func (s *Server) Init() { + s.setDefaults() s.nodes = make(map[string]*Node, 1000) } func (s *Server) Serve() error { - s.setDefaults() - s.init() for { var b [1500]byte - n, addr, err := s.Socket.ReadFrom(b[:]) + n, addr, err := s.Socket.ReadFromUDP(b[:]) if err != nil { return err } @@ -71,7 +92,7 @@ func (s *Server) Serve() error { continue } t := s.findResponseTransaction(d["t"].(string), addr) - t.Response <- d + t.response <- d s.removeTransaction(t) id := "" if d["y"] == "r" { @@ -81,13 +102,13 @@ func (s *Server) Serve() error { } } -func (s *Server) heardFromNode(addr net.Addr, id string) { +func (s *Server) heardFromNode(addr *net.UDPAddr, id string) { n := s.getNode(addr) n.id = id n.lastHeardFrom = time.Now() } -func (s *Server) getNode(addr net.Addr) (n *Node) { +func (s *Server) getNode(addr *net.UDPAddr) (n *Node) { n = s.nodes[addr.String()] if n == nil { n = &Node{ @@ -98,7 +119,7 @@ func (s *Server) getNode(addr net.Addr) (n *Node) { return } -func (s *Server) sentToNode(addr net.Addr) { +func (s *Server) sentToNode(addr *net.UDPAddr) { n := s.getNode(addr) n.lastSentTo = time.Now() } @@ -142,7 +163,7 @@ func (s *Server) IDString() string { return s.ID } -func (s *Server) query(node net.Addr, q string, a map[string]string) (t *transaction, err error) { +func (s *Server) query(node *net.UDPAddr, q string, a map[string]string) (t *transaction, err error) { tid := s.nextTransactionID() if a == nil { a = make(map[string]string, 1) @@ -163,6 +184,7 @@ func (s *Server) query(node net.Addr, q string, a map[string]string) (t *transac t: tid, Response: make(chan Msg, 1), } + t.response = t.Response s.addTransaction(t) n, err := s.Socket.WriteTo(b, node) if err != nil { @@ -178,10 +200,124 @@ func (s *Server) query(node net.Addr, q string, a map[string]string) (t *transac return } -func (s *Server) GetPeers(node *net.UDPAddr, targetInfoHash [20]byte) { +const compactAddrInfoLen = 26 + +type compactAddrInfo *net.UDPAddr + +type compactNodeInfo struct { + ID [20]byte + Addr compactAddrInfo +} + +func (cni *compactNodeInfo) PutBinary(b []byte) { + if n := copy(b[:], cni.ID[:]); n != 20 { + panic(n) + } + ip := cni.Addr.IP.To4() + if len(ip) != 4 { + panic(ip) + } + if n := copy(b[20:], ip); n != 4 { + panic(n) + } + binary.BigEndian.PutUint16(b[24:], uint16(cni.Addr.Port)) +} +func (cni *compactNodeInfo) UnmarshalBinary(b []byte) error { + if len(b) != 26 { + return errors.New("expected 26 bytes") + } + if 20 != copy(cni.ID[:], b[:20]) { + panic("impossibru!") + } + if cni.Addr == nil { + cni.Addr = &net.UDPAddr{} + } + cni.Addr.IP = net.IPv4(b[20], b[21], b[22], b[23]) + cni.Addr.Port = int(binary.BigEndian.Uint16(b[24:26])) + return nil } -func (s *Server) Ping(node net.Addr) (*transaction, error) { +func (s *Server) Ping(node *net.UDPAddr) (*transaction, error) { return s.query(node, "ping", nil) } + +type findNodeResponse struct { + Nodes []compactNodeInfo +} + +func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error { + b := m["r"].(map[string]interface{})["nodes"].(string) + log.Printf("%q", b) + for i := 0; i < len(b); i += 26 { + var n compactNodeInfo + err := n.UnmarshalBinary([]byte(b[i : i+26])) + if err != nil { + return err + } + me.Nodes = append(me.Nodes, n) + } + return nil +} + +func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) { + log.Print(addr) + t, err = s.query(addr, "find_node", map[string]string{"target": targetID}) + if err != nil { + return + } + ch := make(chan Msg) + t.response = ch + go func() { + d, ok := <-t.response + if !ok { + close(t.Response) + return + } + if d["y"] == "r" { + var r findNodeResponse + err = r.UnmarshalKRPCMsg(d) + if err != nil { + log.Print(err) + } else { + for _, cni := range r.Nodes { + n := s.getNode(cni.Addr) + n.id = string(cni.ID[:]) + } + } + } + t.Response <- d + }() + return +} + +func (s *Server) Bootstrap() error { + if len(s.nodes) == 0 { + addr, err := net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881") + if err != nil { + return err + } + s.nodes[addr.String()] = &Node{ + addr: addr, + } + } + queriedNodes := make(map[string]bool, 1000) + for { + for _, node := range s.nodes { + if queriedNodes[node.addr.String()] { + log.Printf("skipping already queried: %s", node.addr) + continue + } + t, err := s.FindNode(node.addr, s.ID) + if err != nil { + return err + } + queriedNodes[node.addr.String()] = true + go func() { + log.Print(<-t.Response) + }() + } + time.Sleep(3 * time.Second) + } + return nil +} diff --git a/dht/dht_test.go b/dht/dht_test.go new file mode 100644 index 00000000..83bd6e8f --- /dev/null +++ b/dht/dht_test.go @@ -0,0 +1,28 @@ +package dht + +import ( + "net" + "testing" +) + +func TestMarshalCompactNodeInfo(t *testing.T) { + cni := compactNodeInfo{ + ID: [20]byte{'a', 'b', 'c'}, + } + var err error + cni.Addr, err = net.ResolveUDPAddr("udp4", "1.2.3.4:5") + if err != nil { + t.Fatal(err) + } + var b [compactAddrInfoLen]byte + cni.PutBinary(b[:]) + if err != nil { + t.Fatal(err) + } + var bb [26]byte + copy(bb[:], []byte("abc")) + copy(bb[20:], []byte("\x01\x02\x03\x04\x00\x05")) + if b != bb { + t.FailNow() + } +} -- 2.48.1