]> Sergey Matveev's repositories - btrtrc.git/blobdiff - socket.go
Improve network handling and only listen networks we will use
[btrtrc.git] / socket.go
index 4d343eef25be8b5f672ccf353b95ee7f3c29950e..6c16e588a874c334979e05da0f09592b0da1b6b5 100644 (file)
--- a/socket.go
+++ b/socket.go
@@ -2,11 +2,9 @@ package torrent
 
 import (
        "context"
-       "fmt"
        "net"
        "net/url"
        "strconv"
-       "strings"
 
        "github.com/anacrolix/missinggo"
        "github.com/anacrolix/missinggo/perf"
@@ -32,24 +30,17 @@ func getProxyDialer(proxyURL string) (proxy.Dialer, error) {
        return proxy.FromURL(fixedURL, proxy.Direct)
 }
 
-func listen(network, addr, proxyURL string, f firewallCallback) (socket, error) {
-       if isTcpNetwork(network) {
-               return listenTcp(network, addr, proxyURL)
-       } else if isUtpNetwork(network) {
-               return listenUtp(network, addr, proxyURL, f)
-       } else {
-               panic(fmt.Sprintf("unknown network %q", network))
+func listen(n network, addr, proxyURL string, f firewallCallback) (socket, error) {
+       switch {
+       case n.Tcp:
+               return listenTcp(n.String(), addr, proxyURL)
+       case n.Udp:
+               return listenUtp(n.String(), addr, proxyURL, f)
+       default:
+               panic(n)
        }
 }
 
-func isTcpNetwork(s string) bool {
-       return strings.Contains(s, "tcp")
-}
-
-func isUtpNetwork(s string) bool {
-       return strings.Contains(s, "utp") || strings.Contains(s, "udp")
-}
-
 func listenTcp(network, address, proxyURL string) (s socket, err error) {
        l, err := net.Listen(network, address)
        if err != nil {
@@ -90,13 +81,13 @@ func (me tcpSocket) dial(ctx context.Context, addr string) (net.Conn, error) {
        return me.d(ctx, addr)
 }
 
-func listenAll(networks []string, getHost func(string) string, port int, proxyURL string, f firewallCallback) ([]socket, error) {
+func listenAll(networks []network, getHost func(string) string, port int, proxyURL string, f firewallCallback) ([]socket, error) {
        if len(networks) == 0 {
                return nil, nil
        }
        var nahs []networkAndHost
        for _, n := range networks {
-               nahs = append(nahs, networkAndHost{n, getHost(n)})
+               nahs = append(nahs, networkAndHost{n, getHost(n.String())})
        }
        for {
                ss, retry, err := listenAllRetry(nahs, port, proxyURL, f)
@@ -107,7 +98,7 @@ func listenAll(networks []string, getHost func(string) string, port int, proxyUR
 }
 
 type networkAndHost struct {
-       Network string
+       Network network
        Host    string
 }