From: Matt Joiner Date: Mon, 30 Sep 2013 11:51:08 +0000 (+1000) Subject: Implementing bitfields and connection message handling X-Git-Tag: v1.0.0~1810 X-Git-Url: http://www.git.stargrave.org/?a=commitdiff_plain;h=081a6805c56309185d170d19518bca5e60999d40;p=btrtrc.git Implementing bitfields and connection message handling --- diff --git a/client.go b/client.go index 284c2c1c..55bf8d1a 100644 --- a/client.go +++ b/client.go @@ -2,9 +2,11 @@ package torrent import ( "bitbucket.org/anacrolix/go.torrent/peer_protocol" + "bufio" "container/list" "crypto" "crypto/rand" + "encoding" "errors" metainfo "github.com/nsf/libtorgo/torrent" "io" @@ -16,7 +18,9 @@ import ( ) const ( - PieceHash = crypto.SHA1 + PieceHash = crypto.SHA1 + maxRequests = 10 + chunkSize = 0x4000 // 16KiB ) type InfoHash [20]byte @@ -45,24 +49,76 @@ const ( ) type piece struct { - State pieceState - Hash pieceSum + State pieceState + Hash pieceSum + PendingChunks map[chunk]struct{} +} + +type chunk struct { + Begin, Length peer_protocol.Integer +} + +type request struct { + Index, Begin, Length peer_protocol.Integer } type connection struct { Socket net.Conn - post chan peer_protocol.Message + post chan encoding.BinaryMarshaler write chan []byte Interested bool Choked bool - Requests []peer_protocol.Request + Requests map[request]struct{} PeerId [20]byte PeerInterested bool PeerChoked bool - PeerRequests []peer_protocol.Request + PeerRequests map[request]struct{} PeerExtensions [8]byte + PeerPieces []bool +} + +func (c *connection) PeerHasPiece(index int) bool { + if c.PeerPieces == nil { + return false + } + return c.PeerPieces[index] +} + +func (c *connection) Post(msg encoding.BinaryMarshaler) { + c.post <- msg +} + +func (c *connection) Request(chunk request) bool { + if len(c.Requests) >= maxRequests { + return false + } + if _, ok := c.Requests[chunk]; !ok { + c.Post(peer_protocol.Message{ + Type: peer_protocol.Request, + Index: chunk.Index, + Begin: chunk.Begin, + Length: chunk.Length, + }) + } + c.Requests[chunk] = struct{}{} + return true +} + +func (c *connection) SetInterested(interested bool) { + if c.Interested == interested { + return + } + c.Post(peer_protocol.Message{ + Type: func() peer_protocol.MessageType { + if interested { + return peer_protocol.Interested + } + return peer_protocol.NotInterested + }(), + }) + c.Interested = interested } func (conn *connection) writer() { @@ -90,7 +146,11 @@ func (conn *connection) writeOptimizer() { write = nil nextWrite = nil } else { - nextWrite = pending.Front().Value.(peer_protocol.Message).Encode() + var err error + nextWrite, err = pending.Front().Value.(encoding.BinaryMarshaler).MarshalBinary() + if err != nil { + panic(err) + } } select { case msg := <-conn.post: @@ -106,10 +166,43 @@ type torrent struct { Pieces []piece Data MMapSpan MetaInfo *metainfo.MetaInfo - Conns []connection + Conns []*connection Peers []Peer } +func (t *torrent) bitfield() (bf []bool) { + for _, p := range t.Pieces { + bf = append(bf, p.State == pieceStateComplete) + } + return +} + +func (t *torrent) pieceChunks(index int) (cs map[chunk]struct{}) { + cs = make(map[chunk]struct{}, (t.MetaInfo.PieceLength+chunkSize-1)/chunkSize) + c := chunk{ + Begin: 0, + } + for left := peer_protocol.Integer(t.PieceSize(index)); left > 0; left -= c.Length { + c.Length = left + if c.Length > chunkSize { + c.Length = chunkSize + } + cs[c] = struct{}{} + c.Begin += c.Length + } + return +} + +func (t *torrent) chunkHeat() (ret map[request]int) { + ret = make(map[request]int) + for _, conn := range t.Conns { + for req, _ := range conn.Requests { + ret[req]++ + } + } + return +} + type Peer struct { Id [20]byte IP net.IP @@ -143,6 +236,8 @@ func (t *torrent) HashPiece(piece int) (ps pieceSum) { return } +// func (t *torrent) bitfield + type client struct { DataDir string HalfOpenLimit int @@ -252,7 +347,7 @@ func (me *client) initiateConn(peer Peer, torrent *torrent) { log.Printf("error connecting to peer: %s", err) return } - log.Printf("connected to %s", sock.RemoteAddr()) + log.Printf("connected to %s", conn.RemoteAddr()) me.handshake(conn, torrent, peer.Id) }() } @@ -263,7 +358,7 @@ func (me *client) handshake(sock net.Conn, torrent *torrent, peerId [20]byte) { Choked: true, PeerChoked: true, write: make(chan []byte), - post: make(chan peer_protocol.Message), + post: make(chan encoding.BinaryMarshaler), } go conn.writer() go conn.writeOptimizer() @@ -303,6 +398,135 @@ func (me *client) handshake(sock net.Conn, torrent *torrent, peerId [20]byte) { conn.post <- peer_protocol.Bytes(torrent.InfoHash[:]) conn.post <- peer_protocol.Bytes(me.PeerId[:]) } + me.withContext(func() { + me.addConnection(torrent, conn) + conn.Post(peer_protocol.Message{ + Type: peer_protocol.Bitfield, + Bitfield: torrent.bitfield(), + }) + go func() { + defer me.withContext(func() { + me.dropConnection(torrent, conn) + }) + err := me.runConnection(torrent, conn) + if err != nil { + log.Print(err) + } + }() + }) +} + +func (me *client) peerGotPiece(torrent *torrent, conn *connection, piece int) { + if torrent.Pieces[piece].State != pieceStateIncomplete { + return + } + conn.SetInterested(true) +} + +func (me *client) peerUnchoked(torrent *torrent, conn *connection) { + chunkHeatMap := torrent.chunkHeat() + for index, has := range conn.PeerPieces { + if !has { + continue + } + for chunk, _ := range torrent.Pieces[index].PendingChunks { + if _, ok := chunkHeatMap[chunk]; ok { + continue + } + conn.SetInterested(true) + if !conn.Request(chunk) { + return + } + } + } + conn.SetInterested(false) +} + +func (me *client) runConnection(torrent *torrent, conn *connection) error { + decoder := peer_protocol.Decoder{ + R: bufio.NewReader(conn.Socket), + MaxLength: 256 * 1024, + } + for { + msg := new(peer_protocol.Message) + err := decoder.Decode(msg) + if err != nil { + return err + } + if msg.Keepalive { + continue + } + go me.withContext(func() { + var err error + switch msg.Type { + case peer_protocol.Choke: + conn.PeerChoked = true + case peer_protocol.Unchoke: + conn.PeerChoked = false + me.peerUnchoked(torrent, conn) + case peer_protocol.Interested: + conn.PeerInterested = true + case peer_protocol.NotInterested: + conn.PeerInterested = false + case peer_protocol.Have: + me.peerGotPiece(torrent, conn, int(msg.Index)) + conn.PeerPieces[msg.Index] = true + case peer_protocol.Request: + conn.PeerRequests[request{ + Index: msg.Index, + Begin: msg.Begin, + Length: msg.Length, + }] = struct{}{} + case peer_protocol.Bitfield: + if len(msg.Bitfield) < len(torrent.Pieces) { + err = errors.New("received invalid bitfield") + break + } + if conn.PeerPieces != nil { + err = errors.New("received unexpected bitfield") + break + } + conn.PeerPieces = msg.Bitfield[:len(torrent.Pieces)] + for index, has := range conn.PeerPieces { + if has { + me.peerGotPiece(torrent, conn, index) + } + } + default: + log.Printf("received unknown message type: %#v", msg.Type) + } + if err != nil { + log.Print(err) + me.dropConnection(torrent, conn) + } + }) + } +} + +func (me *client) dropConnection(torrent *torrent, conn *connection) { + conn.Socket.Close() + for i0, c := range torrent.Conns { + if c != conn { + continue + } + i1 := len(torrent.Conns) - 1 + if i0 != i1 { + torrent.Conns[i0] = torrent.Conns[i1] + } + torrent.Conns = torrent.Conns[:i1] + return + } + panic("no such connection") +} + +func (me *client) addConnection(t *torrent, c *connection) bool { + for _, c := range t.Conns { + if c.PeerId == c.PeerId { + return false + } + } + t.Conns = append(t.Conns, c) + return true } func (me *client) openNewConns() { @@ -367,21 +591,34 @@ func (me *client) withContext(f func()) { func (me *client) pieceHashed(ih InfoHash, piece int, correct bool) { torrent := me.torrents[ih] - torrent.Pieces[piece].State = func() pieceState { + newState := func() pieceState { if correct { return pieceStateComplete } else { return pieceStateIncomplete } }() - for _, piece := range torrent.Pieces { - if piece.State == pieceStateUnknown { - return + oldState := torrent.Pieces[piece].State + if newState == oldState { + return + } + torrent.Pieces[piece].State = newState + if newState == pieceStateIncomplete { + torrent.Pieces[piece].PendingChunks = torrent.pieceChunks(piece) + } + for _, conn := range torrent.Conns { + if correct { + conn.Post(peer_protocol.Message{ + Type: peer_protocol.Have, + Index: peer_protocol.Integer(piece), + }) + } else { + if conn.PeerHasPiece(piece) { + conn.SetInterested(true) + } } } - go func() { - me.torrentFinished <- ih - }() + } func (me *client) run() { diff --git a/cmd/torrent/main.go b/cmd/torrent/main.go index 476e81d7..74380e9d 100644 --- a/cmd/torrent/main.go +++ b/cmd/torrent/main.go @@ -13,6 +13,7 @@ var ( ) func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) flag.Parse() } @@ -29,7 +30,7 @@ func main() { } err = client.AddPeers(torrent.BytesInfoHash(metaInfo.InfoHash), []torrent.Peer{{ IP: net.IPv4(127, 0, 0, 1), - Port: 63983, + Port: 53219, }}) if err != nil { log.Fatal(err) diff --git a/peer_protocol/protocol.go b/peer_protocol/protocol.go index bc34d47b..dfef720e 100644 --- a/peer_protocol/protocol.go +++ b/peer_protocol/protocol.go @@ -1,10 +1,23 @@ package peer_protocol +import ( + "bufio" + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" +) + type ( MessageType byte Integer uint32 ) +func (i *Integer) Read(r io.Reader) error { + return binary.Read(r, binary.BigEndian, i) +} + const ( Protocol = "\x13BitTorrent protocol" ) @@ -16,21 +29,124 @@ const ( NotInterested Have Bitfield - RequestType + Request Piece Cancel ) -type Request struct { +type Message struct { + Keepalive bool + Type MessageType Index, Begin, Length Integer + Piece []byte + Bitfield []bool +} + +func (msg Message) MarshalBinary() (data []byte, err error) { + buf := &bytes.Buffer{} + if msg.Keepalive { + data = buf.Bytes() + return + } + err = buf.WriteByte(byte(msg.Type)) + if err != nil { + return + } + switch msg.Type { + case Choke, Unchoke, Interested, NotInterested: + case Have: + err = binary.Write(buf, binary.BigEndian, msg.Index) + case Request, Cancel: + for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} { + err = binary.Write(buf, binary.BigEndian, i) + if err != nil { + break + } + } + case Bitfield: + _, err = buf.Write(marshalBitfield(msg.Bitfield)) + default: + err = errors.New("unknown message type") + } + data = buf.Bytes() + return +} + +type Decoder struct { + R *bufio.Reader + MaxLength Integer +} + +func (d *Decoder) Decode(msg *Message) (err error) { + var length Integer + err = binary.Read(d.R, binary.BigEndian, &length) + if err != nil { + return + } + if length > d.MaxLength { + return errors.New("message too long") + } + if length == 0 { + msg.Keepalive = true + return + } + msg.Keepalive = false + c, err := d.R.ReadByte() + if err != nil { + return + } + msg.Type = MessageType(c) + switch msg.Type { + case Choke, Unchoke, Interested, NotInterested: + return + case Have: + err = msg.Index.Read(d.R) + case Request, Cancel: + err = binary.Read(d.R, binary.BigEndian, []*Integer{&msg.Index, &msg.Begin, &msg.Length}) + case Bitfield: + b := make([]byte, length-1) + _, err = io.ReadFull(d.R, b) + msg.Bitfield = unmarshalBitfield(b) + default: + err = fmt.Errorf("unknown message type %#v", c) + } + return } -type Message interface { - Encode() []byte +func encodeMessage(type_ MessageType, data interface{}) []byte { + w := &bytes.Buffer{} + w.WriteByte(byte(type_)) + err := binary.Write(w, binary.BigEndian, data) + if err != nil { + panic(err) + } + return w.Bytes() } type Bytes []byte -func (b Bytes) Encode() []byte { - return b +func (b Bytes) MarshalBinary() ([]byte, error) { + return b, nil +} + +func unmarshalBitfield(b []byte) (bf []bool) { + for _, c := range b { + for i := 7; i >= 0; i-- { + bf = append(bf, (c>>uint(i))&1 == 1) + } + } + return +} + +func marshalBitfield(bf []bool) (b []byte) { + b = make([]byte, (len(bf)+7)/8) + for i, have := range bf { + if !have { + continue + } + c := b[i/8] + c |= 1 << uint(7-i%8) + b[i/8] = c + } + return } diff --git a/peer_protocol/protocol_test.go b/peer_protocol/protocol_test.go index 6306d52c..7a7b8dd1 100644 --- a/peer_protocol/protocol_test.go +++ b/peer_protocol/protocol_test.go @@ -10,3 +10,22 @@ func TestConstants(t *testing.T) { t.FailNow() } } +func TestBitfieldEncode(t *testing.T) { + bm := make(Bitfield, 37) + bm[2] = true + bm[7] = true + bm[32] = true + s := string(bm.Encode()) + const expected = "\x21\x00\x00\x00\x80" + if s != expected { + t.Fatalf("got %#v, expected %#v", s, expected) + } +} + +func TestHaveEncode(t *testing.T) { + actual := string(Have(42).Encode()) + expected := "\x04\x00\x00\x00\x2a" + if actual != expected { + t.Fatalf("expected %#v, got %#v", expected, actual) + } +}