From 6cb65f4ecf424613722d4670b4d72727f40dbc71 Mon Sep 17 00:00:00 2001
From: Matt Joiner <anacrolix@gmail.com>
Date: Mon, 22 Nov 2021 18:05:50 +1100
Subject: [PATCH] Don't dial in UDP tracking

This could fix an issue where tracker addresses change, but we're already bound to a particular address and so fail to receive any more responses.
It should also make it easier to share UDP sockets between UDP tracker clients, although that's not currently implemented.
---
 tracker/udp/client.go      | 28 ++++++++++++-----
 tracker/udp/conn-client.go | 61 ++++++++++++++++++++++++--------------
 tracker/udp/dispatcher.go  |  9 ++++--
 tracker/udp_test.go        |  7 +++--
 4 files changed, 70 insertions(+), 35 deletions(-)

diff --git a/tracker/udp/client.go b/tracker/udp/client.go
index 8b072e62..dc67a900 100644
--- a/tracker/udp/client.go
+++ b/tracker/udp/client.go
@@ -7,8 +7,11 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"net"
 	"sync"
 	"time"
+
+	"github.com/anacrolix/dht/v2/krpc"
 )
 
 // Client interacts with UDP trackers via its Writer and Dispatcher. It has no knowledge of
@@ -22,11 +25,16 @@ type Client struct {
 }
 
 func (cl *Client) Announce(
-	ctx context.Context, req AnnounceRequest, peers AnnounceResponsePeers, opts Options,
+	ctx context.Context, req AnnounceRequest, opts Options,
+	// Decides whether the response body is IPv6 or IPv4, see BEP 15.
+	ipv6 func(net.Addr) bool,
 ) (
-	respHdr AnnounceResponseHeader, err error,
+	respHdr AnnounceResponseHeader,
+	// A slice of krpc.NodeAddr, likely wrapped in an appropriate unmarshalling wrapper.
+	peers AnnounceResponsePeers,
+	err error,
 ) {
-	respBody, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...))
+	respBody, addr, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...))
 	if err != nil {
 		return
 	}
@@ -36,6 +44,11 @@ func (cl *Client) Announce(
 		err = fmt.Errorf("reading response header: %w", err)
 		return
 	}
+	if ipv6(addr) {
+		peers = &krpc.CompactIPv6NodeAddrs{}
+	} else {
+		peers = &krpc.CompactIPv4NodeAddrs{}
+	}
 	err = peers.UnmarshalBinary(r.Bytes())
 	if err != nil {
 		err = fmt.Errorf("reading response peers: %w", err)
@@ -43,13 +56,13 @@ func (cl *Client) Announce(
 	return
 }
 
+// There's no way to pass options in a scrape, since we don't when the request body ends.
 func (cl *Client) Scrape(
 	ctx context.Context, ihs []InfoHash,
 ) (
 	out ScrapeResponse, err error,
 ) {
-	// There's no way to pass options in a scrape, since we don't when the request body ends.
-	respBody, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
+	respBody, _, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
 	if err != nil {
 		return
 	}
@@ -77,7 +90,7 @@ func (cl *Client) connect(ctx context.Context) (err error) {
 	if !cl.connIdIssued.IsZero() && time.Since(cl.connIdIssued) < time.Minute {
 		return nil
 	}
-	respBody, err := cl.request(ctx, ActionConnect, nil)
+	respBody, _, err := cl.request(ctx, ActionConnect, nil)
 	if err != nil {
 		return err
 	}
@@ -134,7 +147,7 @@ func (cl *Client) requestWriter(ctx context.Context, action Action, body []byte,
 	}
 }
 
-func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, err error) {
+func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, addr net.Addr, err error) {
 	respChan := make(chan DispatchedResponse, 1)
 	t := cl.Dispatcher.NewTransaction(func(dr DispatchedResponse) {
 		respChan <- dr
@@ -150,6 +163,7 @@ func (cl *Client) request(ctx context.Context, action Action, body []byte) (resp
 	case dr := <-respChan:
 		if dr.Header.Action == action {
 			respBody = dr.Body
+			addr = dr.Addr
 		} else if dr.Header.Action == ActionError {
 			err = errors.New(string(dr.Body))
 		} else {
diff --git a/tracker/udp/conn-client.go b/tracker/udp/conn-client.go
index 81e139c9..511c24fe 100644
--- a/tracker/udp/conn-client.go
+++ b/tracker/udp/conn-client.go
@@ -2,9 +2,9 @@ package udp
 
 import (
 	"context"
+	"log"
 	"net"
 
-	"github.com/anacrolix/dht/v2/krpc"
 	"github.com/anacrolix/missinggo/v2"
 )
 
@@ -20,30 +20,31 @@ type NewConnClientOpts struct {
 // Manages a Client with a specific connection.
 type ConnClient struct {
 	Client  Client
-	conn    net.Conn
+	conn    net.PacketConn
 	d       Dispatcher
 	readErr error
-	ipv6    bool
+	closed  bool
+	newOpts NewConnClientOpts
 }
 
 func (cc *ConnClient) reader() {
 	b := make([]byte, 0x800)
 	for {
-		n, err := cc.conn.Read(b)
+		n, addr, err := cc.conn.ReadFrom(b)
 		if err != nil {
 			// TODO: Do bad things to the dispatcher, and incoming calls to the client if we have a
 			// read error.
 			cc.readErr = err
 			break
 		}
-		_ = cc.d.Dispatch(b[:n])
-		// if err != nil {
-		// 	log.Printf("dispatching packet received on %v (%q): %v", cc.conn, string(b[:n]), err)
-		// }
+		err = cc.d.Dispatch(b[:n], addr)
+		if err != nil {
+			log.Printf("dispatching packet received on %v (%q): %v", cc.conn, string(b[:n]), err)
+		}
 	}
 }
 
-func ipv6(opt *bool, network string, conn net.Conn) bool {
+func ipv6(opt *bool, network string, remoteAddr net.Addr) bool {
 	if opt != nil {
 		return *opt
 	}
@@ -53,21 +54,40 @@ func ipv6(opt *bool, network string, conn net.Conn) bool {
 	case "udp6":
 		return true
 	}
-	rip := missinggo.AddrIP(conn.RemoteAddr())
+	rip := missinggo.AddrIP(remoteAddr)
 	return rip.To16() != nil && rip.To4() == nil
 }
 
+// Allows a UDP Client to write packets to an endpoint without knowing about the network specifics.
+type clientWriter struct {
+	pc      net.PacketConn
+	network string
+	address string
+}
+
+func (me clientWriter) Write(p []byte) (n int, err error) {
+	addr, err := net.ResolveUDPAddr(me.network, me.address)
+	if err != nil {
+		return
+	}
+	return me.pc.WriteTo(p, addr)
+}
+
 func NewConnClient(opts NewConnClientOpts) (cc *ConnClient, err error) {
-	conn, err := net.Dial(opts.Network, opts.Host)
+	conn, err := net.ListenPacket(opts.Network, ":0")
 	if err != nil {
 		return
 	}
 	cc = &ConnClient{
 		Client: Client{
-			Writer: conn,
+			Writer: clientWriter{
+				pc:      conn,
+				network: opts.Network,
+				address: opts.Host,
+			},
 		},
-		conn: conn,
-		ipv6: ipv6(opts.Ipv6, opts.Network, conn),
+		conn:    conn,
+		newOpts: opts,
 	}
 	cc.Client.Dispatcher = &cc.d
 	go cc.reader()
@@ -75,6 +95,7 @@ func NewConnClient(opts NewConnClientOpts) (cc *ConnClient, err error) {
 }
 
 func (c *ConnClient) Close() error {
+	c.closed = true
 	return c.conn.Close()
 }
 
@@ -83,13 +104,7 @@ func (c *ConnClient) Announce(
 ) (
 	h AnnounceResponseHeader, nas AnnounceResponsePeers, err error,
 ) {
-	nas = func() AnnounceResponsePeers {
-		if c.ipv6 {
-			return &krpc.CompactIPv6NodeAddrs{}
-		} else {
-			return &krpc.CompactIPv4NodeAddrs{}
-		}
-	}()
-	h, err = c.Client.Announce(ctx, req, nas, opts)
-	return
+	return c.Client.Announce(ctx, req, opts, func(addr net.Addr) bool {
+		return ipv6(c.newOpts.Ipv6, c.newOpts.Network, addr)
+	})
 }
diff --git a/tracker/udp/dispatcher.go b/tracker/udp/dispatcher.go
index 7fc3e1b3..5709bd55 100644
--- a/tracker/udp/dispatcher.go
+++ b/tracker/udp/dispatcher.go
@@ -3,6 +3,7 @@ package udp
 import (
 	"bytes"
 	"fmt"
+	"net"
 	"sync"
 )
 
@@ -13,7 +14,7 @@ type Dispatcher struct {
 }
 
 // The caller owns b.
-func (me *Dispatcher) Dispatch(b []byte) error {
+func (me *Dispatcher) Dispatch(b []byte, addr net.Addr) error {
 	buf := bytes.NewBuffer(b)
 	var rh ResponseHeader
 	err := Read(buf, &rh)
@@ -26,6 +27,7 @@ func (me *Dispatcher) Dispatch(b []byte) error {
 		t.h(DispatchedResponse{
 			Header: rh,
 			Body:   append([]byte(nil), buf.Bytes()...),
+			Addr:   addr,
 		})
 		return nil
 	} else {
@@ -62,5 +64,8 @@ func (me *Dispatcher) NewTransaction(h TransactionResponseHandler) Transaction {
 
 type DispatchedResponse struct {
 	Header ResponseHeader
-	Body   []byte
+	// Response payload, after the header.
+	Body []byte
+	// Response source address
+	Addr net.Addr
 }
diff --git a/tracker/udp_test.go b/tracker/udp_test.go
index 32304730..7354063b 100644
--- a/tracker/udp_test.go
+++ b/tracker/udp_test.go
@@ -40,7 +40,7 @@ func TestAnnounceLocalhost(t *testing.T) {
 		},
 	}
 	var err error
-	srv.pc, err = net.ListenPacket("udp", ":0")
+	srv.pc, err = net.ListenPacket("udp", "localhost:0")
 	require.NoError(t, err)
 	defer srv.pc.Close()
 	go func() {
@@ -92,7 +92,7 @@ func TestUDPTracker(t *testing.T) {
 		t.Skip(err)
 	}
 	require.NoError(t, err)
-	t.Log(ar)
+	t.Logf("%+v", ar)
 }
 
 func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
@@ -143,7 +143,7 @@ func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
 
 // Check that URLPath option is done correctly.
 func TestURLPathOption(t *testing.T) {
-	conn, err := net.ListenUDP("udp", nil)
+	conn, err := net.ListenPacket("udp", "localhost:0")
 	if err != nil {
 		panic(err)
 	}
@@ -161,6 +161,7 @@ func TestURLPathOption(t *testing.T) {
 		announceErr <- err
 	}()
 	var b [512]byte
+	// conn.SetReadDeadline(time.Now().Add(time.Second))
 	_, addr, _ := conn.ReadFrom(b[:])
 	r := bytes.NewReader(b[:])
 	var h udp.RequestHeader
-- 
2.51.0