]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Implement dht-get-peers command and GetPeers method in dht package
authorMatt Joiner <anacrolix@gmail.com>
Thu, 3 Jul 2014 15:43:04 +0000 (01:43 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Thu, 3 Jul 2014 15:43:04 +0000 (01:43 +1000)
cmd/dht-get-peers/main.go [new file with mode: 0644]
dht/dht.go

diff --git a/cmd/dht-get-peers/main.go b/cmd/dht-get-peers/main.go
new file mode 100644 (file)
index 0000000..9e8b895
--- /dev/null
@@ -0,0 +1,151 @@
+package main
+
+import (
+       "bitbucket.org/anacrolix/go.torrent/dht"
+       "flag"
+       "fmt"
+       "io"
+       "log"
+       "net"
+       "os"
+       "os/signal"
+)
+
+type pingResponse struct {
+       addr string
+       krpc dht.Msg
+}
+
+var (
+       tableFileName = flag.String("tableFile", "", "name of file for storing node info")
+       serveAddr     = flag.String("serveAddr", ":0", "local UDP address")
+       infoHash      = flag.String("infoHash", "", "torrent infohash")
+
+       s dht.Server
+)
+
+func loadTable() error {
+       if *tableFileName == "" {
+               return nil
+       }
+       f, err := os.Open(*tableFileName)
+       if os.IsNotExist(err) {
+               return nil
+       }
+       if err != nil {
+               return fmt.Errorf("error opening table file: %s", err)
+       }
+       defer f.Close()
+       added := 0
+       for {
+               b := make([]byte, dht.CompactNodeInfoLen)
+               _, err := io.ReadFull(f, b)
+               if err == io.EOF {
+                       break
+               }
+               if err != nil {
+                       return fmt.Errorf("error reading table file: %s", err)
+               }
+               var ni dht.NodeInfo
+               err = ni.UnmarshalCompact(b)
+               if err != nil {
+                       return fmt.Errorf("error unmarshaling compact node info: %s", err)
+               }
+               s.AddNode(ni)
+               added++
+       }
+       log.Printf("loaded %d nodes from table file", added)
+       return nil
+}
+
+func init() {
+       log.SetFlags(log.LstdFlags | log.Lshortfile)
+       flag.Parse()
+       switch len(*infoHash) {
+       case 20:
+       case 40:
+               if _, err := fmt.Sscanf(*infoHash, "%x", infoHash); err != nil {
+                       log.Fatal(err)
+               }
+       default:
+               log.Fatal("require 20 byte infohash")
+       }
+       var err error
+       s.Socket, err = net.ListenUDP("udp4", func() *net.UDPAddr {
+               addr, err := net.ResolveUDPAddr("udp4", *serveAddr)
+               if err != nil {
+                       log.Fatalf("error resolving serve addr: %s", err)
+               }
+               return addr
+       }())
+       if err != nil {
+               log.Fatal(err)
+       }
+       s.Init()
+       err = loadTable()
+       if err != nil {
+               log.Fatalf("error loading table: %s", err)
+       }
+       log.Printf("dht server on %s, ID is %q", s.Socket.LocalAddr(), s.IDString())
+       setupSignals()
+}
+
+func saveTable() error {
+       goodNodes := s.Nodes()
+       if *tableFileName == "" {
+               if len(goodNodes) != 0 {
+                       log.Printf("discarding %d good nodes!", len(goodNodes))
+               }
+               return nil
+       }
+       f, err := os.OpenFile(*tableFileName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
+       if err != nil {
+               return fmt.Errorf("error opening table file: %s", err)
+       }
+       defer f.Close()
+       for _, nodeInfo := range goodNodes {
+               var b [dht.CompactNodeInfoLen]byte
+               err := nodeInfo.PutCompact(b[:])
+               if err != nil {
+                       return fmt.Errorf("error compacting node info: %s", err)
+               }
+               _, err = f.Write(b[:])
+               if err != nil {
+                       return fmt.Errorf("error writing compact node info: %s", err)
+               }
+       }
+       log.Printf("saved %d nodes to table file", len(goodNodes))
+       return nil
+}
+
+func setupSignals() {
+       ch := make(chan os.Signal)
+       signal.Notify(ch)
+       go func() {
+               <-ch
+               s.StopServing()
+       }()
+}
+
+func main() {
+       // go s.Bootstrap()
+       go func() {
+               ps, err := s.GetPeers(*infoHash)
+               if err != nil {
+                       log.Fatal(err)
+               }
+               for sl := range ps.Values {
+                       for _, p := range sl {
+                               fmt.Println(p)
+                       }
+               }
+               s.StopServing()
+       }()
+       err := s.Serve()
+       if err := saveTable(); err != nil {
+               log.Printf("error saving node table: %s", err)
+       }
+       if err != nil {
+               log.Fatalf("error serving dht: %s", err)
+       }
+}
index 0bc6c4f4f3f2c841af647027625eaf2fdbc07aae..2b132b6ebb0bc7b2e42e6e040caa35b6ecce6ec8 100644 (file)
@@ -1,6 +1,8 @@
 package dht
 
 import (
+       "bitbucket.org/anacrolix/go.torrent/tracker"
+       "bitbucket.org/anacrolix/go.torrent/util"
        "crypto"
        _ "crypto/sha1"
        "encoding/binary"
@@ -85,6 +87,7 @@ func (s *Server) setDefaults() (err error) {
                }
                s.ID = string(id[:])
        }
+       s.nodes = make(map[string]*Node, 10000)
        return
 }
 
@@ -95,7 +98,7 @@ func (s *Server) Init() error {
 
 func (s *Server) Serve() error {
        for {
-               var b [1500]byte
+               var b [0x10000]byte
                n, addr, err := s.Socket.ReadFromUDP(b[:])
                if err != nil {
                        return err
@@ -103,7 +106,7 @@ func (s *Server) Serve() error {
                var d map[string]interface{}
                err = bencode.Unmarshal(b[:n], &d)
                if err != nil {
-                       log.Printf("bad krpc message: %s", err)
+                       log.Printf("bad krpc message: %s: %q", err, b[:n])
                        continue
                }
                s.mu.Lock()
@@ -343,6 +346,39 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
        return nil
 }
 
+func (t *transaction) onResponse(f func(m Msg)) {
+       ch := make(chan Msg)
+       t.response = ch
+       go func() {
+               d, ok := <-t.response
+               if !ok {
+                       close(t.Response)
+                       return
+               }
+               f(d)
+               t.Response <- d
+       }()
+}
+
+func (s *Server) liftNodes(d Msg) {
+       if d["y"] != "r" {
+               return
+       }
+       var r findNodeResponse
+       err := r.UnmarshalKRPCMsg(d)
+       if err != nil {
+               // log.Print(err)
+       } else {
+               s.mu.Lock()
+               for _, cni := range r.Nodes {
+                       n := s.getNode(cni.Addr)
+                       n.id = string(cni.ID[:])
+               }
+               s.mu.Unlock()
+               // log.Printf("lifted %d nodes", len(r.Nodes))
+       }
+}
+
 // Sends a find_node query to addr. targetID is the node we're looking for.
 func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
        t, err = s.query(addr, "find_node", map[string]string{"target": targetID})
@@ -351,33 +387,125 @@ func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, e
        }
        // Scrape peers from the response to put in the server's table before
        // handing the response back to the caller.
-       ch := make(chan Msg)
-       t.response = ch
-       go func() {
-               d, ok := <-t.response
+       t.onResponse(func(d Msg) {
+               s.liftNodes(d)
+       })
+       return
+}
+
+type getPeersResponse struct {
+       Values []tracker.CompactPeer `bencode:"values"`
+       Nodes  util.CompactPeers     `bencode:"nodes"`
+}
+
+type peerStream struct {
+       mu     sync.Mutex
+       Values chan []tracker.CompactPeer
+       stop   chan struct{}
+}
+
+func (ps *peerStream) Close() {
+       ps.mu.Lock()
+       select {
+       case <-ps.stop:
+       default:
+               close(ps.stop)
+               close(ps.Values)
+       }
+       ps.mu.Unlock()
+}
+
+func extractValues(m Msg) (vs []tracker.CompactPeer) {
+       r, ok := m["r"]
+       if !ok {
+               return
+       }
+       rd, ok := r.(map[string]interface{})
+       if !ok {
+               return
+       }
+       v, ok := rd["values"]
+       if !ok {
+               return
+       }
+       // log.Fatal(m)
+       vl, ok := v.([]interface{})
+       if !ok {
+               panic(v)
+       }
+       vs = make([]tracker.CompactPeer, 0, len(vl))
+       for _, i := range vl {
+               // log.Printf("%T", i)
+               s, ok := i.(string)
                if !ok {
-                       close(t.Response)
-                       return
+                       panic(i)
                }
-               if d["y"] == "r" {
-                       var r findNodeResponse
-                       err = r.UnmarshalKRPCMsg(d)
-                       if err != nil {
-                               log.Print(err)
-                       } else {
-                               s.mu.Lock()
-                               for _, cni := range r.Nodes {
-                                       n := s.getNode(cni.Addr)
-                                       n.id = string(cni.ID[:])
+               var cp tracker.CompactPeer
+               err := cp.UnmarshalBinary([]byte(s))
+               if err != nil {
+                       log.Printf("error decoding values list element: %s", err)
+                       continue
+               }
+               vs = append(vs, cp)
+       }
+       return
+}
+
+func (s *Server) GetPeers(infoHash string) (ps *peerStream, err error) {
+       ps = &peerStream{
+               Values: make(chan []tracker.CompactPeer),
+               stop:   make(chan struct{}),
+       }
+       done := make(chan struct{})
+       pending := 0
+       s.mu.Lock()
+       for _, n := range s.nodes {
+               var t *transaction
+               t, err = s.getPeers(n.addr, infoHash)
+               if err != nil {
+                       ps.Close()
+                       break
+               }
+               go func() {
+                       select {
+                       case m := <-t.Response:
+                               vs := extractValues(m)
+                               if vs != nil {
+                                       ps.Values <- vs
+                                       // } else {
+                                       // log.Print("get_peers response had no values")
                                }
-                               s.mu.Unlock()
+                       case <-ps.stop:
                        }
+                       done <- struct{}{}
+               }()
+               pending++
+       }
+       s.mu.Unlock()
+       go func() {
+               for ; pending > 0; pending-- {
+                       <-done
                }
-               t.Response <- d
+               ps.Close()
        }()
        return
 }
 
+func (s *Server) getPeers(addr *net.UDPAddr, infoHash string) (t *transaction, err error) {
+       if len(infoHash) != 20 {
+               err = fmt.Errorf("infohash has bad length")
+               return
+       }
+       t, err = s.query(addr, "get_peers", map[string]string{"info_hash": infoHash})
+       if err != nil {
+               return
+       }
+       t.onResponse(func(m Msg) {
+               s.liftNodes(m)
+       })
+       return
+}
+
 func (s *Server) addRootNode() error {
        addr, err := net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881")
        if err != nil {
@@ -399,17 +527,24 @@ func (s *Server) Bootstrap() (err error) {
                        return
                }
        }
-       for _, node := range s.nodes {
-               var t *transaction
-               s.mu.Unlock()
-               t, err = s.FindNode(node.addr, s.ID)
-               s.mu.Lock()
-               if err != nil {
-                       return
+       for {
+               for _, node := range s.nodes {
+                       var t *transaction
+                       s.mu.Unlock()
+                       t, err = s.FindNode(node.addr, s.ID)
+                       s.mu.Lock()
+                       if err != nil {
+                               return
+                       }
+                       go func() {
+                               <-t.Response
+                       }()
+               }
+               time.Sleep(5 * time.Second)
+               log.Printf("now have %d nodes", len(s.nodes))
+               if len(s.nodes) >= 8*160 {
+                       break
                }
-               go func() {
-                       <-t.Response
-               }()
        }
        return
 }
@@ -424,7 +559,7 @@ func (s *Server) Nodes() (nis []NodeInfo) {
                ni := NodeInfo{
                        Addr: node.addr,
                }
-               if n := copy(ni.ID[:], node.id); n != 20 {
+               if n := copy(ni.ID[:], node.id); n != 20 && n != 0 {
                        panic(n)
                }
                nis = append(nis, ni)