client.go | 31 +++++++++++++++++++++++++------ client_test.go | 2 +- network_test.go | 2 +- socket.go | 20 +++++++++++--------- utp_go.go | 2 +- utp_libutp.go | 9 +++++++-- utp_test.go | 2 +- diff --git a/client.go b/client.go index 09226f5452e611ecdee716b96a0e362b2844b053..105d4419b9471adbbaeb413d2bf35f2fe2a46ca4 100644 --- a/client.go +++ b/client.go @@ -220,7 +220,7 @@ panic("error generating peer id") } } - cl.conns, err = listenAll(cl.enabledPeerNetworks(), cl.config.ListenHost, cl.config.ListenPort, cl.config.ProxyURL) + cl.conns, err = listenAll(cl.enabledPeerNetworks(), cl.config.ListenHost, cl.config.ListenPort, cl.config.ProxyURL, cl.firewallCallback) if err != nil { return } @@ -247,6 +247,18 @@ } } return +} + +func (cl *Client) firewallCallback(net.Addr) bool { + cl.rLock() + block := !cl.wantConns() + cl.rUnlock() + if block { + torrent.Add("connections firewalled", 1) + } else { + torrent.Add("connections not firewalled", 1) + } + return block } func (cl *Client) enabledPeerNetworks() (ns []string) { @@ -340,14 +352,21 @@ _, blocked := cl.ipBlockRange(ip) return blocked } +func (cl *Client) wantConns() bool { + for _, t := range cl.torrents { + if t.wantConns() { + return true + } + } + return false +} + func (cl *Client) waitAccept() { for { - for _, t := range cl.torrents { - if t.wantConns() { - return - } - } if cl.closed.IsSet() { + return + } + if cl.wantConns() { return } cl.event.Wait() diff --git a/client_test.go b/client_test.go index f36db55254b4e3416d5b278674cd395855dea5a6..ca4041cf2ea25c03608d7ed27d5c7510c75907e9 100644 --- a/client_test.go +++ b/client_test.go @@ -1012,7 +1012,7 @@ client.WaitAll() } func TestClientAddressInUse(t *testing.T) { - s, _ := NewUtpSocket("udp", ":50007") + s, _ := NewUtpSocket("udp", ":50007", nil) if s != nil { defer s.Close() } diff --git a/network_test.go b/network_test.go index b7dd1695e82dcd4e67603c622b0528900bfdaf81..c38f7783828fce2d994cca6d3ac624157571b968 100644 --- a/network_test.go +++ b/network_test.go @@ -23,7 +23,7 @@ assert.Equal(t, validIp4, ip.To4() != nil, ip) } func listenUtpListener(net, addr string) (l net.Listener, err error) { - l, err = NewUtpSocket(net, addr) + l, err = NewUtpSocket(net, addr, nil) return } diff --git a/socket.go b/socket.go index adeef594f049031fb5e25126c5bea7707eb69d10..d4c423d1f9c8a243c80c2df0ee0a88e4105f2fab 100644 --- a/socket.go +++ b/socket.go @@ -31,11 +31,11 @@ return proxy.FromURL(fixedURL, proxy.Direct) } -func listen(network, addr, proxyURL string) (socket, error) { +func listen(network, addr, proxyURL string, f firewallCallback) (socket, error) { if isTcpNetwork(network) { return listenTcp(network, addr, proxyURL) } else if isUtpNetwork(network) { - return listenUtp(network, addr, proxyURL) + return listenUtp(network, addr, proxyURL, f) } else { panic(fmt.Sprintf("unknown network %q", network)) } @@ -97,7 +97,7 @@ } return net.JoinHostPort(host, strconv.FormatInt(int64(port), 10)) } -func listenAll(networks []string, getHost func(string) string, port int, proxyURL string) ([]socket, error) { +func listenAll(networks []string, getHost func(string) string, port int, proxyURL string, f firewallCallback) ([]socket, error) { if len(networks) == 0 { return nil, nil } @@ -106,7 +106,7 @@ for _, n := range networks { nahs = append(nahs, networkAndHost{n, getHost(n)}) } for { - ss, retry, err := listenAllRetry(nahs, port, proxyURL) + ss, retry, err := listenAllRetry(nahs, port, proxyURL, f) if !retry { return ss, err } @@ -118,10 +118,10 @@ Network string Host string } -func listenAllRetry(nahs []networkAndHost, port int, proxyURL string) (ss []socket, retry bool, err error) { +func listenAllRetry(nahs []networkAndHost, port int, proxyURL string, f firewallCallback) (ss []socket, retry bool, err error) { ss = make([]socket, 1, len(nahs)) portStr := strconv.FormatInt(int64(port), 10) - ss[0], err = listen(nahs[0].Network, net.JoinHostPort(nahs[0].Host, portStr), proxyURL) + ss[0], err = listen(nahs[0].Network, net.JoinHostPort(nahs[0].Host, portStr), proxyURL, f) if err != nil { return nil, false, fmt.Errorf("first listen: %s", err) } @@ -135,7 +135,7 @@ } }() portStr = strconv.FormatInt(int64(missinggo.AddrPort(ss[0].Addr())), 10) for _, nah := range nahs[1:] { - s, err := listen(nah.Network, net.JoinHostPort(nah.Host, portStr), proxyURL) + s, err := listen(nah.Network, net.JoinHostPort(nah.Host, portStr), proxyURL, f) if err != nil { return ss, missinggo.IsAddrInUse(err) && port == 0, @@ -146,8 +146,10 @@ } return } -func listenUtp(network, addr, proxyURL string) (s socket, err error) { - us, err := NewUtpSocket(network, addr) +type firewallCallback func(net.Addr) bool + +func listenUtp(network, addr, proxyURL string, fc firewallCallback) (s socket, err error) { + us, err := NewUtpSocket(network, addr, fc) if err != nil { return } diff --git a/utp_go.go b/utp_go.go index c04bed3b8ac9e0aecbfb697d138b005e8fa70e3a..073c34ab997c21fad24bd7881bc15145ac4a1ee7 100644 --- a/utp_go.go +++ b/utp_go.go @@ -6,7 +6,7 @@ import ( "github.com/anacrolix/utp" ) -func NewUtpSocket(network, addr string) (utpSocket, error) { +func NewUtpSocket(network, addr string, _ firewallCallback) (utpSocket, error) { s, err := utp.NewSocket(network, addr) if s == nil { return nil, err diff --git a/utp_libutp.go b/utp_libutp.go index 94eba7cbc425d6821c7a2f7f767960783c3d192c..46ba3afb92fd55217f0a9c31e570f3a23e23c2c0 100644 --- a/utp_libutp.go +++ b/utp_libutp.go @@ -6,11 +6,16 @@ import ( "github.com/anacrolix/go-libutp" ) -func NewUtpSocket(network, addr string) (utpSocket, error) { +func NewUtpSocket(network, addr string, fc firewallCallback) (utpSocket, error) { s, err := utp.NewSocket(network, addr) if s == nil { return nil, err - } else { + } + if err != nil { return s, err } + if fc != nil { + s.SetFirewallCallback(utp.FirewallCallback(fc)) + } + return s, err } diff --git a/utp_test.go b/utp_test.go index b0917ecb65d631614e68e1d08a39a45a886ae6b5..bacb50134547a6f66be942665d99fa9731153d1f 100644 --- a/utp_test.go +++ b/utp_test.go @@ -7,7 +7,7 @@ "github.com/stretchr/testify/assert" ) func TestNewUtpSocketErrorNilInterface(t *testing.T) { - s, err := NewUtpSocket("fix", "your:language") + s, err := NewUtpSocket("fix", "your:language", nil) assert.Error(t, err) if s != nil { t.Fatalf("expected nil, got %#v", s)