diff --git a/CHANGELOG.md b/CHANGELOG.md index b93a0ab..67b8f54 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased][] +### Added + +- New function `sqlitex.ResultBytes`. + ([#86](https://github.com/zombiezen/go-sqlite/pull/86)) + ### Changed - `Conn.Close` returns an error if the connection has already been closed diff --git a/sqlitex/query.go b/sqlitex/query.go index 6ba0a35..122435f 100644 --- a/sqlitex/query.go +++ b/sqlitex/query.go @@ -9,8 +9,10 @@ import ( "zombiezen.com/go/sqlite" ) -var errNoResults = errors.New("sqlite: statement has no results") -var errMultipleResults = errors.New("sqlite: statement has multiple result rows") +var ( + errNoResults = errors.New("sqlite: statement has no results") + errMultipleResults = errors.New("sqlite: statement has multiple result rows") +) func resultSetup(stmt *sqlite.Stmt) error { hasRow, err := stmt.Step() @@ -100,3 +102,18 @@ func ResultFloat(stmt *sqlite.Stmt) (float64, error) { } return res, nil } + +// ResultBytes reads the first column of the first and only row +// produced by running stmt into buf, +// returning the number of bytes read. +// It returns an error if there is not exactly one result row. +func ResultBytes(stmt *sqlite.Stmt, buf []byte) (int, error) { + if err := resultSetup(stmt); err != nil { + return 0, err + } + read := stmt.ColumnBytes(0, buf) + if err := resultTeardown(stmt); err != nil { + return 0, err + } + return read, nil +} diff --git a/sqlitex/query_test.go b/sqlitex/query_test.go new file mode 100644 index 0000000..bca63fe --- /dev/null +++ b/sqlitex/query_test.go @@ -0,0 +1,360 @@ +// Copyright 2024 Roxy Light +// SPDX-License-Identifier: ISC + +package sqlitex + +import ( + "testing" + + "zombiezen.com/go/sqlite" +) + +func TestResultInt64(t *testing.T) { + conn, err := sqlite.OpenConn(":memory:") + if err != nil { + t.Fatal(err) + } + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() + + err = ExecuteScript(conn, ` +CREATE TABLE foo ( + id integer not null primary key +); + +INSERT INTO foo VALUES (1), (2);`, nil) + if err != nil { + t.Fatal(err) + } + + t.Run("Single", func(t *testing.T) { + stmt, _, err := conn.PrepareTransient(`SELECT 42;`) + if err != nil { + t.Fatal(err) + } + defer stmt.Finalize() + + const want = 42 + got, err := ResultInt64(stmt) + if got != want || err != nil { + t.Errorf("ResultInt64(...) = %d, %v; want %d, ", got, err, want) + } + }) + + t.Run("Multiple", func(t *testing.T) { + stmt, _, err := conn.PrepareTransient(`SELECT id FROM foo;`) + if err != nil { + t.Fatal(err) + } + defer stmt.Finalize() + + n, err := ResultInt64(stmt) + if n != 0 || err == nil { + t.Errorf("ResultInt64(...) = %d, %v; want 0, ", n, err) + } else { + t.Log("Returned (expected) error:", err) + } + }) + + t.Run("NoRows", func(t *testing.T) { + stmt, _, err := conn.PrepareTransient(`SELECT id FROM foo WHERE id > 3;`) + if err != nil { + t.Fatal(err) + } + defer stmt.Finalize() + + n, err := ResultInt64(stmt) + if n != 0 || err == nil { + t.Errorf("ResultInt64(...) = %d, %v; want 0, ", n, err) + } else { + t.Log("Returned (expected) error:", err) + } + }) +} + +func TestResultBool(t *testing.T) { + conn, err := sqlite.OpenConn(":memory:") + if err != nil { + t.Fatal(err) + } + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() + + err = ExecuteScript(conn, ` +CREATE TABLE foo ( + id integer not null primary key +); + +INSERT INTO foo VALUES (1), (2);`, nil) + if err != nil { + t.Fatal(err) + } + + t.Run("False", func(t *testing.T) { + stmt, _, err := conn.PrepareTransient(`SELECT false;`) + if err != nil { + t.Fatal(err) + } + defer stmt.Finalize() + + got, err := ResultBool(stmt) + if got || err != nil { + t.Errorf("ResultBool(...) = %t, %v; want false, ", got, err) + } + }) + + t.Run("True", func(t *testing.T) { + stmt, _, err := conn.PrepareTransient(`SELECT true;`) + if err != nil { + t.Fatal(err) + } + defer stmt.Finalize() + + got, err := ResultBool(stmt) + if !got || err != nil { + t.Errorf("ResultBool(...) = %t, %v; want true, ", got, err) + } + }) + + t.Run("Multiple", func(t *testing.T) { + stmt, _, err := conn.PrepareTransient(`SELECT id = 1 FROM foo;`) + if err != nil { + t.Fatal(err) + } + defer stmt.Finalize() + + got, err := ResultBool(stmt) + if got || err == nil { + t.Errorf("ResultBool(...) = %t, %v; want false, ", got, err) + } else { + t.Log("Returned (expected) error:", err) + } + }) + + t.Run("NoRows", func(t *testing.T) { + stmt, _, err := conn.PrepareTransient(`SELECT id = 1 FROM foo WHERE id > 3;`) + if err != nil { + t.Fatal(err) + } + defer stmt.Finalize() + + got, err := ResultBool(stmt) + if got || err == nil { + t.Errorf("ResultBool(...) = %t, %v; want false, ", got, err) + } else { + t.Log("Returned (expected) error:", err) + } + }) +} + +func TestResultText(t *testing.T) { + conn, err := sqlite.OpenConn(":memory:") + if err != nil { + t.Fatal(err) + } + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() + + err = ExecuteScript(conn, ` +CREATE TABLE foo ( + id integer not null primary key, + my_blob blob +); + +INSERT INTO foo VALUES (1, CAST('hi' AS BLOB)), (2, CAST('bye' AS BLOB));`, nil) + if err != nil { + t.Fatal(err) + } + + t.Run("Single", func(t *testing.T) { + stmt, _, err := conn.PrepareTransient(`SELECT my_blob FROM foo WHERE id = 1;`) + if err != nil { + t.Fatal(err) + } + defer stmt.Finalize() + + const want = "hi" + got, err := ResultText(stmt) + if got != want || err != nil { + t.Errorf("ResultText(...) = %q, %v; want %q, ", got, err, want) + } + }) + + t.Run("Multiple", func(t *testing.T) { + stmt, _, err := conn.PrepareTransient(`SELECT my_blob FROM foo;`) + if err != nil { + t.Fatal(err) + } + defer stmt.Finalize() + + got, err := ResultText(stmt) + if got != "" || err == nil { + t.Errorf("ResultText(...) = %q, %v; want 0, ", got, err) + } else { + t.Log("Returned (expected) error:", err) + } + }) + + t.Run("NoRows", func(t *testing.T) { + stmt, _, err := conn.PrepareTransient(`SELECT my_blob FROM foo WHERE id = 3;`) + if err != nil { + t.Fatal(err) + } + defer stmt.Finalize() + + got, err := ResultText(stmt) + if got != "" || err == nil { + t.Errorf("ResultText(...) = %q, %v; want 0, ", got, err) + } else { + t.Log("Returned (expected) error:", err) + } + }) +} + +func TestResultFloat(t *testing.T) { + conn, err := sqlite.OpenConn(":memory:") + if err != nil { + t.Fatal(err) + } + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() + + err = ExecuteScript(conn, ` +CREATE TABLE foo ( + id integer not null primary key +); + +INSERT INTO foo VALUES (1), (2);`, nil) + if err != nil { + t.Fatal(err) + } + + t.Run("Single", func(t *testing.T) { + stmt, _, err := conn.PrepareTransient(`SELECT 42;`) + if err != nil { + t.Fatal(err) + } + defer stmt.Finalize() + + const want = 42.0 + got, err := ResultFloat(stmt) + if got != want || err != nil { + t.Errorf("ResultFloat(...) = %g, %v; want %g, ", got, err, want) + } + }) + + t.Run("Multiple", func(t *testing.T) { + stmt, _, err := conn.PrepareTransient(`SELECT id FROM foo;`) + if err != nil { + t.Fatal(err) + } + defer stmt.Finalize() + + n, err := ResultFloat(stmt) + if n != 0 || err == nil { + t.Errorf("ResultFloat(...) = %g, %v; want 0, ", n, err) + } else { + t.Log("Returned (expected) error:", err) + } + }) + + t.Run("NoRows", func(t *testing.T) { + stmt, _, err := conn.PrepareTransient(`SELECT id FROM foo WHERE id > 3;`) + if err != nil { + t.Fatal(err) + } + defer stmt.Finalize() + + n, err := ResultFloat(stmt) + if n != 0 || err == nil { + t.Errorf("ResultFloat(...) = %g, %v; want 0, ", n, err) + } else { + t.Log("Returned (expected) error:", err) + } + }) +} + +func TestResultBytes(t *testing.T) { + conn, err := sqlite.OpenConn(":memory:") + if err != nil { + t.Fatal(err) + } + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() + + err = ExecuteScript(conn, ` +CREATE TABLE foo ( + id integer not null primary key, + my_blob blob +); + +INSERT INTO foo VALUES (1, CAST('hi' AS BLOB)), (2, CAST('bye' AS BLOB));`, nil) + if err != nil { + t.Fatal(err) + } + + t.Run("Single", func(t *testing.T) { + stmt, _, err := conn.PrepareTransient(`SELECT my_blob FROM foo WHERE id = 1;`) + if err != nil { + t.Fatal(err) + } + defer stmt.Finalize() + + const want = "hi" + buf := make([]byte, 4096) + n, err := ResultBytes(stmt, buf) + if n != len(want) || err != nil { + t.Errorf("ResultBytes(...) = %d, %v; want %d, ", n, err, len(want)) + } + if got := string(buf[:n]); got != want { + t.Errorf("result = %q; want %q", got, want) + } + }) + + t.Run("Multiple", func(t *testing.T) { + stmt, _, err := conn.PrepareTransient(`SELECT my_blob FROM foo;`) + if err != nil { + t.Fatal(err) + } + defer stmt.Finalize() + + buf := make([]byte, 4096) + n, err := ResultBytes(stmt, buf) + if n != 0 || err == nil { + t.Errorf("ResultBytes(...) = %d, %v; want 0, ", n, err) + } else { + t.Log("Returned (expected) error:", err) + } + }) + + t.Run("NoRows", func(t *testing.T) { + stmt, _, err := conn.PrepareTransient(`SELECT my_blob FROM foo WHERE id = 3;`) + if err != nil { + t.Fatal(err) + } + defer stmt.Finalize() + + buf := make([]byte, 4096) + n, err := ResultBytes(stmt, buf) + if n != 0 || err == nil { + t.Errorf("ResultBytes(...) = %d, %v; want 0, ", n, err) + } else { + t.Log("Returned (expected) error:", err) + } + }) +}