Sergey Matveev's repositories - vors.git/blob - cmd/server/main.go
SipHash24 for short messages is much faster and secure enough
[vors.git] / cmd / server / main.go
1 // VoRS -- Vo(IP) Really Simple
2 // Copyright (C) 2024 Sergey Matveev <stargrave@stargrave.org>
3 //
4 // This program is free software: you can redistribute it and/or modify
5 // it under the terms of the GNU Affero General Public License as
6 // published by the Free Software Foundation, version 3 of the License.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU Affero General Public License for more details.
12 //
13 // You should have received a copy of the GNU Affero General Public License
14 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
15
16 package main
17
import (
	"crypto/rand"
	"crypto/subtle"
	"crypto/tls"
	"encoding/base64"
	"encoding/hex"
	"flag"
	"fmt"
	"io"
	"log"
	"log/slog"
	"net"
	"net/netip"
	"os"
	"strconv"
	"strings"
	"sync/atomic"
	"time"

	"github.com/dchest/siphash"
	"github.com/flynn/noise"
	"github.com/jroimartin/gocui"
	vors "go.stargrave.org/vors/internal"
	"golang.org/x/crypto/blake2s"
)
42
43 var (
44         TLSCfg = &tls.Config{
45                 MinVersion:       tls.VersionTLS13,
46                 CurvePreferences: []tls.CurveID{tls.X25519},
47         }
48         Prv, Pub []byte
49         Cookies  = map[vors.Cookie]chan *net.UDPAddr{}
50 )
51
52 func newPeer(conn *net.TCPConn) {
53         logger := slog.With("remote", conn.RemoteAddr().String())
54         logger.Info("connected")
55         defer conn.Close()
56         err := conn.SetNoDelay(true)
57         if err != nil {
58                 log.Fatalln("nodelay:", err)
59         }
60         buf := make([]byte, len(vors.NoisePrologue))
61
62         if _, err = io.ReadFull(conn, buf); err != nil {
63                 logger.Error("handshake: read prologue", "err", err)
64                 return
65         }
66         if string(buf) != vors.NoisePrologue {
67                 logger.Error("handshake: wrong prologue", "err", err)
68                 return
69         }
70
71         hs, err := noise.NewHandshakeState(noise.Config{
72                 CipherSuite:   vors.NoiseCipherSuite,
73                 Pattern:       noise.HandshakeNK,
74                 Initiator:     false,
75                 StaticKeypair: noise.DHKey{Private: Prv, Public: Pub},
76                 Prologue:      []byte(vors.NoisePrologue),
77         })
78         if err != nil {
79                 log.Fatalln("noise.NewHandshakeState:", err)
80         }
81         buf, err = vors.PktRead(conn)
82         if err != nil {
83                 logger.Error("read handshake", "err", err)
84                 return
85         }
86         peer := &Peer{
87                 logger: logger,
88                 conn:   conn,
89                 stats:  &Stats{},
90                 rx:     make(chan []byte),
91                 tx:     make(chan []byte, 10),
92                 alive:  make(chan struct{}),
93         }
94         var room *Room
95         {
96                 nameAndRoom, _, _, err := hs.ReadMessage(nil, buf)
97                 if err != nil {
98                         logger.Error("handshake: decrypt", "err", err)
99                         return
100                 }
101                 cols := strings.SplitN(string(nameAndRoom), " ", 3)
102                 roomName := "/"
103                 if len(cols) > 1 {
104                         roomName = cols[1]
105                 }
106                 var key string
107                 if len(cols) > 2 {
108                         key = cols[2]
109                 }
110                 peer.name = string(cols[0])
111                 logger = logger.With("name", peer.name, "room", roomName)
112                 RoomsM.Lock()
113                 room = Rooms[roomName]
114                 if room == nil {
115                         room = &Room{
116                                 name:  roomName,
117                                 key:   key,
118                                 peers: make(map[byte]*Peer),
119                                 alive: make(chan struct{}),
120                         }
121                         Rooms[roomName] = room
122                         go func() {
123                                 if *NoGUI {
124                                         return
125                                 }
126                                 tick := time.Tick(vors.ScreenRefresh)
127                                 var now time.Time
128                                 var v *gocui.View
129                                 for {
130                                         select {
131                                         case <-room.alive:
132                                                 GUI.DeleteView(room.name)
133                                                 return
134                                         case now = <-tick:
135                                                 v, err = GUI.View(room.name)
136                                                 if err == nil {
137                                                         v.Clear()
138                                                         v.Write([]byte(strings.Join(room.Stats(now), "\n")))
139                                                 }
140                                         }
141                                 }
142                         }()
143                 }
144                 RoomsM.Unlock()
145                 if room.key != key {
146                         logger.Error("wrong password")
147                         buf, _, _, err = hs.WriteMessage(nil, []byte("wrong password"))
148                         if err != nil {
149                                 log.Fatal(err)
150                         }
151                         vors.PktWrite(conn, buf)
152                         return
153                 }
154         }
155         peer.room = room
156
157         for _, p := range room.peers {
158                 if p.name != peer.name {
159                         continue
160                 }
161                 logger.Error("name already taken")
162                 buf, _, _, err = hs.WriteMessage(nil, []byte("name already taken"))
163                 if err != nil {
164                         log.Fatal(err)
165                 }
166                 vors.PktWrite(conn, buf)
167                 return
168         }
169
170         {
171                 var i byte
172                 var ok bool
173                 var found bool
174                 PeersM.Lock()
175                 for i = 0; i <= (1<<8)-1; i++ {
176                         if _, ok = Peers[i]; !ok {
177                                 peer.sid = i
178                                 found = true
179                                 break
180                         }
181                 }
182                 if found {
183                         Peers[peer.sid] = peer
184                         go peer.Tx()
185                 }
186                 PeersM.Unlock()
187                 if !found {
188                         buf, _, _, err = hs.WriteMessage(nil, []byte("too many users"))
189                         if err != nil {
190                                 log.Fatal(err)
191                         }
192                         vors.PktWrite(conn, buf)
193                         return
194                 }
195         }
196         logger = logger.With("sid", peer.sid)
197         room.peers[peer.sid] = peer
198         logger.Info("logged in")
199
200         defer func() {
201                 logger.Info("removing")
202                 PeersM.Lock()
203                 delete(Peers, peer.sid)
204                 delete(room.peers, peer.sid)
205                 PeersM.Unlock()
206                 s := []byte(fmt.Sprintf("%s %d", vors.CmdDel, peer.sid))
207                 for _, p := range room.peers {
208                         go func(tx chan []byte) { tx <- s }(p.tx)
209                 }
210         }()
211
212         {
213                 var cookie vors.Cookie
214                 if _, err = io.ReadFull(rand.Reader, cookie[:]); err != nil {
215                         log.Fatalln("cookie:", err)
216                 }
217                 gotCookie := make(chan *net.UDPAddr)
218                 Cookies[cookie] = gotCookie
219
220                 var txCS, rxCS *noise.CipherState
221                 buf, txCS, rxCS, err := hs.WriteMessage(nil,
222                         []byte(fmt.Sprintf("OK %s", hex.EncodeToString(cookie[:]))))
223                 if err != nil {
224                         log.Fatalln("hs.WriteMessage:", err)
225                 }
226                 if err = vors.PktWrite(conn, buf); err != nil {
227                         logger.Error("handshake write", "err", err)
228                         delete(Cookies, cookie)
229                         return
230                 }
231                 peer.rxCS, peer.txCS = txCS, rxCS
232
233                 timeout := time.NewTimer(vors.PingTime)
234                 select {
235                 case peer.addr = <-gotCookie:
236                 case <-timeout.C:
237                         logger.Error("cookie timeout")
238                         delete(Cookies, cookie)
239                         return
240                 }
241                 delete(Cookies, cookie)
242                 logger.Info("got cookie", "addr", peer.addr)
243                 if !timeout.Stop() {
244                         <-timeout.C
245                 }
246         }
247         go peer.Rx()
248         peer.tx <- []byte(fmt.Sprintf("SID %d", peer.sid))
249
250         for _, p := range room.peers {
251                 if p.sid == peer.sid {
252                         continue
253                 }
254                 peer.tx <- []byte(fmt.Sprintf("%s %d %s %s",
255                         vors.CmdAdd, p.sid, p.name, hex.EncodeToString(p.key)))
256         }
257
258         {
259                 xof, err := blake2s.NewXOF(32+16, nil)
260                 if err != nil {
261                         log.Fatalln(err)
262                 }
263                 xof.Write([]byte(vors.NoisePrologue))
264                 xof.Write(hs.ChannelBinding())
265                 peer.key = make([]byte, 32+16)
266                 if _, err = io.ReadFull(xof, peer.key); err != nil {
267                         log.Fatalln(err)
268                 }
269                 peer.mac = siphash.New(peer.key[32:])
270         }
271
272         {
273                 s := []byte(fmt.Sprintf("%s %d %s %s",
274                         vors.CmdAdd, peer.sid, peer.name, hex.EncodeToString(peer.key)))
275                 for _, p := range room.peers {
276                         if p.sid != peer.sid {
277                                 p.tx <- s
278                         }
279                 }
280         }
281
282         seen := time.Now()
283         go func(seen *time.Time) {
284                 ticker := time.Tick(vors.PingTime)
285                 var now time.Time
286                 for {
287                         select {
288                         case now = <-ticker:
289                                 if seen.Add(2 * vors.PingTime).Before(now) {
290                                         logger.Error("timeout:", "seen", seen)
291                                         peer.Close()
292                                         return
293                                 }
294                         case <-peer.alive:
295                                 return
296                         }
297                 }
298         }(&seen)
299
300         for buf := range peer.rx {
301                 if string(buf) == vors.CmdPing {
302                         seen = time.Now()
303                         peer.tx <- []byte(vors.CmdPong)
304                 }
305         }
306 }
307
308 func main() {
309         bind := flag.String("bind", "[::1]:"+strconv.Itoa(vors.DefaultPort),
310                 "host:TCP/UDP port to listen on")
311         kpFile := flag.String("key", "key", "path to keypair file")
312         version := flag.Bool("version", false, "print version")
313         warranty := flag.Bool("warranty", false, "print warranty information")
314         flag.Parse()
315         log.SetFlags(log.Lmicroseconds | log.Lshortfile)
316
317         if *warranty {
318                 fmt.Println(vors.Warranty)
319                 return
320         }
321         if *version {
322                 fmt.Println(vors.GetVersion())
323                 return
324         }
325
326         {
327                 data, err := os.ReadFile(*kpFile)
328                 if err != nil {
329                         log.Fatal(err)
330                 }
331                 Prv, Pub = data[:len(data)/2], data[len(data)/2:]
332         }
333
334         lnTCP, err := net.ListenTCP("tcp",
335                 net.TCPAddrFromAddrPort(netip.MustParseAddrPort(*bind)))
336         if err != nil {
337                 log.Fatal(err)
338         }
339         lnUDP, err := net.ListenUDP("udp",
340                 net.UDPAddrFromAddrPort(netip.MustParseAddrPort(*bind)))
341         if err != nil {
342                 log.Fatal(err)
343         }
344
345         LoggerReady := make(chan struct{})
346         if *NoGUI {
347                 close(GUIReadyC)
348                 slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil)))
349                 close(LoggerReady)
350         } else {
351                 GUI, err = gocui.NewGui(gocui.OutputNormal)
352                 if err != nil {
353                         log.Fatal(err)
354                 }
355                 defer GUI.Close()
356                 GUI.SetManagerFunc(guiLayout)
357                 if err := GUI.SetKeybinding("", gocui.KeyF10, gocui.ModNone, guiQuit); err != nil {
358                         log.Fatal(err)
359                 }
360
361                 go func() {
362                         <-GUIReadyC
363                         v, err := GUI.View("logs")
364                         if err != nil {
365                                 log.Fatal(err)
366                         }
367                         slog.SetDefault(slog.New(slog.NewTextHandler(v, nil)))
368                         close(LoggerReady)
369                         for {
370                                 time.Sleep(vors.ScreenRefresh)
371                                 GUI.Update(func(gui *gocui.Gui) error {
372                                         return nil
373                                 })
374                         }
375                 }()
376         }
377
378         go func() {
379                 <-LoggerReady
380                 buf := make([]byte, 2*vors.FrameLen)
381                 var n int
382                 var from *net.UDPAddr
383                 var err error
384                 var sid byte
385                 var peer *Peer
386                 tag := make([]byte, siphash.Size)
387                 for {
388                         n, from, err = lnUDP.ReadFromUDP(buf)
389                         if err != nil {
390                                 log.Fatalln("recvfrom:", err)
391                         }
392
393                         if n == vors.CookieLen {
394                                 var cookie vors.Cookie
395                                 copy(cookie[:], buf)
396                                 if c, ok := Cookies[cookie]; ok {
397                                         c <- from
398                                         close(c)
399                                 } else {
400                                         slog.Info("unknown cookie", "cookie", cookie)
401                                 }
402                                 continue
403                         }
404
405                         sid = buf[0]
406                         peer = Peers[sid]
407                         if peer == nil {
408                                 slog.Info("unknown", "sid", sid, "from", from)
409                                 continue
410                         }
411
412                         if from.Port != peer.addr.Port || !from.IP.Equal(peer.addr.IP) {
413                                 slog.Info("wrong addr",
414                                         "peer", peer.name,
415                                         "our", peer.addr,
416                                         "got", from)
417                                 continue
418                         }
419
420                         peer.stats.pktsRx++
421                         peer.stats.bytesRx += vors.IPHdrLen(from.IP) + 8 + uint64(n)
422                         if n == 1 {
423                                 continue
424                         }
425                         if n <= 4+siphash.Size {
426                                 peer.stats.bads++
427                                 continue
428                         }
429
430                         peer.mac.Reset()
431                         if _, err = peer.mac.Write(buf[: n-siphash.Size]); err != nil {
432                                 log.Fatal(err)
433                         }
434                         peer.mac.Sum(tag[:0])
435                         if subtle.ConstantTimeCompare(
436                                 tag[:siphash.Size],
437                                 buf[n-siphash.Size:n],
438                         ) != 1 {
439                                 peer.stats.bads++
440                                 continue
441                         }
442
443                         peer.stats.last = time.Now()
444                         for _, p := range peer.room.peers {
445                                 if p.sid == sid {
446                                         continue
447                                 }
448                                 p.stats.pktsTx++
449                                 p.stats.bytesTx += vors.IPHdrLen(p.addr.IP) + 8 + uint64(n)
450                                 if _, err = lnUDP.WriteToUDP(buf[:n], p.addr); err != nil {
451                                         slog.Warn("sendto", "peer", peer.name, "err", err)
452                                 }
453                         }
454                 }
455         }()
456
457         go func() {
458                 <-LoggerReady
459                 slog.Info("listening",
460                         "bind", *bind,
461                         "pub", base64.RawURLEncoding.EncodeToString(Pub))
462                 for {
463                         conn, err := lnTCP.AcceptTCP()
464                         if err != nil {
465                                 log.Fatalln("accept:", err)
466                         }
467                         go newPeer(conn)
468                 }
469         }()
470
471         if *NoGUI {
472                 <-make(chan struct{})
473         }
474         err = GUI.MainLoop()
475         if err != nil && err != gocui.ErrQuit {
476                 log.Fatal(err)
477         }
478 }