package server
+
+import (
+ "bytes"
+ "context"
+ "crypto/rand"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "net"
+ "net/netip"
+
+ "github.com/anacrolix/dht/v2/krpc"
+ "github.com/anacrolix/log"
+ "github.com/anacrolix/torrent/tracker/udp"
+)
+
+type ConnectionTrackerAddr = string
+
+type ConnectionTracker interface {
+ Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
+ Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
+}
+
+type InfoHash = [20]byte
+
+// This is reserved for stuff like filtering by IP version, avoiding an announcer's IP or key,
+// limiting return count, etc.
+type GetPeersOpts struct{}
+
+type PeerInfo struct {
+ netip.AddrPort
+}
+
+type AnnounceTracker interface {
+ TrackAnnounce(ctx context.Context, req udp.AnnounceRequest, addr RequestSourceAddr) error
+ Scrape(ctx context.Context, infoHashes []InfoHash) ([]udp.ScrapeInfohashResult, error)
+ GetPeers(ctx context.Context, infoHash InfoHash, opts GetPeersOpts) ([]PeerInfo, error)
+}
+
+type Server struct {
+ ConnTracker ConnectionTracker
+ SendResponse func(data []byte, addr net.Addr) (int, error)
+ AnnounceTracker AnnounceTracker
+}
+
+type RequestSourceAddr = net.Addr
+
+func (me *Server) HandleRequest(ctx context.Context, family udp.AddrFamily, source RequestSourceAddr, body []byte) error {
+ var h udp.RequestHeader
+ var r bytes.Reader
+ r.Reset(body)
+ err := udp.Read(&r, &h)
+ if err != nil {
+ err = fmt.Errorf("reading request header: %w", err)
+ return err
+ }
+ switch h.Action {
+ case udp.ActionConnect:
+ err = me.handleConnect(ctx, source, h.TransactionId)
+ case udp.ActionAnnounce:
+ err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
+ default:
+ err = fmt.Errorf("unimplemented")
+ }
+ if err != nil {
+ err = fmt.Errorf("handling action %v: %w", h.Action, err)
+ }
+ return err
+}
+
+func (me *Server) handleAnnounce(
+ ctx context.Context,
+ addrFamily udp.AddrFamily,
+ source RequestSourceAddr,
+ connId udp.ConnectionId,
+ tid udp.TransactionId,
+ r *bytes.Reader,
+) error {
+ ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
+ if err != nil {
+ err = fmt.Errorf("checking conn id: %w", err)
+ return err
+ }
+ if !ok {
+ return fmt.Errorf("invalid connection id: %v", connId)
+ }
+ var req udp.AnnounceRequest
+ err = udp.Read(r, &req)
+ if err != nil {
+ return err
+ }
+ // TODO: This should be done asynchronously to responding to the announce.
+ err = me.AnnounceTracker.TrackAnnounce(ctx, req, source)
+ if err != nil {
+ return err
+ }
+ peers, err := me.AnnounceTracker.GetPeers(ctx, req.InfoHash, GetPeersOpts{})
+ if err != nil {
+ return err
+ }
+ nodeAddrs := make([]krpc.NodeAddr, 0, len(peers))
+ for _, p := range peers {
+ var ip net.IP
+ switch addrFamily {
+ default:
+ continue
+ case udp.AddrFamilyIpv4:
+ if !p.Addr().Unmap().Is4() {
+ continue
+ }
+ ipBuf := p.Addr().As4()
+ ip = ipBuf[:]
+ case udp.AddrFamilyIpv6:
+ ipBuf := p.Addr().As16()
+ ip = ipBuf[:]
+ }
+ nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
+ IP: ip[:],
+ Port: int(p.Port()),
+ })
+ }
+ var buf bytes.Buffer
+ err = udp.Write(&buf, udp.ResponseHeader{
+ Action: udp.ActionAnnounce,
+ TransactionId: tid,
+ })
+ if err != nil {
+ return err
+ }
+ err = udp.Write(&buf, udp.AnnounceResponseHeader{})
+ if err != nil {
+ return err
+ }
+ b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
+ if err != nil {
+ err = fmt.Errorf("marshalling compact node addrs: %w", err)
+ return err
+ }
+ log.Print(nodeAddrs)
+ buf.Write(b)
+ n, err := me.SendResponse(buf.Bytes(), source)
+ if err != nil {
+ return err
+ }
+ if n < buf.Len() {
+ err = io.ErrShortWrite
+ }
+ return err
+}
+
+func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
+ connId := randomConnectionId()
+ err := me.ConnTracker.Add(ctx, source.String(), connId)
+ if err != nil {
+ err = fmt.Errorf("recording conn id: %w", err)
+ return err
+ }
+ var buf bytes.Buffer
+ udp.Write(&buf, udp.ResponseHeader{
+ Action: udp.ActionConnect,
+ TransactionId: tid,
+ })
+ udp.Write(&buf, udp.ConnectionResponse{connId})
+ n, err := me.SendResponse(buf.Bytes(), source)
+ if err != nil {
+ return err
+ }
+ if n < buf.Len() {
+ err = io.ErrShortWrite
+ }
+ return err
+}
+
+func randomConnectionId() udp.ConnectionId {
+ var b [8]byte
+ _, err := rand.Read(b[:])
+ if err != nil {
+ panic(err)
+ }
+ return int64(binary.BigEndian.Uint64(b[:]))
+}
+
+func RunServer(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ for {
+ var b [1500]byte
+ n, addr, err := pc.ReadFrom(b[:])
+ if err != nil {
+ return err
+ }
+ go func() {
+ err := s.HandleRequest(ctx, family, addr, b[:n])
+ if err != nil {
+ log.Printf("error handling %v byte request from %v: %v", n, addr, err)
+ }
+ }()
+ }
+}