]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/protocol.go
2f9547ef29dcf439686da9e610cbfe5b3280c52b
[btrtrc.git] / peer_protocol / protocol.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "bytes"
6         "encoding/binary"
7         "errors"
8         "fmt"
9         "io"
10         "io/ioutil"
11 )
12
13 type (
14         MessageType byte
15         Integer     uint32
16 )
17
18 func (i *Integer) Read(r io.Reader) error {
19         return binary.Read(r, binary.BigEndian, i)
20 }
21
22 const (
23         Protocol = "\x13BitTorrent protocol"
24 )
25
26 const (
27         Choke MessageType = iota
28         Unchoke
29         Interested
30         NotInterested
31         Have
32         Bitfield
33         Request
34         Piece
35         Cancel
36 )
37
38 type Message struct {
39         Keepalive            bool
40         Type                 MessageType
41         Index, Begin, Length Integer
42         Piece                []byte
43         Bitfield             []bool
44 }
45
46 func (msg Message) MarshalBinary() (data []byte, err error) {
47         buf := &bytes.Buffer{}
48         if msg.Keepalive {
49                 data = buf.Bytes()
50                 return
51         }
52         err = buf.WriteByte(byte(msg.Type))
53         if err != nil {
54                 return
55         }
56         switch msg.Type {
57         case Choke, Unchoke, Interested, NotInterested:
58         case Have:
59                 err = binary.Write(buf, binary.BigEndian, msg.Index)
60         case Request, Cancel:
61                 for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
62                         err = binary.Write(buf, binary.BigEndian, i)
63                         if err != nil {
64                                 break
65                         }
66                 }
67         case Bitfield:
68                 _, err = buf.Write(marshalBitfield(msg.Bitfield))
69         default:
70                 err = errors.New("unknown message type")
71         }
72         data = make([]byte, 4+buf.Len())
73         binary.BigEndian.PutUint32(data, uint32(buf.Len()))
74         if buf.Len() != copy(data[4:], buf.Bytes()) {
75                 panic("bad copy")
76         }
77         return
78 }
79
80 type Decoder struct {
81         R         *bufio.Reader
82         MaxLength Integer
83 }
84
85 func (d *Decoder) Decode(msg *Message) (err error) {
86         var length Integer
87         err = binary.Read(d.R, binary.BigEndian, &length)
88         if err != nil {
89                 return
90         }
91         if length > d.MaxLength {
92                 return errors.New("message too long")
93         }
94         r := bufio.NewReader(io.LimitReader(d.R, int64(length)))
95         if length == 0 {
96                 msg.Keepalive = true
97                 return
98         }
99         msg.Keepalive = false
100         c, err := r.ReadByte()
101         if err != nil {
102                 return
103         }
104         msg.Type = MessageType(c)
105         defer func() {
106                 written, _ := io.Copy(ioutil.Discard, r)
107                 if written != 0 && err != nil {
108                         err = fmt.Errorf("short read on message type %d, left %d bytes", msg.Type, written)
109                 }
110         }()
111         switch msg.Type {
112         case Choke, Unchoke, Interested, NotInterested:
113                 return
114         case Have:
115                 err = msg.Index.Read(r)
116         case Request, Cancel:
117                 err = binary.Read(r, binary.BigEndian, []*Integer{&msg.Index, &msg.Begin, &msg.Length})
118         case Bitfield:
119                 b := make([]byte, length-1)
120                 _, err = io.ReadFull(r, b)
121                 msg.Bitfield = unmarshalBitfield(b)
122         case Piece:
123                 for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
124                         err = pi.Read(r)
125                         if err != nil {
126                                 break
127                         }
128                 }
129                 if err != nil {
130                         break
131                 }
132                 msg.Piece, err = ioutil.ReadAll(r)
133         default:
134                 err = fmt.Errorf("unknown message type %#v", c)
135         }
136         if err != nil {
137                 err = fmt.Errorf("decoding type %d: %s", msg.Type, err)
138         }
139         return
140 }
141
142 type Bytes []byte
143
144 func (b Bytes) MarshalBinary() ([]byte, error) {
145         return b, nil
146 }
147
148 func unmarshalBitfield(b []byte) (bf []bool) {
149         for _, c := range b {
150                 for i := 7; i >= 0; i-- {
151                         bf = append(bf, (c>>uint(i))&1 == 1)
152                 }
153         }
154         return
155 }
156
157 func marshalBitfield(bf []bool) (b []byte) {
158         b = make([]byte, (len(bf)+7)/8)
159         for i, have := range bf {
160                 if !have {
161                         continue
162                 }
163                 c := b[i/8]
164                 c |= 1 << uint(7-i%8)
165                 b[i/8] = c
166         }
167         return
168 }