]> Sergey Matveev's repositories - btrtrc.git/blob - dht/dht.go
dc3fe0cd1721a130aacc7e796bcec699b5b09e48
[btrtrc.git] / dht / dht.go
1 // Package DHT implements a DHT for use with the BitTorrent protocol,
2 // described in BEP 5: http://www.bittorrent.org/beps/bep_0005.html.
3 //
4 // Standard use involves creating a NewServer, and calling Announce on it with
5 // the details of your local torrent client and infohash of interest.
6 package dht
7
8 import (
9         "crypto"
10         _ "crypto/sha1"
11         "encoding/binary"
12         "errors"
13         "fmt"
14         "hash/crc32"
15         "io"
16         "log"
17         "math/big"
18         "math/rand"
19         "net"
20         "os"
21         "time"
22
23         "github.com/anacrolix/missinggo"
24         "github.com/anacrolix/sync"
25         "github.com/tylertreat/BoomFilters"
26
27         "github.com/anacrolix/torrent/bencode"
28         "github.com/anacrolix/torrent/iplist"
29         "github.com/anacrolix/torrent/logonce"
30         "github.com/anacrolix/torrent/util"
31 )
32
33 const (
34         maxNodes         = 320
35         queryResendEvery = 5 * time.Second
36 )
37
38 // Uniquely identifies a transaction to us.
39 type transactionKey struct {
40         RemoteAddr string // host:port
41         T          string // The KRPC transaction ID.
42 }
43
44 type Server struct {
45         id               string
46         socket           net.PacketConn
47         transactions     map[transactionKey]*Transaction
48         transactionIDInt uint64
49         nodes            map[string]*node // Keyed by dHTAddr.String().
50         mu               sync.Mutex
51         closed           chan struct{}
52         passive          bool // Don't respond to queries.
53         ipBlockList      *iplist.IPList
54         badNodes         *boom.BloomFilter
55
56         numConfirmedAnnounces int
57         bootstrapNodes        []string
58         config                ServerConfig
59 }
60
61 type ServerConfig struct {
62         Addr string // Listen address. Used if Conn is nil.
63         Conn net.PacketConn
64         // Don't respond to queries from other nodes.
65         Passive bool
66         // DHT Bootstrap nodes
67         BootstrapNodes []string
68         // Disable the DHT security extension:
69         // http://www.libtorrent.org/dht_sec.html.
70         NoSecurity bool
71         // Initial IP blocklist to use. Applied before serving and bootstrapping
72         // begins.
73         IPBlocklist *iplist.IPList
74         // Used to secure the server's ID. Defaults to the Conn's LocalAddr().
75         PublicIP net.IP
76 }
77
78 type ServerStats struct {
79         // Count of nodes in the node table that responded to our last query or
80         // haven't yet been queried.
81         GoodNodes int
82         // Count of nodes in the node table.
83         Nodes int
84         // Transactions awaiting a response.
85         OutstandingTransactions int
86         // Individual announce_peer requests that got a success response.
87         ConfirmedAnnounces int
88         // Nodes that have been blocked.
89         BadNodes uint
90 }
91
92 // Returns statistics for the server.
93 func (s *Server) Stats() (ss ServerStats) {
94         s.mu.Lock()
95         defer s.mu.Unlock()
96         for _, n := range s.nodes {
97                 if n.DefinitelyGood() {
98                         ss.GoodNodes++
99                 }
100         }
101         ss.Nodes = len(s.nodes)
102         ss.OutstandingTransactions = len(s.transactions)
103         ss.ConfirmedAnnounces = s.numConfirmedAnnounces
104         ss.BadNodes = s.badNodes.Count()
105         return
106 }
107
108 // Returns the listen address for the server. Packets arriving to this address
109 // are processed by the server (unless aliens are involved).
110 func (s *Server) Addr() net.Addr {
111         return s.socket.LocalAddr()
112 }
113
114 func makeSocket(addr string) (socket *net.UDPConn, err error) {
115         addr_, err := net.ResolveUDPAddr("", addr)
116         if err != nil {
117                 return
118         }
119         socket, err = net.ListenUDP("udp", addr_)
120         return
121 }
122
123 // Create a new DHT server.
124 func NewServer(c *ServerConfig) (s *Server, err error) {
125         if c == nil {
126                 c = &ServerConfig{}
127         }
128         s = &Server{
129                 config:      *c,
130                 ipBlockList: c.IPBlocklist,
131                 badNodes:    boom.NewBloomFilter(1000, 0.1),
132         }
133         if c.Conn != nil {
134                 s.socket = c.Conn
135         } else {
136                 s.socket, err = makeSocket(c.Addr)
137                 if err != nil {
138                         return
139                 }
140         }
141         s.passive = c.Passive
142         s.bootstrapNodes = c.BootstrapNodes
143         err = s.init()
144         if err != nil {
145                 return
146         }
147         go func() {
148                 err := s.serve()
149                 select {
150                 case <-s.closed:
151                         return
152                 default:
153                 }
154                 if err != nil {
155                         panic(err)
156                 }
157         }()
158         go func() {
159                 err := s.bootstrap()
160                 if err != nil {
161                         select {
162                         case <-s.closed:
163                         default:
164                                 log.Printf("error bootstrapping DHT: %s", err)
165                         }
166                 }
167         }()
168         return
169 }
170
171 // Returns a description of the Server. Python repr-style.
172 func (s *Server) String() string {
173         return fmt.Sprintf("dht server on %s", s.socket.LocalAddr())
174 }
175
176 type nodeID struct {
177         i   big.Int
178         set bool
179 }
180
181 func (nid *nodeID) IsUnset() bool {
182         return !nid.set
183 }
184
185 func nodeIDFromString(s string) (ret nodeID) {
186         if s == "" {
187                 return
188         }
189         ret.i.SetBytes([]byte(s))
190         ret.set = true
191         return
192 }
193
194 func (nid0 *nodeID) Distance(nid1 *nodeID) (ret big.Int) {
195         if nid0.IsUnset() != nid1.IsUnset() {
196                 ret = maxDistance
197                 return
198         }
199         ret.Xor(&nid0.i, &nid1.i)
200         return
201 }
202
203 func (nid *nodeID) ByteString() string {
204         var buf [20]byte
205         b := nid.i.Bytes()
206         copy(buf[20-len(b):], b)
207         return string(buf[:])
208 }
209
210 type node struct {
211         addr          dHTAddr
212         id            nodeID
213         announceToken string
214
215         lastGotQuery    time.Time
216         lastGotResponse time.Time
217         lastSentQuery   time.Time
218 }
219
220 func (n *node) IsSecure() bool {
221         if n.id.IsUnset() {
222                 return false
223         }
224         return NodeIdSecure(n.id.ByteString(), n.addr.IP())
225 }
226
227 func (n *node) idString() string {
228         return n.id.ByteString()
229 }
230
231 func (n *node) SetIDFromBytes(b []byte) {
232         if len(b) != 20 {
233                 panic(b)
234         }
235         n.id.i.SetBytes(b)
236         n.id.set = true
237 }
238
239 func (n *node) SetIDFromString(s string) {
240         n.SetIDFromBytes([]byte(s))
241 }
242
243 func (n *node) IDNotSet() bool {
244         return n.id.i.Int64() == 0
245 }
246
247 func (n *node) NodeInfo() (ret NodeInfo) {
248         ret.Addr = n.addr
249         if n := copy(ret.ID[:], n.idString()); n != 20 {
250                 panic(n)
251         }
252         return
253 }
254
255 func (n *node) DefinitelyGood() bool {
256         if len(n.idString()) != 20 {
257                 return false
258         }
259         // No reason to think ill of them if they've never been queried.
260         if n.lastSentQuery.IsZero() {
261                 return true
262         }
263         // They answered our last query.
264         if n.lastSentQuery.Before(n.lastGotResponse) {
265                 return true
266         }
267         return true
268 }
269
270 // A wrapper around the unmarshalled KRPC dict that constitutes messages in
271 // the DHT. There are various helpers for extracting common data from the
272 // message. In normal use, Msg is abstracted away for you, but it can be of
273 // interest.
274 type Msg map[string]interface{}
275
276 var _ fmt.Stringer = Msg{}
277
278 func (m Msg) String() string {
279         return fmt.Sprintf("%#v", m)
280 }
281
282 func (m Msg) T() (t string) {
283         tif, ok := m["t"]
284         if !ok {
285                 return
286         }
287         t, _ = tif.(string)
288         return
289 }
290
291 func (m Msg) ID() string {
292         defer func() {
293                 recover()
294         }()
295         return m[m["y"].(string)].(map[string]interface{})["id"].(string)
296 }
297
298 // Suggested nodes in a response.
299 func (m Msg) Nodes() (nodes []NodeInfo) {
300         b := func() string {
301                 defer func() {
302                         recover()
303                 }()
304                 return m["r"].(map[string]interface{})["nodes"].(string)
305         }()
306         if len(b)%26 != 0 {
307                 return
308         }
309         for i := 0; i < len(b); i += 26 {
310                 var n NodeInfo
311                 err := n.UnmarshalCompact([]byte(b[i : i+26]))
312                 if err != nil {
313                         continue
314                 }
315                 nodes = append(nodes, n)
316         }
317         return
318 }
319
320 type KRPCError struct {
321         Code int
322         Msg  string
323 }
324
325 func (me KRPCError) Error() string {
326         return fmt.Sprintf("KRPC error %d: %s", me.Code, me.Msg)
327 }
328
329 var _ error = KRPCError{}
330
331 func (m Msg) Error() (ret *KRPCError) {
332         if m["y"] != "e" {
333                 return
334         }
335         ret = &KRPCError{}
336         switch e := m["e"].(type) {
337         case []interface{}:
338                 ret.Code = int(e[0].(int64))
339                 ret.Msg = e[1].(string)
340         case string:
341                 ret.Msg = e
342         default:
343                 logonce.Stderr.Printf(`KRPC error "e" value has unexpected type: %T`, e)
344         }
345         return
346 }
347
348 // Returns the token given in response to a get_peers request for future
349 // announce_peer requests to that node.
350 func (m Msg) AnnounceToken() (token string, ok bool) {
351         defer func() { recover() }()
352         token, ok = m["r"].(map[string]interface{})["token"].(string)
353         return
354 }
355
356 type Transaction struct {
357         mu             sync.Mutex
358         remoteAddr     dHTAddr
359         t              string
360         response       chan Msg
361         onResponse     func(Msg) // Called with the server locked.
362         done           chan struct{}
363         queryPacket    []byte
364         timer          *time.Timer
365         s              *Server
366         retries        int
367         lastSend       time.Time
368         userOnResponse func(Msg)
369 }
370
371 // Set a function to be called with the response.
372 func (t *Transaction) SetResponseHandler(f func(Msg)) {
373         t.mu.Lock()
374         defer t.mu.Unlock()
375         t.userOnResponse = f
376         t.tryHandleResponse()
377 }
378
379 func (t *Transaction) tryHandleResponse() {
380         if t.userOnResponse == nil {
381                 return
382         }
383         select {
384         case r := <-t.response:
385                 t.userOnResponse(r)
386                 // Shouldn't be called more than once.
387                 t.userOnResponse = nil
388         default:
389         }
390 }
391
392 func (t *Transaction) key() transactionKey {
393         return transactionKey{
394                 t.remoteAddr.String(),
395                 t.t,
396         }
397 }
398
399 func jitterDuration(average time.Duration, plusMinus time.Duration) time.Duration {
400         return average - plusMinus/2 + time.Duration(rand.Int63n(int64(plusMinus)))
401 }
402
403 func (t *Transaction) startTimer() {
404         t.timer = time.AfterFunc(jitterDuration(queryResendEvery, time.Second), t.timerCallback)
405 }
406
407 func (t *Transaction) timerCallback() {
408         t.mu.Lock()
409         defer t.mu.Unlock()
410         select {
411         case <-t.done:
412                 return
413         default:
414         }
415         if t.retries == 2 {
416                 t.timeout()
417                 return
418         }
419         t.retries++
420         t.sendQuery()
421         if t.timer.Reset(jitterDuration(queryResendEvery, time.Second)) {
422                 panic("timer should have fired to get here")
423         }
424 }
425
426 func (t *Transaction) sendQuery() error {
427         err := t.s.writeToNode(t.queryPacket, t.remoteAddr)
428         if err != nil {
429                 return err
430         }
431         t.lastSend = time.Now()
432         return nil
433 }
434
435 func (t *Transaction) timeout() {
436         go func() {
437                 t.s.mu.Lock()
438                 defer t.s.mu.Unlock()
439                 t.s.nodeTimedOut(t.remoteAddr)
440         }()
441         t.close()
442 }
443
444 func (t *Transaction) close() {
445         if t.closing() {
446                 return
447         }
448         t.queryPacket = nil
449         close(t.response)
450         t.tryHandleResponse()
451         close(t.done)
452         t.timer.Stop()
453         go func() {
454                 t.s.mu.Lock()
455                 defer t.s.mu.Unlock()
456                 t.s.deleteTransaction(t)
457         }()
458 }
459
460 func (t *Transaction) closing() bool {
461         select {
462         case <-t.done:
463                 return true
464         default:
465                 return false
466         }
467 }
468
469 // Abandon the transaction.
470 func (t *Transaction) Close() {
471         t.mu.Lock()
472         defer t.mu.Unlock()
473         t.close()
474 }
475
476 func (t *Transaction) handleResponse(m Msg) {
477         t.mu.Lock()
478         if t.closing() {
479                 t.mu.Unlock()
480                 return
481         }
482         close(t.done)
483         t.mu.Unlock()
484         if t.onResponse != nil {
485                 t.s.mu.Lock()
486                 t.onResponse(m)
487                 t.s.mu.Unlock()
488         }
489         t.queryPacket = nil
490         select {
491         case t.response <- m:
492         default:
493                 panic("blocked handling response")
494         }
495         close(t.response)
496         t.tryHandleResponse()
497 }
498
499 func maskForIP(ip net.IP) []byte {
500         switch {
501         case ip.To4() != nil:
502                 return []byte{0x03, 0x0f, 0x3f, 0xff}
503         default:
504                 return []byte{0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f, 0xff}
505         }
506 }
507
508 // Generate the CRC used to make or validate secure node ID.
509 func crcIP(ip net.IP, rand uint8) uint32 {
510         if ip4 := ip.To4(); ip4 != nil {
511                 ip = ip4
512         }
513         // Copy IP so we can make changes. Go sux at this.
514         ip = append(make(net.IP, 0, len(ip)), ip...)
515         mask := maskForIP(ip)
516         for i := range mask {
517                 ip[i] &= mask[i]
518         }
519         r := rand & 7
520         ip[0] |= r << 5
521         return crc32.Checksum(ip[:len(mask)], crc32.MakeTable(crc32.Castagnoli))
522 }
523
524 // Makes a node ID secure, in-place. The ID is 20 raw bytes.
525 // http://www.libtorrent.org/dht_sec.html
526 func SecureNodeId(id []byte, ip net.IP) {
527         crc := crcIP(ip, id[19])
528         id[0] = byte(crc >> 24 & 0xff)
529         id[1] = byte(crc >> 16 & 0xff)
530         id[2] = byte(crc>>8&0xf8) | id[2]&7
531 }
532
533 // Returns whether the node ID is considered secure. The id is the 20 raw
534 // bytes. http://www.libtorrent.org/dht_sec.html
535 func NodeIdSecure(id string, ip net.IP) bool {
536         if len(id) != 20 {
537                 panic(fmt.Sprintf("%q", id))
538         }
539         if ip4 := ip.To4(); ip4 != nil {
540                 ip = ip4
541         }
542         crc := crcIP(ip, id[19])
543         if id[0] != byte(crc>>24&0xff) {
544                 return false
545         }
546         if id[1] != byte(crc>>16&0xff) {
547                 return false
548         }
549         if id[2]&0xf8 != byte(crc>>8&0xf8) {
550                 return false
551         }
552         return true
553 }
554
555 func (s *Server) setDefaults() (err error) {
556         if s.id == "" {
557                 var id [20]byte
558                 h := crypto.SHA1.New()
559                 ss, err := os.Hostname()
560                 if err != nil {
561                         log.Print(err)
562                 }
563                 ss += s.socket.LocalAddr().String()
564                 h.Write([]byte(ss))
565                 if b := h.Sum(id[:0:20]); len(b) != 20 {
566                         panic(len(b))
567                 }
568                 if len(id) != 20 {
569                         panic(len(id))
570                 }
571                 publicIP := func() net.IP {
572                         if s.config.PublicIP != nil {
573                                 return s.config.PublicIP
574                         } else {
575                                 return missinggo.AddrIP(s.socket.LocalAddr())
576                         }
577                 }()
578                 SecureNodeId(id[:], publicIP)
579                 s.id = string(id[:])
580         }
581         s.nodes = make(map[string]*node, maxNodes)
582         return
583 }
584
585 // Packets to and from any address matching a range in the list are dropped.
586 func (s *Server) SetIPBlockList(list *iplist.IPList) {
587         s.mu.Lock()
588         defer s.mu.Unlock()
589         s.ipBlockList = list
590 }
591
592 func (s *Server) IPBlocklist() *iplist.IPList {
593         return s.ipBlockList
594 }
595
596 func (s *Server) init() (err error) {
597         err = s.setDefaults()
598         if err != nil {
599                 return
600         }
601         s.closed = make(chan struct{})
602         s.transactions = make(map[transactionKey]*Transaction)
603         return
604 }
605
606 func (s *Server) processPacket(b []byte, addr dHTAddr) {
607         if len(b) < 2 || b[0] != 'd' || b[len(b)-1] != 'e' {
608                 // KRPC messages are bencoded dicts.
609                 readNotKRPCDict.Add(1)
610                 return
611         }
612         var d Msg
613         err := bencode.Unmarshal(b, &d)
614         if err != nil {
615                 readUnmarshalError.Add(1)
616                 func() {
617                         if se, ok := err.(*bencode.SyntaxError); ok {
618                                 // The message was truncated.
619                                 if int(se.Offset) == len(b) {
620                                         return
621                                 }
622                                 // Some messages seem to drop to nul chars abrubtly.
623                                 if int(se.Offset) < len(b) && b[se.Offset] == 0 {
624                                         return
625                                 }
626                                 // The message isn't bencode from the first.
627                                 if se.Offset == 0 {
628                                         return
629                                 }
630                         }
631                         if missinggo.CryHeard() {
632                                 log.Printf("%s: received bad krpc message from %s: %s: %+q", s, addr, err, b)
633                         }
634                 }()
635                 return
636         }
637         s.mu.Lock()
638         defer s.mu.Unlock()
639         if d["y"] == "q" {
640                 readQuery.Add(1)
641                 s.handleQuery(addr, d)
642                 return
643         }
644         t := s.findResponseTransaction(d.T(), addr)
645         if t == nil {
646                 //log.Printf("unexpected message: %#v", d)
647                 return
648         }
649         node := s.getNode(addr, d.ID())
650         node.lastGotResponse = time.Now()
651         // TODO: Update node ID as this is an authoritative packet.
652         go t.handleResponse(d)
653         s.deleteTransaction(t)
654 }
655
656 func (s *Server) serve() error {
657         var b [0x10000]byte
658         for {
659                 n, addr, err := s.socket.ReadFrom(b[:])
660                 if err != nil {
661                         return err
662                 }
663                 read.Add(1)
664                 if n == len(b) {
665                         logonce.Stderr.Printf("received dht packet exceeds buffer size")
666                         continue
667                 }
668                 s.mu.Lock()
669                 blocked := s.ipBlocked(missinggo.AddrIP(addr))
670                 s.mu.Unlock()
671                 if blocked {
672                         readBlocked.Add(1)
673                         continue
674                 }
675                 s.processPacket(b[:n], newDHTAddr(addr))
676         }
677 }
678
679 func (s *Server) ipBlocked(ip net.IP) bool {
680         if s.ipBlockList == nil {
681                 return false
682         }
683         return s.ipBlockList.Lookup(ip) != nil
684 }
685
686 // Adds directly to the node table.
687 func (s *Server) AddNode(ni NodeInfo) {
688         s.mu.Lock()
689         defer s.mu.Unlock()
690         if s.nodes == nil {
691                 s.nodes = make(map[string]*node)
692         }
693         s.getNode(ni.Addr, string(ni.ID[:]))
694 }
695
696 func (s *Server) nodeByID(id string) *node {
697         for _, node := range s.nodes {
698                 if node.idString() == id {
699                         return node
700                 }
701         }
702         return nil
703 }
704
705 func (s *Server) handleQuery(source dHTAddr, m Msg) {
706         args := m["a"].(map[string]interface{})
707         node := s.getNode(source, m.ID())
708         node.SetIDFromString(args["id"].(string))
709         node.lastGotQuery = time.Now()
710         // Don't respond.
711         if s.passive {
712                 return
713         }
714         switch m["q"] {
715         case "ping":
716                 s.reply(source, m["t"].(string), nil)
717         case "get_peers": // TODO: Extract common behaviour with find_node.
718                 targetID := args["info_hash"].(string)
719                 if len(targetID) != 20 {
720                         break
721                 }
722                 var rNodes []NodeInfo
723                 // TODO: Reply with "values" list if we have peers instead.
724                 for _, node := range s.closestGoodNodes(8, targetID) {
725                         rNodes = append(rNodes, node.NodeInfo())
726                 }
727                 nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
728                 for i, ni := range rNodes {
729                         err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
730                         if err != nil {
731                                 panic(err)
732                         }
733                 }
734                 s.reply(source, m["t"].(string), map[string]interface{}{
735                         "nodes": string(nodesBytes),
736                         "token": "hi",
737                 })
738         case "find_node": // TODO: Extract common behaviour with get_peers.
739                 targetID := args["target"].(string)
740                 if len(targetID) != 20 {
741                         log.Printf("bad DHT query: %v", m)
742                         return
743                 }
744                 var rNodes []NodeInfo
745                 if node := s.nodeByID(targetID); node != nil {
746                         rNodes = append(rNodes, node.NodeInfo())
747                 } else {
748                         for _, node := range s.closestGoodNodes(8, targetID) {
749                                 rNodes = append(rNodes, node.NodeInfo())
750                         }
751                 }
752                 nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
753                 for i, ni := range rNodes {
754                         // TODO: Put IPv6 nodes into the correct dict element.
755                         if ni.Addr.UDPAddr().IP.To4() == nil {
756                                 continue
757                         }
758                         err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
759                         if err != nil {
760                                 log.Printf("error compacting %#v: %s", ni, err)
761                                 continue
762                         }
763                 }
764                 s.reply(source, m["t"].(string), map[string]interface{}{
765                         "nodes": string(nodesBytes),
766                 })
767         case "announce_peer":
768                 // TODO(anacrolix): Implement this lolz.
769                 // log.Print(m)
770         case "vote":
771                 // TODO(anacrolix): Or reject, I don't think I want this.
772         default:
773                 log.Printf("%s: not handling received query: q=%s", s, m["q"])
774                 return
775         }
776 }
777
778 func (s *Server) reply(addr dHTAddr, t string, r map[string]interface{}) {
779         if r == nil {
780                 r = make(map[string]interface{}, 1)
781         }
782         r["id"] = s.ID()
783         m := map[string]interface{}{
784                 "t": t,
785                 "y": "r",
786                 "r": r,
787         }
788         b, err := bencode.Marshal(m)
789         if err != nil {
790                 panic(err)
791         }
792         err = s.writeToNode(b, addr)
793         if err != nil {
794                 log.Printf("error replying to %s: %s", addr, err)
795         }
796 }
797
798 // Returns a node struct for the addr. It is taken from the table or created
799 // and possibly added if required and meets validity constraints.
800 func (s *Server) getNode(addr dHTAddr, id string) (n *node) {
801         addrStr := addr.String()
802         n = s.nodes[addrStr]
803         if n != nil {
804                 if id != "" {
805                         n.SetIDFromString(id)
806                 }
807                 return
808         }
809         n = &node{
810                 addr: addr,
811         }
812         if len(id) == 20 {
813                 n.SetIDFromString(id)
814         }
815         if len(s.nodes) >= maxNodes {
816                 return
817         }
818         if !s.config.NoSecurity && !n.IsSecure() {
819                 return
820         }
821         if s.badNodes.Test([]byte(addrStr)) {
822                 return
823         }
824         s.nodes[addrStr] = n
825         return
826 }
827
828 func (s *Server) nodeTimedOut(addr dHTAddr) {
829         node, ok := s.nodes[addr.String()]
830         if !ok {
831                 return
832         }
833         if node.DefinitelyGood() {
834                 return
835         }
836         if len(s.nodes) < maxNodes {
837                 return
838         }
839         delete(s.nodes, addr.String())
840 }
841
842 func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) {
843         if list := s.ipBlockList; list != nil {
844                 if r := list.Lookup(missinggo.AddrIP(node.UDPAddr())); r != nil {
845                         err = fmt.Errorf("write to %s blocked: %s", node, r.Description)
846                         return
847                 }
848         }
849         n, err := s.socket.WriteTo(b, node.UDPAddr())
850         if err != nil {
851                 err = fmt.Errorf("error writing %d bytes to %s: %#v", len(b), node, err)
852                 return
853         }
854         if n != len(b) {
855                 err = io.ErrShortWrite
856                 return
857         }
858         return
859 }
860
861 func (s *Server) findResponseTransaction(transactionID string, sourceNode dHTAddr) *Transaction {
862         return s.transactions[transactionKey{
863                 sourceNode.String(),
864                 transactionID}]
865 }
866
867 func (s *Server) nextTransactionID() string {
868         var b [binary.MaxVarintLen64]byte
869         n := binary.PutUvarint(b[:], s.transactionIDInt)
870         s.transactionIDInt++
871         return string(b[:n])
872 }
873
874 func (s *Server) deleteTransaction(t *Transaction) {
875         delete(s.transactions, t.key())
876 }
877
878 func (s *Server) addTransaction(t *Transaction) {
879         if _, ok := s.transactions[t.key()]; ok {
880                 panic("transaction not unique")
881         }
882         s.transactions[t.key()] = t
883 }
884
885 // Returns the 20-byte server ID. This is the ID used to communicate with the
886 // DHT network.
887 func (s *Server) ID() string {
888         if len(s.id) != 20 {
889                 panic("bad node id")
890         }
891         return s.id
892 }
893
894 func (s *Server) query(node dHTAddr, q string, a map[string]interface{}, onResponse func(Msg)) (t *Transaction, err error) {
895         tid := s.nextTransactionID()
896         if a == nil {
897                 a = make(map[string]interface{}, 1)
898         }
899         a["id"] = s.ID()
900         d := map[string]interface{}{
901                 "t": tid,
902                 "y": "q",
903                 "q": q,
904                 "a": a,
905         }
906         b, err := bencode.Marshal(d)
907         if err != nil {
908                 return
909         }
910         t = &Transaction{
911                 remoteAddr:  node,
912                 t:           tid,
913                 response:    make(chan Msg, 1),
914                 done:        make(chan struct{}),
915                 queryPacket: b,
916                 s:           s,
917                 onResponse:  onResponse,
918         }
919         err = t.sendQuery()
920         if err != nil {
921                 return
922         }
923         s.getNode(node, "").lastSentQuery = time.Now()
924         t.startTimer()
925         s.addTransaction(t)
926         return
927 }
928
929 // The size in bytes of a NodeInfo in its compact binary representation.
930 const CompactNodeInfoLen = 26
931
932 type NodeInfo struct {
933         ID   [20]byte
934         Addr dHTAddr
935 }
936
937 // Writes the node info to its compact binary representation in b. See
938 // CompactNodeInfoLen.
939 func (ni *NodeInfo) PutCompact(b []byte) error {
940         if n := copy(b[:], ni.ID[:]); n != 20 {
941                 panic(n)
942         }
943         ip := missinggo.AddrIP(ni.Addr).To4()
944         if len(ip) != 4 {
945                 return errors.New("expected ipv4 address")
946         }
947         if n := copy(b[20:], ip); n != 4 {
948                 panic(n)
949         }
950         binary.BigEndian.PutUint16(b[24:], uint16(missinggo.AddrPort(ni.Addr)))
951         return nil
952 }
953
954 func (cni *NodeInfo) UnmarshalCompact(b []byte) error {
955         if len(b) != 26 {
956                 return errors.New("expected 26 bytes")
957         }
958         missinggo.CopyExact(cni.ID[:], b[:20])
959         cni.Addr = newDHTAddr(&net.UDPAddr{
960                 IP:   net.IPv4(b[20], b[21], b[22], b[23]),
961                 Port: int(binary.BigEndian.Uint16(b[24:26])),
962         })
963         return nil
964 }
965
966 // Sends a ping query to the address given.
967 func (s *Server) Ping(node *net.UDPAddr) (*Transaction, error) {
968         s.mu.Lock()
969         defer s.mu.Unlock()
970         return s.query(newDHTAddr(node), "ping", nil, nil)
971 }
972
973 func (s *Server) announcePeer(node dHTAddr, infoHash string, port int, token string, impliedPort bool) (err error) {
974         if port == 0 && !impliedPort {
975                 return errors.New("nothing to announce")
976         }
977         _, err = s.query(node, "announce_peer", map[string]interface{}{
978                 "implied_port": func() int {
979                         if impliedPort {
980                                 return 1
981                         } else {
982                                 return 0
983                         }
984                 }(),
985                 "info_hash": infoHash,
986                 "port":      port,
987                 "token":     token,
988         }, func(m Msg) {
989                 if err := m.Error(); err != nil {
990                         announceErrors.Add(1)
991                         // log.Print(token)
992                         // logonce.Stderr.Printf("announce_peer response: %s", err)
993                         return
994                 }
995                 s.numConfirmedAnnounces++
996         })
997         return
998 }
999
1000 // Add response nodes to node table.
1001 func (s *Server) liftNodes(d Msg) {
1002         if d["y"] != "r" {
1003                 return
1004         }
1005         for _, cni := range d.Nodes() {
1006                 if missinggo.AddrPort(cni.Addr) == 0 {
1007                         // TODO: Why would people even do this?
1008                         continue
1009                 }
1010                 if s.ipBlocked(missinggo.AddrIP(cni.Addr)) {
1011                         continue
1012                 }
1013                 n := s.getNode(cni.Addr, string(cni.ID[:]))
1014                 n.SetIDFromBytes(cni.ID[:])
1015         }
1016 }
1017
1018 // Sends a find_node query to addr. targetID is the node we're looking for.
1019 func (s *Server) findNode(addr dHTAddr, targetID string) (t *Transaction, err error) {
1020         t, err = s.query(addr, "find_node", map[string]interface{}{"target": targetID}, func(d Msg) {
1021                 // Scrape peers from the response to put in the server's table before
1022                 // handing the response back to the caller.
1023                 s.liftNodes(d)
1024         })
1025         if err != nil {
1026                 return
1027         }
1028         return
1029 }
1030
1031 // In a get_peers response, the addresses of torrent clients involved with the
1032 // queried info-hash.
1033 func (m Msg) Values() (vs []util.CompactPeer) {
1034         r, ok := m["r"]
1035         if !ok {
1036                 return
1037         }
1038         rd, ok := r.(map[string]interface{})
1039         if !ok {
1040                 return
1041         }
1042         v, ok := rd["values"]
1043         if !ok {
1044                 return
1045         }
1046         vl, ok := v.([]interface{})
1047         if !ok {
1048                 if missinggo.CryHeard() {
1049                         log.Printf(`unexpected krpc "values" field: %#v`, v)
1050                 }
1051                 return
1052         }
1053         vs = make([]util.CompactPeer, 0, len(vl))
1054         for _, i := range vl {
1055                 s, ok := i.(string)
1056                 if !ok {
1057                         panic(i)
1058                 }
1059                 var cp util.CompactPeer
1060                 err := cp.UnmarshalBinary([]byte(s))
1061                 if err != nil {
1062                         log.Printf("error decoding values list element: %s", err)
1063                         continue
1064                 }
1065                 vs = append(vs, cp)
1066         }
1067         return
1068 }
1069
1070 func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *Transaction, err error) {
1071         if len(infoHash) != 20 {
1072                 err = fmt.Errorf("infohash has bad length")
1073                 return
1074         }
1075         t, err = s.query(addr, "get_peers", map[string]interface{}{"info_hash": infoHash}, func(m Msg) {
1076                 s.liftNodes(m)
1077                 at, ok := m.AnnounceToken()
1078                 if ok {
1079                         s.getNode(addr, m.ID()).announceToken = at
1080                 }
1081         })
1082         return
1083 }
1084
1085 func bootstrapAddrs(nodeAddrs []string) (addrs []*net.UDPAddr, err error) {
1086         bootstrapNodes := nodeAddrs
1087         if len(bootstrapNodes) == 0 {
1088                 bootstrapNodes = []string{
1089                         "router.utorrent.com:6881",
1090                         "router.bittorrent.com:6881",
1091                 }
1092         }
1093         for _, addrStr := range bootstrapNodes {
1094                 udpAddr, err := net.ResolveUDPAddr("udp4", addrStr)
1095                 if err != nil {
1096                         continue
1097                 }
1098                 addrs = append(addrs, udpAddr)
1099         }
1100         if len(addrs) == 0 {
1101                 err = errors.New("nothing resolved")
1102         }
1103         return
1104 }
1105
1106 // Adds bootstrap nodes directly to table, if there's room. Node ID security
1107 // is bypassed, but the IP blocklist is not.
1108 func (s *Server) addRootNodes() error {
1109         addrs, err := bootstrapAddrs(s.bootstrapNodes)
1110         if err != nil {
1111                 return err
1112         }
1113         for _, addr := range addrs {
1114                 if len(s.nodes) >= maxNodes {
1115                         break
1116                 }
1117                 if s.nodes[addr.String()] != nil {
1118                         continue
1119                 }
1120                 if s.ipBlocked(addr.IP) {
1121                         log.Printf("dht root node is in the blocklist: %s", addr.IP)
1122                         continue
1123                 }
1124                 s.nodes[addr.String()] = &node{
1125                         addr: newDHTAddr(addr),
1126                 }
1127         }
1128         return nil
1129 }
1130
1131 // Populates the node table.
1132 func (s *Server) bootstrap() (err error) {
1133         s.mu.Lock()
1134         defer s.mu.Unlock()
1135         if len(s.nodes) == 0 {
1136                 err = s.addRootNodes()
1137         }
1138         if err != nil {
1139                 return
1140         }
1141         for {
1142                 var outstanding sync.WaitGroup
1143                 for _, node := range s.nodes {
1144                         var t *Transaction
1145                         t, err = s.findNode(node.addr, s.id)
1146                         if err != nil {
1147                                 err = fmt.Errorf("error sending find_node: %s", err)
1148                                 return
1149                         }
1150                         outstanding.Add(1)
1151                         t.SetResponseHandler(func(Msg) {
1152                                 outstanding.Done()
1153                         })
1154                 }
1155                 noOutstanding := make(chan struct{})
1156                 go func() {
1157                         outstanding.Wait()
1158                         close(noOutstanding)
1159                 }()
1160                 s.mu.Unlock()
1161                 select {
1162                 case <-s.closed:
1163                         s.mu.Lock()
1164                         return
1165                 case <-time.After(15 * time.Second):
1166                 case <-noOutstanding:
1167                 }
1168                 s.mu.Lock()
1169                 // log.Printf("now have %d nodes", len(s.nodes))
1170                 if s.numGoodNodes() >= 160 {
1171                         break
1172                 }
1173         }
1174         return
1175 }
1176
1177 func (s *Server) numGoodNodes() (num int) {
1178         for _, n := range s.nodes {
1179                 if n.DefinitelyGood() {
1180                         num++
1181                 }
1182         }
1183         return
1184 }
1185
1186 // Returns how many nodes are in the node table.
1187 func (s *Server) NumNodes() int {
1188         s.mu.Lock()
1189         defer s.mu.Unlock()
1190         return len(s.nodes)
1191 }
1192
1193 // Exports the current node table.
1194 func (s *Server) Nodes() (nis []NodeInfo) {
1195         s.mu.Lock()
1196         defer s.mu.Unlock()
1197         for _, node := range s.nodes {
1198                 // if !node.Good() {
1199                 //      continue
1200                 // }
1201                 ni := NodeInfo{
1202                         Addr: node.addr,
1203                 }
1204                 if n := copy(ni.ID[:], node.idString()); n != 20 && n != 0 {
1205                         panic(n)
1206                 }
1207                 nis = append(nis, ni)
1208         }
1209         return
1210 }
1211
1212 // Stops the server network activity. This is all that's required to clean-up a Server.
1213 func (s *Server) Close() {
1214         s.mu.Lock()
1215         select {
1216         case <-s.closed:
1217         default:
1218                 close(s.closed)
1219                 s.socket.Close()
1220         }
1221         s.mu.Unlock()
1222 }
1223
1224 var maxDistance big.Int
1225
1226 func init() {
1227         var zero big.Int
1228         maxDistance.SetBit(&zero, 160, 1)
1229 }
1230
1231 func (s *Server) closestGoodNodes(k int, targetID string) []*node {
1232         return s.closestNodes(k, nodeIDFromString(targetID), func(n *node) bool { return n.DefinitelyGood() })
1233 }
1234
1235 func (s *Server) closestNodes(k int, target nodeID, filter func(*node) bool) []*node {
1236         sel := newKClosestNodesSelector(k, target)
1237         idNodes := make(map[string]*node, len(s.nodes))
1238         for _, node := range s.nodes {
1239                 if !filter(node) {
1240                         continue
1241                 }
1242                 sel.Push(node.id)
1243                 idNodes[node.idString()] = node
1244         }
1245         ids := sel.IDs()
1246         ret := make([]*node, 0, len(ids))
1247         for _, id := range ids {
1248                 ret = append(ret, idNodes[id.ByteString()])
1249         }
1250         return ret
1251 }
1252
1253 func (me *Server) badNode(addr dHTAddr) {
1254         me.badNodes.Add([]byte(addr.String()))
1255         delete(me.nodes, addr.String())
1256 }