cmd/init/main.go | 34 +++++++++++++++++++--------------- cmd/resp/main.go | 22 +++++++++++++--------- diff --git a/cmd/init/main.go b/cmd/init/main.go index 77b82139b67ad12d1163e9419729589a037043f78c9847fbe2ff84ff3f83604e..3564dbf0ed737d1a4c718e6774742fddbf2ffce3c45b5fbf264fea4c00584b48 100644 --- a/cmd/init/main.go +++ b/cmd/init/main.go @@ -44,6 +44,7 @@ DstAddrTCP *net.TCPAddr TLSConfig *tls.Config LnUDP *net.UDPConn Peers = make(map[string]chan udpobfs.Buf) + PeersM sync.RWMutex Bufs = sync.Pool{New: func() any { return new([udpobfs.BufLen]byte) }} ) @@ -83,8 +84,17 @@ return } cryptoState := udpobfs.NewCryptoState(seed, true) txs := make(chan udpobfs.Buf) + PeersM.Lock() + Peers[localAddr.String()] = txs + PeersM.Unlock() + var rxPkts, txPkts, rxBytes, txBytes int64 + { + txPkts++ + txBytes += int64(len(dataInitial)) + tmp := make([]byte, udpobfs.SeqLen+len(dataInitial)) + connUDP.WriteTo(cryptoState.Tx(tmp, dataInitial), DstAddrUDP) + } rxFinished := make(chan struct{}) - var rxPkts, txPkts, rxBytes, txBytes int64 go func() { var n int var err error @@ -129,6 +139,7 @@ now := time.Now() lastPing := now last := now var got []byte + var ok bool for { select { case <-ticker.C: @@ -142,8 +153,8 @@ _, err = connUDP.WriteTo( cryptoState.Tx(buf[:udpobfs.SeqLen], nil), DstAddrUDP) lastPing = now } - case tx = <-txs: - if tx.Buf == nil { + case tx, ok = <-txs: + if !ok { return } got = cryptoState.Tx(buf[:udpobfs.SeqLen+tx.N], (*tx.Buf)[:tx.N]) @@ -156,13 +167,6 @@ last = lastPing } } }() - Peers[localAddr.String()] = txs - { - txPkts++ - txBytes += int64(len(dataInitial)) - tmp := make([]byte, udpobfs.SeqLen+len(dataInitial)) - connUDP.WriteTo(cryptoState.Tx(tmp, dataInitial), DstAddrUDP) - } go func() { defer connUDP.Close() ticker := time.NewTicker(udpobfs.LifetimeDuration) @@ -202,12 +206,10 @@ "rxPkts", rxPkts, "rxBytes", rxBytes, "txPkts", txPkts, "txBytes", txBytes) + PeersM.Lock() delete(Peers, localAddr.String()) - txs <- udpobfs.Buf{Buf: nil} - go func() { - for range txs { - } - }() + PeersM.Unlock() + close(txs) } func main() { @@ -277,10 +279,12 @@ n, from, _ = LnUDP.ReadFrom((*buf)[:]) if n == 0 { continue } + PeersM.RLock() txs = Peers[from.String()] if txs != nil { txs <- udpobfs.Buf{Buf: buf, N: n} } + PeersM.RUnlock() if txs == nil { neu := make([]byte, n) copy(neu, (*buf)[:n]) diff --git a/cmd/resp/main.go b/cmd/resp/main.go index cc973368855855ed6671693cb55cc6239712c7e3ceb045a0311970977eccdb16..75452990263b7392253f3389a32572e538d2b9797c00f5e6297f431e4e9bcce7 100644 --- a/cmd/resp/main.go +++ b/cmd/resp/main.go @@ -37,6 +37,7 @@ TLSConfig *tls.Config LnUDP *net.UDPConn LnTCP *net.TCPListener Peers = make(map[string]chan udpobfs.Buf) + PeersM sync.RWMutex Bufs = sync.Pool{New: func() any { return new([udpobfs.BufLen]byte) }} ) @@ -67,6 +68,9 @@ return } cryptoState := udpobfs.NewCryptoState(seed, false) txs := make(chan udpobfs.Buf) + PeersM.Lock() + Peers[remoteAddr.String()] = txs + PeersM.Unlock() txFinished := make(chan struct{}) var rxPkts, txPkts, rxBytes, txBytes int64 go func() { @@ -128,6 +132,7 @@ last := now buf := make([]byte, udpobfs.BufLen) var tx udpobfs.Buf var got []byte + var ok bool for { select { case <-txFinished: @@ -138,9 +143,9 @@ if now.Sub(last) > 2*udpobfs.LifetimeDuration { localUDP.Close() return } - case tx = <-txs: - if tx.Buf == nil { - break + case tx, ok = <-txs: + if !ok { + return } if tx.N < udpobfs.SeqLen { logger.Warn("too short") @@ -162,7 +167,6 @@ last = time.Now() } } }() - Peers[remoteAddr.String()] = txs go func() { buf := make([]byte, 8) for { @@ -181,12 +185,10 @@ "rxPkts", rxPkts, "rxBytes", rxBytes, "txPkts", txPkts, "txBytes", txBytes) + PeersM.Lock() delete(Peers, remoteAddr.String()) - txs <- udpobfs.Buf{Buf: nil} - go func() { - for range txs { - } - }() + PeersM.Unlock() + close(txs) } func main() { @@ -240,10 +242,12 @@ n, from, _ = LnUDP.ReadFrom((*buf)[:]) if n == 0 { continue } + PeersM.RLock() txs = Peers[from.String()] if txs != nil { txs <- udpobfs.Buf{Buf: buf, N: n} } + PeersM.RUnlock() } }()