Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: SingleRow #89

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

[Unreleased]: https://github.com/zombiezen/go-sqlite/compare/v1.1.2...main

Version 1.2.0 introduces `sqlitex.SingleRow` and `sqlitex.SingleRowFS` to
support the common use case of queries that are expected to return only
one result row.

[1.2.0]: https://github.com/zombiezen/go-sqlite/releases/tag/v1.2.0

### Changed

- Add `sqlitex.SingleRow` and `sqlite.SingleRowFS`
(follow-on from [#85](https://github.com/zombiezen/go-sqlite/issues/85)).

## [1.1.2][] - 2024-02-14

Version 1.1.2 updates the `modernc.org/sqlite` version to 1.29.1
Expand Down
51 changes: 51 additions & 0 deletions sqlitex/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,57 @@ func ExecuteTransientFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *Ex
return nil
}

// SingleRow is [Execute], but it returns an error if there is not exactly one result returned.
func SingleRow(conn *sqlite.Conn, query string, opts *ExecOptions) error {
opts, ranOnce := onceWrap(opts)
err := Execute(conn, query, opts)
if err != nil {
return err
}
if !ranOnce() {
return errNoResults
}
return nil
}

// SingleRowFS is [ExecuteFS], but but it returns an error if there is not exactly one result returned.
func SingleRowFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) error {
opts, ranOnce := onceWrap(opts)
err := ExecuteFS(conn, fsys, filename, opts)
if err != nil {
return err
}
if !ranOnce() {
return errNoResults
}
return nil
}

// onceWrap wraps the ResultFunc of an [*ExecOptions] in a closure that returns
// errMultipleResults if it is run more than once.
// If no ResultFunc is set, a no-op handler is used.
// It returns the modified options along with a closure that can be called to
// check if ResultFunc was run, allowing the caller to return errNoResults if it
// was never called.
func onceWrap(opts *ExecOptions) (*ExecOptions, func() bool) {
if opts == nil {
opts = &ExecOptions{}
}
if opts.ResultFunc == nil {
opts.ResultFunc = func(*sqlite.Stmt) error { return nil }
}
called := false
rf := opts.ResultFunc
opts.ResultFunc = func(stmt *sqlite.Stmt) error {
if called {
return errMultipleResults
}
called = true
return rf(stmt)
}
return opts, func() bool { return called }
}

// PrepareTransientFS prepares an SQL statement from a file
// that is not cached by the Conn.
// Subsequent calls with the same query will create new Stmts.
Expand Down
71 changes: 71 additions & 0 deletions sqlitex/exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package sqlitex

import (
"errors"
"fmt"
"reflect"
"testing"
Expand Down Expand Up @@ -296,6 +297,76 @@ INSERT INTO t (a, b) VALUES ('a2', :a2);
})
}

func TestSingleRow(t *testing.T) {
conn, err := sqlite.OpenConn(":memory:", 0)
if err != nil {
t.Fatal(err)
}
defer conn.Close()

script := `
CREATE TABLE t (a TEXT, b INTEGER);
INSERT INTO t (a, b) VALUES ('a1', 1);
INSERT INTO t (a, b) VALUES ('a2', 1);
`
err = ExecuteScript(conn, script, &ExecOptions{})
if err != nil {
t.Fatal(err)
}

t.Run("NoResults", func(t *testing.T) {
aVal := ""
got := SingleRow(conn, `SELECT a FROM t WHERE b==?`, &ExecOptions{
Args: []any{0},
ResultFunc: func(stmt *sqlite.Stmt) error {
aVal = stmt.ColumnText(0)
return nil
},
})
if !errors.Is(got, errNoResults) {
t.Errorf("err = %v; want %v", got, errNoResults)
}
if aVal != "" {
t.Errorf(`aVal = %q; want ""`, aVal)
}
})

t.Run("MultipleResults", func(t *testing.T) {
aVal := ""
got := SingleRow(conn, `SELECT a FROM t WHERE b==?`, &ExecOptions{
Args: []any{1},
ResultFunc: func(stmt *sqlite.Stmt) error {
t.Logf("setting aval to %s", stmt.ColumnText(0))
aVal = stmt.ColumnText(0)
return nil
},
})
if !errors.Is(got, errMultipleResults) {
t.Errorf("err = %v; want %v", got, errMultipleResults)
}
if aVal != "a1" {
t.Errorf(`aVal = %q; want "a1"`, aVal)
}
})

t.Run("SingleResult", func(t *testing.T) {
bVal := 0
got := SingleRow(conn, `SELECT b FROM t WHERE a==?`, &ExecOptions{
Args: []any{"a1"},
ResultFunc: func(stmt *sqlite.Stmt) error {
bVal = stmt.ColumnInt(0)
return nil
},
})
if got != nil {
t.Errorf("err = %v; want nil", got)
}
if bVal != 1 {
t.Errorf(`bVal = %d; want 1`, bVal)
}
})
}

func TestBitsetHasAll(t *testing.T) {
tests := []struct {
bs bitset
Expand Down