From baca8be01cf7e405e00966bae3cec23337e3314bd9480bb02f8902caf8f404c6 Mon Sep 17 00:00:00 2001 From: Sergey Matveev Date: Sun, 28 Apr 2024 22:15:50 +0300 Subject: [PATCH] Use mutexes over iterable maps for reliability --- cmd/client/gui.go | 4 ++++ cmd/client/main.go | 8 ++++++++ cmd/server/gui.go | 2 ++ cmd/server/main.go | 44 ++++++++++++++++++++++++++++++++------------ cmd/server/peer.go | 5 ++++- cmd/server/room.go | 13 ++++++++----- 6 files changed, 58 insertions(+), 18 deletions(-) diff --git a/cmd/client/gui.go b/cmd/client/gui.go index 34f1370..4e9eeab 100644 --- a/cmd/client/gui.go +++ b/cmd/client/gui.go @@ -35,9 +35,11 @@ var ( func tabHandle(gui *gocui.Gui, v *gocui.View) error { sids := make([]int, 0, len(Streams)+1) sids = append(sids, -1) + StreamsM.RLock() for sid := range Streams { sids = append(sids, int(sid)) } + StreamsM.RUnlock() sort.Ints(sids) if CurrentView+1 >= len(sids) { CurrentView = 0 @@ -92,9 +94,11 @@ func guiLayout(gui *gocui.Gui) error { } prevY += 3 sids := make([]int, 0, len(Streams)) + StreamsM.RLock() for sid := range Streams { sids = append(sids, int(sid)) } + StreamsM.RUnlock() sort.Ints(sids) for _, sid := range sids { stream := Streams[byte(sid)] diff --git a/cmd/client/main.go b/cmd/client/main.go index f485d51..d41b8d9 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -29,6 +29,7 @@ import ( "os/exec" "strconv" "strings" + "sync" "time" "github.com/dchest/siphash" @@ -53,6 +54,7 @@ type Stream struct { var ( Streams = map[byte]*Stream{} + StreamsM sync.RWMutex Finish = make(chan struct{}) OurStats = &Stats{dead: make(chan struct{})} Name = flag.String("name", "test", "username") @@ -274,7 +276,9 @@ func main() { log.Fatalln("cookie acceptance failed:", string(args[1])) case vors.CmdSID: sid = args[1][0] + StreamsM.Lock() Streams[sid] = &Stream{name: *Name, stats: OurStats} + StreamsM.Unlock() default: log.Fatalln("unexpected post-cookie cmd:", cmd) } @@ -534,7 +538,9 @@ func main() { } }() go statsDrawer(stream) + StreamsM.Lock() Streams[sid] = stream + StreamsM.Unlock() case vors.CmdDel: sid := args[1][0] s := Streams[sid] @@ -543,7 +549,9 @@ func main() { continue } log.Println("del", s.name, "sid:", sid) + StreamsM.Lock() delete(Streams, sid) + StreamsM.Unlock() close(s.in) close(s.stats.dead) case vors.CmdMuted: diff --git a/cmd/server/gui.go b/cmd/server/gui.go index f2161c1..c9f6ddd 100644 --- a/cmd/server/gui.go +++ b/cmd/server/gui.go @@ -53,10 +53,12 @@ func guiLayout(gui *gocui.Gui) error { v.Autoscroll = true v.Wrap = true } + RoomsM.RLock() roomNames := make([]string, 0, len(Rooms)) for n := range Rooms { roomNames = append(roomNames, n) } + RoomsM.RUnlock() sort.Strings(roomNames) var now time.Time for _, name := range roomNames { diff --git a/cmd/server/main.go b/cmd/server/main.go index 75d0b07..e8fa3b3 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -113,6 +113,7 @@ func newPeer(conn *net.TCPConn) { alive: make(chan struct{}), } Rooms[roomName] = room + RoomsM.Unlock() go func() { if *NoGUI { return @@ -134,8 +135,9 @@ func newPeer(conn *net.TCPConn) { } } }() + } else { + RoomsM.Unlock() } - RoomsM.Unlock() if room.key != key { logger.Error("wrong password") buf, _, _, err = hs.WriteMessage(nil, vors.ArgsEncode( @@ -150,6 +152,7 @@ func newPeer(conn *net.TCPConn) { } peer.room = room + room.peersM.RLock() for _, p := range room.peers { if p.name != peer.name { continue @@ -161,9 +164,11 @@ func newPeer(conn *net.TCPConn) { if err != nil { log.Fatal(err) } + room.peersM.RUnlock() nsConn.Tx(buf) return } + room.peersM.RUnlock() { var i byte @@ -194,19 +199,25 @@ func newPeer(conn *net.TCPConn) { } } logger = logger.With("sid", peer.sid) + room.peersM.Lock() room.peers[peer.sid] = peer + room.peersM.Unlock() logger.Info("logged in") defer func() { logger.Info("removing") PeersM.Lock() delete(Peers, peer.sid) + room.peersM.Lock() delete(room.peers, peer.sid) + room.peersM.Unlock() PeersM.Unlock() s := vors.ArgsEncode([]byte(vors.CmdDel), []byte{peer.sid}) + room.peersM.RLock() for _, p := range room.peers { - go func(tx chan []byte) { tx <- s }(p.tx) + p.tx <- s } + room.peersM.RUnlock() }() { @@ -246,6 +257,7 @@ func newPeer(conn *net.TCPConn) { go peer.Rx() peer.tx <- vors.ArgsEncode([]byte(vors.CmdSID), []byte{peer.sid}) + room.peersM.RLock() for _, p := range room.peers { if p.sid == peer.sid { continue @@ -253,6 +265,7 @@ func newPeer(conn *net.TCPConn) { peer.tx <- vors.ArgsEncode( []byte(vors.CmdAdd), []byte{p.sid}, []byte(p.name), p.key) } + room.peersM.RUnlock() { xof, err := blake2s.NewXOF(chacha20.KeySize+16, nil) @@ -271,11 +284,13 @@ func newPeer(conn *net.TCPConn) { { s := vors.ArgsEncode( []byte(vors.CmdAdd), []byte{peer.sid}, []byte(peer.name), peer.key) + room.peersM.RLock() for _, p := range room.peers { if p.sid != peer.sid { p.tx <- s } } + room.peersM.RUnlock() } seen := time.Now() @@ -313,33 +328,36 @@ func newPeer(conn *net.TCPConn) { case vors.CmdMuted: peer.muted = true s := vors.ArgsEncode([]byte(vors.CmdMuted), []byte{peer.sid}) + room.peersM.RLock() for _, p := range room.peers { - if p.sid == peer.sid { - continue + if p.sid != peer.sid { + p.tx <- s } - go func(tx chan []byte) { tx <- s }(p.tx) } + room.peersM.RUnlock() case vors.CmdUnmuted: peer.muted = false s := vors.ArgsEncode([]byte(vors.CmdUnmuted), []byte{peer.sid}) + room.peersM.RLock() for _, p := range room.peers { - if p.sid == peer.sid { - continue + if p.sid != peer.sid { + p.tx <- s } - go func(tx chan []byte) { tx <- s }(p.tx) } + room.peersM.RUnlock() case vors.CmdChat: if len(args) != 2 { logger.Error("wrong len(args)") continue } - msg := vors.ArgsEncode([]byte(vors.CmdChat), []byte{peer.sid}, args[1]) + s := vors.ArgsEncode([]byte(vors.CmdChat), []byte{peer.sid}, args[1]) + room.peersM.RLock() for _, p := range room.peers { - if p.sid == peer.sid { - continue + if p.sid != peer.sid { + p.tx <- s } - go func(tx chan []byte) { tx <- msg }(p.tx) } + room.peersM.RUnlock() default: logger.Error("unknown", "cmd", cmd) } @@ -485,6 +503,7 @@ func main() { } peer.stats.last = time.Now() + peer.room.peersM.RLock() for _, p := range peer.room.peers { if p.sid == sid || p.addr == nil { continue @@ -495,6 +514,7 @@ func main() { slog.Warn("sendto", "peer", peer.name, "err", err) } } + peer.room.peersM.RUnlock() } }() diff --git a/cmd/server/peer.go b/cmd/server/peer.go index c572ae6..2472882 100644 --- a/cmd/server/peer.go +++ b/cmd/server/peer.go @@ -76,7 +76,10 @@ func (peer *Peer) Tx() { peer.logger.Error("tx encrypt", "err", err) break } - peer.conn.Tx(buf) + if err = peer.conn.Tx(buf); err != nil { + peer.logger.Error("tx write", "err", err) + break + } } peer.Close() } diff --git a/cmd/server/room.go b/cmd/server/room.go index 0be24b2..1d924fc 100644 --- a/cmd/server/room.go +++ b/cmd/server/room.go @@ -12,21 +12,24 @@ import ( var ( Rooms = map[string]*Room{} - RoomsM sync.Mutex + RoomsM sync.RWMutex ) type Room struct { - name string - key string - peers map[byte]*Peer - alive chan struct{} + name string + key string + peers map[byte]*Peer + peersM sync.RWMutex + alive chan struct{} } func (room *Room) Stats(now time.Time) []string { sids := make([]int, 0, len(room.peers)) + room.peersM.RLock() for sid := range room.peers { sids = append(sids, int(sid)) } + room.peersM.RUnlock() sort.Ints(sids) lines := make([]string, 0, len(sids)) for _, sid := range sids { -- 2.48.1