]> Sergey Matveev's repositories - vors.git/blob - cmd/client/main.go
c0fc01550aa9a39223a50b3cdff480c22e826f3b2ab8bbcc0b930ac7d5736ca6
[vors.git] / cmd / client / main.go
1 // VoRS -- Vo(IP) Really Simple
2 // Copyright (C) 2024 Sergey Matveev <stargrave@stargrave.org>
3 //
4 // This program is free software: you can redistribute it and/or modify
5 // it under the terms of the GNU Affero General Public License as
6 // published by the Free Software Foundation, version 3 of the License.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU General Public License for more details.
12 //
13 // You should have received a copy of the GNU General Public License
14 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
15
16 package main
17
18 import (
19         "bytes"
20         "crypto/subtle"
21         "encoding/binary"
22         "encoding/hex"
23         "flag"
24         "io"
25         "log"
26         "net"
27         "os"
28         "os/exec"
29         "strconv"
30         "strings"
31         "time"
32
33         "github.com/flynn/noise"
34         "github.com/jroimartin/gocui"
35         vors "go.stargrave.org/vors/internal"
36         "golang.org/x/crypto/blake2s"
37         "golang.org/x/crypto/chacha20"
38         "golang.org/x/crypto/poly1305"
39         "gopkg.in/hraban/opus.v2"
40 )
41
// Stream is the client-side state of one audio stream, stored in Streams
// under its one-byte stream id. Our own stream is registered too, but
// without an input channel.
type Stream struct {
	name  string      // participant name: our -name, or the name from an add command
	ctr   uint32      // first 4 bytes (big-endian) of the last decoded packet; older/equal ones are dropped
	in    chan []byte // raw UDP packets awaiting decryption/decoding; nil for our own stream
	stats *Stats      // per-stream statistics (OurStats for our own stream)
}
48
49 var (
50         Streams  = map[byte]*Stream{}
51         Finish   = make(chan struct{})
52         OurStats = &Stats{dead: make(chan struct{})}
53         Name     = flag.String("name", "test", "Username")
54         Muted    bool
55 )
56
// parseSID parses a decimal stream identifier received from the server
// (in "SID", add and del control messages). A stream id must fit into a
// single byte, so anything that is not an integer in 0..255 is a fatal
// protocol error.
func parseSID(s string) byte {
	n, err := strconv.Atoi(s)
	if err != nil {
		log.Fatal(err)
	}
	// Without this check a negative value would silently wrap in byte(n).
	if n < 0 {
		log.Fatal("negative stream num")
	}
	if n > 255 {
		log.Fatal("too big stream num")
	}
	return byte(n)
}
67
// makeCmd builds an *exec.Cmd from a command line such as
// "play --rate 48000 -". The line is split on whitespace only: no shell
// quoting or escaping is interpreted. A blank command line is fatal,
// instead of panicking on a missing program name.
func makeCmd(cmd string) *exec.Cmd {
	args := strings.Fields(cmd)
	if len(args) == 0 {
		log.Fatalf("empty command: %q", cmd)
	}
	// exec.Command accepts an empty variadic tail, so no special case is
	// needed for argument-less commands.
	return exec.Command(args[0], args[1:]...)
}
75
// incr treats data as a big-endian unsigned integer and adds one to it in
// place. If every byte is 0xFF (or data is empty) the counter cannot grow.
// In that case the bytes are left zeroed and incr panics with "overflow".
func incr(data []byte) {
	pos := len(data) - 1
	// Trailing 0xFF bytes roll over to zero and carry leftwards.
	for pos >= 0 && data[pos] == 0xFF {
		data[pos] = 0
		pos--
	}
	if pos < 0 {
		panic("overflow")
	}
	data[pos]++
}
85
// soxParams are the SoX options appended to both the default -rec and -play
// commands. They select raw mono signed 16-bit little-endian samples at
// 48000 Hz on the standard stream ("-") that the client pipes, without a
// progress display.
const soxParams = "--no-show-progress --buffer 1920" +
	" --channels 1 --endian little --encoding signed" +
	" --rate 48000 --bits 16 --type raw -"
87
88 func main() {
89         srvAddr := flag.String("srv", "vors.home.arpa:"+strconv.Itoa(vors.DefaultPort),
90                 "Host:TCP/UDP port to connect to")
91         srvPubHex := flag.String("pub", "", "Server's public key, hex")
92         recCmd := flag.String("rec", "rec "+soxParams, "rec command")
93         playCmd := flag.String("play", "play "+soxParams, "play command")
94         vadRaw := flag.Uint("vad", 0, "VAD threshold")
95         flag.Parse()
96         log.SetFlags(log.Lmicroseconds | log.Lshortfile)
97
98         srvPub, err := hex.DecodeString(*srvPubHex)
99         if err != nil {
100                 log.Fatal(err)
101         }
102
103         vad := uint64(*vadRaw)
104         opusEnc := newOpusEnc()
105         var mic io.ReadCloser
106         if *recCmd != "" {
107                 cmd := makeCmd(*recCmd)
108                 mic, err = cmd.StdoutPipe()
109                 if err != nil {
110                         log.Fatal(err)
111                 }
112                 err = cmd.Start()
113                 if err != nil {
114                         log.Fatal(err)
115                 }
116         }
117
118         ctrl, err := net.DialTCP("tcp", nil, vors.MustResolveTCP(*srvAddr))
119         if err != nil {
120                 log.Fatalln("dial server:", err)
121         }
122         defer ctrl.Close()
123         if err = ctrl.SetNoDelay(true); err != nil {
124                 log.Fatalln("nodelay:", err)
125         }
126         if _, err = io.Copy(ctrl, strings.NewReader(vors.NoisePrologue)); err != nil {
127                 log.Fatalln("handshake: write prologue", err)
128                 return
129         }
130
131         hs, err := noise.NewHandshakeState(noise.Config{
132                 CipherSuite: vors.NoiseCipherSuite,
133                 Pattern:     noise.HandshakeNK,
134                 Initiator:   true,
135                 PeerStatic:  srvPub,
136                 Prologue:    []byte(vors.NoisePrologue),
137         })
138         if err != nil {
139                 log.Fatalln("noise.NewHandshakeState:", err)
140         }
141         buf, _, _, err := hs.WriteMessage(nil, []byte(*Name))
142         if err != nil {
143                 log.Fatalln("handshake encrypt:", err)
144         }
145         if err = vors.PktWrite(ctrl, buf); err != nil {
146                 log.Fatalln("write handshake:", err)
147                 return
148         }
149         buf, err = vors.PktRead(ctrl)
150         if err != nil {
151                 log.Fatalln("read handshake:", err)
152         }
153         buf, txCS, rxCS, err := hs.ReadMessage(nil, buf)
154         if err != nil {
155                 log.Fatalln("handshake decrypt:", err)
156         }
157
158         rx := make(chan []byte)
159         go func() {
160                 for {
161                         buf, err := vors.PktRead(ctrl)
162                         if err != nil {
163                                 log.Println("rx", err)
164                                 break
165                         }
166                         buf, err = rxCS.Decrypt(buf[:0], nil, buf)
167                         if err != nil {
168                                 log.Println("rx decrypt", err)
169                                 break
170                         }
171                         rx <- buf
172                 }
173                 Finish <- struct{}{}
174         }()
175
176         srvAddrUDP := vors.MustResolveUDP(*srvAddr)
177         conn, err := net.DialUDP("udp", nil, srvAddrUDP)
178         if err != nil {
179                 log.Fatalln("connect:", err)
180         }
181         var sid byte
182         {
183                 cols := strings.Fields(string(buf))
184                 if cols[0] != "OK" || len(cols) != 2 {
185                         log.Fatalln("handshake failed:", cols)
186                 }
187                 var cookie vors.Cookie
188                 cookieRaw, err := hex.DecodeString(cols[1])
189                 if err != nil {
190                         log.Fatal(err)
191                 }
192                 copy(cookie[:], cookieRaw)
193                 timeout := time.NewTimer(vors.PingTime)
194                 defer func() {
195                         if !timeout.Stop() {
196                                 <-timeout.C
197                         }
198                 }()
199                 ticker := time.NewTicker(time.Second)
200                 if _, err = conn.Write(cookie[:]); err != nil {
201                         log.Fatalln("write:", err)
202                 }
203         WaitForCookieAcceptance:
204                 for {
205                         select {
206                         case <-timeout.C:
207                                 log.Fatalln("cookie acceptance timeout")
208                         case <-ticker.C:
209                                 if _, err = conn.Write(cookie[:]); err != nil {
210                                         log.Fatalln("write:", err)
211                                 }
212                         case buf = <-rx:
213                                 cols = strings.Fields(string(buf))
214                                 if cols[0] != "SID" || len(cols) != 2 {
215                                         log.Fatalln("cookie acceptance failed:", string(buf))
216                                 }
217                                 sid = parseSID(cols[1])
218                                 Streams[sid] = &Stream{name: *Name, stats: OurStats}
219                                 break WaitForCookieAcceptance
220                         }
221                 }
222                 if !timeout.Stop() {
223                         <-timeout.C
224                 }
225         }
226
227         var keyOur []byte
228         {
229                 h, err := blake2s.New256(hs.ChannelBinding())
230                 if err != nil {
231                         log.Fatalln(err)
232                 }
233                 h.Write([]byte(vors.NoisePrologue))
234                 keyOur = h.Sum(nil)
235         }
236
237         seen := time.Now()
238
239         LoggerReady := make(chan struct{})
240         GUI, err = gocui.NewGui(gocui.OutputNormal)
241         if err != nil {
242                 log.Fatal(err)
243         }
244         defer GUI.Close()
245         GUI.SelFgColor = gocui.ColorCyan
246         GUI.Highlight = true
247         GUI.SetManagerFunc(guiLayout)
248         if err := GUI.SetKeybinding("", gocui.KeyF10, gocui.ModNone, guiQuit); err != nil {
249                 log.Fatal(err)
250         }
251         if err := GUI.SetKeybinding("", gocui.KeyEnter, gocui.ModNone, mute); err != nil {
252                 log.Fatal(err)
253         }
254
255         go func() {
256                 <-GUIReadyC
257                 v, err := GUI.View("logs")
258                 if err != nil {
259                         log.Fatal(err)
260                 }
261                 log.SetOutput(v)
262                 log.Println("connected", "sid:", sid,
263                         "addr:", conn.LocalAddr().String())
264                 close(LoggerReady)
265                 for {
266                         time.Sleep(vors.ScreenRefresh)
267                         GUI.Update(func(gui *gocui.Gui) error {
268                                 return nil
269                         })
270                 }
271         }()
272
273         go func() {
274                 <-Finish
275                 go GUI.Close()
276                 time.Sleep(100 * time.Millisecond)
277                 os.Exit(0)
278         }()
279
280         go func() {
281                 for {
282                         time.Sleep(vors.PingTime)
283                         buf, err := txCS.Encrypt(nil, nil, []byte(vors.CmdPing))
284                         if err != nil {
285                                 log.Fatalln("tx encrypt:", err)
286                         }
287                         if err = vors.PktWrite(ctrl, buf); err != nil {
288                                 log.Fatalln("tx:", err)
289                         }
290                 }
291         }()
292
293         go func(seen *time.Time) {
294                 var now time.Time
295                 for buf := range rx {
296                         if string(buf) == vors.CmdPong {
297                                 now = time.Now()
298                                 *seen = now
299                                 continue
300                         }
301                         cols := strings.Fields(string(buf))
302                         switch cols[0] {
303                         case vors.CmdAdd:
304                                 sidRaw, name, keyHex := cols[1], cols[2], cols[3]
305                                 log.Println("add", name, "sid:", sidRaw)
306                                 sid := parseSID(sidRaw)
307                                 key, err := hex.DecodeString(keyHex)
308                                 if err != nil {
309                                         log.Fatal(err)
310                                 }
311                                 stream := &Stream{
312                                         name:  name,
313                                         in:    make(chan []byte, 1<<10),
314                                         stats: &Stats{dead: make(chan struct{})},
315                                 }
316                                 go func() {
317                                         dec, err := opus.NewDecoder(vors.Rate, 1)
318                                         if err != nil {
319                                                 log.Fatal(err)
320                                         }
321                                         if err = DecoderSetComplexity(dec, 10); err != nil {
322                                                 log.Fatal(err)
323                                         }
324
325                                         var player io.WriteCloser
326                                         playerTx := make(chan []byte, 5)
327                                         var cmd *exec.Cmd
328                                         if *playCmd != "" {
329                                                 cmd = makeCmd(*playCmd)
330                                                 player, err = cmd.StdinPipe()
331                                                 if err != nil {
332                                                         log.Fatal(err)
333                                                 }
334                                                 err = cmd.Start()
335                                                 if err != nil {
336                                                         log.Fatal(err)
337                                                 }
338                                                 go func() {
339                                                         var pcmbuf []byte
340                                                         var ok bool
341                                                         var err error
342                                                         for {
343                                                                 for len(playerTx) > 1 {
344                                                                         <-playerTx
345                                                                         stream.stats.reorder++
346                                                                 }
347                                                                 pcmbuf, ok = <-playerTx
348                                                                 if !ok {
349                                                                         break
350                                                                 }
351                                                                 if _, err = io.Copy(player,
352                                                                         bytes.NewReader(pcmbuf)); err != nil {
353                                                                         log.Println("play:", err)
354                                                                 }
355                                                         }
356                                                         cmd.Process.Kill()
357                                                 }()
358                                         }
359
360                                         var ciph *chacha20.Cipher
361                                         var macKey [32]byte
362                                         var mac *poly1305.MAC
363                                         tag := make([]byte, poly1305.TagSize)
364                                         var ctr uint32
365                                         pcm := make([]int16, vors.FrameLen)
366                                         nonce := make([]byte, 12)
367                                         var pkt []byte
368                                         lost := -1
369                                         var lastDur int
370                                         for buf := range stream.in {
371                                                 copy(nonce[len(nonce)-4:], buf)
372                                                 ciph, err = chacha20.NewUnauthenticatedCipher(key, nonce)
373                                                 if err != nil {
374                                                         log.Fatal(err)
375                                                 }
376                                                 clear(macKey[:])
377                                                 ciph.XORKeyStream(macKey[:], macKey[:])
378                                                 ciph.SetCounter(1)
379                                                 mac = poly1305.New(&macKey)
380                                                 if _, err = mac.Write(buf[4 : len(buf)-vors.TagLen]); err != nil {
381                                                         log.Fatal(err)
382                                                 }
383                                                 mac.Sum(tag[:0])
384                                                 if subtle.ConstantTimeCompare(
385                                                         tag[:vors.TagLen],
386                                                         buf[len(buf)-vors.TagLen:],
387                                                 ) != 1 {
388                                                         stream.stats.bads++
389                                                         continue
390                                                 }
391                                                 pkt = buf[4 : len(buf)-vors.TagLen]
392                                                 ciph.XORKeyStream(pkt, pkt)
393
394                                                 ctr = binary.BigEndian.Uint32(nonce[len(nonce)-4:])
395                                                 if lost == -1 {
396                                                         // ignore the very first packet in the stream
397                                                         lost = 0
398                                                 } else {
399                                                         lost = int(ctr - (stream.ctr + 1))
400                                                 }
401                                                 stream.ctr = ctr
402                                                 stream.stats.lost += int64(lost)
403                                                 if lost > vors.MaxLost {
404                                                         lost = 0
405                                                 }
406                                                 for ; lost > 0; lost-- {
407                                                         lastDur, err = dec.LastPacketDuration()
408                                                         if err != nil {
409                                                                 log.Println("PLC:", err)
410                                                                 continue
411                                                         }
412                                                         err = dec.DecodePLC(pcm[:lastDur])
413                                                         if err != nil {
414                                                                 log.Println("PLC:", err)
415                                                                 continue
416                                                         }
417                                                         stream.stats.AddRMS(pcm)
418                                                         if cmd == nil {
419                                                                 continue
420                                                         }
421                                                         pcmbuf := make([]byte, 2*lastDur)
422                                                         pcmConv(pcmbuf, pcm[:lastDur])
423                                                         playerTx <- pcmbuf
424                                                 }
425                                                 _, err = dec.Decode(pkt, pcm)
426                                                 if err != nil {
427                                                         log.Println("decode:", err)
428                                                         continue
429                                                 }
430                                                 stream.stats.AddRMS(pcm)
431                                                 stream.stats.last = time.Now()
432                                                 if cmd == nil {
433                                                         continue
434                                                 }
435                                                 pcmbuf := make([]byte, 2*len(pcm))
436                                                 pcmConv(pcmbuf, pcm)
437                                                 playerTx <- pcmbuf
438                                         }
439                                         if cmd != nil {
440                                                 close(playerTx)
441                                         }
442                                 }()
443                                 go statsDrawer(stream.stats, stream.name)
444                                 Streams[sid] = stream
445                         case vors.CmdDel:
446                                 sid := parseSID(cols[1])
447                                 s := Streams[sid]
448                                 if s == nil {
449                                         log.Println("unknown sid:", sid)
450                                         continue
451                                 }
452                                 log.Println("del", s.name, "sid:", cols[1])
453                                 delete(Streams, sid)
454                                 close(s.in)
455                                 close(s.stats.dead)
456                         default:
457                                 log.Fatal("unknown cmd:", cols[0])
458                         }
459                 }
460         }(&seen)
461
462         go func(seen *time.Time) {
463                 for now := range time.Tick(vors.PingTime) {
464                         if seen.Add(2 * vors.PingTime).Before(now) {
465                                 log.Println("timeout:", seen)
466                                 Finish <- struct{}{}
467                                 break
468                         }
469                 }
470         }(&seen)
471
472         go func() {
473                 <-LoggerReady
474                 var n int
475                 var from *net.UDPAddr
476                 var err error
477                 var stream *Stream
478                 var ctr uint32
479                 for {
480                         buf := make([]byte, 2*vors.FrameLen)
481                         n, from, err = conn.ReadFromUDP(buf)
482                         if err != nil {
483                                 log.Println("recvfrom:", err)
484                                 Finish <- struct{}{}
485                                 break
486                         }
487                         if from.Port != srvAddrUDP.Port || !from.IP.Equal(srvAddrUDP.IP) {
488                                 log.Println("wrong addr:", from)
489                                 continue
490                         }
491                         if n <= 4+vors.TagLen {
492                                 log.Println("too small:", n)
493                                 continue
494                         }
495                         stream = Streams[buf[0]]
496                         if stream == nil {
497                                 // log.Println("unknown stream:", buf[0])
498                                 continue
499                         }
500                         stream.stats.pkts++
501                         stream.stats.bytes += uint64(n)
502                         ctr = binary.BigEndian.Uint32(buf)
503                         if ctr <= stream.ctr {
504                                 stream.stats.reorder++
505                                 continue
506                         }
507                         stream.in <- buf[:n]
508                 }
509         }()
510
511         go statsDrawer(OurStats, *Name)
512         go func() {
513                 <-LoggerReady
514                 for {
515                         OurStats.pkts++
516                         OurStats.bytes += 1
517                         if _, err = conn.Write([]byte{sid}); err != nil {
518                                 log.Println("send:", err)
519                                 Finish <- struct{}{}
520                         }
521                         time.Sleep(time.Second)
522                 }
523         }()
524         go func() {
525                 if *recCmd == "" {
526                         return
527                 }
528                 <-LoggerReady
529                 var ciph *chacha20.Cipher
530                 var macKey [32]byte
531                 var mac *poly1305.MAC
532                 tag := make([]byte, poly1305.TagSize)
533                 buf := make([]byte, 2*vors.FrameLen)
534                 pcm := make([]int16, vors.FrameLen)
535                 nonce := make([]byte, 12)
536                 nonce[len(nonce)-4] = sid
537                 var pkt []byte
538                 var n, i int
539                 for {
540                         _, err = io.ReadFull(mic, buf)
541                         if err != nil {
542                                 log.Println("mic:", err)
543                                 break
544                         }
545                         if Muted {
546                                 continue
547                         }
548                         for i = 0; i < vors.FrameLen; i++ {
549                                 pcm[i] = int16(uint16(buf[i*2+0]) | (uint16(buf[i*2+1]) << 8))
550                         }
551                         if vad != 0 && vors.RMS(pcm) < vad {
552                                 continue
553                         }
554                         n, err = opusEnc.Encode(pcm, buf[4:])
555                         if err != nil {
556                                 log.Fatal(err)
557                         }
558                         if n <= 2 {
559                                 // DTX
560                                 continue
561                         }
562
563                         incr(nonce[len(nonce)-3:])
564                         copy(buf, nonce[len(nonce)-4:])
565                         ciph, err = chacha20.NewUnauthenticatedCipher(keyOur, nonce)
566                         if err != nil {
567                                 log.Fatal(err)
568                         }
569                         clear(macKey[:])
570                         ciph.XORKeyStream(macKey[:], macKey[:])
571                         ciph.SetCounter(1)
572                         ciph.XORKeyStream(buf[4:4+n], buf[4:4+n])
573                         mac = poly1305.New(&macKey)
574                         if _, err = mac.Write(buf[4 : 4+n]); err != nil {
575                                 log.Fatal(err)
576                         }
577                         mac.Sum(tag[:0])
578                         copy(buf[4+n:], tag[:vors.TagLen])
579                         pkt = buf[:4+n+vors.TagLen]
580
581                         OurStats.pkts++
582                         OurStats.bytes += uint64(len(pkt))
583                         OurStats.last = time.Now()
584                         OurStats.AddRMS(pcm)
585                         if _, err = conn.Write(pkt); err != nil {
586                                 log.Println("send:", err)
587                                 break
588                         }
589                 }
590         }()
591
592         err = GUI.MainLoop()
593         if err != nil && err != gocui.ErrQuit {
594                 log.Fatal(err)
595         }
596 }