From: Matt Joiner Date: Tue, 27 May 2014 06:28:56 +0000 (+1000) Subject: Got dht-server working nicely X-Git-Tag: v1.0.0~1723 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=83a02420a5b099a4901df67ee2e1f59534d9135f;p=btrtrc.git Got dht-server working nicely --- diff --git a/cmd/dht-ping/main.go b/cmd/dht-ping/main.go index 833ca6fd..9e43f798 100644 --- a/cmd/dht-ping/main.go +++ b/cmd/dht-ping/main.go @@ -23,7 +23,7 @@ func main() { } s := dht.Server{} var err error - s.Socket, err = net.ListenPacket("udp4", "") + s.Socket, err = net.ListenUDP("udp4", nil) if err != nil { log.Fatal(err) } diff --git a/cmd/dht-server/main.go b/cmd/dht-server/main.go index c0ba99b3..4b7cf8ab 100644 --- a/cmd/dht-server/main.go +++ b/cmd/dht-server/main.go @@ -2,9 +2,13 @@ package main import ( "bitbucket.org/anacrolix/go.torrent/dht" + "flag" + "fmt" + "io" "log" "net" "os" + "os/signal" ) type pingResponse struct { @@ -12,47 +16,119 @@ type pingResponse struct { krpc dht.Msg } -func main() { +var ( + tableFileName = flag.String("tableFile", "", "name of file for storing node info") + serveAddr = flag.String("serveAddr", ":0", "local UDP address") + + s dht.Server +) + +func loadTable() error { + if *tableFileName == "" { + return nil + } + f, err := os.Open(*tableFileName) + if os.IsNotExist(err) { + return nil + } + if err != nil { + return fmt.Errorf("error opening table file: %s", err) + } + defer f.Close() + added := 0 + for { + b := make([]byte, dht.CompactNodeInfoLen) + _, err := io.ReadFull(f, b) + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("error reading table file: %s", err) + } + var ni dht.NodeInfo + err = ni.UnmarshalCompact(b) + if err != nil { + return fmt.Errorf("error unmarshaling compact node info: %s", err) + } + s.AddNode(ni) + added++ + } + log.Printf("loaded %d nodes from table file", added) + return nil +} + +func init() { log.SetFlags(log.LstdFlags | log.Lshortfile) - s := dht.Server{} - var err error - s.Socket, err = net.ListenUDP("udp4", nil) + flag.Parse() + err := loadTable() + if err != nil { + log.Fatalf("error loading table: %s", err) + } + s.Socket, err = net.ListenUDP("udp4", func() *net.UDPAddr { + addr, err := net.ResolveUDPAddr("udp4", *serveAddr) + if err != nil { + log.Fatalf("error resolving serve addr: %s", err) + } + return addr + }()) if err != nil { log.Fatal(err) } + log.Printf("dht server on %s", s.Socket.LocalAddr()) s.Init() - func() { - f, err := os.Open("nodes") - if os.IsNotExist(err) { - return + setupSignals() +} + +func saveTable() error { + goodNodes := s.GoodNodes() + if *tableFileName == "" { + if len(goodNodes) != 0 { + log.Printf("discarding %d good nodes!", len(goodNodes)) } + return nil + } + f, err := os.OpenFile(*tableFileName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666) + if err != nil { + return fmt.Errorf("error opening table file: %s", err) + } + defer f.Close() + for _, nodeInfo := range goodNodes { + var b [dht.CompactNodeInfoLen]byte + err := nodeInfo.PutCompact(b[:]) if err != nil { - log.Fatal(err) + return fmt.Errorf("error compacting node info: %s", err) } - defer f.Close() - err = s.ReadNodes(f) + _, err = f.Write(b[:]) if err != nil { - log.Fatal(err) + return fmt.Errorf("error writing compact node info: %s", err) } - }() - log.Printf("dht server on %s", s.Socket.LocalAddr()) + } + log.Printf("saved %d nodes to table file", len(goodNodes)) + return nil +} + +func setupSignals() { + ch := make(chan os.Signal) + signal.Notify(ch) go func() { - err := s.Serve() - if err != nil { - log.Fatal(err) - } + <-ch + s.StopServing() }() - err = s.Bootstrap() - func() { - f, err := os.OpenFile("nodes", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666) +} + +func main() { + go func() { + err := s.Bootstrap() if err != nil { - log.Print(err) - return + log.Printf("error bootstrapping: %s", err) + s.StopServing() } - defer f.Close() - s.WriteNodes(f) }() + err := s.Serve() + if err := saveTable(); err != nil { + log.Printf("error saving node table: %s", err) + } if err != nil { - log.Fatal(err) + log.Fatalf("error serving dht: %s", err) } } diff --git a/dht/dht.go b/dht/dht.go index cfc123ba..d7e8e92a 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -9,6 +9,7 @@ import ( "io" "log" "net" + "sync" "time" ) @@ -18,6 +19,7 @@ type Server struct { transactions []*transaction transactionIDInt uint64 nodes map[string]*Node + mu sync.Mutex } type Node struct { @@ -27,6 +29,16 @@ type Node struct { lastSentTo time.Time } +func (n *Node) Good() bool { + if len(n.id) != 20 { + return false + } + if time.Now().Sub(n.lastHeardFrom) >= 15*time.Minute { + return false + } + return true +} + type Msg map[string]interface{} var _ fmt.Stringer = Msg{} @@ -42,46 +54,6 @@ type transaction struct { response chan Msg } -func (s *Server) ReadNodes(r io.Reader) error { - for { - var b [compactNodeInfoLen]byte - _, err := io.ReadFull(r, b[:]) - if err == io.EOF { - return nil - } - if err != nil { - return err - } - var cni compactNodeInfo - err = cni.UnmarshalBinary(b[:]) - if err != nil { - return err - } - n := s.getNode(cni.Addr) - n.id = string(cni.ID[:]) - } -} - -func (s *Server) WriteNodes(w io.Writer) (n int, err error) { - for _, node := range s.nodes { - cni := compactNodeInfo{ - Addr: node.addr, - } - if n := copy(cni.ID[:], node.id); n != 20 { - panic(n) - } - var b [26]byte - cni.PutBinary(b[:]) - var nn int - nn, err = w.Write(b[:]) - if err != nil { - return - } - n += nn - } - return -} - func (s *Server) setDefaults() { if s.ID == "" { var id [20]byte @@ -95,7 +67,6 @@ func (s *Server) setDefaults() { func (s *Server) Init() { s.setDefaults() - s.nodes = make(map[string]*Node, 1000) } func (s *Server) Serve() error { @@ -111,13 +82,16 @@ func (s *Server) Serve() error { log.Printf("bad krpc message: %s", err) continue } + s.mu.Lock() if d["y"] == "q" { s.handleQuery(addr, d) + s.mu.Unlock() continue } t := s.findResponseTransaction(d["t"].(string), addr) if t == nil { log.Printf("unexpected message: %#v", d) + s.mu.Unlock() continue } t.response <- d @@ -127,14 +101,26 @@ func (s *Server) Serve() error { id = d["r"].(map[string]interface{})["id"].(string) } s.heardFromNode(addr, id) + s.mu.Unlock() + } +} + +func (s *Server) AddNode(ni NodeInfo) { + if s.nodes == nil { + s.nodes = make(map[string]*Node) + } + n := s.getNode(ni.Addr) + if n.id == "" { + n.id = string(ni.ID[:]) } } func (s *Server) handleQuery(source *net.UDPAddr, m Msg) { + log.Print(m["q"]) if m["q"] != "ping" { return } - s.heardFromNode(source, m["a"].(map[string]string)["id"]) + s.heardFromNode(source, m["a"].(map[string]interface{})["id"].(string)) s.reply(source, m["t"].(string)) } @@ -254,30 +240,29 @@ func (s *Server) query(node *net.UDPAddr, q string, a map[string]string) (t *tra return } -const compactNodeInfoLen = 26 - -type compactAddrInfo *net.UDPAddr +const CompactNodeInfoLen = 26 -type compactNodeInfo struct { +type NodeInfo struct { ID [20]byte - Addr compactAddrInfo + Addr *net.UDPAddr } -func (cni *compactNodeInfo) PutBinary(b []byte) { - if n := copy(b[:], cni.ID[:]); n != 20 { +func (ni *NodeInfo) PutCompact(b []byte) error { + if n := copy(b[:], ni.ID[:]); n != 20 { panic(n) } - ip := cni.Addr.IP.To4() + ip := ni.Addr.IP.To4() if len(ip) != 4 { panic(ip) } if n := copy(b[20:], ip); n != 4 { panic(n) } - binary.BigEndian.PutUint16(b[24:], uint16(cni.Addr.Port)) + binary.BigEndian.PutUint16(b[24:], uint16(ni.Addr.Port)) + return nil } -func (cni *compactNodeInfo) UnmarshalBinary(b []byte) error { +func (cni *NodeInfo) UnmarshalCompact(b []byte) error { if len(b) != 26 { return errors.New("expected 26 bytes") } @@ -297,7 +282,7 @@ func (s *Server) Ping(node *net.UDPAddr) (*transaction, error) { } type findNodeResponse struct { - Nodes []compactNodeInfo + Nodes []NodeInfo } func getResponseNodes(m Msg) (s string, err error) { @@ -318,8 +303,8 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error { return err } for i := 0; i < len(b); i += 26 { - var n compactNodeInfo - err := n.UnmarshalBinary([]byte(b[i : i+26])) + var n NodeInfo + err := n.UnmarshalCompact([]byte(b[i : i+26])) if err != nil { return err } @@ -329,7 +314,6 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error { } func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) { - // log.Print(addr) t, err = s.query(addr, "find_node", map[string]string{"target": targetID}) if err != nil { return @@ -348,10 +332,12 @@ func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, e if err != nil { log.Print(err) } else { + s.mu.Lock() for _, cni := range r.Nodes { n := s.getNode(cni.Addr) n.id = string(cni.ID[:]) } + s.mu.Unlock() } } t.Response <- d @@ -359,33 +345,60 @@ func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, e return } -func (s *Server) Bootstrap() error { +func (s *Server) addRootNode() error { + addr, err := net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881") + if err != nil { + return err + } + s.nodes[addr.String()] = &Node{ + addr: addr, + } + return nil +} + +// Populates the node table. +func (s *Server) Bootstrap() (err error) { + s.mu.Lock() + defer s.mu.Unlock() if len(s.nodes) == 0 { - addr, err := net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881") + err = s.addRootNode() if err != nil { - return err + return } - s.nodes[addr.String()] = &Node{ - addr: addr, + } + for _, node := range s.nodes { + var t *transaction + s.mu.Unlock() + t, err = s.FindNode(node.addr, s.ID) + s.mu.Lock() + if err != nil { + return } + go func() { + <-t.Response + }() } - queriedNodes := make(map[string]bool, 1000) - for i := 0; i < 3; i++ { - log.Printf("node table length: %d", len(s.nodes)) - for _, node := range s.nodes { - if queriedNodes[node.addr.String()] { - continue - } - t, err := s.FindNode(node.addr, s.ID) - if err != nil { - return err - } - queriedNodes[node.addr.String()] = true - go func() { - <-t.Response - }() + return +} + +func (s *Server) GoodNodes() (nis []NodeInfo) { + s.mu.Lock() + defer s.mu.Unlock() + for _, node := range s.nodes { + if !node.Good() { + continue } - time.Sleep(3 * time.Second) + ni := NodeInfo{ + Addr: node.addr, + } + if n := copy(ni.ID[:], node.id); n != 20 { + panic(n) + } + nis = append(nis, ni) } - return nil + return +} + +func (s *Server) StopServing() { + s.Socket.Close() } diff --git a/dht/dht_test.go b/dht/dht_test.go index 83bd6e8f..7a945f5f 100644 --- a/dht/dht_test.go +++ b/dht/dht_test.go @@ -6,7 +6,7 @@ import ( ) func TestMarshalCompactNodeInfo(t *testing.T) { - cni := compactNodeInfo{ + cni := NodeInfo{ ID: [20]byte{'a', 'b', 'c'}, } var err error @@ -14,8 +14,8 @@ func TestMarshalCompactNodeInfo(t *testing.T) { if err != nil { t.Fatal(err) } - var b [compactAddrInfoLen]byte - cni.PutBinary(b[:]) + var b [CompactNodeInfoLen]byte + cni.PutCompact(b[:]) if err != nil { t.Fatal(err) }