From 1fc63bddedfbd96a3f6fcac3444365accfec4fc7 Mon Sep 17 00:00:00 2001
From: Matt Joiner <anacrolix@gmail.com>
Date: Fri, 23 Oct 2020 09:03:44 +1100
Subject: [PATCH] sqlite storage: Add NewProviderPool

---
 storage/sqlite/sqlite-storage.go      | 60 ++++++++++++++++++++-------
 storage/sqlite/sqlite-storage_test.go | 51 +++++++++++++++++++++++
 2 files changed, 97 insertions(+), 14 deletions(-)
 create mode 100644 storage/sqlite/sqlite-storage_test.go

diff --git a/storage/sqlite/sqlite-storage.go b/storage/sqlite/sqlite-storage.go
index 0ec59070..e4eba4be 100644
--- a/storage/sqlite/sqlite-storage.go
+++ b/storage/sqlite/sqlite-storage.go
@@ -2,6 +2,7 @@ package sqliteStorage
 
 import (
 	"bytes"
+	"context"
 	"errors"
 	"fmt"
 	"io"
@@ -27,14 +28,43 @@ create table if not exists blob(
 `)
 }
 
+// Emulates a pool from a single Conn.
+type poolFromConn struct {
+	mu   sync.Mutex
+	conn conn
+}
+
+func (me *poolFromConn) Get(ctx context.Context) conn {
+	me.mu.Lock()
+	return me.conn
+}
+
+func (me *poolFromConn) Put(conn conn) {
+	if conn != me.conn {
+		panic("expected to same conn")
+	}
+	me.mu.Unlock()
+}
+
 func NewProvider(conn *sqlite.Conn) (*provider, error) {
 	err := initConn(conn)
-	return &provider{conn: conn}, err
+	return &provider{&poolFromConn{conn: conn}}, err
+}
+
+func NewProviderPool(pool *sqlitex.Pool) (*provider, error) {
+	conn := pool.Get(context.TODO())
+	defer pool.Put(conn)
+	err := initConn(conn)
+	return &provider{pool: pool}, err
+}
+
+type pool interface {
+	Get(context.Context) conn
+	Put(conn)
 }
 
 type provider struct {
-	mu   sync.Mutex
-	conn conn
+	pool pool
 }
 
 func (p *provider) NewInstance(s string) (resource.Instance, error) {
@@ -47,17 +77,17 @@ type instance struct {
 }
 
 func (i instance) withConn(with func(conn conn)) {
-	i.lockConn()
-	defer i.unlockConn()
-	with(i.p.conn)
+	conn := i.p.pool.Get(context.TODO())
+	defer i.p.pool.Put(conn)
+	with(conn)
 }
 
-func (i instance) lockConn() {
-	i.p.mu.Lock()
+func (i instance) getConn() *sqlite.Conn {
+	return i.p.pool.Get(context.TODO())
 }
 
-func (i instance) unlockConn() {
-	i.p.mu.Unlock()
+func (i instance) putConn(conn *sqlite.Conn) {
+	i.p.pool.Put(conn)
 }
 
 func (i instance) Readdirnames() (names []string, err error) {
@@ -104,15 +134,15 @@ func (me connBlob) Close() error {
 }
 
 func (i instance) Get() (ret io.ReadCloser, err error) {
-	i.lockConn()
-	blob, err := i.openBlob(i.p.conn, false, true)
+	conn := i.getConn()
+	blob, err := i.openBlob(conn, false, true)
 	if err != nil {
-		i.unlockConn()
+		i.putConn(conn)
 		return
 	}
 	var once sync.Once
 	return connBlob{blob, func() {
-		once.Do(i.unlockConn)
+		once.Do(func() { i.putConn(conn) })
 	}}, nil
 }
 
@@ -121,6 +151,8 @@ func (i instance) openBlob(conn conn, write, updateAccess bool) (*sqlite.Blob, e
 	if err != nil {
 		return nil, err
 	}
+	// This seems to cause locking issues with in-memory databases. Is it something to do with not
+	// having WAL?
 	if updateAccess {
 		err = sqlitex.Exec(conn, "update blob set last_used=datetime('now') where rowid=?", nil, rowid)
 		if err != nil {
diff --git a/storage/sqlite/sqlite-storage_test.go b/storage/sqlite/sqlite-storage_test.go
new file mode 100644
index 00000000..c18c4c3a
--- /dev/null
+++ b/storage/sqlite/sqlite-storage_test.go
@@ -0,0 +1,51 @@
+package sqliteStorage
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"path/filepath"
+	"sync"
+	"testing"
+
+	"crawshaw.io/sqlite/sqlitex"
+	_ "github.com/anacrolix/envpprof"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestSimultaneousIncrementalBlob(t *testing.T) {
+	pool, err := sqlitex.Open(
+		// We don't do this in memory, because it seems to have some locking issues with updating
+		// last_used.
+		fmt.Sprintf("file:%s", filepath.Join(t.TempDir(), "sqlite3.db")),
+		0,
+		10)
+	require.NoError(t, err)
+	defer pool.Close()
+	p, err := NewProviderPool(pool)
+	require.NoError(t, err)
+	a, err := p.NewInstance("a")
+	require.NoError(t, err)
+	const contents = "hello, world"
+	require.NoError(t, a.Put(bytes.NewReader([]byte("hello, world"))))
+	rc1, err := a.Get()
+	require.NoError(t, err)
+	rc2, err := a.Get()
+	require.NoError(t, err)
+	var b1, b2 []byte
+	var e1, e2 error
+	var wg sync.WaitGroup
+	doRead := func(b *[]byte, e *error, rc io.ReadCloser, n int) {
+		defer wg.Done()
+		defer rc.Close()
+		*b, *e = ioutil.ReadAll(rc)
+		require.NoError(t, *e, n)
+		assert.EqualValues(t, contents, *b)
+	}
+	wg.Add(2)
+	go doRead(&b2, &e2, rc2, 2)
+	go doRead(&b1, &e1, rc1, 1)
+	wg.Wait()
+}
-- 
2.51.0