Sergey Matveev's repositories - vors.git/blob - cmd/server/main.go
Verify MACs on server side
[vors.git] / cmd / server / main.go
1 // VoRS -- Vo(IP) Really Simple
2 // Copyright (C) 2024 Sergey Matveev <stargrave@stargrave.org>
3 //
4 // This program is free software: you can redistribute it and/or modify
5 // it under the terms of the GNU Affero General Public License as
6 // published by the Free Software Foundation, version 3 of the License.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU General Public License for more details.
12 //
13 // You should have received a copy of the GNU General Public License
14 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
15
16 package main
17
18 import (
19         "bufio"
20         "crypto/rand"
21         "crypto/subtle"
22         "crypto/tls"
23         "encoding/hex"
24         "flag"
25         "fmt"
26         "io"
27         "log"
28         "log/slog"
29         "net"
30         "net/netip"
31         "os"
32         "strconv"
33         "strings"
34         "sync"
35         "time"
36
37         "github.com/dustin/go-humanize"
38         "github.com/jroimartin/gocui"
39         vors "go.stargrave.org/vors/internal"
40         "golang.org/x/crypto/blake2s"
41         "golang.org/x/crypto/chacha20"
42         "golang.org/x/crypto/chacha20poly1305"
43         "golang.org/x/crypto/poly1305"
44 )
45
46 var (
47         TLSCfg = &tls.Config{
48                 MinVersion:       tls.VersionTLS13,
49                 CurvePreferences: []tls.CurveID{tls.X25519},
50         }
51         SPKI   string
52         Passwd = flag.String("passwd", "", "Shared password")
53         Peers  = map[byte]*Peer{}
54         PeersM sync.Mutex
55 )
56
57 type Peer struct {
58         name  string
59         sid   byte
60         addr  *net.UDPAddr
61         conn  net.Conn
62         key   []byte
63         stats *Stats
64 }
65
66 func newPeer(connRaw net.Conn) {
67         logger := slog.With("remote", connRaw.RemoteAddr().String())
68         logger.Info("connected")
69         defer connRaw.Close()
70         if len(Peers) == 256 {
71                 logger.Error("too many peers")
72                 return
73         }
74         conn := tls.Server(connRaw, TLSCfg)
75         err := conn.Handshake()
76         if err != nil {
77                 logger.Error("handshake:", "err", err)
78                 return
79         }
80         defer conn.Close()
81
82         scanner := bufio.NewScanner(conn)
83         peer := Peer{conn: conn, stats: &Stats{dead: make(chan struct{})}}
84         peer.addr = net.UDPAddrFromAddrPort(
85                 netip.MustParseAddrPort(conn.RemoteAddr().String()))
86         if err != nil {
87                 log.Fatal(err)
88         }
89         {
90                 chlng := make([]byte, 16)
91                 if _, err = io.ReadFull(rand.Reader, chlng); err != nil {
92                         log.Fatal(err)
93                         return
94                 }
95                 chlngHex := hex.EncodeToString(chlng)
96                 if _, err = io.Copy(conn, strings.NewReader(chlngHex+"\n")); err != nil {
97                         logger.Error("write challenge:", "err", err)
98                         return
99                 }
100                 h, err := blake2s.New256([]byte(*Passwd))
101                 if err != nil {
102                         log.Fatal(err)
103                 }
104                 h.Write([]byte(chlngHex))
105                 if !scanner.Scan() {
106                         logger.Error("read password:", "err", scanner.Err())
107                         return
108                 }
109                 cols := strings.Fields(scanner.Text())
110                 if len(cols) == 1 {
111                         logger.Error("no name")
112                         io.Copy(conn, strings.NewReader("no name\n"))
113                         return
114                 }
115                 peer.name = cols[1]
116                 if peer.name == "myself" {
117                         logger.Error("reserved name")
118                         io.Copy(conn, strings.NewReader("reserved name\n"))
119                         return
120                 }
121                 logger = logger.With("name", cols[1])
122                 if hex.EncodeToString(h.Sum(nil)) != cols[0] {
123                         logger.Error("wrong password")
124                         io.Copy(conn, strings.NewReader("wrong password\n"))
125                         return
126                 }
127                 for _, p := range Peers {
128                         if p.name == peer.name {
129                                 logger.Error("name already taken")
130                                 io.Copy(conn, strings.NewReader("name already taken\n"))
131                                 return
132                         }
133                 }
134                 var i byte
135                 var ok bool
136                 PeersM.Lock()
137                 for i = 0; i <= 255; i++ {
138                         if _, ok = Peers[i]; !ok {
139                                 peer.sid = i
140                                 break
141                         }
142                 }
143                 Peers[peer.sid] = &peer
144                 PeersM.Unlock()
145                 logger = logger.With("sid", peer.sid)
146                 logger.Info("authenticated")
147                 defer func() {
148                         logger.Info("removing")
149                         PeersM.Lock()
150                         delete(Peers, peer.sid)
151                         close(peer.stats.dead)
152                         s := fmt.Sprintf("%s %d\n", vors.CmdDel, peer.sid)
153                         for _, p := range Peers {
154                                 go io.Copy(p.conn, strings.NewReader(s))
155                         }
156                         PeersM.Unlock()
157                 }()
158                 if _, err = io.Copy(conn, strings.NewReader(
159                         fmt.Sprintf("OK %d\n", peer.sid))); err != nil {
160                         logger.Error("write ok:", "err", err)
161                         return
162                 }
163                 for _, p := range Peers {
164                         if p.sid == peer.sid {
165                                 continue
166                         }
167                         if _, err = io.Copy(conn, strings.NewReader(fmt.Sprintf(
168                                 "%s %d %s %s\n", vors.CmdAdd, p.sid, p.name, hex.EncodeToString(p.key),
169                         ))); err != nil {
170                                 logger.Error("write ADD:", "err", err)
171                                 return
172                         }
173                 }
174                 tlsState := conn.ConnectionState()
175                 peer.key, err = tlsState.ExportKeyingMaterial(
176                         strconv.Itoa(int(peer.sid)), nil, chacha20poly1305.KeySize)
177                 if err != nil {
178                         log.Fatal(err)
179                 }
180                 {
181                         // assume atomic write
182                         s := fmt.Sprintf("%s %d %s %s\n",
183                                 vors.CmdAdd, peer.sid, peer.name, hex.EncodeToString(peer.key))
184                         for _, p := range Peers {
185                                 if p.sid == peer.sid {
186                                         continue
187                                 }
188                                 go io.Copy(p.conn, strings.NewReader(s))
189                         }
190                 }
191                 seen := time.Now()
192                 go func(seen *time.Time) {
193                         for now := range time.Tick(vors.PingTime) {
194                                 if seen.Add(2 * vors.PingTime).Before(now) {
195                                         logger.Error("timeout:", "seen", seen)
196                                         conn.Close()
197                                         break
198                                 }
199                         }
200                 }(&seen)
201                 go func(stats *Stats) {
202                         if *NoGUI {
203                                 return
204                         }
205                         tick := time.Tick(vors.ScreenRefresh)
206                         var now time.Time
207                         var v *gocui.View
208                         for {
209                                 select {
210                                 case <-stats.dead:
211                                         GUI.DeleteView(peer.name)
212                                         return
213                                 case now = <-tick:
214                                         s := fmt.Sprintf(
215                                                 "Rx/Tx: %s / %s  |  %s / %s",
216                                                 humanize.Comma(stats.pktsRx),
217                                                 humanize.Comma(stats.pktsTx),
218                                                 humanize.IBytes(stats.bytesRx),
219                                                 humanize.IBytes(stats.bytesTx),
220                                         )
221                                         if stats.last.Add(time.Second).After(now) {
222                                                 s += "  |  " + vors.CGreen + "TALK" + vors.CReset
223                                         }
224                                         v, err = GUI.View(peer.name)
225                                         if err == nil {
226                                                 v.Clear()
227                                                 v.Write([]byte(s))
228                                         }
229                                 }
230                         }
231                 }(peer.stats)
232                 for scanner.Scan() {
233                         if scanner.Text() == vors.CmdPing {
234                                 if _, err = io.Copy(conn,
235                                         strings.NewReader(vors.CmdPong+"\n")); err != nil {
236                                         logger.Error("write ok:", "err", err)
237                                         return
238                                 }
239                                 seen = time.Now()
240                         }
241                 }
242                 if scanner.Err() != nil {
243                         logger.Error(scanner.Err().Error())
244                 }
245         }
246 }
247
248 func main() {
249         bind := flag.String("bind", "[::1]:12345", "TCP/UDP port to listen on")
250         pemFile := flag.String("pem", "keypair.pem", "PEM with keypair")
251         flag.Parse()
252         log.SetFlags(log.Lmicroseconds | log.Lshortfile)
253         if *Passwd == "" {
254                 log.Fatal("no -passwd specified")
255         }
256         if err := parsePEM(*pemFile); err != nil {
257                 log.Fatal(err)
258         }
259
260         addrTCP, err := net.ResolveTCPAddr("tcp", *bind)
261         if err != nil {
262                 log.Fatal(err)
263         }
264         addrUDP, err := net.ResolveUDPAddr("udp", *bind)
265         if err != nil {
266                 log.Fatal(err)
267         }
268         lnTCP, err := net.ListenTCP("tcp", addrTCP)
269         if err != nil {
270                 log.Fatal(err)
271         }
272         lnUDP, err := net.ListenUDP("udp", addrUDP)
273         if err != nil {
274                 log.Fatal(err)
275         }
276
277         LoggerReady := make(chan struct{})
278         if *NoGUI {
279                 close(GUIReadyC)
280                 slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil)))
281                 close(LoggerReady)
282         } else {
283                 GUI, err = gocui.NewGui(gocui.OutputNormal)
284                 if err != nil {
285                         log.Fatal(err)
286                 }
287                 defer GUI.Close()
288                 GUI.SetManagerFunc(guiLayout)
289                 if err := GUI.SetKeybinding("", gocui.KeyF10, gocui.ModNone, guiQuit); err != nil {
290                         log.Fatal(err)
291                 }
292
293                 go func() {
294                         <-GUIReadyC
295                         v, err := GUI.View("logs")
296                         if err != nil {
297                                 log.Fatal(err)
298                         }
299                         slog.SetDefault(slog.New(slog.NewTextHandler(v, nil)))
300                         close(LoggerReady)
301                         for {
302                                 time.Sleep(vors.ScreenRefresh)
303                                 GUI.Update(func(gui *gocui.Gui) error {
304                                         return nil
305                                 })
306                         }
307                 }()
308         }
309
310         go func() {
311                 <-LoggerReady
312                 buf := make([]byte, 2*vors.FrameLen)
313                 var n int
314                 var from *net.UDPAddr
315                 var err error
316                 var sid byte
317                 var peer *Peer
318                 var ciph *chacha20.Cipher
319                 var macKey [32]byte
320                 var mac *poly1305.MAC
321                 tag := make([]byte, poly1305.TagSize)
322                 nonce := make([]byte, 12)
323                 for {
324                         n, from, err = lnUDP.ReadFromUDP(buf)
325                         if err != nil {
326                                 log.Fatalln("recvfrom:", err)
327                         }
328                         sid = buf[0]
329                         peer = Peers[sid]
330                         if peer == nil {
331                                 slog.Info("unknown:", "sid", sid, "from", from)
332                                 continue
333                         }
334                         if from.Port != peer.addr.Port || !from.IP.Equal(peer.addr.IP) {
335                                 slog.Info("wrong addr:",
336                                         "peer", peer.name,
337                                         "our", peer.addr,
338                                         "got", from)
339                                 continue
340                         }
341                         peer.stats.pktsRx++
342                         peer.stats.bytesRx += uint64(n)
343                         if n == 1 {
344                                 continue
345                         }
346                         if n <= 4+vors.TagLen {
347                                 slog.Info("too small:", "peer", peer.name, "len", n)
348                                 continue
349                         }
350
351                         copy(nonce[len(nonce)-4:], buf)
352                         ciph, err = chacha20.NewUnauthenticatedCipher(peer.key, nonce)
353                         if err != nil {
354                                 log.Fatal(err)
355                         }
356                         clear(macKey[:])
357                         ciph.XORKeyStream(macKey[:], macKey[:])
358                         ciph.SetCounter(1)
359                         mac = poly1305.New(&macKey)
360                         if _, err = mac.Write(buf[4 : n-vors.TagLen]); err != nil {
361                                 log.Fatal(err)
362                         }
363                         mac.Sum(tag[:0])
364                         if subtle.ConstantTimeCompare(
365                                 tag[:vors.TagLen],
366                                 buf[n-vors.TagLen:n],
367                         ) != 1 {
368                                 log.Println("decrypt:", peer.name, "tag differs")
369                                 slog.Info("MAC failed:", "peer", peer.name, "len", n)
370                                 continue
371                         }
372
373                         peer.stats.last = time.Now()
374                         for _, p := range Peers {
375                                 if p.sid == sid {
376                                         continue
377                                 }
378                                 p.stats.pktsTx++
379                                 p.stats.bytesTx += uint64(n)
380                                 if _, err = lnUDP.WriteToUDP(buf[:n], p.addr); err != nil {
381                                         slog.Warn("sendto:", "peer", peer.name, "err", err)
382                                 }
383                         }
384                 }
385         }()
386
387         go func() {
388                 <-LoggerReady
389                 slog.Info("listening", "bind", *bind, "spki", SPKI)
390                 for {
391                         conn, err := lnTCP.Accept()
392                         if err != nil {
393                                 log.Fatalln("accept:", err)
394                         }
395                         go newPeer(conn)
396                 }
397         }()
398
399         if *NoGUI {
400                 dummy := make(chan struct{})
401                 <-dummy
402         } else {
403                 err = GUI.MainLoop()
404                 if err != nil && err != gocui.ErrQuit {
405                         log.Fatal(err)
406                 }
407         }
408 }