]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Support alternate remote host resolution
authorMatt Joiner <anacrolix@gmail.com>
Tue, 6 Dec 2022 04:59:06 +0000 (15:59 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Tue, 6 Dec 2022 23:45:12 +0000 (10:45 +1100)
tracker/http/server/server.go

index 88457826e1bbb819ab97a94243558ad6106cc788..0840fe96cf0a205e46bd78817513381f52014de7 100644 (file)
@@ -6,6 +6,7 @@ import (
        "net/http"
        "net/netip"
        "net/url"
+       "strconv"
 
        "github.com/anacrolix/dht/v2/krpc"
        "github.com/anacrolix/log"
@@ -18,6 +19,9 @@ import (
 
 type Handler struct {
        AnnounceTracker udpTrackerServer.AnnounceTracker
+       // Called to derive an announcer's IP if non-nil. If not specified, the Request.RemoteAddr is
+       // used. Necessary for instances running behind reverse proxies for example.
+       RequestHost func(r *http.Request) (netip.Addr, error)
 }
 
 func unmarshalQueryKeyToArray(w http.ResponseWriter, key string, query url.Values) (ret [20]byte, ok bool) {
@@ -31,6 +35,20 @@ func unmarshalQueryKeyToArray(w http.ResponseWriter, key string, query url.Value
        return
 }
 
+var Logger = log.NewLogger("anacrolix", "torrent", "tracker", "http", "server")
+
+// Returns false if there was an error and it was served.
+func (me Handler) requestHostAddr(r *http.Request) (_ netip.Addr, err error) {
+       if me.RequestHost != nil {
+               return me.RequestHost(r)
+       }
+       host, _, err := net.SplitHostPort(r.RemoteAddr)
+       if err != nil {
+               return
+       }
+       return netip.ParseAddr(host)
+}
+
 func (me Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
        vs := r.URL.Query()
        var event tracker.AnnounceEvent
@@ -47,13 +65,15 @@ func (me Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
        if !ok {
                return
        }
-       host, _, err := net.SplitHostPort(r.RemoteAddr)
+       Logger.WithNames("request").Levelf(log.Debug, "request RemoteAddr=%q, header=%q", r.RemoteAddr, r.Header)
+       addr, err := me.requestHostAddr(r)
        if err != nil {
-               log.Printf("error splitting remote port: %v", err)
-               http.Error(w, "error determining your IP", http.StatusInternalServerError)
+               log.Printf("error getting requester IP: %v", err)
+               http.Error(w, "error determining your IP", http.StatusBadGateway)
                return
        }
-       addrPort, err := netip.ParseAddrPort(net.JoinHostPort(host, vs.Get("port")))
+       portU64, err := strconv.ParseUint(vs.Get("port"), 0, 16)
+       addrPort := netip.AddrPortFrom(addr, uint16(portU64))
        err = me.AnnounceTracker.TrackAnnounce(r.Context(), tracker.AnnounceRequest{
                InfoHash: infoHash,
                PeerId:   peerId,