From 3909c6c125321d31952d0eb142820c50198ed7df Mon Sep 17 00:00:00 2001
From: Marco Vidonis <31407403+marcovidonis@users.noreply.github.com>
Date: Wed, 7 Dec 2022 22:17:33 +0000
Subject: [PATCH] Add customer headers when dialling WS connection to tracker
 (#789)

* expose WebtorrentTrackerHttpHeader field
---
 client.go                    |  5 +++--
 config.go                    |  3 +++
 webtorrent/tracker-client.go | 11 ++++++++++-
 wstracker.go                 | 21 ++++++++++++---------
 4 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/client.go b/client.go
index e68e80b1..4adf28b7 100644
--- a/client.go
+++ b/client.go
@@ -297,8 +297,9 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
 			}
 			return t.announceRequest(event), nil
 		},
-		Proxy:       cl.config.HTTPProxy,
-		DialContext: cl.config.TrackerDialContext,
+		Proxy:                      cl.config.HTTPProxy,
+		WebsocketTrackerHttpHeader: cl.config.WebsocketTrackerHttpHeader,
+		DialContext:                cl.config.TrackerDialContext,
 		OnConn: func(dc datachannel.ReadWriteCloser, dcc webtorrent.DataChannelContext) {
 			cl.lock()
 			defer cl.unlock()
diff --git a/config.go b/config.go
index 09f9bc1e..e1e6452a 100644
--- a/config.go
+++ b/config.go
@@ -117,6 +117,9 @@ type ClientConfig struct {
 	// HttpRequestDirector modifies the request before it's sent.
 	// Useful for adding authentication headers, for example
 	HttpRequestDirector func(*http.Request) error
+	// WebsocketTrackerHttpHeader returns a custom header to be used when dialing a websocket connection
+	// to the tracker. Useful for adding authentication headers
+	WebsocketTrackerHttpHeader func() http.Header
 	// Updated occasionally to when there's been some changes to client
 	// behaviour in case other clients are assuming anything of us. See also
 	// `bep20`.
diff --git a/webtorrent/tracker-client.go b/webtorrent/tracker-client.go
index 64885bf4..60cd8527 100644
--- a/webtorrent/tracker-client.go
+++ b/webtorrent/tracker-client.go
@@ -5,6 +5,7 @@ import (
 	"crypto/rand"
 	"encoding/json"
 	"fmt"
+	"net/http"
 	"sync"
 	"time"
 
@@ -40,6 +41,8 @@ type TrackerClient struct {
 	closed         bool
 	stats          TrackerClientStats
 	pingTicker     *time.Ticker
+
+	WebsocketTrackerHttpHeader func() http.Header
 }
 
 func (me *TrackerClient) Stats() TrackerClientStats {
@@ -86,7 +89,13 @@ func (tc *TrackerClient) doWebsocket() error {
 	tc.mu.Lock()
 	tc.stats.Dials++
 	tc.mu.Unlock()
-	c, _, err := tc.Dialer.Dial(tc.Url, nil)
+
+	var header http.Header
+	if tc.WebsocketTrackerHttpHeader != nil {
+		header = tc.WebsocketTrackerHttpHeader()
+	}
+
+	c, _, err := tc.Dialer.Dial(tc.Url, header)
 	if err != nil {
 		return fmt.Errorf("dialing tracker: %w", err)
 	}
diff --git a/wstracker.go b/wstracker.go
index 9b1a9201..c379dc31 100644
--- a/wstracker.go
+++ b/wstracker.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"fmt"
 	"net"
+	netHttp "net/http"
 	"net/url"
 	"sync"
 
@@ -12,7 +13,7 @@ import (
 	"github.com/pion/datachannel"
 
 	"github.com/anacrolix/torrent/tracker"
-	"github.com/anacrolix/torrent/tracker/http"
+	httpTracker "github.com/anacrolix/torrent/tracker/http"
 	"github.com/anacrolix/torrent/webtorrent"
 )
 
@@ -35,14 +36,15 @@ type refCountedWebtorrentTrackerClient struct {
 }
 
 type websocketTrackers struct {
-	PeerId             [20]byte
-	Logger             log.Logger
-	GetAnnounceRequest func(event tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error)
-	OnConn             func(datachannel.ReadWriteCloser, webtorrent.DataChannelContext)
-	mu                 sync.Mutex
-	clients            map[string]*refCountedWebtorrentTrackerClient
-	Proxy              httpTracker.ProxyFunc
-	DialContext        func(ctx context.Context, network, addr string) (net.Conn, error)
+	PeerId                     [20]byte
+	Logger                     log.Logger
+	GetAnnounceRequest         func(event tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error)
+	OnConn                     func(datachannel.ReadWriteCloser, webtorrent.DataChannelContext)
+	mu                         sync.Mutex
+	clients                    map[string]*refCountedWebtorrentTrackerClient
+	Proxy                      httpTracker.ProxyFunc
+	DialContext                func(ctx context.Context, network, addr string) (net.Conn, error)
+	WebsocketTrackerHttpHeader func() netHttp.Header
 }
 
 func (me *websocketTrackers) Get(url string, infoHash [20]byte) (*webtorrent.TrackerClient, func()) {
@@ -61,6 +63,7 @@ func (me *websocketTrackers) Get(url string, infoHash [20]byte) (*webtorrent.Tra
 				Logger: me.Logger.WithText(func(m log.Msg) string {
 					return fmt.Sprintf("tracker client for %q: %v", url, m)
 				}),
+				WebsocketTrackerHttpHeader: me.WebsocketTrackerHttpHeader,
 			},
 		}
 		value.TrackerClient.Start(func(err error) {
-- 
2.51.0