]> Sergey Matveev's repositories - btrtrc.git/blob - dht/dht.go
5f70993438ca782b2e0cf7f19e5bca3d66cb7343
[btrtrc.git] / dht / dht.go
1 package dht
2
3 import (
4         "crypto"
5         _ "crypto/sha1"
6         "encoding/binary"
7         "errors"
8         "fmt"
9         "io"
10         "log"
11         "math/big"
12         "math/rand"
13         "net"
14         "os"
15         "time"
16
17         "github.com/anacrolix/libtorgo/bencode"
18         "github.com/anacrolix/sync"
19
20         "github.com/anacrolix/torrent/iplist"
21         "github.com/anacrolix/torrent/logonce"
22         "github.com/anacrolix/torrent/util"
23 )
24
25 const (
26         maxNodes         = 320
27         queryResendEvery = 5 * time.Second
28 )
29
30 // Uniquely identifies a transaction to us.
31 type transactionKey struct {
32         RemoteAddr string // host:port
33         T          string // The KRPC transaction ID.
34 }
35
36 type Server struct {
37         id               string
38         socket           net.PacketConn
39         transactions     map[transactionKey]*transaction
40         transactionIDInt uint64
41         nodes            map[string]*Node // Keyed by dHTAddr.String().
42         mu               sync.Mutex
43         closed           chan struct{}
44         passive          bool // Don't respond to queries.
45         ipBlockList      *iplist.IPList
46
47         NumConfirmedAnnounces int
48 }
49
50 type dHTAddr interface {
51         net.Addr
52         UDPAddr() *net.UDPAddr
53 }
54
55 type cachedAddr struct {
56         a net.Addr
57         s string
58 }
59
60 func (ca cachedAddr) Network() string {
61         return ca.a.Network()
62 }
63
64 func (ca cachedAddr) String() string {
65         return ca.s
66 }
67
68 func (ca cachedAddr) UDPAddr() *net.UDPAddr {
69         return ca.a.(*net.UDPAddr)
70 }
71
72 func newDHTAddr(addr net.Addr) dHTAddr {
73         return cachedAddr{addr, addr.String()}
74 }
75
76 type ServerConfig struct {
77         Addr    string
78         Conn    net.PacketConn
79         Passive bool // Don't respond to queries.
80 }
81
82 type serverStats struct {
83         NumGoodNodes               int
84         NumNodes                   int
85         NumOutstandingTransactions int
86 }
87
88 func (s *Server) Stats() (ss serverStats) {
89         s.mu.Lock()
90         defer s.mu.Unlock()
91         for _, n := range s.nodes {
92                 if n.DefinitelyGood() {
93                         ss.NumGoodNodes++
94                 }
95         }
96         ss.NumNodes = len(s.nodes)
97         ss.NumOutstandingTransactions = len(s.transactions)
98         return
99 }
100
101 func (s *Server) LocalAddr() net.Addr {
102         return s.socket.LocalAddr()
103 }
104
105 func makeSocket(addr string) (socket *net.UDPConn, err error) {
106         addr_, err := net.ResolveUDPAddr("", addr)
107         if err != nil {
108                 return
109         }
110         socket, err = net.ListenUDP("udp", addr_)
111         return
112 }
113
114 func NewServer(c *ServerConfig) (s *Server, err error) {
115         if c == nil {
116                 c = &ServerConfig{}
117         }
118         s = &Server{}
119         if c.Conn != nil {
120                 s.socket = c.Conn
121         } else {
122                 s.socket, err = makeSocket(c.Addr)
123                 if err != nil {
124                         return
125                 }
126         }
127         s.passive = c.Passive
128         err = s.init()
129         if err != nil {
130                 return
131         }
132         go func() {
133                 err := s.serve()
134                 select {
135                 case <-s.closed:
136                         return
137                 default:
138                 }
139                 if err != nil {
140                         panic(err)
141                 }
142         }()
143         go func() {
144                 err := s.bootstrap()
145                 if err != nil {
146                         log.Printf("error bootstrapping DHT: %s", err)
147                 }
148         }()
149         return
150 }
151
152 func (s *Server) String() string {
153         return fmt.Sprintf("dht server on %s", s.socket.LocalAddr())
154 }
155
156 type nodeID struct {
157         i   big.Int
158         set bool
159 }
160
161 func (nid *nodeID) IsUnset() bool {
162         return !nid.set
163 }
164
165 func nodeIDFromString(s string) (ret nodeID) {
166         if s == "" {
167                 return
168         }
169         ret.i.SetBytes([]byte(s))
170         ret.set = true
171         return
172 }
173
174 func (nid0 *nodeID) Distance(nid1 *nodeID) (ret big.Int) {
175         if nid0.IsUnset() != nid1.IsUnset() {
176                 ret = maxDistance
177                 return
178         }
179         ret.Xor(&nid0.i, &nid1.i)
180         return
181 }
182
183 func (nid *nodeID) String() string {
184         return string(nid.i.Bytes())
185 }
186
187 type Node struct {
188         addr          dHTAddr
189         id            nodeID
190         announceToken string
191
192         lastGotQuery    time.Time
193         lastGotResponse time.Time
194         lastSentQuery   time.Time
195 }
196
197 func (n *Node) idString() string {
198         return n.id.String()
199 }
200
201 func (n *Node) SetIDFromBytes(b []byte) {
202         n.id.i.SetBytes(b)
203         n.id.set = true
204 }
205
206 func (n *Node) SetIDFromString(s string) {
207         n.id.i.SetBytes([]byte(s))
208 }
209
210 func (n *Node) IDNotSet() bool {
211         return n.id.i.Int64() == 0
212 }
213
214 func (n *Node) NodeInfo() (ret NodeInfo) {
215         ret.Addr = n.addr
216         if n := copy(ret.ID[:], n.idString()); n != 20 {
217                 panic(n)
218         }
219         return
220 }
221
222 func (n *Node) DefinitelyGood() bool {
223         if len(n.idString()) != 20 {
224                 return false
225         }
226         // No reason to think ill of them if they've never been queried.
227         if n.lastSentQuery.IsZero() {
228                 return true
229         }
230         // They answered our last query.
231         if n.lastSentQuery.Before(n.lastGotResponse) {
232                 return true
233         }
234         return true
235 }
236
237 type Msg map[string]interface{}
238
239 var _ fmt.Stringer = Msg{}
240
241 func (m Msg) String() string {
242         return fmt.Sprintf("%#v", m)
243 }
244
245 func (m Msg) T() (t string) {
246         tif, ok := m["t"]
247         if !ok {
248                 return
249         }
250         t, _ = tif.(string)
251         return
252 }
253
254 func (m Msg) ID() string {
255         defer func() {
256                 recover()
257         }()
258         return m[m["y"].(string)].(map[string]interface{})["id"].(string)
259 }
260
261 // Suggested nodes in a response.
262 func (m Msg) Nodes() (nodes []NodeInfo) {
263         b := func() string {
264                 defer func() {
265                         recover()
266                 }()
267                 return m["r"].(map[string]interface{})["nodes"].(string)
268         }()
269         if len(b)%26 != 0 {
270                 return
271         }
272         for i := 0; i < len(b); i += 26 {
273                 var n NodeInfo
274                 err := n.UnmarshalCompact([]byte(b[i : i+26]))
275                 if err != nil {
276                         continue
277                 }
278                 nodes = append(nodes, n)
279         }
280         return
281 }
282
283 type KRPCError struct {
284         Code int
285         Msg  string
286 }
287
288 func (me KRPCError) Error() string {
289         return fmt.Sprintf("KRPC error %d: %s", me.Code, me.Msg)
290 }
291
292 var _ error = KRPCError{}
293
294 func (m Msg) Error() (ret *KRPCError) {
295         if m["y"] != "e" {
296                 return
297         }
298         ret = &KRPCError{}
299         switch e := m["e"].(type) {
300         case []interface{}:
301                 ret.Code = int(e[0].(int64))
302                 ret.Msg = e[1].(string)
303         case string:
304                 ret.Msg = e
305         default:
306                 logonce.Stderr.Printf(`KRPC error "e" value has unexpected type: %T`, e)
307         }
308         return
309 }
310
311 // Returns the token given in response to a get_peers request for future
312 // announce_peer requests to that node.
313 func (m Msg) AnnounceToken() (token string, ok bool) {
314         defer func() { recover() }()
315         token, ok = m["r"].(map[string]interface{})["token"].(string)
316         return
317 }
318
319 type transaction struct {
320         mu             sync.Mutex
321         remoteAddr     dHTAddr
322         t              string
323         response       chan Msg
324         onResponse     func(Msg) // Called with the server locked.
325         done           chan struct{}
326         queryPacket    []byte
327         timer          *time.Timer
328         s              *Server
329         retries        int
330         lastSend       time.Time
331         userOnResponse func(Msg)
332 }
333
334 func (t *transaction) SetResponseHandler(f func(Msg)) {
335         t.mu.Lock()
336         defer t.mu.Unlock()
337         t.userOnResponse = f
338         t.tryHandleResponse()
339 }
340
341 func (t *transaction) tryHandleResponse() {
342         if t.userOnResponse == nil {
343                 return
344         }
345         select {
346         case r := <-t.response:
347                 t.userOnResponse(r)
348                 // Shouldn't be called more than once.
349                 t.userOnResponse = nil
350         default:
351         }
352 }
353
354 func (t *transaction) Key() transactionKey {
355         return transactionKey{
356                 t.remoteAddr.String(),
357                 t.t,
358         }
359 }
360
361 func jitterDuration(average time.Duration, plusMinus time.Duration) time.Duration {
362         return average - plusMinus/2 + time.Duration(rand.Int63n(int64(plusMinus)))
363 }
364
365 func (t *transaction) startTimer() {
366         t.timer = time.AfterFunc(jitterDuration(queryResendEvery, time.Second), t.timerCallback)
367 }
368
369 func (t *transaction) timerCallback() {
370         t.mu.Lock()
371         defer t.mu.Unlock()
372         select {
373         case <-t.done:
374                 return
375         default:
376         }
377         if t.retries == 2 {
378                 t.timeout()
379                 return
380         }
381         t.retries++
382         t.sendQuery()
383         if t.timer.Reset(jitterDuration(queryResendEvery, time.Second)) {
384                 panic("timer should have fired to get here")
385         }
386 }
387
388 func (t *transaction) sendQuery() error {
389         err := t.s.writeToNode(t.queryPacket, t.remoteAddr)
390         if err != nil {
391                 return err
392         }
393         t.lastSend = time.Now()
394         return nil
395 }
396
397 func (t *transaction) timeout() {
398         go func() {
399                 t.s.mu.Lock()
400                 defer t.s.mu.Unlock()
401                 t.s.nodeTimedOut(t.remoteAddr)
402         }()
403         t.close()
404 }
405
406 func (t *transaction) close() {
407         if t.closing() {
408                 return
409         }
410         t.queryPacket = nil
411         close(t.response)
412         t.tryHandleResponse()
413         close(t.done)
414         t.timer.Stop()
415         go func() {
416                 t.s.mu.Lock()
417                 defer t.s.mu.Unlock()
418                 t.s.deleteTransaction(t)
419         }()
420 }
421
422 func (t *transaction) closing() bool {
423         select {
424         case <-t.done:
425                 return true
426         default:
427                 return false
428         }
429 }
430
431 func (t *transaction) Close() {
432         t.mu.Lock()
433         defer t.mu.Unlock()
434         t.close()
435 }
436
437 func (t *transaction) handleResponse(m Msg) {
438         t.mu.Lock()
439         if t.closing() {
440                 t.mu.Unlock()
441                 return
442         }
443         close(t.done)
444         t.mu.Unlock()
445         if t.onResponse != nil {
446                 t.s.mu.Lock()
447                 t.onResponse(m)
448                 t.s.mu.Unlock()
449         }
450         t.queryPacket = nil
451         select {
452         case t.response <- m:
453         default:
454                 panic("blocked handling response")
455         }
456         close(t.response)
457         t.tryHandleResponse()
458 }
459
460 func (s *Server) setDefaults() (err error) {
461         if s.id == "" {
462                 var id [20]byte
463                 h := crypto.SHA1.New()
464                 ss, err := os.Hostname()
465                 if err != nil {
466                         log.Print(err)
467                 }
468                 ss += s.socket.LocalAddr().String()
469                 h.Write([]byte(ss))
470                 if b := h.Sum(id[:0:20]); len(b) != 20 {
471                         panic(len(b))
472                 }
473                 if len(id) != 20 {
474                         panic(len(id))
475                 }
476                 s.id = string(id[:])
477         }
478         s.nodes = make(map[string]*Node, 10000)
479         return
480 }
481
482 func (s *Server) SetIPBlockList(list *iplist.IPList) {
483         s.mu.Lock()
484         defer s.mu.Unlock()
485         s.ipBlockList = list
486 }
487
488 func (s *Server) init() (err error) {
489         err = s.setDefaults()
490         if err != nil {
491                 return
492         }
493         s.closed = make(chan struct{})
494         s.transactions = make(map[transactionKey]*transaction)
495         return
496 }
497
498 func (s *Server) processPacket(b []byte, addr dHTAddr) {
499         var d Msg
500         err := bencode.Unmarshal(b, &d)
501         if err != nil {
502                 func() {
503                         if se, ok := err.(*bencode.SyntaxError); ok {
504                                 // The message was truncated.
505                                 if int(se.Offset) == len(b) {
506                                         return
507                                 }
508                                 // Some messages seem to drop to nul chars abrubtly.
509                                 if int(se.Offset) < len(b) && b[se.Offset] == 0 {
510                                         return
511                                 }
512                                 // The message isn't bencode from the first.
513                                 if se.Offset == 0 {
514                                         return
515                                 }
516                         }
517                         log.Printf("%s: received bad krpc message from %s: %s: %q", s, addr, err, b)
518                 }()
519                 return
520         }
521         s.mu.Lock()
522         defer s.mu.Unlock()
523         if d["y"] == "q" {
524                 s.handleQuery(addr, d)
525                 return
526         }
527         t := s.findResponseTransaction(d.T(), addr)
528         if t == nil {
529                 //log.Printf("unexpected message: %#v", d)
530                 return
531         }
532         node := s.getNode(addr)
533         node.lastGotResponse = time.Now()
534         // TODO: Update node ID as this is an authoritative packet.
535         go t.handleResponse(d)
536         s.deleteTransaction(t)
537 }
538
539 func (s *Server) serve() error {
540         var b [0x10000]byte
541         for {
542                 n, addr, err := s.socket.ReadFrom(b[:])
543                 if err != nil {
544                         return err
545                 }
546                 if n == len(b) {
547                         logonce.Stderr.Printf("received dht packet exceeds buffer size")
548                         continue
549                 }
550                 s.processPacket(b[:n], newDHTAddr(addr))
551         }
552 }
553
554 func (s *Server) ipBlocked(ip net.IP) bool {
555         if s.ipBlockList == nil {
556                 return false
557         }
558         return s.ipBlockList.Lookup(ip) != nil
559 }
560
561 func (s *Server) AddNode(ni NodeInfo) {
562         s.mu.Lock()
563         defer s.mu.Unlock()
564         if s.nodes == nil {
565                 s.nodes = make(map[string]*Node)
566         }
567         n := s.getNode(ni.Addr)
568         if n.IDNotSet() {
569                 n.SetIDFromBytes(ni.ID[:])
570         }
571 }
572
573 func (s *Server) nodeByID(id string) *Node {
574         for _, node := range s.nodes {
575                 if node.idString() == id {
576                         return node
577                 }
578         }
579         return nil
580 }
581
582 func (s *Server) handleQuery(source dHTAddr, m Msg) {
583         args := m["a"].(map[string]interface{})
584         node := s.getNode(source)
585         node.SetIDFromString(args["id"].(string))
586         node.lastGotQuery = time.Now()
587         // Don't respond.
588         if s.passive {
589                 return
590         }
591         switch m["q"] {
592         case "ping":
593                 s.reply(source, m["t"].(string), nil)
594         case "get_peers": // TODO: Extract common behaviour with find_node.
595                 targetID := args["info_hash"].(string)
596                 if len(targetID) != 20 {
597                         break
598                 }
599                 var rNodes []NodeInfo
600                 // TODO: Reply with "values" list if we have peers instead.
601                 for _, node := range s.closestGoodNodes(8, targetID) {
602                         rNodes = append(rNodes, node.NodeInfo())
603                 }
604                 nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
605                 for i, ni := range rNodes {
606                         err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
607                         if err != nil {
608                                 panic(err)
609                         }
610                 }
611                 s.reply(source, m["t"].(string), map[string]interface{}{
612                         "nodes": string(nodesBytes),
613                         "token": "hi",
614                 })
615         case "find_node": // TODO: Extract common behaviour with get_peers.
616                 targetID := args["target"].(string)
617                 if len(targetID) != 20 {
618                         log.Printf("bad DHT query: %v", m)
619                         return
620                 }
621                 var rNodes []NodeInfo
622                 if node := s.nodeByID(targetID); node != nil {
623                         rNodes = append(rNodes, node.NodeInfo())
624                 } else {
625                         for _, node := range s.closestGoodNodes(8, targetID) {
626                                 rNodes = append(rNodes, node.NodeInfo())
627                         }
628                 }
629                 nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
630                 for i, ni := range rNodes {
631                         // TODO: Put IPv6 nodes into the correct dict element.
632                         if ni.Addr.UDPAddr().IP.To4() == nil {
633                                 continue
634                         }
635                         err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
636                         if err != nil {
637                                 log.Printf("error compacting %#v: %s", ni, err)
638                                 continue
639                         }
640                 }
641                 s.reply(source, m["t"].(string), map[string]interface{}{
642                         "nodes": string(nodesBytes),
643                 })
644         case "announce_peer":
645                 // TODO(anacrolix): Implement this lolz.
646                 // log.Print(m)
647         case "vote":
648                 // TODO(anacrolix): Or reject, I don't think I want this.
649         default:
650                 log.Printf("%s: not handling received query: q=%s", s, m["q"])
651                 return
652         }
653 }
654
655 func (s *Server) reply(addr dHTAddr, t string, r map[string]interface{}) {
656         if r == nil {
657                 r = make(map[string]interface{}, 1)
658         }
659         r["id"] = s.IDString()
660         m := map[string]interface{}{
661                 "t": t,
662                 "y": "r",
663                 "r": r,
664         }
665         b, err := bencode.Marshal(m)
666         if err != nil {
667                 panic(err)
668         }
669         err = s.writeToNode(b, addr)
670         if err != nil {
671                 log.Printf("error replying to %s: %s", addr, err)
672         }
673 }
674
675 func (s *Server) getNode(addr dHTAddr) (n *Node) {
676         addrStr := addr.String()
677         n = s.nodes[addrStr]
678         if n == nil {
679                 n = &Node{
680                         addr: addr,
681                 }
682                 if len(s.nodes) < maxNodes {
683                         s.nodes[addrStr] = n
684                 }
685         }
686         return
687 }
688 func (s *Server) nodeTimedOut(addr dHTAddr) {
689         node, ok := s.nodes[addr.String()]
690         if !ok {
691                 return
692         }
693         if node.DefinitelyGood() {
694                 return
695         }
696         if len(s.nodes) < maxNodes {
697                 return
698         }
699         delete(s.nodes, addr.String())
700 }
701
702 func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) {
703         if list := s.ipBlockList; list != nil {
704                 if r := list.Lookup(util.AddrIP(node.UDPAddr())); r != nil {
705                         err = fmt.Errorf("write to %s blocked: %s", node, r.Description)
706                         return
707                 }
708         }
709         n, err := s.socket.WriteTo(b, node.UDPAddr())
710         if err != nil {
711                 err = fmt.Errorf("error writing %d bytes to %s: %#v", len(b), node, err)
712                 return
713         }
714         if n != len(b) {
715                 err = io.ErrShortWrite
716                 return
717         }
718         return
719 }
720
721 func (s *Server) findResponseTransaction(transactionID string, sourceNode dHTAddr) *transaction {
722         return s.transactions[transactionKey{
723                 sourceNode.String(),
724                 transactionID}]
725 }
726
727 func (s *Server) nextTransactionID() string {
728         var b [binary.MaxVarintLen64]byte
729         n := binary.PutUvarint(b[:], s.transactionIDInt)
730         s.transactionIDInt++
731         return string(b[:n])
732 }
733
734 func (s *Server) deleteTransaction(t *transaction) {
735         delete(s.transactions, t.Key())
736 }
737
738 func (s *Server) addTransaction(t *transaction) {
739         if _, ok := s.transactions[t.Key()]; ok {
740                 panic("transaction not unique")
741         }
742         s.transactions[t.Key()] = t
743 }
744
745 func (s *Server) IDString() string {
746         if len(s.id) != 20 {
747                 panic("bad node id")
748         }
749         return s.id
750 }
751
752 func (s *Server) query(node dHTAddr, q string, a map[string]interface{}, onResponse func(Msg)) (t *transaction, err error) {
753         tid := s.nextTransactionID()
754         if a == nil {
755                 a = make(map[string]interface{}, 1)
756         }
757         a["id"] = s.IDString()
758         d := map[string]interface{}{
759                 "t": tid,
760                 "y": "q",
761                 "q": q,
762                 "a": a,
763         }
764         b, err := bencode.Marshal(d)
765         if err != nil {
766                 return
767         }
768         t = &transaction{
769                 remoteAddr:  node,
770                 t:           tid,
771                 response:    make(chan Msg, 1),
772                 done:        make(chan struct{}),
773                 queryPacket: b,
774                 s:           s,
775                 onResponse:  onResponse,
776         }
777         err = t.sendQuery()
778         if err != nil {
779                 return
780         }
781         s.getNode(node).lastSentQuery = time.Now()
782         t.startTimer()
783         s.addTransaction(t)
784         return
785 }
786
787 const CompactNodeInfoLen = 26
788
789 type NodeInfo struct {
790         ID   [20]byte
791         Addr dHTAddr
792 }
793
794 func (ni *NodeInfo) PutCompact(b []byte) error {
795         if n := copy(b[:], ni.ID[:]); n != 20 {
796                 panic(n)
797         }
798         ip := util.AddrIP(ni.Addr).To4()
799         if len(ip) != 4 {
800                 return errors.New("expected ipv4 address")
801         }
802         if n := copy(b[20:], ip); n != 4 {
803                 panic(n)
804         }
805         binary.BigEndian.PutUint16(b[24:], uint16(util.AddrPort(ni.Addr)))
806         return nil
807 }
808
809 func (cni *NodeInfo) UnmarshalCompact(b []byte) error {
810         if len(b) != 26 {
811                 return errors.New("expected 26 bytes")
812         }
813         util.CopyExact(cni.ID[:], b[:20])
814         cni.Addr = newDHTAddr(&net.UDPAddr{
815                 IP:   net.IPv4(b[20], b[21], b[22], b[23]),
816                 Port: int(binary.BigEndian.Uint16(b[24:26])),
817         })
818         return nil
819 }
820
821 func (s *Server) Ping(node *net.UDPAddr) (*transaction, error) {
822         s.mu.Lock()
823         defer s.mu.Unlock()
824         return s.query(newDHTAddr(node), "ping", nil, nil)
825 }
826
827 // Announce a local peer. This can only be done to nodes that gave us an
828 // announce token, which is received in responses during GetPeers. It's
829 // recommended then that GetPeers is called before this method.
830 func (s *Server) AnnouncePeer(port int, impliedPort bool, infoHash string) (err error) {
831         s.mu.Lock()
832         defer s.mu.Unlock()
833         for _, node := range s.closestNodes(160, nodeIDFromString(infoHash), func(n *Node) bool {
834                 return n.announceToken != ""
835         }) {
836                 err = s.announcePeer(node.addr, infoHash, port, node.announceToken, impliedPort)
837                 if err != nil {
838                         break
839                 }
840         }
841         return
842 }
843
844 func (s *Server) announcePeer(node dHTAddr, infoHash string, port int, token string, impliedPort bool) (err error) {
845         if port == 0 && !impliedPort {
846                 return errors.New("nothing to announce")
847         }
848         _, err = s.query(node, "announce_peer", map[string]interface{}{
849                 "implied_port": func() int {
850                         if impliedPort {
851                                 return 1
852                         } else {
853                                 return 0
854                         }
855                 }(),
856                 "info_hash": infoHash,
857                 "port":      port,
858                 "token":     token,
859         }, func(m Msg) {
860                 if err := m.Error(); err != nil {
861                         logonce.Stderr.Printf("announce_peer response: %s", err)
862                         return
863                 }
864                 s.NumConfirmedAnnounces++
865         })
866         return
867 }
868
869 // Add response nodes to node table.
870 func (s *Server) liftNodes(d Msg) {
871         if d["y"] != "r" {
872                 return
873         }
874         for _, cni := range d.Nodes() {
875                 if util.AddrPort(cni.Addr) == 0 {
876                         // TODO: Why would people even do this?
877                         continue
878                 }
879                 if s.ipBlocked(util.AddrIP(cni.Addr)) {
880                         continue
881                 }
882                 n := s.getNode(cni.Addr)
883                 n.SetIDFromBytes(cni.ID[:])
884         }
885 }
886
887 // Sends a find_node query to addr. targetID is the node we're looking for.
888 func (s *Server) findNode(addr dHTAddr, targetID string) (t *transaction, err error) {
889         t, err = s.query(addr, "find_node", map[string]interface{}{"target": targetID}, func(d Msg) {
890                 // Scrape peers from the response to put in the server's table before
891                 // handing the response back to the caller.
892                 s.liftNodes(d)
893         })
894         if err != nil {
895                 return
896         }
897         return
898 }
899
900 // In a get_peers response, the addresses of torrent clients involved with the
901 // queried info-hash.
902 func (m Msg) Values() (vs []util.CompactPeer) {
903         r, ok := m["r"]
904         if !ok {
905                 return
906         }
907         rd, ok := r.(map[string]interface{})
908         if !ok {
909                 return
910         }
911         v, ok := rd["values"]
912         if !ok {
913                 return
914         }
915         vl, ok := v.([]interface{})
916         if !ok {
917                 log.Printf("unexpected krpc values type: %T", v)
918                 return
919         }
920         vs = make([]util.CompactPeer, 0, len(vl))
921         for _, i := range vl {
922                 s, ok := i.(string)
923                 if !ok {
924                         panic(i)
925                 }
926                 var cp util.CompactPeer
927                 err := cp.UnmarshalBinary([]byte(s))
928                 if err != nil {
929                         log.Printf("error decoding values list element: %s", err)
930                         continue
931                 }
932                 vs = append(vs, cp)
933         }
934         return
935 }
936
937 func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *transaction, err error) {
938         if len(infoHash) != 20 {
939                 err = fmt.Errorf("infohash has bad length")
940                 return
941         }
942         t, err = s.query(addr, "get_peers", map[string]interface{}{"info_hash": infoHash}, func(m Msg) {
943                 s.liftNodes(m)
944                 at, ok := m.AnnounceToken()
945                 if ok {
946                         s.getNode(addr).announceToken = at
947                 }
948         })
949         return
950 }
951
952 func bootstrapAddrs() (addrs []*net.UDPAddr, err error) {
953         for _, addrStr := range []string{
954                 "router.utorrent.com:6881",
955                 "router.bittorrent.com:6881",
956         } {
957                 udpAddr, err := net.ResolveUDPAddr("udp4", addrStr)
958                 if err != nil {
959                         continue
960                 }
961                 addrs = append(addrs, udpAddr)
962         }
963         if len(addrs) == 0 {
964                 err = errors.New("nothing resolved")
965         }
966         return
967 }
968
969 func (s *Server) addRootNodes() error {
970         addrs, err := bootstrapAddrs()
971         if err != nil {
972                 return err
973         }
974         for _, addr := range addrs {
975                 s.nodes[addr.String()] = &Node{
976                         addr: newDHTAddr(addr),
977                 }
978         }
979         return nil
980 }
981
982 // Populates the node table.
983 func (s *Server) bootstrap() (err error) {
984         s.mu.Lock()
985         defer s.mu.Unlock()
986         if len(s.nodes) == 0 {
987                 err = s.addRootNodes()
988         }
989         if err != nil {
990                 return
991         }
992         for {
993                 var outstanding sync.WaitGroup
994                 for _, node := range s.nodes {
995                         var t *transaction
996                         t, err = s.findNode(node.addr, s.id)
997                         if err != nil {
998                                 err = fmt.Errorf("error sending find_node: %s", err)
999                                 return
1000                         }
1001                         outstanding.Add(1)
1002                         t.SetResponseHandler(func(Msg) {
1003                                 outstanding.Done()
1004                         })
1005                 }
1006                 noOutstanding := make(chan struct{})
1007                 go func() {
1008                         outstanding.Wait()
1009                         close(noOutstanding)
1010                 }()
1011                 s.mu.Unlock()
1012                 select {
1013                 case <-s.closed:
1014                         s.mu.Lock()
1015                         return
1016                 case <-time.After(15 * time.Second):
1017                 case <-noOutstanding:
1018                 }
1019                 s.mu.Lock()
1020                 // log.Printf("now have %d nodes", len(s.nodes))
1021                 if s.numGoodNodes() >= 160 {
1022                         break
1023                 }
1024         }
1025         return
1026 }
1027
1028 func (s *Server) numGoodNodes() (num int) {
1029         for _, n := range s.nodes {
1030                 if n.DefinitelyGood() {
1031                         num++
1032                 }
1033         }
1034         return
1035 }
1036
1037 func (s *Server) NumNodes() int {
1038         s.mu.Lock()
1039         defer s.mu.Unlock()
1040         return len(s.nodes)
1041 }
1042
1043 func (s *Server) Nodes() (nis []NodeInfo) {
1044         s.mu.Lock()
1045         defer s.mu.Unlock()
1046         for _, node := range s.nodes {
1047                 // if !node.Good() {
1048                 //      continue
1049                 // }
1050                 ni := NodeInfo{
1051                         Addr: node.addr,
1052                 }
1053                 if n := copy(ni.ID[:], node.idString()); n != 20 && n != 0 {
1054                         panic(n)
1055                 }
1056                 nis = append(nis, ni)
1057         }
1058         return
1059 }
1060
1061 func (s *Server) Close() {
1062         s.mu.Lock()
1063         select {
1064         case <-s.closed:
1065         default:
1066                 close(s.closed)
1067                 s.socket.Close()
1068         }
1069         s.mu.Unlock()
1070 }
1071
1072 var maxDistance big.Int
1073
1074 func init() {
1075         var zero big.Int
1076         maxDistance.SetBit(&zero, 160, 1)
1077 }
1078
1079 func (s *Server) closestGoodNodes(k int, targetID string) []*Node {
1080         return s.closestNodes(k, nodeIDFromString(targetID), func(n *Node) bool { return n.DefinitelyGood() })
1081 }
1082
1083 func (s *Server) closestNodes(k int, target nodeID, filter func(*Node) bool) []*Node {
1084         sel := newKClosestNodesSelector(k, target)
1085         idNodes := make(map[string]*Node, len(s.nodes))
1086         for _, node := range s.nodes {
1087                 if !filter(node) {
1088                         continue
1089                 }
1090                 sel.Push(node.id)
1091                 idNodes[node.idString()] = node
1092         }
1093         ids := sel.IDs()
1094         ret := make([]*Node, 0, len(ids))
1095         for _, id := range ids {
1096                 ret = append(ret, idNodes[id.String()])
1097         }
1098         return ret
1099 }