1 package udpTrackerServer
13 "github.com/anacrolix/dht/v2/krpc"
14 "github.com/anacrolix/log"
16 "github.com/anacrolix/torrent/tracker"
17 "github.com/anacrolix/torrent/tracker/udp"
// ConnectionTrackerAddr is the key under which connection IDs are recorded.
// In this file it is always produced from the request source's
// net.Addr.String().
20 type ConnectionTrackerAddr = string
// ConnectionTracker records connection IDs issued to a source address (Add)
// and later verifies an ID presented by that address (Check).
// NOTE(review): any expiry of issued IDs is up to the implementation; it is
// not visible here.
22 type ConnectionTracker interface {
// Add records id as issued to addr.
23 Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
// Check presumably reports whether id was issued to addr; a false result makes
// the announce handler reject the request as an invalid connection id.
24 Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
// InfoHash is a raw 20-byte info hash.
27 type InfoHash = [20]byte
// AnnounceTracker is an alias for tracker.AnnounceTracker. The server uses it
// to record announces (TrackAnnounce) and to look up peers (GetPeers).
29 type AnnounceTracker = tracker.AnnounceTracker
// NOTE(review): the `type Server struct {` header and closing brace for the
// following fields are not visible in this excerpt.
// ConnTracker issues and validates connection IDs per source address.
32 ConnTracker ConnectionTracker
// SendResponse transmits an encoded response to addr and returns the number
// of bytes written. The handlers appear to treat a short count as
// io.ErrShortWrite.
33 SendResponse func(data []byte, addr net.Addr) (int, error)
// AnnounceTracker stores announces and answers peer lookups.
34 AnnounceTracker AnnounceTracker
// RequestSourceAddr is the network address a request was received from; the
// handlers also send their responses to it.
37 type RequestSourceAddr = net.Addr
// HandleRequest decodes a UDP tracker request header and dispatches on its
// action:
//   - connect requests go to handleConnect;
//   - announce requests go to handleAnnounce, which continues reading the
//     announce body from the same reader;
//   - other actions produce an "unimplemented" error.
// A handler error is wrapped with the action that failed.
// NOTE(review): several parameter and body lines (including the construction
// of r) are not visible in this excerpt.
39 func (me *Server) HandleRequest(
// family is forwarded to handleAnnounce, where it selects which peer
// addresses are included and how they are encoded.
41 family udp.AddrFamily,
// source identifies the requester; responses are sent back to it.
42 source RequestSourceAddr,
45 var h udp.RequestHeader
48 err := udp.Read(&r, &h)
50 err = fmt.Errorf("reading request header: %w", err)
54 case udp.ActionConnect:
55 err = me.handleConnect(ctx, source, h.TransactionId)
56 case udp.ActionAnnounce:
// The announce body follows the header, so the same reader is passed on.
57 err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
// NOTE(review): the case label for this branch is not visible here;
// presumably it is the default case.
59 err = fmt.Errorf("unimplemented")
62 err = fmt.Errorf("handling action %v: %w", h.Action, err)
// handleAnnounce answers an announce request from source in these steps:
//   1. Verify via ConnTracker that connId was issued to source.
//   2. Decode the udp.AnnounceRequest that follows the header in r.
//   3. Record the announce with AnnounceTracker.
//   4. Fetch peers for the requested info hash.
//   5. Send an announce response whose peer list holds only addresses that
//      fit addrFamily, in compact form.
// NOTE(review): several lines of this function (error checks, buffer setup,
// some parameters, loop/switch scaffolding) are not visible in this excerpt.
67 func (me *Server) handleAnnounce(
69 addrFamily udp.AddrFamily,
70 source RequestSourceAddr,
71 connId udp.ConnectionId,
72 tid udp.TransactionId,
// Connection IDs are keyed by the textual source address, the same key
// handleConnect uses when it records them.
75 ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
77 err = fmt.Errorf("checking conn id: %w", err)
81 return fmt.Errorf("invalid connection id: %v", connId)
83 var req udp.AnnounceRequest
84 err = udp.Read(r, &req)
88 // TODO: This should be done asynchronously to responding to the announce.
// NOTE(review): parsing source.String() assumes the net.Addr prints as a
// host:port that netip.ParseAddrPort accepts. Type-asserting to *net.UDPAddr
// and calling its AddrPort method would avoid the string round trip.
89 announceAddr, err := netip.ParseAddrPort(source.String())
91 err = fmt.Errorf("converting source net.Addr to AnnounceAddr: %w", err)
94 err = me.AnnounceTracker.TrackAnnounce(ctx, req, announceAddr)
// Peer lookup uses zero-value (default) GetPeersOpts.
98 peers, err := me.AnnounceTracker.GetPeers(ctx, req.InfoHash, tracker.GetPeersOpts{})
102 nodeAddrs := make([]krpc.NodeAddr, 0, len(peers))
103 for _, p := range peers {
// IPv4 responses can only carry IPv4 peers (IPv4-mapped IPv6 addresses pass
// after Unmap); other peers are presumably skipped by the hidden branch body.
108 case udp.AddrFamilyIpv4:
109 if !p.Addr().Unmap().Is4() {
112 ipBuf := p.Addr().As4()
// For IPv6 responses, As16 yields IPv4 peers in IPv4-mapped form.
114 case udp.AddrFamilyIpv6:
115 ipBuf := p.Addr().As16()
118 nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
124 err = udp.Write(&buf, udp.ResponseHeader{
125 Action: udp.ActionAnnounce,
// NOTE(review): the announce response header is written with every field at
// its zero value (presumably interval and peer counts); confirm that is
// intended.
131 err = udp.Write(&buf, udp.AnnounceResponseHeader{})
135 b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
137 err = fmt.Errorf("marshalling compact node addrs: %w", err)
141 n, err := me.SendResponse(buf.Bytes(), source)
// A send that did not write the whole response is reported as a short write.
// The comparison itself is not visible here.
146 err = io.ErrShortWrite
// handleConnect answers a connect request with these steps:
//   1. Generate a random connection ID.
//   2. Record it for source in ConnTracker.
//   3. Send source a connect response (ActionConnect) carrying the new ID.
// The response header presumably also echoes tid; those header lines are not
// visible here.
// NOTE(review): unlike handleAnnounce, this function discards the errors
// returned by udp.Write. That may be benign if buf is an in-memory buffer,
// but buf's declaration is not visible in this excerpt.
151 func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
152 connId := randomConnectionId()
// Keyed by the textual source address so a later announce's Check finds it.
153 err := me.ConnTracker.Add(ctx, source.String(), connId)
155 err = fmt.Errorf("recording conn id: %w", err)
159 udp.Write(&buf, udp.ResponseHeader{
160 Action: udp.ActionConnect,
163 udp.Write(&buf, udp.ConnectionResponse{connId})
164 n, err := me.SendResponse(buf.Bytes(), source)
// A send that did not write the whole response is reported as a short write.
169 err = io.ErrShortWrite
// randomConnectionId returns a new connection ID. It fills b with random
// bytes, decodes the first 8 as a big-endian uint64, and reinterprets the
// result as int64.
// NOTE(review): neither the rand package in use nor the handling of a
// rand.Read error is visible in this excerpt. Connection IDs should be
// unpredictable (crypto/rand), otherwise the connect handshake offers little
// protection against spoofed source addresses.
174 func randomConnectionId() udp.ConnectionId {
176 _, err := rand.Read(b[:])
180 return int64(binary.BigEndian.Uint64(b[:]))
// RunSimple serves UDP tracker requests read from pc using s. Each received
// packet is passed to s.HandleRequest; per-request errors are logged with
// log.Printf rather than returned.
// NOTE(review): this excerpt does not show:
//   - the loop structure;
//   - how ReadFrom errors end the function;
//   - whether requests are handled on separate goroutines (the inner
//     `err :=` suggests a nested scope).
// If HandleRequest runs concurrently while b is reused for the next
// ReadFrom, b[:n] could be overwritten mid-handling; confirm that b is fresh
// per packet.
183 func RunSimple(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
// Derived cancellable context; presumably cancel is deferred so in-flight
// handlers see cancellation when RunSimple returns.
184 ctx, cancel := context.WithCancel(ctx)
188 n, addr, err := pc.ReadFrom(b[:])
193 err := s.HandleRequest(ctx, family, addr, b[:n])
195 log.Printf("error handling %v byte request from %v: %v", n, addr, err)