]> Sergey Matveev's repositories - vors.git/blob - cmd/client/main.go
Shorter MAC
[vors.git] / cmd / client / main.go
1 // VoRS -- Vo(IP) Really Simple
2 // Copyright (C) 2024 Sergey Matveev <stargrave@stargrave.org>
3 //
4 // This program is free software: you can redistribute it and/or modify
5 // it under the terms of the GNU Affero General Public License as
6 // published by the Free Software Foundation, version 3 of the License.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU General Public License for more details.
12 //
13 // You should have received a copy of the GNU General Public License
14 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
15
16 package main
17
18 import (
19         "bufio"
20         "bytes"
21         "crypto/subtle"
22         "crypto/tls"
23         "crypto/x509"
24         "encoding/binary"
25         "encoding/hex"
26         "errors"
27         "flag"
28         "fmt"
29         "io"
30         "log"
31         "net"
32         "net/netip"
33         "os"
34         "os/exec"
35         "strconv"
36         "strings"
37         "time"
38
39         "github.com/jroimartin/gocui"
40         vors "go.stargrave.org/vors/internal"
41         "golang.org/x/crypto/blake2s"
42         "golang.org/x/crypto/chacha20"
43         "golang.org/x/crypto/poly1305"
44         "gopkg.in/hraban/opus.v2"
45 )
46
// TagLen is the number of leading Poly1305 tag bytes that are appended
// to every audio packet: the sender copies only tag[:TagLen] after the
// ciphertext and the receiver compares exactly that many bytes.
const TagLen = 8
48
49 type Stream struct {
50         name  string
51         ctr   uint32
52         in    chan []byte
53         stats *Stats
54 }
55
56 var (
57         Streams  = map[byte]*Stream{}
58         Finish   = make(chan struct{})
59         OurStats = &Stats{dead: make(chan struct{})}
60         Name     = flag.String("name", "test", "Username")
61         Muted    bool
62 )
63
// parseSID converts the decimal stream identifier s to a single byte.
//
// The process is terminated through log.Fatal when s is not a decimal
// integer or when its value does not fit into the 0..255 range. Negative
// values are rejected explicitly, because converting them with byte()
// would silently wrap around to an unrelated stream identifier.
func parseSID(s string) byte {
	n, err := strconv.Atoi(s)
	if err != nil {
		log.Fatal(err)
	}
	if n < 0 {
		log.Fatal("negative stream num")
	}
	if n > 255 {
		log.Fatal("too big stream num")
	}
	return byte(n)
}
74
// makeCmd builds an *exec.Cmd from a whitespace separated command line:
// the first field is the program and the remaining fields are its
// arguments. No shell quoting or escaping is interpreted.
//
// The process is terminated through log.Fatal when cmd contains no
// fields at all, instead of panicking on an out-of-range index.
func makeCmd(cmd string) *exec.Cmd {
	args := strings.Fields(cmd)
	if len(args) == 0 {
		log.Fatal("empty command")
	}
	// An empty variadic tail is fine, so a lone program name needs no
	// special case.
	return exec.Command(args[0], args[1:]...)
}
82
// incr treats data as a big-endian unsigned counter and adds one to it in
// place, carrying from the last byte towards the first. It panics with
// "overflow" when the counter wraps around, which includes the case of
// an empty slice.
func incr(data []byte) {
	pos := len(data)
	for pos > 0 {
		pos--
		data[pos]++
		if data[pos] != 0 {
			// No carry out of this byte: done.
			return
		}
	}
	panic("overflow")
}
92
93 func main() {
94         vadRaw := flag.Uint("vad", 0, "VAD threshold")
95         hostport := flag.String("srv", "[::1]:12345", "TCP/UDP port to connect to")
96         spkiHash := flag.String("spki", "FILL-ME", "SHA256 hash of server's certificate SPKI")
97         passwd := flag.String("passwd", "", "Password")
98         recCmd := flag.String("rec", "rec --no-show-progress --buffer 1920 --channels 1 --endian little --encoding signed --rate 48000 --bits 16 --type raw -", "rec command")
99         playCmd := flag.String("play", "play --no-show-progress --buffer 1920 --channels 1 --endian little --encoding signed --rate 48000 --bits 16 --type raw -", "play command")
100         flag.Parse()
101         log.SetFlags(log.Lmicroseconds | log.Lshortfile)
102
103         vad := uint64(*vadRaw)
104         opusEnc := newOpusEnc()
105         var mic io.ReadCloser
106         var err error
107         if *recCmd != "" {
108                 cmd := makeCmd(*recCmd)
109                 mic, err = cmd.StdoutPipe()
110                 if err != nil {
111                         log.Fatal(err)
112                 }
113                 err = cmd.Start()
114                 if err != nil {
115                         log.Fatal(err)
116                 }
117         }
118
119         addrTCP, err := net.ResolveTCPAddr("tcp", *hostport)
120         if err != nil {
121                 log.Fatal(err)
122         }
123         addrUDP, err := net.ResolveUDPAddr("udp", *hostport)
124         if err != nil {
125                 log.Fatal(err)
126         }
127         ctrlRaw, err := net.DialTCP("tcp", nil, addrTCP)
128         if err != nil {
129                 log.Fatalln("dial server:", err)
130         }
131         defer ctrlRaw.Close()
132         ourAddr := net.UDPAddrFromAddrPort(
133                 netip.MustParseAddrPort(ctrlRaw.LocalAddr().String()))
134         ln, err := net.ListenUDP("udp", ourAddr)
135         if err != nil {
136                 log.Fatal(err)
137         }
138         ctrl := tls.Client(ctrlRaw, &tls.Config{
139                 MinVersion:         tls.VersionTLS13,
140                 CurvePreferences:   []tls.CurveID{tls.X25519},
141                 ServerName:         vors.CN,
142                 InsecureSkipVerify: true,
143                 VerifyPeerCertificate: func(
144                         rawCerts [][]byte, verifiedChains [][]*x509.Certificate,
145                 ) error {
146                         cer, err := x509.ParseCertificate(rawCerts[0])
147                         if err != nil {
148                                 return err
149                         }
150                         if *spkiHash != vors.SPKIHash(cer) {
151                                 return errors.New("server certificate's SPKI hash mismatch")
152                         }
153                         return nil
154                 },
155         })
156         err = ctrl.Handshake()
157         if err != nil {
158                 log.Println("TLS handshake:", err)
159                 return
160         }
161         defer ctrl.Close()
162
163         scanner := bufio.NewScanner(ctrl)
164         if !scanner.Scan() {
165                 log.Println("read challenge:", scanner.Err())
166                 return
167         }
168         {
169                 h, err := blake2s.New256([]byte(*passwd))
170                 if err != nil {
171                         log.Fatal(err)
172                 }
173                 h.Write(scanner.Bytes())
174                 if _, err = io.Copy(ctrl, strings.NewReader(fmt.Sprintf(
175                         "%s %s\n", hex.EncodeToString(h.Sum(nil)), *Name))); err != nil {
176                         log.Println("write password:", err)
177                         return
178                 }
179         }
180         if !scanner.Scan() {
181                 log.Println("auth", scanner.Err())
182                 return
183         }
184         cols := strings.Fields(scanner.Text())
185         if cols[0] != "OK" {
186                 log.Println("auth failed:", scanner.Text())
187                 return
188         }
189         sid := parseSID(cols[1])
190         Streams[sid] = &Stream{name: *Name, stats: OurStats}
191
192         tlsState := ctrl.ConnectionState()
193         keyOur, err := tlsState.ExportKeyingMaterial(cols[1], nil, chacha20.KeySize)
194         if err != nil {
195                 log.Fatal(err)
196         }
197         seen := time.Now()
198
199         LoggerReady := make(chan struct{})
200         GUI, err = gocui.NewGui(gocui.OutputNormal)
201         if err != nil {
202                 log.Fatal(err)
203         }
204         defer GUI.Close()
205         GUI.SelFgColor = gocui.ColorCyan
206         GUI.Highlight = true
207         GUI.SetManagerFunc(guiLayout)
208         if err := GUI.SetKeybinding("", gocui.KeyF10, gocui.ModNone, guiQuit); err != nil {
209                 log.Fatal(err)
210         }
211         if err := GUI.SetKeybinding("", gocui.KeyEnter, gocui.ModNone, mute); err != nil {
212                 log.Fatal(err)
213         }
214
215         go func() {
216                 <-GUIReadyC
217                 v, err := GUI.View("logs")
218                 if err != nil {
219                         log.Fatal(err)
220                 }
221                 log.SetOutput(v)
222                 log.Println("connected")
223                 close(LoggerReady)
224                 for {
225                         time.Sleep(vors.ScreenRefresh)
226                         GUI.Update(func(gui *gocui.Gui) error {
227                                 return nil
228                         })
229                 }
230         }()
231
232         go func() {
233                 <-Finish
234                 go GUI.Close()
235                 time.Sleep(100 * time.Millisecond)
236                 os.Exit(0)
237         }()
238
239         go func() {
240                 var err error
241                 for {
242                         time.Sleep(vors.PingTime)
243                         if _, err = ctrl.Write([]byte(vors.CmdPing + "\n")); err != nil {
244                                 log.Println("ping:", err)
245                                 Finish <- struct{}{}
246                                 break
247                         }
248                 }
249         }()
250
251         go func(seen *time.Time) {
252                 var t string
253                 var now time.Time
254                 for scanner.Scan() {
255                         t = scanner.Text()
256                         if t == vors.CmdPong {
257                                 now = time.Now()
258                                 *seen = now
259                                 continue
260                         }
261                         cols := strings.Fields(t)
262                         switch cols[0] {
263                         case vors.CmdAdd:
264                                 sidRaw, name, keyHex := cols[1], cols[2], cols[3]
265                                 log.Println("add", name)
266                                 sid := parseSID(sidRaw)
267                                 key, err := hex.DecodeString(keyHex)
268                                 if err != nil {
269                                         log.Fatal(err)
270                                 }
271                                 stream := &Stream{
272                                         name:  name,
273                                         in:    make(chan []byte, 1<<10),
274                                         stats: &Stats{dead: make(chan struct{})},
275                                 }
276                                 go func() {
277                                         dec, err := opus.NewDecoder(vors.Rate, 1)
278                                         if err != nil {
279                                                 log.Fatal(err)
280                                         }
281
282                                         var player io.WriteCloser
283                                         var cmd *exec.Cmd
284                                         if *playCmd != "" {
285                                                 cmd = makeCmd(*playCmd)
286                                                 player, err = cmd.StdinPipe()
287                                                 if err != nil {
288                                                         log.Fatal(err)
289                                                 }
290                                                 err = cmd.Start()
291                                                 if err != nil {
292                                                         log.Fatal(err)
293                                                 }
294                                         }
295
296                                         var ciph *chacha20.Cipher
297                                         var macKey [32]byte
298                                         var mac *poly1305.MAC
299                                         tag := make([]byte, poly1305.TagSize)
300                                         var ctr uint32
301                                         pcm := make([]int16, vors.FrameLen)
302                                         pcmbuf := make([]byte, 2*vors.FrameLen)
303                                         nonce := make([]byte, 12)
304                                         var pkt []byte
305                                         lost := -1
306                                         var lastDur int
307                                         for buf := range stream.in {
308                                                 copy(nonce[len(nonce)-4:], buf)
309                                                 ciph, err = chacha20.NewUnauthenticatedCipher(key, nonce)
310                                                 if err != nil {
311                                                         log.Fatal(err)
312                                                 }
313                                                 clear(macKey[:])
314                                                 ciph.XORKeyStream(macKey[:], macKey[:])
315                                                 ciph.SetCounter(1)
316                                                 mac = poly1305.New(&macKey)
317                                                 if _, err = mac.Write(buf[4 : len(buf)-TagLen]); err != nil {
318                                                         log.Fatal(err)
319                                                 }
320                                                 mac.Sum(tag[:0])
321                                                 if subtle.ConstantTimeCompare(tag[:TagLen], buf[len(buf)-TagLen:]) != 1 {
322                                                         log.Println("decrypt:", stream.name, "tag differs")
323                                                         continue
324                                                 }
325                                                 pkt = buf[4 : len(buf)-TagLen]
326                                                 ciph.XORKeyStream(pkt, pkt)
327
328                                                 ctr = binary.BigEndian.Uint32(nonce[len(nonce)-4:])
329                                                 if lost == -1 {
330                                                         // ignore the very first packet in the stream
331                                                         lost = 0
332                                                 } else {
333                                                         lost = int(ctr - (stream.ctr + 1))
334                                                 }
335                                                 stream.ctr = ctr
336                                                 stream.stats.lost += int64(lost)
337                                                 if lost > vors.MaxLost {
338                                                         lost = 0
339                                                 }
340                                                 for ; lost > 0; lost-- {
341                                                         lastDur, err = dec.LastPacketDuration()
342                                                         if err != nil {
343                                                                 log.Println("PLC:", err)
344                                                                 continue
345                                                         }
346                                                         err = dec.DecodePLC(pcm[:lastDur])
347                                                         if err != nil {
348                                                                 log.Println("PLC:", err)
349                                                                 continue
350                                                         }
351                                                         stream.stats.AddRMS(pcm)
352                                                         if cmd == nil {
353                                                                 continue
354                                                         }
355                                                         pcmConv(pcmbuf, pcm[:lastDur])
356                                                         if _, err = io.Copy(player, bytes.NewReader(
357                                                                 pcmbuf[:2*lastDur])); err != nil {
358                                                                 log.Println("play:", err)
359                                                         }
360                                                 }
361                                                 _, err = dec.Decode(pkt, pcm)
362                                                 if err != nil {
363                                                         log.Println("decode:", err)
364                                                         continue
365                                                 }
366                                                 stream.stats.AddRMS(pcm)
367                                                 stream.stats.last = time.Now()
368                                                 if cmd == nil {
369                                                         continue
370                                                 }
371                                                 pcmConv(pcmbuf, pcm)
372                                                 if _, err = io.Copy(player,
373                                                         bytes.NewReader(pcmbuf)); err != nil {
374                                                         log.Println("play:", err)
375                                                 }
376                                         }
377                                         if cmd != nil {
378                                                 cmd.Process.Kill()
379                                         }
380                                 }()
381                                 go statsDrawer(stream.stats, stream.name)
382                                 Streams[sid] = stream
383                         case vors.CmdDel:
384                                 sid := parseSID(cols[1])
385                                 s := Streams[sid]
386                                 if s == nil {
387                                         log.Println("unknown sid:", sid)
388                                         continue
389                                 }
390                                 delete(Streams, sid)
391                                 close(s.in)
392                                 close(s.stats.dead)
393                                 log.Println("del", s.name)
394                         default:
395                                 log.Fatal("unknown cmd:", cols[0])
396                         }
397                 }
398                 if scanner.Err() != nil {
399                         log.Print("scanner:", err)
400                         Finish <- struct{}{}
401                 }
402         }(&seen)
403
404         go func(seen *time.Time) {
405                 for now := range time.Tick(vors.PingTime) {
406                         if seen.Add(2 * vors.PingTime).Before(now) {
407                                 log.Println("timeout:", seen)
408                                 Finish <- struct{}{}
409                                 break
410                         }
411                 }
412         }(&seen)
413
414         go func() {
415                 <-LoggerReady
416                 var n int
417                 var from *net.UDPAddr
418                 var err error
419                 var stream *Stream
420                 var ctr uint32
421                 for {
422                         buf := make([]byte, 2*vors.FrameLen)
423                         n, from, err = ln.ReadFromUDP(buf)
424                         if err != nil {
425                                 log.Println("recvfrom:", err)
426                                 Finish <- struct{}{}
427                                 break
428                         }
429                         if from.Port != addrUDP.Port || !from.IP.Equal(addrUDP.IP) {
430                                 log.Println("wrong addr:", from)
431                                 continue
432                         }
433                         if n <= 1+4+poly1305.TagSize {
434                                 log.Println("too small:", n)
435                                 continue
436                         }
437                         stream = Streams[buf[0]]
438                         if stream == nil {
439                                 log.Println("unknown stream:", buf[0])
440                                 continue
441                         }
442                         stream.stats.pkts++
443                         stream.stats.bytes += uint64(n)
444                         ctr = binary.BigEndian.Uint32(buf)
445                         if ctr <= stream.ctr {
446                                 stream.stats.reorder++
447                                 continue
448                         }
449                         stream.in <- buf[:n]
450                 }
451         }()
452
453         go statsDrawer(OurStats, *Name)
454         go func() {
455                 <-LoggerReady
456                 for {
457                         OurStats.pkts++
458                         OurStats.bytes += 1
459                         if _, err = ln.WriteTo([]byte{sid}, addrUDP); err != nil {
460                                 log.Println("send:", err)
461                                 Finish <- struct{}{}
462                         }
463                         time.Sleep(time.Second)
464                 }
465         }()
466         go func() {
467                 if *recCmd == "" {
468                         return
469                 }
470                 <-LoggerReady
471                 var ciph *chacha20.Cipher
472                 var macKey [32]byte
473                 var mac *poly1305.MAC
474                 tag := make([]byte, poly1305.TagSize)
475                 buf := make([]byte, 2*vors.FrameLen)
476                 pcm := make([]int16, vors.FrameLen)
477                 nonce := make([]byte, 12)
478                 nonce[len(nonce)-4] = sid
479                 var pkt []byte
480                 var n, i int
481                 for {
482                         _, err = io.ReadFull(mic, buf)
483                         if err != nil {
484                                 log.Println("mic:", err)
485                                 break
486                         }
487                         if Muted {
488                                 continue
489                         }
490                         for i = 0; i < vors.FrameLen; i++ {
491                                 pcm[i] = int16(uint16(buf[i*2+0]) | (uint16(buf[i*2+1]) << 8))
492                         }
493                         if vad != 0 && vors.RMS(pcm) < vad {
494                                 continue
495                         }
496                         n, err = opusEnc.Encode(pcm, buf[4:])
497                         if err != nil {
498                                 log.Fatal(err)
499                         }
500
501                         incr(nonce[len(nonce)-3:])
502                         copy(buf, nonce[len(nonce)-4:])
503                         ciph, err = chacha20.NewUnauthenticatedCipher(keyOur, nonce)
504                         if err != nil {
505                                 log.Fatal(err)
506                         }
507                         clear(macKey[:])
508                         ciph.XORKeyStream(macKey[:], macKey[:])
509                         ciph.SetCounter(1)
510                         ciph.XORKeyStream(buf[4:4+n], buf[4:4+n])
511                         mac = poly1305.New(&macKey)
512                         if _, err = mac.Write(buf[4 : 4+n]); err != nil {
513                                 log.Fatal(err)
514                         }
515                         mac.Sum(tag[:0])
516                         copy(buf[4+n:], tag[:TagLen])
517                         pkt = buf[:4+n+TagLen]
518
519                         OurStats.pkts++
520                         OurStats.bytes += uint64(len(pkt))
521                         OurStats.last = time.Now()
522                         OurStats.AddRMS(pcm)
523                         if _, err = ln.WriteTo(pkt, addrUDP); err != nil {
524                                 log.Println("send:", err)
525                                 break
526                         }
527                 }
528         }()
529
530         err = GUI.MainLoop()
531         if err != nil && err != gocui.ErrQuit {
532                 log.Fatal(err)
533         }
534 }