diff --git a/null/uint64.go b/null/uint64.go index 2fe83a8..096f830 100644 --- a/null/uint64.go +++ b/null/uint64.go @@ -11,8 +11,6 @@ var ErrScan = errors.New("converting driver.Value type") // Uint64 represents an uint64 that may be null. // Its behavior is based on the implementation of database/sql.NullInt64. -// Because of limitations in the database/sql/driver package, this type -// is limited to the maximum int64 value. type Uint64 struct { Uint64 uint64 Valid bool @@ -31,6 +29,9 @@ func (n *Uint64) Scan(value any) (err error) { // So, even though we weren't able to convert the database value to uint64, the value wasn't NULL. switch v := value.(type) { + case uint64: + n.Uint64 = v + return nil case []byte: if n.Uint64, err = strconv.ParseUint(string(v), 10, 64); err == nil { return nil @@ -57,15 +58,5 @@ func (n Uint64) Value() (driver.Value, error) { if !n.Valid { return nil, nil } - if n.Uint64 > maxInt64 { - return 0, fmt.Errorf("%w: %d", outOfBounds, n.Uint64) - } - return int64(n.Uint64), nil + return n.Uint64, nil } - -const ( - maxUint64 = ^uint64(0) - maxInt64 = maxUint64 >> 1 -) - -var outOfBounds = errors.New("out-of-bounds") diff --git a/null/uint64_test.go b/null/uint64_test.go index 0afb35e..91f6e17 100644 --- a/null/uint64_test.go +++ b/null/uint64_test.go @@ -7,6 +7,8 @@ import ( "strings" "testing" "time" + + "github.com/smarty/db-connector/mysql" ) func TestUint64Scan(t *testing.T) { @@ -19,6 +21,8 @@ func TestUint64Scan(t *testing.T) { {Name: "NULL ", Source: nil /***********/, Want: Uint64{Valid: false, Uint64: 0}, Err: nil}, {Name: "valid byte slice ", Source: []byte("42") /**/, Want: Uint64{Valid: true, Uint64: 42}, Err: nil}, {Name: "invalid byte slice", Source: []byte("invalid"), Want: Uint64{Valid: true, Uint64: 0}, Err: ErrScan}, + {Name: "actual uint64 ", Source: uint64(42) /****/, Want: Uint64{Valid: true, Uint64: 42}, Err: nil}, + {Name: "large uint64 ", Source: largeUint64 /***/, Want: Uint64{Valid: true, Uint64: largeUint64}, Err: nil}, {Name: "positive int64 ", Source: int64(42) /*****/, Want: Uint64{Valid: true, Uint64: 42}, Err: nil}, {Name: "zero int64 ", Source: int64(0) /******/, Want: Uint64{Valid: true, Uint64: 0}, Err: nil}, {Name: "negative int64 ", Source: int64(-1) /*****/, Want: Uint64{Valid: true, Uint64: 0}, Err: ErrScan}, @@ -80,24 +84,97 @@ func TestUint64Value_ValidInt64Value(t *testing.T) { value, err := v.Value() - if value != int64(42) { - t.Errorf("database/sql requires int64(42), but was: %s(%d)", reflect.TypeOf(value), value) + if value != uint64(42) { + t.Errorf("database/sql requires uint64(42), but was: %s(%d)", reflect.TypeOf(value), value) } if err != nil { t.Error("err should have been nil, but was:", err) } } -func TestUint64Value_OutOfBounds(t *testing.T) { - var v Uint64 - v.Valid = true - v.Uint64 = maxInt64 + 1 - value, err := v.Value() +func TestIntegration(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test") + } + db, err := mysql.New() + if err != nil { + t.Fatal(err) + } + defer func() { _ = db.Close() }() + + err = db.Ping() + if err != nil { + t.Fatal(err) + } + + exec(t, db, "DROP SCHEMA IF EXISTS dbConnector;") + defer exec(t, db, "DROP SCHEMA IF EXISTS dbConnector;") + exec(t, db, "CREATE SCHEMA IF NOT EXISTS dbConnector;") + exec(t, db, "DROP TABLE IF EXISTS dbConnector.null_uint64_integration_test;") + exec(t, db, `CREATE TABLE IF NOT EXISTS dbConnector.null_uint64_integration_test ( + id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + special BIGINT unsigned NULL + );`, + ) + + var ( + null Uint64 + small = Uint64{Uint64: 42, Valid: true} + large = Uint64{Uint64: largeUint64, Valid: true} + ) + numbers := []Uint64{null, small, large} + exec(t, db, + "INSERT INTO dbConnector.null_uint64_integration_test (special) VALUES (?), (?), (?);", + null, small, large, + ) - if value != 0 { - t.Errorf("out-of-bounds value should not be returned") + rows, err := db.Query("SELECT special FROM dbConnector.null_uint64_integration_test;") + if err != nil { + t.Fatal(err) + } + defer func() { _ = rows.Close() }() + + for x := 0; rows.Next(); x++ { + var toScan Uint64 + err = rows.Scan(&toScan) + if err != nil { + t.Fatal(err) + } + if numbers[x].Uint64 != toScan.Uint64 { + t.Errorf("\n"+ + "got: %v\n"+ + "want: %v", + toScan.Uint64, + numbers[x].Uint64, + ) + } + } + + row := db.QueryRow("SELECT COUNT(*) FROM dbConnector.null_uint64_integration_test WHERE special IS NULL;") + var nullCount int + err = row.Scan(&nullCount) + if err != nil { + t.Fatal(err) + } + if nullCount != 1 { + t.Errorf("\n"+ + "got: %v\n"+ + "want: %v", + nullCount, + 1, + ) } - if !errors.Is(err, outOfBounds) { - t.Error("expected out-of-bounds error but got:", err) +} +func exec(t *testing.T, db *sql.DB, statement string, args ...any) { + t.Helper() + _, err := db.Exec(statement, args...) + if err != nil { + t.Fatal(err) } } + +const ( + maxUint64 = ^uint64(0) + maxInt64 = maxUint64 >> 1 + largeUint64 = maxInt64 + 1 +)