]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/decoder.go
f4432f64d6d1750afe4c90b184557cf535b933cd
[btrtrc.git] / peer_protocol / decoder.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "encoding/binary"
6         "fmt"
7         "io"
8         "sync"
9
10         "github.com/pkg/errors"
11 )
12
13 type Decoder struct {
14         R         *bufio.Reader
15         Pool      *sync.Pool
16         MaxLength Integer // TODO: Should this include the length header or not?
17 }
18
19 // io.EOF is returned if the source terminates cleanly on a message boundary.
20 func (d *Decoder) Decode(msg *Message) (err error) {
21         var length Integer
22         err = length.Read(d.R)
23         if err != nil {
24                 return fmt.Errorf("reading message length: %w", err)
25         }
26         if length > d.MaxLength {
27                 return errors.New("message too long")
28         }
29         if length == 0 {
30                 msg.Keepalive = true
31                 return
32         }
33         r := d.R
34         readByte := func() (byte, error) {
35                 length--
36                 return d.R.ReadByte()
37         }
38         c, err := readByte()
39         if err != nil {
40                 return
41         }
42         msg.Type = MessageType(c)
43         switch msg.Type {
44         case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
45         case Have, AllowedFast, Suggest:
46                 length -= 4
47                 err = msg.Index.Read(r)
48         case Request, Cancel, Reject:
49                 for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
50                         err = data.Read(r)
51                         if err != nil {
52                                 break
53                         }
54                 }
55                 length -= 12
56         case Bitfield:
57                 b := make([]byte, length)
58                 _, err = io.ReadFull(r, b)
59                 length = 0
60                 msg.Bitfield = unmarshalBitfield(b)
61         case Piece:
62                 for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
63                         err := pi.Read(r)
64                         if err != nil {
65                                 return err
66                         }
67                 }
68                 length -= 8
69                 dataLen := int64(length)
70                 msg.Piece = *d.Pool.Get().(*[]byte)
71                 if int64(cap(msg.Piece)) < dataLen {
72                         return errors.New("piece data longer than expected")
73                 }
74                 msg.Piece = msg.Piece[:dataLen]
75                 _, err := io.ReadFull(r, msg.Piece)
76                 if err != nil {
77                         return fmt.Errorf("reading piece data: %w", err)
78                 }
79                 length = 0
80         case Extended:
81                 var b byte
82                 b, err = readByte()
83                 if err != nil {
84                         break
85                 }
86                 msg.ExtendedID = ExtensionNumber(b)
87                 msg.ExtendedPayload = make([]byte, length)
88                 _, err = io.ReadFull(r, msg.ExtendedPayload)
89                 if err == io.EOF {
90                         err = io.ErrUnexpectedEOF
91                 }
92                 length = 0
93         case Port:
94                 err = binary.Read(r, binary.BigEndian, &msg.Port)
95                 length -= 2
96         default:
97                 err = fmt.Errorf("unknown message type %#v", c)
98         }
99         if err == nil && length != 0 {
100                 err = fmt.Errorf("%v unused bytes in message type %v", length, msg.Type)
101         }
102         return
103 }
104
105 func readByte(r io.Reader) (b byte, err error) {
106         var arr [1]byte
107         n, err := r.Read(arr[:])
108         b = arr[0]
109         if n == 1 {
110                 err = nil
111                 return
112         }
113         if err == nil {
114                 panic(err)
115         }
116         return
117 }
118
119 func unmarshalBitfield(b []byte) (bf []bool) {
120         for _, c := range b {
121                 for i := 7; i >= 0; i-- {
122                         bf = append(bf, (c>>uint(i))&1 == 1)
123                 }
124         }
125         return
126 }