]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Split Client dialers and listeners
authorMatt Joiner <anacrolix@gmail.com>
Wed, 19 Feb 2020 23:57:02 +0000 (10:57 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Thu, 20 Feb 2020 00:10:05 +0000 (11:10 +1100)
client.go
client_test.go
socket.go

index 8c1eb87f2a57413c3299d49c746fd4f0cc57d8ca..60bd27e88b535189147ec5997a281c7b051d89ed 100644 (file)
--- a/client.go
+++ b/client.go
@@ -58,7 +58,8 @@ type Client struct {
        peerID         PeerID
        defaultStorage *storage.Client
        onClose        []func()
-       conns          []socket
+       dialers        []dialer
+       listeners      []listener
        dhtServers     []*dht.Server
        ipBlockList    iplist.Ranger
        // Our BitTorrent protocol extension bytes, sent in our BT handshakes.
@@ -92,7 +93,7 @@ func (cl *Client) PeerID() PeerID {
 }
 
 func (cl *Client) LocalPort() (port int) {
-       cl.eachListener(func(l socket) bool {
+       cl.eachListener(func(l listener) bool {
                _port := missinggo.AddrPort(l.Addr())
                if _port == 0 {
                        panic(l)
@@ -227,28 +228,34 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
                }
        }
 
-       cl.conns, err = listenAll(cl.listenNetworks(), cl.config.ListenHost, cl.config.ListenPort, cl.config.ProxyURL, cl.firewallCallback)
+       sockets, err := listenAll(cl.listenNetworks(), cl.config.ListenHost, cl.config.ListenPort, cl.firewallCallback)
        if err != nil {
                return
        }
+
        // Check for panics.
        cl.LocalPort()
 
-       for _, s := range cl.conns {
+       for _, _s := range sockets {
+               s := _s // Go is fucking retarded.
+               cl.onClose = append(cl.onClose, func() { s.Close() })
                if peerNetworkEnabled(parseNetworkString(s.Addr().Network()), cl.config) {
+                       cl.dialers = append(cl.dialers, s)
+                       cl.listeners = append(cl.listeners, s)
                        go cl.acceptConnections(s)
                }
        }
 
        go cl.forwardPort()
        if !cfg.NoDHT {
-               for _, s := range cl.conns {
+               for _, s := range sockets {
                        if pc, ok := s.(net.PacketConn); ok {
                                ds, err := cl.newDhtServer(pc)
                                if err != nil {
                                        panic(err)
                                }
                                cl.dhtServers = append(cl.dhtServers, ds)
+                               cl.onClose = append(cl.onClose, func() { ds.Close() })
                        }
                }
        }
@@ -334,27 +341,17 @@ func (cl *Client) eachDhtServer(f func(*dht.Server)) {
        }
 }
 
-func (cl *Client) closeSockets() {
-       cl.eachListener(func(l socket) bool {
-               l.Close()
-               return true
-       })
-       cl.conns = nil
-}
-
 // Stops the client. All connections to peers are closed and all activity will
 // come to a halt.
 func (cl *Client) Close() {
        cl.lock()
        defer cl.unlock()
        cl.closed.Set()
-       cl.eachDhtServer(func(s *dht.Server) { s.Close() })
-       cl.closeSockets()
        for _, t := range cl.torrents {
                t.close()
        }
-       for _, f := range cl.onClose {
-               f()
+       for i := range cl.onClose {
+               cl.onClose[len(cl.onClose)-1-i]()
        }
        cl.event.Broadcast()
 }
@@ -521,18 +518,14 @@ func (cl *Client) dialFirst(ctx context.Context, addr string) (res dialResult) {
        func() {
                cl.lock()
                defer cl.unlock()
-               cl.eachListener(func(s socket) bool {
+               cl.eachDialer(func(s dialer) bool {
                        func() {
-                               network := s.Addr().Network()
-                               if !peerNetworkEnabled(parseNetworkString(network), cl.config) {
-                                       return
-                               }
                                left++
                                //cl.logger.Printf("dialing %s on %s/%s", addr, s.Addr().Network(), s.Addr())
                                go func() {
                                        resCh <- dialResult{
                                                cl.dialFromSocket(ctx, s, addr),
-                                               network,
+                                               s.LocalAddr().Network(),
                                        }
                                }()
                        }()
@@ -566,11 +559,11 @@ func (cl *Client) dialFirst(ctx context.Context, addr string) (res dialResult) {
        return res
 }
 
-func (cl *Client) dialFromSocket(ctx context.Context, s socket, addr string) net.Conn {
-       network := s.Addr().Network()
+func (cl *Client) dialFromSocket(ctx context.Context, s dialer, addr string) net.Conn {
+       network := s.LocalAddr().Network()
        cte := cl.config.ConnTracker.Wait(
                ctx,
-               conntrack.Entry{network, s.Addr().String(), addr},
+               conntrack.Entry{network, s.LocalAddr().String(), addr},
                "dial torrent client",
                0,
        )
@@ -1264,8 +1257,16 @@ func firstNotNil(ips ...net.IP) net.IP {
        return nil
 }
 
-func (cl *Client) eachListener(f func(socket) bool) {
-       for _, s := range cl.conns {
+func (cl *Client) eachDialer(f func(dialer) bool) {
+       for _, s := range cl.dialers {
+               if !f(s) {
+                       break
+               }
+       }
+}
+
+func (cl *Client) eachListener(f func(listener) bool) {
+       for _, s := range cl.listeners {
                if !f(s) {
                        break
                }
@@ -1273,7 +1274,7 @@ func (cl *Client) eachListener(f func(socket) bool) {
 }
 
 func (cl *Client) findListener(f func(net.Listener) bool) (ret net.Listener) {
-       cl.eachListener(func(l socket) bool {
+       cl.eachListener(func(l listener) bool {
                ret = l
                return !f(l)
        })
@@ -1310,7 +1311,7 @@ func (cl *Client) publicAddr(peer net.IP) IpPort {
 func (cl *Client) ListenAddrs() (ret []net.Addr) {
        cl.lock()
        defer cl.unlock()
-       cl.eachListener(func(l socket) bool {
+       cl.eachListener(func(l listener) bool {
                ret = append(ret, l.Addr())
                return true
        })
index 0278d42c41268051a6d7235ff691fcbec067f208..9e070eabd4aa27e2a6d99bc8d0b1bdff501ea5d0 100644 (file)
@@ -910,7 +910,7 @@ func TestClientDynamicListenPortAllProtocols(t *testing.T) {
        defer cl.Close()
        port := cl.LocalPort()
        assert.NotEqual(t, 0, port)
-       cl.eachListener(func(s socket) bool {
+       cl.eachListener(func(s listener) bool {
                assert.Equal(t, port, missinggo.AddrPort(s.Addr()))
                return true
        })
index d61e3d32819763aa42aa7a4008be99f264d031bf..c5f7dbcd96209a124f38a473a0c3e7f5168b38a8 100644 (file)
--- a/socket.go
+++ b/socket.go
@@ -2,95 +2,72 @@ package torrent
 
 import (
        "context"
-       "fmt"
        "net"
-       "net/url"
        "strconv"
 
        "github.com/anacrolix/missinggo"
        "github.com/anacrolix/missinggo/perf"
        "github.com/pkg/errors"
-       "golang.org/x/net/proxy"
 )
 
 type dialer interface {
        dial(_ context.Context, addr string) (net.Conn, error)
+       LocalAddr() net.Addr
 }
 
-type socket interface {
+type listener interface {
        net.Listener
-       dialer
 }
 
-func getProxyDialer(proxyURL string) (proxy.Dialer, error) {
-       fixedURL, err := url.Parse(proxyURL)
-       if err != nil {
-               return nil, err
-       }
-
-       return proxy.FromURL(fixedURL, proxy.Direct)
+type socket interface {
+       listener
+       dialer
 }
 
-func listen(n network, addr, proxyURL string, f firewallCallback) (socket, error) {
+func listen(n network, addr string, f firewallCallback) (socket, error) {
        switch {
        case n.Tcp:
-               return listenTcp(n.String(), addr, proxyURL)
+               return listenTcp(n.String(), addr)
        case n.Udp:
-               return listenUtp(n.String(), addr, proxyURL, f)
+               return listenUtp(n.String(), addr, f)
        default:
                panic(n)
        }
 }
 
-func listenTcp(network, address, proxyURL string) (s socket, err error) {
+func listenTcp(network, address string) (s socket, err error) {
        l, err := net.Listen(network, address)
-       if err != nil {
-               return
-       }
-       defer func() {
-               if err != nil {
-                       l.Close()
-               }
-       }()
-
-       // If we don't need the proxy - then we should return default net.Dialer,
-       // otherwise, let's try to parse the proxyURL and return proxy.Dialer
-       if len(proxyURL) != 0 {
-               dl := disabledListener{l}
-               dialer, err := getProxyDialer(proxyURL)
-               if err != nil {
-                       return nil, err
-               }
-               return tcpSocket{dl, func(ctx context.Context, addr string) (conn net.Conn, err error) {
-                       defer perf.ScopeTimerErr(&err)()
-                       return dialer.Dial(network, addr)
-               }}, nil
-       }
-       dialer := net.Dialer{}
-       return tcpSocket{l, func(ctx context.Context, addr string) (conn net.Conn, err error) {
-               defer perf.ScopeTimerErr(&err)()
-               return dialer.DialContext(ctx, network, addr)
-       }}, nil
+       return tcpSocket{
+               Listener: l,
+               network:  network,
+       }, err
 }
 
-type disabledListener struct {
+type tcpSocket struct {
        net.Listener
+       network string
+       dialer  net.Dialer
 }
 
-func (dl disabledListener) Accept() (net.Conn, error) {
-       return nil, fmt.Errorf("tcp listener disabled due to proxy")
+func (me tcpSocket) dial(ctx context.Context, addr string) (_ net.Conn, err error) {
+       defer perf.ScopeTimerErr(&err)()
+       return me.dialer.DialContext(ctx, me.network, addr)
 }
 
-type tcpSocket struct {
-       net.Listener
-       d func(ctx context.Context, addr string) (net.Conn, error)
+func (me tcpSocket) LocalAddr() net.Addr {
+       return tcpSocketLocalAddr{me.network, me.Listener.Addr().String()}
 }
 
-func (me tcpSocket) dial(ctx context.Context, addr string) (net.Conn, error) {
-       return me.d(ctx, addr)
+type tcpSocketLocalAddr struct {
+       network string
+       s       string
 }
 
-func listenAll(networks []network, getHost func(string) string, port int, proxyURL string, f firewallCallback) ([]socket, error) {
+func (me tcpSocketLocalAddr) Network() string { return me.network }
+
+func (me tcpSocketLocalAddr) String() string { return "" }
+
+func listenAll(networks []network, getHost func(string) string, port int, f firewallCallback) ([]socket, error) {
        if len(networks) == 0 {
                return nil, nil
        }
@@ -99,7 +76,7 @@ func listenAll(networks []network, getHost func(string) string, port int, proxyU
                nahs = append(nahs, networkAndHost{n, getHost(n.String())})
        }
        for {
-               ss, retry, err := listenAllRetry(nahs, port, proxyURL, f)
+               ss, retry, err := listenAllRetry(nahs, port, f)
                if !retry {
                        return ss, err
                }
@@ -111,10 +88,10 @@ type networkAndHost struct {
        Host    string
 }
 
-func listenAllRetry(nahs []networkAndHost, port int, proxyURL string, f firewallCallback) (ss []socket, retry bool, err error) {
+func listenAllRetry(nahs []networkAndHost, port int, f firewallCallback) (ss []socket, retry bool, err error) {
        ss = make([]socket, 1, len(nahs))
        portStr := strconv.FormatInt(int64(port), 10)
-       ss[0], err = listen(nahs[0].Network, net.JoinHostPort(nahs[0].Host, portStr), proxyURL, f)
+       ss[0], err = listen(nahs[0].Network, net.JoinHostPort(nahs[0].Host, portStr), f)
        if err != nil {
                return nil, false, errors.Wrap(err, "first listen")
        }
@@ -128,7 +105,7 @@ func listenAllRetry(nahs []networkAndHost, port int, proxyURL string, f firewall
        }()
        portStr = strconv.FormatInt(int64(missinggo.AddrPort(ss[0].Addr())), 10)
        for _, nah := range nahs[1:] {
-               s, err := listen(nah.Network, net.JoinHostPort(nah.Host, portStr), proxyURL, f)
+               s, err := listen(nah.Network, net.JoinHostPort(nah.Host, portStr), f)
                if err != nil {
                        return ss,
                                missinggo.IsAddrInUse(err) && port == 0,
@@ -141,45 +118,17 @@ func listenAllRetry(nahs []networkAndHost, port int, proxyURL string, f firewall
 
 type firewallCallback func(net.Addr) bool
 
-func listenUtp(network, addr, proxyURL string, fc firewallCallback) (s socket, err error) {
+func listenUtp(network, addr string, fc firewallCallback) (socket, error) {
        us, err := NewUtpSocket(network, addr, fc)
-       if err != nil {
-               return
-       }
-
-       // If we don't need the proxy - then we should return default net.Dialer,
-       // otherwise, let's try to parse the proxyURL and return proxy.Dialer
-       if len(proxyURL) != 0 {
-               ds := disabledUtpSocket{us}
-               dialer, err := getProxyDialer(proxyURL)
-               if err != nil {
-                       return nil, err
-               }
-               return utpSocketSocket{ds, network, dialer}, nil
-       }
-
-       return utpSocketSocket{us, network, nil}, nil
-}
-
-type disabledUtpSocket struct {
-       utpSocket
-}
-
-func (ds disabledUtpSocket) Accept() (net.Conn, error) {
-       return nil, fmt.Errorf("utp listener disabled due to proxy")
+       return utpSocketSocket{us, network}, err
 }
 
 type utpSocketSocket struct {
        utpSocket
        network string
-       d       proxy.Dialer
 }
 
 func (me utpSocketSocket) dial(ctx context.Context, addr string) (conn net.Conn, err error) {
        defer perf.ScopeTimerErr(&err)()
-       if me.d != nil {
-               return me.d.Dial(me.network, addr)
-       }
-
        return me.utpSocket.DialContext(ctx, me.network, addr)
 }