]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/decoder.go
Implement decoding hash request, reject and hashes
[btrtrc.git] / peer_protocol / decoder.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "encoding/binary"
6         "fmt"
7         g "github.com/anacrolix/generics"
8         "io"
9         "sync"
10
11         "github.com/pkg/errors"
12 )
13
14 type Decoder struct {
15         R *bufio.Reader
16         // This must return *[]byte where the slices can fit data for piece messages. I think we store
17         // *[]byte in the pool to avoid an extra allocation every time we put the slice back into the
18         // pool. The chunk size should not change for the life of the decoder.
19         Pool      *sync.Pool
20         MaxLength Integer // TODO: Should this include the length header or not?
21 }
22
23 // This limits reads to the length of a message, returning io.EOF when the end of the message bytes
24 // are reached. If you aren't expecting io.EOF, you should probably wrap it with expectReader.
25 type decodeReader struct {
26         lr io.LimitedReader
27         br *bufio.Reader
28 }
29
30 func (dr *decodeReader) Init(r *bufio.Reader, length int64) {
31         dr.lr.R = r
32         dr.lr.N = length
33         dr.br = r
34 }
35
36 func (dr *decodeReader) ReadByte() (c byte, err error) {
37         if dr.lr.N <= 0 {
38                 err = io.EOF
39                 return
40         }
41         c, err = dr.br.ReadByte()
42         if err == nil {
43                 dr.lr.N--
44         }
45         return
46 }
47
48 func (dr *decodeReader) Read(p []byte) (n int, err error) {
49         n, err = dr.lr.Read(p)
50         if dr.lr.N != 0 && err == io.EOF {
51                 err = io.ErrUnexpectedEOF
52         }
53         return
54 }
55
56 func (dr *decodeReader) UnreadLength() int64 {
57         return dr.lr.N
58 }
59
60 // This expects reads to have enough bytes. io.EOF is mapped to io.ErrUnexpectedEOF. It's probably
61 // not a good idea to pass this to functions that expect to read until the end of something, because
62 // they will probably expect io.EOF.
63 type expectReader struct {
64         dr *decodeReader
65 }
66
67 func (er expectReader) ReadByte() (c byte, err error) {
68         c, err = er.dr.ReadByte()
69         if err == io.EOF {
70                 err = io.ErrUnexpectedEOF
71         }
72         return
73 }
74
75 func (er expectReader) Read(p []byte) (n int, err error) {
76         n, err = er.dr.Read(p)
77         if err == io.EOF {
78                 err = io.ErrUnexpectedEOF
79         }
80         return
81 }
82
83 func (er expectReader) UnreadLength() int64 {
84         return er.dr.UnreadLength()
85 }
86
87 // io.EOF is returned if the source terminates cleanly on a message boundary.
88 func (d *Decoder) Decode(msg *Message) (err error) {
89         var dr decodeReader
90         {
91                 var length Integer
92                 err = length.Read(d.R)
93                 if err != nil {
94                         return fmt.Errorf("reading message length: %w", err)
95                 }
96                 if length > d.MaxLength {
97                         return errors.New("message too long")
98                 }
99                 if length == 0 {
100                         msg.Keepalive = true
101                         return
102                 }
103                 dr.Init(d.R, int64(length))
104         }
105         r := expectReader{&dr}
106         c, err := r.ReadByte()
107         if err != nil {
108                 return
109         }
110         msg.Type = MessageType(c)
111         err = readMessageAfterType(msg, &r, d.Pool)
112         if err != nil {
113                 err = fmt.Errorf("reading fields for message type %v: %w", msg.Type, err)
114                 return
115         }
116         if r.UnreadLength() != 0 {
117                 err = fmt.Errorf("%v unused bytes in message type %v", r.UnreadLength(), msg.Type)
118         }
119         return
120 }
121
122 func readMessageAfterType(msg *Message, r *expectReader, piecePool *sync.Pool) (err error) {
123         switch msg.Type {
124         case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
125         case Have, AllowedFast, Suggest:
126                 err = msg.Index.Read(r)
127         case Request, Cancel, Reject:
128                 for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
129                         err = data.Read(r)
130                         if err != nil {
131                                 break
132                         }
133                 }
134         case Bitfield:
135                 b := make([]byte, r.UnreadLength())
136                 _, err = io.ReadFull(r, b)
137                 msg.Bitfield = unmarshalBitfield(b)
138         case Piece:
139                 for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
140                         err = pi.Read(r)
141                         if err != nil {
142                                 return
143                         }
144                 }
145                 dataLen := r.UnreadLength()
146                 if piecePool == nil {
147                         msg.Piece = make([]byte, dataLen)
148                 } else {
149                         msg.Piece = *piecePool.Get().(*[]byte)
150                         if int64(cap(msg.Piece)) < dataLen {
151                                 return errors.New("piece data longer than expected")
152                         }
153                         msg.Piece = msg.Piece[:dataLen]
154                 }
155                 _, err = io.ReadFull(r, msg.Piece)
156         case Extended:
157                 var b byte
158                 b, err = r.ReadByte()
159                 if err != nil {
160                         break
161                 }
162                 msg.ExtendedID = ExtensionNumber(b)
163                 msg.ExtendedPayload = make([]byte, r.UnreadLength())
164                 _, err = io.ReadFull(r, msg.ExtendedPayload)
165         case Port:
166                 err = binary.Read(r, binary.BigEndian, &msg.Port)
167         case HashRequest, HashReject:
168                 err = readHashRequest(r, msg)
169         case Hashes:
170                 err = readHashRequest(r, msg)
171                 numHashes := (r.UnreadLength() + 31) / 32
172                 g.MakeSliceWithCap(&msg.Hashes, numHashes)
173                 for range numHashes {
174                         var oneHash [32]byte
175                         _, err = io.ReadFull(r, oneHash[:])
176                         if err != nil {
177                                 err = fmt.Errorf("error while reading hashes: %w", err)
178                                 return
179                         }
180                         msg.Hashes = append(msg.Hashes, oneHash)
181                 }
182         default:
183                 err = errors.New("unhandled message type")
184         }
185         return
186 }
187
188 func readHashRequest(r io.Reader, msg *Message) (err error) {
189         _, err = io.ReadFull(r, msg.PiecesRoot[:])
190         if err != nil {
191                 return
192         }
193         return readSeq(r, &msg.BaseLayer, &msg.Index, &msg.Length, &msg.ProofLayers)
194 }
195
196 func readSeq(r io.Reader, data ...any) (err error) {
197         for _, d := range data {
198                 err = binary.Read(r, binary.BigEndian, d)
199                 if err != nil {
200                         return
201                 }
202         }
203         return
204 }
205
206 func unmarshalBitfield(b []byte) (bf []bool) {
207         for _, c := range b {
208                 for i := 7; i >= 0; i-- {
209                         bf = append(bf, (c>>uint(i))&1 == 1)
210                 }
211         }
212         return
213 }