]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/decoder.go
Break up peer_protocol into several files
[btrtrc.git] / peer_protocol / decoder.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "encoding/binary"
6         "errors"
7         "fmt"
8         "io"
9         "io/ioutil"
10         "sync"
11 )
12
13 type Decoder struct {
14         R         *bufio.Reader
15         Pool      *sync.Pool
16         MaxLength Integer // TODO: Should this include the length header or not?
17 }
18
19 // io.EOF is returned if the source terminates cleanly on a message boundary.
20 func (d *Decoder) Decode(msg *Message) (err error) {
21         var length Integer
22         err = binary.Read(d.R, binary.BigEndian, &length)
23         if err != nil {
24                 if err != io.EOF {
25                         err = fmt.Errorf("error reading message length: %s", err)
26                 }
27                 return
28         }
29         if length > d.MaxLength {
30                 return errors.New("message too long")
31         }
32         if length == 0 {
33                 msg.Keepalive = true
34                 return
35         }
36         msg.Keepalive = false
37         r := &io.LimitedReader{d.R, int64(length)}
38         // Check that all of r was utilized.
39         defer func() {
40                 if err != nil {
41                         return
42                 }
43                 if r.N != 0 {
44                         err = fmt.Errorf("%d bytes unused in message type %d", r.N, msg.Type)
45                 }
46         }()
47         msg.Keepalive = false
48         c, err := readByte(r)
49         if err != nil {
50                 return
51         }
52         msg.Type = MessageType(c)
53         switch msg.Type {
54         case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
55                 return
56         case Have:
57                 err = msg.Index.Read(r)
58         case Request, Cancel, Reject:
59                 for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
60                         err = data.Read(r)
61                         if err != nil {
62                                 break
63                         }
64                 }
65         case Bitfield:
66                 b := make([]byte, length-1)
67                 _, err = io.ReadFull(r, b)
68                 msg.Bitfield = unmarshalBitfield(b)
69         case Piece:
70                 for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
71                         err = pi.Read(r)
72                         if err != nil {
73                                 break
74                         }
75                 }
76                 if err != nil {
77                         break
78                 }
79                 //msg.Piece, err = ioutil.ReadAll(r)
80                 b := *d.Pool.Get().(*[]byte)
81                 n, err := io.ReadFull(r, b)
82                 if err != nil {
83                         if err != io.ErrUnexpectedEOF || n != int(length-9) {
84                                 return err
85                         }
86                         b = b[0:n]
87                 }
88                 msg.Piece = b
89         case Extended:
90                 msg.ExtendedID, err = readByte(r)
91                 if err != nil {
92                         break
93                 }
94                 msg.ExtendedPayload, err = ioutil.ReadAll(r)
95         case Port:
96                 err = binary.Read(r, binary.BigEndian, &msg.Port)
97         default:
98                 err = fmt.Errorf("unknown message type %#v", c)
99         }
100         return
101 }
102
103 func readByte(r io.Reader) (b byte, err error) {
104         var arr [1]byte
105         n, err := r.Read(arr[:])
106         b = arr[0]
107         if n == 1 {
108                 err = nil
109                 return
110         }
111         if err == nil {
112                 panic(err)
113         }
114         return
115 }
116
117 func unmarshalBitfield(b []byte) (bf []bool) {
118         for _, c := range b {
119                 for i := 7; i >= 0; i-- {
120                         bf = append(bf, (c>>uint(i))&1 == 1)
121                 }
122         }
123         return
124 }