]> Sergey Matveev's repositories - btrtrc.git/blob - dht/dht.go
dht: Make Msg a struct with bencode tags
[btrtrc.git] / dht / dht.go
1 // Package DHT implements a DHT for use with the BitTorrent protocol,
2 // described in BEP 5: http://www.bittorrent.org/beps/bep_0005.html.
3 //
4 // Standard use involves creating a NewServer, and calling Announce on it with
5 // the details of your local torrent client and infohash of interest.
6 package dht
7
8 import (
9         "crypto"
10         _ "crypto/sha1"
11         "encoding/binary"
12         "errors"
13         "fmt"
14         "hash/crc32"
15         "io"
16         "log"
17         "math/big"
18         "math/rand"
19         "net"
20         "os"
21         "strconv"
22         "time"
23
24         "github.com/anacrolix/missinggo"
25         "github.com/anacrolix/sync"
26         "github.com/tylertreat/BoomFilters"
27
28         "github.com/anacrolix/torrent/bencode"
29         "github.com/anacrolix/torrent/iplist"
30         "github.com/anacrolix/torrent/logonce"
31 )
32
33 const (
34         maxNodes = 320
35 )
36
37 var (
38         queryResendEvery = 5 * time.Second
39 )
40
41 // Uniquely identifies a transaction to us.
42 type transactionKey struct {
43         RemoteAddr string // host:port
44         T          string // The KRPC transaction ID.
45 }
46
47 type Server struct {
48         id               string
49         socket           net.PacketConn
50         transactions     map[transactionKey]*Transaction
51         transactionIDInt uint64
52         nodes            map[string]*node // Keyed by dHTAddr.String().
53         mu               sync.Mutex
54         closed           chan struct{}
55         ipBlockList      iplist.Ranger
56         badNodes         *boom.BloomFilter
57
58         numConfirmedAnnounces int
59         bootstrapNodes        []string
60         config                ServerConfig
61 }
62
63 type ServerConfig struct {
64         Addr string // Listen address. Used if Conn is nil.
65         Conn net.PacketConn
66         // Don't respond to queries from other nodes.
67         Passive bool
68         // DHT Bootstrap nodes
69         BootstrapNodes []string
70         // Disable the DHT security extension:
71         // http://www.libtorrent.org/dht_sec.html.
72         NoSecurity bool
73         // Initial IP blocklist to use. Applied before serving and bootstrapping
74         // begins.
75         IPBlocklist iplist.Ranger
76         // Used to secure the server's ID. Defaults to the Conn's LocalAddr().
77         PublicIP net.IP
78 }
79
80 type ServerStats struct {
81         // Count of nodes in the node table that responded to our last query or
82         // haven't yet been queried.
83         GoodNodes int
84         // Count of nodes in the node table.
85         Nodes int
86         // Transactions awaiting a response.
87         OutstandingTransactions int
88         // Individual announce_peer requests that got a success response.
89         ConfirmedAnnounces int
90         // Nodes that have been blocked.
91         BadNodes uint
92 }
93
94 // Returns statistics for the server.
95 func (s *Server) Stats() (ss ServerStats) {
96         s.mu.Lock()
97         defer s.mu.Unlock()
98         for _, n := range s.nodes {
99                 if n.DefinitelyGood() {
100                         ss.GoodNodes++
101                 }
102         }
103         ss.Nodes = len(s.nodes)
104         ss.OutstandingTransactions = len(s.transactions)
105         ss.ConfirmedAnnounces = s.numConfirmedAnnounces
106         ss.BadNodes = s.badNodes.Count()
107         return
108 }
109
110 // Returns the listen address for the server. Packets arriving to this address
111 // are processed by the server (unless aliens are involved).
112 func (s *Server) Addr() net.Addr {
113         return s.socket.LocalAddr()
114 }
115
116 func makeSocket(addr string) (socket *net.UDPConn, err error) {
117         addr_, err := net.ResolveUDPAddr("", addr)
118         if err != nil {
119                 return
120         }
121         socket, err = net.ListenUDP("udp", addr_)
122         return
123 }
124
125 // Create a new DHT server.
126 func NewServer(c *ServerConfig) (s *Server, err error) {
127         if c == nil {
128                 c = &ServerConfig{}
129         }
130         s = &Server{
131                 config:      *c,
132                 ipBlockList: c.IPBlocklist,
133                 badNodes:    boom.NewBloomFilter(1000, 0.1),
134         }
135         if c.Conn != nil {
136                 s.socket = c.Conn
137         } else {
138                 s.socket, err = makeSocket(c.Addr)
139                 if err != nil {
140                         return
141                 }
142         }
143         s.bootstrapNodes = c.BootstrapNodes
144         err = s.init()
145         if err != nil {
146                 return
147         }
148         go func() {
149                 err := s.serve()
150                 select {
151                 case <-s.closed:
152                         return
153                 default:
154                 }
155                 if err != nil {
156                         panic(err)
157                 }
158         }()
159         go func() {
160                 err := s.bootstrap()
161                 if err != nil {
162                         select {
163                         case <-s.closed:
164                         default:
165                                 log.Printf("error bootstrapping DHT: %s", err)
166                         }
167                 }
168         }()
169         return
170 }
171
172 // Returns a description of the Server. Python repr-style.
173 func (s *Server) String() string {
174         return fmt.Sprintf("dht server on %s", s.socket.LocalAddr())
175 }
176
177 type nodeID struct {
178         i   big.Int
179         set bool
180 }
181
182 func (nid *nodeID) IsUnset() bool {
183         return !nid.set
184 }
185
186 func nodeIDFromString(s string) (ret nodeID) {
187         if s == "" {
188                 return
189         }
190         ret.i.SetBytes([]byte(s))
191         ret.set = true
192         return
193 }
194
195 func (nid0 *nodeID) Distance(nid1 *nodeID) (ret big.Int) {
196         if nid0.IsUnset() != nid1.IsUnset() {
197                 ret = maxDistance
198                 return
199         }
200         ret.Xor(&nid0.i, &nid1.i)
201         return
202 }
203
204 func (nid *nodeID) ByteString() string {
205         var buf [20]byte
206         b := nid.i.Bytes()
207         copy(buf[20-len(b):], b)
208         return string(buf[:])
209 }
210
211 type node struct {
212         addr          dHTAddr
213         id            nodeID
214         announceToken string
215
216         lastGotQuery    time.Time
217         lastGotResponse time.Time
218         lastSentQuery   time.Time
219 }
220
221 func (n *node) IsSecure() bool {
222         if n.id.IsUnset() {
223                 return false
224         }
225         return NodeIdSecure(n.id.ByteString(), n.addr.IP())
226 }
227
228 func (n *node) idString() string {
229         return n.id.ByteString()
230 }
231
232 func (n *node) SetIDFromBytes(b []byte) {
233         if len(b) != 20 {
234                 panic(b)
235         }
236         n.id.i.SetBytes(b)
237         n.id.set = true
238 }
239
240 func (n *node) SetIDFromString(s string) {
241         n.SetIDFromBytes([]byte(s))
242 }
243
244 func (n *node) IDNotSet() bool {
245         return n.id.i.Int64() == 0
246 }
247
248 func (n *node) NodeInfo() (ret NodeInfo) {
249         ret.Addr = n.addr
250         if n := copy(ret.ID[:], n.idString()); n != 20 {
251                 panic(n)
252         }
253         return
254 }
255
256 func (n *node) DefinitelyGood() bool {
257         if len(n.idString()) != 20 {
258                 return false
259         }
260         // No reason to think ill of them if they've never been queried.
261         if n.lastSentQuery.IsZero() {
262                 return true
263         }
264         // They answered our last query.
265         if n.lastSentQuery.Before(n.lastGotResponse) {
266                 return true
267         }
268         return true
269 }
270
271 type Transaction struct {
272         mu             sync.Mutex
273         remoteAddr     dHTAddr
274         t              string
275         response       chan Msg
276         onResponse     func(Msg) // Called with the server locked.
277         done           chan struct{}
278         queryPacket    []byte
279         timer          *time.Timer
280         s              *Server
281         retries        int
282         lastSend       time.Time
283         userOnResponse func(Msg)
284 }
285
286 // Set a function to be called with the response.
287 func (t *Transaction) SetResponseHandler(f func(Msg)) {
288         t.mu.Lock()
289         defer t.mu.Unlock()
290         t.userOnResponse = f
291         t.tryHandleResponse()
292 }
293
294 func (t *Transaction) tryHandleResponse() {
295         if t.userOnResponse == nil {
296                 return
297         }
298         select {
299         case r := <-t.response:
300                 t.userOnResponse(r)
301                 // Shouldn't be called more than once.
302                 t.userOnResponse = nil
303         default:
304         }
305 }
306
307 func (t *Transaction) key() transactionKey {
308         return transactionKey{
309                 t.remoteAddr.String(),
310                 t.t,
311         }
312 }
313
314 func jitterDuration(average time.Duration, plusMinus time.Duration) time.Duration {
315         return average - plusMinus/2 + time.Duration(rand.Int63n(int64(plusMinus)))
316 }
317
318 func (t *Transaction) startTimer() {
319         t.timer = time.AfterFunc(jitterDuration(queryResendEvery, time.Second), t.timerCallback)
320 }
321
322 func (t *Transaction) timerCallback() {
323         t.mu.Lock()
324         defer t.mu.Unlock()
325         select {
326         case <-t.done:
327                 return
328         default:
329         }
330         if t.retries == 2 {
331                 t.timeout()
332                 return
333         }
334         t.retries++
335         t.sendQuery()
336         if t.timer.Reset(jitterDuration(queryResendEvery, time.Second)) {
337                 panic("timer should have fired to get here")
338         }
339 }
340
341 func (t *Transaction) sendQuery() error {
342         err := t.s.writeToNode(t.queryPacket, t.remoteAddr)
343         if err != nil {
344                 return err
345         }
346         t.lastSend = time.Now()
347         return nil
348 }
349
350 func (t *Transaction) timeout() {
351         go func() {
352                 t.s.mu.Lock()
353                 defer t.s.mu.Unlock()
354                 t.s.nodeTimedOut(t.remoteAddr)
355         }()
356         t.close()
357 }
358
359 func (t *Transaction) close() {
360         if t.closing() {
361                 return
362         }
363         t.queryPacket = nil
364         close(t.response)
365         t.tryHandleResponse()
366         close(t.done)
367         t.timer.Stop()
368         go func() {
369                 t.s.mu.Lock()
370                 defer t.s.mu.Unlock()
371                 t.s.deleteTransaction(t)
372         }()
373 }
374
375 func (t *Transaction) closing() bool {
376         select {
377         case <-t.done:
378                 return true
379         default:
380                 return false
381         }
382 }
383
384 // Abandon the transaction.
385 func (t *Transaction) Close() {
386         t.mu.Lock()
387         defer t.mu.Unlock()
388         t.close()
389 }
390
391 func (t *Transaction) handleResponse(m Msg) {
392         t.mu.Lock()
393         if t.closing() {
394                 t.mu.Unlock()
395                 return
396         }
397         close(t.done)
398         t.mu.Unlock()
399         if t.onResponse != nil {
400                 t.s.mu.Lock()
401                 t.onResponse(m)
402                 t.s.mu.Unlock()
403         }
404         t.queryPacket = nil
405         select {
406         case t.response <- m:
407         default:
408                 panic("blocked handling response")
409         }
410         close(t.response)
411         t.tryHandleResponse()
412 }
413
414 func maskForIP(ip net.IP) []byte {
415         switch {
416         case ip.To4() != nil:
417                 return []byte{0x03, 0x0f, 0x3f, 0xff}
418         default:
419                 return []byte{0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f, 0xff}
420         }
421 }
422
423 // Generate the CRC used to make or validate secure node ID.
424 func crcIP(ip net.IP, rand uint8) uint32 {
425         if ip4 := ip.To4(); ip4 != nil {
426                 ip = ip4
427         }
428         // Copy IP so we can make changes. Go sux at this.
429         ip = append(make(net.IP, 0, len(ip)), ip...)
430         mask := maskForIP(ip)
431         for i := range mask {
432                 ip[i] &= mask[i]
433         }
434         r := rand & 7
435         ip[0] |= r << 5
436         return crc32.Checksum(ip[:len(mask)], crc32.MakeTable(crc32.Castagnoli))
437 }
438
439 // Makes a node ID secure, in-place. The ID is 20 raw bytes.
440 // http://www.libtorrent.org/dht_sec.html
441 func SecureNodeId(id []byte, ip net.IP) {
442         crc := crcIP(ip, id[19])
443         id[0] = byte(crc >> 24 & 0xff)
444         id[1] = byte(crc >> 16 & 0xff)
445         id[2] = byte(crc>>8&0xf8) | id[2]&7
446 }
447
448 // Returns whether the node ID is considered secure. The id is the 20 raw
449 // bytes. http://www.libtorrent.org/dht_sec.html
450 func NodeIdSecure(id string, ip net.IP) bool {
451         if len(id) != 20 {
452                 panic(fmt.Sprintf("%q", id))
453         }
454         if ip4 := ip.To4(); ip4 != nil {
455                 ip = ip4
456         }
457         crc := crcIP(ip, id[19])
458         if id[0] != byte(crc>>24&0xff) {
459                 return false
460         }
461         if id[1] != byte(crc>>16&0xff) {
462                 return false
463         }
464         if id[2]&0xf8 != byte(crc>>8&0xf8) {
465                 return false
466         }
467         return true
468 }
469
470 func (s *Server) setDefaults() (err error) {
471         if s.id == "" {
472                 var id [20]byte
473                 h := crypto.SHA1.New()
474                 ss, err := os.Hostname()
475                 if err != nil {
476                         log.Print(err)
477                 }
478                 ss += s.socket.LocalAddr().String()
479                 h.Write([]byte(ss))
480                 if b := h.Sum(id[:0:20]); len(b) != 20 {
481                         panic(len(b))
482                 }
483                 if len(id) != 20 {
484                         panic(len(id))
485                 }
486                 publicIP := func() net.IP {
487                         if s.config.PublicIP != nil {
488                                 return s.config.PublicIP
489                         } else {
490                                 return missinggo.AddrIP(s.socket.LocalAddr())
491                         }
492                 }()
493                 SecureNodeId(id[:], publicIP)
494                 s.id = string(id[:])
495         }
496         s.nodes = make(map[string]*node, maxNodes)
497         return
498 }
499
500 // Packets to and from any address matching a range in the list are dropped.
501 func (s *Server) SetIPBlockList(list iplist.Ranger) {
502         s.mu.Lock()
503         defer s.mu.Unlock()
504         s.ipBlockList = list
505 }
506
507 func (s *Server) IPBlocklist() iplist.Ranger {
508         return s.ipBlockList
509 }
510
511 func (s *Server) init() (err error) {
512         err = s.setDefaults()
513         if err != nil {
514                 return
515         }
516         s.closed = make(chan struct{})
517         s.transactions = make(map[transactionKey]*Transaction)
518         return
519 }
520
521 func (s *Server) processPacket(b []byte, addr dHTAddr) {
522         if len(b) < 2 || b[0] != 'd' || b[len(b)-1] != 'e' {
523                 // KRPC messages are bencoded dicts.
524                 readNotKRPCDict.Add(1)
525                 return
526         }
527         var d Msg
528         err := bencode.Unmarshal(b, &d)
529         if err != nil {
530                 readUnmarshalError.Add(1)
531                 func() {
532                         if se, ok := err.(*bencode.SyntaxError); ok {
533                                 // The message was truncated.
534                                 if int(se.Offset) == len(b) {
535                                         return
536                                 }
537                                 // Some messages seem to drop to nul chars abrubtly.
538                                 if int(se.Offset) < len(b) && b[se.Offset] == 0 {
539                                         return
540                                 }
541                                 // The message isn't bencode from the first.
542                                 if se.Offset == 0 {
543                                         return
544                                 }
545                         }
546                         if missinggo.CryHeard() {
547                                 log.Printf("%s: received bad krpc message from %s: %s: %+q", s, addr, err, b)
548                         }
549                 }()
550                 return
551         }
552         s.mu.Lock()
553         defer s.mu.Unlock()
554         if d.Y == "q" {
555                 readQuery.Add(1)
556                 s.handleQuery(addr, d)
557                 return
558         }
559         t := s.findResponseTransaction(d.T, addr)
560         if t == nil {
561                 //log.Printf("unexpected message: %#v", d)
562                 return
563         }
564         node := s.getNode(addr, d.SenderID())
565         node.lastGotResponse = time.Now()
566         // TODO: Update node ID as this is an authoritative packet.
567         go t.handleResponse(d)
568         s.deleteTransaction(t)
569 }
570
571 func (s *Server) serve() error {
572         var b [0x10000]byte
573         for {
574                 n, addr, err := s.socket.ReadFrom(b[:])
575                 if err != nil {
576                         return err
577                 }
578                 read.Add(1)
579                 if n == len(b) {
580                         logonce.Stderr.Printf("received dht packet exceeds buffer size")
581                         continue
582                 }
583                 s.mu.Lock()
584                 blocked := s.ipBlocked(missinggo.AddrIP(addr))
585                 s.mu.Unlock()
586                 if blocked {
587                         readBlocked.Add(1)
588                         continue
589                 }
590                 s.processPacket(b[:n], newDHTAddr(addr))
591         }
592 }
593
594 func (s *Server) ipBlocked(ip net.IP) (blocked bool) {
595         if s.ipBlockList == nil {
596                 return
597         }
598         _, blocked = s.ipBlockList.Lookup(ip)
599         return
600 }
601
602 // Adds directly to the node table.
603 func (s *Server) AddNode(ni NodeInfo) {
604         s.mu.Lock()
605         defer s.mu.Unlock()
606         if s.nodes == nil {
607                 s.nodes = make(map[string]*node)
608         }
609         s.getNode(ni.Addr, string(ni.ID[:]))
610 }
611
612 func (s *Server) nodeByID(id string) *node {
613         for _, node := range s.nodes {
614                 if node.idString() == id {
615                         return node
616                 }
617         }
618         return nil
619 }
620
621 func (s *Server) handleQuery(source dHTAddr, m Msg) {
622         node := s.getNode(source, m.SenderID())
623         node.lastGotQuery = time.Now()
624         // Don't respond.
625         if s.config.Passive {
626                 return
627         }
628         args := m.A
629         switch m.Q {
630         case "ping":
631                 s.reply(source, m.T, Return{})
632         case "get_peers": // TODO: Extract common behaviour with find_node.
633                 targetID := args.InfoHash
634                 if len(targetID) != 20 {
635                         break
636                 }
637                 var rNodes []NodeInfo
638                 // TODO: Reply with "values" list if we have peers instead.
639                 for _, node := range s.closestGoodNodes(8, targetID) {
640                         rNodes = append(rNodes, node.NodeInfo())
641                 }
642                 s.reply(source, m.T, Return{
643                         Nodes: rNodes,
644                         // TODO: Generate this dynamically, and store it for the source.
645                         Token: "hi",
646                 })
647         case "find_node": // TODO: Extract common behaviour with get_peers.
648                 targetID := args.Target
649                 if len(targetID) != 20 {
650                         log.Printf("bad DHT query: %v", m)
651                         return
652                 }
653                 var rNodes []NodeInfo
654                 if node := s.nodeByID(targetID); node != nil {
655                         rNodes = append(rNodes, node.NodeInfo())
656                 } else {
657                         // This will probably cause a crash for IPv6, but meh.
658                         for _, node := range s.closestGoodNodes(8, targetID) {
659                                 rNodes = append(rNodes, node.NodeInfo())
660                         }
661                 }
662                 s.reply(source, m.T, Return{
663                         Nodes: rNodes,
664                 })
665         case "announce_peer":
666                 // TODO(anacrolix): Implement this lolz.
667                 // log.Print(m)
668         case "vote":
669                 // TODO(anacrolix): Or reject, I don't think I want this.
670         default:
671                 log.Printf("%s: not handling received query: q=%s", s, m.Q)
672                 return
673         }
674 }
675
676 func (s *Server) reply(addr dHTAddr, t string, r Return) {
677         r.ID = s.ID()
678         m := Msg{
679                 T: t,
680                 Y: "r",
681                 R: &r,
682         }
683         b, err := bencode.Marshal(m)
684         if err != nil {
685                 panic(err)
686         }
687         err = s.writeToNode(b, addr)
688         if err != nil {
689                 log.Printf("error replying to %s: %s", addr, err)
690         }
691 }
692
693 // Returns a node struct for the addr. It is taken from the table or created
694 // and possibly added if required and meets validity constraints.
695 func (s *Server) getNode(addr dHTAddr, id string) (n *node) {
696         addrStr := addr.String()
697         n = s.nodes[addrStr]
698         if n != nil {
699                 if id != "" {
700                         n.SetIDFromString(id)
701                 }
702                 return
703         }
704         n = &node{
705                 addr: addr,
706         }
707         if len(id) == 20 {
708                 n.SetIDFromString(id)
709         }
710         if len(s.nodes) >= maxNodes {
711                 return
712         }
713         if !s.config.NoSecurity && !n.IsSecure() {
714                 return
715         }
716         if s.badNodes.Test([]byte(addrStr)) {
717                 return
718         }
719         s.nodes[addrStr] = n
720         return
721 }
722
723 func (s *Server) nodeTimedOut(addr dHTAddr) {
724         node, ok := s.nodes[addr.String()]
725         if !ok {
726                 return
727         }
728         if node.DefinitelyGood() {
729                 return
730         }
731         if len(s.nodes) < maxNodes {
732                 return
733         }
734         delete(s.nodes, addr.String())
735 }
736
737 func (s *Server) writeToNode(b []byte, node dHTAddr) (err error) {
738         if list := s.ipBlockList; list != nil {
739                 if r, ok := list.Lookup(missinggo.AddrIP(node.UDPAddr())); ok {
740                         err = fmt.Errorf("write to %s blocked: %s", node, r.Description)
741                         return
742                 }
743         }
744         n, err := s.socket.WriteTo(b, node.UDPAddr())
745         if err != nil {
746                 err = fmt.Errorf("error writing %d bytes to %s: %#v", len(b), node, err)
747                 return
748         }
749         if n != len(b) {
750                 err = io.ErrShortWrite
751                 return
752         }
753         return
754 }
755
756 func (s *Server) findResponseTransaction(transactionID string, sourceNode dHTAddr) *Transaction {
757         return s.transactions[transactionKey{
758                 sourceNode.String(),
759                 transactionID}]
760 }
761
762 func (s *Server) nextTransactionID() string {
763         var b [binary.MaxVarintLen64]byte
764         n := binary.PutUvarint(b[:], s.transactionIDInt)
765         s.transactionIDInt++
766         return string(b[:n])
767 }
768
769 func (s *Server) deleteTransaction(t *Transaction) {
770         delete(s.transactions, t.key())
771 }
772
773 func (s *Server) addTransaction(t *Transaction) {
774         if _, ok := s.transactions[t.key()]; ok {
775                 panic("transaction not unique")
776         }
777         s.transactions[t.key()] = t
778 }
779
780 // Returns the 20-byte server ID. This is the ID used to communicate with the
781 // DHT network.
782 func (s *Server) ID() string {
783         if len(s.id) != 20 {
784                 panic("bad node id")
785         }
786         return s.id
787 }
788
789 func (s *Server) query(node dHTAddr, q string, a map[string]interface{}, onResponse func(Msg)) (t *Transaction, err error) {
790         tid := s.nextTransactionID()
791         if a == nil {
792                 a = make(map[string]interface{}, 1)
793         }
794         a["id"] = s.ID()
795         d := map[string]interface{}{
796                 "t": tid,
797                 "y": "q",
798                 "q": q,
799                 "a": a,
800         }
801         // BEP 43. Outgoing queries from uncontactiable nodes should contain
802         // "ro":1 in the top level dictionary.
803         if s.config.Passive {
804                 d["ro"] = 1
805         }
806         b, err := bencode.Marshal(d)
807         if err != nil {
808                 return
809         }
810         t = &Transaction{
811                 remoteAddr:  node,
812                 t:           tid,
813                 response:    make(chan Msg, 1),
814                 done:        make(chan struct{}),
815                 queryPacket: b,
816                 s:           s,
817                 onResponse:  onResponse,
818         }
819         err = t.sendQuery()
820         if err != nil {
821                 return
822         }
823         s.getNode(node, "").lastSentQuery = time.Now()
824         t.startTimer()
825         s.addTransaction(t)
826         return
827 }
828
829 // The size in bytes of a NodeInfo in its compact binary representation.
830 const CompactIPv4NodeInfoLen = 26
831
832 type NodeInfo struct {
833         ID   [20]byte
834         Addr dHTAddr
835 }
836
837 // Writes the node info to its compact binary representation in b. See
838 // CompactNodeInfoLen.
839 func (ni *NodeInfo) PutCompact(b []byte) error {
840         if n := copy(b[:], ni.ID[:]); n != 20 {
841                 panic(n)
842         }
843         ip := missinggo.AddrIP(ni.Addr).To4()
844         if len(ip) != 4 {
845                 return errors.New("expected ipv4 address")
846         }
847         if n := copy(b[20:], ip); n != 4 {
848                 panic(n)
849         }
850         binary.BigEndian.PutUint16(b[24:], uint16(missinggo.AddrPort(ni.Addr)))
851         return nil
852 }
853
854 func (cni *NodeInfo) UnmarshalCompactIPv4(b []byte) error {
855         if len(b) != 26 {
856                 return errors.New("expected 26 bytes")
857         }
858         missinggo.CopyExact(cni.ID[:], b[:20])
859         cni.Addr = newDHTAddr(&net.UDPAddr{
860                 IP:   net.IPv4(b[20], b[21], b[22], b[23]),
861                 Port: int(binary.BigEndian.Uint16(b[24:26])),
862         })
863         return nil
864 }
865
866 // Sends a ping query to the address given.
867 func (s *Server) Ping(node *net.UDPAddr) (*Transaction, error) {
868         s.mu.Lock()
869         defer s.mu.Unlock()
870         return s.query(newDHTAddr(node), "ping", nil, nil)
871 }
872
873 func (s *Server) announcePeer(node dHTAddr, infoHash string, port int, token string, impliedPort bool) (err error) {
874         if port == 0 && !impliedPort {
875                 return errors.New("nothing to announce")
876         }
877         _, err = s.query(node, "announce_peer", map[string]interface{}{
878                 "implied_port": func() int {
879                         if impliedPort {
880                                 return 1
881                         } else {
882                                 return 0
883                         }
884                 }(),
885                 "info_hash": infoHash,
886                 "port":      port,
887                 "token":     token,
888         }, func(m Msg) {
889                 if err := m.Error(); err != nil {
890                         announceErrors.Add(1)
891                         // log.Print(token)
892                         // logonce.Stderr.Printf("announce_peer response: %s", err)
893                         return
894                 }
895                 s.numConfirmedAnnounces++
896         })
897         return
898 }
899
900 // Add response nodes to node table.
901 func (s *Server) liftNodes(d Msg) {
902         if d.Y != "r" {
903                 return
904         }
905         for _, cni := range d.R.Nodes {
906                 if missinggo.AddrPort(cni.Addr) == 0 {
907                         // TODO: Why would people even do this?
908                         continue
909                 }
910                 if s.ipBlocked(missinggo.AddrIP(cni.Addr)) {
911                         continue
912                 }
913                 n := s.getNode(cni.Addr, string(cni.ID[:]))
914                 n.SetIDFromBytes(cni.ID[:])
915         }
916 }
917
918 // Sends a find_node query to addr. targetID is the node we're looking for.
919 func (s *Server) findNode(addr dHTAddr, targetID string) (t *Transaction, err error) {
920         t, err = s.query(addr, "find_node", map[string]interface{}{"target": targetID}, func(d Msg) {
921                 // Scrape peers from the response to put in the server's table before
922                 // handing the response back to the caller.
923                 s.liftNodes(d)
924         })
925         if err != nil {
926                 return
927         }
928         return
929 }
930
931 type Peer struct {
932         IP   net.IP
933         Port int
934 }
935
936 func (me *Peer) String() string {
937         return net.JoinHostPort(me.IP.String(), strconv.FormatInt(int64(me.Port), 10))
938 }
939
940 func (s *Server) getPeers(addr dHTAddr, infoHash string) (t *Transaction, err error) {
941         if len(infoHash) != 20 {
942                 err = fmt.Errorf("infohash has bad length")
943                 return
944         }
945         t, err = s.query(addr, "get_peers", map[string]interface{}{"info_hash": infoHash}, func(m Msg) {
946                 s.liftNodes(m)
947                 s.getNode(addr, m.SenderID()).announceToken = m.R.Token
948         })
949         return
950 }
951
952 func bootstrapAddrs(nodeAddrs []string) (addrs []*net.UDPAddr, err error) {
953         bootstrapNodes := nodeAddrs
954         if len(bootstrapNodes) == 0 {
955                 bootstrapNodes = []string{
956                         "router.utorrent.com:6881",
957                         "router.bittorrent.com:6881",
958                 }
959         }
960         for _, addrStr := range bootstrapNodes {
961                 udpAddr, err := net.ResolveUDPAddr("udp4", addrStr)
962                 if err != nil {
963                         continue
964                 }
965                 addrs = append(addrs, udpAddr)
966         }
967         if len(addrs) == 0 {
968                 err = errors.New("nothing resolved")
969         }
970         return
971 }
972
973 // Adds bootstrap nodes directly to table, if there's room. Node ID security
974 // is bypassed, but the IP blocklist is not.
975 func (s *Server) addRootNodes() error {
976         addrs, err := bootstrapAddrs(s.bootstrapNodes)
977         if err != nil {
978                 return err
979         }
980         for _, addr := range addrs {
981                 if len(s.nodes) >= maxNodes {
982                         break
983                 }
984                 if s.nodes[addr.String()] != nil {
985                         continue
986                 }
987                 if s.ipBlocked(addr.IP) {
988                         log.Printf("dht root node is in the blocklist: %s", addr.IP)
989                         continue
990                 }
991                 s.nodes[addr.String()] = &node{
992                         addr: newDHTAddr(addr),
993                 }
994         }
995         return nil
996 }
997
998 // Populates the node table.
999 func (s *Server) bootstrap() (err error) {
1000         s.mu.Lock()
1001         defer s.mu.Unlock()
1002         if len(s.nodes) == 0 {
1003                 err = s.addRootNodes()
1004         }
1005         if err != nil {
1006                 return
1007         }
1008         for {
1009                 var outstanding sync.WaitGroup
1010                 for _, node := range s.nodes {
1011                         var t *Transaction
1012                         t, err = s.findNode(node.addr, s.id)
1013                         if err != nil {
1014                                 err = fmt.Errorf("error sending find_node: %s", err)
1015                                 return
1016                         }
1017                         outstanding.Add(1)
1018                         t.SetResponseHandler(func(Msg) {
1019                                 outstanding.Done()
1020                         })
1021                 }
1022                 noOutstanding := make(chan struct{})
1023                 go func() {
1024                         outstanding.Wait()
1025                         close(noOutstanding)
1026                 }()
1027                 s.mu.Unlock()
1028                 select {
1029                 case <-s.closed:
1030                         s.mu.Lock()
1031                         return
1032                 case <-time.After(15 * time.Second):
1033                 case <-noOutstanding:
1034                 }
1035                 s.mu.Lock()
1036                 // log.Printf("now have %d nodes", len(s.nodes))
1037                 if s.numGoodNodes() >= 160 {
1038                         break
1039                 }
1040         }
1041         return
1042 }
1043
1044 func (s *Server) numGoodNodes() (num int) {
1045         for _, n := range s.nodes {
1046                 if n.DefinitelyGood() {
1047                         num++
1048                 }
1049         }
1050         return
1051 }
1052
1053 // Returns how many nodes are in the node table.
1054 func (s *Server) NumNodes() int {
1055         s.mu.Lock()
1056         defer s.mu.Unlock()
1057         return len(s.nodes)
1058 }
1059
1060 // Exports the current node table.
1061 func (s *Server) Nodes() (nis []NodeInfo) {
1062         s.mu.Lock()
1063         defer s.mu.Unlock()
1064         for _, node := range s.nodes {
1065                 // if !node.Good() {
1066                 //      continue
1067                 // }
1068                 ni := NodeInfo{
1069                         Addr: node.addr,
1070                 }
1071                 if n := copy(ni.ID[:], node.idString()); n != 20 && n != 0 {
1072                         panic(n)
1073                 }
1074                 nis = append(nis, ni)
1075         }
1076         return
1077 }
1078
1079 // Stops the server network activity. This is all that's required to clean-up a Server.
1080 func (s *Server) Close() {
1081         s.mu.Lock()
1082         select {
1083         case <-s.closed:
1084         default:
1085                 close(s.closed)
1086                 s.socket.Close()
1087         }
1088         s.mu.Unlock()
1089 }
1090
1091 var maxDistance big.Int
1092
1093 func init() {
1094         var zero big.Int
1095         maxDistance.SetBit(&zero, 160, 1)
1096 }
1097
1098 func (s *Server) closestGoodNodes(k int, targetID string) []*node {
1099         return s.closestNodes(k, nodeIDFromString(targetID), func(n *node) bool { return n.DefinitelyGood() })
1100 }
1101
1102 func (s *Server) closestNodes(k int, target nodeID, filter func(*node) bool) []*node {
1103         sel := newKClosestNodesSelector(k, target)
1104         idNodes := make(map[string]*node, len(s.nodes))
1105         for _, node := range s.nodes {
1106                 if !filter(node) {
1107                         continue
1108                 }
1109                 sel.Push(node.id)
1110                 idNodes[node.idString()] = node
1111         }
1112         ids := sel.IDs()
1113         ret := make([]*node, 0, len(ids))
1114         for _, id := range ids {
1115                 ret = append(ret, idNodes[id.ByteString()])
1116         }
1117         return ret
1118 }
1119
1120 func (me *Server) badNode(addr dHTAddr) {
1121         me.badNodes.Add([]byte(addr.String()))
1122         delete(me.nodes, addr.String())
1123 }