peer-conn-msg-writer.go      | 22 ++++++++++++++-
peer-conn-msg-writer_test.go | 74 ++++++++++++++++++++++++++++++++++++++++++++++++
peer_protocol/msg.go         | 81 ++++++++++++++++++++++++++++++++++++++++++++---------
requesting.go                |  3 +--
diff --git a/peer-conn-msg-writer.go b/peer-conn-msg-writer.go
index 1bacc59d188c59ceb726abf4911939092ac9f574..4f17e5edec4886560cf56e7b2974723f6cb82ee0 100644
--- a/peer-conn-msg-writer.go
+++ b/peer-conn-msg-writer.go
@@ -117,10 +117,30 @@ keepAliveTimer.Reset(keepAliveTimeout)
 	}
 }
 
+// writeToBuffer encodes msg directly into cn.writeBuffer. If encoding fails,
+// the buffer is truncated back to its previous length so that no partial
+// message is left queued. write calls it with cn.mu held.
+func (cn *peerConnMsgWriter) writeToBuffer(msg pp.Message) (err error) {
+	originalLen := cn.writeBuffer.Len()
+	defer func() {
+		if err != nil {
+			// Revert the buffer to its state before this write.
+			cn.writeBuffer.Truncate(originalLen)
+		}
+	}()
+	return msg.WriteTo(cn.writeBuffer)
+}
+
+// write appends msg to the write buffer and broadcasts on cn.writeCond. It
+// reports whether the buffer still has room, i.e. is not yet full.
 func (cn *peerConnMsgWriter) write(msg pp.Message) bool {
 	cn.mu.Lock()
 	defer cn.mu.Unlock()
-	cn.writeBuffer.Write(msg.MustMarshalBinary())
+	if err := cn.writeToBuffer(msg); err != nil {
+		// An encoding error means msg itself is invalid (for example an
+		// unknown message type); fail loudly instead of silently dropping it.
+		panic(err)
+	}
 	cn.writeCond.Broadcast()
 	return !cn.writeBufferFull()
 }
diff --git a/peer-conn-msg-writer_test.go b/peer-conn-msg-writer_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..308d18e5af200530c780a6c750d000e0cb20abbb
--- /dev/null
+++ b/peer-conn-msg-writer_test.go
@@ -0,0 +1,74 @@
+package torrent
+
+import (
+	"bytes"
+	"testing"
+
+	"github.com/dustin/go-humanize"
+
+	pp "github.com/anacrolix/torrent/peer_protocol"
+)
+
+// PieceMsg returns a Piece message whose payload is length zeroed bytes.
+func PieceMsg(length int64) pp.Message {
+	return pp.Message{
+		Type:  pp.Piece,
+		Index: pp.Integer(0),
+		Begin: pp.Integer(0),
+		Piece: make([]byte, length),
+	}
+}
+
+// benchmarkPieceLengths are the Piece payload sizes BenchmarkWritePieceMsg covers.
+var benchmarkPieceLengths = []int{defaultChunkSize, 1 << 20, 4 << 20, 8 << 20}
+
+// runBenchmarkWriteToBuffer measures encoding a Piece message directly into
+// the writer's buffer with writeToBuffer.
+func runBenchmarkWriteToBuffer(b *testing.B, length int64) {
+	writer := &peerConnMsgWriter{
+		writeBuffer: &bytes.Buffer{},
+	}
+	msg := PieceMsg(length)
+
+	b.ReportAllocs()
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		writer.writeBuffer.Reset()
+		if err := writer.writeToBuffer(msg); err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+// BenchmarkWritePieceMsg compares writeToBuffer with marshalling through
+// MustMarshalBinary, for several piece sizes.
+func BenchmarkWritePieceMsg(b *testing.B) {
+	for _, length := range benchmarkPieceLengths {
+		b.Run(humanize.IBytes(uint64(length)), func(b *testing.B) {
+			b.Run("ToBuffer", func(b *testing.B) {
+				b.SetBytes(int64(length))
+				runBenchmarkWriteToBuffer(b, int64(length))
+			})
+			b.Run("MarshalBinary", func(b *testing.B) {
+				b.SetBytes(int64(length))
+				runBenchmarkMarshalBinaryWrite(b, int64(length))
+			})
+		})
+	}
+}
+
+// runBenchmarkMarshalBinaryWrite measures marshalling into a fresh slice with
+// MustMarshalBinary and then copying that slice into the writer's buffer.
+func runBenchmarkMarshalBinaryWrite(b *testing.B, length int64) {
+	writer := &peerConnMsgWriter{
+		writeBuffer: &bytes.Buffer{},
+	}
+	msg := PieceMsg(length)
+
+	b.ReportAllocs()
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		writer.writeBuffer.Reset()
+		writer.writeBuffer.Write(msg.MustMarshalBinary())
+	}
+}
diff --git a/peer_protocol/msg.go b/peer_protocol/msg.go
index b08bb5380e009d514a0e30aa05853b845d6fe825..2ce2b9565c9607f4ff03556df3f385c6f8cdb0bd 100644
--- a/peer_protocol/msg.go
+++ b/peer_protocol/msg.go
@@ -6,6 +6,7 @@ "bytes"
 	"encoding"
 	"encoding/binary"
 	"fmt"
+	"io"
 )
 
 // This is a lazy union representing all the possible fields for messages. Go doesn't have ADTs, and
@@ -61,13 +62,17 @@ }
 	return b
 }
 
-func (msg Message) MarshalBinary() (data []byte, err error) {
-	// It might look like you could have a pool of buffers and preallocate the message length
-	// prefix, but because we have to return []byte, it becomes non-trivial to make this fast. You
-	// will need a benchmark.
-	var buf bytes.Buffer
+// MessageWriter is the destination Message encodings are written to.
+// *bytes.Buffer satisfies it.
+type MessageWriter interface {
+	io.ByteWriter
+	io.Writer
+}
+
+// writePayloadTo encodes msg's payload, without the length prefix, into buf.
+func (msg *Message) writePayloadTo(buf MessageWriter) (err error) {
 	mustWrite := func(data any) {
-		err := binary.Write(&buf, binary.BigEndian, data)
+		err := binary.Write(buf, binary.BigEndian, data)
 		if err != nil {
 			panic(err)
 		}
@@ -85,10 +90,10 @@ }
 		switch msg.Type {
 		case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
 		case Have, AllowedFast, Suggest:
-			err = binary.Write(&buf, binary.BigEndian, msg.Index)
+			err = binary.Write(buf, binary.BigEndian, msg.Index)
 		case Request, Cancel, Reject:
 			for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
-				err = binary.Write(&buf, binary.BigEndian, i)
+				err = binary.Write(buf, binary.BigEndian, i)
 				if err != nil {
 					break
 				}
@@ -97,7 +102,7 @@ case Bitfield:
 			_, err = buf.Write(marshalBitfield(msg.Bitfield))
 		case Piece:
 			for _, i := range []Integer{msg.Index, msg.Begin} {
-				err = binary.Write(&buf, binary.BigEndian, i)
+				err = binary.Write(buf, binary.BigEndian, i)
 				if err != nil {
 					return
 				}
@@ -116,7 +121,7 @@ return
 			}
 			_, err = buf.Write(msg.ExtendedPayload)
 		case Port:
-			err = binary.Write(&buf, binary.BigEndian, msg.Port)
+			err = binary.Write(buf, binary.BigEndian, msg.Port)
 		case HashRequest:
 			buf.Write(msg.PiecesRoot[:])
 			writeConsecutive(msg.BaseLayer, msg.Index, msg.Length, msg.ProofLayers)
@@ -124,11 +129,42 @@ default:
 			err = fmt.Errorf("unknown message type: %v", msg.Type)
 		}
 	}
-	data = make([]byte, 4+buf.Len())
-	binary.BigEndian.PutUint32(data, uint32(buf.Len()))
-	if buf.Len() != copy(data[4:], buf.Bytes()) {
-		panic("bad copy")
+	return
+}
+
+// WriteTo writes msg to w as a big-endian length prefix followed by the
+// payload. The payload is encoded twice: once into a lengthWriter to compute
+// the prefix, and once into w. Unlike io.WriterTo, it does not report the
+// number of bytes written.
+func (msg *Message) WriteTo(w MessageWriter) (err error) {
+	length, err := msg.getPayloadLength()
+	if err != nil {
+		return
+	}
+	err = binary.Write(w, binary.BigEndian, length)
+	if err != nil {
+		return
 	}
+	return msg.writePayloadTo(w)
+}
+
+// getPayloadLength returns the encoded size of msg's payload, excluding the
+// length prefix, by encoding it into a byte-counting lengthWriter.
+func (msg *Message) getPayloadLength() (length Integer, err error) {
+	var lw lengthWriter
+	err = msg.writePayloadTo(&lw)
+	length = lw.n
+	return
+}
+
+// MarshalBinary returns msg in the same encoding as WriteTo, in a freshly
+// allocated slice. Callers that can reuse a buffer should call WriteTo instead.
+func (msg Message) MarshalBinary() (data []byte, err error) {
+	var buf bytes.Buffer
+	if err = msg.WriteTo(&buf); err != nil {
+		return nil, err
+	}
+	data = buf.Bytes()
 	return
 }
 
@@ -158,3 +194,20 @@ return fmt.Errorf("%d trailing bytes", d.R.Buffered())
 	}
 	return nil
 }
+
+// lengthWriter is a MessageWriter that discards what is written to it and
+// only counts the bytes, so a payload can be sized before it is encoded.
+type lengthWriter struct {
+	n Integer
+}
+
+func (l *lengthWriter) WriteByte(c byte) error {
+	l.n++
+	return nil
+}
+
+func (l *lengthWriter) Write(p []byte) (n int, err error) {
+	n = len(p)
+	l.n += Integer(n)
+	return
+}
diff --git a/requesting.go b/requesting.go
index 51419a3599a75e2a932c7bd1964afed152374136..a59250375ada02b3dd3f8672d75a63bca9493bd5 100644
--- a/requesting.go
+++ b/requesting.go
@@ -9,9 +9,8 @@ "runtime/pprof"
 	"time"
 	"unsafe"
 
+	"github.com/RoaringBitmap/roaring"
 	g "github.com/anacrolix/generics"
-
-	"github.com/RoaringBitmap/roaring"
 	"github.com/anacrolix/generics/heap"
 	"github.com/anacrolix/log"
 	"github.com/anacrolix/multiless"