From d2f162972d1c4a309b01b6c628a1bbf319070e41 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Tue, 21 Jan 2025 01:42:57 +0000 Subject: [PATCH] More type safe. (#216) --- backup.go | 38 +++--- blob.go | 64 ++++----- config.go | 158 +++++++++++----------- conn.go | 147 ++++++++++---------- const.go | 16 ++- context.go | 40 +++--- driver/driver.go | 4 +- error.go | 16 +-- error_test.go | 36 ++--- ext/README.md | 2 + ext/csv/csv.go | 2 +- ext/pivot/pivot.go | 4 +- func.go | 72 +++++----- internal/util/func.go | 3 - internal/util/handle.go | 12 +- internal/util/mem.go | 85 +++++++----- internal/util/mem_test.go | 26 ++-- internal/util/mmap_unix.go | 18 +-- internal/util/mmap_windows.go | 6 +- sqlite.go | 77 +++++------ sqlite_test.go | 20 +-- stmt.go | 247 +++++++++++++++++----------------- txn.go | 29 ++-- util/sql3util/parse.go | 18 +-- value.go | 63 +++++---- vfs/api.go | 4 +- vfs/cksm.go | 8 +- vfs/const.go | 7 +- vfs/file.go | 8 +- vfs/filename.go | 34 ++--- vfs/lock_test.go | 36 ++--- vfs/memdb/memdb.go | 4 +- vfs/shm_bsd.go | 2 +- vfs/shm_dotlk.go | 14 +- vfs/shm_ofd.go | 2 +- vfs/shm_windows.go | 14 +- vfs/vfs.go | 157 +++++++++++---------- vfs/vfs_test.go | 47 ++++--- vtab.go | 180 ++++++++++++------------- 39 files changed, 865 insertions(+), 855 deletions(-) diff --git a/backup.go b/backup.go index b16c7511..58b6229a 100644 --- a/backup.go +++ b/backup.go @@ -5,8 +5,8 @@ package sqlite3 // https://sqlite.org/c3ref/backup.html type Backup struct { c *Conn - handle uint32 - otherc uint32 + handle ptr_t + otherc ptr_t } // Backup backs up srcDB on the src connection to the "main" database in dstURI. @@ -61,7 +61,7 @@ func (src *Conn) BackupInit(srcDB, dstURI string) (*Backup, error) { return src.backupInit(dst, "main", src.handle, srcDB) } -func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string) (*Backup, error) { +func (c *Conn) backupInit(dst ptr_t, dstName string, src ptr_t, srcName string) (*Backup, error) { defer c.arena.mark()() dstPtr := c.arena.string(dstName) srcPtr := c.arena.string(srcName) @@ -71,19 +71,19 @@ func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string other = src } - r := c.call("sqlite3_backup_init", - uint64(dst), uint64(dstPtr), - uint64(src), uint64(srcPtr)) - if r == 0 { + ptr := ptr_t(c.call("sqlite3_backup_init", + stk_t(dst), stk_t(dstPtr), + stk_t(src), stk_t(srcPtr))) + if ptr == 0 { defer c.closeDB(other) - r = c.call("sqlite3_errcode", uint64(dst)) - return nil, c.sqlite.error(r, dst) + rc := res_t(c.call("sqlite3_errcode", stk_t(dst))) + return nil, c.sqlite.error(rc, dst) } return &Backup{ c: c, otherc: other, - handle: uint32(r), + handle: ptr, }, nil } @@ -97,10 +97,10 @@ func (b *Backup) Close() error { return nil } - r := b.c.call("sqlite3_backup_finish", uint64(b.handle)) + rc := res_t(b.c.call("sqlite3_backup_finish", stk_t(b.handle))) b.c.closeDB(b.otherc) b.handle = 0 - return b.c.error(r) + return b.c.error(rc) } // Step copies up to nPage pages between the source and destination databases. @@ -108,11 +108,11 @@ func (b *Backup) Close() error { // // https://sqlite.org/c3ref/backup_finish.html#sqlite3backupstep func (b *Backup) Step(nPage int) (done bool, err error) { - r := b.c.call("sqlite3_backup_step", uint64(b.handle), uint64(nPage)) - if r == _DONE { + rc := res_t(b.c.call("sqlite3_backup_step", stk_t(b.handle), stk_t(nPage))) + if rc == _DONE { return true, nil } - return false, b.c.error(r) + return false, b.c.error(rc) } // Remaining returns the number of pages still to be backed up @@ -120,8 +120,8 @@ func (b *Backup) Step(nPage int) (done bool, err error) { // // https://sqlite.org/c3ref/backup_finish.html#sqlite3backupremaining func (b *Backup) Remaining() int { - r := b.c.call("sqlite3_backup_remaining", uint64(b.handle)) - return int(int32(r)) + n := int32(b.c.call("sqlite3_backup_remaining", stk_t(b.handle))) + return int(n) } // PageCount returns the total number of pages in the source database @@ -129,6 +129,6 @@ func (b *Backup) Remaining() int { // // https://sqlite.org/c3ref/backup_finish.html#sqlite3backuppagecount func (b *Backup) PageCount() int { - r := b.c.call("sqlite3_backup_pagecount", uint64(b.handle)) - return int(int32(r)) + n := int32(b.c.call("sqlite3_backup_pagecount", stk_t(b.handle))) + return int(n) } diff --git a/blob.go b/blob.go index a0969eb6..2fac7204 100644 --- a/blob.go +++ b/blob.go @@ -20,8 +20,8 @@ type Blob struct { c *Conn bytes int64 offset int64 - handle uint32 - bufptr uint32 + handle ptr_t + bufptr ptr_t buflen int64 } @@ -37,23 +37,23 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, tablePtr := c.arena.string(table) columnPtr := c.arena.string(column) - var flags uint64 + var flags int32 if write { flags = 1 } c.checkInterrupt(c.handle) - r := c.call("sqlite3_blob_open", uint64(c.handle), - uint64(dbPtr), uint64(tablePtr), uint64(columnPtr), - uint64(row), flags, uint64(blobPtr)) + rc := res_t(c.call("sqlite3_blob_open", stk_t(c.handle), + stk_t(dbPtr), stk_t(tablePtr), stk_t(columnPtr), + stk_t(row), stk_t(flags), stk_t(blobPtr))) - if err := c.error(r); err != nil { + if err := c.error(rc); err != nil { return nil, err } blob := Blob{c: c} - blob.handle = util.ReadUint32(c.mod, blobPtr) - blob.bytes = int64(c.call("sqlite3_blob_bytes", uint64(blob.handle))) + blob.handle = util.Read32[ptr_t](c.mod, blobPtr) + blob.bytes = int64(int32(c.call("sqlite3_blob_bytes", stk_t(blob.handle)))) return &blob, nil } @@ -67,10 +67,10 @@ func (b *Blob) Close() error { return nil } - r := b.c.call("sqlite3_blob_close", uint64(b.handle)) + rc := res_t(b.c.call("sqlite3_blob_close", stk_t(b.handle))) b.c.free(b.bufptr) b.handle = 0 - return b.c.error(r) + return b.c.error(rc) } // Size returns the size of the BLOB in bytes. @@ -94,13 +94,13 @@ func (b *Blob) Read(p []byte) (n int, err error) { want = avail } if want > b.buflen { - b.bufptr = b.c.realloc(b.bufptr, uint64(want)) + b.bufptr = b.c.realloc(b.bufptr, want) b.buflen = want } - r := b.c.call("sqlite3_blob_read", uint64(b.handle), - uint64(b.bufptr), uint64(want), uint64(b.offset)) - err = b.c.error(r) + rc := res_t(b.c.call("sqlite3_blob_read", stk_t(b.handle), + stk_t(b.bufptr), stk_t(want), stk_t(b.offset))) + err = b.c.error(rc) if err != nil { return 0, err } @@ -109,7 +109,7 @@ func (b *Blob) Read(p []byte) (n int, err error) { err = io.EOF } - copy(p, util.View(b.c.mod, b.bufptr, uint64(want))) + copy(p, util.View(b.c.mod, b.bufptr, want)) return int(want), err } @@ -127,19 +127,19 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) { want = avail } if want > b.buflen { - b.bufptr = b.c.realloc(b.bufptr, uint64(want)) + b.bufptr = b.c.realloc(b.bufptr, want) b.buflen = want } for want > 0 { - r := b.c.call("sqlite3_blob_read", uint64(b.handle), - uint64(b.bufptr), uint64(want), uint64(b.offset)) - err = b.c.error(r) + rc := res_t(b.c.call("sqlite3_blob_read", stk_t(b.handle), + stk_t(b.bufptr), stk_t(want), stk_t(b.offset))) + err = b.c.error(rc) if err != nil { return n, err } - mem := util.View(b.c.mod, b.bufptr, uint64(want)) + mem := util.View(b.c.mod, b.bufptr, want) m, err := w.Write(mem[:want]) b.offset += int64(m) n += int64(m) @@ -165,14 +165,14 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) { func (b *Blob) Write(p []byte) (n int, err error) { want := int64(len(p)) if want > b.buflen { - b.bufptr = b.c.realloc(b.bufptr, uint64(want)) + b.bufptr = b.c.realloc(b.bufptr, want) b.buflen = want } util.WriteBytes(b.c.mod, b.bufptr, p) - r := b.c.call("sqlite3_blob_write", uint64(b.handle), - uint64(b.bufptr), uint64(want), uint64(b.offset)) - err = b.c.error(r) + rc := res_t(b.c.call("sqlite3_blob_write", stk_t(b.handle), + stk_t(b.bufptr), stk_t(want), stk_t(b.offset))) + err = b.c.error(rc) if err != nil { return 0, err } @@ -196,17 +196,17 @@ func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) { want = 1 } if want > b.buflen { - b.bufptr = b.c.realloc(b.bufptr, uint64(want)) + b.bufptr = b.c.realloc(b.bufptr, want) b.buflen = want } for { - mem := util.View(b.c.mod, b.bufptr, uint64(want)) + mem := util.View(b.c.mod, b.bufptr, want) m, err := r.Read(mem[:want]) if m > 0 { - r := b.c.call("sqlite3_blob_write", uint64(b.handle), - uint64(b.bufptr), uint64(m), uint64(b.offset)) - err := b.c.error(r) + rc := res_t(b.c.call("sqlite3_blob_write", stk_t(b.handle), + stk_t(b.bufptr), stk_t(m), stk_t(b.offset))) + err := b.c.error(rc) if err != nil { return n, err } @@ -254,8 +254,8 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) { // https://sqlite.org/c3ref/blob_reopen.html func (b *Blob) Reopen(row int64) error { b.c.checkInterrupt(b.c.handle) - err := b.c.error(b.c.call("sqlite3_blob_reopen", uint64(b.handle), uint64(row))) - b.bytes = int64(b.c.call("sqlite3_blob_bytes", uint64(b.handle))) + err := b.c.error(res_t(b.c.call("sqlite3_blob_reopen", stk_t(b.handle), stk_t(row)))) + b.bytes = int64(int32(b.c.call("sqlite3_blob_bytes", stk_t(b.handle)))) b.offset = 0 return err } diff --git a/config.go b/config.go index 474f960a..7fff6ead 100644 --- a/config.go +++ b/config.go @@ -32,7 +32,7 @@ func (c *Conn) Config(op DBConfig, arg ...bool) (bool, error) { defer c.arena.mark()() argsPtr := c.arena.new(intlen + ptrlen) - var flag int + var flag int32 switch { case len(arg) == 0: flag = -1 @@ -40,31 +40,31 @@ func (c *Conn) Config(op DBConfig, arg ...bool) (bool, error) { flag = 1 } - util.WriteUint32(c.mod, argsPtr+0*ptrlen, uint32(flag)) - util.WriteUint32(c.mod, argsPtr+1*ptrlen, argsPtr) + util.Write32(c.mod, argsPtr+0*ptrlen, flag) + util.Write32(c.mod, argsPtr+1*ptrlen, argsPtr) - r := c.call("sqlite3_db_config", uint64(c.handle), - uint64(op), uint64(argsPtr)) - return util.ReadUint32(c.mod, argsPtr) != 0, c.error(r) + rc := res_t(c.call("sqlite3_db_config", stk_t(c.handle), + stk_t(op), stk_t(argsPtr))) + return util.Read32[uint32](c.mod, argsPtr) != 0, c.error(rc) } // ConfigLog sets up the error logging callback for the connection. // // https://sqlite.org/errlog.html func (c *Conn) ConfigLog(cb func(code ExtendedErrorCode, msg string)) error { - var enable uint64 + var enable int32 if cb != nil { enable = 1 } - r := c.call("sqlite3_config_log_go", enable) - if err := c.error(r); err != nil { + rc := res_t(c.call("sqlite3_config_log_go", stk_t(enable))) + if err := c.error(rc); err != nil { return err } c.log = cb return nil } -func logCallback(ctx context.Context, mod api.Module, _, iCode, zMsg uint32) { +func logCallback(ctx context.Context, mod api.Module, _ ptr_t, iCode res_t, zMsg ptr_t) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.log != nil { msg := util.ReadString(mod, zMsg, _MAX_LENGTH) c.log(xErrorCode(iCode), msg) @@ -88,93 +88,93 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro defer c.arena.mark()() ptr := c.arena.new(max(ptrlen, intlen)) - var schemaPtr uint32 + var schemaPtr ptr_t if schema != "" { schemaPtr = c.arena.string(schema) } - var rc uint64 - var res any + var rc res_t + var ret any switch op { default: return nil, MISUSE case FCNTL_RESET_CACHE: - rc = c.call("sqlite3_file_control", - uint64(c.handle), uint64(schemaPtr), - uint64(op), 0) + rc = res_t(c.call("sqlite3_file_control", + stk_t(c.handle), stk_t(schemaPtr), + stk_t(op), 0)) case FCNTL_PERSIST_WAL, FCNTL_POWERSAFE_OVERWRITE: - var flag int + var flag int32 switch { case len(arg) == 0: flag = -1 case arg[0]: flag = 1 } - util.WriteUint32(c.mod, ptr, uint32(flag)) - rc = c.call("sqlite3_file_control", - uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr)) - res = util.ReadUint32(c.mod, ptr) != 0 + util.Write32(c.mod, ptr, flag) + rc = res_t(c.call("sqlite3_file_control", + stk_t(c.handle), stk_t(schemaPtr), + stk_t(op), stk_t(ptr))) + ret = util.Read32[uint32](c.mod, ptr) != 0 case FCNTL_CHUNK_SIZE: - util.WriteUint32(c.mod, ptr, uint32(arg[0].(int))) - rc = c.call("sqlite3_file_control", - uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr)) + util.Write32(c.mod, ptr, int32(arg[0].(int))) + rc = res_t(c.call("sqlite3_file_control", + stk_t(c.handle), stk_t(schemaPtr), + stk_t(op), stk_t(ptr))) case FCNTL_RESERVE_BYTES: bytes := -1 if len(arg) > 0 { bytes = arg[0].(int) } - util.WriteUint32(c.mod, ptr, uint32(bytes)) - rc = c.call("sqlite3_file_control", - uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr)) - res = int(util.ReadUint32(c.mod, ptr)) + util.Write32(c.mod, ptr, int32(bytes)) + rc = res_t(c.call("sqlite3_file_control", + stk_t(c.handle), stk_t(schemaPtr), + stk_t(op), stk_t(ptr))) + ret = int(util.Read32[int32](c.mod, ptr)) case FCNTL_DATA_VERSION: - rc = c.call("sqlite3_file_control", - uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr)) - res = util.ReadUint32(c.mod, ptr) + rc = res_t(c.call("sqlite3_file_control", + stk_t(c.handle), stk_t(schemaPtr), + stk_t(op), stk_t(ptr))) + ret = util.Read32[uint32](c.mod, ptr) case FCNTL_LOCKSTATE: - rc = c.call("sqlite3_file_control", - uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr)) - res = vfs.LockLevel(util.ReadUint32(c.mod, ptr)) + rc = res_t(c.call("sqlite3_file_control", + stk_t(c.handle), stk_t(schemaPtr), + stk_t(op), stk_t(ptr))) + ret = util.Read32[vfs.LockLevel](c.mod, ptr) case FCNTL_VFS_POINTER: - rc = c.call("sqlite3_file_control", - uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr)) + rc = res_t(c.call("sqlite3_file_control", + stk_t(c.handle), stk_t(schemaPtr), + stk_t(op), stk_t(ptr))) if rc == _OK { const zNameOffset = 16 - ptr = util.ReadUint32(c.mod, ptr) - ptr = util.ReadUint32(c.mod, ptr+zNameOffset) + ptr = util.Read32[ptr_t](c.mod, ptr) + ptr = util.Read32[ptr_t](c.mod, ptr+zNameOffset) name := util.ReadString(c.mod, ptr, _MAX_NAME) - res = vfs.Find(name) + ret = vfs.Find(name) } case FCNTL_FILE_POINTER, FCNTL_JOURNAL_POINTER: - rc = c.call("sqlite3_file_control", - uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr)) + rc = res_t(c.call("sqlite3_file_control", + stk_t(c.handle), stk_t(schemaPtr), + stk_t(op), stk_t(ptr))) if rc == _OK { const fileHandleOffset = 4 - ptr = util.ReadUint32(c.mod, ptr) - ptr = util.ReadUint32(c.mod, ptr+fileHandleOffset) - res = util.GetHandle(c.ctx, ptr) + ptr = util.Read32[ptr_t](c.mod, ptr) + ptr = util.Read32[ptr_t](c.mod, ptr+fileHandleOffset) + ret = util.GetHandle(c.ctx, ptr) } } if err := c.error(rc); err != nil { return nil, err } - return res, nil + return ret, nil } // Limit allows the size of various constructs to be @@ -182,20 +182,20 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro // // https://sqlite.org/c3ref/limit.html func (c *Conn) Limit(id LimitCategory, value int) int { - r := c.call("sqlite3_limit", uint64(c.handle), uint64(id), uint64(value)) - return int(int32(r)) + v := int32(c.call("sqlite3_limit", stk_t(c.handle), stk_t(id), stk_t(value))) + return int(v) } // SetAuthorizer registers an authorizer callback with the database connection. // // https://sqlite.org/c3ref/set_authorizer.html func (c *Conn) SetAuthorizer(cb func(action AuthorizerActionCode, name3rd, name4th, schema, inner string) AuthorizerReturnCode) error { - var enable uint64 + var enable int32 if cb != nil { enable = 1 } - r := c.call("sqlite3_set_authorizer_go", uint64(c.handle), enable) - if err := c.error(r); err != nil { + rc := res_t(c.call("sqlite3_set_authorizer_go", stk_t(c.handle), stk_t(enable))) + if err := c.error(rc); err != nil { return err } c.authorizer = cb @@ -203,7 +203,7 @@ func (c *Conn) SetAuthorizer(cb func(action AuthorizerActionCode, name3rd, name4 } -func authorizerCallback(ctx context.Context, mod api.Module, pDB uint32, action AuthorizerActionCode, zName3rd, zName4th, zSchema, zInner uint32) (rc AuthorizerReturnCode) { +func authorizerCallback(ctx context.Context, mod api.Module, pDB ptr_t, action AuthorizerActionCode, zName3rd, zName4th, zSchema, zInner ptr_t) (rc AuthorizerReturnCode) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.authorizer != nil { var name3rd, name4th, schema, inner string if zName3rd != 0 { @@ -227,15 +227,15 @@ func authorizerCallback(ctx context.Context, mod api.Module, pDB uint32, action // // https://sqlite.org/c3ref/trace_v2.html func (c *Conn) Trace(mask TraceEvent, cb func(evt TraceEvent, arg1 any, arg2 any) error) error { - r := c.call("sqlite3_trace_go", uint64(c.handle), uint64(mask)) - if err := c.error(r); err != nil { + rc := res_t(c.call("sqlite3_trace_go", stk_t(c.handle), stk_t(mask))) + if err := c.error(rc); err != nil { return err } c.trace = cb return nil } -func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pArg1, pArg2 uint32) (rc uint32) { +func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pArg1, pArg2 ptr_t) (rc res_t) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.trace != nil { var arg1, arg2 any if evt == TRACE_CLOSE { @@ -248,7 +248,7 @@ func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pAr case TRACE_STMT: arg2 = s.SQL() case TRACE_PROFILE: - arg2 = int64(util.ReadUint64(mod, pArg2)) + arg2 = util.Read64[int64](mod, pArg2) } break } @@ -269,20 +269,20 @@ func (c *Conn) WALCheckpoint(schema string, mode CheckpointMode) (nLog, nCkpt in nLogPtr := c.arena.new(ptrlen) nCkptPtr := c.arena.new(ptrlen) schemaPtr := c.arena.string(schema) - r := c.call("sqlite3_wal_checkpoint_v2", - uint64(c.handle), uint64(schemaPtr), uint64(mode), - uint64(nLogPtr), uint64(nCkptPtr)) - nLog = int(int32(util.ReadUint32(c.mod, nLogPtr))) - nCkpt = int(int32(util.ReadUint32(c.mod, nCkptPtr))) - return nLog, nCkpt, c.error(r) + rc := res_t(c.call("sqlite3_wal_checkpoint_v2", + stk_t(c.handle), stk_t(schemaPtr), stk_t(mode), + stk_t(nLogPtr), stk_t(nCkptPtr))) + nLog = int(util.Read32[int32](c.mod, nLogPtr)) + nCkpt = int(util.Read32[int32](c.mod, nCkptPtr)) + return nLog, nCkpt, c.error(rc) } // WALAutoCheckpoint configures WAL auto-checkpoints. // // https://sqlite.org/c3ref/wal_autocheckpoint.html func (c *Conn) WALAutoCheckpoint(pages int) error { - r := c.call("sqlite3_wal_autocheckpoint", uint64(c.handle), uint64(pages)) - return c.error(r) + rc := res_t(c.call("sqlite3_wal_autocheckpoint", stk_t(c.handle), stk_t(pages))) + return c.error(rc) } // WALHook registers a callback function to be invoked @@ -290,15 +290,15 @@ func (c *Conn) WALAutoCheckpoint(pages int) error { // // https://sqlite.org/c3ref/wal_hook.html func (c *Conn) WALHook(cb func(db *Conn, schema string, pages int) error) { - var enable uint64 + var enable int32 if cb != nil { enable = 1 } - c.call("sqlite3_wal_hook_go", uint64(c.handle), enable) + c.call("sqlite3_wal_hook_go", stk_t(c.handle), stk_t(enable)) c.wal = cb } -func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema uint32, pages int32) (rc uint32) { +func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema ptr_t, pages int32) (rc res_t) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.wal != nil { schema := util.ReadString(mod, zSchema, _MAX_NAME) err := c.wal(c, schema, int(pages)) @@ -311,15 +311,15 @@ func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema uint32, pa // // https://sqlite.org/c3ref/autovacuum_pages.html func (c *Conn) AutoVacuumPages(cb func(schema string, dbPages, freePages, bytesPerPage uint) uint) error { - var funcPtr uint32 + var funcPtr ptr_t if cb != nil { funcPtr = util.AddHandle(c.ctx, cb) } - r := c.call("sqlite3_autovacuum_pages_go", uint64(c.handle), uint64(funcPtr)) - return c.error(r) + rc := res_t(c.call("sqlite3_autovacuum_pages_go", stk_t(c.handle), stk_t(funcPtr))) + return c.error(rc) } -func autoVacuumCallback(ctx context.Context, mod api.Module, pApp, zSchema, nDbPage, nFreePage, nBytePerPage uint32) uint32 { +func autoVacuumCallback(ctx context.Context, mod api.Module, pApp, zSchema ptr_t, nDbPage, nFreePage, nBytePerPage uint32) uint32 { fn := util.GetHandle(ctx, pApp).(func(schema string, dbPages, freePages, bytesPerPage uint) uint) schema := util.ReadString(mod, zSchema, _MAX_NAME) return uint32(fn(schema, uint(nDbPage), uint(nFreePage), uint(nBytePerPage))) @@ -329,14 +329,14 @@ func autoVacuumCallback(ctx context.Context, mod api.Module, pApp, zSchema, nDbP // // https://sqlite.org/c3ref/hard_heap_limit64.html func (c *Conn) SoftHeapLimit(n int64) int64 { - return int64(c.call("sqlite3_soft_heap_limit64", uint64(n))) + return int64(c.call("sqlite3_soft_heap_limit64", stk_t(n))) } // HardHeapLimit imposes a hard limit on heap size. // // https://sqlite.org/c3ref/hard_heap_limit64.html func (c *Conn) HardHeapLimit(n int64) int64 { - return int64(c.call("sqlite3_hard_heap_limit64", uint64(n))) + return int64(c.call("sqlite3_hard_heap_limit64", stk_t(n))) } // EnableChecksums enables checksums on a database. diff --git a/conn.go b/conn.go index 862d4306..fffc7416 100644 --- a/conn.go +++ b/conn.go @@ -35,11 +35,11 @@ type Conn struct { update func(AuthorizerActionCode, string, string, int64) commit func() bool rollback func() - arena arena busy1st time.Time busylst time.Time - handle uint32 + arena arena + handle ptr_t } // Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI]. @@ -70,7 +70,7 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) { type connKey = util.ConnKey -func newConn(ctx context.Context, filename string, flags OpenFlag) (res *Conn, _ error) { +func newConn(ctx context.Context, filename string, flags OpenFlag) (ret *Conn, _ error) { err := ctx.Err() if err != nil { return nil, err @@ -82,7 +82,7 @@ func newConn(ctx context.Context, filename string, flags OpenFlag) (res *Conn, _ return nil, err } defer func() { - if res == nil { + if ret == nil { c.Close() c.sqlite.close() } else { @@ -91,7 +91,7 @@ func newConn(ctx context.Context, filename string, flags OpenFlag) (res *Conn, _ }() c.ctx = context.WithValue(c.ctx, connKey{}, c) - c.arena = c.newArena(1024) + c.arena = c.newArena() c.handle, err = c.openDB(filename, flags) if err == nil { err = initExtensions(c) @@ -102,21 +102,21 @@ func newConn(ctx context.Context, filename string, flags OpenFlag) (res *Conn, _ return c, nil } -func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { +func (c *Conn) openDB(filename string, flags OpenFlag) (ptr_t, error) { defer c.arena.mark()() connPtr := c.arena.new(ptrlen) namePtr := c.arena.string(filename) flags |= OPEN_EXRESCODE - r := c.call("sqlite3_open_v2", uint64(namePtr), uint64(connPtr), uint64(flags), 0) + rc := res_t(c.call("sqlite3_open_v2", stk_t(namePtr), stk_t(connPtr), stk_t(flags), 0)) - handle := util.ReadUint32(c.mod, connPtr) - if err := c.sqlite.error(r, handle); err != nil { + handle := util.Read32[ptr_t](c.mod, connPtr) + if err := c.sqlite.error(rc, handle); err != nil { c.closeDB(handle) return 0, err } - c.call("sqlite3_progress_handler_go", uint64(handle), 100) + c.call("sqlite3_progress_handler_go", stk_t(handle), 100) if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") { var pragmas strings.Builder if _, after, ok := strings.Cut(filename, "?"); ok { @@ -130,8 +130,8 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { if pragmas.Len() != 0 { c.checkInterrupt(handle) pragmaPtr := c.arena.string(pragmas.String()) - r := c.call("sqlite3_exec", uint64(handle), uint64(pragmaPtr), 0, 0, 0) - if err := c.sqlite.error(r, handle, pragmas.String()); err != nil { + rc := res_t(c.call("sqlite3_exec", stk_t(handle), stk_t(pragmaPtr), 0, 0, 0)) + if err := c.sqlite.error(rc, handle, pragmas.String()); err != nil { err = fmt.Errorf("sqlite3: invalid _pragma: %w", err) c.closeDB(handle) return 0, err @@ -141,9 +141,9 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { return handle, nil } -func (c *Conn) closeDB(handle uint32) { - r := c.call("sqlite3_close_v2", uint64(handle)) - if err := c.sqlite.error(r, handle); err != nil { +func (c *Conn) closeDB(handle ptr_t) { + rc := res_t(c.call("sqlite3_close_v2", stk_t(handle))) + if err := c.sqlite.error(rc, handle); err != nil { panic(err) } } @@ -165,8 +165,8 @@ func (c *Conn) Close() error { c.pending.Close() c.pending = nil - r := c.call("sqlite3_close", uint64(c.handle)) - if err := c.error(r); err != nil { + rc := res_t(c.call("sqlite3_close", stk_t(c.handle))) + if err := c.error(rc); err != nil { return err } @@ -183,8 +183,8 @@ func (c *Conn) Exec(sql string) error { sqlPtr := c.arena.string(sql) c.checkInterrupt(c.handle) - r := c.call("sqlite3_exec", uint64(c.handle), uint64(sqlPtr), 0, 0, 0) - return c.error(r, sql) + rc := res_t(c.call("sqlite3_exec", stk_t(c.handle), stk_t(sqlPtr), 0, 0, 0)) + return c.error(rc, sql) } // Prepare calls [Conn.PrepareFlags] with no flags. @@ -209,17 +209,17 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str sqlPtr := c.arena.string(sql) c.checkInterrupt(c.handle) - r := c.call("sqlite3_prepare_v3", uint64(c.handle), - uint64(sqlPtr), uint64(len(sql)+1), uint64(flags), - uint64(stmtPtr), uint64(tailPtr)) + rc := res_t(c.call("sqlite3_prepare_v3", stk_t(c.handle), + stk_t(sqlPtr), stk_t(len(sql)+1), stk_t(flags), + stk_t(stmtPtr), stk_t(tailPtr))) stmt = &Stmt{c: c} - stmt.handle = util.ReadUint32(c.mod, stmtPtr) - if sql := sql[util.ReadUint32(c.mod, tailPtr)-sqlPtr:]; sql != "" { + stmt.handle = util.Read32[ptr_t](c.mod, stmtPtr) + if sql := sql[util.Read32[ptr_t](c.mod, tailPtr)-sqlPtr:]; sql != "" { tail = sql } - if err := c.error(r, sql); err != nil { + if err := c.error(rc, sql); err != nil { return nil, "", err } if stmt.handle == 0 { @@ -233,9 +233,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str // // https://sqlite.org/c3ref/db_name.html func (c *Conn) DBName(n int) string { - r := c.call("sqlite3_db_name", uint64(c.handle), uint64(n)) - - ptr := uint32(r) + ptr := ptr_t(c.call("sqlite3_db_name", stk_t(c.handle), stk_t(n))) if ptr == 0 { return "" } @@ -246,34 +244,34 @@ func (c *Conn) DBName(n int) string { // // https://sqlite.org/c3ref/db_filename.html func (c *Conn) Filename(schema string) *vfs.Filename { - var ptr uint32 + var ptr ptr_t if schema != "" { defer c.arena.mark()() ptr = c.arena.string(schema) } - r := c.call("sqlite3_db_filename", uint64(c.handle), uint64(ptr)) - return vfs.GetFilename(c.ctx, c.mod, uint32(r), vfs.OPEN_MAIN_DB) + ptr = ptr_t(c.call("sqlite3_db_filename", stk_t(c.handle), stk_t(ptr))) + return vfs.GetFilename(c.ctx, c.mod, ptr, vfs.OPEN_MAIN_DB) } // ReadOnly determines if a database is read-only. // // https://sqlite.org/c3ref/db_readonly.html func (c *Conn) ReadOnly(schema string) (ro bool, ok bool) { - var ptr uint32 + var ptr ptr_t if schema != "" { defer c.arena.mark()() ptr = c.arena.string(schema) } - r := c.call("sqlite3_db_readonly", uint64(c.handle), uint64(ptr)) - return int32(r) > 0, int32(r) < 0 + b := int32(c.call("sqlite3_db_readonly", stk_t(c.handle), stk_t(ptr))) + return b > 0, b < 0 } // GetAutocommit tests the connection for auto-commit mode. // // https://sqlite.org/c3ref/get_autocommit.html func (c *Conn) GetAutocommit() bool { - r := c.call("sqlite3_get_autocommit", uint64(c.handle)) - return r != 0 + b := int32(c.call("sqlite3_get_autocommit", stk_t(c.handle))) + return b != 0 } // LastInsertRowID returns the rowid of the most recent successful INSERT @@ -281,8 +279,7 @@ func (c *Conn) GetAutocommit() bool { // // https://sqlite.org/c3ref/last_insert_rowid.html func (c *Conn) LastInsertRowID() int64 { - r := c.call("sqlite3_last_insert_rowid", uint64(c.handle)) - return int64(r) + return int64(c.call("sqlite3_last_insert_rowid", stk_t(c.handle))) } // SetLastInsertRowID allows the application to set the value returned by @@ -290,7 +287,7 @@ func (c *Conn) LastInsertRowID() int64 { // // https://sqlite.org/c3ref/set_last_insert_rowid.html func (c *Conn) SetLastInsertRowID(id int64) { - c.call("sqlite3_set_last_insert_rowid", uint64(c.handle), uint64(id)) + c.call("sqlite3_set_last_insert_rowid", stk_t(c.handle), stk_t(id)) } // Changes returns the number of rows modified, inserted or deleted @@ -299,8 +296,7 @@ func (c *Conn) SetLastInsertRowID(id int64) { // // https://sqlite.org/c3ref/changes.html func (c *Conn) Changes() int64 { - r := c.call("sqlite3_changes64", uint64(c.handle)) - return int64(r) + return int64(c.call("sqlite3_changes64", stk_t(c.handle))) } // TotalChanges returns the number of rows modified, inserted or deleted @@ -309,16 +305,15 @@ func (c *Conn) Changes() int64 { // // https://sqlite.org/c3ref/total_changes.html func (c *Conn) TotalChanges() int64 { - r := c.call("sqlite3_total_changes64", uint64(c.handle)) - return int64(r) + return int64(c.call("sqlite3_total_changes64", stk_t(c.handle))) } // ReleaseMemory frees memory used by a database connection. // // https://sqlite.org/c3ref/db_release_memory.html func (c *Conn) ReleaseMemory() error { - r := c.call("sqlite3_db_release_memory", uint64(c.handle)) - return c.error(r) + rc := res_t(c.call("sqlite3_db_release_memory", stk_t(c.handle))) + return c.error(rc) } // GetInterrupt gets the context set with [Conn.SetInterrupt]. @@ -354,10 +349,10 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { defer c.arena.mark()() stmtPtr := c.arena.new(ptrlen) loopPtr := c.arena.string(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`) - c.call("sqlite3_prepare_v3", uint64(c.handle), uint64(loopPtr), math.MaxUint64, - uint64(PREPARE_PERSISTENT), uint64(stmtPtr), 0) + c.call("sqlite3_prepare_v3", stk_t(c.handle), stk_t(loopPtr), math.MaxUint64, + stk_t(PREPARE_PERSISTENT), stk_t(stmtPtr), 0) c.pending = &Stmt{c: c} - c.pending.handle = util.ReadUint32(c.mod, stmtPtr) + c.pending.handle = util.Read32[ptr_t](c.mod, stmtPtr) } if old.Done() != nil && ctx.Err() == nil { @@ -369,13 +364,13 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { return old } -func (c *Conn) checkInterrupt(handle uint32) { +func (c *Conn) checkInterrupt(handle ptr_t) { if c.interrupt.Err() != nil { - c.call("sqlite3_interrupt", uint64(handle)) + c.call("sqlite3_interrupt", stk_t(handle)) } } -func progressCallback(ctx context.Context, mod api.Module, _ uint32) (interrupt uint32) { +func progressCallback(ctx context.Context, mod api.Module, _ ptr_t) (interrupt int32) { if c, ok := ctx.Value(connKey{}).(*Conn); ok { if c.interrupt.Done() != nil { runtime.Gosched() @@ -392,11 +387,11 @@ func progressCallback(ctx context.Context, mod api.Module, _ uint32) (interrupt // https://sqlite.org/c3ref/busy_timeout.html func (c *Conn) BusyTimeout(timeout time.Duration) error { ms := min((timeout+time.Millisecond-1)/time.Millisecond, math.MaxInt32) - r := c.call("sqlite3_busy_timeout", uint64(c.handle), uint64(ms)) - return c.error(r) + rc := res_t(c.call("sqlite3_busy_timeout", stk_t(c.handle), stk_t(ms))) + return c.error(rc) } -func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (retry uint32) { +func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (retry int32) { // https://fractaledmind.github.io/2024/04/15/sqlite-on-rails-the-how-and-why-of-optimal-performance/ if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.interrupt.Err() == nil { switch { @@ -419,19 +414,19 @@ func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (r // // https://sqlite.org/c3ref/busy_handler.html func (c *Conn) BusyHandler(cb func(ctx context.Context, count int) (retry bool)) error { - var enable uint64 + var enable int32 if cb != nil { enable = 1 } - r := c.call("sqlite3_busy_handler_go", uint64(c.handle), enable) - if err := c.error(r); err != nil { + rc := res_t(c.call("sqlite3_busy_handler_go", stk_t(c.handle), stk_t(enable))) + if err := c.error(rc); err != nil { return err } c.busy = cb return nil } -func busyCallback(ctx context.Context, mod api.Module, pDB uint32, count int32) (retry uint32) { +func busyCallback(ctx context.Context, mod api.Module, pDB ptr_t, count int32) (retry int32) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil { interrupt := c.interrupt if interrupt == nil { @@ -452,16 +447,16 @@ func (c *Conn) Status(op DBStatus, reset bool) (current, highwater int, err erro hiPtr := c.arena.new(intlen) curPtr := c.arena.new(intlen) - var i uint64 + var i int32 if reset { i = 1 } - r := c.call("sqlite3_db_status", uint64(c.handle), - uint64(op), uint64(curPtr), uint64(hiPtr), i) - if err = c.error(r); err == nil { - current = int(util.ReadUint32(c.mod, curPtr)) - highwater = int(util.ReadUint32(c.mod, hiPtr)) + rc := res_t(c.call("sqlite3_db_status", stk_t(c.handle), + stk_t(op), stk_t(curPtr), stk_t(hiPtr), stk_t(i))) + if err = c.error(rc); err == nil { + current = int(util.Read32[int32](c.mod, curPtr)) + highwater = int(util.Read32[int32](c.mod, hiPtr)) } return } @@ -472,7 +467,7 @@ func (c *Conn) Status(op DBStatus, reset bool) (current, highwater int, err erro func (c *Conn) TableColumnMetadata(schema, table, column string) (declType, collSeq string, notNull, primaryKey, autoInc bool, err error) { defer c.arena.mark()() - var schemaPtr, columnPtr uint32 + var schemaPtr, columnPtr ptr_t declTypePtr := c.arena.new(ptrlen) collSeqPtr := c.arena.new(ptrlen) notNullPtr := c.arena.new(ptrlen) @@ -486,25 +481,25 @@ func (c *Conn) TableColumnMetadata(schema, table, column string) (declType, coll columnPtr = c.arena.string(column) } - r := c.call("sqlite3_table_column_metadata", uint64(c.handle), - uint64(schemaPtr), uint64(tablePtr), uint64(columnPtr), - uint64(declTypePtr), uint64(collSeqPtr), - uint64(notNullPtr), uint64(primaryKeyPtr), uint64(autoIncPtr)) - if err = c.error(r); err == nil && column != "" { - if ptr := util.ReadUint32(c.mod, declTypePtr); ptr != 0 { + rc := res_t(c.call("sqlite3_table_column_metadata", stk_t(c.handle), + stk_t(schemaPtr), stk_t(tablePtr), stk_t(columnPtr), + stk_t(declTypePtr), stk_t(collSeqPtr), + stk_t(notNullPtr), stk_t(primaryKeyPtr), stk_t(autoIncPtr))) + if err = c.error(rc); err == nil && column != "" { + if ptr := util.Read32[ptr_t](c.mod, declTypePtr); ptr != 0 { declType = util.ReadString(c.mod, ptr, _MAX_NAME) } - if ptr := util.ReadUint32(c.mod, collSeqPtr); ptr != 0 { + if ptr := util.Read32[ptr_t](c.mod, collSeqPtr); ptr != 0 { collSeq = util.ReadString(c.mod, ptr, _MAX_NAME) } - notNull = util.ReadUint32(c.mod, notNullPtr) != 0 - autoInc = util.ReadUint32(c.mod, autoIncPtr) != 0 - primaryKey = util.ReadUint32(c.mod, primaryKeyPtr) != 0 + notNull = util.Read32[uint32](c.mod, notNullPtr) != 0 + autoInc = util.Read32[uint32](c.mod, autoIncPtr) != 0 + primaryKey = util.Read32[uint32](c.mod, primaryKeyPtr) != 0 } return } -func (c *Conn) error(rc uint64, sql ...string) error { +func (c *Conn) error(rc res_t, sql ...string) error { return c.sqlite.error(rc, c.handle, sql...) } diff --git a/const.go b/const.go index d4908de0..086902a6 100644 --- a/const.go +++ b/const.go @@ -1,6 +1,10 @@ package sqlite3 -import "strconv" +import ( + "strconv" + + "github.com/ncruces/go-sqlite3/internal/util" +) const ( _OK = 0 /* Successful result */ @@ -12,8 +16,14 @@ const ( _MAX_SQL_LENGTH = 1e9 _MAX_FUNCTION_ARG = 100 - ptrlen = 4 - intlen = 4 + ptrlen = util.PtrLen + intlen = util.IntLen +) + +type ( + stk_t = util.Stk_t + ptr_t = util.Ptr_t + res_t = util.Res_t ) // ErrorCode is a result code that [Error.Code] might return. diff --git a/context.go b/context.go index 86be214e..637ddc28 100644 --- a/context.go +++ b/context.go @@ -15,7 +15,7 @@ import ( // https://sqlite.org/c3ref/context.html type Context struct { c *Conn - handle uint32 + handle ptr_t } // Conn returns the database connection of the @@ -32,14 +32,14 @@ func (ctx Context) Conn() *Conn { // https://sqlite.org/c3ref/get_auxdata.html func (ctx Context) SetAuxData(n int, data any) { ptr := util.AddHandle(ctx.c.ctx, data) - ctx.c.call("sqlite3_set_auxdata_go", uint64(ctx.handle), uint64(n), uint64(ptr)) + ctx.c.call("sqlite3_set_auxdata_go", stk_t(ctx.handle), stk_t(n), stk_t(ptr)) } // GetAuxData returns metadata for argument n of the function. // // https://sqlite.org/c3ref/get_auxdata.html func (ctx Context) GetAuxData(n int) any { - ptr := uint32(ctx.c.call("sqlite3_get_auxdata", uint64(ctx.handle), uint64(n))) + ptr := ptr_t(ctx.c.call("sqlite3_get_auxdata", stk_t(ctx.handle), stk_t(n))) return util.GetHandle(ctx.c.ctx, ptr) } @@ -68,7 +68,7 @@ func (ctx Context) ResultInt(value int) { // https://sqlite.org/c3ref/result_blob.html func (ctx Context) ResultInt64(value int64) { ctx.c.call("sqlite3_result_int64", - uint64(ctx.handle), uint64(value)) + stk_t(ctx.handle), stk_t(value)) } // ResultFloat sets the result of the function to a float64. @@ -76,7 +76,7 @@ func (ctx Context) ResultInt64(value int64) { // https://sqlite.org/c3ref/result_blob.html func (ctx Context) ResultFloat(value float64) { ctx.c.call("sqlite3_result_double", - uint64(ctx.handle), math.Float64bits(value)) + stk_t(ctx.handle), stk_t(math.Float64bits(value))) } // ResultText sets the result of the function to a string. @@ -85,7 +85,7 @@ func (ctx Context) ResultFloat(value float64) { func (ctx Context) ResultText(value string) { ptr := ctx.c.newString(value) ctx.c.call("sqlite3_result_text_go", - uint64(ctx.handle), uint64(ptr), uint64(len(value))) + stk_t(ctx.handle), stk_t(ptr), stk_t(len(value))) } // ResultRawText sets the text result of the function to a []byte. @@ -95,7 +95,7 @@ func (ctx Context) ResultText(value string) { func (ctx Context) ResultRawText(value []byte) { ptr := ctx.c.newBytes(value) ctx.c.call("sqlite3_result_text_go", - uint64(ctx.handle), uint64(ptr), uint64(len(value))) + stk_t(ctx.handle), stk_t(ptr), stk_t(len(value))) } // ResultBlob sets the result of the function to a []byte. @@ -105,7 +105,7 @@ func (ctx Context) ResultRawText(value []byte) { func (ctx Context) ResultBlob(value []byte) { ptr := ctx.c.newBytes(value) ctx.c.call("sqlite3_result_blob_go", - uint64(ctx.handle), uint64(ptr), uint64(len(value))) + stk_t(ctx.handle), stk_t(ptr), stk_t(len(value))) } // ResultZeroBlob sets the result of the function to a zero-filled, length n BLOB. @@ -113,7 +113,7 @@ func (ctx Context) ResultBlob(value []byte) { // https://sqlite.org/c3ref/result_blob.html func (ctx Context) ResultZeroBlob(n int64) { ctx.c.call("sqlite3_result_zeroblob64", - uint64(ctx.handle), uint64(n)) + stk_t(ctx.handle), stk_t(n)) } // ResultNull sets the result of the function to NULL. @@ -121,7 +121,7 @@ func (ctx Context) ResultZeroBlob(n int64) { // https://sqlite.org/c3ref/result_blob.html func (ctx Context) ResultNull() { ctx.c.call("sqlite3_result_null", - uint64(ctx.handle)) + stk_t(ctx.handle)) } // ResultTime sets the result of the function to a [time.Time]. @@ -146,14 +146,14 @@ func (ctx Context) ResultTime(value time.Time, format TimeFormat) { } func (ctx Context) resultRFC3339Nano(value time.Time) { - const maxlen = uint64(len(time.RFC3339Nano)) + 5 + const maxlen = int64(len(time.RFC3339Nano)) + 5 ptr := ctx.c.new(maxlen) buf := util.View(ctx.c.mod, ptr, maxlen) buf = value.AppendFormat(buf[:0], time.RFC3339Nano) ctx.c.call("sqlite3_result_text_go", - uint64(ctx.handle), uint64(ptr), uint64(len(buf))) + stk_t(ctx.handle), stk_t(ptr), stk_t(len(buf))) } // ResultPointer sets the result of the function to NULL, just like [Context.ResultNull], @@ -164,7 +164,7 @@ func (ctx Context) resultRFC3339Nano(value time.Time) { func (ctx Context) ResultPointer(ptr any) { valPtr := util.AddHandle(ctx.c.ctx, ptr) ctx.c.call("sqlite3_result_pointer_go", - uint64(ctx.handle), uint64(valPtr)) + stk_t(ctx.handle), stk_t(valPtr)) } // ResultJSON sets the result of the function to the JSON encoding of value. @@ -188,7 +188,7 @@ func (ctx Context) ResultValue(value Value) { return } ctx.c.call("sqlite3_result_value", - uint64(ctx.handle), uint64(value.handle)) + stk_t(ctx.handle), stk_t(value.handle)) } // ResultError sets the result of the function an error. @@ -196,12 +196,12 @@ func (ctx Context) ResultValue(value Value) { // https://sqlite.org/c3ref/result_blob.html func (ctx Context) ResultError(err error) { if errors.Is(err, NOMEM) { - ctx.c.call("sqlite3_result_error_nomem", uint64(ctx.handle)) + ctx.c.call("sqlite3_result_error_nomem", stk_t(ctx.handle)) return } if errors.Is(err, TOOBIG) { - ctx.c.call("sqlite3_result_error_toobig", uint64(ctx.handle)) + ctx.c.call("sqlite3_result_error_toobig", stk_t(ctx.handle)) return } @@ -210,11 +210,11 @@ func (ctx Context) ResultError(err error) { defer ctx.c.arena.mark()() ptr := ctx.c.arena.string(msg) ctx.c.call("sqlite3_result_error", - uint64(ctx.handle), uint64(ptr), uint64(len(msg))) + stk_t(ctx.handle), stk_t(ptr), stk_t(len(msg))) } if code != _OK { ctx.c.call("sqlite3_result_error_code", - uint64(ctx.handle), uint64(code)) + stk_t(ctx.handle), stk_t(code)) } } @@ -223,6 +223,6 @@ func (ctx Context) ResultError(err error) { // // https://sqlite.org/c3ref/vtab_nochange.html func (ctx Context) VTabNoChange() bool { - r := ctx.c.call("sqlite3_vtab_nochange", uint64(ctx.handle)) - return r != 0 + b := int32(ctx.c.call("sqlite3_vtab_nochange", stk_t(ctx.handle))) + return b != 0 } diff --git a/driver/driver.go b/driver/driver.go index 66b209eb..cc132afe 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -201,7 +201,7 @@ func (n *connector) Driver() driver.Driver { return &SQLite{} } -func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) { +func (n *connector) Connect(ctx context.Context) (ret driver.Conn, err error) { c := &conn{ txLock: n.txLock, tmRead: n.tmRead, @@ -213,7 +213,7 @@ func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) { return nil, err } defer func() { - if res == nil { + if ret == nil { c.Close() } }() diff --git a/error.go b/error.go index 870aa3ab..6d4bd63f 100644 --- a/error.go +++ b/error.go @@ -15,7 +15,7 @@ type Error struct { str string msg string sql string - code uint64 + code res_t } // Code returns the primary error code for this error. @@ -146,27 +146,27 @@ func (e ExtendedErrorCode) Code() ErrorCode { return ErrorCode(e) } -func errorCode(err error, def ErrorCode) (msg string, code uint32) { +func errorCode(err error, def ErrorCode) (msg string, code res_t) { switch code := err.(type) { case nil: return "", _OK case ErrorCode: - return "", uint32(code) + return "", res_t(code) case xErrorCode: - return "", uint32(code) + return "", res_t(code) case *Error: - return code.msg, uint32(code.code) + return code.msg, res_t(code.code) } var ecode ErrorCode var xcode xErrorCode switch { case errors.As(err, &xcode): - code = uint32(xcode) + code = res_t(xcode) case errors.As(err, &ecode): - code = uint32(ecode) + code = res_t(ecode) default: - code = uint32(def) + code = res_t(def) } return err.Error(), code } diff --git a/error_test.go b/error_test.go index 2204fa8b..2ec3f49a 100644 --- a/error_test.go +++ b/error_test.go @@ -59,14 +59,14 @@ func TestError_Temporary(t *testing.T) { tests := []struct { name string - code uint64 + code res_t want bool }{ - {"ERROR", uint64(ERROR), false}, - {"BUSY", uint64(BUSY), true}, - {"BUSY_RECOVERY", uint64(BUSY_RECOVERY), true}, - {"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), true}, - {"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true}, + {"ERROR", res_t(ERROR), false}, + {"BUSY", res_t(BUSY), true}, + {"BUSY_RECOVERY", res_t(BUSY_RECOVERY), true}, + {"BUSY_SNAPSHOT", res_t(BUSY_SNAPSHOT), true}, + {"BUSY_TIMEOUT", res_t(BUSY_TIMEOUT), true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -97,14 +97,14 @@ func TestError_Timeout(t *testing.T) { tests := []struct { name string - code uint64 + code res_t want bool }{ - {"ERROR", uint64(ERROR), false}, - {"BUSY", uint64(BUSY), false}, - {"BUSY_RECOVERY", uint64(BUSY_RECOVERY), false}, - {"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), false}, - {"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true}, + {"ERROR", res_t(ERROR), false}, + {"BUSY", res_t(BUSY), false}, + {"BUSY_RECOVERY", res_t(BUSY_RECOVERY), false}, + {"BUSY_SNAPSHOT", res_t(BUSY_SNAPSHOT), false}, + {"BUSY_TIMEOUT", res_t(BUSY_TIMEOUT), true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -136,8 +136,8 @@ func Test_ErrorCode_Error(t *testing.T) { // Test all error codes. for i := 0; i == int(ErrorCode(i)); i++ { want := "sqlite3: " - r := db.call("sqlite3_errstr", uint64(i)) - want += util.ReadString(db.mod, uint32(r), _MAX_NAME) + ptr := ptr_t(db.call("sqlite3_errstr", stk_t(i))) + want += util.ReadString(db.mod, ptr, _MAX_NAME) got := ErrorCode(i).Error() if got != want { @@ -158,8 +158,8 @@ func Test_ExtendedErrorCode_Error(t *testing.T) { // Test all extended error codes. for i := 0; i == int(ExtendedErrorCode(i)); i++ { want := "sqlite3: " - r := db.call("sqlite3_errstr", uint64(i)) - want += util.ReadString(db.mod, uint32(r), _MAX_NAME) + ptr := ptr_t(db.call("sqlite3_errstr", stk_t(i))) + want += util.ReadString(db.mod, ptr, _MAX_NAME) got := ExtendedErrorCode(i).Error() if got != want { @@ -172,7 +172,7 @@ func Test_errorCode(t *testing.T) { tests := []struct { arg error wantMsg string - wantCode uint32 + wantCode res_t }{ {nil, "", _OK}, {ERROR, "", util.ERROR}, @@ -190,7 +190,7 @@ func Test_errorCode(t *testing.T) { if gotMsg != tt.wantMsg { t.Errorf("errorCode() gotMsg = %q, want %q", gotMsg, tt.wantMsg) } - if gotCode != uint32(tt.wantCode) { + if gotCode != tt.wantCode { t.Errorf("errorCode() gotCode = %d, want %d", gotCode, tt.wantCode) } }) diff --git a/ext/README.md b/ext/README.md index 35028dc7..5760cb78 100644 --- a/ext/README.md +++ b/ext/README.md @@ -25,6 +25,8 @@ you can load into your database connections. creates [pivot tables](https://github.com/jakethaw/pivot_vtab). - [`github.com/ncruces/go-sqlite3/ext/regexp`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/regexp) provides regular expression functions. +- [`github.com/ncruces/go-sqlite3/ext/serdes`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/serdes) + (de)serializes databases. - [`github.com/ncruces/go-sqlite3/ext/statement`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/statement) creates [parameterized views](https://github.com/0x09/sqlite-statement-vtab). - [`github.com/ncruces/go-sqlite3/ext/stats`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats) diff --git a/ext/csv/csv.go b/ext/csv/csv.go index 7169d534..2ea48719 100644 --- a/ext/csv/csv.go +++ b/ext/csv/csv.go @@ -30,7 +30,7 @@ func Register(db *sqlite3.Conn) error { // RegisterFS registers the CSV virtual table. // If a filename is specified, fsys is used to open the file. func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error { - declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (res *table, err error) { + declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) { var ( filename string data string diff --git a/ext/pivot/pivot.go b/ext/pivot/pivot.go index eafb615c..b669325d 100644 --- a/ext/pivot/pivot.go +++ b/ext/pivot/pivot.go @@ -25,14 +25,14 @@ type table struct { cols []*sqlite3.Value } -func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (res *table, err error) { +func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (ret *table, err error) { if len(arg) != 3 { return nil, fmt.Errorf("pivot: wrong number of arguments") } t := &table{db: db} defer func() { - if res == nil { + if ret == nil { t.Close() } }() diff --git a/func.go b/func.go index c416e695..6b69368b 100644 --- a/func.go +++ b/func.go @@ -14,12 +14,12 @@ import ( // // https://sqlite.org/c3ref/collation_needed.html func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error { - var enable uint64 + var enable int32 if cb != nil { enable = 1 } - r := c.call("sqlite3_collation_needed_go", uint64(c.handle), enable) - if err := c.error(r); err != nil { + rc := res_t(c.call("sqlite3_collation_needed_go", stk_t(c.handle), stk_t(enable))) + if err := c.error(rc); err != nil { return err } c.collation = cb @@ -33,8 +33,8 @@ func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error { // This can be used to load schemas that contain // one or more unknown collating sequences. func (c Conn) AnyCollationNeeded() error { - r := c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0) - if err := c.error(r); err != nil { + rc := res_t(c.call("sqlite3_anycollseq_init", stk_t(c.handle), 0, 0)) + if err := c.error(rc); err != nil { return err } c.collation = nil @@ -45,31 +45,31 @@ func (c Conn) AnyCollationNeeded() error { // // https://sqlite.org/c3ref/create_collation.html func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { - var funcPtr uint32 + var funcPtr ptr_t defer c.arena.mark()() namePtr := c.arena.string(name) if fn != nil { funcPtr = util.AddHandle(c.ctx, fn) } - r := c.call("sqlite3_create_collation_go", - uint64(c.handle), uint64(namePtr), uint64(funcPtr)) - return c.error(r) + rc := res_t(c.call("sqlite3_create_collation_go", + stk_t(c.handle), stk_t(namePtr), stk_t(funcPtr))) + return c.error(rc) } // CreateFunction defines a new scalar SQL function. // // https://sqlite.org/c3ref/create_function.html func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn ScalarFunction) error { - var funcPtr uint32 + var funcPtr ptr_t defer c.arena.mark()() namePtr := c.arena.string(name) if fn != nil { funcPtr = util.AddHandle(c.ctx, fn) } - r := c.call("sqlite3_create_function_go", - uint64(c.handle), uint64(namePtr), uint64(nArg), - uint64(flag), uint64(funcPtr)) - return c.error(r) + rc := res_t(c.call("sqlite3_create_function_go", + stk_t(c.handle), stk_t(namePtr), stk_t(nArg), + stk_t(flag), stk_t(funcPtr))) + return c.error(rc) } // ScalarFunction is the type of a scalar SQL function. @@ -82,7 +82,7 @@ type ScalarFunction func(ctx Context, arg ...Value) // // https://sqlite.org/c3ref/create_function.html func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error { - var funcPtr uint32 + var funcPtr ptr_t defer c.arena.mark()() namePtr := c.arena.string(name) if fn != nil { @@ -92,10 +92,10 @@ func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn if _, ok := fn().(WindowFunction); ok { call = "sqlite3_create_window_function_go" } - r := c.call(call, - uint64(c.handle), uint64(namePtr), uint64(nArg), - uint64(flag), uint64(funcPtr)) - return c.error(r) + rc := res_t(c.call(call, + stk_t(c.handle), stk_t(namePtr), stk_t(nArg), + stk_t(flag), stk_t(funcPtr))) + return c.error(rc) } // AggregateFunction is the interface an aggregate function should implement. @@ -129,28 +129,28 @@ type WindowFunction interface { func (c *Conn) OverloadFunction(name string, nArg int) error { defer c.arena.mark()() namePtr := c.arena.string(name) - r := c.call("sqlite3_overload_function", - uint64(c.handle), uint64(namePtr), uint64(nArg)) - return c.error(r) + rc := res_t(c.call("sqlite3_overload_function", + stk_t(c.handle), stk_t(namePtr), stk_t(nArg))) + return c.error(rc) } -func destroyCallback(ctx context.Context, mod api.Module, pApp uint32) { +func destroyCallback(ctx context.Context, mod api.Module, pApp ptr_t) { util.DelHandle(ctx, pApp) } -func collationCallback(ctx context.Context, mod api.Module, pArg, pDB, eTextRep, zName uint32) { +func collationCallback(ctx context.Context, mod api.Module, pArg, pDB ptr_t, eTextRep uint32, zName ptr_t) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.collation != nil { name := util.ReadString(mod, zName, _MAX_NAME) c.collation(c, name) } } -func compareCallback(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 { +func compareCallback(ctx context.Context, mod api.Module, pApp ptr_t, nKey1 int32, pKey1 ptr_t, nKey2 int32, pKey2 ptr_t) uint32 { fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int) - return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2)))) + return uint32(fn(util.View(mod, pKey1, int64(nKey1)), util.View(mod, pKey2, int64(nKey2)))) } -func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp, nArg, pArg uint32) { +func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp ptr_t, nArg int32, pArg ptr_t) { args := getFuncArgs() defer putFuncArgs(args) db := ctx.Value(connKey{}).(*Conn) @@ -159,7 +159,7 @@ func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp, nArg, pArg ui fn(Context{db, pCtx}, args[:nArg]...) } -func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp, nArg, pArg uint32) { +func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, nArg int32, pArg ptr_t) { args := getFuncArgs() defer putFuncArgs(args) db := ctx.Value(connKey{}).(*Conn) @@ -168,7 +168,7 @@ func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp, nArg, p fn.Step(Context{db, pCtx}, args[:nArg]...) } -func finalCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp uint32) { +func finalCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t) { db := ctx.Value(connKey{}).(*Conn) fn, handle := callbackAggregate(db, pAgg, pApp) fn.Value(Context{db, pCtx}) @@ -178,13 +178,13 @@ func finalCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp uint32) } } -func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg uint32) { +func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t) { db := ctx.Value(connKey{}).(*Conn) fn := util.GetHandle(db.ctx, pAgg).(AggregateFunction) fn.Value(Context{db, pCtx}) } -func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg, nArg, pArg uint32) { +func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t, nArg int32, pArg ptr_t) { args := getFuncArgs() defer putFuncArgs(args) db := ctx.Value(connKey{}).(*Conn) @@ -193,9 +193,9 @@ func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg, nArg, pArg fn.Inverse(Context{db, pCtx}, args[:nArg]...) } -func callbackAggregate(db *Conn, pAgg, pApp uint32) (AggregateFunction, uint32) { +func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) { if pApp == 0 { - handle := util.ReadUint32(db.mod, pAgg) + handle := util.Read32[ptr_t](db.mod, pAgg) return util.GetHandle(db.ctx, handle).(AggregateFunction), handle } @@ -203,17 +203,17 @@ func callbackAggregate(db *Conn, pAgg, pApp uint32) (AggregateFunction, uint32) fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)() if pAgg != 0 { handle := util.AddHandle(db.ctx, fn) - util.WriteUint32(db.mod, pAgg, handle) + util.Write32(db.mod, pAgg, handle) return fn, handle } return fn, 0 } -func callbackArgs(db *Conn, arg []Value, pArg uint32) { +func callbackArgs(db *Conn, arg []Value, pArg ptr_t) { for i := range arg { arg[i] = Value{ c: db, - handle: util.ReadUint32(db.mod, pArg+ptrlen*uint32(i)), + handle: util.Read32[ptr_t](db.mod, pArg+ptr_t(i)*ptrlen), } } } diff --git a/internal/util/func.go b/internal/util/func.go index 468ff741..e705f318 100644 --- a/internal/util/func.go +++ b/internal/util/func.go @@ -7,9 +7,6 @@ import ( "github.com/tetratelabs/wazero/api" ) -type i32 interface{ ~int32 | ~uint32 } -type i64 interface{ ~int64 | ~uint64 } - type funcVI[T0 i32] func(context.Context, api.Module, T0) func (fn funcVI[T0]) Call(ctx context.Context, mod api.Module, stack []uint64) { diff --git a/internal/util/handle.go b/internal/util/handle.go index e4e33854..f9f39b44 100644 --- a/internal/util/handle.go +++ b/internal/util/handle.go @@ -20,7 +20,7 @@ func (s *handleState) CloseNotify(ctx context.Context, exitCode uint32) { s.holes = 0 } -func GetHandle(ctx context.Context, id uint32) any { +func GetHandle(ctx context.Context, id Ptr_t) any { if id == 0 { return nil } @@ -28,14 +28,14 @@ func GetHandle(ctx context.Context, id uint32) any { return s.handles[^id] } -func DelHandle(ctx context.Context, id uint32) error { +func DelHandle(ctx context.Context, id Ptr_t) error { if id == 0 { return nil } s := ctx.Value(moduleKey{}).(*moduleState) a := s.handles[^id] s.handles[^id] = nil - if l := uint32(len(s.handles)); l == ^id { + if l := Ptr_t(len(s.handles)); l == ^id { s.handles = s.handles[:l-1] } else { s.holes++ @@ -46,7 +46,7 @@ func DelHandle(ctx context.Context, id uint32) error { return nil } -func AddHandle(ctx context.Context, a any) uint32 { +func AddHandle(ctx context.Context, a any) Ptr_t { if a == nil { panic(NilErr) } @@ -59,12 +59,12 @@ func AddHandle(ctx context.Context, a any) uint32 { if h == nil { s.holes-- s.handles[id] = a - return ^uint32(id) + return ^Ptr_t(id) } } } // Add a new slot. s.handles = append(s.handles, a) - return -uint32(len(s.handles)) + return -Ptr_t(len(s.handles)) } diff --git a/internal/util/mem.go b/internal/util/mem.go index a09523fd..bfb1a644 100644 --- a/internal/util/mem.go +++ b/internal/util/mem.go @@ -7,110 +7,121 @@ import ( "github.com/tetratelabs/wazero/api" ) -func View(mod api.Module, ptr uint32, size uint64) []byte { +const ( + PtrLen = 4 + IntLen = 4 +) + +type ( + i8 interface{ ~int8 | ~uint8 } + i32 interface{ ~int32 | ~uint32 } + i64 interface{ ~int64 | ~uint64 } + + Stk_t = uint64 + Ptr_t uint32 + Res_t int32 +) + +func View(mod api.Module, ptr Ptr_t, size int64) []byte { if ptr == 0 { panic(NilErr) } - if size > math.MaxUint32 { - panic(RangeErr) - } if size == 0 { return nil } - buf, ok := mod.Memory().Read(ptr, uint32(size)) + if uint64(size) > math.MaxUint32 { + panic(RangeErr) + } + buf, ok := mod.Memory().Read(uint32(ptr), uint32(size)) if !ok { panic(RangeErr) } return buf } -func ReadUint8(mod api.Module, ptr uint32) uint8 { +func Read[T i8](mod api.Module, ptr Ptr_t) T { if ptr == 0 { panic(NilErr) } - v, ok := mod.Memory().ReadByte(ptr) + v, ok := mod.Memory().ReadByte(uint32(ptr)) if !ok { panic(RangeErr) } - return v + return T(v) } -func ReadUint32(mod api.Module, ptr uint32) uint32 { +func Write[T i8](mod api.Module, ptr Ptr_t, v T) { if ptr == 0 { panic(NilErr) } - v, ok := mod.Memory().ReadUint32Le(ptr) + ok := mod.Memory().WriteByte(uint32(ptr), uint8(v)) if !ok { panic(RangeErr) } - return v } -func WriteUint8(mod api.Module, ptr uint32, v uint8) { +func Read32[T i32](mod api.Module, ptr Ptr_t) T { if ptr == 0 { panic(NilErr) } - ok := mod.Memory().WriteByte(ptr, v) + v, ok := mod.Memory().ReadUint32Le(uint32(ptr)) if !ok { panic(RangeErr) } + return T(v) } -func WriteUint32(mod api.Module, ptr uint32, v uint32) { +func Write32[T i32](mod api.Module, ptr Ptr_t, v T) { if ptr == 0 { panic(NilErr) } - ok := mod.Memory().WriteUint32Le(ptr, v) + ok := mod.Memory().WriteUint32Le(uint32(ptr), uint32(v)) if !ok { panic(RangeErr) } } -func ReadUint64(mod api.Module, ptr uint32) uint64 { +func Read64[T i64](mod api.Module, ptr Ptr_t) T { if ptr == 0 { panic(NilErr) } - v, ok := mod.Memory().ReadUint64Le(ptr) + v, ok := mod.Memory().ReadUint64Le(uint32(ptr)) if !ok { panic(RangeErr) } - return v + return T(v) } -func WriteUint64(mod api.Module, ptr uint32, v uint64) { +func Write64[T i64](mod api.Module, ptr Ptr_t, v T) { if ptr == 0 { panic(NilErr) } - ok := mod.Memory().WriteUint64Le(ptr, v) + ok := mod.Memory().WriteUint64Le(uint32(ptr), uint64(v)) if !ok { panic(RangeErr) } } -func ReadFloat64(mod api.Module, ptr uint32) float64 { - return math.Float64frombits(ReadUint64(mod, ptr)) +func ReadFloat64(mod api.Module, ptr Ptr_t) float64 { + return math.Float64frombits(Read64[uint64](mod, ptr)) } -func WriteFloat64(mod api.Module, ptr uint32, v float64) { - WriteUint64(mod, ptr, math.Float64bits(v)) +func WriteFloat64(mod api.Module, ptr Ptr_t, v float64) { + Write64(mod, ptr, math.Float64bits(v)) } -func ReadString(mod api.Module, ptr, maxlen uint32) string { +func ReadString(mod api.Module, ptr Ptr_t, maxlen int64) string { if ptr == 0 { panic(NilErr) } - switch maxlen { - case 0: + if maxlen <= 0 { return "" - case math.MaxUint32: - // avoid overflow - default: - maxlen = maxlen + 1 } mem := mod.Memory() - buf, ok := mem.Read(ptr, maxlen) + maxlen = min(maxlen, math.MaxInt32-1) + 1 + buf, ok := mem.Read(uint32(ptr), uint32(maxlen)) if !ok { - buf, ok = mem.Read(ptr, mem.Size()-ptr) + buf, ok = mem.Read(uint32(ptr), mem.Size()-uint32(ptr)) if !ok { panic(RangeErr) } @@ -122,13 +133,13 @@ func ReadString(mod api.Module, ptr, maxlen uint32) string { } } -func WriteBytes(mod api.Module, ptr uint32, b []byte) { - buf := View(mod, ptr, uint64(len(b))) +func WriteBytes(mod api.Module, ptr Ptr_t, b []byte) { + buf := View(mod, ptr, int64(len(b))) copy(buf, b) } -func WriteString(mod api.Module, ptr uint32, s string) { - buf := View(mod, ptr, uint64(len(s)+1)) +func WriteString(mod api.Module, ptr Ptr_t, s string) { + buf := View(mod, ptr, int64(len(s))+1) buf[len(s)] = 0 copy(buf, s) } diff --git a/internal/util/mem_test.go b/internal/util/mem_test.go index 733ab344..28d28555 100644 --- a/internal/util/mem_test.go +++ b/internal/util/mem_test.go @@ -31,90 +31,90 @@ func TestView_overflow(t *testing.T) { func TestReadUint8_nil(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - ReadUint8(mock, 0) + Read[byte](mock, 0) t.Error("want panic") } func TestReadUint8_range(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - ReadUint8(mock, wazerotest.PageSize) + Read[byte](mock, wazerotest.PageSize) t.Error("want panic") } func TestReadUint32_nil(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - ReadUint32(mock, 0) + Read32[uint32](mock, 0) t.Error("want panic") } func TestReadUint32_range(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - ReadUint32(mock, wazerotest.PageSize-2) + Read32[uint32](mock, wazerotest.PageSize-2) t.Error("want panic") } func TestReadUint64_nil(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - ReadUint64(mock, 0) + Read64[uint64](mock, 0) t.Error("want panic") } func TestReadUint64_range(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - ReadUint64(mock, wazerotest.PageSize-2) + Read64[uint64](mock, wazerotest.PageSize-2) t.Error("want panic") } func TestWriteUint8_nil(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - WriteUint8(mock, 0, 1) + Write[byte](mock, 0, 1) t.Error("want panic") } func TestWriteUint8_range(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - WriteUint8(mock, wazerotest.PageSize, 1) + Write[byte](mock, wazerotest.PageSize, 1) t.Error("want panic") } func TestWriteUint32_nil(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - WriteUint32(mock, 0, 1) + Write32[uint32](mock, 0, 1) t.Error("want panic") } func TestWriteUint32_range(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - WriteUint32(mock, wazerotest.PageSize-2, 1) + Write32[uint32](mock, wazerotest.PageSize-2, 1) t.Error("want panic") } func TestWriteUint64_nil(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - WriteUint64(mock, 0, 1) + Write64[uint64](mock, 0, 1) t.Error("want panic") } func TestWriteUint64_range(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - WriteUint64(mock, wazerotest.PageSize-2, 1) + Write64[uint64](mock, wazerotest.PageSize-2, 1) t.Error("want panic") } func TestReadString_range(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - ReadString(mock, wazerotest.PageSize+2, math.MaxUint32) + ReadString(mock, wazerotest.PageSize+2, math.MaxInt) t.Error("want panic") } diff --git a/internal/util/mmap_unix.go b/internal/util/mmap_unix.go index 4ff05666..42a24752 100644 --- a/internal/util/mmap_unix.go +++ b/internal/util/mmap_unix.go @@ -25,9 +25,9 @@ func (s *mmapState) new(ctx context.Context, mod api.Module, size int32) *Mapped // Allocate page aligned memmory. alloc := mod.ExportedFunction("aligned_alloc") - stack := [...]uint64{ - uint64(unix.Getpagesize()), - uint64(size), + stack := [...]Stk_t{ + Stk_t(unix.Getpagesize()), + Stk_t(size), } if err := alloc.CallWithStack(ctx, stack[:]); err != nil { panic(err) @@ -37,20 +37,20 @@ func (s *mmapState) new(ctx context.Context, mod api.Module, size int32) *Mapped } // Save the newly allocated region. - ptr := uint32(stack[0]) - buf := View(mod, ptr, uint64(size)) - res := &MappedRegion{ + ptr := Ptr_t(stack[0]) + buf := View(mod, ptr, int64(size)) + ret := &MappedRegion{ Ptr: ptr, size: size, addr: unsafe.Pointer(&buf[0]), } - s.regions = append(s.regions, res) - return res + s.regions = append(s.regions, ret) + return ret } type MappedRegion struct { addr unsafe.Pointer - Ptr uint32 + Ptr Ptr_t size int32 used bool } diff --git a/internal/util/mmap_windows.go b/internal/util/mmap_windows.go index efff1e73..913a5f73 100644 --- a/internal/util/mmap_windows.go +++ b/internal/util/mmap_windows.go @@ -29,13 +29,13 @@ func MapRegion(ctx context.Context, mod api.Module, f *os.File, offset int64, si return nil, err } - res := &MappedRegion{Handle: h, addr: a} + ret := &MappedRegion{Handle: h, addr: a} // SliceHeader, although deprecated, avoids a go vet warning. - sh := (*reflect.SliceHeader)(unsafe.Pointer(&res.Data)) + sh := (*reflect.SliceHeader)(unsafe.Pointer(&ret.Data)) sh.Len = int(size) sh.Cap = int(size) sh.Data = a - return res, nil + return ret, nil } func (r *MappedRegion) Unmap() error { diff --git a/sqlite.go b/sqlite.go index 90352d07..8203603e 100644 --- a/sqlite.go +++ b/sqlite.go @@ -3,7 +3,6 @@ package sqlite3 import ( "context" - "math" "math/bits" "os" "sync" @@ -94,7 +93,7 @@ type sqlite struct { id [32]*byte mask uint32 } - stack [9]uint64 + stack [9]stk_t } func instantiateSQLite() (sqlt *sqlite, err error) { @@ -120,7 +119,7 @@ func (sqlt *sqlite) close() error { return sqlt.mod.Close(sqlt.ctx) } -func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error { +func (sqlt *sqlite) error(rc res_t, handle ptr_t, sql ...string) error { if rc == _OK { return nil } @@ -131,18 +130,18 @@ func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error { panic(util.OOMErr) } - if r := sqlt.call("sqlite3_errstr", rc); r != 0 { - err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_NAME) + if ptr := ptr_t(sqlt.call("sqlite3_errstr", stk_t(rc))); ptr != 0 { + err.str = util.ReadString(sqlt.mod, ptr, _MAX_NAME) } if handle != 0 { - if r := sqlt.call("sqlite3_errmsg", uint64(handle)); r != 0 { - err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_LENGTH) + if ptr := ptr_t(sqlt.call("sqlite3_errmsg", stk_t(handle))); ptr != 0 { + err.msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH) } if len(sql) != 0 { - if r := sqlt.call("sqlite3_error_offset", uint64(handle)); r != math.MaxUint32 { - err.sql = sql[0][r:] + if i := int32(sqlt.call("sqlite3_error_offset", stk_t(handle))); i != -1 { + err.sql = sql[0][i:] } } } @@ -182,7 +181,7 @@ func (sqlt *sqlite) putfn(name string, fn api.Function) { } } -func (sqlt *sqlite) call(name string, params ...uint64) uint64 { +func (sqlt *sqlite) call(name string, params ...stk_t) stk_t { copy(sqlt.stack[:], params) fn := sqlt.getfn(name) err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:]) @@ -190,33 +189,33 @@ func (sqlt *sqlite) call(name string, params ...uint64) uint64 { panic(err) } sqlt.putfn(name, fn) - return sqlt.stack[0] + return stk_t(sqlt.stack[0]) } -func (sqlt *sqlite) free(ptr uint32) { +func (sqlt *sqlite) free(ptr ptr_t) { if ptr == 0 { return } - sqlt.call("sqlite3_free", uint64(ptr)) + sqlt.call("sqlite3_free", stk_t(ptr)) } -func (sqlt *sqlite) new(size uint64) uint32 { - ptr := uint32(sqlt.call("sqlite3_malloc64", size)) +func (sqlt *sqlite) new(size int64) ptr_t { + ptr := ptr_t(sqlt.call("sqlite3_malloc64", stk_t(size))) if ptr == 0 && size != 0 { panic(util.OOMErr) } return ptr } -func (sqlt *sqlite) realloc(ptr uint32, size uint64) uint32 { - ptr = uint32(sqlt.call("sqlite3_realloc64", uint64(ptr), size)) +func (sqlt *sqlite) realloc(ptr ptr_t, size int64) ptr_t { + ptr = ptr_t(sqlt.call("sqlite3_realloc64", stk_t(ptr), stk_t(size))) if ptr == 0 && size != 0 { panic(util.OOMErr) } return ptr } -func (sqlt *sqlite) newBytes(b []byte) uint32 { +func (sqlt *sqlite) newBytes(b []byte) ptr_t { if (*[0]byte)(b) == nil { return 0 } @@ -224,33 +223,31 @@ func (sqlt *sqlite) newBytes(b []byte) uint32 { if size == 0 { size = 1 } - ptr := sqlt.new(uint64(size)) + ptr := sqlt.new(int64(size)) util.WriteBytes(sqlt.mod, ptr, b) return ptr } -func (sqlt *sqlite) newString(s string) uint32 { - ptr := sqlt.new(uint64(len(s) + 1)) +func (sqlt *sqlite) newString(s string) ptr_t { + ptr := sqlt.new(int64(len(s)) + 1) util.WriteString(sqlt.mod, ptr, s) return ptr } -func (sqlt *sqlite) newArena(size uint64) arena { - // Ensure the arena's size is a multiple of 8. - size = (size + 7) &^ 7 +const arenaSize = 4096 + +func (sqlt *sqlite) newArena() arena { return arena{ sqlt: sqlt, - size: uint32(size), - base: sqlt.new(size), + base: sqlt.new(arenaSize), } } type arena struct { sqlt *sqlite - ptrs []uint32 - base uint32 - next uint32 - size uint32 + ptrs []ptr_t + base ptr_t + next int32 } func (a *arena) free() { @@ -277,34 +274,34 @@ func (a *arena) mark() (reset func()) { } } -func (a *arena) new(size uint64) uint32 { +func (a *arena) new(size int64) ptr_t { // Align the next address, to 4 or 8 bytes. if size&7 != 0 { a.next = (a.next + 3) &^ 3 } else { a.next = (a.next + 7) &^ 7 } - if size <= uint64(a.size-a.next) { - ptr := a.base + a.next - a.next += uint32(size) - return ptr + if size <= arenaSize-int64(a.next) { + ptr := a.base + ptr_t(a.next) + a.next += int32(size) + return ptr_t(ptr) } ptr := a.sqlt.new(size) a.ptrs = append(a.ptrs, ptr) - return ptr + return ptr_t(ptr) } -func (a *arena) bytes(b []byte) uint32 { +func (a *arena) bytes(b []byte) ptr_t { if (*[0]byte)(b) == nil { return 0 } - ptr := a.new(uint64(len(b))) + ptr := a.new(int64(len(b))) util.WriteBytes(a.sqlt.mod, ptr, b) return ptr } -func (a *arena) string(s string) uint32 { - ptr := a.new(uint64(len(s) + 1)) +func (a *arena) string(s string) ptr_t { + ptr := a.new(int64(len(s)) + 1) util.WriteString(a.sqlt.mod, ptr, s) return ptr } diff --git a/sqlite_test.go b/sqlite_test.go index fbcd069b..5b969dea 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -22,7 +22,7 @@ func Test_sqlite_error_OOM(t *testing.T) { defer sqlite.close() defer func() { _ = recover() }() - sqlite.error(uint64(NOMEM), 0) + sqlite.error(res_t(NOMEM), 0) t.Error("want panic") } @@ -65,7 +65,7 @@ func Test_sqlite_newArena(t *testing.T) { } defer sqlite.close() - arena := sqlite.newArena(16) + arena := sqlite.newArena() defer arena.free() const title = "Lorem ipsum" @@ -73,7 +73,7 @@ func Test_sqlite_newArena(t *testing.T) { if ptr == 0 { t.Fatalf("got nullptr") } - if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != title { + if got := util.ReadString(sqlite.mod, ptr, math.MaxInt); got != title { t.Errorf("got %q, want %q", got, title) } @@ -82,7 +82,7 @@ func Test_sqlite_newArena(t *testing.T) { if ptr == 0 { t.Fatalf("got nullptr") } - if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != body { + if got := util.ReadString(sqlite.mod, ptr, math.MaxInt); got != body { t.Errorf("got %q, want %q", got, body) } @@ -94,7 +94,7 @@ func Test_sqlite_newArena(t *testing.T) { if ptr == 0 { t.Fatalf("got nullptr") } - if got := util.View(sqlite.mod, ptr, uint64(len(title))); string(got) != title { + if got := util.View(sqlite.mod, ptr, int64(len(title))); string(got) != title { t.Errorf("got %q, want %q", got, title) } @@ -122,7 +122,7 @@ func Test_sqlite_newBytes(t *testing.T) { } want := buf - if got := util.View(sqlite.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) { + if got := util.View(sqlite.mod, ptr, int64(len(want))); !bytes.Equal(got, want) { t.Errorf("got %q, want %q", got, want) } @@ -157,7 +157,7 @@ func Test_sqlite_newString(t *testing.T) { } want := str + "\000" - if got := util.View(sqlite.mod, ptr, uint64(len(want))); string(got) != want { + if got := util.View(sqlite.mod, ptr, int64(len(want))); string(got) != want { t.Errorf("got %q, want %q", got, want) } } @@ -183,7 +183,7 @@ func Test_sqlite_getString(t *testing.T) { } want := "sqlite3" - if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != want { + if got := util.ReadString(sqlite.mod, ptr, math.MaxInt); got != want { t.Errorf("got %q, want %q", got, want) } if got := util.ReadString(sqlite.mod, ptr, 0); got != "" { @@ -192,13 +192,13 @@ func Test_sqlite_getString(t *testing.T) { func() { defer func() { _ = recover() }() - util.ReadString(sqlite.mod, ptr, uint32(len(want)/2)) + util.ReadString(sqlite.mod, ptr, int64(len(want))/2) t.Error("want panic") }() func() { defer func() { _ = recover() }() - util.ReadString(sqlite.mod, 0, math.MaxUint32) + util.ReadString(sqlite.mod, 0, math.MaxInt) t.Error("want panic") }() } diff --git a/stmt.go b/stmt.go index fdb13dcf..4e17d103 100644 --- a/stmt.go +++ b/stmt.go @@ -16,7 +16,7 @@ type Stmt struct { c *Conn err error sql string - handle uint32 + handle ptr_t } // Close destroys the prepared statement object. @@ -29,7 +29,7 @@ func (s *Stmt) Close() error { return nil } - r := s.c.call("sqlite3_finalize", uint64(s.handle)) + rc := res_t(s.c.call("sqlite3_finalize", stk_t(s.handle))) stmts := s.c.stmts for i := range stmts { if s == stmts[i] { @@ -42,7 +42,7 @@ func (s *Stmt) Close() error { } s.handle = 0 - return s.c.error(r) + return s.c.error(rc) } // Conn returns the database connection to which the prepared statement belongs. @@ -64,9 +64,9 @@ func (s *Stmt) SQL() string { // // https://sqlite.org/c3ref/expanded_sql.html func (s *Stmt) ExpandedSQL() string { - r := s.c.call("sqlite3_expanded_sql", uint64(s.handle)) - sql := util.ReadString(s.c.mod, uint32(r), _MAX_SQL_LENGTH) - s.c.free(uint32(r)) + ptr := ptr_t(s.c.call("sqlite3_expanded_sql", stk_t(s.handle))) + sql := util.ReadString(s.c.mod, ptr, _MAX_SQL_LENGTH) + s.c.free(ptr) return sql } @@ -75,25 +75,25 @@ func (s *Stmt) ExpandedSQL() string { // // https://sqlite.org/c3ref/stmt_readonly.html func (s *Stmt) ReadOnly() bool { - r := s.c.call("sqlite3_stmt_readonly", uint64(s.handle)) - return r != 0 + b := int32(s.c.call("sqlite3_stmt_readonly", stk_t(s.handle))) + return b != 0 } // Reset resets the prepared statement object. // // https://sqlite.org/c3ref/reset.html func (s *Stmt) Reset() error { - r := s.c.call("sqlite3_reset", uint64(s.handle)) + rc := res_t(s.c.call("sqlite3_reset", stk_t(s.handle))) s.err = nil - return s.c.error(r) + return s.c.error(rc) } // Busy determines if a prepared statement has been reset. // // https://sqlite.org/c3ref/stmt_busy.html func (s *Stmt) Busy() bool { - r := s.c.call("sqlite3_stmt_busy", uint64(s.handle)) - return r != 0 + rc := res_t(s.c.call("sqlite3_stmt_busy", stk_t(s.handle))) + return rc != 0 } // Step evaluates the SQL statement. @@ -107,15 +107,15 @@ func (s *Stmt) Busy() bool { // https://sqlite.org/c3ref/step.html func (s *Stmt) Step() bool { s.c.checkInterrupt(s.c.handle) - r := s.c.call("sqlite3_step", uint64(s.handle)) - switch r { + rc := res_t(s.c.call("sqlite3_step", stk_t(s.handle))) + switch rc { case _ROW: s.err = nil return true case _DONE: s.err = nil default: - s.err = s.c.error(r) + s.err = s.c.error(rc) } return false } @@ -143,30 +143,30 @@ func (s *Stmt) Status(op StmtStatus, reset bool) int { if op > STMTSTATUS_FILTER_HIT && op != STMTSTATUS_MEMUSED { return 0 } - var i uint64 + var i int32 if reset { i = 1 } - r := s.c.call("sqlite3_stmt_status", uint64(s.handle), - uint64(op), i) - return int(int32(r)) + n := int32(s.c.call("sqlite3_stmt_status", stk_t(s.handle), + stk_t(op), stk_t(i))) + return int(n) } // ClearBindings resets all bindings on the prepared statement. // // https://sqlite.org/c3ref/clear_bindings.html func (s *Stmt) ClearBindings() error { - r := s.c.call("sqlite3_clear_bindings", uint64(s.handle)) - return s.c.error(r) + rc := res_t(s.c.call("sqlite3_clear_bindings", stk_t(s.handle))) + return s.c.error(rc) } // BindCount returns the number of SQL parameters in the prepared statement. // // https://sqlite.org/c3ref/bind_parameter_count.html func (s *Stmt) BindCount() int { - r := s.c.call("sqlite3_bind_parameter_count", - uint64(s.handle)) - return int(int32(r)) + n := int32(s.c.call("sqlite3_bind_parameter_count", + stk_t(s.handle))) + return int(n) } // BindIndex returns the index of a parameter in the prepared statement @@ -176,9 +176,9 @@ func (s *Stmt) BindCount() int { func (s *Stmt) BindIndex(name string) int { defer s.c.arena.mark()() namePtr := s.c.arena.string(name) - r := s.c.call("sqlite3_bind_parameter_index", - uint64(s.handle), uint64(namePtr)) - return int(int32(r)) + i := int32(s.c.call("sqlite3_bind_parameter_index", + stk_t(s.handle), stk_t(namePtr))) + return int(i) } // BindName returns the name of a parameter in the prepared statement. @@ -186,10 +186,8 @@ func (s *Stmt) BindIndex(name string) int { // // https://sqlite.org/c3ref/bind_parameter_name.html func (s *Stmt) BindName(param int) string { - r := s.c.call("sqlite3_bind_parameter_name", - uint64(s.handle), uint64(param)) - - ptr := uint32(r) + ptr := ptr_t(s.c.call("sqlite3_bind_parameter_name", + stk_t(s.handle), stk_t(param))) if ptr == 0 { return "" } @@ -223,9 +221,9 @@ func (s *Stmt) BindInt(param int, value int) error { // // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindInt64(param int, value int64) error { - r := s.c.call("sqlite3_bind_int64", - uint64(s.handle), uint64(param), uint64(value)) - return s.c.error(r) + rc := res_t(s.c.call("sqlite3_bind_int64", + stk_t(s.handle), stk_t(param), stk_t(value))) + return s.c.error(rc) } // BindFloat binds a float64 to the prepared statement. @@ -233,9 +231,10 @@ func (s *Stmt) BindInt64(param int, value int64) error { // // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindFloat(param int, value float64) error { - r := s.c.call("sqlite3_bind_double", - uint64(s.handle), uint64(param), math.Float64bits(value)) - return s.c.error(r) + rc := res_t(s.c.call("sqlite3_bind_double", + stk_t(s.handle), stk_t(param), + stk_t(math.Float64bits(value)))) + return s.c.error(rc) } // BindText binds a string to the prepared statement. @@ -247,10 +246,10 @@ func (s *Stmt) BindText(param int, value string) error { return TOOBIG } ptr := s.c.newString(value) - r := s.c.call("sqlite3_bind_text_go", - uint64(s.handle), uint64(param), - uint64(ptr), uint64(len(value))) - return s.c.error(r) + rc := res_t(s.c.call("sqlite3_bind_text_go", + stk_t(s.handle), stk_t(param), + stk_t(ptr), stk_t(len(value)))) + return s.c.error(rc) } // BindRawText binds a []byte to the prepared statement as text. @@ -263,10 +262,10 @@ func (s *Stmt) BindRawText(param int, value []byte) error { return TOOBIG } ptr := s.c.newBytes(value) - r := s.c.call("sqlite3_bind_text_go", - uint64(s.handle), uint64(param), - uint64(ptr), uint64(len(value))) - return s.c.error(r) + rc := res_t(s.c.call("sqlite3_bind_text_go", + stk_t(s.handle), stk_t(param), + stk_t(ptr), stk_t(len(value)))) + return s.c.error(rc) } // BindBlob binds a []byte to the prepared statement. @@ -279,10 +278,10 @@ func (s *Stmt) BindBlob(param int, value []byte) error { return TOOBIG } ptr := s.c.newBytes(value) - r := s.c.call("sqlite3_bind_blob_go", - uint64(s.handle), uint64(param), - uint64(ptr), uint64(len(value))) - return s.c.error(r) + rc := res_t(s.c.call("sqlite3_bind_blob_go", + stk_t(s.handle), stk_t(param), + stk_t(ptr), stk_t(len(value)))) + return s.c.error(rc) } // BindZeroBlob binds a zero-filled, length n BLOB to the prepared statement. @@ -290,9 +289,9 @@ func (s *Stmt) BindBlob(param int, value []byte) error { // // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindZeroBlob(param int, n int64) error { - r := s.c.call("sqlite3_bind_zeroblob64", - uint64(s.handle), uint64(param), uint64(n)) - return s.c.error(r) + rc := res_t(s.c.call("sqlite3_bind_zeroblob64", + stk_t(s.handle), stk_t(param), stk_t(n))) + return s.c.error(rc) } // BindNull binds a NULL to the prepared statement. @@ -300,9 +299,9 @@ func (s *Stmt) BindZeroBlob(param int, n int64) error { // // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindNull(param int) error { - r := s.c.call("sqlite3_bind_null", - uint64(s.handle), uint64(param)) - return s.c.error(r) + rc := res_t(s.c.call("sqlite3_bind_null", + stk_t(s.handle), stk_t(param))) + return s.c.error(rc) } // BindTime binds a [time.Time] to the prepared statement. @@ -327,16 +326,16 @@ func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error { } func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error { - const maxlen = uint64(len(time.RFC3339Nano)) + 5 + const maxlen = int64(len(time.RFC3339Nano)) + 5 ptr := s.c.new(maxlen) buf := util.View(s.c.mod, ptr, maxlen) buf = value.AppendFormat(buf[:0], time.RFC3339Nano) - r := s.c.call("sqlite3_bind_text_go", - uint64(s.handle), uint64(param), - uint64(ptr), uint64(len(buf))) - return s.c.error(r) + rc := res_t(s.c.call("sqlite3_bind_text_go", + stk_t(s.handle), stk_t(param), + stk_t(ptr), stk_t(len(buf)))) + return s.c.error(rc) } // BindPointer binds a NULL to the prepared statement, just like [Stmt.BindNull], @@ -347,9 +346,9 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error { // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindPointer(param int, ptr any) error { valPtr := util.AddHandle(s.c.ctx, ptr) - r := s.c.call("sqlite3_bind_pointer_go", - uint64(s.handle), uint64(param), uint64(valPtr)) - return s.c.error(r) + rc := res_t(s.c.call("sqlite3_bind_pointer_go", + stk_t(s.handle), stk_t(param), stk_t(valPtr))) + return s.c.error(rc) } // BindJSON binds the JSON encoding of value to the prepared statement. @@ -372,27 +371,27 @@ func (s *Stmt) BindValue(param int, value Value) error { if value.c != s.c { return MISUSE } - r := s.c.call("sqlite3_bind_value", - uint64(s.handle), uint64(param), uint64(value.handle)) - return s.c.error(r) + rc := res_t(s.c.call("sqlite3_bind_value", + stk_t(s.handle), stk_t(param), stk_t(value.handle))) + return s.c.error(rc) } // DataCount resets the number of columns in a result set. // // https://sqlite.org/c3ref/data_count.html func (s *Stmt) DataCount() int { - r := s.c.call("sqlite3_data_count", - uint64(s.handle)) - return int(int32(r)) + n := int32(s.c.call("sqlite3_data_count", + stk_t(s.handle))) + return int(n) } // ColumnCount returns the number of columns in a result set. // // https://sqlite.org/c3ref/column_count.html func (s *Stmt) ColumnCount() int { - r := s.c.call("sqlite3_column_count", - uint64(s.handle)) - return int(int32(r)) + n := int32(s.c.call("sqlite3_column_count", + stk_t(s.handle))) + return int(n) } // ColumnName returns the name of the result column. @@ -400,12 +399,12 @@ func (s *Stmt) ColumnCount() int { // // https://sqlite.org/c3ref/column_name.html func (s *Stmt) ColumnName(col int) string { - r := s.c.call("sqlite3_column_name", - uint64(s.handle), uint64(col)) - if r == 0 { + ptr := ptr_t(s.c.call("sqlite3_column_name", + stk_t(s.handle), stk_t(col))) + if ptr == 0 { panic(util.OOMErr) } - return util.ReadString(s.c.mod, uint32(r), _MAX_NAME) + return util.ReadString(s.c.mod, ptr, _MAX_NAME) } // ColumnType returns the initial [Datatype] of the result column. @@ -413,9 +412,8 @@ func (s *Stmt) ColumnName(col int) string { // // https://sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnType(col int) Datatype { - r := s.c.call("sqlite3_column_type", - uint64(s.handle), uint64(col)) - return Datatype(r) + return Datatype(s.c.call("sqlite3_column_type", + stk_t(s.handle), stk_t(col))) } // ColumnDeclType returns the declared datatype of the result column. @@ -423,12 +421,12 @@ func (s *Stmt) ColumnType(col int) Datatype { // // https://sqlite.org/c3ref/column_decltype.html func (s *Stmt) ColumnDeclType(col int) string { - r := s.c.call("sqlite3_column_decltype", - uint64(s.handle), uint64(col)) - if r == 0 { + ptr := ptr_t(s.c.call("sqlite3_column_decltype", + stk_t(s.handle), stk_t(col))) + if ptr == 0 { return "" } - return util.ReadString(s.c.mod, uint32(r), _MAX_NAME) + return util.ReadString(s.c.mod, ptr, _MAX_NAME) } // ColumnDatabaseName returns the name of the database @@ -437,12 +435,12 @@ func (s *Stmt) ColumnDeclType(col int) string { // // https://sqlite.org/c3ref/column_database_name.html func (s *Stmt) ColumnDatabaseName(col int) string { - r := s.c.call("sqlite3_column_database_name", - uint64(s.handle), uint64(col)) - if r == 0 { + ptr := ptr_t(s.c.call("sqlite3_column_database_name", + stk_t(s.handle), stk_t(col))) + if ptr == 0 { return "" } - return util.ReadString(s.c.mod, uint32(r), _MAX_NAME) + return util.ReadString(s.c.mod, ptr, _MAX_NAME) } // ColumnTableName returns the name of the table @@ -451,12 +449,12 @@ func (s *Stmt) ColumnDatabaseName(col int) string { // // https://sqlite.org/c3ref/column_database_name.html func (s *Stmt) ColumnTableName(col int) string { - r := s.c.call("sqlite3_column_table_name", - uint64(s.handle), uint64(col)) - if r == 0 { + ptr := ptr_t(s.c.call("sqlite3_column_table_name", + stk_t(s.handle), stk_t(col))) + if ptr == 0 { return "" } - return util.ReadString(s.c.mod, uint32(r), _MAX_NAME) + return util.ReadString(s.c.mod, ptr, _MAX_NAME) } // ColumnOriginName returns the name of the table column @@ -465,12 +463,12 @@ func (s *Stmt) ColumnTableName(col int) string { // // https://sqlite.org/c3ref/column_database_name.html func (s *Stmt) ColumnOriginName(col int) string { - r := s.c.call("sqlite3_column_origin_name", - uint64(s.handle), uint64(col)) - if r == 0 { + ptr := ptr_t(s.c.call("sqlite3_column_origin_name", + stk_t(s.handle), stk_t(col))) + if ptr == 0 { return "" } - return util.ReadString(s.c.mod, uint32(r), _MAX_NAME) + return util.ReadString(s.c.mod, ptr, _MAX_NAME) } // ColumnBool returns the value of the result column as a bool. @@ -497,9 +495,8 @@ func (s *Stmt) ColumnInt(col int) int { // // https://sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnInt64(col int) int64 { - r := s.c.call("sqlite3_column_int64", - uint64(s.handle), uint64(col)) - return int64(r) + return int64(s.c.call("sqlite3_column_int64", + stk_t(s.handle), stk_t(col))) } // ColumnFloat returns the value of the result column as a float64. @@ -507,9 +504,9 @@ func (s *Stmt) ColumnInt64(col int) int64 { // // https://sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnFloat(col int) float64 { - r := s.c.call("sqlite3_column_double", - uint64(s.handle), uint64(col)) - return math.Float64frombits(r) + f := uint64(s.c.call("sqlite3_column_double", + stk_t(s.handle), stk_t(col))) + return math.Float64frombits(f) } // ColumnTime returns the value of the result column as a [time.Time]. @@ -561,9 +558,9 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte { // // https://sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnRawText(col int) []byte { - r := s.c.call("sqlite3_column_text", - uint64(s.handle), uint64(col)) - return s.columnRawBytes(col, uint32(r)) + ptr := ptr_t(s.c.call("sqlite3_column_text", + stk_t(s.handle), stk_t(col))) + return s.columnRawBytes(col, ptr) } // ColumnRawBlob returns the value of the result column as a []byte. @@ -573,23 +570,23 @@ func (s *Stmt) ColumnRawText(col int) []byte { // // https://sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnRawBlob(col int) []byte { - r := s.c.call("sqlite3_column_blob", - uint64(s.handle), uint64(col)) - return s.columnRawBytes(col, uint32(r)) + ptr := ptr_t(s.c.call("sqlite3_column_blob", + stk_t(s.handle), stk_t(col))) + return s.columnRawBytes(col, ptr) } -func (s *Stmt) columnRawBytes(col int, ptr uint32) []byte { +func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte { if ptr == 0 { - r := s.c.call("sqlite3_errcode", uint64(s.c.handle)) - if r != _ROW && r != _DONE { - s.err = s.c.error(r) + rc := res_t(s.c.call("sqlite3_errcode", stk_t(s.c.handle))) + if rc != _ROW && rc != _DONE { + s.err = s.c.error(rc) } return nil } - r := s.c.call("sqlite3_column_bytes", - uint64(s.handle), uint64(col)) - return util.View(s.c.mod, ptr, r) + n := int32(s.c.call("sqlite3_column_bytes", + stk_t(s.handle), stk_t(col))) + return util.View(s.c.mod, ptr, int64(n)) } // ColumnJSON parses the JSON-encoded value of the result column @@ -621,12 +618,12 @@ func (s *Stmt) ColumnJSON(col int, ptr any) error { // // https://sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnValue(col int) Value { - r := s.c.call("sqlite3_column_value", - uint64(s.handle), uint64(col)) + ptr := ptr_t(s.c.call("sqlite3_column_value", + stk_t(s.handle), stk_t(col))) return Value{ c: s.c, unprot: true, - handle: uint32(r), + handle: ptr, } } @@ -640,13 +637,13 @@ func (s *Stmt) ColumnValue(col int) Value { // subsequent calls to [Stmt] methods. func (s *Stmt) Columns(dest ...any) error { defer s.c.arena.mark()() - count := uint64(len(dest)) + count := int64(len(dest)) typePtr := s.c.arena.new(count) dataPtr := s.c.arena.new(count * 8) - r := s.c.call("sqlite3_columns_go", - uint64(s.handle), count, uint64(typePtr), uint64(dataPtr)) - if err := s.c.error(r); err != nil { + rc := res_t(s.c.call("sqlite3_columns_go", + stk_t(s.handle), stk_t(count), stk_t(typePtr), stk_t(dataPtr))) + if err := s.c.error(rc); err != nil { return err } @@ -660,19 +657,19 @@ func (s *Stmt) Columns(dest ...any) error { for i := range dest { switch types[i] { case byte(INTEGER): - dest[i] = int64(util.ReadUint64(s.c.mod, dataPtr)) + dest[i] = util.Read64[int64](s.c.mod, dataPtr) case byte(FLOAT): dest[i] = util.ReadFloat64(s.c.mod, dataPtr) case byte(NULL): dest[i] = nil default: - ptr := util.ReadUint32(s.c.mod, dataPtr+0) + ptr := util.Read32[ptr_t](s.c.mod, dataPtr+0) if ptr == 0 { dest[i] = []byte{} continue } - len := util.ReadUint32(s.c.mod, dataPtr+4) - buf := util.View(s.c.mod, ptr, uint64(len)) + len := util.Read32[int32](s.c.mod, dataPtr+4) + buf := util.View(s.c.mod, ptr, int64(len)) if types[i] == byte(TEXT) { dest[i] = string(buf) } else { diff --git a/txn.go b/txn.go index 57ba979a..b24789f8 100644 --- a/txn.go +++ b/txn.go @@ -229,13 +229,12 @@ func (c *Conn) txnExecInterrupted(sql string) error { // // https://sqlite.org/c3ref/txn_state.html func (c *Conn) TxnState(schema string) TxnState { - var ptr uint32 + var ptr ptr_t if schema != "" { defer c.arena.mark()() ptr = c.arena.string(schema) } - r := c.call("sqlite3_txn_state", uint64(c.handle), uint64(ptr)) - return TxnState(r) + return TxnState(c.call("sqlite3_txn_state", stk_t(c.handle), stk_t(ptr))) } // CommitHook registers a callback function to be invoked @@ -244,11 +243,11 @@ func (c *Conn) TxnState(schema string) TxnState { // // https://sqlite.org/c3ref/commit_hook.html func (c *Conn) CommitHook(cb func() (ok bool)) { - var enable uint64 + var enable int32 if cb != nil { enable = 1 } - c.call("sqlite3_commit_hook_go", uint64(c.handle), enable) + c.call("sqlite3_commit_hook_go", stk_t(c.handle), stk_t(enable)) c.commit = cb } @@ -257,11 +256,11 @@ func (c *Conn) CommitHook(cb func() (ok bool)) { // // https://sqlite.org/c3ref/commit_hook.html func (c *Conn) RollbackHook(cb func()) { - var enable uint64 + var enable int32 if cb != nil { enable = 1 } - c.call("sqlite3_rollback_hook_go", uint64(c.handle), enable) + c.call("sqlite3_rollback_hook_go", stk_t(c.handle), stk_t(enable)) c.rollback = cb } @@ -270,15 +269,15 @@ func (c *Conn) RollbackHook(cb func()) { // // https://sqlite.org/c3ref/update_hook.html func (c *Conn) UpdateHook(cb func(action AuthorizerActionCode, schema, table string, rowid int64)) { - var enable uint64 + var enable int32 if cb != nil { enable = 1 } - c.call("sqlite3_update_hook_go", uint64(c.handle), enable) + c.call("sqlite3_update_hook_go", stk_t(c.handle), stk_t(enable)) c.update = cb } -func commitCallback(ctx context.Context, mod api.Module, pDB uint32) (rollback uint32) { +func commitCallback(ctx context.Context, mod api.Module, pDB ptr_t) (rollback int32) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.commit != nil { if !c.commit() { rollback = 1 @@ -287,17 +286,17 @@ func commitCallback(ctx context.Context, mod api.Module, pDB uint32) (rollback u return rollback } -func rollbackCallback(ctx context.Context, mod api.Module, pDB uint32) { +func rollbackCallback(ctx context.Context, mod api.Module, pDB ptr_t) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.rollback != nil { c.rollback() } } -func updateCallback(ctx context.Context, mod api.Module, pDB uint32, action AuthorizerActionCode, zSchema, zTabName uint32, rowid uint64) { +func updateCallback(ctx context.Context, mod api.Module, pDB ptr_t, action AuthorizerActionCode, zSchema, zTabName ptr_t, rowid int64) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.update != nil { schema := util.ReadString(mod, zSchema, _MAX_NAME) table := util.ReadString(mod, zTabName, _MAX_NAME) - c.update(action, schema, table, int64(rowid)) + c.update(action, schema, table, rowid) } } @@ -305,6 +304,6 @@ func updateCallback(ctx context.Context, mod api.Module, pDB uint32, action Auth // // https://sqlite.org/c3ref/db_cacheflush.html func (c *Conn) CacheFlush() error { - r := c.call("sqlite3_db_cacheflush", uint64(c.handle)) - return c.error(r) + rc := res_t(c.call("sqlite3_db_cacheflush", stk_t(c.handle))) + return c.error(rc) } diff --git a/util/sql3util/parse.go b/util/sql3util/parse.go index f84fc4dd..7dd76ceb 100644 --- a/util/sql3util/parse.go +++ b/util/sql3util/parse.go @@ -50,7 +50,7 @@ func ParseTable(sql string) (_ *Table, err error) { copy(buf, sql) } - stack := [...]uint64{sqlp, uint64(len(sql)), errp} + stack := [...]util.Stk_t{sqlp, util.Stk_t(len(sql)), errp} err = mod.ExportedFunction("sql3parse_table").CallWithStack(ctx, stack[:]) if err != nil { return nil, err @@ -96,9 +96,9 @@ func (t *Table) load(mod api.Module, ptr uint32, sql string) { t.IsWithoutRowID = loadBool(mod, ptr+26) t.IsStrict = loadBool(mod, ptr+27) - t.Columns = loadSlice(mod, ptr+28, func(ptr uint32, res *Column) { + t.Columns = loadSlice(mod, ptr+28, func(ptr uint32, ret *Column) { p, _ := mod.Memory().ReadUint32Le(ptr) - res.load(mod, p, sql) + ret.load(mod, p, sql) }) t.Type = loadEnum[StatementType](mod, ptr+44) @@ -166,8 +166,8 @@ type ForeignKey struct { func (f *ForeignKey) load(mod api.Module, ptr uint32, sql string) { f.Table = loadString(mod, ptr+0, sql) - f.Columns = loadSlice(mod, ptr+8, func(ptr uint32, res *string) { - *res = loadString(mod, ptr, sql) + f.Columns = loadSlice(mod, ptr+8, func(ptr uint32, ret *string) { + *ret = loadString(mod, ptr, sql) }) f.OnDelete = loadEnum[FKAction](mod, ptr+16) @@ -191,12 +191,12 @@ func loadSlice[T any](mod api.Module, ptr uint32, fn func(uint32, *T)) []T { return nil } len, _ := mod.Memory().ReadUint32Le(ptr + 0) - res := make([]T, len) - for i := range res { - fn(ref, &res[i]) + ret := make([]T, len) + for i := range ret { + fn(ref, &ret[i]) ref += 4 } - return res + return ret } func loadEnum[T ~uint32](mod api.Module, ptr uint32) T { diff --git a/value.go b/value.go index 43b1a0f1..a2399fba 100644 --- a/value.go +++ b/value.go @@ -14,27 +14,27 @@ import ( // https://sqlite.org/c3ref/value.html type Value struct { c *Conn - handle uint32 + handle ptr_t unprot bool copied bool } -func (v Value) protected() uint64 { +func (v Value) protected() stk_t { if v.unprot { panic(util.ValueErr) } - return uint64(v.handle) + return stk_t(v.handle) } // Dup makes a copy of the SQL value and returns a pointer to that copy. // // https://sqlite.org/c3ref/value_dup.html func (v Value) Dup() *Value { - r := v.c.call("sqlite3_value_dup", uint64(v.handle)) + ptr := ptr_t(v.c.call("sqlite3_value_dup", stk_t(v.handle))) return &Value{ c: v.c, copied: true, - handle: uint32(r), + handle: ptr, } } @@ -45,7 +45,7 @@ func (dup *Value) Close() error { if !dup.copied { panic(util.ValueErr) } - dup.c.call("sqlite3_value_free", uint64(dup.handle)) + dup.c.call("sqlite3_value_free", stk_t(dup.handle)) dup.handle = 0 return nil } @@ -54,16 +54,14 @@ func (dup *Value) Close() error { // // https://sqlite.org/c3ref/value_blob.html func (v Value) Type() Datatype { - r := v.c.call("sqlite3_value_type", v.protected()) - return Datatype(r) + return Datatype(v.c.call("sqlite3_value_type", v.protected())) } // Type returns the numeric datatype of the value. // // https://sqlite.org/c3ref/value_blob.html func (v Value) NumericType() Datatype { - r := v.c.call("sqlite3_value_numeric_type", v.protected()) - return Datatype(r) + return Datatype(v.c.call("sqlite3_value_numeric_type", v.protected())) } // Bool returns the value as a bool. @@ -87,16 +85,15 @@ func (v Value) Int() int { // // https://sqlite.org/c3ref/value_blob.html func (v Value) Int64() int64 { - r := v.c.call("sqlite3_value_int64", v.protected()) - return int64(r) + return int64(v.c.call("sqlite3_value_int64", v.protected())) } // Float returns the value as a float64. // // https://sqlite.org/c3ref/value_blob.html func (v Value) Float() float64 { - r := v.c.call("sqlite3_value_double", v.protected()) - return math.Float64frombits(r) + f := uint64(v.c.call("sqlite3_value_double", v.protected())) + return math.Float64frombits(f) } // Time returns the value as a [time.Time]. @@ -141,8 +138,8 @@ func (v Value) Blob(buf []byte) []byte { // // https://sqlite.org/c3ref/value_blob.html func (v Value) RawText() []byte { - r := v.c.call("sqlite3_value_text", v.protected()) - return v.rawBytes(uint32(r)) + ptr := ptr_t(v.c.call("sqlite3_value_text", v.protected())) + return v.rawBytes(ptr) } // RawBlob returns the value as a []byte. @@ -151,24 +148,24 @@ func (v Value) RawText() []byte { // // https://sqlite.org/c3ref/value_blob.html func (v Value) RawBlob() []byte { - r := v.c.call("sqlite3_value_blob", v.protected()) - return v.rawBytes(uint32(r)) + ptr := ptr_t(v.c.call("sqlite3_value_blob", v.protected())) + return v.rawBytes(ptr) } -func (v Value) rawBytes(ptr uint32) []byte { +func (v Value) rawBytes(ptr ptr_t) []byte { if ptr == 0 { return nil } - r := v.c.call("sqlite3_value_bytes", v.protected()) - return util.View(v.c.mod, ptr, r) + n := int32(v.c.call("sqlite3_value_bytes", v.protected())) + return util.View(v.c.mod, ptr, int64(n)) } // Pointer gets the pointer associated with this value, // or nil if it has no associated pointer. func (v Value) Pointer() any { - r := v.c.call("sqlite3_value_pointer_go", v.protected()) - return util.GetHandle(v.c.ctx, uint32(r)) + ptr := ptr_t(v.c.call("sqlite3_value_pointer_go", v.protected())) + return util.GetHandle(v.c.ctx, ptr) } // JSON parses a JSON-encoded value @@ -197,16 +194,16 @@ func (v Value) JSON(ptr any) error { // // https://sqlite.org/c3ref/value_blob.html func (v Value) NoChange() bool { - r := v.c.call("sqlite3_value_nochange", v.protected()) - return r != 0 + b := int32(v.c.call("sqlite3_value_nochange", v.protected())) + return b != 0 } // FromBind returns true if value originated from a bound parameter. // // https://sqlite.org/c3ref/value_blob.html func (v Value) FromBind() bool { - r := v.c.call("sqlite3_value_frombind", v.protected()) - return r != 0 + b := int32(v.c.call("sqlite3_value_frombind", v.protected())) + return b != 0 } // InFirst returns the first element @@ -216,13 +213,13 @@ func (v Value) FromBind() bool { func (v Value) InFirst() (Value, error) { defer v.c.arena.mark()() valPtr := v.c.arena.new(ptrlen) - r := v.c.call("sqlite3_vtab_in_first", uint64(v.handle), uint64(valPtr)) - if err := v.c.error(r); err != nil { + rc := res_t(v.c.call("sqlite3_vtab_in_first", stk_t(v.handle), stk_t(valPtr))) + if err := v.c.error(rc); err != nil { return Value{}, err } return Value{ c: v.c, - handle: util.ReadUint32(v.c.mod, valPtr), + handle: util.Read32[ptr_t](v.c.mod, valPtr), }, nil } @@ -233,12 +230,12 @@ func (v Value) InFirst() (Value, error) { func (v Value) InNext() (Value, error) { defer v.c.arena.mark()() valPtr := v.c.arena.new(ptrlen) - r := v.c.call("sqlite3_vtab_in_next", uint64(v.handle), uint64(valPtr)) - if err := v.c.error(r); err != nil { + rc := res_t(v.c.call("sqlite3_vtab_in_next", stk_t(v.handle), stk_t(valPtr))) + if err := v.c.error(rc); err != nil { return Value{}, err } return Value{ c: v.c, - handle: util.ReadUint32(v.c.mod, valPtr), + handle: util.Read32[ptr_t](v.c.mod, valPtr), }, nil } diff --git a/vfs/api.go b/vfs/api.go index f2531f22..d5bb3a7a 100644 --- a/vfs/api.go +++ b/vfs/api.go @@ -193,7 +193,7 @@ type FileSharedMemory interface { // SharedMemory is a shared-memory WAL-index implementation. // Use [NewSharedMemory] to create a shared-memory. type SharedMemory interface { - shmMap(context.Context, api.Module, int32, int32, bool) (uint32, _ErrorCode) + shmMap(context.Context, api.Module, int32, int32, bool) (ptr_t, _ErrorCode) shmLock(int32, int32, _ShmFlag) _ErrorCode shmUnmap(bool) shmBarrier() @@ -207,7 +207,7 @@ type blockingSharedMemory interface { type fileControl interface { File - fileControl(ctx context.Context, mod api.Module, op _FcntlOpcode, pArg uint32) _ErrorCode + fileControl(ctx context.Context, mod api.Module, op _FcntlOpcode, pArg ptr_t) _ErrorCode } type filePDB interface { diff --git a/vfs/cksm.go b/vfs/cksm.go index 42d7468f..39493df9 100644 --- a/vfs/cksm.go +++ b/vfs/cksm.go @@ -102,14 +102,14 @@ func (c cksmFile) Pragma(name string, value string) (string, error) { } func (c cksmFile) DeviceCharacteristics() DeviceCharacteristic { - res := c.File.DeviceCharacteristics() + ret := c.File.DeviceCharacteristics() if c.verifyCksm { - res &^= IOCAP_SUBPAGE_READ + ret &^= IOCAP_SUBPAGE_READ } - return res + return ret } -func (c cksmFile) fileControl(ctx context.Context, mod api.Module, op _FcntlOpcode, pArg uint32) _ErrorCode { +func (c cksmFile) fileControl(ctx context.Context, mod api.Module, op _FcntlOpcode, pArg ptr_t) _ErrorCode { switch op { case _FCNTL_CKPT_START: c.inCkpt = true diff --git a/vfs/const.go b/vfs/const.go index 1c9b77a7..dc3b0db8 100644 --- a/vfs/const.go +++ b/vfs/const.go @@ -8,7 +8,12 @@ const ( _MAX_PATHNAME = 1024 _DEFAULT_SECTOR_SIZE = 4096 - ptrlen = 4 + ptrlen = util.PtrLen +) + +type ( + stk_t = util.Stk_t + ptr_t = util.Ptr_t ) // https://sqlite.org/rescode.html diff --git a/vfs/file.go b/vfs/file.go index e028a2a5..bc90555e 100644 --- a/vfs/file.go +++ b/vfs/file.go @@ -186,14 +186,14 @@ func (f *vfsFile) SectorSize() int { } func (f *vfsFile) DeviceCharacteristics() DeviceCharacteristic { - res := IOCAP_SUBPAGE_READ + ret := IOCAP_SUBPAGE_READ if osBatchAtomic(f.File) { - res |= IOCAP_BATCH_ATOMIC + ret |= IOCAP_BATCH_ATOMIC } if f.psow { - res |= IOCAP_POWERSAFE_OVERWRITE + ret |= IOCAP_POWERSAFE_OVERWRITE } - return res + return ret } func (f *vfsFile) SizeHint(size int64) error { diff --git a/vfs/filename.go b/vfs/filename.go index d9a29cd4..965c3b1a 100644 --- a/vfs/filename.go +++ b/vfs/filename.go @@ -16,13 +16,13 @@ import ( type Filename struct { ctx context.Context mod api.Module - zPath uint32 + zPath ptr_t flags OpenFlag - stack [2]uint64 + stack [2]stk_t } // GetFilename is an internal API users should not call directly. -func GetFilename(ctx context.Context, mod api.Module, id uint32, flags OpenFlag) *Filename { +func GetFilename(ctx context.Context, mod api.Module, id ptr_t, flags OpenFlag) *Filename { if id == 0 { return nil } @@ -71,12 +71,12 @@ func (n *Filename) path(method string) string { return "" } - n.stack[0] = uint64(n.zPath) + n.stack[0] = stk_t(n.zPath) fn := n.mod.ExportedFunction(method) if err := fn.CallWithStack(n.ctx, n.stack[:]); err != nil { panic(err) } - return util.ReadString(n.mod, uint32(n.stack[0]), _MAX_PATHNAME) + return util.ReadString(n.mod, ptr_t(n.stack[0]), _MAX_PATHNAME) } // DatabaseFile returns the main database [File] corresponding to a journal. @@ -90,12 +90,12 @@ func (n *Filename) DatabaseFile() File { return nil } - n.stack[0] = uint64(n.zPath) + n.stack[0] = stk_t(n.zPath) fn := n.mod.ExportedFunction("sqlite3_database_file_object") if err := fn.CallWithStack(n.ctx, n.stack[:]); err != nil { panic(err) } - file, _ := vfsFileGet(n.ctx, n.mod, uint32(n.stack[0])).(File) + file, _ := vfsFileGet(n.ctx, n.mod, ptr_t(n.stack[0])).(File) return file } @@ -108,13 +108,13 @@ func (n *Filename) URIParameter(key string) string { } uriKey := n.mod.ExportedFunction("sqlite3_uri_key") - n.stack[0] = uint64(n.zPath) - n.stack[1] = uint64(0) + n.stack[0] = stk_t(n.zPath) + n.stack[1] = stk_t(0) if err := uriKey.CallWithStack(n.ctx, n.stack[:]); err != nil { panic(err) } - ptr := uint32(n.stack[0]) + ptr := ptr_t(n.stack[0]) if ptr == 0 { return "" } @@ -127,13 +127,13 @@ func (n *Filename) URIParameter(key string) string { if k == "" { return "" } - ptr += uint32(len(k)) + 1 + ptr += ptr_t(len(k)) + 1 v := util.ReadString(n.mod, ptr, _MAX_NAME) if k == key { return v } - ptr += uint32(len(v)) + 1 + ptr += ptr_t(len(v)) + 1 } } @@ -146,13 +146,13 @@ func (n *Filename) URIParameters() url.Values { } uriKey := n.mod.ExportedFunction("sqlite3_uri_key") - n.stack[0] = uint64(n.zPath) - n.stack[1] = uint64(0) + n.stack[0] = stk_t(n.zPath) + n.stack[1] = stk_t(0) if err := uriKey.CallWithStack(n.ctx, n.stack[:]); err != nil { panic(err) } - ptr := uint32(n.stack[0]) + ptr := ptr_t(n.stack[0]) if ptr == 0 { return nil } @@ -167,13 +167,13 @@ func (n *Filename) URIParameters() url.Values { if k == "" { return params } - ptr += uint32(len(k)) + 1 + ptr += ptr_t(len(k)) + 1 v := util.ReadString(n.mod, ptr, _MAX_NAME) if params == nil { params = url.Values{} } params.Add(k, v) - ptr += uint32(len(v)) + 1 + ptr += ptr_t(len(v)) + 1 } } diff --git a/vfs/lock_test.go b/vfs/lock_test.go index 7c19bfd7..e8dfe845 100644 --- a/vfs/lock_test.go +++ b/vfs/lock_test.go @@ -47,21 +47,21 @@ func Test_vfsLock(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != 0 { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_NONE { t.Error("file was locked") } rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != 0 { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_NONE { t.Error("file was locked") } rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_NONE) { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_NONE { t.Error("invalid lock state", got) } @@ -74,21 +74,21 @@ func Test_vfsLock(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != 0 { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_NONE { t.Error("file was locked") } rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != 0 { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_NONE { t.Error("file was locked") } rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_SHARED) { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_SHARED { t.Error("invalid lock state", got) } @@ -105,21 +105,21 @@ func Test_vfsLock(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got == 0 { + if got := util.Read32[LockLevel](mod, pOutput); got == LOCK_NONE { t.Log("file wasn't locked, locking is incompatible with SQLite") } rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got == 0 { + if got := util.Read32[LockLevel](mod, pOutput); got == LOCK_NONE { t.Error("file wasn't locked") } rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_RESERVED) { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_RESERVED { t.Error("invalid lock state", got) } @@ -132,21 +132,21 @@ func Test_vfsLock(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got == 0 { + if got := util.Read32[LockLevel](mod, pOutput); got == LOCK_NONE { t.Log("file wasn't locked, locking is incompatible with SQLite") } rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got == 0 { + if got := util.Read32[LockLevel](mod, pOutput); got == LOCK_NONE { t.Error("file wasn't locked") } rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_EXCLUSIVE) { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_EXCLUSIVE { t.Error("invalid lock state", got) } @@ -159,21 +159,21 @@ func Test_vfsLock(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got == 0 { + if got := util.Read32[LockLevel](mod, pOutput); got == LOCK_NONE { t.Log("file wasn't locked, locking is incompatible with SQLite") } rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got == 0 { + if got := util.Read32[LockLevel](mod, pOutput); got == LOCK_NONE { t.Error("file wasn't locked") } rc = vfsFileControl(ctx, mod, pFile1, _FCNTL_LOCKSTATE, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_NONE) { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_NONE { t.Error("invalid lock state", got) } @@ -186,14 +186,14 @@ func Test_vfsLock(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != 0 { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_NONE { t.Error("file was locked") } rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != 0 { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_NONE { t.Error("file was locked") } @@ -205,7 +205,7 @@ func Test_vfsLock(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_SHARED) { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_SHARED { t.Error("invalid lock state", got) } } diff --git a/vfs/memdb/memdb.go b/vfs/memdb/memdb.go index 4adb2dde..419fd1c6 100644 --- a/vfs/memdb/memdb.go +++ b/vfs/memdb/memdb.go @@ -10,9 +10,11 @@ import ( "github.com/ncruces/go-sqlite3/vfs" ) -// Must be a multiple of 64K (the largest page size). const sectorSize = 65536 +// Ensure sectorSize is a multiple of 64K (the largest page size). +var _ [0]struct{} = [sectorSize & 65535]struct{}{} + type memVFS struct{} func (memVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) { diff --git a/vfs/shm_bsd.go b/vfs/shm_bsd.go index 76e6888e..11e7bb2f 100644 --- a/vfs/shm_bsd.go +++ b/vfs/shm_bsd.go @@ -142,7 +142,7 @@ func (s *vfsShm) shmOpen() _ErrorCode { return _OK } -func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (uint32, _ErrorCode) { +func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (ptr_t, _ErrorCode) { // Ensure size is a multiple of the OS page size. if int(size)&(unix.Getpagesize()-1) != 0 { return 0, _IOERR_SHMMAP diff --git a/vfs/shm_dotlk.go b/vfs/shm_dotlk.go index 842bea8f..cb697a9c 100644 --- a/vfs/shm_dotlk.go +++ b/vfs/shm_dotlk.go @@ -35,8 +35,8 @@ type vfsShm struct { free api.Function path string shadow [][_WALINDEX_PGSZ]byte - ptrs []uint32 - stack [1]uint64 + ptrs []ptr_t + stack [1]stk_t lock [_SHM_NLOCK]bool } @@ -96,7 +96,7 @@ func (s *vfsShm) shmOpen() _ErrorCode { return _OK } -func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (uint32, _ErrorCode) { +func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (ptr_t, _ErrorCode) { if size != _WALINDEX_PGSZ { return 0, _IOERR_SHMMAP } @@ -128,15 +128,15 @@ func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, ext // Allocate local memory. for int(id) >= len(s.ptrs) { - s.stack[0] = uint64(size) + s.stack[0] = stk_t(size) if err := s.alloc.CallWithStack(ctx, s.stack[:]); err != nil { panic(err) } if s.stack[0] == 0 { panic(util.OOMErr) } - clear(util.View(s.mod, uint32(s.stack[0]), _WALINDEX_PGSZ)) - s.ptrs = append(s.ptrs, uint32(s.stack[0])) + clear(util.View(s.mod, ptr_t(s.stack[0]), _WALINDEX_PGSZ)) + s.ptrs = append(s.ptrs, ptr_t(s.stack[0])) } s.shadow[0][4] = 1 @@ -168,7 +168,7 @@ func (s *vfsShm) shmUnmap(delete bool) { defer s.Unlock() for _, p := range s.ptrs { - s.stack[0] = uint64(p) + s.stack[0] = stk_t(p) if err := s.free.CallWithStack(context.Background(), s.stack[:]); err != nil { panic(err) } diff --git a/vfs/shm_ofd.go b/vfs/shm_ofd.go index dd361119..b0f50fcb 100644 --- a/vfs/shm_ofd.go +++ b/vfs/shm_ofd.go @@ -73,7 +73,7 @@ func (s *vfsShm) shmOpen() _ErrorCode { return rc } -func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (uint32, _ErrorCode) { +func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (ptr_t, _ErrorCode) { // Ensure size is a multiple of the OS page size. if int(size)&(unix.Getpagesize()-1) != 0 { return 0, _IOERR_SHMMAP diff --git a/vfs/shm_windows.go b/vfs/shm_windows.go index 1de57640..ed2e93f8 100644 --- a/vfs/shm_windows.go +++ b/vfs/shm_windows.go @@ -26,8 +26,8 @@ type vfsShm struct { regions []*util.MappedRegion shared [][]byte shadow [][_WALINDEX_PGSZ]byte - ptrs []uint32 - stack [1]uint64 + ptrs []ptr_t + stack [1]stk_t fileLock bool blocking bool sync.Mutex @@ -72,7 +72,7 @@ func (s *vfsShm) shmOpen() _ErrorCode { return rc } -func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (_ uint32, rc _ErrorCode) { +func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (_ ptr_t, rc _ErrorCode) { // Ensure size is a multiple of the OS page size. if size != _WALINDEX_PGSZ || (windows.Getpagesize()-1)&_WALINDEX_PGSZ != 0 { return 0, _IOERR_SHMMAP @@ -119,15 +119,15 @@ func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, ext // Allocate local memory. for int(id) >= len(s.ptrs) { - s.stack[0] = uint64(size) + s.stack[0] = stk_t(size) if err := s.alloc.CallWithStack(ctx, s.stack[:]); err != nil { panic(err) } if s.stack[0] == 0 { panic(util.OOMErr) } - clear(util.View(s.mod, uint32(s.stack[0]), _WALINDEX_PGSZ)) - s.ptrs = append(s.ptrs, uint32(s.stack[0])) + clear(util.View(s.mod, ptr_t(s.stack[0]), _WALINDEX_PGSZ)) + s.ptrs = append(s.ptrs, ptr_t(s.stack[0])) } s.shadow[0][4] = 1 @@ -168,7 +168,7 @@ func (s *vfsShm) shmUnmap(delete bool) { // Free local memory. for _, p := range s.ptrs { - s.stack[0] = uint64(p) + s.stack[0] = stk_t(p) if err := s.free.CallWithStack(context.Background(), s.stack[:]); err != nil { panic(err) } diff --git a/vfs/vfs.go b/vfs/vfs.go index d8816e40..ca105fff 100644 --- a/vfs/vfs.go +++ b/vfs/vfs.go @@ -49,7 +49,7 @@ func ExportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder return env } -func vfsFind(ctx context.Context, mod api.Module, zVfsName uint32) uint32 { +func vfsFind(ctx context.Context, mod api.Module, zVfsName ptr_t) uint32 { name := util.ReadString(mod, zVfsName, _MAX_NAME) if vfs := Find(name); vfs != nil && vfs != (vfsOS{}) { return 1 @@ -57,46 +57,46 @@ func vfsFind(ctx context.Context, mod api.Module, zVfsName uint32) uint32 { return 0 } -func vfsLocaltime(ctx context.Context, mod api.Module, pTm uint32, t int64) _ErrorCode { +func vfsLocaltime(ctx context.Context, mod api.Module, pTm ptr_t, t int64) _ErrorCode { tm := time.Unix(t, 0) - var isdst int + var isdst int32 if tm.IsDST() { isdst = 1 } const size = 32 / 8 // https://pubs.opengroup.org/onlinepubs/7908799/xsh/time.h.html - util.WriteUint32(mod, pTm+0*size, uint32(tm.Second())) - util.WriteUint32(mod, pTm+1*size, uint32(tm.Minute())) - util.WriteUint32(mod, pTm+2*size, uint32(tm.Hour())) - util.WriteUint32(mod, pTm+3*size, uint32(tm.Day())) - util.WriteUint32(mod, pTm+4*size, uint32(tm.Month()-time.January)) - util.WriteUint32(mod, pTm+5*size, uint32(tm.Year()-1900)) - util.WriteUint32(mod, pTm+6*size, uint32(tm.Weekday()-time.Sunday)) - util.WriteUint32(mod, pTm+7*size, uint32(tm.YearDay()-1)) - util.WriteUint32(mod, pTm+8*size, uint32(isdst)) + util.Write32(mod, pTm+0*size, int32(tm.Second())) + util.Write32(mod, pTm+1*size, int32(tm.Minute())) + util.Write32(mod, pTm+2*size, int32(tm.Hour())) + util.Write32(mod, pTm+3*size, int32(tm.Day())) + util.Write32(mod, pTm+4*size, int32(tm.Month()-time.January)) + util.Write32(mod, pTm+5*size, int32(tm.Year()-1900)) + util.Write32(mod, pTm+6*size, int32(tm.Weekday()-time.Sunday)) + util.Write32(mod, pTm+7*size, int32(tm.YearDay()-1)) + util.Write32(mod, pTm+8*size, isdst) return _OK } -func vfsRandomness(ctx context.Context, mod api.Module, pVfs uint32, nByte int32, zByte uint32) uint32 { - mem := util.View(mod, zByte, uint64(nByte)) +func vfsRandomness(ctx context.Context, mod api.Module, pVfs ptr_t, nByte int32, zByte ptr_t) uint32 { + mem := util.View(mod, zByte, int64(nByte)) n, _ := rand.Reader.Read(mem) return uint32(n) } -func vfsSleep(ctx context.Context, mod api.Module, pVfs uint32, nMicro int32) _ErrorCode { +func vfsSleep(ctx context.Context, mod api.Module, pVfs ptr_t, nMicro int32) _ErrorCode { time.Sleep(time.Duration(nMicro) * time.Microsecond) return _OK } -func vfsCurrentTime64(ctx context.Context, mod api.Module, pVfs, piNow uint32) _ErrorCode { +func vfsCurrentTime64(ctx context.Context, mod api.Module, pVfs, piNow ptr_t) _ErrorCode { day, nsec := julianday.Date(time.Now()) msec := day*86_400_000 + nsec/1_000_000 - util.WriteUint64(mod, piNow, uint64(msec)) + util.Write64(mod, piNow, msec) return _OK } -func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative uint32, nFull int32, zFull uint32) _ErrorCode { +func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative ptr_t, nFull int32, zFull ptr_t) _ErrorCode { vfs := vfsGet(mod, pVfs) path := util.ReadString(mod, zRelative, _MAX_PATHNAME) @@ -110,7 +110,7 @@ func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative uint32 return vfsErrorCode(err, _CANTOPEN_FULLPATH) } -func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32) _ErrorCode { +func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath ptr_t, syncDir int32) _ErrorCode { vfs := vfsGet(mod, pVfs) path := util.ReadString(mod, zPath, _MAX_PATHNAME) @@ -118,21 +118,20 @@ func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32) return vfsErrorCode(err, _IOERR_DELETE) } -func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags AccessFlag, pResOut uint32) _ErrorCode { +func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath ptr_t, flags AccessFlag, pResOut ptr_t) _ErrorCode { vfs := vfsGet(mod, pVfs) path := util.ReadString(mod, zPath, _MAX_PATHNAME) ok, err := vfs.Access(path, flags) - var res uint32 + var res int32 if ok { res = 1 } - - util.WriteUint32(mod, pResOut, res) + util.Write32(mod, pResOut, res) return vfsErrorCode(err, _IOERR_ACCESS) } -func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, flags OpenFlag, pOutFlags, pOutVFS uint32) _ErrorCode { +func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile ptr_t, flags OpenFlag, pOutFlags, pOutVFS ptr_t) _ErrorCode { vfs := vfsGet(mod, pVfs) name := GetFilename(ctx, mod, zPath, flags) @@ -154,24 +153,24 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla } if file, ok := file.(FileSharedMemory); ok && pOutVFS != 0 && file.SharedMemory() != nil { - util.WriteUint32(mod, pOutVFS, 1) + util.Write32(mod, pOutVFS, int32(1)) } if pOutFlags != 0 { - util.WriteUint32(mod, pOutFlags, uint32(flags)) + util.Write32(mod, pOutFlags, flags) } file = cksmWrapFile(name, flags, file) vfsFileRegister(ctx, mod, pFile, file) return _OK } -func vfsClose(ctx context.Context, mod api.Module, pFile uint32) _ErrorCode { +func vfsClose(ctx context.Context, mod api.Module, pFile ptr_t) _ErrorCode { err := vfsFileClose(ctx, mod, pFile) return vfsErrorCode(err, _IOERR_CLOSE) } -func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf uint32, iAmt int32, iOfst int64) _ErrorCode { +func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf ptr_t, iAmt int32, iOfst int64) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) - buf := util.View(mod, zBuf, uint64(iAmt)) + buf := util.View(mod, zBuf, int64(iAmt)) n, err := file.ReadAt(buf, iOfst) if n == int(iAmt) { @@ -184,59 +183,58 @@ func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf uint32, iAmt int32 return _IOERR_SHORT_READ } -func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf uint32, iAmt int32, iOfst int64) _ErrorCode { +func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf ptr_t, iAmt int32, iOfst int64) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) - buf := util.View(mod, zBuf, uint64(iAmt)) + buf := util.View(mod, zBuf, int64(iAmt)) _, err := file.WriteAt(buf, iOfst) return vfsErrorCode(err, _IOERR_WRITE) } -func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte int64) _ErrorCode { +func vfsTruncate(ctx context.Context, mod api.Module, pFile ptr_t, nByte int64) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) err := file.Truncate(nByte) return vfsErrorCode(err, _IOERR_TRUNCATE) } -func vfsSync(ctx context.Context, mod api.Module, pFile uint32, flags SyncFlag) _ErrorCode { +func vfsSync(ctx context.Context, mod api.Module, pFile ptr_t, flags SyncFlag) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) err := file.Sync(flags) return vfsErrorCode(err, _IOERR_FSYNC) } -func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) _ErrorCode { +func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize ptr_t) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) size, err := file.Size() - util.WriteUint64(mod, pSize, uint64(size)) + util.Write64(mod, pSize, size) return vfsErrorCode(err, _IOERR_SEEK) } -func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock LockLevel) _ErrorCode { +func vfsLock(ctx context.Context, mod api.Module, pFile ptr_t, eLock LockLevel) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) err := file.Lock(eLock) return vfsErrorCode(err, _IOERR_LOCK) } -func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock LockLevel) _ErrorCode { +func vfsUnlock(ctx context.Context, mod api.Module, pFile ptr_t, eLock LockLevel) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) err := file.Unlock(eLock) return vfsErrorCode(err, _IOERR_UNLOCK) } -func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut uint32) _ErrorCode { +func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut ptr_t) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) locked, err := file.CheckReservedLock() - var res uint32 + var res int32 if locked { res = 1 } - - util.WriteUint32(mod, pResOut, res) + util.Write32(mod, pResOut, res) return vfsErrorCode(err, _IOERR_CHECKRESERVEDLOCK) } -func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _FcntlOpcode, pArg uint32) _ErrorCode { +func vfsFileControl(ctx context.Context, mod api.Module, pFile ptr_t, op _FcntlOpcode, pArg ptr_t) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) if file, ok := file.(fileControl); ok { return file.fileControl(ctx, mod, op, pArg) @@ -244,62 +242,62 @@ func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _Fcntl return vfsFileControlImpl(ctx, mod, file, op, pArg) } -func vfsFileControlImpl(ctx context.Context, mod api.Module, file File, op _FcntlOpcode, pArg uint32) _ErrorCode { +func vfsFileControlImpl(ctx context.Context, mod api.Module, file File, op _FcntlOpcode, pArg ptr_t) _ErrorCode { switch op { case _FCNTL_LOCKSTATE: if file, ok := file.(FileLockState); ok { if lk := file.LockState(); lk <= LOCK_EXCLUSIVE { - util.WriteUint32(mod, pArg, uint32(lk)) + util.Write32(mod, pArg, lk) return _OK } } case _FCNTL_PERSIST_WAL: if file, ok := file.(FilePersistWAL); ok { - if i := util.ReadUint32(mod, pArg); int32(i) >= 0 { + if i := util.Read32[int32](mod, pArg); i >= 0 { file.SetPersistWAL(i != 0) } else if file.PersistWAL() { - util.WriteUint32(mod, pArg, 1) + util.Write32(mod, pArg, int32(1)) } else { - util.WriteUint32(mod, pArg, 0) + util.Write32(mod, pArg, int32(0)) } return _OK } case _FCNTL_POWERSAFE_OVERWRITE: if file, ok := file.(FilePowersafeOverwrite); ok { - if i := util.ReadUint32(mod, pArg); int32(i) >= 0 { + if i := util.Read32[int32](mod, pArg); i >= 0 { file.SetPowersafeOverwrite(i != 0) } else if file.PowersafeOverwrite() { - util.WriteUint32(mod, pArg, 1) + util.Write32(mod, pArg, int32(1)) } else { - util.WriteUint32(mod, pArg, 0) + util.Write32(mod, pArg, int32(0)) } return _OK } case _FCNTL_CHUNK_SIZE: if file, ok := file.(FileChunkSize); ok { - size := util.ReadUint32(mod, pArg) + size := util.Read32[int32](mod, pArg) file.ChunkSize(int(size)) return _OK } case _FCNTL_SIZE_HINT: if file, ok := file.(FileSizeHint); ok { - size := util.ReadUint64(mod, pArg) - err := file.SizeHint(int64(size)) + size := util.Read64[int64](mod, pArg) + err := file.SizeHint(size) return vfsErrorCode(err, _IOERR_TRUNCATE) } case _FCNTL_HAS_MOVED: if file, ok := file.(FileHasMoved); ok { moved, err := file.HasMoved() - var res uint32 + var val uint32 if moved { - res = 1 + val = 1 } - util.WriteUint32(mod, pArg, res) + util.Write32(mod, pArg, val) return vfsErrorCode(err, _IOERR_FSTAT) } @@ -354,10 +352,10 @@ func vfsFileControlImpl(ctx context.Context, mod api.Module, file File, op _Fcnt case _FCNTL_PRAGMA: if file, ok := file.(FilePragma); ok { - ptr := util.ReadUint32(mod, pArg+1*ptrlen) + ptr := util.Read32[ptr_t](mod, pArg+1*ptrlen) name := util.ReadString(mod, ptr, _MAX_SQL_LENGTH) var value string - if ptr := util.ReadUint32(mod, pArg+2*ptrlen); ptr != 0 { + if ptr := util.Read32[ptr_t](mod, pArg+2*ptrlen); ptr != 0 { value = util.ReadString(mod, ptr, _MAX_SQL_LENGTH) } @@ -369,22 +367,22 @@ func vfsFileControlImpl(ctx context.Context, mod api.Module, file File, op _Fcnt } if out != "" { fn := mod.ExportedFunction("sqlite3_malloc64") - stack := [...]uint64{uint64(len(out) + 1)} + stack := [...]stk_t{stk_t(len(out) + 1)} if err := fn.CallWithStack(ctx, stack[:]); err != nil { panic(err) } - util.WriteUint32(mod, pArg, uint32(stack[0])) - util.WriteString(mod, uint32(stack[0]), out) + util.Write32(mod, pArg, ptr_t(stack[0])) + util.WriteString(mod, ptr_t(stack[0]), out) } return ret } case _FCNTL_BUSYHANDLER: if file, ok := file.(FileBusyHandler); ok { - arg := util.ReadUint64(mod, pArg) + arg := util.Read64[stk_t](mod, pArg) fn := mod.ExportedFunction("sqlite3_invoke_busy_handler_go") file.BusyHandler(func() bool { - stack := [...]uint64{arg} + stack := [...]stk_t{arg} if err := fn.CallWithStack(ctx, stack[:]); err != nil { panic(err) } @@ -396,7 +394,7 @@ func vfsFileControlImpl(ctx context.Context, mod api.Module, file File, op _Fcnt case _FCNTL_LOCK_TIMEOUT: if file, ok := file.(FileSharedMemory); ok { if shm, ok := file.SharedMemory().(blockingSharedMemory); ok { - shm.shmEnableBlocking(util.ReadUint32(mod, pArg) != 0) + shm.shmEnableBlocking(util.Read32[uint32](mod, pArg) != 0) return _OK } } @@ -411,44 +409,45 @@ func vfsFileControlImpl(ctx context.Context, mod api.Module, file File, op _Fcnt return _NOTFOUND } -func vfsSectorSize(ctx context.Context, mod api.Module, pFile uint32) uint32 { +func vfsSectorSize(ctx context.Context, mod api.Module, pFile ptr_t) uint32 { file := vfsFileGet(ctx, mod, pFile).(File) return uint32(file.SectorSize()) } -func vfsDeviceCharacteristics(ctx context.Context, mod api.Module, pFile uint32) DeviceCharacteristic { +func vfsDeviceCharacteristics(ctx context.Context, mod api.Module, pFile ptr_t) DeviceCharacteristic { file := vfsFileGet(ctx, mod, pFile).(File) return file.DeviceCharacteristics() } -func vfsShmBarrier(ctx context.Context, mod api.Module, pFile uint32) { +func vfsShmBarrier(ctx context.Context, mod api.Module, pFile ptr_t) { shm := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory() shm.shmBarrier() } -func vfsShmMap(ctx context.Context, mod api.Module, pFile uint32, iRegion, szRegion int32, bExtend, pp uint32) _ErrorCode { +func vfsShmMap(ctx context.Context, mod api.Module, pFile ptr_t, iRegion, szRegion, bExtend int32, pp ptr_t) _ErrorCode { shm := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory() p, rc := shm.shmMap(ctx, mod, iRegion, szRegion, bExtend != 0) - util.WriteUint32(mod, pp, p) + util.Write32(mod, pp, p) return rc } -func vfsShmLock(ctx context.Context, mod api.Module, pFile uint32, offset, n int32, flags _ShmFlag) _ErrorCode { +func vfsShmLock(ctx context.Context, mod api.Module, pFile ptr_t, offset, n int32, flags _ShmFlag) _ErrorCode { shm := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory() return shm.shmLock(offset, n, flags) } -func vfsShmUnmap(ctx context.Context, mod api.Module, pFile, bDelete uint32) _ErrorCode { +func vfsShmUnmap(ctx context.Context, mod api.Module, pFile ptr_t, bDelete int32) _ErrorCode { shm := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory() shm.shmUnmap(bDelete != 0) return _OK } -func vfsGet(mod api.Module, pVfs uint32) VFS { +func vfsGet(mod api.Module, pVfs ptr_t) VFS { var name string if pVfs != 0 { const zNameOffset = 16 - name = util.ReadString(mod, util.ReadUint32(mod, pVfs+zNameOffset), _MAX_NAME) + ptr := util.Read32[ptr_t](mod, pVfs+zNameOffset) + name = util.ReadString(mod, ptr, _MAX_NAME) } if vfs := Find(name); vfs != nil { return vfs @@ -456,21 +455,21 @@ func vfsGet(mod api.Module, pVfs uint32) VFS { panic(util.NoVFSErr + util.ErrorString(name)) } -func vfsFileRegister(ctx context.Context, mod api.Module, pFile uint32, file File) { +func vfsFileRegister(ctx context.Context, mod api.Module, pFile ptr_t, file File) { const fileHandleOffset = 4 id := util.AddHandle(ctx, file) - util.WriteUint32(mod, pFile+fileHandleOffset, id) + util.Write32(mod, pFile+fileHandleOffset, id) } -func vfsFileGet(ctx context.Context, mod api.Module, pFile uint32) any { +func vfsFileGet(ctx context.Context, mod api.Module, pFile ptr_t) any { const fileHandleOffset = 4 - id := util.ReadUint32(mod, pFile+fileHandleOffset) + id := util.Read32[ptr_t](mod, pFile+fileHandleOffset) return util.GetHandle(ctx, id) } -func vfsFileClose(ctx context.Context, mod api.Module, pFile uint32) error { +func vfsFileClose(ctx context.Context, mod api.Module, pFile ptr_t) error { const fileHandleOffset = 4 - id := util.ReadUint32(mod, pFile+fileHandleOffset) + id := util.Read32[ptr_t](mod, pFile+fileHandleOffset) return util.DelHandle(ctx, id) } diff --git a/vfs/vfs_test.go b/vfs/vfs_test.go index 79c8c3e9..64d9a21d 100644 --- a/vfs/vfs_test.go +++ b/vfs/vfs_test.go @@ -5,7 +5,6 @@ import ( "context" "errors" "io/fs" - "math" "os" "os/user" "path/filepath" @@ -29,28 +28,28 @@ func Test_vfsLocaltime(t *testing.T) { t.Fatal("returned", rc) } - if s := util.ReadUint32(mod, 4+0*4); int(s) != tm.Second() { + if s := util.Read32[int32](mod, 4+0*4); int(s) != tm.Second() { t.Error("wrong second") } - if m := util.ReadUint32(mod, 4+1*4); int(m) != tm.Minute() { + if m := util.Read32[int32](mod, 4+1*4); int(m) != tm.Minute() { t.Error("wrong minute") } - if h := util.ReadUint32(mod, 4+2*4); int(h) != tm.Hour() { + if h := util.Read32[int32](mod, 4+2*4); int(h) != tm.Hour() { t.Error("wrong hour") } - if d := util.ReadUint32(mod, 4+3*4); int(d) != tm.Day() { + if d := util.Read32[int32](mod, 4+3*4); int(d) != tm.Day() { t.Error("wrong day") } - if m := util.ReadUint32(mod, 4+4*4); time.Month(1+m) != tm.Month() { + if m := util.Read32[int32](mod, 4+4*4); time.Month(1+m) != tm.Month() { t.Error("wrong month") } - if y := util.ReadUint32(mod, 4+5*4); 1900+int(y) != tm.Year() { + if y := util.Read32[int32](mod, 4+5*4); 1900+int(y) != tm.Year() { t.Error("wrong year") } - if w := util.ReadUint32(mod, 4+6*4); time.Weekday(w) != tm.Weekday() { + if w := util.Read32[int32](mod, 4+6*4); time.Weekday(w) != tm.Weekday() { t.Error("wrong weekday") } - if d := util.ReadUint32(mod, 4+7*4); int(d) != tm.YearDay()-1 { + if d := util.Read32[int32](mod, 4+7*4); int(d) != tm.YearDay()-1 { t.Error("wrong yearday") } } @@ -99,7 +98,7 @@ func Test_vfsCurrentTime64(t *testing.T) { day, nsec := julianday.Date(now) want := day*86_400_000 + nsec/1_000_000 - if got := util.ReadUint64(mod, 4); float32(got) != float32(want) { + if got := util.Read64[int64](mod, 4); float32(got) != float32(want) { t.Errorf("got %v, want %v", got, want) } } @@ -173,7 +172,7 @@ func Test_vfsAccess(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, 4); got != 1 { + if got := util.Read32[int32](mod, 4); got != 1 { t.Error("directory did not exist") } @@ -181,7 +180,7 @@ func Test_vfsAccess(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, 4); got != 1 { + if got := util.Read32[int32](mod, 4); got != 1 { t.Error("can't access directory") } @@ -190,7 +189,7 @@ func Test_vfsAccess(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, 4); got != 1 { + if got := util.Read32[int32](mod, 4); got != 1 { t.Error("can't access file") } @@ -208,7 +207,7 @@ func Test_vfsAccess(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, 4); got != 0 { + if got := util.Read32[int32](mod, 4); got != 0 { t.Error("can access file") } } @@ -241,7 +240,7 @@ func Test_vfsFile(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, 16); got != uint32(len(text)) { + if got := util.Read64[int64](mod, 16); got != int64(len(text)) { t.Errorf("got %d", got) } @@ -265,7 +264,7 @@ func Test_vfsFile(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, 16); got != 4 { + if got := util.Read64[int64](mod, 16); got != 4 { t.Errorf("got %d", got) } @@ -296,46 +295,46 @@ func Test_vfsFile_psow(t *testing.T) { } // Read powersafe overwrite. - util.WriteUint32(mod, 16, math.MaxUint32) + util.Write32(mod, 16, int32(-1)) rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, 16); got == 0 { + if got := util.Read32[int32](mod, 16); got == 0 { t.Error("psow disabled") } // Unset powersafe overwrite. - util.WriteUint32(mod, 16, 0) + util.Write32(mod, 16, int32(0)) rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16) if rc != _OK { t.Fatal("returned", rc) } // Read powersafe overwrite. - util.WriteUint32(mod, 16, math.MaxUint32) + util.Write32(mod, 16, int32(-1)) rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, 16); got != 0 { + if got := util.Read32[int32](mod, 16); got != 0 { t.Error("psow enabled") } // Set powersafe overwrite. - util.WriteUint32(mod, 16, 1) + util.Write32(mod, 16, int32(1)) rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16) if rc != _OK { t.Fatal("returned", rc) } // Read powersafe overwrite. - util.WriteUint32(mod, 16, math.MaxUint32) + util.Write32(mod, 16, int32(-1)) rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, 16); got == 0 { + if got := util.Read32[int32](mod, 16); got == 0 { t.Error("psow disabled") } diff --git a/vtab.go b/vtab.go index 1998a528..7b29a5e4 100644 --- a/vtab.go +++ b/vtab.go @@ -58,15 +58,15 @@ func CreateModule[T VTab](db *Conn, name string, create, connect VTabConstructor flags |= VTAB_SHADOWTABS } - var modulePtr uint32 + var modulePtr ptr_t defer db.arena.mark()() namePtr := db.arena.string(name) if connect != nil { modulePtr = util.AddHandle(db.ctx, module[T]{create, connect}) } - r := db.call("sqlite3_create_module_go", uint64(db.handle), - uint64(namePtr), uint64(flags), uint64(modulePtr)) - return db.error(r) + rc := res_t(db.call("sqlite3_create_module_go", stk_t(db.handle), + stk_t(namePtr), stk_t(flags), stk_t(modulePtr))) + return db.error(rc) } func implements[T any](typ reflect.Type) bool { @@ -80,8 +80,8 @@ func implements[T any](typ reflect.Type) bool { func (c *Conn) DeclareVTab(sql string) error { defer c.arena.mark()() sqlPtr := c.arena.string(sql) - r := c.call("sqlite3_declare_vtab", uint64(c.handle), uint64(sqlPtr)) - return c.error(r) + rc := res_t(c.call("sqlite3_declare_vtab", stk_t(c.handle), stk_t(sqlPtr))) + return c.error(rc) } // VTabConflictMode is a virtual table conflict resolution mode. @@ -101,8 +101,7 @@ const ( // // https://sqlite.org/c3ref/vtab_on_conflict.html func (c *Conn) VTabOnConflict() VTabConflictMode { - r := c.call("sqlite3_vtab_on_conflict", uint64(c.handle)) - return VTabConflictMode(r) + return VTabConflictMode(c.call("sqlite3_vtab_on_conflict", stk_t(c.handle))) } // VTabConfigOption is a virtual table configuration option. @@ -121,14 +120,14 @@ const ( // // https://sqlite.org/c3ref/vtab_config.html func (c *Conn) VTabConfig(op VTabConfigOption, args ...any) error { - var i uint64 + var i int32 if op == VTAB_CONSTRAINT_SUPPORT && len(args) > 0 { if b, ok := args[0].(bool); ok && b { i = 1 } } - r := c.call("sqlite3_vtab_config_go", uint64(c.handle), uint64(op), i) - return c.error(r) + rc := res_t(c.call("sqlite3_vtab_config_go", stk_t(c.handle), stk_t(op), stk_t(i))) + return c.error(rc) } // VTabConstructor is a virtual table constructor function. @@ -263,7 +262,7 @@ type IndexInfo struct { // Inputs Constraint []IndexConstraint OrderBy []IndexOrderBy - ColumnsUsed int64 + ColumnsUsed uint64 // Outputs ConstraintUsage []IndexConstraintUsage IdxNum int @@ -274,7 +273,7 @@ type IndexInfo struct { EstimatedRows int64 // Internal c *Conn - handle uint32 + handle ptr_t } // An IndexConstraint describes virtual table indexing constraint information. @@ -309,14 +308,14 @@ type IndexConstraintUsage struct { func (idx *IndexInfo) RHSValue(column int) (Value, error) { defer idx.c.arena.mark()() valPtr := idx.c.arena.new(ptrlen) - r := idx.c.call("sqlite3_vtab_rhs_value", uint64(idx.handle), - uint64(column), uint64(valPtr)) - if err := idx.c.error(r); err != nil { + rc := res_t(idx.c.call("sqlite3_vtab_rhs_value", stk_t(idx.handle), + stk_t(column), stk_t(valPtr))) + if err := idx.c.error(rc); err != nil { return Value{}, err } return Value{ c: idx.c, - handle: util.ReadUint32(idx.c.mod, valPtr), + handle: util.Read32[ptr_t](idx.c.mod, valPtr), }, nil } @@ -324,26 +323,26 @@ func (idx *IndexInfo) RHSValue(column int) (Value, error) { // // https://sqlite.org/c3ref/vtab_collation.html func (idx *IndexInfo) Collation(column int) string { - r := idx.c.call("sqlite3_vtab_collation", uint64(idx.handle), - uint64(column)) - return util.ReadString(idx.c.mod, uint32(r), _MAX_NAME) + ptr := ptr_t(idx.c.call("sqlite3_vtab_collation", stk_t(idx.handle), + stk_t(column))) + return util.ReadString(idx.c.mod, ptr, _MAX_NAME) } // Distinct determines if a virtual table query is DISTINCT. // // https://sqlite.org/c3ref/vtab_distinct.html func (idx *IndexInfo) Distinct() int { - r := idx.c.call("sqlite3_vtab_distinct", uint64(idx.handle)) - return int(r) + i := int32(idx.c.call("sqlite3_vtab_distinct", stk_t(idx.handle))) + return int(i) } // In identifies and handles IN constraints. // // https://sqlite.org/c3ref/vtab_in.html func (idx *IndexInfo) In(column, handle int) bool { - r := idx.c.call("sqlite3_vtab_in", uint64(idx.handle), - uint64(column), uint64(handle)) - return r != 0 + b := int32(idx.c.call("sqlite3_vtab_in", stk_t(idx.handle), + stk_t(column), stk_t(handle))) + return b != 0 } func (idx *IndexInfo) load() { @@ -351,34 +350,35 @@ func (idx *IndexInfo) load() { mod := idx.c.mod ptr := idx.handle - idx.Constraint = make([]IndexConstraint, util.ReadUint32(mod, ptr+0)) - idx.ConstraintUsage = make([]IndexConstraintUsage, util.ReadUint32(mod, ptr+0)) - idx.OrderBy = make([]IndexOrderBy, util.ReadUint32(mod, ptr+8)) + nConstraint := util.Read32[int32](mod, ptr+0) + idx.Constraint = make([]IndexConstraint, nConstraint) + idx.ConstraintUsage = make([]IndexConstraintUsage, nConstraint) + idx.OrderBy = make([]IndexOrderBy, util.Read32[int32](mod, ptr+8)) - constraintPtr := util.ReadUint32(mod, ptr+4) + constraintPtr := util.Read32[ptr_t](mod, ptr+4) constraint := idx.Constraint for i := range idx.Constraint { constraint[i] = IndexConstraint{ - Column: int(int32(util.ReadUint32(mod, constraintPtr+0))), - Op: IndexConstraintOp(util.ReadUint8(mod, constraintPtr+4)), - Usable: util.ReadUint8(mod, constraintPtr+5) != 0, + Column: int(util.Read32[int32](mod, constraintPtr+0)), + Op: util.Read[IndexConstraintOp](mod, constraintPtr+4), + Usable: util.Read[byte](mod, constraintPtr+5) != 0, } constraintPtr += 12 } - orderByPtr := util.ReadUint32(mod, ptr+12) + orderByPtr := util.Read32[ptr_t](mod, ptr+12) orderBy := idx.OrderBy for i := range orderBy { orderBy[i] = IndexOrderBy{ - Column: int(int32(util.ReadUint32(mod, orderByPtr+0))), - Desc: util.ReadUint8(mod, orderByPtr+4) != 0, + Column: int(util.Read32[int32](mod, orderByPtr+0)), + Desc: util.Read[byte](mod, orderByPtr+4) != 0, } orderByPtr += 8 } idx.EstimatedCost = util.ReadFloat64(mod, ptr+40) - idx.EstimatedRows = int64(util.ReadUint64(mod, ptr+48)) - idx.ColumnsUsed = int64(util.ReadUint64(mod, ptr+64)) + idx.EstimatedRows = util.Read64[int64](mod, ptr+48) + idx.ColumnsUsed = util.Read64[uint64](mod, ptr+64) } func (idx *IndexInfo) save() { @@ -386,26 +386,26 @@ func (idx *IndexInfo) save() { mod := idx.c.mod ptr := idx.handle - usagePtr := util.ReadUint32(mod, ptr+16) + usagePtr := util.Read32[ptr_t](mod, ptr+16) for _, usage := range idx.ConstraintUsage { - util.WriteUint32(mod, usagePtr+0, uint32(usage.ArgvIndex)) + util.Write32(mod, usagePtr+0, int32(usage.ArgvIndex)) if usage.Omit { - util.WriteUint8(mod, usagePtr+4, 1) + util.Write(mod, usagePtr+4, int8(1)) } usagePtr += 8 } - util.WriteUint32(mod, ptr+20, uint32(idx.IdxNum)) + util.Write32(mod, ptr+20, int32(idx.IdxNum)) if idx.IdxStr != "" { - util.WriteUint32(mod, ptr+24, idx.c.newString(idx.IdxStr)) - util.WriteUint32(mod, ptr+28, 1) // needToFreeIdxStr + util.Write32(mod, ptr+24, idx.c.newString(idx.IdxStr)) + util.Write32(mod, ptr+28, int32(1)) // needToFreeIdxStr } if idx.OrderByConsumed { - util.WriteUint32(mod, ptr+32, 1) + util.Write32(mod, ptr+32, int32(1)) } util.WriteFloat64(mod, ptr+40, idx.EstimatedCost) - util.WriteUint64(mod, ptr+48, uint64(idx.EstimatedRows)) - util.WriteUint32(mod, ptr+56, uint32(idx.IdxFlags)) + util.Write64(mod, ptr+48, idx.EstimatedRows) + util.Write32(mod, ptr+56, idx.IdxFlags) } // IndexConstraintOp is a virtual table constraint operator code. @@ -442,33 +442,33 @@ const ( INDEX_SCAN_UNIQUE IndexScanFlag = 1 ) -func vtabModuleCallback(i vtabConstructor) func(_ context.Context, _ api.Module, _, _, _, _, _ uint32) uint32 { - return func(ctx context.Context, mod api.Module, pMod, nArg, pArg, ppVTab, pzErr uint32) uint32 { +func vtabModuleCallback(i vtabConstructor) func(_ context.Context, _ api.Module, _ ptr_t, _ int32, _, _, _ ptr_t) res_t { + return func(ctx context.Context, mod api.Module, pMod ptr_t, nArg int32, pArg, ppVTab, pzErr ptr_t) res_t { arg := make([]reflect.Value, 1+nArg) arg[0] = reflect.ValueOf(ctx.Value(connKey{})) - for i := uint32(0); i < nArg; i++ { - ptr := util.ReadUint32(mod, pArg+i*ptrlen) + for i := int32(0); i < nArg; i++ { + ptr := util.Read32[ptr_t](mod, pArg+ptr_t(i)*ptrlen) arg[i+1] = reflect.ValueOf(util.ReadString(mod, ptr, _MAX_SQL_LENGTH)) } module := vtabGetHandle(ctx, mod, pMod) - res := reflect.ValueOf(module).Index(int(i)).Call(arg) - err, _ := res[1].Interface().(error) + val := reflect.ValueOf(module).Index(int(i)).Call(arg) + err, _ := val[1].Interface().(error) if err == nil { - vtabPutHandle(ctx, mod, ppVTab, res[0].Interface()) + vtabPutHandle(ctx, mod, ppVTab, val[0].Interface()) } return vtabError(ctx, mod, pzErr, _PTR_ERROR, err) } } -func vtabDisconnectCallback(ctx context.Context, mod api.Module, pVTab uint32) uint32 { +func vtabDisconnectCallback(ctx context.Context, mod api.Module, pVTab ptr_t) res_t { err := vtabDelHandle(ctx, mod, pVTab) return vtabError(ctx, mod, 0, _PTR_ERROR, err) } -func vtabDestroyCallback(ctx context.Context, mod api.Module, pVTab uint32) uint32 { +func vtabDestroyCallback(ctx context.Context, mod api.Module, pVTab ptr_t) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabDestroyer) err := vtab.Destroy() if cerr := vtabDelHandle(ctx, mod, pVTab); err == nil { @@ -477,7 +477,7 @@ func vtabDestroyCallback(ctx context.Context, mod api.Module, pVTab uint32) uint return vtabError(ctx, mod, 0, _PTR_ERROR, err) } -func vtabBestIndexCallback(ctx context.Context, mod api.Module, pVTab, pIdxInfo uint32) uint32 { +func vtabBestIndexCallback(ctx context.Context, mod api.Module, pVTab, pIdxInfo ptr_t) res_t { var info IndexInfo info.handle = pIdxInfo info.c = ctx.Value(connKey{}).(*Conn) @@ -490,7 +490,7 @@ func vtabBestIndexCallback(ctx context.Context, mod api.Module, pVTab, pIdxInfo return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabUpdateCallback(ctx context.Context, mod api.Module, pVTab, nArg, pArg, pRowID uint32) uint32 { +func vtabUpdateCallback(ctx context.Context, mod api.Module, pVTab ptr_t, nArg int32, pArg, pRowID ptr_t) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabUpdater) db := ctx.Value(connKey{}).(*Conn) @@ -498,33 +498,33 @@ func vtabUpdateCallback(ctx context.Context, mod api.Module, pVTab, nArg, pArg, callbackArgs(db, args, pArg) rowID, err := vtab.Update(args...) if err == nil { - util.WriteUint64(mod, pRowID, uint64(rowID)) + util.Write64(mod, pRowID, rowID) } return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabRenameCallback(ctx context.Context, mod api.Module, pVTab, zNew uint32) uint32 { +func vtabRenameCallback(ctx context.Context, mod api.Module, pVTab, zNew ptr_t) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabRenamer) err := vtab.Rename(util.ReadString(mod, zNew, _MAX_NAME)) return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabFindFuncCallback(ctx context.Context, mod api.Module, pVTab uint32, nArg int32, zName, pxFunc uint32) uint32 { +func vtabFindFuncCallback(ctx context.Context, mod api.Module, pVTab ptr_t, nArg int32, zName, pxFunc ptr_t) int32 { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabOverloader) f, op := vtab.FindFunction(int(nArg), util.ReadString(mod, zName, _MAX_NAME)) if op != 0 { - var wrapper uint32 + var wrapper ptr_t wrapper = util.AddHandle(ctx, func(c Context, arg ...Value) { defer util.DelHandle(ctx, wrapper) f(c, arg...) }) - util.WriteUint32(mod, pxFunc, wrapper) + util.Write32(mod, pxFunc, wrapper) } - return uint32(op) + return int32(op) } -func vtabIntegrityCallback(ctx context.Context, mod api.Module, pVTab, zSchema, zTabName, mFlags, pzErr uint32) uint32 { +func vtabIntegrityCallback(ctx context.Context, mod api.Module, pVTab, zSchema, zTabName ptr_t, mFlags uint32, pzErr ptr_t) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabChecker) schema := util.ReadString(mod, zSchema, _MAX_NAME) table := util.ReadString(mod, zTabName, _MAX_NAME) @@ -536,49 +536,49 @@ func vtabIntegrityCallback(ctx context.Context, mod api.Module, pVTab, zSchema, return code } -func vtabBeginCallback(ctx context.Context, mod api.Module, pVTab uint32) uint32 { +func vtabBeginCallback(ctx context.Context, mod api.Module, pVTab ptr_t) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabTxn) err := vtab.Begin() return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabSyncCallback(ctx context.Context, mod api.Module, pVTab uint32) uint32 { +func vtabSyncCallback(ctx context.Context, mod api.Module, pVTab ptr_t) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabTxn) err := vtab.Sync() return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabCommitCallback(ctx context.Context, mod api.Module, pVTab uint32) uint32 { +func vtabCommitCallback(ctx context.Context, mod api.Module, pVTab ptr_t) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabTxn) err := vtab.Commit() return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabRollbackCallback(ctx context.Context, mod api.Module, pVTab uint32) uint32 { +func vtabRollbackCallback(ctx context.Context, mod api.Module, pVTab ptr_t) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabTxn) err := vtab.Rollback() return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabSavepointCallback(ctx context.Context, mod api.Module, pVTab uint32, id int32) uint32 { +func vtabSavepointCallback(ctx context.Context, mod api.Module, pVTab ptr_t, id int32) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabSavepointer) err := vtab.Savepoint(int(id)) return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabReleaseCallback(ctx context.Context, mod api.Module, pVTab uint32, id int32) uint32 { +func vtabReleaseCallback(ctx context.Context, mod api.Module, pVTab ptr_t, id int32) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabSavepointer) err := vtab.Release(int(id)) return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabRollbackToCallback(ctx context.Context, mod api.Module, pVTab uint32, id int32) uint32 { +func vtabRollbackToCallback(ctx context.Context, mod api.Module, pVTab ptr_t, id int32) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabSavepointer) err := vtab.RollbackTo(int(id)) return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func cursorOpenCallback(ctx context.Context, mod api.Module, pVTab, ppCur uint32) uint32 { +func cursorOpenCallback(ctx context.Context, mod api.Module, pVTab, ppCur ptr_t) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTab) cursor, err := vtab.Open() @@ -589,12 +589,12 @@ func cursorOpenCallback(ctx context.Context, mod api.Module, pVTab, ppCur uint32 return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func cursorCloseCallback(ctx context.Context, mod api.Module, pCur uint32) uint32 { +func cursorCloseCallback(ctx context.Context, mod api.Module, pCur ptr_t) res_t { err := vtabDelHandle(ctx, mod, pCur) return vtabError(ctx, mod, 0, _VTAB_ERROR, err) } -func cursorFilterCallback(ctx context.Context, mod api.Module, pCur uint32, idxNum int32, idxStr, nArg, pArg uint32) uint32 { +func cursorFilterCallback(ctx context.Context, mod api.Module, pCur ptr_t, idxNum int32, idxStr ptr_t, nArg int32, pArg ptr_t) res_t { cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) db := ctx.Value(connKey{}).(*Conn) args := make([]Value, nArg) @@ -607,7 +607,7 @@ func cursorFilterCallback(ctx context.Context, mod api.Module, pCur uint32, idxN return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err) } -func cursorEOFCallback(ctx context.Context, mod api.Module, pCur uint32) uint32 { +func cursorEOFCallback(ctx context.Context, mod api.Module, pCur ptr_t) int32 { cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) if cursor.EOF() { return 1 @@ -615,25 +615,25 @@ func cursorEOFCallback(ctx context.Context, mod api.Module, pCur uint32) uint32 return 0 } -func cursorNextCallback(ctx context.Context, mod api.Module, pCur uint32) uint32 { +func cursorNextCallback(ctx context.Context, mod api.Module, pCur ptr_t) res_t { cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) err := cursor.Next() return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err) } -func cursorColumnCallback(ctx context.Context, mod api.Module, pCur, pCtx uint32, n int32) uint32 { +func cursorColumnCallback(ctx context.Context, mod api.Module, pCur, pCtx ptr_t, n int32) res_t { cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) db := ctx.Value(connKey{}).(*Conn) err := cursor.Column(Context{db, pCtx}, int(n)) return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err) } -func cursorRowIDCallback(ctx context.Context, mod api.Module, pCur, pRowID uint32) uint32 { +func cursorRowIDCallback(ctx context.Context, mod api.Module, pCur, pRowID ptr_t) res_t { cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) rowID, err := cursor.RowID() if err == nil { - util.WriteUint64(mod, pRowID, uint64(rowID)) + util.Write64(mod, pRowID, rowID) } return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err) @@ -645,7 +645,7 @@ const ( _CURSOR_ERROR ) -func vtabError(ctx context.Context, mod api.Module, ptr, kind uint32, err error) uint32 { +func vtabError(ctx context.Context, mod api.Module, ptr ptr_t, kind uint32, err error) res_t { const zErrMsgOffset = 8 msg, code := errorCode(err, ERROR) if msg != "" && ptr != 0 { @@ -653,32 +653,32 @@ func vtabError(ctx context.Context, mod api.Module, ptr, kind uint32, err error) case _VTAB_ERROR: ptr = ptr + zErrMsgOffset // zErrMsg case _CURSOR_ERROR: - ptr = util.ReadUint32(mod, ptr) + zErrMsgOffset // pVTab->zErrMsg + ptr = util.Read32[ptr_t](mod, ptr) + zErrMsgOffset // pVTab->zErrMsg } db := ctx.Value(connKey{}).(*Conn) - if ptr := util.ReadUint32(mod, ptr); ptr != 0 { + if ptr := util.Read32[ptr_t](mod, ptr); ptr != 0 { db.free(ptr) } - util.WriteUint32(mod, ptr, db.newString(msg)) + util.Write32(mod, ptr, db.newString(msg)) } return code } -func vtabGetHandle(ctx context.Context, mod api.Module, ptr uint32) any { +func vtabGetHandle(ctx context.Context, mod api.Module, ptr ptr_t) any { const handleOffset = 4 - handle := util.ReadUint32(mod, ptr-handleOffset) + handle := util.Read32[ptr_t](mod, ptr-handleOffset) return util.GetHandle(ctx, handle) } -func vtabDelHandle(ctx context.Context, mod api.Module, ptr uint32) error { +func vtabDelHandle(ctx context.Context, mod api.Module, ptr ptr_t) error { const handleOffset = 4 - handle := util.ReadUint32(mod, ptr-handleOffset) + handle := util.Read32[ptr_t](mod, ptr-handleOffset) return util.DelHandle(ctx, handle) } -func vtabPutHandle(ctx context.Context, mod api.Module, pptr uint32, val any) { +func vtabPutHandle(ctx context.Context, mod api.Module, pptr ptr_t, val any) { const handleOffset = 4 handle := util.AddHandle(ctx, val) - ptr := util.ReadUint32(mod, pptr) - util.WriteUint32(mod, ptr-handleOffset, handle) + ptr := util.Read32[ptr_t](mod, pptr) + util.Write32(mod, ptr-handleOffset, handle) }