]> Sergey Matveev's repositories - btrtrc.git/blobdiff - peer-conn-msg-writer.go
Drop support for go 1.20
[btrtrc.git] / peer-conn-msg-writer.go
index dff4eb9ec8f98405821ffa36ae9122dcb1e2b67d..1bacc59d188c59ceb726abf4911939092ac9f574 100644 (file)
@@ -12,7 +12,7 @@ import (
        pp "github.com/anacrolix/torrent/peer_protocol"
 )
 
-func (pc *PeerConn) startWriter() {
+func (pc *PeerConn) initMessageWriter() {
        w := &pc.messageWriter
        *w = peerConnMsgWriter{
                fillWriteBuffer: func() {
@@ -27,18 +27,24 @@ func (pc *PeerConn) startWriter() {
                logger: pc.logger,
                w:      pc.w,
                keepAlive: func() bool {
-                       pc.locker().Lock()
-                       defer pc.locker().Unlock()
+                       pc.locker().RLock()
+                       defer pc.locker().RUnlock()
                        return pc.useful()
                },
                writeBuffer: new(bytes.Buffer),
        }
-       go func() {
-               defer pc.locker().Unlock()
-               defer pc.close()
-               defer pc.locker().Lock()
-               pc.messageWriter.run(pc.t.cl.config.KeepAliveTimeout)
-       }()
+}
+
+func (pc *PeerConn) startMessageWriter() {
+       pc.initMessageWriter()
+       go pc.messageWriterRunner()
+}
+
+func (pc *PeerConn) messageWriterRunner() {
+       defer pc.locker().Unlock()
+       defer pc.close()
+       defer pc.locker().Lock()
+       pc.messageWriter.run(pc.t.cl.config.KeepAliveTimeout)
 }
 
 type peerConnMsgWriter struct {
@@ -59,35 +65,16 @@ type peerConnMsgWriter struct {
 // activity elsewhere in the Client, and some is determined locally when the
 // connection is writable.
 func (cn *peerConnMsgWriter) run(keepAliveTimeout time.Duration) {
-       var (
-               lastWrite      time.Time = time.Now()
-               keepAliveTimer *time.Timer
-       )
-       cn.mu.Lock()
-       defer cn.mu.Unlock()
-       keepAliveTimer = time.AfterFunc(keepAliveTimeout, func() {
-               cn.mu.Lock()
-               defer cn.mu.Unlock()
-               if time.Since(lastWrite) >= keepAliveTimeout {
-                       cn.writeCond.Broadcast()
-               }
-               keepAliveTimer.Reset(keepAliveTimeout)
-       })
-       defer keepAliveTimer.Stop()
+       lastWrite := time.Now()
+       keepAliveTimer := time.NewTimer(keepAliveTimeout)
        frontBuf := new(bytes.Buffer)
        for {
                if cn.closed.IsSet() {
                        return
                }
-               keepAlive := false
-               if cn.writeBuffer.Len() == 0 {
-                       func() {
-                               cn.mu.Unlock()
-                               defer cn.mu.Lock()
-                               cn.fillWriteBuffer()
-                               keepAlive = cn.keepAlive()
-                       }()
-               }
+               cn.fillWriteBuffer()
+               keepAlive := cn.keepAlive()
+               cn.mu.Lock()
                if cn.writeBuffer.Len() == 0 && time.Since(lastWrite) >= keepAliveTimeout && keepAlive {
                        cn.writeBuffer.Write(pp.Message{Keepalive: true}.MustMarshalBinary())
                        torrent.Add("written keepalives", 1)
@@ -98,27 +85,35 @@ func (cn *peerConnMsgWriter) run(keepAliveTimeout time.Duration) {
                        select {
                        case <-cn.closed.Done():
                        case <-writeCond:
+                       case <-keepAliveTimer.C:
                        }
-                       cn.mu.Lock()
                        continue
                }
                // Flip the buffers.
                frontBuf, cn.writeBuffer = cn.writeBuffer, frontBuf
                cn.mu.Unlock()
-               n, err := cn.w.Write(frontBuf.Bytes())
-               cn.mu.Lock()
-               if n != 0 {
-                       lastWrite = time.Now()
-                       keepAliveTimer.Reset(keepAliveTimeout)
+               if frontBuf.Len() == 0 {
+                       panic("expected non-empty front buffer")
+               }
+               var err error
+               for frontBuf.Len() != 0 {
+                       // Limit write size for WebRTC. See https://github.com/pion/datachannel/issues/59.
+                       next := frontBuf.Next(1<<16 - 1)
+                       var n int
+                       n, err = cn.w.Write(next)
+                       if err == nil && n != len(next) {
+                               panic("expected full write")
+                       }
+                       if err != nil {
+                               break
+                       }
                }
                if err != nil {
                        cn.logger.WithDefaultLevel(log.Debug).Printf("error writing: %v", err)
                        return
                }
-               if n != frontBuf.Len() {
-                       panic("short write")
-               }
-               frontBuf.Reset()
+               lastWrite = time.Now()
+               keepAliveTimer.Reset(keepAliveTimeout)
        }
 }