]> Sergey Matveev's repositories - btrtrc.git/blob - dht/dht.go
Avoid allocation in iplist.Ranger.Lookup
[btrtrc.git] / dht / dht.go
1 // Package DHT implements a DHT for use with the BitTorrent protocol,
2 // described in BEP 5: http://www.bittorrent.org/beps/bep_0005.html.
3 //
4 // Standard use involves creating a NewServer, and calling Announce on it with
5 // the details of your local torrent client and infohash of interest.
6 package dht
7
8 import (
9         "crypto"
10         _ "crypto/sha1"
11         "encoding/binary"
12         "errors"
13         "fmt"
14         "hash/crc32"
15         "io"
16         "log"
17         "math/big"
18         "math/rand"
19         "net"
20         "os"
21         "strconv"
22         "time"
23
24         "github.com/anacrolix/missinggo"
25         "github.com/anacrolix/sync"
26         "github.com/tylertreat/BoomFilters"
27
28         "github.com/anacrolix/torrent/bencode"
29         "github.com/anacrolix/torrent/iplist"
30         "github.com/anacrolix/torrent/logonce"
31         "github.com/anacrolix/torrent/util"
32 )
33
34 const (
35         maxNodes         = 320
36         queryResendEvery = 5 * time.Second
37 )
38
39 // Uniquely identifies a transaction to us.
40 type transactionKey struct {
41         RemoteAddr string // host:port
42         T          string // The KRPC transaction ID.
43 }
44
45 type Server struct {
46         id               string
47         socket           net.PacketConn
48         transactions     map[transactionKey]*Transaction
49         transactionIDInt uint64
50         nodes            map[string]*node // Keyed by dHTAddr.String().
51         mu               sync.Mutex
52         closed           chan struct{}
53         ipBlockList      iplist.Ranger
54         badNodes         *boom.BloomFilter
55
56         numConfirmedAnnounces int
57         bootstrapNodes        []string
58         config                ServerConfig
59 }
60
61 type ServerConfig struct {
62         Addr string // Listen address. Used if Conn is nil.
63         Conn net.PacketConn
64         // Don't respond to queries from other nodes.
65         Passive bool
66         // DHT Bootstrap nodes
67         BootstrapNodes []string
68         // Disable the DHT security extension:
69         // http://www.libtorrent.org/dht_sec.html.
70         NoSecurity bool
71         // Initial IP blocklist to use. Applied before serving and bootstrapping
72         // begins.
73         IPBlocklist iplist.Ranger
74         // Used to secure the server's ID. Defaults to the Conn's LocalAddr().
75         PublicIP net.IP
76 }
77
78 type ServerStats struct {
79         // Count of nodes in the node table that responded to our last query or
80         // haven't yet been queried.
81         GoodNodes int
82         // Count of nodes in the node table.
83         Nodes int
84         // Transactions awaiting a response.
85         OutstandingTransactions int
86         // Individual announce_peer requests that got a success response.
87         ConfirmedAnnounces int
88         // Nodes that have been blocked.
89         BadNodes uint
90 }
91
92 // Returns statistics for the server.
93 func (s *Server) Stats() (ss ServerStats) {
94         s.mu.Lock()
95         defer s.mu.Unlock()
96         for _, n := range s.nodes {
97                 if n.DefinitelyGood() {
98                         ss.GoodNodes++
99                 }
100         }
101         ss.Nodes = len(s.nodes)
102         ss.OutstandingTransactions = len(s.transactions)
103         ss.ConfirmedAnnounces = s.numConfirmedAnnounces
104         ss.BadNodes = s.badNodes.Count()
105         return
106 }
107
108 // Returns the listen address for the server. Packets arriving to this address
109 // are processed by the server (unless aliens are involved).
110 func (s *Server) Addr() net.Addr {
111         return s.socket.LocalAddr()
112 }
113
114 func makeSocket(addr string) (socket *net.UDPConn, err error) {
115         addr_, err := net.ResolveUDPAddr("", addr)
116         if err != nil {
117                 return
118         }
119         socket, err = net.ListenUDP("udp", addr_)
120         return
121 }
122
123 // Create a new DHT server.
124 func NewServer(c *ServerConfig) (s *Server, err error) {
125         if c == nil {
126                 c = &ServerConfig{}
127         }
128         s = &Server{
129                 config:      *c,
130                 ipBlockList: c.IPBlocklist,
131                 badNodes:    boom.NewBloomFilter(1000, 0.1),
132         }
133         if c.Conn != nil {
134                 s.socket = c.Conn
135         } else {
136                 s.socket, err = makeSocket(c.Addr)
137                 if err != nil {
138                         return
139                 }
140         }
141         s.bootstrapNodes = c.BootstrapNodes
142         err = s.init()
143         if err != nil {
144                 return
145         }
146         go func() {
147                 err := s.serve()
148                 select {
149                 case <-s.closed:
150                         return
151                 default:
152                 }
153                 if err != nil {
154                         panic(err)
155                 }
156         }()
157         go func() {
158                 err := s.bootstrap()
159                 if err != nil {
160                         select {
161                         case <-s.closed:
162                         default:
163                                 log.Printf("error bootstrapping DHT: %s", err)
164                         }
165                 }
166         }()
167         return
168 }
169
170 // Returns a description of the Server. Python repr-style.
171 func (s *Server) String() string {
172         return fmt.Sprintf("dht server on %s", s.socket.LocalAddr())
173 }
174
175 type nodeID struct {
176         i   big.Int
177         set bool
178 }
179
180 func (nid *nodeID) IsUnset() bool {
181         return !nid.set
182 }
183
184 func nodeIDFromString(s string) (ret nodeID) {
185         if s == "" {
186                 return
187         }
188         ret.i.SetBytes([]byte(s))
189         ret.set = true
190         return
191 }
192
193 func (nid0 *nodeID) Distance(nid1 *nodeID) (ret big.Int) {
194         if nid0.IsUnset() != nid1.IsUnset() {
195                 ret = maxDistance
196                 return
197         }
198         ret.Xor(&nid0.i, &nid1.i)
199         return
200 }
201
202 func (nid *nodeID) ByteString() string {
203         var buf [20]byte
204         b := nid.i.Bytes()
205         copy(buf[20-len(b):], b)
206         return string(buf[:])
207 }
208
209 type node struct {
210         addr          dHTAddr
211         id            nodeID
212         announceToken string
213
214         lastGotQuery    time.Time
215         lastGotResponse time.Time
216         lastSentQuery   time.Time
217 }
218
219 func (n *node) IsSecure() bool {
220         if n.id.IsUnset() {
221                 return false
222         }
223         return NodeIdSecure(n.id.ByteString(), n.addr.IP())
224 }
225
226 func (n *node) idString() string {
227         return n.id.ByteString()
228 }
229
230 func (n *node) SetIDFromBytes(b []byte) {
231         if len(b) != 20 {
232                 panic(b)
233         }
234         n.id.i.SetBytes(b)
235         n.id.set = true
236 }
237
238 func (n *node) SetIDFromString(s string) {
239         n.SetIDFromBytes([]byte(s))
240 }
241
242 func (n *node) IDNotSet() bool {
243         return n.id.i.Int64() == 0
244 }
245
246 func (n *node) NodeInfo() (ret NodeInfo) {
247         ret.Addr = n.addr
248         if n := copy(ret.ID[:], n.idString()); n != 20 {
249                 panic(n)
250         }
251         return
252 }
253
254 func (n *node) DefinitelyGood() bool {
255         if len(n.idString()) != 20 {
256                 return false
257         }
258         // No reason to think ill of them if they've never been queried.
259         if n.lastSentQuery.IsZero() {
260                 return true
261         }
262         // They answered our last query.
263         if n.lastSentQuery.Before(n.lastGotResponse) {
264                 return true
265         }
266         return true
267 }
268
269 // A wrapper around the unmarshalled KRPC dict that constitutes messages in
270 // the DHT. There are various helpers for extracting common data from the
271 // message. In normal use, Msg is abstracted away for you, but it can be of
272 // interest.
273 type Msg map[string]interface{}
274
275 var _ fmt.Stringer = Msg{}
276
277 func (m Msg) String() string {
278         return fmt.Sprintf("%#v", m)
279 }
280
281 func (m Msg) T() (t string) {
282         tif, ok := m["t"]
283         if !ok {
284                 return
285         }
286         t, _ = tif.(string)
287         return
288 }
289
290 func (m Msg) Args() map[string]interface{} {
291         defer func() {
292                 recover()
293         }()
294         return m["a"].(map[string]interface{})
295 }
296
297 func (m Msg) SenderID() string {
298         defer func() {
299                 recover()
300         }()
301         switch m["y"].(string) {
302         case "q":
303                 return m.Args()["id"].(string)
304         case "r":
305                 return m["r"].(map[string]interface{})["id"].(string)
306         }
307         return ""
308 }
309
310 // Suggested nodes in a response.
311 func (m Msg) Nodes() (nodes []NodeInfo) {
312         b := func() string {
313                 defer func() {
314                         recover()
315                 }()
316                 return m["r"].(map[string]interface{})["nodes"].(string)
317         }()
318         if len(b)%26 != 0 {
319                 return
320         }
321         for i := 0; i < len(b); i += 26 {
322                 var n NodeInfo
323                 err := n.UnmarshalCompact([]byte(b[i : i+26]))
324                 if err != nil {
325                         continue
326                 }
327                 nodes = append(nodes, n)
328         }
329         return
330 }
331
332 type KRPCError struct {
333         Code int
334         Msg  string
335 }
336
337 func (me KRPCError) Error() string {
338         return fmt.Sprintf("KRPC error %d: %s", me.Code, me.Msg)
339 }
340
341 var _ error = KRPCError{}
342
343 func (m Msg) Error() (ret *KRPCError) {
344         if m["y"] != "e" {
345                 return
346         }
347         ret = &KRPCError{}
348         switch e := m["e"].(type) {
349         case []interface{}:
350                 ret.Code = int(e[0].(int64))
351                 ret.Msg = e[1].(string)
352         case string:
353                 ret.Msg = e
354         default:
355                 logonce.Stderr.Printf(`KRPC error "e" value has unexpected type: %T`, e)
356         }
357         return
358 }
359
360 // Returns the token given in response to a get_peers request for future
361 // announce_peer requests to that node.
362 func (m Msg) AnnounceToken() (token string, ok bool) {
363         defer func() { recover() }()
364         token, ok = m["r"].(map[string]interface{})["token"].(string)
365         return
366 }
367
368 type Transaction struct {
369         mu             sync.Mutex
370         remoteAddr     dHTAddr
371         t              string
372         response       chan Msg
373         onResponse     func(Msg) // Called with the server locked.
374         done           chan struct{}
375         queryPacket    []byte
376         timer          *time.Timer
377         s              *Server
378         retries        int
379         lastSend       time.Time
380         userOnResponse func(Msg)
381 }
382
383 // Set a function to be called with the response.
384 func (t *Transaction) SetResponseHandler(f func(Msg)) {
385         t.mu.Lock()
386         defer t.mu.Unlock()
387         t.userOnResponse = f
388         t.tryHandleResponse()
389 }
390
391 func (t *Transaction) tryHandleResponse() {
392         if t.userOnResponse == nil {
393                 return
394         }
395         select {
396         case r := <-t.response:
397                 t.userOnResponse(r)
398                 // Shouldn't be called more than once.
399                 t.userOnResponse = nil
400         default:
401         }
402 }
403
404 func (t *Transaction) key() transactionKey {
405         return transactionKey{
406                 t.remoteAddr.String(),
407                 t.t,
408         }
409 }
410
411 func jitterDuration(average time.Duration, plusMinus time.Duration) time.Duration {
412         return average - plusMinus/2 + time.Duration(rand.Int63n(int64(plusMinus)))
413 }
414
415 func (t *Transaction) startTimer() {
416         t.timer = time.AfterFunc(jitterDuration(queryResendEvery, time.Second), t.timerCallback)
417 }
418
419 func (t *Transaction) timerCallback() {
420         t.mu.Lock()
421         defer t.mu.Unlock()
422         select {
423         case <-t.done:
424                 return
425         default:
426         }
427         if t.retries == 2 {
428                 t.timeout()
429                 return
430         }
431         t.retries++
432         t.sendQuery()
433         if t.timer.Reset(jitterDuration(queryResendEvery, time.Second)) {
434                 panic("timer should have fired to get here")
435         }
436 }
437
438 func (t *Transaction) sendQuery() error {
439         err := t.s.writeToNode(t.queryPacket, t.remoteAddr)
440         if err != nil {
441                 return err
442         }
443         t.lastSend = time.Now()
444         return nil
445 }
446
447 func (t *Transaction) timeout() {
448         go func() {
449                 t.s.mu.Lock()
450                 defer t.s.mu.Unlock()
451                 t.s.nodeTimedOut(t.remoteAddr)
452         }()
453         t.close()
454 }
455
456 func (t *Transaction) close() {
457         if t.closing() {
458                 return
459         }
460         t.queryPacket = nil
461         close(t.response)
462         t.tryHandleResponse()
463         close(t.done)
464         t.timer.Stop()
465         go func() {
466                 t.s.mu.Lock()
467                 defer t.s.mu.Unlock()
468                 t.s.deleteTransaction(t)
469         }()
470 }
471
472 func (t *Transaction) closing() bool {
473         select {
474         case <-t.done:
475                 return true
476         default:
477                 return false
478         }
479 }
480
481 // Abandon the transaction.
482 func (t *Transaction) Close() {
483         t.mu.Lock()
484         defer t.mu.Unlock()
485         t.close()
486 }
487
488 func (t *Transaction) handleResponse(m Msg) {
489         t.mu.Lock()
490         if t.closing() {
491                 t.mu.Unlock()
492                 return
493         }
494         close(t.done)
495         t.mu.Unlock()
496         if t.onResponse != nil {
497                 t.s.mu.Lock()
498                 t.onResponse(m)
499                 t.s.mu.Unlock()
500         }
501         t.queryPacket = nil
502         select {
503         case t.response <- m:
504         default:
505                 panic("blocked handling response")
506         }
507         close(t.response)
508         t.tryHandleResponse()
509 }
510
511 func maskForIP(ip net.IP) []byte {
512         switch {
513         case ip.To4() != nil:
514                 return []byte{0x03, 0x0f, 0x3f, 0xff}
515         default:
516                 return []byte{0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f, 0xff}
517         }
518 }
519
520 // Generate the CRC used to make or validate secure node ID.
521 func crcIP(ip net.IP, rand uint8) uint32 {
522         if ip4 := ip.To4(); ip4 != nil {
523                 ip = ip4
524         }
525         // Copy IP so we can make changes. Go sux at this.
526         ip = append(make(net.IP, 0, len(ip)), ip...)
527         mask := maskForIP(ip)
528         for i := range mask {
529                 ip[i] &= mask[i]
530         }
531         r := rand & 7
532         ip[0] |= r << 5
533         return crc32.Checksum(ip[:len(mask)], crc32.MakeTable(crc32.Castagnoli))
534 }
535
536 // Makes a node ID secure, in-place. The ID is 20 raw bytes.
537 // http://www.libtorrent.org/dht_sec.html
538 func SecureNodeId(id []byte, ip net.IP) {
539         crc := crcIP(ip, id[19])
540         id[0] = byte(crc >> 24 & 0xff)
541         id[1] = byte(crc >> 16 & 0xff)
542         id[2] = byte(crc>>8&0xf8) | id[2]&7
543 }
544
545 // Returns whether the node ID is considered secure. The id is the 20 raw
546 // bytes. http://www.libtorrent.org/dht_sec.html
547 func NodeIdSecure(id string, ip net.IP) bool {
548         if len(id) != 20 {
549                 panic(fmt.Sprintf("%q", id))
550         }
551         if ip4 := ip.To4(); ip4 != nil {
552                 ip = ip4
553         }
554         crc := crcIP(ip, id[19])
555         if id[0] != byte(crc>>24&0xff) {
556                 return false
557         }
558         if id[1] != byte(crc>>16&0xff) {
559                 return false
560         }
561         if id[2]&0xf8 != byte(crc>>8&0xf8) {
562                 return false
563         }
564         return true
565 }
566
567 func (s *Server) setDefaults() (err error) {
568         if s.id == "" {
569                 var id [20]byte
570                 h := crypto.SHA1.New()
571                 ss, err := os.Hostname()
572                 if err != nil {
573                         log.Print(err)
574                 }
575                 ss += s.socket.LocalAddr().String()
576                 h.Write([]byte(ss))
577                 if b := h.Sum(id[:0:20]); len(b) != 20 {
578                         panic(len(b))
579                 }
580                 if len(id) != 20 {
581                         panic(len(id))
582                 }
583                 publicIP := func() net.IP {
584                         if s.config.PublicIP != nil {
585                                 return s.config.PublicIP
586                         } else {
587                                 return missinggo.AddrIP(s.socket.LocalAddr())
588                         }
589                 }()
590                 SecureNodeId(id[:], publicIP)
591                 s.id = string(id[:])
592         }
593         s.nodes = make(map[string]*node, maxNodes)
594         return
595 }
596
597 // Packets to and from any address matching a range in the list are dropped.
598 func (s *Server) SetIPBlockList(list iplist.Ranger) {
599         s.mu.Lock()
600         defer s.mu.Unlock()
601         s.ipBlockList = list
602 }
603
604 func (s *Server) IPBlocklist() iplist.Ranger {
605         return s.ipBlockList
606 }
607
608 func (s *Server) init() (err error) {
609         err = s.setDefaults()
610         if err != nil {
611                 return
612         }
613         s.closed = make(chan struct{})
614         s.transactions = make(map[transactionKey]*Transaction)
615         return
616 }
617
618 func (s *Server) processPacket(b []byte, addr dHTAddr) {
619         if len(b) < 2 || b[0] != 'd' || b[len(b)-1] != 'e' {
620                 // KRPC messages are bencoded dicts.
621                 readNotKRPCDict.Add(1)
622                 return
623         }
624         var d Msg
625         err := bencode.Unmarshal(b, &d)
626         if err != nil {
627                 readUnmarshalError.Add(1)
628                 func() {
629                         if se, ok := err.(*bencode.SyntaxError); ok {
630                                 // The message was truncated.
631                                 if int(se.Offset) == len(b) {
632                                         return
633                                 }
634                                 // Some messages seem to drop to nul chars abrubtly.
635                                 if int(se.Offset) < len(b) && b[se.Offset] == 0 {
636                                         return
637                                 }
638                                 // The message isn't bencode from the first.
639                                 if se.Offset == 0 {
640                                         return
641                                 }
642                         }
643                         if missinggo.CryHeard() {
644                                 log.Printf("%s: received bad krpc message from %s: %s: %+q", s, addr, err, b)
645                         }
646                 }()
647                 return
648         }
649         s.mu.Lock()
650         defer s.mu.Unlock()
651         if d["y"] == "q" {
652                 readQuery.Add(1)
653                 s.handleQuery(addr, d)
654                 return
655         }
656         t := s.findResponseTransaction(d.T(), addr)
657         if t == nil {
658                 //log.Printf("unexpected message: %#v", d)
659                 return
660         }
661         node := s.getNode(addr, d.SenderID())
662         node.lastGotResponse = time.Now()
663         // TODO: Update node ID as this is an authoritative packet.
664         go t.handleResponse(d)
665         s.deleteTransaction(t)
666 }
667
668 func (s *Server) serve() error {
669         var b [0x10000]byte
670         for {
671                 n, addr, err := s.socket.ReadFrom(b[:])
672                 if err != nil {
673                         return err
674                 }
675                 read.Add(1)
676                 if n == len(b) {
677                         logonce.Stderr.Printf("received dht packet exceeds buffer size")
678                         continue
679                 }
680                 s.mu.Lock()
681                 blocked := s.ipBlocked(missinggo.AddrIP(addr))
682                 s.mu.Unlock()
683                 if blocked {
684                         readBlocked.Add(1)
685                         continue
686                 }
687                 s.processPacket(b[:n], newDHTAddr(addr))
688         }
689 }
690
691 func (s *Server) ipBlocked(ip net.IP) (blocked bool) {
692         if s.ipBlockList == nil {
693                 return
694         }
695         _, blocked = s.ipBlockList.Lookup(ip)
696         return
697 }
698
699 // Adds directly to the node table.
700 func (s *Server) AddNode(ni NodeInfo) {
701         s.mu.Lock()
702         defer s.mu.Unlock()
703         if s.nodes == nil {
704                 s.nodes = make(map[string]*node)
705         }
706         s.getNode(ni.Addr, string(ni.ID[:]))
707 }
708
709 func (s *Server) nodeByID(id string) *node {
710         for _, node := range s.nodes {
711                 if node.idString() == id {
712                         return node
713                 }
714         }
715         return nil
716 }
717
718 func (s *Server) handleQuery(source dHTAddr, m Msg) {
719         node := s.getNode(source, m.SenderID())
720         node.lastGotQuery = time.Now()
721         // Don't respond.
722         if s.config.Passive {
723                 return
724         }
725         args := m.Args()
726         if args == nil {
727                 return
728         }
729         switch m["q"] {
730         case "ping":
731                 s.reply(source, m["t"].(string), nil)
732         case "get_peers": // TODO: Extract common behaviour with find_node.
733                 targetID := args["info_hash"].(string)
734                 if len(targetID) != 20 {
735                         break
736                 }
737                 var rNodes []NodeInfo
738                 // TODO: Reply with "values" list if we have peers instead.
739                 for _, node := range s.closestGoodNodes(8, targetID) {
740                         rNodes = append(rNodes, node.NodeInfo())
741                 }
742                 nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
743                 for i, ni := range rNodes {
744                         err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
745                         if err != nil {
746                                 panic(err)
747                         }
748                 }
749                 s.reply(source, m["t"].(string), map[string]interface{}{
750                         "nodes": string(nodesBytes),
751                         "token": "hi",
752                 })
753         case "find_node": // TODO: Extract common behaviour with get_peers.
754                 targetID := args["target"].(string)
755                 if len(targetID) != 20 {
756                         log.Printf("bad DHT query: %v", m)
757                         return
758                 }
759                 var rNodes []NodeInfo
760                 if node := s.nodeByID(targetID); node != nil {
761                         rNodes = append(rNodes, node.NodeInfo())
762                 } else {
763                         for _, node := range s.closestGoodNodes(8, targetID) {
764                                 rNodes = append(rNodes, node.NodeInfo())
765                         }
766                 }
767                 nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
768                 for i, ni := range rNodes {
769                         // TODO: Put IPv6 nodes into the correct dict element.
770                         if ni.Addr.UDPAddr().IP.To4() == nil {
771                                 continue
772                         }
773                         err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
774                         if err != nil {
775                                 log.Printf("error compacting %#v: %s", ni, err)
776                                 continue
777                         }
778                 }
779                 s.reply(source, m["t"].(string), map[string]interface{}{
780                         "nodes": string(nodesBytes),
781                 })
782         case "announce_peer":
783                 // TODO(anacrolix): Implement this lolz.
784                 // log.Print(m)
785         case "vote":
786                 // TODO(anacrolix): Or reject, I don't think I want this.
787         default:
788                 log.Printf("%s: not handling received query: q=%s", s, m["q"])
789                 return
790         }
791 }
792
793 func (s *Server) reply(addr dHTAddr, t string, r map[string]interface{}) {
794         if r == nil {
795                 r = make(map[string]interface{}, 1)
796         }
797         r["id"] = s.ID()
798         m := map[string]interface{}{
799                 "t": t,
800                 "y": "r",
801                 "r": r,
802         }
803         b, err := bencode.Marshal(m)
804         if err != nil {
805                 panic(err)
806         }
807         err = s.writeToNode(b, addr)
808         if err != nil {
809                 log.Printf("error replying to %s: %s", addr, err)
810         }
811 }
812
813 // Returns a node struct for the addr. It is taken from the table or created
814 // and possibly added if required and meets validity constraints.
815 func (s *Server) getNode(addr dHTAddr, id string) (n *node) {
816         addrStr := addr.String()
817         n = s.nodes[addrStr]
818         if n != nil {
819                 if id != "" {
820                         n.SetIDFromString(id)
821                 }
822                 return
823         }
824         n = &node{
825                 addr: addr,
826         }
827         if len(id) == 20 {
828                 n.SetIDFromString(id)
829         }
830         if len(s.nodes) >= maxNodes {
831                 return
832         }
833         if !s.config.NoSecurity && !n.IsSecure() {
834                 return
835         }
836         if s.badNodes.Test([]byte(addrStr)) {
837                 return
838         }
839         s.nodes[addrStr] = n
840         return
841 }
842
843 func (s *Server) nodeTimedOut(addr dHTAddr) {
844         node, ok := s.nodes[addr.String()]
845         if !ok {
846                 return
847         }
848         if node.DefinitelyGood() {
849                 return
850         }
851         if len(s.nodes) < maxNodes {
852                 return
853         }
854         delete(s.nodes, addr.String())
855 }
856
857 func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) {
858         if list := s.ipBlockList; list != nil {
859                 if r := list.Lookup(missinggo.AddrIP(node.UDPAddr())); r != nil {
860                         err = fmt.Errorf("write to %s blocked: %s", node, r.Description)
861                         return
862                 }
863         }
864         n, err := s.socket.WriteTo(b, node.UDPAddr())
865         if err != nil {
866                 err = fmt.Errorf("error writing %d bytes to %s: %#v", len(b), node, err)
867                 return
868         }
869         if n != len(b) {
870                 err = io.ErrShortWrite
871                 return
872         }
873         return
874 }
875
876 func (s *Server) findResponseTransaction(transactionID string, sourceNode dHTAddr) *Transaction {
877         return s.transactions[transactionKey{
878                 sourceNode.String(),
879                 transactionID}]
880 }
881
882 func (s *Server) nextTransactionID() string {
883         var b [binary.MaxVarintLen64]byte
884         n := binary.PutUvarint(b[:], s.transactionIDInt)
885         s.transactionIDInt++
886         return string(b[:n])
887 }
888
889 func (s *Server) deleteTransaction(t *Transaction) {
890         delete(s.transactions, t.key())
891 }
892
893 func (s *Server) addTransaction(t *Transaction) {
894         if _, ok := s.transactions[t.key()]; ok {
895                 panic("transaction not unique")
896         }
897         s.transactions[t.key()] = t
898 }
899
900 // Returns the 20-byte server ID. This is the ID used to communicate with the
901 // DHT network.
902 func (s *Server) ID() string {
903         if len(s.id) != 20 {
904                 panic("bad node id")
905         }
906         return s.id
907 }
908
909 func (s *Server) query(node dHTAddr, q string, a map[string]interface{}, onResponse func(Msg)) (t *Transaction, err error) {
910         tid := s.nextTransactionID()
911         if a == nil {
912                 a = make(map[string]interface{}, 1)
913         }
914         a["id"] = s.ID()
915         d := map[string]interface{}{
916                 "t": tid,
917                 "y": "q",
918                 "q": q,
919                 "a": a,
920         }
921         // BEP 43. Outgoing queries from uncontactiable nodes should contain
922         // "ro":1 in the top level dictionary.
923         if s.config.Passive {
924                 d["ro"] = 1
925         }
926         b, err := bencode.Marshal(d)
927         if err != nil {
928                 return
929         }
930         t = &Transaction{
931                 remoteAddr:  node,
932                 t:           tid,
933                 response:    make(chan Msg, 1),
934                 done:        make(chan struct{}),
935                 queryPacket: b,
936                 s:           s,
937                 onResponse:  onResponse,
938         }
939         err = t.sendQuery()
940         if err != nil {
941                 return
942         }
943         s.getNode(node, "").lastSentQuery = time.Now()
944         t.startTimer()
945         s.addTransaction(t)
946         return
947 }
948
949 // The size in bytes of a NodeInfo in its compact binary representation.
950 const CompactNodeInfoLen = 26
951
952 type NodeInfo struct {
953         ID   [20]byte
954         Addr dHTAddr
955 }
956
957 // Writes the node info to its compact binary representation in b. See
958 // CompactNodeInfoLen.
959 func (ni *NodeInfo) PutCompact(b []byte) error {
960         if n := copy(b[:], ni.ID[:]); n != 20 {
961                 panic(n)
962         }
963         ip := missinggo.AddrIP(ni.Addr).To4()
964         if len(ip) != 4 {
965                 return errors.New("expected ipv4 address")
966         }
967         if n := copy(b[20:], ip); n != 4 {
968                 panic(n)
969         }
970         binary.BigEndian.PutUint16(b[24:], uint16(missinggo.AddrPort(ni.Addr)))
971         return nil
972 }
973
974 func (cni *NodeInfo) UnmarshalCompact(b []byte) error {
975         if len(b) != 26 {
976                 return errors.New("expected 26 bytes")
977         }
978         missinggo.CopyExact(cni.ID[:], b[:20])
979         cni.Addr = newDHTAddr(&net.UDPAddr{
980                 IP:   net.IPv4(b[20], b[21], b[22], b[23]),
981                 Port: int(binary.BigEndian.Uint16(b[24:26])),
982         })
983         return nil
984 }
985
986 // Sends a ping query to the address given.
987 func (s *Server) Ping(node *net.UDPAddr) (*Transaction, error) {
988         s.mu.Lock()
989         defer s.mu.Unlock()
990         return s.query(newDHTAddr(node), "ping", nil, nil)
991 }
992
993 func (s *Server) announcePeer(node dHTAddr, infoHash string, port int, token string, impliedPort bool) (err error) {
994         if port == 0 && !impliedPort {
995                 return errors.New("nothing to announce")
996         }
997         _, err = s.query(node, "announce_peer", map[string]interface{}{
998                 "implied_port": func() int {
999                         if impliedPort {
1000                                 return 1
1001                         } else {
1002                                 return 0
1003                         }
1004                 }(),
1005                 "info_hash": infoHash,
1006                 "port":      port,
1007                 "token":     token,
1008         }, func(m Msg) {
1009                 if err := m.Error(); err != nil {
1010                         announceErrors.Add(1)
1011                         // log.Print(token)
1012                         // logonce.Stderr.Printf("announce_peer response: %s", err)
1013                         return
1014                 }
1015                 s.numConfirmedAnnounces++
1016         })
1017         return
1018 }
1019
1020 // Add response nodes to node table.
1021 func (s *Server) liftNodes(d Msg) {
1022         if d["y"] != "r" {
1023                 return
1024         }
1025         for _, cni := range d.Nodes() {
1026                 if missinggo.AddrPort(cni.Addr) == 0 {
1027                         // TODO: Why would people even do this?
1028                         continue
1029                 }
1030                 if s.ipBlocked(missinggo.AddrIP(cni.Addr)) {
1031                         continue
1032                 }
1033                 n := s.getNode(cni.Addr, string(cni.ID[:]))
1034                 n.SetIDFromBytes(cni.ID[:])
1035         }
1036 }
1037
1038 // Sends a find_node query to addr. targetID is the node we're looking for.
1039 func (s *Server) findNode(addr dHTAddr, targetID string) (t *Transaction, err error) {
1040         t, err = s.query(addr, "find_node", map[string]interface{}{"target": targetID}, func(d Msg) {
1041                 // Scrape peers from the response to put in the server's table before
1042                 // handing the response back to the caller.
1043                 s.liftNodes(d)
1044         })
1045         if err != nil {
1046                 return
1047         }
1048         return
1049 }
1050
1051 type Peer struct {
1052         IP   net.IP
1053         Port int
1054 }
1055
1056 func (me *Peer) String() string {
1057         return net.JoinHostPort(me.IP.String(), strconv.FormatInt(int64(me.Port), 10))
1058 }
1059
1060 // In a get_peers response, the addresses of torrent clients involved with the
1061 // queried info-hash.
1062 func (m Msg) Values() (vs []Peer) {
1063         v := func() interface{} {
1064                 defer func() {
1065                         recover()
1066                 }()
1067                 return m["r"].(map[string]interface{})["values"]
1068         }()
1069         if v == nil {
1070                 return
1071         }
1072         vl, ok := v.([]interface{})
1073         if !ok {
1074                 if missinggo.CryHeard() {
1075                         log.Printf(`unexpected krpc "values" field: %#v`, v)
1076                 }
1077                 return
1078         }
1079         vs = make([]Peer, 0, len(vl))
1080         for _, i := range vl {
1081                 s, ok := i.(string)
1082                 if !ok {
1083                         panic(i)
1084                 }
1085                 // Because it's a list of strings, we can let the length of the string
1086                 // determine the IP version of the compact peer.
1087                 var cp util.CompactPeer
1088                 err := cp.UnmarshalBinary([]byte(s))
1089                 if err != nil {
1090                         log.Printf("error decoding values list element: %s", err)
1091                         continue
1092                 }
1093                 vs = append(vs, Peer{cp.IP[:], int(cp.Port)})
1094         }
1095         return
1096 }
1097
1098 func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *Transaction, err error) {
1099         if len(infoHash) != 20 {
1100                 err = fmt.Errorf("infohash has bad length")
1101                 return
1102         }
1103         t, err = s.query(addr, "get_peers", map[string]interface{}{"info_hash": infoHash}, func(m Msg) {
1104                 s.liftNodes(m)
1105                 at, ok := m.AnnounceToken()
1106                 if ok {
1107                         s.getNode(addr, m.SenderID()).announceToken = at
1108                 }
1109         })
1110         return
1111 }
1112
1113 func bootstrapAddrs(nodeAddrs []string) (addrs []*net.UDPAddr, err error) {
1114         bootstrapNodes := nodeAddrs
1115         if len(bootstrapNodes) == 0 {
1116                 bootstrapNodes = []string{
1117                         "router.utorrent.com:6881",
1118                         "router.bittorrent.com:6881",
1119                 }
1120         }
1121         for _, addrStr := range bootstrapNodes {
1122                 udpAddr, err := net.ResolveUDPAddr("udp4", addrStr)
1123                 if err != nil {
1124                         continue
1125                 }
1126                 addrs = append(addrs, udpAddr)
1127         }
1128         if len(addrs) == 0 {
1129                 err = errors.New("nothing resolved")
1130         }
1131         return
1132 }
1133
1134 // Adds bootstrap nodes directly to table, if there's room. Node ID security
1135 // is bypassed, but the IP blocklist is not.
1136 func (s *Server) addRootNodes() error {
1137         addrs, err := bootstrapAddrs(s.bootstrapNodes)
1138         if err != nil {
1139                 return err
1140         }
1141         for _, addr := range addrs {
1142                 if len(s.nodes) >= maxNodes {
1143                         break
1144                 }
1145                 if s.nodes[addr.String()] != nil {
1146                         continue
1147                 }
1148                 if s.ipBlocked(addr.IP) {
1149                         log.Printf("dht root node is in the blocklist: %s", addr.IP)
1150                         continue
1151                 }
1152                 s.nodes[addr.String()] = &node{
1153                         addr: newDHTAddr(addr),
1154                 }
1155         }
1156         return nil
1157 }
1158
1159 // Populates the node table.
1160 func (s *Server) bootstrap() (err error) {
1161         s.mu.Lock()
1162         defer s.mu.Unlock()
1163         if len(s.nodes) == 0 {
1164                 err = s.addRootNodes()
1165         }
1166         if err != nil {
1167                 return
1168         }
1169         for {
1170                 var outstanding sync.WaitGroup
1171                 for _, node := range s.nodes {
1172                         var t *Transaction
1173                         t, err = s.findNode(node.addr, s.id)
1174                         if err != nil {
1175                                 err = fmt.Errorf("error sending find_node: %s", err)
1176                                 return
1177                         }
1178                         outstanding.Add(1)
1179                         t.SetResponseHandler(func(Msg) {
1180                                 outstanding.Done()
1181                         })
1182                 }
1183                 noOutstanding := make(chan struct{})
1184                 go func() {
1185                         outstanding.Wait()
1186                         close(noOutstanding)
1187                 }()
1188                 s.mu.Unlock()
1189                 select {
1190                 case <-s.closed:
1191                         s.mu.Lock()
1192                         return
1193                 case <-time.After(15 * time.Second):
1194                 case <-noOutstanding:
1195                 }
1196                 s.mu.Lock()
1197                 // log.Printf("now have %d nodes", len(s.nodes))
1198                 if s.numGoodNodes() >= 160 {
1199                         break
1200                 }
1201         }
1202         return
1203 }
1204
1205 func (s *Server) numGoodNodes() (num int) {
1206         for _, n := range s.nodes {
1207                 if n.DefinitelyGood() {
1208                         num++
1209                 }
1210         }
1211         return
1212 }
1213
1214 // Returns how many nodes are in the node table.
1215 func (s *Server) NumNodes() int {
1216         s.mu.Lock()
1217         defer s.mu.Unlock()
1218         return len(s.nodes)
1219 }
1220
1221 // Exports the current node table.
1222 func (s *Server) Nodes() (nis []NodeInfo) {
1223         s.mu.Lock()
1224         defer s.mu.Unlock()
1225         for _, node := range s.nodes {
1226                 // if !node.Good() {
1227                 //      continue
1228                 // }
1229                 ni := NodeInfo{
1230                         Addr: node.addr,
1231                 }
1232                 if n := copy(ni.ID[:], node.idString()); n != 20 && n != 0 {
1233                         panic(n)
1234                 }
1235                 nis = append(nis, ni)
1236         }
1237         return
1238 }
1239
1240 // Stops the server network activity. This is all that's required to clean-up a Server.
1241 func (s *Server) Close() {
1242         s.mu.Lock()
1243         select {
1244         case <-s.closed:
1245         default:
1246                 close(s.closed)
1247                 s.socket.Close()
1248         }
1249         s.mu.Unlock()
1250 }
1251
1252 var maxDistance big.Int
1253
1254 func init() {
1255         var zero big.Int
1256         maxDistance.SetBit(&zero, 160, 1)
1257 }
1258
1259 func (s *Server) closestGoodNodes(k int, targetID string) []*node {
1260         return s.closestNodes(k, nodeIDFromString(targetID), func(n *node) bool { return n.DefinitelyGood() })
1261 }
1262
1263 func (s *Server) closestNodes(k int, target nodeID, filter func(*node) bool) []*node {
1264         sel := newKClosestNodesSelector(k, target)
1265         idNodes := make(map[string]*node, len(s.nodes))
1266         for _, node := range s.nodes {
1267                 if !filter(node) {
1268                         continue
1269                 }
1270                 sel.Push(node.id)
1271                 idNodes[node.idString()] = node
1272         }
1273         ids := sel.IDs()
1274         ret := make([]*node, 0, len(ids))
1275         for _, id := range ids {
1276                 ret = append(ret, idNodes[id.ByteString()])
1277         }
1278         return ret
1279 }
1280
1281 func (me *Server) badNode(addr dHTAddr) {
1282         me.badNodes.Add([]byte(addr.String()))
1283         delete(me.nodes, addr.String())
1284 }