]> Sergey Matveev's repositories - vors.git/blob - cmd/client/main.go
bad619faed3490fe9f5d1de277d770194152d61c263e260e6b8b93683a36ba39
[vors.git] / cmd / client / main.go
1 // VoRS -- Vo(IP) Really Simple
2 // Copyright (C) 2024 Sergey Matveev <stargrave@stargrave.org>
3 //
4 // This program is free software: you can redistribute it and/or modify
5 // it under the terms of the GNU Affero General Public License as
6 // published by the Free Software Foundation, version 3 of the License.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU General Public License for more details.
12 //
13 // You should have received a copy of the GNU General Public License
14 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
15
16 package main
17
import (
	"bufio"
	"bytes"
	"crypto/subtle"
	"crypto/tls"
	"crypto/x509"
	"encoding/binary"
	"encoding/hex"
	"errors"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"net/netip"
	"os"
	"os/exec"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/jroimartin/gocui"
	vors "go.stargrave.org/vors/internal"
	"golang.org/x/crypto/blake2s"
	"golang.org/x/crypto/chacha20"
	"golang.org/x/crypto/poly1305"
	"gopkg.in/hraban/opus.v2"
)
46
47 type Stream struct {
48         name  string
49         ctr   uint32
50         in    chan []byte
51         stats *Stats
52 }
53
54 var (
55         Streams  = map[byte]*Stream{}
56         Finish   = make(chan struct{})
57         OurStats = &Stats{dead: make(chan struct{})}
58         Name     = flag.String("name", "test", "Username")
59         Muted    bool
60 )
61
62 func parseSID(s string) byte {
63         n, err := strconv.Atoi(s)
64         if err != nil {
65                 log.Fatal(err)
66         }
67         if n > 255 {
68                 log.Fatal("too big stream num")
69         }
70         return byte(n)
71 }
72
73 func makeCmd(cmd string) *exec.Cmd {
74         args := strings.Fields(cmd)
75         if len(args) == 1 {
76                 return exec.Command(args[0])
77         }
78         return exec.Command(args[0], args[1:]...)
79 }
80
// incr treats data as a big-endian unsigned counter and adds one to it in
// place, carrying from the last byte towards the first. If every byte was
// 0xFF the counter wraps to all zeroes and incr panics with "overflow".
func incr(data []byte) {
	for pos := len(data); pos > 0; {
		pos--
		if data[pos] != 0xFF {
			data[pos]++
			return
		}
		// 0xFF rolls over to zero; carry into the next more significant byte.
		data[pos] = 0
	}
	panic("overflow")
}
90
// main runs the VoRS client. It authenticates to the server over a TLS 1.3
// control connection whose certificate is pinned by SPKI hash, then sends
// and receives encrypted Opus audio over UDP from the same local address,
// while a gocui terminal UI shows logs and per-stream statistics.
func main() {
	vadRaw := flag.Uint("vad", 0, "VAD threshold")
	hostport := flag.String("srv", "[::1]:12345", "TCP/UDP port to connect to")
	spkiHash := flag.String("spki", "FILL-ME", "SHA256 hash of server's certificate SPKI")
	passwd := flag.String("passwd", "", "Password")
	recCmd := flag.String("rec", "rec --no-show-progress --buffer 1920 --channels 1 --endian little --encoding signed --rate 48000 --bits 16 --type raw -", "rec command")
	playCmd := flag.String("play", "play --no-show-progress --buffer 1920 --channels 1 --endian little --encoding signed --rate 48000 --bits 16 --type raw -", "play command")
	flag.Parse()
	log.SetFlags(log.Lmicroseconds | log.Lshortfile)

	vad := uint64(*vadRaw)
	opusEnc := newOpusEnc()
	var mic io.ReadCloser
	var err error
	if *recCmd != "" {
		cmd := makeCmd(*recCmd)
		mic, err = cmd.StdoutPipe()
		if err != nil {
			log.Fatal(err)
		}
		err = cmd.Start()
		if err != nil {
			log.Fatal(err)
		}
	}

	addrTCP, err := net.ResolveTCPAddr("tcp", *hostport)
	if err != nil {
		log.Fatal(err)
	}
	addrUDP, err := net.ResolveUDPAddr("udp", *hostport)
	if err != nil {
		log.Fatal(err)
	}
	ctrlRaw, err := net.DialTCP("tcp", nil, addrTCP)
	if err != nil {
		log.Fatalln("dial server:", err)
	}
	defer ctrlRaw.Close()
	// Audio goes over UDP from the same local address and port the kernel
	// assigned to the TCP control connection.
	ourAddr := net.UDPAddrFromAddrPort(
		netip.MustParseAddrPort(ctrlRaw.LocalAddr().String()))
	ln, err := net.ListenUDP("udp", ourAddr)
	if err != nil {
		log.Fatal(err)
	}
	// Certificate chain validation is disabled: the server is authenticated
	// solely by comparing its certificate's SPKI hash with -spki.
	ctrl := tls.Client(ctrlRaw, &tls.Config{
		MinVersion:         tls.VersionTLS13,
		CurvePreferences:   []tls.CurveID{tls.X25519},
		ServerName:         vors.CN,
		InsecureSkipVerify: true,
		VerifyPeerCertificate: func(
			rawCerts [][]byte, verifiedChains [][]*x509.Certificate,
		) error {
			if len(rawCerts) == 0 {
				return errors.New("server sent no certificate")
			}
			cer, err := x509.ParseCertificate(rawCerts[0])
			if err != nil {
				return err
			}
			if *spkiHash != vors.SPKIHash(cer) {
				return errors.New("server certificate's SPKI hash mismatch")
			}
			return nil
		},
	})
	err = ctrl.Handshake()
	if err != nil {
		log.Println("TLS handshake:", err)
		return
	}
	defer ctrl.Close()

	// Challenge-response login: reply to the server's challenge line with
	// hex(BLAKE2s-256 keyed by the password over the challenge) and our name.
	scanner := bufio.NewScanner(ctrl)
	if !scanner.Scan() {
		log.Println("read challenge:", scanner.Err())
		return
	}
	{
		h, err := blake2s.New256([]byte(*passwd))
		if err != nil {
			log.Fatal(err)
		}
		h.Write(scanner.Bytes())
		if _, err = io.Copy(ctrl, strings.NewReader(fmt.Sprintf(
			"%s %s\n", hex.EncodeToString(h.Sum(nil)), *Name))); err != nil {
			log.Println("write password:", err)
			return
		}
	}
	if !scanner.Scan() {
		log.Println("auth", scanner.Err())
		return
	}
	// A successful reply is "OK <sid>"; reject anything shorter instead of
	// panicking on an out-of-range index.
	cols := strings.Fields(scanner.Text())
	if len(cols) < 2 || cols[0] != "OK" {
		log.Println("auth failed:", scanner.Text())
		return
	}
	sid := parseSID(cols[1])
	Streams[sid] = &Stream{name: *Name, stats: OurStats}

	// Our sending key is exported from the TLS session, labelled by our sid.
	tlsState := ctrl.ConnectionState()
	keyOur, err := tlsState.ExportKeyingMaterial(cols[1], nil, chacha20.KeySize)
	if err != nil {
		log.Fatal(err)
	}
	seen := time.Now()

	// streamsMu guards Streams from here on: the control reader adds and
	// deletes entries while the UDP reader looks them up and enqueues
	// packets. Holding it across the enqueue guarantees CmdDel cannot close
	// stream.in between lookup and send ("send on closed channel" panic).
	var streamsMu sync.Mutex

	LoggerReady := make(chan struct{})
	GUI, err = gocui.NewGui(gocui.OutputNormal)
	if err != nil {
		log.Fatal(err)
	}
	defer GUI.Close()
	GUI.SelFgColor = gocui.ColorCyan
	GUI.Highlight = true
	GUI.SetManagerFunc(guiLayout)
	if err := GUI.SetKeybinding("", gocui.KeyF10, gocui.ModNone, guiQuit); err != nil {
		log.Fatal(err)
	}
	if err := GUI.SetKeybinding("", gocui.KeyEnter, gocui.ModNone, mute); err != nil {
		log.Fatal(err)
	}

	// Redirect logging into the "logs" view once the GUI exists, then keep
	// forcing periodic redraws.
	go func() {
		<-GUIReadyC
		v, err := GUI.View("logs")
		if err != nil {
			log.Fatal(err)
		}
		log.SetOutput(v)
		log.Println("connected")
		close(LoggerReady)
		for {
			time.Sleep(vors.ScreenRefresh)
			GUI.Update(func(gui *gocui.Gui) error {
				return nil
			})
		}
	}()

	go func() {
		<-Finish
		go GUI.Close()
		time.Sleep(100 * time.Millisecond)
		os.Exit(0)
	}()

	// Periodic ping over the control connection; a write failure is fatal.
	go func() {
		for {
			time.Sleep(vors.PingTime)
			if _, err := ctrl.Write([]byte(vors.CmdPing + "\n")); err != nil {
				log.Println("ping:", err)
				Finish <- struct{}{}
				return
			}
		}
	}()

	// Control reader: handles pongs and stream add/del announcements.
	go func(seen *time.Time) {
		for scanner.Scan() {
			t := scanner.Text()
			if t == vors.CmdPong {
				*seen = time.Now()
				continue
			}
			cols := strings.Fields(t)
			if len(cols) == 0 {
				log.Println("empty cmd")
				continue
			}
			switch cols[0] {
			case vors.CmdAdd:
				if len(cols) < 4 {
					log.Println("malformed cmd:", t)
					continue
				}
				sidRaw, name, keyHex := cols[1], cols[2], cols[3]
				log.Println("add", name)
				sid := parseSID(sidRaw)
				key, err := hex.DecodeString(keyHex)
				if err != nil {
					log.Fatal(err)
				}
				stream := &Stream{
					name:  name,
					in:    make(chan []byte, 1<<10),
					stats: &Stats{dead: make(chan struct{})},
				}
				go playIncoming(stream, key, *playCmd)
				go statsDrawer(stream.stats, stream.name)
				streamsMu.Lock()
				Streams[sid] = stream
				streamsMu.Unlock()
			case vors.CmdDel:
				if len(cols) < 2 {
					log.Println("malformed cmd:", t)
					continue
				}
				sid := parseSID(cols[1])
				streamsMu.Lock()
				s := Streams[sid]
				if s != nil {
					delete(Streams, sid)
					close(s.in)
				}
				streamsMu.Unlock()
				if s == nil {
					log.Println("unknown sid:", sid)
					continue
				}
				close(s.stats.dead)
				log.Println("del", s.name)
			default:
				log.Fatal("unknown cmd:", cols[0])
			}
		}
		// Losing the control connection, cleanly or not, ends the session.
		if err := scanner.Err(); err != nil {
			log.Println("scanner:", err)
		} else {
			log.Println("scanner: control connection closed")
		}
		Finish <- struct{}{}
	}(&seen)

	// Pong watchdog: give up if nothing was heard for two ping periods.
	go func(seen *time.Time) {
		for now := range time.Tick(vors.PingTime) {
			if seen.Add(2 * vors.PingTime).Before(now) {
				log.Println("timeout:", seen)
				Finish <- struct{}{}
				break
			}
		}
	}(&seen)

	// UDP reader: accepts packets only from the server address and hands
	// them to the owning stream's decoder.
	go func() {
		<-LoggerReady
		for {
			// A fresh buffer per packet: it is handed off to another goroutine.
			buf := make([]byte, 2*vors.FrameLen)
			n, from, err := ln.ReadFromUDP(buf)
			if err != nil {
				log.Println("recvfrom:", err)
				Finish <- struct{}{}
				return
			}
			if from.Port != addrUDP.Port || !from.IP.Equal(addrUDP.IP) {
				log.Println("wrong addr:", from)
				continue
			}
			if n <= 4+vors.TagLen {
				log.Println("too small:", n)
				continue
			}
			streamsMu.Lock()
			stream := Streams[buf[0]]
			if stream == nil {
				streamsMu.Unlock()
				log.Println("unknown stream:", buf[0])
				continue
			}
			if stream.in == nil {
				// Our own stream has no decoder; sending on its nil channel
				// would block forever.
				streamsMu.Unlock()
				continue
			}
			stream.stats.pkts++
			stream.stats.bytes += uint64(n)
			// NOTE(review): stream.ctr is written by the decoder goroutine
			// without synchronization; this read is racy.
			ctr := binary.BigEndian.Uint32(buf)
			if ctr <= stream.ctr {
				stream.stats.reorder++
				streamsMu.Unlock()
				continue
			}
			select {
			case stream.in <- buf[:n]:
			default:
				// The decoder is not keeping up. Drop the packet rather than
				// block while holding streamsMu; the decoder sees the counter
				// gap and accounts for it as loss.
			}
			streamsMu.Unlock()
		}
	}()

	// UDP keepalive: a one-byte packet carrying our sid every second.
	go statsDrawer(OurStats, *Name)
	go func() {
		<-LoggerReady
		for {
			OurStats.pkts++
			OurStats.bytes++
			if _, err := ln.WriteTo([]byte{sid}, addrUDP); err != nil {
				log.Println("send:", err)
				Finish <- struct{}{}
				return
			}
			time.Sleep(time.Second)
		}
	}()

	// Microphone sender. Packet layout:
	// sid(1) || counter(3, big-endian) || ciphertext || truncated poly1305 tag.
	// The first four bytes double as the tail of the ChaCha20 nonce.
	go func() {
		if *recCmd == "" {
			return
		}
		<-LoggerReady
		var ciph *chacha20.Cipher
		var macKey [32]byte
		var mac *poly1305.MAC
		var err error
		tag := make([]byte, poly1305.TagSize)
		buf := make([]byte, 2*vors.FrameLen)
		pcm := make([]int16, vors.FrameLen)
		nonce := make([]byte, 12)
		nonce[len(nonce)-4] = sid
		var pkt []byte
		var n int
		for {
			if _, err = io.ReadFull(mic, buf); err != nil {
				log.Println("mic:", err)
				break
			}
			if Muted {
				continue
			}
			// Raw input is signed 16-bit little-endian mono.
			for i := range pcm {
				pcm[i] = int16(binary.LittleEndian.Uint16(buf[i*2:]))
			}
			if vad != 0 && vors.RMS(pcm) < vad {
				continue
			}
			// Bound the encoder's output so the 4-byte header and the tag
			// always fit in buf.
			n, err = opusEnc.Encode(pcm, buf[4:len(buf)-vors.TagLen])
			if err != nil {
				log.Fatal(err)
			}

			incr(nonce[len(nonce)-3:])
			copy(buf, nonce[len(nonce)-4:])
			ciph, err = chacha20.NewUnauthenticatedCipher(keyOur, nonce)
			if err != nil {
				log.Fatal(err)
			}
			// Keystream block 0 yields the one-time poly1305 key; the
			// payload is encrypted starting from block 1.
			clear(macKey[:])
			ciph.XORKeyStream(macKey[:], macKey[:])
			ciph.SetCounter(1)
			ciph.XORKeyStream(buf[4:4+n], buf[4:4+n])
			mac = poly1305.New(&macKey)
			if _, err = mac.Write(buf[4 : 4+n]); err != nil {
				log.Fatal(err)
			}
			mac.Sum(tag[:0])
			copy(buf[4+n:], tag[:vors.TagLen])
			pkt = buf[:4+n+vors.TagLen]

			OurStats.pkts++
			OurStats.bytes += uint64(len(pkt))
			OurStats.last = time.Now()
			OurStats.AddRMS(pcm)
			if _, err = ln.WriteTo(pkt, addrUDP); err != nil {
				log.Println("send:", err)
				break
			}
		}
	}()

	err = GUI.MainLoop()
	if err != nil && err != gocui.ErrQuit {
		log.Fatal(err)
	}
}

// playIncoming is the per-stream decoder. For each packet received on
// stream.in it verifies the truncated poly1305 tag, decrypts the Opus
// payload with key, conceals up to vors.MaxLost missing packets with Opus
// PLC, updates stream.stats, and, if playCmd is non-empty, pipes the PCM
// into a player process. It returns once stream.in is closed, after killing
// and reaping the player.
func playIncoming(stream *Stream, key []byte, playCmd string) {
	dec, err := opus.NewDecoder(vors.Rate, 1)
	if err != nil {
		log.Fatal(err)
	}

	var player io.WriteCloser
	var cmd *exec.Cmd
	if playCmd != "" {
		cmd = makeCmd(playCmd)
		player, err = cmd.StdinPipe()
		if err != nil {
			log.Fatal(err)
		}
		err = cmd.Start()
		if err != nil {
			log.Fatal(err)
		}
		defer func() {
			// Best-effort teardown: kill the player, then Wait to release
			// its resources. Errors such as "signal: killed" are expected.
			_ = cmd.Process.Kill()
			_ = cmd.Wait()
		}()
	}

	pcm := make([]int16, vors.FrameLen)
	pcmbuf := make([]byte, 2*vors.FrameLen)
	// play writes samples to the player process, if there is one.
	play := func(samples []int16) {
		if cmd == nil {
			return
		}
		pcmConv(pcmbuf, samples)
		if _, err := io.Copy(player, bytes.NewReader(
			pcmbuf[:2*len(samples)])); err != nil {
			log.Println("play:", err)
		}
	}

	var ciph *chacha20.Cipher
	var macKey [32]byte
	var mac *poly1305.MAC
	tag := make([]byte, poly1305.TagSize)
	var ctr uint32
	nonce := make([]byte, 12)
	var pkt []byte
	lost := -1
	for buf := range stream.in {
		// The packet's first four bytes are the tail of the nonce.
		copy(nonce[len(nonce)-4:], buf)
		ciph, err = chacha20.NewUnauthenticatedCipher(key, nonce)
		if err != nil {
			log.Fatal(err)
		}
		clear(macKey[:])
		ciph.XORKeyStream(macKey[:], macKey[:])
		ciph.SetCounter(1)
		mac = poly1305.New(&macKey)
		pkt = buf[4 : len(buf)-vors.TagLen]
		if _, err = mac.Write(pkt); err != nil {
			log.Fatal(err)
		}
		mac.Sum(tag[:0])
		if subtle.ConstantTimeCompare(
			tag[:vors.TagLen],
			buf[len(buf)-vors.TagLen:],
		) != 1 {
			log.Println("decrypt:", stream.name, "tag differs")
			continue
		}
		ciph.XORKeyStream(pkt, pkt)

		ctr = binary.BigEndian.Uint32(nonce[len(nonce)-4:])
		if lost == -1 {
			// ignore the very first packet in the stream
			lost = 0
		} else {
			lost = int(ctr - (stream.ctr + 1))
		}
		stream.ctr = ctr
		stream.stats.lost += int64(lost)
		if lost > vors.MaxLost {
			lost = 0
		}
		for ; lost > 0; lost-- {
			lastDur, err := dec.LastPacketDuration()
			if err != nil {
				log.Println("PLC:", err)
				continue
			}
			if lastDur <= 0 || lastDur > len(pcm) {
				// Unknown or oversized duration: conceal one nominal frame.
				lastDur = len(pcm)
			}
			if err = dec.DecodePLC(pcm[:lastDur]); err != nil {
				log.Println("PLC:", err)
				continue
			}
			stream.stats.AddRMS(pcm[:lastDur])
			play(pcm[:lastDur])
		}
		n, err := dec.Decode(pkt, pcm)
		if err != nil {
			log.Println("decode:", err)
			continue
		}
		stream.stats.AddRMS(pcm[:n])
		stream.stats.last = time.Now()
		play(pcm[:n])
	}
}