]> Sergey Matveev's repositories - vors.git/blob - cmd/client/main.go
fadaf51fd28367cfbed9a8468f666009e8433415a1dfffb436a083f2b60b2243
[vors.git] / cmd / client / main.go
1 // VoRS -- Vo(IP) Really Simple
2 // Copyright (C) 2024 Sergey Matveev <stargrave@stargrave.org>
3 //
4 // This program is free software: you can redistribute it and/or modify
5 // it under the terms of the GNU Affero General Public License as
6 // published by the Free Software Foundation, version 3 of the License.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU Affero General Public License for more details.
12 //
13 // You should have received a copy of the GNU Affero General Public License
14 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
15
16 package main
17
18 import (
19         "bytes"
20         "crypto/subtle"
21         "encoding/base64"
22         "encoding/binary"
23         "encoding/hex"
24         "flag"
25         "fmt"
26         "io"
27         "log"
28         "net"
29         "os"
30         "os/exec"
31         "strconv"
32         "strings"
33         "time"
34
35         "github.com/flynn/noise"
36         "github.com/jroimartin/gocui"
37         "go.stargrave.org/opus/v2"
38         vors "go.stargrave.org/vors/internal"
39         "golang.org/x/crypto/blake2s"
40         "golang.org/x/crypto/chacha20"
41         "golang.org/x/crypto/poly1305"
42 )
43
// Stream is the client-side state of one audio stream known to this client,
// stored in Streams under its stream ID. Our own stream is registered too,
// but without an input channel.
type Stream struct {
	name  string      // participant name, as announced by the server
	ctr   uint32      // counter of the last authenticated packet
	in    chan []byte // raw UDP packets awaiting decoding; nil for our own stream
	stats *Stats      // counters and levels shown by the UI
}
50
var (
	// Streams maps stream IDs to every stream this client knows about,
	// including its own.
	// NOTE(review): mutated by the control-command goroutine in main and read
	// by the UDP receiver without synchronization — verify.
	Streams  = map[byte]*Stream{}
	// Finish is signalled by any goroutine that wants the client to exit;
	// main's receiver closes the GUI and exits the process.
	Finish   = make(chan struct{})
	// OurStats accumulates statistics of our own outgoing stream.
	OurStats = &Stats{dead: make(chan struct{})}
	// Name is the user name; main replaces spaces with "-".
	Name     = flag.String("name", "test", "username")
	// Room is the room to join.
	Room     = flag.String("room", "/", "room name")
	// Muted suppresses sending of microphone frames. It is toggled via the
	// mute FIFO (and presumably by the `mute` key binding — confirm).
	// NOTE(review): read by the mic goroutine without synchronization.
	Muted    bool
)
59
// parseSID parses a decimal stream identifier received from the server.
// Stream IDs are single bytes, so anything that is not an integer in the
// 0..255 range is fatal. Negative numbers are rejected explicitly: the
// byte conversion would otherwise wrap "-1" around to stream 255.
func parseSID(s string) byte {
	n, err := strconv.Atoi(s)
	if err != nil {
		log.Fatal(err)
	}
	if n < 0 {
		log.Fatal("negative stream num")
	}
	if n > 255 {
		log.Fatal("too big stream num")
	}
	return byte(n)
}
70
// makeCmd splits cmd on whitespace into a program name and its arguments
// and returns the corresponding, not yet started, *exec.Cmd. No shell
// quoting is interpreted. An empty or whitespace-only command is a fatal
// configuration error; previously it panicked on an out-of-range index.
func makeCmd(cmd string) *exec.Cmd {
	args := strings.Fields(cmd)
	if len(args) == 0 {
		log.Fatalf("empty command: %q", cmd)
	}
	return exec.Command(args[0], args[1:]...)
}
78
// incr adds one, in place, to data interpreted as a big-endian unsigned
// integer. Bytes equal to 0xFF roll over to zero and carry into the next
// more significant byte. If every byte rolls over, the counter is
// exhausted and incr panics.
func incr(data []byte) {
	for idx := len(data) - 1; idx >= 0; idx-- {
		if data[idx] != 0xFF {
			data[idx]++
			return
		}
		data[idx] = 0
	}
	panic("overflow")
}
88
// soxParams are the SoX arguments shared by the default "rec" and "play"
// commands. They select raw 48 kHz mono 16-bit signed little-endian PCM,
// exchanged through stdout (rec) / stdin (play) via "-".
// The 1920-byte buffer presumably equals one frame of 2*vors.FrameLen bytes,
// which is what the mic reader consumes — TODO confirm.
const soxParams = "--no-show-progress --buffer 1920 --channels 1 --endian little --encoding signed --rate 48000 --bits 16 --type raw -"
90
91 func main() {
92         srvAddr := flag.String("srv", "vors.home.arpa:"+strconv.Itoa(vors.DefaultPort),
93                 "host:TCP/UDP port to connect to")
94         srvPubB64 := flag.String("pub", "", "server's public key, Base64")
95         recCmd := flag.String("rec", "rec "+soxParams, "rec command")
96         playCmd := flag.String("play", "play "+soxParams, "play command")
97         vadRaw := flag.Uint("vad", 0, "VAD threshold")
98         passwd := flag.String("passwd", "", "protected room's password")
99         muteToggle := flag.String("mute-toggle", "", "path to FIFO to toggle mute")
100         version := flag.Bool("version", false, "print version")
101         warranty := flag.Bool("warranty", false, "print warranty information")
102         flag.Parse()
103         log.SetFlags(log.Lmicroseconds | log.Lshortfile)
104
105         if *warranty {
106                 fmt.Println(vors.Warranty)
107                 return
108         }
109         if *version {
110                 fmt.Println(vors.GetVersion())
111                 return
112         }
113
114         srvPub, err := base64.RawURLEncoding.DecodeString(*srvPubB64)
115         if err != nil {
116                 log.Fatal(err)
117         }
118         *Name = strings.ReplaceAll(*Name, " ", "-")
119
120         go func() {
121                 if *muteToggle == "" {
122                         return
123                 }
124                 for {
125                         fd, err := os.OpenFile(*muteToggle, os.O_WRONLY, os.FileMode(0666))
126                         if err != nil {
127                                 log.Fatalln(err)
128                         }
129                         Muted = !Muted
130                         var reply string
131                         if Muted {
132                                 reply = "muted"
133                         } else {
134                                 reply = "unmuted"
135                         }
136                         fd.WriteString(reply + "\n")
137                         fd.Close()
138                         time.Sleep(time.Second)
139                 }
140         }()
141
142         vad := uint64(*vadRaw)
143         opusEnc := newOpusEnc()
144         var mic io.ReadCloser
145         if *recCmd != "" {
146                 cmd := makeCmd(*recCmd)
147                 mic, err = cmd.StdoutPipe()
148                 if err != nil {
149                         log.Fatal(err)
150                 }
151                 err = cmd.Start()
152                 if err != nil {
153                         log.Fatal(err)
154                 }
155         }
156
157         ctrl, err := net.DialTCP("tcp", nil, vors.MustResolveTCP(*srvAddr))
158         if err != nil {
159                 log.Fatalln("dial server:", err)
160         }
161         defer ctrl.Close()
162         if err = ctrl.SetNoDelay(true); err != nil {
163                 log.Fatalln("nodelay:", err)
164         }
165
166         hs, err := noise.NewHandshakeState(noise.Config{
167                 CipherSuite: vors.NoiseCipherSuite,
168                 Pattern:     noise.HandshakeNK,
169                 Initiator:   true,
170                 PeerStatic:  srvPub,
171                 Prologue:    []byte(vors.NoisePrologue),
172         })
173         if err != nil {
174                 log.Fatalln("noise.NewHandshakeState:", err)
175         }
176         buf, _, _, err := hs.WriteMessage(nil, []byte(*Name+" "+*Room+" "+*passwd))
177         if err != nil {
178                 log.Fatalln("handshake encrypt:", err)
179         }
180         buf = append(
181                 append(
182                         []byte(vors.NoisePrologue),
183                         byte((len(buf) & 0xFF00) >> 8),
184                         byte((len(buf) & 0x00FF) >> 0),
185                 ),
186                 buf...
187         )
188         _, err = io.Copy(ctrl, bytes.NewReader(buf))
189         if err != nil {
190                 log.Fatalln("write handshake:", err)
191                 return
192         }
193         buf, err = vors.PktRead(ctrl)
194         if err != nil {
195                 log.Fatalln("read handshake:", err)
196         }
197         buf, txCS, rxCS, err := hs.ReadMessage(nil, buf)
198         if err != nil {
199                 log.Fatalln("handshake decrypt:", err)
200         }
201
202         rx := make(chan []byte)
203         go func() {
204                 for {
205                         buf, err := vors.PktRead(ctrl)
206                         if err != nil {
207                                 log.Println("rx", err)
208                                 break
209                         }
210                         buf, err = rxCS.Decrypt(buf[:0], nil, buf)
211                         if err != nil {
212                                 log.Println("rx decrypt", err)
213                                 break
214                         }
215                         rx <- buf
216                 }
217                 Finish <- struct{}{}
218         }()
219
220         srvAddrUDP := vors.MustResolveUDP(*srvAddr)
221         conn, err := net.DialUDP("udp", nil, srvAddrUDP)
222         if err != nil {
223                 log.Fatalln("connect:", err)
224         }
225         var sid byte
226         {
227                 cols := strings.Fields(string(buf))
228                 if cols[0] != "OK" || len(cols) != 2 {
229                         log.Fatalln("handshake failed:", cols)
230                 }
231                 var cookie vors.Cookie
232                 cookieRaw, err := hex.DecodeString(cols[1])
233                 if err != nil {
234                         log.Fatal(err)
235                 }
236                 copy(cookie[:], cookieRaw)
237                 timeout := time.NewTimer(vors.PingTime)
238                 defer func() {
239                         if !timeout.Stop() {
240                                 <-timeout.C
241                         }
242                 }()
243                 ticker := time.NewTicker(time.Second)
244                 if _, err = conn.Write(cookie[:]); err != nil {
245                         log.Fatalln("write:", err)
246                 }
247         WaitForCookieAcceptance:
248                 for {
249                         select {
250                         case <-timeout.C:
251                                 log.Fatalln("cookie acceptance timeout")
252                         case <-ticker.C:
253                                 if _, err = conn.Write(cookie[:]); err != nil {
254                                         log.Fatalln("write:", err)
255                                 }
256                         case buf = <-rx:
257                                 cols = strings.Fields(string(buf))
258                                 if cols[0] != "SID" || len(cols) != 2 {
259                                         log.Fatalln("cookie acceptance failed:", string(buf))
260                                 }
261                                 sid = parseSID(cols[1])
262                                 Streams[sid] = &Stream{name: *Name, stats: OurStats}
263                                 break WaitForCookieAcceptance
264                         }
265                 }
266                 if !timeout.Stop() {
267                         <-timeout.C
268                 }
269         }
270
271         var keyOur []byte
272         {
273                 h, err := blake2s.New256(hs.ChannelBinding())
274                 if err != nil {
275                         log.Fatalln(err)
276                 }
277                 h.Write([]byte(vors.NoisePrologue))
278                 keyOur = h.Sum(nil)
279         }
280
281         seen := time.Now()
282
283         LoggerReady := make(chan struct{})
284         GUI, err = gocui.NewGui(gocui.OutputNormal)
285         if err != nil {
286                 log.Fatal(err)
287         }
288         defer GUI.Close()
289         GUI.SelFgColor = gocui.ColorCyan
290         GUI.Highlight = true
291         GUI.SetManagerFunc(guiLayout)
292         if err := GUI.SetKeybinding("", gocui.KeyF10, gocui.ModNone, guiQuit); err != nil {
293                 log.Fatal(err)
294         }
295         if err := GUI.SetKeybinding("", gocui.KeyEnter, gocui.ModNone, mute); err != nil {
296                 log.Fatal(err)
297         }
298
299         go func() {
300                 <-GUIReadyC
301                 v, err := GUI.View("logs")
302                 if err != nil {
303                         log.Fatal(err)
304                 }
305                 log.SetOutput(v)
306                 log.Println("connected", "sid:", sid,
307                         "addr:", conn.LocalAddr().String())
308                 close(LoggerReady)
309                 for {
310                         time.Sleep(vors.ScreenRefresh)
311                         GUI.Update(func(gui *gocui.Gui) error {
312                                 return nil
313                         })
314                 }
315         }()
316
317         go func() {
318                 <-Finish
319                 go GUI.Close()
320                 time.Sleep(100 * time.Millisecond)
321                 os.Exit(0)
322         }()
323
324         go func() {
325                 for {
326                         time.Sleep(vors.PingTime)
327                         buf, err := txCS.Encrypt(nil, nil, []byte(vors.CmdPing))
328                         if err != nil {
329                                 log.Fatalln("tx encrypt:", err)
330                         }
331                         if err = vors.PktWrite(ctrl, buf); err != nil {
332                                 log.Fatalln("tx:", err)
333                         }
334                 }
335         }()
336
337         go func(seen *time.Time) {
338                 var now time.Time
339                 for buf := range rx {
340                         if string(buf) == vors.CmdPong {
341                                 now = time.Now()
342                                 *seen = now
343                                 continue
344                         }
345                         cols := strings.Fields(string(buf))
346                         switch cols[0] {
347                         case vors.CmdAdd:
348                                 sidRaw, name, keyHex := cols[1], cols[2], cols[3]
349                                 log.Println("add", name, "sid:", sidRaw)
350                                 sid := parseSID(sidRaw)
351                                 key, err := hex.DecodeString(keyHex)
352                                 if err != nil {
353                                         log.Fatal(err)
354                                 }
355                                 stream := &Stream{
356                                         name:  name,
357                                         in:    make(chan []byte, 1<<10),
358                                         stats: &Stats{dead: make(chan struct{})},
359                                 }
360                                 go func() {
361                                         dec, err := opus.NewDecoder(vors.Rate, 1)
362                                         if err != nil {
363                                                 log.Fatal(err)
364                                         }
365                                         if err = dec.SetComplexity(10); err != nil {
366                                                 log.Fatal(err)
367                                         }
368
369                                         var player io.WriteCloser
370                                         playerTx := make(chan []byte, 5)
371                                         var cmd *exec.Cmd
372                                         if *playCmd != "" {
373                                                 cmd = makeCmd(*playCmd)
374                                                 player, err = cmd.StdinPipe()
375                                                 if err != nil {
376                                                         log.Fatal(err)
377                                                 }
378                                                 err = cmd.Start()
379                                                 if err != nil {
380                                                         log.Fatal(err)
381                                                 }
382                                                 go func() {
383                                                         var pcmbuf []byte
384                                                         var ok bool
385                                                         var err error
386                                                         for {
387                                                                 for len(playerTx) > 1 {
388                                                                         <-playerTx
389                                                                         stream.stats.reorder++
390                                                                 }
391                                                                 pcmbuf, ok = <-playerTx
392                                                                 if !ok {
393                                                                         break
394                                                                 }
395                                                                 if _, err = io.Copy(player,
396                                                                         bytes.NewReader(pcmbuf)); err != nil {
397                                                                         log.Println("play:", err)
398                                                                 }
399                                                         }
400                                                         cmd.Process.Kill()
401                                                 }()
402                                         }
403
404                                         var ciph *chacha20.Cipher
405                                         var macKey [32]byte
406                                         var mac *poly1305.MAC
407                                         tag := make([]byte, poly1305.TagSize)
408                                         var ctr uint32
409                                         pcm := make([]int16, vors.FrameLen)
410                                         nonce := make([]byte, 12)
411                                         var pkt []byte
412                                         lost := -1
413                                         var lastDur int
414                                         for buf := range stream.in {
415                                                 copy(nonce[len(nonce)-4:], buf)
416                                                 ciph, err = chacha20.NewUnauthenticatedCipher(key, nonce)
417                                                 if err != nil {
418                                                         log.Fatal(err)
419                                                 }
420                                                 clear(macKey[:])
421                                                 ciph.XORKeyStream(macKey[:], macKey[:])
422                                                 ciph.SetCounter(1)
423                                                 mac = poly1305.New(&macKey)
424                                                 if _, err = mac.Write(buf[4 : len(buf)-vors.TagLen]); err != nil {
425                                                         log.Fatal(err)
426                                                 }
427                                                 mac.Sum(tag[:0])
428                                                 if subtle.ConstantTimeCompare(
429                                                         tag[:vors.TagLen],
430                                                         buf[len(buf)-vors.TagLen:],
431                                                 ) != 1 {
432                                                         stream.stats.bads++
433                                                         continue
434                                                 }
435                                                 pkt = buf[4 : len(buf)-vors.TagLen]
436                                                 ciph.XORKeyStream(pkt, pkt)
437
438                                                 ctr = binary.BigEndian.Uint32(nonce[len(nonce)-4:])
439                                                 if lost == -1 {
440                                                         // ignore the very first packet in the stream
441                                                         lost = 0
442                                                 } else {
443                                                         lost = int(ctr - (stream.ctr + 1))
444                                                 }
445                                                 stream.ctr = ctr
446                                                 stream.stats.lost += int64(lost)
447                                                 if lost > vors.MaxLost {
448                                                         lost = 0
449                                                 }
450                                                 for ; lost > 0; lost-- {
451                                                         lastDur, err = dec.LastPacketDuration()
452                                                         if err != nil {
453                                                                 log.Println("PLC:", err)
454                                                                 continue
455                                                         }
456                                                         err = dec.DecodePLC(pcm[:lastDur])
457                                                         if err != nil {
458                                                                 log.Println("PLC:", err)
459                                                                 continue
460                                                         }
461                                                         stream.stats.AddRMS(pcm)
462                                                         if cmd == nil {
463                                                                 continue
464                                                         }
465                                                         pcmbuf := make([]byte, 2*lastDur)
466                                                         pcmConv(pcmbuf, pcm[:lastDur])
467                                                         playerTx <- pcmbuf
468                                                 }
469                                                 _, err = dec.Decode(pkt, pcm)
470                                                 if err != nil {
471                                                         log.Println("decode:", err)
472                                                         continue
473                                                 }
474                                                 stream.stats.AddRMS(pcm)
475                                                 stream.stats.last = time.Now()
476                                                 if cmd == nil {
477                                                         continue
478                                                 }
479                                                 pcmbuf := make([]byte, 2*len(pcm))
480                                                 pcmConv(pcmbuf, pcm)
481                                                 playerTx <- pcmbuf
482                                         }
483                                         if cmd != nil {
484                                                 close(playerTx)
485                                         }
486                                 }()
487                                 go statsDrawer(stream.stats, stream.name)
488                                 Streams[sid] = stream
489                         case vors.CmdDel:
490                                 sid := parseSID(cols[1])
491                                 s := Streams[sid]
492                                 if s == nil {
493                                         log.Println("unknown sid:", sid)
494                                         continue
495                                 }
496                                 log.Println("del", s.name, "sid:", cols[1])
497                                 delete(Streams, sid)
498                                 close(s.in)
499                                 close(s.stats.dead)
500                         default:
501                                 log.Fatal("unknown cmd:", cols[0])
502                         }
503                 }
504         }(&seen)
505
506         go func(seen *time.Time) {
507                 for now := range time.Tick(vors.PingTime) {
508                         if seen.Add(2 * vors.PingTime).Before(now) {
509                                 log.Println("timeout:", seen)
510                                 Finish <- struct{}{}
511                                 break
512                         }
513                 }
514         }(&seen)
515
516         go func() {
517                 <-LoggerReady
518                 var n int
519                 var from *net.UDPAddr
520                 var err error
521                 var stream *Stream
522                 var ctr uint32
523                 for {
524                         buf := make([]byte, 2*vors.FrameLen)
525                         n, from, err = conn.ReadFromUDP(buf)
526                         if err != nil {
527                                 log.Println("recvfrom:", err)
528                                 Finish <- struct{}{}
529                                 break
530                         }
531                         if from.Port != srvAddrUDP.Port || !from.IP.Equal(srvAddrUDP.IP) {
532                                 log.Println("wrong addr:", from)
533                                 continue
534                         }
535                         if n <= 4+vors.TagLen {
536                                 log.Println("too small:", n)
537                                 continue
538                         }
539                         stream = Streams[buf[0]]
540                         if stream == nil {
541                                 // log.Println("unknown stream:", buf[0])
542                                 continue
543                         }
544                         stream.stats.pkts++
545                         stream.stats.bytes += vors.IPHdrLen(from.IP) + 8 + uint64(n)
546                         ctr = binary.BigEndian.Uint32(buf)
547                         if ctr <= stream.ctr {
548                                 stream.stats.reorder++
549                                 continue
550                         }
551                         stream.in <- buf[:n]
552                 }
553         }()
554
555         go statsDrawer(OurStats, *Name)
556         go func() {
557                 <-LoggerReady
558                 for now := range time.NewTicker(time.Second).C {
559                         if !OurStats.last.Add(time.Second).Before(now) {
560                                 continue
561                         }
562                         OurStats.pkts++
563                         OurStats.bytes += vors.IPHdrLen(srvAddrUDP.IP) + 8 + 1
564                         if _, err = conn.Write([]byte{sid}); err != nil {
565                                 log.Println("send:", err)
566                                 Finish <- struct{}{}
567                         }
568                 }
569         }()
570         go func() {
571                 if *recCmd == "" {
572                         return
573                 }
574                 <-LoggerReady
575                 var ciph *chacha20.Cipher
576                 var macKey [32]byte
577                 var mac *poly1305.MAC
578                 tag := make([]byte, poly1305.TagSize)
579                 buf := make([]byte, 2*vors.FrameLen)
580                 pcm := make([]int16, vors.FrameLen)
581                 nonce := make([]byte, 12)
582                 nonce[len(nonce)-4] = sid
583                 var pkt []byte
584                 var n, i int
585                 for {
586                         _, err = io.ReadFull(mic, buf)
587                         if err != nil {
588                                 log.Println("mic:", err)
589                                 break
590                         }
591                         if Muted {
592                                 continue
593                         }
594                         for i = 0; i < vors.FrameLen; i++ {
595                                 pcm[i] = int16(uint16(buf[i*2+0]) | (uint16(buf[i*2+1]) << 8))
596                         }
597                         if vad != 0 && vors.RMS(pcm) < vad {
598                                 continue
599                         }
600                         n, err = opusEnc.Encode(pcm, buf[4:])
601                         if err != nil {
602                                 log.Fatal(err)
603                         }
604                         if n <= 2 {
605                                 // DTX
606                                 continue
607                         }
608
609                         incr(nonce[len(nonce)-3:])
610                         copy(buf, nonce[len(nonce)-4:])
611                         ciph, err = chacha20.NewUnauthenticatedCipher(keyOur, nonce)
612                         if err != nil {
613                                 log.Fatal(err)
614                         }
615                         clear(macKey[:])
616                         ciph.XORKeyStream(macKey[:], macKey[:])
617                         ciph.SetCounter(1)
618                         ciph.XORKeyStream(buf[4:4+n], buf[4:4+n])
619                         mac = poly1305.New(&macKey)
620                         if _, err = mac.Write(buf[4 : 4+n]); err != nil {
621                                 log.Fatal(err)
622                         }
623                         mac.Sum(tag[:0])
624                         copy(buf[4+n:], tag[:vors.TagLen])
625                         pkt = buf[:4+n+vors.TagLen]
626
627                         OurStats.pkts++
628                         OurStats.bytes += vors.IPHdrLen(srvAddrUDP.IP) + 8 + uint64(len(pkt))
629                         OurStats.last = time.Now()
630                         OurStats.AddRMS(pcm)
631                         if _, err = conn.Write(pkt); err != nil {
632                                 log.Println("send:", err)
633                                 break
634                         }
635                 }
636         }()
637
638         err = GUI.MainLoop()
639         if err != nil && err != gocui.ErrQuit {
640                 log.Fatal(err)
641         }
642 }