]> Sergey Matveev's repositories - btrtrc.git/blob - dht/dht.go
Add packed IP list
[btrtrc.git] / dht / dht.go
1 // Package DHT implements a DHT for use with the BitTorrent protocol,
2 // described in BEP 5: http://www.bittorrent.org/beps/bep_0005.html.
3 //
4 // Standard use involves creating a NewServer, and calling Announce on it with
5 // the details of your local torrent client and infohash of interest.
6 package dht
7
8 import (
9         "crypto"
10         _ "crypto/sha1"
11         "encoding/binary"
12         "errors"
13         "fmt"
14         "hash/crc32"
15         "io"
16         "log"
17         "math/big"
18         "math/rand"
19         "net"
20         "os"
21         "strconv"
22         "time"
23
24         "github.com/anacrolix/missinggo"
25         "github.com/anacrolix/sync"
26         "github.com/tylertreat/BoomFilters"
27
28         "github.com/anacrolix/torrent/bencode"
29         "github.com/anacrolix/torrent/iplist"
30         "github.com/anacrolix/torrent/logonce"
31         "github.com/anacrolix/torrent/util"
32 )
33
34 const (
35         maxNodes         = 320
36         queryResendEvery = 5 * time.Second
37 )
38
39 // Uniquely identifies a transaction to us.
40 type transactionKey struct {
41         RemoteAddr string // host:port
42         T          string // The KRPC transaction ID.
43 }
44
45 type Server struct {
46         id               string
47         socket           net.PacketConn
48         transactions     map[transactionKey]*Transaction
49         transactionIDInt uint64
50         nodes            map[string]*node // Keyed by dHTAddr.String().
51         mu               sync.Mutex
52         closed           chan struct{}
53         ipBlockList      iplist.Ranger
54         badNodes         *boom.BloomFilter
55
56         numConfirmedAnnounces int
57         bootstrapNodes        []string
58         config                ServerConfig
59 }
60
61 type ServerConfig struct {
62         Addr string // Listen address. Used if Conn is nil.
63         Conn net.PacketConn
64         // Don't respond to queries from other nodes.
65         Passive bool
66         // DHT Bootstrap nodes
67         BootstrapNodes []string
68         // Disable the DHT security extension:
69         // http://www.libtorrent.org/dht_sec.html.
70         NoSecurity bool
71         // Initial IP blocklist to use. Applied before serving and bootstrapping
72         // begins.
73         IPBlocklist iplist.Ranger
74         // Used to secure the server's ID. Defaults to the Conn's LocalAddr().
75         PublicIP net.IP
76 }
77
78 type ServerStats struct {
79         // Count of nodes in the node table that responded to our last query or
80         // haven't yet been queried.
81         GoodNodes int
82         // Count of nodes in the node table.
83         Nodes int
84         // Transactions awaiting a response.
85         OutstandingTransactions int
86         // Individual announce_peer requests that got a success response.
87         ConfirmedAnnounces int
88         // Nodes that have been blocked.
89         BadNodes uint
90 }
91
92 // Returns statistics for the server.
93 func (s *Server) Stats() (ss ServerStats) {
94         s.mu.Lock()
95         defer s.mu.Unlock()
96         for _, n := range s.nodes {
97                 if n.DefinitelyGood() {
98                         ss.GoodNodes++
99                 }
100         }
101         ss.Nodes = len(s.nodes)
102         ss.OutstandingTransactions = len(s.transactions)
103         ss.ConfirmedAnnounces = s.numConfirmedAnnounces
104         ss.BadNodes = s.badNodes.Count()
105         return
106 }
107
108 // Returns the listen address for the server. Packets arriving to this address
109 // are processed by the server (unless aliens are involved).
110 func (s *Server) Addr() net.Addr {
111         return s.socket.LocalAddr()
112 }
113
114 func makeSocket(addr string) (socket *net.UDPConn, err error) {
115         addr_, err := net.ResolveUDPAddr("", addr)
116         if err != nil {
117                 return
118         }
119         socket, err = net.ListenUDP("udp", addr_)
120         return
121 }
122
123 // Create a new DHT server.
124 func NewServer(c *ServerConfig) (s *Server, err error) {
125         if c == nil {
126                 c = &ServerConfig{}
127         }
128         s = &Server{
129                 config:      *c,
130                 ipBlockList: c.IPBlocklist,
131                 badNodes:    boom.NewBloomFilter(1000, 0.1),
132         }
133         if c.Conn != nil {
134                 s.socket = c.Conn
135         } else {
136                 s.socket, err = makeSocket(c.Addr)
137                 if err != nil {
138                         return
139                 }
140         }
141         s.bootstrapNodes = c.BootstrapNodes
142         err = s.init()
143         if err != nil {
144                 return
145         }
146         go func() {
147                 err := s.serve()
148                 select {
149                 case <-s.closed:
150                         return
151                 default:
152                 }
153                 if err != nil {
154                         panic(err)
155                 }
156         }()
157         go func() {
158                 err := s.bootstrap()
159                 if err != nil {
160                         select {
161                         case <-s.closed:
162                         default:
163                                 log.Printf("error bootstrapping DHT: %s", err)
164                         }
165                 }
166         }()
167         return
168 }
169
170 // Returns a description of the Server. Python repr-style.
171 func (s *Server) String() string {
172         return fmt.Sprintf("dht server on %s", s.socket.LocalAddr())
173 }
174
175 type nodeID struct {
176         i   big.Int
177         set bool
178 }
179
180 func (nid *nodeID) IsUnset() bool {
181         return !nid.set
182 }
183
184 func nodeIDFromString(s string) (ret nodeID) {
185         if s == "" {
186                 return
187         }
188         ret.i.SetBytes([]byte(s))
189         ret.set = true
190         return
191 }
192
193 func (nid0 *nodeID) Distance(nid1 *nodeID) (ret big.Int) {
194         if nid0.IsUnset() != nid1.IsUnset() {
195                 ret = maxDistance
196                 return
197         }
198         ret.Xor(&nid0.i, &nid1.i)
199         return
200 }
201
202 func (nid *nodeID) ByteString() string {
203         var buf [20]byte
204         b := nid.i.Bytes()
205         copy(buf[20-len(b):], b)
206         return string(buf[:])
207 }
208
209 type node struct {
210         addr          dHTAddr
211         id            nodeID
212         announceToken string
213
214         lastGotQuery    time.Time
215         lastGotResponse time.Time
216         lastSentQuery   time.Time
217 }
218
219 func (n *node) IsSecure() bool {
220         if n.id.IsUnset() {
221                 return false
222         }
223         return NodeIdSecure(n.id.ByteString(), n.addr.IP())
224 }
225
226 func (n *node) idString() string {
227         return n.id.ByteString()
228 }
229
230 func (n *node) SetIDFromBytes(b []byte) {
231         if len(b) != 20 {
232                 panic(b)
233         }
234         n.id.i.SetBytes(b)
235         n.id.set = true
236 }
237
238 func (n *node) SetIDFromString(s string) {
239         n.SetIDFromBytes([]byte(s))
240 }
241
242 func (n *node) IDNotSet() bool {
243         return n.id.i.Int64() == 0
244 }
245
246 func (n *node) NodeInfo() (ret NodeInfo) {
247         ret.Addr = n.addr
248         if n := copy(ret.ID[:], n.idString()); n != 20 {
249                 panic(n)
250         }
251         return
252 }
253
254 func (n *node) DefinitelyGood() bool {
255         if len(n.idString()) != 20 {
256                 return false
257         }
258         // No reason to think ill of them if they've never been queried.
259         if n.lastSentQuery.IsZero() {
260                 return true
261         }
262         // They answered our last query.
263         if n.lastSentQuery.Before(n.lastGotResponse) {
264                 return true
265         }
266         return true
267 }
268
269 // A wrapper around the unmarshalled KRPC dict that constitutes messages in
270 // the DHT. There are various helpers for extracting common data from the
271 // message. In normal use, Msg is abstracted away for you, but it can be of
272 // interest.
273 type Msg map[string]interface{}
274
275 var _ fmt.Stringer = Msg{}
276
277 func (m Msg) String() string {
278         return fmt.Sprintf("%#v", m)
279 }
280
281 func (m Msg) T() (t string) {
282         tif, ok := m["t"]
283         if !ok {
284                 return
285         }
286         t, _ = tif.(string)
287         return
288 }
289
290 func (m Msg) Args() map[string]interface{} {
291         defer func() {
292                 recover()
293         }()
294         return m["a"].(map[string]interface{})
295 }
296
297 func (m Msg) SenderID() string {
298         defer func() {
299                 recover()
300         }()
301         switch m["y"].(string) {
302         case "q":
303                 return m.Args()["id"].(string)
304         case "r":
305                 return m["r"].(map[string]interface{})["id"].(string)
306         }
307         return ""
308 }
309
310 // Suggested nodes in a response.
311 func (m Msg) Nodes() (nodes []NodeInfo) {
312         b := func() string {
313                 defer func() {
314                         recover()
315                 }()
316                 return m["r"].(map[string]interface{})["nodes"].(string)
317         }()
318         if len(b)%26 != 0 {
319                 return
320         }
321         for i := 0; i < len(b); i += 26 {
322                 var n NodeInfo
323                 err := n.UnmarshalCompact([]byte(b[i : i+26]))
324                 if err != nil {
325                         continue
326                 }
327                 nodes = append(nodes, n)
328         }
329         return
330 }
331
332 type KRPCError struct {
333         Code int
334         Msg  string
335 }
336
337 func (me KRPCError) Error() string {
338         return fmt.Sprintf("KRPC error %d: %s", me.Code, me.Msg)
339 }
340
341 var _ error = KRPCError{}
342
343 func (m Msg) Error() (ret *KRPCError) {
344         if m["y"] != "e" {
345                 return
346         }
347         ret = &KRPCError{}
348         switch e := m["e"].(type) {
349         case []interface{}:
350                 ret.Code = int(e[0].(int64))
351                 ret.Msg = e[1].(string)
352         case string:
353                 ret.Msg = e
354         default:
355                 logonce.Stderr.Printf(`KRPC error "e" value has unexpected type: %T`, e)
356         }
357         return
358 }
359
360 // Returns the token given in response to a get_peers request for future
361 // announce_peer requests to that node.
362 func (m Msg) AnnounceToken() (token string, ok bool) {
363         defer func() { recover() }()
364         token, ok = m["r"].(map[string]interface{})["token"].(string)
365         return
366 }
367
368 type Transaction struct {
369         mu             sync.Mutex
370         remoteAddr     dHTAddr
371         t              string
372         response       chan Msg
373         onResponse     func(Msg) // Called with the server locked.
374         done           chan struct{}
375         queryPacket    []byte
376         timer          *time.Timer
377         s              *Server
378         retries        int
379         lastSend       time.Time
380         userOnResponse func(Msg)
381 }
382
383 // Set a function to be called with the response.
384 func (t *Transaction) SetResponseHandler(f func(Msg)) {
385         t.mu.Lock()
386         defer t.mu.Unlock()
387         t.userOnResponse = f
388         t.tryHandleResponse()
389 }
390
391 func (t *Transaction) tryHandleResponse() {
392         if t.userOnResponse == nil {
393                 return
394         }
395         select {
396         case r := <-t.response:
397                 t.userOnResponse(r)
398                 // Shouldn't be called more than once.
399                 t.userOnResponse = nil
400         default:
401         }
402 }
403
404 func (t *Transaction) key() transactionKey {
405         return transactionKey{
406                 t.remoteAddr.String(),
407                 t.t,
408         }
409 }
410
411 func jitterDuration(average time.Duration, plusMinus time.Duration) time.Duration {
412         return average - plusMinus/2 + time.Duration(rand.Int63n(int64(plusMinus)))
413 }
414
415 func (t *Transaction) startTimer() {
416         t.timer = time.AfterFunc(jitterDuration(queryResendEvery, time.Second), t.timerCallback)
417 }
418
419 func (t *Transaction) timerCallback() {
420         t.mu.Lock()
421         defer t.mu.Unlock()
422         select {
423         case <-t.done:
424                 return
425         default:
426         }
427         if t.retries == 2 {
428                 t.timeout()
429                 return
430         }
431         t.retries++
432         t.sendQuery()
433         if t.timer.Reset(jitterDuration(queryResendEvery, time.Second)) {
434                 panic("timer should have fired to get here")
435         }
436 }
437
438 func (t *Transaction) sendQuery() error {
439         err := t.s.writeToNode(t.queryPacket, t.remoteAddr)
440         if err != nil {
441                 return err
442         }
443         t.lastSend = time.Now()
444         return nil
445 }
446
447 func (t *Transaction) timeout() {
448         go func() {
449                 t.s.mu.Lock()
450                 defer t.s.mu.Unlock()
451                 t.s.nodeTimedOut(t.remoteAddr)
452         }()
453         t.close()
454 }
455
456 func (t *Transaction) close() {
457         if t.closing() {
458                 return
459         }
460         t.queryPacket = nil
461         close(t.response)
462         t.tryHandleResponse()
463         close(t.done)
464         t.timer.Stop()
465         go func() {
466                 t.s.mu.Lock()
467                 defer t.s.mu.Unlock()
468                 t.s.deleteTransaction(t)
469         }()
470 }
471
472 func (t *Transaction) closing() bool {
473         select {
474         case <-t.done:
475                 return true
476         default:
477                 return false
478         }
479 }
480
481 // Abandon the transaction.
482 func (t *Transaction) Close() {
483         t.mu.Lock()
484         defer t.mu.Unlock()
485         t.close()
486 }
487
488 func (t *Transaction) handleResponse(m Msg) {
489         t.mu.Lock()
490         if t.closing() {
491                 t.mu.Unlock()
492                 return
493         }
494         close(t.done)
495         t.mu.Unlock()
496         if t.onResponse != nil {
497                 t.s.mu.Lock()
498                 t.onResponse(m)
499                 t.s.mu.Unlock()
500         }
501         t.queryPacket = nil
502         select {
503         case t.response <- m:
504         default:
505                 panic("blocked handling response")
506         }
507         close(t.response)
508         t.tryHandleResponse()
509 }
510
511 func maskForIP(ip net.IP) []byte {
512         switch {
513         case ip.To4() != nil:
514                 return []byte{0x03, 0x0f, 0x3f, 0xff}
515         default:
516                 return []byte{0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f, 0xff}
517         }
518 }
519
520 // Generate the CRC used to make or validate secure node ID.
521 func crcIP(ip net.IP, rand uint8) uint32 {
522         if ip4 := ip.To4(); ip4 != nil {
523                 ip = ip4
524         }
525         // Copy IP so we can make changes. Go sux at this.
526         ip = append(make(net.IP, 0, len(ip)), ip...)
527         mask := maskForIP(ip)
528         for i := range mask {
529                 ip[i] &= mask[i]
530         }
531         r := rand & 7
532         ip[0] |= r << 5
533         return crc32.Checksum(ip[:len(mask)], crc32.MakeTable(crc32.Castagnoli))
534 }
535
536 // Makes a node ID secure, in-place. The ID is 20 raw bytes.
537 // http://www.libtorrent.org/dht_sec.html
538 func SecureNodeId(id []byte, ip net.IP) {
539         crc := crcIP(ip, id[19])
540         id[0] = byte(crc >> 24 & 0xff)
541         id[1] = byte(crc >> 16 & 0xff)
542         id[2] = byte(crc>>8&0xf8) | id[2]&7
543 }
544
545 // Returns whether the node ID is considered secure. The id is the 20 raw
546 // bytes. http://www.libtorrent.org/dht_sec.html
547 func NodeIdSecure(id string, ip net.IP) bool {
548         if len(id) != 20 {
549                 panic(fmt.Sprintf("%q", id))
550         }
551         if ip4 := ip.To4(); ip4 != nil {
552                 ip = ip4
553         }
554         crc := crcIP(ip, id[19])
555         if id[0] != byte(crc>>24&0xff) {
556                 return false
557         }
558         if id[1] != byte(crc>>16&0xff) {
559                 return false
560         }
561         if id[2]&0xf8 != byte(crc>>8&0xf8) {
562                 return false
563         }
564         return true
565 }
566
567 func (s *Server) setDefaults() (err error) {
568         if s.id == "" {
569                 var id [20]byte
570                 h := crypto.SHA1.New()
571                 ss, err := os.Hostname()
572                 if err != nil {
573                         log.Print(err)
574                 }
575                 ss += s.socket.LocalAddr().String()
576                 h.Write([]byte(ss))
577                 if b := h.Sum(id[:0:20]); len(b) != 20 {
578                         panic(len(b))
579                 }
580                 if len(id) != 20 {
581                         panic(len(id))
582                 }
583                 publicIP := func() net.IP {
584                         if s.config.PublicIP != nil {
585                                 return s.config.PublicIP
586                         } else {
587                                 return missinggo.AddrIP(s.socket.LocalAddr())
588                         }
589                 }()
590                 SecureNodeId(id[:], publicIP)
591                 s.id = string(id[:])
592         }
593         s.nodes = make(map[string]*node, maxNodes)
594         return
595 }
596
597 // Packets to and from any address matching a range in the list are dropped.
598 func (s *Server) SetIPBlockList(list iplist.Ranger) {
599         s.mu.Lock()
600         defer s.mu.Unlock()
601         s.ipBlockList = list
602 }
603
604 func (s *Server) IPBlocklist() iplist.Ranger {
605         return s.ipBlockList
606 }
607
608 func (s *Server) init() (err error) {
609         err = s.setDefaults()
610         if err != nil {
611                 return
612         }
613         s.closed = make(chan struct{})
614         s.transactions = make(map[transactionKey]*Transaction)
615         return
616 }
617
618 func (s *Server) processPacket(b []byte, addr dHTAddr) {
619         if len(b) < 2 || b[0] != 'd' || b[len(b)-1] != 'e' {
620                 // KRPC messages are bencoded dicts.
621                 readNotKRPCDict.Add(1)
622                 return
623         }
624         var d Msg
625         err := bencode.Unmarshal(b, &d)
626         if err != nil {
627                 readUnmarshalError.Add(1)
628                 func() {
629                         if se, ok := err.(*bencode.SyntaxError); ok {
630                                 // The message was truncated.
631                                 if int(se.Offset) == len(b) {
632                                         return
633                                 }
634                                 // Some messages seem to drop to nul chars abrubtly.
635                                 if int(se.Offset) < len(b) && b[se.Offset] == 0 {
636                                         return
637                                 }
638                                 // The message isn't bencode from the first.
639                                 if se.Offset == 0 {
640                                         return
641                                 }
642                         }
643                         if missinggo.CryHeard() {
644                                 log.Printf("%s: received bad krpc message from %s: %s: %+q", s, addr, err, b)
645                         }
646                 }()
647                 return
648         }
649         s.mu.Lock()
650         defer s.mu.Unlock()
651         if d["y"] == "q" {
652                 readQuery.Add(1)
653                 s.handleQuery(addr, d)
654                 return
655         }
656         t := s.findResponseTransaction(d.T(), addr)
657         if t == nil {
658                 //log.Printf("unexpected message: %#v", d)
659                 return
660         }
661         node := s.getNode(addr, d.SenderID())
662         node.lastGotResponse = time.Now()
663         // TODO: Update node ID as this is an authoritative packet.
664         go t.handleResponse(d)
665         s.deleteTransaction(t)
666 }
667
668 func (s *Server) serve() error {
669         var b [0x10000]byte
670         for {
671                 n, addr, err := s.socket.ReadFrom(b[:])
672                 if err != nil {
673                         return err
674                 }
675                 read.Add(1)
676                 if n == len(b) {
677                         logonce.Stderr.Printf("received dht packet exceeds buffer size")
678                         continue
679                 }
680                 s.mu.Lock()
681                 blocked := s.ipBlocked(missinggo.AddrIP(addr))
682                 s.mu.Unlock()
683                 if blocked {
684                         readBlocked.Add(1)
685                         continue
686                 }
687                 s.processPacket(b[:n], newDHTAddr(addr))
688         }
689 }
690
691 func (s *Server) ipBlocked(ip net.IP) bool {
692         if s.ipBlockList == nil {
693                 return false
694         }
695         return s.ipBlockList.Lookup(ip) != nil
696 }
697
698 // Adds directly to the node table.
699 func (s *Server) AddNode(ni NodeInfo) {
700         s.mu.Lock()
701         defer s.mu.Unlock()
702         if s.nodes == nil {
703                 s.nodes = make(map[string]*node)
704         }
705         s.getNode(ni.Addr, string(ni.ID[:]))
706 }
707
708 func (s *Server) nodeByID(id string) *node {
709         for _, node := range s.nodes {
710                 if node.idString() == id {
711                         return node
712                 }
713         }
714         return nil
715 }
716
717 func (s *Server) handleQuery(source dHTAddr, m Msg) {
718         node := s.getNode(source, m.SenderID())
719         node.lastGotQuery = time.Now()
720         // Don't respond.
721         if s.config.Passive {
722                 return
723         }
724         args := m.Args()
725         if args == nil {
726                 return
727         }
728         switch m["q"] {
729         case "ping":
730                 s.reply(source, m["t"].(string), nil)
731         case "get_peers": // TODO: Extract common behaviour with find_node.
732                 targetID := args["info_hash"].(string)
733                 if len(targetID) != 20 {
734                         break
735                 }
736                 var rNodes []NodeInfo
737                 // TODO: Reply with "values" list if we have peers instead.
738                 for _, node := range s.closestGoodNodes(8, targetID) {
739                         rNodes = append(rNodes, node.NodeInfo())
740                 }
741                 nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
742                 for i, ni := range rNodes {
743                         err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
744                         if err != nil {
745                                 panic(err)
746                         }
747                 }
748                 s.reply(source, m["t"].(string), map[string]interface{}{
749                         "nodes": string(nodesBytes),
750                         "token": "hi",
751                 })
752         case "find_node": // TODO: Extract common behaviour with get_peers.
753                 targetID := args["target"].(string)
754                 if len(targetID) != 20 {
755                         log.Printf("bad DHT query: %v", m)
756                         return
757                 }
758                 var rNodes []NodeInfo
759                 if node := s.nodeByID(targetID); node != nil {
760                         rNodes = append(rNodes, node.NodeInfo())
761                 } else {
762                         for _, node := range s.closestGoodNodes(8, targetID) {
763                                 rNodes = append(rNodes, node.NodeInfo())
764                         }
765                 }
766                 nodesBytes := make([]byte, CompactNodeInfoLen*len(rNodes))
767                 for i, ni := range rNodes {
768                         // TODO: Put IPv6 nodes into the correct dict element.
769                         if ni.Addr.UDPAddr().IP.To4() == nil {
770                                 continue
771                         }
772                         err := ni.PutCompact(nodesBytes[i*CompactNodeInfoLen : (i+1)*CompactNodeInfoLen])
773                         if err != nil {
774                                 log.Printf("error compacting %#v: %s", ni, err)
775                                 continue
776                         }
777                 }
778                 s.reply(source, m["t"].(string), map[string]interface{}{
779                         "nodes": string(nodesBytes),
780                 })
781         case "announce_peer":
782                 // TODO(anacrolix): Implement this lolz.
783                 // log.Print(m)
784         case "vote":
785                 // TODO(anacrolix): Or reject, I don't think I want this.
786         default:
787                 log.Printf("%s: not handling received query: q=%s", s, m["q"])
788                 return
789         }
790 }
791
792 func (s *Server) reply(addr dHTAddr, t string, r map[string]interface{}) {
793         if r == nil {
794                 r = make(map[string]interface{}, 1)
795         }
796         r["id"] = s.ID()
797         m := map[string]interface{}{
798                 "t": t,
799                 "y": "r",
800                 "r": r,
801         }
802         b, err := bencode.Marshal(m)
803         if err != nil {
804                 panic(err)
805         }
806         err = s.writeToNode(b, addr)
807         if err != nil {
808                 log.Printf("error replying to %s: %s", addr, err)
809         }
810 }
811
812 // Returns a node struct for the addr. It is taken from the table or created
813 // and possibly added if required and meets validity constraints.
814 func (s *Server) getNode(addr dHTAddr, id string) (n *node) {
815         addrStr := addr.String()
816         n = s.nodes[addrStr]
817         if n != nil {
818                 if id != "" {
819                         n.SetIDFromString(id)
820                 }
821                 return
822         }
823         n = &node{
824                 addr: addr,
825         }
826         if len(id) == 20 {
827                 n.SetIDFromString(id)
828         }
829         if len(s.nodes) >= maxNodes {
830                 return
831         }
832         if !s.config.NoSecurity && !n.IsSecure() {
833                 return
834         }
835         if s.badNodes.Test([]byte(addrStr)) {
836                 return
837         }
838         s.nodes[addrStr] = n
839         return
840 }
841
842 func (s *Server) nodeTimedOut(addr dHTAddr) {
843         node, ok := s.nodes[addr.String()]
844         if !ok {
845                 return
846         }
847         if node.DefinitelyGood() {
848                 return
849         }
850         if len(s.nodes) < maxNodes {
851                 return
852         }
853         delete(s.nodes, addr.String())
854 }
855
856 func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) {
857         if list := s.ipBlockList; list != nil {
858                 if r := list.Lookup(missinggo.AddrIP(node.UDPAddr())); r != nil {
859                         err = fmt.Errorf("write to %s blocked: %s", node, r.Description)
860                         return
861                 }
862         }
863         n, err := s.socket.WriteTo(b, node.UDPAddr())
864         if err != nil {
865                 err = fmt.Errorf("error writing %d bytes to %s: %#v", len(b), node, err)
866                 return
867         }
868         if n != len(b) {
869                 err = io.ErrShortWrite
870                 return
871         }
872         return
873 }
874
875 func (s *Server) findResponseTransaction(transactionID string, sourceNode dHTAddr) *Transaction {
876         return s.transactions[transactionKey{
877                 sourceNode.String(),
878                 transactionID}]
879 }
880
881 func (s *Server) nextTransactionID() string {
882         var b [binary.MaxVarintLen64]byte
883         n := binary.PutUvarint(b[:], s.transactionIDInt)
884         s.transactionIDInt++
885         return string(b[:n])
886 }
887
888 func (s *Server) deleteTransaction(t *Transaction) {
889         delete(s.transactions, t.key())
890 }
891
892 func (s *Server) addTransaction(t *Transaction) {
893         if _, ok := s.transactions[t.key()]; ok {
894                 panic("transaction not unique")
895         }
896         s.transactions[t.key()] = t
897 }
898
899 // Returns the 20-byte server ID. This is the ID used to communicate with the
900 // DHT network.
901 func (s *Server) ID() string {
902         if len(s.id) != 20 {
903                 panic("bad node id")
904         }
905         return s.id
906 }
907
908 func (s *Server) query(node dHTAddr, q string, a map[string]interface{}, onResponse func(Msg)) (t *Transaction, err error) {
909         tid := s.nextTransactionID()
910         if a == nil {
911                 a = make(map[string]interface{}, 1)
912         }
913         a["id"] = s.ID()
914         d := map[string]interface{}{
915                 "t": tid,
916                 "y": "q",
917                 "q": q,
918                 "a": a,
919         }
920         // BEP 43. Outgoing queries from uncontactiable nodes should contain
921         // "ro":1 in the top level dictionary.
922         if s.config.Passive {
923                 d["ro"] = 1
924         }
925         b, err := bencode.Marshal(d)
926         if err != nil {
927                 return
928         }
929         t = &Transaction{
930                 remoteAddr:  node,
931                 t:           tid,
932                 response:    make(chan Msg, 1),
933                 done:        make(chan struct{}),
934                 queryPacket: b,
935                 s:           s,
936                 onResponse:  onResponse,
937         }
938         err = t.sendQuery()
939         if err != nil {
940                 return
941         }
942         s.getNode(node, "").lastSentQuery = time.Now()
943         t.startTimer()
944         s.addTransaction(t)
945         return
946 }
947
948 // The size in bytes of a NodeInfo in its compact binary representation.
949 const CompactNodeInfoLen = 26
950
951 type NodeInfo struct {
952         ID   [20]byte
953         Addr dHTAddr
954 }
955
956 // Writes the node info to its compact binary representation in b. See
957 // CompactNodeInfoLen.
958 func (ni *NodeInfo) PutCompact(b []byte) error {
959         if n := copy(b[:], ni.ID[:]); n != 20 {
960                 panic(n)
961         }
962         ip := missinggo.AddrIP(ni.Addr).To4()
963         if len(ip) != 4 {
964                 return errors.New("expected ipv4 address")
965         }
966         if n := copy(b[20:], ip); n != 4 {
967                 panic(n)
968         }
969         binary.BigEndian.PutUint16(b[24:], uint16(missinggo.AddrPort(ni.Addr)))
970         return nil
971 }
972
973 func (cni *NodeInfo) UnmarshalCompact(b []byte) error {
974         if len(b) != 26 {
975                 return errors.New("expected 26 bytes")
976         }
977         missinggo.CopyExact(cni.ID[:], b[:20])
978         cni.Addr = newDHTAddr(&net.UDPAddr{
979                 IP:   net.IPv4(b[20], b[21], b[22], b[23]),
980                 Port: int(binary.BigEndian.Uint16(b[24:26])),
981         })
982         return nil
983 }
984
985 // Sends a ping query to the address given.
986 func (s *Server) Ping(node *net.UDPAddr) (*Transaction, error) {
987         s.mu.Lock()
988         defer s.mu.Unlock()
989         return s.query(newDHTAddr(node), "ping", nil, nil)
990 }
991
992 func (s *Server) announcePeer(node dHTAddr, infoHash string, port int, token string, impliedPort bool) (err error) {
993         if port == 0 && !impliedPort {
994                 return errors.New("nothing to announce")
995         }
996         _, err = s.query(node, "announce_peer", map[string]interface{}{
997                 "implied_port": func() int {
998                         if impliedPort {
999                                 return 1
1000                         } else {
1001                                 return 0
1002                         }
1003                 }(),
1004                 "info_hash": infoHash,
1005                 "port":      port,
1006                 "token":     token,
1007         }, func(m Msg) {
1008                 if err := m.Error(); err != nil {
1009                         announceErrors.Add(1)
1010                         // log.Print(token)
1011                         // logonce.Stderr.Printf("announce_peer response: %s", err)
1012                         return
1013                 }
1014                 s.numConfirmedAnnounces++
1015         })
1016         return
1017 }
1018
1019 // Add response nodes to node table.
1020 func (s *Server) liftNodes(d Msg) {
1021         if d["y"] != "r" {
1022                 return
1023         }
1024         for _, cni := range d.Nodes() {
1025                 if missinggo.AddrPort(cni.Addr) == 0 {
1026                         // TODO: Why would people even do this?
1027                         continue
1028                 }
1029                 if s.ipBlocked(missinggo.AddrIP(cni.Addr)) {
1030                         continue
1031                 }
1032                 n := s.getNode(cni.Addr, string(cni.ID[:]))
1033                 n.SetIDFromBytes(cni.ID[:])
1034         }
1035 }
1036
1037 // Sends a find_node query to addr. targetID is the node we're looking for.
1038 func (s *Server) findNode(addr dHTAddr, targetID string) (t *Transaction, err error) {
1039         t, err = s.query(addr, "find_node", map[string]interface{}{"target": targetID}, func(d Msg) {
1040                 // Scrape peers from the response to put in the server's table before
1041                 // handing the response back to the caller.
1042                 s.liftNodes(d)
1043         })
1044         if err != nil {
1045                 return
1046         }
1047         return
1048 }
1049
1050 type Peer struct {
1051         IP   net.IP
1052         Port int
1053 }
1054
1055 func (me *Peer) String() string {
1056         return net.JoinHostPort(me.IP.String(), strconv.FormatInt(int64(me.Port), 10))
1057 }
1058
1059 // In a get_peers response, the addresses of torrent clients involved with the
1060 // queried info-hash.
1061 func (m Msg) Values() (vs []Peer) {
1062         v := func() interface{} {
1063                 defer func() {
1064                         recover()
1065                 }()
1066                 return m["r"].(map[string]interface{})["values"]
1067         }()
1068         if v == nil {
1069                 return
1070         }
1071         vl, ok := v.([]interface{})
1072         if !ok {
1073                 if missinggo.CryHeard() {
1074                         log.Printf(`unexpected krpc "values" field: %#v`, v)
1075                 }
1076                 return
1077         }
1078         vs = make([]Peer, 0, len(vl))
1079         for _, i := range vl {
1080                 s, ok := i.(string)
1081                 if !ok {
1082                         panic(i)
1083                 }
1084                 // Because it's a list of strings, we can let the length of the string
1085                 // determine the IP version of the compact peer.
1086                 var cp util.CompactPeer
1087                 err := cp.UnmarshalBinary([]byte(s))
1088                 if err != nil {
1089                         log.Printf("error decoding values list element: %s", err)
1090                         continue
1091                 }
1092                 vs = append(vs, Peer{cp.IP[:], int(cp.Port)})
1093         }
1094         return
1095 }
1096
1097 func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *Transaction, err error) {
1098         if len(infoHash) != 20 {
1099                 err = fmt.Errorf("infohash has bad length")
1100                 return
1101         }
1102         t, err = s.query(addr, "get_peers", map[string]interface{}{"info_hash": infoHash}, func(m Msg) {
1103                 s.liftNodes(m)
1104                 at, ok := m.AnnounceToken()
1105                 if ok {
1106                         s.getNode(addr, m.SenderID()).announceToken = at
1107                 }
1108         })
1109         return
1110 }
1111
1112 func bootstrapAddrs(nodeAddrs []string) (addrs []*net.UDPAddr, err error) {
1113         bootstrapNodes := nodeAddrs
1114         if len(bootstrapNodes) == 0 {
1115                 bootstrapNodes = []string{
1116                         "router.utorrent.com:6881",
1117                         "router.bittorrent.com:6881",
1118                 }
1119         }
1120         for _, addrStr := range bootstrapNodes {
1121                 udpAddr, err := net.ResolveUDPAddr("udp4", addrStr)
1122                 if err != nil {
1123                         continue
1124                 }
1125                 addrs = append(addrs, udpAddr)
1126         }
1127         if len(addrs) == 0 {
1128                 err = errors.New("nothing resolved")
1129         }
1130         return
1131 }
1132
1133 // Adds bootstrap nodes directly to table, if there's room. Node ID security
1134 // is bypassed, but the IP blocklist is not.
1135 func (s *Server) addRootNodes() error {
1136         addrs, err := bootstrapAddrs(s.bootstrapNodes)
1137         if err != nil {
1138                 return err
1139         }
1140         for _, addr := range addrs {
1141                 if len(s.nodes) >= maxNodes {
1142                         break
1143                 }
1144                 if s.nodes[addr.String()] != nil {
1145                         continue
1146                 }
1147                 if s.ipBlocked(addr.IP) {
1148                         log.Printf("dht root node is in the blocklist: %s", addr.IP)
1149                         continue
1150                 }
1151                 s.nodes[addr.String()] = &node{
1152                         addr: newDHTAddr(addr),
1153                 }
1154         }
1155         return nil
1156 }
1157
1158 // Populates the node table.
1159 func (s *Server) bootstrap() (err error) {
1160         s.mu.Lock()
1161         defer s.mu.Unlock()
1162         if len(s.nodes) == 0 {
1163                 err = s.addRootNodes()
1164         }
1165         if err != nil {
1166                 return
1167         }
1168         for {
1169                 var outstanding sync.WaitGroup
1170                 for _, node := range s.nodes {
1171                         var t *Transaction
1172                         t, err = s.findNode(node.addr, s.id)
1173                         if err != nil {
1174                                 err = fmt.Errorf("error sending find_node: %s", err)
1175                                 return
1176                         }
1177                         outstanding.Add(1)
1178                         t.SetResponseHandler(func(Msg) {
1179                                 outstanding.Done()
1180                         })
1181                 }
1182                 noOutstanding := make(chan struct{})
1183                 go func() {
1184                         outstanding.Wait()
1185                         close(noOutstanding)
1186                 }()
1187                 s.mu.Unlock()
1188                 select {
1189                 case <-s.closed:
1190                         s.mu.Lock()
1191                         return
1192                 case <-time.After(15 * time.Second):
1193                 case <-noOutstanding:
1194                 }
1195                 s.mu.Lock()
1196                 // log.Printf("now have %d nodes", len(s.nodes))
1197                 if s.numGoodNodes() >= 160 {
1198                         break
1199                 }
1200         }
1201         return
1202 }
1203
1204 func (s *Server) numGoodNodes() (num int) {
1205         for _, n := range s.nodes {
1206                 if n.DefinitelyGood() {
1207                         num++
1208                 }
1209         }
1210         return
1211 }
1212
1213 // Returns how many nodes are in the node table.
1214 func (s *Server) NumNodes() int {
1215         s.mu.Lock()
1216         defer s.mu.Unlock()
1217         return len(s.nodes)
1218 }
1219
1220 // Exports the current node table.
1221 func (s *Server) Nodes() (nis []NodeInfo) {
1222         s.mu.Lock()
1223         defer s.mu.Unlock()
1224         for _, node := range s.nodes {
1225                 // if !node.Good() {
1226                 //      continue
1227                 // }
1228                 ni := NodeInfo{
1229                         Addr: node.addr,
1230                 }
1231                 if n := copy(ni.ID[:], node.idString()); n != 20 && n != 0 {
1232                         panic(n)
1233                 }
1234                 nis = append(nis, ni)
1235         }
1236         return
1237 }
1238
1239 // Stops the server network activity. This is all that's required to clean-up a Server.
1240 func (s *Server) Close() {
1241         s.mu.Lock()
1242         select {
1243         case <-s.closed:
1244         default:
1245                 close(s.closed)
1246                 s.socket.Close()
1247         }
1248         s.mu.Unlock()
1249 }
1250
1251 var maxDistance big.Int
1252
1253 func init() {
1254         var zero big.Int
1255         maxDistance.SetBit(&zero, 160, 1)
1256 }
1257
1258 func (s *Server) closestGoodNodes(k int, targetID string) []*node {
1259         return s.closestNodes(k, nodeIDFromString(targetID), func(n *node) bool { return n.DefinitelyGood() })
1260 }
1261
1262 func (s *Server) closestNodes(k int, target nodeID, filter func(*node) bool) []*node {
1263         sel := newKClosestNodesSelector(k, target)
1264         idNodes := make(map[string]*node, len(s.nodes))
1265         for _, node := range s.nodes {
1266                 if !filter(node) {
1267                         continue
1268                 }
1269                 sel.Push(node.id)
1270                 idNodes[node.idString()] = node
1271         }
1272         ids := sel.IDs()
1273         ret := make([]*node, 0, len(ids))
1274         for _, id := range ids {
1275                 ret = append(ret, idNodes[id.ByteString()])
1276         }
1277         return ret
1278 }
1279
1280 func (me *Server) badNode(addr dHTAddr) {
1281         me.badNodes.Add([]byte(addr.String()))
1282         delete(me.nodes, addr.String())
1283 }