"bufio"
"encoding/binary"
"fmt"
+ g "github.com/anacrolix/generics"
"io"
"sync"
MaxLength Integer // TODO: Should this include the length header or not?
}
-// io.EOF is returned if the source terminates cleanly on a message boundary.
-func (d *Decoder) Decode(msg *Message) (err error) {
- var length Integer
- err = length.Read(d.R)
- if err != nil {
- return fmt.Errorf("reading message length: %w", err)
+// This limits reads to the length of a message, returning io.EOF when the end of the message bytes
+// are reached. If you aren't expecting io.EOF, you should probably wrap it with expectReader.
+type decodeReader struct {
+ lr io.LimitedReader
+ br *bufio.Reader
+}
+
+func (dr *decodeReader) Init(r *bufio.Reader, length int64) {
+ dr.lr.R = r
+ dr.lr.N = length
+ dr.br = r
+}
+
+func (dr *decodeReader) ReadByte() (c byte, err error) {
+ if dr.lr.N <= 0 {
+ err = io.EOF
+ return
}
- if length > d.MaxLength {
- return errors.New("message too long")
+ c, err = dr.br.ReadByte()
+ if err == nil {
+ dr.lr.N--
}
- if length == 0 {
- msg.Keepalive = true
- return
+ return
+}
+
+func (dr *decodeReader) Read(p []byte) (n int, err error) {
+ n, err = dr.lr.Read(p)
+ if dr.lr.N != 0 && err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return
+}
+
+func (dr *decodeReader) UnreadLength() int64 {
+ return dr.lr.N
+}
+
+// This expects reads to have enough bytes. io.EOF is mapped to io.ErrUnexpectedEOF. It's probably
+// not a good idea to pass this to functions that expect to read until the end of something, because
+// they will probably expect io.EOF.
+type expectReader struct {
+ dr *decodeReader
+}
+
+func (er expectReader) ReadByte() (c byte, err error) {
+ c, err = er.dr.ReadByte()
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
}
- r := d.R
- readByte := func() (byte, error) {
- length--
- return d.R.ReadByte()
+ return
+}
+
+func (er expectReader) Read(p []byte) (n int, err error) {
+ n, err = er.dr.Read(p)
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
}
- // From this point onwards, EOF is unexpected
- defer func() {
- if err == io.EOF {
- err = io.ErrUnexpectedEOF
+ return
+}
+
+func (er expectReader) UnreadLength() int64 {
+ return er.dr.UnreadLength()
+}
+
+// io.EOF is returned if the source terminates cleanly on a message boundary.
+func (d *Decoder) Decode(msg *Message) (err error) {
+ var dr decodeReader
+ {
+ var length Integer
+ err = length.Read(d.R)
+ if err != nil {
+ return fmt.Errorf("reading message length: %w", err)
}
- }()
- c, err := readByte()
+ if length > d.MaxLength {
+ return errors.New("message too long")
+ }
+ if length == 0 {
+ msg.Keepalive = true
+ return
+ }
+ dr.Init(d.R, int64(length))
+ }
+ r := expectReader{&dr}
+ c, err := r.ReadByte()
if err != nil {
return
}
msg.Type = MessageType(c)
- // Can return directly in cases when err is not nil, or length is known to be zero.
+ err = readMessageAfterType(msg, &r, d.Pool)
+ if err != nil {
+ err = fmt.Errorf("reading fields for message type %v: %w", msg.Type, err)
+ return
+ }
+ if r.UnreadLength() != 0 {
+ err = fmt.Errorf("%v unused bytes in message type %v", r.UnreadLength(), msg.Type)
+ }
+ return
+}
+
+func readMessageAfterType(msg *Message, r *expectReader, piecePool *sync.Pool) (err error) {
switch msg.Type {
case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
case Have, AllowedFast, Suggest:
- length -= 4
err = msg.Index.Read(r)
case Request, Cancel, Reject:
for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
break
}
}
- length -= 12
case Bitfield:
- b := make([]byte, length)
+ b := make([]byte, r.UnreadLength())
_, err = io.ReadFull(r, b)
- length = 0
msg.Bitfield = unmarshalBitfield(b)
- return
case Piece:
for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
- err := pi.Read(r)
+ err = pi.Read(r)
if err != nil {
- return err
+ return
}
}
- length -= 8
- dataLen := int64(length)
- if d.Pool == nil {
+ dataLen := r.UnreadLength()
+ if piecePool == nil {
msg.Piece = make([]byte, dataLen)
} else {
- msg.Piece = *d.Pool.Get().(*[]byte)
+ msg.Piece = *piecePool.Get().(*[]byte)
if int64(cap(msg.Piece)) < dataLen {
return errors.New("piece data longer than expected")
}
msg.Piece = msg.Piece[:dataLen]
}
_, err = io.ReadFull(r, msg.Piece)
- length = 0
- return
case Extended:
var b byte
- b, err = readByte()
+ b, err = r.ReadByte()
if err != nil {
break
}
msg.ExtendedID = ExtensionNumber(b)
- msg.ExtendedPayload = make([]byte, length)
+ msg.ExtendedPayload = make([]byte, r.UnreadLength())
_, err = io.ReadFull(r, msg.ExtendedPayload)
- length = 0
- return
case Port:
err = binary.Read(r, binary.BigEndian, &msg.Port)
- length -= 2
+ case HashRequest, HashReject:
+ err = readHashRequest(r, msg)
+ case Hashes:
+ err = readHashRequest(r, msg)
+ numHashes := (r.UnreadLength() + 31) / 32
+ g.MakeSliceWithCap(&msg.Hashes, numHashes)
+ for range numHashes {
+ var oneHash [32]byte
+ _, err = io.ReadFull(r, oneHash[:])
+ if err != nil {
+ err = fmt.Errorf("error while reading hashes: %w", err)
+ return
+ }
+ msg.Hashes = append(msg.Hashes, oneHash)
+ }
default:
- err = fmt.Errorf("unknown message type %#v", c)
- }
- if err == nil && length != 0 {
- err = fmt.Errorf("%v unused bytes in message type %v", length, msg.Type)
+ err = errors.New("unhandled message type")
}
return
}
-func readByte(r io.Reader) (b byte, err error) {
- var arr [1]byte
- n, err := r.Read(arr[:])
- b = arr[0]
- if n == 1 {
- err = nil
+func readHashRequest(r io.Reader, msg *Message) (err error) {
+ _, err = io.ReadFull(r, msg.PiecesRoot[:])
+ if err != nil {
return
}
- if err == nil {
- panic(err)
+ return readSeq(r, &msg.BaseLayer, &msg.Index, &msg.Length, &msg.ProofLayers)
+}
+
+func readSeq(r io.Reader, data ...any) (err error) {
+ for _, d := range data {
+ err = binary.Read(r, binary.BigEndian, d)
+ if err != nil {
+ return
+ }
}
return
}
import "strconv"
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[Choke-0]
+ _ = x[Unchoke-1]
+ _ = x[Interested-2]
+ _ = x[NotInterested-3]
+ _ = x[Have-4]
+ _ = x[Bitfield-5]
+ _ = x[Request-6]
+ _ = x[Piece-7]
+ _ = x[Cancel-8]
+ _ = x[Port-9]
+ _ = x[Suggest-13]
+ _ = x[HaveAll-14]
+ _ = x[HaveNone-15]
+ _ = x[Reject-16]
+ _ = x[AllowedFast-17]
+ _ = x[Extended-20]
+ _ = x[HashRequest-21]
+ _ = x[Hashes-22]
+ _ = x[HashReject-23]
+}
+
const (
_MessageType_name_0 = "ChokeUnchokeInterestedNotInterestedHaveBitfieldRequestPieceCancelPort"
_MessageType_name_1 = "SuggestHaveAllHaveNoneRejectAllowedFast"
- _MessageType_name_2 = "Extended"
+ _MessageType_name_2 = "ExtendedHashRequestHashesHashReject"
)
var (
_MessageType_index_0 = [...]uint8{0, 5, 12, 22, 35, 39, 47, 54, 59, 65, 69}
_MessageType_index_1 = [...]uint8{0, 7, 14, 22, 28, 39}
+ _MessageType_index_2 = [...]uint8{0, 8, 19, 25, 35}
)
func (i MessageType) String() string {
case 13 <= i && i <= 17:
i -= 13
return _MessageType_name_1[_MessageType_index_1[i]:_MessageType_index_1[i+1]]
- case i == 20:
- return _MessageType_name_2
+ case 20 <= i && i <= 23:
+ i -= 20
+ return _MessageType_name_2[_MessageType_index_2[i]:_MessageType_index_2[i+1]]
default:
return "MessageType(" + strconv.FormatInt(int64(i), 10) + ")"
}