]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Don't dial in UDP tracking
authorMatt Joiner <anacrolix@gmail.com>
Mon, 22 Nov 2021 07:05:50 +0000 (18:05 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Mon, 22 Nov 2021 07:11:09 +0000 (18:11 +1100)
This could fix an issue where tracker addresses change, but we're already bound to a particular address and so fail to receive any more responses.
It should also make it easier to share UDP sockets between UDP tracker clients, although that's not currently implemented.

tracker/udp/client.go
tracker/udp/conn-client.go
tracker/udp/dispatcher.go
tracker/udp_test.go

index 8b072e62fef4b8db7b3027814ce7da4cb8a98293..dc67a9001fd43aec5bc915f1acd1832b042d5627 100644 (file)
@@ -7,8 +7,11 @@ import (
        "errors"
        "fmt"
        "io"
+       "net"
        "sync"
        "time"
+
+       "github.com/anacrolix/dht/v2/krpc"
 )
 
 // Client interacts with UDP trackers via its Writer and Dispatcher. It has no knowledge of
@@ -22,11 +25,16 @@ type Client struct {
 }
 
 func (cl *Client) Announce(
-       ctx context.Context, req AnnounceRequest, peers AnnounceResponsePeers, opts Options,
+       ctx context.Context, req AnnounceRequest, opts Options,
+       // Decides whether the response body is IPv6 or IPv4, see BEP 15.
+       ipv6 func(net.Addr) bool,
 ) (
-       respHdr AnnounceResponseHeader, err error,
+       respHdr AnnounceResponseHeader,
+       // A slice of krpc.NodeAddr, likely wrapped in an appropriate unmarshalling wrapper.
+       peers AnnounceResponsePeers,
+       err error,
 ) {
-       respBody, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...))
+       respBody, addr, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...))
        if err != nil {
                return
        }
@@ -36,6 +44,11 @@ func (cl *Client) Announce(
                err = fmt.Errorf("reading response header: %w", err)
                return
        }
+       if ipv6(addr) {
+               peers = &krpc.CompactIPv6NodeAddrs{}
+       } else {
+               peers = &krpc.CompactIPv4NodeAddrs{}
+       }
        err = peers.UnmarshalBinary(r.Bytes())
        if err != nil {
                err = fmt.Errorf("reading response peers: %w", err)
@@ -43,13 +56,13 @@ func (cl *Client) Announce(
        return
 }
 
+// There's no way to pass options in a scrape, since we don't when the request body ends.
 func (cl *Client) Scrape(
        ctx context.Context, ihs []InfoHash,
 ) (
        out ScrapeResponse, err error,
 ) {
-       // There's no way to pass options in a scrape, since we don't when the request body ends.
-       respBody, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
+       respBody, _, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
        if err != nil {
                return
        }
@@ -77,7 +90,7 @@ func (cl *Client) connect(ctx context.Context) (err error) {
        if !cl.connIdIssued.IsZero() && time.Since(cl.connIdIssued) < time.Minute {
                return nil
        }
-       respBody, err := cl.request(ctx, ActionConnect, nil)
+       respBody, _, err := cl.request(ctx, ActionConnect, nil)
        if err != nil {
                return err
        }
@@ -134,7 +147,7 @@ func (cl *Client) requestWriter(ctx context.Context, action Action, body []byte,
        }
 }
 
-func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, err error) {
+func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, addr net.Addr, err error) {
        respChan := make(chan DispatchedResponse, 1)
        t := cl.Dispatcher.NewTransaction(func(dr DispatchedResponse) {
                respChan <- dr
@@ -150,6 +163,7 @@ func (cl *Client) request(ctx context.Context, action Action, body []byte) (resp
        case dr := <-respChan:
                if dr.Header.Action == action {
                        respBody = dr.Body
+                       addr = dr.Addr
                } else if dr.Header.Action == ActionError {
                        err = errors.New(string(dr.Body))
                } else {
index 81e139c96151fbd856d95b5bd528a9456193e1db..511c24fe203aa3f4e7c386eedc408cde87003ea1 100644 (file)
@@ -2,9 +2,9 @@ package udp
 
 import (
        "context"
+       "log"
        "net"
 
-       "github.com/anacrolix/dht/v2/krpc"
        "github.com/anacrolix/missinggo/v2"
 )
 
@@ -20,30 +20,31 @@ type NewConnClientOpts struct {
 // Manages a Client with a specific connection.
 type ConnClient struct {
        Client  Client
-       conn    net.Conn
+       conn    net.PacketConn
        d       Dispatcher
        readErr error
-       ipv6    bool
+       closed  bool
+       newOpts NewConnClientOpts
 }
 
 func (cc *ConnClient) reader() {
        b := make([]byte, 0x800)
        for {
-               n, err := cc.conn.Read(b)
+               n, addr, err := cc.conn.ReadFrom(b)
                if err != nil {
                        // TODO: Do bad things to the dispatcher, and incoming calls to the client if we have a
                        // read error.
                        cc.readErr = err
                        break
                }
-               _ = cc.d.Dispatch(b[:n])
-               // if err != nil {
-               //      log.Printf("dispatching packet received on %v (%q): %v", cc.conn, string(b[:n]), err)
-               // }
+               err = cc.d.Dispatch(b[:n], addr)
+               if err != nil {
+                       log.Printf("dispatching packet received on %v (%q): %v", cc.conn, string(b[:n]), err)
+               }
        }
 }
 
-func ipv6(opt *bool, network string, conn net.Conn) bool {
+func ipv6(opt *bool, network string, remoteAddr net.Addr) bool {
        if opt != nil {
                return *opt
        }
@@ -53,21 +54,40 @@ func ipv6(opt *bool, network string, conn net.Conn) bool {
        case "udp6":
                return true
        }
-       rip := missinggo.AddrIP(conn.RemoteAddr())
+       rip := missinggo.AddrIP(remoteAddr)
        return rip.To16() != nil && rip.To4() == nil
 }
 
+// Allows a UDP Client to write packets to an endpoint without knowing about the network specifics.
+type clientWriter struct {
+       pc      net.PacketConn
+       network string
+       address string
+}
+
+func (me clientWriter) Write(p []byte) (n int, err error) {
+       addr, err := net.ResolveUDPAddr(me.network, me.address)
+       if err != nil {
+               return
+       }
+       return me.pc.WriteTo(p, addr)
+}
+
 func NewConnClient(opts NewConnClientOpts) (cc *ConnClient, err error) {
-       conn, err := net.Dial(opts.Network, opts.Host)
+       conn, err := net.ListenPacket(opts.Network, ":0")
        if err != nil {
                return
        }
        cc = &ConnClient{
                Client: Client{
-                       Writer: conn,
+                       Writer: clientWriter{
+                               pc:      conn,
+                               network: opts.Network,
+                               address: opts.Host,
+                       },
                },
-               conn: conn,
-               ipv6: ipv6(opts.Ipv6, opts.Network, conn),
+               conn:    conn,
+               newOpts: opts,
        }
        cc.Client.Dispatcher = &cc.d
        go cc.reader()
@@ -75,6 +95,7 @@ func NewConnClient(opts NewConnClientOpts) (cc *ConnClient, err error) {
 }
 
 func (c *ConnClient) Close() error {
+       c.closed = true
        return c.conn.Close()
 }
 
@@ -83,13 +104,7 @@ func (c *ConnClient) Announce(
 ) (
        h AnnounceResponseHeader, nas AnnounceResponsePeers, err error,
 ) {
-       nas = func() AnnounceResponsePeers {
-               if c.ipv6 {
-                       return &krpc.CompactIPv6NodeAddrs{}
-               } else {
-                       return &krpc.CompactIPv4NodeAddrs{}
-               }
-       }()
-       h, err = c.Client.Announce(ctx, req, nas, opts)
-       return
+       return c.Client.Announce(ctx, req, opts, func(addr net.Addr) bool {
+               return ipv6(c.newOpts.Ipv6, c.newOpts.Network, addr)
+       })
 }
index 7fc3e1b346e68c88d9523ff89294242ac758a4b2..5709bd55491de442b94894a084c364dd06b52d91 100644 (file)
@@ -3,6 +3,7 @@ package udp
 import (
        "bytes"
        "fmt"
+       "net"
        "sync"
 )
 
@@ -13,7 +14,7 @@ type Dispatcher struct {
 }
 
 // The caller owns b.
-func (me *Dispatcher) Dispatch(b []byte) error {
+func (me *Dispatcher) Dispatch(b []byte, addr net.Addr) error {
        buf := bytes.NewBuffer(b)
        var rh ResponseHeader
        err := Read(buf, &rh)
@@ -26,6 +27,7 @@ func (me *Dispatcher) Dispatch(b []byte) error {
                t.h(DispatchedResponse{
                        Header: rh,
                        Body:   append([]byte(nil), buf.Bytes()...),
+                       Addr:   addr,
                })
                return nil
        } else {
@@ -62,5 +64,8 @@ func (me *Dispatcher) NewTransaction(h TransactionResponseHandler) Transaction {
 
 type DispatchedResponse struct {
        Header ResponseHeader
-       Body   []byte
+       // Response payload, after the header.
+       Body []byte
+       // Response source address
+       Addr net.Addr
 }
index 323047306f02c6f7ac60cfc1b6e698b712de5f79..7354063b69dd664569341b3c3c1131135daa5d52 100644 (file)
@@ -40,7 +40,7 @@ func TestAnnounceLocalhost(t *testing.T) {
                },
        }
        var err error
-       srv.pc, err = net.ListenPacket("udp", ":0")
+       srv.pc, err = net.ListenPacket("udp", "localhost:0")
        require.NoError(t, err)
        defer srv.pc.Close()
        go func() {
@@ -92,7 +92,7 @@ func TestUDPTracker(t *testing.T) {
                t.Skip(err)
        }
        require.NoError(t, err)
-       t.Log(ar)
+       t.Logf("%+v", ar)
 }
 
 func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
@@ -143,7 +143,7 @@ func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
 
 // Check that URLPath option is done correctly.
 func TestURLPathOption(t *testing.T) {
-       conn, err := net.ListenUDP("udp", nil)
+       conn, err := net.ListenPacket("udp", "localhost:0")
        if err != nil {
                panic(err)
        }
@@ -161,6 +161,7 @@ func TestURLPathOption(t *testing.T) {
                announceErr <- err
        }()
        var b [512]byte
+       // conn.SetReadDeadline(time.Now().Add(time.Second))
        _, addr, _ := conn.ReadFrom(b[:])
        r := bytes.NewReader(b[:])
        var h udp.RequestHeader