Sergey Matveev's repositories - vors.git/blob - cmd/server/main.go
Show bad packets counter
[vors.git] / cmd / server / main.go
1 // VoRS -- Vo(IP) Really Simple
2 // Copyright (C) 2024 Sergey Matveev <stargrave@stargrave.org>
3 //
4 // This program is free software: you can redistribute it and/or modify
5 // it under the terms of the GNU Affero General Public License as
6 // published by the Free Software Foundation, version 3 of the License.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU General Public License for more details.
12 //
13 // You should have received a copy of the GNU General Public License
14 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
15
16 package main
17
18 import (
19         "crypto/rand"
20         "crypto/subtle"
21         "crypto/tls"
22         "encoding/hex"
23         "flag"
24         "fmt"
25         "io"
26         "log"
27         "log/slog"
28         "net"
29         "net/netip"
30         "os"
31         "strconv"
32         "sync"
33         "time"
34
35         "github.com/dustin/go-humanize"
36         "github.com/flynn/noise"
37         "github.com/jroimartin/gocui"
38         vors "go.stargrave.org/vors/internal"
39         "golang.org/x/crypto/blake2s"
40         "golang.org/x/crypto/chacha20"
41         "golang.org/x/crypto/poly1305"
42 )
43
44 var (
45         TLSCfg = &tls.Config{
46                 MinVersion:       tls.VersionTLS13,
47                 CurvePreferences: []tls.CurveID{tls.X25519},
48         }
49         Peers    = map[byte]*Peer{}
50         PeersM   sync.Mutex
51         Prv, Pub []byte
52         Cookies  = map[vors.Cookie]chan *net.UDPAddr{}
53 )
54
55 func newPeer(conn *net.TCPConn) {
56         logger := slog.With("remote", conn.RemoteAddr().String())
57         logger.Info("connected")
58         defer conn.Close()
59         if len(Peers) == 1<<8 {
60                 logger.Error("too many peers")
61                 return
62         }
63         err := conn.SetNoDelay(true)
64         if err != nil {
65                 log.Fatalln("nodelay:", err)
66         }
67         buf := make([]byte, len(vors.NoisePrologue))
68
69         if _, err = io.ReadFull(conn, buf); err != nil {
70                 logger.Error("handshake: read prologue", "err", err)
71                 return
72         }
73         if string(buf) != vors.NoisePrologue {
74                 logger.Error("handshake: wrong prologue", "err", err)
75                 return
76         }
77
78         hs, err := noise.NewHandshakeState(noise.Config{
79                 CipherSuite:   vors.NoiseCipherSuite,
80                 Pattern:       noise.HandshakeNK,
81                 Initiator:     false,
82                 StaticKeypair: noise.DHKey{Private: Prv, Public: Pub},
83                 Prologue:      []byte(vors.NoisePrologue),
84         })
85         if err != nil {
86                 log.Fatalln("noise.NewHandshakeState:", err)
87         }
88         buf, err = vors.PktRead(conn)
89         if err != nil {
90                 logger.Error("read handshake", "err", err)
91                 return
92         }
93         peer := Peer{
94                 logger: logger,
95                 conn:   conn,
96                 stats:  &Stats{alive: make(chan struct{})},
97                 rx:     make(chan []byte),
98                 tx:     make(chan []byte, 10),
99                 alive:  make(chan struct{}),
100         }
101         {
102                 name, _, _, err := hs.ReadMessage(nil, buf)
103                 if err != nil {
104                         logger.Error("handshake: decrypt", "err", err)
105                         return
106                 }
107                 peer.name = string(name)
108         }
109         logger = logger.With("name", peer.name)
110
111         for _, p := range Peers {
112                 if p.name != peer.name {
113                         continue
114                 }
115                 logger.Error("name already taken")
116                 buf, _, _, err = hs.WriteMessage(nil, []byte("name already taken"))
117                 if err != nil {
118                         log.Fatal(err)
119                 }
120                 vors.PktWrite(conn, buf)
121                 return
122         }
123
124         {
125                 var i byte
126                 var ok bool
127                 var found bool
128                 PeersM.Lock()
129                 for i = 0; i <= (1<<8)-1; i++ {
130                         if _, ok = Peers[i]; !ok {
131                                 peer.sid = i
132                                 found = true
133                                 break
134                         }
135                 }
136                 if found {
137                         Peers[peer.sid] = &peer
138                         go peer.Tx()
139                 }
140                 PeersM.Unlock()
141                 if !found {
142                         buf, _, _, err = hs.WriteMessage(nil, []byte("too many users"))
143                         if err != nil {
144                                 log.Fatal(err)
145                         }
146                         vors.PktWrite(conn, buf)
147                         return
148                 }
149         }
150         logger = logger.With("sid", peer.sid)
151         logger.Info("logged in")
152
153         defer func() {
154                 logger.Info("removing")
155                 PeersM.Lock()
156                 delete(Peers, peer.sid)
157                 PeersM.Unlock()
158                 close(peer.stats.alive)
159                 s := []byte(fmt.Sprintf("%s %d", vors.CmdDel, peer.sid))
160                 for _, p := range Peers {
161                         go func(tx chan []byte) { tx <- s }(p.tx)
162                 }
163         }()
164
165         {
166                 var cookie vors.Cookie
167                 if _, err = io.ReadFull(rand.Reader, cookie[:]); err != nil {
168                         log.Fatalln("cookie:", err)
169                 }
170                 gotCookie := make(chan *net.UDPAddr)
171                 Cookies[cookie] = gotCookie
172
173                 var txCS, rxCS *noise.CipherState
174                 buf, txCS, rxCS, err := hs.WriteMessage(nil,
175                         []byte(fmt.Sprintf("OK %s", hex.EncodeToString(cookie[:]))))
176                 if err = vors.PktWrite(conn, buf); err != nil {
177                         logger.Error("handshake write", "err", err)
178                         delete(Cookies, cookie)
179                         return
180                 }
181                 peer.rxCS, peer.txCS = txCS, rxCS
182
183                 timeout := time.NewTimer(vors.PingTime)
184                 select {
185                 case peer.addr = <-gotCookie:
186                 case <-timeout.C:
187                         logger.Error("cookie timeout")
188                         delete(Cookies, cookie)
189                         return
190                 }
191                 delete(Cookies, cookie)
192                 logger.Info("got cookie", "addr", peer.addr)
193                 if !timeout.Stop() {
194                         <-timeout.C
195                 }
196         }
197         go peer.Rx()
198         peer.tx <- []byte(fmt.Sprintf("SID %d", peer.sid))
199
200         for _, p := range Peers {
201                 if p.sid == peer.sid {
202                         continue
203                 }
204                 peer.tx <- []byte(fmt.Sprintf("%s %d %s %s",
205                         vors.CmdAdd, p.sid, p.name, hex.EncodeToString(p.key)))
206         }
207
208         {
209                 h, err := blake2s.New256(hs.ChannelBinding())
210                 if err != nil {
211                         log.Fatalln(err)
212                 }
213                 h.Write([]byte(vors.NoisePrologue))
214                 peer.key = h.Sum(nil)
215         }
216
217         {
218                 s := []byte(fmt.Sprintf("%s %d %s %s",
219                         vors.CmdAdd, peer.sid, peer.name, hex.EncodeToString(peer.key)))
220                 for _, p := range Peers {
221                         if p.sid != peer.sid {
222                                 p.tx <- s
223                         }
224                 }
225         }
226
227         seen := time.Now()
228         go func(seen *time.Time) {
229                 ticker := time.Tick(vors.PingTime)
230                 var now time.Time
231                 for {
232                         select {
233                         case now = <-ticker:
234                                 if seen.Add(2 * vors.PingTime).Before(now) {
235                                         logger.Error("timeout:", "seen", seen)
236                                         peer.Close()
237                                         return
238                                 }
239                         case <-peer.alive:
240                                 return
241                         }
242                 }
243         }(&seen)
244
245         go func(stats *Stats) {
246                 if *NoGUI {
247                         return
248                 }
249                 tick := time.Tick(vors.ScreenRefresh)
250                 var now time.Time
251                 var v *gocui.View
252                 for {
253                         select {
254                         case <-stats.alive:
255                                 GUI.DeleteView(peer.name)
256                                 return
257                         case now = <-tick:
258                                 s := fmt.Sprintf(
259                                         "%s | Rx/Tx/Bad: %s / %s / %s |  %s / %s",
260                                         peer.addr,
261                                         humanize.Comma(stats.pktsRx),
262                                         humanize.Comma(stats.pktsTx),
263                                         humanize.Comma(stats.bads),
264                                         humanize.IBytes(stats.bytesRx),
265                                         humanize.IBytes(stats.bytesTx),
266                                 )
267                                 if stats.last.Add(vors.ScreenRefresh).After(now) {
268                                         s += "  |  " + vors.CGreen + "TALK" + vors.CReset
269                                 }
270                                 v, err = GUI.View(peer.name)
271                                 if err == nil {
272                                         v.Clear()
273                                         v.Write([]byte(s))
274                                 }
275                         }
276                 }
277         }(peer.stats)
278
279         for buf := range peer.rx {
280                 if string(buf) == vors.CmdPing {
281                         seen = time.Now()
282                         peer.tx <- []byte(vors.CmdPong)
283                 }
284         }
285 }
286
287 func main() {
288         bind := flag.String("bind", "[::1]:"+strconv.Itoa(vors.DefaultPort),
289                 "Host:TCP/UDP port to listen on")
290         kpFile := flag.String("key", "key", "Path to keypair file")
291         flag.Parse()
292         log.SetFlags(log.Lmicroseconds | log.Lshortfile)
293
294         {
295                 data, err := os.ReadFile(*kpFile)
296                 if err != nil {
297                         log.Fatal(err)
298                 }
299                 Prv, Pub = data[:len(data)/2], data[len(data)/2:]
300         }
301
302         lnTCP, err := net.ListenTCP("tcp",
303                 net.TCPAddrFromAddrPort(netip.MustParseAddrPort(*bind)))
304         if err != nil {
305                 log.Fatal(err)
306         }
307         lnUDP, err := net.ListenUDP("udp",
308                 net.UDPAddrFromAddrPort(netip.MustParseAddrPort(*bind)))
309         if err != nil {
310                 log.Fatal(err)
311         }
312
313         LoggerReady := make(chan struct{})
314         if *NoGUI {
315                 close(GUIReadyC)
316                 slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil)))
317                 close(LoggerReady)
318         } else {
319                 GUI, err = gocui.NewGui(gocui.OutputNormal)
320                 if err != nil {
321                         log.Fatal(err)
322                 }
323                 defer GUI.Close()
324                 GUI.SetManagerFunc(guiLayout)
325                 if err := GUI.SetKeybinding("", gocui.KeyF10, gocui.ModNone, guiQuit); err != nil {
326                         log.Fatal(err)
327                 }
328
329                 go func() {
330                         <-GUIReadyC
331                         v, err := GUI.View("logs")
332                         if err != nil {
333                                 log.Fatal(err)
334                         }
335                         slog.SetDefault(slog.New(slog.NewTextHandler(v, nil)))
336                         close(LoggerReady)
337                         for {
338                                 time.Sleep(vors.ScreenRefresh)
339                                 GUI.Update(func(gui *gocui.Gui) error {
340                                         return nil
341                                 })
342                         }
343                 }()
344         }
345
346         go func() {
347                 <-LoggerReady
348                 buf := make([]byte, 2*vors.FrameLen)
349                 var n int
350                 var from *net.UDPAddr
351                 var err error
352                 var sid byte
353                 var peer *Peer
354                 var ciph *chacha20.Cipher
355                 var macKey [32]byte
356                 var mac *poly1305.MAC
357                 tag := make([]byte, poly1305.TagSize)
358                 nonce := make([]byte, 12)
359                 for {
360                         n, from, err = lnUDP.ReadFromUDP(buf)
361                         if err != nil {
362                                 log.Fatalln("recvfrom:", err)
363                         }
364
365                         if n == vors.CookieLen {
366                                 var cookie vors.Cookie
367                                 copy(cookie[:], buf)
368                                 if c, ok := Cookies[cookie]; ok {
369                                         c <- from
370                                         close(c)
371                                 } else {
372                                         slog.Info("unknown cookie", "cookie", cookie)
373                                 }
374                                 continue
375                         }
376
377                         sid = buf[0]
378                         peer = Peers[sid]
379                         if peer == nil {
380                                 slog.Info("unknown", "sid", sid, "from", from)
381                                 continue
382                         }
383
384                         if from.Port != peer.addr.Port || !from.IP.Equal(peer.addr.IP) {
385                                 slog.Info("wrong addr",
386                                         "peer", peer.name,
387                                         "our", peer.addr,
388                                         "got", from)
389                                 continue
390                         }
391
392                         peer.stats.pktsRx++
393                         peer.stats.bytesRx += uint64(n)
394                         if n == 1 {
395                                 continue
396                         }
397                         if n <= 4+vors.TagLen {
398                                 peer.stats.bads++
399                                 continue
400                         }
401
402                         copy(nonce[len(nonce)-4:], buf)
403                         ciph, err = chacha20.NewUnauthenticatedCipher(peer.key, nonce)
404                         if err != nil {
405                                 log.Fatal(err)
406                         }
407                         clear(macKey[:])
408                         ciph.XORKeyStream(macKey[:], macKey[:])
409                         ciph.SetCounter(1)
410                         mac = poly1305.New(&macKey)
411                         if _, err = mac.Write(buf[4 : n-vors.TagLen]); err != nil {
412                                 log.Fatal(err)
413                         }
414                         mac.Sum(tag[:0])
415                         if subtle.ConstantTimeCompare(
416                                 tag[:vors.TagLen],
417                                 buf[n-vors.TagLen:n],
418                         ) != 1 {
419                                 peer.stats.bads++
420                                 continue
421                         }
422
423                         peer.stats.last = time.Now()
424                         for _, p := range Peers {
425                                 if p.sid == sid {
426                                         continue
427                                 }
428                                 p.stats.pktsTx++
429                                 p.stats.bytesTx += uint64(n)
430                                 if _, err = lnUDP.WriteToUDP(buf[:n], p.addr); err != nil {
431                                         slog.Warn("sendto", "peer", peer.name, "err", err)
432                                 }
433                         }
434                 }
435         }()
436
437         go func() {
438                 <-LoggerReady
439                 slog.Info("listening", "bind", *bind, "pub", hex.EncodeToString(Pub))
440                 for {
441                         conn, err := lnTCP.AcceptTCP()
442                         if err != nil {
443                                 log.Fatalln("accept:", err)
444                         }
445                         go newPeer(conn)
446                 }
447         }()
448
449         if *NoGUI {
450                 <-make(chan struct{})
451         }
452         err = GUI.MainLoop()
453         if err != nil && err != gocui.ErrQuit {
454                 log.Fatal(err)
455         }
456 }