]> Sergey Matveev's repositories - btrtrc.git/blob - tracker/udp/server/server.go
815002147223177265625ac29639ae4280a4d7af
[btrtrc.git] / tracker / udp / server / server.go
1 package server
2
3 import (
4         "bytes"
5         "context"
6         "crypto/rand"
7         "encoding/binary"
8         "fmt"
9         "io"
10         "net"
11         "net/netip"
12
13         "github.com/anacrolix/dht/v2/krpc"
14         "github.com/anacrolix/log"
15         "github.com/anacrolix/torrent/tracker/udp"
16 )
17
18 type ConnectionTrackerAddr = string
19
20 type ConnectionTracker interface {
21         Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
22         Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
23 }
24
25 type InfoHash = [20]byte
26
27 // This is reserved for stuff like filtering by IP version, avoiding an announcer's IP or key,
28 // limiting return count, etc.
29 type GetPeersOpts struct{}
30
31 type PeerInfo struct {
32         netip.AddrPort
33 }
34
35 type AnnounceTracker interface {
36         TrackAnnounce(ctx context.Context, req udp.AnnounceRequest, addr RequestSourceAddr) error
37         Scrape(ctx context.Context, infoHashes []InfoHash) ([]udp.ScrapeInfohashResult, error)
38         GetPeers(ctx context.Context, infoHash InfoHash, opts GetPeersOpts) ([]PeerInfo, error)
39 }
40
41 type Server struct {
42         ConnTracker     ConnectionTracker
43         SendResponse    func(data []byte, addr net.Addr) (int, error)
44         AnnounceTracker AnnounceTracker
45 }
46
47 type RequestSourceAddr = net.Addr
48
49 func (me *Server) HandleRequest(ctx context.Context, family udp.AddrFamily, source RequestSourceAddr, body []byte) error {
50         var h udp.RequestHeader
51         var r bytes.Reader
52         r.Reset(body)
53         err := udp.Read(&r, &h)
54         if err != nil {
55                 err = fmt.Errorf("reading request header: %w", err)
56                 return err
57         }
58         switch h.Action {
59         case udp.ActionConnect:
60                 err = me.handleConnect(ctx, source, h.TransactionId)
61         case udp.ActionAnnounce:
62                 err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
63         default:
64                 err = fmt.Errorf("unimplemented")
65         }
66         if err != nil {
67                 err = fmt.Errorf("handling action %v: %w", h.Action, err)
68         }
69         return err
70 }
71
72 func (me *Server) handleAnnounce(
73         ctx context.Context,
74         addrFamily udp.AddrFamily,
75         source RequestSourceAddr,
76         connId udp.ConnectionId,
77         tid udp.TransactionId,
78         r *bytes.Reader,
79 ) error {
80         ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
81         if err != nil {
82                 err = fmt.Errorf("checking conn id: %w", err)
83                 return err
84         }
85         if !ok {
86                 return fmt.Errorf("invalid connection id: %v", connId)
87         }
88         var req udp.AnnounceRequest
89         err = udp.Read(r, &req)
90         if err != nil {
91                 return err
92         }
93         // TODO: This should be done asynchronously to responding to the announce.
94         err = me.AnnounceTracker.TrackAnnounce(ctx, req, source)
95         if err != nil {
96                 return err
97         }
98         peers, err := me.AnnounceTracker.GetPeers(ctx, req.InfoHash, GetPeersOpts{})
99         if err != nil {
100                 return err
101         }
102         nodeAddrs := make([]krpc.NodeAddr, 0, len(peers))
103         for _, p := range peers {
104                 var ip net.IP
105                 switch addrFamily {
106                 default:
107                         continue
108                 case udp.AddrFamilyIpv4:
109                         if !p.Addr().Unmap().Is4() {
110                                 continue
111                         }
112                         ipBuf := p.Addr().As4()
113                         ip = ipBuf[:]
114                 case udp.AddrFamilyIpv6:
115                         ipBuf := p.Addr().As16()
116                         ip = ipBuf[:]
117                 }
118                 nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
119                         IP:   ip[:],
120                         Port: int(p.Port()),
121                 })
122         }
123         var buf bytes.Buffer
124         err = udp.Write(&buf, udp.ResponseHeader{
125                 Action:        udp.ActionAnnounce,
126                 TransactionId: tid,
127         })
128         if err != nil {
129                 return err
130         }
131         err = udp.Write(&buf, udp.AnnounceResponseHeader{})
132         if err != nil {
133                 return err
134         }
135         b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
136         if err != nil {
137                 err = fmt.Errorf("marshalling compact node addrs: %w", err)
138                 return err
139         }
140         log.Print(nodeAddrs)
141         buf.Write(b)
142         n, err := me.SendResponse(buf.Bytes(), source)
143         if err != nil {
144                 return err
145         }
146         if n < buf.Len() {
147                 err = io.ErrShortWrite
148         }
149         return err
150 }
151
152 func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
153         connId := randomConnectionId()
154         err := me.ConnTracker.Add(ctx, source.String(), connId)
155         if err != nil {
156                 err = fmt.Errorf("recording conn id: %w", err)
157                 return err
158         }
159         var buf bytes.Buffer
160         udp.Write(&buf, udp.ResponseHeader{
161                 Action:        udp.ActionConnect,
162                 TransactionId: tid,
163         })
164         udp.Write(&buf, udp.ConnectionResponse{connId})
165         n, err := me.SendResponse(buf.Bytes(), source)
166         if err != nil {
167                 return err
168         }
169         if n < buf.Len() {
170                 err = io.ErrShortWrite
171         }
172         return err
173 }
174
175 func randomConnectionId() udp.ConnectionId {
176         var b [8]byte
177         _, err := rand.Read(b[:])
178         if err != nil {
179                 panic(err)
180         }
181         return int64(binary.BigEndian.Uint64(b[:]))
182 }
183
184 func RunServer(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
185         ctx, cancel := context.WithCancel(ctx)
186         defer cancel()
187         for {
188                 var b [1500]byte
189                 n, addr, err := pc.ReadFrom(b[:])
190                 if err != nil {
191                         return err
192                 }
193                 go func() {
194                         err := s.HandleRequest(ctx, family, addr, b[:n])
195                         if err != nil {
196                                 log.Printf("error handling %v byte request from %v: %v", n, addr, err)
197                         }
198                 }()
199         }
200 }