]> Sergey Matveev's repositories - btrtrc.git/commitdiff
dht: Reply to get_peers and find_node queries
authorMatt Joiner <anacrolix@gmail.com>
Fri, 11 Jul 2014 15:24:01 +0000 (01:24 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Fri, 11 Jul 2014 15:24:01 +0000 (01:24 +1000)
dht/dht.go

index 372f78489b7cd7a755d79fae4d94da171748de24..16e4a20b3b4335faf931d1aee442324a408688b7 100644 (file)
@@ -22,7 +22,7 @@ type Server struct {
        Socket           *net.UDPConn
        transactions     []*transaction
        transactionIDInt uint64
-       nodes            map[string]*Node
+       nodes            map[string]*Node // Keyed by *net.UDPAddr.String().
        mu               sync.Mutex
        closed           chan struct{}
 }
@@ -38,6 +38,14 @@ type Node struct {
        lastSentTo    time.Time
 }
 
+func (n *Node) NodeInfo() (ret NodeInfo) {
+       ret.Addr = n.addr
+       if n := copy(ret.ID[:], n.id); n != 20 {
+               panic(n)
+       }
+       return
+}
+
 func (n *Node) Good() bool {
        if len(n.id) != 20 {
                return false
@@ -134,7 +142,7 @@ func (s *Server) Serve() error {
                }
                t := s.findResponseTransaction(d["t"].(string), addr)
                if t == nil {
-                       log.Printf("unexpected message: %#v", d)
+                       //log.Printf("unexpected message: %#v", d)
                        s.mu.Unlock()
                        continue
                }
@@ -159,22 +167,76 @@ func (s *Server) AddNode(ni NodeInfo) {
        }
 }
 
+func (s *Server) nodeByID(id string) *Node {
+       for _, node := range s.nodes {
+               if node.id == id {
+                       return node
+               }
+       }
+       return nil
+}
+
 func (s *Server) handleQuery(source *net.UDPAddr, m Msg) {
-       if m["q"] != "ping" {
+       args := m["a"].(map[string]interface{})
+       s.heardFromNode(source, args["id"].(string))
+       switch m["q"] {
+       case "ping":
+               s.reply(source, m["t"].(string), nil)
+       case "get_peers":
+               targetID := args["info_hash"].(string)
+               var rNodes []NodeInfo
+               // TODO: Reply with "values" list if we have peers instead.
+               for _, node := range s.closestGoodNodes(8, targetID) {
+                       rNodes = append(rNodes, node.NodeInfo())
+               }
+               nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
+               for i, ni := range rNodes {
+                       err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
+                       if err != nil {
+                               panic(err)
+                       }
+               }
+               s.reply(source, m["t"].(string), map[string]interface{}{
+                       "nodes": string(nodesBytes),
+                       "token": "hi",
+               })
+       case "find_node":
+               targetID := args["target"].(string)
+               var rNodes []NodeInfo
+               if node := s.nodeByID(targetID); node != nil {
+                       rNodes = append(rNodes, node.NodeInfo())
+               } else {
+                       for _, node := range s.closestGoodNodes(8, targetID) {
+                               rNodes = append(rNodes, node.NodeInfo())
+                       }
+               }
+               nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
+               for i, ni := range rNodes {
+                       err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
+                       if err != nil {
+                               panic(err)
+                       }
+               }
+               s.reply(source, m["t"].(string), map[string]interface{}{
+                       "nodes": string(nodesBytes),
+               })
+       case "announce_peer":
+               log.Print(m)
+       default:
                log.Printf("%s: not handling received query: q=%s", s, m["q"])
                return
        }
-       s.heardFromNode(source, m["a"].(map[string]interface{})["id"].(string))
-       s.reply(source, m["t"].(string))
 }
 
-func (s *Server) reply(addr *net.UDPAddr, t string) {
+func (s *Server) reply(addr *net.UDPAddr, t string, r map[string]interface{}) {
+       if r == nil {
+               r = make(map[string]interface{}, 1)
+       }
+       r["id"] = s.IDString()
        m := map[string]interface{}{
                "t": t,
                "y": "r",
-               "r": map[string]string{
-                       "id": s.IDString(),
-               },
+               "r": r,
        }
        b, err := bencode.Marshal(m)
        if err != nil {
@@ -661,13 +723,20 @@ func idDistance(a, b string) (ret int) {
        return
 }
 
-// func (s *Server) closestNodes(k int) (ret *closestNodes) {
-//     heap.Init(ret)
-//     for _, node := range s.nodes {
-//             heap.Push(ret, node)
-//             if ret.Len() > k {
-//                     heap.Pop(ret)
-//             }
-//     }
-//     return
-// }
+func (s *Server) closestGoodNodes(k int, targetID string) []*Node {
+       sel := newKClosestNodesSelector(k, targetID)
+       idNodes := make(map[string]*Node, len(s.nodes))
+       for _, node := range s.nodes {
+               if !node.Good() {
+                       continue
+               }
+               sel.Push(node.id)
+               idNodes[node.id] = node
+       }
+       ids := sel.IDs()
+       ret := make([]*Node, 0, len(ids))
+       for _, id := range ids {
+               ret = append(ret, idNodes[id])
+       }
+       return ret
+}