From eb9c032f2b5a1071052d3c0ccf6dbc97d7a907e4 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Mon, 5 Dec 2022 12:52:19 +1100 Subject: [PATCH] Start a UDP server implementation --- tracker/udp/addr-family.go | 25 +++++ tracker/udp/server/server.go | 199 +++++++++++++++++++++++++++++++++++ tracker/udp_test.go | 1 + 3 files changed, 225 insertions(+) diff --git a/tracker/udp/addr-family.go b/tracker/udp/addr-family.go index 0213f41f..ddecb4c9 100644 --- a/tracker/udp/addr-family.go +++ b/tracker/udp/addr-family.go @@ -1 +1,26 @@ package udp + +import ( + "encoding" + + "github.com/anacrolix/dht/v2/krpc" +) + +// Discriminates behaviours based on address family in use. +type AddrFamily int + +const ( + AddrFamilyIpv4 = iota + 1 + AddrFamilyIpv6 +) + +// Returns a marshaler for the given node addrs for the specified family. +func GetNodeAddrsCompactMarshaler(nas []krpc.NodeAddr, family AddrFamily) encoding.BinaryMarshaler { + switch family { + case AddrFamilyIpv4: + return krpc.CompactIPv4NodeAddrs(nas) + case AddrFamilyIpv6: + return krpc.CompactIPv6NodeAddrs(nas) + } + return nil +} diff --git a/tracker/udp/server/server.go b/tracker/udp/server/server.go index abb4e431..81500214 100644 --- a/tracker/udp/server/server.go +++ b/tracker/udp/server/server.go @@ -1 +1,200 @@ package server + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/binary" + "fmt" + "io" + "net" + "net/netip" + + "github.com/anacrolix/dht/v2/krpc" + "github.com/anacrolix/log" + "github.com/anacrolix/torrent/tracker/udp" +) + +type ConnectionTrackerAddr = string + +type ConnectionTracker interface { + Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error + Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error) +} + +type InfoHash = [20]byte + +// This is reserved for stuff like filtering by IP version, avoiding an announcer's IP or key, +// limiting return count, etc. +type GetPeersOpts struct{} + +type PeerInfo struct { + netip.AddrPort +} + +type AnnounceTracker interface { + TrackAnnounce(ctx context.Context, req udp.AnnounceRequest, addr RequestSourceAddr) error + Scrape(ctx context.Context, infoHashes []InfoHash) ([]udp.ScrapeInfohashResult, error) + GetPeers(ctx context.Context, infoHash InfoHash, opts GetPeersOpts) ([]PeerInfo, error) +} + +type Server struct { + ConnTracker ConnectionTracker + SendResponse func(data []byte, addr net.Addr) (int, error) + AnnounceTracker AnnounceTracker +} + +type RequestSourceAddr = net.Addr + +func (me *Server) HandleRequest(ctx context.Context, family udp.AddrFamily, source RequestSourceAddr, body []byte) error { + var h udp.RequestHeader + var r bytes.Reader + r.Reset(body) + err := udp.Read(&r, &h) + if err != nil { + err = fmt.Errorf("reading request header: %w", err) + return err + } + switch h.Action { + case udp.ActionConnect: + err = me.handleConnect(ctx, source, h.TransactionId) + case udp.ActionAnnounce: + err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r) + default: + err = fmt.Errorf("unimplemented") + } + if err != nil { + err = fmt.Errorf("handling action %v: %w", h.Action, err) + } + return err +} + +func (me *Server) handleAnnounce( + ctx context.Context, + addrFamily udp.AddrFamily, + source RequestSourceAddr, + connId udp.ConnectionId, + tid udp.TransactionId, + r *bytes.Reader, +) error { + ok, err := me.ConnTracker.Check(ctx, source.String(), connId) + if err != nil { + err = fmt.Errorf("checking conn id: %w", err) + return err + } + if !ok { + return fmt.Errorf("invalid connection id: %v", connId) + } + var req udp.AnnounceRequest + err = udp.Read(r, &req) + if err != nil { + return err + } + // TODO: This should be done asynchronously to responding to the announce. + err = me.AnnounceTracker.TrackAnnounce(ctx, req, source) + if err != nil { + return err + } + peers, err := me.AnnounceTracker.GetPeers(ctx, req.InfoHash, GetPeersOpts{}) + if err != nil { + return err + } + nodeAddrs := make([]krpc.NodeAddr, 0, len(peers)) + for _, p := range peers { + var ip net.IP + switch addrFamily { + default: + continue + case udp.AddrFamilyIpv4: + if !p.Addr().Unmap().Is4() { + continue + } + ipBuf := p.Addr().As4() + ip = ipBuf[:] + case udp.AddrFamilyIpv6: + ipBuf := p.Addr().As16() + ip = ipBuf[:] + } + nodeAddrs = append(nodeAddrs, krpc.NodeAddr{ + IP: ip[:], + Port: int(p.Port()), + }) + } + var buf bytes.Buffer + err = udp.Write(&buf, udp.ResponseHeader{ + Action: udp.ActionAnnounce, + TransactionId: tid, + }) + if err != nil { + return err + } + err = udp.Write(&buf, udp.AnnounceResponseHeader{}) + if err != nil { + return err + } + b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary() + if err != nil { + err = fmt.Errorf("marshalling compact node addrs: %w", err) + return err + } + log.Print(nodeAddrs) + buf.Write(b) + n, err := me.SendResponse(buf.Bytes(), source) + if err != nil { + return err + } + if n < buf.Len() { + err = io.ErrShortWrite + } + return err +} + +func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error { + connId := randomConnectionId() + err := me.ConnTracker.Add(ctx, source.String(), connId) + if err != nil { + err = fmt.Errorf("recording conn id: %w", err) + return err + } + var buf bytes.Buffer + udp.Write(&buf, udp.ResponseHeader{ + Action: udp.ActionConnect, + TransactionId: tid, + }) + udp.Write(&buf, udp.ConnectionResponse{connId}) + n, err := me.SendResponse(buf.Bytes(), source) + if err != nil { + return err + } + if n < buf.Len() { + err = io.ErrShortWrite + } + return err +} + +func randomConnectionId() udp.ConnectionId { + var b [8]byte + _, err := rand.Read(b[:]) + if err != nil { + panic(err) + } + return int64(binary.BigEndian.Uint64(b[:])) +} + +func RunServer(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + for { + var b [1500]byte + n, addr, err := pc.ReadFrom(b[:]) + if err != nil { + return err + } + go func() { + err := s.HandleRequest(ctx, family, addr, b[:n]) + if err != nil { + log.Printf("error handling %v byte request from %v: %v", n, addr, err) + } + }() + } +} diff --git a/tracker/udp_test.go b/tracker/udp_test.go index 7354063b..751e41b9 100644 --- a/tracker/udp_test.go +++ b/tracker/udp_test.go @@ -23,6 +23,7 @@ import ( var trackers = []string{ "udp://tracker.opentrackr.org:1337/announce", "udp://tracker.openbittorrent.com:6969/announce", + "udp://localhost:42069", } func TestAnnounceLocalhost(t *testing.T) { -- 2.44.0