]> Sergey Matveev's repositories - btrtrc.git/blob - socket.go
Split Client dialers and listeners
[btrtrc.git] / socket.go
1 package torrent
2
3 import (
4         "context"
5         "net"
6         "strconv"
7
8         "github.com/anacrolix/missinggo"
9         "github.com/anacrolix/missinggo/perf"
10         "github.com/pkg/errors"
11 )
12
13 type dialer interface {
14         dial(_ context.Context, addr string) (net.Conn, error)
15         LocalAddr() net.Addr
16 }
17
18 type listener interface {
19         net.Listener
20 }
21
22 type socket interface {
23         listener
24         dialer
25 }
26
27 func listen(n network, addr string, f firewallCallback) (socket, error) {
28         switch {
29         case n.Tcp:
30                 return listenTcp(n.String(), addr)
31         case n.Udp:
32                 return listenUtp(n.String(), addr, f)
33         default:
34                 panic(n)
35         }
36 }
37
38 func listenTcp(network, address string) (s socket, err error) {
39         l, err := net.Listen(network, address)
40         return tcpSocket{
41                 Listener: l,
42                 network:  network,
43         }, err
44 }
45
46 type tcpSocket struct {
47         net.Listener
48         network string
49         dialer  net.Dialer
50 }
51
52 func (me tcpSocket) dial(ctx context.Context, addr string) (_ net.Conn, err error) {
53         defer perf.ScopeTimerErr(&err)()
54         return me.dialer.DialContext(ctx, me.network, addr)
55 }
56
57 func (me tcpSocket) LocalAddr() net.Addr {
58         return tcpSocketLocalAddr{me.network, me.Listener.Addr().String()}
59 }
60
61 type tcpSocketLocalAddr struct {
62         network string
63         s       string
64 }
65
66 func (me tcpSocketLocalAddr) Network() string { return me.network }
67
68 func (me tcpSocketLocalAddr) String() string { return "" }
69
70 func listenAll(networks []network, getHost func(string) string, port int, f firewallCallback) ([]socket, error) {
71         if len(networks) == 0 {
72                 return nil, nil
73         }
74         var nahs []networkAndHost
75         for _, n := range networks {
76                 nahs = append(nahs, networkAndHost{n, getHost(n.String())})
77         }
78         for {
79                 ss, retry, err := listenAllRetry(nahs, port, f)
80                 if !retry {
81                         return ss, err
82                 }
83         }
84 }
85
86 type networkAndHost struct {
87         Network network
88         Host    string
89 }
90
91 func listenAllRetry(nahs []networkAndHost, port int, f firewallCallback) (ss []socket, retry bool, err error) {
92         ss = make([]socket, 1, len(nahs))
93         portStr := strconv.FormatInt(int64(port), 10)
94         ss[0], err = listen(nahs[0].Network, net.JoinHostPort(nahs[0].Host, portStr), f)
95         if err != nil {
96                 return nil, false, errors.Wrap(err, "first listen")
97         }
98         defer func() {
99                 if err != nil || retry {
100                         for _, s := range ss {
101                                 s.Close()
102                         }
103                         ss = nil
104                 }
105         }()
106         portStr = strconv.FormatInt(int64(missinggo.AddrPort(ss[0].Addr())), 10)
107         for _, nah := range nahs[1:] {
108                 s, err := listen(nah.Network, net.JoinHostPort(nah.Host, portStr), f)
109                 if err != nil {
110                         return ss,
111                                 missinggo.IsAddrInUse(err) && port == 0,
112                                 errors.Wrap(err, "subsequent listen")
113                 }
114                 ss = append(ss, s)
115         }
116         return
117 }
118
119 type firewallCallback func(net.Addr) bool
120
121 func listenUtp(network, addr string, fc firewallCallback) (socket, error) {
122         us, err := NewUtpSocket(network, addr, fc)
123         return utpSocketSocket{us, network}, err
124 }
125
126 type utpSocketSocket struct {
127         utpSocket
128         network string
129 }
130
131 func (me utpSocketSocket) dial(ctx context.Context, addr string) (conn net.Conn, err error) {
132         defer perf.ScopeTimerErr(&err)()
133         return me.utpSocket.DialContext(ctx, me.network, addr)
134 }