From 3909c6c125321d31952d0eb142820c50198ed7df Mon Sep 17 00:00:00 2001
From: Marco Vidonis <31407403+marcovidonis@users.noreply.github.com>
Date: Wed, 7 Dec 2022 22:17:33 +0000
Subject: [PATCH] Add customer headers when dialling WS connection to tracker
(#789)
* expose WebtorrentTrackerHttpHeader field
---
client.go | 5 +++--
config.go | 3 +++
webtorrent/tracker-client.go | 11 ++++++++++-
wstracker.go | 21 ++++++++++++---------
4 files changed, 28 insertions(+), 12 deletions(-)
diff --git a/client.go b/client.go
index e68e80b1..4adf28b7 100644
--- a/client.go
+++ b/client.go
@@ -297,8 +297,9 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
}
return t.announceRequest(event), nil
},
- Proxy: cl.config.HTTPProxy,
- DialContext: cl.config.TrackerDialContext,
+ Proxy: cl.config.HTTPProxy,
+ WebsocketTrackerHttpHeader: cl.config.WebsocketTrackerHttpHeader,
+ DialContext: cl.config.TrackerDialContext,
OnConn: func(dc datachannel.ReadWriteCloser, dcc webtorrent.DataChannelContext) {
cl.lock()
defer cl.unlock()
diff --git a/config.go b/config.go
index 09f9bc1e..e1e6452a 100644
--- a/config.go
+++ b/config.go
@@ -117,6 +117,9 @@ type ClientConfig struct {
// HttpRequestDirector modifies the request before it's sent.
// Useful for adding authentication headers, for example
HttpRequestDirector func(*http.Request) error
+ // WebsocketTrackerHttpHeader returns a custom header to be used when dialing a websocket connection
+ // to the tracker. Useful for adding authentication headers
+ WebsocketTrackerHttpHeader func() http.Header
// Updated occasionally to when there's been some changes to client
// behaviour in case other clients are assuming anything of us. See also
// `bep20`.
diff --git a/webtorrent/tracker-client.go b/webtorrent/tracker-client.go
index 64885bf4..60cd8527 100644
--- a/webtorrent/tracker-client.go
+++ b/webtorrent/tracker-client.go
@@ -5,6 +5,7 @@ import (
"crypto/rand"
"encoding/json"
"fmt"
+ "net/http"
"sync"
"time"
@@ -40,6 +41,8 @@ type TrackerClient struct {
closed bool
stats TrackerClientStats
pingTicker *time.Ticker
+
+ WebsocketTrackerHttpHeader func() http.Header
}
func (me *TrackerClient) Stats() TrackerClientStats {
@@ -86,7 +89,13 @@ func (tc *TrackerClient) doWebsocket() error {
tc.mu.Lock()
tc.stats.Dials++
tc.mu.Unlock()
- c, _, err := tc.Dialer.Dial(tc.Url, nil)
+
+ var header http.Header
+ if tc.WebsocketTrackerHttpHeader != nil {
+ header = tc.WebsocketTrackerHttpHeader()
+ }
+
+ c, _, err := tc.Dialer.Dial(tc.Url, header)
if err != nil {
return fmt.Errorf("dialing tracker: %w", err)
}
diff --git a/wstracker.go b/wstracker.go
index 9b1a9201..c379dc31 100644
--- a/wstracker.go
+++ b/wstracker.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
+ netHttp "net/http"
"net/url"
"sync"
@@ -12,7 +13,7 @@ import (
"github.com/pion/datachannel"
"github.com/anacrolix/torrent/tracker"
- "github.com/anacrolix/torrent/tracker/http"
+ httpTracker "github.com/anacrolix/torrent/tracker/http"
"github.com/anacrolix/torrent/webtorrent"
)
@@ -35,14 +36,15 @@ type refCountedWebtorrentTrackerClient struct {
}
type websocketTrackers struct {
- PeerId [20]byte
- Logger log.Logger
- GetAnnounceRequest func(event tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error)
- OnConn func(datachannel.ReadWriteCloser, webtorrent.DataChannelContext)
- mu sync.Mutex
- clients map[string]*refCountedWebtorrentTrackerClient
- Proxy httpTracker.ProxyFunc
- DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
+ PeerId [20]byte
+ Logger log.Logger
+ GetAnnounceRequest func(event tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error)
+ OnConn func(datachannel.ReadWriteCloser, webtorrent.DataChannelContext)
+ mu sync.Mutex
+ clients map[string]*refCountedWebtorrentTrackerClient
+ Proxy httpTracker.ProxyFunc
+ DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
+ WebsocketTrackerHttpHeader func() netHttp.Header
}
func (me *websocketTrackers) Get(url string, infoHash [20]byte) (*webtorrent.TrackerClient, func()) {
@@ -61,6 +63,7 @@ func (me *websocketTrackers) Get(url string, infoHash [20]byte) (*webtorrent.Tra
Logger: me.Logger.WithText(func(m log.Msg) string {
return fmt.Sprintf("tracker client for %q: %v", url, m)
}),
+ WebsocketTrackerHttpHeader: me.WebsocketTrackerHttpHeader,
},
}
value.TrackerClient.Start(func(err error) {
--
2.51.0