]> Sergey Matveev's repositories - btrtrc.git/blob - tracker/udp/server/server.go
Support upstream trackers
[btrtrc.git] / tracker / udp / server / server.go
1 package udpTrackerServer
2
3 import (
4         "bytes"
5         "context"
6         "crypto/rand"
7         "encoding/binary"
8         "fmt"
9         "io"
10         "net"
11         "net/netip"
12
13         "github.com/anacrolix/dht/v2/krpc"
14         "github.com/anacrolix/log"
15
16         "github.com/anacrolix/torrent/tracker"
17         "github.com/anacrolix/torrent/tracker/udp"
18 )
19
20 type ConnectionTrackerAddr = string
21
22 type ConnectionTracker interface {
23         Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
24         Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
25 }
26
27 type InfoHash = [20]byte
28
29 type AnnounceTracker = tracker.AnnounceTracker
30
31 type Server struct {
32         ConnTracker  ConnectionTracker
33         SendResponse func(data []byte, addr net.Addr) (int, error)
34         Announce     tracker.AnnounceHandler
35 }
36
37 type RequestSourceAddr = net.Addr
38
39 func (me *Server) HandleRequest(
40         ctx context.Context,
41         family udp.AddrFamily,
42         source RequestSourceAddr,
43         body []byte,
44 ) error {
45         var h udp.RequestHeader
46         var r bytes.Reader
47         r.Reset(body)
48         err := udp.Read(&r, &h)
49         if err != nil {
50                 err = fmt.Errorf("reading request header: %w", err)
51                 return err
52         }
53         switch h.Action {
54         case udp.ActionConnect:
55                 err = me.handleConnect(ctx, source, h.TransactionId)
56         case udp.ActionAnnounce:
57                 err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
58         default:
59                 err = fmt.Errorf("unimplemented")
60         }
61         if err != nil {
62                 err = fmt.Errorf("handling action %v: %w", h.Action, err)
63         }
64         return err
65 }
66
67 func (me *Server) handleAnnounce(
68         ctx context.Context,
69         addrFamily udp.AddrFamily,
70         source RequestSourceAddr,
71         connId udp.ConnectionId,
72         tid udp.TransactionId,
73         r *bytes.Reader,
74 ) error {
75         // Should we set a timeout of 10s or something for the entire response, so that we give up if a
76         // retry is imminent?
77
78         ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
79         if err != nil {
80                 err = fmt.Errorf("checking conn id: %w", err)
81                 return err
82         }
83         if !ok {
84                 return fmt.Errorf("invalid connection id: %v", connId)
85         }
86         var req udp.AnnounceRequest
87         err = udp.Read(r, &req)
88         if err != nil {
89                 return err
90         }
91         // TODO: This should be done asynchronously to responding to the announce.
92         announceAddr, err := netip.ParseAddrPort(source.String())
93         if err != nil {
94                 err = fmt.Errorf("converting source net.Addr to AnnounceAddr: %w", err)
95                 return err
96         }
97         peers, err := me.Announce.Serve(ctx, req, announceAddr)
98         if err != nil {
99                 return err
100         }
101         nodeAddrs := make([]krpc.NodeAddr, 0, len(peers))
102         for _, p := range peers {
103                 var ip net.IP
104                 switch addrFamily {
105                 default:
106                         continue
107                 case udp.AddrFamilyIpv4:
108                         if !p.Addr().Unmap().Is4() {
109                                 continue
110                         }
111                         ipBuf := p.Addr().As4()
112                         ip = ipBuf[:]
113                 case udp.AddrFamilyIpv6:
114                         ipBuf := p.Addr().As16()
115                         ip = ipBuf[:]
116                 }
117                 nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
118                         IP:   ip[:],
119                         Port: int(p.Port()),
120                 })
121         }
122         var buf bytes.Buffer
123         err = udp.Write(&buf, udp.ResponseHeader{
124                 Action:        udp.ActionAnnounce,
125                 TransactionId: tid,
126         })
127         if err != nil {
128                 return err
129         }
130         err = udp.Write(&buf, udp.AnnounceResponseHeader{})
131         if err != nil {
132                 return err
133         }
134         b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
135         if err != nil {
136                 err = fmt.Errorf("marshalling compact node addrs: %w", err)
137                 return err
138         }
139         buf.Write(b)
140         n, err := me.SendResponse(buf.Bytes(), source)
141         if err != nil {
142                 return err
143         }
144         if n < buf.Len() {
145                 err = io.ErrShortWrite
146         }
147         return err
148 }
149
150 func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
151         connId := randomConnectionId()
152         err := me.ConnTracker.Add(ctx, source.String(), connId)
153         if err != nil {
154                 err = fmt.Errorf("recording conn id: %w", err)
155                 return err
156         }
157         var buf bytes.Buffer
158         udp.Write(&buf, udp.ResponseHeader{
159                 Action:        udp.ActionConnect,
160                 TransactionId: tid,
161         })
162         udp.Write(&buf, udp.ConnectionResponse{connId})
163         n, err := me.SendResponse(buf.Bytes(), source)
164         if err != nil {
165                 return err
166         }
167         if n < buf.Len() {
168                 err = io.ErrShortWrite
169         }
170         return err
171 }
172
173 func randomConnectionId() udp.ConnectionId {
174         var b [8]byte
175         _, err := rand.Read(b[:])
176         if err != nil {
177                 panic(err)
178         }
179         return int64(binary.BigEndian.Uint64(b[:]))
180 }
181
182 func RunSimple(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
183         ctx, cancel := context.WithCancel(ctx)
184         defer cancel()
185         for {
186                 var b [1500]byte
187                 n, addr, err := pc.ReadFrom(b[:])
188                 if err != nil {
189                         return err
190                 }
191                 go func() {
192                         err := s.HandleRequest(ctx, family, addr, b[:n])
193                         if err != nil {
194                                 log.Printf("error handling %v byte request from %v: %v", n, addr, err)
195                         }
196                 }()
197         }
198 }