13 "github.com/anacrolix/dht/v2/krpc"
14 "github.com/anacrolix/log"
15 "github.com/anacrolix/torrent/tracker/udp"
// ConnectionTrackerAddr is the key a ConnectionTracker uses to identify a client.
// Server passes the String() form of the request's source address.
18 type ConnectionTrackerAddr = string
// ConnectionTracker remembers which connection ids have been handed out to
// which client addresses. Whether and when ids expire is up to the
// implementation.
20 type ConnectionTracker interface {
// Add records id as issued to addr. Server calls it when answering a connect
// request.
21 Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
// Check reports whether id was issued to addr. Server rejects an announce when
// Check returns false.
22 Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
// InfoHash is a 20-byte torrent infohash, as carried in announce requests and
// passed to scrape and peer lookups.
25 type InfoHash = [20]byte
// GetPeersOpts holds options for AnnounceTracker.GetPeers. It currently has no fields.
27 // This is reserved for stuff like filtering by IP version, avoiding an announcer's IP or key,
28 // limiting return count, etc.
29 type GetPeersOpts struct{}
// PeerInfo describes a peer returned by AnnounceTracker.GetPeers. Server calls
// Addr() on each PeerInfo when building announce responses.
31 type PeerInfo struct {
// AnnounceTracker is the backend Server consults when handling announces.
35 type AnnounceTracker interface {
// TrackAnnounce is given each announce request whose connection id passed the
// ConnectionTracker check, together with the address it came from.
36 TrackAnnounce(ctx context.Context, req udp.AnnounceRequest, addr RequestSourceAddr) error
// Scrape returns scrape results for infoHashes.
// NOTE(review): presumably one result per requested infohash, in request order — confirm.
37 Scrape(ctx context.Context, infoHashes []InfoHash) ([]udp.ScrapeInfohashResult, error)
// GetPeers returns the peers to include in an announce response for infoHash.
38 GetPeers(ctx context.Context, infoHash InfoHash, opts GetPeersOpts) ([]PeerInfo, error)
// ConnTracker records the connection ids handed out by connect requests and
// validates them on announce requests.
42 ConnTracker ConnectionTracker
// SendResponse writes a response datagram to addr. The handlers report a
// short write (returned count too small) as io.ErrShortWrite.
43 SendResponse func(data []byte, addr net.Addr) (int, error)
// AnnounceTracker is told about validated announces and supplies the peers
// returned to clients.
44 AnnounceTracker AnnounceTracker
// RequestSourceAddr is the address a request datagram arrived from; responses
// are sent back to it.
47 type RequestSourceAddr = net.Addr
// HandleRequest decodes the request header at the start of body and dispatches
// on its action. Connect and announce are handled; any other action yields an
// "unimplemented" error. family selects how peer addresses are encoded in
// announce responses. Errors from the action handlers are wrapped with the
// action that was being handled.
49 func (me *Server) HandleRequest(ctx context.Context, family udp.AddrFamily, source RequestSourceAddr, body []byte) error {
50 var h udp.RequestHeader
53 err := udp.Read(&r, &h)
55 err = fmt.Errorf("reading request header: %w", err)
59 case udp.ActionConnect:
60 err = me.handleConnect(ctx, source, h.TransactionId)
61 case udp.ActionAnnounce:
// r presumably reads from body; handleAnnounce decodes the announce payload
// that follows the header from it.
62 err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
64 err = fmt.Errorf("unimplemented")
67 err = fmt.Errorf("handling action %v: %w", h.Action, err)
// handleAnnounce answers an announce request from source.
//
// It performs these steps in order:
//   - verifies connId with ConnTracker;
//   - decodes the announce request from r;
//   - passes the request to AnnounceTracker.TrackAnnounce;
//   - fetches peers with AnnounceTracker.GetPeers;
//   - sends back an announce response whose peer list is compact-encoded
//     for addrFamily.
//
// tid is the transaction id from the request header (see HandleRequest).
72 func (me *Server) handleAnnounce(
74 addrFamily udp.AddrFamily,
75 source RequestSourceAddr,
76 connId udp.ConnectionId,
77 tid udp.TransactionId,
// The id must have been recorded for this same source.String() key by
// handleConnect.
// NOTE(review): for *net.UDPAddr sources String() includes the port, so the id
// is tied to ip:port — confirm that is intended.
80 ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
82 err = fmt.Errorf("checking conn id: %w", err)
86 return fmt.Errorf("invalid connection id: %v", connId)
88 var req udp.AnnounceRequest
89 err = udp.Read(r, &req)
93 // TODO: This should be done asynchronously to responding to the announce.
94 err = me.AnnounceTracker.TrackAnnounce(ctx, req, source)
98 peers, err := me.AnnounceTracker.GetPeers(ctx, req.InfoHash, GetPeersOpts{})
102 nodeAddrs := make([]krpc.NodeAddr, 0, len(peers))
103 for _, p := range peers {
108 case udp.AddrFamilyIpv4:
// An IPv4 response can only carry IPv4 peers. IPv4-mapped IPv6 addresses are
// unmapped first so they still count as IPv4. Other peers fall into the branch
// below (presumably skipped).
109 if !p.Addr().Unmap().Is4() {
112 ipBuf := p.Addr().As4()
114 case udp.AddrFamilyIpv6:
// Assuming p.Addr() is a netip.Addr, As16 also accepts IPv4 addresses and
// returns their IPv4-mapped form.
115 ipBuf := p.Addr().As16()
118 nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
124 err = udp.Write(&buf, udp.ResponseHeader{
125 Action: udp.ActionAnnounce,
// NOTE(review): written as a zero value, so whatever this header holds
// (presumably interval and swarm counts) is all zero — confirm.
131 err = udp.Write(&buf, udp.AnnounceResponseHeader{})
135 b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
137 err = fmt.Errorf("marshalling compact node addrs: %w", err)
142 n, err := me.SendResponse(buf.Bytes(), source)
147 err = io.ErrShortWrite
// handleConnect answers a connect request. It generates a random connection
// id, records it for source via ConnTracker.Add, and sends a connect response
// carrying the id back to source.
152 func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
153 connId := randomConnectionId()
// Keyed by source.String(), the same key handleAnnounce checks against.
154 err := me.ConnTracker.Add(ctx, source.String(), connId)
156 err = fmt.Errorf("recording conn id: %w", err)
// NOTE(review): udp.Write errors are ignored here, unlike in handleAnnounce.
160 udp.Write(&buf, udp.ResponseHeader{
161 Action: udp.ActionConnect,
164 udp.Write(&buf, udp.ConnectionResponse{connId})
165 n, err := me.SendResponse(buf.Bytes(), source)
170 err = io.ErrShortWrite
// randomConnectionId returns a new connection id. The id is the bytes that
// rand.Read fills into b, interpreted as a big-endian integer.
// NOTE(review): binary.BigEndian.Uint64 needs b to hold at least 8 bytes.
// Confirm rand is crypto/rand so that ids are hard for spoofed sources to guess.
175 func randomConnectionId() udp.ConnectionId {
177 _, err := rand.Read(b[:])
181 return int64(binary.BigEndian.Uint64(b[:]))
// RunServer reads request datagrams from pc and hands each one to
// s.HandleRequest with the given address family. Errors from handling a
// request are logged together with its size and source address. The incoming
// ctx is wrapped in a cancellable context before being passed to the handlers.
// NOTE(review): b[:n] aliases the read buffer. If requests are handled on
// separate goroutines, each packet must get its own buffer.
184 func RunServer(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
185 ctx, cancel := context.WithCancel(ctx)
189 n, addr, err := pc.ReadFrom(b[:])
194 err := s.HandleRequest(ctx, family, addr, b[:n])
196 log.Printf("error handling %v byte request from %v: %v", n, addr, err)