Sergey Matveev's repositories - btrtrc.git/commitdiff
Optimize memory usage by avoiding intermediate buffer in message serialization (...
author Luo Zhengjie <48181069+Zerolzj@users.noreply.github.com>
Thu, 25 Apr 2024 05:19:54 +0000 (13:19 +0800)
committer GitHub <noreply@github.com>
Thu, 25 Apr 2024 05:19:54 +0000 (15:19 +1000)
* Optimize memory usage by avoiding intermediate buffer in message serialization

This commit replaces the use of an intermediate buffer in the message serialization process with a direct write-to-buffer approach. The original implementation used MustMarshalBinary(), which involved an extra memory copy to an intermediate buffer before writing to the final writeBuffer, leading to high memory consumption for large messages. The new WriteTo function writes message data directly to the writeBuffer, significantly reducing memory overhead and CPU time spent on garbage collection.

* add benchmark for write

* benchmark for 1M/4M/8M

* Tidy up new benchmarks

* Maintain older payload write implementation

---------

Co-authored-by: luozhengjie.lzj <luozhengjie.lzj@alibaba-inc.com>
Co-authored-by: Matt Joiner <anacrolix@gmail.com>
peer-conn-msg-writer.go
peer-conn-msg-writer_test.go [new file with mode: 0644]
peer_protocol/msg.go
requesting.go

index 1bacc59d188c59ceb726abf4911939092ac9f574..4f17e5edec4886560cf56e7b2974723f6cb82ee0 100644 (file)
@@ -117,10 +117,21 @@ func (cn *peerConnMsgWriter) run(keepAliveTimeout time.Duration) {
        }
 }
 
+func (cn *peerConnMsgWriter) writeToBuffer(msg pp.Message) (err error) {
+       originalLen := cn.writeBuffer.Len()
+       defer func() {
+               if err != nil {
+                       // Since an error occurred during buffer write, revert buffer to its original state before the write.
+                       cn.writeBuffer.Truncate(originalLen)
+               }
+       }()
+       return msg.WriteTo(cn.writeBuffer)
+}
+
 func (cn *peerConnMsgWriter) write(msg pp.Message) bool {
        cn.mu.Lock()
        defer cn.mu.Unlock()
-       cn.writeBuffer.Write(msg.MustMarshalBinary())
+       cn.writeToBuffer(msg)
        cn.writeCond.Broadcast()
        return !cn.writeBufferFull()
 }
diff --git a/peer-conn-msg-writer_test.go b/peer-conn-msg-writer_test.go
new file mode 100644 (file)
index 0000000..308d18e
--- /dev/null
@@ -0,0 +1,68 @@
+package torrent
+
+import (
+       "bytes"
+       "testing"
+
+       "github.com/dustin/go-humanize"
+
+       pp "github.com/anacrolix/torrent/peer_protocol"
+)
+
+func PieceMsg(length int64) pp.Message {
+       return pp.Message{
+               Type:  pp.Piece,
+               Index: pp.Integer(0),
+               Begin: pp.Integer(0),
+               Piece: make([]byte, length),
+       }
+}
+
+var benchmarkPieceLengths = []int{defaultChunkSize, 1 << 20, 4 << 20, 8 << 20}
+
+func runBenchmarkWriteToBuffer(b *testing.B, length int64) {
+       writer := &peerConnMsgWriter{
+               writeBuffer: &bytes.Buffer{},
+       }
+       msg := PieceMsg(length)
+
+       b.ReportAllocs()
+       b.ResetTimer()
+       for i := 0; i < b.N; i++ {
+               //b.StopTimer()
+               writer.writeBuffer.Reset()
+               //b.StartTimer()
+               writer.writeToBuffer(msg)
+       }
+}
+
+func BenchmarkWritePieceMsg(b *testing.B) {
+       for _, length := range benchmarkPieceLengths {
+               b.Run(humanize.IBytes(uint64(length)), func(b *testing.B) {
+                       b.Run("ToBuffer", func(b *testing.B) {
+                               b.SetBytes(int64(length))
+                               runBenchmarkWriteToBuffer(b, int64(length))
+                       })
+                       b.Run("MarshalBinary", func(b *testing.B) {
+                               b.SetBytes(int64(length))
+                               runBenchmarkMarshalBinaryWrite(b, int64(length))
+                       })
+               })
+       }
+}
+
+func runBenchmarkMarshalBinaryWrite(b *testing.B, length int64) {
+       writer := &peerConnMsgWriter{
+               writeBuffer: &bytes.Buffer{},
+       }
+       msg := PieceMsg(length)
+
+       b.ReportAllocs()
+       b.ResetTimer()
+       for i := 0; i < b.N; i++ {
+               //b.StopTimer()
+               writer.writeBuffer.Reset()
+               //b.StartTimer()
+               writer.writeBuffer.Write(msg.MustMarshalBinary())
+       }
+}
index b08bb5380e009d514a0e30aa05853b845d6fe825..2ce2b9565c9607f4ff03556df3f385c6f8cdb0bd 100644 (file)
@@ -6,6 +6,7 @@ import (
        "encoding"
        "encoding/binary"
        "fmt"
+       "io"
 )
 
 // This is a lazy union representing all the possible fields for messages. Go doesn't have ADTs, and
@@ -61,13 +62,14 @@ func (msg Message) MustMarshalBinary() []byte {
        return b
 }
 
-func (msg Message) MarshalBinary() (data []byte, err error) {
-       // It might look like you could have a pool of buffers and preallocate the message length
-       // prefix, but because we have to return []byte, it becomes non-trivial to make this fast. You
-       // will need a benchmark.
-       var buf bytes.Buffer
+type MessageWriter interface {
+       io.ByteWriter
+       io.Writer
+}
+
+func (msg *Message) writePayloadTo(buf MessageWriter) (err error) {
        mustWrite := func(data any) {
-               err := binary.Write(&buf, binary.BigEndian, data)
+               err := binary.Write(buf, binary.BigEndian, data)
                if err != nil {
                        panic(err)
                }
@@ -85,10 +87,10 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
                switch msg.Type {
                case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
                case Have, AllowedFast, Suggest:
-                       err = binary.Write(&buf, binary.BigEndian, msg.Index)
+                       err = binary.Write(buf, binary.BigEndian, msg.Index)
                case Request, Cancel, Reject:
                        for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
-                               err = binary.Write(&buf, binary.BigEndian, i)
+                               err = binary.Write(buf, binary.BigEndian, i)
                                if err != nil {
                                        break
                                }
@@ -97,7 +99,7 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
                        _, err = buf.Write(marshalBitfield(msg.Bitfield))
                case Piece:
                        for _, i := range []Integer{msg.Index, msg.Begin} {
-                               err = binary.Write(&buf, binary.BigEndian, i)
+                               err = binary.Write(buf, binary.BigEndian, i)
                                if err != nil {
                                        return
                                }
@@ -116,7 +118,7 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
                        }
                        _, err = buf.Write(msg.ExtendedPayload)
                case Port:
-                       err = binary.Write(&buf, binary.BigEndian, msg.Port)
+                       err = binary.Write(buf, binary.BigEndian, msg.Port)
                case HashRequest:
                        buf.Write(msg.PiecesRoot[:])
                        writeConsecutive(msg.BaseLayer, msg.Index, msg.Length, msg.ProofLayers)
@@ -124,11 +126,35 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
                        err = fmt.Errorf("unknown message type: %v", msg.Type)
                }
        }
-       data = make([]byte, 4+buf.Len())
-       binary.BigEndian.PutUint32(data, uint32(buf.Len()))
-       if buf.Len() != copy(data[4:], buf.Bytes()) {
-               panic("bad copy")
+       return
+}
+
+func (msg *Message) WriteTo(w MessageWriter) (err error) {
+       length, err := msg.getPayloadLength()
+       if err != nil {
+               return
+       }
+       err = binary.Write(w, binary.BigEndian, length)
+       if err != nil {
+               return
        }
+       return msg.writePayloadTo(w)
+}
+
+func (msg *Message) getPayloadLength() (length Integer, err error) {
+       var lw lengthWriter
+       err = msg.writePayloadTo(&lw)
+       length = lw.n
+       return
+}
+
+func (msg Message) MarshalBinary() (data []byte, err error) {
+       // It might look like you could have a pool of buffers and preallocate the message length
+       // prefix, but because we have to return []byte, it becomes non-trivial to make this fast. You
+       // will need a benchmark.
+       var buf bytes.Buffer
+       err = msg.WriteTo(&buf)
+       data = buf.Bytes()
        return
 }
 
@@ -158,3 +184,18 @@ func (me *Message) UnmarshalBinary(b []byte) error {
        }
        return nil
 }
+
+type lengthWriter struct {
+       n Integer
+}
+
+func (l *lengthWriter) WriteByte(c byte) error {
+       l.n++
+       return nil
+}
+
+func (l *lengthWriter) Write(p []byte) (n int, err error) {
+       n = len(p)
+       l.n += Integer(n)
+       return
+}
index 51419a3599a75e2a932c7bd1964afed152374136..a59250375ada02b3dd3f8672d75a63bca9493bd5 100644 (file)
@@ -9,9 +9,8 @@ import (
        "time"
        "unsafe"
 
-       g "github.com/anacrolix/generics"
-
        "github.com/RoaringBitmap/roaring"
+       g "github.com/anacrolix/generics"
        "github.com/anacrolix/generics/heap"
        "github.com/anacrolix/log"
        "github.com/anacrolix/multiless"