]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/decoder.go
Merge branch 'webtorrent'
[btrtrc.git] / peer_protocol / decoder.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "encoding/binary"
6         "fmt"
7         "io"
8         "io/ioutil"
9         "sync"
10
11         "github.com/pkg/errors"
12 )
13
14 type Decoder struct {
15         R         *bufio.Reader
16         Pool      *sync.Pool
17         MaxLength Integer // TODO: Should this include the length header or not?
18 }
19
20 // io.EOF is returned if the source terminates cleanly on a message boundary.
21 // TODO: Is that before or after the message?
22 func (d *Decoder) Decode(msg *Message) (err error) {
23         var length Integer
24         err = binary.Read(d.R, binary.BigEndian, &length)
25         if err != nil {
26                 if err != io.EOF {
27                         err = fmt.Errorf("error reading message length: %s", err)
28                 }
29                 return
30         }
31         if length > d.MaxLength {
32                 return errors.New("message too long")
33         }
34         if length == 0 {
35                 msg.Keepalive = true
36                 return
37         }
38         msg.Keepalive = false
39         r := &io.LimitedReader{R: d.R, N: int64(length)}
40         // Check that all of r was utilized.
41         defer func() {
42                 if err != nil {
43                         return
44                 }
45                 if r.N != 0 {
46                         err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type)
47                 }
48         }()
49         msg.Keepalive = false
50         c, err := readByte(r)
51         if err != nil {
52                 return
53         }
54         msg.Type = MessageType(c)
55         switch msg.Type {
56         case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
57                 return
58         case Have, AllowedFast, Suggest:
59                 err = msg.Index.Read(r)
60         case Request, Cancel, Reject:
61                 for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
62                         err = data.Read(r)
63                         if err != nil {
64                                 break
65                         }
66                 }
67         case Bitfield:
68                 b := make([]byte, length-1)
69                 _, err = io.ReadFull(r, b)
70                 msg.Bitfield = unmarshalBitfield(b)
71         case Piece:
72                 for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
73                         err := pi.Read(r)
74                         if err != nil {
75                                 return err
76                         }
77                 }
78                 dataLen := r.N
79                 msg.Piece = (*d.Pool.Get().(*[]byte))
80                 if int64(cap(msg.Piece)) < dataLen {
81                         return errors.New("piece data longer than expected")
82                 }
83                 msg.Piece = msg.Piece[:dataLen]
84                 _, err := io.ReadFull(r, msg.Piece)
85                 if err != nil {
86                         return errors.Wrap(err, "reading piece data")
87                 }
88         case Extended:
89                 var b byte
90                 b, err = readByte(r)
91                 if err != nil {
92                         break
93                 }
94                 msg.ExtendedID = ExtensionNumber(b)
95                 msg.ExtendedPayload, err = ioutil.ReadAll(r)
96         case Port:
97                 err = binary.Read(r, binary.BigEndian, &msg.Port)
98         default:
99                 err = fmt.Errorf("unknown message type %#v", c)
100         }
101         return
102 }
103
104 func readByte(r io.Reader) (b byte, err error) {
105         var arr [1]byte
106         n, err := r.Read(arr[:])
107         b = arr[0]
108         if n == 1 {
109                 err = nil
110                 return
111         }
112         if err == nil {
113                 panic(err)
114         }
115         return
116 }
117
118 func unmarshalBitfield(b []byte) (bf []bool) {
119         for _, c := range b {
120                 for i := 7; i >= 0; i-- {
121                         bf = append(bf, (c>>uint(i))&1 == 1)
122                 }
123         }
124         return
125 }