From: Matt Joiner <anacrolix@gmail.com>
Date: Tue, 27 Nov 2018 23:30:21 +0000 (+1100)
Subject: Improve network handling and only listen networks we will use
X-Git-Tag: v1.0.0~9
X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=d9e1ebde700e11bb534ba117605238c483598728;p=btrtrc.git

Improve network handling and only listen networks we will use

Fixes #290.
---

diff --git a/client.go b/client.go
index a896a97a..1fa9c2bc 100644
--- a/client.go
+++ b/client.go
@@ -225,7 +225,7 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
 		}
 	}
 
-	cl.conns, err = listenAll(allPeerNetworks, cl.config.ListenHost, cl.config.ListenPort, cl.config.ProxyURL, cl.firewallCallback)
+	cl.conns, err = listenAll(cl.listenNetworks(), cl.config.ListenHost, cl.config.ListenPort, cl.config.ProxyURL, cl.firewallCallback)
 	if err != nil {
 		return
 	}
@@ -233,7 +233,7 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
 	cl.LocalPort()
 
 	for _, s := range cl.conns {
-		if peerNetworkEnabled(s.Addr().Network(), cl.config) {
+		if peerNetworkEnabled(parseNetworkString(s.Addr().Network()), cl.config) {
 			go cl.acceptConnections(s)
 		}
 	}
@@ -266,7 +266,7 @@ func (cl *Client) firewallCallback(net.Addr) bool {
 	return block
 }
 
-func (cl *Client) enabledPeerNetworks() (ns []string) {
+func (cl *Client) enabledPeerNetworks() (ns []network) {
 	for _, n := range allPeerNetworks {
 		if peerNetworkEnabled(n, cl.config) {
 			ns = append(ns, n)
@@ -275,6 +275,31 @@ func (cl *Client) enabledPeerNetworks() (ns []string) {
 	return
 }
 
+func (cl *Client) listenOnNetwork(n network) bool {
+	if n.Ipv4 && cl.config.DisableIPv4 {
+		return false
+	}
+	if n.Ipv6 && cl.config.DisableIPv6 {
+		return false
+	}
+	if n.Tcp && cl.config.DisableTCP {
+		return false
+	}
+	if n.Udp && cl.config.DisableUTP && cl.config.NoDHT {
+		return false
+	}
+	return true
+}
+
+func (cl *Client) listenNetworks() (ns []network) {
+	for _, n := range allPeerNetworks {
+		if cl.listenOnNetwork(n) {
+			ns = append(ns, n)
+		}
+	}
+	return
+}
+
 func (cl *Client) newDhtServer(conn net.PacketConn) (s *dht.Server, err error) {
 	cfg := dht.ServerConfig{
 		IPBlocklist:    cl.ipBlockList,
@@ -475,26 +500,6 @@ func (cl *Client) dopplegangerAddr(addr string) bool {
 	return ok
 }
 
-var allPeerNetworks = []string{"tcp4", "tcp6", "udp4", "udp6"}
-
-func peerNetworkEnabled(network string, cfg *ClientConfig) bool {
-	c := func(s string) bool {
-		return strings.Contains(network, s)
-	}
-	if cfg.DisableUTP {
-		if c("udp") || c("utp") {
-			return false
-		}
-	}
-	if cfg.DisableTCP && c("tcp") {
-		return false
-	}
-	if cfg.DisableIPv6 && c("6") {
-		return false
-	}
-	return true
-}
-
 // Returns a connection over UTP or TCP, whichever is first to connect.
 func (cl *Client) dialFirst(ctx context.Context, addr string) dialResult {
 	ctx, cancel := context.WithCancel(ctx)
@@ -507,7 +512,7 @@ func (cl *Client) dialFirst(ctx context.Context, addr string) dialResult {
 		defer cl.unlock()
 		cl.eachListener(func(s socket) bool {
 			network := s.Addr().Network()
-			if peerNetworkEnabled(network, cl.config) {
+			if peerNetworkEnabled(parseNetworkString(network), cl.config) {
 				left++
 				go func() {
 					cte := cl.config.ConnTracker.Wait(
diff --git a/connection.go b/connection.go
index 69b80244..21f1fdc2 100644
--- a/connection.go
+++ b/connection.go
@@ -239,7 +239,7 @@ func (cn *connection) connectionFlags() (ret string) {
 }
 
 func (cn *connection) utp() bool {
-	return isUtpNetwork(cn.network)
+	return parseNetworkString(cn.network).Udp
 }
 
 // Inspired by https://github.com/transmission/transmission/wiki/Peer-Status-Text.
diff --git a/networks.go b/networks.go
new file mode 100644
index 00000000..068a9a58
--- /dev/null
+++ b/networks.go
@@ -0,0 +1,57 @@
+package torrent
+
+import "strings"
+
+var allPeerNetworks = func() (ret []network) {
+	for _, s := range []string{"tcp4", "tcp6", "udp4", "udp6"} {
+		ret = append(ret, parseNetworkString(s))
+	}
+	return
+}()
+
+type network struct {
+	Ipv4 bool
+	Ipv6 bool
+	Udp  bool
+	Tcp  bool
+}
+
+func (n network) String() (ret string) {
+	a := func(b bool, s string) {
+		if b {
+			ret += s
+		}
+	}
+	a(n.Udp, "udp")
+	a(n.Tcp, "tcp")
+	a(n.Ipv4, "4")
+	a(n.Ipv6, "6")
+	return
+}
+
+func parseNetworkString(network string) (ret network) {
+	c := func(s string) bool {
+		return strings.Contains(network, s)
+	}
+	ret.Ipv4 = c("4")
+	ret.Ipv6 = c("6")
+	ret.Udp = c("udp")
+	ret.Tcp = c("tcp")
+	return
+}
+
+func peerNetworkEnabled(n network, cfg *ClientConfig) bool {
+	if cfg.DisableUTP && n.Udp {
+		return false
+	}
+	if cfg.DisableTCP && n.Tcp {
+		return false
+	}
+	if cfg.DisableIPv6 && n.Ipv6 {
+		return false
+	}
+	if cfg.DisableIPv4 && n.Ipv4 {
+		return false
+	}
+	return true
+}
diff --git a/socket.go b/socket.go
index 4d343eef..6c16e588 100644
--- a/socket.go
+++ b/socket.go
@@ -2,11 +2,9 @@ package torrent
 
 import (
 	"context"
-	"fmt"
 	"net"
 	"net/url"
 	"strconv"
-	"strings"
 
 	"github.com/anacrolix/missinggo"
 	"github.com/anacrolix/missinggo/perf"
@@ -32,24 +30,17 @@ func getProxyDialer(proxyURL string) (proxy.Dialer, error) {
 	return proxy.FromURL(fixedURL, proxy.Direct)
 }
 
-func listen(network, addr, proxyURL string, f firewallCallback) (socket, error) {
-	if isTcpNetwork(network) {
-		return listenTcp(network, addr, proxyURL)
-	} else if isUtpNetwork(network) {
-		return listenUtp(network, addr, proxyURL, f)
-	} else {
-		panic(fmt.Sprintf("unknown network %q", network))
+func listen(n network, addr, proxyURL string, f firewallCallback) (socket, error) {
+	switch {
+	case n.Tcp:
+		return listenTcp(n.String(), addr, proxyURL)
+	case n.Udp:
+		return listenUtp(n.String(), addr, proxyURL, f)
+	default:
+		panic(n)
 	}
 }
 
-func isTcpNetwork(s string) bool {
-	return strings.Contains(s, "tcp")
-}
-
-func isUtpNetwork(s string) bool {
-	return strings.Contains(s, "utp") || strings.Contains(s, "udp")
-}
-
 func listenTcp(network, address, proxyURL string) (s socket, err error) {
 	l, err := net.Listen(network, address)
 	if err != nil {
@@ -90,13 +81,13 @@ func (me tcpSocket) dial(ctx context.Context, addr string) (net.Conn, error) {
 	return me.d(ctx, addr)
 }
 
-func listenAll(networks []string, getHost func(string) string, port int, proxyURL string, f firewallCallback) ([]socket, error) {
+func listenAll(networks []network, getHost func(string) string, port int, proxyURL string, f firewallCallback) ([]socket, error) {
 	if len(networks) == 0 {
 		return nil, nil
 	}
 	var nahs []networkAndHost
 	for _, n := range networks {
-		nahs = append(nahs, networkAndHost{n, getHost(n)})
+		nahs = append(nahs, networkAndHost{n, getHost(n.String())})
 	}
 	for {
 		ss, retry, err := listenAllRetry(nahs, port, proxyURL, f)
@@ -107,7 +98,7 @@ func listenAll(networks []string, getHost func(string) string, port int, proxyUR
 }
 
 type networkAndHost struct {
-	Network string
+	Network network
 	Host    string
 }