Skip to content

Commit

Permalink
The latest MySQL driver (v1.8.1) supports full range of uint64 values.
Browse files Browse the repository at this point in the history
  • Loading branch information
mdwhatcott committed Jun 3, 2024
1 parent 0af0dd1 commit 366bdd7
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 24 deletions.
17 changes: 4 additions & 13 deletions null/uint64.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ var ErrScan = errors.New("converting driver.Value type")

// Uint64 represents an uint64 that may be null.
// Its behavior is based on the implementation of database/sql.NullInt64.
// Because of limitations in the database/sql/driver package, this type
// is limited to the maximum int64 value.
type Uint64 struct {
Uint64 uint64
Valid bool
Expand All @@ -31,6 +29,9 @@ func (n *Uint64) Scan(value any) (err error) {
// So, even though we weren't able to convert the database value to uint64, the value wasn't NULL.

switch v := value.(type) {
case uint64:
n.Uint64 = v
return nil
case []byte:
if n.Uint64, err = strconv.ParseUint(string(v), 10, 64); err == nil {
return nil
Expand All @@ -57,15 +58,5 @@ func (n Uint64) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
if n.Uint64 > maxInt64 {
return 0, fmt.Errorf("%w: %d", outOfBounds, n.Uint64)
}
return int64(n.Uint64), nil
return n.Uint64, nil
}

const (
maxUint64 = ^uint64(0)
maxInt64 = maxUint64 >> 1
)

var outOfBounds = errors.New("out-of-bounds")
99 changes: 88 additions & 11 deletions null/uint64_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"strings"
"testing"
"time"

"github.com/smarty/db-connector/mysql"
)

func TestUint64Scan(t *testing.T) {
Expand All @@ -19,6 +21,8 @@ func TestUint64Scan(t *testing.T) {
{Name: "NULL ", Source: nil /***********/, Want: Uint64{Valid: false, Uint64: 0}, Err: nil},
{Name: "valid byte slice ", Source: []byte("42") /**/, Want: Uint64{Valid: true, Uint64: 42}, Err: nil},
{Name: "invalid byte slice", Source: []byte("invalid"), Want: Uint64{Valid: true, Uint64: 0}, Err: ErrScan},
{Name: "actual uint64 ", Source: uint64(42) /****/, Want: Uint64{Valid: true, Uint64: 42}, Err: nil},
{Name: "large uint64 ", Source: largeUint64 /***/, Want: Uint64{Valid: true, Uint64: largeUint64}, Err: nil},
{Name: "positive int64 ", Source: int64(42) /*****/, Want: Uint64{Valid: true, Uint64: 42}, Err: nil},
{Name: "zero int64 ", Source: int64(0) /******/, Want: Uint64{Valid: true, Uint64: 0}, Err: nil},
{Name: "negative int64 ", Source: int64(-1) /*****/, Want: Uint64{Valid: true, Uint64: 0}, Err: ErrScan},
Expand Down Expand Up @@ -80,24 +84,97 @@ func TestUint64Value_ValidInt64Value(t *testing.T) {

value, err := v.Value()

if value != int64(42) {
t.Errorf("database/sql requires int64(42), but was: %s(%d)", reflect.TypeOf(value), value)
if value != uint64(42) {
t.Errorf("database/sql requires uint64(42), but was: %s(%d)", reflect.TypeOf(value), value)
}
if err != nil {
t.Error("err should have been nil, but was:", err)
}
}
func TestUint64Value_OutOfBounds(t *testing.T) {
var v Uint64
v.Valid = true
v.Uint64 = maxInt64 + 1

value, err := v.Value()
func TestIntegration(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
db, err := mysql.New()
if err != nil {
t.Fatal(err)
}
defer func() { _ = db.Close() }()

err = db.Ping()
if err != nil {
t.Fatal(err)
}

exec(t, db, "DROP SCHEMA IF EXISTS dbConnector;")
defer exec(t, db, "DROP SCHEMA IF EXISTS dbConnector;")
exec(t, db, "CREATE SCHEMA IF NOT EXISTS dbConnector;")
exec(t, db, "DROP TABLE IF EXISTS dbConnector.null_uint64_integration_test;")
exec(t, db, `CREATE TABLE IF NOT EXISTS dbConnector.null_uint64_integration_test (
id INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
special BIGINT unsigned NULL
);`,
)

var (
null Uint64
small = Uint64{Uint64: 42, Valid: true}
large = Uint64{Uint64: largeUint64, Valid: true}
)
numbers := []Uint64{null, small, large}
exec(t, db,
"INSERT INTO dbConnector.null_uint64_integration_test (special) VALUES (?), (?), (?);",
null, small, large,
)

if value != 0 {
t.Errorf("out-of-bounds value should not be returned")
rows, err := db.Query("SELECT special FROM dbConnector.null_uint64_integration_test;")
if err != nil {
t.Fatal(err)
}
defer func() { _ = rows.Close() }()

for x := 0; rows.Next(); x++ {
var toScan Uint64
err = rows.Scan(&toScan)
if err != nil {
t.Fatal(err)
}
if numbers[x].Uint64 != toScan.Uint64 {
t.Errorf("\n"+
"got: %v\n"+
"want: %v",
toScan.Uint64,
numbers[x].Uint64,
)
}
}

row := db.QueryRow("SELECT COUNT(*) FROM dbConnector.null_uint64_integration_test WHERE special IS NULL;")
var nullCount int
err = row.Scan(&nullCount)
if err != nil {
t.Fatal(err)
}
if nullCount != 1 {
t.Errorf("\n"+
"got: %v\n"+
"want: %v",
nullCount,
1,
)
}
if !errors.Is(err, outOfBounds) {
t.Error("expected out-of-bounds error but got:", err)
}
func exec(t *testing.T, db *sql.DB, statement string, args ...any) {
t.Helper()
_, err := db.Exec(statement, args...)
if err != nil {
t.Fatal(err)
}
}

const (
maxUint64 = ^uint64(0)
maxInt64 = maxUint64 >> 1
largeUint64 = maxInt64 + 1
)

0 comments on commit 366bdd7

Please sign in to comment.