]> Sergey Matveev's repositories - btrtrc.git/blob - socket.go
propagate proxy url parsing err
[btrtrc.git] / socket.go
1 package torrent
2
3 import (
4         "context"
5         "fmt"
6         "net"
7         "net/url"
8         "strconv"
9
10         "github.com/anacrolix/missinggo"
11         "github.com/anacrolix/missinggo/perf"
12         "github.com/pkg/errors"
13         "golang.org/x/net/proxy"
14 )
15
16 type dialer interface {
17         dial(_ context.Context, addr string) (net.Conn, error)
18 }
19
20 type socket interface {
21         net.Listener
22         dialer
23 }
24
25 func getProxyDialer(proxyURL string) (proxy.Dialer, error) {
26         fixedURL, err := url.Parse(proxyURL)
27         if err != nil {
28                 return nil, err
29         }
30
31         return proxy.FromURL(fixedURL, proxy.Direct)
32 }
33
34 func listen(n network, addr, proxyURL string, f firewallCallback) (socket, error) {
35         switch {
36         case n.Tcp:
37                 return listenTcp(n.String(), addr, proxyURL)
38         case n.Udp:
39                 return listenUtp(n.String(), addr, proxyURL, f)
40         default:
41                 panic(n)
42         }
43 }
44
45 func listenTcp(network, address, proxyURL string) (s socket, err error) {
46         l, err := net.Listen(network, address)
47         if err != nil {
48                 return
49         }
50         defer func() {
51                 if err != nil {
52                         l.Close()
53                 }
54         }()
55
56         // If we don't need the proxy - then we should return default net.Dialer,
57         // otherwise, let's try to parse the proxyURL and return proxy.Dialer
58         if len(proxyURL) != 0 {
59                 dl := disabledListener{l}
60                 dialer, err := getProxyDialer(proxyURL)
61                 if err != nil {
62                         return nil, err
63                 }
64                 return tcpSocket{dl, func(ctx context.Context, addr string) (conn net.Conn, err error) {
65                         defer perf.ScopeTimerErr(&err)()
66                         return dialer.Dial(network, addr)
67                 }}, nil
68         }
69         dialer := net.Dialer{}
70         return tcpSocket{l, func(ctx context.Context, addr string) (conn net.Conn, err error) {
71                 defer perf.ScopeTimerErr(&err)()
72                 return dialer.DialContext(ctx, network, addr)
73         }}, nil
74 }
75
76 type disabledListener struct {
77         net.Listener
78 }
79
80 func (dl disabledListener) Accept() (net.Conn, error) {
81         return nil, fmt.Errorf("tcp listener disabled due to proxy")
82 }
83
84 type tcpSocket struct {
85         net.Listener
86         d func(ctx context.Context, addr string) (net.Conn, error)
87 }
88
89 func (me tcpSocket) dial(ctx context.Context, addr string) (net.Conn, error) {
90         return me.d(ctx, addr)
91 }
92
93 func listenAll(networks []network, getHost func(string) string, port int, proxyURL string, f firewallCallback) ([]socket, error) {
94         if len(networks) == 0 {
95                 return nil, nil
96         }
97         var nahs []networkAndHost
98         for _, n := range networks {
99                 nahs = append(nahs, networkAndHost{n, getHost(n.String())})
100         }
101         for {
102                 ss, retry, err := listenAllRetry(nahs, port, proxyURL, f)
103                 if !retry {
104                         return ss, err
105                 }
106         }
107 }
108
109 type networkAndHost struct {
110         Network network
111         Host    string
112 }
113
114 func listenAllRetry(nahs []networkAndHost, port int, proxyURL string, f firewallCallback) (ss []socket, retry bool, err error) {
115         ss = make([]socket, 1, len(nahs))
116         portStr := strconv.FormatInt(int64(port), 10)
117         ss[0], err = listen(nahs[0].Network, net.JoinHostPort(nahs[0].Host, portStr), proxyURL, f)
118         if err != nil {
119                 return nil, false, errors.Wrap(err, "first listen")
120         }
121         defer func() {
122                 if err != nil || retry {
123                         for _, s := range ss {
124                                 s.Close()
125                         }
126                         ss = nil
127                 }
128         }()
129         portStr = strconv.FormatInt(int64(missinggo.AddrPort(ss[0].Addr())), 10)
130         for _, nah := range nahs[1:] {
131                 s, err := listen(nah.Network, net.JoinHostPort(nah.Host, portStr), proxyURL, f)
132                 if err != nil {
133                         return ss,
134                                 missinggo.IsAddrInUse(err) && port == 0,
135                                 errors.Wrap(err, "subsequent listen")
136                 }
137                 ss = append(ss, s)
138         }
139         return
140 }
141
142 type firewallCallback func(net.Addr) bool
143
144 func listenUtp(network, addr, proxyURL string, fc firewallCallback) (s socket, err error) {
145         us, err := NewUtpSocket(network, addr, fc)
146         if err != nil {
147                 return
148         }
149
150         // If we don't need the proxy - then we should return default net.Dialer,
151         // otherwise, let's try to parse the proxyURL and return proxy.Dialer
152         if len(proxyURL) != 0 {
153                 ds := disabledUtpSocket{us}
154                 dialer, err := getProxyDialer(proxyURL)
155                 if err != nil {
156                         return nil, err
157                 }
158                 return utpSocketSocket{ds, network, dialer}, nil
159         }
160
161         return utpSocketSocket{us, network, nil}, nil
162 }
163
164 type disabledUtpSocket struct {
165         utpSocket
166 }
167
168 func (ds disabledUtpSocket) Accept() (net.Conn, error) {
169         return nil, fmt.Errorf("utp listener disabled due to proxy")
170 }
171
172 type utpSocketSocket struct {
173         utpSocket
174         network string
175         d       proxy.Dialer
176 }
177
178 func (me utpSocketSocket) dial(ctx context.Context, addr string) (conn net.Conn, err error) {
179         defer perf.ScopeTimerErr(&err)()
180         if me.d != nil {
181                 return me.d.Dial(me.network, addr)
182         }
183
184         return me.utpSocket.DialContext(ctx, me.network, addr)
185 }