]> Sergey Matveev's repositories - btrtrc.git/blob - tracker/udp/server/server.go
c8c9bec364c65a180caa30c6459e21f58d9a27f0
[btrtrc.git] / tracker / udp / server / server.go
1 package udpTrackerServer
2
3 import (
4         "bytes"
5         "context"
6         "crypto/rand"
7         "encoding/binary"
8         "fmt"
9         "io"
10         "net"
11         "net/netip"
12
13         "github.com/anacrolix/dht/v2/krpc"
14         "github.com/anacrolix/generics"
15         "github.com/anacrolix/log"
16         "go.opentelemetry.io/otel"
17
18         "github.com/anacrolix/torrent/tracker"
19         "github.com/anacrolix/torrent/tracker/udp"
20 )
21
22 type ConnectionTrackerAddr = string
23
24 type ConnectionTracker interface {
25         Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
26         Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
27 }
28
29 type InfoHash = [20]byte
30
31 type AnnounceTracker = tracker.AnnounceTracker
32
33 type Server struct {
34         ConnTracker  ConnectionTracker
35         SendResponse func(data []byte, addr net.Addr) (int, error)
36         Announce     *tracker.AnnounceHandler
37 }
38
39 type RequestSourceAddr = net.Addr
40
41 var tracer = otel.Tracer("torrent.tracker.udp")
42
43 func (me *Server) HandleRequest(
44         ctx context.Context,
45         family udp.AddrFamily,
46         source RequestSourceAddr,
47         body []byte,
48 ) error {
49         ctx, span := tracer.Start(ctx, "Server.HandleRequest")
50         defer span.End()
51         var h udp.RequestHeader
52         var r bytes.Reader
53         r.Reset(body)
54         err := udp.Read(&r, &h)
55         if err != nil {
56                 err = fmt.Errorf("reading request header: %w", err)
57                 return err
58         }
59         switch h.Action {
60         case udp.ActionConnect:
61                 err = me.handleConnect(ctx, source, h.TransactionId)
62         case udp.ActionAnnounce:
63                 err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
64         default:
65                 err = fmt.Errorf("unimplemented")
66         }
67         if err != nil {
68                 err = fmt.Errorf("handling action %v: %w", h.Action, err)
69         }
70         return err
71 }
72
73 func (me *Server) handleAnnounce(
74         ctx context.Context,
75         addrFamily udp.AddrFamily,
76         source RequestSourceAddr,
77         connId udp.ConnectionId,
78         tid udp.TransactionId,
79         r *bytes.Reader,
80 ) error {
81         // Should we set a timeout of 10s or something for the entire response, so that we give up if a
82         // retry is imminent?
83
84         ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
85         if err != nil {
86                 err = fmt.Errorf("checking conn id: %w", err)
87                 return err
88         }
89         if !ok {
90                 return fmt.Errorf("incorrect connection id: %x", connId)
91         }
92         var req udp.AnnounceRequest
93         err = udp.Read(r, &req)
94         if err != nil {
95                 return err
96         }
97         // TODO: This should be done asynchronously to responding to the announce.
98         announceAddr, err := netip.ParseAddrPort(source.String())
99         if err != nil {
100                 err = fmt.Errorf("converting source net.Addr to AnnounceAddr: %w", err)
101                 return err
102         }
103         opts := tracker.GetPeersOpts{MaxCount: generics.Some[uint](50)}
104         if addrFamily == udp.AddrFamilyIpv4 {
105                 opts.MaxCount = generics.Some[uint](150)
106         }
107         res := me.Announce.Serve(ctx, req, announceAddr, opts)
108         if res.Err != nil {
109                 return res.Err
110         }
111         nodeAddrs := make([]krpc.NodeAddr, 0, len(res.Peers))
112         for _, p := range res.Peers {
113                 var ip net.IP
114                 switch addrFamily {
115                 default:
116                         continue
117                 case udp.AddrFamilyIpv4:
118                         if !p.Addr().Unmap().Is4() {
119                                 continue
120                         }
121                         ipBuf := p.Addr().As4()
122                         ip = ipBuf[:]
123                 case udp.AddrFamilyIpv6:
124                         ipBuf := p.Addr().As16()
125                         ip = ipBuf[:]
126                 }
127                 nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
128                         IP:   ip[:],
129                         Port: int(p.Port()),
130                 })
131         }
132         var buf bytes.Buffer
133         err = udp.Write(&buf, udp.ResponseHeader{
134                 Action:        udp.ActionAnnounce,
135                 TransactionId: tid,
136         })
137         if err != nil {
138                 return err
139         }
140         err = udp.Write(&buf, udp.AnnounceResponseHeader{
141                 Interval: res.Interval.UnwrapOr(5 * 60),
142         })
143         if err != nil {
144                 return err
145         }
146         b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
147         if err != nil {
148                 err = fmt.Errorf("marshalling compact node addrs: %w", err)
149                 return err
150         }
151         buf.Write(b)
152         n, err := me.SendResponse(buf.Bytes(), source)
153         if err != nil {
154                 return err
155         }
156         if n < buf.Len() {
157                 err = io.ErrShortWrite
158         }
159         return err
160 }
161
162 func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
163         connId := randomConnectionId()
164         err := me.ConnTracker.Add(ctx, source.String(), connId)
165         if err != nil {
166                 err = fmt.Errorf("recording conn id: %w", err)
167                 return err
168         }
169         var buf bytes.Buffer
170         udp.Write(&buf, udp.ResponseHeader{
171                 Action:        udp.ActionConnect,
172                 TransactionId: tid,
173         })
174         udp.Write(&buf, udp.ConnectionResponse{connId})
175         n, err := me.SendResponse(buf.Bytes(), source)
176         if err != nil {
177                 return err
178         }
179         if n < buf.Len() {
180                 err = io.ErrShortWrite
181         }
182         return err
183 }
184
185 func randomConnectionId() udp.ConnectionId {
186         var b [8]byte
187         _, err := rand.Read(b[:])
188         if err != nil {
189                 panic(err)
190         }
191         return binary.BigEndian.Uint64(b[:])
192 }
193
194 func RunSimple(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
195         ctx, cancel := context.WithCancel(ctx)
196         defer cancel()
197         for {
198                 var b [1500]byte
199                 n, addr, err := pc.ReadFrom(b[:])
200                 if err != nil {
201                         return err
202                 }
203                 go func() {
204                         err := s.HandleRequest(ctx, family, addr, b[:n])
205                         if err != nil {
206                                 log.Printf("error handling %v byte request from %v: %v", n, addr, err)
207                         }
208                 }()
209         }
210 }