]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/decoder.go
peer_protocol: Use faster form for Integer.{UnmarshalBinary,Read}
[btrtrc.git] / peer_protocol / decoder.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "encoding/binary"
6         "fmt"
7         "io"
8         "io/ioutil"
9         "sync"
10
11         "github.com/pkg/errors"
12 )
13
14 type Decoder struct {
15         R         *bufio.Reader
16         Pool      *sync.Pool
17         MaxLength Integer // TODO: Should this include the length header or not?
18 }
19
20 // io.EOF is returned if the source terminates cleanly on a message boundary.
21 func (d *Decoder) Decode(msg *Message) (err error) {
22         var length Integer
23         err = length.Read(d.R)
24         if err != nil {
25                 return fmt.Errorf("reading message length: %w", err)
26         }
27         if length > d.MaxLength {
28                 return errors.New("message too long")
29         }
30         if length == 0 {
31                 msg.Keepalive = true
32                 return
33         }
34         msg.Keepalive = false
35         r := &io.LimitedReader{R: d.R, N: int64(length)}
36         // Check that all of r was utilized.
37         defer func() {
38                 if err != nil {
39                         return
40                 }
41                 if r.N != 0 {
42                         err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type)
43                 }
44         }()
45         msg.Keepalive = false
46         c, err := readByte(r)
47         if err != nil {
48                 return
49         }
50         msg.Type = MessageType(c)
51         switch msg.Type {
52         case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
53                 return
54         case Have, AllowedFast, Suggest:
55                 err = msg.Index.Read(r)
56         case Request, Cancel, Reject:
57                 for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
58                         err = data.Read(r)
59                         if err != nil {
60                                 break
61                         }
62                 }
63         case Bitfield:
64                 b := make([]byte, length-1)
65                 _, err = io.ReadFull(r, b)
66                 msg.Bitfield = unmarshalBitfield(b)
67         case Piece:
68                 for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
69                         err := pi.Read(r)
70                         if err != nil {
71                                 return err
72                         }
73                 }
74                 dataLen := r.N
75                 msg.Piece = (*d.Pool.Get().(*[]byte))
76                 if int64(cap(msg.Piece)) < dataLen {
77                         return errors.New("piece data longer than expected")
78                 }
79                 msg.Piece = msg.Piece[:dataLen]
80                 _, err := io.ReadFull(r, msg.Piece)
81                 if err != nil {
82                         return errors.Wrap(err, "reading piece data")
83                 }
84         case Extended:
85                 var b byte
86                 b, err = readByte(r)
87                 if err != nil {
88                         break
89                 }
90                 msg.ExtendedID = ExtensionNumber(b)
91                 msg.ExtendedPayload, err = ioutil.ReadAll(r)
92         case Port:
93                 err = binary.Read(r, binary.BigEndian, &msg.Port)
94         default:
95                 err = fmt.Errorf("unknown message type %#v", c)
96         }
97         return
98 }
99
100 func readByte(r io.Reader) (b byte, err error) {
101         var arr [1]byte
102         n, err := r.Read(arr[:])
103         b = arr[0]
104         if n == 1 {
105                 err = nil
106                 return
107         }
108         if err == nil {
109                 panic(err)
110         }
111         return
112 }
113
114 func unmarshalBitfield(b []byte) (bf []bool) {
115         for _, c := range b {
116                 for i := 7; i >= 0; i-- {
117                         bf = append(bf, (c>>uint(i))&1 == 1)
118                 }
119         }
120         return
121 }