]> Sergey Matveev's repositories - btrtrc.git/blobdiff - peer_protocol/msg.go
Skip test failures due to Go uTP implementation
[btrtrc.git] / peer_protocol / msg.go
index 7d3764b714223cafe832cc73b7a8f708694cf968..23710e6339d4cd93b79071cdb4c17b0165fbf5cb 100644 (file)
@@ -1,11 +1,15 @@
 package peer_protocol
 
 import (
+       "bufio"
        "bytes"
+       "encoding"
        "encoding/binary"
        "fmt"
 )
 
+// This is a lazy union representing all the possible fields for messages. Go doesn't have ADTs, and
+// I didn't choose to use type-assertions.
 type Message struct {
        Keepalive            bool
        Type                 MessageType
@@ -17,6 +21,11 @@ type Message struct {
        Port                 uint16
 }
 
+var _ interface {
+       encoding.BinaryUnmarshaler
+       encoding.BinaryMarshaler
+} = (*Message)(nil)
+
 func MakeCancelMessage(piece, offset, length Integer) Message {
        return Message{
                Type:   Cancel,
@@ -26,8 +35,18 @@ func MakeCancelMessage(piece, offset, length Integer) Message {
        }
 }
 
-func (msg Message) RequestSpec() RequestSpec {
-       return RequestSpec{msg.Index, msg.Begin, msg.Length}
+func (msg Message) RequestSpec() (ret RequestSpec) {
+       return RequestSpec{
+               msg.Index,
+               msg.Begin,
+               func() Integer {
+                       if msg.Type == Piece {
+                               return Integer(len(msg.Piece))
+                       } else {
+                               return msg.Length
+                       }
+               }(),
+       }
 }
 
 func (msg Message) MustMarshalBinary() []byte {
@@ -39,7 +58,7 @@ func (msg Message) MustMarshalBinary() []byte {
 }
 
 func (msg Message) MarshalBinary() (data []byte, err error) {
-       buf := &bytes.Buffer{}
+       var buf bytes.Buffer
        if !msg.Keepalive {
                err = buf.WriteByte(byte(msg.Type))
                if err != nil {
@@ -48,10 +67,10 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
                switch msg.Type {
                case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
                case Have:
-                       err = binary.Write(buf, binary.BigEndian, msg.Index)
+                       err = binary.Write(&buf, binary.BigEndian, msg.Index)
                case Request, Cancel, Reject:
                        for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
-                               err = binary.Write(buf, binary.BigEndian, i)
+                               err = binary.Write(&buf, binary.BigEndian, i)
                                if err != nil {
                                        break
                                }
@@ -60,7 +79,7 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
                        _, err = buf.Write(marshalBitfield(msg.Bitfield))
                case Piece:
                        for _, i := range []Integer{msg.Index, msg.Begin} {
-                               err = binary.Write(buf, binary.BigEndian, i)
+                               err = binary.Write(&buf, binary.BigEndian, i)
                                if err != nil {
                                        return
                                }
@@ -79,7 +98,7 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
                        }
                        _, err = buf.Write(msg.ExtendedPayload)
                case Port:
-                       err = binary.Write(buf, binary.BigEndian, msg.Port)
+                       err = binary.Write(&buf, binary.BigEndian, msg.Port)
                default:
                        err = fmt.Errorf("unknown message type: %v", msg.Type)
                }
@@ -104,3 +123,17 @@ func marshalBitfield(bf []bool) (b []byte) {
        }
        return
 }
+
+func (me *Message) UnmarshalBinary(b []byte) error {
+       d := Decoder{
+               R: bufio.NewReader(bytes.NewReader(b)),
+       }
+       err := d.Decode(me)
+       if err != nil {
+               return err
+       }
+       if d.R.Buffered() != 0 {
+               return fmt.Errorf("%d trailing bytes", d.R.Buffered())
+       }
+       return nil
+}