]> Sergey Matveev's repositories - vors.git/blob - cmd/client/main.go
6b4fe68911b6d3a8cd92ab588d51b5dca1aa84466a1d9fd47235c6172bbdbd48
[vors.git] / cmd / client / main.go
1 // VoRS -- Vo(IP) Really Simple
2 // Copyright (C) 2024 Sergey Matveev <stargrave@stargrave.org>
3 //
4 // This program is free software: you can redistribute it and/or modify
5 // it under the terms of the GNU Affero General Public License as
6 // published by the Free Software Foundation, version 3 of the License.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU General Public License for more details.
12 //
13 // You should have received a copy of the GNU General Public License
14 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
15
16 package main
17
18 import (
19         "bufio"
20         "bytes"
21         "crypto/tls"
22         "crypto/x509"
23         "encoding/binary"
24         "encoding/hex"
25         "errors"
26         "flag"
27         "fmt"
28         "io"
29         "log"
30         "net"
31         "net/netip"
32         "os"
33         "os/exec"
34         "strconv"
35         "strings"
36         "time"
37
38         "github.com/jroimartin/gocui"
39         vors "go.stargrave.org/vors/internal"
40         "golang.org/x/crypto/blake2s"
41         "golang.org/x/crypto/chacha20"
42         "golang.org/x/crypto/chacha20poly1305"
43         "golang.org/x/crypto/poly1305"
44         "gopkg.in/hraban/opus.v2"
45 )
46
// Stream is the client-side state of one audio stream, stored in Streams
// under its stream id. Our own stream is registered there too.
type Stream struct {
	// name is the user name announced for the stream (our -name flag for
	// our own stream, the name from the server's add command otherwise).
	name string
	// ctr is the big-endian sid+counter prefix of the last packet the
	// stream's decoder accepted; the UDP reader drops packets whose prefix
	// is not greater than it.
	ctr uint32
	// in carries raw datagrams from the UDP reader to the stream's decoder
	// goroutine; it is closed when the stream is deleted. It is nil for our
	// own stream.
	in chan []byte
	// stats holds per-stream statistics; stats.dead is closed when the
	// stream is deleted.
	stats *Stats
}
53
54 var (
55         Streams  = map[byte]*Stream{}
56         Finish   = make(chan struct{})
57         OurStats = &Stats{dead: make(chan struct{})}
58         Name     = flag.String("name", "test", "Username")
59         Muted    bool
60 )
61
// parseSID converts a decimal stream identifier, as sent by the server in
// its control messages, into a byte. Anything that is not an integer in the
// range 0..255 is fatal. Negative values must be rejected explicitly:
// byte(n) would otherwise wrap them (e.g. -1 becomes 255) and silently
// alias another stream.
func parseSID(s string) byte {
	n, err := strconv.Atoi(s)
	if err != nil {
		log.Fatal(err)
	}
	if n < 0 {
		log.Fatal("negative stream num")
	}
	if n > 255 {
		log.Fatal("too big stream num")
	}
	return byte(n)
}
72
// makeCmd builds, but does not start, an *exec.Cmd from a command line.
// The line is split on whitespace with strings.Fields; no shell quoting or
// escaping is interpreted, so individual arguments cannot contain spaces.
// The first field is the program and the remaining fields are its
// arguments. An empty or whitespace-only command line is fatal, instead of
// panicking with an index-out-of-range error.
func makeCmd(cmd string) *exec.Cmd {
	args := strings.Fields(cmd)
	if len(args) == 0 {
		log.Fatal("empty command")
	}
	return exec.Command(args[0], args[1:]...)
}
80
// incr treats data as a big-endian unsigned integer and adds one to it in
// place, carrying towards data[0]. If every byte wraps around (the value
// was all 0xFF, or data is empty), data is left zeroed and incr panics
// with "overflow".
func incr(data []byte) {
	pos := len(data)
	for pos > 0 {
		pos--
		data[pos]++
		if data[pos] > 0 {
			// No carry out of this byte: done.
			return
		}
	}
	panic("overflow")
}
90
91 func main() {
92         vadRaw := flag.Uint("vad", 0, "VAD threshold")
93         hostport := flag.String("srv", "[::1]:12345", "TCP/UDP port to connect to")
94         spkiHash := flag.String("spki", "FILL-ME", "SHA256 hash of server's certificate SPKI")
95         passwd := flag.String("passwd", "", "Password")
96         recCmd := flag.String("rec", "rec --no-show-progress --buffer 1920 --channels 1 --endian little --encoding signed --rate 48000 --bits 16 --type raw -", "rec command")
97         playCmd := flag.String("play", "play --no-show-progress --buffer 1920 --channels 1 --endian little --encoding signed --rate 48000 --bits 16 --type raw -", "play command")
98         flag.Parse()
99         log.SetFlags(log.Lmicroseconds | log.Lshortfile)
100
101         vad := uint64(*vadRaw)
102         opusEnc := newOpusEnc()
103         var mic io.ReadCloser
104         var err error
105         if *recCmd != "" {
106                 cmd := makeCmd(*recCmd)
107                 mic, err = cmd.StdoutPipe()
108                 if err != nil {
109                         log.Fatal(err)
110                 }
111                 err = cmd.Start()
112                 if err != nil {
113                         log.Fatal(err)
114                 }
115         }
116
117         addrTCP, err := net.ResolveTCPAddr("tcp", *hostport)
118         if err != nil {
119                 log.Fatal(err)
120         }
121         addrUDP, err := net.ResolveUDPAddr("udp", *hostport)
122         if err != nil {
123                 log.Fatal(err)
124         }
125         ctrlRaw, err := net.DialTCP("tcp", nil, addrTCP)
126         if err != nil {
127                 log.Fatalln("dial server:", err)
128         }
129         defer ctrlRaw.Close()
130         ourAddr := net.UDPAddrFromAddrPort(
131                 netip.MustParseAddrPort(ctrlRaw.LocalAddr().String()))
132         ln, err := net.ListenUDP("udp", ourAddr)
133         if err != nil {
134                 log.Fatal(err)
135         }
136         ctrl := tls.Client(ctrlRaw, &tls.Config{
137                 MinVersion:         tls.VersionTLS13,
138                 CurvePreferences:   []tls.CurveID{tls.X25519},
139                 ServerName:         vors.CN,
140                 InsecureSkipVerify: true,
141                 VerifyPeerCertificate: func(
142                         rawCerts [][]byte, verifiedChains [][]*x509.Certificate,
143                 ) error {
144                         cer, err := x509.ParseCertificate(rawCerts[0])
145                         if err != nil {
146                                 return err
147                         }
148                         if *spkiHash != vors.SPKIHash(cer) {
149                                 return errors.New("server certificate's SPKI hash mismatch")
150                         }
151                         return nil
152                 },
153         })
154         err = ctrl.Handshake()
155         if err != nil {
156                 log.Println("TLS handshake:", err)
157                 return
158         }
159         defer ctrl.Close()
160
161         scanner := bufio.NewScanner(ctrl)
162         if !scanner.Scan() {
163                 log.Println("read challenge:", scanner.Err())
164                 return
165         }
166         {
167                 h, err := blake2s.New256([]byte(*passwd))
168                 if err != nil {
169                         log.Fatal(err)
170                 }
171                 h.Write(scanner.Bytes())
172                 if _, err = io.Copy(ctrl, strings.NewReader(fmt.Sprintf(
173                         "%s %s\n", hex.EncodeToString(h.Sum(nil)), *Name))); err != nil {
174                         log.Println("write password:", err)
175                         return
176                 }
177         }
178         if !scanner.Scan() {
179                 log.Println("auth", scanner.Err())
180                 return
181         }
182         cols := strings.Fields(scanner.Text())
183         if cols[0] != "OK" {
184                 log.Println("auth failed:", scanner.Text())
185                 return
186         }
187         sid := parseSID(cols[1])
188         Streams[sid] = &Stream{name: *Name, stats: OurStats}
189
190         tlsState := ctrl.ConnectionState()
191         key, err := tlsState.ExportKeyingMaterial(
192                 cols[1], nil, chacha20poly1305.KeySize)
193         if err != nil {
194                 log.Fatal(err)
195         }
196         ciph, err := chacha20poly1305.New(key)
197         if err != nil {
198                 log.Fatal(err)
199         }
200         seen := time.Now()
201
202         LoggerReady := make(chan struct{})
203         GUI, err = gocui.NewGui(gocui.OutputNormal)
204         if err != nil {
205                 log.Fatal(err)
206         }
207         defer GUI.Close()
208         GUI.SelFgColor = gocui.ColorCyan
209         GUI.Highlight = true
210         GUI.SetManagerFunc(guiLayout)
211         if err := GUI.SetKeybinding("", gocui.KeyF10, gocui.ModNone, guiQuit); err != nil {
212                 log.Fatal(err)
213         }
214         if err := GUI.SetKeybinding("", gocui.KeyEnter, gocui.ModNone, mute); err != nil {
215                 log.Fatal(err)
216         }
217
218         go func() {
219                 <-GUIReadyC
220                 v, err := GUI.View("logs")
221                 if err != nil {
222                         log.Fatal(err)
223                 }
224                 log.SetOutput(v)
225                 log.Println("connected")
226                 close(LoggerReady)
227                 for {
228                         time.Sleep(vors.ScreenRefresh)
229                         GUI.Update(func(gui *gocui.Gui) error {
230                                 return nil
231                         })
232                 }
233         }()
234
235         go func() {
236                 <-Finish
237                 go GUI.Close()
238                 time.Sleep(100 * time.Millisecond)
239                 os.Exit(0)
240         }()
241
242         go func() {
243                 var err error
244                 for {
245                         time.Sleep(vors.PingTime)
246                         if _, err = ctrl.Write([]byte(vors.CmdPing + "\n")); err != nil {
247                                 log.Println("ping:", err)
248                                 Finish <- struct{}{}
249                                 break
250                         }
251                 }
252         }()
253
254         go func(seen *time.Time) {
255                 var t string
256                 var now time.Time
257                 for scanner.Scan() {
258                         t = scanner.Text()
259                         if t == vors.CmdPong {
260                                 now = time.Now()
261                                 *seen = now
262                                 continue
263                         }
264                         cols := strings.Fields(t)
265                         switch cols[0] {
266                         case vors.CmdAdd:
267                                 sidRaw, name, keyHex := cols[1], cols[2], cols[3]
268                                 log.Println("add", name)
269                                 sid := parseSID(sidRaw)
270                                 key, err := hex.DecodeString(keyHex)
271                                 if err != nil {
272                                         log.Fatal(err)
273                                 }
274                                 stream := &Stream{
275                                         name:  name,
276                                         in:    make(chan []byte, 1<<10),
277                                         stats: &Stats{dead: make(chan struct{})},
278                                 }
279                                 go func() {
280                                         ciph, err := chacha20poly1305.New(key)
281                                         if err != nil {
282                                                 log.Fatal(err)
283                                         }
284                                         dec, err := opus.NewDecoder(vors.Rate, 1)
285                                         if err != nil {
286                                                 log.Fatal(err)
287                                         }
288
289                                         var player io.WriteCloser
290                                         var cmd *exec.Cmd
291                                         if *playCmd != "" {
292                                                 cmd = makeCmd(*playCmd)
293                                                 player, err = cmd.StdinPipe()
294                                                 if err != nil {
295                                                         log.Fatal(err)
296                                                 }
297                                                 err = cmd.Start()
298                                                 if err != nil {
299                                                         log.Fatal(err)
300                                                 }
301                                         }
302
303                                         ctr := uint32(sid) << 24
304                                         pcm := make([]int16, vors.FrameLen)
305                                         pcmbuf := make([]byte, 2*vors.FrameLen)
306                                         decbuf := make([]byte, 2*vors.FrameLen)
307                                         nonce := make([]byte, chacha20.NonceSize)
308                                         ctrBuf := nonce[len(nonce)-4:]
309                                         var pkt []byte
310                                         lost := -1
311                                         var lastDur int
312                                         for buf := range stream.in {
313                                                 ctr = binary.BigEndian.Uint32(buf)
314                                                 copy(ctrBuf, buf)
315                                                 pkt, err = ciph.Open(
316                                                         decbuf[:0], nonce, buf[4:], []byte{buf[0]})
317                                                 if err != nil {
318                                                         log.Println("decrypt:", stream.name, err)
319                                                         continue
320                                                 }
321                                                 if lost == -1 {
322                                                         // ignore the very first packet in the stream
323                                                         lost = 0
324                                                 } else {
325                                                         lost = int(ctr - (stream.ctr + 1))
326                                                 }
327                                                 stream.ctr = ctr
328                                                 stream.stats.lost += int64(lost)
329                                                 if lost > vors.MaxLost {
330                                                         lost = 0
331                                                 }
332                                                 for ; lost > 0; lost-- {
333                                                         lastDur, err = dec.LastPacketDuration()
334                                                         if err != nil {
335                                                                 log.Println("PLC:", err)
336                                                                 continue
337                                                         }
338                                                         err = dec.DecodePLC(pcm[:lastDur])
339                                                         if err != nil {
340                                                                 log.Println("PLC:", err)
341                                                                 continue
342                                                         }
343                                                         stream.stats.AddRMS(pcm)
344                                                         if cmd == nil {
345                                                                 continue
346                                                         }
347                                                         pcmConv(pcmbuf, pcm[:lastDur])
348                                                         if _, err = io.Copy(player, bytes.NewReader(
349                                                                 pcmbuf[:2*lastDur])); err != nil {
350                                                                 log.Println("play:", err)
351                                                         }
352                                                 }
353                                                 _, err = dec.Decode(pkt, pcm)
354                                                 if err != nil {
355                                                         log.Println("decode:", err)
356                                                         continue
357                                                 }
358                                                 stream.stats.AddRMS(pcm)
359                                                 stream.stats.last = time.Now()
360                                                 if cmd == nil {
361                                                         continue
362                                                 }
363                                                 pcmConv(pcmbuf, pcm)
364                                                 if _, err = io.Copy(player,
365                                                         bytes.NewReader(pcmbuf)); err != nil {
366                                                         log.Println("play:", err)
367                                                 }
368                                         }
369                                         if cmd != nil {
370                                                 cmd.Process.Kill()
371                                         }
372                                 }()
373                                 go statsDrawer(stream.stats, stream.name)
374                                 Streams[sid] = stream
375                         case vors.CmdDel:
376                                 sid := parseSID(cols[1])
377                                 s := Streams[sid]
378                                 if s == nil {
379                                         log.Println("unknown sid:", sid)
380                                         continue
381                                 }
382                                 delete(Streams, sid)
383                                 close(s.in)
384                                 close(s.stats.dead)
385                                 log.Println("del", s.name)
386                         default:
387                                 log.Fatal("unknown cmd:", cols[0])
388                         }
389                 }
390                 if scanner.Err() != nil {
391                         log.Print("scanner:", err)
392                         Finish <- struct{}{}
393                 }
394         }(&seen)
395
396         go func(seen *time.Time) {
397                 for now := range time.Tick(vors.PingTime) {
398                         if seen.Add(2 * vors.PingTime).Before(now) {
399                                 log.Println("timeout:", seen)
400                                 Finish <- struct{}{}
401                                 break
402                         }
403                 }
404         }(&seen)
405
406         go func() {
407                 <-LoggerReady
408                 var n int
409                 var from *net.UDPAddr
410                 var err error
411                 var stream *Stream
412                 var ctr uint32
413                 for {
414                         buf := make([]byte, 2*vors.FrameLen)
415                         n, from, err = ln.ReadFromUDP(buf)
416                         if err != nil {
417                                 log.Println("recvfrom:", err)
418                                 Finish <- struct{}{}
419                                 break
420                         }
421                         if from.Port != addrUDP.Port || !from.IP.Equal(addrUDP.IP) {
422                                 log.Println("wrong addr:", from)
423                                 continue
424                         }
425                         if n <= 1+4+poly1305.TagSize {
426                                 log.Println("too small:", n)
427                                 continue
428                         }
429                         stream = Streams[buf[0]]
430                         if stream == nil {
431                                 log.Println("unknown stream:", buf[0])
432                                 continue
433                         }
434                         stream.stats.pkts++
435                         stream.stats.bytes += uint64(n)
436                         ctr = binary.BigEndian.Uint32(buf)
437                         if ctr <= stream.ctr {
438                                 stream.stats.reorder++
439                                 continue
440                         }
441                         stream.in <- buf[:n]
442                 }
443         }()
444
445         go statsDrawer(OurStats, *Name)
446         go func() {
447                 <-LoggerReady
448                 for {
449                         OurStats.pkts++
450                         OurStats.bytes += 1
451                         if _, err = ln.WriteTo([]byte{sid}, addrUDP); err != nil {
452                                 log.Println("send:", err)
453                                 Finish <- struct{}{}
454                         }
455                         time.Sleep(time.Second)
456                 }
457         }()
458         go func() {
459                 if *recCmd == "" {
460                         return
461                 }
462                 <-LoggerReady
463                 buf := make([]byte, 2*vors.FrameLen)
464                 pcm := make([]int16, vors.FrameLen)
465                 nonce := make([]byte, ciph.NonceSize())
466                 nonce[len(nonce)-4] = sid
467                 ctr := nonce[len(nonce)-3:]
468                 sidAndCtr := nonce[len(nonce)-4:]
469                 var pkt []byte
470                 var n, i int
471                 for {
472                         _, err = io.ReadFull(mic, buf)
473                         if err != nil {
474                                 log.Println("mic:", err)
475                                 break
476                         }
477                         if Muted {
478                                 continue
479                         }
480                         for i = 0; i < vors.FrameLen; i++ {
481                                 pcm[i] = int16(uint16(buf[i*2+0]) | (uint16(buf[i*2+1]) << 8))
482                         }
483                         if vad != 0 && vors.RMS(pcm) < vad {
484                                 continue
485                         }
486                         incr(ctr)
487                         copy(buf, sidAndCtr)
488                         n, err = opusEnc.Encode(pcm, buf[4:])
489                         if err != nil {
490                                 log.Fatal(err)
491                         }
492                         pkt = ciph.Seal(buf[:4], nonce, buf[4:4+n], []byte{sid})
493                         OurStats.pkts++
494                         OurStats.bytes += uint64(len(pkt))
495                         OurStats.last = time.Now()
496                         OurStats.AddRMS(pcm)
497                         if _, err = ln.WriteTo(pkt, addrUDP); err != nil {
498                                 log.Println("send:", err)
499                                 break
500                         }
501                 }
502         }()
503
504         err = GUI.MainLoop()
505         if err != nil && err != gocui.ErrQuit {
506                 log.Fatal(err)
507         }
508 }