]> Sergey Matveev's repositories - btrtrc.git/blobdiff - client.go
Improve network handling and only listen networks we will use
[btrtrc.git] / client.go
index a896a97a35c89d98f85faf50f6b9c0dd47e0e62e..1fa9c2bc2950f77840adc708c62c4fd3f901b7de 100644 (file)
--- a/client.go
+++ b/client.go
@@ -225,7 +225,7 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
                }
        }
 
-       cl.conns, err = listenAll(allPeerNetworks, cl.config.ListenHost, cl.config.ListenPort, cl.config.ProxyURL, cl.firewallCallback)
+       cl.conns, err = listenAll(cl.listenNetworks(), cl.config.ListenHost, cl.config.ListenPort, cl.config.ProxyURL, cl.firewallCallback)
        if err != nil {
                return
        }
@@ -233,7 +233,7 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
        cl.LocalPort()
 
        for _, s := range cl.conns {
-               if peerNetworkEnabled(s.Addr().Network(), cl.config) {
+               if peerNetworkEnabled(parseNetworkString(s.Addr().Network()), cl.config) {
                        go cl.acceptConnections(s)
                }
        }
@@ -266,7 +266,7 @@ func (cl *Client) firewallCallback(net.Addr) bool {
        return block
 }
 
-func (cl *Client) enabledPeerNetworks() (ns []string) {
+func (cl *Client) enabledPeerNetworks() (ns []network) {
        for _, n := range allPeerNetworks {
                if peerNetworkEnabled(n, cl.config) {
                        ns = append(ns, n)
@@ -275,6 +275,31 @@ func (cl *Client) enabledPeerNetworks() (ns []string) {
        return
 }
 
+func (cl *Client) listenOnNetwork(n network) bool {
+       if n.Ipv4 && cl.config.DisableIPv4 {
+               return false
+       }
+       if n.Ipv6 && cl.config.DisableIPv6 {
+               return false
+       }
+       if n.Tcp && cl.config.DisableTCP {
+               return false
+       }
+       if n.Udp && cl.config.DisableUTP && cl.config.NoDHT {
+               return false
+       }
+       return true
+}
+
+func (cl *Client) listenNetworks() (ns []network) {
+       for _, n := range allPeerNetworks {
+               if cl.listenOnNetwork(n) {
+                       ns = append(ns, n)
+               }
+       }
+       return
+}
+
 func (cl *Client) newDhtServer(conn net.PacketConn) (s *dht.Server, err error) {
        cfg := dht.ServerConfig{
                IPBlocklist:    cl.ipBlockList,
@@ -475,26 +500,6 @@ func (cl *Client) dopplegangerAddr(addr string) bool {
        return ok
 }
 
-var allPeerNetworks = []string{"tcp4", "tcp6", "udp4", "udp6"}
-
-func peerNetworkEnabled(network string, cfg *ClientConfig) bool {
-       c := func(s string) bool {
-               return strings.Contains(network, s)
-       }
-       if cfg.DisableUTP {
-               if c("udp") || c("utp") {
-                       return false
-               }
-       }
-       if cfg.DisableTCP && c("tcp") {
-               return false
-       }
-       if cfg.DisableIPv6 && c("6") {
-               return false
-       }
-       return true
-}
-
 // Returns a connection over UTP or TCP, whichever is first to connect.
 func (cl *Client) dialFirst(ctx context.Context, addr string) dialResult {
        ctx, cancel := context.WithCancel(ctx)
@@ -507,7 +512,7 @@ func (cl *Client) dialFirst(ctx context.Context, addr string) dialResult {
                defer cl.unlock()
                cl.eachListener(func(s socket) bool {
                        network := s.Addr().Network()
-                       if peerNetworkEnabled(network, cl.config) {
+                       if peerNetworkEnabled(parseNetworkString(network), cl.config) {
                                left++
                                go func() {
                                        cte := cl.config.ConnTracker.Wait(