diff --git a/conn.go b/conn.go index f170ccf5..39870b14 100644 --- a/conn.go +++ b/conn.go @@ -72,6 +72,9 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) { c.arena = c.newArena(1024) c.ctx = context.WithValue(c.ctx, connKey{}, c) c.handle, err = c.openDB(filename, flags) + if err == nil { + err = initExtensions(c) + } if err != nil { return nil, err } diff --git a/ext/array/array.go b/ext/array/array.go index 57901f8f..619e975a 100644 --- a/ext/array/array.go +++ b/ext/array/array.go @@ -15,8 +15,8 @@ import ( // The argument must be bound to a Go slice or array of // ints, floats, bools, strings or byte slices, // using [sqlite3.BindPointer] or [sqlite3.Pointer]. -func Register(db *sqlite3.Conn) { - sqlite3.CreateModule(db, "array", nil, +func Register(db *sqlite3.Conn) error { + return sqlite3.CreateModule(db, "array", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (array, error) { err := db.DeclareVTab(`CREATE TABLE x(value, array HIDDEN)`) return array{}, err diff --git a/ext/array/array_test.go b/ext/array/array_test.go index 0467b034..be97d7c4 100644 --- a/ext/array/array_test.go +++ b/ext/array/array_test.go @@ -15,10 +15,7 @@ import ( ) func Example_driver() { - db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error { - array.Register(c) - return nil - }) + db, err := driver.Open(":memory:", array.Register) if err != nil { log.Fatal(err) } @@ -53,14 +50,14 @@ func Example_driver() { } func Example() { + sqlite3.AutoExtension(array.Register) + db, err := sqlite3.Open(":memory:") if err != nil { log.Fatal(err) } defer db.Close() - array.Register(db) - stmt, _, err := db.Prepare(` SELECT name FROM pragma_function_list @@ -91,10 +88,7 @@ func Example() { func Test_cursor_Column(t *testing.T) { t.Parallel() - db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error { - array.Register(c) - return nil - }) + db, err := driver.Open(":memory:", array.Register) if err != nil { t.Fatal(err) } @@ -139,7 +133,10 @@ func Test_array_errors(t *testing.T) { } defer db.Close() - array.Register(db) + err = array.Register(db) + if err != nil { + t.Fatal(err) + } err = db.Exec(`SELECT * FROM array()`) if err == nil { diff --git a/ext/blobio/blob.go b/ext/blobio/blob.go index e5e15128..6202e0bc 100644 --- a/ext/blobio/blob.go +++ b/ext/blobio/blob.go @@ -29,10 +29,11 @@ import ( // along with the [sqlite3.Blob] handle. // // https://sqlite.org/c3ref/blob.html -func Register(db *sqlite3.Conn) { - db.CreateFunction("readblob", 6, 0, readblob) - db.CreateFunction("writeblob", 6, 0, writeblob) - db.CreateFunction("openblob", -1, 0, openblob) +func Register(db *sqlite3.Conn) error { + return errors.Join( + db.CreateFunction("readblob", 6, 0, readblob), + db.CreateFunction("writeblob", 6, 0, writeblob), + db.CreateFunction("openblob", -1, 0, openblob)) } // OpenCallback is the type for the openblob callback. diff --git a/ext/blobio/blob_test.go b/ext/blobio/blob_test.go index b6f9899f..3caac219 100644 --- a/ext/blobio/blob_test.go +++ b/ext/blobio/blob_test.go @@ -18,10 +18,7 @@ import ( func Example() { // Open the database, registering the extension. - db, err := driver.Open("file:/test.db?vfs=memdb", func(conn *sqlite3.Conn) error { - blobio.Register(conn) - return nil - }) + db, err := driver.Open("file:/test.db?vfs=memdb", blobio.Register) if err != nil { log.Fatal(err) @@ -60,6 +57,11 @@ func Example() { // Hello BLOB! } +func init() { + sqlite3.AutoExtension(blobio.Register) + sqlite3.AutoExtension(array.Register) +} + func Test_readblob(t *testing.T) { t.Parallel() @@ -69,9 +71,6 @@ func Test_readblob(t *testing.T) { } defer db.Close() - blobio.Register(db) - array.Register(db) - err = db.Exec(`SELECT readblob()`) if err == nil { t.Fatal("want error") @@ -129,9 +128,6 @@ func Test_openblob(t *testing.T) { } defer db.Close() - blobio.Register(db) - array.Register(db) - err = db.Exec(`SELECT openblob()`) if err == nil { t.Fatal("want error") diff --git a/ext/bloom/bloom.go b/ext/bloom/bloom.go index 6bc4c378..ed360fcf 100644 --- a/ext/bloom/bloom.go +++ b/ext/bloom/bloom.go @@ -20,8 +20,8 @@ import ( // Register registers the bloom_filter virtual table: // // CREATE VIRTUAL TABLE foo USING bloom_filter(nElements, falseProb, kHashes) -func Register(db *sqlite3.Conn) { - sqlite3.CreateModule(db, "bloom_filter", create, connect) +func Register(db *sqlite3.Conn) error { + return sqlite3.CreateModule(db, "bloom_filter", create, connect) } type bloom struct { diff --git a/ext/bloom/bloom_test.go b/ext/bloom/bloom_test.go index d9df376e..02c08fcb 100644 --- a/ext/bloom/bloom_test.go +++ b/ext/bloom/bloom_test.go @@ -12,6 +12,10 @@ import ( _ "github.com/ncruces/go-sqlite3/internal/testcfg" ) +func init() { + sqlite3.AutoExtension(bloom.Register) +} + func TestRegister(t *testing.T) { t.Parallel() @@ -21,8 +25,6 @@ func TestRegister(t *testing.T) { } defer db.Close() - bloom.Register(db) - err = db.Exec(` CREATE VIRTUAL TABLE sports_cars USING bloom_filter(20); INSERT INTO sports_cars VALUES ('ferrari'), ('lamborghini'), ('alfa romeo') @@ -90,8 +92,6 @@ func Test_compatible(t *testing.T) { } defer db.Close() - bloom.Register(db) - query, _, err := db.Prepare(`SELECT COUNT(*) FROM plants(?)`) if err != nil { t.Fatal(err) diff --git a/ext/csv/csv.go b/ext/csv/csv.go index 9f05d8f0..bb1924e5 100644 --- a/ext/csv/csv.go +++ b/ext/csv/csv.go @@ -23,13 +23,13 @@ import ( // Register registers the CSV virtual table. // If a filename is specified, [os.Open] is used to open the file. -func Register(db *sqlite3.Conn) { - RegisterFS(db, osutil.FS{}) +func Register(db *sqlite3.Conn) error { + return RegisterFS(db, osutil.FS{}) } // RegisterFS registers the CSV virtual table. // If a filename is specified, fsys is used to open the file. -func RegisterFS(db *sqlite3.Conn, fsys fs.FS) { +func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error { declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) { var ( filename string @@ -118,7 +118,7 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) { return table, nil } - sqlite3.CreateModule(db, "csv", declare, declare) + return sqlite3.CreateModule(db, "csv", declare, declare) } type table struct { diff --git a/ext/csv/csv_test.go b/ext/csv/csv_test.go index a21a7960..386f855f 100644 --- a/ext/csv/csv_test.go +++ b/ext/csv/csv_test.go @@ -18,7 +18,10 @@ func Example() { } defer db.Close() - csv.Register(db) + err = csv.Register(db) + if err != nil { + log.Fatal(err) + } err = db.Exec(` CREATE VIRTUAL TABLE eurofxref USING csv( @@ -51,6 +54,10 @@ func Example() { // On Twosday, 1€ = $1.1342 } +func init() { + sqlite3.AutoExtension(csv.Register) +} + func TestRegister(t *testing.T) { t.Parallel() @@ -60,8 +67,6 @@ func TestRegister(t *testing.T) { } defer db.Close() - csv.Register(db) - const data = ` # Comment "Rob" "Pike" rob @@ -124,8 +129,6 @@ func TestAffinity(t *testing.T) { } defer db.Close() - csv.Register(db) - const data = "01\n0.10\ne" err = db.Exec(` CREATE VIRTUAL TABLE temp.nums USING csv( @@ -168,8 +171,6 @@ func TestRegister_errors(t *testing.T) { } defer db.Close() - csv.Register(db) - err = db.Exec(`CREATE VIRTUAL TABLE temp.users USING csv()`) if err == nil { t.Fatal("want error") diff --git a/ext/fileio/fileio.go b/ext/fileio/fileio.go index dcb375ad..4c1bca7d 100644 --- a/ext/fileio/fileio.go +++ b/ext/fileio/fileio.go @@ -14,24 +14,26 @@ import ( // Register registers SQL functions readfile, writefile, lsmode, // and the table-valued function fsdir. -func Register(db *sqlite3.Conn) { - RegisterFS(db, nil) +func Register(db *sqlite3.Conn) error { + return RegisterFS(db, nil) } // Register registers SQL functions readfile, lsmode, // and the table-valued function fsdir; // fsys will be used to read files and list directories. -func RegisterFS(db *sqlite3.Conn, fsys fs.FS) { - db.CreateFunction("lsmode", 1, sqlite3.DETERMINISTIC, lsmode) - db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys)) +func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error { + var err error if fsys == nil { - db.CreateFunction("writefile", -1, sqlite3.DIRECTONLY, writefile) + err = db.CreateFunction("writefile", -1, sqlite3.DIRECTONLY, writefile) } - sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) { - err := db.DeclareVTab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`) - db.VTabConfig(sqlite3.VTAB_DIRECTONLY) - return fsdir{fsys}, err - }) + return errors.Join(err, + db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys)), + db.CreateFunction("lsmode", 1, sqlite3.DETERMINISTIC, lsmode), + sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) { + err := db.DeclareVTab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`) + db.VTabConfig(sqlite3.VTAB_DIRECTONLY) + return fsdir{fsys}, err + })) } func lsmode(ctx sqlite3.Context, arg ...sqlite3.Value) { diff --git a/ext/fileio/fileio_test.go b/ext/fileio/fileio_test.go index d08a9a7e..99971375 100644 --- a/ext/fileio/fileio_test.go +++ b/ext/fileio/fileio_test.go @@ -17,10 +17,7 @@ import ( func Test_lsmode(t *testing.T) { t.Parallel() - db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error { - fileio.Register(c) - return nil - }) + db, err := driver.Open(":memory:", fileio.Register) if err != nil { t.Fatal(err) } diff --git a/ext/fileio/fsdir_test.go b/ext/fileio/fsdir_test.go index 95595c4c..b9bedf35 100644 --- a/ext/fileio/fsdir_test.go +++ b/ext/fileio/fsdir_test.go @@ -68,7 +68,10 @@ func Test_fsdir_errors(t *testing.T) { } defer db.Close() - fileio.Register(db) + err = fileio.Register(db) + if err != nil { + t.Fatal(err) + } err = db.Exec(`SELECT name FROM fsdir()`) if err == nil { diff --git a/ext/fileio/write_test.go b/ext/fileio/write_test.go index fc211fb0..d9ab80d8 100644 --- a/ext/fileio/write_test.go +++ b/ext/fileio/write_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/ncruces/go-sqlite3" "github.com/ncruces/go-sqlite3/driver" _ "github.com/ncruces/go-sqlite3/embed" _ "github.com/ncruces/go-sqlite3/internal/testcfg" @@ -16,10 +15,7 @@ import ( func Test_writefile(t *testing.T) { t.Parallel() - db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error { - Register(c) - return nil - }) + db, err := driver.Open(":memory:", Register) if err != nil { t.Fatal(err) } diff --git a/ext/hash/hash.go b/ext/hash/hash.go index c649c8f6..4ee9cc16 100644 --- a/ext/hash/hash.go +++ b/ext/hash/hash.go @@ -21,47 +21,60 @@ package hash import ( "crypto" + "errors" "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/internal/util" ) // Register registers cryptographic hash functions for a database connection. -func Register(db *sqlite3.Conn) { +func Register(db *sqlite3.Conn) error { flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS + var errs util.ErrorJoiner if crypto.MD4.Available() { - db.CreateFunction("md4", 1, flags, md4Func) + errs.Join( + db.CreateFunction("md4", 1, flags, md4Func)) } if crypto.MD5.Available() { - db.CreateFunction("md5", 1, flags, md5Func) + errs.Join( + db.CreateFunction("md5", 1, flags, md5Func)) } if crypto.SHA1.Available() { - db.CreateFunction("sha1", 1, flags, sha1Func) + errs.Join( + db.CreateFunction("sha1", 1, flags, sha1Func)) } if crypto.SHA3_512.Available() { - db.CreateFunction("sha3", 1, flags, sha3Func) - db.CreateFunction("sha3", 2, flags, sha3Func) + errs.Join( + db.CreateFunction("sha3", 1, flags, sha3Func), + db.CreateFunction("sha3", 2, flags, sha3Func)) } if crypto.SHA256.Available() { - db.CreateFunction("sha224", 1, flags, sha224Func) - db.CreateFunction("sha256", 1, flags, sha256Func) - db.CreateFunction("sha256", 2, flags, sha256Func) + errs.Join( + db.CreateFunction("sha224", 1, flags, sha224Func), + db.CreateFunction("sha256", 1, flags, sha256Func), + db.CreateFunction("sha256", 2, flags, sha256Func)) } if crypto.SHA512.Available() { - db.CreateFunction("sha384", 1, flags, sha384Func) - db.CreateFunction("sha512", 1, flags, sha512Func) - db.CreateFunction("sha512", 2, flags, sha512Func) + errs.Join( + db.CreateFunction("sha384", 1, flags, sha384Func), + db.CreateFunction("sha512", 1, flags, sha512Func), + db.CreateFunction("sha512", 2, flags, sha512Func)) } if crypto.BLAKE2s_256.Available() { - db.CreateFunction("blake2s", 1, flags, blake2sFunc) + errs.Join( + db.CreateFunction("blake2s", 1, flags, blake2sFunc)) } if crypto.BLAKE2b_512.Available() { - db.CreateFunction("blake2b", 1, flags, blake2bFunc) - db.CreateFunction("blake2b", 2, flags, blake2bFunc) + errs.Join( + db.CreateFunction("blake2b", 1, flags, blake2bFunc), + db.CreateFunction("blake2b", 2, flags, blake2bFunc)) } if crypto.RIPEMD160.Available() { - db.CreateFunction("ripemd160", 1, flags, ripemd160Func) + errs.Join( + db.CreateFunction("ripemd160", 1, flags, ripemd160Func)) } + return errors.Join(errs...) } func md4Func(ctx sqlite3.Context, arg ...sqlite3.Value) { diff --git a/ext/hash/hash_test.go b/ext/hash/hash_test.go index 9bf67bed..91aa7564 100644 --- a/ext/hash/hash_test.go +++ b/ext/hash/hash_test.go @@ -7,7 +7,6 @@ import ( _ "crypto/sha512" "testing" - "github.com/ncruces/go-sqlite3" "github.com/ncruces/go-sqlite3/driver" _ "github.com/ncruces/go-sqlite3/embed" _ "github.com/ncruces/go-sqlite3/internal/testcfg" @@ -53,10 +52,7 @@ func TestRegister(t *testing.T) { {"blake2b('', 256)", "0E5751C026E543B2E8AB2EB06099DAA1D1E5DF47778F7787FAAB45CDF12FE3A8"}, } - db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error { - Register(c) - return nil - }) + db, err := driver.Open(":memory:", Register) if err != nil { t.Fatal(err) } diff --git a/ext/lines/lines.go b/ext/lines/lines.go index 90ec369e..7c30f86e 100644 --- a/ext/lines/lines.go +++ b/ext/lines/lines.go @@ -13,6 +13,7 @@ package lines import ( "bufio" "bytes" + "errors" "fmt" "io" "io/fs" @@ -25,27 +26,28 @@ import ( // The lines function reads from a database blob or text. // The lines_read function reads from a file or an [io.Reader]. // If a filename is specified, [os.Open] is used to open the file. -func Register(db *sqlite3.Conn) { - RegisterFS(db, osutil.FS{}) +func Register(db *sqlite3.Conn) error { + return RegisterFS(db, osutil.FS{}) } // RegisterFS registers the lines and lines_read table-valued functions. // The lines function reads from a database blob or text. // The lines_read function reads from a file or an [io.Reader]. // If a filename is specified, fsys is used to open the file. -func RegisterFS(db *sqlite3.Conn, fsys fs.FS) { - sqlite3.CreateModule(db, "lines", nil, - func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) { - err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`) - db.VTabConfig(sqlite3.VTAB_INNOCUOUS) - return lines{}, err - }) - sqlite3.CreateModule(db, "lines_read", nil, - func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) { - err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`) - db.VTabConfig(sqlite3.VTAB_DIRECTONLY) - return lines{fsys}, err - }) +func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error { + return errors.Join( + sqlite3.CreateModule(db, "lines", nil, + func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) { + err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`) + db.VTabConfig(sqlite3.VTAB_INNOCUOUS) + return lines{}, err + }), + sqlite3.CreateModule(db, "lines_read", nil, + func(db *sqlite3.Conn, _, _, _ string, _ ...string) (lines, error) { + err := db.DeclareVTab(`CREATE TABLE x(line TEXT, data HIDDEN)`) + db.VTabConfig(sqlite3.VTAB_DIRECTONLY) + return lines{fsys}, err + })) } type lines struct { diff --git a/ext/lines/lines_test.go b/ext/lines/lines_test.go index 580617aa..e99dc117 100644 --- a/ext/lines/lines_test.go +++ b/ext/lines/lines_test.go @@ -18,10 +18,7 @@ import ( ) func Example() { - db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error { - lines.Register(c) - return nil - }) + db, err := driver.Open(":memory:", lines.Register) if err != nil { log.Fatal(err) } @@ -70,10 +67,7 @@ func Example() { func Test_lines(t *testing.T) { t.Parallel() - db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error { - lines.Register(c) - return nil - }) + db, err := driver.Open(":memory:", lines.Register) if err != nil { log.Fatal(err) } @@ -103,10 +97,7 @@ func Test_lines(t *testing.T) { func Test_lines_error(t *testing.T) { t.Parallel() - db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error { - lines.Register(c) - return nil - }) + db, err := driver.Open(":memory:", lines.Register) if err != nil { log.Fatal(err) } @@ -130,10 +121,7 @@ func Test_lines_error(t *testing.T) { func Test_lines_read(t *testing.T) { t.Parallel() - db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error { - lines.Register(c) - return nil - }) + db, err := driver.Open(":memory:", lines.Register) if err != nil { log.Fatal(err) } @@ -164,10 +152,7 @@ func Test_lines_read(t *testing.T) { func Test_lines_test(t *testing.T) { t.Parallel() - db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error { - lines.Register(c) - return nil - }) + db, err := driver.Open(":memory:", lines.Register) if err != nil { log.Fatal(err) } diff --git a/ext/pivot/pivot.go b/ext/pivot/pivot.go index e05fbe4d..d410f929 100644 --- a/ext/pivot/pivot.go +++ b/ext/pivot/pivot.go @@ -13,8 +13,8 @@ import ( ) // Register registers the pivot virtual table. -func Register(db *sqlite3.Conn) { - sqlite3.CreateModule(db, "pivot", declare, declare) +func Register(db *sqlite3.Conn) error { + return sqlite3.CreateModule(db, "pivot", declare, declare) } type table struct { diff --git a/ext/pivot/pivot_test.go b/ext/pivot/pivot_test.go index aa64001d..c8ccd6a2 100644 --- a/ext/pivot/pivot_test.go +++ b/ext/pivot/pivot_test.go @@ -14,14 +14,14 @@ import ( // https://antonz.org/sqlite-pivot-table/ func Example() { + sqlite3.AutoExtension(pivot.Register) + db, err := sqlite3.Open(":memory:") if err != nil { log.Fatal(err) } defer db.Close() - pivot.Register(db) - err = db.Exec(` CREATE TABLE sales(product TEXT, year INT, income DECIMAL); INSERT INTO sales(product, year, income) VALUES @@ -83,6 +83,10 @@ func Example() { // gamma 80 75 78 80 } +func init() { + sqlite3.AutoExtension(pivot.Register) +} + func TestRegister(t *testing.T) { t.Parallel() @@ -92,8 +96,6 @@ func TestRegister(t *testing.T) { } defer db.Close() - pivot.Register(db) - err = db.Exec(` CREATE TABLE r AS SELECT 1 id UNION SELECT 2 UNION SELECT 3; @@ -153,8 +155,6 @@ func TestRegister_errors(t *testing.T) { } defer db.Close() - pivot.Register(db) - err = db.Exec(`CREATE VIRTUAL TABLE pivot USING pivot()`) if err == nil { t.Fatal("want error") diff --git a/ext/regexp/regexp.go b/ext/regexp/regexp.go index b8315e08..4e0e5081 100644 --- a/ext/regexp/regexp.go +++ b/ext/regexp/regexp.go @@ -12,19 +12,20 @@ package regexp import ( + "errors" "regexp" "github.com/ncruces/go-sqlite3" ) // Register registers Unicode aware functions for a database connection. -func Register(db *sqlite3.Conn) { +func Register(db *sqlite3.Conn) error { flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS - - db.CreateFunction("regexp", 2, flags, regex) - db.CreateFunction("regexp_like", 2, flags, regexLike) - db.CreateFunction("regexp_substr", 2, flags, regexSubstr) - db.CreateFunction("regexp_replace", 3, flags, regexReplace) + return errors.Join( + db.CreateFunction("regexp", 2, flags, regex), + db.CreateFunction("regexp_like", 2, flags, regexLike), + db.CreateFunction("regexp_substr", 2, flags, regexSubstr), + db.CreateFunction("regexp_replace", 3, flags, regexReplace)) } func load(ctx sqlite3.Context, i int, expr string) (*regexp.Regexp, error) { diff --git a/ext/regexp/regexp_test.go b/ext/regexp/regexp_test.go index 3cd94568..3901458d 100644 --- a/ext/regexp/regexp_test.go +++ b/ext/regexp/regexp_test.go @@ -3,7 +3,6 @@ package regexp import ( "testing" - "github.com/ncruces/go-sqlite3" "github.com/ncruces/go-sqlite3/driver" _ "github.com/ncruces/go-sqlite3/embed" _ "github.com/ncruces/go-sqlite3/internal/testcfg" @@ -12,10 +11,7 @@ import ( func TestRegister(t *testing.T) { t.Parallel() - db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error { - Register(conn) - return nil - }) + db, err := driver.Open(":memory:", Register) if err != nil { t.Fatal(err) } @@ -50,10 +46,7 @@ func TestRegister(t *testing.T) { func TestRegister_errors(t *testing.T) { t.Parallel() - db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error { - Register(conn) - return nil - }) + db, err := driver.Open(":memory:", Register) if err != nil { t.Fatal(err) } diff --git a/ext/statement/stmt.go b/ext/statement/stmt.go index 1fc92fb4..25c7469d 100644 --- a/ext/statement/stmt.go +++ b/ext/statement/stmt.go @@ -17,8 +17,8 @@ import ( ) // Register registers the statement virtual table. -func Register(db *sqlite3.Conn) { - sqlite3.CreateModule(db, "statement", declare, declare) +func Register(db *sqlite3.Conn) error { + return sqlite3.CreateModule(db, "statement", declare, declare) } type table struct { diff --git a/ext/statement/stmt_test.go b/ext/statement/stmt_test.go index 42b571ad..4b407e23 100644 --- a/ext/statement/stmt_test.go +++ b/ext/statement/stmt_test.go @@ -12,14 +12,14 @@ import ( ) func Example() { + sqlite3.AutoExtension(statement.Register) + db, err := sqlite3.Open(":memory:") if err != nil { log.Fatal(err) } defer db.Close() - statement.Register(db) - err = db.Exec(` CREATE VIRTUAL TABLE split_date USING statement(( SELECT @@ -48,6 +48,10 @@ func Example() { // Twosday was 2022-2-22 } +func init() { + sqlite3.AutoExtension(statement.Register) +} + func TestRegister(t *testing.T) { t.Parallel() @@ -57,8 +61,6 @@ func TestRegister(t *testing.T) { } defer db.Close() - statement.Register(db) - err = db.Exec(` CREATE VIRTUAL TABLE arguments USING statement((SELECT ? AS a, ? AS b, ? AS c)) `) @@ -107,8 +109,6 @@ func TestRegister_errors(t *testing.T) { } defer db.Close() - statement.Register(db) - err = db.Exec(`CREATE VIRTUAL TABLE split_date USING statement()`) if err == nil { t.Fatal("want error") diff --git a/ext/stats/boolean_test.go b/ext/stats/boolean_test.go index 31959a1d..7ea2536d 100644 --- a/ext/stats/boolean_test.go +++ b/ext/stats/boolean_test.go @@ -5,7 +5,6 @@ import ( "github.com/ncruces/go-sqlite3" _ "github.com/ncruces/go-sqlite3/embed" - "github.com/ncruces/go-sqlite3/ext/stats" _ "github.com/ncruces/go-sqlite3/internal/testcfg" ) @@ -18,8 +17,6 @@ func TestRegister_boolean(t *testing.T) { } defer db.Close() - stats.Register(db) - err = db.Exec(`CREATE TABLE data (x)`) if err != nil { t.Fatal(err) diff --git a/ext/stats/percentile_test.go b/ext/stats/percentile_test.go index 9251a675..7abf40c6 100644 --- a/ext/stats/percentile_test.go +++ b/ext/stats/percentile_test.go @@ -6,7 +6,6 @@ import ( "github.com/ncruces/go-sqlite3" _ "github.com/ncruces/go-sqlite3/embed" - "github.com/ncruces/go-sqlite3/ext/stats" _ "github.com/ncruces/go-sqlite3/internal/testcfg" ) @@ -19,8 +18,6 @@ func TestRegister_percentile(t *testing.T) { } defer db.Close() - stats.Register(db) - err = db.Exec(`CREATE TABLE data (x)`) if err != nil { t.Fatal(err) diff --git a/ext/stats/stats.go b/ext/stats/stats.go index b1bed239..452ebc8d 100644 --- a/ext/stats/stats.go +++ b/ext/stats/stats.go @@ -44,33 +44,38 @@ // [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html package stats -import "github.com/ncruces/go-sqlite3" +import ( + "errors" + + "github.com/ncruces/go-sqlite3" +) // Register registers statistics functions. -func Register(db *sqlite3.Conn) { +func Register(db *sqlite3.Conn) error { flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS - db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop)) - db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp)) - db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop)) - db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp)) - db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop)) - db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp)) - db.CreateWindowFunction("corr", 2, flags, newCovariance(corr)) - db.CreateWindowFunction("regr_r2", 2, flags, newCovariance(regr_r2)) - db.CreateWindowFunction("regr_sxx", 2, flags, newCovariance(regr_sxx)) - db.CreateWindowFunction("regr_syy", 2, flags, newCovariance(regr_syy)) - db.CreateWindowFunction("regr_sxy", 2, flags, newCovariance(regr_sxy)) - db.CreateWindowFunction("regr_avgx", 2, flags, newCovariance(regr_avgx)) - db.CreateWindowFunction("regr_avgy", 2, flags, newCovariance(regr_avgy)) - db.CreateWindowFunction("regr_slope", 2, flags, newCovariance(regr_slope)) - db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept)) - db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count)) - db.CreateWindowFunction("regr_json", 2, flags, newCovariance(regr_json)) - db.CreateWindowFunction("median", 1, flags, newPercentile(median)) - db.CreateWindowFunction("percentile_cont", 2, flags, newPercentile(percentile_cont)) - db.CreateWindowFunction("percentile_disc", 2, flags, newPercentile(percentile_disc)) - db.CreateWindowFunction("every", 1, flags, newBoolean(every)) - db.CreateWindowFunction("some", 1, flags, newBoolean(some)) + return errors.Join( + db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop)), + db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp)), + db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop)), + db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp)), + db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop)), + db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp)), + db.CreateWindowFunction("corr", 2, flags, newCovariance(corr)), + db.CreateWindowFunction("regr_r2", 2, flags, newCovariance(regr_r2)), + db.CreateWindowFunction("regr_sxx", 2, flags, newCovariance(regr_sxx)), + db.CreateWindowFunction("regr_syy", 2, flags, newCovariance(regr_syy)), + db.CreateWindowFunction("regr_sxy", 2, flags, newCovariance(regr_sxy)), + db.CreateWindowFunction("regr_avgx", 2, flags, newCovariance(regr_avgx)), + db.CreateWindowFunction("regr_avgy", 2, flags, newCovariance(regr_avgy)), + db.CreateWindowFunction("regr_slope", 2, flags, newCovariance(regr_slope)), + db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept)), + db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count)), + db.CreateWindowFunction("regr_json", 2, flags, newCovariance(regr_json)), + db.CreateWindowFunction("median", 1, flags, newPercentile(median)), + db.CreateWindowFunction("percentile_cont", 2, flags, newPercentile(percentile_cont)), + db.CreateWindowFunction("percentile_disc", 2, flags, newPercentile(percentile_disc)), + db.CreateWindowFunction("every", 1, flags, newBoolean(every)), + db.CreateWindowFunction("some", 1, flags, newBoolean(some))) } const ( diff --git a/ext/stats/stats_test.go b/ext/stats/stats_test.go index 33bdd416..c3d6b5f6 100644 --- a/ext/stats/stats_test.go +++ b/ext/stats/stats_test.go @@ -10,6 +10,10 @@ import ( _ "github.com/ncruces/go-sqlite3/internal/testcfg" ) +func init() { + sqlite3.AutoExtension(stats.Register) +} + func TestRegister_variance(t *testing.T) { t.Parallel() @@ -19,8 +23,6 @@ func TestRegister_variance(t *testing.T) { } defer db.Close() - stats.Register(db) - err = db.Exec(`CREATE TABLE data (x)`) if err != nil { t.Fatal(err) @@ -88,8 +90,6 @@ func TestRegister_covariance(t *testing.T) { } defer db.Close() - stats.Register(db) - err = db.Exec(`CREATE TABLE data (y, x)`) if err != nil { t.Fatal(err) @@ -217,8 +217,6 @@ func Benchmark_variance(b *testing.B) { } defer db.Close() - stats.Register(db) - stmt, _, err := db.Prepare(`SELECT var_pop(value) FROM generate_series(0, ?)`) if err != nil { b.Fatal(err) diff --git a/ext/unicode/unicode.go b/ext/unicode/unicode.go index 2c8caeec..04c3e303 100644 --- a/ext/unicode/unicode.go +++ b/ext/unicode/unicode.go @@ -18,6 +18,7 @@ package unicode import ( "bytes" + "errors" "regexp" "strings" "unicode/utf8" @@ -30,29 +31,29 @@ import ( ) // Register registers Unicode aware functions for a database connection. -func Register(db *sqlite3.Conn) { +func Register(db *sqlite3.Conn) error { flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS + return errors.Join( + db.CreateFunction("like", 2, flags, like), + db.CreateFunction("like", 3, flags, like), + db.CreateFunction("upper", 1, flags, upper), + db.CreateFunction("upper", 2, flags, upper), + db.CreateFunction("lower", 1, flags, lower), + db.CreateFunction("lower", 2, flags, lower), + db.CreateFunction("regexp", 2, flags, regex), + db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY, + func(ctx sqlite3.Context, arg ...sqlite3.Value) { + name := arg[1].Text() + if name == "" { + return + } - db.CreateFunction("like", 2, flags, like) - db.CreateFunction("like", 3, flags, like) - db.CreateFunction("upper", 1, flags, upper) - db.CreateFunction("upper", 2, flags, upper) - db.CreateFunction("lower", 1, flags, lower) - db.CreateFunction("lower", 2, flags, lower) - db.CreateFunction("regexp", 2, flags, regex) - db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY, - func(ctx sqlite3.Context, arg ...sqlite3.Value) { - name := arg[1].Text() - if name == "" { - return - } - - err := RegisterCollation(db, arg[0].Text(), name) - if err != nil { - ctx.ResultError(err) - return - } - }) + err := RegisterCollation(db, arg[0].Text(), name) + if err != nil { + ctx.ResultError(err) + return + } + })) } // RegisterCollation registers a Unicode collation sequence for a database connection. diff --git a/ext/uuid/uuid.go b/ext/uuid/uuid.go index 4fd5fcee..ae6c213e 100644 --- a/ext/uuid/uuid.go +++ b/ext/uuid/uuid.go @@ -5,6 +5,7 @@ package uuid import ( "bytes" + "errors" "fmt" "github.com/google/uuid" @@ -25,14 +26,15 @@ import ( // uuid_blob(u) // // Converts a UUID into a 16-byte blob. -func Register(db *sqlite3.Conn) { +func Register(db *sqlite3.Conn) error { flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS - db.CreateFunction("uuid", 0, sqlite3.INNOCUOUS, generate) - db.CreateFunction("uuid", 1, sqlite3.INNOCUOUS, generate) - db.CreateFunction("uuid", 2, sqlite3.INNOCUOUS, generate) - db.CreateFunction("uuid", 3, sqlite3.INNOCUOUS, generate) - db.CreateFunction("uuid_str", 1, flags, toString) - db.CreateFunction("uuid_blob", 1, flags, toBLOB) + return errors.Join( + db.CreateFunction("uuid", 0, sqlite3.INNOCUOUS, generate), + db.CreateFunction("uuid", 1, sqlite3.INNOCUOUS, generate), + db.CreateFunction("uuid", 2, sqlite3.INNOCUOUS, generate), + db.CreateFunction("uuid", 3, sqlite3.INNOCUOUS, generate), + db.CreateFunction("uuid_str", 1, flags, toString), + db.CreateFunction("uuid_blob", 1, flags, toBlob)) } func generate(ctx sqlite3.Context, arg ...sqlite3.Value) { @@ -147,7 +149,7 @@ func fromValue(arg sqlite3.Value) (u uuid.UUID, err error) { return u, err } -func toBLOB(ctx sqlite3.Context, arg ...sqlite3.Value) { +func toBlob(ctx sqlite3.Context, arg ...sqlite3.Value) { u, err := fromValue(arg[0]) if err != nil { ctx.ResultError(err) diff --git a/ext/uuid/uuid_test.go b/ext/uuid/uuid_test.go index 3345a07e..4adcaa98 100644 --- a/ext/uuid/uuid_test.go +++ b/ext/uuid/uuid_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/google/uuid" - "github.com/ncruces/go-sqlite3" "github.com/ncruces/go-sqlite3/driver" _ "github.com/ncruces/go-sqlite3/embed" _ "github.com/ncruces/go-sqlite3/internal/testcfg" @@ -13,10 +12,7 @@ import ( func Test_generate(t *testing.T) { t.Parallel() - db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error { - Register(conn) - return nil - }) + db, err := driver.Open(":memory:", Register) if err != nil { t.Fatal(err) } @@ -135,10 +131,7 @@ func Test_generate(t *testing.T) { func Test_convert(t *testing.T) { t.Parallel() - db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error { - Register(conn) - return nil - }) + db, err := driver.Open(":memory:", Register) if err != nil { t.Fatal(err) } diff --git a/ext/zorder/zorder.go b/ext/zorder/zorder.go index 8ab93724..9ba6e471 100644 --- a/ext/zorder/zorder.go +++ b/ext/zorder/zorder.go @@ -4,15 +4,18 @@ package zorder import ( + "errors" + "github.com/ncruces/go-sqlite3" "github.com/ncruces/go-sqlite3/internal/util" ) // Register registers the zorder and unzorder SQL functions. -func Register(db *sqlite3.Conn) { +func Register(db *sqlite3.Conn) error { flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS - db.CreateFunction("zorder", -1, flags, zorder) - db.CreateFunction("unzorder", 3, flags, unzorder) + return errors.Join( + db.CreateFunction("zorder", -1, flags, zorder), + db.CreateFunction("unzorder", 3, flags, unzorder)) } func zorder(ctx sqlite3.Context, arg ...sqlite3.Value) { diff --git a/ext/zorder/zorder_test.go b/ext/zorder/zorder_test.go index c8fbbe3e..fd2632ae 100644 --- a/ext/zorder/zorder_test.go +++ b/ext/zorder/zorder_test.go @@ -3,7 +3,6 @@ package zorder_test import ( "testing" - "github.com/ncruces/go-sqlite3" "github.com/ncruces/go-sqlite3/driver" _ "github.com/ncruces/go-sqlite3/embed" "github.com/ncruces/go-sqlite3/ext/zorder" @@ -13,10 +12,7 @@ import ( func TestRegister_zorder(t *testing.T) { t.Parallel() - db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error { - zorder.Register(c) - return nil - }) + db, err := driver.Open(":memory:", zorder.Register) if err != nil { t.Fatal(err) } @@ -60,10 +56,7 @@ func TestRegister_zorder(t *testing.T) { func TestRegister_unzorder(t *testing.T) { t.Parallel() - db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error { - zorder.Register(c) - return nil - }) + db, err := driver.Open(":memory:", zorder.Register) if err != nil { t.Fatal(err) } @@ -90,10 +83,7 @@ func TestRegister_unzorder(t *testing.T) { func TestRegister_error(t *testing.T) { t.Parallel() - db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error { - zorder.Register(c) - return nil - }) + db, err := driver.Open(":memory:", zorder.Register) if err != nil { t.Fatal(err) } diff --git a/func.go b/func.go index 255584a4..ab486e79 100644 --- a/func.go +++ b/func.go @@ -31,8 +31,9 @@ func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error { // // This can be used to load schemas that contain // one or more unknown collating sequences. -func (c *Conn) AnyCollationNeeded() { - c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0) +func (c Conn) AnyCollationNeeded() error { + r := c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0) + return c.error(r) } // CreateCollation defines a new collating sequence. diff --git a/internal/util/error.go b/internal/util/error.go index 1f5555fd..2aecac96 100644 --- a/internal/util/error.go +++ b/internal/util/error.go @@ -104,3 +104,13 @@ func ErrorCodeString(rc uint32) string { } return "sqlite3: unknown error" } + +type ErrorJoiner []error + +func (j *ErrorJoiner) Join(errs ...error) { + for _, err := range errs { + if err != nil { + *j = append(*j, err) + } + } +} diff --git a/registry.go b/registry.go new file mode 100644 index 00000000..043d69ee --- /dev/null +++ b/registry.go @@ -0,0 +1,30 @@ +package sqlite3 + +import "sync" + +var ( + // +checklocks:extRegistryMtx + extRegistry []func(*Conn) error + extRegistryMtx sync.RWMutex +) + +// AutoExtension causes the entryPoint function to be invoked +// for each new database connection that is created. +// +// https://sqlite.org/c3ref/auto_extension.html +func AutoExtension(entryPoint func(*Conn) error) { + extRegistryMtx.Lock() + defer extRegistryMtx.Unlock() + extRegistry = append(extRegistry, entryPoint) +} + +func initExtensions(c *Conn) error { + extRegistryMtx.RLock() + defer extRegistryMtx.RUnlock() + for _, f := range extRegistry { + if err := f(c); err != nil { + return err + } + } + return nil +} diff --git a/tests/func_test.go b/tests/func_test.go index 136a87b4..08ee7ac4 100644 --- a/tests/func_test.go +++ b/tests/func_test.go @@ -207,7 +207,10 @@ func TestAnyCollationNeeded(t *testing.T) { t.Fatal(err) } - db.AnyCollationNeeded() + err = db.AnyCollationNeeded() + if err != nil { + t.Fatal(err) + } stmt, _, err := db.Prepare(`SELECT id, name FROM users ORDER BY name COLLATE silly`) if err != nil {