]> Sergey Matveev's repositories - btrtrc.git/blob - tracker/udp/server/server.go
9df2fc67caee6909adbd472a50f15499d1568102
[btrtrc.git] / tracker / udp / server / server.go
1 package udpTrackerServer
2
3 import (
4         "bytes"
5         "context"
6         "crypto/rand"
7         "encoding/binary"
8         "fmt"
9         "io"
10         "net"
11         "net/netip"
12
13         "github.com/anacrolix/dht/v2/krpc"
14         "github.com/anacrolix/generics"
15         "github.com/anacrolix/log"
16         trackerServer "github.com/anacrolix/torrent/tracker/server"
17         "go.opentelemetry.io/otel"
18         "go.opentelemetry.io/otel/attribute"
19         "go.opentelemetry.io/otel/codes"
20         "go.opentelemetry.io/otel/trace"
21
22         "github.com/anacrolix/torrent/tracker/udp"
23 )
24
25 type ConnectionTrackerAddr = string
26
27 type ConnectionTracker interface {
28         Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
29         Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
30 }
31
32 type InfoHash = [20]byte
33
34 type AnnounceTracker = trackerServer.AnnounceTracker
35
36 type Server struct {
37         ConnTracker  ConnectionTracker
38         SendResponse func(data []byte, addr net.Addr) (int, error)
39         Announce     *trackerServer.AnnounceHandler
40 }
41
42 type RequestSourceAddr = net.Addr
43
44 var tracer = otel.Tracer("torrent.tracker.udp")
45
46 func (me *Server) HandleRequest(
47         ctx context.Context,
48         family udp.AddrFamily,
49         source RequestSourceAddr,
50         body []byte,
51 ) (err error) {
52         ctx, span := tracer.Start(ctx, "Server.HandleRequest",
53                 trace.WithAttributes(attribute.Int("payload.len", len(body))))
54         defer span.End()
55         defer func() {
56                 if err != nil {
57                         span.SetStatus(codes.Error, err.Error())
58                 }
59         }()
60         var h udp.RequestHeader
61         var r bytes.Reader
62         r.Reset(body)
63         err = udp.Read(&r, &h)
64         if err != nil {
65                 err = fmt.Errorf("reading request header: %w", err)
66                 return err
67         }
68         switch h.Action {
69         case udp.ActionConnect:
70                 err = me.handleConnect(ctx, source, h.TransactionId)
71         case udp.ActionAnnounce:
72                 err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
73         default:
74                 err = fmt.Errorf("unimplemented")
75         }
76         if err != nil {
77                 err = fmt.Errorf("handling action %v: %w", h.Action, err)
78         }
79         return err
80 }
81
82 func (me *Server) handleAnnounce(
83         ctx context.Context,
84         addrFamily udp.AddrFamily,
85         source RequestSourceAddr,
86         connId udp.ConnectionId,
87         tid udp.TransactionId,
88         r *bytes.Reader,
89 ) error {
90         // Should we set a timeout of 10s or something for the entire response, so that we give up if a
91         // retry is imminent?
92
93         ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
94         if err != nil {
95                 err = fmt.Errorf("checking conn id: %w", err)
96                 return err
97         }
98         if !ok {
99                 return fmt.Errorf("incorrect connection id: %x", connId)
100         }
101         var req udp.AnnounceRequest
102         err = udp.Read(r, &req)
103         if err != nil {
104                 return err
105         }
106         // TODO: This should be done asynchronously to responding to the announce.
107         announceAddr, err := netip.ParseAddrPort(source.String())
108         if err != nil {
109                 err = fmt.Errorf("converting source net.Addr to AnnounceAddr: %w", err)
110                 return err
111         }
112         opts := trackerServer.GetPeersOpts{MaxCount: generics.Some[uint](50)}
113         if addrFamily == udp.AddrFamilyIpv4 {
114                 opts.MaxCount = generics.Some[uint](150)
115         }
116         res := me.Announce.Serve(ctx, req, announceAddr, opts)
117         if res.Err != nil {
118                 return res.Err
119         }
120         nodeAddrs := make([]krpc.NodeAddr, 0, len(res.Peers))
121         for _, p := range res.Peers {
122                 var ip net.IP
123                 switch addrFamily {
124                 default:
125                         continue
126                 case udp.AddrFamilyIpv4:
127                         if !p.Addr().Unmap().Is4() {
128                                 continue
129                         }
130                         ipBuf := p.Addr().As4()
131                         ip = ipBuf[:]
132                 case udp.AddrFamilyIpv6:
133                         ipBuf := p.Addr().As16()
134                         ip = ipBuf[:]
135                 }
136                 nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
137                         IP:   ip[:],
138                         Port: int(p.Port()),
139                 })
140         }
141         var buf bytes.Buffer
142         err = udp.Write(&buf, udp.ResponseHeader{
143                 Action:        udp.ActionAnnounce,
144                 TransactionId: tid,
145         })
146         if err != nil {
147                 return err
148         }
149         err = udp.Write(&buf, udp.AnnounceResponseHeader{
150                 Interval: res.Interval.UnwrapOr(5 * 60),
151                 Seeders:  res.Seeders.Value,
152                 Leechers: res.Leechers.Value,
153         })
154         if err != nil {
155                 return err
156         }
157         b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
158         if err != nil {
159                 err = fmt.Errorf("marshalling compact node addrs: %w", err)
160                 return err
161         }
162         buf.Write(b)
163         n, err := me.SendResponse(buf.Bytes(), source)
164         if err != nil {
165                 return err
166         }
167         if n < buf.Len() {
168                 err = io.ErrShortWrite
169         }
170         return err
171 }
172
173 func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
174         connId := randomConnectionId()
175         err := me.ConnTracker.Add(ctx, source.String(), connId)
176         if err != nil {
177                 err = fmt.Errorf("recording conn id: %w", err)
178                 return err
179         }
180         var buf bytes.Buffer
181         udp.Write(&buf, udp.ResponseHeader{
182                 Action:        udp.ActionConnect,
183                 TransactionId: tid,
184         })
185         udp.Write(&buf, udp.ConnectionResponse{connId})
186         n, err := me.SendResponse(buf.Bytes(), source)
187         if err != nil {
188                 return err
189         }
190         if n < buf.Len() {
191                 err = io.ErrShortWrite
192         }
193         return err
194 }
195
196 func randomConnectionId() udp.ConnectionId {
197         var b [8]byte
198         _, err := rand.Read(b[:])
199         if err != nil {
200                 panic(err)
201         }
202         return binary.BigEndian.Uint64(b[:])
203 }
204
205 func RunSimple(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
206         ctx, cancel := context.WithCancel(ctx)
207         defer cancel()
208         for {
209                 var b [1500]byte
210                 n, addr, err := pc.ReadFrom(b[:])
211                 if err != nil {
212                         return err
213                 }
214                 go func() {
215                         err := s.HandleRequest(ctx, family, addr, b[:n])
216                         if err != nil {
217                                 log.Printf("error handling %v byte request from %v: %v", n, addr, err)
218                         }
219                 }()
220         }
221 }