]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Got dht-server working nicely
authorMatt Joiner <anacrolix@gmail.com>
Tue, 27 May 2014 06:28:56 +0000 (16:28 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Tue, 27 May 2014 06:28:56 +0000 (16:28 +1000)
cmd/dht-ping/main.go
cmd/dht-server/main.go
dht/dht.go
dht/dht_test.go

index 833ca6fdf39827739d6e8540f9c270407223e9bc..9e43f798ff39ff2a5396ecb3490cab307b602772 100644 (file)
@@ -23,7 +23,7 @@ func main() {
        }
        s := dht.Server{}
        var err error
-       s.Socket, err = net.ListenPacket("udp4", "")
+       s.Socket, err = net.ListenUDP("udp4", nil)
        if err != nil {
                log.Fatal(err)
        }
index c0ba99b32b838459fbf0e0a3099986d9a79cea7a..4b7cf8abf49c9bcc45a98dab04bdb079ef64b005 100644 (file)
@@ -2,9 +2,13 @@ package main
 
 import (
        "bitbucket.org/anacrolix/go.torrent/dht"
+       "flag"
+       "fmt"
+       "io"
        "log"
        "net"
        "os"
+       "os/signal"
 )
 
 type pingResponse struct {
@@ -12,47 +16,119 @@ type pingResponse struct {
        krpc dht.Msg
 }
 
-func main() {
+var (
+       tableFileName = flag.String("tableFile", "", "name of file for storing node info")
+       serveAddr     = flag.String("serveAddr", ":0", "local UDP address")
+
+       s dht.Server
+)
+
+func loadTable() error {
+       if *tableFileName == "" {
+               return nil
+       }
+       f, err := os.Open(*tableFileName)
+       if os.IsNotExist(err) {
+               return nil
+       }
+       if err != nil {
+               return fmt.Errorf("error opening table file: %s", err)
+       }
+       defer f.Close()
+       added := 0
+       for {
+               b := make([]byte, dht.CompactNodeInfoLen)
+               _, err := io.ReadFull(f, b)
+               if err == io.EOF {
+                       break
+               }
+               if err != nil {
+                       return fmt.Errorf("error reading table file: %s", err)
+               }
+               var ni dht.NodeInfo
+               err = ni.UnmarshalCompact(b)
+               if err != nil {
+                       return fmt.Errorf("error unmarshaling compact node info: %s", err)
+               }
+               s.AddNode(ni)
+               added++
+       }
+       log.Printf("loaded %d nodes from table file", added)
+       return nil
+}
+
+func init() {
        log.SetFlags(log.LstdFlags | log.Lshortfile)
-       s := dht.Server{}
-       var err error
-       s.Socket, err = net.ListenUDP("udp4", nil)
+       flag.Parse()
+       err := loadTable()
+       if err != nil {
+               log.Fatalf("error loading table: %s", err)
+       }
+       s.Socket, err = net.ListenUDP("udp4", func() *net.UDPAddr {
+               addr, err := net.ResolveUDPAddr("udp4", *serveAddr)
+               if err != nil {
+                       log.Fatalf("error resolving serve addr: %s", err)
+               }
+               return addr
+       }())
        if err != nil {
                log.Fatal(err)
        }
+       log.Printf("dht server on %s", s.Socket.LocalAddr())
        s.Init()
-       func() {
-               f, err := os.Open("nodes")
-               if os.IsNotExist(err) {
-                       return
+       setupSignals()
+}
+
+func saveTable() error {
+       goodNodes := s.GoodNodes()
+       if *tableFileName == "" {
+               if len(goodNodes) != 0 {
+                       log.Printf("discarding %d good nodes!", len(goodNodes))
                }
+               return nil
+       }
+       f, err := os.OpenFile(*tableFileName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
+       if err != nil {
+               return fmt.Errorf("error opening table file: %s", err)
+       }
+       defer f.Close()
+       for _, nodeInfo := range goodNodes {
+               var b [dht.CompactNodeInfoLen]byte
+               err := nodeInfo.PutCompact(b[:])
                if err != nil {
-                       log.Fatal(err)
+                       return fmt.Errorf("error compacting node info: %s", err)
                }
-               defer f.Close()
-               err = s.ReadNodes(f)
+               _, err = f.Write(b[:])
                if err != nil {
-                       log.Fatal(err)
+                       return fmt.Errorf("error writing compact node info: %s", err)
                }
-       }()
-       log.Printf("dht server on %s", s.Socket.LocalAddr())
+       }
+       log.Printf("saved %d nodes to table file", len(goodNodes))
+       return nil
+}
+
+func setupSignals() {
+       ch := make(chan os.Signal)
+       signal.Notify(ch)
        go func() {
-               err := s.Serve()
-               if err != nil {
-                       log.Fatal(err)
-               }
+               <-ch
+               s.StopServing()
        }()
-       err = s.Bootstrap()
-       func() {
-               f, err := os.OpenFile("nodes", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
+}
+
+func main() {
+       go func() {
+               err := s.Bootstrap()
                if err != nil {
-                       log.Print(err)
-                       return
+                       log.Printf("error bootstrapping: %s", err)
+                       s.StopServing()
                }
-               defer f.Close()
-               s.WriteNodes(f)
        }()
+       err := s.Serve()
+       if err := saveTable(); err != nil {
+               log.Printf("error saving node table: %s", err)
+       }
        if err != nil {
-               log.Fatal(err)
+               log.Fatalf("error serving dht: %s", err)
        }
 }
index cfc123ba584698f70a1e3fbf65c09c6d873a7c9c..d7e8e92ac1134adcbdbd82c548ace24d470b0c4c 100644 (file)
@@ -9,6 +9,7 @@ import (
        "io"
        "log"
        "net"
+       "sync"
        "time"
 )
 
@@ -18,6 +19,7 @@ type Server struct {
        transactions     []*transaction
        transactionIDInt uint64
        nodes            map[string]*Node
+       mu               sync.Mutex
 }
 
 type Node struct {
@@ -27,6 +29,16 @@ type Node struct {
        lastSentTo    time.Time
 }
 
+func (n *Node) Good() bool {
+       if len(n.id) != 20 {
+               return false
+       }
+       if time.Now().Sub(n.lastHeardFrom) >= 15*time.Minute {
+               return false
+       }
+       return true
+}
+
 type Msg map[string]interface{}
 
 var _ fmt.Stringer = Msg{}
@@ -42,46 +54,6 @@ type transaction struct {
        response   chan Msg
 }
 
-func (s *Server) ReadNodes(r io.Reader) error {
-       for {
-               var b [compactNodeInfoLen]byte
-               _, err := io.ReadFull(r, b[:])
-               if err == io.EOF {
-                       return nil
-               }
-               if err != nil {
-                       return err
-               }
-               var cni compactNodeInfo
-               err = cni.UnmarshalBinary(b[:])
-               if err != nil {
-                       return err
-               }
-               n := s.getNode(cni.Addr)
-               n.id = string(cni.ID[:])
-       }
-}
-
-func (s *Server) WriteNodes(w io.Writer) (n int, err error) {
-       for _, node := range s.nodes {
-               cni := compactNodeInfo{
-                       Addr: node.addr,
-               }
-               if n := copy(cni.ID[:], node.id); n != 20 {
-                       panic(n)
-               }
-               var b [26]byte
-               cni.PutBinary(b[:])
-               var nn int
-               nn, err = w.Write(b[:])
-               if err != nil {
-                       return
-               }
-               n += nn
-       }
-       return
-}
-
 func (s *Server) setDefaults() {
        if s.ID == "" {
                var id [20]byte
@@ -95,7 +67,6 @@ func (s *Server) setDefaults() {
 
 func (s *Server) Init() {
        s.setDefaults()
-       s.nodes = make(map[string]*Node, 1000)
 }
 
 func (s *Server) Serve() error {
@@ -111,13 +82,16 @@ func (s *Server) Serve() error {
                        log.Printf("bad krpc message: %s", err)
                        continue
                }
+               s.mu.Lock()
                if d["y"] == "q" {
                        s.handleQuery(addr, d)
+                       s.mu.Unlock()
                        continue
                }
                t := s.findResponseTransaction(d["t"].(string), addr)
                if t == nil {
                        log.Printf("unexpected message: %#v", d)
+                       s.mu.Unlock()
                        continue
                }
                t.response <- d
@@ -127,14 +101,26 @@ func (s *Server) Serve() error {
                        id = d["r"].(map[string]interface{})["id"].(string)
                }
                s.heardFromNode(addr, id)
+               s.mu.Unlock()
+       }
+}
+
+func (s *Server) AddNode(ni NodeInfo) {
+       if s.nodes == nil {
+               s.nodes = make(map[string]*Node)
+       }
+       n := s.getNode(ni.Addr)
+       if n.id == "" {
+               n.id = string(ni.ID[:])
        }
 }
 
 func (s *Server) handleQuery(source *net.UDPAddr, m Msg) {
+       log.Print(m["q"])
        if m["q"] != "ping" {
                return
        }
-       s.heardFromNode(source, m["a"].(map[string]string)["id"])
+       s.heardFromNode(source, m["a"].(map[string]interface{})["id"].(string))
        s.reply(source, m["t"].(string))
 }
 
@@ -254,30 +240,29 @@ func (s *Server) query(node *net.UDPAddr, q string, a map[string]string) (t *tra
        return
 }
 
-const compactNodeInfoLen = 26
-
-type compactAddrInfo *net.UDPAddr
+const CompactNodeInfoLen = 26
 
-type compactNodeInfo struct {
+type NodeInfo struct {
        ID   [20]byte
-       Addr compactAddrInfo
+       Addr *net.UDPAddr
 }
 
-func (cni *compactNodeInfo) PutBinary(b []byte) {
-       if n := copy(b[:], cni.ID[:]); n != 20 {
+func (ni *NodeInfo) PutCompact(b []byte) error {
+       if n := copy(b[:], ni.ID[:]); n != 20 {
                panic(n)
        }
-       ip := cni.Addr.IP.To4()
+       ip := ni.Addr.IP.To4()
        if len(ip) != 4 {
                panic(ip)
        }
        if n := copy(b[20:], ip); n != 4 {
                panic(n)
        }
-       binary.BigEndian.PutUint16(b[24:], uint16(cni.Addr.Port))
+       binary.BigEndian.PutUint16(b[24:], uint16(ni.Addr.Port))
+       return nil
 }
 
-func (cni *compactNodeInfo) UnmarshalBinary(b []byte) error {
+func (cni *NodeInfo) UnmarshalCompact(b []byte) error {
        if len(b) != 26 {
                return errors.New("expected 26 bytes")
        }
@@ -297,7 +282,7 @@ func (s *Server) Ping(node *net.UDPAddr) (*transaction, error) {
 }
 
 type findNodeResponse struct {
-       Nodes []compactNodeInfo
+       Nodes []NodeInfo
 }
 
 func getResponseNodes(m Msg) (s string, err error) {
@@ -318,8 +303,8 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
                return err
        }
        for i := 0; i < len(b); i += 26 {
-               var n compactNodeInfo
-               err := n.UnmarshalBinary([]byte(b[i : i+26]))
+               var n NodeInfo
+               err := n.UnmarshalCompact([]byte(b[i : i+26]))
                if err != nil {
                        return err
                }
@@ -329,7 +314,6 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
 }
 
 func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
-       // log.Print(addr)
        t, err = s.query(addr, "find_node", map[string]string{"target": targetID})
        if err != nil {
                return
@@ -348,10 +332,12 @@ func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, e
                        if err != nil {
                                log.Print(err)
                        } else {
+                               s.mu.Lock()
                                for _, cni := range r.Nodes {
                                        n := s.getNode(cni.Addr)
                                        n.id = string(cni.ID[:])
                                }
+                               s.mu.Unlock()
                        }
                }
                t.Response <- d
@@ -359,33 +345,60 @@ func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, e
        return
 }
 
-func (s *Server) Bootstrap() error {
+func (s *Server) addRootNode() error {
+       addr, err := net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881")
+       if err != nil {
+               return err
+       }
+       s.nodes[addr.String()] = &Node{
+               addr: addr,
+       }
+       return nil
+}
+
+// Populates the node table.
+func (s *Server) Bootstrap() (err error) {
+       s.mu.Lock()
+       defer s.mu.Unlock()
        if len(s.nodes) == 0 {
-               addr, err := net.ResolveUDPAddr("udp4", "router.bittorrent.com:6881")
+               err = s.addRootNode()
                if err != nil {
-                       return err
+                       return
                }
-               s.nodes[addr.String()] = &Node{
-                       addr: addr,
+       }
+       for _, node := range s.nodes {
+               var t *transaction
+               s.mu.Unlock()
+               t, err = s.FindNode(node.addr, s.ID)
+               s.mu.Lock()
+               if err != nil {
+                       return
                }
+               go func() {
+                       <-t.Response
+               }()
        }
-       queriedNodes := make(map[string]bool, 1000)
-       for i := 0; i < 3; i++ {
-               log.Printf("node table length: %d", len(s.nodes))
-               for _, node := range s.nodes {
-                       if queriedNodes[node.addr.String()] {
-                               continue
-                       }
-                       t, err := s.FindNode(node.addr, s.ID)
-                       if err != nil {
-                               return err
-                       }
-                       queriedNodes[node.addr.String()] = true
-                       go func() {
-                               <-t.Response
-                       }()
+       return
+}
+
+func (s *Server) GoodNodes() (nis []NodeInfo) {
+       s.mu.Lock()
+       defer s.mu.Unlock()
+       for _, node := range s.nodes {
+               if !node.Good() {
+                       continue
                }
-               time.Sleep(3 * time.Second)
+               ni := NodeInfo{
+                       Addr: node.addr,
+               }
+               if n := copy(ni.ID[:], node.id); n != 20 {
+                       panic(n)
+               }
+               nis = append(nis, ni)
        }
-       return nil
+       return
+}
+
+func (s *Server) StopServing() {
+       s.Socket.Close()
 }
index 83bd6e8f0d6911d30198a41a3d61f3517eb9b0fa..7a945f5f27431b875d22a5d896fe4e3e3e5ff3ff 100644 (file)
@@ -6,7 +6,7 @@ import (
 )
 
 func TestMarshalCompactNodeInfo(t *testing.T) {
-       cni := compactNodeInfo{
+       cni := NodeInfo{
                ID: [20]byte{'a', 'b', 'c'},
        }
        var err error
@@ -14,8 +14,8 @@ func TestMarshalCompactNodeInfo(t *testing.T) {
        if err != nil {
                t.Fatal(err)
        }
-       var b [compactAddrInfoLen]byte
-       cni.PutBinary(b[:])
+       var b [CompactNodeInfoLen]byte
+       cni.PutCompact(b[:])
        if err != nil {
                t.Fatal(err)
        }