]> Sergey Matveev's repositories - btrtrc.git/commitdiff
Remove requests from the outbound message queue if cancelled before they're written
authorMatt Joiner <anacrolix@gmail.com>
Wed, 28 May 2014 15:27:48 +0000 (01:27 +1000)
committerMatt Joiner <anacrolix@gmail.com>
Wed, 28 May 2014 15:27:48 +0000 (01:27 +1000)
Only post peer protocol messages to the channel, bytes must be done directly.
This fixes a possible issue where slow responses during handshake could cause
keep alive messages to be sent prematurely.

client.go
connection.go
connection_test.go [new file with mode: 0644]

index 7d3eb7ebb17424d48daf1c1e45616b07a43e1df3..31ecc9f15682b3990d158bda5371fc6e05c7374e 100644 (file)
--- a/client.go
+++ b/client.go
@@ -19,7 +19,6 @@ import (
        "bufio"
        "container/list"
        "crypto/rand"
-       "encoding"
        "errors"
        "fmt"
        "io"
@@ -278,7 +277,7 @@ func (me *Client) runConnection(sock net.Conn, torrent *torrent) (err error) {
                Choked:          true,
                PeerChoked:      true,
                write:           make(chan []byte),
-               post:            make(chan encoding.BinaryMarshaler),
+               post:            make(chan pp.Message),
                PeerMaxRequests: 250,
        }
        defer func() {
@@ -289,12 +288,12 @@ func (me *Client) runConnection(sock net.Conn, torrent *torrent) (err error) {
                conn.Close()
        }()
        go conn.writer()
-       go conn.writeOptimizer()
-       conn.post <- pp.Bytes(pp.Protocol)
-       conn.post <- pp.Bytes("\x00\x00\x00\x00\x00\x00\x00\x00")
+       // go conn.writeOptimizer()
+       conn.write <- pp.Bytes(pp.Protocol)
+       conn.write <- pp.Bytes("\x00\x00\x00\x00\x00\x00\x00\x00")
        if torrent != nil {
-               conn.post <- pp.Bytes(torrent.InfoHash[:])
-               conn.post <- pp.Bytes(me.PeerId[:])
+               conn.write <- pp.Bytes(torrent.InfoHash[:])
+               conn.write <- pp.Bytes(me.PeerId[:])
        }
        var b [28]byte
        _, err = io.ReadFull(conn.Socket, b[:])
@@ -327,14 +326,15 @@ func (me *Client) runConnection(sock net.Conn, torrent *torrent) (err error) {
                if torrent == nil {
                        return
                }
-               conn.post <- pp.Bytes(torrent.InfoHash[:])
-               conn.post <- pp.Bytes(me.PeerId[:])
+               conn.write <- pp.Bytes(torrent.InfoHash[:])
+               conn.write <- pp.Bytes(me.PeerId[:])
        }
        me.mu.Lock()
        defer me.mu.Unlock()
        if !me.addConnection(torrent, conn) {
                return
        }
+       go conn.writeOptimizer(time.Minute)
        if torrent.haveAnyPieces() {
                conn.Post(pp.Message{
                        Type:     pp.Bitfield,
index 6bb5f5b2bbb18688d6794de693415b5a0aa4f435..df774451888f8981a5d56a51fdf97eee41d3b250 100644 (file)
@@ -16,7 +16,7 @@ type connection struct {
        Socket net.Conn
        closed bool
        mu     sync.Mutex // Only for closing.
-       post   chan encoding.BinaryMarshaler
+       post   chan peer_protocol.Message
        write  chan []byte
 
        // Stuff controlled by the local peer.
@@ -58,7 +58,7 @@ func (c *connection) PeerHasPiece(index peer_protocol.Integer) bool {
        return c.PeerPieces[index]
 }
 
-func (c *connection) Post(msg encoding.BinaryMarshaler) {
+func (c *connection) Post(msg peer_protocol.Message) {
        c.post <- msg
 }
 
@@ -166,56 +166,71 @@ var (
        keepAliveBytes [4]byte
 )
 
+// Writes buffers to the socket from the write channel.
 func (conn *connection) writer() {
-       timer := time.NewTimer(0)
-       defer timer.Stop()
-       for {
-               if !timer.Reset(time.Minute) {
-                       <-timer.C
-               }
-               var b []byte
-               select {
-               case <-timer.C:
-                       b = keepAliveBytes[:]
-               case b = <-conn.write:
-                       if b == nil {
-                               return
-                       }
-               }
+       for b := range conn.write {
                _, err := conn.Socket.Write(b)
-               if conn.getClosed() {
-                       break
-               }
                if err != nil {
-                       log.Print(err)
+                       if !conn.getClosed() {
+                               log.Print(err)
+                       }
                        break
                }
        }
 }
 
-func (conn *connection) writeOptimizer() {
-       pending := list.New()
-       var nextWrite []byte
-       defer close(conn.write)
+func (conn *connection) writeOptimizer(keepAliveDelay time.Duration) {
+       defer close(conn.write) // Responsible for notifying downstream routines.
+       pending := list.New()   // Message queue.
+       var nextWrite []byte    // Set to nil if we need to need to marshal the next message.
+       timer := time.NewTimer(keepAliveDelay)
+       defer timer.Stop()
+       lastWrite := time.Now()
        for {
-               write := conn.write
+               write := conn.write // Set to nil if there's nothing to write.
                if pending.Len() == 0 {
                        write = nil
-               } else {
+               } else if nextWrite == nil {
                        var err error
                        nextWrite, err = pending.Front().Value.(encoding.BinaryMarshaler).MarshalBinary()
                        if err != nil {
                                panic(err)
                        }
                }
+       event:
                select {
+               case <-timer.C:
+                       if pending.Len() != 0 {
+                               break
+                       }
+                       keepAliveTime := lastWrite.Add(keepAliveDelay)
+                       if time.Now().Before(keepAliveTime) {
+                               timer.Reset(keepAliveTime.Sub(time.Now()))
+                               break
+                       }
+                       pending.PushBack(peer_protocol.Message{Keepalive: true})
                case msg, ok := <-conn.post:
                        if !ok {
                                return
                        }
+                       if msg.Type == peer_protocol.Cancel {
+                               for e := pending.Back(); e != nil; e = e.Prev() {
+                                       elemMsg := e.Value.(peer_protocol.Message)
+                                       if elemMsg.Type == peer_protocol.Request && msg.Index == elemMsg.Index && msg.Begin == elemMsg.Begin && msg.Length == elemMsg.Length {
+                                               pending.Remove(e)
+                                               log.Print("optimized cancel! %q", msg)
+                                               break event
+                                       }
+                               }
+                       }
                        pending.PushBack(msg)
                case write <- nextWrite:
                        pending.Remove(pending.Front())
+                       nextWrite = nil
+                       lastWrite = time.Now()
+                       if pending.Len() == 0 {
+                               timer.Reset(keepAliveDelay)
+                       }
                }
        }
 }
diff --git a/connection_test.go b/connection_test.go
new file mode 100644 (file)
index 0000000..9374a66
--- /dev/null
@@ -0,0 +1,46 @@
+package torrent
+
+import (
+       "bitbucket.org/anacrolix/go.torrent/peer_protocol"
+       "testing"
+       "time"
+)
+
+func TestCancelRequestOptimized(t *testing.T) {
+       c := &connection{
+               PeerMaxRequests: 1,
+               PeerPieces:      []bool{false, true},
+               post:            make(chan peer_protocol.Message),
+               write:           make(chan []byte),
+       }
+       if len(c.Requests) != 0 {
+               t.FailNow()
+       }
+       // Keepalive timeout of 0 works because I'm just that good.
+       go c.writeOptimizer(0 * time.Millisecond)
+       c.Request(newRequest(1, 2, 3))
+       if len(c.Requests) != 1 {
+               t.Fatal("request was not posted")
+       }
+       // Posting this message should removing the pending Request.
+       if !c.Cancel(newRequest(1, 2, 3)) {
+               t.Fatal("request was not found")
+       }
+       // Check that the write optimization has filtered out the Request message.
+       for _, b := range []string{
+               // The initial request triggers an Interested message.
+               "\x00\x00\x00\x01\x02",
+               // Let a keep-alive through to verify there were no pending messages.
+               "\x00\x00\x00\x00",
+       } {
+               bb := string(<-c.write)
+               if b != bb {
+                       t.Fatalf("received message %q is not expected: %q", bb, b)
+               }
+       }
+       close(c.post)
+       _, ok := <-c.write
+       if ok {
+               t.Fatal("write channel didn't close")
+       }
+}