]> Sergey Matveev's repositories - vors.git/blob - cmd/server/main.go
7f475f0e9b6bc201b780a25dce35c32b0cc42a78fd2872d040d09a488fc2bd19
[vors.git] / cmd / server / main.go
1 // VoRS -- Vo(IP) Really Simple
2 // Copyright (C) 2024 Sergey Matveev <stargrave@stargrave.org>
3 //
4 // This program is free software: you can redistribute it and/or modify
5 // it under the terms of the GNU Affero General Public License as
6 // published by the Free Software Foundation, version 3 of the License.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU Affero General Public License for more details.
12 //
13 // You should have received a copy of the GNU Affero General Public License
14 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
15
16 package main
17
18 import (
19         "crypto/rand"
20         "crypto/subtle"
21         "crypto/tls"
22         "encoding/base64"
23         "encoding/hex"
24         "flag"
25         "fmt"
26         "io"
27         "log"
28         "log/slog"
29         "net"
30         "net/netip"
31         "os"
32         "strconv"
33         "strings"
34         "time"
35
36         "github.com/dchest/siphash"
37         "github.com/flynn/noise"
38         "github.com/jroimartin/gocui"
39         vors "go.stargrave.org/vors/internal"
40         "golang.org/x/crypto/blake2s"
41         "golang.org/x/crypto/chacha20"
42 )
43
44 var (
45         TLSCfg = &tls.Config{
46                 MinVersion:       tls.VersionTLS13,
47                 CurvePreferences: []tls.CurveID{tls.X25519},
48         }
49         Prv, Pub []byte
50         Cookies  = map[vors.Cookie]chan *net.UDPAddr{}
51 )
52
53 func newPeer(conn *net.TCPConn) {
54         logger := slog.With("remote", conn.RemoteAddr().String())
55         logger.Info("connected")
56         defer conn.Close()
57         err := conn.SetNoDelay(true)
58         if err != nil {
59                 log.Fatalln("nodelay:", err)
60         }
61         buf := make([]byte, len(vors.NoisePrologue))
62
63         if _, err = io.ReadFull(conn, buf); err != nil {
64                 logger.Error("handshake: read prologue", "err", err)
65                 return
66         }
67         if string(buf) != vors.NoisePrologue {
68                 logger.Error("handshake: wrong prologue", "err", err)
69                 return
70         }
71
72         hs, err := noise.NewHandshakeState(noise.Config{
73                 CipherSuite:   vors.NoiseCipherSuite,
74                 Pattern:       noise.HandshakeNK,
75                 Initiator:     false,
76                 StaticKeypair: noise.DHKey{Private: Prv, Public: Pub},
77                 Prologue:      []byte(vors.NoisePrologue),
78         })
79         if err != nil {
80                 log.Fatalln("noise.NewHandshakeState:", err)
81         }
82         buf, err = vors.PktRead(conn)
83         if err != nil {
84                 logger.Error("read handshake", "err", err)
85                 return
86         }
87         peer := &Peer{
88                 logger: logger,
89                 conn:   conn,
90                 stats:  &Stats{},
91                 rx:     make(chan []byte),
92                 tx:     make(chan []byte, 10),
93                 alive:  make(chan struct{}),
94         }
95         var room *Room
96         {
97                 nameAndRoom, _, _, err := hs.ReadMessage(nil, buf)
98                 if err != nil {
99                         logger.Error("handshake: decrypt", "err", err)
100                         return
101                 }
102                 cols := strings.SplitN(string(nameAndRoom), " ", 3)
103                 roomName := "/"
104                 if len(cols) > 1 {
105                         roomName = cols[1]
106                 }
107                 var key string
108                 if len(cols) > 2 {
109                         key = cols[2]
110                 }
111                 peer.name = string(cols[0])
112                 logger = logger.With("name", peer.name, "room", roomName)
113                 RoomsM.Lock()
114                 room = Rooms[roomName]
115                 if room == nil {
116                         room = &Room{
117                                 name:  roomName,
118                                 key:   key,
119                                 peers: make(map[byte]*Peer),
120                                 alive: make(chan struct{}),
121                         }
122                         Rooms[roomName] = room
123                         go func() {
124                                 if *NoGUI {
125                                         return
126                                 }
127                                 tick := time.Tick(vors.ScreenRefresh)
128                                 var now time.Time
129                                 var v *gocui.View
130                                 for {
131                                         select {
132                                         case <-room.alive:
133                                                 GUI.DeleteView(room.name)
134                                                 return
135                                         case now = <-tick:
136                                                 v, err = GUI.View(room.name)
137                                                 if err == nil {
138                                                         v.Clear()
139                                                         v.Write([]byte(strings.Join(room.Stats(now), "\n")))
140                                                 }
141                                         }
142                                 }
143                         }()
144                 }
145                 RoomsM.Unlock()
146                 if room.key != key {
147                         logger.Error("wrong password")
148                         buf, _, _, err = hs.WriteMessage(nil, []byte("wrong password"))
149                         if err != nil {
150                                 log.Fatal(err)
151                         }
152                         vors.PktWrite(conn, buf)
153                         return
154                 }
155         }
156         peer.room = room
157
158         for _, p := range room.peers {
159                 if p.name != peer.name {
160                         continue
161                 }
162                 logger.Error("name already taken")
163                 buf, _, _, err = hs.WriteMessage(nil, []byte("name already taken"))
164                 if err != nil {
165                         log.Fatal(err)
166                 }
167                 vors.PktWrite(conn, buf)
168                 return
169         }
170
171         {
172                 var i byte
173                 var ok bool
174                 var found bool
175                 PeersM.Lock()
176                 for i = 0; i <= (1<<8)-1; i++ {
177                         if _, ok = Peers[i]; !ok {
178                                 peer.sid = i
179                                 found = true
180                                 break
181                         }
182                 }
183                 if found {
184                         Peers[peer.sid] = peer
185                         go peer.Tx()
186                 }
187                 PeersM.Unlock()
188                 if !found {
189                         buf, _, _, err = hs.WriteMessage(nil, []byte("too many users"))
190                         if err != nil {
191                                 log.Fatal(err)
192                         }
193                         vors.PktWrite(conn, buf)
194                         return
195                 }
196         }
197         logger = logger.With("sid", peer.sid)
198         room.peers[peer.sid] = peer
199         logger.Info("logged in")
200
201         defer func() {
202                 logger.Info("removing")
203                 PeersM.Lock()
204                 delete(Peers, peer.sid)
205                 delete(room.peers, peer.sid)
206                 PeersM.Unlock()
207                 s := []byte(fmt.Sprintf("%s %d", vors.CmdDel, peer.sid))
208                 for _, p := range room.peers {
209                         go func(tx chan []byte) { tx <- s }(p.tx)
210                 }
211         }()
212
213         {
214                 var cookie vors.Cookie
215                 if _, err = io.ReadFull(rand.Reader, cookie[:]); err != nil {
216                         log.Fatalln("cookie:", err)
217                 }
218                 gotCookie := make(chan *net.UDPAddr)
219                 Cookies[cookie] = gotCookie
220
221                 var txCS, rxCS *noise.CipherState
222                 buf, txCS, rxCS, err := hs.WriteMessage(nil,
223                         []byte(fmt.Sprintf("OK %s", hex.EncodeToString(cookie[:]))))
224                 if err != nil {
225                         log.Fatalln("hs.WriteMessage:", err)
226                 }
227                 if err = vors.PktWrite(conn, buf); err != nil {
228                         logger.Error("handshake write", "err", err)
229                         delete(Cookies, cookie)
230                         return
231                 }
232                 peer.rxCS, peer.txCS = txCS, rxCS
233
234                 timeout := time.NewTimer(vors.PingTime)
235                 select {
236                 case peer.addr = <-gotCookie:
237                 case <-timeout.C:
238                         logger.Error("cookie timeout")
239                         delete(Cookies, cookie)
240                         return
241                 }
242                 delete(Cookies, cookie)
243                 logger.Info("got cookie", "addr", peer.addr)
244                 if !timeout.Stop() {
245                         <-timeout.C
246                 }
247         }
248         go peer.Rx()
249         peer.tx <- []byte(fmt.Sprintf("SID %d", peer.sid))
250
251         for _, p := range room.peers {
252                 if p.sid == peer.sid {
253                         continue
254                 }
255                 peer.tx <- []byte(fmt.Sprintf("%s %d %s %s",
256                         vors.CmdAdd, p.sid, p.name, hex.EncodeToString(p.key)))
257         }
258
259         {
260                 xof, err := blake2s.NewXOF(chacha20.KeySize+16, nil)
261                 if err != nil {
262                         log.Fatalln(err)
263                 }
264                 xof.Write([]byte(vors.NoisePrologue))
265                 xof.Write(hs.ChannelBinding())
266                 peer.key = make([]byte, chacha20.KeySize+16)
267                 if _, err = io.ReadFull(xof, peer.key); err != nil {
268                         log.Fatalln(err)
269                 }
270                 peer.mac = siphash.New(peer.key[chacha20.KeySize:])
271         }
272
273         {
274                 s := []byte(fmt.Sprintf("%s %d %s %s",
275                         vors.CmdAdd, peer.sid, peer.name, hex.EncodeToString(peer.key)))
276                 for _, p := range room.peers {
277                         if p.sid != peer.sid {
278                                 p.tx <- s
279                         }
280                 }
281         }
282
283         seen := time.Now()
284         go func(seen *time.Time) {
285                 ticker := time.Tick(vors.PingTime)
286                 var now time.Time
287                 for {
288                         select {
289                         case now = <-ticker:
290                                 if seen.Add(2 * vors.PingTime).Before(now) {
291                                         logger.Error("timeout:", "seen", seen)
292                                         peer.Close()
293                                         return
294                                 }
295                         case <-peer.alive:
296                                 return
297                         }
298                 }
299         }(&seen)
300
301         for buf := range peer.rx {
302                 if string(buf) == vors.CmdPing {
303                         seen = time.Now()
304                         peer.tx <- []byte(vors.CmdPong)
305                 }
306         }
307 }
308
309 func main() {
310         bind := flag.String("bind", "[::1]:"+strconv.Itoa(vors.DefaultPort),
311                 "host:TCP/UDP port to listen on")
312         kpFile := flag.String("key", "key", "path to keypair file")
313         prefer4 := flag.Bool("4", false,
314                 "Prefer obsolete legacy IPv4 address during name resolution")
315         version := flag.Bool("version", false, "print version")
316         warranty := flag.Bool("warranty", false, "print warranty information")
317         flag.Parse()
318         log.SetFlags(log.Lmicroseconds | log.Lshortfile)
319
320         if *warranty {
321                 fmt.Println(vors.Warranty)
322                 return
323         }
324         if *version {
325                 fmt.Println(vors.GetVersion())
326                 return
327         }
328
329         {
330                 data, err := os.ReadFile(*kpFile)
331                 if err != nil {
332                         log.Fatal(err)
333                 }
334                 Prv, Pub = data[:len(data)/2], data[len(data)/2:]
335         }
336
337         vors.PreferIPv4 = *prefer4
338         lnTCP, err := net.ListenTCP("tcp",
339                 net.TCPAddrFromAddrPort(netip.MustParseAddrPort(*bind)))
340         if err != nil {
341                 log.Fatal(err)
342         }
343         lnUDP, err := net.ListenUDP("udp",
344                 net.UDPAddrFromAddrPort(netip.MustParseAddrPort(*bind)))
345         if err != nil {
346                 log.Fatal(err)
347         }
348
349         LoggerReady := make(chan struct{})
350         if *NoGUI {
351                 close(GUIReadyC)
352                 slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil)))
353                 close(LoggerReady)
354         } else {
355                 GUI, err = gocui.NewGui(gocui.OutputNormal)
356                 if err != nil {
357                         log.Fatal(err)
358                 }
359                 defer GUI.Close()
360                 GUI.SetManagerFunc(guiLayout)
361                 if err := GUI.SetKeybinding("", 'q', gocui.ModNone, guiQuit); err != nil {
362                         log.Fatal(err)
363                 }
364
365                 go func() {
366                         <-GUIReadyC
367                         v, err := GUI.View("logs")
368                         if err != nil {
369                                 log.Fatal(err)
370                         }
371                         slog.SetDefault(slog.New(slog.NewTextHandler(v, nil)))
372                         close(LoggerReady)
373                         for {
374                                 time.Sleep(vors.ScreenRefresh)
375                                 GUI.Update(func(gui *gocui.Gui) error {
376                                         return nil
377                                 })
378                         }
379                 }()
380         }
381
382         go func() {
383                 <-LoggerReady
384                 buf := make([]byte, 2*vors.FrameLen)
385                 var n int
386                 var from *net.UDPAddr
387                 var err error
388                 var sid byte
389                 var peer *Peer
390                 tag := make([]byte, siphash.Size)
391                 for {
392                         n, from, err = lnUDP.ReadFromUDP(buf)
393                         if err != nil {
394                                 log.Fatalln("recvfrom:", err)
395                         }
396
397                         if n == vors.CookieLen {
398                                 var cookie vors.Cookie
399                                 copy(cookie[:], buf)
400                                 if c, ok := Cookies[cookie]; ok {
401                                         c <- from
402                                         close(c)
403                                         continue
404                                 }
405                         }
406
407                         sid = buf[0]
408                         peer = Peers[sid]
409                         if peer == nil {
410                                 slog.Info("unknown", "sid", sid, "from", from)
411                                 continue
412                         }
413
414                         if peer.addr == nil ||
415                                 from.Port != peer.addr.Port ||
416                                 !from.IP.Equal(peer.addr.IP) {
417                                 slog.Info("wrong addr",
418                                         "peer", peer.name,
419                                         "our", peer.addr,
420                                         "got", from)
421                                 continue
422                         }
423
424                         peer.stats.pktsRx++
425                         peer.stats.bytesRx += vors.IPHdrLen(from.IP) + 8 + uint64(n)
426                         if n == 1 {
427                                 continue
428                         }
429                         if n <= 4+siphash.Size {
430                                 peer.stats.bads++
431                                 continue
432                         }
433
434                         peer.mac.Reset()
435                         if _, err = peer.mac.Write(buf[:n-siphash.Size]); err != nil {
436                                 log.Fatal(err)
437                         }
438                         peer.mac.Sum(tag[:0])
439                         if subtle.ConstantTimeCompare(
440                                 tag[:siphash.Size],
441                                 buf[n-siphash.Size:n],
442                         ) != 1 {
443                                 peer.stats.bads++
444                                 continue
445                         }
446
447                         peer.stats.last = time.Now()
448                         for _, p := range peer.room.peers {
449                                 if p.sid == sid || p.addr == nil {
450                                         continue
451                                 }
452                                 p.stats.pktsTx++
453                                 p.stats.bytesTx += vors.IPHdrLen(p.addr.IP) + 8 + uint64(n)
454                                 if _, err = lnUDP.WriteToUDP(buf[:n], p.addr); err != nil {
455                                         slog.Warn("sendto", "peer", peer.name, "err", err)
456                                 }
457                         }
458                 }
459         }()
460
461         go func() {
462                 <-LoggerReady
463                 slog.Info("listening",
464                         "bind", *bind,
465                         "pub", base64.RawURLEncoding.EncodeToString(Pub))
466                 for {
467                         conn, err := lnTCP.AcceptTCP()
468                         if err != nil {
469                                 log.Fatalln("accept:", err)
470                         }
471                         go newPeer(conn)
472                 }
473         }()
474
475         if *NoGUI {
476                 <-make(chan struct{})
477         }
478         err = GUI.MainLoop()
479         if err != nil && err != gocui.ErrQuit {
480                 log.Fatal(err)
481         }
482 }