Sergey Matveev's repositories - btrtrc.git/blob - wstracker.go
Include holepunch message protocol family in metrics
[btrtrc.git] / wstracker.go
1 package torrent
2
3 import (
4         "context"
5         "fmt"
6         "net"
7         netHttp "net/http"
8         "net/url"
9         "sync"
10
11         "github.com/anacrolix/log"
12         "github.com/gorilla/websocket"
13         "github.com/pion/datachannel"
14
15         "github.com/anacrolix/torrent/tracker"
16         httpTracker "github.com/anacrolix/torrent/tracker/http"
17         "github.com/anacrolix/torrent/webtorrent"
18 )
19
20 type websocketTrackerStatus struct {
21         url url.URL
22         tc  *webtorrent.TrackerClient
23 }
24
25 func (me websocketTrackerStatus) statusLine() string {
26         return fmt.Sprintf("%+v", me.tc.Stats())
27 }
28
29 func (me websocketTrackerStatus) URL() *url.URL {
30         return &me.url
31 }
32
33 type refCountedWebtorrentTrackerClient struct {
34         webtorrent.TrackerClient
35         refCount int
36 }
37
38 type websocketTrackers struct {
39         PeerId                     [20]byte
40         Logger                     log.Logger
41         GetAnnounceRequest         func(event tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error)
42         OnConn                     func(datachannel.ReadWriteCloser, webtorrent.DataChannelContext)
43         mu                         sync.Mutex
44         clients                    map[string]*refCountedWebtorrentTrackerClient
45         Proxy                      httpTracker.ProxyFunc
46         DialContext                func(ctx context.Context, network, addr string) (net.Conn, error)
47         WebsocketTrackerHttpHeader func() netHttp.Header
48 }
49
50 func (me *websocketTrackers) Get(url string, infoHash [20]byte) (*webtorrent.TrackerClient, func()) {
51         me.mu.Lock()
52         defer me.mu.Unlock()
53         value, ok := me.clients[url]
54         if !ok {
55                 dialer := &websocket.Dialer{Proxy: me.Proxy, NetDialContext: me.DialContext, HandshakeTimeout: websocket.DefaultDialer.HandshakeTimeout}
56                 value = &refCountedWebtorrentTrackerClient{
57                         TrackerClient: webtorrent.TrackerClient{
58                                 Dialer:             dialer,
59                                 Url:                url,
60                                 GetAnnounceRequest: me.GetAnnounceRequest,
61                                 PeerId:             me.PeerId,
62                                 OnConn:             me.OnConn,
63                                 Logger: me.Logger.WithText(func(m log.Msg) string {
64                                         return fmt.Sprintf("tracker client for %q: %v", url, m)
65                                 }),
66                                 WebsocketTrackerHttpHeader: me.WebsocketTrackerHttpHeader,
67                         },
68                 }
69                 value.TrackerClient.Start(func(err error) {
70                         if err != nil {
71                                 me.Logger.Printf("error running tracker client for %q: %v", url, err)
72                         }
73                 })
74                 if me.clients == nil {
75                         me.clients = make(map[string]*refCountedWebtorrentTrackerClient)
76                 }
77                 me.clients[url] = value
78         }
79         value.refCount++
80         return &value.TrackerClient, func() {
81                 me.mu.Lock()
82                 defer me.mu.Unlock()
83                 value.TrackerClient.CloseOffersForInfohash(infoHash)
84                 value.refCount--
85                 if value.refCount == 0 {
86                         value.TrackerClient.Close()
87                         delete(me.clients, url)
88                 }
89         }
90 }