]> Sergey Matveev's repositories - btrtrc.git/commitdiff
dht: Concurrency improvements and fixes to bootstrapping and getting peers
authorMatt Joiner <anacrolix@gmail.com>
Wed, 9 Jul 2014 14:13:54 +0000 (00:13 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Wed, 9 Jul 2014 14:13:54 +0000 (00:13 +1000)
dht/dht.go

index 2b132b6ebb0bc7b2e42e6e040caa35b6ecce6ec8..372f78489b7cd7a755d79fae4d94da171748de24 100644 (file)
@@ -24,6 +24,11 @@ type Server struct {
        transactionIDInt uint64
        nodes            map[string]*Node
        mu               sync.Mutex
+       closed           chan struct{}
+}
+
+func (s *Server) String() string {
+       return fmt.Sprintf("dht server on %s", s.Socket.LocalAddr())
 }
 
 type Node struct {
@@ -55,7 +60,15 @@ type transaction struct {
        remoteAddr net.Addr
        t          string
        Response   chan Msg
-       response   chan Msg
+       onResponse func(Msg)
+}
+
+func (t *transaction) handleResponse(m Msg) {
+       if t.onResponse != nil {
+               t.onResponse(m)
+       }
+       t.Response <- m
+       close(t.Response)
 }
 
 func (s *Server) setDefaults() (err error) {
@@ -91,9 +104,13 @@ func (s *Server) setDefaults() (err error) {
        return
 }
 
-func (s *Server) Init() error {
-       return s.setDefaults()
-       //s.nodes = make(map[string]*Node)
+func (s *Server) Init() (err error) {
+       err = s.setDefaults()
+       if err != nil {
+               return
+       }
+       s.closed = make(chan struct{})
+       return
 }
 
 func (s *Server) Serve() error {
@@ -106,7 +123,7 @@ func (s *Server) Serve() error {
                var d map[string]interface{}
                err = bencode.Unmarshal(b[:n], &d)
                if err != nil {
-                       log.Printf("bad krpc message: %s: %q", err, b[:n])
+                       log.Printf("%s: received bad krpc message: %s: %q", s, err, b[:n])
                        continue
                }
                s.mu.Lock()
@@ -121,7 +138,7 @@ func (s *Server) Serve() error {
                        s.mu.Unlock()
                        continue
                }
-               t.response <- d
+               t.handleResponse(d)
                s.removeTransaction(t)
                id := ""
                if d["y"] == "r" {
@@ -143,8 +160,8 @@ func (s *Server) AddNode(ni NodeInfo) {
 }
 
 func (s *Server) handleQuery(source *net.UDPAddr, m Msg) {
-       log.Print(m["q"])
        if m["q"] != "ping" {
+               log.Printf("%s: not handling received query: q=%s", s, m["q"])
                return
        }
        s.heardFromNode(source, m["a"].(map[string]interface{})["id"].(string))
@@ -264,7 +281,6 @@ func (s *Server) query(node *net.UDPAddr, q string, a map[string]string) (t *tra
                t:          tid,
                Response:   make(chan Msg, 1),
        }
-       t.response = t.Response
        s.addTransaction(t)
        err = s.writeToNode(b, node)
        if err != nil {
@@ -346,18 +362,54 @@ func (me *findNodeResponse) UnmarshalKRPCMsg(m Msg) error {
        return nil
 }
 
-func (t *transaction) onResponse(f func(m Msg)) {
-       ch := make(chan Msg)
-       t.response = ch
-       go func() {
-               d, ok := <-t.response
-               if !ok {
-                       close(t.Response)
+func (t *transaction) setOnResponse(f func(m Msg)) {
+       if t.onResponse != nil {
+               panic(t.onResponse)
+       }
+       t.onResponse = f
+}
+
+func unmarshalNodeInfoBinary(b []byte) (ret []NodeInfo, err error) {
+       if len(b)%26 != 0 {
+               err = errors.New("bad buffer length")
+               return
+       }
+       ret = make([]NodeInfo, 0, len(b)/26)
+       for i := 0; i < len(b); i += 26 {
+               var ni NodeInfo
+               err = ni.UnmarshalCompact(b[i : i+26])
+               if err != nil {
                        return
                }
-               f(d)
-               t.Response <- d
-       }()
+               ret = append(ret, ni)
+       }
+       return
+}
+
+func extractNodes(d Msg) (nodes []NodeInfo, err error) {
+       if d["y"] != "r" {
+               return
+       }
+       r, ok := d["r"]
+       if !ok {
+               err = errors.New("missing r dict")
+               return
+       }
+       rd, ok := r.(map[string]interface{})
+       if !ok {
+               err = errors.New("bad r value type")
+               return
+       }
+       n, ok := rd["nodes"]
+       if !ok {
+               return
+       }
+       ns, ok := n.(string)
+       if !ok {
+               err = errors.New("bad nodes value type")
+               return
+       }
+       return unmarshalNodeInfoBinary([]byte(ns))
 }
 
 func (s *Server) liftNodes(d Msg) {
@@ -369,25 +421,23 @@ func (s *Server) liftNodes(d Msg) {
        if err != nil {
                // log.Print(err)
        } else {
-               s.mu.Lock()
                for _, cni := range r.Nodes {
                        n := s.getNode(cni.Addr)
                        n.id = string(cni.ID[:])
                }
-               s.mu.Unlock()
                // log.Printf("lifted %d nodes", len(r.Nodes))
        }
 }
 
 // Sends a find_node query to addr. targetID is the node we're looking for.
-func (s *Server) FindNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
+func (s *Server) findNode(addr *net.UDPAddr, targetID string) (t *transaction, err error) {
        t, err = s.query(addr, "find_node", map[string]string{"target": targetID})
        if err != nil {
                return
        }
        // Scrape peers from the response to put in the server's table before
        // handing the response back to the caller.
-       t.onResponse(func(d Msg) {
+       t.setOnResponse(func(d Msg) {
                s.liftNodes(d)
        })
        return
@@ -471,9 +521,10 @@ func (s *Server) GetPeers(infoHash string) (ps *peerStream, err error) {
                        case m := <-t.Response:
                                vs := extractValues(m)
                                if vs != nil {
-                                       ps.Values <- vs
-                                       // } else {
-                                       // log.Print("get_peers response had no values")
+                                       select {
+                                       case ps.Values <- vs:
+                                       case <-ps.stop:
+                                       }
                                }
                        case <-ps.stop:
                        }
@@ -484,7 +535,10 @@ func (s *Server) GetPeers(infoHash string) (ps *peerStream, err error) {
        s.mu.Unlock()
        go func() {
                for ; pending > 0; pending-- {
-                       <-done
+                       select {
+                       case <-done:
+                       case <-s.closed:
+                       }
                }
                ps.Close()
        }()
@@ -500,7 +554,7 @@ func (s *Server) getPeers(addr *net.UDPAddr, infoHash string) (t *transaction, e
        if err != nil {
                return
        }
-       t.onResponse(func(m Msg) {
+       t.setOnResponse(func(m Msg) {
                s.liftNodes(m)
        })
        return
@@ -523,24 +577,38 @@ func (s *Server) Bootstrap() (err error) {
        defer s.mu.Unlock()
        if len(s.nodes) == 0 {
                err = s.addRootNode()
-               if err != nil {
-                       return
-               }
+       }
+       if err != nil {
+               return
        }
        for {
+               var outstanding sync.WaitGroup
                for _, node := range s.nodes {
                        var t *transaction
-                       s.mu.Unlock()
-                       t, err = s.FindNode(node.addr, s.ID)
-                       s.mu.Lock()
+                       t, err = s.findNode(node.addr, s.ID)
                        if err != nil {
                                return
                        }
+                       outstanding.Add(1)
                        go func() {
                                <-t.Response
+                               outstanding.Done()
                        }()
                }
-               time.Sleep(5 * time.Second)
+               noOutstanding := make(chan struct{})
+               go func() {
+                       outstanding.Wait()
+                       close(noOutstanding)
+               }()
+               s.mu.Unlock()
+               select {
+               case <-s.closed:
+                       s.mu.Lock()
+                       return
+               case <-time.After(15 * time.Second):
+               case <-noOutstanding:
+               }
+               s.mu.Lock()
                log.Printf("now have %d nodes", len(s.nodes))
                if len(s.nodes) >= 8*160 {
                        break
@@ -569,6 +637,13 @@ func (s *Server) Nodes() (nis []NodeInfo) {
 
 func (s *Server) StopServing() {
        s.Socket.Close()
+       s.mu.Lock()
+       select {
+       case <-s.closed:
+       default:
+               close(s.closed)
+       }
+       s.mu.Unlock()
 }
 
 func idDistance(a, b string) (ret int) {