tracker/http/peer.go | 8 ++++++++
tracker/http/server/server.go | 125 +++++++++++++++++++++++++++++++++++++++++++++++++++++
tracker/server/server.go | 324 +++++++++++++++++++++++++++++++++++++++++++++++++++++
tracker/server/upstream-announcing.go | 18 ++++++++++++++++++
tracker/server/use.go | 9 +++++++++
tracker/udp-server_test.go | 8 ++++----
tracker/udp/addr-family.go | 25 +++++++++++++++++++++++++
tracker/udp/announce.go | 7 ++++++-
tracker/udp/protocol.go | 2 +-
tracker/udp/server/server.go | 242 ++++++++++++++++++++++++++++++++++++++++++++++++++++-
tracker/udp_test.go | 1 +
diff --git a/tracker/http/peer.go b/tracker/http/peer.go
index 363ba6d3c7b9f7d317b0f2785a3e5338c16df49e..b0deee0b3e540772669d74ceed32e43c497ed601 100644
--- a/tracker/http/peer.go
+++ b/tracker/http/peer.go
@@ -3,14 +3,22 @@
 import (
 	"fmt"
 	"net"
+	"net/netip"
 
 	"github.com/anacrolix/dht/v2/krpc"
 )
 
+// TODO: Use netip.Addr and Option[[20]byte].
 type Peer struct {
 	IP   net.IP `bencode:"ip"`
 	Port int    `bencode:"port"`
 	ID   []byte `bencode:"peer id"`
+}
+
+// ToNetipAddrPort converts IP and Port to a netip.AddrPort. ok is false when IP is not a valid
+// 4- or 16-byte address.
+func (p Peer) ToNetipAddrPort() (addrPort netip.AddrPort, ok bool) {
+	addr, ok := netip.AddrFromSlice(p.IP)
+	addrPort = netip.AddrPortFrom(addr, uint16(p.Port))
+	return
 }
 
 func (p Peer) String() string {
diff --git a/tracker/http/server/server.go b/tracker/http/server/server.go
new file mode 100644
index 0000000000000000000000000000000000000000..30be15c65c8c4719bbdcf1b173919294b09d286b
--- /dev/null
+++ b/tracker/http/server/server.go
@@ -0,0 +1,125 @@
+package httpTrackerServer
+
+import (
+	"fmt"
+	"net"
+	"net/http"
+	"net/netip"
+	"net/url"
+	"strconv"
+
+	"github.com/anacrolix/dht/v2/krpc"
+	"github.com/anacrolix/generics"
+	"github.com/anacrolix/log"
+	trackerServer "github.com/anacrolix/torrent/tracker/server"
+
+	"github.com/anacrolix/torrent/bencode"
+	"github.com/anacrolix/torrent/tracker"
+	httpTracker "github.com/anacrolix/torrent/tracker/http"
+)
+
+// Handler serves HTTP tracker announces by delegating to an AnnounceHandler.
+type Handler struct {
+	Announce *trackerServer.AnnounceHandler
+	// Called to derive an announcer's IP if non-nil. If not specified, the Request.RemoteAddr is
+	// used. Necessary for instances running behind reverse proxies for example.
+	RequestHost func(r *http.Request) (netip.Addr, error)
+}
+
+// unmarshalQueryKeyToArray reads query[key] as exactly 20 raw bytes. On a length mismatch it
+// writes a 400 response to w and returns ok == false.
+func unmarshalQueryKeyToArray(w http.ResponseWriter, key string, query url.Values) (ret [20]byte, ok bool) {
+	str := query.Get(key)
+	if len(str) != len(ret) {
+		http.Error(w, fmt.Sprintf("%v has wrong length", key), http.StatusBadRequest)
+		return
+	}
+	copy(ret[:], str)
+	ok = true
+	return
+}
+
+// requestHostAddr determines the announcer's IP: via RequestHost when it is set, otherwise from
+// the host part of r.RemoteAddr. It does not write to the response; callers handle the error.
+func (me Handler) requestHostAddr(r *http.Request) (_ netip.Addr, err error) {
+	if me.RequestHost != nil {
+		return me.RequestHost(r)
+	}
+	host, _, err := net.SplitHostPort(r.RemoteAddr)
+	if err != nil {
+		return
+	}
+	return netip.ParseAddr(host)
+}
+
+var requestHeadersLogger = log.Default.WithNames("request", "headers")
+
+// ServeHTTP handles one announce request. It validates event, info_hash, peer_id (each 20
+// bytes) and port, passes the announce to Announce.Serve (asking for at most 200 peers), and
+// writes a bencoded response: IPv4 peers in the compact peers list, IPv6 peers in peers6.
+func (me Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	vs := r.URL.Query()
+	var event tracker.AnnounceEvent
+	err := event.UnmarshalText([]byte(vs.Get("event")))
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusBadRequest)
+		return
+	}
+	infoHash, ok := unmarshalQueryKeyToArray(w, "info_hash", vs)
+	if !ok {
+		return
+	}
+	peerId, ok := unmarshalQueryKeyToArray(w, "peer_id", vs)
+	if !ok {
+		return
+	}
+	requestHeadersLogger.Levelf(log.Debug, "request RemoteAddr=%q, header=%q", r.RemoteAddr, r.Header)
+	addr, err := me.requestHostAddr(r)
+	if err != nil {
+		log.Printf("error determining requester IP: %v", err)
+		http.Error(w, "error determining your IP", http.StatusBadGateway)
+		return
+	}
+	// Query parameters are decimal. Base 10 avoids reading a leading zero as octal. A missing or
+	// malformed port is rejected rather than tracking the announcer on port 0.
+	portU64, err := strconv.ParseUint(vs.Get("port"), 10, 16)
+	if err != nil {
+		http.Error(w, "invalid port", http.StatusBadRequest)
+		return
+	}
+	addrPort := netip.AddrPortFrom(addr, uint16(portU64))
+	// A missing or malformed left is passed on as -1.
+	left, err := strconv.ParseInt(vs.Get("left"), 10, 64)
+	if err != nil {
+		left = -1
+	}
+	res := me.Announce.Serve(
+		r.Context(),
+		tracker.AnnounceRequest{
+			InfoHash: infoHash,
+			PeerId:   peerId,
+			Event:    event,
+			Port:     addrPort.Port(),
+			NumWant:  -1,
+			Left:     left,
+		},
+		addrPort,
+		trackerServer.GetPeersOpts{
+			MaxCount: generics.Some[uint](200),
+		},
+	)
+	err = res.Err
+	if err != nil {
+		log.Printf("error serving announce: %v", err)
+		http.Error(w, "error handling announce", http.StatusInternalServerError)
+		return
+	}
+	var resp httpTracker.HttpResponse
+	resp.Incomplete = res.Leechers.Value
+	resp.Complete = res.Seeders.Value
+	resp.Interval = res.Interval.UnwrapOr(5 * 60)
+	resp.Peers.Compact = true
+	for _, peer := range res.Peers {
+		// Unmap so IPv4-mapped IPv6 addresses land in the compact IPv4 list rather than being
+		// sent as 16-byte entries in peers6.
+		peerAddr := peer.Addr().Unmap()
+		switch {
+		case peerAddr.Is4():
+			resp.Peers.List = append(resp.Peers.List, tracker.Peer{
+				IP:   peerAddr.AsSlice(),
+				Port: int(peer.Port()),
+			})
+		case peerAddr.Is6():
+			resp.Peers6 = append(resp.Peers6, krpc.NodeAddr{
+				IP:   peerAddr.AsSlice(),
+				Port: int(peer.Port()),
+			})
+		}
+	}
+	err = bencode.NewEncoder(w).Encode(resp)
+	if err != nil {
+		log.Printf("error encoding and writing response body: %v", err)
+	}
+}
diff --git a/tracker/server/server.go b/tracker/server/server.go
new file mode 100644
index 0000000000000000000000000000000000000000..823816d58f6db300f0971113605ccd46f5d0f3e5
--- /dev/null
+++ b/tracker/server/server.go
@@ -0,0 +1,324 @@
+package trackerServer
+
+import (
+	"context"
+	"encoding/hex"
+	"fmt"
+	"net/netip"
+	"sync"
+	"time"
+
+	"github.com/anacrolix/generics"
+	"github.com/anacrolix/log"
+	"github.com/anacrolix/torrent/tracker"
+	"go.opentelemetry.io/otel"
+	"go.opentelemetry.io/otel/attribute"
+	"go.opentelemetry.io/otel/codes"
+	"go.opentelemetry.io/otel/trace"
+
+	"github.com/anacrolix/torrent/tracker/udp"
+)
+
+// This is reserved for stuff like filtering by IP version, avoiding an announcer's IP or key,
+// limiting return count, etc.
+type GetPeersOpts struct {
+	// Upper bound on the number of peers wanted, if set. Unset means no explicit limit.
+	MaxCount generics.Option[uint]
+}
+
+type InfoHash = [20]byte
+
+// PeerInfo is a peer as returned to announcers.
+type PeerInfo struct {
+	AnnounceAddr
+}
+
+type AnnounceAddr = netip.AddrPort
+
+// AnnounceTracker is the peer store used by AnnounceHandler: it records announces and answers
+// peer and scrape queries.
+type AnnounceTracker interface {
+	TrackAnnounce(ctx context.Context, req udp.AnnounceRequest, addr AnnounceAddr) error
+	Scrape(ctx context.Context, infoHashes []InfoHash) ([]udp.ScrapeInfohashResult, error)
+	GetPeers(
+		ctx context.Context,
+		infoHash InfoHash,
+		opts GetPeersOpts,
+		remote AnnounceAddr,
+	) ServerAnnounceResult
+}
+
+// ServerAnnounceResult is the outcome of an announce. Optional fields left unset are filled with
+// defaults (or zero) by the protocol front ends.
+type ServerAnnounceResult struct {
+	Err      error
+	Peers    []PeerInfo
+	Interval generics.Option[int32]
+	Leechers generics.Option[int32]
+	Seeders  generics.Option[int32]
+}
+
+// AnnounceHandler serves announces from an AnnounceTracker, optionally augmenting sparse results
+// with peers fetched from upstream trackers.
+type AnnounceHandler struct {
+	AnnounceTracker AnnounceTracker
+
+	// UpstreamTrackers and UpstreamTrackerUrls are index-aligned: entry i of each describes the
+	// same upstream tracker.
+	UpstreamTrackers       []Client
+	UpstreamTrackerUrls    []string
+	UpstreamAnnouncePeerId [20]byte
+	UpstreamAnnounceGate   UpstreamAnnounceGater
+
+	mu sync.Mutex
+	// Operations are only removed when all the upstream peers have been tracked.
+	ongoingUpstreamAugmentations map[InfoHash]augmentationOperation
+}
+
+type peerSet = map[PeerInfo]struct{}
+
+type augmentationOperation struct {
+	// Closed when no more announce responses are pending. finalPeers will contain all the peers
+	// seen.
+	doneAnnouncing chan struct{}
+	// This receives the latest peerSet until doneAnnouncing is closed.
+	curPeers chan peerSet
+	// This contains the final peerSet after doneAnnouncing is closed.
+	finalPeers peerSet
+}
+
+// getCurPeers returns a snapshot of the peers gathered so far.
+func (me augmentationOperation) getCurPeers() (ret peerSet) {
+	ret, _ = me.getCurPeersAndDone()
+	return
+}
+
+// getCurPeersAndDone returns a snapshot of the peers gathered so far. done reports whether
+// announcing has finished, in which case the snapshot is the final set.
+func (me augmentationOperation) getCurPeersAndDone() (ret peerSet, done bool) {
+	select {
+	case ret = <-me.curPeers:
+	case <-me.doneAnnouncing:
+		ret = copyPeerSet(me.finalPeers)
+		done = true
+	}
+	return
+}
+
+// Adds peers from new that aren't in orig. Modifies both arguments.
+// NOTE: in fact only new is mutated. orig's capacity is clipped before appending, so storage
+// that may belong to the AnnounceTracker is never written. The extended slice is returned and
+// callers must use it, because appending to the parameter is invisible outside this function.
+func addMissing(orig []PeerInfo, new peerSet) []PeerInfo {
+	for _, peer := range orig {
+		delete(new, peer)
+	}
+	orig = orig[:len(orig):len(orig)]
+	for peer := range new {
+		orig = append(orig, peer)
+	}
+	return orig
+}
+
+// limitPeers truncates peers to limit.Value entries when limit is set. This keeps peers merged
+// in from upstream trackers from pushing a response past the caller's MaxCount.
+func limitPeers(peers []PeerInfo, limit generics.Option[uint]) []PeerInfo {
+	if limit.Ok && uint(len(peers)) > limit.Value {
+		return peers[:limit.Value]
+	}
+	return peers
+}
+
+var tracer = otel.Tracer("torrent.tracker.udp")
+
+// Serve records an announce from addr (using req.Port instead of addr's port when it is non-zero),
+// then returns peers for req.InfoHash from AnnounceTracker.
+//
+// If an upstream augmentation is already running for the info hash, its current peers are
+// merged in. Otherwise, when at most one peer is known and peers were requested, a new
+// augmentation is started. Up to 10 seconds are then spent waiting for upstream peers. The
+// result never exceeds opts.MaxCount, further limited by a non-negative req.NumWant.
+func (me *AnnounceHandler) Serve(
+	ctx context.Context, req AnnounceRequest, addr AnnounceAddr, opts GetPeersOpts,
+) (ret ServerAnnounceResult) {
+	ctx, span := tracer.Start(
+		ctx,
+		"AnnounceHandler.Serve",
+		trace.WithAttributes(
+			attribute.Int64("announce.request.num_want", int64(req.NumWant)),
+			attribute.Int("announce.request.port", int(req.Port)),
+			attribute.String("announce.request.info_hash", hex.EncodeToString(req.InfoHash[:])),
+			attribute.String("announce.request.event", req.Event.String()),
+			attribute.Int64("announce.get_peers.opts.max_count_value", int64(opts.MaxCount.Value)),
+			attribute.Bool("announce.get_peers.opts.max_count_ok", opts.MaxCount.Ok),
+			attribute.String("announce.source.addr.ip", addr.Addr().String()),
+			attribute.Int("announce.source.addr.port", int(addr.Port())),
+		),
+	)
+	defer span.End()
+	defer func() {
+		span.SetAttributes(attribute.Int("announce.get_peers.len", len(ret.Peers)))
+		if ret.Err != nil {
+			span.SetStatus(codes.Error, ret.Err.Error())
+		}
+	}()
+
+	if req.Port != 0 {
+		addr = netip.AddrPortFrom(addr.Addr(), req.Port)
+	}
+	ret.Err = me.AnnounceTracker.TrackAnnounce(ctx, req, addr)
+	if ret.Err != nil {
+		ret.Err = fmt.Errorf("tracking announce: %w", ret.Err)
+		return
+	}
+	infoHash := req.InfoHash
+	var op generics.Option[augmentationOperation]
+	// Grab a handle to any augmentations that are already running.
+	me.mu.Lock()
+	op.Value, op.Ok = me.ongoingUpstreamAugmentations[infoHash]
+	me.mu.Unlock()
+	// Apply num_want limit to max count. I really can't tell if this is the right place to do it,
+	// but it seems the most flexible. Only non-negative values act as a limit. Negative values
+	// (the callers here use -1) request no particular count, and converting them to uint would
+	// wrap around.
+	if req.NumWant >= 0 {
+		newCount := uint(req.NumWant)
+		if opts.MaxCount.Ok {
+			if newCount < opts.MaxCount.Value {
+				opts.MaxCount.Value = newCount
+			}
+		} else {
+			opts.MaxCount = generics.Some(newCount)
+		}
+	}
+	ret = me.AnnounceTracker.GetPeers(ctx, infoHash, opts, addr)
+	if ret.Err != nil {
+		return
+	}
+	// Take whatever peers it has ready. If it's finished, it doesn't matter if we do this inside
+	// the mutex or not.
+	if op.Ok {
+		curPeers, done := op.Value.getCurPeersAndDone()
+		ret.Peers = addMissing(ret.Peers, curPeers)
+		if done {
+			// It doesn't get any better with this operation. Forget it.
+			op.Ok = false
+		}
+	}
+	me.mu.Lock()
+	// If we didn't have an operation, and don't have enough peers, start one. Allowing 1 is
+	// assuming the announcing peer might be that one. Really we should record a value to prevent
+	// duplicate announces. Also don't announce upstream if we got no peers because the caller asked
+	// for none.
+	if !op.Ok && len(ret.Peers) <= 1 && opts.MaxCount.UnwrapOr(1) > 0 {
+		op.Value, op.Ok = me.ongoingUpstreamAugmentations[infoHash]
+		if !op.Ok {
+			op.Set(me.augmentPeersFromUpstream(req.InfoHash))
+			generics.MakeMapIfNilAndSet(&me.ongoingUpstreamAugmentations, infoHash, op.Value)
+		}
+	}
+	me.mu.Unlock()
+	// Wait a while for the current operation.
+	if op.Ok {
+		// Force the augmentation to return with whatever it has if it hasn't completed in a
+		// reasonable time.
+		ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
+		select {
+		case <-ctx.Done():
+		case <-op.Value.doneAnnouncing:
+		}
+		cancel()
+		ret.Peers = addMissing(ret.Peers, op.Value.getCurPeers())
+	}
+	// Local peers come first, so upstream peers are the ones dropped when over the limit.
+	ret.Peers = limitPeers(ret.Peers, opts.MaxCount)
+	return
+}
+
+// augmentPeersFromUpstream announces infoHash to each upstream tracker that UpstreamAnnounceGate
+// admits. All announces share a one-minute deadline. Peers are exposed through the returned
+// operation as responses arrive. After every upstream has responded or been skipped, the
+// distinct peers are recorded via AnnounceTracker.TrackAnnounce, and the operation is removed
+// from ongoingUpstreamAugmentations.
+//
+// Serve calls this while holding me.mu. This function only takes me.mu from a background
+// goroutine.
+func (me *AnnounceHandler) augmentPeersFromUpstream(infoHash [20]byte) augmentationOperation {
+	const announceTimeout = time.Minute
+	announceCtx, cancel := context.WithTimeout(context.Background(), announceTimeout)
+	subReq := AnnounceRequest{
+		InfoHash: infoHash,
+		PeerId:   me.UpstreamAnnouncePeerId,
+		Event:    tracker.None,
+		Key:      0,
+		NumWant:  -1,
+		Port:     0,
+	}
+	// Each upstream goroutine sends exactly once (nil if skipped), which the collector below
+	// matches with one pendingUpstreams.Done.
+	peersChan := make(chan []Peer)
+	var pendingUpstreams sync.WaitGroup
+	for i := range me.UpstreamTrackers {
+		client := me.UpstreamTrackers[i]
+		url := me.UpstreamTrackerUrls[i]
+		pendingUpstreams.Add(1)
+		go func() {
+			started, err := me.UpstreamAnnounceGate.Start(announceCtx, url, infoHash, announceTimeout)
+			if err != nil {
+				log.Printf("error reserving announce for %x to %v: %v", infoHash, url, err)
+			}
+			if err != nil || !started {
+				peersChan <- nil
+				return
+			}
+			log.Printf("announcing %x upstream to %v", infoHash, url)
+			resp, err := client.Announce(announceCtx, subReq, tracker.AnnounceOpt{
+				UserAgent: "aragorn",
+			})
+			interval := resp.Interval
+			go func() {
+				if interval < 5*60 {
+					// This is as much to reduce load on upstream trackers in the event of errors,
+					// as it is to reduce load on our peer store.
+					interval = 5 * 60
+				}
+				err := me.UpstreamAnnounceGate.Completed(context.Background(), url, infoHash, interval)
+				if err != nil {
+					log.Printf("error recording completed announce for %x to %v: %v", infoHash, url, err)
+				}
+			}()
+			peersChan <- resp.Peers
+			if err != nil {
+				log.Levelf(log.Warning, "error announcing to upstream %q: %v", url, err)
+			}
+		}()
+	}
+	// Written only by the collector goroutine before each pendingUpstreams.Done. It is read only
+	// after pendingUpstreams.Wait returns.
+	peersToTrack := make(map[string]Peer)
+	go func() {
+		pendingUpstreams.Wait()
+		cancel()
+		close(peersChan)
+		log.Levelf(log.Debug, "adding %v distinct peers from upstream trackers", len(peersToTrack))
+		for _, peer := range peersToTrack {
+			addrPort, ok := peer.ToNetipAddrPort()
+			if !ok {
+				continue
+			}
+			trackReq := AnnounceRequest{
+				InfoHash: infoHash,
+				Event:    tracker.Started,
+				Port:     uint16(peer.Port),
+				// Let's assume upstream peers are leechers without knowing better.
+				Left: -1,
+			}
+			copy(trackReq.PeerId[:], peer.ID)
+			// TODO: How do we know if these peers are leechers or seeders?
+			err := me.AnnounceTracker.TrackAnnounce(context.TODO(), trackReq, addrPort)
+			if err != nil {
+				log.Levelf(log.Error, "error tracking upstream peer: %v", err)
+			}
+		}
+		me.mu.Lock()
+		delete(me.ongoingUpstreamAugmentations, infoHash)
+		me.mu.Unlock()
+	}()
+	curPeersChan := make(chan map[PeerInfo]struct{})
+	doneChan := make(chan struct{})
+	retPeers := make(map[PeerInfo]struct{})
+	// Collector: sole owner of retPeers and peersToTrack while announcing. It hands out
+	// snapshots on curPeersChan until peersChan is closed.
+	go func() {
+		defer close(doneChan)
+		for {
+			select {
+			case peers, ok := <-peersChan:
+				if !ok {
+					return
+				}
+				voldemort(peers, peersToTrack, retPeers)
+				pendingUpstreams.Done()
+			case curPeersChan <- copyPeerSet(retPeers):
+			}
+		}
+	}()
+	// Take return references.
+	return augmentationOperation{
+		curPeers:       curPeersChan,
+		finalPeers:     retPeers,
+		doneAnnouncing: doneChan,
+	}
+}
+
+// copyPeerSet returns a shallow copy of orig.
+func copyPeerSet(orig peerSet) (ret peerSet) {
+	ret = make(peerSet, len(orig))
+	for k, v := range orig {
+		ret[k] = v
+	}
+	return
+}
+
+// Adds peers to trailing containers.
+// For each peer, voldemort records it in toTrack keyed by its String() form, deduplicating
+// across upstreams. If the peer's IP is a valid 4- or 16-byte address, it also adds the peer as
+// a PeerInfo to every set in sets. Peers with invalid IPs are still recorded in toTrack.
+func voldemort(peers []Peer, toTrack map[string]Peer, sets ...map[PeerInfo]struct{}) {
+	for _, protoPeer := range peers {
+		toTrack[protoPeer.String()] = protoPeer
+		addr, ok := netip.AddrFromSlice(protoPeer.IP)
+		if !ok {
+			continue
+		}
+		handlerPeer := PeerInfo{netip.AddrPortFrom(addr, uint16(protoPeer.Port))}
+		for _, set := range sets {
+			set[handlerPeer] = struct{}{}
+		}
+	}
+}
diff --git a/tracker/server/upstream-announcing.go b/tracker/server/upstream-announcing.go
new file mode 100644
index 0000000000000000000000000000000000000000..cfbf61c85a502219264f5e58ac30ad1c3568575a
--- /dev/null
+++ b/tracker/server/upstream-announcing.go
@@ -0,0 +1,18 @@
+package trackerServer
+
+import (
+	"context"
+	"time"
+)
+
+// UpstreamAnnounceGater decides whether an announce of an info hash to an upstream tracker may
+// go ahead. AnnounceHandler calls Start before announcing and only announces when it returns
+// true. It calls Completed afterwards with the interval to apply.
+type UpstreamAnnounceGater interface {
+	Start(ctx context.Context, tracker string, infoHash InfoHash,
+		// How long the announce block remains before discarding it.
+		timeout time.Duration,
+	) (bool, error)
+	Completed(
+		ctx context.Context, tracker string, infoHash InfoHash,
+		// Num of seconds reported by tracker, or some suitable value the caller has chosen.
+		interval int32,
+	) error
+}
diff --git a/tracker/server/use.go b/tracker/server/use.go
new file mode 100644
index 0000000000000000000000000000000000000000..942321c554b9c737812c3fb33a8347a3fe9ba83f
--- /dev/null
+++ b/tracker/server/use.go
@@ -0,0 +1,9 @@
+package trackerServer
+
+import "github.com/anacrolix/torrent/tracker"
+
+type (
+	AnnounceRequest = tracker.AnnounceRequest
+	Client          = tracker.Client
+	Peer            = tracker.Peer
+)
diff --git a/tracker/udp-server_test.go b/tracker/udp-server_test.go
index 824038ea276870c166e79fc052c909258890a6fd..7308ed0d8ef322514cfff25b1331976ffc564fa1 100644
--- a/tracker/udp-server_test.go
+++ b/tracker/udp-server_test.go
@@ -21,7 +21,7 @@
 }
 type server struct {
 	pc    net.PacketConn
-	conns map[int64]struct{}
+	conns map[udp.ConnectionId]struct{}
 	t     map[[20]byte]torrent
 }
 
@@ -46,10 +46,10 @@
 	_, err = s.pc.WriteTo(b, addr)
 	return
 }
-func (s *server) newConn() (ret int64) {
-	ret = rand.Int63()
+func (s *server) newConn() (ret udp.ConnectionId) {
+	ret = rand.Uint64()
 	if s.conns == nil {
-		s.conns = make(map[int64]struct{})
+		s.conns = make(map[udp.ConnectionId]struct{})
 	}
 	s.conns[ret] = struct{}{}
 	return
diff --git a/tracker/udp/addr-family.go b/tracker/udp/addr-family.go
index 0213f41f0215799c942685d46b38544290b3680b..ddecb4c9fafb3e4ed0225a081fa36055cf1832b0 100644
--- a/tracker/udp/addr-family.go
+++ b/tracker/udp/addr-family.go
@@ -1 +1,26 @@
 package udp
+
+import (
+	"encoding"
+
+	"github.com/anacrolix/dht/v2/krpc"
+)
+
+// Discriminates behaviours based on address family in use.
+type AddrFamily int
+
+// Typed so they cannot be confused with plain ints. The zero value is deliberately not a valid
+// family.
+const (
+	AddrFamilyIpv4 AddrFamily = iota + 1
+	AddrFamilyIpv6
+)
+
+// Returns a marshaler for the given node addrs for the specified family.
+// It returns nil for any family other than AddrFamilyIpv4 or AddrFamilyIpv6, so callers must
+// validate the family first.
+func GetNodeAddrsCompactMarshaler(nas []krpc.NodeAddr, family AddrFamily) encoding.BinaryMarshaler {
+	switch family {
+	case AddrFamilyIpv4:
+		return krpc.CompactIPv4NodeAddrs(nas)
+	case AddrFamilyIpv6:
+		return krpc.CompactIPv6NodeAddrs(nas)
+	}
+	return nil
+}
diff --git a/tracker/udp/announce.go b/tracker/udp/announce.go
index 59b6c6cfa728e8b7c694de89074d24b200da0341..b5c9f8ffbaa317651c0d8a73b4272d96c84e38e6 100644
--- a/tracker/udp/announce.go
+++ b/tracker/udp/announce.go
@@ -38,7 +38,12 @@
 var announceEventStrings = []string{"", "completed", "started", "stopped"}
 
 func (e AnnounceEvent) String() string {
-	// See BEP 3, "event", and https://github.com/anacrolix/torrent/issues/416#issuecomment-751427001.
+	// See BEP 3, "event", and
+	// https://github.com/anacrolix/torrent/issues/416#issuecomment-751427001. Return a safe default
+	// in case event values are not sanitized.
+	if e < 0 || int(e) >= len(announceEventStrings) {
+		return ""
+	}
 	return announceEventStrings[e]
 }
 
diff --git a/tracker/udp/protocol.go b/tracker/udp/protocol.go
index f6beb4c6d55c3e495730fcac1caebfc23c717d4e..653d013efcd5367888ea6bde3af13d7d6ae7d749 100644
--- a/tracker/udp/protocol.go
+++ b/tracker/udp/protocol.go
@@ -26,7 +26,7 @@
 )
 
 type TransactionId = int32
-type ConnectionId = int64
+type ConnectionId = uint64
 
 type ConnectionRequest struct {
 	ConnectionId ConnectionId
diff --git a/tracker/udp/server/server.go b/tracker/udp/server/server.go
index abb4e431abd516750a5a1e5e2b77073c236b8f9e..5666e8052be0e0d8b37962dafd094205211fdd20 100644
--- a/tracker/udp/server/server.go
+++ b/tracker/udp/server/server.go
@@ -1 +1,241 @@
-package server
+package udpTrackerServer
+
+import (
+	"bytes"
+	"context"
+	"crypto/rand"
+	"encoding/binary"
+	"fmt"
+	"io"
+	"net"
+	"net/netip"
+
+	"github.com/anacrolix/dht/v2/krpc"
+	"github.com/anacrolix/generics"
+	"github.com/anacrolix/log"
+	trackerServer "github.com/anacrolix/torrent/tracker/server"
+	"go.opentelemetry.io/otel"
+	"go.opentelemetry.io/otel/attribute"
+	"go.opentelemetry.io/otel/codes"
+	"go.opentelemetry.io/otel/trace"
+
+	"github.com/anacrolix/torrent/tracker/udp"
+)
+
+type ConnectionTrackerAddr = string
+
+// ConnectionTracker remembers connection IDs issued to source addresses. Add is called when a
+// connect request is answered. Check must report whether id was issued to addr.
+type ConnectionTracker interface {
+	Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
+	Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
+}
+
+type InfoHash = [20]byte
+
+type AnnounceTracker = trackerServer.AnnounceTracker
+
+// Server handles UDP tracker requests. Responses are written through SendResponse; a short
+// write is reported as io.ErrShortWrite.
+type Server struct {
+	ConnTracker  ConnectionTracker
+	SendResponse func(ctx context.Context, data []byte, addr net.Addr) (int, error)
+	Announce     *trackerServer.AnnounceHandler
+}
+
+type RequestSourceAddr = net.Addr
+
+var tracer = otel.Tracer("torrent.tracker.udp")
+
+// HandleRequest parses the request header in body and dispatches connect and announce actions.
+// Other actions return an error.
+func (me *Server) HandleRequest(
+	ctx context.Context,
+	family udp.AddrFamily,
+	source RequestSourceAddr,
+	body []byte,
+) (err error) {
+	ctx, span := tracer.Start(ctx, "Server.HandleRequest",
+		trace.WithAttributes(attribute.Int("payload.len", len(body))))
+	defer span.End()
+	defer func() {
+		if err != nil {
+			span.SetStatus(codes.Error, err.Error())
+		}
+	}()
+	var h udp.RequestHeader
+	var r bytes.Reader
+	r.Reset(body)
+	err = udp.Read(&r, &h)
+	if err != nil {
+		err = fmt.Errorf("reading request header: %w", err)
+		return err
+	}
+	switch h.Action {
+	case udp.ActionConnect:
+		err = me.handleConnect(ctx, source, h.TransactionId)
+	case udp.ActionAnnounce:
+		err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
+	default:
+		err = fmt.Errorf("unimplemented")
+	}
+	if err != nil {
+		err = fmt.Errorf("handling action %v: %w", h.Action, err)
+	}
+	return err
+}
+
+// handleAnnounce verifies connId for source, serves the announce, and sends a response. The
+// response carries peers of addrFamily only: IPv4-only for IPv4 (at most 150), and 16-byte
+// addresses for IPv6 (at most 50).
+func (me *Server) handleAnnounce(
+	ctx context.Context,
+	addrFamily udp.AddrFamily,
+	source RequestSourceAddr,
+	connId udp.ConnectionId,
+	tid udp.TransactionId,
+	r *bytes.Reader,
+) error {
+	// Should we set a timeout of 10s or something for the entire response, so that we give up if a
+	// retry is imminent?
+
+	// Reject unknown families before tracking anything. GetNodeAddrsCompactMarshaler would return
+	// nil for them, and calling MarshalBinary on nil panics.
+	switch addrFamily {
+	case udp.AddrFamilyIpv4, udp.AddrFamilyIpv6:
+	default:
+		return fmt.Errorf("unsupported address family %v", addrFamily)
+	}
+	ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
+	if err != nil {
+		err = fmt.Errorf("checking conn id: %w", err)
+		return err
+	}
+	if !ok {
+		return fmt.Errorf("incorrect connection id: %x", connId)
+	}
+	var req udp.AnnounceRequest
+	err = udp.Read(r, &req)
+	if err != nil {
+		return err
+	}
+	// TODO: This should be done asynchronously to responding to the announce.
+	announceAddr, err := netip.ParseAddrPort(source.String())
+	if err != nil {
+		err = fmt.Errorf("converting source net.Addr to AnnounceAddr: %w", err)
+		return err
+	}
+	opts := trackerServer.GetPeersOpts{MaxCount: generics.Some[uint](50)}
+	if addrFamily == udp.AddrFamilyIpv4 {
+		opts.MaxCount = generics.Some[uint](150)
+	}
+	res := me.Announce.Serve(ctx, req, announceAddr, opts)
+	if res.Err != nil {
+		return res.Err
+	}
+	nodeAddrs := make([]krpc.NodeAddr, 0, len(res.Peers))
+	for _, p := range res.Peers {
+		var ip net.IP
+		switch addrFamily {
+		default:
+			continue
+		case udp.AddrFamilyIpv4:
+			if !p.Addr().Unmap().Is4() {
+				continue
+			}
+			ipBuf := p.Addr().As4()
+			ip = ipBuf[:]
+		case udp.AddrFamilyIpv6:
+			ipBuf := p.Addr().As16()
+			ip = ipBuf[:]
+		}
+		nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
+			IP:   ip,
+			Port: int(p.Port()),
+		})
+	}
+	var buf bytes.Buffer
+	err = udp.Write(&buf, udp.ResponseHeader{
+		Action:        udp.ActionAnnounce,
+		TransactionId: tid,
+	})
+	if err != nil {
+		return err
+	}
+	err = udp.Write(&buf, udp.AnnounceResponseHeader{
+		Interval: res.Interval.UnwrapOr(5 * 60),
+		Seeders:  res.Seeders.Value,
+		Leechers: res.Leechers.Value,
+	})
+	if err != nil {
+		return err
+	}
+	b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
+	if err != nil {
+		err = fmt.Errorf("marshalling compact node addrs: %w", err)
+		return err
+	}
+	buf.Write(b)
+	n, err := me.SendResponse(ctx, buf.Bytes(), source)
+	if err != nil {
+		return err
+	}
+	if n < buf.Len() {
+		err = io.ErrShortWrite
+	}
+	return err
+}
+
+// handleConnect issues a random connection ID to source, records it with ConnTracker, and sends
+// the connect response.
+func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
+	connId := randomConnectionId()
+	err := me.ConnTracker.Add(ctx, source.String(), connId)
+	if err != nil {
+		err = fmt.Errorf("recording conn id: %w", err)
+		return err
+	}
+	var buf bytes.Buffer
+	err = udp.Write(&buf, udp.ResponseHeader{
+		Action:        udp.ActionConnect,
+		TransactionId: tid,
+	})
+	if err != nil {
+		return err
+	}
+	err = udp.Write(&buf, udp.ConnectionResponse{connId})
+	if err != nil {
+		return err
+	}
+	n, err := me.SendResponse(ctx, buf.Bytes(), source)
+	if err != nil {
+		return err
+	}
+	if n < buf.Len() {
+		err = io.ErrShortWrite
+	}
+	return err
+}
+
+// randomConnectionId returns a connection ID from crypto/rand. It panics if the system
+// randomness source fails.
+func randomConnectionId() udp.ConnectionId {
+	var b [8]byte
+	_, err := rand.Read(b[:])
+	if err != nil {
+		panic(err)
+	}
+	return binary.BigEndian.Uint64(b[:])
+}
+
+// RunSimple reads packets from pc into a 1500-byte buffer and handles each one in its own
+// goroutine, with at most 1000 in flight. Packets that arrive while the limit is reached are
+// dropped. It returns when a read fails or ctx is done.
+func RunSimple(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+	var b [1500]byte
+	// Limit concurrent handled requests.
+	sem := make(chan struct{}, 1000)
+	for {
+		n, addr, err := pc.ReadFrom(b[:])
+		ctx, span := tracer.Start(ctx, "handle udp packet")
+		if err != nil {
+			span.SetStatus(codes.Error, err.Error())
+			span.End()
+			return err
+		}
+		select {
+		case <-ctx.Done():
+			// err is nil here (checked above), so report the context's error instead.
+			span.SetStatus(codes.Error, ctx.Err().Error())
+			span.End()
+			return ctx.Err()
+		default:
+			span.SetStatus(codes.Error, "concurrency limit reached")
+			span.End()
+			log.Levelf(log.Debug, "dropping request from %v: concurrency limit reached", addr)
+			continue
+		case sem <- struct{}{}:
+		}
+		// Copy the packet out of the shared read buffer before handing it to the goroutine.
+		b := append([]byte(nil), b[:n]...)
+		go func() {
+			defer span.End()
+			defer func() { <-sem }()
+			err := s.HandleRequest(ctx, family, addr, b)
+			if err != nil {
+				log.Printf("error handling %v byte request from %v: %v", n, addr, err)
+			}
+		}()
+	}
+}
diff --git a/tracker/udp_test.go b/tracker/udp_test.go
index 7354063b69dd664569341b3c3c1131135daa5d52..751e41b942b2c2d45f544087f7f6d6a622fec905 100644
--- a/tracker/udp_test.go
+++ b/tracker/udp_test.go
@@ -23,6 +23,7 @@
 var trackers = []string{
 	"udp://tracker.opentrackr.org:1337/announce",
 	"udp://tracker.openbittorrent.com:6969/announce",
+	"udp://localhost:42069",
 }
 
 func TestAnnounceLocalhost(t *testing.T) {