7 g "github.com/anacrolix/generics"
11 "github.com/pkg/errors"
16 // This must return *[]byte where the slices can fit data for piece messages. I think we store
17 // *[]byte in the pool to avoid an extra allocation every time we put the slice back into the
18 // pool. The chunk size should not change for the life of the decoder.
20 MaxLength Integer // TODO: Should this include the length header or not?
23 // This limits reads to the length of a message, returning io.EOF when the end of the message bytes
24 // is reached. If you aren't expecting io.EOF, you should probably wrap it with expectReader.
25 type decodeReader struct {
30 func (dr *decodeReader) Init(r *bufio.Reader, length int64) {
36 func (dr *decodeReader) ReadByte() (c byte, err error) {
41 c, err = dr.br.ReadByte()
48 func (dr *decodeReader) Read(p []byte) (n int, err error) {
49 n, err = dr.lr.Read(p)
50 if dr.lr.N != 0 && err == io.EOF {
51 err = io.ErrUnexpectedEOF
56 func (dr *decodeReader) UnreadLength() int64 {
60 // This expects reads to have enough bytes. io.EOF is mapped to io.ErrUnexpectedEOF. It's probably
61 // not a good idea to pass this to functions that expect to read until the end of something, because
62 // they will probably expect io.EOF.
63 type expectReader struct {
67 func (er expectReader) ReadByte() (c byte, err error) {
68 c, err = er.dr.ReadByte()
70 err = io.ErrUnexpectedEOF
75 func (er expectReader) Read(p []byte) (n int, err error) {
76 n, err = er.dr.Read(p)
78 err = io.ErrUnexpectedEOF
83 func (er expectReader) UnreadLength() int64 {
84 return er.dr.UnreadLength()
87 // io.EOF is returned if the source terminates cleanly on a message boundary.
88 func (d *Decoder) Decode(msg *Message) (err error) {
92 err = length.Read(d.R)
94 return fmt.Errorf("reading message length: %w", err)
96 if length > d.MaxLength {
97 return errors.New("message too long")
103 dr.Init(d.R, int64(length))
105 r := expectReader{&dr}
106 c, err := r.ReadByte()
110 msg.Type = MessageType(c)
111 err = readMessageAfterType(msg, &r, d.Pool)
113 err = fmt.Errorf("reading fields for message type %v: %w", msg.Type, err)
116 if r.UnreadLength() != 0 {
117 err = fmt.Errorf("%v unused bytes in message type %v", r.UnreadLength(), msg.Type)
122 func readMessageAfterType(msg *Message, r *expectReader, piecePool *sync.Pool) (err error) {
124 case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
125 case Have, AllowedFast, Suggest:
126 err = msg.Index.Read(r)
127 case Request, Cancel, Reject:
128 for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
135 b := make([]byte, r.UnreadLength())
136 _, err = io.ReadFull(r, b)
137 msg.Bitfield = unmarshalBitfield(b)
139 for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
145 dataLen := r.UnreadLength()
146 if piecePool == nil {
147 msg.Piece = make([]byte, dataLen)
149 msg.Piece = *piecePool.Get().(*[]byte)
150 if int64(cap(msg.Piece)) < dataLen {
151 return errors.New("piece data longer than expected")
153 msg.Piece = msg.Piece[:dataLen]
155 _, err = io.ReadFull(r, msg.Piece)
158 b, err = r.ReadByte()
162 msg.ExtendedID = ExtensionNumber(b)
163 msg.ExtendedPayload = make([]byte, r.UnreadLength())
164 _, err = io.ReadFull(r, msg.ExtendedPayload)
166 err = binary.Read(r, binary.BigEndian, &msg.Port)
167 case HashRequest, HashReject:
168 err = readHashRequest(r, msg)
170 err = readHashRequest(r, msg)
171 numHashes := (r.UnreadLength() + 31) / 32
172 g.MakeSliceWithCap(&msg.Hashes, numHashes)
173 for range numHashes {
175 _, err = io.ReadFull(r, oneHash[:])
177 err = fmt.Errorf("error while reading hashes: %w", err)
180 msg.Hashes = append(msg.Hashes, oneHash)
183 err = errors.New("unhandled message type")
188 func readHashRequest(r io.Reader, msg *Message) (err error) {
189 _, err = io.ReadFull(r, msg.PiecesRoot[:])
193 return readSeq(r, &msg.BaseLayer, &msg.Index, &msg.Length, &msg.ProofLayers)
196 func readSeq(r io.Reader, data ...any) (err error) {
197 for _, d := range data {
198 err = binary.Read(r, binary.BigEndian, d)
206 func unmarshalBitfield(b []byte) (bf []bool) {
207 for _, c := range b {
208 for i := 7; i >= 0; i-- {
209 bf = append(bf, (c>>uint(i))&1 == 1)