]> Sergey Matveev's repositories - btrtrc.git/blob - client.go
Support different hosts for each network
[btrtrc.git] / client.go
1 package torrent
2
3 import (
4         "bufio"
5         "bytes"
6         "context"
7         "crypto/rand"
8         "encoding/binary"
9         "errors"
10         "fmt"
11         "io"
12         "net"
13         "strconv"
14         "strings"
15         "time"
16
17         "github.com/anacrolix/dht"
18         "github.com/anacrolix/dht/krpc"
19         "github.com/anacrolix/log"
20         "github.com/anacrolix/missinggo"
21         "github.com/anacrolix/missinggo/pproffd"
22         "github.com/anacrolix/missinggo/pubsub"
23         "github.com/anacrolix/missinggo/slices"
24         "github.com/anacrolix/sync"
25         "github.com/dustin/go-humanize"
26         "github.com/google/btree"
27         "golang.org/x/time/rate"
28
29         "github.com/anacrolix/torrent/bencode"
30         "github.com/anacrolix/torrent/iplist"
31         "github.com/anacrolix/torrent/metainfo"
32         "github.com/anacrolix/torrent/mse"
33         pp "github.com/anacrolix/torrent/peer_protocol"
34         "github.com/anacrolix/torrent/storage"
35 )
36
37 // Clients contain zero or more Torrents. A Client manages a blocklist, the
38 // TCP/UDP protocol ports, and DHT as desired.
39 type Client struct {
40         mu     sync.RWMutex
41         event  sync.Cond
42         closed missinggo.Event
43
44         config Config
45         logger *log.Logger
46
47         halfOpenLimit  int
48         peerID         PeerID
49         defaultStorage *storage.Client
50         onClose        []func()
51         conns          []socket
52         dhtServers     []*dht.Server
53         ipBlockList    iplist.Ranger
54         // Our BitTorrent protocol extension bytes, sent in our BT handshakes.
55         extensionBytes peerExtensionBytes
56         uploadLimit    *rate.Limiter
57         downloadLimit  *rate.Limiter
58
59         // Set of addresses that have our client ID. This intentionally will
60         // include ourselves if we end up trying to connect to our own address
61         // through legitimate channels.
62         dopplegangerAddrs map[string]struct{}
63         badPeerIPs        map[string]struct{}
64         torrents          map[metainfo.Hash]*Torrent
65 }
66
67 func (cl *Client) BadPeerIPs() []string {
68         cl.mu.RLock()
69         defer cl.mu.RUnlock()
70         return cl.badPeerIPsLocked()
71 }
72
73 func (cl *Client) badPeerIPsLocked() []string {
74         return slices.FromMapKeys(cl.badPeerIPs).([]string)
75 }
76
77 func (cl *Client) PeerID() PeerID {
78         return cl.peerID
79 }
80
81 type torrentAddr string
82
83 func (torrentAddr) Network() string { return "" }
84
85 func (me torrentAddr) String() string { return string(me) }
86
87 func (cl *Client) LocalPort() (port int) {
88         cl.eachListener(func(l socket) bool {
89                 _port := missinggo.AddrPort(l.Addr())
90                 if _port == 0 {
91                         panic(l)
92                 }
93                 if port == 0 {
94                         port = _port
95                 } else if port != _port {
96                         panic("mismatched ports")
97                 }
98                 return true
99         })
100         return
101 }
102
103 func writeDhtServerStatus(w io.Writer, s *dht.Server) {
104         dhtStats := s.Stats()
105         fmt.Fprintf(w, "\tDHT nodes: %d (%d good, %d banned)\n", dhtStats.Nodes, dhtStats.GoodNodes, dhtStats.BadNodes)
106         fmt.Fprintf(w, "\tDHT Server ID: %x\n", s.ID())
107         fmt.Fprintf(w, "\tDHT port: %d\n", missinggo.AddrPort(s.Addr()))
108         fmt.Fprintf(w, "\tDHT announces: %d\n", dhtStats.ConfirmedAnnounces)
109         fmt.Fprintf(w, "\tOutstanding transactions: %d\n", dhtStats.OutstandingTransactions)
110 }
111
112 // Writes out a human readable status of the client, such as for writing to a
113 // HTTP status page.
114 func (cl *Client) WriteStatus(_w io.Writer) {
115         cl.mu.Lock()
116         defer cl.mu.Unlock()
117         w := bufio.NewWriter(_w)
118         defer w.Flush()
119         fmt.Fprintf(w, "Listen port: %d\n", cl.LocalPort())
120         fmt.Fprintf(w, "Peer ID: %+q\n", cl.PeerID())
121         fmt.Fprintf(w, "Announce key: %x\n", cl.announceKey())
122         fmt.Fprintf(w, "Banned IPs: %d\n", len(cl.badPeerIPsLocked()))
123         cl.eachDhtServer(func(s *dht.Server) {
124                 fmt.Fprintf(w, "%s DHT server:\n", s.Addr().Network())
125                 writeDhtServerStatus(w, s)
126         })
127         fmt.Fprintf(w, "# Torrents: %d\n", len(cl.torrentsAsSlice()))
128         fmt.Fprintln(w)
129         for _, t := range slices.Sort(cl.torrentsAsSlice(), func(l, r *Torrent) bool {
130                 return l.InfoHash().AsString() < r.InfoHash().AsString()
131         }).([]*Torrent) {
132                 if t.name() == "" {
133                         fmt.Fprint(w, "<unknown name>")
134                 } else {
135                         fmt.Fprint(w, t.name())
136                 }
137                 fmt.Fprint(w, "\n")
138                 if t.info != nil {
139                         fmt.Fprintf(w, "%f%% of %d bytes (%s)", 100*(1-float64(t.bytesMissingLocked())/float64(t.info.TotalLength())), t.length, humanize.Bytes(uint64(t.info.TotalLength())))
140                 } else {
141                         w.WriteString("<missing metainfo>")
142                 }
143                 fmt.Fprint(w, "\n")
144                 t.writeStatus(w)
145                 fmt.Fprintln(w)
146         }
147 }
148
149 const debugLogValue = "debug"
150
151 func (cl *Client) debugLogFilter(m *log.Msg) bool {
152         if !cl.config.Debug {
153                 _, ok := m.Values()[debugLogValue]
154                 return !ok
155         }
156         return true
157 }
158
159 func (cl *Client) initLogger() {
160         cl.logger = log.Default.Clone().AddValue(cl).AddFilter(log.NewFilter(cl.debugLogFilter))
161 }
162
163 func (cl *Client) announceKey() int32 {
164         return int32(binary.BigEndian.Uint32(cl.peerID[16:20]))
165 }
166
167 func NewClient(cfg *Config) (cl *Client, err error) {
168         if cfg == nil {
169                 cfg = &Config{}
170         }
171         cfg.setDefaults()
172
173         defer func() {
174                 if err != nil {
175                         cl = nil
176                 }
177         }()
178         cl = &Client{
179                 halfOpenLimit:     cfg.HalfOpenConnsPerTorrent,
180                 config:            *cfg,
181                 dopplegangerAddrs: make(map[string]struct{}),
182                 torrents:          make(map[metainfo.Hash]*Torrent),
183         }
184         cl.initLogger()
185         defer func() {
186                 if err == nil {
187                         return
188                 }
189                 cl.Close()
190         }()
191         if cfg.UploadRateLimiter == nil {
192                 cl.uploadLimit = rate.NewLimiter(rate.Inf, 0)
193         } else {
194                 cl.uploadLimit = cfg.UploadRateLimiter
195         }
196         if cfg.DownloadRateLimiter == nil {
197                 cl.downloadLimit = rate.NewLimiter(rate.Inf, 0)
198         } else {
199                 cl.downloadLimit = cfg.DownloadRateLimiter
200         }
201         cl.extensionBytes = defaultPeerExtensionBytes()
202         cl.event.L = &cl.mu
203         storageImpl := cfg.DefaultStorage
204         if storageImpl == nil {
205                 // We'd use mmap but HFS+ doesn't support sparse files.
206                 storageImpl = storage.NewFile(cfg.DataDir)
207                 cl.onClose = append(cl.onClose, func() {
208                         if err := storageImpl.Close(); err != nil {
209                                 log.Printf("error closing default storage: %s", err)
210                         }
211                 })
212         }
213         cl.defaultStorage = storage.NewClient(storageImpl)
214         if cfg.IPBlocklist != nil {
215                 cl.ipBlockList = cfg.IPBlocklist
216         }
217
218         if cfg.PeerID != "" {
219                 missinggo.CopyExact(&cl.peerID, cfg.PeerID)
220         } else {
221                 o := copy(cl.peerID[:], cfg.Bep20)
222                 _, err = rand.Read(cl.peerID[o:])
223                 if err != nil {
224                         panic("error generating peer id")
225                 }
226         }
227
228         cl.conns, err = listenAll(cl.enabledPeerNetworks(), cl.config.ListenHost, cl.config.ListenPort)
229         if err != nil {
230                 return
231         }
232         cl.LocalPort()
233
234         for _, s := range cl.conns {
235                 if peerNetworkEnabled(s.Addr().Network(), cl.config) {
236                         go cl.acceptConnections(s)
237                 }
238         }
239
240         go cl.forwardPort()
241         if !cfg.NoDHT {
242                 for _, s := range cl.conns {
243                         if pc, ok := s.(net.PacketConn); ok {
244                                 ds, err := cl.newDhtServer(pc)
245                                 if err != nil {
246                                         panic(err)
247                                 }
248                                 cl.dhtServers = append(cl.dhtServers, ds)
249                         }
250                 }
251         }
252
253         return
254 }
255
256 func (cl *Client) enabledPeerNetworks() (ns []string) {
257         for _, n := range allPeerNetworks {
258                 if peerNetworkEnabled(n, cl.config) {
259                         ns = append(ns, n)
260                 }
261         }
262         return
263 }
264
265 func (cl *Client) newDhtServer(conn net.PacketConn) (s *dht.Server, err error) {
266         cfg := dht.ServerConfig{
267                 IPBlocklist:    cl.ipBlockList,
268                 Conn:           conn,
269                 OnAnnouncePeer: cl.onDHTAnnouncePeer,
270                 PublicIP: func() net.IP {
271                         if connIsIpv6(conn) && cl.config.PublicIp6 != nil {
272                                 return cl.config.PublicIp6
273                         }
274                         return cl.config.PublicIp4
275                 }(),
276                 StartingNodes: cl.config.DhtStartingNodes,
277         }
278         s, err = dht.NewServer(&cfg)
279         if err == nil {
280                 go func() {
281                         if _, err := s.Bootstrap(); err != nil {
282                                 log.Printf("error bootstrapping dht: %s", err)
283                         }
284                 }()
285         }
286         return
287 }
288
289 func firstNonEmptyString(ss ...string) string {
290         for _, s := range ss {
291                 if s != "" {
292                         return s
293                 }
294         }
295         return ""
296 }
297
298 func (cl *Client) Closed() <-chan struct{} {
299         cl.mu.Lock()
300         defer cl.mu.Unlock()
301         return cl.closed.C()
302 }
303
304 func (cl *Client) eachDhtServer(f func(*dht.Server)) {
305         for _, ds := range cl.dhtServers {
306                 f(ds)
307         }
308 }
309
310 func (cl *Client) closeSockets() {
311         cl.eachListener(func(l socket) bool {
312                 l.Close()
313                 return true
314         })
315         cl.conns = nil
316 }
317
318 // Stops the client. All connections to peers are closed and all activity will
319 // come to a halt.
320 func (cl *Client) Close() {
321         cl.mu.Lock()
322         defer cl.mu.Unlock()
323         cl.closed.Set()
324         cl.eachDhtServer(func(s *dht.Server) { s.Close() })
325         cl.closeSockets()
326         for _, t := range cl.torrents {
327                 t.close()
328         }
329         for _, f := range cl.onClose {
330                 f()
331         }
332         cl.event.Broadcast()
333 }
334
335 func (cl *Client) ipBlockRange(ip net.IP) (r iplist.Range, blocked bool) {
336         if cl.ipBlockList == nil {
337                 return
338         }
339         return cl.ipBlockList.Lookup(ip)
340 }
341
342 func (cl *Client) ipIsBlocked(ip net.IP) bool {
343         _, blocked := cl.ipBlockRange(ip)
344         return blocked
345 }
346
347 func (cl *Client) waitAccept() {
348         for {
349                 for _, t := range cl.torrents {
350                         if t.wantConns() {
351                                 return
352                         }
353                 }
354                 if cl.closed.IsSet() {
355                         return
356                 }
357                 cl.event.Wait()
358         }
359 }
360
361 func (cl *Client) rejectAccepted(conn net.Conn) bool {
362         ra := conn.RemoteAddr()
363         rip := missinggo.AddrIP(ra)
364         if cl.config.DisableIPv4Peers && rip.To4() != nil {
365                 return true
366         }
367         if cl.config.DisableIPv4 && len(rip) == net.IPv4len {
368                 return true
369         }
370         if cl.config.DisableIPv6 && len(rip) == net.IPv6len && rip.To4() == nil {
371                 return true
372         }
373         return cl.badPeerIPPort(rip, missinggo.AddrPort(ra))
374 }
375
376 func (cl *Client) acceptConnections(l net.Listener) {
377         cl.mu.Lock()
378         defer cl.mu.Unlock()
379         for {
380                 cl.waitAccept()
381                 cl.mu.Unlock()
382                 conn, err := l.Accept()
383                 conn = pproffd.WrapNetConn(conn)
384                 cl.mu.Lock()
385                 if cl.closed.IsSet() {
386                         if conn != nil {
387                                 conn.Close()
388                         }
389                         return
390                 }
391                 if err != nil {
392                         log.Print(err)
393                         // I think something harsher should happen here? Our accept
394                         // routine just fucked off.
395                         return
396                 }
397                 log.Fmsg("accepted connection from %s", conn.RemoteAddr()).AddValue(debugLogValue).Log(cl.logger)
398                 go torrent.Add(fmt.Sprintf("accepted conn remote IP len=%d", len(missinggo.AddrIP(conn.RemoteAddr()))), 1)
399                 go torrent.Add(fmt.Sprintf("accepted conn network=%s", conn.RemoteAddr().Network()), 1)
400                 go torrent.Add(fmt.Sprintf("accepted on %s listener", l.Addr().Network()), 1)
401                 if cl.rejectAccepted(conn) {
402                         go torrent.Add("rejected accepted connections", 1)
403                         conn.Close()
404                 } else {
405                         go cl.incomingConnection(conn)
406                 }
407         }
408 }
409
410 func (cl *Client) incomingConnection(nc net.Conn) {
411         defer nc.Close()
412         if tc, ok := nc.(*net.TCPConn); ok {
413                 tc.SetLinger(0)
414         }
415         c := cl.newConnection(nc)
416         c.Discovery = peerSourceIncoming
417         cl.runReceivedConn(c)
418 }
419
420 // Returns a handle to the given torrent, if it's present in the client.
421 func (cl *Client) Torrent(ih metainfo.Hash) (t *Torrent, ok bool) {
422         cl.mu.Lock()
423         defer cl.mu.Unlock()
424         t, ok = cl.torrents[ih]
425         return
426 }
427
428 func (cl *Client) torrent(ih metainfo.Hash) *Torrent {
429         return cl.torrents[ih]
430 }
431
432 type dialResult struct {
433         Conn net.Conn
434 }
435
436 func countDialResult(err error) {
437         if err == nil {
438                 successfulDials.Add(1)
439         } else {
440                 unsuccessfulDials.Add(1)
441         }
442 }
443
444 func reducedDialTimeout(minDialTimeout, max time.Duration, halfOpenLimit int, pendingPeers int) (ret time.Duration) {
445         ret = max / time.Duration((pendingPeers+halfOpenLimit)/halfOpenLimit)
446         if ret < minDialTimeout {
447                 ret = minDialTimeout
448         }
449         return
450 }
451
452 // Returns whether an address is known to connect to a client with our own ID.
453 func (cl *Client) dopplegangerAddr(addr string) bool {
454         _, ok := cl.dopplegangerAddrs[addr]
455         return ok
456 }
457
458 func (cl *Client) dialTCP(ctx context.Context, addr string) (c net.Conn, err error) {
459         d := net.Dialer{
460         // Can't bind to the listen address, even though we intend to create an
461         // endpoint pair that is distinct. Oh well.
462
463         // LocalAddr: cl.tcpListener.Addr(),
464         }
465         c, err = d.DialContext(ctx, "tcp"+ipNetworkSuffix(!cl.config.DisableIPv4 && !cl.config.DisableIPv4Peers, !cl.config.DisableIPv6), addr)
466         countDialResult(err)
467         if err == nil {
468                 c.(*net.TCPConn).SetLinger(0)
469         }
470         c = pproffd.WrapNetConn(c)
471         return
472 }
473
474 func ipNetworkSuffix(allowIpv4, allowIpv6 bool) string {
475         switch {
476         case allowIpv4 && allowIpv6:
477                 return ""
478         case allowIpv4 && !allowIpv6:
479                 return "4"
480         case !allowIpv4 && allowIpv6:
481                 return "6"
482         default:
483                 panic("unhandled ip network combination")
484         }
485 }
486
487 func dialUTP(ctx context.Context, addr string, sock utpSocket) (c net.Conn, err error) {
488         return sock.DialContext(ctx, "", addr)
489 }
490
491 var allPeerNetworks = []string{"tcp4", "tcp6", "udp4", "udp6"}
492
493 func peerNetworkEnabled(network string, cfg Config) bool {
494         c := func(s string) bool {
495                 return strings.Contains(network, s)
496         }
497         if cfg.DisableUTP {
498                 if c("udp") || c("utp") {
499                         return false
500                 }
501         }
502         if cfg.DisableTCP && c("tcp") {
503                 return false
504         }
505         return true
506 }
507
508 // Returns a connection over UTP or TCP, whichever is first to connect.
509 func (cl *Client) dialFirst(ctx context.Context, addr string) net.Conn {
510         ctx, cancel := context.WithCancel(ctx)
511         // As soon as we return one connection, cancel the others.
512         defer cancel()
513         left := 0
514         resCh := make(chan dialResult, left)
515         dial := func(f func(_ context.Context, addr string) (net.Conn, error)) {
516                 left++
517                 go func() {
518                         c, err := f(ctx, addr)
519                         countDialResult(err)
520                         resCh <- dialResult{c}
521                 }()
522         }
523         func() {
524                 cl.mu.Lock()
525                 defer cl.mu.Unlock()
526                 cl.eachListener(func(s socket) bool {
527                         if peerNetworkEnabled(s.Addr().Network(), cl.config) {
528                                 dial(s.dial)
529                         }
530                         return true
531                 })
532         }()
533         var res dialResult
534         // Wait for a successful connection.
535         for ; left > 0 && res.Conn == nil; left-- {
536                 res = <-resCh
537         }
538         // There are still incompleted dials.
539         go func() {
540                 for ; left > 0; left-- {
541                         conn := (<-resCh).Conn
542                         if conn != nil {
543                                 conn.Close()
544                         }
545                 }
546         }()
547         if res.Conn != nil {
548                 go torrent.Add(fmt.Sprintf("network dialed first: %s", res.Conn.RemoteAddr().Network()), 1)
549         }
550         return res.Conn
551 }
552
553 func (cl *Client) noLongerHalfOpen(t *Torrent, addr string) {
554         if _, ok := t.halfOpen[addr]; !ok {
555                 panic("invariant broken")
556         }
557         delete(t.halfOpen, addr)
558         t.openNewConns()
559 }
560
561 // Performs initiator handshakes and returns a connection. Returns nil
562 // *connection if no connection for valid reasons.
563 func (cl *Client) handshakesConnection(ctx context.Context, nc net.Conn, t *Torrent, encryptHeader bool) (c *connection, err error) {
564         c = cl.newConnection(nc)
565         c.headerEncrypted = encryptHeader
566         ctx, cancel := context.WithTimeout(ctx, cl.config.HandshakesTimeout)
567         defer cancel()
568         dl, ok := ctx.Deadline()
569         if !ok {
570                 panic(ctx)
571         }
572         err = nc.SetDeadline(dl)
573         if err != nil {
574                 panic(err)
575         }
576         ok, err = cl.initiateHandshakes(c, t)
577         if !ok {
578                 c = nil
579         }
580         return
581 }
582
583 // Returns nil connection and nil error if no connection could be established
584 // for valid reasons.
585 func (cl *Client) establishOutgoingConnEx(t *Torrent, addr string, ctx context.Context, obfuscatedHeader bool) (c *connection, err error) {
586         nc := cl.dialFirst(ctx, addr)
587         if nc == nil {
588                 return
589         }
590         defer func() {
591                 if c == nil || err != nil {
592                         nc.Close()
593                 }
594         }()
595         return cl.handshakesConnection(ctx, nc, t, obfuscatedHeader)
596 }
597
598 // Returns nil connection and nil error if no connection could be established
599 // for valid reasons.
600 func (cl *Client) establishOutgoingConn(t *Torrent, addr string) (c *connection, err error) {
601         ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
602         defer cancel()
603         obfuscatedHeaderFirst := !cl.config.DisableEncryption && !cl.config.PreferNoEncryption
604         c, err = cl.establishOutgoingConnEx(t, addr, ctx, obfuscatedHeaderFirst)
605         if err != nil {
606                 return
607         }
608         if c != nil {
609                 go torrent.Add("initiated conn with preferred header obfuscation", 1)
610                 return
611         }
612         if cl.config.ForceEncryption {
613                 // We should have just tried with an obfuscated header. A plaintext
614                 // header can't result in an encrypted connection, so we're done.
615                 if !obfuscatedHeaderFirst {
616                         panic(cl.config.EncryptionPolicy)
617                 }
618                 return
619         }
620         // Try again with encryption if we didn't earlier, or without if we did.
621         c, err = cl.establishOutgoingConnEx(t, addr, ctx, !obfuscatedHeaderFirst)
622         if c != nil {
623                 go torrent.Add("initiated conn with fallback header obfuscation", 1)
624         }
625         return
626 }
627
628 // Called to dial out and run a connection. The addr we're given is already
629 // considered half-open.
630 func (cl *Client) outgoingConnection(t *Torrent, addr string, ps peerSource) {
631         c, err := cl.establishOutgoingConn(t, addr)
632         cl.mu.Lock()
633         defer cl.mu.Unlock()
634         // Don't release lock between here and addConnection, unless it's for
635         // failure.
636         cl.noLongerHalfOpen(t, addr)
637         if err != nil {
638                 if cl.config.Debug {
639                         log.Printf("error establishing outgoing connection: %s", err)
640                 }
641                 return
642         }
643         if c == nil {
644                 return
645         }
646         defer c.Close()
647         c.Discovery = ps
648         cl.runHandshookConn(c, t, true)
649 }
650
651 // The port number for incoming peer connections. 0 if the client isn't
652 // listening.
653 func (cl *Client) incomingPeerPort() int {
654         return cl.LocalPort()
655 }
656
657 func (cl *Client) initiateHandshakes(c *connection, t *Torrent) (ok bool, err error) {
658         if c.headerEncrypted {
659                 var rw io.ReadWriter
660                 rw, c.cryptoMethod, err = mse.InitiateHandshake(
661                         struct {
662                                 io.Reader
663                                 io.Writer
664                         }{c.r, c.w},
665                         t.infoHash[:],
666                         nil,
667                         func() mse.CryptoMethod {
668                                 switch {
669                                 case cl.config.ForceEncryption:
670                                         return mse.CryptoMethodRC4
671                                 case cl.config.DisableEncryption:
672                                         return mse.CryptoMethodPlaintext
673                                 default:
674                                         return mse.AllSupportedCrypto
675                                 }
676                         }(),
677                 )
678                 c.setRW(rw)
679                 if err != nil {
680                         return
681                 }
682         }
683         ih, ok, err := cl.connBTHandshake(c, &t.infoHash)
684         if ih != t.infoHash {
685                 ok = false
686         }
687         return
688 }
689
690 // Calls f with any secret keys.
691 func (cl *Client) forSkeys(f func([]byte) bool) {
692         cl.mu.Lock()
693         defer cl.mu.Unlock()
694         for ih := range cl.torrents {
695                 if !f(ih[:]) {
696                         break
697                 }
698         }
699 }
700
701 // Do encryption and bittorrent handshakes as receiver.
702 func (cl *Client) receiveHandshakes(c *connection) (t *Torrent, err error) {
703         var rw io.ReadWriter
704         rw, c.headerEncrypted, c.cryptoMethod, err = handleEncryption(c.rw(), cl.forSkeys, cl.config.EncryptionPolicy)
705         c.setRW(rw)
706         if err != nil {
707                 if err == mse.ErrNoSecretKeyMatch {
708                         err = nil
709                 }
710                 return
711         }
712         if cl.config.ForceEncryption && !c.headerEncrypted {
713                 err = errors.New("connection not encrypted")
714                 return
715         }
716         ih, ok, err := cl.connBTHandshake(c, nil)
717         if err != nil {
718                 err = fmt.Errorf("error during bt handshake: %s", err)
719                 return
720         }
721         if !ok {
722                 return
723         }
724         cl.mu.Lock()
725         t = cl.torrents[ih]
726         cl.mu.Unlock()
727         return
728 }
729
730 // Returns !ok if handshake failed for valid reasons.
731 func (cl *Client) connBTHandshake(c *connection, ih *metainfo.Hash) (ret metainfo.Hash, ok bool, err error) {
732         res, ok, err := handshake(c.rw(), ih, cl.peerID, cl.extensionBytes)
733         if err != nil || !ok {
734                 return
735         }
736         ret = res.Hash
737         c.PeerExtensionBytes = res.peerExtensionBytes
738         c.PeerID = res.PeerID
739         c.completedHandshake = time.Now()
740         return
741 }
742
743 func (cl *Client) runReceivedConn(c *connection) {
744         err := c.conn.SetDeadline(time.Now().Add(cl.config.HandshakesTimeout))
745         if err != nil {
746                 panic(err)
747         }
748         t, err := cl.receiveHandshakes(c)
749         if err != nil {
750                 if cl.config.Debug {
751                         log.Printf("error receiving handshakes: %s", err)
752                 }
753                 return
754         }
755         if t == nil {
756                 return
757         }
758         cl.mu.Lock()
759         defer cl.mu.Unlock()
760         cl.runHandshookConn(c, t, false)
761 }
762
763 func (cl *Client) runHandshookConn(c *connection, t *Torrent, outgoing bool) {
764         t.reconcileHandshakeStats(c)
765         if c.PeerID == cl.peerID {
766                 if outgoing {
767                         connsToSelf.Add(1)
768                         addr := c.conn.RemoteAddr().String()
769                         cl.dopplegangerAddrs[addr] = struct{}{}
770                 } else {
771                         // Because the remote address is not necessarily the same as its
772                         // client's torrent listen address, we won't record the remote address
773                         // as a doppleganger. Instead, the initiator can record *us* as the
774                         // doppleganger.
775                 }
776                 return
777         }
778         c.conn.SetWriteDeadline(time.Time{})
779         c.r = deadlineReader{c.conn, c.r}
780         completedHandshakeConnectionFlags.Add(c.connectionFlags(), 1)
781         if connIsIpv6(c.conn) {
782                 torrent.Add("completed handshake over ipv6", 1)
783         }
784         if !t.addConnection(c, outgoing) {
785                 return
786         }
787         defer t.dropConnection(c)
788         go c.writer(time.Minute)
789         cl.sendInitialMessages(c, t)
790         err := c.mainReadLoop()
791         if err != nil && cl.config.Debug {
792                 log.Printf("error during connection main read loop: %s", err)
793         }
794 }
795
796 func (cl *Client) sendInitialMessages(conn *connection, torrent *Torrent) {
797         func() {
798                 if conn.fastEnabled() {
799                         if torrent.haveAllPieces() {
800                                 conn.Post(pp.Message{Type: pp.HaveAll})
801                                 conn.sentHaves.AddRange(0, conn.t.NumPieces())
802                                 return
803                         } else if !torrent.haveAnyPieces() {
804                                 conn.Post(pp.Message{Type: pp.HaveNone})
805                                 conn.sentHaves.Clear()
806                                 return
807                         }
808                 }
809                 conn.PostBitfield()
810         }()
811         if conn.PeerExtensionBytes.SupportsExtended() && cl.extensionBytes.SupportsExtended() {
812                 conn.Post(pp.Message{
813                         Type:       pp.Extended,
814                         ExtendedID: pp.HandshakeExtendedID,
815                         ExtendedPayload: func() []byte {
816                                 d := map[string]interface{}{
817                                         "m": func() (ret map[string]int) {
818                                                 ret = make(map[string]int, 2)
819                                                 ret["ut_metadata"] = metadataExtendedId
820                                                 if !cl.config.DisablePEX {
821                                                         ret["ut_pex"] = pexExtendedId
822                                                 }
823                                                 return
824                                         }(),
825                                         "v": cl.config.ExtendedHandshakeClientVersion,
826                                         // No upload queue is implemented yet.
827                                         "reqq": 64,
828                                 }
829                                 if !cl.config.DisableEncryption {
830                                         d["e"] = 1
831                                 }
832                                 if torrent.metadataSizeKnown() {
833                                         d["metadata_size"] = torrent.metadataSize()
834                                 }
835                                 if p := cl.incomingPeerPort(); p != 0 {
836                                         d["p"] = p
837                                 }
838                                 yourip, err := addrCompactIP(conn.remoteAddr())
839                                 if err != nil {
840                                         log.Printf("error calculating yourip field value in extension handshake: %s", err)
841                                 } else {
842                                         d["yourip"] = yourip
843                                 }
844                                 // log.Printf("sending %v", d)
845                                 b, err := bencode.Marshal(d)
846                                 if err != nil {
847                                         panic(err)
848                                 }
849                                 return b
850                         }(),
851                 })
852         }
853         if conn.PeerExtensionBytes.SupportsDHT() && cl.extensionBytes.SupportsDHT() && cl.haveDhtServer() {
854                 conn.Post(pp.Message{
855                         Type: pp.Port,
856                         Port: cl.dhtPort(),
857                 })
858         }
859 }
860
861 func (cl *Client) dhtPort() (ret uint16) {
862         cl.eachDhtServer(func(s *dht.Server) {
863                 ret = uint16(missinggo.AddrPort(s.Addr()))
864         })
865         return
866 }
867
868 func (cl *Client) haveDhtServer() (ret bool) {
869         cl.eachDhtServer(func(_ *dht.Server) {
870                 ret = true
871         })
872         return
873 }
874
875 // Process incoming ut_metadata message.
876 func (cl *Client) gotMetadataExtensionMsg(payload []byte, t *Torrent, c *connection) error {
877         var d map[string]int
878         err := bencode.Unmarshal(payload, &d)
879         if _, ok := err.(bencode.ErrUnusedTrailingBytes); ok {
880         } else if err != nil {
881                 return fmt.Errorf("error unmarshalling bencode: %s", err)
882         }
883         msgType, ok := d["msg_type"]
884         if !ok {
885                 return errors.New("missing msg_type field")
886         }
887         piece := d["piece"]
888         switch msgType {
889         case pp.DataMetadataExtensionMsgType:
890                 if !c.requestedMetadataPiece(piece) {
891                         return fmt.Errorf("got unexpected piece %d", piece)
892                 }
893                 c.metadataRequests[piece] = false
894                 begin := len(payload) - metadataPieceSize(d["total_size"], piece)
895                 if begin < 0 || begin >= len(payload) {
896                         return fmt.Errorf("data has bad offset in payload: %d", begin)
897                 }
898                 t.saveMetadataPiece(piece, payload[begin:])
899                 c.stats.ChunksReadUseful++
900                 c.t.stats.ChunksReadUseful++
901                 c.lastUsefulChunkReceived = time.Now()
902                 return t.maybeCompleteMetadata()
903         case pp.RequestMetadataExtensionMsgType:
904                 if !t.haveMetadataPiece(piece) {
905                         c.Post(t.newMetadataExtensionMessage(c, pp.RejectMetadataExtensionMsgType, d["piece"], nil))
906                         return nil
907                 }
908                 start := (1 << 14) * piece
909                 c.Post(t.newMetadataExtensionMessage(c, pp.DataMetadataExtensionMsgType, piece, t.metadataBytes[start:start+t.metadataPieceSize(piece)]))
910                 return nil
911         case pp.RejectMetadataExtensionMsgType:
912                 return nil
913         default:
914                 return errors.New("unknown msg_type value")
915         }
916 }
917
918 func (cl *Client) badPeerIPPort(ip net.IP, port int) bool {
919         if port == 0 {
920                 return true
921         }
922         if cl.dopplegangerAddr(net.JoinHostPort(ip.String(), strconv.FormatInt(int64(port), 10))) {
923                 return true
924         }
925         if _, ok := cl.ipBlockRange(ip); ok {
926                 return true
927         }
928         if _, ok := cl.badPeerIPs[ip.String()]; ok {
929                 return true
930         }
931         return false
932 }
933
934 // Return a Torrent ready for insertion into a Client.
935 func (cl *Client) newTorrent(ih metainfo.Hash, specStorage storage.ClientImpl) (t *Torrent) {
936         // use provided storage, if provided
937         storageClient := cl.defaultStorage
938         if specStorage != nil {
939                 storageClient = storage.NewClient(specStorage)
940         }
941
942         t = &Torrent{
943                 cl:       cl,
944                 infoHash: ih,
945                 peers: prioritizedPeers{
946                         om: btree.New(2),
947                         getPrio: func(p Peer) peerPriority {
948                                 return bep40Priority(cl.publicAddr(p.IP), p.addr())
949                         },
950                 },
951                 conns: make(map[*connection]struct{}, 2*cl.config.EstablishedConnsPerTorrent),
952
953                 halfOpen:          make(map[string]Peer),
954                 pieceStateChanges: pubsub.NewPubSub(),
955
956                 storageOpener:       storageClient,
957                 maxEstablishedConns: cl.config.EstablishedConnsPerTorrent,
958
959                 networkingEnabled: true,
960                 requestStrategy:   2,
961                 metadataChanged: sync.Cond{
962                         L: &cl.mu,
963                 },
964         }
965         t.logger = cl.logger.Clone().AddValue(t)
966         t.setChunkSize(defaultChunkSize)
967         return
968 }
969
970 // A file-like handle to some torrent data resource.
971 type Handle interface {
972         io.Reader
973         io.Seeker
974         io.Closer
975         io.ReaderAt
976 }
977
978 func (cl *Client) AddTorrentInfoHash(infoHash metainfo.Hash) (t *Torrent, new bool) {
979         return cl.AddTorrentInfoHashWithStorage(infoHash, nil)
980 }
981
982 // Adds a torrent by InfoHash with a custom Storage implementation.
983 // If the torrent already exists then this Storage is ignored and the
984 // existing torrent returned with `new` set to `false`
985 func (cl *Client) AddTorrentInfoHashWithStorage(infoHash metainfo.Hash, specStorage storage.ClientImpl) (t *Torrent, new bool) {
986         cl.mu.Lock()
987         defer cl.mu.Unlock()
988         t, ok := cl.torrents[infoHash]
989         if ok {
990                 return
991         }
992         new = true
993         t = cl.newTorrent(infoHash, specStorage)
994         cl.eachDhtServer(func(s *dht.Server) {
995                 go t.dhtAnnouncer(s)
996         })
997         cl.torrents[infoHash] = t
998         t.updateWantPeersEvent()
999         // Tickle Client.waitAccept, new torrent may want conns.
1000         cl.event.Broadcast()
1001         return
1002 }
1003
1004 // Add or merge a torrent spec. If the torrent is already present, the
1005 // trackers will be merged with the existing ones. If the Info isn't yet
1006 // known, it will be set. The display name is replaced if the new spec
1007 // provides one. Returns new if the torrent wasn't already in the client.
1008 // Note that any `Storage` defined on the spec will be ignored if the
1009 // torrent is already present (i.e. `new` return value is `true`)
1010 func (cl *Client) AddTorrentSpec(spec *TorrentSpec) (t *Torrent, new bool, err error) {
1011         t, new = cl.AddTorrentInfoHashWithStorage(spec.InfoHash, spec.Storage)
1012         if spec.DisplayName != "" {
1013                 t.SetDisplayName(spec.DisplayName)
1014         }
1015         if spec.InfoBytes != nil {
1016                 err = t.SetInfoBytes(spec.InfoBytes)
1017                 if err != nil {
1018                         return
1019                 }
1020         }
1021         cl.mu.Lock()
1022         defer cl.mu.Unlock()
1023         if spec.ChunkSize != 0 {
1024                 t.setChunkSize(pp.Integer(spec.ChunkSize))
1025         }
1026         t.addTrackers(spec.Trackers)
1027         t.maybeNewConns()
1028         return
1029 }
1030
1031 func (cl *Client) dropTorrent(infoHash metainfo.Hash) (err error) {
1032         t, ok := cl.torrents[infoHash]
1033         if !ok {
1034                 err = fmt.Errorf("no such torrent")
1035                 return
1036         }
1037         err = t.close()
1038         if err != nil {
1039                 panic(err)
1040         }
1041         delete(cl.torrents, infoHash)
1042         return
1043 }
1044
1045 func (cl *Client) allTorrentsCompleted() bool {
1046         for _, t := range cl.torrents {
1047                 if !t.haveInfo() {
1048                         return false
1049                 }
1050                 if !t.haveAllPieces() {
1051                         return false
1052                 }
1053         }
1054         return true
1055 }
1056
1057 // Returns true when all torrents are completely downloaded and false if the
1058 // client is stopped before that.
1059 func (cl *Client) WaitAll() bool {
1060         cl.mu.Lock()
1061         defer cl.mu.Unlock()
1062         for !cl.allTorrentsCompleted() {
1063                 if cl.closed.IsSet() {
1064                         return false
1065                 }
1066                 cl.event.Wait()
1067         }
1068         return true
1069 }
1070
1071 // Returns handles to all the torrents loaded in the Client.
1072 func (cl *Client) Torrents() []*Torrent {
1073         cl.mu.Lock()
1074         defer cl.mu.Unlock()
1075         return cl.torrentsAsSlice()
1076 }
1077
1078 func (cl *Client) torrentsAsSlice() (ret []*Torrent) {
1079         for _, t := range cl.torrents {
1080                 ret = append(ret, t)
1081         }
1082         return
1083 }
1084
1085 func (cl *Client) AddMagnet(uri string) (T *Torrent, err error) {
1086         spec, err := TorrentSpecFromMagnetURI(uri)
1087         if err != nil {
1088                 return
1089         }
1090         T, _, err = cl.AddTorrentSpec(spec)
1091         return
1092 }
1093
1094 func (cl *Client) AddTorrent(mi *metainfo.MetaInfo) (T *Torrent, err error) {
1095         T, _, err = cl.AddTorrentSpec(TorrentSpecFromMetaInfo(mi))
1096         var ss []string
1097         slices.MakeInto(&ss, mi.Nodes)
1098         cl.AddDHTNodes(ss)
1099         return
1100 }
1101
1102 func (cl *Client) AddTorrentFromFile(filename string) (T *Torrent, err error) {
1103         mi, err := metainfo.LoadFromFile(filename)
1104         if err != nil {
1105                 return
1106         }
1107         return cl.AddTorrent(mi)
1108 }
1109
1110 func (cl *Client) DhtServers() []*dht.Server {
1111         return cl.dhtServers
1112 }
1113
1114 func (cl *Client) AddDHTNodes(nodes []string) {
1115         for _, n := range nodes {
1116                 hmp := missinggo.SplitHostMaybePort(n)
1117                 ip := net.ParseIP(hmp.Host)
1118                 if ip == nil {
1119                         log.Printf("won't add DHT node with bad IP: %q", hmp.Host)
1120                         continue
1121                 }
1122                 ni := krpc.NodeInfo{
1123                         Addr: krpc.NodeAddr{
1124                                 IP:   ip,
1125                                 Port: hmp.Port,
1126                         },
1127                 }
1128                 cl.eachDhtServer(func(s *dht.Server) {
1129                         s.AddNode(ni)
1130                 })
1131         }
1132 }
1133
1134 func (cl *Client) banPeerIP(ip net.IP) {
1135         if cl.badPeerIPs == nil {
1136                 cl.badPeerIPs = make(map[string]struct{})
1137         }
1138         cl.badPeerIPs[ip.String()] = struct{}{}
1139 }
1140
1141 func (cl *Client) newConnection(nc net.Conn) (c *connection) {
1142         c = &connection{
1143                 conn:            nc,
1144                 Choked:          true,
1145                 PeerChoked:      true,
1146                 PeerMaxRequests: 250,
1147                 writeBuffer:     new(bytes.Buffer),
1148         }
1149         c.writerCond.L = &cl.mu
1150         c.setRW(connStatsReadWriter{nc, &cl.mu, c})
1151         c.r = &rateLimitedReader{
1152                 l: cl.downloadLimit,
1153                 r: c.r,
1154         }
1155         return
1156 }
1157
1158 func (cl *Client) onDHTAnnouncePeer(ih metainfo.Hash, p dht.Peer) {
1159         cl.mu.Lock()
1160         defer cl.mu.Unlock()
1161         t := cl.torrent(ih)
1162         if t == nil {
1163                 return
1164         }
1165         t.addPeers([]Peer{{
1166                 IP:     p.IP,
1167                 Port:   p.Port,
1168                 Source: peerSourceDHTAnnouncePeer,
1169         }})
1170 }
1171
1172 func firstNotNil(ips ...net.IP) net.IP {
1173         for _, ip := range ips {
1174                 if ip != nil {
1175                         return ip
1176                 }
1177         }
1178         return nil
1179 }
1180
1181 func (cl *Client) eachListener(f func(socket) bool) {
1182         for _, s := range cl.conns {
1183                 if !f(s) {
1184                         break
1185                 }
1186         }
1187 }
1188
1189 func (cl *Client) findListener(f func(net.Listener) bool) (ret net.Listener) {
1190         cl.eachListener(func(l socket) bool {
1191                 ret = l
1192                 return !f(l)
1193         })
1194         return
1195 }
1196
1197 func (cl *Client) publicIp(peer net.IP) net.IP {
1198         // TODO: Use BEP 10 to determine how peers are seeing us.
1199         if peer.To4() != nil {
1200                 return firstNotNil(
1201                         cl.config.PublicIp4,
1202                         cl.findListenerIp(func(ip net.IP) bool { return ip.To4() != nil }),
1203                 )
1204         } else {
1205                 return firstNotNil(
1206                         cl.config.PublicIp6,
1207                         cl.findListenerIp(func(ip net.IP) bool { return ip.To4() == nil }),
1208                 )
1209         }
1210 }
1211
1212 func (cl *Client) findListenerIp(f func(net.IP) bool) net.IP {
1213         return missinggo.AddrIP(cl.findListener(func(l net.Listener) bool {
1214                 return f(missinggo.AddrIP(l.Addr()))
1215         }).Addr())
1216 }
1217
1218 // Our IP as a peer should see it.
1219 func (cl *Client) publicAddr(peer net.IP) ipPort {
1220         return ipPort{cl.publicIp(peer), uint16(cl.incomingPeerPort())}
1221 }
1222
1223 func (cl *Client) ListenAddrs() (ret []net.Addr) {
1224         cl.eachListener(func(l socket) bool {
1225                 ret = append(ret, l.Addr())
1226                 return true
1227         })
1228         return
1229 }