1 package udpTrackerServer
13 "github.com/anacrolix/dht/v2/krpc"
14 "github.com/anacrolix/generics"
15 "github.com/anacrolix/log"
16 trackerServer "github.com/anacrolix/torrent/tracker/server"
17 "go.opentelemetry.io/otel"
18 "go.opentelemetry.io/otel/attribute"
19 "go.opentelemetry.io/otel/codes"
20 "go.opentelemetry.io/otel/trace"
22 "github.com/anacrolix/torrent/tracker/udp"
// ConnectionTrackerAddr is the string key under which a ConnectionTracker
// files a client. The handlers in this file pass the String() form of the
// request's source address.
25 type ConnectionTrackerAddr = string
// ConnectionTracker keeps track of the connection IDs handed out to clients
// and validates the IDs that clients later present.
27 type ConnectionTracker interface {
// Add records that id was issued to addr. The connect handler calls it before
// replying and reports a failure as "recording conn id".
28 Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
// Check reports whether id is acceptable for addr. The announce handler calls
// it before decoding the request; presumably a false result leads to the
// "incorrect connection id" error — confirm against the handler.
29 Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
// InfoHash is a 20-byte value; going by its name, a torrent info hash.
32 type InfoHash = [20]byte
// AnnounceTracker is an alias for trackerServer.AnnounceTracker, so the type
// can be named through this package.
34 type AnnounceTracker = trackerServer.AnnounceTracker
// ConnTracker stores the connection IDs issued by connect requests and checks
// the IDs presented with announce requests.
37 ConnTracker ConnectionTracker
// SendResponse transmits an encoded response to addr. The returned int is
// presumably the number of bytes sent: the handlers can report
// io.ErrShortWrite after calling it — confirm.
38 SendResponse func(ctx context.Context, data []byte, addr net.Addr) (int, error)
// Announce serves announce requests; the announce handler calls its Serve
// method after the connection ID check.
39 Announce *trackerServer.AnnounceHandler
// RequestSourceAddr is the network address a request came from. RunSimple
// passes the address returned by the packet connection's ReadFrom.
42 type RequestSourceAddr = net.Addr
// tracer is the OpenTelemetry tracer for this package; HandleRequest starts a
// span from it for each request.
44 var tracer = otel.Tracer("torrent.tracker.udp")
// HandleRequest decodes the request header from the payload and dispatches on
// its action: connect requests go to handleConnect, announce requests go to
// handleAnnounce together with the remainder of the payload. A tracing span
// carrying the payload length is started for each call. A handler error is
// wrapped with the action it came from.
//
// family is forwarded to handleAnnounce and source to both handlers.
46 func (me *Server) HandleRequest(
48 family udp.AddrFamily,
49 source RequestSourceAddr,
52 ctx, span := tracer.Start(ctx, "Server.HandleRequest",
53 trace.WithAttributes(attribute.Int("payload.len", len(body))))
// Mark the span as failed, using the error text as the status description.
57 span.SetStatus(codes.Error, err.Error())
60 var h udp.RequestHeader
// Every request starts with the common header; r is left positioned after it
// so the announce handler can keep reading from the same reader.
63 err = udp.Read(&r, &h)
65 err = fmt.Errorf("reading request header: %w", err)
69 case udp.ActionConnect:
70 err = me.handleConnect(ctx, source, h.TransactionId)
71 case udp.ActionAnnounce:
72 err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
// NOTE(review): presumably the fallback for any action other than connect and
// announce — confirm.
74 err = fmt.Errorf("unimplemented")
// Add the action to the error so callers can tell which handler failed.
77 err = fmt.Errorf("handling action %v: %w", h.Action, err)
// handleAnnounce answers an announce request. It checks connId against
// ConnTracker for the source address, decodes the announce request from r,
// asks the Announce handler for peers, and sends back a response made of a
// response header, an announce response header, and the peers encoded as
// compact node addresses for addrFamily.
82 func (me *Server) handleAnnounce(
84 addrFamily udp.AddrFamily,
85 source RequestSourceAddr,
86 connId udp.ConnectionId,
87 tid udp.TransactionId,
90 // Should we set a timeout of 10s or something for the entire response, so that we give up if a
// The connection ID is looked up under the string form of the source address,
// the same key handleConnect records it under.
93 ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
95 err = fmt.Errorf("checking conn id: %w", err)
99 return fmt.Errorf("incorrect connection id: %x", connId)
101 var req udp.AnnounceRequest
102 err = udp.Read(r, &req)
106 // TODO: This should be done asynchronously to responding to the announce.
// The announcing peer's address is obtained by parsing the textual form of
// the source address as an IP:port pair.
107 announceAddr, err := netip.ParseAddrPort(source.String())
109 err = fmt.Errorf("converting source net.Addr to AnnounceAddr: %w", err)
// Ask for at most 50 peers by default and 150 for IPv4 requests.
// NOTE(review): presumably because compact IPv4 entries are smaller — confirm.
112 opts := trackerServer.GetPeersOpts{MaxCount: generics.Some[uint](50)}
113 if addrFamily == udp.AddrFamilyIpv4 {
114 opts.MaxCount = generics.Some[uint](150)
116 res := me.Announce.Serve(ctx, req, announceAddr, opts)
// Sized for the case where every returned peer is included.
120 nodeAddrs := make([]krpc.NodeAddr, 0, len(res.Peers))
121 for _, p := range res.Peers {
126 case udp.AddrFamilyIpv4:
// Unmap first so an IPv4-mapped IPv6 address counts as IPv4. Presumably
// peers failing this test are left out of an IPv4 response — confirm.
127 if !p.Addr().Unmap().Is4() {
130 ipBuf := p.Addr().As4()
132 case udp.AddrFamilyIpv6:
133 ipBuf := p.Addr().As16()
136 nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
142 err = udp.Write(&buf, udp.ResponseHeader{
143 Action: udp.ActionAnnounce,
149 err = udp.Write(&buf, udp.AnnounceResponseHeader{
// Falls back to 5 * 60 when the announce result carries no interval
// (presumably seconds — confirm).
150 Interval: res.Interval.UnwrapOr(5 * 60),
// NOTE(review): Value is read directly here, unlike Interval above which goes
// through UnwrapOr; presumably an unset option yields the zero value — confirm.
151 Seeders: res.Seeders.Value,
152 Leechers: res.Leechers.Value,
157 b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
159 err = fmt.Errorf("marshalling compact node addrs: %w", err)
163 n, err := me.SendResponse(ctx, buf.Bytes(), source)
// Reported when the response was not sent in full.
168 err = io.ErrShortWrite
// handleConnect answers a connect request. It generates a random connection
// ID, records it in ConnTracker under the string form of the source address,
// and sends back a response header with ActionConnect followed by a
// ConnectionResponse carrying the new ID.
173 func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
174 connId := randomConnectionId()
// Record the ID before replying so a later announce from this address can be
// checked against it.
175 err := me.ConnTracker.Add(ctx, source.String(), connId)
177 err = fmt.Errorf("recording conn id: %w", err)
// NOTE(review): the results of these two udp.Write calls are discarded,
// whereas handleAnnounce assigns them to err. Consider checking them here too.
181 udp.Write(&buf, udp.ResponseHeader{
182 Action: udp.ActionConnect,
185 udp.Write(&buf, udp.ConnectionResponse{connId})
186 n, err := me.SendResponse(ctx, buf.Bytes(), source)
// Reported when the response was not sent in full.
191 err = io.ErrShortWrite
// randomConnectionId returns a connection ID made from bytes obtained from
// rand.Read, interpreted as a big-endian uint64.
196 func randomConnectionId() udp.ConnectionId {
// NOTE(review): which rand package this is cannot be told from here; a
// connection ID should be unpredictable, so confirm it is crypto/rand.
198 _, err := rand.Read(b[:])
202 return binary.BigEndian.Uint64(b[:])
// RunSimple reads packets from pc and passes each one, with its sender's
// address and the given address family, to s.HandleRequest. At most 1000
// requests are handled at a time; a packet that arrives while the limit is
// reached is dropped and logged. Handler errors are logged.
205 func RunSimple(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
// Handlers receive a context derived from the caller's that can be cancelled
// independently.
206 ctx, cancel := context.WithCancel(ctx)
209 // Limit concurrent handled requests.
210 sem := make(chan struct{}, 1000)
212 n, addr, err := pc.ReadFrom(b[:])
220 log.Printf("dropping request from %v: concurrency limit reached", addr)
// Taking a slot in sem admits the request; the slot is given back when the
// handler finishes.
222 case sem <- struct{}{}:
// Copy the packet so the handler does not share the read buffer, which is
// presumably reused by the next ReadFrom.
224 b := append([]byte(nil), b[:n]...)
226 defer func() { <-sem }()
227 err := s.HandleRequest(ctx, family, addr, b)
229 log.Printf("error handling %v byte request from %v: %v", n, addr, err)