]> Sergey Matveev's repositories - btrtrc.git/blob - dht/dht.go
d2f148799f944262600c5792db103b11a6ceaa16
[btrtrc.git] / dht / dht.go
1 package dht
2
3 import (
4         "crypto"
5         _ "crypto/sha1"
6         "encoding/binary"
7         "errors"
8         "fmt"
9         "io"
10         "log"
11         "math/big"
12         "math/rand"
13         "net"
14         "os"
15         "time"
16
17         "github.com/anacrolix/torrent/iplist"
18         "github.com/anacrolix/torrent/logonce"
19         "github.com/anacrolix/torrent/util"
20         "github.com/anacrolix/sync"
21         "github.com/anacrolix/libtorgo/bencode"
22 )
23
24 const (
25         maxNodes         = 320
26         queryResendEvery = 5 * time.Second
27 )
28
29 // Uniquely identifies a transaction to us.
30 type transactionKey struct {
31         RemoteAddr string // host:port
32         T          string // The KRPC transaction ID.
33 }
34
35 type Server struct {
36         id               string
37         socket           net.PacketConn
38         transactions     map[transactionKey]*transaction
39         transactionIDInt uint64
40         nodes            map[string]*Node // Keyed by dHTAddr.String().
41         mu               sync.Mutex
42         closed           chan struct{}
43         passive          bool // Don't respond to queries.
44         ipBlockList      *iplist.IPList
45
46         NumConfirmedAnnounces int
47 }
48
49 type dHTAddr interface {
50         net.Addr
51         UDPAddr() *net.UDPAddr
52 }
53
54 type cachedAddr struct {
55         a net.Addr
56         s string
57 }
58
59 func (ca cachedAddr) Network() string {
60         return ca.a.Network()
61 }
62
63 func (ca cachedAddr) String() string {
64         return ca.s
65 }
66
67 func (ca cachedAddr) UDPAddr() *net.UDPAddr {
68         return ca.a.(*net.UDPAddr)
69 }
70
71 func newDHTAddr(addr net.Addr) dHTAddr {
72         return cachedAddr{addr, addr.String()}
73 }
74
75 type ServerConfig struct {
76         Addr    string
77         Conn    net.PacketConn
78         Passive bool // Don't respond to queries.
79 }
80
81 type serverStats struct {
82         NumGoodNodes               int
83         NumNodes                   int
84         NumOutstandingTransactions int
85 }
86
87 func (s *Server) Stats() (ss serverStats) {
88         s.mu.Lock()
89         defer s.mu.Unlock()
90         for _, n := range s.nodes {
91                 if n.DefinitelyGood() {
92                         ss.NumGoodNodes++
93                 }
94         }
95         ss.NumNodes = len(s.nodes)
96         ss.NumOutstandingTransactions = len(s.transactions)
97         return
98 }
99
100 func (s *Server) LocalAddr() net.Addr {
101         return s.socket.LocalAddr()
102 }
103
104 func makeSocket(addr string) (socket *net.UDPConn, err error) {
105         addr_, err := net.ResolveUDPAddr("", addr)
106         if err != nil {
107                 return
108         }
109         socket, err = net.ListenUDP("udp", addr_)
110         return
111 }
112
113 func NewServer(c *ServerConfig) (s *Server, err error) {
114         if c == nil {
115                 c = &ServerConfig{}
116         }
117         s = &Server{}
118         if c.Conn != nil {
119                 s.socket = c.Conn
120         } else {
121                 s.socket, err = makeSocket(c.Addr)
122                 if err != nil {
123                         return
124                 }
125         }
126         s.passive = c.Passive
127         err = s.init()
128         if err != nil {
129                 return
130         }
131         go func() {
132                 err := s.serve()
133                 select {
134                 case <-s.closed:
135                         return
136                 default:
137                 }
138                 if err != nil {
139                         panic(err)
140                 }
141         }()
142         go func() {
143                 err := s.bootstrap()
144                 if err != nil {
145                         log.Printf("error bootstrapping DHT: %s", err)
146                 }
147         }()
148         return
149 }
150
151 func (s *Server) String() string {
152         return fmt.Sprintf("dht server on %s", s.socket.LocalAddr())
153 }
154
155 type nodeID struct {
156         i   big.Int
157         set bool
158 }
159
160 func (nid *nodeID) IsUnset() bool {
161         return !nid.set
162 }
163
164 func nodeIDFromString(s string) (ret nodeID) {
165         if s == "" {
166                 return
167         }
168         ret.i.SetBytes([]byte(s))
169         ret.set = true
170         return
171 }
172
173 func (nid0 *nodeID) Distance(nid1 *nodeID) (ret big.Int) {
174         if nid0.IsUnset() != nid1.IsUnset() {
175                 ret = maxDistance
176                 return
177         }
178         ret.Xor(&nid0.i, &nid1.i)
179         return
180 }
181
182 func (nid *nodeID) String() string {
183         return string(nid.i.Bytes())
184 }
185
186 type Node struct {
187         addr          dHTAddr
188         id            nodeID
189         announceToken string
190
191         lastGotQuery    time.Time
192         lastGotResponse time.Time
193         lastSentQuery   time.Time
194 }
195
196 func (n *Node) idString() string {
197         return n.id.String()
198 }
199
200 func (n *Node) SetIDFromBytes(b []byte) {
201         n.id.i.SetBytes(b)
202         n.id.set = true
203 }
204
205 func (n *Node) SetIDFromString(s string) {
206         n.id.i.SetBytes([]byte(s))
207 }
208
209 func (n *Node) IDNotSet() bool {
210         return n.id.i.Int64() == 0
211 }
212
213 func (n *Node) NodeInfo() (ret NodeInfo) {
214         ret.Addr = n.addr
215         if n := copy(ret.ID[:], n.idString()); n != 20 {
216                 panic(n)
217         }
218         return
219 }
220
221 func (n *Node) DefinitelyGood() bool {
222         if len(n.idString()) != 20 {
223                 return false
224         }
225         // No reason to think ill of them if they've never been queried.
226         if n.lastSentQuery.IsZero() {
227                 return true
228         }
229         // They answered our last query.
230         if n.lastSentQuery.Before(n.lastGotResponse) {
231                 return true
232         }
233         return true
234 }
235
236 type Msg map[string]interface{}
237
238 var _ fmt.Stringer = Msg{}
239
240 func (m Msg) String() string {
241         return fmt.Sprintf("%#v", m)
242 }
243
244 func (m Msg) T() (t string) {
245         tif, ok := m["t"]
246         if !ok {
247                 return
248         }
249         t, _ = tif.(string)
250         return
251 }
252
253 func (m Msg) ID() string {
254         defer func() {
255                 recover()
256         }()
257         return m[m["y"].(string)].(map[string]interface{})["id"].(string)
258 }
259
260 // Suggested nodes in a response.
261 func (m Msg) Nodes() (nodes []NodeInfo) {
262         b := func() string {
263                 defer func() {
264                         recover()
265                 }()
266                 return m["r"].(map[string]interface{})["nodes"].(string)
267         }()
268         if len(b)%26 != 0 {
269                 return
270         }
271         for i := 0; i < len(b); i += 26 {
272                 var n NodeInfo
273                 err := n.UnmarshalCompact([]byte(b[i : i+26]))
274                 if err != nil {
275                         continue
276                 }
277                 nodes = append(nodes, n)
278         }
279         return
280 }
281
282 type KRPCError struct {
283         Code int
284         Msg  string
285 }
286
287 func (me KRPCError) Error() string {
288         return fmt.Sprintf("KRPC error %d: %s", me.Code, me.Msg)
289 }
290
291 var _ error = KRPCError{}
292
293 func (m Msg) Error() (ret *KRPCError) {
294         if m["y"] != "e" {
295                 return
296         }
297         ret = &KRPCError{}
298         switch e := m["e"].(type) {
299         case []interface{}:
300                 ret.Code = int(e[0].(int64))
301                 ret.Msg = e[1].(string)
302         case string:
303                 ret.Msg = e
304         default:
305                 logonce.Stderr.Printf(`KRPC error "e" value has unexpected type: %T`, e)
306         }
307         return
308 }
309
310 // Returns the token given in response to a get_peers request for future
311 // announce_peer requests to that node.
312 func (m Msg) AnnounceToken() (token string, ok bool) {
313         defer func() { recover() }()
314         token, ok = m["r"].(map[string]interface{})["token"].(string)
315         return
316 }
317
318 type transaction struct {
319         mu             sync.Mutex
320         remoteAddr     dHTAddr
321         t              string
322         response       chan Msg
323         onResponse     func(Msg) // Called with the server locked.
324         done           chan struct{}
325         queryPacket    []byte
326         timer          *time.Timer
327         s              *Server
328         retries        int
329         lastSend       time.Time
330         userOnResponse func(Msg)
331 }
332
333 func (t *transaction) SetResponseHandler(f func(Msg)) {
334         t.mu.Lock()
335         defer t.mu.Unlock()
336         t.userOnResponse = f
337         t.tryHandleResponse()
338 }
339
340 func (t *transaction) tryHandleResponse() {
341         if t.userOnResponse == nil {
342                 return
343         }
344         select {
345         case r := <-t.response:
346                 t.userOnResponse(r)
347                 // Shouldn't be called more than once.
348                 t.userOnResponse = nil
349         default:
350         }
351 }
352
353 func (t *transaction) Key() transactionKey {
354         return transactionKey{
355                 t.remoteAddr.String(),
356                 t.t,
357         }
358 }
359
360 func jitterDuration(average time.Duration, plusMinus time.Duration) time.Duration {
361         return average - plusMinus/2 + time.Duration(rand.Int63n(int64(plusMinus)))
362 }
363
364 func (t *transaction) startTimer() {
365         t.timer = time.AfterFunc(jitterDuration(queryResendEvery, time.Second), t.timerCallback)
366 }
367
368 func (t *transaction) timerCallback() {
369         t.mu.Lock()
370         defer t.mu.Unlock()
371         select {
372         case <-t.done:
373                 return
374         default:
375         }
376         if t.retries == 2 {
377                 t.timeout()
378                 return
379         }
380         t.retries++
381         t.sendQuery()
382         if t.timer.Reset(jitterDuration(queryResendEvery, time.Second)) {
383                 panic("timer should have fired to get here")
384         }
385 }
386
387 func (t *transaction) sendQuery() error {
388         err := t.s.writeToNode(t.queryPacket, t.remoteAddr)
389         if err != nil {
390                 return err
391         }
392         t.lastSend = time.Now()
393         return nil
394 }
395
396 func (t *transaction) timeout() {
397         go func() {
398                 t.s.mu.Lock()
399                 defer t.s.mu.Unlock()
400                 t.s.nodeTimedOut(t.remoteAddr)
401         }()
402         t.close()
403 }
404
405 func (t *transaction) close() {
406         if t.closing() {
407                 return
408         }
409         t.queryPacket = nil
410         close(t.response)
411         t.tryHandleResponse()
412         close(t.done)
413         t.timer.Stop()
414         go func() {
415                 t.s.mu.Lock()
416                 defer t.s.mu.Unlock()
417                 t.s.deleteTransaction(t)
418         }()
419 }
420
421 func (t *transaction) closing() bool {
422         select {
423         case <-t.done:
424                 return true
425         default:
426                 return false
427         }
428 }
429
430 func (t *transaction) Close() {
431         t.mu.Lock()
432         defer t.mu.Unlock()
433         t.close()
434 }
435
436 func (t *transaction) handleResponse(m Msg) {
437         t.mu.Lock()
438         if t.closing() {
439                 t.mu.Unlock()
440                 return
441         }
442         close(t.done)
443         t.mu.Unlock()
444         if t.onResponse != nil {
445                 t.s.mu.Lock()
446                 t.onResponse(m)
447                 t.s.mu.Unlock()
448         }
449         t.queryPacket = nil
450         select {
451         case t.response <- m:
452         default:
453                 panic("blocked handling response")
454         }
455         close(t.response)
456         t.tryHandleResponse()
457 }
458
459 func (s *Server) setDefaults() (err error) {
460         if s.id == "" {
461                 var id [20]byte
462                 h := crypto.SHA1.New()
463                 ss, err := os.Hostname()
464                 if err != nil {
465                         log.Print(err)
466                 }
467                 ss += s.socket.LocalAddr().String()
468                 h.Write([]byte(ss))
469                 if b := h.Sum(id[:0:20]); len(b) != 20 {
470                         panic(len(b))
471                 }
472                 if len(id) != 20 {
473                         panic(len(id))
474                 }
475                 s.id = string(id[:])
476         }
477         s.nodes = make(map[string]*Node, 10000)
478         return
479 }
480
481 func (s *Server) SetIPBlockList(list *iplist.IPList) {
482         s.mu.Lock()
483         defer s.mu.Unlock()
484         s.ipBlockList = list
485 }
486
487 func (s *Server) init() (err error) {
488         err = s.setDefaults()
489         if err != nil {
490                 return
491         }
492         s.closed = make(chan struct{})
493         s.transactions = make(map[transactionKey]*transaction)
494         return
495 }
496
497 func (s *Server) processPacket(b []byte, addr dHTAddr) {
498         var d Msg
499         err := bencode.Unmarshal(b, &d)
500         if err != nil {
501                 func() {
502                         if se, ok := err.(*bencode.SyntaxError); ok {
503                                 // The message was truncated.
504                                 if int(se.Offset) == len(b) {
505                                         return
506                                 }
507                                 // Some messages seem to drop to nul chars abrubtly.
508                                 if int(se.Offset) < len(b) && b[se.Offset] == 0 {
509                                         return
510                                 }
511                                 // The message isn't bencode from the first.
512                                 if se.Offset == 0 {
513                                         return
514                                 }
515                         }
516                         log.Printf("%s: received bad krpc message from %s: %s: %q", s, addr, err, b)
517                 }()
518                 return
519         }
520         s.mu.Lock()
521         defer s.mu.Unlock()
522         if d["y"] == "q" {
523                 s.handleQuery(addr, d)
524                 return
525         }
526         t := s.findResponseTransaction(d.T(), addr)
527         if t == nil {
528                 //log.Printf("unexpected message: %#v", d)
529                 return
530         }
531         node := s.getNode(addr)
532         node.lastGotResponse = time.Now()
533         // TODO: Update node ID as this is an authoritative packet.
534         go t.handleResponse(d)
535         s.deleteTransaction(t)
536 }
537
538 func (s *Server) serve() error {
539         var b [0x10000]byte
540         for {
541                 n, addr, err := s.socket.ReadFrom(b[:])
542                 if err != nil {
543                         return err
544                 }
545                 if n == len(b) {
546                         logonce.Stderr.Printf("received dht packet exceeds buffer size")
547                         continue
548                 }
549                 s.processPacket(b[:n], newDHTAddr(addr))
550         }
551 }
552
553 func (s *Server) ipBlocked(ip net.IP) bool {
554         if s.ipBlockList == nil {
555                 return false
556         }
557         return s.ipBlockList.Lookup(ip) != nil
558 }
559
560 func (s *Server) AddNode(ni NodeInfo) {
561         s.mu.Lock()
562         defer s.mu.Unlock()
563         if s.nodes == nil {
564                 s.nodes = make(map[string]*Node)
565         }
566         n := s.getNode(ni.Addr)
567         if n.IDNotSet() {
568                 n.SetIDFromBytes(ni.ID[:])
569         }
570 }
571
572 func (s *Server) nodeByID(id string) *Node {
573         for _, node := range s.nodes {
574                 if node.idString() == id {
575                         return node
576                 }
577         }
578         return nil
579 }
580
581 func (s *Server) handleQuery(source dHTAddr, m Msg) {
582         args := m["a"].(map[string]interface{})
583         node := s.getNode(source)
584         node.SetIDFromString(args["id"].(string))
585         node.lastGotQuery = time.Now()
586         // Don't respond.
587         if s.passive {
588                 return
589         }
590         switch m["q"] {
591         case "ping":
592                 s.reply(source, m["t"].(string), nil)
593         case "get_peers": // TODO: Extract common behaviour with find_node.
594                 targetID := args["info_hash"].(string)
595                 if len(targetID) != 20 {
596                         break
597                 }
598                 var rNodes []NodeInfo
599                 // TODO: Reply with "values" list if we have peers instead.
600                 for _, node := range s.closestGoodNodes(8, targetID) {
601                         rNodes = append(rNodes, node.NodeInfo())
602                 }
603                 nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
604                 for i, ni := range rNodes {
605                         err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
606                         if err != nil {
607                                 panic(err)
608                         }
609                 }
610                 s.reply(source, m["t"].(string), map[string]interface{}{
611                         "nodes": string(nodesBytes),
612                         "token": "hi",
613                 })
614         case "find_node": // TODO: Extract common behaviour with get_peers.
615                 targetID := args["target"].(string)
616                 if len(targetID) != 20 {
617                         log.Printf("bad DHT query: %v", m)
618                         return
619                 }
620                 var rNodes []NodeInfo
621                 if node := s.nodeByID(targetID); node != nil {
622                         rNodes = append(rNodes, node.NodeInfo())
623                 } else {
624                         for _, node := range s.closestGoodNodes(8, targetID) {
625                                 rNodes = append(rNodes, node.NodeInfo())
626                         }
627                 }
628                 nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
629                 for i, ni := range rNodes {
630                         // TODO: Put IPv6 nodes into the correct dict element.
631                         if ni.Addr.UDPAddr().IP.To4() == nil {
632                                 continue
633                         }
634                         err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
635                         if err != nil {
636                                 log.Printf("error compacting %#v: %s", ni, err)
637                                 continue
638                         }
639                 }
640                 s.reply(source, m["t"].(string), map[string]interface{}{
641                         "nodes": string(nodesBytes),
642                 })
643         case "announce_peer":
644                 // TODO(anacrolix): Implement this lolz.
645                 // log.Print(m)
646         case "vote":
647                 // TODO(anacrolix): Or reject, I don't think I want this.
648         default:
649                 log.Printf("%s: not handling received query: q=%s", s, m["q"])
650                 return
651         }
652 }
653
654 func (s *Server) reply(addr dHTAddr, t string, r map[string]interface{}) {
655         if r == nil {
656                 r = make(map[string]interface{}, 1)
657         }
658         r["id"] = s.IDString()
659         m := map[string]interface{}{
660                 "t": t,
661                 "y": "r",
662                 "r": r,
663         }
664         b, err := bencode.Marshal(m)
665         if err != nil {
666                 panic(err)
667         }
668         err = s.writeToNode(b, addr)
669         if err != nil {
670                 log.Printf("error replying to %s: %s", addr, err)
671         }
672 }
673
674 func (s *Server) getNode(addr dHTAddr) (n *Node) {
675         addrStr := addr.String()
676         n = s.nodes[addrStr]
677         if n == nil {
678                 n = &Node{
679                         addr: addr,
680                 }
681                 if len(s.nodes) < maxNodes {
682                         s.nodes[addrStr] = n
683                 }
684         }
685         return
686 }
687 func (s *Server) nodeTimedOut(addr dHTAddr) {
688         node, ok := s.nodes[addr.String()]
689         if !ok {
690                 return
691         }
692         if node.DefinitelyGood() {
693                 return
694         }
695         if len(s.nodes) < maxNodes {
696                 return
697         }
698         delete(s.nodes, addr.String())
699 }
700
701 func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) {
702         if list := s.ipBlockList; list != nil {
703                 if r := list.Lookup(util.AddrIP(node.UDPAddr())); r != nil {
704                         err = fmt.Errorf("write to %s blocked: %s", node, r.Description)
705                         return
706                 }
707         }
708         n, err := s.socket.WriteTo(b, node.UDPAddr())
709         if err != nil {
710                 err = fmt.Errorf("error writing %d bytes to %s: %#v", len(b), node, err)
711                 return
712         }
713         if n != len(b) {
714                 err = io.ErrShortWrite
715                 return
716         }
717         return
718 }
719
720 func (s *Server) findResponseTransaction(transactionID string, sourceNode dHTAddr) *transaction {
721         return s.transactions[transactionKey{
722                 sourceNode.String(),
723                 transactionID}]
724 }
725
726 func (s *Server) nextTransactionID() string {
727         var b [binary.MaxVarintLen64]byte
728         n := binary.PutUvarint(b[:], s.transactionIDInt)
729         s.transactionIDInt++
730         return string(b[:n])
731 }
732
733 func (s *Server) deleteTransaction(t *transaction) {
734         delete(s.transactions, t.Key())
735 }
736
737 func (s *Server) addTransaction(t *transaction) {
738         if _, ok := s.transactions[t.Key()]; ok {
739                 panic("transaction not unique")
740         }
741         s.transactions[t.Key()] = t
742 }
743
744 func (s *Server) IDString() string {
745         if len(s.id) != 20 {
746                 panic("bad node id")
747         }
748         return s.id
749 }
750
751 func (s *Server) query(node dHTAddr, q string, a map[string]interface{}, onResponse func(Msg)) (t *transaction, err error) {
752         tid := s.nextTransactionID()
753         if a == nil {
754                 a = make(map[string]interface{}, 1)
755         }
756         a["id"] = s.IDString()
757         d := map[string]interface{}{
758                 "t": tid,
759                 "y": "q",
760                 "q": q,
761                 "a": a,
762         }
763         b, err := bencode.Marshal(d)
764         if err != nil {
765                 return
766         }
767         t = &transaction{
768                 remoteAddr:  node,
769                 t:           tid,
770                 response:    make(chan Msg, 1),
771                 done:        make(chan struct{}),
772                 queryPacket: b,
773                 s:           s,
774                 onResponse:  onResponse,
775         }
776         err = t.sendQuery()
777         if err != nil {
778                 return
779         }
780         s.getNode(node).lastSentQuery = time.Now()
781         t.startTimer()
782         s.addTransaction(t)
783         return
784 }
785
786 const CompactNodeInfoLen = 26
787
788 type NodeInfo struct {
789         ID   [20]byte
790         Addr dHTAddr
791 }
792
793 func (ni *NodeInfo) PutCompact(b []byte) error {
794         if n := copy(b[:], ni.ID[:]); n != 20 {
795                 panic(n)
796         }
797         ip := util.AddrIP(ni.Addr).To4()
798         if len(ip) != 4 {
799                 return errors.New("expected ipv4 address")
800         }
801         if n := copy(b[20:], ip); n != 4 {
802                 panic(n)
803         }
804         binary.BigEndian.PutUint16(b[24:], uint16(util.AddrPort(ni.Addr)))
805         return nil
806 }
807
808 func (cni *NodeInfo) UnmarshalCompact(b []byte) error {
809         if len(b) != 26 {
810                 return errors.New("expected 26 bytes")
811         }
812         util.CopyExact(cni.ID[:], b[:20])
813         cni.Addr = newDHTAddr(&net.UDPAddr{
814                 IP:   net.IPv4(b[20], b[21], b[22], b[23]),
815                 Port: int(binary.BigEndian.Uint16(b[24:26])),
816         })
817         return nil
818 }
819
820 func (s *Server) Ping(node *net.UDPAddr) (*transaction, error) {
821         s.mu.Lock()
822         defer s.mu.Unlock()
823         return s.query(newDHTAddr(node), "ping", nil, nil)
824 }
825
826 // Announce a local peer. This can only be done to nodes that gave us an
827 // announce token, which is received in responses during GetPeers. It's
828 // recommended then that GetPeers is called before this method.
829 func (s *Server) AnnouncePeer(port int, impliedPort bool, infoHash string) (err error) {
830         s.mu.Lock()
831         defer s.mu.Unlock()
832         for _, node := range s.closestNodes(160, nodeIDFromString(infoHash), func(n *Node) bool {
833                 return n.announceToken != ""
834         }) {
835                 err = s.announcePeer(node.addr, infoHash, port, node.announceToken, impliedPort)
836                 if err != nil {
837                         break
838                 }
839         }
840         return
841 }
842
843 func (s *Server) announcePeer(node dHTAddr, infoHash string, port int, token string, impliedPort bool) (err error) {
844         if port == 0 && !impliedPort {
845                 return errors.New("nothing to announce")
846         }
847         _, err = s.query(node, "announce_peer", map[string]interface{}{
848                 "implied_port": func() int {
849                         if impliedPort {
850                                 return 1
851                         } else {
852                                 return 0
853                         }
854                 }(),
855                 "info_hash": infoHash,
856                 "port":      port,
857                 "token":     token,
858         }, func(m Msg) {
859                 if err := m.Error(); err != nil {
860                         logonce.Stderr.Printf("announce_peer response: %s", err)
861                         return
862                 }
863                 s.NumConfirmedAnnounces++
864         })
865         return
866 }
867
868 // Add response nodes to node table.
869 func (s *Server) liftNodes(d Msg) {
870         if d["y"] != "r" {
871                 return
872         }
873         for _, cni := range d.Nodes() {
874                 if util.AddrPort(cni.Addr) == 0 {
875                         // TODO: Why would people even do this?
876                         continue
877                 }
878                 if s.ipBlocked(util.AddrIP(cni.Addr)) {
879                         continue
880                 }
881                 n := s.getNode(cni.Addr)
882                 n.SetIDFromBytes(cni.ID[:])
883         }
884 }
885
886 // Sends a find_node query to addr. targetID is the node we're looking for.
887 func (s *Server) findNode(addr dHTAddr, targetID string) (t *transaction, err error) {
888         t, err = s.query(addr, "find_node", map[string]interface{}{"target": targetID}, func(d Msg) {
889                 // Scrape peers from the response to put in the server's table before
890                 // handing the response back to the caller.
891                 s.liftNodes(d)
892         })
893         if err != nil {
894                 return
895         }
896         return
897 }
898
899 // In a get_peers response, the addresses of torrent clients involved with the
900 // queried info-hash.
901 func (m Msg) Values() (vs []util.CompactPeer) {
902         r, ok := m["r"]
903         if !ok {
904                 return
905         }
906         rd, ok := r.(map[string]interface{})
907         if !ok {
908                 return
909         }
910         v, ok := rd["values"]
911         if !ok {
912                 return
913         }
914         vl, ok := v.([]interface{})
915         if !ok {
916                 log.Printf("unexpected krpc values type: %T", v)
917                 return
918         }
919         vs = make([]util.CompactPeer, 0, len(vl))
920         for _, i := range vl {
921                 s, ok := i.(string)
922                 if !ok {
923                         panic(i)
924                 }
925                 var cp util.CompactPeer
926                 err := cp.UnmarshalBinary([]byte(s))
927                 if err != nil {
928                         log.Printf("error decoding values list element: %s", err)
929                         continue
930                 }
931                 vs = append(vs, cp)
932         }
933         return
934 }
935
936 func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *transaction, err error) {
937         if len(infoHash) != 20 {
938                 err = fmt.Errorf("infohash has bad length")
939                 return
940         }
941         t, err = s.query(addr, "get_peers", map[string]interface{}{"info_hash": infoHash}, func(m Msg) {
942                 s.liftNodes(m)
943                 at, ok := m.AnnounceToken()
944                 if ok {
945                         s.getNode(addr).announceToken = at
946                 }
947         })
948         return
949 }
950
951 func bootstrapAddrs() (addrs []*net.UDPAddr, err error) {
952         for _, addrStr := range []string{
953                 "router.utorrent.com:6881",
954                 "router.bittorrent.com:6881",
955         } {
956                 udpAddr, err := net.ResolveUDPAddr("udp4", addrStr)
957                 if err != nil {
958                         continue
959                 }
960                 addrs = append(addrs, udpAddr)
961         }
962         if len(addrs) == 0 {
963                 err = errors.New("nothing resolved")
964         }
965         return
966 }
967
968 func (s *Server) addRootNodes() error {
969         addrs, err := bootstrapAddrs()
970         if err != nil {
971                 return err
972         }
973         for _, addr := range addrs {
974                 s.nodes[addr.String()] = &Node{
975                         addr: newDHTAddr(addr),
976                 }
977         }
978         return nil
979 }
980
981 // Populates the node table.
982 func (s *Server) bootstrap() (err error) {
983         s.mu.Lock()
984         defer s.mu.Unlock()
985         if len(s.nodes) == 0 {
986                 err = s.addRootNodes()
987         }
988         if err != nil {
989                 return
990         }
991         for {
992                 var outstanding sync.WaitGroup
993                 for _, node := range s.nodes {
994                         var t *transaction
995                         t, err = s.findNode(node.addr, s.id)
996                         if err != nil {
997                                 err = fmt.Errorf("error sending find_node: %s", err)
998                                 return
999                         }
1000                         outstanding.Add(1)
1001                         t.SetResponseHandler(func(Msg) {
1002                                 outstanding.Done()
1003                         })
1004                 }
1005                 noOutstanding := make(chan struct{})
1006                 go func() {
1007                         outstanding.Wait()
1008                         close(noOutstanding)
1009                 }()
1010                 s.mu.Unlock()
1011                 select {
1012                 case <-s.closed:
1013                         s.mu.Lock()
1014                         return
1015                 case <-time.After(15 * time.Second):
1016                 case <-noOutstanding:
1017                 }
1018                 s.mu.Lock()
1019                 // log.Printf("now have %d nodes", len(s.nodes))
1020                 if s.numGoodNodes() >= 160 {
1021                         break
1022                 }
1023         }
1024         return
1025 }
1026
1027 func (s *Server) numGoodNodes() (num int) {
1028         for _, n := range s.nodes {
1029                 if n.DefinitelyGood() {
1030                         num++
1031                 }
1032         }
1033         return
1034 }
1035
1036 func (s *Server) NumNodes() int {
1037         s.mu.Lock()
1038         defer s.mu.Unlock()
1039         return len(s.nodes)
1040 }
1041
1042 func (s *Server) Nodes() (nis []NodeInfo) {
1043         s.mu.Lock()
1044         defer s.mu.Unlock()
1045         for _, node := range s.nodes {
1046                 // if !node.Good() {
1047                 //      continue
1048                 // }
1049                 ni := NodeInfo{
1050                         Addr: node.addr,
1051                 }
1052                 if n := copy(ni.ID[:], node.idString()); n != 20 && n != 0 {
1053                         panic(n)
1054                 }
1055                 nis = append(nis, ni)
1056         }
1057         return
1058 }
1059
1060 func (s *Server) Close() {
1061         s.mu.Lock()
1062         select {
1063         case <-s.closed:
1064         default:
1065                 close(s.closed)
1066                 s.socket.Close()
1067         }
1068         s.mu.Unlock()
1069 }
1070
1071 var maxDistance big.Int
1072
1073 func init() {
1074         var zero big.Int
1075         maxDistance.SetBit(&zero, 160, 1)
1076 }
1077
1078 func (s *Server) closestGoodNodes(k int, targetID string) []*Node {
1079         return s.closestNodes(k, nodeIDFromString(targetID), func(n *Node) bool { return n.DefinitelyGood() })
1080 }
1081
1082 func (s *Server) closestNodes(k int, target nodeID, filter func(*Node) bool) []*Node {
1083         sel := newKClosestNodesSelector(k, target)
1084         idNodes := make(map[string]*Node, len(s.nodes))
1085         for _, node := range s.nodes {
1086                 if !filter(node) {
1087                         continue
1088                 }
1089                 sel.Push(node.id)
1090                 idNodes[node.idString()] = node
1091         }
1092         ids := sel.IDs()
1093         ret := make([]*Node, 0, len(ids))
1094         for _, id := range ids {
1095                 ret = append(ret, idNodes[id.String()])
1096         }
1097         return ret
1098 }