]> Sergey Matveev's repositories - btrtrc.git/blob - tracker/udp/server/server.go
Don't use AnnounceHandler by value
[btrtrc.git] / tracker / udp / server / server.go
1 package udpTrackerServer
2
3 import (
4         "bytes"
5         "context"
6         "crypto/rand"
7         "encoding/binary"
8         "fmt"
9         "io"
10         "net"
11         "net/netip"
12
13         "github.com/anacrolix/dht/v2/krpc"
14         "github.com/anacrolix/generics"
15         "github.com/anacrolix/log"
16
17         "github.com/anacrolix/torrent/tracker"
18         "github.com/anacrolix/torrent/tracker/udp"
19 )
20
21 type ConnectionTrackerAddr = string
22
23 type ConnectionTracker interface {
24         Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
25         Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
26 }
27
28 type InfoHash = [20]byte
29
30 type AnnounceTracker = tracker.AnnounceTracker
31
32 type Server struct {
33         ConnTracker  ConnectionTracker
34         SendResponse func(data []byte, addr net.Addr) (int, error)
35         Announce     *tracker.AnnounceHandler
36 }
37
38 type RequestSourceAddr = net.Addr
39
40 func (me *Server) HandleRequest(
41         ctx context.Context,
42         family udp.AddrFamily,
43         source RequestSourceAddr,
44         body []byte,
45 ) error {
46         var h udp.RequestHeader
47         var r bytes.Reader
48         r.Reset(body)
49         err := udp.Read(&r, &h)
50         if err != nil {
51                 err = fmt.Errorf("reading request header: %w", err)
52                 return err
53         }
54         switch h.Action {
55         case udp.ActionConnect:
56                 err = me.handleConnect(ctx, source, h.TransactionId)
57         case udp.ActionAnnounce:
58                 err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
59         default:
60                 err = fmt.Errorf("unimplemented")
61         }
62         if err != nil {
63                 err = fmt.Errorf("handling action %v: %w", h.Action, err)
64         }
65         return err
66 }
67
68 func (me *Server) handleAnnounce(
69         ctx context.Context,
70         addrFamily udp.AddrFamily,
71         source RequestSourceAddr,
72         connId udp.ConnectionId,
73         tid udp.TransactionId,
74         r *bytes.Reader,
75 ) error {
76         // Should we set a timeout of 10s or something for the entire response, so that we give up if a
77         // retry is imminent?
78
79         ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
80         if err != nil {
81                 err = fmt.Errorf("checking conn id: %w", err)
82                 return err
83         }
84         if !ok {
85                 return fmt.Errorf("incorrect connection id: %x", connId)
86         }
87         var req udp.AnnounceRequest
88         err = udp.Read(r, &req)
89         if err != nil {
90                 return err
91         }
92         // TODO: This should be done asynchronously to responding to the announce.
93         announceAddr, err := netip.ParseAddrPort(source.String())
94         if err != nil {
95                 err = fmt.Errorf("converting source net.Addr to AnnounceAddr: %w", err)
96                 return err
97         }
98         opts := tracker.GetPeersOpts{MaxCount: generics.Some[uint](50)}
99         if addrFamily == udp.AddrFamilyIpv4 {
100                 opts.MaxCount = generics.Some[uint](150)
101         }
102         peers, err := me.Announce.Serve(ctx, req, announceAddr, opts)
103         if err != nil {
104                 return err
105         }
106         nodeAddrs := make([]krpc.NodeAddr, 0, len(peers))
107         for _, p := range peers {
108                 var ip net.IP
109                 switch addrFamily {
110                 default:
111                         continue
112                 case udp.AddrFamilyIpv4:
113                         if !p.Addr().Unmap().Is4() {
114                                 continue
115                         }
116                         ipBuf := p.Addr().As4()
117                         ip = ipBuf[:]
118                 case udp.AddrFamilyIpv6:
119                         ipBuf := p.Addr().As16()
120                         ip = ipBuf[:]
121                 }
122                 nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
123                         IP:   ip[:],
124                         Port: int(p.Port()),
125                 })
126         }
127         var buf bytes.Buffer
128         err = udp.Write(&buf, udp.ResponseHeader{
129                 Action:        udp.ActionAnnounce,
130                 TransactionId: tid,
131         })
132         if err != nil {
133                 return err
134         }
135         err = udp.Write(&buf, udp.AnnounceResponseHeader{})
136         if err != nil {
137                 return err
138         }
139         b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
140         if err != nil {
141                 err = fmt.Errorf("marshalling compact node addrs: %w", err)
142                 return err
143         }
144         buf.Write(b)
145         n, err := me.SendResponse(buf.Bytes(), source)
146         if err != nil {
147                 return err
148         }
149         if n < buf.Len() {
150                 err = io.ErrShortWrite
151         }
152         return err
153 }
154
155 func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
156         connId := randomConnectionId()
157         err := me.ConnTracker.Add(ctx, source.String(), connId)
158         if err != nil {
159                 err = fmt.Errorf("recording conn id: %w", err)
160                 return err
161         }
162         var buf bytes.Buffer
163         udp.Write(&buf, udp.ResponseHeader{
164                 Action:        udp.ActionConnect,
165                 TransactionId: tid,
166         })
167         udp.Write(&buf, udp.ConnectionResponse{connId})
168         n, err := me.SendResponse(buf.Bytes(), source)
169         if err != nil {
170                 return err
171         }
172         if n < buf.Len() {
173                 err = io.ErrShortWrite
174         }
175         return err
176 }
177
178 func randomConnectionId() udp.ConnectionId {
179         var b [8]byte
180         _, err := rand.Read(b[:])
181         if err != nil {
182                 panic(err)
183         }
184         return binary.BigEndian.Uint64(b[:])
185 }
186
187 func RunSimple(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
188         ctx, cancel := context.WithCancel(ctx)
189         defer cancel()
190         for {
191                 var b [1500]byte
192                 n, addr, err := pc.ReadFrom(b[:])
193                 if err != nil {
194                         return err
195                 }
196                 go func() {
197                         err := s.HandleRequest(ctx, family, addr, b[:n])
198                         if err != nil {
199                                 log.Printf("error handling %v byte request from %v: %v", n, addr, err)
200                         }
201                 }()
202         }
203 }