From: Matt Joiner Date: Mon, 21 Jun 2021 03:29:26 +0000 (+1000) Subject: Tidy up the Dialer interface X-Git-Tag: v1.29.0~23 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=b9c7d6266b5b4612696f0f1aa3b18ee1f55343aa;p=btrtrc.git Tidy up the Dialer interface --- diff --git a/dialer.go b/dialer.go index e8126bd6..d499af30 100644 --- a/dialer.go +++ b/dialer.go @@ -3,43 +3,32 @@ package torrent import ( "context" "net" - - "github.com/anacrolix/missinggo/perf" ) +// Dialers have the network locked in. type Dialer interface { Dial(_ context.Context, addr string) (net.Conn, error) DialerNetwork() string } -type NetDialer struct { - Network string - Dialer net.Dialer +// An interface to ease wrapping dialers that explicitly include a network parameter. +type DialContexter interface { + DialContext(ctx context.Context, network, addr string) (net.Conn, error) } -func (me NetDialer) DialerNetwork() string { - return me.Network -} +// Used by wrappers of standard library network types. +var DefaultNetDialer = &net.Dialer{} -func (me NetDialer) Dial(ctx context.Context, addr string) (_ net.Conn, err error) { - defer perf.ScopeTimerErr(&err)() - return me.Dialer.DialContext(ctx, me.Network, addr) -} - -func (me NetDialer) LocalAddr() net.Addr { - return netDialerLocalAddr{me.Network, me.Dialer.LocalAddr} +// Adapts a DialContexter to the Dial interface in this package. +type NetworkDialer struct { + Network string + Dialer DialContexter } -type netDialerLocalAddr struct { - network string - addr net.Addr +func (me NetworkDialer) DialerNetwork() string { + return me.Network } -func (me netDialerLocalAddr) Network() string { return me.network } - -func (me netDialerLocalAddr) String() string { - if me.addr == nil { - return "" - } - return me.addr.String() +func (me NetworkDialer) Dial(ctx context.Context, addr string) (_ net.Conn, err error) { + return me.Dialer.DialContext(ctx, me.Network, addr) } diff --git a/socket.go b/socket.go index ba2a091b..7313f632 100644 --- a/socket.go +++ b/socket.go @@ -39,15 +39,16 @@ func listenTcp(network, address string) (s socket, err error) { l, err := net.Listen(network, address) return tcpSocket{ Listener: l, - NetDialer: NetDialer{ + NetworkDialer: NetworkDialer{ Network: network, + Dialer: DefaultNetDialer, }, }, err } type tcpSocket struct { net.Listener - NetDialer + NetworkDialer } func listenAll(networks []network, getHost func(string) string, port int, f firewallCallback) ([]socket, error) { diff --git a/test/unix_test.go b/test/unix_test.go index 1e877c0d..d8a3ff9f 100644 --- a/test/unix_test.go +++ b/test/unix_test.go @@ -24,7 +24,7 @@ func TestUnixConns(t *testing.T) { cfg.Debug = true }, Client: func(cl *torrent.Client) { - cl.AddDialer(torrent.NetDialer{Network: "unix"}) + cl.AddDialer(torrent.NetworkDialer{Network: "unix", Dialer: torrent.DefaultNetDialer}) l, err := net.Listen("unix", filepath.Join(t.TempDir(), "socket")) if err != nil { panic(err)