]> Sergey Matveev's repositories - btrtrc.git/blobdiff - peer_protocol/msg.go
Optimize memory usage by avoiding intermediate buffer in message serialization (...
[btrtrc.git] / peer_protocol / msg.go
index b08bb5380e009d514a0e30aa05853b845d6fe825..2ce2b9565c9607f4ff03556df3f385c6f8cdb0bd 100644 (file)
@@ -6,6 +6,7 @@ import (
        "encoding"
        "encoding/binary"
        "fmt"
+       "io"
 )
 
 // This is a lazy union representing all the possible fields for messages. Go doesn't have ADTs, and
@@ -61,13 +62,14 @@ func (msg Message) MustMarshalBinary() []byte {
        return b
 }
 
-func (msg Message) MarshalBinary() (data []byte, err error) {
-       // It might look like you could have a pool of buffers and preallocate the message length
-       // prefix, but because we have to return []byte, it becomes non-trivial to make this fast. You
-       // will need a benchmark.
-       var buf bytes.Buffer
+type MessageWriter interface {
+       io.ByteWriter
+       io.Writer
+}
+
+func (msg *Message) writePayloadTo(buf MessageWriter) (err error) {
        mustWrite := func(data any) {
-               err := binary.Write(&buf, binary.BigEndian, data)
+               err := binary.Write(buf, binary.BigEndian, data)
                if err != nil {
                        panic(err)
                }
@@ -85,10 +87,10 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
                switch msg.Type {
                case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
                case Have, AllowedFast, Suggest:
-                       err = binary.Write(&buf, binary.BigEndian, msg.Index)
+                       err = binary.Write(buf, binary.BigEndian, msg.Index)
                case Request, Cancel, Reject:
                        for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
-                               err = binary.Write(&buf, binary.BigEndian, i)
+                               err = binary.Write(buf, binary.BigEndian, i)
                                if err != nil {
                                        break
                                }
@@ -97,7 +99,7 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
                        _, err = buf.Write(marshalBitfield(msg.Bitfield))
                case Piece:
                        for _, i := range []Integer{msg.Index, msg.Begin} {
-                               err = binary.Write(&buf, binary.BigEndian, i)
+                               err = binary.Write(buf, binary.BigEndian, i)
                                if err != nil {
                                        return
                                }
@@ -116,7 +118,7 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
                        }
                        _, err = buf.Write(msg.ExtendedPayload)
                case Port:
-                       err = binary.Write(&buf, binary.BigEndian, msg.Port)
+                       err = binary.Write(buf, binary.BigEndian, msg.Port)
                case HashRequest:
                        buf.Write(msg.PiecesRoot[:])
                        writeConsecutive(msg.BaseLayer, msg.Index, msg.Length, msg.ProofLayers)
@@ -124,11 +126,35 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
                        err = fmt.Errorf("unknown message type: %v", msg.Type)
                }
        }
-       data = make([]byte, 4+buf.Len())
-       binary.BigEndian.PutUint32(data, uint32(buf.Len()))
-       if buf.Len() != copy(data[4:], buf.Bytes()) {
-               panic("bad copy")
+       return
+}
+
+func (msg *Message) WriteTo(w MessageWriter) (err error) {
+       length, err := msg.getPayloadLength()
+       if err != nil {
+               return
+       }
+       err = binary.Write(w, binary.BigEndian, length)
+       if err != nil {
+               return
        }
+       return msg.writePayloadTo(w)
+}
+
+func (msg *Message) getPayloadLength() (length Integer, err error) {
+       var lw lengthWriter
+       err = msg.writePayloadTo(&lw)
+       length = lw.n
+       return
+}
+
+func (msg Message) MarshalBinary() (data []byte, err error) {
+       // It might look like you could have a pool of buffers and preallocate the message length
+       // prefix, but because we have to return []byte, it becomes non-trivial to make this fast. You
+       // will need a benchmark.
+       var buf bytes.Buffer
+       err = msg.WriteTo(&buf)
+       data = buf.Bytes()
        return
 }
 
@@ -158,3 +184,18 @@ func (me *Message) UnmarshalBinary(b []byte) error {
        }
        return nil
 }
+
+type lengthWriter struct {
+       n Integer
+}
+
+func (l *lengthWriter) WriteByte(c byte) error {
+       l.n++
+       return nil
+}
+
+func (l *lengthWriter) Write(p []byte) (n int, err error) {
+       n = len(p)
+       l.n += Integer(n)
+       return
+}