]> Sergey Matveev's repositories - btrtrc.git/blob - socket.go
Improve network handling and only listen networks we will use
[btrtrc.git] / socket.go
1 package torrent
2
3 import (
4         "context"
5         "net"
6         "net/url"
7         "strconv"
8
9         "github.com/anacrolix/missinggo"
10         "github.com/anacrolix/missinggo/perf"
11         "github.com/pkg/errors"
12         "golang.org/x/net/proxy"
13 )
14
15 type dialer interface {
16         dial(_ context.Context, addr string) (net.Conn, error)
17 }
18
19 type socket interface {
20         net.Listener
21         dialer
22 }
23
24 func getProxyDialer(proxyURL string) (proxy.Dialer, error) {
25         fixedURL, err := url.Parse(proxyURL)
26         if err != nil {
27                 return nil, err
28         }
29
30         return proxy.FromURL(fixedURL, proxy.Direct)
31 }
32
33 func listen(n network, addr, proxyURL string, f firewallCallback) (socket, error) {
34         switch {
35         case n.Tcp:
36                 return listenTcp(n.String(), addr, proxyURL)
37         case n.Udp:
38                 return listenUtp(n.String(), addr, proxyURL, f)
39         default:
40                 panic(n)
41         }
42 }
43
44 func listenTcp(network, address, proxyURL string) (s socket, err error) {
45         l, err := net.Listen(network, address)
46         if err != nil {
47                 return
48         }
49         defer func() {
50                 if err != nil {
51                         l.Close()
52                 }
53         }()
54
55         // If we don't need the proxy - then we should return default net.Dialer,
56         // otherwise, let's try to parse the proxyURL and return proxy.Dialer
57         if len(proxyURL) != 0 {
58                 // TODO: The error should be propagated, as proxy may be in use for
59                 // security or privacy reasons. Also just pass proxy.Dialer in from
60                 // the Config.
61                 if dialer, err := getProxyDialer(proxyURL); err == nil {
62                         return tcpSocket{l, func(ctx context.Context, addr string) (conn net.Conn, err error) {
63                                 defer perf.ScopeTimerErr(&err)()
64                                 return dialer.Dial(network, addr)
65                         }}, nil
66                 }
67         }
68         dialer := net.Dialer{}
69         return tcpSocket{l, func(ctx context.Context, addr string) (conn net.Conn, err error) {
70                 defer perf.ScopeTimerErr(&err)()
71                 return dialer.DialContext(ctx, network, addr)
72         }}, nil
73 }
74
75 type tcpSocket struct {
76         net.Listener
77         d func(ctx context.Context, addr string) (net.Conn, error)
78 }
79
80 func (me tcpSocket) dial(ctx context.Context, addr string) (net.Conn, error) {
81         return me.d(ctx, addr)
82 }
83
84 func listenAll(networks []network, getHost func(string) string, port int, proxyURL string, f firewallCallback) ([]socket, error) {
85         if len(networks) == 0 {
86                 return nil, nil
87         }
88         var nahs []networkAndHost
89         for _, n := range networks {
90                 nahs = append(nahs, networkAndHost{n, getHost(n.String())})
91         }
92         for {
93                 ss, retry, err := listenAllRetry(nahs, port, proxyURL, f)
94                 if !retry {
95                         return ss, err
96                 }
97         }
98 }
99
100 type networkAndHost struct {
101         Network network
102         Host    string
103 }
104
105 func listenAllRetry(nahs []networkAndHost, port int, proxyURL string, f firewallCallback) (ss []socket, retry bool, err error) {
106         ss = make([]socket, 1, len(nahs))
107         portStr := strconv.FormatInt(int64(port), 10)
108         ss[0], err = listen(nahs[0].Network, net.JoinHostPort(nahs[0].Host, portStr), proxyURL, f)
109         if err != nil {
110                 return nil, false, errors.Wrap(err, "first listen")
111         }
112         defer func() {
113                 if err != nil || retry {
114                         for _, s := range ss {
115                                 s.Close()
116                         }
117                         ss = nil
118                 }
119         }()
120         portStr = strconv.FormatInt(int64(missinggo.AddrPort(ss[0].Addr())), 10)
121         for _, nah := range nahs[1:] {
122                 s, err := listen(nah.Network, net.JoinHostPort(nah.Host, portStr), proxyURL, f)
123                 if err != nil {
124                         return ss,
125                                 missinggo.IsAddrInUse(err) && port == 0,
126                                 errors.Wrap(err, "subsequent listen")
127                 }
128                 ss = append(ss, s)
129         }
130         return
131 }
132
133 type firewallCallback func(net.Addr) bool
134
135 func listenUtp(network, addr, proxyURL string, fc firewallCallback) (s socket, err error) {
136         us, err := NewUtpSocket(network, addr, fc)
137         if err != nil {
138                 return
139         }
140
141         // If we don't need the proxy - then we should return default net.Dialer,
142         // otherwise, let's try to parse the proxyURL and return proxy.Dialer
143         if len(proxyURL) != 0 {
144                 if dialer, err := getProxyDialer(proxyURL); err == nil {
145                         return utpSocketSocket{us, network, dialer}, nil
146                 }
147         }
148
149         return utpSocketSocket{us, network, nil}, nil
150 }
151
152 type utpSocketSocket struct {
153         utpSocket
154         network string
155         d       proxy.Dialer
156 }
157
158 func (me utpSocketSocket) dial(ctx context.Context, addr string) (conn net.Conn, err error) {
159         defer perf.ScopeTimerErr(&err)()
160         if me.d != nil {
161                 return me.d.Dial(me.network, addr)
162         }
163
164         return me.utpSocket.DialContext(ctx, me.network, addr)
165 }