Sergey Matveev's repositories - vors.git/blob - cmd/server/main.go
Working version
[vors.git] / cmd / server / main.go
1 // VoRS -- Vo(IP) Really Simple
2 // Copyright (C) 2024 Sergey Matveev <stargrave@stargrave.org>
3 //
4 // This program is free software: you can redistribute it and/or modify
5 // it under the terms of the GNU Affero General Public License as
6 // published by the Free Software Foundation, version 3 of the License.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU General Public License for more details.
12 //
13 // You should have received a copy of the GNU General Public License
14 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
15
16 package main
17
18 import (
19         "bufio"
20         "crypto/rand"
21         "crypto/tls"
22         "encoding/hex"
23         "flag"
24         "fmt"
25         "io"
26         "log"
27         "log/slog"
28         "net"
29         "net/netip"
30         "os"
31         "strconv"
32         "strings"
33         "sync"
34         "time"
35
36         "github.com/dustin/go-humanize"
37         "github.com/jroimartin/gocui"
38         vors "go.stargrave.org/vors/internal"
39         "golang.org/x/crypto/blake2s"
40         "golang.org/x/crypto/chacha20poly1305"
41 )
42
43 var (
44         TLSCfg = &tls.Config{
45                 MinVersion:       tls.VersionTLS13,
46                 CurvePreferences: []tls.CurveID{tls.X25519},
47         }
48         SPKI   string
49         Passwd = flag.String("passwd", "", "Shared password")
50         Peers  = map[byte]*Peer{}
51         PeersM sync.Mutex
52 )
53
// Peer is a client that passed authentication and is registered in Peers.
type Peer struct {
	name  string       // login name; newPeer rejects "myself" and names already in use
	sid   byte         // session id: key in Peers and first byte of this peer's UDP datagrams
	addr  *net.UDPAddr // remote address of the control connection; UDP from elsewhere is dropped
	conn  net.Conn     // TLS control connection carrying the line-based commands
	key   []byte       // TLS-exported keying material (label: decimal sid), sent hex-encoded in CmdAdd lines
	stats *Stats       // traffic counters shown in the GUI; stats.dead is closed on removal
}
62
63 func newPeer(connRaw net.Conn) {
64         logger := slog.With("remote", connRaw.RemoteAddr().String())
65         logger.Info("connected")
66         defer connRaw.Close()
67         if len(Peers) == 256 {
68                 logger.Error("too many peers")
69                 return
70         }
71         conn := tls.Server(connRaw, TLSCfg)
72         err := conn.Handshake()
73         if err != nil {
74                 logger.Error("handshake:", "err", err)
75                 return
76         }
77         defer conn.Close()
78
79         scanner := bufio.NewScanner(conn)
80         peer := Peer{conn: conn, stats: &Stats{dead: make(chan struct{})}}
81         peer.addr = net.UDPAddrFromAddrPort(
82                 netip.MustParseAddrPort(conn.RemoteAddr().String()))
83         if err != nil {
84                 log.Fatal(err)
85         }
86         {
87                 chlng := make([]byte, 16)
88                 if _, err = io.ReadFull(rand.Reader, chlng); err != nil {
89                         log.Fatal(err)
90                         return
91                 }
92                 chlngHex := hex.EncodeToString(chlng)
93                 if _, err = io.Copy(conn, strings.NewReader(chlngHex+"\n")); err != nil {
94                         logger.Error("write challenge:", "err", err)
95                         return
96                 }
97                 h, err := blake2s.New256([]byte(*Passwd))
98                 if err != nil {
99                         log.Fatal(err)
100                 }
101                 h.Write([]byte(chlngHex))
102                 if !scanner.Scan() {
103                         logger.Error("read password:", "err", scanner.Err())
104                         return
105                 }
106                 cols := strings.Fields(scanner.Text())
107                 if len(cols) == 1 {
108                         logger.Error("no name")
109                         io.Copy(conn, strings.NewReader("no name\n"))
110                         return
111                 }
112                 peer.name = cols[1]
113                 if peer.name == "myself" {
114                         logger.Error("reserved name")
115                         io.Copy(conn, strings.NewReader("reserved name\n"))
116                         return
117                 }
118                 logger = logger.With("name", cols[1])
119                 if hex.EncodeToString(h.Sum(nil)) != cols[0] {
120                         logger.Error("wrong password")
121                         io.Copy(conn, strings.NewReader("wrong password\n"))
122                         return
123                 }
124                 for _, p := range Peers {
125                         if p.name == peer.name {
126                                 logger.Error("name already taken")
127                                 io.Copy(conn, strings.NewReader("name already taken\n"))
128                                 return
129                         }
130                 }
131                 var i byte
132                 var ok bool
133                 PeersM.Lock()
134                 for i = 0; i <= 255; i++ {
135                         if _, ok = Peers[i]; !ok {
136                                 peer.sid = i
137                                 break
138                         }
139                 }
140                 Peers[peer.sid] = &peer
141                 PeersM.Unlock()
142                 logger = logger.With("sid", peer.sid)
143                 logger.Info("authenticated")
144                 defer func() {
145                         logger.Info("removing")
146                         PeersM.Lock()
147                         delete(Peers, peer.sid)
148                         close(peer.stats.dead)
149                         s := fmt.Sprintf("%s %d\n", vors.CmdDel, peer.sid)
150                         for _, p := range Peers {
151                                 go io.Copy(p.conn, strings.NewReader(s))
152                         }
153                         PeersM.Unlock()
154                 }()
155                 if _, err = io.Copy(conn, strings.NewReader(
156                         fmt.Sprintf("OK %d\n", peer.sid))); err != nil {
157                         logger.Error("write ok:", "err", err)
158                         return
159                 }
160                 for _, p := range Peers {
161                         if p.sid == peer.sid {
162                                 continue
163                         }
164                         if _, err = io.Copy(conn, strings.NewReader(fmt.Sprintf(
165                                 "%s %d %s %s\n", vors.CmdAdd, p.sid, p.name, hex.EncodeToString(p.key),
166                         ))); err != nil {
167                                 logger.Error("write ADD:", "err", err)
168                                 return
169                         }
170                 }
171                 tlsState := conn.ConnectionState()
172                 peer.key, err = tlsState.ExportKeyingMaterial(
173                         strconv.Itoa(int(peer.sid)), nil, chacha20poly1305.KeySize)
174                 if err != nil {
175                         log.Fatal(err)
176                 }
177                 {
178                         // assume atomic write
179                         s := fmt.Sprintf("%s %d %s %s\n",
180                                 vors.CmdAdd, peer.sid, peer.name, hex.EncodeToString(peer.key))
181                         for _, p := range Peers {
182                                 if p.sid == peer.sid {
183                                         continue
184                                 }
185                                 go io.Copy(p.conn, strings.NewReader(s))
186                         }
187                 }
188                 seen := time.Now()
189                 go func(seen *time.Time) {
190                         for now := range time.Tick(vors.PingTime) {
191                                 if seen.Add(2 * vors.PingTime).Before(now) {
192                                         logger.Error("timeout:", "seen", seen)
193                                         conn.Close()
194                                         break
195                                 }
196                         }
197                 }(&seen)
198                 go func(stats *Stats) {
199                         if *NoGUI {
200                                 return
201                         }
202                         tick := time.Tick(vors.ScreenRefresh)
203                         var now time.Time
204                         var v *gocui.View
205                         for {
206                                 select {
207                                 case <-stats.dead:
208                                         GUI.DeleteView(peer.name)
209                                         return
210                                 case now = <-tick:
211                                         s := fmt.Sprintf(
212                                                 "Rx/Tx: %s / %s  |  %s / %s",
213                                                 humanize.Comma(stats.pktsRx),
214                                                 humanize.Comma(stats.pktsTx),
215                                                 humanize.IBytes(stats.bytesRx),
216                                                 humanize.IBytes(stats.bytesTx),
217                                         )
218                                         if stats.last.Add(time.Second).After(now) {
219                                                 s += "  |  " + vors.CGreen + "TALK" + vors.CReset
220                                         }
221                                         v, err = GUI.View(peer.name)
222                                         if err == nil {
223                                                 v.Clear()
224                                                 v.Write([]byte(s))
225                                         }
226                                 }
227                         }
228                 }(peer.stats)
229                 for scanner.Scan() {
230                         if scanner.Text() == vors.CmdPing {
231                                 if _, err = io.Copy(conn,
232                                         strings.NewReader(vors.CmdPong+"\n")); err != nil {
233                                         logger.Error("write ok:", "err", err)
234                                         return
235                                 }
236                                 seen = time.Now()
237                         }
238                 }
239                 if scanner.Err() != nil {
240                         logger.Error(scanner.Err().Error())
241                 }
242         }
243 }
244
245 func main() {
246         bind := flag.String("bind", "[::1]:12345", "TCP/UDP port to listen on")
247         pemFile := flag.String("pem", "keypair.pem", "PEM with keypair")
248         flag.Parse()
249         log.SetFlags(log.Lmicroseconds | log.Lshortfile)
250         if *Passwd == "" {
251                 log.Fatal("no -passwd specified")
252         }
253         if err := parsePEM(*pemFile); err != nil {
254                 log.Fatal(err)
255         }
256
257         addrTCP, err := net.ResolveTCPAddr("tcp", *bind)
258         if err != nil {
259                 log.Fatal(err)
260         }
261         addrUDP, err := net.ResolveUDPAddr("udp", *bind)
262         if err != nil {
263                 log.Fatal(err)
264         }
265         lnTCP, err := net.ListenTCP("tcp", addrTCP)
266         if err != nil {
267                 log.Fatal(err)
268         }
269         lnUDP, err := net.ListenUDP("udp", addrUDP)
270         if err != nil {
271                 log.Fatal(err)
272         }
273
274         LoggerReady := make(chan struct{})
275         if *NoGUI {
276                 close(GUIReadyC)
277                 slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil)))
278                 close(LoggerReady)
279         } else {
280                 GUI, err = gocui.NewGui(gocui.OutputNormal)
281                 if err != nil {
282                         log.Fatal(err)
283                 }
284                 defer GUI.Close()
285                 GUI.SetManagerFunc(guiLayout)
286                 if err := GUI.SetKeybinding("", gocui.KeyF10, gocui.ModNone, guiQuit); err != nil {
287                         log.Fatal(err)
288                 }
289
290                 go func() {
291                         <-GUIReadyC
292                         v, err := GUI.View("logs")
293                         if err != nil {
294                                 log.Fatal(err)
295                         }
296                         slog.SetDefault(slog.New(slog.NewTextHandler(v, nil)))
297                         close(LoggerReady)
298                         for {
299                                 time.Sleep(vors.ScreenRefresh)
300                                 GUI.Update(func(gui *gocui.Gui) error {
301                                         return nil
302                                 })
303                         }
304                 }()
305         }
306
307         go func() {
308                 <-LoggerReady
309                 buf := make([]byte, 2*vors.FrameLen)
310                 var n int
311                 var from *net.UDPAddr
312                 var err error
313                 var sid byte
314                 var peer *Peer
315                 for {
316                         n, from, err = lnUDP.ReadFromUDP(buf)
317                         if err != nil {
318                                 log.Fatalln("recvfrom:", err)
319                         }
320                         sid = buf[0]
321                         peer = Peers[sid]
322                         if peer == nil {
323                                 slog.Info("unknown:", "sid", sid, "from", from)
324                                 continue
325                         }
326                         if from.Port != peer.addr.Port || !from.IP.Equal(peer.addr.IP) {
327                                 slog.Info("wrong addr:",
328                                         "peer", peer.name,
329                                         "our", peer.addr,
330                                         "got", from)
331                                 continue
332                         }
333                         peer.stats.pktsRx++
334                         peer.stats.bytesRx += uint64(n)
335                         if n == 1 {
336                                 continue
337                         }
338                         peer.stats.last = time.Now()
339                         for _, p := range Peers {
340                                 if p.sid == sid {
341                                         continue
342                                 }
343                                 p.stats.pktsTx++
344                                 p.stats.bytesTx += uint64(n)
345                                 if _, err = lnUDP.WriteToUDP(buf[:n], p.addr); err != nil {
346                                         slog.Warn("sendto:", "peer", peer.name, "err", err)
347                                 }
348                         }
349                 }
350         }()
351
352         go func() {
353                 <-LoggerReady
354                 slog.Info("listening", "bind", *bind, "spki", SPKI)
355                 for {
356                         conn, err := lnTCP.Accept()
357                         if err != nil {
358                                 log.Fatalln("accept:", err)
359                         }
360                         go newPeer(conn)
361                 }
362         }()
363
364         if *NoGUI {
365                 dummy := make(chan struct{})
366                 <-dummy
367         } else {
368                 err = GUI.MainLoop()
369                 if err != nil && err != gocui.ErrQuit {
370                         log.Fatal(err)
371                 }
372         }
373 }