]> Sergey Matveev's repositories - btrtrc.git/blob - socket.go
sortimports
[btrtrc.git] / socket.go
1 package torrent
2
3 import (
4         "context"
5         "fmt"
6         "net"
7         "net/url"
8         "strconv"
9         "strings"
10
11         "github.com/anacrolix/missinggo"
12         "github.com/anacrolix/missinggo/perf"
13         "golang.org/x/net/proxy"
14 )
15
16 type dialer interface {
17         dial(_ context.Context, addr string) (net.Conn, error)
18 }
19
20 type socket interface {
21         net.Listener
22         dialer
23 }
24
25 func getProxyDialer(proxyURL string) (proxy.Dialer, error) {
26         fixedURL, err := url.Parse(proxyURL)
27         if err != nil {
28                 return nil, err
29         }
30
31         return proxy.FromURL(fixedURL, proxy.Direct)
32 }
33
34 func listen(network, addr, proxyURL string) (socket, error) {
35         if isTcpNetwork(network) {
36                 return listenTcp(network, addr, proxyURL)
37         } else if isUtpNetwork(network) {
38                 return listenUtp(network, addr, proxyURL)
39         } else {
40                 panic(fmt.Sprintf("unknown network %q", network))
41         }
42 }
43
44 func isTcpNetwork(s string) bool {
45         return strings.Contains(s, "tcp")
46 }
47
48 func isUtpNetwork(s string) bool {
49         return strings.Contains(s, "utp") || strings.Contains(s, "udp")
50 }
51
52 func listenTcp(network, address, proxyURL string) (s socket, err error) {
53         l, err := net.Listen(network, address)
54         if err != nil {
55                 return
56         }
57         defer func() {
58                 if err != nil {
59                         l.Close()
60                 }
61         }()
62
63         // If we don't need the proxy - then we should return default net.Dialer,
64         // otherwise, let's try to parse the proxyURL and return proxy.Dialer
65         if len(proxyURL) != 0 {
66                 // TODO: The error should be propagated, as proxy may be in use for
67                 // security or privacy reasons. Also just pass proxy.Dialer in from
68                 // the Config.
69                 if dialer, err := getProxyDialer(proxyURL); err == nil {
70                         return tcpSocket{l, func(ctx context.Context, addr string) (conn net.Conn, err error) {
71                                 defer perf.ScopeTimerErr(&err)()
72                                 return dialer.Dial(network, addr)
73                         }}, nil
74                 }
75         }
76         dialer := net.Dialer{}
77         return tcpSocket{l, func(ctx context.Context, addr string) (conn net.Conn, err error) {
78                 defer perf.ScopeTimerErr(&err)()
79                 return dialer.DialContext(ctx, network, addr)
80         }}, nil
81 }
82
83 type tcpSocket struct {
84         net.Listener
85         d func(ctx context.Context, addr string) (net.Conn, error)
86 }
87
88 func (me tcpSocket) dial(ctx context.Context, addr string) (net.Conn, error) {
89         return me.d(ctx, addr)
90 }
91
92 func setPort(addr string, port int) string {
93         host, _, err := net.SplitHostPort(addr)
94         if err != nil {
95                 panic(err)
96         }
97         return net.JoinHostPort(host, strconv.FormatInt(int64(port), 10))
98 }
99
100 func listenAll(networks []string, getHost func(string) string, port int, proxyURL string) ([]socket, error) {
101         if len(networks) == 0 {
102                 return nil, nil
103         }
104         var nahs []networkAndHost
105         for _, n := range networks {
106                 nahs = append(nahs, networkAndHost{n, getHost(n)})
107         }
108         for {
109                 ss, retry, err := listenAllRetry(nahs, port, proxyURL)
110                 if !retry {
111                         return ss, err
112                 }
113         }
114 }
115
116 type networkAndHost struct {
117         Network string
118         Host    string
119 }
120
121 func listenAllRetry(nahs []networkAndHost, port int, proxyURL string) (ss []socket, retry bool, err error) {
122         ss = make([]socket, 1, len(nahs))
123         portStr := strconv.FormatInt(int64(port), 10)
124         ss[0], err = listen(nahs[0].Network, net.JoinHostPort(nahs[0].Host, portStr), proxyURL)
125         if err != nil {
126                 return nil, false, fmt.Errorf("first listen: %s", err)
127         }
128         defer func() {
129                 if err != nil || retry {
130                         for _, s := range ss {
131                                 s.Close()
132                         }
133                         ss = nil
134                 }
135         }()
136         portStr = strconv.FormatInt(int64(missinggo.AddrPort(ss[0].Addr())), 10)
137         for _, nah := range nahs[1:] {
138                 s, err := listen(nah.Network, net.JoinHostPort(nah.Host, portStr), proxyURL)
139                 if err != nil {
140                         return ss,
141                                 missinggo.IsAddrInUse(err) && port == 0,
142                                 fmt.Errorf("subsequent listen: %s", err)
143                 }
144                 ss = append(ss, s)
145         }
146         return
147 }
148
149 func listenUtp(network, addr, proxyURL string) (s socket, err error) {
150         us, err := NewUtpSocket(network, addr)
151         if err != nil {
152                 return
153         }
154
155         // If we don't need the proxy - then we should return default net.Dialer,
156         // otherwise, let's try to parse the proxyURL and return proxy.Dialer
157         if len(proxyURL) != 0 {
158                 if dialer, err := getProxyDialer(proxyURL); err == nil {
159                         return utpSocketSocket{us, network, dialer}, nil
160                 }
161         }
162
163         return utpSocketSocket{us, network, nil}, nil
164 }
165
166 type utpSocketSocket struct {
167         utpSocket
168         network string
169         d       proxy.Dialer
170 }
171
172 func (me utpSocketSocket) dial(ctx context.Context, addr string) (conn net.Conn, err error) {
173         defer perf.ScopeTimerErr(&err)()
174         if me.d != nil {
175                 return me.d.Dial(me.network, addr)
176         }
177
178         return me.utpSocket.DialContext(ctx, me.network, addr)
179 }