]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Rewrite piece data decoding and relax test
authorMatt Joiner <anacrolix@gmail.com>
Sat, 14 Jul 2018 01:50:43 +0000 (11:50 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Sat, 14 Jul 2018 01:50:43 +0000 (11:50 +1000)
peer_protocol/decoder.go
peer_protocol/decoder_test.go

index 47a8b65ebae7d7e045471af99fb0e3928911ffc6..b7ccab14c4ac500d9b2eaa333fda35ca0e533f2b 100644 (file)
@@ -3,11 +3,12 @@ package peer_protocol
 import (
        "bufio"
        "encoding/binary"
-       "errors"
        "fmt"
        "io"
        "io/ioutil"
        "sync"
+
+       "github.com/pkg/errors"
 )
 
 type Decoder struct {
@@ -69,24 +70,21 @@ func (d *Decoder) Decode(msg *Message) (err error) {
                msg.Bitfield = unmarshalBitfield(b)
        case Piece:
                for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
-                       err = pi.Read(r)
+                       err := pi.Read(r)
                        if err != nil {
-                               break
+                               return err
                        }
                }
-               if err != nil {
-                       break
+               dataLen := r.N
+               msg.Piece = (*d.Pool.Get().(*[]byte))
+               if int64(cap(msg.Piece)) < dataLen {
+                       return errors.New("piece data longer than expected")
                }
-               //msg.Piece, err = ioutil.ReadAll(r)
-               b := *d.Pool.Get().(*[]byte)
-               n, err := io.ReadFull(r, b)
+               msg.Piece = msg.Piece[:dataLen]
+               _, err := io.ReadFull(r, msg.Piece)
                if err != nil {
-                       if err != io.ErrUnexpectedEOF || n != int(length-9) {
-                               return err
-                       }
-                       b = b[0:n]
+                       return errors.Wrap(err, "reading piece data")
                }
-               msg.Piece = b
        case Extended:
                b, err := readByte(r)
                if err != nil {
index bbd8194d03e686965430555c6b990cacce67c05f..33909cdc321d96f0e6e978891223566d780627c6 100644 (file)
@@ -89,5 +89,5 @@ func TestDecodeOverlongPiece(t *testing.T) {
                }},
        }
        var m Message
-       require.EqualError(t, d.Decode(&m), "piece data longer than expected")
+       require.Error(t, d.Decode(&m))
 }