]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/decoder.go
Drop support for go 1.20
[btrtrc.git] / peer_protocol / decoder.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "encoding/binary"
6         "fmt"
7         "io"
8         "sync"
9
10         "github.com/pkg/errors"
11 )
12
13 type Decoder struct {
14         R *bufio.Reader
15         // This must return *[]byte where the slices can fit data for piece messages. I think we store
16         // *[]byte in the pool to avoid an extra allocation every time we put the slice back into the
17         // pool. The chunk size should not change for the life of the decoder.
18         Pool      *sync.Pool
19         MaxLength Integer // TODO: Should this include the length header or not?
20 }
21
22 // io.EOF is returned if the source terminates cleanly on a message boundary.
23 func (d *Decoder) Decode(msg *Message) (err error) {
24         var length Integer
25         err = length.Read(d.R)
26         if err != nil {
27                 return fmt.Errorf("reading message length: %w", err)
28         }
29         if length > d.MaxLength {
30                 return errors.New("message too long")
31         }
32         if length == 0 {
33                 msg.Keepalive = true
34                 return
35         }
36         r := d.R
37         readByte := func() (byte, error) {
38                 length--
39                 return d.R.ReadByte()
40         }
41         // From this point onwards, EOF is unexpected
42         defer func() {
43                 if err == io.EOF {
44                         err = io.ErrUnexpectedEOF
45                 }
46         }()
47         c, err := readByte()
48         if err != nil {
49                 return
50         }
51         msg.Type = MessageType(c)
52         // Can return directly in cases when err is not nil, or length is known to be zero.
53         switch msg.Type {
54         case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
55         case Have, AllowedFast, Suggest:
56                 length -= 4
57                 err = msg.Index.Read(r)
58         case Request, Cancel, Reject:
59                 for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
60                         err = data.Read(r)
61                         if err != nil {
62                                 break
63                         }
64                 }
65                 length -= 12
66         case Bitfield:
67                 b := make([]byte, length)
68                 _, err = io.ReadFull(r, b)
69                 length = 0
70                 msg.Bitfield = unmarshalBitfield(b)
71                 return
72         case Piece:
73                 for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
74                         err := pi.Read(r)
75                         if err != nil {
76                                 return err
77                         }
78                 }
79                 length -= 8
80                 dataLen := int64(length)
81                 if d.Pool == nil {
82                         msg.Piece = make([]byte, dataLen)
83                 } else {
84                         msg.Piece = *d.Pool.Get().(*[]byte)
85                         if int64(cap(msg.Piece)) < dataLen {
86                                 return errors.New("piece data longer than expected")
87                         }
88                         msg.Piece = msg.Piece[:dataLen]
89                 }
90                 _, err = io.ReadFull(r, msg.Piece)
91                 length = 0
92                 return
93         case Extended:
94                 var b byte
95                 b, err = readByte()
96                 if err != nil {
97                         break
98                 }
99                 msg.ExtendedID = ExtensionNumber(b)
100                 msg.ExtendedPayload = make([]byte, length)
101                 _, err = io.ReadFull(r, msg.ExtendedPayload)
102                 length = 0
103                 return
104         case Port:
105                 err = binary.Read(r, binary.BigEndian, &msg.Port)
106                 length -= 2
107         default:
108                 err = fmt.Errorf("unknown message type %#v", c)
109         }
110         if err == nil && length != 0 {
111                 err = fmt.Errorf("%v unused bytes in message type %v", length, msg.Type)
112         }
113         return
114 }
115
116 func readByte(r io.Reader) (b byte, err error) {
117         var arr [1]byte
118         n, err := r.Read(arr[:])
119         b = arr[0]
120         if n == 1 {
121                 err = nil
122                 return
123         }
124         if err == nil {
125                 panic(err)
126         }
127         return
128 }
129
130 func unmarshalBitfield(b []byte) (bf []bool) {
131         for _, c := range b {
132                 for i := 7; i >= 0; i-- {
133                         bf = append(bf, (c>>uint(i))&1 == 1)
134                 }
135         }
136         return
137 }