]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Use the new firewall callback support in go-libutp
authorMatt Joiner <anacrolix@gmail.com>
Wed, 25 Jul 2018 07:11:09 +0000 (17:11 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Wed, 25 Jul 2018 07:11:09 +0000 (17:11 +1000)
client.go
client_test.go
network_test.go
socket.go
utp_go.go
utp_libutp.go
utp_test.go

index 09226f5452e611ecdee716b96a0e362b2844b053..105d4419b9471adbbaeb413d2bf35f2fe2a46ca4 100644 (file)
--- a/client.go
+++ b/client.go
@@ -220,7 +220,7 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
                }
        }
 
-       cl.conns, err = listenAll(cl.enabledPeerNetworks(), cl.config.ListenHost, cl.config.ListenPort, cl.config.ProxyURL)
+       cl.conns, err = listenAll(cl.enabledPeerNetworks(), cl.config.ListenHost, cl.config.ListenPort, cl.config.ProxyURL, cl.firewallCallback)
        if err != nil {
                return
        }
@@ -249,6 +249,18 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
        return
 }
 
+func (cl *Client) firewallCallback(net.Addr) bool {
+       cl.rLock()
+       block := !cl.wantConns()
+       cl.rUnlock()
+       if block {
+               torrent.Add("connections firewalled", 1)
+       } else {
+               torrent.Add("connections not firewalled", 1)
+       }
+       return block
+}
+
 func (cl *Client) enabledPeerNetworks() (ns []string) {
        for _, n := range allPeerNetworks {
                if peerNetworkEnabled(n, cl.config) {
@@ -340,16 +352,23 @@ func (cl *Client) ipIsBlocked(ip net.IP) bool {
        return blocked
 }
 
+func (cl *Client) wantConns() bool {
+       for _, t := range cl.torrents {
+               if t.wantConns() {
+                       return true
+               }
+       }
+       return false
+}
+
 func (cl *Client) waitAccept() {
        for {
-               for _, t := range cl.torrents {
-                       if t.wantConns() {
-                               return
-                       }
-               }
                if cl.closed.IsSet() {
                        return
                }
+               if cl.wantConns() {
+                       return
+               }
                cl.event.Wait()
        }
 }
index f36db55254b4e3416d5b278674cd395855dea5a6..ca4041cf2ea25c03608d7ed27d5c7510c75907e9 100644 (file)
@@ -1012,7 +1012,7 @@ func TestMultipleTorrentsWithEncryption(t *testing.T) {
 }
 
 func TestClientAddressInUse(t *testing.T) {
-       s, _ := NewUtpSocket("udp", ":50007")
+       s, _ := NewUtpSocket("udp", ":50007", nil)
        if s != nil {
                defer s.Close()
        }
index b7dd1695e82dcd4e67603c622b0528900bfdaf81..c38f7783828fce2d994cca6d3ac624157571b968 100644 (file)
@@ -23,7 +23,7 @@ func testListenerNetwork(
 }
 
 func listenUtpListener(net, addr string) (l net.Listener, err error) {
-       l, err = NewUtpSocket(net, addr)
+       l, err = NewUtpSocket(net, addr, nil)
        return
 }
 
index adeef594f049031fb5e25126c5bea7707eb69d10..d4c423d1f9c8a243c80c2df0ee0a88e4105f2fab 100644 (file)
--- a/socket.go
+++ b/socket.go
@@ -31,11 +31,11 @@ func getProxyDialer(proxyURL string) (proxy.Dialer, error) {
        return proxy.FromURL(fixedURL, proxy.Direct)
 }
 
-func listen(network, addr, proxyURL string) (socket, error) {
+func listen(network, addr, proxyURL string, f firewallCallback) (socket, error) {
        if isTcpNetwork(network) {
                return listenTcp(network, addr, proxyURL)
        } else if isUtpNetwork(network) {
-               return listenUtp(network, addr, proxyURL)
+               return listenUtp(network, addr, proxyURL, f)
        } else {
                panic(fmt.Sprintf("unknown network %q", network))
        }
@@ -97,7 +97,7 @@ func setPort(addr string, port int) string {
        return net.JoinHostPort(host, strconv.FormatInt(int64(port), 10))
 }
 
-func listenAll(networks []string, getHost func(string) string, port int, proxyURL string) ([]socket, error) {
+func listenAll(networks []string, getHost func(string) string, port int, proxyURL string, f firewallCallback) ([]socket, error) {
        if len(networks) == 0 {
                return nil, nil
        }
@@ -106,7 +106,7 @@ func listenAll(networks []string, getHost func(string) string, port int, proxyUR
                nahs = append(nahs, networkAndHost{n, getHost(n)})
        }
        for {
-               ss, retry, err := listenAllRetry(nahs, port, proxyURL)
+               ss, retry, err := listenAllRetry(nahs, port, proxyURL, f)
                if !retry {
                        return ss, err
                }
@@ -118,10 +118,10 @@ type networkAndHost struct {
        Host    string
 }
 
-func listenAllRetry(nahs []networkAndHost, port int, proxyURL string) (ss []socket, retry bool, err error) {
+func listenAllRetry(nahs []networkAndHost, port int, proxyURL string, f firewallCallback) (ss []socket, retry bool, err error) {
        ss = make([]socket, 1, len(nahs))
        portStr := strconv.FormatInt(int64(port), 10)
-       ss[0], err = listen(nahs[0].Network, net.JoinHostPort(nahs[0].Host, portStr), proxyURL)
+       ss[0], err = listen(nahs[0].Network, net.JoinHostPort(nahs[0].Host, portStr), proxyURL, f)
        if err != nil {
                return nil, false, fmt.Errorf("first listen: %s", err)
        }
@@ -135,7 +135,7 @@ func listenAllRetry(nahs []networkAndHost, port int, proxyURL string) (ss []sock
        }()
        portStr = strconv.FormatInt(int64(missinggo.AddrPort(ss[0].Addr())), 10)
        for _, nah := range nahs[1:] {
-               s, err := listen(nah.Network, net.JoinHostPort(nah.Host, portStr), proxyURL)
+               s, err := listen(nah.Network, net.JoinHostPort(nah.Host, portStr), proxyURL, f)
                if err != nil {
                        return ss,
                                missinggo.IsAddrInUse(err) && port == 0,
@@ -146,8 +146,10 @@ func listenAllRetry(nahs []networkAndHost, port int, proxyURL string) (ss []sock
        return
 }
 
-func listenUtp(network, addr, proxyURL string) (s socket, err error) {
-       us, err := NewUtpSocket(network, addr)
+type firewallCallback func(net.Addr) bool
+
+func listenUtp(network, addr, proxyURL string, fc firewallCallback) (s socket, err error) {
+       us, err := NewUtpSocket(network, addr, fc)
        if err != nil {
                return
        }
index c04bed3b8ac9e0aecbfb697d138b005e8fa70e3a..073c34ab997c21fad24bd7881bc15145ac4a1ee7 100644 (file)
--- a/utp_go.go
+++ b/utp_go.go
@@ -6,7 +6,7 @@ import (
        "github.com/anacrolix/utp"
 )
 
-func NewUtpSocket(network, addr string) (utpSocket, error) {
+func NewUtpSocket(network, addr string, _ firewallCallback) (utpSocket, error) {
        s, err := utp.NewSocket(network, addr)
        if s == nil {
                return nil, err
index 94eba7cbc425d6821c7a2f7f767960783c3d192c..46ba3afb92fd55217f0a9c31e570f3a23e23c2c0 100644 (file)
@@ -6,11 +6,16 @@ import (
        "github.com/anacrolix/go-libutp"
 )
 
-func NewUtpSocket(network, addr string) (utpSocket, error) {
+func NewUtpSocket(network, addr string, fc firewallCallback) (utpSocket, error) {
        s, err := utp.NewSocket(network, addr)
        if s == nil {
                return nil, err
-       } else {
+       }
+       if err != nil {
                return s, err
        }
+       if fc != nil {
+               s.SetFirewallCallback(utp.FirewallCallback(fc))
+       }
+       return s, err
 }
index b0917ecb65d631614e68e1d08a39a45a886ae6b5..bacb50134547a6f66be942665d99fa9731153d1f 100644 (file)
@@ -7,7 +7,7 @@ import (
 )
 
 func TestNewUtpSocketErrorNilInterface(t *testing.T) {
-       s, err := NewUtpSocket("fix", "your:language")
+       s, err := NewUtpSocket("fix", "your:language", nil)
        assert.Error(t, err)
        if s != nil {
                t.Fatalf("expected nil, got %#v", s)