From 3f5ef0be3e4540b9072af82340bb460233f92a09 Mon Sep 17 00:00:00 2001 From: Luo Zhengjie <48181069+Zerolzj@users.noreply.github.com> Date: Thu, 25 Apr 2024 13:19:54 +0800 Subject: [PATCH] Optimize memory usage by avoiding intermediate buffer in message serialization (#928) * Optimize memory usage by avoiding intermediate buffer in message serialization This commit replaces the use of an intermediate buffer in the message serialization process with a direct write-to-buffer approach. The original implementation used MustMarshalBinary() which involved an extra memory copy to an intermediate buffer before writing to the final writeBuffer, leading to high memory consumption for large messages. The new WriteTo function writes message data directly to the writeBuffer, significantly reducing memory overhead and CPU time spent on garbage collection. * add benchmark for write * benchmark for 1M/4M/8M * Tidy up new benchmarks * Maintain older payload write implementation --------- Co-authored-by: luozhengjie.lzj Co-authored-by: Matt Joiner --- peer-conn-msg-writer.go | 13 ++++++- peer-conn-msg-writer_test.go | 68 +++++++++++++++++++++++++++++++++++ peer_protocol/msg.go | 69 ++++++++++++++++++++++++++++-------- requesting.go | 3 +- 4 files changed, 136 insertions(+), 17 deletions(-) create mode 100644 peer-conn-msg-writer_test.go diff --git a/peer-conn-msg-writer.go b/peer-conn-msg-writer.go index 1bacc59d..4f17e5ed 100644 --- a/peer-conn-msg-writer.go +++ b/peer-conn-msg-writer.go @@ -117,10 +117,21 @@ func (cn *peerConnMsgWriter) run(keepAliveTimeout time.Duration) { } } +func (cn *peerConnMsgWriter) writeToBuffer(msg pp.Message) (err error) { + originalLen := cn.writeBuffer.Len() + defer func() { + if err != nil { + // Since an error occurred during buffer write, revert buffer to its original state before the write. + cn.writeBuffer.Truncate(originalLen) + } + }() + return msg.WriteTo(cn.writeBuffer) +} + func (cn *peerConnMsgWriter) write(msg pp.Message) bool { cn.mu.Lock() defer cn.mu.Unlock() - cn.writeBuffer.Write(msg.MustMarshalBinary()) + cn.writeToBuffer(msg) cn.writeCond.Broadcast() return !cn.writeBufferFull() } diff --git a/peer-conn-msg-writer_test.go b/peer-conn-msg-writer_test.go new file mode 100644 index 00000000..308d18e5 --- /dev/null +++ b/peer-conn-msg-writer_test.go @@ -0,0 +1,68 @@ +package torrent + +import ( + "bytes" + "testing" + + "github.com/dustin/go-humanize" + + pp "github.com/anacrolix/torrent/peer_protocol" +) + +func PieceMsg(length int64) pp.Message { + return pp.Message{ + Type: pp.Piece, + Index: pp.Integer(0), + Begin: pp.Integer(0), + Piece: make([]byte, length), + } +} + +var benchmarkPieceLengths = []int{defaultChunkSize, 1 << 20, 4 << 20, 8 << 20} + +func runBenchmarkWriteToBuffer(b *testing.B, length int64) { + writer := &peerConnMsgWriter{ + writeBuffer: &bytes.Buffer{}, + } + msg := PieceMsg(length) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + //b.StopTimer() + writer.writeBuffer.Reset() + //b.StartTimer() + writer.writeToBuffer(msg) + } +} + +func BenchmarkWritePieceMsg(b *testing.B) { + for _, length := range benchmarkPieceLengths { + b.Run(humanize.IBytes(uint64(length)), func(b *testing.B) { + b.Run("ToBuffer", func(b *testing.B) { + b.SetBytes(int64(length)) + runBenchmarkWriteToBuffer(b, int64(length)) + }) + b.Run("MarshalBinary", func(b *testing.B) { + b.SetBytes(int64(length)) + runBenchmarkMarshalBinaryWrite(b, int64(length)) + }) + }) + } +} + +func runBenchmarkMarshalBinaryWrite(b *testing.B, length int64) { + writer := &peerConnMsgWriter{ + writeBuffer: &bytes.Buffer{}, + } + msg := PieceMsg(length) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + //b.StopTimer() + writer.writeBuffer.Reset() + //b.StartTimer() + writer.writeBuffer.Write(msg.MustMarshalBinary()) + } +} diff --git a/peer_protocol/msg.go b/peer_protocol/msg.go index b08bb538..2ce2b956 100644 --- a/peer_protocol/msg.go +++ b/peer_protocol/msg.go @@ -6,6 +6,7 @@ import ( "encoding" "encoding/binary" "fmt" + "io" ) // This is a lazy union representing all the possible fields for messages. Go doesn't have ADTs, and @@ -61,13 +62,14 @@ func (msg Message) MustMarshalBinary() []byte { return b } -func (msg Message) MarshalBinary() (data []byte, err error) { - // It might look like you could have a pool of buffers and preallocate the message length - // prefix, but because we have to return []byte, it becomes non-trivial to make this fast. You - // will need a benchmark. - var buf bytes.Buffer +type MessageWriter interface { + io.ByteWriter + io.Writer +} + +func (msg *Message) writePayloadTo(buf MessageWriter) (err error) { mustWrite := func(data any) { - err := binary.Write(&buf, binary.BigEndian, data) + err := binary.Write(buf, binary.BigEndian, data) if err != nil { panic(err) } @@ -85,10 +87,10 @@ func (msg Message) MarshalBinary() (data []byte, err error) { switch msg.Type { case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone: case Have, AllowedFast, Suggest: - err = binary.Write(&buf, binary.BigEndian, msg.Index) + err = binary.Write(buf, binary.BigEndian, msg.Index) case Request, Cancel, Reject: for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} { - err = binary.Write(&buf, binary.BigEndian, i) + err = binary.Write(buf, binary.BigEndian, i) if err != nil { break } @@ -97,7 +99,7 @@ func (msg Message) MarshalBinary() (data []byte, err error) { _, err = buf.Write(marshalBitfield(msg.Bitfield)) case Piece: for _, i := range []Integer{msg.Index, msg.Begin} { - err = binary.Write(&buf, binary.BigEndian, i) + err = binary.Write(buf, binary.BigEndian, i) if err != nil { return } @@ -116,7 +118,7 @@ func (msg Message) MarshalBinary() (data []byte, err error) { } _, err = buf.Write(msg.ExtendedPayload) case Port: - err = binary.Write(&buf, binary.BigEndian, msg.Port) + err = binary.Write(buf, binary.BigEndian, msg.Port) case HashRequest: buf.Write(msg.PiecesRoot[:]) writeConsecutive(msg.BaseLayer, msg.Index, msg.Length, msg.ProofLayers) @@ -124,11 +126,35 @@ func (msg Message) MarshalBinary() (data []byte, err error) { err = fmt.Errorf("unknown message type: %v", msg.Type) } } - data = make([]byte, 4+buf.Len()) - binary.BigEndian.PutUint32(data, uint32(buf.Len())) - if buf.Len() != copy(data[4:], buf.Bytes()) { - panic("bad copy") + return +} + +func (msg *Message) WriteTo(w MessageWriter) (err error) { + length, err := msg.getPayloadLength() + if err != nil { + return + } + err = binary.Write(w, binary.BigEndian, length) + if err != nil { + return } + return msg.writePayloadTo(w) +} + +func (msg *Message) getPayloadLength() (length Integer, err error) { + var lw lengthWriter + err = msg.writePayloadTo(&lw) + length = lw.n + return +} + +func (msg Message) MarshalBinary() (data []byte, err error) { + // It might look like you could have a pool of buffers and preallocate the message length + // prefix, but because we have to return []byte, it becomes non-trivial to make this fast. You + // will need a benchmark. + var buf bytes.Buffer + err = msg.WriteTo(&buf) + data = buf.Bytes() return } @@ -158,3 +184,18 @@ func (me *Message) UnmarshalBinary(b []byte) error { } return nil } + +type lengthWriter struct { + n Integer +} + +func (l *lengthWriter) WriteByte(c byte) error { + l.n++ + return nil +} + +func (l *lengthWriter) Write(p []byte) (n int, err error) { + n = len(p) + l.n += Integer(n) + return +} diff --git a/requesting.go b/requesting.go index 51419a35..a5925037 100644 --- a/requesting.go +++ b/requesting.go @@ -9,9 +9,8 @@ import ( "time" "unsafe" - g "github.com/anacrolix/generics" - "github.com/RoaringBitmap/roaring" + g "github.com/anacrolix/generics" "github.com/anacrolix/generics/heap" "github.com/anacrolix/log" "github.com/anacrolix/multiless" -- 2.44.0