]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Implementing bitfields and connection message handling
authorMatt Joiner <anacrolix@gmail.com>
Mon, 30 Sep 2013 11:51:08 +0000 (21:51 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Mon, 30 Sep 2013 11:51:08 +0000 (21:51 +1000)
client.go
cmd/torrent/main.go
peer_protocol/protocol.go
peer_protocol/protocol_test.go

index 284c2c1cb4e7891919da4c18cf5943f6595c28d9..55bf8d1a6e52ea816a917f261991f2dbbfe5adf4 100644 (file)
--- a/client.go
+++ b/client.go
@@ -2,9 +2,11 @@ package torrent
 
 import (
        "bitbucket.org/anacrolix/go.torrent/peer_protocol"
+       "bufio"
        "container/list"
        "crypto"
        "crypto/rand"
+       "encoding"
        "errors"
        metainfo "github.com/nsf/libtorgo/torrent"
        "io"
@@ -16,7 +18,9 @@ import (
 )
 
 const (
-       PieceHash = crypto.SHA1
+       PieceHash   = crypto.SHA1
+       maxRequests = 10
+       chunkSize   = 0x4000 // 16KiB
 )
 
 type InfoHash [20]byte
@@ -45,24 +49,76 @@ const (
 )
 
 type piece struct {
-       State pieceState
-       Hash  pieceSum
+       State         pieceState
+       Hash          pieceSum
+       PendingChunks map[chunk]struct{}
+}
+
+type chunk struct {
+       Begin, Length peer_protocol.Integer
+}
+
+type request struct {
+       Index, Begin, Length peer_protocol.Integer
 }
 
 type connection struct {
        Socket net.Conn
-       post   chan peer_protocol.Message
+       post   chan encoding.BinaryMarshaler
        write  chan []byte
 
        Interested bool
        Choked     bool
-       Requests   []peer_protocol.Request
+       Requests   map[request]struct{}
 
        PeerId         [20]byte
        PeerInterested bool
        PeerChoked     bool
-       PeerRequests   []peer_protocol.Request
+       PeerRequests   map[request]struct{}
        PeerExtensions [8]byte
+       PeerPieces     []bool
+}
+
+func (c *connection) PeerHasPiece(index int) bool {
+       if c.PeerPieces == nil {
+               return false
+       }
+       return c.PeerPieces[index]
+}
+
+func (c *connection) Post(msg encoding.BinaryMarshaler) {
+       c.post <- msg
+}
+
+func (c *connection) Request(chunk request) bool {
+       if len(c.Requests) >= maxRequests {
+               return false
+       }
+       if _, ok := c.Requests[chunk]; !ok {
+               c.Post(peer_protocol.Message{
+                       Type:   peer_protocol.Request,
+                       Index:  chunk.Index,
+                       Begin:  chunk.Begin,
+                       Length: chunk.Length,
+               })
+       }
+       c.Requests[chunk] = struct{}{}
+       return true
+}
+
+func (c *connection) SetInterested(interested bool) {
+       if c.Interested == interested {
+               return
+       }
+       c.Post(peer_protocol.Message{
+               Type: func() peer_protocol.MessageType {
+                       if interested {
+                               return peer_protocol.Interested
+                       }
+                       return peer_protocol.NotInterested
+               }(),
+       })
+       c.Interested = interested
 }
 
 func (conn *connection) writer() {
@@ -90,7 +146,11 @@ func (conn *connection) writeOptimizer() {
                        write = nil
                        nextWrite = nil
                } else {
-                       nextWrite = pending.Front().Value.(peer_protocol.Message).Encode()
+                       var err error
+                       nextWrite, err = pending.Front().Value.(encoding.BinaryMarshaler).MarshalBinary()
+                       if err != nil {
+                               panic(err)
+                       }
                }
                select {
                case msg := <-conn.post:
@@ -106,10 +166,43 @@ type torrent struct {
        Pieces   []piece
        Data     MMapSpan
        MetaInfo *metainfo.MetaInfo
-       Conns    []connection
+       Conns    []*connection
        Peers    []Peer
 }
 
+func (t *torrent) bitfield() (bf []bool) {
+       for _, p := range t.Pieces {
+               bf = append(bf, p.State == pieceStateComplete)
+       }
+       return
+}
+
+func (t *torrent) pieceChunks(index int) (cs map[chunk]struct{}) {
+       cs = make(map[chunk]struct{}, (t.MetaInfo.PieceLength+chunkSize-1)/chunkSize)
+       c := chunk{
+               Begin: 0,
+       }
+       for left := peer_protocol.Integer(t.PieceSize(index)); left > 0; left -= c.Length {
+               c.Length = left
+               if c.Length > chunkSize {
+                       c.Length = chunkSize
+               }
+               cs[c] = struct{}{}
+               c.Begin += c.Length
+       }
+       return
+}
+
+func (t *torrent) chunkHeat() (ret map[request]int) {
+       ret = make(map[request]int)
+       for _, conn := range t.Conns {
+               for req, _ := range conn.Requests {
+                       ret[req]++
+               }
+       }
+       return
+}
+
 type Peer struct {
        Id   [20]byte
        IP   net.IP
@@ -143,6 +236,8 @@ func (t *torrent) HashPiece(piece int) (ps pieceSum) {
        return
 }
 
+// func (t *torrent) bitfield
+
 type client struct {
        DataDir       string
        HalfOpenLimit int
@@ -252,7 +347,7 @@ func (me *client) initiateConn(peer Peer, torrent *torrent) {
                        log.Printf("error connecting to peer: %s", err)
                        return
                }
-               log.Printf("connected to %s", sock.RemoteAddr())
+               log.Printf("connected to %s", conn.RemoteAddr())
                me.handshake(conn, torrent, peer.Id)
        }()
 }
@@ -263,7 +358,7 @@ func (me *client) handshake(sock net.Conn, torrent *torrent, peerId [20]byte) {
                Choked:     true,
                PeerChoked: true,
                write:      make(chan []byte),
-               post:       make(chan peer_protocol.Message),
+               post:       make(chan encoding.BinaryMarshaler),
        }
        go conn.writer()
        go conn.writeOptimizer()
@@ -303,6 +398,135 @@ func (me *client) handshake(sock net.Conn, torrent *torrent, peerId [20]byte) {
                conn.post <- peer_protocol.Bytes(torrent.InfoHash[:])
                conn.post <- peer_protocol.Bytes(me.PeerId[:])
        }
+       me.withContext(func() {
+               me.addConnection(torrent, conn)
+               conn.Post(peer_protocol.Message{
+                       Type:     peer_protocol.Bitfield,
+                       Bitfield: torrent.bitfield(),
+               })
+               go func() {
+                       defer me.withContext(func() {
+                               me.dropConnection(torrent, conn)
+                       })
+                       err := me.runConnection(torrent, conn)
+                       if err != nil {
+                               log.Print(err)
+                       }
+               }()
+       })
+}
+
+func (me *client) peerGotPiece(torrent *torrent, conn *connection, piece int) {
+       if torrent.Pieces[piece].State != pieceStateIncomplete {
+               return
+       }
+       conn.SetInterested(true)
+}
+
+func (me *client) peerUnchoked(torrent *torrent, conn *connection) {
+       chunkHeatMap := torrent.chunkHeat()
+       for index, has := range conn.PeerPieces {
+               if !has {
+                       continue
+               }
+               for chunk, _ := range torrent.Pieces[index].PendingChunks {
+                       if _, ok := chunkHeatMap[chunk]; ok {
+                               continue
+                       }
+                       conn.SetInterested(true)
+                       if !conn.Request(chunk) {
+                               return
+                       }
+               }
+       }
+       conn.SetInterested(false)
+}
+
+func (me *client) runConnection(torrent *torrent, conn *connection) error {
+       decoder := peer_protocol.Decoder{
+               R:         bufio.NewReader(conn.Socket),
+               MaxLength: 256 * 1024,
+       }
+       for {
+               msg := new(peer_protocol.Message)
+               err := decoder.Decode(msg)
+               if err != nil {
+                       return err
+               }
+               if msg.Keepalive {
+                       continue
+               }
+               go me.withContext(func() {
+                       var err error
+                       switch msg.Type {
+                       case peer_protocol.Choke:
+                               conn.PeerChoked = true
+                       case peer_protocol.Unchoke:
+                               conn.PeerChoked = false
+                               me.peerUnchoked(torrent, conn)
+                       case peer_protocol.Interested:
+                               conn.PeerInterested = true
+                       case peer_protocol.NotInterested:
+                               conn.PeerInterested = false
+                       case peer_protocol.Have:
+                               me.peerGotPiece(torrent, conn, int(msg.Index))
+                               conn.PeerPieces[msg.Index] = true
+                       case peer_protocol.Request:
+                               conn.PeerRequests[request{
+                                       Index:  msg.Index,
+                                       Begin:  msg.Begin,
+                                       Length: msg.Length,
+                               }] = struct{}{}
+                       case peer_protocol.Bitfield:
+                               if len(msg.Bitfield) < len(torrent.Pieces) {
+                                       err = errors.New("received invalid bitfield")
+                                       break
+                               }
+                               if conn.PeerPieces != nil {
+                                       err = errors.New("received unexpected bitfield")
+                                       break
+                               }
+                               conn.PeerPieces = msg.Bitfield[:len(torrent.Pieces)]
+                               for index, has := range conn.PeerPieces {
+                                       if has {
+                                               me.peerGotPiece(torrent, conn, index)
+                                       }
+                               }
+                       default:
+                               log.Printf("received unknown message type: %#v", msg.Type)
+                       }
+                       if err != nil {
+                               log.Print(err)
+                               me.dropConnection(torrent, conn)
+                       }
+               })
+       }
+}
+
+func (me *client) dropConnection(torrent *torrent, conn *connection) {
+       conn.Socket.Close()
+       for i0, c := range torrent.Conns {
+               if c != conn {
+                       continue
+               }
+               i1 := len(torrent.Conns) - 1
+               if i0 != i1 {
+                       torrent.Conns[i0] = torrent.Conns[i1]
+               }
+               torrent.Conns = torrent.Conns[:i1]
+               return
+       }
+       panic("no such connection")
+}
+
+func (me *client) addConnection(t *torrent, c *connection) bool {
+       for _, c := range t.Conns {
+               if c.PeerId == c.PeerId {
+                       return false
+               }
+       }
+       t.Conns = append(t.Conns, c)
+       return true
 }
 
 func (me *client) openNewConns() {
@@ -367,21 +591,34 @@ func (me *client) withContext(f func()) {
 
 func (me *client) pieceHashed(ih InfoHash, piece int, correct bool) {
        torrent := me.torrents[ih]
-       torrent.Pieces[piece].State = func() pieceState {
+       newState := func() pieceState {
                if correct {
                        return pieceStateComplete
                } else {
                        return pieceStateIncomplete
                }
        }()
-       for _, piece := range torrent.Pieces {
-               if piece.State == pieceStateUnknown {
-                       return
+       oldState := torrent.Pieces[piece].State
+       if newState == oldState {
+               return
+       }
+       torrent.Pieces[piece].State = newState
+       if newState == pieceStateIncomplete {
+               torrent.Pieces[piece].PendingChunks = torrent.pieceChunks(piece)
+       }
+       for _, conn := range torrent.Conns {
+               if correct {
+                       conn.Post(peer_protocol.Message{
+                               Type:  peer_protocol.Have,
+                               Index: peer_protocol.Integer(piece),
+                       })
+               } else {
+                       if conn.PeerHasPiece(piece) {
+                               conn.SetInterested(true)
+                       }
                }
        }
-       go func() {
-               me.torrentFinished <- ih
-       }()
+
 }
 
 func (me *client) run() {
index 476e81d7951262f53493195015e2a182f5dc1cde..74380e9d8012d3b85d6228bff2c4586a6560d862 100644 (file)
@@ -13,6 +13,7 @@ var (
 )
 
 func init() {
+       log.SetFlags(log.LstdFlags | log.Lshortfile)
        flag.Parse()
 }
 
@@ -29,7 +30,7 @@ func main() {
                }
                err = client.AddPeers(torrent.BytesInfoHash(metaInfo.InfoHash), []torrent.Peer{{
                        IP:   net.IPv4(127, 0, 0, 1),
-                       Port: 63983,
+                       Port: 53219,
                }})
                if err != nil {
                        log.Fatal(err)
index bc34d47b2872f8f175691681f1c6dea12e604a0c..dfef720e3a148ada07030e194b49e7c8f8d126dd 100644 (file)
@@ -1,10 +1,23 @@
 package peer_protocol
 
+import (
+       "bufio"
+       "bytes"
+       "encoding/binary"
+       "errors"
+       "fmt"
+       "io"
+)
+
 type (
        MessageType byte
        Integer     uint32
 )
 
+func (i *Integer) Read(r io.Reader) error {
+       return binary.Read(r, binary.BigEndian, i)
+}
+
 const (
        Protocol = "\x13BitTorrent protocol"
 )
@@ -16,21 +29,124 @@ const (
        NotInterested
        Have
        Bitfield
-       RequestType
+       Request
        Piece
        Cancel
 )
 
-type Request struct {
+type Message struct {
+       Keepalive            bool
+       Type                 MessageType
        Index, Begin, Length Integer
+       Piece                []byte
+       Bitfield             []bool
+}
+
+func (msg Message) MarshalBinary() (data []byte, err error) {
+       buf := &bytes.Buffer{}
+       if msg.Keepalive {
+               data = buf.Bytes()
+               return
+       }
+       err = buf.WriteByte(byte(msg.Type))
+       if err != nil {
+               return
+       }
+       switch msg.Type {
+       case Choke, Unchoke, Interested, NotInterested:
+       case Have:
+               err = binary.Write(buf, binary.BigEndian, msg.Index)
+       case Request, Cancel:
+               for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
+                       err = binary.Write(buf, binary.BigEndian, i)
+                       if err != nil {
+                               break
+                       }
+               }
+       case Bitfield:
+               _, err = buf.Write(marshalBitfield(msg.Bitfield))
+       default:
+               err = errors.New("unknown message type")
+       }
+       data = buf.Bytes()
+       return
+}
+
+type Decoder struct {
+       R         *bufio.Reader
+       MaxLength Integer
+}
+
+func (d *Decoder) Decode(msg *Message) (err error) {
+       var length Integer
+       err = binary.Read(d.R, binary.BigEndian, &length)
+       if err != nil {
+               return
+       }
+       if length > d.MaxLength {
+               return errors.New("message too long")
+       }
+       if length == 0 {
+               msg.Keepalive = true
+               return
+       }
+       msg.Keepalive = false
+       c, err := d.R.ReadByte()
+       if err != nil {
+               return
+       }
+       msg.Type = MessageType(c)
+       switch msg.Type {
+       case Choke, Unchoke, Interested, NotInterested:
+               return
+       case Have:
+               err = msg.Index.Read(d.R)
+       case Request, Cancel:
+               err = binary.Read(d.R, binary.BigEndian, []*Integer{&msg.Index, &msg.Begin, &msg.Length})
+       case Bitfield:
+               b := make([]byte, length-1)
+               _, err = io.ReadFull(d.R, b)
+               msg.Bitfield = unmarshalBitfield(b)
+       default:
+               err = fmt.Errorf("unknown message type %#v", c)
+       }
+       return
 }
 
-type Message interface {
-       Encode() []byte
+func encodeMessage(type_ MessageType, data interface{}) []byte {
+       w := &bytes.Buffer{}
+       w.WriteByte(byte(type_))
+       err := binary.Write(w, binary.BigEndian, data)
+       if err != nil {
+               panic(err)
+       }
+       return w.Bytes()
 }
 
 type Bytes []byte
 
-func (b Bytes) Encode() []byte {
-       return b
+func (b Bytes) MarshalBinary() ([]byte, error) {
+       return b, nil
+}
+
+func unmarshalBitfield(b []byte) (bf []bool) {
+       for _, c := range b {
+               for i := 7; i >= 0; i-- {
+                       bf = append(bf, (c>>uint(i))&1 == 1)
+               }
+       }
+       return
+}
+
+func marshalBitfield(bf []bool) (b []byte) {
+       b = make([]byte, (len(bf)+7)/8)
+       for i, have := range bf {
+               if !have {
+                       continue
+               }
+               c := b[i/8]
+               c |= 1 << uint(7-i%8)
+               b[i/8] = c
+       }
+       return
 }
index 6306d52c042df8f1f8e0f2a29f9c61444283e71c..7a7b8dd1ac1d109caafaa8811c6c5a86e8a68fb4 100644 (file)
@@ -10,3 +10,22 @@ func TestConstants(t *testing.T) {
                t.FailNow()
        }
 }
+func TestBitfieldEncode(t *testing.T) {
+       bm := make(Bitfield, 37)
+       bm[2] = true
+       bm[7] = true
+       bm[32] = true
+       s := string(bm.Encode())
+       const expected = "\x21\x00\x00\x00\x80"
+       if s != expected {
+               t.Fatalf("got %#v, expected %#v", s, expected)
+       }
+}
+
+func TestHaveEncode(t *testing.T) {
+       actual := string(Have(42).Encode())
+       expected := "\x04\x00\x00\x00\x2a"
+       if actual != expected {
+               t.Fatalf("expected %#v, got %#v", expected, actual)
+       }
+}