]> Sergey Matveev's repositories - btrtrc.git/blobdiff - peer_protocol/protocol.go
Add a end-to-end test for torrentfs
[btrtrc.git] / peer_protocol / protocol.go
index 2f9547ef29dcf439686da9e610cbfe5b3280c52b..281d84869a9666929ccedc6aa91081bc256dd5ae 100644 (file)
@@ -24,15 +24,15 @@ const (
 )
 
 const (
-       Choke MessageType = iota
-       Unchoke
-       Interested
-       NotInterested
-       Have
-       Bitfield
-       Request
-       Piece
-       Cancel
+       Choke         MessageType = iota
+       Unchoke                   // 1
+       Interested                // 2
+       NotInterested             // 3
+       Have                      // 4
+       Bitfield                  // 5
+       Request                   // 6
+       Piece                     // 7
+       Cancel                    // 8
 )
 
 type Message struct {
@@ -66,8 +66,22 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
                }
        case Bitfield:
                _, err = buf.Write(marshalBitfield(msg.Bitfield))
+       case Piece:
+               for _, i := range []Integer{msg.Index, msg.Begin} {
+                       err = binary.Write(buf, binary.BigEndian, i)
+                       if err != nil {
+                               return
+                       }
+               }
+               n, err := buf.Write(msg.Piece)
+               if err != nil {
+                       break
+               }
+               if n != len(msg.Piece) {
+                       panic(n)
+               }
        default:
-               err = errors.New("unknown message type")
+               err = fmt.Errorf("unknown message type: %s", msg.Type)
        }
        data = make([]byte, 4+buf.Len())
        binary.BigEndian.PutUint32(data, uint32(buf.Len()))
@@ -114,7 +128,12 @@ func (d *Decoder) Decode(msg *Message) (err error) {
        case Have:
                err = msg.Index.Read(r)
        case Request, Cancel:
-               err = binary.Read(r, binary.BigEndian, []*Integer{&msg.Index, &msg.Begin, &msg.Length})
+               for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
+                       err = data.Read(r)
+                       if err != nil {
+                               break
+                       }
+               }
        case Bitfield:
                b := make([]byte, length-1)
                _, err = io.ReadFull(r, b)