Sergey Matveev's repositories - btrtrc.git/blob - wstracker.go
Pool webtorrent tracker websockets at the Client level
[btrtrc.git] / wstracker.go
1 package torrent
2
3 import (
4         "fmt"
5         "net/url"
6         "sync"
7
8         "github.com/anacrolix/log"
9
10         "github.com/anacrolix/torrent/tracker"
11         "github.com/anacrolix/torrent/webtorrent"
12         "github.com/pion/datachannel"
13 )
14
// websocketTracker pairs a webtorrent tracker client with the tracker URL
// that statusLine and URL report.
type websocketTracker struct {
	// url is the tracker's URL, returned by value from URL.
	url url.URL
	// The embedded client's methods are promoted onto websocketTracker.
	*webtorrent.TrackerClient
}
19
20 func (me websocketTracker) statusLine() string {
21         return fmt.Sprintf("%q", me.url.String())
22 }
23
24 func (me websocketTracker) URL() url.URL {
25         return me.url
26 }
27
// refCountedWebtorrentTrackerClient is a pooled tracker client together with
// the number of outstanding references handed out by websocketTrackers.Get.
type refCountedWebtorrentTrackerClient struct {
	// Held by value; Get hands out a pointer to this field.
	webtorrent.TrackerClient
	// refCount is the number of references currently held. Get and its
	// release closure change it only while websocketTrackers.mu is held.
	refCount int
}
32
// websocketTrackers pools webtorrent tracker clients by URL, so that all users
// of the same tracker URL share a single client. Get creates and starts a
// client on first use and closes it once the last reference is released.
type websocketTrackers struct {
	// PeerId, Logger, GetAnnounceRequest and OnConn are copied into every
	// TrackerClient that Get creates.
	PeerId             [20]byte
	Logger             log.Logger
	GetAnnounceRequest func(event tracker.AnnounceEvent, infoHash [20]byte) tracker.AnnounceRequest
	OnConn             func(datachannel.ReadWriteCloser, webtorrent.DataChannelContext)
	// mu guards clients and the refCount of every entry in it.
	mu                 sync.Mutex
	// clients maps a tracker URL to its pooled client. It is nil until Get
	// first needs it.
	clients            map[string]*refCountedWebtorrentTrackerClient
}
41
42 func (me *websocketTrackers) Get(url string) (*webtorrent.TrackerClient, func()) {
43         me.mu.Lock()
44         defer me.mu.Unlock()
45         value, ok := me.clients[url]
46         if !ok {
47                 value = &refCountedWebtorrentTrackerClient{
48                         TrackerClient: webtorrent.TrackerClient{
49                                 Url:                url,
50                                 GetAnnounceRequest: me.GetAnnounceRequest,
51                                 PeerId:             me.PeerId,
52                                 OnConn:             me.OnConn,
53                                 Logger: me.Logger.WithText(func(m log.Msg) string {
54                                         return fmt.Sprintf("tracker client for %q: %v", url, m)
55                                 }),
56                         },
57                 }
58                 go func() {
59                         err := value.TrackerClient.Run()
60                         if err != nil {
61                                 me.Logger.Printf("error running tracker client for %q: %v", url, err)
62                         }
63                 }()
64                 if me.clients == nil {
65                         me.clients = make(map[string]*refCountedWebtorrentTrackerClient)
66                 }
67                 me.clients[url] = value
68         }
69         value.refCount++
70         return &value.TrackerClient, func() {
71                 me.mu.Lock()
72                 defer me.mu.Unlock()
73                 value.refCount--
74                 if value.refCount == 0 {
75                         value.TrackerClient.Close()
76                         delete(me.clients, url)
77                 }
78         }
79 }