]> Sergey Matveev's repositories - btrtrc.git/blob - socket.go
Add support for non-IP-based networks
[btrtrc.git] / socket.go
1 package torrent
2
3 import (
4         "context"
5         "net"
6         "strconv"
7
8         "github.com/anacrolix/missinggo"
9         "github.com/anacrolix/missinggo/perf"
10         "github.com/pkg/errors"
11 )
12
13 type Listener interface {
14         net.Listener
15 }
16
17 type socket interface {
18         Listener
19         Dialer
20 }
21
22 func listen(n network, addr string, f firewallCallback) (socket, error) {
23         switch {
24         case n.Tcp:
25                 return listenTcp(n.String(), addr)
26         case n.Udp:
27                 return listenUtp(n.String(), addr, f)
28         default:
29                 panic(n)
30         }
31 }
32
33 func listenTcp(network, address string) (s socket, err error) {
34         l, err := net.Listen(network, address)
35         return tcpSocket{
36                 Listener: l,
37                 NetDialer: NetDialer{
38                         Network: network,
39                 },
40         }, err
41 }
42
43 type tcpSocket struct {
44         net.Listener
45         NetDialer
46 }
47
48 func listenAll(networks []network, getHost func(string) string, port int, f firewallCallback) ([]socket, error) {
49         if len(networks) == 0 {
50                 return nil, nil
51         }
52         var nahs []networkAndHost
53         for _, n := range networks {
54                 nahs = append(nahs, networkAndHost{n, getHost(n.String())})
55         }
56         for {
57                 ss, retry, err := listenAllRetry(nahs, port, f)
58                 if !retry {
59                         return ss, err
60                 }
61         }
62 }
63
64 type networkAndHost struct {
65         Network network
66         Host    string
67 }
68
69 func listenAllRetry(nahs []networkAndHost, port int, f firewallCallback) (ss []socket, retry bool, err error) {
70         ss = make([]socket, 1, len(nahs))
71         portStr := strconv.FormatInt(int64(port), 10)
72         ss[0], err = listen(nahs[0].Network, net.JoinHostPort(nahs[0].Host, portStr), f)
73         if err != nil {
74                 return nil, false, errors.Wrap(err, "first listen")
75         }
76         defer func() {
77                 if err != nil || retry {
78                         for _, s := range ss {
79                                 s.Close()
80                         }
81                         ss = nil
82                 }
83         }()
84         portStr = strconv.FormatInt(int64(missinggo.AddrPort(ss[0].Addr())), 10)
85         for _, nah := range nahs[1:] {
86                 s, err := listen(nah.Network, net.JoinHostPort(nah.Host, portStr), f)
87                 if err != nil {
88                         return ss,
89                                 missinggo.IsAddrInUse(err) && port == 0,
90                                 errors.Wrap(err, "subsequent listen")
91                 }
92                 ss = append(ss, s)
93         }
94         return
95 }
96
97 type firewallCallback func(net.Addr) bool
98
99 func listenUtp(network, addr string, fc firewallCallback) (socket, error) {
100         us, err := NewUtpSocket(network, addr, fc)
101         return utpSocketSocket{us, network}, err
102 }
103
104 type utpSocketSocket struct {
105         utpSocket
106         network string
107 }
108
109 func (me utpSocketSocket) Dial(ctx context.Context, addr string) (conn net.Conn, err error) {
110         defer perf.ScopeTimerErr(&err)()
111         return me.utpSocket.DialContext(ctx, me.network, addr)
112 }