]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/decoder.go
Rewrite piece data decoding and relax test
[btrtrc.git] / peer_protocol / decoder.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "encoding/binary"
6         "fmt"
7         "io"
8         "io/ioutil"
9         "sync"
10
11         "github.com/pkg/errors"
12 )
13
14 type Decoder struct {
15         R         *bufio.Reader
16         Pool      *sync.Pool
17         MaxLength Integer // TODO: Should this include the length header or not?
18 }
19
20 // io.EOF is returned if the source terminates cleanly on a message boundary.
21 // TODO: Is that before or after the message?
22 func (d *Decoder) Decode(msg *Message) (err error) {
23         var length Integer
24         err = binary.Read(d.R, binary.BigEndian, &length)
25         if err != nil {
26                 if err != io.EOF {
27                         err = fmt.Errorf("error reading message length: %s", err)
28                 }
29                 return
30         }
31         if length > d.MaxLength {
32                 return errors.New("message too long")
33         }
34         if length == 0 {
35                 msg.Keepalive = true
36                 return
37         }
38         msg.Keepalive = false
39         r := &io.LimitedReader{R: d.R, N: int64(length)}
40         // Check that all of r was utilized.
41         defer func() {
42                 if err != nil {
43                         return
44                 }
45                 if r.N != 0 {
46                         err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type)
47                 }
48         }()
49         msg.Keepalive = false
50         c, err := readByte(r)
51         if err != nil {
52                 return
53         }
54         msg.Type = MessageType(c)
55         switch msg.Type {
56         case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
57                 return
58         case Have, AllowedFast, Suggest:
59                 err = msg.Index.Read(r)
60         case Request, Cancel, Reject:
61                 for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
62                         err = data.Read(r)
63                         if err != nil {
64                                 break
65                         }
66                 }
67         case Bitfield:
68                 b := make([]byte, length-1)
69                 _, err = io.ReadFull(r, b)
70                 msg.Bitfield = unmarshalBitfield(b)
71         case Piece:
72                 for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
73                         err := pi.Read(r)
74                         if err != nil {
75                                 return err
76                         }
77                 }
78                 dataLen := r.N
79                 msg.Piece = (*d.Pool.Get().(*[]byte))
80                 if int64(cap(msg.Piece)) < dataLen {
81                         return errors.New("piece data longer than expected")
82                 }
83                 msg.Piece = msg.Piece[:dataLen]
84                 _, err := io.ReadFull(r, msg.Piece)
85                 if err != nil {
86                         return errors.Wrap(err, "reading piece data")
87                 }
88         case Extended:
89                 b, err := readByte(r)
90                 if err != nil {
91                         break
92                 }
93                 msg.ExtendedID = ExtensionNumber(b)
94                 msg.ExtendedPayload, err = ioutil.ReadAll(r)
95         case Port:
96                 err = binary.Read(r, binary.BigEndian, &msg.Port)
97         default:
98                 err = fmt.Errorf("unknown message type %#v", c)
99         }
100         return
101 }
102
103 func readByte(r io.Reader) (b byte, err error) {
104         var arr [1]byte
105         n, err := r.Read(arr[:])
106         b = arr[0]
107         if n == 1 {
108                 err = nil
109                 return
110         }
111         if err == nil {
112                 panic(err)
113         }
114         return
115 }
116
117 func unmarshalBitfield(b []byte) (bf []bool) {
118         for _, c := range b {
119                 for i := 7; i >= 0; i-- {
120                         bf = append(bf, (c>>uint(i))&1 == 1)
121                 }
122         }
123         return
124 }