]> Sergey Matveev's repositories - btrtrc.git/blob - tracker/udp/server/server.go
68abb1061aca1b8b9a6e882bf74057b8dd8ac723
[btrtrc.git] / tracker / udp / server / server.go
1 package udpTrackerServer
2
3 import (
4         "bytes"
5         "context"
6         "crypto/rand"
7         "encoding/binary"
8         "fmt"
9         "io"
10         "net"
11         "net/netip"
12
13         "github.com/anacrolix/dht/v2/krpc"
14         "github.com/anacrolix/log"
15
16         "github.com/anacrolix/torrent/tracker"
17         "github.com/anacrolix/torrent/tracker/udp"
18 )
19
20 type ConnectionTrackerAddr = string
21
22 type ConnectionTracker interface {
23         Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
24         Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
25 }
26
27 type InfoHash = [20]byte
28
29 type AnnounceTracker = tracker.AnnounceTracker
30
31 type Server struct {
32         ConnTracker     ConnectionTracker
33         SendResponse    func(data []byte, addr net.Addr) (int, error)
34         AnnounceTracker AnnounceTracker
35 }
36
37 type RequestSourceAddr = net.Addr
38
39 func (me *Server) HandleRequest(
40         ctx context.Context,
41         family udp.AddrFamily,
42         source RequestSourceAddr,
43         body []byte,
44 ) error {
45         var h udp.RequestHeader
46         var r bytes.Reader
47         r.Reset(body)
48         err := udp.Read(&r, &h)
49         if err != nil {
50                 err = fmt.Errorf("reading request header: %w", err)
51                 return err
52         }
53         switch h.Action {
54         case udp.ActionConnect:
55                 err = me.handleConnect(ctx, source, h.TransactionId)
56         case udp.ActionAnnounce:
57                 err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
58         default:
59                 err = fmt.Errorf("unimplemented")
60         }
61         if err != nil {
62                 err = fmt.Errorf("handling action %v: %w", h.Action, err)
63         }
64         return err
65 }
66
67 func (me *Server) handleAnnounce(
68         ctx context.Context,
69         addrFamily udp.AddrFamily,
70         source RequestSourceAddr,
71         connId udp.ConnectionId,
72         tid udp.TransactionId,
73         r *bytes.Reader,
74 ) error {
75         ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
76         if err != nil {
77                 err = fmt.Errorf("checking conn id: %w", err)
78                 return err
79         }
80         if !ok {
81                 return fmt.Errorf("invalid connection id: %v", connId)
82         }
83         var req udp.AnnounceRequest
84         err = udp.Read(r, &req)
85         if err != nil {
86                 return err
87         }
88         // TODO: This should be done asynchronously to responding to the announce.
89         announceAddr, err := netip.ParseAddrPort(source.String())
90         if err != nil {
91                 err = fmt.Errorf("converting source net.Addr to AnnounceAddr: %w", err)
92                 return err
93         }
94         err = me.AnnounceTracker.TrackAnnounce(ctx, req, announceAddr)
95         if err != nil {
96                 return err
97         }
98         peers, err := me.AnnounceTracker.GetPeers(ctx, req.InfoHash, tracker.GetPeersOpts{})
99         if err != nil {
100                 return err
101         }
102         nodeAddrs := make([]krpc.NodeAddr, 0, len(peers))
103         for _, p := range peers {
104                 var ip net.IP
105                 switch addrFamily {
106                 default:
107                         continue
108                 case udp.AddrFamilyIpv4:
109                         if !p.Addr().Unmap().Is4() {
110                                 continue
111                         }
112                         ipBuf := p.Addr().As4()
113                         ip = ipBuf[:]
114                 case udp.AddrFamilyIpv6:
115                         ipBuf := p.Addr().As16()
116                         ip = ipBuf[:]
117                 }
118                 nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
119                         IP:   ip[:],
120                         Port: int(p.Port()),
121                 })
122         }
123         var buf bytes.Buffer
124         err = udp.Write(&buf, udp.ResponseHeader{
125                 Action:        udp.ActionAnnounce,
126                 TransactionId: tid,
127         })
128         if err != nil {
129                 return err
130         }
131         err = udp.Write(&buf, udp.AnnounceResponseHeader{})
132         if err != nil {
133                 return err
134         }
135         b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
136         if err != nil {
137                 err = fmt.Errorf("marshalling compact node addrs: %w", err)
138                 return err
139         }
140         buf.Write(b)
141         n, err := me.SendResponse(buf.Bytes(), source)
142         if err != nil {
143                 return err
144         }
145         if n < buf.Len() {
146                 err = io.ErrShortWrite
147         }
148         return err
149 }
150
151 func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
152         connId := randomConnectionId()
153         err := me.ConnTracker.Add(ctx, source.String(), connId)
154         if err != nil {
155                 err = fmt.Errorf("recording conn id: %w", err)
156                 return err
157         }
158         var buf bytes.Buffer
159         udp.Write(&buf, udp.ResponseHeader{
160                 Action:        udp.ActionConnect,
161                 TransactionId: tid,
162         })
163         udp.Write(&buf, udp.ConnectionResponse{connId})
164         n, err := me.SendResponse(buf.Bytes(), source)
165         if err != nil {
166                 return err
167         }
168         if n < buf.Len() {
169                 err = io.ErrShortWrite
170         }
171         return err
172 }
173
174 func randomConnectionId() udp.ConnectionId {
175         var b [8]byte
176         _, err := rand.Read(b[:])
177         if err != nil {
178                 panic(err)
179         }
180         return int64(binary.BigEndian.Uint64(b[:]))
181 }
182
183 func RunSimple(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
184         ctx, cancel := context.WithCancel(ctx)
185         defer cancel()
186         for {
187                 var b [1500]byte
188                 n, addr, err := pc.ReadFrom(b[:])
189                 if err != nil {
190                         return err
191                 }
192                 go func() {
193                         err := s.HandleRequest(ctx, family, addr, b[:n])
194                         if err != nil {
195                                 log.Printf("error handling %v byte request from %v: %v", n, addr, err)
196                         }
197                 }()
198         }
199 }