]> Sergey Matveev's repositories - btrtrc.git/blob - wstracker.go
gorond ./...
[btrtrc.git] / wstracker.go
1 package torrent
2
3 import (
4         "context"
5         "fmt"
6         "net"
7         "net/url"
8         "sync"
9
10         "github.com/anacrolix/log"
11         "github.com/gorilla/websocket"
12         "github.com/pion/datachannel"
13
14         "github.com/anacrolix/torrent/tracker"
15         "github.com/anacrolix/torrent/tracker/http"
16         "github.com/anacrolix/torrent/webtorrent"
17 )
18
// websocketTrackerStatus describes a single websocket tracker: its URL and
// the webtorrent tracker client associated with it. Its methods expose the
// URL and a one-line rendering of the client's statistics.
19 type websocketTrackerStatus struct {
20         url url.URL
        // tc is dereferenced by statusLine (tc.Stats()), so it presumably
        // must be non-nil whenever a status line is requested — confirm.
21         tc  *webtorrent.TrackerClient
22 }
23
24 func (me websocketTrackerStatus) statusLine() string {
25         return fmt.Sprintf("%+v", me.tc.Stats())
26 }
27
28 func (me websocketTrackerStatus) URL() *url.URL {
29         return &me.url
30 }
31
// refCountedWebtorrentTrackerClient pairs a webtorrent.TrackerClient with the
// number of outstanding users obtained through websocketTrackers.Get.
// Get increments refCount, and the release function it returns decrements it.
// When the count reaches zero, the client is closed and forgotten.
32 type refCountedWebtorrentTrackerClient struct {
33         webtorrent.TrackerClient
        // refCount is only read and written with websocketTrackers.mu held
        // (see Get and its release function).
34         refCount int
35 }
36
// websocketTrackers hands out webtorrent tracker clients keyed by tracker URL,
// so every Get for the same URL shares a single client until all users have
// released it. The exported fields configure each client that Get creates.
37 type websocketTrackers struct {
38         PeerId             [20]byte
39         Logger             log.Logger
40         GetAnnounceRequest func(event tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error)
41         OnConn             func(datachannel.ReadWriteCloser, webtorrent.DataChannelContext)
        // mu guards clients and the refCount of every entry in it.
42         mu                 sync.Mutex
        // clients maps a tracker URL to its live client. Get allocates the map
        // lazily, so it does not need to be initialized up front.
43         clients            map[string]*refCountedWebtorrentTrackerClient
        // Proxy and DialContext are handed to the websocket.Dialer that Get
        // builds for each new client.
44         Proxy              http.ProxyFunc
45         DialContext        func(ctx context.Context, network, addr string) (net.Conn, error)
46 }
47
48 func (me *websocketTrackers) Get(url string, infoHash [20]byte) (*webtorrent.TrackerClient, func()) {
49         me.mu.Lock()
50         defer me.mu.Unlock()
51         value, ok := me.clients[url]
52         if !ok {
53                 dialer := &websocket.Dialer{Proxy: me.Proxy, NetDialContext: me.DialContext, HandshakeTimeout: websocket.DefaultDialer.HandshakeTimeout}
54                 value = &refCountedWebtorrentTrackerClient{
55                         TrackerClient: webtorrent.TrackerClient{
56                                 Dialer:             dialer,
57                                 Url:                url,
58                                 GetAnnounceRequest: me.GetAnnounceRequest,
59                                 PeerId:             me.PeerId,
60                                 OnConn:             me.OnConn,
61                                 Logger: me.Logger.WithText(func(m log.Msg) string {
62                                         return fmt.Sprintf("tracker client for %q: %v", url, m)
63                                 }),
64                         },
65                 }
66                 value.TrackerClient.Start(func(err error) {
67                         if err != nil {
68                                 me.Logger.Printf("error running tracker client for %q: %v", url, err)
69                         }
70                 })
71                 if me.clients == nil {
72                         me.clients = make(map[string]*refCountedWebtorrentTrackerClient)
73                 }
74                 me.clients[url] = value
75         }
76         value.refCount++
77         return &value.TrackerClient, func() {
78                 me.mu.Lock()
79                 defer me.mu.Unlock()
80                 value.TrackerClient.CloseOffersForInfohash(infoHash)
81                 value.refCount--
82                 if value.refCount == 0 {
83                         value.TrackerClient.Close()
84                         delete(me.clients, url)
85                 }
86         }
87 }