]> Sergey Matveev's repositories - vors.git/blob - cmd/server/main.go
Rooms support
[vors.git] / cmd / server / main.go
1 // VoRS -- Vo(IP) Really Simple
2 // Copyright (C) 2024 Sergey Matveev <stargrave@stargrave.org>
3 //
4 // This program is free software: you can redistribute it and/or modify
5 // it under the terms of the GNU Affero General Public License as
6 // published by the Free Software Foundation, version 3 of the License.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU General Public License for more details.
12 //
13 // You should have received a copy of the GNU General Public License
14 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
15
16 package main
17
18 import (
19         "crypto/rand"
20         "crypto/subtle"
21         "crypto/tls"
22         "encoding/hex"
23         "flag"
24         "fmt"
25         "io"
26         "log"
27         "log/slog"
28         "net"
29         "net/netip"
30         "os"
31         "strconv"
32         "strings"
33         "time"
34
35         "github.com/flynn/noise"
36         "github.com/jroimartin/gocui"
37         vors "go.stargrave.org/vors/internal"
38         "golang.org/x/crypto/blake2s"
39         "golang.org/x/crypto/chacha20"
40         "golang.org/x/crypto/poly1305"
41 )
42
43 var (
44         TLSCfg = &tls.Config{
45                 MinVersion:       tls.VersionTLS13,
46                 CurvePreferences: []tls.CurveID{tls.X25519},
47         }
48         Prv, Pub []byte
49         Cookies  = map[vors.Cookie]chan *net.UDPAddr{}
50 )
51
52 func newPeer(conn *net.TCPConn) {
53         logger := slog.With("remote", conn.RemoteAddr().String())
54         logger.Info("connected")
55         defer conn.Close()
56         err := conn.SetNoDelay(true)
57         if err != nil {
58                 log.Fatalln("nodelay:", err)
59         }
60         buf := make([]byte, len(vors.NoisePrologue))
61
62         if _, err = io.ReadFull(conn, buf); err != nil {
63                 logger.Error("handshake: read prologue", "err", err)
64                 return
65         }
66         if string(buf) != vors.NoisePrologue {
67                 logger.Error("handshake: wrong prologue", "err", err)
68                 return
69         }
70
71         hs, err := noise.NewHandshakeState(noise.Config{
72                 CipherSuite:   vors.NoiseCipherSuite,
73                 Pattern:       noise.HandshakeNK,
74                 Initiator:     false,
75                 StaticKeypair: noise.DHKey{Private: Prv, Public: Pub},
76                 Prologue:      []byte(vors.NoisePrologue),
77         })
78         if err != nil {
79                 log.Fatalln("noise.NewHandshakeState:", err)
80         }
81         buf, err = vors.PktRead(conn)
82         if err != nil {
83                 logger.Error("read handshake", "err", err)
84                 return
85         }
86         peer := &Peer{
87                 logger: logger,
88                 conn:   conn,
89                 stats:  &Stats{},
90                 rx:     make(chan []byte),
91                 tx:     make(chan []byte, 10),
92                 alive:  make(chan struct{}),
93         }
94         var room *Room
95         {
96                 nameAndRoom, _, _, err := hs.ReadMessage(nil, buf)
97                 if err != nil {
98                         logger.Error("handshake: decrypt", "err", err)
99                         return
100                 }
101                 cols := strings.SplitN(string(nameAndRoom), " ", 3)
102                 roomName := "/"
103                 if len(cols) > 1 {
104                         roomName = cols[1]
105                 }
106                 var key string
107                 if len(cols) > 2 {
108                         key = cols[2]
109                 }
110                 peer.name = string(cols[0])
111                 logger = logger.With("name", peer.name, "room", roomName)
112                 RoomsM.Lock()
113                 room = Rooms[roomName]
114                 if room == nil {
115                         room = &Room{
116                                 name:  roomName,
117                                 key:   key,
118                                 peers: make(map[byte]*Peer),
119                                 alive: make(chan struct{}),
120                         }
121                         Rooms[roomName] = room
122                         go func() {
123                                 if *NoGUI {
124                                         return
125                                 }
126                                 tick := time.Tick(vors.ScreenRefresh)
127                                 var now time.Time
128                                 var v *gocui.View
129                                 for {
130                                         select {
131                                         case <-room.alive:
132                                                 GUI.DeleteView(room.name)
133                                                 return
134                                         case now = <-tick:
135                                                 v, err = GUI.View(room.name)
136                                                 if err == nil {
137                                                         v.Clear()
138                                                         v.Write([]byte(strings.Join(room.Stats(now), "\n")))
139                                                 }
140                                         }
141                                 }
142                         }()
143                 }
144                 RoomsM.Unlock()
145                 if room.key != key {
146                         logger.Error("wrong password")
147                         buf, _, _, err = hs.WriteMessage(nil, []byte("wrong password"))
148                         if err != nil {
149                                 log.Fatal(err)
150                         }
151                         vors.PktWrite(conn, buf)
152                         return
153                 }
154         }
155         peer.room = room
156
157         for _, p := range room.peers {
158                 if p.name != peer.name {
159                         continue
160                 }
161                 logger.Error("name already taken")
162                 buf, _, _, err = hs.WriteMessage(nil, []byte("name already taken"))
163                 if err != nil {
164                         log.Fatal(err)
165                 }
166                 vors.PktWrite(conn, buf)
167                 return
168         }
169
170         {
171                 var i byte
172                 var ok bool
173                 var found bool
174                 PeersM.Lock()
175                 for i = 0; i <= (1<<8)-1; i++ {
176                         if _, ok = Peers[i]; !ok {
177                                 peer.sid = i
178                                 found = true
179                                 break
180                         }
181                 }
182                 if found {
183                         Peers[peer.sid] = peer
184                         go peer.Tx()
185                 }
186                 PeersM.Unlock()
187                 if !found {
188                         buf, _, _, err = hs.WriteMessage(nil, []byte("too many users"))
189                         if err != nil {
190                                 log.Fatal(err)
191                         }
192                         vors.PktWrite(conn, buf)
193                         return
194                 }
195         }
196         logger = logger.With("sid", peer.sid)
197         room.peers[peer.sid] = peer
198         logger.Info("logged in")
199
200         defer func() {
201                 logger.Info("removing")
202                 PeersM.Lock()
203                 delete(Peers, peer.sid)
204                 delete(room.peers, peer.sid)
205                 PeersM.Unlock()
206                 s := []byte(fmt.Sprintf("%s %d", vors.CmdDel, peer.sid))
207                 for _, p := range room.peers {
208                         go func(tx chan []byte) { tx <- s }(p.tx)
209                 }
210         }()
211
212         {
213                 var cookie vors.Cookie
214                 if _, err = io.ReadFull(rand.Reader, cookie[:]); err != nil {
215                         log.Fatalln("cookie:", err)
216                 }
217                 gotCookie := make(chan *net.UDPAddr)
218                 Cookies[cookie] = gotCookie
219
220                 var txCS, rxCS *noise.CipherState
221                 buf, txCS, rxCS, err := hs.WriteMessage(nil,
222                         []byte(fmt.Sprintf("OK %s", hex.EncodeToString(cookie[:]))))
223                 if err != nil {
224                         log.Fatalln("hs.WriteMessage:", err)
225                 }
226                 if err = vors.PktWrite(conn, buf); err != nil {
227                         logger.Error("handshake write", "err", err)
228                         delete(Cookies, cookie)
229                         return
230                 }
231                 peer.rxCS, peer.txCS = txCS, rxCS
232
233                 timeout := time.NewTimer(vors.PingTime)
234                 select {
235                 case peer.addr = <-gotCookie:
236                 case <-timeout.C:
237                         logger.Error("cookie timeout")
238                         delete(Cookies, cookie)
239                         return
240                 }
241                 delete(Cookies, cookie)
242                 logger.Info("got cookie", "addr", peer.addr)
243                 if !timeout.Stop() {
244                         <-timeout.C
245                 }
246         }
247         go peer.Rx()
248         peer.tx <- []byte(fmt.Sprintf("SID %d", peer.sid))
249
250         for _, p := range room.peers {
251                 if p.sid == peer.sid {
252                         continue
253                 }
254                 peer.tx <- []byte(fmt.Sprintf("%s %d %s %s",
255                         vors.CmdAdd, p.sid, p.name, hex.EncodeToString(p.key)))
256         }
257
258         {
259                 h, err := blake2s.New256(hs.ChannelBinding())
260                 if err != nil {
261                         log.Fatalln(err)
262                 }
263                 h.Write([]byte(vors.NoisePrologue))
264                 peer.key = h.Sum(nil)
265         }
266
267         {
268                 s := []byte(fmt.Sprintf("%s %d %s %s",
269                         vors.CmdAdd, peer.sid, peer.name, hex.EncodeToString(peer.key)))
270                 for _, p := range room.peers {
271                         if p.sid != peer.sid {
272                                 p.tx <- s
273                         }
274                 }
275         }
276
277         seen := time.Now()
278         go func(seen *time.Time) {
279                 ticker := time.Tick(vors.PingTime)
280                 var now time.Time
281                 for {
282                         select {
283                         case now = <-ticker:
284                                 if seen.Add(2 * vors.PingTime).Before(now) {
285                                         logger.Error("timeout:", "seen", seen)
286                                         peer.Close()
287                                         return
288                                 }
289                         case <-peer.alive:
290                                 return
291                         }
292                 }
293         }(&seen)
294
295         for buf := range peer.rx {
296                 if string(buf) == vors.CmdPing {
297                         seen = time.Now()
298                         peer.tx <- []byte(vors.CmdPong)
299                 }
300         }
301 }
302
303 func main() {
304         bind := flag.String("bind", "[::1]:"+strconv.Itoa(vors.DefaultPort),
305                 "Host:TCP/UDP port to listen on")
306         kpFile := flag.String("key", "key", "Path to keypair file")
307         flag.Parse()
308         log.SetFlags(log.Lmicroseconds | log.Lshortfile)
309
310         {
311                 data, err := os.ReadFile(*kpFile)
312                 if err != nil {
313                         log.Fatal(err)
314                 }
315                 Prv, Pub = data[:len(data)/2], data[len(data)/2:]
316         }
317
318         lnTCP, err := net.ListenTCP("tcp",
319                 net.TCPAddrFromAddrPort(netip.MustParseAddrPort(*bind)))
320         if err != nil {
321                 log.Fatal(err)
322         }
323         lnUDP, err := net.ListenUDP("udp",
324                 net.UDPAddrFromAddrPort(netip.MustParseAddrPort(*bind)))
325         if err != nil {
326                 log.Fatal(err)
327         }
328
329         LoggerReady := make(chan struct{})
330         if *NoGUI {
331                 close(GUIReadyC)
332                 slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil)))
333                 close(LoggerReady)
334         } else {
335                 GUI, err = gocui.NewGui(gocui.OutputNormal)
336                 if err != nil {
337                         log.Fatal(err)
338                 }
339                 defer GUI.Close()
340                 GUI.SetManagerFunc(guiLayout)
341                 if err := GUI.SetKeybinding("", gocui.KeyF10, gocui.ModNone, guiQuit); err != nil {
342                         log.Fatal(err)
343                 }
344
345                 go func() {
346                         <-GUIReadyC
347                         v, err := GUI.View("logs")
348                         if err != nil {
349                                 log.Fatal(err)
350                         }
351                         slog.SetDefault(slog.New(slog.NewTextHandler(v, nil)))
352                         close(LoggerReady)
353                         for {
354                                 time.Sleep(vors.ScreenRefresh)
355                                 GUI.Update(func(gui *gocui.Gui) error {
356                                         return nil
357                                 })
358                         }
359                 }()
360         }
361
362         go func() {
363                 <-LoggerReady
364                 buf := make([]byte, 2*vors.FrameLen)
365                 var n int
366                 var from *net.UDPAddr
367                 var err error
368                 var sid byte
369                 var peer *Peer
370                 var ciph *chacha20.Cipher
371                 var macKey [32]byte
372                 var mac *poly1305.MAC
373                 tag := make([]byte, poly1305.TagSize)
374                 nonce := make([]byte, 12)
375                 for {
376                         n, from, err = lnUDP.ReadFromUDP(buf)
377                         if err != nil {
378                                 log.Fatalln("recvfrom:", err)
379                         }
380
381                         if n == vors.CookieLen {
382                                 var cookie vors.Cookie
383                                 copy(cookie[:], buf)
384                                 if c, ok := Cookies[cookie]; ok {
385                                         c <- from
386                                         close(c)
387                                 } else {
388                                         slog.Info("unknown cookie", "cookie", cookie)
389                                 }
390                                 continue
391                         }
392
393                         sid = buf[0]
394                         peer = Peers[sid]
395                         if peer == nil {
396                                 slog.Info("unknown", "sid", sid, "from", from)
397                                 continue
398                         }
399
400                         if from.Port != peer.addr.Port || !from.IP.Equal(peer.addr.IP) {
401                                 slog.Info("wrong addr",
402                                         "peer", peer.name,
403                                         "our", peer.addr,
404                                         "got", from)
405                                 continue
406                         }
407
408                         peer.stats.pktsRx++
409                         peer.stats.bytesRx += uint64(n)
410                         if n == 1 {
411                                 continue
412                         }
413                         if n <= 4+vors.TagLen {
414                                 peer.stats.bads++
415                                 continue
416                         }
417
418                         copy(nonce[len(nonce)-4:], buf)
419                         ciph, err = chacha20.NewUnauthenticatedCipher(peer.key, nonce)
420                         if err != nil {
421                                 log.Fatal(err)
422                         }
423                         clear(macKey[:])
424                         ciph.XORKeyStream(macKey[:], macKey[:])
425                         ciph.SetCounter(1)
426                         mac = poly1305.New(&macKey)
427                         if _, err = mac.Write(buf[4 : n-vors.TagLen]); err != nil {
428                                 log.Fatal(err)
429                         }
430                         mac.Sum(tag[:0])
431                         if subtle.ConstantTimeCompare(
432                                 tag[:vors.TagLen],
433                                 buf[n-vors.TagLen:n],
434                         ) != 1 {
435                                 peer.stats.bads++
436                                 continue
437                         }
438
439                         peer.stats.last = time.Now()
440                         for _, p := range peer.room.peers {
441                                 if p.sid == sid {
442                                         continue
443                                 }
444                                 p.stats.pktsTx++
445                                 p.stats.bytesTx += uint64(n)
446                                 if _, err = lnUDP.WriteToUDP(buf[:n], p.addr); err != nil {
447                                         slog.Warn("sendto", "peer", peer.name, "err", err)
448                                 }
449                         }
450                 }
451         }()
452
453         go func() {
454                 <-LoggerReady
455                 slog.Info("listening", "bind", *bind, "pub", hex.EncodeToString(Pub))
456                 for {
457                         conn, err := lnTCP.AcceptTCP()
458                         if err != nil {
459                                 log.Fatalln("accept:", err)
460                         }
461                         go newPeer(conn)
462                 }
463         }()
464
465         if *NoGUI {
466                 <-make(chan struct{})
467         }
468         err = GUI.MainLoop()
469         if err != nil && err != gocui.ErrQuit {
470                 log.Fatal(err)
471         }
472 }