]> Sergey Matveev's repositories - vors.git/commitdiff
Use mutexes over iterable maps for reliability
authorSergey Matveev <stargrave@stargrave.org>
Sun, 28 Apr 2024 19:15:50 +0000 (22:15 +0300)
committerSergey Matveev <stargrave@stargrave.org>
Mon, 29 Apr 2024 11:28:12 +0000 (14:28 +0300)
cmd/client/gui.go
cmd/client/main.go
cmd/server/gui.go
cmd/server/main.go
cmd/server/peer.go
cmd/server/room.go

index 34f1370688108dbb38d677cfda96a3f7e9383bde7ba5be472656d4e02989dd74..4e9eeab201aad0b6953dcbb4c124c7688ca332e85e4f65ae068088a15e645d11 100644 (file)
@@ -35,9 +35,11 @@ var (
 func tabHandle(gui *gocui.Gui, v *gocui.View) error {
        sids := make([]int, 0, len(Streams)+1)
        sids = append(sids, -1)
+       StreamsM.RLock()
        for sid := range Streams {
                sids = append(sids, int(sid))
        }
+       StreamsM.RUnlock()
        sort.Ints(sids)
        if CurrentView+1 >= len(sids) {
                CurrentView = 0
@@ -92,9 +94,11 @@ func guiLayout(gui *gocui.Gui) error {
        }
        prevY += 3
        sids := make([]int, 0, len(Streams))
+       StreamsM.RLock()
        for sid := range Streams {
                sids = append(sids, int(sid))
        }
+       StreamsM.RUnlock()
        sort.Ints(sids)
        for _, sid := range sids {
                stream := Streams[byte(sid)]
index f485d513ffa4fa32d2a2bec43be19b9e8d779982e3da482a72dc56459931315d..d41b8d9987dc95236845cc7e2bdf8a9065bdd2b25d3cbaf947bd294be22db5e1 100644 (file)
@@ -29,6 +29,7 @@ import (
        "os/exec"
        "strconv"
        "strings"
+       "sync"
        "time"
 
        "github.com/dchest/siphash"
@@ -53,6 +54,7 @@ type Stream struct {
 
 var (
        Streams  = map[byte]*Stream{}
+       StreamsM sync.RWMutex
        Finish   = make(chan struct{})
        OurStats = &Stats{dead: make(chan struct{})}
        Name     = flag.String("name", "test", "username")
@@ -274,7 +276,9 @@ func main() {
                                        log.Fatalln("cookie acceptance failed:", string(args[1]))
                                case vors.CmdSID:
                                        sid = args[1][0]
+                                       StreamsM.Lock()
                                        Streams[sid] = &Stream{name: *Name, stats: OurStats}
+                                       StreamsM.Unlock()
                                default:
                                        log.Fatalln("unexpected post-cookie cmd:", cmd)
                                }
@@ -534,7 +538,9 @@ func main() {
                                        }
                                }()
                                go statsDrawer(stream)
+                               StreamsM.Lock()
                                Streams[sid] = stream
+                               StreamsM.Unlock()
                        case vors.CmdDel:
                                sid := args[1][0]
                                s := Streams[sid]
@@ -543,7 +549,9 @@ func main() {
                                        continue
                                }
                                log.Println("del", s.name, "sid:", sid)
+                               StreamsM.Lock()
                                delete(Streams, sid)
+                               StreamsM.Unlock()
                                close(s.in)
                                close(s.stats.dead)
                        case vors.CmdMuted:
index f2161c15b2a7df8d8902890d72129bb853bdd597f65640c778b576a9308d535d..c9f6ddd97e91d913fc55eac6515de5104be52eb1fe308243f91d7d5a174ac657 100644 (file)
@@ -53,10 +53,12 @@ func guiLayout(gui *gocui.Gui) error {
                v.Autoscroll = true
                v.Wrap = true
        }
+       RoomsM.RLock()
        roomNames := make([]string, 0, len(Rooms))
        for n := range Rooms {
                roomNames = append(roomNames, n)
        }
+       RoomsM.RUnlock()
        sort.Strings(roomNames)
        var now time.Time
        for _, name := range roomNames {
index 75d0b078bdbdc51234e0fe2cf4c798e45dd1b62df8d0927cc02de4a30d411d85..e8fa3b3c0c761ca9b0eb89600bc56a00a53d4743f41d16857b04ba766d0f0a86 100644 (file)
@@ -113,6 +113,7 @@ func newPeer(conn *net.TCPConn) {
                                alive: make(chan struct{}),
                        }
                        Rooms[roomName] = room
+                       RoomsM.Unlock()
                        go func() {
                                if *NoGUI {
                                        return
@@ -134,8 +135,9 @@ func newPeer(conn *net.TCPConn) {
                                        }
                                }
                        }()
+               } else {
+                       RoomsM.Unlock()
                }
-               RoomsM.Unlock()
                if room.key != key {
                        logger.Error("wrong password")
                        buf, _, _, err = hs.WriteMessage(nil, vors.ArgsEncode(
@@ -150,6 +152,7 @@ func newPeer(conn *net.TCPConn) {
        }
        peer.room = room
 
+       room.peersM.RLock()
        for _, p := range room.peers {
                if p.name != peer.name {
                        continue
@@ -161,9 +164,11 @@ func newPeer(conn *net.TCPConn) {
                if err != nil {
                        log.Fatal(err)
                }
+               room.peersM.RUnlock()
                nsConn.Tx(buf)
                return
        }
+       room.peersM.RUnlock()
 
        {
                var i byte
@@ -194,19 +199,25 @@ func newPeer(conn *net.TCPConn) {
                }
        }
        logger = logger.With("sid", peer.sid)
+       room.peersM.Lock()
        room.peers[peer.sid] = peer
+       room.peersM.Unlock()
        logger.Info("logged in")
 
        defer func() {
                logger.Info("removing")
                PeersM.Lock()
                delete(Peers, peer.sid)
+               room.peersM.Lock()
                delete(room.peers, peer.sid)
+               room.peersM.Unlock()
                PeersM.Unlock()
                s := vors.ArgsEncode([]byte(vors.CmdDel), []byte{peer.sid})
+               room.peersM.RLock()
                for _, p := range room.peers {
-                       go func(tx chan []byte) { tx <- s }(p.tx)
+                       p.tx <- s
                }
+               room.peersM.RUnlock()
        }()
 
        {
@@ -246,6 +257,7 @@ func newPeer(conn *net.TCPConn) {
        go peer.Rx()
        peer.tx <- vors.ArgsEncode([]byte(vors.CmdSID), []byte{peer.sid})
 
+       room.peersM.RLock()
        for _, p := range room.peers {
                if p.sid == peer.sid {
                        continue
@@ -253,6 +265,7 @@ func newPeer(conn *net.TCPConn) {
                peer.tx <- vors.ArgsEncode(
                        []byte(vors.CmdAdd), []byte{p.sid}, []byte(p.name), p.key)
        }
+       room.peersM.RUnlock()
 
        {
                xof, err := blake2s.NewXOF(chacha20.KeySize+16, nil)
@@ -271,11 +284,13 @@ func newPeer(conn *net.TCPConn) {
        {
                s := vors.ArgsEncode(
                        []byte(vors.CmdAdd), []byte{peer.sid}, []byte(peer.name), peer.key)
+               room.peersM.RLock()
                for _, p := range room.peers {
                        if p.sid != peer.sid {
                                p.tx <- s
                        }
                }
+               room.peersM.RUnlock()
        }
 
        seen := time.Now()
@@ -313,33 +328,36 @@ func newPeer(conn *net.TCPConn) {
                case vors.CmdMuted:
                        peer.muted = true
                        s := vors.ArgsEncode([]byte(vors.CmdMuted), []byte{peer.sid})
+                       room.peersM.RLock()
                        for _, p := range room.peers {
-                               if p.sid == peer.sid {
-                                       continue
+                               if p.sid != peer.sid {
+                                       p.tx <- s
                                }
-                               go func(tx chan []byte) { tx <- s }(p.tx)
                        }
+                       room.peersM.RUnlock()
                case vors.CmdUnmuted:
                        peer.muted = false
                        s := vors.ArgsEncode([]byte(vors.CmdUnmuted), []byte{peer.sid})
+                       room.peersM.RLock()
                        for _, p := range room.peers {
-                               if p.sid == peer.sid {
-                                       continue
+                               if p.sid != peer.sid {
+                                       p.tx <- s
                                }
-                               go func(tx chan []byte) { tx <- s }(p.tx)
                        }
+                       room.peersM.RUnlock()
                case vors.CmdChat:
                        if len(args) != 2 {
                                logger.Error("wrong len(args)")
                                continue
                        }
-                       msg := vors.ArgsEncode([]byte(vors.CmdChat), []byte{peer.sid}, args[1])
+                       s := vors.ArgsEncode([]byte(vors.CmdChat), []byte{peer.sid}, args[1])
+                       room.peersM.RLock()
                        for _, p := range room.peers {
-                               if p.sid == peer.sid {
-                                       continue
+                               if p.sid != peer.sid {
+                                       p.tx <- s
                                }
-                               go func(tx chan []byte) { tx <- msg }(p.tx)
                        }
+                       room.peersM.RUnlock()
                default:
                        logger.Error("unknown", "cmd", cmd)
                }
@@ -485,6 +503,7 @@ func main() {
                        }
 
                        peer.stats.last = time.Now()
+                       peer.room.peersM.RLock()
                        for _, p := range peer.room.peers {
                                if p.sid == sid || p.addr == nil {
                                        continue
@@ -495,6 +514,7 @@ func main() {
                                        slog.Warn("sendto", "peer", peer.name, "err", err)
                                }
                        }
+                       peer.room.peersM.RUnlock()
                }
        }()
 
index c572ae6424dc35e6504c0c699e61c0249caf6005547a2627eac7ff786ce7e225..24728822f5f1cc274985d01705aae41ba1be422be02771f27a7782877447a047 100644 (file)
@@ -76,7 +76,10 @@ func (peer *Peer) Tx() {
                        peer.logger.Error("tx encrypt", "err", err)
                        break
                }
-               peer.conn.Tx(buf)
+               if err = peer.conn.Tx(buf); err != nil {
+                       peer.logger.Error("tx write", "err", err)
+                       break
+               }
        }
        peer.Close()
 }
index 0be24b25d24697e39460d00b3b604f4b11820a72f13202bf289f7552704d8024..1d924fc308ede370df69b646541c9e0e2016a14259f19eb864bf70d4bbbbeb27 100644 (file)
@@ -12,21 +12,24 @@ import (
 
 var (
        Rooms  = map[string]*Room{}
-       RoomsM sync.Mutex
+       RoomsM sync.RWMutex
 )
 
 type Room struct {
-       name  string
-       key   string
-       peers map[byte]*Peer
-       alive chan struct{}
+       name   string
+       key    string
+       peers  map[byte]*Peer
+       peersM sync.RWMutex
+       alive  chan struct{}
 }
 
 func (room *Room) Stats(now time.Time) []string {
        sids := make([]int, 0, len(room.peers))
+       room.peersM.RLock()
        for sid := range room.peers {
                sids = append(sids, int(sid))
        }
+       room.peersM.RUnlock()
        sort.Ints(sids)
        lines := make([]string, 0, len(sids))
        for _, sid := range sids {