]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Start a UDP server implementation
authorMatt Joiner <anacrolix@gmail.com>
Mon, 5 Dec 2022 01:52:19 +0000 (12:52 +1100)
committerMatt Joiner <anacrolix@gmail.com>
Tue, 6 Dec 2022 23:45:12 +0000 (10:45 +1100)
tracker/udp/addr-family.go
tracker/udp/server/server.go
tracker/udp_test.go

index 0213f41f0215799c942685d46b38544290b3680b..ddecb4c9fafb3e4ed0225a081fa36055cf1832b0 100644 (file)
@@ -1 +1,26 @@
 package udp
+
+import (
+       "encoding"
+
+       "github.com/anacrolix/dht/v2/krpc"
+)
+
+// Discriminates behaviours based on address family in use.
+type AddrFamily int
+
+const (
+       AddrFamilyIpv4 = iota + 1
+       AddrFamilyIpv6
+)
+
+// Returns a marshaler for the given node addrs for the specified family.
+func GetNodeAddrsCompactMarshaler(nas []krpc.NodeAddr, family AddrFamily) encoding.BinaryMarshaler {
+       switch family {
+       case AddrFamilyIpv4:
+               return krpc.CompactIPv4NodeAddrs(nas)
+       case AddrFamilyIpv6:
+               return krpc.CompactIPv6NodeAddrs(nas)
+       }
+       return nil
+}
index abb4e431abd516750a5a1e5e2b77073c236b8f9e..815002147223177265625ac29639ae4280a4d7af 100644 (file)
@@ -1 +1,200 @@
 package server
+
+import (
+       "bytes"
+       "context"
+       "crypto/rand"
+       "encoding/binary"
+       "fmt"
+       "io"
+       "net"
+       "net/netip"
+
+       "github.com/anacrolix/dht/v2/krpc"
+       "github.com/anacrolix/log"
+       "github.com/anacrolix/torrent/tracker/udp"
+)
+
+type ConnectionTrackerAddr = string
+
+type ConnectionTracker interface {
+       Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
+       Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
+}
+
+type InfoHash = [20]byte
+
+// This is reserved for stuff like filtering by IP version, avoiding an announcer's IP or key,
+// limiting return count, etc.
+type GetPeersOpts struct{}
+
+type PeerInfo struct {
+       netip.AddrPort
+}
+
+type AnnounceTracker interface {
+       TrackAnnounce(ctx context.Context, req udp.AnnounceRequest, addr RequestSourceAddr) error
+       Scrape(ctx context.Context, infoHashes []InfoHash) ([]udp.ScrapeInfohashResult, error)
+       GetPeers(ctx context.Context, infoHash InfoHash, opts GetPeersOpts) ([]PeerInfo, error)
+}
+
+type Server struct {
+       ConnTracker     ConnectionTracker
+       SendResponse    func(data []byte, addr net.Addr) (int, error)
+       AnnounceTracker AnnounceTracker
+}
+
+type RequestSourceAddr = net.Addr
+
+func (me *Server) HandleRequest(ctx context.Context, family udp.AddrFamily, source RequestSourceAddr, body []byte) error {
+       var h udp.RequestHeader
+       var r bytes.Reader
+       r.Reset(body)
+       err := udp.Read(&r, &h)
+       if err != nil {
+               err = fmt.Errorf("reading request header: %w", err)
+               return err
+       }
+       switch h.Action {
+       case udp.ActionConnect:
+               err = me.handleConnect(ctx, source, h.TransactionId)
+       case udp.ActionAnnounce:
+               err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
+       default:
+               err = fmt.Errorf("unimplemented")
+       }
+       if err != nil {
+               err = fmt.Errorf("handling action %v: %w", h.Action, err)
+       }
+       return err
+}
+
+func (me *Server) handleAnnounce(
+       ctx context.Context,
+       addrFamily udp.AddrFamily,
+       source RequestSourceAddr,
+       connId udp.ConnectionId,
+       tid udp.TransactionId,
+       r *bytes.Reader,
+) error {
+       ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
+       if err != nil {
+               err = fmt.Errorf("checking conn id: %w", err)
+               return err
+       }
+       if !ok {
+               return fmt.Errorf("invalid connection id: %v", connId)
+       }
+       var req udp.AnnounceRequest
+       err = udp.Read(r, &req)
+       if err != nil {
+               return err
+       }
+       // TODO: This should be done asynchronously to responding to the announce.
+       err = me.AnnounceTracker.TrackAnnounce(ctx, req, source)
+       if err != nil {
+               return err
+       }
+       peers, err := me.AnnounceTracker.GetPeers(ctx, req.InfoHash, GetPeersOpts{})
+       if err != nil {
+               return err
+       }
+       nodeAddrs := make([]krpc.NodeAddr, 0, len(peers))
+       for _, p := range peers {
+               var ip net.IP
+               switch addrFamily {
+               default:
+                       continue
+               case udp.AddrFamilyIpv4:
+                       if !p.Addr().Unmap().Is4() {
+                               continue
+                       }
+                       ipBuf := p.Addr().As4()
+                       ip = ipBuf[:]
+               case udp.AddrFamilyIpv6:
+                       ipBuf := p.Addr().As16()
+                       ip = ipBuf[:]
+               }
+               nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
+                       IP:   ip[:],
+                       Port: int(p.Port()),
+               })
+       }
+       var buf bytes.Buffer
+       err = udp.Write(&buf, udp.ResponseHeader{
+               Action:        udp.ActionAnnounce,
+               TransactionId: tid,
+       })
+       if err != nil {
+               return err
+       }
+       err = udp.Write(&buf, udp.AnnounceResponseHeader{})
+       if err != nil {
+               return err
+       }
+       b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
+       if err != nil {
+               err = fmt.Errorf("marshalling compact node addrs: %w", err)
+               return err
+       }
+       log.Print(nodeAddrs)
+       buf.Write(b)
+       n, err := me.SendResponse(buf.Bytes(), source)
+       if err != nil {
+               return err
+       }
+       if n < buf.Len() {
+               err = io.ErrShortWrite
+       }
+       return err
+}
+
+func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
+       connId := randomConnectionId()
+       err := me.ConnTracker.Add(ctx, source.String(), connId)
+       if err != nil {
+               err = fmt.Errorf("recording conn id: %w", err)
+               return err
+       }
+       var buf bytes.Buffer
+       udp.Write(&buf, udp.ResponseHeader{
+               Action:        udp.ActionConnect,
+               TransactionId: tid,
+       })
+       udp.Write(&buf, udp.ConnectionResponse{connId})
+       n, err := me.SendResponse(buf.Bytes(), source)
+       if err != nil {
+               return err
+       }
+       if n < buf.Len() {
+               err = io.ErrShortWrite
+       }
+       return err
+}
+
+func randomConnectionId() udp.ConnectionId {
+       var b [8]byte
+       _, err := rand.Read(b[:])
+       if err != nil {
+               panic(err)
+       }
+       return int64(binary.BigEndian.Uint64(b[:]))
+}
+
+func RunServer(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
+       ctx, cancel := context.WithCancel(ctx)
+       defer cancel()
+       for {
+               var b [1500]byte
+               n, addr, err := pc.ReadFrom(b[:])
+               if err != nil {
+                       return err
+               }
+               go func() {
+                       err := s.HandleRequest(ctx, family, addr, b[:n])
+                       if err != nil {
+                               log.Printf("error handling %v byte request from %v: %v", n, addr, err)
+                       }
+               }()
+       }
+}
index 7354063b69dd664569341b3c3c1131135daa5d52..751e41b942b2c2d45f544087f7f6d6a622fec905 100644 (file)
@@ -23,6 +23,7 @@ import (
 var trackers = []string{
        "udp://tracker.opentrackr.org:1337/announce",
        "udp://tracker.openbittorrent.com:6969/announce",
+       "udp://localhost:42069",
 }
 
 func TestAnnounceLocalhost(t *testing.T) {