]> Sergey Matveev's repositories - btrtrc.git/commitdiff
dht-server: Save and load node table between invocations
authorMatt Joiner <anacrolix@gmail.com>
Sun, 25 May 2014 13:04:55 +0000 (23:04 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Sun, 25 May 2014 13:04:55 +0000 (23:04 +1000)
cmd/dht-server/main.go
dht/dht.go

index 1e6abe7245be74322e068544bf39802ddea1c982..c0ba99b32b838459fbf0e0a3099986d9a79cea7a 100644 (file)
@@ -4,6 +4,7 @@ import (
        "bitbucket.org/anacrolix/go.torrent/dht"
        "log"
        "net"
+       "os"
 )
 
 type pingResponse struct {
@@ -20,6 +21,20 @@ func main() {
                log.Fatal(err)
        }
        s.Init()
+       func() {
+               f, err := os.Open("nodes")
+               if os.IsNotExist(err) {
+                       return
+               }
+               if err != nil {
+                       log.Fatal(err)
+               }
+               defer f.Close()
+               err = s.ReadNodes(f)
+               if err != nil {
+                       log.Fatal(err)
+               }
+       }()
        log.Printf("dht server on %s", s.Socket.LocalAddr())
        go func() {
                err := s.Serve()
@@ -28,8 +43,16 @@ func main() {
                }
        }()
        err = s.Bootstrap()
+       func() {
+               f, err := os.OpenFile("nodes", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
+               if err != nil {
+                       log.Print(err)
+                       return
+               }
+               defer f.Close()
+               s.WriteNodes(f)
+       }()
        if err != nil {
                log.Fatal(err)
        }
-       select {}
 }
index 27c1a79e54f3ff9460b454688783b920381e2b41..cfc123ba584698f70a1e3fbf65c09c6d873a7c9c 100644 (file)
@@ -42,6 +42,26 @@ type transaction struct {
        response   chan Msg
 }
 
+func (s *Server) ReadNodes(r io.Reader) error {
+       for {
+               var b [compactNodeInfoLen]byte
+               _, err := io.ReadFull(r, b[:])
+               if err == io.EOF {
+                       return nil
+               }
+               if err != nil {
+                       return err
+               }
+               var cni compactNodeInfo
+               err = cni.UnmarshalBinary(b[:])
+               if err != nil {
+                       return err
+               }
+               n := s.getNode(cni.Addr)
+               n.id = string(cni.ID[:])
+       }
+}
+
 func (s *Server) WriteNodes(w io.Writer) (n int, err error) {
        for _, node := range s.nodes {
                cni := compactNodeInfo{
@@ -91,7 +111,15 @@ func (s *Server) Serve() error {
                        log.Printf("bad krpc message: %s", err)
                        continue
                }
+               if d["y"] == "q" {
+                       s.handleQuery(addr, d)
+                       continue
+               }
                t := s.findResponseTransaction(d["t"].(string), addr)
+               if t == nil {
+                       log.Printf("unexpected message: %#v", d)
+                       continue
+               }
                t.response <- d
                s.removeTransaction(t)
                id := ""
@@ -102,6 +130,32 @@ func (s *Server) Serve() error {
        }
 }
 
+func (s *Server) handleQuery(source *net.UDPAddr, m Msg) {
+       if m["q"] != "ping" {
+               return
+       }
+       s.heardFromNode(source, m["a"].(map[string]string)["id"])
+       s.reply(source, m["t"].(string))
+}
+
+func (s *Server) reply(addr *net.UDPAddr, t string) {
+       m := map[string]interface{}{
+               "t": t,
+               "y": "r",
+               "r": map[string]string{
+                       "id": s.IDString(),
+               },
+       }
+       b, err := bencode.Marshal(m)
+       if err != nil {
+               panic(err)
+       }
+       _, err = s.Socket.WriteTo(b, addr)
+       if err != nil {
+               panic(err)
+       }
+}
+
 func (s *Server) heardFromNode(addr *net.UDPAddr, id string) {
        n := s.getNode(addr)
        n.id = id
@@ -200,7 +254,7 @@ func (s *Server) query(node *net.UDPAddr, q string, a map[string]string) (t *tra
        return
 }
 
-const compactAddrInfoLen = 26
+const compactNodeInfoLen = 26
 
 type compactAddrInfo *net.UDPAddr
 
@@ -246,9 +300,23 @@ type findNodeResponse struct {
        Nodes []compactNodeInfo
 }
 
+func getResponseNodes(m Msg) (s string, err error) {
+       defer func() {
+               r := recover()
+               if r == nil {
+                       return
+               }
+               err = fmt.Errorf("couldn't get response nodes: %s: %#v", r, m)
+       }()
+       s = m["r"].(map[string]interface{})["nodes"].(string)
+       return
+}
+
 func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
-       b := m["r"].(map[string]interface{})["nodes"].(string)
-       log.Printf("%q", b)
+       b, err := getResponseNodes(m)
+       if err != nil {
+               return err
+       }
        for i := 0; i < len(b); i += 26 {
                var n compactNodeInfo
                err := n.UnmarshalBinary([]byte(b[i : i+26]))
@@ -261,7 +329,7 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
 }
 
 func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
-       log.Print(addr)
+       // log.Print(addr)
        t, err = s.query(addr, "find_node", map[string]string{"target": targetID})
        if err != nil {
                return
@@ -302,10 +370,10 @@ func (s *Server) Bootstrap() error {
                }
        }
        queriedNodes := make(map[string]bool, 1000)
-       for {
+       for i := 0; i < 3; i++ {
+               log.Printf("node table length: %d", len(s.nodes))
                for _, node := range s.nodes {
                        if queriedNodes[node.addr.String()] {
-                               log.Printf("skipping already queried: %s", node.addr)
                                continue
                        }
                        t, err := s.FindNode(node.addr, s.ID)
@@ -314,7 +382,7 @@ func (s *Server) Bootstrap() error {
                        }
                        queriedNodes[node.addr.String()] = true
                        go func() {
-                               log.Print(<-t.Response)
+                               <-t.Response
                        }()
                }
                time.Sleep(3 * time.Second)