diff --git a/CHANGELOG.md b/CHANGELOG.md index c323080..0409307 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [Unreleased]: https://github.com/zombiezen/go-sqlite/compare/v1.1.2...main +Version 1.2.0 introduces `sqlitex.SingleRow` and `sqlitex.SingleRowFS` to +support the common use case of queries that are expected to return only +one result row. + +[1.2.0]: https://github.com/zombiezen/go-sqlite/releases/tag/v1.2.0 + +### Changed + +- Add `sqlitex.SingleRow` and `sqlite.SingleRowFS` + (follow-on from [#85](https://github.com/zombiezen/go-sqlite/issues/85)). + ## [1.1.2][] - 2024-02-14 Version 1.1.2 updates the `modernc.org/sqlite` version to 1.29.1 diff --git a/sqlitex/exec.go b/sqlitex/exec.go index 91da538..7e32659 100644 --- a/sqlitex/exec.go +++ b/sqlitex/exec.go @@ -243,6 +243,57 @@ func ExecuteTransientFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *Ex return nil } +// SingleRow is [Execute], but it returns an error if there is not exactly one result returned. +func SingleRow(conn *sqlite.Conn, query string, opts *ExecOptions) error { + opts, ranOnce := onceWrap(opts) + err := Execute(conn, query, opts) + if err != nil { + return err + } + if !ranOnce() { + return errNoResults + } + return nil +} + +// SingleRowFS is [ExecuteFS], but but it returns an error if there is not exactly one result returned. +func SingleRowFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) error { + opts, ranOnce := onceWrap(opts) + err := ExecuteFS(conn, fsys, filename, opts) + if err != nil { + return err + } + if !ranOnce() { + return errNoResults + } + return nil +} + +// onceWrap wraps the ResultFunc of an [*ExecOptions] in a closure that returns +// errMultipleResults if it is run more than once. +// If no ResultFunc is set, a no-op handler is used. +// It returns the modified options along with a closure that can be called to +// check if ResultFunc was run, allowing the caller to return errNoResults if it +// was never called. +func onceWrap(opts *ExecOptions) (*ExecOptions, func() bool) { + if opts == nil { + opts = &ExecOptions{} + } + if opts.ResultFunc == nil { + opts.ResultFunc = func(*sqlite.Stmt) error { return nil } + } + called := false + rf := opts.ResultFunc + opts.ResultFunc = func(stmt *sqlite.Stmt) error { + if called { + return errMultipleResults + } + called = true + return rf(stmt) + } + return opts, func() bool { return called } +} + // PrepareTransientFS prepares an SQL statement from a file // that is not cached by the Conn. // Subsequent calls with the same query will create new Stmts. diff --git a/sqlitex/exec_test.go b/sqlitex/exec_test.go index 97741ce..0b8f8ba 100644 --- a/sqlitex/exec_test.go +++ b/sqlitex/exec_test.go @@ -18,6 +18,7 @@ package sqlitex import ( + "errors" "fmt" "reflect" "testing" @@ -296,6 +297,76 @@ INSERT INTO t (a, b) VALUES ('a2', :a2); }) } +func TestSingleRow(t *testing.T) { + conn, err := sqlite.OpenConn(":memory:", 0) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + script := ` +CREATE TABLE t (a TEXT, b INTEGER); +INSERT INTO t (a, b) VALUES ('a1', 1); +INSERT INTO t (a, b) VALUES ('a2', 1); +` + err = ExecuteScript(conn, script, &ExecOptions{}) + if err != nil { + t.Fatal(err) + } + + t.Run("NoResults", func(t *testing.T) { + aVal := "" + got := SingleRow(conn, `SELECT a FROM t WHERE b==?`, &ExecOptions{ + Args: []any{0}, + ResultFunc: func(stmt *sqlite.Stmt) error { + aVal = stmt.ColumnText(0) + return nil + }, + }) + if !errors.Is(got, errNoResults) { + t.Errorf("err = %v; want %v", got, errNoResults) + } + if aVal != "" { + t.Errorf(`aVal = %q; want ""`, aVal) + } + }) + + t.Run("MultipleResults", func(t *testing.T) { + aVal := "" + got := SingleRow(conn, `SELECT a FROM t WHERE b==?`, &ExecOptions{ + Args: []any{1}, + ResultFunc: func(stmt *sqlite.Stmt) error { + t.Logf("setting aval to %s", stmt.ColumnText(0)) + aVal = stmt.ColumnText(0) + return nil + }, + }) + if !errors.Is(got, errMultipleResults) { + t.Errorf("err = %v; want %v", got, errMultipleResults) + } + if aVal != "a1" { + t.Errorf(`aVal = %q; want "a1"`, aVal) + } + }) + + t.Run("SingleResult", func(t *testing.T) { + bVal := 0 + got := SingleRow(conn, `SELECT b FROM t WHERE a==?`, &ExecOptions{ + Args: []any{"a1"}, + ResultFunc: func(stmt *sqlite.Stmt) error { + bVal = stmt.ColumnInt(0) + return nil + }, + }) + if got != nil { + t.Errorf("err = %v; want nil", got) + } + if bVal != 1 { + t.Errorf(`bVal = %d; want 1`, bVal) + } + }) +} + func TestBitsetHasAll(t *testing.T) { tests := []struct { bs bitset