package peer_protocol
import (
+ "bufio"
"bytes"
+ "encoding"
"encoding/binary"
"fmt"
)
+// This is a lazy union representing all the possible fields for messages. Go doesn't have ADTs, and
+// I didn't choose to use type-assertions.
type Message struct {
Keepalive bool
Type MessageType
Port uint16
}
+var _ interface {
+ encoding.BinaryUnmarshaler
+ encoding.BinaryMarshaler
+} = (*Message)(nil)
+
func MakeCancelMessage(piece, offset, length Integer) Message {
return Message{
Type: Cancel,
}
}
-func (msg Message) RequestSpec() RequestSpec {
- return RequestSpec{msg.Index, msg.Begin, msg.Length}
+func (msg Message) RequestSpec() (ret RequestSpec) {
+ return RequestSpec{
+ msg.Index,
+ msg.Begin,
+ func() Integer {
+ if msg.Type == Piece {
+ return Integer(len(msg.Piece))
+ } else {
+ return msg.Length
+ }
+ }(),
+ }
}
func (msg Message) MustMarshalBinary() []byte {
}
func (msg Message) MarshalBinary() (data []byte, err error) {
- buf := &bytes.Buffer{}
+ var buf bytes.Buffer
if !msg.Keepalive {
err = buf.WriteByte(byte(msg.Type))
if err != nil {
switch msg.Type {
case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
case Have:
- err = binary.Write(buf, binary.BigEndian, msg.Index)
+ err = binary.Write(&buf, binary.BigEndian, msg.Index)
case Request, Cancel, Reject:
for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
- err = binary.Write(buf, binary.BigEndian, i)
+ err = binary.Write(&buf, binary.BigEndian, i)
if err != nil {
break
}
_, err = buf.Write(marshalBitfield(msg.Bitfield))
case Piece:
for _, i := range []Integer{msg.Index, msg.Begin} {
- err = binary.Write(buf, binary.BigEndian, i)
+ err = binary.Write(&buf, binary.BigEndian, i)
if err != nil {
return
}
}
_, err = buf.Write(msg.ExtendedPayload)
case Port:
- err = binary.Write(buf, binary.BigEndian, msg.Port)
+ err = binary.Write(&buf, binary.BigEndian, msg.Port)
default:
err = fmt.Errorf("unknown message type: %v", msg.Type)
}
}
return
}
+
+func (me *Message) UnmarshalBinary(b []byte) error {
+ d := Decoder{
+ R: bufio.NewReader(bytes.NewReader(b)),
+ }
+ err := d.Decode(me)
+ if err != nil {
+ return err
+ }
+ if d.R.Buffered() != 0 {
+ return fmt.Errorf("%d trailing bytes", d.R.Buffered())
+ }
+ return nil
+}