From: Matt Joiner Date: Thu, 3 Jul 2014 15:43:04 +0000 (+1000) Subject: Implement dht-get-peers command and GetPeers method in dht package X-Git-Tag: v1.0.0~1686 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=78fe1b11ae7021393d66106b6e6e7c4791d106ec;p=btrtrc.git Implement dht-get-peers command and GetPeers method in dht package --- diff --git a/cmd/dht-get-peers/main.go b/cmd/dht-get-peers/main.go new file mode 100644 index 00000000..9e8b8951 --- /dev/null +++ b/cmd/dht-get-peers/main.go @@ -0,0 +1,151 @@ +package main + +import ( + "bitbucket.org/anacrolix/go.torrent/dht" + "flag" + "fmt" + "io" + "log" + "net" + "os" + "os/signal" +) + +type pingResponse struct { + addr string + krpc dht.Msg +} + +var ( + tableFileName = flag.String("tableFile", "", "name of file for storing node info") + serveAddr = flag.String("serveAddr", ":0", "local UDP address") + infoHash = flag.String("infoHash", "", "torrent infohash") + + s dht.Server +) + +func loadTable() error { + if *tableFileName == "" { + return nil + } + f, err := os.Open(*tableFileName) + if os.IsNotExist(err) { + return nil + } + if err != nil { + return fmt.Errorf("error opening table file: %s", err) + } + defer f.Close() + added := 0 + for { + b := make([]byte, dht.CompactNodeInfoLen) + _, err := io.ReadFull(f, b) + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("error reading table file: %s", err) + } + var ni dht.NodeInfo + err = ni.UnmarshalCompact(b) + if err != nil { + return fmt.Errorf("error unmarshaling compact node info: %s", err) + } + s.AddNode(ni) + added++ + } + log.Printf("loaded %d nodes from table file", added) + return nil +} + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + flag.Parse() + switch len(*infoHash) { + case 20: + case 40: + if _, err := fmt.Sscanf(*infoHash, "%x", infoHash); err != nil { + log.Fatal(err) + } + default: + log.Fatal("require 20 byte infohash") + } + var err error + s.Socket, err = net.ListenUDP("udp4", func() *net.UDPAddr { + addr, err := net.ResolveUDPAddr("udp4", *serveAddr) + if err != nil { + log.Fatalf("error resolving serve addr: %s", err) + } + return addr + }()) + if err != nil { + log.Fatal(err) + } + s.Init() + err = loadTable() + if err != nil { + log.Fatalf("error loading table: %s", err) + } + log.Printf("dht server on %s, ID is %q", s.Socket.LocalAddr(), s.IDString()) + setupSignals() +} + +func saveTable() error { + goodNodes := s.Nodes() + if *tableFileName == "" { + if len(goodNodes) != 0 { + log.Printf("discarding %d good nodes!", len(goodNodes)) + } + return nil + } + f, err := os.OpenFile(*tableFileName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666) + if err != nil { + return fmt.Errorf("error opening table file: %s", err) + } + defer f.Close() + for _, nodeInfo := range goodNodes { + var b [dht.CompactNodeInfoLen]byte + err := nodeInfo.PutCompact(b[:]) + if err != nil { + return fmt.Errorf("error compacting node info: %s", err) + } + _, err = f.Write(b[:]) + if err != nil { + return fmt.Errorf("error writing compact node info: %s", err) + } + } + log.Printf("saved %d nodes to table file", len(goodNodes)) + return nil +} + +func setupSignals() { + ch := make(chan os.Signal) + signal.Notify(ch) + go func() { + <-ch + s.StopServing() + }() +} + +func main() { + // go s.Bootstrap() + go func() { + ps, err := s.GetPeers(*infoHash) + if err != nil { + log.Fatal(err) + } + for sl := range ps.Values { + for _, p := range sl { + fmt.Println(p) + } + } + s.StopServing() + }() + err := s.Serve() + if err := saveTable(); err != nil { + log.Printf("error saving node table: %s", err) + } + if err != nil { + log.Fatalf("error serving dht: %s", err) + } +} diff --git a/dht/dht.go b/dht/dht.go index 0bc6c4f4..2b132b6e 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -1,6 +1,8 @@ package dht import ( + "bitbucket.org/anacrolix/go.torrent/tracker" + "bitbucket.org/anacrolix/go.torrent/util" "crypto" _ "crypto/sha1" "encoding/binary" @@ -85,6 +87,7 @@ func (s *Server) setDefaults() (err error) { } s.ID = string(id[:]) } + s.nodes = make(map[string]*Node, 10000) return } @@ -95,7 +98,7 @@ func (s *Server) Init() error { func (s *Server) Serve() error { for { - var b [1500]byte + var b [0x10000]byte n, addr, err := s.Socket.ReadFromUDP(b[:]) if err != nil { return err @@ -103,7 +106,7 @@ func (s *Server) Serve() error { var d map[string]interface{} err = bencode.Unmarshal(b[:n], &d) if err != nil { - log.Printf("bad krpc message: %s", err) + log.Printf("bad krpc message: %s: %q", err, b[:n]) continue } s.mu.Lock() @@ -343,6 +346,39 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error { return nil } +func (t *transaction) onResponse(f func(m Msg)) { + ch := make(chan Msg) + t.response = ch + go func() { + d, ok := <-t.response + if !ok { + close(t.Response) + return + } + f(d) + t.Response <- d + }() +} + +func (s *Server) liftNodes(d Msg) { + if d["y"] != "r" { + return + } + var r findNodeResponse + err := r.UnmarshalKRPCMsg(d) + if err != nil { + // log.Print(err) + } else { + s.mu.Lock() + for _, cni := range r.Nodes { + n := s.getNode(cni.Addr) + n.id = string(cni.ID[:]) + } + s.mu.Unlock() + // log.Printf("lifted %d nodes", len(r.Nodes)) + } +} + // Sends a find_node query to addr. targetID is the node we're looking for. func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) { t, err = s.query(addr, "find_node", map[string]string{"target": targetID}) @@ -351,33 +387,125 @@ func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, e } // Scrape peers from the response to put in the server's table before // handing the response back to the caller. - ch := make(chan Msg) - t.response = ch - go func() { - d, ok := <-t.response + t.onResponse(func(d Msg) { + s.liftNodes(d) + }) + return +} + +type getPeersResponse struct { + Values []tracker.CompactPeer `bencode:"values"` + Nodes util.CompactPeers `bencode:"nodes"` +} + +type peerStream struct { + mu sync.Mutex + Values chan []tracker.CompactPeer + stop chan struct{} +} + +func (ps *peerStream) Close() { + ps.mu.Lock() + select { + case <-ps.stop: + default: + close(ps.stop) + close(ps.Values) + } + ps.mu.Unlock() +} + +func extractValues(m Msg) (vs []tracker.CompactPeer) { + r, ok := m["r"] + if !ok { + return + } + rd, ok := r.(map[string]interface{}) + if !ok { + return + } + v, ok := rd["values"] + if !ok { + return + } + // log.Fatal(m) + vl, ok := v.([]interface{}) + if !ok { + panic(v) + } + vs = make([]tracker.CompactPeer, 0, len(vl)) + for _, i := range vl { + // log.Printf("%T", i) + s, ok := i.(string) if !ok { - close(t.Response) - return + panic(i) } - if d["y"] == "r" { - var r findNodeResponse - err = r.UnmarshalKRPCMsg(d) - if err != nil { - log.Print(err) - } else { - s.mu.Lock() - for _, cni := range r.Nodes { - n := s.getNode(cni.Addr) - n.id = string(cni.ID[:]) + var cp tracker.CompactPeer + err := cp.UnmarshalBinary([]byte(s)) + if err != nil { + log.Printf("error decoding values list element: %s", err) + continue + } + vs = append(vs, cp) + } + return +} + +func (s *Server) GetPeers(infoHash string) (ps *peerStream, err error) { + ps = &peerStream{ + Values: make(chan []tracker.CompactPeer), + stop: make(chan struct{}), + } + done := make(chan struct{}) + pending := 0 + s.mu.Lock() + for _, n := range s.nodes { + var t *transaction + t, err = s.getPeers(n.addr, infoHash) + if err != nil { + ps.Close() + break + } + go func() { + select { + case m := <-t.Response: + vs := extractValues(m) + if vs != nil { + ps.Values <- vs + // } else { + // log.Print("get_peers response had no values") } - s.mu.Unlock() + case <-ps.stop: } + done <- struct{}{} + }() + pending++ + } + s.mu.Unlock() + go func() { + for ; pending > 0; pending-- { + <-done } - t.Response <- d + ps.Close() }() return } +func (s *Server) getPeers(addr *net.UDPAddr, infoHash string) (t *transaction, err error) { + if len(infoHash) != 20 { + err = fmt.Errorf("infohash has bad length") + return + } + t, err = s.query(addr, "get_peers", map[string]string{"info_hash": infoHash}) + if err != nil { + return + } + t.onResponse(func(m Msg) { + s.liftNodes(m) + }) + return +} + func (s *Server) addRootNode() error { addr, err := net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881") if err != nil { @@ -399,17 +527,24 @@ func (s *Server) Bootstrap() (err error) { return } } - for _, node := range s.nodes { - var t *transaction - s.mu.Unlock() - t, err = s.FindNode(node.addr, s.ID) - s.mu.Lock() - if err != nil { - return + for { + for _, node := range s.nodes { + var t *transaction + s.mu.Unlock() + t, err = s.FindNode(node.addr, s.ID) + s.mu.Lock() + if err != nil { + return + } + go func() { + <-t.Response + }() + } + time.Sleep(5 * time.Second) + log.Printf("now have %d nodes", len(s.nodes)) + if len(s.nodes) >= 8*160 { + break } - go func() { - <-t.Response - }() } return } @@ -424,7 +559,7 @@ func (s *Server) Nodes() (nis []NodeInfo) { ni := NodeInfo{ Addr: node.addr, } - if n := copy(ni.ID[:], node.id); n != 20 { + if n := copy(ni.ID[:], node.id); n != 20 && n != 0 { panic(n) } nis = append(nis, ni)