]> Sergey Matveev's repositories - btrtrc.git/blobdiff - peer_protocol/msg.go
Skip test failures due to Go uTP implementation
[btrtrc.git] / peer_protocol / msg.go
index c0d94e3727bb7f1f6b1739d5f634b760f3a31505..23710e6339d4cd93b79071cdb4c17b0165fbf5cb 100644 (file)
@@ -1,7 +1,9 @@
 package peer_protocol
 
 import (
+       "bufio"
        "bytes"
+       "encoding"
        "encoding/binary"
        "fmt"
 )
@@ -19,6 +21,11 @@ type Message struct {
        Port                 uint16
 }
 
+var _ interface {
+       encoding.BinaryUnmarshaler
+       encoding.BinaryMarshaler
+} = (*Message)(nil)
+
 func MakeCancelMessage(piece, offset, length Integer) Message {
        return Message{
                Type:   Cancel,
@@ -51,7 +58,7 @@ func (msg Message) MustMarshalBinary() []byte {
 }
 
 func (msg Message) MarshalBinary() (data []byte, err error) {
-       buf := &bytes.Buffer{}
+       var buf bytes.Buffer
        if !msg.Keepalive {
                err = buf.WriteByte(byte(msg.Type))
                if err != nil {
@@ -60,10 +67,10 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
                switch msg.Type {
                case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
                case Have:
-                       err = binary.Write(buf, binary.BigEndian, msg.Index)
+                       err = binary.Write(&buf, binary.BigEndian, msg.Index)
                case Request, Cancel, Reject:
                        for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
-                               err = binary.Write(buf, binary.BigEndian, i)
+                               err = binary.Write(&buf, binary.BigEndian, i)
                                if err != nil {
                                        break
                                }
@@ -72,7 +79,7 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
                        _, err = buf.Write(marshalBitfield(msg.Bitfield))
                case Piece:
                        for _, i := range []Integer{msg.Index, msg.Begin} {
-                               err = binary.Write(buf, binary.BigEndian, i)
+                               err = binary.Write(&buf, binary.BigEndian, i)
                                if err != nil {
                                        return
                                }
@@ -91,7 +98,7 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
                        }
                        _, err = buf.Write(msg.ExtendedPayload)
                case Port:
-                       err = binary.Write(buf, binary.BigEndian, msg.Port)
+                       err = binary.Write(&buf, binary.BigEndian, msg.Port)
                default:
                        err = fmt.Errorf("unknown message type: %v", msg.Type)
                }
@@ -116,3 +123,17 @@ func marshalBitfield(bf []bool) (b []byte) {
        }
        return
 }
+
+func (me *Message) UnmarshalBinary(b []byte) error {
+       d := Decoder{
+               R: bufio.NewReader(bytes.NewReader(b)),
+       }
+       err := d.Decode(me)
+       if err != nil {
+               return err
+       }
+       if d.R.Buffered() != 0 {
+               return fmt.Errorf("%d trailing bytes", d.R.Buffered())
+       }
+       return nil
+}