]> Sergey Matveev's repositories - btrtrc.git/blob - tracker/udp/server/server.go
2007233f4f0c067e04ff6c755ef56631eb0ea161
[btrtrc.git] / tracker / udp / server / server.go
1 package udpTrackerServer
2
3 import (
4         "bytes"
5         "context"
6         "crypto/rand"
7         "encoding/binary"
8         "fmt"
9         "io"
10         "net"
11         "net/netip"
12
13         "github.com/anacrolix/dht/v2/krpc"
14         "github.com/anacrolix/generics"
15         "github.com/anacrolix/log"
16         trackerServer "github.com/anacrolix/torrent/tracker/server"
17         "go.opentelemetry.io/otel"
18         "go.opentelemetry.io/otel/codes"
19
20         "github.com/anacrolix/torrent/tracker/udp"
21 )
22
23 type ConnectionTrackerAddr = string
24
25 type ConnectionTracker interface {
26         Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
27         Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
28 }
29
30 type InfoHash = [20]byte
31
32 type AnnounceTracker = trackerServer.AnnounceTracker
33
34 type Server struct {
35         ConnTracker  ConnectionTracker
36         SendResponse func(data []byte, addr net.Addr) (int, error)
37         Announce     *trackerServer.AnnounceHandler
38 }
39
40 type RequestSourceAddr = net.Addr
41
42 var tracer = otel.Tracer("torrent.tracker.udp")
43
44 func (me *Server) HandleRequest(
45         ctx context.Context,
46         family udp.AddrFamily,
47         source RequestSourceAddr,
48         body []byte,
49 ) (err error) {
50         ctx, span := tracer.Start(ctx, "Server.HandleRequest")
51         defer span.End()
52         defer func() {
53                 if err != nil {
54                         span.SetStatus(codes.Error, err.Error())
55                 }
56         }()
57         var h udp.RequestHeader
58         var r bytes.Reader
59         r.Reset(body)
60         err = udp.Read(&r, &h)
61         if err != nil {
62                 err = fmt.Errorf("reading request header: %w", err)
63                 return err
64         }
65         switch h.Action {
66         case udp.ActionConnect:
67                 err = me.handleConnect(ctx, source, h.TransactionId)
68         case udp.ActionAnnounce:
69                 err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
70         default:
71                 err = fmt.Errorf("unimplemented")
72         }
73         if err != nil {
74                 err = fmt.Errorf("handling action %v: %w", h.Action, err)
75         }
76         return err
77 }
78
79 func (me *Server) handleAnnounce(
80         ctx context.Context,
81         addrFamily udp.AddrFamily,
82         source RequestSourceAddr,
83         connId udp.ConnectionId,
84         tid udp.TransactionId,
85         r *bytes.Reader,
86 ) error {
87         // Should we set a timeout of 10s or something for the entire response, so that we give up if a
88         // retry is imminent?
89
90         ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
91         if err != nil {
92                 err = fmt.Errorf("checking conn id: %w", err)
93                 return err
94         }
95         if !ok {
96                 return fmt.Errorf("incorrect connection id: %x", connId)
97         }
98         var req udp.AnnounceRequest
99         err = udp.Read(r, &req)
100         if err != nil {
101                 return err
102         }
103         // TODO: This should be done asynchronously to responding to the announce.
104         announceAddr, err := netip.ParseAddrPort(source.String())
105         if err != nil {
106                 err = fmt.Errorf("converting source net.Addr to AnnounceAddr: %w", err)
107                 return err
108         }
109         opts := trackerServer.GetPeersOpts{MaxCount: generics.Some[uint](50)}
110         if addrFamily == udp.AddrFamilyIpv4 {
111                 opts.MaxCount = generics.Some[uint](150)
112         }
113         res := me.Announce.Serve(ctx, req, announceAddr, opts)
114         if res.Err != nil {
115                 return res.Err
116         }
117         nodeAddrs := make([]krpc.NodeAddr, 0, len(res.Peers))
118         for _, p := range res.Peers {
119                 var ip net.IP
120                 switch addrFamily {
121                 default:
122                         continue
123                 case udp.AddrFamilyIpv4:
124                         if !p.Addr().Unmap().Is4() {
125                                 continue
126                         }
127                         ipBuf := p.Addr().As4()
128                         ip = ipBuf[:]
129                 case udp.AddrFamilyIpv6:
130                         ipBuf := p.Addr().As16()
131                         ip = ipBuf[:]
132                 }
133                 nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
134                         IP:   ip[:],
135                         Port: int(p.Port()),
136                 })
137         }
138         var buf bytes.Buffer
139         err = udp.Write(&buf, udp.ResponseHeader{
140                 Action:        udp.ActionAnnounce,
141                 TransactionId: tid,
142         })
143         if err != nil {
144                 return err
145         }
146         err = udp.Write(&buf, udp.AnnounceResponseHeader{
147                 Interval: res.Interval.UnwrapOr(5 * 60),
148                 Seeders:  res.Seeders.Value,
149                 Leechers: res.Leechers.Value,
150         })
151         if err != nil {
152                 return err
153         }
154         b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
155         if err != nil {
156                 err = fmt.Errorf("marshalling compact node addrs: %w", err)
157                 return err
158         }
159         buf.Write(b)
160         n, err := me.SendResponse(buf.Bytes(), source)
161         if err != nil {
162                 return err
163         }
164         if n < buf.Len() {
165                 err = io.ErrShortWrite
166         }
167         return err
168 }
169
170 func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
171         connId := randomConnectionId()
172         err := me.ConnTracker.Add(ctx, source.String(), connId)
173         if err != nil {
174                 err = fmt.Errorf("recording conn id: %w", err)
175                 return err
176         }
177         var buf bytes.Buffer
178         udp.Write(&buf, udp.ResponseHeader{
179                 Action:        udp.ActionConnect,
180                 TransactionId: tid,
181         })
182         udp.Write(&buf, udp.ConnectionResponse{connId})
183         n, err := me.SendResponse(buf.Bytes(), source)
184         if err != nil {
185                 return err
186         }
187         if n < buf.Len() {
188                 err = io.ErrShortWrite
189         }
190         return err
191 }
192
193 func randomConnectionId() udp.ConnectionId {
194         var b [8]byte
195         _, err := rand.Read(b[:])
196         if err != nil {
197                 panic(err)
198         }
199         return binary.BigEndian.Uint64(b[:])
200 }
201
202 func RunSimple(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
203         ctx, cancel := context.WithCancel(ctx)
204         defer cancel()
205         for {
206                 var b [1500]byte
207                 n, addr, err := pc.ReadFrom(b[:])
208                 if err != nil {
209                         return err
210                 }
211                 go func() {
212                         err := s.HandleRequest(ctx, family, addr, b[:n])
213                         if err != nil {
214                                 log.Printf("error handling %v byte request from %v: %v", n, addr, err)
215                         }
216                 }()
217         }
218 }