]> Sergey Matveev's repositories - vors.git/blob - cmd/server/main.go
17c9ce3daecaf66a17ffd27dc99ed892e04e60262e1f76e2e71f89ed98af2980
[vors.git] / cmd / server / main.go
1 // VoRS -- Vo(IP) Really Simple
2 // Copyright (C) 2024 Sergey Matveev <stargrave@stargrave.org>
3 //
4 // This program is free software: you can redistribute it and/or modify
5 // it under the terms of the GNU Affero General Public License as
6 // published by the Free Software Foundation, version 3 of the License.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU General Public License for more details.
12 //
13 // You should have received a copy of the GNU General Public License
14 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
15
16 package main
17
18 import (
19         "crypto/rand"
20         "crypto/subtle"
21         "crypto/tls"
22         "encoding/hex"
23         "flag"
24         "fmt"
25         "io"
26         "log"
27         "log/slog"
28         "net"
29         "net/netip"
30         "os"
31         "strconv"
32         "sync"
33         "time"
34
35         "github.com/dustin/go-humanize"
36         "github.com/flynn/noise"
37         "github.com/jroimartin/gocui"
38         vors "go.stargrave.org/vors/internal"
39         "golang.org/x/crypto/blake2s"
40         "golang.org/x/crypto/chacha20"
41         "golang.org/x/crypto/poly1305"
42 )
43
44 var (
45         TLSCfg = &tls.Config{
46                 MinVersion:       tls.VersionTLS13,
47                 CurvePreferences: []tls.CurveID{tls.X25519},
48         }
49         Peers    = map[byte]*Peer{}
50         PeersM   sync.Mutex
51         Prv, Pub []byte
52         Cookies  = map[vors.Cookie]chan *net.UDPAddr{}
53 )
54
55 func newPeer(conn *net.TCPConn) {
56         logger := slog.With("remote", conn.RemoteAddr().String())
57         logger.Info("connected")
58         defer conn.Close()
59         if len(Peers) == 1<<8 {
60                 logger.Error("too many peers")
61                 return
62         }
63         err := conn.SetNoDelay(true)
64         if err != nil {
65                 log.Fatalln("nodelay:", err)
66         }
67         buf := make([]byte, len(vors.NoisePrologue))
68
69         if _, err = io.ReadFull(conn, buf); err != nil {
70                 logger.Error("handshake: read prologue", "err", err)
71                 return
72         }
73         if string(buf) != vors.NoisePrologue {
74                 logger.Error("handshake: wrong prologue", "err", err)
75                 return
76         }
77
78         hs, err := noise.NewHandshakeState(noise.Config{
79                 CipherSuite:   vors.NoiseCipherSuite,
80                 Pattern:       noise.HandshakeNK,
81                 Initiator:     false,
82                 StaticKeypair: noise.DHKey{Private: Prv, Public: Pub},
83                 Prologue:      []byte(vors.NoisePrologue),
84         })
85         if err != nil {
86                 log.Fatalln("noise.NewHandshakeState:", err)
87         }
88         buf, err = vors.PktRead(conn)
89         if err != nil {
90                 logger.Error("read handshake", "err", err)
91                 return
92         }
93         peer := Peer{
94                 logger: logger,
95                 conn:   conn,
96                 stats:  &Stats{alive: make(chan struct{})},
97                 rx:     make(chan []byte),
98                 tx:     make(chan []byte, 10),
99                 alive:  make(chan struct{}),
100         }
101         {
102                 name, _, _, err := hs.ReadMessage(nil, buf)
103                 if err != nil {
104                         logger.Error("handshake: decrypt", "err", err)
105                         return
106                 }
107                 peer.name = string(name)
108         }
109         logger = logger.With("name", peer.name)
110
111         for _, p := range Peers {
112                 if p.name != peer.name {
113                         continue
114                 }
115                 logger.Error("name already taken")
116                 buf, _, _, err = hs.WriteMessage(nil, []byte("name already taken"))
117                 if err != nil {
118                         log.Fatal(err)
119                 }
120                 vors.PktWrite(conn, buf)
121                 return
122         }
123
124         {
125                 var i byte
126                 var ok bool
127                 var found bool
128                 PeersM.Lock()
129                 for i = 0; i <= (1<<8)-1; i++ {
130                         if _, ok = Peers[i]; !ok {
131                                 peer.sid = i
132                                 found = true
133                                 break
134                         }
135                 }
136                 if found {
137                         Peers[peer.sid] = &peer
138                         go peer.Tx()
139                 }
140                 PeersM.Unlock()
141                 if !found {
142                         buf, _, _, err = hs.WriteMessage(nil, []byte("too many users"))
143                         if err != nil {
144                                 log.Fatal(err)
145                         }
146                         vors.PktWrite(conn, buf)
147                         return
148                 }
149         }
150         logger = logger.With("sid", peer.sid)
151         logger.Info("logged in")
152
153         defer func() {
154                 logger.Info("removing")
155                 PeersM.Lock()
156                 delete(Peers, peer.sid)
157                 PeersM.Unlock()
158                 close(peer.stats.alive)
159                 s := []byte(fmt.Sprintf("%s %d", vors.CmdDel, peer.sid))
160                 for _, p := range Peers {
161                         go func(tx chan []byte) { tx <- s }(p.tx)
162                 }
163         }()
164
165         {
166                 var cookie vors.Cookie
167                 if _, err = io.ReadFull(rand.Reader, cookie[:]); err != nil {
168                         log.Fatalln("cookie:", err)
169                 }
170                 gotCookie := make(chan *net.UDPAddr)
171                 Cookies[cookie] = gotCookie
172
173                 var txCS, rxCS *noise.CipherState
174                 buf, txCS, rxCS, err := hs.WriteMessage(nil,
175                         []byte(fmt.Sprintf("OK %s", hex.EncodeToString(cookie[:]))))
176                 if err = vors.PktWrite(conn, buf); err != nil {
177                         logger.Error("handshake write", "err", err)
178                         delete(Cookies, cookie)
179                         return
180                 }
181                 peer.rxCS, peer.txCS = txCS, rxCS
182
183                 timeout := time.NewTimer(vors.PingTime)
184                 select {
185                 case peer.addr = <-gotCookie:
186                 case <-timeout.C:
187                         logger.Error("cookie timeout")
188                         delete(Cookies, cookie)
189                         return
190                 }
191                 delete(Cookies, cookie)
192                 logger.Info("got cookie", "addr", peer.addr)
193                 if !timeout.Stop() {
194                         <-timeout.C
195                 }
196         }
197         go peer.Rx()
198         peer.tx <- []byte(fmt.Sprintf("SID %d", peer.sid))
199
200         for _, p := range Peers {
201                 if p.sid == peer.sid {
202                         continue
203                 }
204                 peer.tx <- []byte(fmt.Sprintf("%s %d %s %s",
205                         vors.CmdAdd, p.sid, p.name, hex.EncodeToString(p.key)))
206         }
207
208         {
209                 h, err := blake2s.New256(hs.ChannelBinding())
210                 if err != nil {
211                         log.Fatalln(err)
212                 }
213                 h.Write([]byte(vors.NoisePrologue))
214                 peer.key = h.Sum(nil)
215         }
216
217         {
218                 s := []byte(fmt.Sprintf("%s %d %s %s",
219                         vors.CmdAdd, peer.sid, peer.name, hex.EncodeToString(peer.key)))
220                 for _, p := range Peers {
221                         if p.sid != peer.sid {
222                                 p.tx <- s
223                         }
224                 }
225         }
226
227         seen := time.Now()
228         go func(seen *time.Time) {
229                 ticker := time.Tick(vors.PingTime)
230                 var now time.Time
231                 for {
232                         select {
233                         case now = <-ticker:
234                                 if seen.Add(2 * vors.PingTime).Before(now) {
235                                         logger.Error("timeout:", "seen", seen)
236                                         peer.Close()
237                                         return
238                                 }
239                         case <-peer.alive:
240                                 return
241                         }
242                 }
243         }(&seen)
244
245         go func(stats *Stats) {
246                 if *NoGUI {
247                         return
248                 }
249                 tick := time.Tick(vors.ScreenRefresh)
250                 var now time.Time
251                 var v *gocui.View
252                 for {
253                         select {
254                         case <-stats.alive:
255                                 GUI.DeleteView(peer.name)
256                                 return
257                         case now = <-tick:
258                                 s := fmt.Sprintf(
259                                         "%s | Rx/Tx: %s / %s  |  %s / %s",
260                                         peer.addr,
261                                         humanize.Comma(stats.pktsRx),
262                                         humanize.Comma(stats.pktsTx),
263                                         humanize.IBytes(stats.bytesRx),
264                                         humanize.IBytes(stats.bytesTx),
265                                 )
266                                 if stats.last.Add(vors.ScreenRefresh).After(now) {
267                                         s += "  |  " + vors.CGreen + "TALK" + vors.CReset
268                                 }
269                                 v, err = GUI.View(peer.name)
270                                 if err == nil {
271                                         v.Clear()
272                                         v.Write([]byte(s))
273                                 }
274                         }
275                 }
276         }(peer.stats)
277
278         for buf := range peer.rx {
279                 if string(buf) == vors.CmdPing {
280                         seen = time.Now()
281                         peer.tx <- []byte(vors.CmdPong)
282                 }
283         }
284 }
285
286 func main() {
287         bind := flag.String("bind", "[::1]:"+strconv.Itoa(vors.DefaultPort),
288                 "Host:TCP/UDP port to listen on")
289         kpFile := flag.String("key", "key", "Path to keypair file")
290         flag.Parse()
291         log.SetFlags(log.Lmicroseconds | log.Lshortfile)
292
293         {
294                 data, err := os.ReadFile(*kpFile)
295                 if err != nil {
296                         log.Fatal(err)
297                 }
298                 Prv, Pub = data[:len(data)/2], data[len(data)/2:]
299         }
300
301         lnTCP, err := net.ListenTCP("tcp",
302                 net.TCPAddrFromAddrPort(netip.MustParseAddrPort(*bind)))
303         if err != nil {
304                 log.Fatal(err)
305         }
306         lnUDP, err := net.ListenUDP("udp",
307                 net.UDPAddrFromAddrPort(netip.MustParseAddrPort(*bind)))
308         if err != nil {
309                 log.Fatal(err)
310         }
311
312         LoggerReady := make(chan struct{})
313         if *NoGUI {
314                 close(GUIReadyC)
315                 slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil)))
316                 close(LoggerReady)
317         } else {
318                 GUI, err = gocui.NewGui(gocui.OutputNormal)
319                 if err != nil {
320                         log.Fatal(err)
321                 }
322                 defer GUI.Close()
323                 GUI.SetManagerFunc(guiLayout)
324                 if err := GUI.SetKeybinding("", gocui.KeyF10, gocui.ModNone, guiQuit); err != nil {
325                         log.Fatal(err)
326                 }
327
328                 go func() {
329                         <-GUIReadyC
330                         v, err := GUI.View("logs")
331                         if err != nil {
332                                 log.Fatal(err)
333                         }
334                         slog.SetDefault(slog.New(slog.NewTextHandler(v, nil)))
335                         close(LoggerReady)
336                         for {
337                                 time.Sleep(vors.ScreenRefresh)
338                                 GUI.Update(func(gui *gocui.Gui) error {
339                                         return nil
340                                 })
341                         }
342                 }()
343         }
344
345         go func() {
346                 <-LoggerReady
347                 buf := make([]byte, 2*vors.FrameLen)
348                 var n int
349                 var from *net.UDPAddr
350                 var err error
351                 var sid byte
352                 var peer *Peer
353                 var ciph *chacha20.Cipher
354                 var macKey [32]byte
355                 var mac *poly1305.MAC
356                 tag := make([]byte, poly1305.TagSize)
357                 nonce := make([]byte, 12)
358                 for {
359                         n, from, err = lnUDP.ReadFromUDP(buf)
360                         if err != nil {
361                                 log.Fatalln("recvfrom:", err)
362                         }
363
364                         if n == vors.CookieLen {
365                                 var cookie vors.Cookie
366                                 copy(cookie[:], buf)
367                                 if c, ok := Cookies[cookie]; ok {
368                                         c <- from
369                                         close(c)
370                                 } else {
371                                         slog.Info("unknown cookie", "cookie", cookie)
372                                 }
373                                 continue
374                         }
375
376                         sid = buf[0]
377                         peer = Peers[sid]
378                         if peer == nil {
379                                 slog.Info("unknown", "sid", sid, "from", from)
380                                 continue
381                         }
382
383                         if from.Port != peer.addr.Port || !from.IP.Equal(peer.addr.IP) {
384                                 slog.Info("wrong addr",
385                                         "peer", peer.name,
386                                         "our", peer.addr,
387                                         "got", from)
388                                 continue
389                         }
390
391                         peer.stats.pktsRx++
392                         peer.stats.bytesRx += uint64(n)
393                         if n == 1 {
394                                 continue
395                         }
396                         if n <= 4+vors.TagLen {
397                                 slog.Info("too small", "peer", peer.name, "len", n)
398                                 continue
399                         }
400
401                         copy(nonce[len(nonce)-4:], buf)
402                         ciph, err = chacha20.NewUnauthenticatedCipher(peer.key, nonce)
403                         if err != nil {
404                                 log.Fatal(err)
405                         }
406                         clear(macKey[:])
407                         ciph.XORKeyStream(macKey[:], macKey[:])
408                         ciph.SetCounter(1)
409                         mac = poly1305.New(&macKey)
410                         if _, err = mac.Write(buf[4 : n-vors.TagLen]); err != nil {
411                                 log.Fatal(err)
412                         }
413                         mac.Sum(tag[:0])
414                         if subtle.ConstantTimeCompare(
415                                 tag[:vors.TagLen],
416                                 buf[n-vors.TagLen:n],
417                         ) != 1 {
418                                 log.Println("decrypt:", peer.name, "tag differs")
419                                 slog.Info("MAC failed", "peer", peer.name, "len", n)
420                                 continue
421                         }
422
423                         peer.stats.last = time.Now()
424                         for _, p := range Peers {
425                                 if p.sid == sid {
426                                         continue
427                                 }
428                                 p.stats.pktsTx++
429                                 p.stats.bytesTx += uint64(n)
430                                 if _, err = lnUDP.WriteToUDP(buf[:n], p.addr); err != nil {
431                                         slog.Warn("sendto", "peer", peer.name, "err", err)
432                                 }
433                         }
434                 }
435         }()
436
437         go func() {
438                 <-LoggerReady
439                 slog.Info("listening", "bind", *bind, "pub", hex.EncodeToString(Pub))
440                 for {
441                         conn, err := lnTCP.AcceptTCP()
442                         if err != nil {
443                                 log.Fatalln("accept:", err)
444                         }
445                         go newPeer(conn)
446                 }
447         }()
448
449         if *NoGUI {
450                 <-make(chan struct{})
451         }
452         err = GUI.MainLoop()
453         if err != nil && err != gocui.ErrQuit {
454                 log.Fatal(err)
455         }
456 }