// Sergey Matveev's repositories - vors.git/blob - cmd/client/main.go
// Rooms support
// [vors.git] / cmd / client / main.go
1 // VoRS -- Vo(IP) Really Simple
2 // Copyright (C) 2024 Sergey Matveev <stargrave@stargrave.org>
3 //
4 // This program is free software: you can redistribute it and/or modify
5 // it under the terms of the GNU Affero General Public License as
6 // published by the Free Software Foundation, version 3 of the License.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU General Public License for more details.
12 //
13 // You should have received a copy of the GNU General Public License
14 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
15
16 package main
17
18 import (
19         "bytes"
20         "crypto/subtle"
21         "encoding/binary"
22         "encoding/hex"
23         "flag"
24         "io"
25         "log"
26         "net"
27         "os"
28         "os/exec"
29         "strconv"
30         "strings"
31         "time"
32
33         "github.com/flynn/noise"
34         "github.com/jroimartin/gocui"
35         vors "go.stargrave.org/vors/internal"
36         "golang.org/x/crypto/blake2s"
37         "golang.org/x/crypto/chacha20"
38         "golang.org/x/crypto/poly1305"
39         "gopkg.in/hraban/opus.v2"
40 )
41
// Stream is the client-side state of one audio stream, identified by its
// stream ID (the key in Streams). Our own stream is registered too, but
// without an input channel.
type Stream struct {
	// name is the participant name, used in logs and passed to statsDrawer.
	name string
	// ctr is the packet counter of the last authenticated packet; the UDP
	// reader drops packets whose counter is not greater than it.
	ctr uint32
	// in carries raw received UDP packets (counter, ciphertext, tag) to the
	// stream's decoder goroutine; it is closed when the server deletes the
	// stream. nil for our own stream.
	in chan []byte
	// stats accumulates per-stream counters displayed by statsDrawer.
	stats *Stats
}
48
49 var (
50         Streams  = map[byte]*Stream{}
51         Finish   = make(chan struct{})
52         OurStats = &Stats{dead: make(chan struct{})}
53         Name     = flag.String("name", "test", "Username")
54         Room     = flag.String("room", "/", "Room name")
55         Muted    bool
56 )
57
// parseSID parses s as a decimal stream identifier and returns it as a byte.
// It terminates the program via log.Fatal when s is not an integer or does
// not fit into the 0..255 range of a byte.
func parseSID(s string) byte {
	n, err := strconv.Atoi(s)
	if err != nil {
		log.Fatal(err)
	}
	// Negative values must be rejected as well: byte(n) would silently wrap
	// them into a valid-looking SID (e.g. -1 -> 255).
	if n < 0 {
		log.Fatal("negative stream num")
	}
	if n > 255 {
		log.Fatal("too big stream num")
	}
	return byte(n)
}
68
// makeCmd splits cmd on whitespace and returns an *exec.Cmd that runs the
// first field with the remaining fields as its arguments. No shell quoting
// is interpreted. A command line containing no fields at all is fatal;
// previously a whitespace-only value panicked on args[0].
func makeCmd(cmd string) *exec.Cmd {
	args := strings.Fields(cmd)
	if len(args) == 0 {
		log.Fatalf("empty command: %q", cmd)
	}
	// args[1:] is empty for a bare program name, which exec.Command accepts.
	return exec.Command(args[0], args[1:]...)
}
76
// incr increments data in place as a big-endian unsigned integer. It panics
// with "overflow" when every byte wraps to zero (including empty data).
func incr(data []byte) {
	for pos := len(data); pos > 0; pos-- {
		data[pos-1]++
		if data[pos-1] > 0 {
			// No carry into the more significant byte: done.
			return
		}
	}
	panic("overflow")
}
86
// soxParams are the SoX arguments shared by the default rec and play
// commands: raw 16-bit signed little-endian mono PCM at 48 kHz on stdio.
const soxParams = "--no-show-progress --buffer 1920" +
	" --channels 1 --endian little --encoding signed" +
	" --rate 48000 --bits 16" +
	" --type raw -"
88
89 func main() {
90         srvAddr := flag.String("srv", "vors.home.arpa:"+strconv.Itoa(vors.DefaultPort),
91                 "Host:TCP/UDP port to connect to")
92         srvPubHex := flag.String("pub", "", "Server's public key, hex")
93         recCmd := flag.String("rec", "rec "+soxParams, "rec command")
94         playCmd := flag.String("play", "play "+soxParams, "play command")
95         vadRaw := flag.Uint("vad", 0, "VAD threshold")
96         passwd := flag.String("passwd", "", "Protected room's password")
97         flag.Parse()
98         log.SetFlags(log.Lmicroseconds | log.Lshortfile)
99
100         srvPub, err := hex.DecodeString(*srvPubHex)
101         if err != nil {
102                 log.Fatal(err)
103         }
104         *Name = strings.ReplaceAll(*Name, " ", "-")
105
106         vad := uint64(*vadRaw)
107         opusEnc := newOpusEnc()
108         var mic io.ReadCloser
109         if *recCmd != "" {
110                 cmd := makeCmd(*recCmd)
111                 mic, err = cmd.StdoutPipe()
112                 if err != nil {
113                         log.Fatal(err)
114                 }
115                 err = cmd.Start()
116                 if err != nil {
117                         log.Fatal(err)
118                 }
119         }
120
121         ctrl, err := net.DialTCP("tcp", nil, vors.MustResolveTCP(*srvAddr))
122         if err != nil {
123                 log.Fatalln("dial server:", err)
124         }
125         defer ctrl.Close()
126         if err = ctrl.SetNoDelay(true); err != nil {
127                 log.Fatalln("nodelay:", err)
128         }
129         if _, err = io.Copy(ctrl, strings.NewReader(vors.NoisePrologue)); err != nil {
130                 log.Fatalln("handshake: write prologue", err)
131                 return
132         }
133
134         hs, err := noise.NewHandshakeState(noise.Config{
135                 CipherSuite: vors.NoiseCipherSuite,
136                 Pattern:     noise.HandshakeNK,
137                 Initiator:   true,
138                 PeerStatic:  srvPub,
139                 Prologue:    []byte(vors.NoisePrologue),
140         })
141         if err != nil {
142                 log.Fatalln("noise.NewHandshakeState:", err)
143         }
144         buf, _, _, err := hs.WriteMessage(nil, []byte(*Name+" "+*Room+" "+*passwd))
145         if err != nil {
146                 log.Fatalln("handshake encrypt:", err)
147         }
148         if err = vors.PktWrite(ctrl, buf); err != nil {
149                 log.Fatalln("write handshake:", err)
150                 return
151         }
152         buf, err = vors.PktRead(ctrl)
153         if err != nil {
154                 log.Fatalln("read handshake:", err)
155         }
156         buf, txCS, rxCS, err := hs.ReadMessage(nil, buf)
157         if err != nil {
158                 log.Fatalln("handshake decrypt:", err)
159         }
160
161         rx := make(chan []byte)
162         go func() {
163                 for {
164                         buf, err := vors.PktRead(ctrl)
165                         if err != nil {
166                                 log.Println("rx", err)
167                                 break
168                         }
169                         buf, err = rxCS.Decrypt(buf[:0], nil, buf)
170                         if err != nil {
171                                 log.Println("rx decrypt", err)
172                                 break
173                         }
174                         rx <- buf
175                 }
176                 Finish <- struct{}{}
177         }()
178
179         srvAddrUDP := vors.MustResolveUDP(*srvAddr)
180         conn, err := net.DialUDP("udp", nil, srvAddrUDP)
181         if err != nil {
182                 log.Fatalln("connect:", err)
183         }
184         var sid byte
185         {
186                 cols := strings.Fields(string(buf))
187                 if cols[0] != "OK" || len(cols) != 2 {
188                         log.Fatalln("handshake failed:", cols)
189                 }
190                 var cookie vors.Cookie
191                 cookieRaw, err := hex.DecodeString(cols[1])
192                 if err != nil {
193                         log.Fatal(err)
194                 }
195                 copy(cookie[:], cookieRaw)
196                 timeout := time.NewTimer(vors.PingTime)
197                 defer func() {
198                         if !timeout.Stop() {
199                                 <-timeout.C
200                         }
201                 }()
202                 ticker := time.NewTicker(time.Second)
203                 if _, err = conn.Write(cookie[:]); err != nil {
204                         log.Fatalln("write:", err)
205                 }
206         WaitForCookieAcceptance:
207                 for {
208                         select {
209                         case <-timeout.C:
210                                 log.Fatalln("cookie acceptance timeout")
211                         case <-ticker.C:
212                                 if _, err = conn.Write(cookie[:]); err != nil {
213                                         log.Fatalln("write:", err)
214                                 }
215                         case buf = <-rx:
216                                 cols = strings.Fields(string(buf))
217                                 if cols[0] != "SID" || len(cols) != 2 {
218                                         log.Fatalln("cookie acceptance failed:", string(buf))
219                                 }
220                                 sid = parseSID(cols[1])
221                                 Streams[sid] = &Stream{name: *Name, stats: OurStats}
222                                 break WaitForCookieAcceptance
223                         }
224                 }
225                 if !timeout.Stop() {
226                         <-timeout.C
227                 }
228         }
229
230         var keyOur []byte
231         {
232                 h, err := blake2s.New256(hs.ChannelBinding())
233                 if err != nil {
234                         log.Fatalln(err)
235                 }
236                 h.Write([]byte(vors.NoisePrologue))
237                 keyOur = h.Sum(nil)
238         }
239
240         seen := time.Now()
241
242         LoggerReady := make(chan struct{})
243         GUI, err = gocui.NewGui(gocui.OutputNormal)
244         if err != nil {
245                 log.Fatal(err)
246         }
247         defer GUI.Close()
248         GUI.SelFgColor = gocui.ColorCyan
249         GUI.Highlight = true
250         GUI.SetManagerFunc(guiLayout)
251         if err := GUI.SetKeybinding("", gocui.KeyF10, gocui.ModNone, guiQuit); err != nil {
252                 log.Fatal(err)
253         }
254         if err := GUI.SetKeybinding("", gocui.KeyEnter, gocui.ModNone, mute); err != nil {
255                 log.Fatal(err)
256         }
257
258         go func() {
259                 <-GUIReadyC
260                 v, err := GUI.View("logs")
261                 if err != nil {
262                         log.Fatal(err)
263                 }
264                 log.SetOutput(v)
265                 log.Println("connected", "sid:", sid,
266                         "addr:", conn.LocalAddr().String())
267                 close(LoggerReady)
268                 for {
269                         time.Sleep(vors.ScreenRefresh)
270                         GUI.Update(func(gui *gocui.Gui) error {
271                                 return nil
272                         })
273                 }
274         }()
275
276         go func() {
277                 <-Finish
278                 go GUI.Close()
279                 time.Sleep(100 * time.Millisecond)
280                 os.Exit(0)
281         }()
282
283         go func() {
284                 for {
285                         time.Sleep(vors.PingTime)
286                         buf, err := txCS.Encrypt(nil, nil, []byte(vors.CmdPing))
287                         if err != nil {
288                                 log.Fatalln("tx encrypt:", err)
289                         }
290                         if err = vors.PktWrite(ctrl, buf); err != nil {
291                                 log.Fatalln("tx:", err)
292                         }
293                 }
294         }()
295
296         go func(seen *time.Time) {
297                 var now time.Time
298                 for buf := range rx {
299                         if string(buf) == vors.CmdPong {
300                                 now = time.Now()
301                                 *seen = now
302                                 continue
303                         }
304                         cols := strings.Fields(string(buf))
305                         switch cols[0] {
306                         case vors.CmdAdd:
307                                 sidRaw, name, keyHex := cols[1], cols[2], cols[3]
308                                 log.Println("add", name, "sid:", sidRaw)
309                                 sid := parseSID(sidRaw)
310                                 key, err := hex.DecodeString(keyHex)
311                                 if err != nil {
312                                         log.Fatal(err)
313                                 }
314                                 stream := &Stream{
315                                         name:  name,
316                                         in:    make(chan []byte, 1<<10),
317                                         stats: &Stats{dead: make(chan struct{})},
318                                 }
319                                 go func() {
320                                         dec, err := opus.NewDecoder(vors.Rate, 1)
321                                         if err != nil {
322                                                 log.Fatal(err)
323                                         }
324                                         if err = DecoderSetComplexity(dec, 10); err != nil {
325                                                 log.Fatal(err)
326                                         }
327
328                                         var player io.WriteCloser
329                                         playerTx := make(chan []byte, 5)
330                                         var cmd *exec.Cmd
331                                         if *playCmd != "" {
332                                                 cmd = makeCmd(*playCmd)
333                                                 player, err = cmd.StdinPipe()
334                                                 if err != nil {
335                                                         log.Fatal(err)
336                                                 }
337                                                 err = cmd.Start()
338                                                 if err != nil {
339                                                         log.Fatal(err)
340                                                 }
341                                                 go func() {
342                                                         var pcmbuf []byte
343                                                         var ok bool
344                                                         var err error
345                                                         for {
346                                                                 for len(playerTx) > 1 {
347                                                                         <-playerTx
348                                                                         stream.stats.reorder++
349                                                                 }
350                                                                 pcmbuf, ok = <-playerTx
351                                                                 if !ok {
352                                                                         break
353                                                                 }
354                                                                 if _, err = io.Copy(player,
355                                                                         bytes.NewReader(pcmbuf)); err != nil {
356                                                                         log.Println("play:", err)
357                                                                 }
358                                                         }
359                                                         cmd.Process.Kill()
360                                                 }()
361                                         }
362
363                                         var ciph *chacha20.Cipher
364                                         var macKey [32]byte
365                                         var mac *poly1305.MAC
366                                         tag := make([]byte, poly1305.TagSize)
367                                         var ctr uint32
368                                         pcm := make([]int16, vors.FrameLen)
369                                         nonce := make([]byte, 12)
370                                         var pkt []byte
371                                         lost := -1
372                                         var lastDur int
373                                         for buf := range stream.in {
374                                                 copy(nonce[len(nonce)-4:], buf)
375                                                 ciph, err = chacha20.NewUnauthenticatedCipher(key, nonce)
376                                                 if err != nil {
377                                                         log.Fatal(err)
378                                                 }
379                                                 clear(macKey[:])
380                                                 ciph.XORKeyStream(macKey[:], macKey[:])
381                                                 ciph.SetCounter(1)
382                                                 mac = poly1305.New(&macKey)
383                                                 if _, err = mac.Write(buf[4 : len(buf)-vors.TagLen]); err != nil {
384                                                         log.Fatal(err)
385                                                 }
386                                                 mac.Sum(tag[:0])
387                                                 if subtle.ConstantTimeCompare(
388                                                         tag[:vors.TagLen],
389                                                         buf[len(buf)-vors.TagLen:],
390                                                 ) != 1 {
391                                                         stream.stats.bads++
392                                                         continue
393                                                 }
394                                                 pkt = buf[4 : len(buf)-vors.TagLen]
395                                                 ciph.XORKeyStream(pkt, pkt)
396
397                                                 ctr = binary.BigEndian.Uint32(nonce[len(nonce)-4:])
398                                                 if lost == -1 {
399                                                         // ignore the very first packet in the stream
400                                                         lost = 0
401                                                 } else {
402                                                         lost = int(ctr - (stream.ctr + 1))
403                                                 }
404                                                 stream.ctr = ctr
405                                                 stream.stats.lost += int64(lost)
406                                                 if lost > vors.MaxLost {
407                                                         lost = 0
408                                                 }
409                                                 for ; lost > 0; lost-- {
410                                                         lastDur, err = dec.LastPacketDuration()
411                                                         if err != nil {
412                                                                 log.Println("PLC:", err)
413                                                                 continue
414                                                         }
415                                                         err = dec.DecodePLC(pcm[:lastDur])
416                                                         if err != nil {
417                                                                 log.Println("PLC:", err)
418                                                                 continue
419                                                         }
420                                                         stream.stats.AddRMS(pcm)
421                                                         if cmd == nil {
422                                                                 continue
423                                                         }
424                                                         pcmbuf := make([]byte, 2*lastDur)
425                                                         pcmConv(pcmbuf, pcm[:lastDur])
426                                                         playerTx <- pcmbuf
427                                                 }
428                                                 _, err = dec.Decode(pkt, pcm)
429                                                 if err != nil {
430                                                         log.Println("decode:", err)
431                                                         continue
432                                                 }
433                                                 stream.stats.AddRMS(pcm)
434                                                 stream.stats.last = time.Now()
435                                                 if cmd == nil {
436                                                         continue
437                                                 }
438                                                 pcmbuf := make([]byte, 2*len(pcm))
439                                                 pcmConv(pcmbuf, pcm)
440                                                 playerTx <- pcmbuf
441                                         }
442                                         if cmd != nil {
443                                                 close(playerTx)
444                                         }
445                                 }()
446                                 go statsDrawer(stream.stats, stream.name)
447                                 Streams[sid] = stream
448                         case vors.CmdDel:
449                                 sid := parseSID(cols[1])
450                                 s := Streams[sid]
451                                 if s == nil {
452                                         log.Println("unknown sid:", sid)
453                                         continue
454                                 }
455                                 log.Println("del", s.name, "sid:", cols[1])
456                                 delete(Streams, sid)
457                                 close(s.in)
458                                 close(s.stats.dead)
459                         default:
460                                 log.Fatal("unknown cmd:", cols[0])
461                         }
462                 }
463         }(&seen)
464
465         go func(seen *time.Time) {
466                 for now := range time.Tick(vors.PingTime) {
467                         if seen.Add(2 * vors.PingTime).Before(now) {
468                                 log.Println("timeout:", seen)
469                                 Finish <- struct{}{}
470                                 break
471                         }
472                 }
473         }(&seen)
474
475         go func() {
476                 <-LoggerReady
477                 var n int
478                 var from *net.UDPAddr
479                 var err error
480                 var stream *Stream
481                 var ctr uint32
482                 for {
483                         buf := make([]byte, 2*vors.FrameLen)
484                         n, from, err = conn.ReadFromUDP(buf)
485                         if err != nil {
486                                 log.Println("recvfrom:", err)
487                                 Finish <- struct{}{}
488                                 break
489                         }
490                         if from.Port != srvAddrUDP.Port || !from.IP.Equal(srvAddrUDP.IP) {
491                                 log.Println("wrong addr:", from)
492                                 continue
493                         }
494                         if n <= 4+vors.TagLen {
495                                 log.Println("too small:", n)
496                                 continue
497                         }
498                         stream = Streams[buf[0]]
499                         if stream == nil {
500                                 // log.Println("unknown stream:", buf[0])
501                                 continue
502                         }
503                         stream.stats.pkts++
504                         stream.stats.bytes += uint64(n)
505                         ctr = binary.BigEndian.Uint32(buf)
506                         if ctr <= stream.ctr {
507                                 stream.stats.reorder++
508                                 continue
509                         }
510                         stream.in <- buf[:n]
511                 }
512         }()
513
514         go statsDrawer(OurStats, *Name)
515         go func() {
516                 <-LoggerReady
517                 for {
518                         OurStats.pkts++
519                         OurStats.bytes += 1
520                         if _, err = conn.Write([]byte{sid}); err != nil {
521                                 log.Println("send:", err)
522                                 Finish <- struct{}{}
523                         }
524                         time.Sleep(time.Second)
525                 }
526         }()
527         go func() {
528                 if *recCmd == "" {
529                         return
530                 }
531                 <-LoggerReady
532                 var ciph *chacha20.Cipher
533                 var macKey [32]byte
534                 var mac *poly1305.MAC
535                 tag := make([]byte, poly1305.TagSize)
536                 buf := make([]byte, 2*vors.FrameLen)
537                 pcm := make([]int16, vors.FrameLen)
538                 nonce := make([]byte, 12)
539                 nonce[len(nonce)-4] = sid
540                 var pkt []byte
541                 var n, i int
542                 for {
543                         _, err = io.ReadFull(mic, buf)
544                         if err != nil {
545                                 log.Println("mic:", err)
546                                 break
547                         }
548                         if Muted {
549                                 continue
550                         }
551                         for i = 0; i < vors.FrameLen; i++ {
552                                 pcm[i] = int16(uint16(buf[i*2+0]) | (uint16(buf[i*2+1]) << 8))
553                         }
554                         if vad != 0 && vors.RMS(pcm) < vad {
555                                 continue
556                         }
557                         n, err = opusEnc.Encode(pcm, buf[4:])
558                         if err != nil {
559                                 log.Fatal(err)
560                         }
561                         if n <= 2 {
562                                 // DTX
563                                 continue
564                         }
565
566                         incr(nonce[len(nonce)-3:])
567                         copy(buf, nonce[len(nonce)-4:])
568                         ciph, err = chacha20.NewUnauthenticatedCipher(keyOur, nonce)
569                         if err != nil {
570                                 log.Fatal(err)
571                         }
572                         clear(macKey[:])
573                         ciph.XORKeyStream(macKey[:], macKey[:])
574                         ciph.SetCounter(1)
575                         ciph.XORKeyStream(buf[4:4+n], buf[4:4+n])
576                         mac = poly1305.New(&macKey)
577                         if _, err = mac.Write(buf[4 : 4+n]); err != nil {
578                                 log.Fatal(err)
579                         }
580                         mac.Sum(tag[:0])
581                         copy(buf[4+n:], tag[:vors.TagLen])
582                         pkt = buf[:4+n+vors.TagLen]
583
584                         OurStats.pkts++
585                         OurStats.bytes += uint64(len(pkt))
586                         OurStats.last = time.Now()
587                         OurStats.AddRMS(pcm)
588                         if _, err = conn.Write(pkt); err != nil {
589                                 log.Println("send:", err)
590                                 break
591                         }
592                 }
593         }()
594
595         err = GUI.MainLoop()
596         if err != nil && err != gocui.ErrQuit {
597                 log.Fatal(err)
598         }
599 }