]> Sergey Matveev's repositories - vors.git/blob - cmd/client/main.go
go fmt
[vors.git] / cmd / client / main.go
1 // VoRS -- Vo(IP) Really Simple
2 // Copyright (C) 2024 Sergey Matveev <stargrave@stargrave.org>
3 //
4 // This program is free software: you can redistribute it and/or modify
5 // it under the terms of the GNU Affero General Public License as
6 // published by the Free Software Foundation, version 3 of the License.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU Affero General Public License for more details.
12 //
13 // You should have received a copy of the GNU Affero General Public License
14 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
15
16 package main
17
18 import (
19         "bytes"
20         "crypto/subtle"
21         "encoding/base64"
22         "encoding/binary"
23         "encoding/hex"
24         "flag"
25         "fmt"
26         "io"
27         "log"
28         "net"
29         "os"
30         "os/exec"
31         "strconv"
32         "strings"
33         "time"
34
35         "github.com/dchest/siphash"
36         "github.com/flynn/noise"
37         "github.com/jroimartin/gocui"
38         "go.stargrave.org/opus/v2"
39         vors "go.stargrave.org/vors/internal"
40         "golang.org/x/crypto/blake2s"
41         "golang.org/x/crypto/chacha20"
42 )
43
44 type Stream struct {
45         name  string
46         ctr   uint32
47         in    chan []byte
48         stats *Stats
49 }
50
51 var (
52         Streams  = map[byte]*Stream{}
53         Finish   = make(chan struct{})
54         OurStats = &Stats{dead: make(chan struct{})}
55         Name     = flag.String("name", "test", "username")
56         Room     = flag.String("room", "/", "room name")
57         Muted    bool
58 )
59
// parseSID converts the decimal stream identifier s, as sent by the server,
// into a byte. It terminates the process if s is not a decimal integer in
// the range 0..255.
func parseSID(s string) byte {
	sid, err := sidFromDecimal(s)
	if err != nil {
		log.Fatal(err)
	}
	return sid
}

// sidFromDecimal is the non-fatal core of parseSID: it validates that s is a
// decimal integer in 0..255 and returns it as a byte, or returns an error.
// Negative values are rejected explicitly; a plain byte(n) conversion would
// silently wrap them (e.g. "-1" would become stream 255).
func sidFromDecimal(s string) (byte, error) {
	n, err := strconv.Atoi(s)
	if err != nil {
		return 0, err
	}
	if n < 0 {
		return 0, fmt.Errorf("negative stream num: %d", n)
	}
	if n > 255 {
		return 0, fmt.Errorf("too big stream num")
	}
	return byte(n), nil
}
70
// makeCmd splits cmd on whitespace into a program name followed by its
// arguments and returns an *exec.Cmd for it. No shell is involved, so quotes
// and other shell syntax are not interpreted.
//
// A command line consisting only of whitespace terminates the process with a
// clear message instead of panicking on an out-of-range index (callers only
// check for the empty string).
func makeCmd(cmd string) *exec.Cmd {
	args := strings.Fields(cmd)
	if len(args) == 0 {
		log.Fatalf("empty command: %q", cmd)
	}
	return exec.Command(args[0], args[1:]...)
}
78
// incr adds one, in place, to data interpreted as a big-endian unsigned
// integer. It panics with "overflow" when the value wraps around to zero
// (every byte was 0xFF), and also when data is empty.
func incr(data []byte) {
	pos := len(data)
	for pos > 0 {
		pos--
		data[pos]++
		if data[pos] != 0 {
			// No carry into the next more significant byte.
			return
		}
	}
	panic("overflow")
}
88
// soxParams are the SoX options shared by the default -rec and -play
// commands: mono, signed 16-bit little-endian raw PCM at 48000 Hz, with the
// audio file given as "-" (the command's stdout for rec, stdin for play, which
// is where main connects its pipes) and progress output disabled.
const soxParams = "--no-show-progress --buffer 1920 " +
	"--channels 1 --endian little --encoding signed " +
	"--rate 48000 --bits 16 --type raw -"
90
91 func main() {
92         srvAddr := flag.String("srv", "vors.home.arpa:"+strconv.Itoa(vors.DefaultPort),
93                 "host:TCP/UDP port to connect to")
94         srvPubB64 := flag.String("pub", "", "server's public key, Base64")
95         recCmd := flag.String("rec", "rec "+soxParams, "rec command")
96         playCmd := flag.String("play", "play "+soxParams, "play command")
97         vadRaw := flag.Uint("vad", 0, "VAD threshold")
98         passwd := flag.String("passwd", "", "protected room's password")
99         muteToggle := flag.String("mute-toggle", "", "path to FIFO to toggle mute")
100         version := flag.Bool("version", false, "print version")
101         warranty := flag.Bool("warranty", false, "print warranty information")
102         flag.Parse()
103         log.SetFlags(log.Lmicroseconds | log.Lshortfile)
104
105         if *warranty {
106                 fmt.Println(vors.Warranty)
107                 return
108         }
109         if *version {
110                 fmt.Println(vors.GetVersion())
111                 return
112         }
113
114         srvPub, err := base64.RawURLEncoding.DecodeString(*srvPubB64)
115         if err != nil {
116                 log.Fatal(err)
117         }
118         *Name = strings.ReplaceAll(*Name, " ", "-")
119
120         go func() {
121                 if *muteToggle == "" {
122                         return
123                 }
124                 for {
125                         fd, err := os.OpenFile(*muteToggle, os.O_WRONLY, os.FileMode(0666))
126                         if err != nil {
127                                 log.Fatalln(err)
128                         }
129                         Muted = !Muted
130                         var reply string
131                         if Muted {
132                                 reply = "muted"
133                         } else {
134                                 reply = "unmuted"
135                         }
136                         fd.WriteString(reply + "\n")
137                         fd.Close()
138                         time.Sleep(time.Second)
139                 }
140         }()
141
142         vad := uint64(*vadRaw)
143         opusEnc := newOpusEnc()
144         var mic io.ReadCloser
145         if *recCmd != "" {
146                 cmd := makeCmd(*recCmd)
147                 mic, err = cmd.StdoutPipe()
148                 if err != nil {
149                         log.Fatal(err)
150                 }
151                 err = cmd.Start()
152                 if err != nil {
153                         log.Fatal(err)
154                 }
155         }
156
157         ctrl, err := net.DialTCP("tcp", nil, vors.MustResolveTCP(*srvAddr))
158         if err != nil {
159                 log.Fatalln("dial server:", err)
160         }
161         defer ctrl.Close()
162         if err = ctrl.SetNoDelay(true); err != nil {
163                 log.Fatalln("nodelay:", err)
164         }
165
166         hs, err := noise.NewHandshakeState(noise.Config{
167                 CipherSuite: vors.NoiseCipherSuite,
168                 Pattern:     noise.HandshakeNK,
169                 Initiator:   true,
170                 PeerStatic:  srvPub,
171                 Prologue:    []byte(vors.NoisePrologue),
172         })
173         if err != nil {
174                 log.Fatalln("noise.NewHandshakeState:", err)
175         }
176         buf, _, _, err := hs.WriteMessage(nil, []byte(*Name+" "+*Room+" "+*passwd))
177         if err != nil {
178                 log.Fatalln("handshake encrypt:", err)
179         }
180         buf = append(
181                 append(
182                         []byte(vors.NoisePrologue),
183                         byte((len(buf)&0xFF00)>>8),
184                         byte((len(buf)&0x00FF)>>0),
185                 ),
186                 buf...,
187         )
188         _, err = io.Copy(ctrl, bytes.NewReader(buf))
189         if err != nil {
190                 log.Fatalln("write handshake:", err)
191                 return
192         }
193         buf, err = vors.PktRead(ctrl)
194         if err != nil {
195                 log.Fatalln("read handshake:", err)
196         }
197         buf, txCS, rxCS, err := hs.ReadMessage(nil, buf)
198         if err != nil {
199                 log.Fatalln("handshake decrypt:", err)
200         }
201
202         rx := make(chan []byte)
203         go func() {
204                 for {
205                         buf, err := vors.PktRead(ctrl)
206                         if err != nil {
207                                 log.Println("rx", err)
208                                 break
209                         }
210                         buf, err = rxCS.Decrypt(buf[:0], nil, buf)
211                         if err != nil {
212                                 log.Println("rx decrypt", err)
213                                 break
214                         }
215                         rx <- buf
216                 }
217                 Finish <- struct{}{}
218         }()
219
220         srvAddrUDP := vors.MustResolveUDP(*srvAddr)
221         conn, err := net.DialUDP("udp", nil, srvAddrUDP)
222         if err != nil {
223                 log.Fatalln("connect:", err)
224         }
225         var sid byte
226         {
227                 cols := strings.Fields(string(buf))
228                 if cols[0] != "OK" || len(cols) != 2 {
229                         log.Fatalln("handshake failed:", cols)
230                 }
231                 var cookie vors.Cookie
232                 cookieRaw, err := hex.DecodeString(cols[1])
233                 if err != nil {
234                         log.Fatal(err)
235                 }
236                 copy(cookie[:], cookieRaw)
237                 timeout := time.NewTimer(vors.PingTime)
238                 defer func() {
239                         if !timeout.Stop() {
240                                 <-timeout.C
241                         }
242                 }()
243                 ticker := time.NewTicker(time.Second)
244                 if _, err = conn.Write(cookie[:]); err != nil {
245                         log.Fatalln("write:", err)
246                 }
247         WaitForCookieAcceptance:
248                 for {
249                         select {
250                         case <-timeout.C:
251                                 log.Fatalln("cookie acceptance timeout")
252                         case <-ticker.C:
253                                 if _, err = conn.Write(cookie[:]); err != nil {
254                                         log.Fatalln("write:", err)
255                                 }
256                         case buf = <-rx:
257                                 cols = strings.Fields(string(buf))
258                                 if cols[0] != "SID" || len(cols) != 2 {
259                                         log.Fatalln("cookie acceptance failed:", string(buf))
260                                 }
261                                 sid = parseSID(cols[1])
262                                 Streams[sid] = &Stream{name: *Name, stats: OurStats}
263                                 break WaitForCookieAcceptance
264                         }
265                 }
266                 if !timeout.Stop() {
267                         <-timeout.C
268                 }
269         }
270
271         var keyCiphOur []byte
272         var keyMACOur []byte
273         {
274                 xof, err := blake2s.NewXOF(32+16, nil)
275                 if err != nil {
276                         log.Fatalln(err)
277                 }
278                 xof.Write([]byte(vors.NoisePrologue))
279                 xof.Write(hs.ChannelBinding())
280                 buf := make([]byte, 32+16)
281                 if _, err = io.ReadFull(xof, buf); err != nil {
282                         log.Fatalln(err)
283                 }
284                 keyCiphOur, keyMACOur = buf[:32], buf[32:]
285         }
286
287         seen := time.Now()
288
289         LoggerReady := make(chan struct{})
290         GUI, err = gocui.NewGui(gocui.OutputNormal)
291         if err != nil {
292                 log.Fatal(err)
293         }
294         defer GUI.Close()
295         GUI.SelFgColor = gocui.ColorCyan
296         GUI.Highlight = true
297         GUI.SetManagerFunc(guiLayout)
298         if err := GUI.SetKeybinding("", gocui.KeyF10, gocui.ModNone, guiQuit); err != nil {
299                 log.Fatal(err)
300         }
301         if err := GUI.SetKeybinding("", gocui.KeyEnter, gocui.ModNone, mute); err != nil {
302                 log.Fatal(err)
303         }
304
305         go func() {
306                 <-GUIReadyC
307                 v, err := GUI.View("logs")
308                 if err != nil {
309                         log.Fatal(err)
310                 }
311                 log.SetOutput(v)
312                 log.Println("connected", "sid:", sid,
313                         "addr:", conn.LocalAddr().String())
314                 close(LoggerReady)
315                 for {
316                         time.Sleep(vors.ScreenRefresh)
317                         GUI.Update(func(gui *gocui.Gui) error {
318                                 return nil
319                         })
320                 }
321         }()
322
323         go func() {
324                 <-Finish
325                 go GUI.Close()
326                 time.Sleep(100 * time.Millisecond)
327                 os.Exit(0)
328         }()
329
330         go func() {
331                 for {
332                         time.Sleep(vors.PingTime)
333                         buf, err := txCS.Encrypt(nil, nil, []byte(vors.CmdPing))
334                         if err != nil {
335                                 log.Fatalln("tx encrypt:", err)
336                         }
337                         if err = vors.PktWrite(ctrl, buf); err != nil {
338                                 log.Fatalln("tx:", err)
339                         }
340                 }
341         }()
342
343         go func(seen *time.Time) {
344                 var now time.Time
345                 for buf := range rx {
346                         if string(buf) == vors.CmdPong {
347                                 now = time.Now()
348                                 *seen = now
349                                 continue
350                         }
351                         cols := strings.Fields(string(buf))
352                         switch cols[0] {
353                         case vors.CmdAdd:
354                                 sidRaw, name, keyHex := cols[1], cols[2], cols[3]
355                                 log.Println("add", name, "sid:", sidRaw)
356                                 sid := parseSID(sidRaw)
357                                 key, err := hex.DecodeString(keyHex)
358                                 if err != nil {
359                                         log.Fatal(err)
360                                 }
361                                 keyCiph, keyMAC := key[:32], key[32:]
362                                 stream := &Stream{
363                                         name:  name,
364                                         in:    make(chan []byte, 1<<10),
365                                         stats: &Stats{dead: make(chan struct{})},
366                                 }
367                                 go func() {
368                                         dec, err := opus.NewDecoder(vors.Rate, 1)
369                                         if err != nil {
370                                                 log.Fatal(err)
371                                         }
372                                         if err = dec.SetComplexity(10); err != nil {
373                                                 log.Fatal(err)
374                                         }
375
376                                         var player io.WriteCloser
377                                         playerTx := make(chan []byte, 5)
378                                         var cmd *exec.Cmd
379                                         if *playCmd != "" {
380                                                 cmd = makeCmd(*playCmd)
381                                                 player, err = cmd.StdinPipe()
382                                                 if err != nil {
383                                                         log.Fatal(err)
384                                                 }
385                                                 err = cmd.Start()
386                                                 if err != nil {
387                                                         log.Fatal(err)
388                                                 }
389                                                 go func() {
390                                                         var pcmbuf []byte
391                                                         var ok bool
392                                                         var err error
393                                                         for {
394                                                                 for len(playerTx) > vors.MaxLost {
395                                                                         <-playerTx
396                                                                         stream.stats.reorder++
397                                                                 }
398                                                                 pcmbuf, ok = <-playerTx
399                                                                 if !ok {
400                                                                         break
401                                                                 }
402                                                                 if _, err = io.Copy(player,
403                                                                         bytes.NewReader(pcmbuf)); err != nil {
404                                                                         log.Println("play:", err)
405                                                                 }
406                                                         }
407                                                         cmd.Process.Kill()
408                                                 }()
409                                         }
410
411                                         var ciph *chacha20.Cipher
412                                         mac := siphash.New(keyMAC)
413                                         tag := make([]byte, siphash.Size)
414                                         var ctr uint32
415                                         pcm := make([]int16, vors.FrameLen)
416                                         nonce := make([]byte, 12)
417                                         var pkt []byte
418                                         lost := -1
419                                         var lastDur int
420                                         for buf := range stream.in {
421                                                 copy(nonce[len(nonce)-4:], buf)
422                                                 mac.Reset()
423                                                 if _, err = mac.Write(buf[:len(buf)-siphash.Size]); err != nil {
424                                                         log.Fatal(err)
425                                                 }
426                                                 mac.Sum(tag[:0])
427                                                 if subtle.ConstantTimeCompare(
428                                                         tag[:siphash.Size],
429                                                         buf[len(buf)-siphash.Size:],
430                                                 ) != 1 {
431                                                         stream.stats.bads++
432                                                         continue
433                                                 }
434                                                 ciph, err = chacha20.NewUnauthenticatedCipher(keyCiph, nonce)
435                                                 if err != nil {
436                                                         log.Fatal(err)
437                                                 }
438                                                 pkt = buf[4 : len(buf)-siphash.Size]
439                                                 ciph.XORKeyStream(pkt, pkt)
440
441                                                 ctr = binary.BigEndian.Uint32(nonce[len(nonce)-4:])
442                                                 if lost == -1 {
443                                                         // ignore the very first packet in the stream
444                                                         lost = 0
445                                                 } else {
446                                                         lost = int(ctr - (stream.ctr + 1))
447                                                 }
448                                                 stream.ctr = ctr
449                                                 stream.stats.lost += int64(lost)
450                                                 if lost > vors.MaxLost {
451                                                         lost = 0
452                                                 }
453                                                 for ; lost > 0; lost-- {
454                                                         lastDur, err = dec.LastPacketDuration()
455                                                         if err != nil {
456                                                                 log.Println("PLC:", err)
457                                                                 continue
458                                                         }
459                                                         err = dec.DecodePLC(pcm[:lastDur])
460                                                         if err != nil {
461                                                                 log.Println("PLC:", err)
462                                                                 continue
463                                                         }
464                                                         stream.stats.AddRMS(pcm)
465                                                         if cmd == nil {
466                                                                 continue
467                                                         }
468                                                         pcmbuf := make([]byte, 2*lastDur)
469                                                         pcmConv(pcmbuf, pcm[:lastDur])
470                                                         playerTx <- pcmbuf
471                                                 }
472                                                 _, err = dec.Decode(pkt, pcm)
473                                                 if err != nil {
474                                                         log.Println("decode:", err)
475                                                         continue
476                                                 }
477                                                 stream.stats.AddRMS(pcm)
478                                                 stream.stats.last = time.Now()
479                                                 if cmd == nil {
480                                                         continue
481                                                 }
482                                                 pcmbuf := make([]byte, 2*len(pcm))
483                                                 pcmConv(pcmbuf, pcm)
484                                                 playerTx <- pcmbuf
485                                         }
486                                         if cmd != nil {
487                                                 close(playerTx)
488                                         }
489                                 }()
490                                 go statsDrawer(stream.stats, stream.name)
491                                 Streams[sid] = stream
492                         case vors.CmdDel:
493                                 sid := parseSID(cols[1])
494                                 s := Streams[sid]
495                                 if s == nil {
496                                         log.Println("unknown sid:", sid)
497                                         continue
498                                 }
499                                 log.Println("del", s.name, "sid:", cols[1])
500                                 delete(Streams, sid)
501                                 close(s.in)
502                                 close(s.stats.dead)
503                         default:
504                                 log.Fatal("unknown cmd:", cols[0])
505                         }
506                 }
507         }(&seen)
508
509         go func(seen *time.Time) {
510                 for now := range time.Tick(vors.PingTime) {
511                         if seen.Add(2 * vors.PingTime).Before(now) {
512                                 log.Println("timeout:", seen)
513                                 Finish <- struct{}{}
514                                 break
515                         }
516                 }
517         }(&seen)
518
519         go func() {
520                 <-LoggerReady
521                 var n int
522                 var from *net.UDPAddr
523                 var err error
524                 var stream *Stream
525                 var ctr uint32
526                 for {
527                         buf := make([]byte, 2*vors.FrameLen)
528                         n, from, err = conn.ReadFromUDP(buf)
529                         if err != nil {
530                                 log.Println("recvfrom:", err)
531                                 Finish <- struct{}{}
532                                 break
533                         }
534                         if from.Port != srvAddrUDP.Port || !from.IP.Equal(srvAddrUDP.IP) {
535                                 log.Println("wrong addr:", from)
536                                 continue
537                         }
538                         if n <= 4+siphash.Size {
539                                 log.Println("too small:", n)
540                                 continue
541                         }
542                         stream = Streams[buf[0]]
543                         if stream == nil {
544                                 // log.Println("unknown stream:", buf[0])
545                                 continue
546                         }
547                         stream.stats.pkts++
548                         stream.stats.bytes += vors.IPHdrLen(from.IP) + 8 + uint64(n)
549                         ctr = binary.BigEndian.Uint32(buf)
550                         if ctr <= stream.ctr {
551                                 stream.stats.reorder++
552                                 continue
553                         }
554                         stream.in <- buf[:n]
555                 }
556         }()
557
558         go statsDrawer(OurStats, *Name)
559         go func() {
560                 <-LoggerReady
561                 for now := range time.NewTicker(time.Second).C {
562                         if !OurStats.last.Add(time.Second).Before(now) {
563                                 continue
564                         }
565                         OurStats.pkts++
566                         OurStats.bytes += vors.IPHdrLen(srvAddrUDP.IP) + 8 + 1
567                         if _, err = conn.Write([]byte{sid}); err != nil {
568                                 log.Println("send:", err)
569                                 Finish <- struct{}{}
570                         }
571                 }
572         }()
573         go func() {
574                 if *recCmd == "" {
575                         return
576                 }
577                 <-LoggerReady
578                 var ciph *chacha20.Cipher
579                 mac := siphash.New(keyMACOur)
580                 tag := make([]byte, siphash.Size)
581                 buf := make([]byte, 2*vors.FrameLen)
582                 pcm := make([]int16, vors.FrameLen)
583                 nonce := make([]byte, 12)
584                 nonce[len(nonce)-4] = sid
585                 var pkt []byte
586                 var n, i int
587                 for {
588                         _, err = io.ReadFull(mic, buf)
589                         if err != nil {
590                                 log.Println("mic:", err)
591                                 break
592                         }
593                         if Muted {
594                                 continue
595                         }
596                         for i = 0; i < vors.FrameLen; i++ {
597                                 pcm[i] = int16(uint16(buf[i*2+0]) | (uint16(buf[i*2+1]) << 8))
598                         }
599                         if vad != 0 && vors.RMS(pcm) < vad {
600                                 continue
601                         }
602                         n, err = opusEnc.Encode(pcm, buf[4:])
603                         if err != nil {
604                                 log.Fatal(err)
605                         }
606                         if n <= 2 {
607                                 // DTX
608                                 continue
609                         }
610
611                         incr(nonce[len(nonce)-3:])
612                         copy(buf, nonce[len(nonce)-4:])
613                         ciph, err = chacha20.NewUnauthenticatedCipher(keyCiphOur, nonce)
614                         if err != nil {
615                                 log.Fatal(err)
616                         }
617                         ciph.XORKeyStream(buf[4:4+n], buf[4:4+n])
618                         mac.Reset()
619                         if _, err = mac.Write(buf[:4+n]); err != nil {
620                                 log.Fatal(err)
621                         }
622                         mac.Sum(tag[:0])
623                         copy(buf[4+n:], tag)
624                         pkt = buf[:4+n+siphash.Size]
625
626                         OurStats.pkts++
627                         OurStats.bytes += vors.IPHdrLen(srvAddrUDP.IP) + 8 + uint64(len(pkt))
628                         OurStats.last = time.Now()
629                         OurStats.AddRMS(pcm)
630                         if _, err = conn.Write(pkt); err != nil {
631                                 log.Println("send:", err)
632                                 break
633                         }
634                 }
635         }()
636
637         err = GUI.MainLoop()
638         if err != nil && err != gocui.ErrQuit {
639                 log.Fatal(err)
640         }
641 }