]> Sergey Matveev's repositories - btrtrc.git/blob - peer_protocol/msg.go
go1.19 compat
[btrtrc.git] / peer_protocol / msg.go
1 package peer_protocol
2
3 import (
4         "bufio"
5         "bytes"
6         "encoding"
7         "encoding/binary"
8         "fmt"
9 )
10
11 // This is a lazy union representing all the possible fields for messages. Go doesn't have ADTs, and
12 // I didn't choose to use type-assertions.
13 type Message struct {
14         Keepalive            bool
15         Type                 MessageType
16         Index, Begin, Length Integer
17         Piece                []byte
18         Bitfield             []bool
19         ExtendedID           ExtensionNumber
20         ExtendedPayload      []byte
21         Port                 uint16
22 }
23
24 var _ interface {
25         encoding.BinaryUnmarshaler
26         encoding.BinaryMarshaler
27 } = (*Message)(nil)
28
29 func MakeCancelMessage(piece, offset, length Integer) Message {
30         return Message{
31                 Type:   Cancel,
32                 Index:  piece,
33                 Begin:  offset,
34                 Length: length,
35         }
36 }
37
38 func (msg Message) RequestSpec() (ret RequestSpec) {
39         return RequestSpec{
40                 msg.Index,
41                 msg.Begin,
42                 func() Integer {
43                         if msg.Type == Piece {
44                                 return Integer(len(msg.Piece))
45                         } else {
46                                 return msg.Length
47                         }
48                 }(),
49         }
50 }
51
52 func (msg Message) MustMarshalBinary() []byte {
53         b, err := msg.MarshalBinary()
54         if err != nil {
55                 panic(err)
56         }
57         return b
58 }
59
60 func (msg Message) MarshalBinary() (data []byte, err error) {
61         var buf bytes.Buffer
62         if !msg.Keepalive {
63                 err = buf.WriteByte(byte(msg.Type))
64                 if err != nil {
65                         return
66                 }
67                 switch msg.Type {
68                 case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
69                 case Have:
70                         err = binary.Write(&buf, binary.BigEndian, msg.Index)
71                 case Request, Cancel, Reject:
72                         for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
73                                 err = binary.Write(&buf, binary.BigEndian, i)
74                                 if err != nil {
75                                         break
76                                 }
77                         }
78                 case Bitfield:
79                         _, err = buf.Write(marshalBitfield(msg.Bitfield))
80                 case Piece:
81                         for _, i := range []Integer{msg.Index, msg.Begin} {
82                                 err = binary.Write(&buf, binary.BigEndian, i)
83                                 if err != nil {
84                                         return
85                                 }
86                         }
87                         n, err := buf.Write(msg.Piece)
88                         if err != nil {
89                                 break
90                         }
91                         if n != len(msg.Piece) {
92                                 panic(n)
93                         }
94                 case Extended:
95                         err = buf.WriteByte(byte(msg.ExtendedID))
96                         if err != nil {
97                                 return
98                         }
99                         _, err = buf.Write(msg.ExtendedPayload)
100                 case Port:
101                         err = binary.Write(&buf, binary.BigEndian, msg.Port)
102                 default:
103                         err = fmt.Errorf("unknown message type: %v", msg.Type)
104                 }
105         }
106         data = make([]byte, 4+buf.Len())
107         binary.BigEndian.PutUint32(data, uint32(buf.Len()))
108         if buf.Len() != copy(data[4:], buf.Bytes()) {
109                 panic("bad copy")
110         }
111         return
112 }
113
114 func marshalBitfield(bf []bool) (b []byte) {
115         b = make([]byte, (len(bf)+7)/8)
116         for i, have := range bf {
117                 if !have {
118                         continue
119                 }
120                 c := b[i/8]
121                 c |= 1 << uint(7-i%8)
122                 b[i/8] = c
123         }
124         return
125 }
126
127 func (me *Message) UnmarshalBinary(b []byte) error {
128         d := Decoder{
129                 R: bufio.NewReader(bytes.NewReader(b)),
130         }
131         err := d.Decode(me)
132         if err != nil {
133                 return err
134         }
135         if d.R.Buffered() != 0 {
136                 return fmt.Errorf("%d trailing bytes", d.R.Buffered())
137         }
138         return nil
139 }