]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/decoder.go
Support AllowedFast and enable fast extension
[btrtrc.git] / peer_protocol / decoder.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "encoding/binary"
6         "errors"
7         "fmt"
8         "io"
9         "io/ioutil"
10         "sync"
11 )
12
13 type Decoder struct {
14         R         *bufio.Reader
15         Pool      *sync.Pool
16         MaxLength Integer // TODO: Should this include the length header or not?
17 }
18
19 // io.EOF is returned if the source terminates cleanly on a message boundary.
20 // TODO: Is that before or after the message?
21 func (d *Decoder) Decode(msg *Message) (err error) {
22         var length Integer
23         err = binary.Read(d.R, binary.BigEndian, &length)
24         if err != nil {
25                 if err != io.EOF {
26                         err = fmt.Errorf("error reading message length: %s", err)
27                 }
28                 return
29         }
30         if length > d.MaxLength {
31                 return errors.New("message too long")
32         }
33         if length == 0 {
34                 msg.Keepalive = true
35                 return
36         }
37         msg.Keepalive = false
38         r := &io.LimitedReader{d.R, int64(length)}
39         // Check that all of r was utilized.
40         defer func() {
41                 if err != nil {
42                         return
43                 }
44                 if r.N != 0 {
45                         err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type)
46                 }
47         }()
48         msg.Keepalive = false
49         c, err := readByte(r)
50         if err != nil {
51                 return
52         }
53         msg.Type = MessageType(c)
54         switch msg.Type {
55         case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
56                 return
57         case Have, AllowedFast, Suggest:
58                 err = msg.Index.Read(r)
59         case Request, Cancel, Reject:
60                 for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
61                         err = data.Read(r)
62                         if err != nil {
63                                 break
64                         }
65                 }
66         case Bitfield:
67                 b := make([]byte, length-1)
68                 _, err = io.ReadFull(r, b)
69                 msg.Bitfield = unmarshalBitfield(b)
70         case Piece:
71                 for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
72                         err = pi.Read(r)
73                         if err != nil {
74                                 break
75                         }
76                 }
77                 if err != nil {
78                         break
79                 }
80                 //msg.Piece, err = ioutil.ReadAll(r)
81                 b := *d.Pool.Get().(*[]byte)
82                 n, err := io.ReadFull(r, b)
83                 if err != nil {
84                         if err != io.ErrUnexpectedEOF || n != int(length-9) {
85                                 return err
86                         }
87                         b = b[0:n]
88                 }
89                 msg.Piece = b
90         case Extended:
91                 msg.ExtendedID, err = readByte(r)
92                 if err != nil {
93                         break
94                 }
95                 msg.ExtendedPayload, err = ioutil.ReadAll(r)
96         case Port:
97                 err = binary.Read(r, binary.BigEndian, &msg.Port)
98         default:
99                 err = fmt.Errorf("unknown message type %#v", c)
100         }
101         return
102 }
103
104 func readByte(r io.Reader) (b byte, err error) {
105         var arr [1]byte
106         n, err := r.Read(arr[:])
107         b = arr[0]
108         if n == 1 {
109                 err = nil
110                 return
111         }
112         if err == nil {
113                 panic(err)
114         }
115         return
116 }
117
118 func unmarshalBitfield(b []byte) (bf []bool) {
119         for _, c := range b {
120                 for i := 7; i >= 0; i-- {
121                         bf = append(bf, (c>>uint(i))&1 == 1)
122                 }
123         }
124         return
125 }