From: Matt Joiner Date: Fri, 2 Feb 2018 10:29:57 +0000 (+1100) Subject: Break up peer_protocol into several files X-Git-Tag: v1.0.0~225 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=b610107d8d08930272b2f034c6b12ecfebda74c9;p=btrtrc.git Break up peer_protocol into several files --- diff --git a/peer_protocol/decoder.go b/peer_protocol/decoder.go new file mode 100644 index 00000000..2d82c1c6 --- /dev/null +++ b/peer_protocol/decoder.go @@ -0,0 +1,124 @@ +package peer_protocol + +import ( + "bufio" + "encoding/binary" + "errors" + "fmt" + "io" + "io/ioutil" + "sync" +) + +type Decoder struct { + R *bufio.Reader + Pool *sync.Pool + MaxLength Integer // TODO: Should this include the length header or not? +} + +// io.EOF is returned if the source terminates cleanly on a message boundary. +func (d *Decoder) Decode(msg *Message) (err error) { + var length Integer + err = binary.Read(d.R, binary.BigEndian, &length) + if err != nil { + if err != io.EOF { + err = fmt.Errorf("error reading message length: %s", err) + } + return + } + if length > d.MaxLength { + return errors.New("message too long") + } + if length == 0 { + msg.Keepalive = true + return + } + msg.Keepalive = false + r := &io.LimitedReader{d.R, int64(length)} + // Check that all of r was utilized. + defer func() { + if err != nil { + return + } + if r.N != 0 { + err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type) + } + }() + msg.Keepalive = false + c, err := readByte(r) + if err != nil { + return + } + msg.Type = MessageType(c) + switch msg.Type { + case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone: + return + case Have: + err = msg.Index.Read(r) + case Request, Cancel, Reject: + for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} { + err = data.Read(r) + if err != nil { + break + } + } + case Bitfield: + b := make([]byte, length-1) + _, err = io.ReadFull(r, b) + msg.Bitfield = unmarshalBitfield(b) + case Piece: + for _, pi := range []*Integer{&msg.Index, &msg.Begin} { + err = pi.Read(r) + if err != nil { + break + } + } + if err != nil { + break + } + //msg.Piece, err = ioutil.ReadAll(r) + b := *d.Pool.Get().(*[]byte) + n, err := io.ReadFull(r, b) + if err != nil { + if err != io.ErrUnexpectedEOF || n != int(length-9) { + return err + } + b = b[0:n] + } + msg.Piece = b + case Extended: + msg.ExtendedID, err = readByte(r) + if err != nil { + break + } + msg.ExtendedPayload, err = ioutil.ReadAll(r) + case Port: + err = binary.Read(r, binary.BigEndian, &msg.Port) + default: + err = fmt.Errorf("unknown message type %#v", c) + } + return +} + +func readByte(r io.Reader) (b byte, err error) { + var arr [1]byte + n, err := r.Read(arr[:]) + b = arr[0] + if n == 1 { + err = nil + return + } + if err == nil { + panic(err) + } + return +} + +func unmarshalBitfield(b []byte) (bf []bool) { + for _, c := range b { + for i := 7; i >= 0; i-- { + bf = append(bf, (c>>uint(i))&1 == 1) + } + } + return +} diff --git a/peer_protocol/int.go b/peer_protocol/int.go new file mode 100644 index 00000000..6c43da44 --- /dev/null +++ b/peer_protocol/int.go @@ -0,0 +1,21 @@ +package peer_protocol + +import ( + "encoding/binary" + "io" +) + +type Integer uint32 + +func (i *Integer) Read(r io.Reader) error { + return binary.Read(r, binary.BigEndian, i) +} + +// It's perfectly fine to cast these to an int. TODO: Or is it? +func (i Integer) Int() int { + return int(i) +} + +func (i Integer) Uint64() uint64 { + return uint64(i) +} diff --git a/peer_protocol/msg.go b/peer_protocol/msg.go new file mode 100644 index 00000000..d5c316c9 --- /dev/null +++ b/peer_protocol/msg.go @@ -0,0 +1,102 @@ +package peer_protocol + +import ( + "bytes" + "encoding/binary" + "fmt" +) + +type Message struct { + Keepalive bool + Type MessageType + Index, Begin, Length Integer + Piece []byte + Bitfield []bool + ExtendedID byte + ExtendedPayload []byte + Port uint16 +} + +func MakeCancelMessage(piece, offset, length Integer) Message { + return Message{ + Type: Cancel, + Index: piece, + Begin: offset, + Length: length, + } +} + +func (msg Message) MustMarshalBinary() []byte { + b, err := msg.MarshalBinary() + if err != nil { + panic(err) + } + return b +} + +func (msg Message) MarshalBinary() (data []byte, err error) { + buf := &bytes.Buffer{} + if !msg.Keepalive { + err = buf.WriteByte(byte(msg.Type)) + if err != nil { + return + } + switch msg.Type { + case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone: + case Have: + err = binary.Write(buf, binary.BigEndian, msg.Index) + case Request, Cancel, Reject: + for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} { + err = binary.Write(buf, binary.BigEndian, i) + if err != nil { + break + } + } + case Bitfield: + _, err = buf.Write(marshalBitfield(msg.Bitfield)) + case Piece: + for _, i := range []Integer{msg.Index, msg.Begin} { + err = binary.Write(buf, binary.BigEndian, i) + if err != nil { + return + } + } + n, err := buf.Write(msg.Piece) + if err != nil { + break + } + if n != len(msg.Piece) { + panic(n) + } + case Extended: + err = buf.WriteByte(msg.ExtendedID) + if err != nil { + return + } + _, err = buf.Write(msg.ExtendedPayload) + case Port: + err = binary.Write(buf, binary.BigEndian, msg.Port) + default: + err = fmt.Errorf("unknown message type: %v", msg.Type) + } + } + data = make([]byte, 4+buf.Len()) + binary.BigEndian.PutUint32(data, uint32(buf.Len())) + if buf.Len() != copy(data[4:], buf.Bytes()) { + panic("bad copy") + } + return +} + +func marshalBitfield(bf []bool) (b []byte) { + b = make([]byte, (len(bf)+7)/8) + for i, have := range bf { + if !have { + continue + } + c := b[i/8] + c |= 1 << uint(7-i%8) + b[i/8] = c + } + return +} diff --git a/peer_protocol/protocol.go b/peer_protocol/protocol.go index 27827eed..8f914591 100644 --- a/peer_protocol/protocol.go +++ b/peer_protocol/protocol.go @@ -1,36 +1,11 @@ package peer_protocol -import ( - "bufio" - "bytes" - "encoding/binary" - "errors" - "fmt" - "io" - "io/ioutil" - "sync" +const ( + Protocol = "\x13BitTorrent protocol" ) type ( MessageType byte - Integer uint32 -) - -func (i *Integer) Read(r io.Reader) error { - return binary.Read(r, binary.BigEndian, i) -} - -// It's perfectly fine to cast these to an int. TODO: Or is it? -func (i Integer) Int() int { - return int(i) -} - -func (i Integer) Uint64() uint64 { - return uint64(i) -} - -const ( - Protocol = "\x13BitTorrent protocol" ) const ( @@ -60,217 +35,3 @@ const ( DataMetadataExtensionMsgType = 1 RejectMetadataExtensionMsgType = 2 ) - -type Message struct { - Keepalive bool - Type MessageType - Index, Begin, Length Integer - Piece []byte - Bitfield []bool - ExtendedID byte - ExtendedPayload []byte - Port uint16 -} - -func (msg Message) MustMarshalBinary() []byte { - b, err := msg.MarshalBinary() - if err != nil { - panic(err) - } - return b -} - -func (msg Message) MarshalBinary() (data []byte, err error) { - buf := &bytes.Buffer{} - if !msg.Keepalive { - err = buf.WriteByte(byte(msg.Type)) - if err != nil { - return - } - switch msg.Type { - case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone: - case Have: - err = binary.Write(buf, binary.BigEndian, msg.Index) - case Request, Cancel, Reject: - for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} { - err = binary.Write(buf, binary.BigEndian, i) - if err != nil { - break - } - } - case Bitfield: - _, err = buf.Write(marshalBitfield(msg.Bitfield)) - case Piece: - for _, i := range []Integer{msg.Index, msg.Begin} { - err = binary.Write(buf, binary.BigEndian, i) - if err != nil { - return - } - } - n, err := buf.Write(msg.Piece) - if err != nil { - break - } - if n != len(msg.Piece) { - panic(n) - } - case Extended: - err = buf.WriteByte(msg.ExtendedID) - if err != nil { - return - } - _, err = buf.Write(msg.ExtendedPayload) - case Port: - err = binary.Write(buf, binary.BigEndian, msg.Port) - default: - err = fmt.Errorf("unknown message type: %v", msg.Type) - } - } - data = make([]byte, 4+buf.Len()) - binary.BigEndian.PutUint32(data, uint32(buf.Len())) - if buf.Len() != copy(data[4:], buf.Bytes()) { - panic("bad copy") - } - return -} - -type Decoder struct { - R *bufio.Reader - Pool *sync.Pool - MaxLength Integer // TODO: Should this include the length header or not? -} - -func readByte(r io.Reader) (b byte, err error) { - var arr [1]byte - n, err := r.Read(arr[:]) - b = arr[0] - if n == 1 { - err = nil - return - } - if err == nil { - panic(err) - } - return -} - -// io.EOF is returned if the source terminates cleanly on a message boundary. -func (d *Decoder) Decode(msg *Message) (err error) { - var length Integer - err = binary.Read(d.R, binary.BigEndian, &length) - if err != nil { - if err != io.EOF { - err = fmt.Errorf("error reading message length: %s", err) - } - return - } - if length > d.MaxLength { - return errors.New("message too long") - } - if length == 0 { - msg.Keepalive = true - return - } - msg.Keepalive = false - r := &io.LimitedReader{d.R, int64(length)} - // Check that all of r was utilized. - defer func() { - if err != nil { - return - } - if r.N != 0 { - err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type) - } - }() - msg.Keepalive = false - c, err := readByte(r) - if err != nil { - return - } - msg.Type = MessageType(c) - switch msg.Type { - case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone: - return - case Have: - err = msg.Index.Read(r) - case Request, Cancel, Reject: - for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} { - err = data.Read(r) - if err != nil { - break - } - } - case Bitfield: - b := make([]byte, length-1) - _, err = io.ReadFull(r, b) - msg.Bitfield = unmarshalBitfield(b) - case Piece: - for _, pi := range []*Integer{&msg.Index, &msg.Begin} { - err = pi.Read(r) - if err != nil { - break - } - } - if err != nil { - break - } - //msg.Piece, err = ioutil.ReadAll(r) - b := *d.Pool.Get().(*[]byte) - n, err := io.ReadFull(r, b) - if err != nil { - if err != io.ErrUnexpectedEOF || n != int(length-9) { - return err - } - b = b[0:n] - } - msg.Piece = b - case Extended: - msg.ExtendedID, err = readByte(r) - if err != nil { - break - } - msg.ExtendedPayload, err = ioutil.ReadAll(r) - case Port: - err = binary.Read(r, binary.BigEndian, &msg.Port) - default: - err = fmt.Errorf("unknown message type %#v", c) - } - return -} - -type Bytes []byte - -func (b Bytes) MarshalBinary() ([]byte, error) { - return b, nil -} - -func unmarshalBitfield(b []byte) (bf []bool) { - for _, c := range b { - for i := 7; i >= 0; i-- { - bf = append(bf, (c>>uint(i))&1 == 1) - } - } - return -} - -func marshalBitfield(bf []bool) (b []byte) { - b = make([]byte, (len(bf)+7)/8) - for i, have := range bf { - if !have { - continue - } - c := b[i/8] - c |= 1 << uint(7-i%8) - b[i/8] = c - } - return -} - -func MakeCancelMessage(piece, offset, length Integer) Message { - return Message{ - Type: Cancel, - Index: piece, - Begin: offset, - Length: length, - } -}