]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Decoding of Piece messages, and checking entire message is consumed
authorMatt Joiner <anacrolix@gmail.com>
Wed, 2 Oct 2013 07:57:19 +0000 (17:57 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Wed, 2 Oct 2013 07:57:19 +0000 (17:57 +1000)
peer_protocol/protocol.go

index 4020c5240d3f8066b1799f87993f2c973ea4db1f..2f9547ef29dcf439686da9e610cbfe5b3280c52b 100644 (file)
@@ -7,6 +7,7 @@ import (
        "errors"
        "fmt"
        "io"
+       "io/ioutil"
 )
 
 type (
@@ -90,30 +91,51 @@ func (d *Decoder) Decode(msg *Message) (err error) {
        if length > d.MaxLength {
                return errors.New("message too long")
        }
+       r := bufio.NewReader(io.LimitReader(d.R, int64(length)))
        if length == 0 {
                msg.Keepalive = true
                return
        }
        msg.Keepalive = false
-       c, err := d.R.ReadByte()
+       c, err := r.ReadByte()
        if err != nil {
                return
        }
        msg.Type = MessageType(c)
+       defer func() {
+               written, _ := io.Copy(ioutil.Discard, r)
+               if written != 0 && err != nil {
+                       err = fmt.Errorf("short read on message type %d, left %d bytes", msg.Type, written)
+               }
+       }()
        switch msg.Type {
        case Choke, Unchoke, Interested, NotInterested:
                return
        case Have:
-               err = msg.Index.Read(d.R)
+               err = msg.Index.Read(r)
        case Request, Cancel:
-               err = binary.Read(d.R, binary.BigEndian, []*Integer{&msg.Index, &msg.Begin, &msg.Length})
+               err = binary.Read(r, binary.BigEndian, []*Integer{&msg.Index, &msg.Begin, &msg.Length})
        case Bitfield:
                b := make([]byte, length-1)
-               _, err = io.ReadFull(d.R, b)
+               _, err = io.ReadFull(r, b)
                msg.Bitfield = unmarshalBitfield(b)
+       case Piece:
+               for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
+                       err = pi.Read(r)
+                       if err != nil {
+                               break
+                       }
+               }
+               if err != nil {
+                       break
+               }
+               msg.Piece, err = ioutil.ReadAll(r)
        default:
                err = fmt.Errorf("unknown message type %#v", c)
        }
+       if err != nil {
+               err = fmt.Errorf("decoding type %d: %s", msg.Type, err)
+       }
        return
 }