]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/decoder.go
Improvements to decoder fuzzing
[btrtrc.git] / peer_protocol / decoder.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "encoding/binary"
6         "fmt"
7         "io"
8         "io/ioutil"
9         "sync"
10
11         "github.com/pkg/errors"
12 )
13
14 type Decoder struct {
15         R         *bufio.Reader
16         Pool      *sync.Pool
17         MaxLength Integer // TODO: Should this include the length header or not?
18 }
19
20 // io.EOF is returned if the source terminates cleanly on a message boundary.
21 func (d *Decoder) Decode(msg *Message) (err error) {
22         var length Integer
23         err = binary.Read(d.R, binary.BigEndian, &length)
24         if err != nil {
25                 if err != io.EOF {
26                         err = fmt.Errorf("error reading message length: %w", err)
27                 }
28                 return
29         }
30         if length > d.MaxLength {
31                 return errors.New("message too long")
32         }
33         if length == 0 {
34                 msg.Keepalive = true
35                 return
36         }
37         msg.Keepalive = false
38         r := &io.LimitedReader{R: d.R, N: int64(length)}
39         // Check that all of r was utilized.
40         defer func() {
41                 if err != nil {
42                         return
43                 }
44                 if r.N != 0 {
45                         err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type)
46                 }
47         }()
48         msg.Keepalive = false
49         c, err := readByte(r)
50         if err != nil {
51                 return
52         }
53         msg.Type = MessageType(c)
54         switch msg.Type {
55         case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
56                 return
57         case Have, AllowedFast, Suggest:
58                 err = msg.Index.Read(r)
59         case Request, Cancel, Reject:
60                 for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
61                         err = data.Read(r)
62                         if err != nil {
63                                 break
64                         }
65                 }
66         case Bitfield:
67                 b := make([]byte, length-1)
68                 _, err = io.ReadFull(r, b)
69                 msg.Bitfield = unmarshalBitfield(b)
70         case Piece:
71                 for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
72                         err := pi.Read(r)
73                         if err != nil {
74                                 return err
75                         }
76                 }
77                 dataLen := r.N
78                 msg.Piece = (*d.Pool.Get().(*[]byte))
79                 if int64(cap(msg.Piece)) < dataLen {
80                         return errors.New("piece data longer than expected")
81                 }
82                 msg.Piece = msg.Piece[:dataLen]
83                 _, err := io.ReadFull(r, msg.Piece)
84                 if err != nil {
85                         return errors.Wrap(err, "reading piece data")
86                 }
87         case Extended:
88                 var b byte
89                 b, err = readByte(r)
90                 if err != nil {
91                         break
92                 }
93                 msg.ExtendedID = ExtensionNumber(b)
94                 msg.ExtendedPayload, err = ioutil.ReadAll(r)
95         case Port:
96                 err = binary.Read(r, binary.BigEndian, &msg.Port)
97         default:
98                 err = fmt.Errorf("unknown message type %#v", c)
99         }
100         return
101 }
102
103 func readByte(r io.Reader) (b byte, err error) {
104         var arr [1]byte
105         n, err := r.Read(arr[:])
106         b = arr[0]
107         if n == 1 {
108                 err = nil
109                 return
110         }
111         if err == nil {
112                 panic(err)
113         }
114         return
115 }
116
117 func unmarshalBitfield(b []byte) (bf []bool) {
118         for _, c := range b {
119                 for i := 7; i >= 0; i-- {
120                         bf = append(bf, (c>>uint(i))&1 == 1)
121                 }
122         }
123         return
124 }