]> Sergey Matveev's repositories - vors.git/blob - cmd/client/main.go
fbc355c47d79e506cb95a9df76d3771ee94d07eded53bad9286f8363ab4bf35e
[vors.git] / cmd / client / main.go
1 // VoRS -- Vo(IP) Really Simple
2 // Copyright (C) 2024 Sergey Matveev <stargrave@stargrave.org>
3 //
4 // This program is free software: you can redistribute it and/or modify
5 // it under the terms of the GNU Affero General Public License as
6 // published by the Free Software Foundation, version 3 of the License.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU Affero General Public License for more details.
12 //
13 // You should have received a copy of the GNU Affero General Public License
14 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
15
16 package main
17
import (
	"bytes"
	"crypto/subtle"
	"encoding/base64"
	"encoding/binary"
	"encoding/hex"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"os/exec"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/dchest/siphash"
	"github.com/flynn/noise"
	"github.com/jroimartin/gocui"
	"go.stargrave.org/opus/v2"
	vors "go.stargrave.org/vors/v3/internal"
	"golang.org/x/crypto/blake2s"
	"golang.org/x/crypto/chacha20"
)
43
44 type Stream struct {
45         name  string
46         ctr   uint32
47         in    chan []byte
48         stats *Stats
49 }
50
51 var (
52         Streams  = map[byte]*Stream{}
53         Finish   = make(chan struct{})
54         OurStats = &Stats{dead: make(chan struct{})}
55         Name     = flag.String("name", "test", "username")
56         Room     = flag.String("room", "/", "room name")
57         Muted    bool
58 )
59
// parseSID converts a decimal stream identifier received from the server
// into a byte. Anything that is not an integer in the range 0..255 is fatal:
// negative values used to slip through and silently wrap via byte(n).
func parseSID(s string) byte {
	n, err := strconv.Atoi(s)
	if err != nil {
		log.Fatal(err)
	}
	if n < 0 {
		log.Fatal("negative stream num")
	}
	if n > 255 {
		log.Fatal("too big stream num")
	}
	return byte(n)
}
70
// incr adds one, in place, to data interpreted as a big-endian unsigned
// integer. It panics with "overflow" when every byte wraps around to zero
// (including the degenerate case of an empty slice).
func incr(data []byte) {
	pos := len(data)
	for pos > 0 {
		pos--
		data[pos]++
		if data[pos] > 0 {
			// No carry into the next, more significant byte.
			return
		}
	}
	panic("overflow")
}
80
81 func main() {
82         srvAddr := flag.String("srv", "vors.home.arpa:"+strconv.Itoa(vors.DefaultPort),
83                 "host:TCP/UDP port to connect to")
84         srvPubB64 := flag.String("pub", "", "server's public key, Base64")
85         recCmd := flag.String("rec", "rec "+vors.SoxParams, "rec command")
86         playCmd := flag.String("play", "play "+vors.SoxParams, "play command")
87         vadRaw := flag.Uint("vad", 0, "VAD threshold")
88         passwd := flag.String("passwd", "", "protected room's password")
89         muteToggle := flag.String("mute-toggle", "", "path to FIFO to toggle mute")
90         prefer4 := flag.Bool("4", false,
91                 "Prefer obsolete legacy IPv4 address during name resolution")
92         version := flag.Bool("version", false, "print version")
93         warranty := flag.Bool("warranty", false, "print warranty information")
94         flag.Parse()
95         log.SetFlags(log.Lmicroseconds | log.Lshortfile)
96
97         if *warranty {
98                 fmt.Println(vors.Warranty)
99                 return
100         }
101         if *version {
102                 fmt.Println(vors.GetVersion())
103                 return
104         }
105
106         srvPub, err := base64.RawURLEncoding.DecodeString(*srvPubB64)
107         if err != nil {
108                 log.Fatal(err)
109         }
110         *Name = strings.ReplaceAll(*Name, " ", "-")
111
112         go func() {
113                 if *muteToggle == "" {
114                         return
115                 }
116                 for {
117                         fd, err := os.OpenFile(*muteToggle, os.O_WRONLY, os.FileMode(0666))
118                         if err != nil {
119                                 log.Fatalln(err)
120                         }
121                         Muted = !Muted
122                         var reply string
123                         if Muted {
124                                 reply = "muted"
125                         } else {
126                                 reply = "unmuted"
127                         }
128                         fd.WriteString(reply + "\n")
129                         fd.Close()
130                         time.Sleep(time.Second)
131                 }
132         }()
133
134         vad := uint64(*vadRaw)
135         opusEnc := newOpusEnc()
136         var mic io.ReadCloser
137         if *recCmd != "" {
138                 cmd := vors.MakeCmd(*recCmd)
139                 mic, err = cmd.StdoutPipe()
140                 if err != nil {
141                         log.Fatal(err)
142                 }
143                 err = cmd.Start()
144                 if err != nil {
145                         log.Fatal(err)
146                 }
147         }
148
149         vors.PreferIPv4 = *prefer4
150         ctrl, err := net.DialTCP("tcp", nil, vors.MustResolveTCP(*srvAddr))
151         if err != nil {
152                 log.Fatalln("dial server:", err)
153         }
154         defer ctrl.Close()
155         if err = ctrl.SetNoDelay(true); err != nil {
156                 log.Fatalln("nodelay:", err)
157         }
158
159         hs, err := noise.NewHandshakeState(noise.Config{
160                 CipherSuite: vors.NoiseCipherSuite,
161                 Pattern:     noise.HandshakeNK,
162                 Initiator:   true,
163                 PeerStatic:  srvPub,
164                 Prologue:    []byte(vors.NoisePrologue),
165         })
166         if err != nil {
167                 log.Fatalln("noise.NewHandshakeState:", err)
168         }
169         buf, _, _, err := hs.WriteMessage(nil, []byte(*Name+" "+*Room+" "+*passwd))
170         if err != nil {
171                 log.Fatalln("handshake encrypt:", err)
172         }
173         buf = append(
174                 append(
175                         []byte(vors.NoisePrologue),
176                         byte((len(buf)&0xFF00)>>8),
177                         byte((len(buf)&0x00FF)>>0),
178                 ),
179                 buf...,
180         )
181         _, err = io.Copy(ctrl, bytes.NewReader(buf))
182         if err != nil {
183                 log.Fatalln("write handshake:", err)
184                 return
185         }
186         buf, err = vors.PktRead(ctrl)
187         if err != nil {
188                 log.Fatalln("read handshake:", err)
189         }
190         buf, txCS, rxCS, err := hs.ReadMessage(nil, buf)
191         if err != nil {
192                 log.Fatalln("handshake decrypt:", err)
193         }
194
195         rx := make(chan []byte)
196         go func() {
197                 for {
198                         buf, err := vors.PktRead(ctrl)
199                         if err != nil {
200                                 log.Println("rx", err)
201                                 break
202                         }
203                         buf, err = rxCS.Decrypt(buf[:0], nil, buf)
204                         if err != nil {
205                                 log.Println("rx decrypt", err)
206                                 break
207                         }
208                         rx <- buf
209                 }
210                 Finish <- struct{}{}
211         }()
212
213         srvAddrUDP := vors.MustResolveUDP(*srvAddr)
214         conn, err := net.DialUDP("udp", nil, srvAddrUDP)
215         if err != nil {
216                 log.Fatalln("connect:", err)
217         }
218         var sid byte
219         {
220                 cols := strings.Fields(string(buf))
221                 if cols[0] != "OK" || len(cols) != 2 {
222                         log.Fatalln("handshake failed:", cols)
223                 }
224                 var cookie vors.Cookie
225                 cookieRaw, err := hex.DecodeString(cols[1])
226                 if err != nil {
227                         log.Fatal(err)
228                 }
229                 copy(cookie[:], cookieRaw)
230                 timeout := time.NewTimer(vors.PingTime)
231                 defer func() {
232                         if !timeout.Stop() {
233                                 <-timeout.C
234                         }
235                 }()
236                 ticker := time.NewTicker(time.Second)
237                 if _, err = conn.Write(cookie[:]); err != nil {
238                         log.Fatalln("write:", err)
239                 }
240         WaitForCookieAcceptance:
241                 for {
242                         select {
243                         case <-timeout.C:
244                                 log.Fatalln("cookie acceptance timeout")
245                         case <-ticker.C:
246                                 if _, err = conn.Write(cookie[:]); err != nil {
247                                         log.Fatalln("write:", err)
248                                 }
249                         case buf = <-rx:
250                                 cols = strings.Fields(string(buf))
251                                 if cols[0] != "SID" || len(cols) != 2 {
252                                         log.Fatalln("cookie acceptance failed:", string(buf))
253                                 }
254                                 sid = parseSID(cols[1])
255                                 Streams[sid] = &Stream{name: *Name, stats: OurStats}
256                                 break WaitForCookieAcceptance
257                         }
258                 }
259                 if !timeout.Stop() {
260                         <-timeout.C
261                 }
262         }
263
264         var keyCiphOur []byte
265         var keyMACOur []byte
266         {
267                 xof, err := blake2s.NewXOF(chacha20.KeySize+16, nil)
268                 if err != nil {
269                         log.Fatalln(err)
270                 }
271                 xof.Write([]byte(vors.NoisePrologue))
272                 xof.Write(hs.ChannelBinding())
273                 buf := make([]byte, chacha20.KeySize+16)
274                 if _, err = io.ReadFull(xof, buf); err != nil {
275                         log.Fatalln(err)
276                 }
277                 keyCiphOur, keyMACOur = buf[:chacha20.KeySize], buf[chacha20.KeySize:]
278         }
279
280         seen := time.Now()
281
282         LoggerReady := make(chan struct{})
283         GUI, err = gocui.NewGui(gocui.OutputNormal)
284         if err != nil {
285                 log.Fatal(err)
286         }
287         defer GUI.Close()
288         GUI.SelFgColor = gocui.ColorCyan
289         GUI.Highlight = true
290         GUI.SetManagerFunc(guiLayout)
291         if err := GUI.SetKeybinding("", 'q', gocui.ModNone, guiQuit); err != nil {
292                 log.Fatal(err)
293         }
294         if err := GUI.SetKeybinding("", gocui.KeyEnter, gocui.ModNone, mute); err != nil {
295                 log.Fatal(err)
296         }
297
298         go func() {
299                 <-GUIReadyC
300                 v, err := GUI.View("logs")
301                 if err != nil {
302                         log.Fatal(err)
303                 }
304                 log.SetOutput(v)
305                 log.Println("connected", "sid:", sid,
306                         "addr:", conn.LocalAddr().String())
307                 close(LoggerReady)
308                 for {
309                         time.Sleep(vors.ScreenRefresh)
310                         GUI.Update(func(gui *gocui.Gui) error {
311                                 return nil
312                         })
313                 }
314         }()
315
316         go func() {
317                 <-Finish
318                 go GUI.Close()
319                 time.Sleep(100 * time.Millisecond)
320                 os.Exit(0)
321         }()
322
323         go func() {
324                 for {
325                         time.Sleep(vors.PingTime)
326                         buf, err := txCS.Encrypt(nil, nil, []byte(vors.CmdPing))
327                         if err != nil {
328                                 log.Fatalln("tx encrypt:", err)
329                         }
330                         if err = vors.PktWrite(ctrl, buf); err != nil {
331                                 log.Fatalln("tx:", err)
332                         }
333                 }
334         }()
335
336         go func(seen *time.Time) {
337                 var now time.Time
338                 for buf := range rx {
339                         if string(buf) == vors.CmdPong {
340                                 now = time.Now()
341                                 *seen = now
342                                 continue
343                         }
344                         cols := strings.Fields(string(buf))
345                         switch cols[0] {
346                         case vors.CmdAdd:
347                                 sidRaw, name, keyHex := cols[1], cols[2], cols[3]
348                                 log.Println("add", name, "sid:", sidRaw)
349                                 sid := parseSID(sidRaw)
350                                 key, err := hex.DecodeString(keyHex)
351                                 if err != nil {
352                                         log.Fatal(err)
353                                 }
354                                 keyCiph, keyMAC := key[:chacha20.KeySize], key[chacha20.KeySize:]
355                                 stream := &Stream{
356                                         name:  name,
357                                         in:    make(chan []byte, 1<<10),
358                                         stats: &Stats{dead: make(chan struct{})},
359                                 }
360                                 go func() {
361                                         dec, err := opus.NewDecoder(vors.Rate, 1)
362                                         if err != nil {
363                                                 log.Fatal(err)
364                                         }
365                                         if err = dec.SetComplexity(10); err != nil {
366                                                 log.Fatal(err)
367                                         }
368
369                                         var player io.WriteCloser
370                                         playerTx := make(chan []byte, 5)
371                                         var cmd *exec.Cmd
372                                         if *playCmd != "" {
373                                                 cmd = vors.MakeCmd(*playCmd)
374                                                 player, err = cmd.StdinPipe()
375                                                 if err != nil {
376                                                         log.Fatal(err)
377                                                 }
378                                                 err = cmd.Start()
379                                                 if err != nil {
380                                                         log.Fatal(err)
381                                                 }
382                                                 go func() {
383                                                         var pcmbuf []byte
384                                                         var ok bool
385                                                         var err error
386                                                         for {
387                                                                 for len(playerTx) > vors.MaxLost {
388                                                                         <-playerTx
389                                                                         stream.stats.reorder++
390                                                                 }
391                                                                 pcmbuf, ok = <-playerTx
392                                                                 if !ok {
393                                                                         break
394                                                                 }
395                                                                 if _, err = io.Copy(player,
396                                                                         bytes.NewReader(pcmbuf)); err != nil {
397                                                                         log.Println("play:", err)
398                                                                 }
399                                                         }
400                                                         cmd.Process.Kill()
401                                                 }()
402                                         }
403
404                                         var ciph *chacha20.Cipher
405                                         mac := siphash.New(keyMAC)
406                                         tag := make([]byte, siphash.Size)
407                                         var ctr uint32
408                                         pcm := make([]int16, vors.FrameLen)
409                                         nonce := make([]byte, 12)
410                                         var pkt []byte
411                                         lost := -1
412                                         var lastDur int
413                                         for buf := range stream.in {
414                                                 copy(nonce[len(nonce)-4:], buf)
415                                                 mac.Reset()
416                                                 if _, err = mac.Write(buf[:len(buf)-siphash.Size]); err != nil {
417                                                         log.Fatal(err)
418                                                 }
419                                                 mac.Sum(tag[:0])
420                                                 if subtle.ConstantTimeCompare(
421                                                         tag[:siphash.Size],
422                                                         buf[len(buf)-siphash.Size:],
423                                                 ) != 1 {
424                                                         stream.stats.bads++
425                                                         continue
426                                                 }
427                                                 ciph, err = chacha20.NewUnauthenticatedCipher(keyCiph, nonce)
428                                                 if err != nil {
429                                                         log.Fatal(err)
430                                                 }
431                                                 pkt = buf[4 : len(buf)-siphash.Size]
432                                                 ciph.XORKeyStream(pkt, pkt)
433
434                                                 ctr = binary.BigEndian.Uint32(nonce[len(nonce)-4:])
435                                                 if lost == -1 {
436                                                         // ignore the very first packet in the stream
437                                                         lost = 0
438                                                 } else {
439                                                         lost = int(ctr - (stream.ctr + 1))
440                                                 }
441                                                 stream.ctr = ctr
442                                                 stream.stats.lost += int64(lost)
443                                                 if lost > vors.MaxLost {
444                                                         lost = 0
445                                                 }
446                                                 for ; lost > 0; lost-- {
447                                                         lastDur, err = dec.LastPacketDuration()
448                                                         if err != nil {
449                                                                 log.Println("PLC:", err)
450                                                                 continue
451                                                         }
452                                                         err = dec.DecodePLC(pcm[:lastDur])
453                                                         if err != nil {
454                                                                 log.Println("PLC:", err)
455                                                                 continue
456                                                         }
457                                                         stream.stats.AddRMS(pcm)
458                                                         if cmd == nil {
459                                                                 continue
460                                                         }
461                                                         pcmbuf := make([]byte, 2*lastDur)
462                                                         pcmConv(pcmbuf, pcm[:lastDur])
463                                                         playerTx <- pcmbuf
464                                                 }
465                                                 _, err = dec.Decode(pkt, pcm)
466                                                 if err != nil {
467                                                         log.Println("decode:", err)
468                                                         continue
469                                                 }
470                                                 stream.stats.AddRMS(pcm)
471                                                 stream.stats.last = time.Now()
472                                                 if cmd == nil {
473                                                         continue
474                                                 }
475                                                 pcmbuf := make([]byte, 2*len(pcm))
476                                                 pcmConv(pcmbuf, pcm)
477                                                 playerTx <- pcmbuf
478                                         }
479                                         if cmd != nil {
480                                                 close(playerTx)
481                                         }
482                                 }()
483                                 go statsDrawer(stream.stats, stream.name)
484                                 Streams[sid] = stream
485                         case vors.CmdDel:
486                                 sid := parseSID(cols[1])
487                                 s := Streams[sid]
488                                 if s == nil {
489                                         log.Println("unknown sid:", sid)
490                                         continue
491                                 }
492                                 log.Println("del", s.name, "sid:", cols[1])
493                                 delete(Streams, sid)
494                                 close(s.in)
495                                 close(s.stats.dead)
496                         default:
497                                 log.Fatal("unknown cmd:", cols[0])
498                         }
499                 }
500         }(&seen)
501
502         go func(seen *time.Time) {
503                 for now := range time.Tick(vors.PingTime) {
504                         if seen.Add(2 * vors.PingTime).Before(now) {
505                                 log.Println("timeout:", seen)
506                                 Finish <- struct{}{}
507                                 break
508                         }
509                 }
510         }(&seen)
511
512         go func() {
513                 <-LoggerReady
514                 var n int
515                 var from *net.UDPAddr
516                 var err error
517                 var stream *Stream
518                 var ctr uint32
519                 for {
520                         buf := make([]byte, 2*vors.FrameLen)
521                         n, from, err = conn.ReadFromUDP(buf)
522                         if err != nil {
523                                 log.Println("recvfrom:", err)
524                                 Finish <- struct{}{}
525                                 break
526                         }
527                         if from.Port != srvAddrUDP.Port || !from.IP.Equal(srvAddrUDP.IP) {
528                                 log.Println("wrong addr:", from)
529                                 continue
530                         }
531                         if n <= 4+siphash.Size {
532                                 log.Println("too small:", n)
533                                 continue
534                         }
535                         stream = Streams[buf[0]]
536                         if stream == nil {
537                                 // log.Println("unknown stream:", buf[0])
538                                 continue
539                         }
540                         stream.stats.pkts++
541                         stream.stats.bytes += vors.IPHdrLen(from.IP) + 8 + uint64(n)
542                         ctr = binary.BigEndian.Uint32(buf)
543                         if ctr <= stream.ctr {
544                                 stream.stats.reorder++
545                                 continue
546                         }
547                         stream.in <- buf[:n]
548                 }
549         }()
550
551         go statsDrawer(OurStats, *Name)
552         go func() {
553                 <-LoggerReady
554                 for now := range time.NewTicker(time.Second).C {
555                         if !OurStats.last.Add(time.Second).Before(now) {
556                                 continue
557                         }
558                         OurStats.pkts++
559                         OurStats.bytes += vors.IPHdrLen(srvAddrUDP.IP) + 8 + 1
560                         if _, err = conn.Write([]byte{sid}); err != nil {
561                                 log.Println("send:", err)
562                         }
563                 }
564         }()
565         go func() {
566                 if *recCmd == "" {
567                         return
568                 }
569                 <-LoggerReady
570                 var ciph *chacha20.Cipher
571                 mac := siphash.New(keyMACOur)
572                 tag := make([]byte, siphash.Size)
573                 buf := make([]byte, 2*vors.FrameLen)
574                 pcm := make([]int16, vors.FrameLen)
575                 nonce := make([]byte, 12)
576                 nonce[len(nonce)-4] = sid
577                 var pkt []byte
578                 var n, i int
579                 for {
580                         _, err = io.ReadFull(mic, buf)
581                         if err != nil {
582                                 log.Println("mic:", err)
583                                 break
584                         }
585                         if Muted {
586                                 continue
587                         }
588                         for i = 0; i < vors.FrameLen; i++ {
589                                 pcm[i] = int16(uint16(buf[i*2+0]) | (uint16(buf[i*2+1]) << 8))
590                         }
591                         if vad != 0 && vors.RMS(pcm) < vad {
592                                 continue
593                         }
594                         n, err = opusEnc.Encode(pcm, buf[4:])
595                         if err != nil {
596                                 log.Fatal(err)
597                         }
598                         if n <= 2 {
599                                 // DTX
600                                 continue
601                         }
602
603                         incr(nonce[len(nonce)-3:])
604                         copy(buf, nonce[len(nonce)-4:])
605                         ciph, err = chacha20.NewUnauthenticatedCipher(keyCiphOur, nonce)
606                         if err != nil {
607                                 log.Fatal(err)
608                         }
609                         ciph.XORKeyStream(buf[4:4+n], buf[4:4+n])
610                         mac.Reset()
611                         if _, err = mac.Write(buf[:4+n]); err != nil {
612                                 log.Fatal(err)
613                         }
614                         mac.Sum(tag[:0])
615                         copy(buf[4+n:], tag)
616                         pkt = buf[:4+n+siphash.Size]
617
618                         OurStats.pkts++
619                         OurStats.bytes += vors.IPHdrLen(srvAddrUDP.IP) + 8 + uint64(len(pkt))
620                         OurStats.last = time.Now()
621                         OurStats.AddRMS(pcm)
622                         if _, err = conn.Write(pkt); err != nil {
623                                 log.Println("send:", err)
624                         }
625                 }
626         }()
627
628         err = GUI.MainLoop()
629         if err != nil && err != gocui.ErrQuit {
630                 log.Fatal(err)
631         }
632 }