Skip to content

Commit

Permalink
Pointer-passing interfaces.
Browse files Browse the repository at this point in the history
  • Loading branch information
ncruces committed Nov 7, 2023
1 parent 24b965a commit 24c9b57
Show file tree
Hide file tree
Showing 11 changed files with 191 additions and 2 deletions.
10 changes: 10 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,16 @@ func (ctx Context) resultRFC3339Nano(value time.Time) {
uint64(ctx.c.api.destructor), _UTF8)
}

// ResultPointer sets the result of the function to NULL, just like [Context.ResultNull],
// except that it also associates ptr with that NULL value such that it can be retrieved
// within an application-defined SQL function using [Value.Pointer].
//
// https://www.sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultPointer(ptr any) {
valPtr := util.AddHandle(ctx.c.ctx, ptr)
ctx.c.call(ctx.c.api.resultPointer, uint64(valPtr))
}

// ResultJSON sets the result of the function to the JSON encoding of value.
//
// https://www.sqlite.org/c3ref/result_blob.html
Expand Down
5 changes: 4 additions & 1 deletion driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,8 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error {
err = s.Stmt.BindBlob(id, a)
case sqlite3.ZeroBlob:
err = s.Stmt.BindZeroBlob(id, int64(a))
case interface{ Value() any }:
err = s.Stmt.BindPointer(id, a.Value())
case time.Time:
err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
case json.Marshaler:
Expand All @@ -400,7 +402,8 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error {
func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
switch arg.Value.(type) {
case bool, int, int64, float64, string, []byte,
sqlite3.ZeroBlob, time.Time, json.Marshaler, nil:
sqlite3.ZeroBlob, interface{ Value() any },
time.Time, json.Marshaler, nil:
return nil
default:
return driver.ErrSkip
Expand Down
3 changes: 3 additions & 0 deletions embed/exports.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ sqlite3_bind_double
sqlite3_bind_text64
sqlite3_bind_blob64
sqlite3_bind_zeroblob64
sqlite3_bind_pointer_go
sqlite3_column_count
sqlite3_column_name
sqlite3_column_type
Expand Down Expand Up @@ -64,12 +65,14 @@ sqlite3_value_double
sqlite3_value_text
sqlite3_value_blob
sqlite3_value_bytes
sqlite3_value_pointer_go
sqlite3_result_null
sqlite3_result_int64
sqlite3_result_double
sqlite3_result_text64
sqlite3_result_blob64
sqlite3_result_zeroblob64
sqlite3_result_pointer_go
sqlite3_result_value
sqlite3_result_error
sqlite3_result_error_code
Expand Down
Binary file modified embed/sqlite3.wasm
Binary file not shown.
59 changes: 59 additions & 0 deletions ext/blob/blob.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Package blob provides an alternative interface to incremental BLOB I/O.
package blob

import (
"errors"

"github.com/ncruces/go-sqlite3"
)

// Register registers the blob_open SQL function.
func Register(db *sqlite3.Conn) {
db.CreateFunction("blob_open", -1,
sqlite3.DETERMINISTIC|sqlite3.DIRECTONLY, openBlob)
}

func openBlob(ctx sqlite3.Context, arg ...sqlite3.Value) {
if len(arg) < 6 {
ctx.ResultError(errors.New("wrong number of arguments to function blob_open()"))
return
}

row := arg[3].Int64()

var err error
blob, ok := ctx.GetAuxData(0).(*sqlite3.Blob)
if ok {
err = blob.Reopen(row)
if errors.Is(err, sqlite3.MISUSE) {
// Blob was closed (db, table or column changed).
ok = false
}
}

if !ok {
db := arg[0].Text()
table := arg[1].Text()
column := arg[2].Text()
write := arg[4].Bool()
blob, err = ctx.Conn().OpenBlob(db, table, column, row, write)
}
if err != nil {
ctx.ResultError(err)
return
}

fn := arg[5].Pointer().(OpenCallback)
err = fn(blob, arg[6:]...)
if err != nil {
ctx.ResultError(err)
return
}

// This ensures the blob is closed if db, table or column change.
ctx.SetAuxData(0, blob)
ctx.SetAuxData(1, blob)
ctx.SetAuxData(2, blob)
}

type OpenCallback func(*sqlite3.Blob, ...sqlite3.Value) error
61 changes: 61 additions & 0 deletions ext/blob/blob_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package blob_test

import (
"io"
"log"
"os"

"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/blob"
)

func Example() {
// Open the database, registering the extension.
db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error {
blob.Register(conn)
return nil
})

if err != nil {
log.Fatal(err)
}
defer os.Remove("demo.db")
defer db.Close()

_, err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
log.Fatal(err)
}

const message = "Hello BLOB!"

// Create the BLOB.
_, err = db.Exec(`INSERT INTO test VALUES (?)`, sqlite3.ZeroBlob(len(message)))
if err != nil {
log.Fatal(err)
}

// Write the BLOB.
_, err = db.Exec(`SELECT blob_open('main', 'test', 'col', last_insert_rowid(), true, ?)`,
sqlite3.Pointer[blob.OpenCallback](func(blob *sqlite3.Blob, _ ...sqlite3.Value) error {
_, err = io.WriteString(blob, message)
return err
}))
if err != nil {
log.Fatal(err)
}

// Read the BLOB.
_, err = db.Exec(`SELECT blob_open('main', 'test', 'col', rowid, false, ?) FROM test`,
sqlite3.Pointer[blob.OpenCallback](func(blob *sqlite3.Blob, _ ...sqlite3.Value) error {
_, err = io.Copy(os.Stdout, blob)
return err
}))
if err != nil {
log.Fatal(err)
}
// Output:
// Hello BLOB!
}
14 changes: 14 additions & 0 deletions pointer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package sqlite3

// Pointer returns a pointer to a value
// that can be used as an argument to
// [database/sql.DB.Exec] and similar methods.
//
// https://www.sqlite.org/bindptr.html
func Pointer[T any](val T) any {
return pointer[T]{val}
}

type pointer[T any] struct{ val T }

func (p pointer[T]) Value() any { return p.val }
6 changes: 6 additions & 0 deletions sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
bindText: getFun("sqlite3_bind_text64"),
bindBlob: getFun("sqlite3_bind_blob64"),
bindZeroBlob: getFun("sqlite3_bind_zeroblob64"),
bindPointer: getFun("sqlite3_bind_pointer_go"),
columnCount: getFun("sqlite3_column_count"),
columnName: getFun("sqlite3_column_name"),
columnType: getFun("sqlite3_column_type"),
Expand Down Expand Up @@ -169,12 +170,14 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
valueText: getFun("sqlite3_value_text"),
valueBlob: getFun("sqlite3_value_blob"),
valueBytes: getFun("sqlite3_value_bytes"),
valuePointer: getFun("sqlite3_value_pointer_go"),
resultNull: getFun("sqlite3_result_null"),
resultInteger: getFun("sqlite3_result_int64"),
resultFloat: getFun("sqlite3_result_double"),
resultText: getFun("sqlite3_result_text64"),
resultBlob: getFun("sqlite3_result_blob64"),
resultZeroBlob: getFun("sqlite3_result_zeroblob64"),
resultPointer: getFun("sqlite3_result_pointer_go"),
resultValue: getFun("sqlite3_result_value"),
resultError: getFun("sqlite3_result_error"),
resultErrorCode: getFun("sqlite3_result_error_code"),
Expand Down Expand Up @@ -353,6 +356,7 @@ type sqliteAPI struct {
bindText api.Function
bindBlob api.Function
bindZeroBlob api.Function
bindPointer api.Function
columnCount api.Function
columnName api.Function
columnType api.Function
Expand Down Expand Up @@ -390,12 +394,14 @@ type sqliteAPI struct {
valueText api.Function
valueBlob api.Function
valueBytes api.Function
valuePointer api.Function
resultNull api.Function
resultInteger api.Function
resultFloat api.Function
resultText api.Function
resultBlob api.Function
resultZeroBlob api.Function
resultPointer api.Function
resultValue api.Function
resultError api.Function
resultErrorCode api.Function
Expand Down
16 changes: 15 additions & 1 deletion sqlite3/func.c
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,18 @@ int sqlite3_create_window_function_go(sqlite3 *db, const char *zName, int nArg,

void sqlite3_set_auxdata_go(sqlite3_context *ctx, int iArg, void *pAux) {
sqlite3_set_auxdata(ctx, iArg, pAux, go_destroy);
}
}

#define GO_POINTER_TYPE "github.com/ncruces/go-sqlite3.Pointer"

int sqlite3_bind_pointer_go(sqlite3_stmt *stmt, int i, void *pApp) {
return sqlite3_bind_pointer(stmt, i, pApp, GO_POINTER_TYPE, go_destroy);
}

void sqlite3_result_pointer_go(sqlite3_context *ctx, void *pApp) {
sqlite3_result_pointer(ctx, pApp, GO_POINTER_TYPE, go_destroy);
}

void *sqlite3_value_pointer_go(sqlite3_value *val) {
return sqlite3_value_pointer(val, GO_POINTER_TYPE);
}
12 changes: 12 additions & 0 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,18 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
return s.c.error(r)
}

// BindPointer binds a NULL to the prepared statement, just like [Stmt.BindNull],
// but it also associates ptr with that NULL value such that it can be retrieved
// within an application-defined SQL function using [Value.Pointer].
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindPointer(param int, ptr any) error {
valPtr := util.AddHandle(s.c.ctx, ptr)
r := s.c.call(s.c.api.bindPointer,
uint64(s.handle), uint64(param), uint64(valPtr))
return s.c.error(r)
}

// BindJSON binds the JSON encoding of value to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
Expand Down
7 changes: 7 additions & 0 deletions value.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,13 @@ func (v Value) rawBytes(ptr uint32) []byte {
return util.View(v.mod, ptr, r)
}

// Pointer gets the pointer associated with this value,
// or nil if it has no associated pointer.
func (v Value) Pointer() any {
r := v.call(v.api.valuePointer, uint64(v.handle))
return util.GetHandle(v.ctx, uint32(r))
}

// JSON parses a JSON-encoded value
// and stores the result in the value pointed to by ptr.
func (v Value) JSON(ptr any) error {
Expand Down

0 comments on commit 24c9b57

Please sign in to comment.