From 28531a4fccf0bdff0d40257a8d31bb0d4e9a0a64 Mon Sep 17 00:00:00 2001
From: Matt Joiner <anacrolix@gmail.com>
Date: Tue, 1 Oct 2013 18:43:18 +1000
Subject: [PATCH] Fix request/chunk confusion, missing outgoing message prefix,
 protocol tests; improve request triggering

---
 client.go                      | 116 +++++++++++++++++++++------------
 cmd/torrent/main.go            |   2 +-
 peer_protocol/protocol.go      |  16 ++---
 peer_protocol/protocol_test.go |  41 +++++++++---
 4 files changed, 114 insertions(+), 61 deletions(-)

diff --git a/client.go b/client.go
index 55bf8d1a..fddefa37 100644
--- a/client.go
+++ b/client.go
@@ -49,17 +49,18 @@ const (
 )
 
 type piece struct {
-	State         pieceState
-	Hash          pieceSum
-	PendingChunks map[chunk]struct{}
+	State             pieceState
+	Hash              pieceSum
+	PendingChunkSpecs map[chunkSpec]struct{}
 }
 
-type chunk struct {
+type chunkSpec struct {
 	Begin, Length peer_protocol.Integer
 }
 
 type request struct {
-	Index, Begin, Length peer_protocol.Integer
+	Index peer_protocol.Integer
+	chunkSpec
 }
 
 type connection struct {
@@ -102,6 +103,9 @@ func (c *connection) Request(chunk request) bool {
 			Length: chunk.Length,
 		})
 	}
+	if c.Requests == nil {
+		c.Requests = make(map[request]struct{}, maxRequests)
+	}
 	c.Requests[chunk] = struct{}{}
 	return true
 }
@@ -114,8 +118,9 @@ func (c *connection) SetInterested(interested bool) {
 		Type: func() peer_protocol.MessageType {
 			if interested {
 				return peer_protocol.Interested
+			} else {
+				return peer_protocol.NotInterested
 			}
-			return peer_protocol.NotInterested
 		}(),
 	})
 	c.Interested = interested
@@ -124,7 +129,6 @@ func (c *connection) SetInterested(interested bool) {
 func (conn *connection) writer() {
 	for {
 		b := <-conn.write
-		log.Printf("writing %#v", string(b))
 		n, err := conn.Socket.Write(b)
 		if err != nil {
 			log.Print(err)
@@ -134,6 +138,7 @@ func (conn *connection) writer() {
 		if n != len(b) {
 			panic("didn't write all bytes")
 		}
+		log.Printf("wrote %#v", string(b))
 	}
 }
 
@@ -144,7 +149,6 @@ func (conn *connection) writeOptimizer() {
 		write := conn.write
 		if pending.Len() == 0 {
 			write = nil
-			nextWrite = nil
 		} else {
 			var err error
 			nextWrite, err = pending.Front().Value.(encoding.BinaryMarshaler).MarshalBinary()
@@ -177,9 +181,9 @@ func (t *torrent) bitfield() (bf []bool) {
 	return
 }
 
-func (t *torrent) pieceChunks(index int) (cs map[chunk]struct{}) {
-	cs = make(map[chunk]struct{}, (t.MetaInfo.PieceLength+chunkSize-1)/chunkSize)
-	c := chunk{
+func (t *torrent) pieceChunkSpecs(index int) (cs map[chunkSpec]struct{}) {
+	cs = make(map[chunkSpec]struct{}, (t.MetaInfo.PieceLength+chunkSize-1)/chunkSize)
+	c := chunkSpec{
 		Begin: 0,
 	}
 	for left := peer_protocol.Integer(t.PieceSize(index)); left > 0; left -= c.Length {
@@ -193,7 +197,7 @@ func (t *torrent) pieceChunks(index int) (cs map[chunk]struct{}) {
 	return
 }
 
-func (t *torrent) chunkHeat() (ret map[request]int) {
+func (t *torrent) requestHeat() (ret map[request]int) {
 	ret = make(map[request]int)
 	for _, conn := range t.Conns {
 		for req, _ := range conn.Requests {
@@ -352,6 +356,15 @@ func (me *client) initiateConn(peer Peer, torrent *torrent) {
 	}()
 }
 
+func (me *torrent) haveAnyPieces() bool {
+	for _, piece := range me.Pieces {
+		if piece.State == pieceStateComplete {
+			return true
+		}
+	}
+	return false
+}
+
 func (me *client) handshake(sock net.Conn, torrent *torrent, peerId [20]byte) {
 	conn := &connection{
 		Socket:     sock,
@@ -400,10 +413,12 @@ func (me *client) handshake(sock net.Conn, torrent *torrent, peerId [20]byte) {
 	}
 	me.withContext(func() {
 		me.addConnection(torrent, conn)
-		conn.Post(peer_protocol.Message{
-			Type:     peer_protocol.Bitfield,
-			Bitfield: torrent.bitfield(),
-		})
+		if torrent.haveAnyPieces() {
+			conn.Post(peer_protocol.Message{
+				Type:     peer_protocol.Bitfield,
+				Bitfield: torrent.bitfield(),
+			})
+		}
 		go func() {
 			defer me.withContext(func() {
 				me.dropConnection(torrent, conn)
@@ -417,29 +432,22 @@ func (me *client) handshake(sock net.Conn, torrent *torrent, peerId [20]byte) {
 }
 
 func (me *client) peerGotPiece(torrent *torrent, conn *connection, piece int) {
-	if torrent.Pieces[piece].State != pieceStateIncomplete {
-		return
+	if conn.PeerPieces == nil {
+		conn.PeerPieces = make([]bool, len(torrent.Pieces))
+	}
+	conn.PeerPieces[piece] = true
+	if torrent.wantPiece(piece) {
+		conn.SetInterested(true)
+		me.replenishConnRequests(torrent, conn)
 	}
-	conn.SetInterested(true)
+}
+
+func (t *torrent) wantPiece(index int) bool {
+	return t.Pieces[index].State == pieceStateIncomplete
 }
 
 func (me *client) peerUnchoked(torrent *torrent, conn *connection) {
-	chunkHeatMap := torrent.chunkHeat()
-	for index, has := range conn.PeerPieces {
-		if !has {
-			continue
-		}
-		for chunk, _ := range torrent.Pieces[index].PendingChunks {
-			if _, ok := chunkHeatMap[chunk]; ok {
-				continue
-			}
-			conn.SetInterested(true)
-			if !conn.Request(chunk) {
-				return
-			}
-		}
-	}
-	conn.SetInterested(false)
+	me.replenishConnRequests(torrent, conn)
 }
 
 func (me *client) runConnection(torrent *torrent, conn *connection) error {
@@ -457,6 +465,7 @@ func (me *client) runConnection(torrent *torrent, conn *connection) error {
 			continue
 		}
 		go me.withContext(func() {
+			log.Print(msg)
 			var err error
 			switch msg.Type {
 			case peer_protocol.Choke:
@@ -470,12 +479,10 @@ func (me *client) runConnection(torrent *torrent, conn *connection) error {
 				conn.PeerInterested = false
 			case peer_protocol.Have:
 				me.peerGotPiece(torrent, conn, int(msg.Index))
-				conn.PeerPieces[msg.Index] = true
 			case peer_protocol.Request:
 				conn.PeerRequests[request{
-					Index:  msg.Index,
-					Begin:  msg.Begin,
-					Length: msg.Length,
+					Index:     msg.Index,
+					chunkSpec: chunkSpec{msg.Begin, msg.Length},
 				}] = struct{}{}
 			case peer_protocol.Bitfield:
 				if len(msg.Bitfield) < len(torrent.Pieces) {
@@ -589,6 +596,33 @@ func (me *client) withContext(f func()) {
 	me.actorTask <- f
 }
 
+func (me *client) replenishConnRequests(torrent *torrent, conn *connection) {
+	if len(conn.Requests) >= maxRequests {
+		return
+	}
+	if conn.PeerChoked {
+		return
+	}
+	requestHeatMap := torrent.requestHeat()
+	for index, has := range conn.PeerPieces {
+		if !has {
+			continue
+		}
+		for chunkSpec, _ := range torrent.Pieces[index].PendingChunkSpecs {
+			request := request{peer_protocol.Integer(index), chunkSpec}
+			if heat := requestHeatMap[request]; heat > 0 {
+				continue
+			}
+			conn.SetInterested(true)
+			if !conn.Request(request) {
+				return
+			}
+		}
+	}
+	//conn.SetInterested(false)
+
+}
+
 func (me *client) pieceHashed(ih InfoHash, piece int, correct bool) {
 	torrent := me.torrents[ih]
 	newState := func() pieceState {
@@ -604,7 +638,7 @@ func (me *client) pieceHashed(ih InfoHash, piece int, correct bool) {
 	}
 	torrent.Pieces[piece].State = newState
 	if newState == pieceStateIncomplete {
-		torrent.Pieces[piece].PendingChunks = torrent.pieceChunks(piece)
+		torrent.Pieces[piece].PendingChunkSpecs = torrent.pieceChunkSpecs(piece)
 	}
 	for _, conn := range torrent.Conns {
 		if correct {
@@ -614,7 +648,7 @@ func (me *client) pieceHashed(ih InfoHash, piece int, correct bool) {
 			})
 		} else {
 			if conn.PeerHasPiece(piece) {
-				conn.SetInterested(true)
+				me.replenishConnRequests(torrent, conn)
 			}
 		}
 	}
diff --git a/cmd/torrent/main.go b/cmd/torrent/main.go
index 74380e9d..76a9d943 100644
--- a/cmd/torrent/main.go
+++ b/cmd/torrent/main.go
@@ -30,7 +30,7 @@ func main() {
 		}
 		err = client.AddPeers(torrent.BytesInfoHash(metaInfo.InfoHash), []torrent.Peer{{
 			IP:   net.IPv4(127, 0, 0, 1),
-			Port: 53219,
+			Port: 50933,
 		}})
 		if err != nil {
 			log.Fatal(err)
diff --git a/peer_protocol/protocol.go b/peer_protocol/protocol.go
index dfef720e..4020c524 100644
--- a/peer_protocol/protocol.go
+++ b/peer_protocol/protocol.go
@@ -68,7 +68,11 @@ func (msg Message) MarshalBinary() (data []byte, err error) {
 	default:
 		err = errors.New("unknown message type")
 	}
-	data = buf.Bytes()
+	data = make([]byte, 4+buf.Len())
+	binary.BigEndian.PutUint32(data, uint32(buf.Len()))
+	if buf.Len() != copy(data[4:], buf.Bytes()) {
+		panic("bad copy")
+	}
 	return
 }
 
@@ -113,16 +117,6 @@ func (d *Decoder) Decode(msg *Message) (err error) {
 	return
 }
 
-func encodeMessage(type_ MessageType, data interface{}) []byte {
-	w := &bytes.Buffer{}
-	w.WriteByte(byte(type_))
-	err := binary.Write(w, binary.BigEndian, data)
-	if err != nil {
-		panic(err)
-	}
-	return w.Bytes()
-}
-
 type Bytes []byte
 
 func (b Bytes) MarshalBinary() ([]byte, error) {
diff --git a/peer_protocol/protocol_test.go b/peer_protocol/protocol_test.go
index 7a7b8dd1..3c04aff8 100644
--- a/peer_protocol/protocol_test.go
+++ b/peer_protocol/protocol_test.go
@@ -10,22 +10,47 @@ func TestConstants(t *testing.T) {
 		t.FailNow()
 	}
 }
+
 func TestBitfieldEncode(t *testing.T) {
-	bm := make(Bitfield, 37)
-	bm[2] = true
-	bm[7] = true
-	bm[32] = true
-	s := string(bm.Encode())
+	bf := make([]bool, 37)
+	bf[2] = true
+	bf[7] = true
+	bf[32] = true
+	s := string(marshalBitfield(bf))
 	const expected = "\x21\x00\x00\x00\x80"
 	if s != expected {
 		t.Fatalf("got %#v, expected %#v", s, expected)
 	}
 }
 
+func TestBitfieldUnmarshal(t *testing.T) {
+	bf := unmarshalBitfield([]byte("\x81\x06"))
+	expected := make([]bool, 16)
+	expected[0] = true
+	expected[7] = true
+	expected[13] = true
+	expected[14] = true
+	if len(bf) != len(expected) {
+		t.FailNow()
+	}
+	for i := range expected {
+		if bf[i] != expected[i] {
+			t.FailNow()
+		}
+	}
+}
+
 func TestHaveEncode(t *testing.T) {
-	actual := string(Have(42).Encode())
+	actualBytes, err := Message{
+		Type:  Have,
+		Index: 42,
+	}.MarshalBinary()
+	if err != nil {
+		t.Fatal(err)
+	}
+	actualString := string(actualBytes)
 	expected := "\x04\x00\x00\x00\x2a"
-	if actual != expected {
-		t.Fatalf("expected %#v, got %#v", expected, actual)
+	if actualString != expected {
+		t.Fatalf("expected %#v, got %#v", expected, actualString)
 	}
 }
-- 
2.51.0