Skip to content

Commit

Permalink
Automatically load extensions. (#115)
Browse files Browse the repository at this point in the history
  • Loading branch information
ncruces authored Jul 8, 2024
1 parent fff8b1c commit b5f746a
Show file tree
Hide file tree
Showing 36 changed files with 261 additions and 245 deletions.
3 changes: 3 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
c.arena = c.newArena(1024)
c.ctx = context.WithValue(c.ctx, connKey{}, c)
c.handle, err = c.openDB(filename, flags)
if err == nil {
err = initExtensions(c)
}
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions ext/array/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import (
// The argument must be bound to a Go slice or array of
// ints, floats, bools, strings or byte slices,
// using [sqlite3.BindPointer] or [sqlite3.Pointer].
func Register(db *sqlite3.Conn) {
sqlite3.CreateModule(db, "array", nil,
func Register(db *sqlite3.Conn) error {
return sqlite3.CreateModule(db, "array", nil,
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (array, error) {
err := db.DeclareVTab(`CREATE TABLE x(value, array HIDDEN)`)
return array{}, err
Expand Down
19 changes: 8 additions & 11 deletions ext/array/array_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@ import (
)

func Example_driver() {
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
array.Register(c)
return nil
})
db, err := driver.Open(":memory:", array.Register)
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -53,14 +50,14 @@ func Example_driver() {
}

func Example() {
sqlite3.AutoExtension(array.Register)

db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()

array.Register(db)

stmt, _, err := db.Prepare(`
SELECT name
FROM pragma_function_list
Expand Down Expand Up @@ -91,10 +88,7 @@ func Example() {
func Test_cursor_Column(t *testing.T) {
t.Parallel()

db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
array.Register(c)
return nil
})
db, err := driver.Open(":memory:", array.Register)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -139,7 +133,10 @@ func Test_array_errors(t *testing.T) {
}
defer db.Close()

array.Register(db)
err = array.Register(db)
if err != nil {
t.Fatal(err)
}

err = db.Exec(`SELECT * FROM array()`)
if err == nil {
Expand Down
9 changes: 5 additions & 4 deletions ext/blobio/blob.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ import (
// along with the [sqlite3.Blob] handle.
//
// https://sqlite.org/c3ref/blob.html
func Register(db *sqlite3.Conn) {
db.CreateFunction("readblob", 6, 0, readblob)
db.CreateFunction("writeblob", 6, 0, writeblob)
db.CreateFunction("openblob", -1, 0, openblob)
func Register(db *sqlite3.Conn) error {
return errors.Join(
db.CreateFunction("readblob", 6, 0, readblob),
db.CreateFunction("writeblob", 6, 0, writeblob),
db.CreateFunction("openblob", -1, 0, openblob))
}

// OpenCallback is the type for the openblob callback.
Expand Down
16 changes: 6 additions & 10 deletions ext/blobio/blob_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@ import (

func Example() {
// Open the database, registering the extension.
db, err := driver.Open("file:/test.db?vfs=memdb", func(conn *sqlite3.Conn) error {
blobio.Register(conn)
return nil
})
db, err := driver.Open("file:/test.db?vfs=memdb", blobio.Register)

if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -60,6 +57,11 @@ func Example() {
// Hello BLOB!
}

func init() {
sqlite3.AutoExtension(blobio.Register)
sqlite3.AutoExtension(array.Register)
}

func Test_readblob(t *testing.T) {
t.Parallel()

Expand All @@ -69,9 +71,6 @@ func Test_readblob(t *testing.T) {
}
defer db.Close()

blobio.Register(db)
array.Register(db)

err = db.Exec(`SELECT readblob()`)
if err == nil {
t.Fatal("want error")
Expand Down Expand Up @@ -129,9 +128,6 @@ func Test_openblob(t *testing.T) {
}
defer db.Close()

blobio.Register(db)
array.Register(db)

err = db.Exec(`SELECT openblob()`)
if err == nil {
t.Fatal("want error")
Expand Down
4 changes: 2 additions & 2 deletions ext/bloom/bloom.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ import (
// Register registers the bloom_filter virtual table:
//
// CREATE VIRTUAL TABLE foo USING bloom_filter(nElements, falseProb, kHashes)
func Register(db *sqlite3.Conn) {
sqlite3.CreateModule(db, "bloom_filter", create, connect)
func Register(db *sqlite3.Conn) error {
return sqlite3.CreateModule(db, "bloom_filter", create, connect)
}

type bloom struct {
Expand Down
8 changes: 4 additions & 4 deletions ext/bloom/bloom_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ import (
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
)

func init() {
sqlite3.AutoExtension(bloom.Register)
}

func TestRegister(t *testing.T) {
t.Parallel()

Expand All @@ -21,8 +25,6 @@ func TestRegister(t *testing.T) {
}
defer db.Close()

bloom.Register(db)

err = db.Exec(`
CREATE VIRTUAL TABLE sports_cars USING bloom_filter(20);
INSERT INTO sports_cars VALUES ('ferrari'), ('lamborghini'), ('alfa romeo')
Expand Down Expand Up @@ -90,8 +92,6 @@ func Test_compatible(t *testing.T) {
}
defer db.Close()

bloom.Register(db)

query, _, err := db.Prepare(`SELECT COUNT(*) FROM plants(?)`)
if err != nil {
t.Fatal(err)
Expand Down
8 changes: 4 additions & 4 deletions ext/csv/csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ import (

// Register registers the CSV virtual table.
// If a filename is specified, [os.Open] is used to open the file.
func Register(db *sqlite3.Conn) {
RegisterFS(db, osutil.FS{})
func Register(db *sqlite3.Conn) error {
return RegisterFS(db, osutil.FS{})
}

// RegisterFS registers the CSV virtual table.
// If a filename is specified, fsys is used to open the file.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) {
var (
filename string
Expand Down Expand Up @@ -118,7 +118,7 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
return table, nil
}

sqlite3.CreateModule(db, "csv", declare, declare)
return sqlite3.CreateModule(db, "csv", declare, declare)
}

type table struct {
Expand Down
15 changes: 8 additions & 7 deletions ext/csv/csv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ func Example() {
}
defer db.Close()

csv.Register(db)
err = csv.Register(db)
if err != nil {
log.Fatal(err)
}

err = db.Exec(`
CREATE VIRTUAL TABLE eurofxref USING csv(
Expand Down Expand Up @@ -51,6 +54,10 @@ func Example() {
// On Twosday, 1€ = $1.1342
}

func init() {
sqlite3.AutoExtension(csv.Register)
}

func TestRegister(t *testing.T) {
t.Parallel()

Expand All @@ -60,8 +67,6 @@ func TestRegister(t *testing.T) {
}
defer db.Close()

csv.Register(db)

const data = `
# Comment
"Rob" "Pike" rob
Expand Down Expand Up @@ -124,8 +129,6 @@ func TestAffinity(t *testing.T) {
}
defer db.Close()

csv.Register(db)

const data = "01\n0.10\ne"
err = db.Exec(`
CREATE VIRTUAL TABLE temp.nums USING csv(
Expand Down Expand Up @@ -168,8 +171,6 @@ func TestRegister_errors(t *testing.T) {
}
defer db.Close()

csv.Register(db)

err = db.Exec(`CREATE VIRTUAL TABLE temp.users USING csv()`)
if err == nil {
t.Fatal("want error")
Expand Down
24 changes: 13 additions & 11 deletions ext/fileio/fileio.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,26 @@ import (

// Register registers SQL functions readfile, writefile, lsmode,
// and the table-valued function fsdir.
func Register(db *sqlite3.Conn) {
RegisterFS(db, nil)
func Register(db *sqlite3.Conn) error {
return RegisterFS(db, nil)
}

// Register registers SQL functions readfile, lsmode,
// and the table-valued function fsdir;
// fsys will be used to read files and list directories.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
db.CreateFunction("lsmode", 1, sqlite3.DETERMINISTIC, lsmode)
db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys))
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
var err error
if fsys == nil {
db.CreateFunction("writefile", -1, sqlite3.DIRECTONLY, writefile)
err = db.CreateFunction("writefile", -1, sqlite3.DIRECTONLY, writefile)
}
sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) {
err := db.DeclareVTab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`)
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
return fsdir{fsys}, err
})
return errors.Join(err,
db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys)),
db.CreateFunction("lsmode", 1, sqlite3.DETERMINISTIC, lsmode),
sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) {
err := db.DeclareVTab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`)
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
return fsdir{fsys}, err
}))
}

func lsmode(ctx sqlite3.Context, arg ...sqlite3.Value) {
Expand Down
5 changes: 1 addition & 4 deletions ext/fileio/fileio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@ import (
func Test_lsmode(t *testing.T) {
t.Parallel()

db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
fileio.Register(c)
return nil
})
db, err := driver.Open(":memory:", fileio.Register)
if err != nil {
t.Fatal(err)
}
Expand Down
5 changes: 4 additions & 1 deletion ext/fileio/fsdir_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ func Test_fsdir_errors(t *testing.T) {
}
defer db.Close()

fileio.Register(db)
err = fileio.Register(db)
if err != nil {
t.Fatal(err)
}

err = db.Exec(`SELECT name FROM fsdir()`)
if err == nil {
Expand Down
6 changes: 1 addition & 5 deletions ext/fileio/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"testing"
"time"

"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
Expand All @@ -16,10 +15,7 @@ import (
func Test_writefile(t *testing.T) {
t.Parallel()

db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
Register(c)
return nil
})
db, err := driver.Open(":memory:", Register)
if err != nil {
t.Fatal(err)
}
Expand Down
Loading

0 comments on commit b5f746a

Please sign in to comment.