]> Sergey Matveev's repositories - btrtrc.git/blobdiff - tracker/udp/server/server.go
Start a UDP server implementation
[btrtrc.git] / tracker / udp / server / server.go
index abb4e431abd516750a5a1e5e2b77073c236b8f9e..815002147223177265625ac29639ae4280a4d7af 100644 (file)
@@ -1 +1,200 @@
 package server
+
+import (
+       "bytes"
+       "context"
+       "crypto/rand"
+       "encoding/binary"
+       "fmt"
+       "io"
+       "net"
+       "net/netip"
+
+       "github.com/anacrolix/dht/v2/krpc"
+       "github.com/anacrolix/log"
+       "github.com/anacrolix/torrent/tracker/udp"
+)
+
+type ConnectionTrackerAddr = string
+
+type ConnectionTracker interface {
+       Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
+       Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
+}
+
+type InfoHash = [20]byte
+
+// This is reserved for stuff like filtering by IP version, avoiding an announcer's IP or key,
+// limiting return count, etc.
+type GetPeersOpts struct{}
+
+type PeerInfo struct {
+       netip.AddrPort
+}
+
+type AnnounceTracker interface {
+       TrackAnnounce(ctx context.Context, req udp.AnnounceRequest, addr RequestSourceAddr) error
+       Scrape(ctx context.Context, infoHashes []InfoHash) ([]udp.ScrapeInfohashResult, error)
+       GetPeers(ctx context.Context, infoHash InfoHash, opts GetPeersOpts) ([]PeerInfo, error)
+}
+
+type Server struct {
+       ConnTracker     ConnectionTracker
+       SendResponse    func(data []byte, addr net.Addr) (int, error)
+       AnnounceTracker AnnounceTracker
+}
+
+type RequestSourceAddr = net.Addr
+
+func (me *Server) HandleRequest(ctx context.Context, family udp.AddrFamily, source RequestSourceAddr, body []byte) error {
+       var h udp.RequestHeader
+       var r bytes.Reader
+       r.Reset(body)
+       err := udp.Read(&r, &h)
+       if err != nil {
+               err = fmt.Errorf("reading request header: %w", err)
+               return err
+       }
+       switch h.Action {
+       case udp.ActionConnect:
+               err = me.handleConnect(ctx, source, h.TransactionId)
+       case udp.ActionAnnounce:
+               err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
+       default:
+               err = fmt.Errorf("unimplemented")
+       }
+       if err != nil {
+               err = fmt.Errorf("handling action %v: %w", h.Action, err)
+       }
+       return err
+}
+
+func (me *Server) handleAnnounce(
+       ctx context.Context,
+       addrFamily udp.AddrFamily,
+       source RequestSourceAddr,
+       connId udp.ConnectionId,
+       tid udp.TransactionId,
+       r *bytes.Reader,
+) error {
+       ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
+       if err != nil {
+               err = fmt.Errorf("checking conn id: %w", err)
+               return err
+       }
+       if !ok {
+               return fmt.Errorf("invalid connection id: %v", connId)
+       }
+       var req udp.AnnounceRequest
+       err = udp.Read(r, &req)
+       if err != nil {
+               return err
+       }
+       // TODO: This should be done asynchronously to responding to the announce.
+       err = me.AnnounceTracker.TrackAnnounce(ctx, req, source)
+       if err != nil {
+               return err
+       }
+       peers, err := me.AnnounceTracker.GetPeers(ctx, req.InfoHash, GetPeersOpts{})
+       if err != nil {
+               return err
+       }
+       nodeAddrs := make([]krpc.NodeAddr, 0, len(peers))
+       for _, p := range peers {
+               var ip net.IP
+               switch addrFamily {
+               default:
+                       continue
+               case udp.AddrFamilyIpv4:
+                       if !p.Addr().Unmap().Is4() {
+                               continue
+                       }
+                       ipBuf := p.Addr().As4()
+                       ip = ipBuf[:]
+               case udp.AddrFamilyIpv6:
+                       ipBuf := p.Addr().As16()
+                       ip = ipBuf[:]
+               }
+               nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
+                       IP:   ip[:],
+                       Port: int(p.Port()),
+               })
+       }
+       var buf bytes.Buffer
+       err = udp.Write(&buf, udp.ResponseHeader{
+               Action:        udp.ActionAnnounce,
+               TransactionId: tid,
+       })
+       if err != nil {
+               return err
+       }
+       err = udp.Write(&buf, udp.AnnounceResponseHeader{})
+       if err != nil {
+               return err
+       }
+       b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
+       if err != nil {
+               err = fmt.Errorf("marshalling compact node addrs: %w", err)
+               return err
+       }
+       log.Print(nodeAddrs)
+       buf.Write(b)
+       n, err := me.SendResponse(buf.Bytes(), source)
+       if err != nil {
+               return err
+       }
+       if n < buf.Len() {
+               err = io.ErrShortWrite
+       }
+       return err
+}
+
+func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
+       connId := randomConnectionId()
+       err := me.ConnTracker.Add(ctx, source.String(), connId)
+       if err != nil {
+               err = fmt.Errorf("recording conn id: %w", err)
+               return err
+       }
+       var buf bytes.Buffer
+       udp.Write(&buf, udp.ResponseHeader{
+               Action:        udp.ActionConnect,
+               TransactionId: tid,
+       })
+       udp.Write(&buf, udp.ConnectionResponse{connId})
+       n, err := me.SendResponse(buf.Bytes(), source)
+       if err != nil {
+               return err
+       }
+       if n < buf.Len() {
+               err = io.ErrShortWrite
+       }
+       return err
+}
+
+func randomConnectionId() udp.ConnectionId {
+       var b [8]byte
+       _, err := rand.Read(b[:])
+       if err != nil {
+               panic(err)
+       }
+       return int64(binary.BigEndian.Uint64(b[:]))
+}
+
+func RunServer(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
+       ctx, cancel := context.WithCancel(ctx)
+       defer cancel()
+       for {
+               var b [1500]byte
+               n, addr, err := pc.ReadFrom(b[:])
+               if err != nil {
+                       return err
+               }
+               go func() {
+                       err := s.HandleRequest(ctx, family, addr, b[:n])
+                       if err != nil {
+                               log.Printf("error handling %v byte request from %v: %v", n, addr, err)
+                       }
+               }()
+       }
+}