From 736151ddec10ccc3bf2f710eba8cf63ead4e3390 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Mon, 20 Jan 2025 19:31:06 +0000 Subject: [PATCH 1/6] More type safe. --- backup.go | 36 +++--- blob.go | 46 ++++---- config.go | 120 +++++++++---------- conn.go | 115 +++++++++---------- const.go | 15 ++- context.go | 8 +- error.go | 2 +- error_test.go | 32 +++--- func.go | 64 +++++------ internal/util/func.go | 1 + internal/util/handle.go | 12 +- internal/util/mem.go | 62 +++++----- internal/util/mem_test.go | 24 ++-- internal/util/mmap_unix.go | 4 +- sqlite.go | 51 ++++---- sqlite_test.go | 2 +- stmt.go | 230 ++++++++++++++++++------------------- txn.go | 15 ++- value.go | 57 +++++---- vfs/api.go | 4 +- vfs/cksm.go | 2 +- vfs/const.go | 4 +- vfs/filename.go | 20 ++-- vfs/lock_test.go | 36 +++--- vfs/shm_ofd.go | 2 +- vfs/vfs.go | 141 ++++++++++++----------- vfs/vfs_test.go | 47 ++++---- vtab.go | 170 +++++++++++++-------------- 28 files changed, 666 insertions(+), 656 deletions(-) diff --git a/backup.go b/backup.go index b16c7511..6378aab8 100644 --- a/backup.go +++ b/backup.go @@ -5,8 +5,8 @@ package sqlite3 // https://sqlite.org/c3ref/backup.html type Backup struct { c *Conn - handle uint32 - otherc uint32 + handle ptr_t + otherc ptr_t } // Backup backs up srcDB on the src connection to the "main" database in dstURI. @@ -61,7 +61,7 @@ func (src *Conn) BackupInit(srcDB, dstURI string) (*Backup, error) { return src.backupInit(dst, "main", src.handle, srcDB) } -func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string) (*Backup, error) { +func (c *Conn) backupInit(dst ptr_t, dstName string, src ptr_t, srcName string) (*Backup, error) { defer c.arena.mark()() dstPtr := c.arena.string(dstName) srcPtr := c.arena.string(srcName) @@ -71,19 +71,19 @@ func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string other = src } - r := c.call("sqlite3_backup_init", + ptr := ptr_t(c.call("sqlite3_backup_init", uint64(dst), uint64(dstPtr), - uint64(src), uint64(srcPtr)) - if r == 0 { + uint64(src), uint64(srcPtr))) + if ptr == 0 { defer c.closeDB(other) - r = c.call("sqlite3_errcode", uint64(dst)) - return nil, c.sqlite.error(r, dst) + rc := res_t(c.call("sqlite3_errcode", uint64(dst))) + return nil, c.sqlite.error(rc, dst) } return &Backup{ c: c, otherc: other, - handle: uint32(r), + handle: ptr, }, nil } @@ -97,10 +97,10 @@ func (b *Backup) Close() error { return nil } - r := b.c.call("sqlite3_backup_finish", uint64(b.handle)) + rc := res_t(b.c.call("sqlite3_backup_finish", uint64(b.handle))) b.c.closeDB(b.otherc) b.handle = 0 - return b.c.error(r) + return b.c.error(rc) } // Step copies up to nPage pages between the source and destination databases. @@ -108,11 +108,11 @@ func (b *Backup) Close() error { // // https://sqlite.org/c3ref/backup_finish.html#sqlite3backupstep func (b *Backup) Step(nPage int) (done bool, err error) { - r := b.c.call("sqlite3_backup_step", uint64(b.handle), uint64(nPage)) - if r == _DONE { + rc := res_t(b.c.call("sqlite3_backup_step", uint64(b.handle), uint64(nPage))) + if rc == _DONE { return true, nil } - return false, b.c.error(r) + return false, b.c.error(rc) } // Remaining returns the number of pages still to be backed up @@ -120,8 +120,8 @@ func (b *Backup) Step(nPage int) (done bool, err error) { // // https://sqlite.org/c3ref/backup_finish.html#sqlite3backupremaining func (b *Backup) Remaining() int { - r := b.c.call("sqlite3_backup_remaining", uint64(b.handle)) - return int(int32(r)) + n := int32(b.c.call("sqlite3_backup_remaining", uint64(b.handle))) + return int(n) } // PageCount returns the total number of pages in the source database @@ -129,6 +129,6 @@ func (b *Backup) Remaining() int { // // https://sqlite.org/c3ref/backup_finish.html#sqlite3backuppagecount func (b *Backup) PageCount() int { - r := b.c.call("sqlite3_backup_pagecount", uint64(b.handle)) - return int(int32(r)) + n := int32(b.c.call("sqlite3_backup_pagecount", uint64(b.handle))) + return int(n) } diff --git a/blob.go b/blob.go index a0969eb6..a2e4cfee 100644 --- a/blob.go +++ b/blob.go @@ -20,8 +20,8 @@ type Blob struct { c *Conn bytes int64 offset int64 - handle uint32 - bufptr uint32 + handle ptr_t + bufptr ptr_t buflen int64 } @@ -43,17 +43,17 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, } c.checkInterrupt(c.handle) - r := c.call("sqlite3_blob_open", uint64(c.handle), + rc := res_t(c.call("sqlite3_blob_open", uint64(c.handle), uint64(dbPtr), uint64(tablePtr), uint64(columnPtr), - uint64(row), flags, uint64(blobPtr)) + uint64(row), flags, uint64(blobPtr))) - if err := c.error(r); err != nil { + if err := c.error(rc); err != nil { return nil, err } blob := Blob{c: c} - blob.handle = util.ReadUint32(c.mod, blobPtr) - blob.bytes = int64(c.call("sqlite3_blob_bytes", uint64(blob.handle))) + blob.handle = util.Read32[ptr_t](c.mod, blobPtr) + blob.bytes = int64(int32(c.call("sqlite3_blob_bytes", uint64(blob.handle)))) return &blob, nil } @@ -67,10 +67,10 @@ func (b *Blob) Close() error { return nil } - r := b.c.call("sqlite3_blob_close", uint64(b.handle)) + rc := res_t(b.c.call("sqlite3_blob_close", uint64(b.handle))) b.c.free(b.bufptr) b.handle = 0 - return b.c.error(r) + return b.c.error(rc) } // Size returns the size of the BLOB in bytes. @@ -98,9 +98,9 @@ func (b *Blob) Read(p []byte) (n int, err error) { b.buflen = want } - r := b.c.call("sqlite3_blob_read", uint64(b.handle), - uint64(b.bufptr), uint64(want), uint64(b.offset)) - err = b.c.error(r) + rc := res_t(b.c.call("sqlite3_blob_read", uint64(b.handle), + uint64(b.bufptr), uint64(want), uint64(b.offset))) + err = b.c.error(rc) if err != nil { return 0, err } @@ -132,9 +132,9 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) { } for want > 0 { - r := b.c.call("sqlite3_blob_read", uint64(b.handle), - uint64(b.bufptr), uint64(want), uint64(b.offset)) - err = b.c.error(r) + rc := res_t(b.c.call("sqlite3_blob_read", uint64(b.handle), + uint64(b.bufptr), uint64(want), uint64(b.offset))) + err = b.c.error(rc) if err != nil { return n, err } @@ -170,9 +170,9 @@ func (b *Blob) Write(p []byte) (n int, err error) { } util.WriteBytes(b.c.mod, b.bufptr, p) - r := b.c.call("sqlite3_blob_write", uint64(b.handle), - uint64(b.bufptr), uint64(want), uint64(b.offset)) - err = b.c.error(r) + rc := res_t(b.c.call("sqlite3_blob_write", uint64(b.handle), + uint64(b.bufptr), uint64(want), uint64(b.offset))) + err = b.c.error(rc) if err != nil { return 0, err } @@ -204,9 +204,9 @@ func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) { mem := util.View(b.c.mod, b.bufptr, uint64(want)) m, err := r.Read(mem[:want]) if m > 0 { - r := b.c.call("sqlite3_blob_write", uint64(b.handle), - uint64(b.bufptr), uint64(m), uint64(b.offset)) - err := b.c.error(r) + rc := res_t(b.c.call("sqlite3_blob_write", uint64(b.handle), + uint64(b.bufptr), uint64(m), uint64(b.offset))) + err := b.c.error(rc) if err != nil { return n, err } @@ -254,8 +254,8 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) { // https://sqlite.org/c3ref/blob_reopen.html func (b *Blob) Reopen(row int64) error { b.c.checkInterrupt(b.c.handle) - err := b.c.error(b.c.call("sqlite3_blob_reopen", uint64(b.handle), uint64(row))) - b.bytes = int64(b.c.call("sqlite3_blob_bytes", uint64(b.handle))) + err := b.c.error(res_t(b.c.call("sqlite3_blob_reopen", uint64(b.handle), uint64(row)))) + b.bytes = int64(int32(b.c.call("sqlite3_blob_bytes", uint64(b.handle)))) b.offset = 0 return err } diff --git a/config.go b/config.go index 474f960a..7391d578 100644 --- a/config.go +++ b/config.go @@ -32,7 +32,7 @@ func (c *Conn) Config(op DBConfig, arg ...bool) (bool, error) { defer c.arena.mark()() argsPtr := c.arena.new(intlen + ptrlen) - var flag int + var flag int32 switch { case len(arg) == 0: flag = -1 @@ -40,12 +40,12 @@ func (c *Conn) Config(op DBConfig, arg ...bool) (bool, error) { flag = 1 } - util.WriteUint32(c.mod, argsPtr+0*ptrlen, uint32(flag)) - util.WriteUint32(c.mod, argsPtr+1*ptrlen, argsPtr) + util.Write32(c.mod, argsPtr+0*ptrlen, flag) + util.Write32(c.mod, argsPtr+1*ptrlen, argsPtr) - r := c.call("sqlite3_db_config", uint64(c.handle), - uint64(op), uint64(argsPtr)) - return util.ReadUint32(c.mod, argsPtr) != 0, c.error(r) + rc := res_t(c.call("sqlite3_db_config", uint64(c.handle), + uint64(op), uint64(argsPtr))) + return util.Read32[uint32](c.mod, argsPtr) != 0, c.error(rc) } // ConfigLog sets up the error logging callback for the connection. @@ -56,15 +56,15 @@ func (c *Conn) ConfigLog(cb func(code ExtendedErrorCode, msg string)) error { if cb != nil { enable = 1 } - r := c.call("sqlite3_config_log_go", enable) - if err := c.error(r); err != nil { + rc := res_t(c.call("sqlite3_config_log_go", enable)) + if err := c.error(rc); err != nil { return err } c.log = cb return nil } -func logCallback(ctx context.Context, mod api.Module, _, iCode, zMsg uint32) { +func logCallback(ctx context.Context, mod api.Module, _, iCode, zMsg ptr_t) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.log != nil { msg := util.ReadString(mod, zMsg, _MAX_LENGTH) c.log(xErrorCode(iCode), msg) @@ -88,85 +88,85 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro defer c.arena.mark()() ptr := c.arena.new(max(ptrlen, intlen)) - var schemaPtr uint32 + var schemaPtr ptr_t if schema != "" { schemaPtr = c.arena.string(schema) } - var rc uint64 + var rc res_t var res any switch op { default: return nil, MISUSE case FCNTL_RESET_CACHE: - rc = c.call("sqlite3_file_control", + rc = res_t(c.call("sqlite3_file_control", uint64(c.handle), uint64(schemaPtr), - uint64(op), 0) + uint64(op), 0)) case FCNTL_PERSIST_WAL, FCNTL_POWERSAFE_OVERWRITE: - var flag int + var flag int32 switch { case len(arg) == 0: flag = -1 case arg[0]: flag = 1 } - util.WriteUint32(c.mod, ptr, uint32(flag)) - rc = c.call("sqlite3_file_control", + util.Write32(c.mod, ptr, flag) + rc = res_t(c.call("sqlite3_file_control", uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr)) - res = util.ReadUint32(c.mod, ptr) != 0 + uint64(op), uint64(ptr))) + res = util.Read32[uint32](c.mod, ptr) != 0 case FCNTL_CHUNK_SIZE: - util.WriteUint32(c.mod, ptr, uint32(arg[0].(int))) - rc = c.call("sqlite3_file_control", + util.Write32(c.mod, ptr, int32(arg[0].(int))) + rc = res_t(c.call("sqlite3_file_control", uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr)) + uint64(op), uint64(ptr))) case FCNTL_RESERVE_BYTES: bytes := -1 if len(arg) > 0 { bytes = arg[0].(int) } - util.WriteUint32(c.mod, ptr, uint32(bytes)) - rc = c.call("sqlite3_file_control", + util.Write32(c.mod, ptr, int32(bytes)) + rc = res_t(c.call("sqlite3_file_control", uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr)) - res = int(util.ReadUint32(c.mod, ptr)) + uint64(op), uint64(ptr))) + res = int(util.Read32[int32](c.mod, ptr)) case FCNTL_DATA_VERSION: - rc = c.call("sqlite3_file_control", + rc = res_t(c.call("sqlite3_file_control", uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr)) - res = util.ReadUint32(c.mod, ptr) + uint64(op), uint64(ptr))) + res = util.Read32[uint32](c.mod, ptr) case FCNTL_LOCKSTATE: - rc = c.call("sqlite3_file_control", + rc = res_t(c.call("sqlite3_file_control", uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr)) - res = vfs.LockLevel(util.ReadUint32(c.mod, ptr)) + uint64(op), uint64(ptr))) + res = util.Read32[vfs.LockLevel](c.mod, ptr) case FCNTL_VFS_POINTER: - rc = c.call("sqlite3_file_control", + rc = res_t(c.call("sqlite3_file_control", uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr)) + uint64(op), uint64(ptr))) if rc == _OK { const zNameOffset = 16 - ptr = util.ReadUint32(c.mod, ptr) - ptr = util.ReadUint32(c.mod, ptr+zNameOffset) + ptr = util.Read32[ptr_t](c.mod, ptr) + ptr = util.Read32[ptr_t](c.mod, ptr+zNameOffset) name := util.ReadString(c.mod, ptr, _MAX_NAME) res = vfs.Find(name) } case FCNTL_FILE_POINTER, FCNTL_JOURNAL_POINTER: - rc = c.call("sqlite3_file_control", + rc = res_t(c.call("sqlite3_file_control", uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr)) + uint64(op), uint64(ptr))) if rc == _OK { const fileHandleOffset = 4 - ptr = util.ReadUint32(c.mod, ptr) - ptr = util.ReadUint32(c.mod, ptr+fileHandleOffset) + ptr = util.Read32[ptr_t](c.mod, ptr) + ptr = util.Read32[ptr_t](c.mod, ptr+fileHandleOffset) res = util.GetHandle(c.ctx, ptr) } } @@ -182,8 +182,8 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro // // https://sqlite.org/c3ref/limit.html func (c *Conn) Limit(id LimitCategory, value int) int { - r := c.call("sqlite3_limit", uint64(c.handle), uint64(id), uint64(value)) - return int(int32(r)) + v := int32(c.call("sqlite3_limit", uint64(c.handle), uint64(id), uint64(value))) + return int(v) } // SetAuthorizer registers an authorizer callback with the database connection. @@ -194,8 +194,8 @@ func (c *Conn) SetAuthorizer(cb func(action AuthorizerActionCode, name3rd, name4 if cb != nil { enable = 1 } - r := c.call("sqlite3_set_authorizer_go", uint64(c.handle), enable) - if err := c.error(r); err != nil { + rc := res_t(c.call("sqlite3_set_authorizer_go", uint64(c.handle), enable)) + if err := c.error(rc); err != nil { return err } c.authorizer = cb @@ -203,7 +203,7 @@ func (c *Conn) SetAuthorizer(cb func(action AuthorizerActionCode, name3rd, name4 } -func authorizerCallback(ctx context.Context, mod api.Module, pDB uint32, action AuthorizerActionCode, zName3rd, zName4th, zSchema, zInner uint32) (rc AuthorizerReturnCode) { +func authorizerCallback(ctx context.Context, mod api.Module, pDB ptr_t, action AuthorizerActionCode, zName3rd, zName4th, zSchema, zInner ptr_t) (rc AuthorizerReturnCode) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.authorizer != nil { var name3rd, name4th, schema, inner string if zName3rd != 0 { @@ -227,15 +227,15 @@ func authorizerCallback(ctx context.Context, mod api.Module, pDB uint32, action // // https://sqlite.org/c3ref/trace_v2.html func (c *Conn) Trace(mask TraceEvent, cb func(evt TraceEvent, arg1 any, arg2 any) error) error { - r := c.call("sqlite3_trace_go", uint64(c.handle), uint64(mask)) - if err := c.error(r); err != nil { + rc := res_t(c.call("sqlite3_trace_go", uint64(c.handle), uint64(mask))) + if err := c.error(rc); err != nil { return err } c.trace = cb return nil } -func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pArg1, pArg2 uint32) (rc uint32) { +func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pArg1, pArg2 ptr_t) (rc uint32) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.trace != nil { var arg1, arg2 any if evt == TRACE_CLOSE { @@ -248,7 +248,7 @@ func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pAr case TRACE_STMT: arg2 = s.SQL() case TRACE_PROFILE: - arg2 = int64(util.ReadUint64(mod, pArg2)) + arg2 = util.Read64[int64](mod, pArg2) } break } @@ -269,20 +269,20 @@ func (c *Conn) WALCheckpoint(schema string, mode CheckpointMode) (nLog, nCkpt in nLogPtr := c.arena.new(ptrlen) nCkptPtr := c.arena.new(ptrlen) schemaPtr := c.arena.string(schema) - r := c.call("sqlite3_wal_checkpoint_v2", + rc := res_t(c.call("sqlite3_wal_checkpoint_v2", uint64(c.handle), uint64(schemaPtr), uint64(mode), - uint64(nLogPtr), uint64(nCkptPtr)) - nLog = int(int32(util.ReadUint32(c.mod, nLogPtr))) - nCkpt = int(int32(util.ReadUint32(c.mod, nCkptPtr))) - return nLog, nCkpt, c.error(r) + uint64(nLogPtr), uint64(nCkptPtr))) + nLog = int(util.Read32[int32](c.mod, nLogPtr)) + nCkpt = int(util.Read32[int32](c.mod, nCkptPtr)) + return nLog, nCkpt, c.error(rc) } // WALAutoCheckpoint configures WAL auto-checkpoints. // // https://sqlite.org/c3ref/wal_autocheckpoint.html func (c *Conn) WALAutoCheckpoint(pages int) error { - r := c.call("sqlite3_wal_autocheckpoint", uint64(c.handle), uint64(pages)) - return c.error(r) + rc := res_t(c.call("sqlite3_wal_autocheckpoint", uint64(c.handle), uint64(pages))) + return c.error(rc) } // WALHook registers a callback function to be invoked @@ -298,7 +298,7 @@ func (c *Conn) WALHook(cb func(db *Conn, schema string, pages int) error) { c.wal = cb } -func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema uint32, pages int32) (rc uint32) { +func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema ptr_t, pages int32) (rc uint32) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.wal != nil { schema := util.ReadString(mod, zSchema, _MAX_NAME) err := c.wal(c, schema, int(pages)) @@ -311,15 +311,15 @@ func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema uint32, pa // // https://sqlite.org/c3ref/autovacuum_pages.html func (c *Conn) AutoVacuumPages(cb func(schema string, dbPages, freePages, bytesPerPage uint) uint) error { - var funcPtr uint32 + var funcPtr ptr_t if cb != nil { funcPtr = util.AddHandle(c.ctx, cb) } - r := c.call("sqlite3_autovacuum_pages_go", uint64(c.handle), uint64(funcPtr)) - return c.error(r) + rc := res_t(c.call("sqlite3_autovacuum_pages_go", uint64(c.handle), uint64(funcPtr))) + return c.error(rc) } -func autoVacuumCallback(ctx context.Context, mod api.Module, pApp, zSchema, nDbPage, nFreePage, nBytePerPage uint32) uint32 { +func autoVacuumCallback(ctx context.Context, mod api.Module, pApp, zSchema ptr_t, nDbPage, nFreePage, nBytePerPage uint32) uint32 { fn := util.GetHandle(ctx, pApp).(func(schema string, dbPages, freePages, bytesPerPage uint) uint) schema := util.ReadString(mod, zSchema, _MAX_NAME) return uint32(fn(schema, uint(nDbPage), uint(nFreePage), uint(nBytePerPage))) diff --git a/conn.go b/conn.go index 862d4306..394dd657 100644 --- a/conn.go +++ b/conn.go @@ -39,7 +39,7 @@ type Conn struct { busy1st time.Time busylst time.Time - handle uint32 + handle ptr_t } // Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI]. @@ -102,16 +102,16 @@ func newConn(ctx context.Context, filename string, flags OpenFlag) (res *Conn, _ return c, nil } -func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { +func (c *Conn) openDB(filename string, flags OpenFlag) (ptr_t, error) { defer c.arena.mark()() connPtr := c.arena.new(ptrlen) namePtr := c.arena.string(filename) flags |= OPEN_EXRESCODE - r := c.call("sqlite3_open_v2", uint64(namePtr), uint64(connPtr), uint64(flags), 0) + rc := res_t(c.call("sqlite3_open_v2", uint64(namePtr), uint64(connPtr), uint64(flags), 0)) - handle := util.ReadUint32(c.mod, connPtr) - if err := c.sqlite.error(r, handle); err != nil { + handle := util.Read32[ptr_t](c.mod, connPtr) + if err := c.sqlite.error(rc, handle); err != nil { c.closeDB(handle) return 0, err } @@ -130,8 +130,8 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { if pragmas.Len() != 0 { c.checkInterrupt(handle) pragmaPtr := c.arena.string(pragmas.String()) - r := c.call("sqlite3_exec", uint64(handle), uint64(pragmaPtr), 0, 0, 0) - if err := c.sqlite.error(r, handle, pragmas.String()); err != nil { + rc := res_t(c.call("sqlite3_exec", uint64(handle), uint64(pragmaPtr), 0, 0, 0)) + if err := c.sqlite.error(rc, handle, pragmas.String()); err != nil { err = fmt.Errorf("sqlite3: invalid _pragma: %w", err) c.closeDB(handle) return 0, err @@ -141,9 +141,9 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { return handle, nil } -func (c *Conn) closeDB(handle uint32) { - r := c.call("sqlite3_close_v2", uint64(handle)) - if err := c.sqlite.error(r, handle); err != nil { +func (c *Conn) closeDB(handle ptr_t) { + rc := res_t(c.call("sqlite3_close_v2", uint64(handle))) + if err := c.sqlite.error(rc, handle); err != nil { panic(err) } } @@ -165,8 +165,8 @@ func (c *Conn) Close() error { c.pending.Close() c.pending = nil - r := c.call("sqlite3_close", uint64(c.handle)) - if err := c.error(r); err != nil { + rc := res_t(c.call("sqlite3_close", uint64(c.handle))) + if err := c.error(rc); err != nil { return err } @@ -183,8 +183,8 @@ func (c *Conn) Exec(sql string) error { sqlPtr := c.arena.string(sql) c.checkInterrupt(c.handle) - r := c.call("sqlite3_exec", uint64(c.handle), uint64(sqlPtr), 0, 0, 0) - return c.error(r, sql) + rc := res_t(c.call("sqlite3_exec", uint64(c.handle), uint64(sqlPtr), 0, 0, 0)) + return c.error(rc, sql) } // Prepare calls [Conn.PrepareFlags] with no flags. @@ -209,17 +209,17 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str sqlPtr := c.arena.string(sql) c.checkInterrupt(c.handle) - r := c.call("sqlite3_prepare_v3", uint64(c.handle), + rc := res_t(c.call("sqlite3_prepare_v3", uint64(c.handle), uint64(sqlPtr), uint64(len(sql)+1), uint64(flags), - uint64(stmtPtr), uint64(tailPtr)) + uint64(stmtPtr), uint64(tailPtr))) stmt = &Stmt{c: c} - stmt.handle = util.ReadUint32(c.mod, stmtPtr) - if sql := sql[util.ReadUint32(c.mod, tailPtr)-sqlPtr:]; sql != "" { + stmt.handle = util.Read32[ptr_t](c.mod, stmtPtr) + if sql := sql[util.Read32[ptr_t](c.mod, tailPtr)-sqlPtr:]; sql != "" { tail = sql } - if err := c.error(r, sql); err != nil { + if err := c.error(rc, sql); err != nil { return nil, "", err } if stmt.handle == 0 { @@ -233,9 +233,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str // // https://sqlite.org/c3ref/db_name.html func (c *Conn) DBName(n int) string { - r := c.call("sqlite3_db_name", uint64(c.handle), uint64(n)) - - ptr := uint32(r) + ptr := ptr_t(c.call("sqlite3_db_name", uint64(c.handle), uint64(n))) if ptr == 0 { return "" } @@ -246,34 +244,34 @@ func (c *Conn) DBName(n int) string { // // https://sqlite.org/c3ref/db_filename.html func (c *Conn) Filename(schema string) *vfs.Filename { - var ptr uint32 + var ptr ptr_t if schema != "" { defer c.arena.mark()() ptr = c.arena.string(schema) } - r := c.call("sqlite3_db_filename", uint64(c.handle), uint64(ptr)) - return vfs.GetFilename(c.ctx, c.mod, uint32(r), vfs.OPEN_MAIN_DB) + ptr = ptr_t(c.call("sqlite3_db_filename", uint64(c.handle), uint64(ptr))) + return vfs.GetFilename(c.ctx, c.mod, ptr, vfs.OPEN_MAIN_DB) } // ReadOnly determines if a database is read-only. // // https://sqlite.org/c3ref/db_readonly.html func (c *Conn) ReadOnly(schema string) (ro bool, ok bool) { - var ptr uint32 + var ptr ptr_t if schema != "" { defer c.arena.mark()() ptr = c.arena.string(schema) } - r := c.call("sqlite3_db_readonly", uint64(c.handle), uint64(ptr)) - return int32(r) > 0, int32(r) < 0 + b := int32(c.call("sqlite3_db_readonly", uint64(c.handle), uint64(ptr))) + return b > 0, b < 0 } // GetAutocommit tests the connection for auto-commit mode. // // https://sqlite.org/c3ref/get_autocommit.html func (c *Conn) GetAutocommit() bool { - r := c.call("sqlite3_get_autocommit", uint64(c.handle)) - return r != 0 + b := int32(c.call("sqlite3_get_autocommit", uint64(c.handle))) + return b != 0 } // LastInsertRowID returns the rowid of the most recent successful INSERT @@ -281,8 +279,7 @@ func (c *Conn) GetAutocommit() bool { // // https://sqlite.org/c3ref/last_insert_rowid.html func (c *Conn) LastInsertRowID() int64 { - r := c.call("sqlite3_last_insert_rowid", uint64(c.handle)) - return int64(r) + return int64(c.call("sqlite3_last_insert_rowid", uint64(c.handle))) } // SetLastInsertRowID allows the application to set the value returned by @@ -299,8 +296,7 @@ func (c *Conn) SetLastInsertRowID(id int64) { // // https://sqlite.org/c3ref/changes.html func (c *Conn) Changes() int64 { - r := c.call("sqlite3_changes64", uint64(c.handle)) - return int64(r) + return int64(c.call("sqlite3_changes64", uint64(c.handle))) } // TotalChanges returns the number of rows modified, inserted or deleted @@ -309,16 +305,15 @@ func (c *Conn) Changes() int64 { // // https://sqlite.org/c3ref/total_changes.html func (c *Conn) TotalChanges() int64 { - r := c.call("sqlite3_total_changes64", uint64(c.handle)) - return int64(r) + return int64(c.call("sqlite3_total_changes64", uint64(c.handle))) } // ReleaseMemory frees memory used by a database connection. // // https://sqlite.org/c3ref/db_release_memory.html func (c *Conn) ReleaseMemory() error { - r := c.call("sqlite3_db_release_memory", uint64(c.handle)) - return c.error(r) + rc := res_t(c.call("sqlite3_db_release_memory", uint64(c.handle))) + return c.error(rc) } // GetInterrupt gets the context set with [Conn.SetInterrupt]. @@ -357,7 +352,7 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { c.call("sqlite3_prepare_v3", uint64(c.handle), uint64(loopPtr), math.MaxUint64, uint64(PREPARE_PERSISTENT), uint64(stmtPtr), 0) c.pending = &Stmt{c: c} - c.pending.handle = util.ReadUint32(c.mod, stmtPtr) + c.pending.handle = util.Read32[ptr_t](c.mod, stmtPtr) } if old.Done() != nil && ctx.Err() == nil { @@ -369,7 +364,7 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { return old } -func (c *Conn) checkInterrupt(handle uint32) { +func (c *Conn) checkInterrupt(handle ptr_t) { if c.interrupt.Err() != nil { c.call("sqlite3_interrupt", uint64(handle)) } @@ -392,8 +387,8 @@ func progressCallback(ctx context.Context, mod api.Module, _ uint32) (interrupt // https://sqlite.org/c3ref/busy_timeout.html func (c *Conn) BusyTimeout(timeout time.Duration) error { ms := min((timeout+time.Millisecond-1)/time.Millisecond, math.MaxInt32) - r := c.call("sqlite3_busy_timeout", uint64(c.handle), uint64(ms)) - return c.error(r) + rc := res_t(c.call("sqlite3_busy_timeout", uint64(c.handle), uint64(ms))) + return c.error(rc) } func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (retry uint32) { @@ -423,15 +418,15 @@ func (c *Conn) BusyHandler(cb func(ctx context.Context, count int) (retry bool)) if cb != nil { enable = 1 } - r := c.call("sqlite3_busy_handler_go", uint64(c.handle), enable) - if err := c.error(r); err != nil { + rc := res_t(c.call("sqlite3_busy_handler_go", uint64(c.handle), enable)) + if err := c.error(rc); err != nil { return err } c.busy = cb return nil } -func busyCallback(ctx context.Context, mod api.Module, pDB uint32, count int32) (retry uint32) { +func busyCallback(ctx context.Context, mod api.Module, pDB ptr_t, count int32) (retry uint32) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil { interrupt := c.interrupt if interrupt == nil { @@ -457,11 +452,11 @@ func (c *Conn) Status(op DBStatus, reset bool) (current, highwater int, err erro i = 1 } - r := c.call("sqlite3_db_status", uint64(c.handle), - uint64(op), uint64(curPtr), uint64(hiPtr), i) - if err = c.error(r); err == nil { - current = int(util.ReadUint32(c.mod, curPtr)) - highwater = int(util.ReadUint32(c.mod, hiPtr)) + rc := res_t(c.call("sqlite3_db_status", uint64(c.handle), + uint64(op), uint64(curPtr), uint64(hiPtr), i)) + if err = c.error(rc); err == nil { + current = int(util.Read32[int32](c.mod, curPtr)) + highwater = int(util.Read32[int32](c.mod, hiPtr)) } return } @@ -472,7 +467,7 @@ func (c *Conn) Status(op DBStatus, reset bool) (current, highwater int, err erro func (c *Conn) TableColumnMetadata(schema, table, column string) (declType, collSeq string, notNull, primaryKey, autoInc bool, err error) { defer c.arena.mark()() - var schemaPtr, columnPtr uint32 + var schemaPtr, columnPtr ptr_t declTypePtr := c.arena.new(ptrlen) collSeqPtr := c.arena.new(ptrlen) notNullPtr := c.arena.new(ptrlen) @@ -486,25 +481,25 @@ func (c *Conn) TableColumnMetadata(schema, table, column string) (declType, coll columnPtr = c.arena.string(column) } - r := c.call("sqlite3_table_column_metadata", uint64(c.handle), + rc := res_t(c.call("sqlite3_table_column_metadata", uint64(c.handle), uint64(schemaPtr), uint64(tablePtr), uint64(columnPtr), uint64(declTypePtr), uint64(collSeqPtr), - uint64(notNullPtr), uint64(primaryKeyPtr), uint64(autoIncPtr)) - if err = c.error(r); err == nil && column != "" { - if ptr := util.ReadUint32(c.mod, declTypePtr); ptr != 0 { + uint64(notNullPtr), uint64(primaryKeyPtr), uint64(autoIncPtr))) + if err = c.error(rc); err == nil && column != "" { + if ptr := util.Read32[ptr_t](c.mod, declTypePtr); ptr != 0 { declType = util.ReadString(c.mod, ptr, _MAX_NAME) } - if ptr := util.ReadUint32(c.mod, collSeqPtr); ptr != 0 { + if ptr := util.Read32[ptr_t](c.mod, collSeqPtr); ptr != 0 { collSeq = util.ReadString(c.mod, ptr, _MAX_NAME) } - notNull = util.ReadUint32(c.mod, notNullPtr) != 0 - autoInc = util.ReadUint32(c.mod, autoIncPtr) != 0 - primaryKey = util.ReadUint32(c.mod, primaryKeyPtr) != 0 + notNull = util.Read32[uint32](c.mod, notNullPtr) != 0 + autoInc = util.Read32[uint32](c.mod, autoIncPtr) != 0 + primaryKey = util.Read32[uint32](c.mod, primaryKeyPtr) != 0 } return } -func (c *Conn) error(rc uint64, sql ...string) error { +func (c *Conn) error(rc res_t, sql ...string) error { return c.sqlite.error(rc, c.handle, sql...) } diff --git a/const.go b/const.go index d4908de0..60d2bdc6 100644 --- a/const.go +++ b/const.go @@ -1,6 +1,10 @@ package sqlite3 -import "strconv" +import ( + "strconv" + + "github.com/ncruces/go-sqlite3/internal/util" +) const ( _OK = 0 /* Successful result */ @@ -12,8 +16,13 @@ const ( _MAX_SQL_LENGTH = 1e9 _MAX_FUNCTION_ARG = 100 - ptrlen = 4 - intlen = 4 + ptrlen = util.PtrLen + intlen = util.IntLen +) + +type ( + ptr_t = util.Ptr_t + res_t = util.Res_t ) // ErrorCode is a result code that [Error.Code] might return. diff --git a/context.go b/context.go index 86be214e..34ee92f1 100644 --- a/context.go +++ b/context.go @@ -15,7 +15,7 @@ import ( // https://sqlite.org/c3ref/context.html type Context struct { c *Conn - handle uint32 + handle ptr_t } // Conn returns the database connection of the @@ -39,7 +39,7 @@ func (ctx Context) SetAuxData(n int, data any) { // // https://sqlite.org/c3ref/get_auxdata.html func (ctx Context) GetAuxData(n int) any { - ptr := uint32(ctx.c.call("sqlite3_get_auxdata", uint64(ctx.handle), uint64(n))) + ptr := ptr_t(ctx.c.call("sqlite3_get_auxdata", uint64(ctx.handle), uint64(n))) return util.GetHandle(ctx.c.ctx, ptr) } @@ -223,6 +223,6 @@ func (ctx Context) ResultError(err error) { // // https://sqlite.org/c3ref/vtab_nochange.html func (ctx Context) VTabNoChange() bool { - r := ctx.c.call("sqlite3_vtab_nochange", uint64(ctx.handle)) - return r != 0 + b := int32(ctx.c.call("sqlite3_vtab_nochange", uint64(ctx.handle))) + return b != 0 } diff --git a/error.go b/error.go index 870aa3ab..3799416e 100644 --- a/error.go +++ b/error.go @@ -15,7 +15,7 @@ type Error struct { str string msg string sql string - code uint64 + code res_t } // Code returns the primary error code for this error. diff --git a/error_test.go b/error_test.go index 2204fa8b..1cdc804a 100644 --- a/error_test.go +++ b/error_test.go @@ -59,14 +59,14 @@ func TestError_Temporary(t *testing.T) { tests := []struct { name string - code uint64 + code res_t want bool }{ - {"ERROR", uint64(ERROR), false}, - {"BUSY", uint64(BUSY), true}, - {"BUSY_RECOVERY", uint64(BUSY_RECOVERY), true}, - {"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), true}, - {"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true}, + {"ERROR", res_t(ERROR), false}, + {"BUSY", res_t(BUSY), true}, + {"BUSY_RECOVERY", res_t(BUSY_RECOVERY), true}, + {"BUSY_SNAPSHOT", res_t(BUSY_SNAPSHOT), true}, + {"BUSY_TIMEOUT", res_t(BUSY_TIMEOUT), true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -97,14 +97,14 @@ func TestError_Timeout(t *testing.T) { tests := []struct { name string - code uint64 + code res_t want bool }{ - {"ERROR", uint64(ERROR), false}, - {"BUSY", uint64(BUSY), false}, - {"BUSY_RECOVERY", uint64(BUSY_RECOVERY), false}, - {"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), false}, - {"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true}, + {"ERROR", res_t(ERROR), false}, + {"BUSY", res_t(BUSY), false}, + {"BUSY_RECOVERY", res_t(BUSY_RECOVERY), false}, + {"BUSY_SNAPSHOT", res_t(BUSY_SNAPSHOT), false}, + {"BUSY_TIMEOUT", res_t(BUSY_TIMEOUT), true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -136,8 +136,8 @@ func Test_ErrorCode_Error(t *testing.T) { // Test all error codes. for i := 0; i == int(ErrorCode(i)); i++ { want := "sqlite3: " - r := db.call("sqlite3_errstr", uint64(i)) - want += util.ReadString(db.mod, uint32(r), _MAX_NAME) + ptr := ptr_t(db.call("sqlite3_errstr", uint64(i))) + want += util.ReadString(db.mod, ptr, _MAX_NAME) got := ErrorCode(i).Error() if got != want { @@ -158,8 +158,8 @@ func Test_ExtendedErrorCode_Error(t *testing.T) { // Test all extended error codes. for i := 0; i == int(ExtendedErrorCode(i)); i++ { want := "sqlite3: " - r := db.call("sqlite3_errstr", uint64(i)) - want += util.ReadString(db.mod, uint32(r), _MAX_NAME) + ptr := ptr_t(db.call("sqlite3_errstr", uint64(i))) + want += util.ReadString(db.mod, ptr, _MAX_NAME) got := ExtendedErrorCode(i).Error() if got != want { diff --git a/func.go b/func.go index c416e695..cdb4e8e6 100644 --- a/func.go +++ b/func.go @@ -18,8 +18,8 @@ func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error { if cb != nil { enable = 1 } - r := c.call("sqlite3_collation_needed_go", uint64(c.handle), enable) - if err := c.error(r); err != nil { + rc := res_t(c.call("sqlite3_collation_needed_go", uint64(c.handle), enable)) + if err := c.error(rc); err != nil { return err } c.collation = cb @@ -33,8 +33,8 @@ func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error { // This can be used to load schemas that contain // one or more unknown collating sequences. func (c Conn) AnyCollationNeeded() error { - r := c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0) - if err := c.error(r); err != nil { + rc := res_t(c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0)) + if err := c.error(rc); err != nil { return err } c.collation = nil @@ -45,31 +45,31 @@ func (c Conn) AnyCollationNeeded() error { // // https://sqlite.org/c3ref/create_collation.html func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { - var funcPtr uint32 + var funcPtr ptr_t defer c.arena.mark()() namePtr := c.arena.string(name) if fn != nil { funcPtr = util.AddHandle(c.ctx, fn) } - r := c.call("sqlite3_create_collation_go", - uint64(c.handle), uint64(namePtr), uint64(funcPtr)) - return c.error(r) + rc := res_t(c.call("sqlite3_create_collation_go", + uint64(c.handle), uint64(namePtr), uint64(funcPtr))) + return c.error(rc) } // CreateFunction defines a new scalar SQL function. // // https://sqlite.org/c3ref/create_function.html func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn ScalarFunction) error { - var funcPtr uint32 + var funcPtr ptr_t defer c.arena.mark()() namePtr := c.arena.string(name) if fn != nil { funcPtr = util.AddHandle(c.ctx, fn) } - r := c.call("sqlite3_create_function_go", + rc := res_t(c.call("sqlite3_create_function_go", uint64(c.handle), uint64(namePtr), uint64(nArg), - uint64(flag), uint64(funcPtr)) - return c.error(r) + uint64(flag), uint64(funcPtr))) + return c.error(rc) } // ScalarFunction is the type of a scalar SQL function. @@ -82,7 +82,7 @@ type ScalarFunction func(ctx Context, arg ...Value) // // https://sqlite.org/c3ref/create_function.html func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error { - var funcPtr uint32 + var funcPtr ptr_t defer c.arena.mark()() namePtr := c.arena.string(name) if fn != nil { @@ -92,10 +92,10 @@ func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn if _, ok := fn().(WindowFunction); ok { call = "sqlite3_create_window_function_go" } - r := c.call(call, + rc := res_t(c.call(call, uint64(c.handle), uint64(namePtr), uint64(nArg), - uint64(flag), uint64(funcPtr)) - return c.error(r) + uint64(flag), uint64(funcPtr))) + return c.error(rc) } // AggregateFunction is the interface an aggregate function should implement. @@ -129,28 +129,28 @@ type WindowFunction interface { func (c *Conn) OverloadFunction(name string, nArg int) error { defer c.arena.mark()() namePtr := c.arena.string(name) - r := c.call("sqlite3_overload_function", - uint64(c.handle), uint64(namePtr), uint64(nArg)) - return c.error(r) + rc := res_t(c.call("sqlite3_overload_function", + uint64(c.handle), uint64(namePtr), uint64(nArg))) + return c.error(rc) } -func destroyCallback(ctx context.Context, mod api.Module, pApp uint32) { +func destroyCallback(ctx context.Context, mod api.Module, pApp ptr_t) { util.DelHandle(ctx, pApp) } -func collationCallback(ctx context.Context, mod api.Module, pArg, pDB, eTextRep, zName uint32) { +func collationCallback(ctx context.Context, mod api.Module, pArg, pDB ptr_t, eTextRep uint32, zName ptr_t) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.collation != nil { name := util.ReadString(mod, zName, _MAX_NAME) c.collation(c, name) } } -func compareCallback(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 { +func compareCallback(ctx context.Context, mod api.Module, pApp ptr_t, nKey1 uint32, pKey1 ptr_t, nKey2 uint32, pKey2 ptr_t) uint32 { fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int) return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2)))) } -func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp, nArg, pArg uint32) { +func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp ptr_t, nArg uint32, pArg ptr_t) { args := getFuncArgs() defer putFuncArgs(args) db := ctx.Value(connKey{}).(*Conn) @@ -159,7 +159,7 @@ func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp, nArg, pArg ui fn(Context{db, pCtx}, args[:nArg]...) } -func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp, nArg, pArg uint32) { +func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, nArg uint32, pArg ptr_t) { args := getFuncArgs() defer putFuncArgs(args) db := ctx.Value(connKey{}).(*Conn) @@ -168,7 +168,7 @@ func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp, nArg, p fn.Step(Context{db, pCtx}, args[:nArg]...) } -func finalCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp uint32) { +func finalCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t) { db := ctx.Value(connKey{}).(*Conn) fn, handle := callbackAggregate(db, pAgg, pApp) fn.Value(Context{db, pCtx}) @@ -178,13 +178,13 @@ func finalCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp uint32) } } -func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg uint32) { +func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t) { db := ctx.Value(connKey{}).(*Conn) fn := util.GetHandle(db.ctx, pAgg).(AggregateFunction) fn.Value(Context{db, pCtx}) } -func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg, nArg, pArg uint32) { +func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t, nArg uint32, pArg ptr_t) { args := getFuncArgs() defer putFuncArgs(args) db := ctx.Value(connKey{}).(*Conn) @@ -193,9 +193,9 @@ func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg, nArg, pArg fn.Inverse(Context{db, pCtx}, args[:nArg]...) } -func callbackAggregate(db *Conn, pAgg, pApp uint32) (AggregateFunction, uint32) { +func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) { if pApp == 0 { - handle := util.ReadUint32(db.mod, pAgg) + handle := util.Read32[ptr_t](db.mod, pAgg) return util.GetHandle(db.ctx, handle).(AggregateFunction), handle } @@ -203,17 +203,17 @@ func callbackAggregate(db *Conn, pAgg, pApp uint32) (AggregateFunction, uint32) fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)() if pAgg != 0 { handle := util.AddHandle(db.ctx, fn) - util.WriteUint32(db.mod, pAgg, handle) + util.Write32(db.mod, pAgg, handle) return fn, handle } return fn, 0 } -func callbackArgs(db *Conn, arg []Value, pArg uint32) { +func callbackArgs(db *Conn, arg []Value, pArg ptr_t) { for i := range arg { arg[i] = Value{ c: db, - handle: util.ReadUint32(db.mod, pArg+ptrlen*uint32(i)), + handle: util.Read32[ptr_t](db.mod, pArg+ptrlen*ptr_t(i)), } } } diff --git a/internal/util/func.go b/internal/util/func.go index 468ff741..d310afc2 100644 --- a/internal/util/func.go +++ b/internal/util/func.go @@ -7,6 +7,7 @@ import ( "github.com/tetratelabs/wazero/api" ) +type i8 interface{ ~int8 | ~uint8 } type i32 interface{ ~int32 | ~uint32 } type i64 interface{ ~int64 | ~uint64 } diff --git a/internal/util/handle.go b/internal/util/handle.go index e4e33854..f9f39b44 100644 --- a/internal/util/handle.go +++ b/internal/util/handle.go @@ -20,7 +20,7 @@ func (s *handleState) CloseNotify(ctx context.Context, exitCode uint32) { s.holes = 0 } -func GetHandle(ctx context.Context, id uint32) any { +func GetHandle(ctx context.Context, id Ptr_t) any { if id == 0 { return nil } @@ -28,14 +28,14 @@ func GetHandle(ctx context.Context, id uint32) any { return s.handles[^id] } -func DelHandle(ctx context.Context, id uint32) error { +func DelHandle(ctx context.Context, id Ptr_t) error { if id == 0 { return nil } s := ctx.Value(moduleKey{}).(*moduleState) a := s.handles[^id] s.handles[^id] = nil - if l := uint32(len(s.handles)); l == ^id { + if l := Ptr_t(len(s.handles)); l == ^id { s.handles = s.handles[:l-1] } else { s.holes++ @@ -46,7 +46,7 @@ func DelHandle(ctx context.Context, id uint32) error { return nil } -func AddHandle(ctx context.Context, a any) uint32 { +func AddHandle(ctx context.Context, a any) Ptr_t { if a == nil { panic(NilErr) } @@ -59,12 +59,12 @@ func AddHandle(ctx context.Context, a any) uint32 { if h == nil { s.holes-- s.handles[id] = a - return ^uint32(id) + return ^Ptr_t(id) } } } // Add a new slot. s.handles = append(s.handles, a) - return -uint32(len(s.handles)) + return -Ptr_t(len(s.handles)) } diff --git a/internal/util/mem.go b/internal/util/mem.go index a09523fd..7172ab08 100644 --- a/internal/util/mem.go +++ b/internal/util/mem.go @@ -7,7 +7,17 @@ import ( "github.com/tetratelabs/wazero/api" ) -func View(mod api.Module, ptr uint32, size uint64) []byte { +const ( + PtrLen = 4 + IntLen = 4 +) + +type ( + Ptr_t uint32 + Res_t int32 +) + +func View(mod api.Module, ptr Ptr_t, size uint64) []byte { if ptr == 0 { panic(NilErr) } @@ -17,85 +27,85 @@ func View(mod api.Module, ptr uint32, size uint64) []byte { if size == 0 { return nil } - buf, ok := mod.Memory().Read(ptr, uint32(size)) + buf, ok := mod.Memory().Read(uint32(ptr), uint32(size)) if !ok { panic(RangeErr) } return buf } -func ReadUint8(mod api.Module, ptr uint32) uint8 { +func Read[T i8](mod api.Module, ptr Ptr_t) T { if ptr == 0 { panic(NilErr) } - v, ok := mod.Memory().ReadByte(ptr) + v, ok := mod.Memory().ReadByte(uint32(ptr)) if !ok { panic(RangeErr) } - return v + return T(v) } -func ReadUint32(mod api.Module, ptr uint32) uint32 { +func Write[T i8](mod api.Module, ptr Ptr_t, v T) { if ptr == 0 { panic(NilErr) } - v, ok := mod.Memory().ReadUint32Le(ptr) + ok := mod.Memory().WriteByte(uint32(ptr), uint8(v)) if !ok { panic(RangeErr) } - return v } -func WriteUint8(mod api.Module, ptr uint32, v uint8) { +func Read32[T i32](mod api.Module, ptr Ptr_t) T { if ptr == 0 { panic(NilErr) } - ok := mod.Memory().WriteByte(ptr, v) + v, ok := mod.Memory().ReadUint32Le(uint32(ptr)) if !ok { panic(RangeErr) } + return T(v) } -func WriteUint32(mod api.Module, ptr uint32, v uint32) { +func Write32[T i32](mod api.Module, ptr Ptr_t, v T) { if ptr == 0 { panic(NilErr) } - ok := mod.Memory().WriteUint32Le(ptr, v) + ok := mod.Memory().WriteUint32Le(uint32(ptr), uint32(v)) if !ok { panic(RangeErr) } } -func ReadUint64(mod api.Module, ptr uint32) uint64 { +func Read64[T i64](mod api.Module, ptr Ptr_t) T { if ptr == 0 { panic(NilErr) } - v, ok := mod.Memory().ReadUint64Le(ptr) + v, ok := mod.Memory().ReadUint64Le(uint32(ptr)) if !ok { panic(RangeErr) } - return v + return T(v) } -func WriteUint64(mod api.Module, ptr uint32, v uint64) { +func Write64[T i64](mod api.Module, ptr Ptr_t, v T) { if ptr == 0 { panic(NilErr) } - ok := mod.Memory().WriteUint64Le(ptr, v) + ok := mod.Memory().WriteUint64Le(uint32(ptr), uint64(v)) if !ok { panic(RangeErr) } } -func ReadFloat64(mod api.Module, ptr uint32) float64 { - return math.Float64frombits(ReadUint64(mod, ptr)) +func ReadFloat64(mod api.Module, ptr Ptr_t) float64 { + return math.Float64frombits(Read64[uint64](mod, ptr)) } -func WriteFloat64(mod api.Module, ptr uint32, v float64) { - WriteUint64(mod, ptr, math.Float64bits(v)) +func WriteFloat64(mod api.Module, ptr Ptr_t, v float64) { + Write64(mod, ptr, math.Float64bits(v)) } -func ReadString(mod api.Module, ptr, maxlen uint32) string { +func ReadString(mod api.Module, ptr Ptr_t, maxlen uint32) string { if ptr == 0 { panic(NilErr) } @@ -108,9 +118,9 @@ func ReadString(mod api.Module, ptr, maxlen uint32) string { maxlen = maxlen + 1 } mem := mod.Memory() - buf, ok := mem.Read(ptr, maxlen) + buf, ok := mem.Read(uint32(ptr), maxlen) if !ok { - buf, ok = mem.Read(ptr, mem.Size()-ptr) + buf, ok = mem.Read(uint32(ptr), mem.Size()-uint32(ptr)) if !ok { panic(RangeErr) } @@ -122,12 +132,12 @@ func ReadString(mod api.Module, ptr, maxlen uint32) string { } } -func WriteBytes(mod api.Module, ptr uint32, b []byte) { +func WriteBytes(mod api.Module, ptr Ptr_t, b []byte) { buf := View(mod, ptr, uint64(len(b))) copy(buf, b) } -func WriteString(mod api.Module, ptr uint32, s string) { +func WriteString(mod api.Module, ptr Ptr_t, s string) { buf := View(mod, ptr, uint64(len(s)+1)) buf[len(s)] = 0 copy(buf, s) diff --git a/internal/util/mem_test.go b/internal/util/mem_test.go index 733ab344..0226e7b0 100644 --- a/internal/util/mem_test.go +++ b/internal/util/mem_test.go @@ -31,84 +31,84 @@ func TestView_overflow(t *testing.T) { func TestReadUint8_nil(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - ReadUint8(mock, 0) + Read[byte](mock, 0) t.Error("want panic") } func TestReadUint8_range(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - ReadUint8(mock, wazerotest.PageSize) + Read[byte](mock, wazerotest.PageSize) t.Error("want panic") } func TestReadUint32_nil(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - ReadUint32(mock, 0) + Read32[uint32](mock, 0) t.Error("want panic") } func TestReadUint32_range(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - ReadUint32(mock, wazerotest.PageSize-2) + Read32[uint32](mock, wazerotest.PageSize-2) t.Error("want panic") } func TestReadUint64_nil(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - ReadUint64(mock, 0) + Read64[uint64](mock, 0) t.Error("want panic") } func TestReadUint64_range(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - ReadUint64(mock, wazerotest.PageSize-2) + Read64[uint64](mock, wazerotest.PageSize-2) t.Error("want panic") } func TestWriteUint8_nil(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - WriteUint8(mock, 0, 1) + Write[byte](mock, 0, 1) t.Error("want panic") } func TestWriteUint8_range(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - WriteUint8(mock, wazerotest.PageSize, 1) + Write[byte](mock, wazerotest.PageSize, 1) t.Error("want panic") } func TestWriteUint32_nil(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - WriteUint32(mock, 0, 1) + Write32[uint32](mock, 0, 1) t.Error("want panic") } func TestWriteUint32_range(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - WriteUint32(mock, wazerotest.PageSize-2, 1) + Write32[uint32](mock, wazerotest.PageSize-2, 1) t.Error("want panic") } func TestWriteUint64_nil(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - WriteUint64(mock, 0, 1) + Write64[uint64](mock, 0, 1) t.Error("want panic") } func TestWriteUint64_range(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - WriteUint64(mock, wazerotest.PageSize-2, 1) + Write64[uint64](mock, wazerotest.PageSize-2, 1) t.Error("want panic") } diff --git a/internal/util/mmap_unix.go b/internal/util/mmap_unix.go index 4ff05666..0c5363a7 100644 --- a/internal/util/mmap_unix.go +++ b/internal/util/mmap_unix.go @@ -37,7 +37,7 @@ func (s *mmapState) new(ctx context.Context, mod api.Module, size int32) *Mapped } // Save the newly allocated region. - ptr := uint32(stack[0]) + ptr := Ptr_t(stack[0]) buf := View(mod, ptr, uint64(size)) res := &MappedRegion{ Ptr: ptr, @@ -50,7 +50,7 @@ func (s *mmapState) new(ctx context.Context, mod api.Module, size int32) *Mapped type MappedRegion struct { addr unsafe.Pointer - Ptr uint32 + Ptr Ptr_t size int32 used bool } diff --git a/sqlite.go b/sqlite.go index 90352d07..d478dca6 100644 --- a/sqlite.go +++ b/sqlite.go @@ -3,7 +3,6 @@ package sqlite3 import ( "context" - "math" "math/bits" "os" "sync" @@ -120,7 +119,7 @@ func (sqlt *sqlite) close() error { return sqlt.mod.Close(sqlt.ctx) } -func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error { +func (sqlt *sqlite) error(rc res_t, handle ptr_t, sql ...string) error { if rc == _OK { return nil } @@ -131,18 +130,18 @@ func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error { panic(util.OOMErr) } - if r := sqlt.call("sqlite3_errstr", rc); r != 0 { - err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_NAME) + if ptr := ptr_t(sqlt.call("sqlite3_errstr", uint64(rc))); ptr != 0 { + err.str = util.ReadString(sqlt.mod, ptr, _MAX_NAME) } if handle != 0 { - if r := sqlt.call("sqlite3_errmsg", uint64(handle)); r != 0 { - err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_LENGTH) + if ptr := ptr_t(sqlt.call("sqlite3_errmsg", uint64(handle))); ptr != 0 { + err.msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH) } if len(sql) != 0 { - if r := sqlt.call("sqlite3_error_offset", uint64(handle)); r != math.MaxUint32 { - err.sql = sql[0][r:] + if i := int32(sqlt.call("sqlite3_error_offset", uint64(handle))); i != -1 { + err.sql = sql[0][i:] } } } @@ -182,7 +181,9 @@ func (sqlt *sqlite) putfn(name string, fn api.Function) { } } -func (sqlt *sqlite) call(name string, params ...uint64) uint64 { +type stk64 uint64 + +func (sqlt *sqlite) call(name string, params ...uint64) stk64 { copy(sqlt.stack[:], params) fn := sqlt.getfn(name) err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:]) @@ -190,33 +191,33 @@ func (sqlt *sqlite) call(name string, params ...uint64) uint64 { panic(err) } sqlt.putfn(name, fn) - return sqlt.stack[0] + return stk64(sqlt.stack[0]) } -func (sqlt *sqlite) free(ptr uint32) { +func (sqlt *sqlite) free(ptr ptr_t) { if ptr == 0 { return } sqlt.call("sqlite3_free", uint64(ptr)) } -func (sqlt *sqlite) new(size uint64) uint32 { - ptr := uint32(sqlt.call("sqlite3_malloc64", size)) +func (sqlt *sqlite) new(size uint64) ptr_t { + ptr := ptr_t(sqlt.call("sqlite3_malloc64", size)) if ptr == 0 && size != 0 { panic(util.OOMErr) } return ptr } -func (sqlt *sqlite) realloc(ptr uint32, size uint64) uint32 { - ptr = uint32(sqlt.call("sqlite3_realloc64", uint64(ptr), size)) +func (sqlt *sqlite) realloc(ptr ptr_t, size uint64) ptr_t { + ptr = ptr_t(sqlt.call("sqlite3_realloc64", uint64(ptr), size)) if ptr == 0 && size != 0 { panic(util.OOMErr) } return ptr } -func (sqlt *sqlite) newBytes(b []byte) uint32 { +func (sqlt *sqlite) newBytes(b []byte) ptr_t { if (*[0]byte)(b) == nil { return 0 } @@ -229,7 +230,7 @@ func (sqlt *sqlite) newBytes(b []byte) uint32 { return ptr } -func (sqlt *sqlite) newString(s string) uint32 { +func (sqlt *sqlite) newString(s string) ptr_t { ptr := sqlt.new(uint64(len(s) + 1)) util.WriteString(sqlt.mod, ptr, s) return ptr @@ -247,8 +248,8 @@ func (sqlt *sqlite) newArena(size uint64) arena { type arena struct { sqlt *sqlite - ptrs []uint32 - base uint32 + ptrs []ptr_t + base ptr_t next uint32 size uint32 } @@ -277,7 +278,7 @@ func (a *arena) mark() (reset func()) { } } -func (a *arena) new(size uint64) uint32 { +func (a *arena) new(size uint64) ptr_t { // Align the next address, to 4 or 8 bytes. if size&7 != 0 { a.next = (a.next + 3) &^ 3 @@ -285,16 +286,16 @@ func (a *arena) new(size uint64) uint32 { a.next = (a.next + 7) &^ 7 } if size <= uint64(a.size-a.next) { - ptr := a.base + a.next + ptr := a.base + ptr_t(a.next) a.next += uint32(size) - return ptr + return ptr_t(ptr) } ptr := a.sqlt.new(size) a.ptrs = append(a.ptrs, ptr) - return ptr + return ptr_t(ptr) } -func (a *arena) bytes(b []byte) uint32 { +func (a *arena) bytes(b []byte) ptr_t { if (*[0]byte)(b) == nil { return 0 } @@ -303,7 +304,7 @@ func (a *arena) bytes(b []byte) uint32 { return ptr } -func (a *arena) string(s string) uint32 { +func (a *arena) string(s string) ptr_t { ptr := a.new(uint64(len(s) + 1)) util.WriteString(a.sqlt.mod, ptr, s) return ptr diff --git a/sqlite_test.go b/sqlite_test.go index fbcd069b..0e4b06cc 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -22,7 +22,7 @@ func Test_sqlite_error_OOM(t *testing.T) { defer sqlite.close() defer func() { _ = recover() }() - sqlite.error(uint64(NOMEM), 0) + sqlite.error(res_t(NOMEM), 0) t.Error("want panic") } diff --git a/stmt.go b/stmt.go index fdb13dcf..10430075 100644 --- a/stmt.go +++ b/stmt.go @@ -16,7 +16,7 @@ type Stmt struct { c *Conn err error sql string - handle uint32 + handle ptr_t } // Close destroys the prepared statement object. @@ -29,7 +29,7 @@ func (s *Stmt) Close() error { return nil } - r := s.c.call("sqlite3_finalize", uint64(s.handle)) + rc := res_t(s.c.call("sqlite3_finalize", uint64(s.handle))) stmts := s.c.stmts for i := range stmts { if s == stmts[i] { @@ -42,7 +42,7 @@ func (s *Stmt) Close() error { } s.handle = 0 - return s.c.error(r) + return s.c.error(rc) } // Conn returns the database connection to which the prepared statement belongs. @@ -64,9 +64,9 @@ func (s *Stmt) SQL() string { // // https://sqlite.org/c3ref/expanded_sql.html func (s *Stmt) ExpandedSQL() string { - r := s.c.call("sqlite3_expanded_sql", uint64(s.handle)) - sql := util.ReadString(s.c.mod, uint32(r), _MAX_SQL_LENGTH) - s.c.free(uint32(r)) + ptr := ptr_t(s.c.call("sqlite3_expanded_sql", uint64(s.handle))) + sql := util.ReadString(s.c.mod, ptr, _MAX_SQL_LENGTH) + s.c.free(ptr) return sql } @@ -75,25 +75,25 @@ func (s *Stmt) ExpandedSQL() string { // // https://sqlite.org/c3ref/stmt_readonly.html func (s *Stmt) ReadOnly() bool { - r := s.c.call("sqlite3_stmt_readonly", uint64(s.handle)) - return r != 0 + b := int32(s.c.call("sqlite3_stmt_readonly", uint64(s.handle))) + return b != 0 } // Reset resets the prepared statement object. // // https://sqlite.org/c3ref/reset.html func (s *Stmt) Reset() error { - r := s.c.call("sqlite3_reset", uint64(s.handle)) + rc := res_t(s.c.call("sqlite3_reset", uint64(s.handle))) s.err = nil - return s.c.error(r) + return s.c.error(rc) } // Busy determines if a prepared statement has been reset. // // https://sqlite.org/c3ref/stmt_busy.html func (s *Stmt) Busy() bool { - r := s.c.call("sqlite3_stmt_busy", uint64(s.handle)) - return r != 0 + rc := res_t(s.c.call("sqlite3_stmt_busy", uint64(s.handle))) + return rc != 0 } // Step evaluates the SQL statement. @@ -107,15 +107,15 @@ func (s *Stmt) Busy() bool { // https://sqlite.org/c3ref/step.html func (s *Stmt) Step() bool { s.c.checkInterrupt(s.c.handle) - r := s.c.call("sqlite3_step", uint64(s.handle)) - switch r { + rc := res_t(s.c.call("sqlite3_step", uint64(s.handle))) + switch rc { case _ROW: s.err = nil return true case _DONE: s.err = nil default: - s.err = s.c.error(r) + s.err = s.c.error(rc) } return false } @@ -147,26 +147,26 @@ func (s *Stmt) Status(op StmtStatus, reset bool) int { if reset { i = 1 } - r := s.c.call("sqlite3_stmt_status", uint64(s.handle), - uint64(op), i) - return int(int32(r)) + n := int32(s.c.call("sqlite3_stmt_status", uint64(s.handle), + uint64(op), i)) + return int(n) } // ClearBindings resets all bindings on the prepared statement. // // https://sqlite.org/c3ref/clear_bindings.html func (s *Stmt) ClearBindings() error { - r := s.c.call("sqlite3_clear_bindings", uint64(s.handle)) - return s.c.error(r) + rc := res_t(s.c.call("sqlite3_clear_bindings", uint64(s.handle))) + return s.c.error(rc) } // BindCount returns the number of SQL parameters in the prepared statement. // // https://sqlite.org/c3ref/bind_parameter_count.html func (s *Stmt) BindCount() int { - r := s.c.call("sqlite3_bind_parameter_count", - uint64(s.handle)) - return int(int32(r)) + n := int32(s.c.call("sqlite3_bind_parameter_count", + uint64(s.handle))) + return int(n) } // BindIndex returns the index of a parameter in the prepared statement @@ -176,9 +176,9 @@ func (s *Stmt) BindCount() int { func (s *Stmt) BindIndex(name string) int { defer s.c.arena.mark()() namePtr := s.c.arena.string(name) - r := s.c.call("sqlite3_bind_parameter_index", - uint64(s.handle), uint64(namePtr)) - return int(int32(r)) + i := int32(s.c.call("sqlite3_bind_parameter_index", + uint64(s.handle), uint64(namePtr))) + return int(i) } // BindName returns the name of a parameter in the prepared statement. @@ -186,10 +186,8 @@ func (s *Stmt) BindIndex(name string) int { // // https://sqlite.org/c3ref/bind_parameter_name.html func (s *Stmt) BindName(param int) string { - r := s.c.call("sqlite3_bind_parameter_name", - uint64(s.handle), uint64(param)) - - ptr := uint32(r) + ptr := ptr_t(s.c.call("sqlite3_bind_parameter_name", + uint64(s.handle), uint64(param))) if ptr == 0 { return "" } @@ -223,9 +221,9 @@ func (s *Stmt) BindInt(param int, value int) error { // // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindInt64(param int, value int64) error { - r := s.c.call("sqlite3_bind_int64", - uint64(s.handle), uint64(param), uint64(value)) - return s.c.error(r) + rc := res_t(s.c.call("sqlite3_bind_int64", + uint64(s.handle), uint64(param), uint64(value))) + return s.c.error(rc) } // BindFloat binds a float64 to the prepared statement. @@ -233,9 +231,9 @@ func (s *Stmt) BindInt64(param int, value int64) error { // // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindFloat(param int, value float64) error { - r := s.c.call("sqlite3_bind_double", - uint64(s.handle), uint64(param), math.Float64bits(value)) - return s.c.error(r) + rc := res_t(s.c.call("sqlite3_bind_double", + uint64(s.handle), uint64(param), math.Float64bits(value))) + return s.c.error(rc) } // BindText binds a string to the prepared statement. @@ -247,10 +245,10 @@ func (s *Stmt) BindText(param int, value string) error { return TOOBIG } ptr := s.c.newString(value) - r := s.c.call("sqlite3_bind_text_go", + rc := res_t(s.c.call("sqlite3_bind_text_go", uint64(s.handle), uint64(param), - uint64(ptr), uint64(len(value))) - return s.c.error(r) + uint64(ptr), uint64(len(value)))) + return s.c.error(rc) } // BindRawText binds a []byte to the prepared statement as text. @@ -263,10 +261,10 @@ func (s *Stmt) BindRawText(param int, value []byte) error { return TOOBIG } ptr := s.c.newBytes(value) - r := s.c.call("sqlite3_bind_text_go", + rc := res_t(s.c.call("sqlite3_bind_text_go", uint64(s.handle), uint64(param), - uint64(ptr), uint64(len(value))) - return s.c.error(r) + uint64(ptr), uint64(len(value)))) + return s.c.error(rc) } // BindBlob binds a []byte to the prepared statement. @@ -279,10 +277,10 @@ func (s *Stmt) BindBlob(param int, value []byte) error { return TOOBIG } ptr := s.c.newBytes(value) - r := s.c.call("sqlite3_bind_blob_go", + rc := res_t(s.c.call("sqlite3_bind_blob_go", uint64(s.handle), uint64(param), - uint64(ptr), uint64(len(value))) - return s.c.error(r) + uint64(ptr), uint64(len(value)))) + return s.c.error(rc) } // BindZeroBlob binds a zero-filled, length n BLOB to the prepared statement. @@ -290,9 +288,9 @@ func (s *Stmt) BindBlob(param int, value []byte) error { // // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindZeroBlob(param int, n int64) error { - r := s.c.call("sqlite3_bind_zeroblob64", - uint64(s.handle), uint64(param), uint64(n)) - return s.c.error(r) + rc := res_t(s.c.call("sqlite3_bind_zeroblob64", + uint64(s.handle), uint64(param), uint64(n))) + return s.c.error(rc) } // BindNull binds a NULL to the prepared statement. @@ -300,9 +298,9 @@ func (s *Stmt) BindZeroBlob(param int, n int64) error { // // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindNull(param int) error { - r := s.c.call("sqlite3_bind_null", - uint64(s.handle), uint64(param)) - return s.c.error(r) + rc := res_t(s.c.call("sqlite3_bind_null", + uint64(s.handle), uint64(param))) + return s.c.error(rc) } // BindTime binds a [time.Time] to the prepared statement. @@ -333,10 +331,10 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error { buf := util.View(s.c.mod, ptr, maxlen) buf = value.AppendFormat(buf[:0], time.RFC3339Nano) - r := s.c.call("sqlite3_bind_text_go", + rc := res_t(s.c.call("sqlite3_bind_text_go", uint64(s.handle), uint64(param), - uint64(ptr), uint64(len(buf))) - return s.c.error(r) + uint64(ptr), uint64(len(buf)))) + return s.c.error(rc) } // BindPointer binds a NULL to the prepared statement, just like [Stmt.BindNull], @@ -347,9 +345,9 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error { // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindPointer(param int, ptr any) error { valPtr := util.AddHandle(s.c.ctx, ptr) - r := s.c.call("sqlite3_bind_pointer_go", - uint64(s.handle), uint64(param), uint64(valPtr)) - return s.c.error(r) + rc := res_t(s.c.call("sqlite3_bind_pointer_go", + uint64(s.handle), uint64(param), uint64(valPtr))) + return s.c.error(rc) } // BindJSON binds the JSON encoding of value to the prepared statement. @@ -372,27 +370,27 @@ func (s *Stmt) BindValue(param int, value Value) error { if value.c != s.c { return MISUSE } - r := s.c.call("sqlite3_bind_value", - uint64(s.handle), uint64(param), uint64(value.handle)) - return s.c.error(r) + rc := res_t(s.c.call("sqlite3_bind_value", + uint64(s.handle), uint64(param), uint64(value.handle))) + return s.c.error(rc) } // DataCount resets the number of columns in a result set. // // https://sqlite.org/c3ref/data_count.html func (s *Stmt) DataCount() int { - r := s.c.call("sqlite3_data_count", - uint64(s.handle)) - return int(int32(r)) + n := int32(s.c.call("sqlite3_data_count", + uint64(s.handle))) + return int(n) } // ColumnCount returns the number of columns in a result set. // // https://sqlite.org/c3ref/column_count.html func (s *Stmt) ColumnCount() int { - r := s.c.call("sqlite3_column_count", - uint64(s.handle)) - return int(int32(r)) + n := int32(s.c.call("sqlite3_column_count", + uint64(s.handle))) + return int(n) } // ColumnName returns the name of the result column. @@ -400,12 +398,12 @@ func (s *Stmt) ColumnCount() int { // // https://sqlite.org/c3ref/column_name.html func (s *Stmt) ColumnName(col int) string { - r := s.c.call("sqlite3_column_name", - uint64(s.handle), uint64(col)) - if r == 0 { + ptr := ptr_t(s.c.call("sqlite3_column_name", + uint64(s.handle), uint64(col))) + if ptr == 0 { panic(util.OOMErr) } - return util.ReadString(s.c.mod, uint32(r), _MAX_NAME) + return util.ReadString(s.c.mod, ptr, _MAX_NAME) } // ColumnType returns the initial [Datatype] of the result column. @@ -413,9 +411,8 @@ func (s *Stmt) ColumnName(col int) string { // // https://sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnType(col int) Datatype { - r := s.c.call("sqlite3_column_type", - uint64(s.handle), uint64(col)) - return Datatype(r) + return Datatype(s.c.call("sqlite3_column_type", + uint64(s.handle), uint64(col))) } // ColumnDeclType returns the declared datatype of the result column. @@ -423,12 +420,12 @@ func (s *Stmt) ColumnType(col int) Datatype { // // https://sqlite.org/c3ref/column_decltype.html func (s *Stmt) ColumnDeclType(col int) string { - r := s.c.call("sqlite3_column_decltype", - uint64(s.handle), uint64(col)) - if r == 0 { + ptr := ptr_t(s.c.call("sqlite3_column_decltype", + uint64(s.handle), uint64(col))) + if ptr == 0 { return "" } - return util.ReadString(s.c.mod, uint32(r), _MAX_NAME) + return util.ReadString(s.c.mod, ptr, _MAX_NAME) } // ColumnDatabaseName returns the name of the database @@ -437,12 +434,12 @@ func (s *Stmt) ColumnDeclType(col int) string { // // https://sqlite.org/c3ref/column_database_name.html func (s *Stmt) ColumnDatabaseName(col int) string { - r := s.c.call("sqlite3_column_database_name", - uint64(s.handle), uint64(col)) - if r == 0 { + ptr := ptr_t(s.c.call("sqlite3_column_database_name", + uint64(s.handle), uint64(col))) + if ptr == 0 { return "" } - return util.ReadString(s.c.mod, uint32(r), _MAX_NAME) + return util.ReadString(s.c.mod, ptr, _MAX_NAME) } // ColumnTableName returns the name of the table @@ -451,12 +448,12 @@ func (s *Stmt) ColumnDatabaseName(col int) string { // // https://sqlite.org/c3ref/column_database_name.html func (s *Stmt) ColumnTableName(col int) string { - r := s.c.call("sqlite3_column_table_name", - uint64(s.handle), uint64(col)) - if r == 0 { + ptr := ptr_t(s.c.call("sqlite3_column_table_name", + uint64(s.handle), uint64(col))) + if ptr == 0 { return "" } - return util.ReadString(s.c.mod, uint32(r), _MAX_NAME) + return util.ReadString(s.c.mod, ptr, _MAX_NAME) } // ColumnOriginName returns the name of the table column @@ -465,12 +462,12 @@ func (s *Stmt) ColumnTableName(col int) string { // // https://sqlite.org/c3ref/column_database_name.html func (s *Stmt) ColumnOriginName(col int) string { - r := s.c.call("sqlite3_column_origin_name", - uint64(s.handle), uint64(col)) - if r == 0 { + ptr := ptr_t(s.c.call("sqlite3_column_origin_name", + uint64(s.handle), uint64(col))) + if ptr == 0 { return "" } - return util.ReadString(s.c.mod, uint32(r), _MAX_NAME) + return util.ReadString(s.c.mod, ptr, _MAX_NAME) } // ColumnBool returns the value of the result column as a bool. @@ -497,9 +494,8 @@ func (s *Stmt) ColumnInt(col int) int { // // https://sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnInt64(col int) int64 { - r := s.c.call("sqlite3_column_int64", - uint64(s.handle), uint64(col)) - return int64(r) + return int64(s.c.call("sqlite3_column_int64", + uint64(s.handle), uint64(col))) } // ColumnFloat returns the value of the result column as a float64. @@ -507,9 +503,9 @@ func (s *Stmt) ColumnInt64(col int) int64 { // // https://sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnFloat(col int) float64 { - r := s.c.call("sqlite3_column_double", - uint64(s.handle), uint64(col)) - return math.Float64frombits(r) + f := uint64(s.c.call("sqlite3_column_double", + uint64(s.handle), uint64(col))) + return math.Float64frombits(f) } // ColumnTime returns the value of the result column as a [time.Time]. @@ -561,9 +557,9 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte { // // https://sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnRawText(col int) []byte { - r := s.c.call("sqlite3_column_text", - uint64(s.handle), uint64(col)) - return s.columnRawBytes(col, uint32(r)) + ptr := ptr_t(s.c.call("sqlite3_column_text", + uint64(s.handle), uint64(col))) + return s.columnRawBytes(col, ptr) } // ColumnRawBlob returns the value of the result column as a []byte. @@ -573,23 +569,23 @@ func (s *Stmt) ColumnRawText(col int) []byte { // // https://sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnRawBlob(col int) []byte { - r := s.c.call("sqlite3_column_blob", - uint64(s.handle), uint64(col)) - return s.columnRawBytes(col, uint32(r)) + ptr := ptr_t(s.c.call("sqlite3_column_blob", + uint64(s.handle), uint64(col))) + return s.columnRawBytes(col, ptr) } -func (s *Stmt) columnRawBytes(col int, ptr uint32) []byte { +func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte { if ptr == 0 { - r := s.c.call("sqlite3_errcode", uint64(s.c.handle)) - if r != _ROW && r != _DONE { - s.err = s.c.error(r) + rc := res_t(s.c.call("sqlite3_errcode", uint64(s.c.handle))) + if rc != _ROW && rc != _DONE { + s.err = s.c.error(rc) } return nil } - r := s.c.call("sqlite3_column_bytes", - uint64(s.handle), uint64(col)) - return util.View(s.c.mod, ptr, r) + n := int32(s.c.call("sqlite3_column_bytes", + uint64(s.handle), uint64(col))) + return util.View(s.c.mod, ptr, uint64(n)) } // ColumnJSON parses the JSON-encoded value of the result column @@ -621,12 +617,12 @@ func (s *Stmt) ColumnJSON(col int, ptr any) error { // // https://sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnValue(col int) Value { - r := s.c.call("sqlite3_column_value", - uint64(s.handle), uint64(col)) + ptr := ptr_t(s.c.call("sqlite3_column_value", + uint64(s.handle), uint64(col))) return Value{ c: s.c, unprot: true, - handle: uint32(r), + handle: ptr, } } @@ -644,9 +640,9 @@ func (s *Stmt) Columns(dest ...any) error { typePtr := s.c.arena.new(count) dataPtr := s.c.arena.new(count * 8) - r := s.c.call("sqlite3_columns_go", - uint64(s.handle), count, uint64(typePtr), uint64(dataPtr)) - if err := s.c.error(r); err != nil { + rc := res_t(s.c.call("sqlite3_columns_go", + uint64(s.handle), count, uint64(typePtr), uint64(dataPtr))) + if err := s.c.error(rc); err != nil { return err } @@ -660,18 +656,18 @@ func (s *Stmt) Columns(dest ...any) error { for i := range dest { switch types[i] { case byte(INTEGER): - dest[i] = int64(util.ReadUint64(s.c.mod, dataPtr)) + dest[i] = util.Read64[int64](s.c.mod, dataPtr) case byte(FLOAT): dest[i] = util.ReadFloat64(s.c.mod, dataPtr) case byte(NULL): dest[i] = nil default: - ptr := util.ReadUint32(s.c.mod, dataPtr+0) + ptr := util.Read32[ptr_t](s.c.mod, dataPtr+0) if ptr == 0 { dest[i] = []byte{} continue } - len := util.ReadUint32(s.c.mod, dataPtr+4) + len := util.Read32[int32](s.c.mod, dataPtr+4) buf := util.View(s.c.mod, ptr, uint64(len)) if types[i] == byte(TEXT) { dest[i] = string(buf) diff --git a/txn.go b/txn.go index 57ba979a..bdee752e 100644 --- a/txn.go +++ b/txn.go @@ -229,13 +229,12 @@ func (c *Conn) txnExecInterrupted(sql string) error { // // https://sqlite.org/c3ref/txn_state.html func (c *Conn) TxnState(schema string) TxnState { - var ptr uint32 + var ptr ptr_t if schema != "" { defer c.arena.mark()() ptr = c.arena.string(schema) } - r := c.call("sqlite3_txn_state", uint64(c.handle), uint64(ptr)) - return TxnState(r) + return TxnState(c.call("sqlite3_txn_state", uint64(c.handle), uint64(ptr))) } // CommitHook registers a callback function to be invoked @@ -278,7 +277,7 @@ func (c *Conn) UpdateHook(cb func(action AuthorizerActionCode, schema, table str c.update = cb } -func commitCallback(ctx context.Context, mod api.Module, pDB uint32) (rollback uint32) { +func commitCallback(ctx context.Context, mod api.Module, pDB ptr_t) (rollback uint32) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.commit != nil { if !c.commit() { rollback = 1 @@ -287,13 +286,13 @@ func commitCallback(ctx context.Context, mod api.Module, pDB uint32) (rollback u return rollback } -func rollbackCallback(ctx context.Context, mod api.Module, pDB uint32) { +func rollbackCallback(ctx context.Context, mod api.Module, pDB ptr_t) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.rollback != nil { c.rollback() } } -func updateCallback(ctx context.Context, mod api.Module, pDB uint32, action AuthorizerActionCode, zSchema, zTabName uint32, rowid uint64) { +func updateCallback(ctx context.Context, mod api.Module, pDB ptr_t, action AuthorizerActionCode, zSchema, zTabName ptr_t, rowid uint64) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.update != nil { schema := util.ReadString(mod, zSchema, _MAX_NAME) table := util.ReadString(mod, zTabName, _MAX_NAME) @@ -305,6 +304,6 @@ func updateCallback(ctx context.Context, mod api.Module, pDB uint32, action Auth // // https://sqlite.org/c3ref/db_cacheflush.html func (c *Conn) CacheFlush() error { - r := c.call("sqlite3_db_cacheflush", uint64(c.handle)) - return c.error(r) + rc := res_t(c.call("sqlite3_db_cacheflush", uint64(c.handle))) + return c.error(rc) } diff --git a/value.go b/value.go index 43b1a0f1..2e5d4cf4 100644 --- a/value.go +++ b/value.go @@ -14,7 +14,7 @@ import ( // https://sqlite.org/c3ref/value.html type Value struct { c *Conn - handle uint32 + handle ptr_t unprot bool copied bool } @@ -30,11 +30,11 @@ func (v Value) protected() uint64 { // // https://sqlite.org/c3ref/value_dup.html func (v Value) Dup() *Value { - r := v.c.call("sqlite3_value_dup", uint64(v.handle)) + ptr := ptr_t(v.c.call("sqlite3_value_dup", uint64(v.handle))) return &Value{ c: v.c, copied: true, - handle: uint32(r), + handle: ptr, } } @@ -54,16 +54,14 @@ func (dup *Value) Close() error { // // https://sqlite.org/c3ref/value_blob.html func (v Value) Type() Datatype { - r := v.c.call("sqlite3_value_type", v.protected()) - return Datatype(r) + return Datatype(v.c.call("sqlite3_value_type", v.protected())) } // Type returns the numeric datatype of the value. // // https://sqlite.org/c3ref/value_blob.html func (v Value) NumericType() Datatype { - r := v.c.call("sqlite3_value_numeric_type", v.protected()) - return Datatype(r) + return Datatype(v.c.call("sqlite3_value_numeric_type", v.protected())) } // Bool returns the value as a bool. @@ -87,16 +85,15 @@ func (v Value) Int() int { // // https://sqlite.org/c3ref/value_blob.html func (v Value) Int64() int64 { - r := v.c.call("sqlite3_value_int64", v.protected()) - return int64(r) + return int64(v.c.call("sqlite3_value_int64", v.protected())) } // Float returns the value as a float64. // // https://sqlite.org/c3ref/value_blob.html func (v Value) Float() float64 { - r := v.c.call("sqlite3_value_double", v.protected()) - return math.Float64frombits(r) + f := uint64(v.c.call("sqlite3_value_double", v.protected())) + return math.Float64frombits(f) } // Time returns the value as a [time.Time]. @@ -141,8 +138,8 @@ func (v Value) Blob(buf []byte) []byte { // // https://sqlite.org/c3ref/value_blob.html func (v Value) RawText() []byte { - r := v.c.call("sqlite3_value_text", v.protected()) - return v.rawBytes(uint32(r)) + ptr := ptr_t(v.c.call("sqlite3_value_text", v.protected())) + return v.rawBytes(ptr) } // RawBlob returns the value as a []byte. @@ -151,24 +148,24 @@ func (v Value) RawText() []byte { // // https://sqlite.org/c3ref/value_blob.html func (v Value) RawBlob() []byte { - r := v.c.call("sqlite3_value_blob", v.protected()) - return v.rawBytes(uint32(r)) + ptr := ptr_t(v.c.call("sqlite3_value_blob", v.protected())) + return v.rawBytes(ptr) } -func (v Value) rawBytes(ptr uint32) []byte { +func (v Value) rawBytes(ptr ptr_t) []byte { if ptr == 0 { return nil } - r := v.c.call("sqlite3_value_bytes", v.protected()) - return util.View(v.c.mod, ptr, r) + n := int32(v.c.call("sqlite3_value_bytes", v.protected())) + return util.View(v.c.mod, ptr, uint64(n)) } // Pointer gets the pointer associated with this value, // or nil if it has no associated pointer. func (v Value) Pointer() any { - r := v.c.call("sqlite3_value_pointer_go", v.protected()) - return util.GetHandle(v.c.ctx, uint32(r)) + ptr := ptr_t(v.c.call("sqlite3_value_pointer_go", v.protected())) + return util.GetHandle(v.c.ctx, ptr) } // JSON parses a JSON-encoded value @@ -197,16 +194,16 @@ func (v Value) JSON(ptr any) error { // // https://sqlite.org/c3ref/value_blob.html func (v Value) NoChange() bool { - r := v.c.call("sqlite3_value_nochange", v.protected()) - return r != 0 + b := int32(v.c.call("sqlite3_value_nochange", v.protected())) + return b != 0 } // FromBind returns true if value originated from a bound parameter. // // https://sqlite.org/c3ref/value_blob.html func (v Value) FromBind() bool { - r := v.c.call("sqlite3_value_frombind", v.protected()) - return r != 0 + b := int32(v.c.call("sqlite3_value_frombind", v.protected())) + return b != 0 } // InFirst returns the first element @@ -216,13 +213,13 @@ func (v Value) FromBind() bool { func (v Value) InFirst() (Value, error) { defer v.c.arena.mark()() valPtr := v.c.arena.new(ptrlen) - r := v.c.call("sqlite3_vtab_in_first", uint64(v.handle), uint64(valPtr)) - if err := v.c.error(r); err != nil { + rc := res_t(v.c.call("sqlite3_vtab_in_first", uint64(v.handle), uint64(valPtr))) + if err := v.c.error(rc); err != nil { return Value{}, err } return Value{ c: v.c, - handle: util.ReadUint32(v.c.mod, valPtr), + handle: util.Read32[ptr_t](v.c.mod, valPtr), }, nil } @@ -233,12 +230,12 @@ func (v Value) InFirst() (Value, error) { func (v Value) InNext() (Value, error) { defer v.c.arena.mark()() valPtr := v.c.arena.new(ptrlen) - r := v.c.call("sqlite3_vtab_in_next", uint64(v.handle), uint64(valPtr)) - if err := v.c.error(r); err != nil { + rc := res_t(v.c.call("sqlite3_vtab_in_next", uint64(v.handle), uint64(valPtr))) + if err := v.c.error(rc); err != nil { return Value{}, err } return Value{ c: v.c, - handle: util.ReadUint32(v.c.mod, valPtr), + handle: util.Read32[ptr_t](v.c.mod, valPtr), }, nil } diff --git a/vfs/api.go b/vfs/api.go index f2531f22..d5bb3a7a 100644 --- a/vfs/api.go +++ b/vfs/api.go @@ -193,7 +193,7 @@ type FileSharedMemory interface { // SharedMemory is a shared-memory WAL-index implementation. // Use [NewSharedMemory] to create a shared-memory. type SharedMemory interface { - shmMap(context.Context, api.Module, int32, int32, bool) (uint32, _ErrorCode) + shmMap(context.Context, api.Module, int32, int32, bool) (ptr_t, _ErrorCode) shmLock(int32, int32, _ShmFlag) _ErrorCode shmUnmap(bool) shmBarrier() @@ -207,7 +207,7 @@ type blockingSharedMemory interface { type fileControl interface { File - fileControl(ctx context.Context, mod api.Module, op _FcntlOpcode, pArg uint32) _ErrorCode + fileControl(ctx context.Context, mod api.Module, op _FcntlOpcode, pArg ptr_t) _ErrorCode } type filePDB interface { diff --git a/vfs/cksm.go b/vfs/cksm.go index 42d7468f..51f5e8a4 100644 --- a/vfs/cksm.go +++ b/vfs/cksm.go @@ -109,7 +109,7 @@ func (c cksmFile) DeviceCharacteristics() DeviceCharacteristic { return res } -func (c cksmFile) fileControl(ctx context.Context, mod api.Module, op _FcntlOpcode, pArg uint32) _ErrorCode { +func (c cksmFile) fileControl(ctx context.Context, mod api.Module, op _FcntlOpcode, pArg ptr_t) _ErrorCode { switch op { case _FCNTL_CKPT_START: c.inCkpt = true diff --git a/vfs/const.go b/vfs/const.go index 1c9b77a7..b1de96f5 100644 --- a/vfs/const.go +++ b/vfs/const.go @@ -8,9 +8,11 @@ const ( _MAX_PATHNAME = 1024 _DEFAULT_SECTOR_SIZE = 4096 - ptrlen = 4 + ptrlen = util.PtrLen ) +type ptr_t = util.Ptr_t + // https://sqlite.org/rescode.html type _ErrorCode uint32 diff --git a/vfs/filename.go b/vfs/filename.go index d9a29cd4..7137abd9 100644 --- a/vfs/filename.go +++ b/vfs/filename.go @@ -16,13 +16,13 @@ import ( type Filename struct { ctx context.Context mod api.Module - zPath uint32 + zPath ptr_t flags OpenFlag stack [2]uint64 } // GetFilename is an internal API users should not call directly. -func GetFilename(ctx context.Context, mod api.Module, id uint32, flags OpenFlag) *Filename { +func GetFilename(ctx context.Context, mod api.Module, id ptr_t, flags OpenFlag) *Filename { if id == 0 { return nil } @@ -76,7 +76,7 @@ func (n *Filename) path(method string) string { if err := fn.CallWithStack(n.ctx, n.stack[:]); err != nil { panic(err) } - return util.ReadString(n.mod, uint32(n.stack[0]), _MAX_PATHNAME) + return util.ReadString(n.mod, ptr_t(n.stack[0]), _MAX_PATHNAME) } // DatabaseFile returns the main database [File] corresponding to a journal. @@ -95,7 +95,7 @@ func (n *Filename) DatabaseFile() File { if err := fn.CallWithStack(n.ctx, n.stack[:]); err != nil { panic(err) } - file, _ := vfsFileGet(n.ctx, n.mod, uint32(n.stack[0])).(File) + file, _ := vfsFileGet(n.ctx, n.mod, ptr_t(n.stack[0])).(File) return file } @@ -114,7 +114,7 @@ func (n *Filename) URIParameter(key string) string { panic(err) } - ptr := uint32(n.stack[0]) + ptr := ptr_t(n.stack[0]) if ptr == 0 { return "" } @@ -127,13 +127,13 @@ func (n *Filename) URIParameter(key string) string { if k == "" { return "" } - ptr += uint32(len(k)) + 1 + ptr += ptr_t(len(k)) + 1 v := util.ReadString(n.mod, ptr, _MAX_NAME) if k == key { return v } - ptr += uint32(len(v)) + 1 + ptr += ptr_t(len(v)) + 1 } } @@ -152,7 +152,7 @@ func (n *Filename) URIParameters() url.Values { panic(err) } - ptr := uint32(n.stack[0]) + ptr := ptr_t(n.stack[0]) if ptr == 0 { return nil } @@ -167,13 +167,13 @@ func (n *Filename) URIParameters() url.Values { if k == "" { return params } - ptr += uint32(len(k)) + 1 + ptr += ptr_t(len(k)) + 1 v := util.ReadString(n.mod, ptr, _MAX_NAME) if params == nil { params = url.Values{} } params.Add(k, v) - ptr += uint32(len(v)) + 1 + ptr += ptr_t(len(v)) + 1 } } diff --git a/vfs/lock_test.go b/vfs/lock_test.go index 7c19bfd7..e8dfe845 100644 --- a/vfs/lock_test.go +++ b/vfs/lock_test.go @@ -47,21 +47,21 @@ func Test_vfsLock(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != 0 { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_NONE { t.Error("file was locked") } rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != 0 { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_NONE { t.Error("file was locked") } rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_NONE) { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_NONE { t.Error("invalid lock state", got) } @@ -74,21 +74,21 @@ func Test_vfsLock(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != 0 { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_NONE { t.Error("file was locked") } rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != 0 { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_NONE { t.Error("file was locked") } rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_SHARED) { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_SHARED { t.Error("invalid lock state", got) } @@ -105,21 +105,21 @@ func Test_vfsLock(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got == 0 { + if got := util.Read32[LockLevel](mod, pOutput); got == LOCK_NONE { t.Log("file wasn't locked, locking is incompatible with SQLite") } rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got == 0 { + if got := util.Read32[LockLevel](mod, pOutput); got == LOCK_NONE { t.Error("file wasn't locked") } rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_RESERVED) { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_RESERVED { t.Error("invalid lock state", got) } @@ -132,21 +132,21 @@ func Test_vfsLock(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got == 0 { + if got := util.Read32[LockLevel](mod, pOutput); got == LOCK_NONE { t.Log("file wasn't locked, locking is incompatible with SQLite") } rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got == 0 { + if got := util.Read32[LockLevel](mod, pOutput); got == LOCK_NONE { t.Error("file wasn't locked") } rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_EXCLUSIVE) { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_EXCLUSIVE { t.Error("invalid lock state", got) } @@ -159,21 +159,21 @@ func Test_vfsLock(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got == 0 { + if got := util.Read32[LockLevel](mod, pOutput); got == LOCK_NONE { t.Log("file wasn't locked, locking is incompatible with SQLite") } rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got == 0 { + if got := util.Read32[LockLevel](mod, pOutput); got == LOCK_NONE { t.Error("file wasn't locked") } rc = vfsFileControl(ctx, mod, pFile1, _FCNTL_LOCKSTATE, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_NONE) { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_NONE { t.Error("invalid lock state", got) } @@ -186,14 +186,14 @@ func Test_vfsLock(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != 0 { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_NONE { t.Error("file was locked") } rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != 0 { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_NONE { t.Error("file was locked") } @@ -205,7 +205,7 @@ func Test_vfsLock(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_SHARED) { + if got := util.Read32[LockLevel](mod, pOutput); got != LOCK_SHARED { t.Error("invalid lock state", got) } } diff --git a/vfs/shm_ofd.go b/vfs/shm_ofd.go index dd361119..b0f50fcb 100644 --- a/vfs/shm_ofd.go +++ b/vfs/shm_ofd.go @@ -73,7 +73,7 @@ func (s *vfsShm) shmOpen() _ErrorCode { return rc } -func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (uint32, _ErrorCode) { +func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (ptr_t, _ErrorCode) { // Ensure size is a multiple of the OS page size. if int(size)&(unix.Getpagesize()-1) != 0 { return 0, _IOERR_SHMMAP diff --git a/vfs/vfs.go b/vfs/vfs.go index d8816e40..2a7cd9a3 100644 --- a/vfs/vfs.go +++ b/vfs/vfs.go @@ -49,7 +49,7 @@ func ExportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder return env } -func vfsFind(ctx context.Context, mod api.Module, zVfsName uint32) uint32 { +func vfsFind(ctx context.Context, mod api.Module, zVfsName ptr_t) uint32 { name := util.ReadString(mod, zVfsName, _MAX_NAME) if vfs := Find(name); vfs != nil && vfs != (vfsOS{}) { return 1 @@ -57,46 +57,46 @@ func vfsFind(ctx context.Context, mod api.Module, zVfsName uint32) uint32 { return 0 } -func vfsLocaltime(ctx context.Context, mod api.Module, pTm uint32, t int64) _ErrorCode { +func vfsLocaltime(ctx context.Context, mod api.Module, pTm ptr_t, t int64) _ErrorCode { tm := time.Unix(t, 0) - var isdst int + var isdst int32 if tm.IsDST() { isdst = 1 } const size = 32 / 8 // https://pubs.opengroup.org/onlinepubs/7908799/xsh/time.h.html - util.WriteUint32(mod, pTm+0*size, uint32(tm.Second())) - util.WriteUint32(mod, pTm+1*size, uint32(tm.Minute())) - util.WriteUint32(mod, pTm+2*size, uint32(tm.Hour())) - util.WriteUint32(mod, pTm+3*size, uint32(tm.Day())) - util.WriteUint32(mod, pTm+4*size, uint32(tm.Month()-time.January)) - util.WriteUint32(mod, pTm+5*size, uint32(tm.Year()-1900)) - util.WriteUint32(mod, pTm+6*size, uint32(tm.Weekday()-time.Sunday)) - util.WriteUint32(mod, pTm+7*size, uint32(tm.YearDay()-1)) - util.WriteUint32(mod, pTm+8*size, uint32(isdst)) + util.Write32(mod, pTm+0*size, int32(tm.Second())) + util.Write32(mod, pTm+1*size, int32(tm.Minute())) + util.Write32(mod, pTm+2*size, int32(tm.Hour())) + util.Write32(mod, pTm+3*size, int32(tm.Day())) + util.Write32(mod, pTm+4*size, int32(tm.Month()-time.January)) + util.Write32(mod, pTm+5*size, int32(tm.Year()-1900)) + util.Write32(mod, pTm+6*size, int32(tm.Weekday()-time.Sunday)) + util.Write32(mod, pTm+7*size, int32(tm.YearDay()-1)) + util.Write32(mod, pTm+8*size, isdst) return _OK } -func vfsRandomness(ctx context.Context, mod api.Module, pVfs uint32, nByte int32, zByte uint32) uint32 { +func vfsRandomness(ctx context.Context, mod api.Module, pVfs ptr_t, nByte int32, zByte ptr_t) uint32 { mem := util.View(mod, zByte, uint64(nByte)) n, _ := rand.Reader.Read(mem) return uint32(n) } -func vfsSleep(ctx context.Context, mod api.Module, pVfs uint32, nMicro int32) _ErrorCode { +func vfsSleep(ctx context.Context, mod api.Module, pVfs ptr_t, nMicro int32) _ErrorCode { time.Sleep(time.Duration(nMicro) * time.Microsecond) return _OK } -func vfsCurrentTime64(ctx context.Context, mod api.Module, pVfs, piNow uint32) _ErrorCode { +func vfsCurrentTime64(ctx context.Context, mod api.Module, pVfs, piNow ptr_t) _ErrorCode { day, nsec := julianday.Date(time.Now()) msec := day*86_400_000 + nsec/1_000_000 - util.WriteUint64(mod, piNow, uint64(msec)) + util.Write64(mod, piNow, msec) return _OK } -func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative uint32, nFull int32, zFull uint32) _ErrorCode { +func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative ptr_t, nFull int32, zFull ptr_t) _ErrorCode { vfs := vfsGet(mod, pVfs) path := util.ReadString(mod, zRelative, _MAX_PATHNAME) @@ -110,7 +110,7 @@ func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative uint32 return vfsErrorCode(err, _CANTOPEN_FULLPATH) } -func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32) _ErrorCode { +func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath ptr_t, syncDir uint32) _ErrorCode { vfs := vfsGet(mod, pVfs) path := util.ReadString(mod, zPath, _MAX_PATHNAME) @@ -118,21 +118,21 @@ func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32) return vfsErrorCode(err, _IOERR_DELETE) } -func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags AccessFlag, pResOut uint32) _ErrorCode { +func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath ptr_t, flags AccessFlag, pResOut ptr_t) _ErrorCode { vfs := vfsGet(mod, pVfs) path := util.ReadString(mod, zPath, _MAX_PATHNAME) ok, err := vfs.Access(path, flags) - var res uint32 + var res int32 if ok { res = 1 } - util.WriteUint32(mod, pResOut, res) + util.Write32(mod, pResOut, res) return vfsErrorCode(err, _IOERR_ACCESS) } -func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, flags OpenFlag, pOutFlags, pOutVFS uint32) _ErrorCode { +func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile ptr_t, flags OpenFlag, pOutFlags, pOutVFS ptr_t) _ErrorCode { vfs := vfsGet(mod, pVfs) name := GetFilename(ctx, mod, zPath, flags) @@ -154,22 +154,22 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla } if file, ok := file.(FileSharedMemory); ok && pOutVFS != 0 && file.SharedMemory() != nil { - util.WriteUint32(mod, pOutVFS, 1) + util.Write32(mod, pOutVFS, int32(1)) } if pOutFlags != 0 { - util.WriteUint32(mod, pOutFlags, uint32(flags)) + util.Write32(mod, pOutFlags, flags) } file = cksmWrapFile(name, flags, file) vfsFileRegister(ctx, mod, pFile, file) return _OK } -func vfsClose(ctx context.Context, mod api.Module, pFile uint32) _ErrorCode { +func vfsClose(ctx context.Context, mod api.Module, pFile ptr_t) _ErrorCode { err := vfsFileClose(ctx, mod, pFile) return vfsErrorCode(err, _IOERR_CLOSE) } -func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf uint32, iAmt int32, iOfst int64) _ErrorCode { +func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf ptr_t, iAmt int32, iOfst int64) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) buf := util.View(mod, zBuf, uint64(iAmt)) @@ -184,7 +184,7 @@ func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf uint32, iAmt int32 return _IOERR_SHORT_READ } -func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf uint32, iAmt int32, iOfst int64) _ErrorCode { +func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf ptr_t, iAmt int32, iOfst int64) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) buf := util.View(mod, zBuf, uint64(iAmt)) @@ -192,51 +192,51 @@ func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf uint32, iAmt int3 return vfsErrorCode(err, _IOERR_WRITE) } -func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte int64) _ErrorCode { +func vfsTruncate(ctx context.Context, mod api.Module, pFile ptr_t, nByte int64) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) err := file.Truncate(nByte) return vfsErrorCode(err, _IOERR_TRUNCATE) } -func vfsSync(ctx context.Context, mod api.Module, pFile uint32, flags SyncFlag) _ErrorCode { +func vfsSync(ctx context.Context, mod api.Module, pFile ptr_t, flags SyncFlag) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) err := file.Sync(flags) return vfsErrorCode(err, _IOERR_FSYNC) } -func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) _ErrorCode { +func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize ptr_t) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) size, err := file.Size() - util.WriteUint64(mod, pSize, uint64(size)) + util.Write64(mod, pSize, size) return vfsErrorCode(err, _IOERR_SEEK) } -func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock LockLevel) _ErrorCode { +func vfsLock(ctx context.Context, mod api.Module, pFile ptr_t, eLock LockLevel) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) err := file.Lock(eLock) return vfsErrorCode(err, _IOERR_LOCK) } -func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock LockLevel) _ErrorCode { +func vfsUnlock(ctx context.Context, mod api.Module, pFile ptr_t, eLock LockLevel) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) err := file.Unlock(eLock) return vfsErrorCode(err, _IOERR_UNLOCK) } -func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut uint32) _ErrorCode { +func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut ptr_t) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) locked, err := file.CheckReservedLock() - var res uint32 + var res int32 if locked { res = 1 } - util.WriteUint32(mod, pResOut, res) + util.Write32(mod, pResOut, res) return vfsErrorCode(err, _IOERR_CHECKRESERVEDLOCK) } -func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _FcntlOpcode, pArg uint32) _ErrorCode { +func vfsFileControl(ctx context.Context, mod api.Module, pFile ptr_t, op _FcntlOpcode, pArg ptr_t) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) if file, ok := file.(fileControl); ok { return file.fileControl(ctx, mod, op, pArg) @@ -244,51 +244,51 @@ func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _Fcntl return vfsFileControlImpl(ctx, mod, file, op, pArg) } -func vfsFileControlImpl(ctx context.Context, mod api.Module, file File, op _FcntlOpcode, pArg uint32) _ErrorCode { +func vfsFileControlImpl(ctx context.Context, mod api.Module, file File, op _FcntlOpcode, pArg ptr_t) _ErrorCode { switch op { case _FCNTL_LOCKSTATE: if file, ok := file.(FileLockState); ok { if lk := file.LockState(); lk <= LOCK_EXCLUSIVE { - util.WriteUint32(mod, pArg, uint32(lk)) + util.Write32(mod, pArg, lk) return _OK } } case _FCNTL_PERSIST_WAL: if file, ok := file.(FilePersistWAL); ok { - if i := util.ReadUint32(mod, pArg); int32(i) >= 0 { + if i := util.Read32[int32](mod, pArg); i >= 0 { file.SetPersistWAL(i != 0) } else if file.PersistWAL() { - util.WriteUint32(mod, pArg, 1) + util.Write32(mod, pArg, int32(1)) } else { - util.WriteUint32(mod, pArg, 0) + util.Write32(mod, pArg, int32(0)) } return _OK } case _FCNTL_POWERSAFE_OVERWRITE: if file, ok := file.(FilePowersafeOverwrite); ok { - if i := util.ReadUint32(mod, pArg); int32(i) >= 0 { + if i := util.Read32[int32](mod, pArg); i >= 0 { file.SetPowersafeOverwrite(i != 0) } else if file.PowersafeOverwrite() { - util.WriteUint32(mod, pArg, 1) + util.Write32(mod, pArg, int32(1)) } else { - util.WriteUint32(mod, pArg, 0) + util.Write32(mod, pArg, int32(0)) } return _OK } case _FCNTL_CHUNK_SIZE: if file, ok := file.(FileChunkSize); ok { - size := util.ReadUint32(mod, pArg) + size := util.Read32[int32](mod, pArg) file.ChunkSize(int(size)) return _OK } case _FCNTL_SIZE_HINT: if file, ok := file.(FileSizeHint); ok { - size := util.ReadUint64(mod, pArg) - err := file.SizeHint(int64(size)) + size := util.Read64[int64](mod, pArg) + err := file.SizeHint(size) return vfsErrorCode(err, _IOERR_TRUNCATE) } @@ -299,7 +299,7 @@ func vfsFileControlImpl(ctx context.Context, mod api.Module, file File, op _Fcnt if moved { res = 1 } - util.WriteUint32(mod, pArg, res) + util.Write32(mod, pArg, res) return vfsErrorCode(err, _IOERR_FSTAT) } @@ -354,10 +354,10 @@ func vfsFileControlImpl(ctx context.Context, mod api.Module, file File, op _Fcnt case _FCNTL_PRAGMA: if file, ok := file.(FilePragma); ok { - ptr := util.ReadUint32(mod, pArg+1*ptrlen) + ptr := util.Read32[ptr_t](mod, pArg+1*ptrlen) name := util.ReadString(mod, ptr, _MAX_SQL_LENGTH) var value string - if ptr := util.ReadUint32(mod, pArg+2*ptrlen); ptr != 0 { + if ptr := util.Read32[ptr_t](mod, pArg+2*ptrlen); ptr != 0 { value = util.ReadString(mod, ptr, _MAX_SQL_LENGTH) } @@ -373,15 +373,15 @@ func vfsFileControlImpl(ctx context.Context, mod api.Module, file File, op _Fcnt if err := fn.CallWithStack(ctx, stack[:]); err != nil { panic(err) } - util.WriteUint32(mod, pArg, uint32(stack[0])) - util.WriteString(mod, uint32(stack[0]), out) + util.Write32(mod, pArg, ptr_t(stack[0])) + util.WriteString(mod, ptr_t(stack[0]), out) } return ret } case _FCNTL_BUSYHANDLER: if file, ok := file.(FileBusyHandler); ok { - arg := util.ReadUint64(mod, pArg) + arg := util.Read64[uint64](mod, pArg) fn := mod.ExportedFunction("sqlite3_invoke_busy_handler_go") file.BusyHandler(func() bool { stack := [...]uint64{arg} @@ -396,7 +396,7 @@ func vfsFileControlImpl(ctx context.Context, mod api.Module, file File, op _Fcnt case _FCNTL_LOCK_TIMEOUT: if file, ok := file.(FileSharedMemory); ok { if shm, ok := file.SharedMemory().(blockingSharedMemory); ok { - shm.shmEnableBlocking(util.ReadUint32(mod, pArg) != 0) + shm.shmEnableBlocking(util.Read32[uint32](mod, pArg) != 0) return _OK } } @@ -411,44 +411,45 @@ func vfsFileControlImpl(ctx context.Context, mod api.Module, file File, op _Fcnt return _NOTFOUND } -func vfsSectorSize(ctx context.Context, mod api.Module, pFile uint32) uint32 { +func vfsSectorSize(ctx context.Context, mod api.Module, pFile ptr_t) uint32 { file := vfsFileGet(ctx, mod, pFile).(File) return uint32(file.SectorSize()) } -func vfsDeviceCharacteristics(ctx context.Context, mod api.Module, pFile uint32) DeviceCharacteristic { +func vfsDeviceCharacteristics(ctx context.Context, mod api.Module, pFile ptr_t) DeviceCharacteristic { file := vfsFileGet(ctx, mod, pFile).(File) return file.DeviceCharacteristics() } -func vfsShmBarrier(ctx context.Context, mod api.Module, pFile uint32) { +func vfsShmBarrier(ctx context.Context, mod api.Module, pFile ptr_t) { shm := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory() shm.shmBarrier() } -func vfsShmMap(ctx context.Context, mod api.Module, pFile uint32, iRegion, szRegion int32, bExtend, pp uint32) _ErrorCode { +func vfsShmMap(ctx context.Context, mod api.Module, pFile ptr_t, iRegion, szRegion, bExtend int32, pp ptr_t) _ErrorCode { shm := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory() p, rc := shm.shmMap(ctx, mod, iRegion, szRegion, bExtend != 0) - util.WriteUint32(mod, pp, p) + util.Write32(mod, pp, p) return rc } -func vfsShmLock(ctx context.Context, mod api.Module, pFile uint32, offset, n int32, flags _ShmFlag) _ErrorCode { +func vfsShmLock(ctx context.Context, mod api.Module, pFile ptr_t, offset, n int32, flags _ShmFlag) _ErrorCode { shm := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory() return shm.shmLock(offset, n, flags) } -func vfsShmUnmap(ctx context.Context, mod api.Module, pFile, bDelete uint32) _ErrorCode { +func vfsShmUnmap(ctx context.Context, mod api.Module, pFile ptr_t, bDelete uint32) _ErrorCode { shm := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory() shm.shmUnmap(bDelete != 0) return _OK } -func vfsGet(mod api.Module, pVfs uint32) VFS { +func vfsGet(mod api.Module, pVfs ptr_t) VFS { var name string if pVfs != 0 { const zNameOffset = 16 - name = util.ReadString(mod, util.ReadUint32(mod, pVfs+zNameOffset), _MAX_NAME) + ptr := util.Read32[ptr_t](mod, pVfs+zNameOffset) + name = util.ReadString(mod, ptr, _MAX_NAME) } if vfs := Find(name); vfs != nil { return vfs @@ -456,21 +457,21 @@ func vfsGet(mod api.Module, pVfs uint32) VFS { panic(util.NoVFSErr + util.ErrorString(name)) } -func vfsFileRegister(ctx context.Context, mod api.Module, pFile uint32, file File) { +func vfsFileRegister(ctx context.Context, mod api.Module, pFile ptr_t, file File) { const fileHandleOffset = 4 id := util.AddHandle(ctx, file) - util.WriteUint32(mod, pFile+fileHandleOffset, id) + util.Write32(mod, pFile+fileHandleOffset, id) } -func vfsFileGet(ctx context.Context, mod api.Module, pFile uint32) any { +func vfsFileGet(ctx context.Context, mod api.Module, pFile ptr_t) any { const fileHandleOffset = 4 - id := util.ReadUint32(mod, pFile+fileHandleOffset) + id := util.Read32[ptr_t](mod, pFile+fileHandleOffset) return util.GetHandle(ctx, id) } -func vfsFileClose(ctx context.Context, mod api.Module, pFile uint32) error { +func vfsFileClose(ctx context.Context, mod api.Module, pFile ptr_t) error { const fileHandleOffset = 4 - id := util.ReadUint32(mod, pFile+fileHandleOffset) + id := util.Read32[ptr_t](mod, pFile+fileHandleOffset) return util.DelHandle(ctx, id) } diff --git a/vfs/vfs_test.go b/vfs/vfs_test.go index 79c8c3e9..64d9a21d 100644 --- a/vfs/vfs_test.go +++ b/vfs/vfs_test.go @@ -5,7 +5,6 @@ import ( "context" "errors" "io/fs" - "math" "os" "os/user" "path/filepath" @@ -29,28 +28,28 @@ func Test_vfsLocaltime(t *testing.T) { t.Fatal("returned", rc) } - if s := util.ReadUint32(mod, 4+0*4); int(s) != tm.Second() { + if s := util.Read32[int32](mod, 4+0*4); int(s) != tm.Second() { t.Error("wrong second") } - if m := util.ReadUint32(mod, 4+1*4); int(m) != tm.Minute() { + if m := util.Read32[int32](mod, 4+1*4); int(m) != tm.Minute() { t.Error("wrong minute") } - if h := util.ReadUint32(mod, 4+2*4); int(h) != tm.Hour() { + if h := util.Read32[int32](mod, 4+2*4); int(h) != tm.Hour() { t.Error("wrong hour") } - if d := util.ReadUint32(mod, 4+3*4); int(d) != tm.Day() { + if d := util.Read32[int32](mod, 4+3*4); int(d) != tm.Day() { t.Error("wrong day") } - if m := util.ReadUint32(mod, 4+4*4); time.Month(1+m) != tm.Month() { + if m := util.Read32[int32](mod, 4+4*4); time.Month(1+m) != tm.Month() { t.Error("wrong month") } - if y := util.ReadUint32(mod, 4+5*4); 1900+int(y) != tm.Year() { + if y := util.Read32[int32](mod, 4+5*4); 1900+int(y) != tm.Year() { t.Error("wrong year") } - if w := util.ReadUint32(mod, 4+6*4); time.Weekday(w) != tm.Weekday() { + if w := util.Read32[int32](mod, 4+6*4); time.Weekday(w) != tm.Weekday() { t.Error("wrong weekday") } - if d := util.ReadUint32(mod, 4+7*4); int(d) != tm.YearDay()-1 { + if d := util.Read32[int32](mod, 4+7*4); int(d) != tm.YearDay()-1 { t.Error("wrong yearday") } } @@ -99,7 +98,7 @@ func Test_vfsCurrentTime64(t *testing.T) { day, nsec := julianday.Date(now) want := day*86_400_000 + nsec/1_000_000 - if got := util.ReadUint64(mod, 4); float32(got) != float32(want) { + if got := util.Read64[int64](mod, 4); float32(got) != float32(want) { t.Errorf("got %v, want %v", got, want) } } @@ -173,7 +172,7 @@ func Test_vfsAccess(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, 4); got != 1 { + if got := util.Read32[int32](mod, 4); got != 1 { t.Error("directory did not exist") } @@ -181,7 +180,7 @@ func Test_vfsAccess(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, 4); got != 1 { + if got := util.Read32[int32](mod, 4); got != 1 { t.Error("can't access directory") } @@ -190,7 +189,7 @@ func Test_vfsAccess(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, 4); got != 1 { + if got := util.Read32[int32](mod, 4); got != 1 { t.Error("can't access file") } @@ -208,7 +207,7 @@ func Test_vfsAccess(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, 4); got != 0 { + if got := util.Read32[int32](mod, 4); got != 0 { t.Error("can access file") } } @@ -241,7 +240,7 @@ func Test_vfsFile(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, 16); got != uint32(len(text)) { + if got := util.Read64[int64](mod, 16); got != int64(len(text)) { t.Errorf("got %d", got) } @@ -265,7 +264,7 @@ func Test_vfsFile(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, 16); got != 4 { + if got := util.Read64[int64](mod, 16); got != 4 { t.Errorf("got %d", got) } @@ -296,46 +295,46 @@ func Test_vfsFile_psow(t *testing.T) { } // Read powersafe overwrite. - util.WriteUint32(mod, 16, math.MaxUint32) + util.Write32(mod, 16, int32(-1)) rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, 16); got == 0 { + if got := util.Read32[int32](mod, 16); got == 0 { t.Error("psow disabled") } // Unset powersafe overwrite. - util.WriteUint32(mod, 16, 0) + util.Write32(mod, 16, int32(0)) rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16) if rc != _OK { t.Fatal("returned", rc) } // Read powersafe overwrite. - util.WriteUint32(mod, 16, math.MaxUint32) + util.Write32(mod, 16, int32(-1)) rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, 16); got != 0 { + if got := util.Read32[int32](mod, 16); got != 0 { t.Error("psow enabled") } // Set powersafe overwrite. - util.WriteUint32(mod, 16, 1) + util.Write32(mod, 16, int32(1)) rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16) if rc != _OK { t.Fatal("returned", rc) } // Read powersafe overwrite. - util.WriteUint32(mod, 16, math.MaxUint32) + util.Write32(mod, 16, int32(-1)) rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16) if rc != _OK { t.Fatal("returned", rc) } - if got := util.ReadUint32(mod, 16); got == 0 { + if got := util.Read32[int32](mod, 16); got == 0 { t.Error("psow disabled") } diff --git a/vtab.go b/vtab.go index 1998a528..ef9ce94f 100644 --- a/vtab.go +++ b/vtab.go @@ -58,15 +58,15 @@ func CreateModule[T VTab](db *Conn, name string, create, connect VTabConstructor flags |= VTAB_SHADOWTABS } - var modulePtr uint32 + var modulePtr ptr_t defer db.arena.mark()() namePtr := db.arena.string(name) if connect != nil { modulePtr = util.AddHandle(db.ctx, module[T]{create, connect}) } - r := db.call("sqlite3_create_module_go", uint64(db.handle), - uint64(namePtr), uint64(flags), uint64(modulePtr)) - return db.error(r) + rc := res_t(db.call("sqlite3_create_module_go", uint64(db.handle), + uint64(namePtr), uint64(flags), uint64(modulePtr))) + return db.error(rc) } func implements[T any](typ reflect.Type) bool { @@ -80,8 +80,8 @@ func implements[T any](typ reflect.Type) bool { func (c *Conn) DeclareVTab(sql string) error { defer c.arena.mark()() sqlPtr := c.arena.string(sql) - r := c.call("sqlite3_declare_vtab", uint64(c.handle), uint64(sqlPtr)) - return c.error(r) + rc := res_t(c.call("sqlite3_declare_vtab", uint64(c.handle), uint64(sqlPtr))) + return c.error(rc) } // VTabConflictMode is a virtual table conflict resolution mode. @@ -101,8 +101,7 @@ const ( // // https://sqlite.org/c3ref/vtab_on_conflict.html func (c *Conn) VTabOnConflict() VTabConflictMode { - r := c.call("sqlite3_vtab_on_conflict", uint64(c.handle)) - return VTabConflictMode(r) + return VTabConflictMode(c.call("sqlite3_vtab_on_conflict", uint64(c.handle))) } // VTabConfigOption is a virtual table configuration option. @@ -127,8 +126,8 @@ func (c *Conn) VTabConfig(op VTabConfigOption, args ...any) error { i = 1 } } - r := c.call("sqlite3_vtab_config_go", uint64(c.handle), uint64(op), i) - return c.error(r) + rc := res_t(c.call("sqlite3_vtab_config_go", uint64(c.handle), uint64(op), i)) + return c.error(rc) } // VTabConstructor is a virtual table constructor function. @@ -263,7 +262,7 @@ type IndexInfo struct { // Inputs Constraint []IndexConstraint OrderBy []IndexOrderBy - ColumnsUsed int64 + ColumnsUsed uint64 // Outputs ConstraintUsage []IndexConstraintUsage IdxNum int @@ -274,7 +273,7 @@ type IndexInfo struct { EstimatedRows int64 // Internal c *Conn - handle uint32 + handle ptr_t } // An IndexConstraint describes virtual table indexing constraint information. @@ -309,14 +308,14 @@ type IndexConstraintUsage struct { func (idx *IndexInfo) RHSValue(column int) (Value, error) { defer idx.c.arena.mark()() valPtr := idx.c.arena.new(ptrlen) - r := idx.c.call("sqlite3_vtab_rhs_value", uint64(idx.handle), - uint64(column), uint64(valPtr)) - if err := idx.c.error(r); err != nil { + rc := res_t(idx.c.call("sqlite3_vtab_rhs_value", uint64(idx.handle), + uint64(column), uint64(valPtr))) + if err := idx.c.error(rc); err != nil { return Value{}, err } return Value{ c: idx.c, - handle: util.ReadUint32(idx.c.mod, valPtr), + handle: util.Read32[ptr_t](idx.c.mod, valPtr), }, nil } @@ -324,26 +323,26 @@ func (idx *IndexInfo) RHSValue(column int) (Value, error) { // // https://sqlite.org/c3ref/vtab_collation.html func (idx *IndexInfo) Collation(column int) string { - r := idx.c.call("sqlite3_vtab_collation", uint64(idx.handle), - uint64(column)) - return util.ReadString(idx.c.mod, uint32(r), _MAX_NAME) + ptr := ptr_t(idx.c.call("sqlite3_vtab_collation", uint64(idx.handle), + uint64(column))) + return util.ReadString(idx.c.mod, ptr, _MAX_NAME) } // Distinct determines if a virtual table query is DISTINCT. // // https://sqlite.org/c3ref/vtab_distinct.html func (idx *IndexInfo) Distinct() int { - r := idx.c.call("sqlite3_vtab_distinct", uint64(idx.handle)) - return int(r) + i := int32(idx.c.call("sqlite3_vtab_distinct", uint64(idx.handle))) + return int(i) } // In identifies and handles IN constraints. // // https://sqlite.org/c3ref/vtab_in.html func (idx *IndexInfo) In(column, handle int) bool { - r := idx.c.call("sqlite3_vtab_in", uint64(idx.handle), - uint64(column), uint64(handle)) - return r != 0 + b := int32(idx.c.call("sqlite3_vtab_in", uint64(idx.handle), + uint64(column), uint64(handle))) + return b != 0 } func (idx *IndexInfo) load() { @@ -351,34 +350,35 @@ func (idx *IndexInfo) load() { mod := idx.c.mod ptr := idx.handle - idx.Constraint = make([]IndexConstraint, util.ReadUint32(mod, ptr+0)) - idx.ConstraintUsage = make([]IndexConstraintUsage, util.ReadUint32(mod, ptr+0)) - idx.OrderBy = make([]IndexOrderBy, util.ReadUint32(mod, ptr+8)) + nConstraint := util.Read32[int32](mod, ptr+0) + idx.Constraint = make([]IndexConstraint, nConstraint) + idx.ConstraintUsage = make([]IndexConstraintUsage, nConstraint) + idx.OrderBy = make([]IndexOrderBy, util.Read32[int32](mod, ptr+8)) - constraintPtr := util.ReadUint32(mod, ptr+4) + constraintPtr := util.Read32[ptr_t](mod, ptr+4) constraint := idx.Constraint for i := range idx.Constraint { constraint[i] = IndexConstraint{ - Column: int(int32(util.ReadUint32(mod, constraintPtr+0))), - Op: IndexConstraintOp(util.ReadUint8(mod, constraintPtr+4)), - Usable: util.ReadUint8(mod, constraintPtr+5) != 0, + Column: int(util.Read32[int32](mod, constraintPtr+0)), + Op: util.Read[IndexConstraintOp](mod, constraintPtr+4), + Usable: util.Read[byte](mod, constraintPtr+5) != 0, } constraintPtr += 12 } - orderByPtr := util.ReadUint32(mod, ptr+12) + orderByPtr := util.Read32[ptr_t](mod, ptr+12) orderBy := idx.OrderBy for i := range orderBy { orderBy[i] = IndexOrderBy{ - Column: int(int32(util.ReadUint32(mod, orderByPtr+0))), - Desc: util.ReadUint8(mod, orderByPtr+4) != 0, + Column: int(util.Read32[int32](mod, orderByPtr+0)), + Desc: util.Read[byte](mod, orderByPtr+4) != 0, } orderByPtr += 8 } idx.EstimatedCost = util.ReadFloat64(mod, ptr+40) - idx.EstimatedRows = int64(util.ReadUint64(mod, ptr+48)) - idx.ColumnsUsed = int64(util.ReadUint64(mod, ptr+64)) + idx.EstimatedRows = util.Read64[int64](mod, ptr+48) + idx.ColumnsUsed = util.Read64[uint64](mod, ptr+64) } func (idx *IndexInfo) save() { @@ -386,26 +386,26 @@ func (idx *IndexInfo) save() { mod := idx.c.mod ptr := idx.handle - usagePtr := util.ReadUint32(mod, ptr+16) + usagePtr := util.Read32[ptr_t](mod, ptr+16) for _, usage := range idx.ConstraintUsage { - util.WriteUint32(mod, usagePtr+0, uint32(usage.ArgvIndex)) + util.Write32(mod, usagePtr+0, int32(usage.ArgvIndex)) if usage.Omit { - util.WriteUint8(mod, usagePtr+4, 1) + util.Write(mod, usagePtr+4, int8(1)) } usagePtr += 8 } - util.WriteUint32(mod, ptr+20, uint32(idx.IdxNum)) + util.Write32(mod, ptr+20, int32(idx.IdxNum)) if idx.IdxStr != "" { - util.WriteUint32(mod, ptr+24, idx.c.newString(idx.IdxStr)) - util.WriteUint32(mod, ptr+28, 1) // needToFreeIdxStr + util.Write32(mod, ptr+24, idx.c.newString(idx.IdxStr)) + util.Write32(mod, ptr+28, int32(1)) // needToFreeIdxStr } if idx.OrderByConsumed { - util.WriteUint32(mod, ptr+32, 1) + util.Write32(mod, ptr+32, int32(1)) } util.WriteFloat64(mod, ptr+40, idx.EstimatedCost) - util.WriteUint64(mod, ptr+48, uint64(idx.EstimatedRows)) - util.WriteUint32(mod, ptr+56, uint32(idx.IdxFlags)) + util.Write64(mod, ptr+48, idx.EstimatedRows) + util.Write32(mod, ptr+56, idx.IdxFlags) } // IndexConstraintOp is a virtual table constraint operator code. @@ -442,13 +442,13 @@ const ( INDEX_SCAN_UNIQUE IndexScanFlag = 1 ) -func vtabModuleCallback(i vtabConstructor) func(_ context.Context, _ api.Module, _, _, _, _, _ uint32) uint32 { - return func(ctx context.Context, mod api.Module, pMod, nArg, pArg, ppVTab, pzErr uint32) uint32 { +func vtabModuleCallback(i vtabConstructor) func(_ context.Context, _ api.Module, _ ptr_t, _ int32, _, _, _ ptr_t) uint32 { + return func(ctx context.Context, mod api.Module, pMod ptr_t, nArg int32, pArg, ppVTab, pzErr ptr_t) uint32 { arg := make([]reflect.Value, 1+nArg) arg[0] = reflect.ValueOf(ctx.Value(connKey{})) - for i := uint32(0); i < nArg; i++ { - ptr := util.ReadUint32(mod, pArg+i*ptrlen) + for i := int32(0); i < nArg; i++ { + ptr := util.Read32[ptr_t](mod, pArg+ptr_t(i*ptrlen)) arg[i+1] = reflect.ValueOf(util.ReadString(mod, ptr, _MAX_SQL_LENGTH)) } @@ -463,12 +463,12 @@ func vtabModuleCallback(i vtabConstructor) func(_ context.Context, _ api.Module, } } -func vtabDisconnectCallback(ctx context.Context, mod api.Module, pVTab uint32) uint32 { +func vtabDisconnectCallback(ctx context.Context, mod api.Module, pVTab ptr_t) uint32 { err := vtabDelHandle(ctx, mod, pVTab) return vtabError(ctx, mod, 0, _PTR_ERROR, err) } -func vtabDestroyCallback(ctx context.Context, mod api.Module, pVTab uint32) uint32 { +func vtabDestroyCallback(ctx context.Context, mod api.Module, pVTab ptr_t) uint32 { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabDestroyer) err := vtab.Destroy() if cerr := vtabDelHandle(ctx, mod, pVTab); err == nil { @@ -477,7 +477,7 @@ func vtabDestroyCallback(ctx context.Context, mod api.Module, pVTab uint32) uint return vtabError(ctx, mod, 0, _PTR_ERROR, err) } -func vtabBestIndexCallback(ctx context.Context, mod api.Module, pVTab, pIdxInfo uint32) uint32 { +func vtabBestIndexCallback(ctx context.Context, mod api.Module, pVTab, pIdxInfo ptr_t) uint32 { var info IndexInfo info.handle = pIdxInfo info.c = ctx.Value(connKey{}).(*Conn) @@ -490,7 +490,7 @@ func vtabBestIndexCallback(ctx context.Context, mod api.Module, pVTab, pIdxInfo return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabUpdateCallback(ctx context.Context, mod api.Module, pVTab, nArg, pArg, pRowID uint32) uint32 { +func vtabUpdateCallback(ctx context.Context, mod api.Module, pVTab ptr_t, nArg int32, pArg, pRowID ptr_t) uint32 { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabUpdater) db := ctx.Value(connKey{}).(*Conn) @@ -498,33 +498,33 @@ func vtabUpdateCallback(ctx context.Context, mod api.Module, pVTab, nArg, pArg, callbackArgs(db, args, pArg) rowID, err := vtab.Update(args...) if err == nil { - util.WriteUint64(mod, pRowID, uint64(rowID)) + util.Write64(mod, pRowID, rowID) } return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabRenameCallback(ctx context.Context, mod api.Module, pVTab, zNew uint32) uint32 { +func vtabRenameCallback(ctx context.Context, mod api.Module, pVTab, zNew ptr_t) uint32 { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabRenamer) err := vtab.Rename(util.ReadString(mod, zNew, _MAX_NAME)) return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabFindFuncCallback(ctx context.Context, mod api.Module, pVTab uint32, nArg int32, zName, pxFunc uint32) uint32 { +func vtabFindFuncCallback(ctx context.Context, mod api.Module, pVTab ptr_t, nArg int32, zName, pxFunc ptr_t) uint32 { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabOverloader) f, op := vtab.FindFunction(int(nArg), util.ReadString(mod, zName, _MAX_NAME)) if op != 0 { - var wrapper uint32 + var wrapper ptr_t wrapper = util.AddHandle(ctx, func(c Context, arg ...Value) { defer util.DelHandle(ctx, wrapper) f(c, arg...) }) - util.WriteUint32(mod, pxFunc, wrapper) + util.Write32(mod, pxFunc, wrapper) } return uint32(op) } -func vtabIntegrityCallback(ctx context.Context, mod api.Module, pVTab, zSchema, zTabName, mFlags, pzErr uint32) uint32 { +func vtabIntegrityCallback(ctx context.Context, mod api.Module, pVTab, zSchema, zTabName ptr_t, mFlags uint32, pzErr ptr_t) uint32 { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabChecker) schema := util.ReadString(mod, zSchema, _MAX_NAME) table := util.ReadString(mod, zTabName, _MAX_NAME) @@ -536,49 +536,49 @@ func vtabIntegrityCallback(ctx context.Context, mod api.Module, pVTab, zSchema, return code } -func vtabBeginCallback(ctx context.Context, mod api.Module, pVTab uint32) uint32 { +func vtabBeginCallback(ctx context.Context, mod api.Module, pVTab ptr_t) uint32 { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabTxn) err := vtab.Begin() return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabSyncCallback(ctx context.Context, mod api.Module, pVTab uint32) uint32 { +func vtabSyncCallback(ctx context.Context, mod api.Module, pVTab ptr_t) uint32 { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabTxn) err := vtab.Sync() return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabCommitCallback(ctx context.Context, mod api.Module, pVTab uint32) uint32 { +func vtabCommitCallback(ctx context.Context, mod api.Module, pVTab ptr_t) uint32 { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabTxn) err := vtab.Commit() return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabRollbackCallback(ctx context.Context, mod api.Module, pVTab uint32) uint32 { +func vtabRollbackCallback(ctx context.Context, mod api.Module, pVTab ptr_t) uint32 { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabTxn) err := vtab.Rollback() return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabSavepointCallback(ctx context.Context, mod api.Module, pVTab uint32, id int32) uint32 { +func vtabSavepointCallback(ctx context.Context, mod api.Module, pVTab ptr_t, id int32) uint32 { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabSavepointer) err := vtab.Savepoint(int(id)) return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabReleaseCallback(ctx context.Context, mod api.Module, pVTab uint32, id int32) uint32 { +func vtabReleaseCallback(ctx context.Context, mod api.Module, pVTab ptr_t, id int32) uint32 { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabSavepointer) err := vtab.Release(int(id)) return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabRollbackToCallback(ctx context.Context, mod api.Module, pVTab uint32, id int32) uint32 { +func vtabRollbackToCallback(ctx context.Context, mod api.Module, pVTab ptr_t, id int32) uint32 { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabSavepointer) err := vtab.RollbackTo(int(id)) return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func cursorOpenCallback(ctx context.Context, mod api.Module, pVTab, ppCur uint32) uint32 { +func cursorOpenCallback(ctx context.Context, mod api.Module, pVTab, ppCur ptr_t) uint32 { vtab := vtabGetHandle(ctx, mod, pVTab).(VTab) cursor, err := vtab.Open() @@ -589,12 +589,12 @@ func cursorOpenCallback(ctx context.Context, mod api.Module, pVTab, ppCur uint32 return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func cursorCloseCallback(ctx context.Context, mod api.Module, pCur uint32) uint32 { +func cursorCloseCallback(ctx context.Context, mod api.Module, pCur ptr_t) uint32 { err := vtabDelHandle(ctx, mod, pCur) return vtabError(ctx, mod, 0, _VTAB_ERROR, err) } -func cursorFilterCallback(ctx context.Context, mod api.Module, pCur uint32, idxNum int32, idxStr, nArg, pArg uint32) uint32 { +func cursorFilterCallback(ctx context.Context, mod api.Module, pCur ptr_t, idxNum int32, idxStr ptr_t, nArg int32, pArg ptr_t) uint32 { cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) db := ctx.Value(connKey{}).(*Conn) args := make([]Value, nArg) @@ -607,7 +607,7 @@ func cursorFilterCallback(ctx context.Context, mod api.Module, pCur uint32, idxN return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err) } -func cursorEOFCallback(ctx context.Context, mod api.Module, pCur uint32) uint32 { +func cursorEOFCallback(ctx context.Context, mod api.Module, pCur ptr_t) uint32 { cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) if cursor.EOF() { return 1 @@ -615,25 +615,25 @@ func cursorEOFCallback(ctx context.Context, mod api.Module, pCur uint32) uint32 return 0 } -func cursorNextCallback(ctx context.Context, mod api.Module, pCur uint32) uint32 { +func cursorNextCallback(ctx context.Context, mod api.Module, pCur ptr_t) uint32 { cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) err := cursor.Next() return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err) } -func cursorColumnCallback(ctx context.Context, mod api.Module, pCur, pCtx uint32, n int32) uint32 { +func cursorColumnCallback(ctx context.Context, mod api.Module, pCur, pCtx ptr_t, n int32) uint32 { cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) db := ctx.Value(connKey{}).(*Conn) err := cursor.Column(Context{db, pCtx}, int(n)) return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err) } -func cursorRowIDCallback(ctx context.Context, mod api.Module, pCur, pRowID uint32) uint32 { +func cursorRowIDCallback(ctx context.Context, mod api.Module, pCur, pRowID ptr_t) uint32 { cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) rowID, err := cursor.RowID() if err == nil { - util.WriteUint64(mod, pRowID, uint64(rowID)) + util.Write64(mod, pRowID, rowID) } return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err) @@ -645,7 +645,7 @@ const ( _CURSOR_ERROR ) -func vtabError(ctx context.Context, mod api.Module, ptr, kind uint32, err error) uint32 { +func vtabError(ctx context.Context, mod api.Module, ptr ptr_t, kind uint32, err error) uint32 { const zErrMsgOffset = 8 msg, code := errorCode(err, ERROR) if msg != "" && ptr != 0 { @@ -653,32 +653,32 @@ func vtabError(ctx context.Context, mod api.Module, ptr, kind uint32, err error) case _VTAB_ERROR: ptr = ptr + zErrMsgOffset // zErrMsg case _CURSOR_ERROR: - ptr = util.ReadUint32(mod, ptr) + zErrMsgOffset // pVTab->zErrMsg + ptr = util.Read32[ptr_t](mod, ptr) + zErrMsgOffset // pVTab->zErrMsg } db := ctx.Value(connKey{}).(*Conn) - if ptr := util.ReadUint32(mod, ptr); ptr != 0 { + if ptr := util.Read32[ptr_t](mod, ptr); ptr != 0 { db.free(ptr) } - util.WriteUint32(mod, ptr, db.newString(msg)) + util.Write32(mod, ptr, db.newString(msg)) } return code } -func vtabGetHandle(ctx context.Context, mod api.Module, ptr uint32) any { +func vtabGetHandle(ctx context.Context, mod api.Module, ptr ptr_t) any { const handleOffset = 4 - handle := util.ReadUint32(mod, ptr-handleOffset) + handle := util.Read32[ptr_t](mod, ptr-handleOffset) return util.GetHandle(ctx, handle) } -func vtabDelHandle(ctx context.Context, mod api.Module, ptr uint32) error { +func vtabDelHandle(ctx context.Context, mod api.Module, ptr ptr_t) error { const handleOffset = 4 - handle := util.ReadUint32(mod, ptr-handleOffset) + handle := util.Read32[ptr_t](mod, ptr-handleOffset) return util.DelHandle(ctx, handle) } -func vtabPutHandle(ctx context.Context, mod api.Module, pptr uint32, val any) { +func vtabPutHandle(ctx context.Context, mod api.Module, pptr ptr_t, val any) { const handleOffset = 4 handle := util.AddHandle(ctx, val) - ptr := util.ReadUint32(mod, pptr) - util.WriteUint32(mod, ptr-handleOffset, handle) + ptr := util.Read32[ptr_t](mod, pptr) + util.Write32(mod, ptr-handleOffset, handle) } From 72e88d558e7b56d2bb96f2fd9300896d539f30b2 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Mon, 20 Jan 2025 21:22:54 +0000 Subject: [PATCH 2/6] Refactor. --- config.go | 18 +++++++++--------- conn.go | 4 ++-- driver/driver.go | 4 ++-- ext/csv/csv.go | 2 +- ext/pivot/pivot.go | 4 ++-- internal/util/mmap_unix.go | 6 +++--- internal/util/mmap_windows.go | 6 +++--- util/sql3util/parse.go | 16 ++++++++-------- vfs/cksm.go | 6 +++--- vfs/file.go | 8 ++++---- vfs/vfs.go | 20 +++++++++----------- vtab.go | 6 +++--- 12 files changed, 49 insertions(+), 51 deletions(-) diff --git a/config.go b/config.go index 7391d578..4fb55681 100644 --- a/config.go +++ b/config.go @@ -64,7 +64,7 @@ func (c *Conn) ConfigLog(cb func(code ExtendedErrorCode, msg string)) error { return nil } -func logCallback(ctx context.Context, mod api.Module, _, iCode, zMsg ptr_t) { +func logCallback(ctx context.Context, mod api.Module, _, iCode res_t, zMsg ptr_t) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.log != nil { msg := util.ReadString(mod, zMsg, _MAX_LENGTH) c.log(xErrorCode(iCode), msg) @@ -94,7 +94,7 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro } var rc res_t - var res any + var ret any switch op { default: return nil, MISUSE @@ -116,7 +116,7 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro rc = res_t(c.call("sqlite3_file_control", uint64(c.handle), uint64(schemaPtr), uint64(op), uint64(ptr))) - res = util.Read32[uint32](c.mod, ptr) != 0 + ret = util.Read32[uint32](c.mod, ptr) != 0 case FCNTL_CHUNK_SIZE: util.Write32(c.mod, ptr, int32(arg[0].(int))) @@ -133,19 +133,19 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro rc = res_t(c.call("sqlite3_file_control", uint64(c.handle), uint64(schemaPtr), uint64(op), uint64(ptr))) - res = int(util.Read32[int32](c.mod, ptr)) + ret = int(util.Read32[int32](c.mod, ptr)) case FCNTL_DATA_VERSION: rc = res_t(c.call("sqlite3_file_control", uint64(c.handle), uint64(schemaPtr), uint64(op), uint64(ptr))) - res = util.Read32[uint32](c.mod, ptr) + ret = util.Read32[uint32](c.mod, ptr) case FCNTL_LOCKSTATE: rc = res_t(c.call("sqlite3_file_control", uint64(c.handle), uint64(schemaPtr), uint64(op), uint64(ptr))) - res = util.Read32[vfs.LockLevel](c.mod, ptr) + ret = util.Read32[vfs.LockLevel](c.mod, ptr) case FCNTL_VFS_POINTER: rc = res_t(c.call("sqlite3_file_control", @@ -156,7 +156,7 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro ptr = util.Read32[ptr_t](c.mod, ptr) ptr = util.Read32[ptr_t](c.mod, ptr+zNameOffset) name := util.ReadString(c.mod, ptr, _MAX_NAME) - res = vfs.Find(name) + ret = vfs.Find(name) } case FCNTL_FILE_POINTER, FCNTL_JOURNAL_POINTER: @@ -167,14 +167,14 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro const fileHandleOffset = 4 ptr = util.Read32[ptr_t](c.mod, ptr) ptr = util.Read32[ptr_t](c.mod, ptr+fileHandleOffset) - res = util.GetHandle(c.ctx, ptr) + ret = util.GetHandle(c.ctx, ptr) } } if err := c.error(rc); err != nil { return nil, err } - return res, nil + return ret, nil } // Limit allows the size of various constructs to be diff --git a/conn.go b/conn.go index 394dd657..94e764b4 100644 --- a/conn.go +++ b/conn.go @@ -70,7 +70,7 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) { type connKey = util.ConnKey -func newConn(ctx context.Context, filename string, flags OpenFlag) (res *Conn, _ error) { +func newConn(ctx context.Context, filename string, flags OpenFlag) (ret *Conn, _ error) { err := ctx.Err() if err != nil { return nil, err @@ -82,7 +82,7 @@ func newConn(ctx context.Context, filename string, flags OpenFlag) (res *Conn, _ return nil, err } defer func() { - if res == nil { + if ret == nil { c.Close() c.sqlite.close() } else { diff --git a/driver/driver.go b/driver/driver.go index 66b209eb..cc132afe 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -201,7 +201,7 @@ func (n *connector) Driver() driver.Driver { return &SQLite{} } -func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) { +func (n *connector) Connect(ctx context.Context) (ret driver.Conn, err error) { c := &conn{ txLock: n.txLock, tmRead: n.tmRead, @@ -213,7 +213,7 @@ func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) { return nil, err } defer func() { - if res == nil { + if ret == nil { c.Close() } }() diff --git a/ext/csv/csv.go b/ext/csv/csv.go index 7169d534..2ea48719 100644 --- a/ext/csv/csv.go +++ b/ext/csv/csv.go @@ -30,7 +30,7 @@ func Register(db *sqlite3.Conn) error { // RegisterFS registers the CSV virtual table. // If a filename is specified, fsys is used to open the file. func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error { - declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (res *table, err error) { + declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) { var ( filename string data string diff --git a/ext/pivot/pivot.go b/ext/pivot/pivot.go index eafb615c..b669325d 100644 --- a/ext/pivot/pivot.go +++ b/ext/pivot/pivot.go @@ -25,14 +25,14 @@ type table struct { cols []*sqlite3.Value } -func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (res *table, err error) { +func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (ret *table, err error) { if len(arg) != 3 { return nil, fmt.Errorf("pivot: wrong number of arguments") } t := &table{db: db} defer func() { - if res == nil { + if ret == nil { t.Close() } }() diff --git a/internal/util/mmap_unix.go b/internal/util/mmap_unix.go index 0c5363a7..03ea8ba0 100644 --- a/internal/util/mmap_unix.go +++ b/internal/util/mmap_unix.go @@ -39,13 +39,13 @@ func (s *mmapState) new(ctx context.Context, mod api.Module, size int32) *Mapped // Save the newly allocated region. ptr := Ptr_t(stack[0]) buf := View(mod, ptr, uint64(size)) - res := &MappedRegion{ + ret := &MappedRegion{ Ptr: ptr, size: size, addr: unsafe.Pointer(&buf[0]), } - s.regions = append(s.regions, res) - return res + s.regions = append(s.regions, ret) + return ret } type MappedRegion struct { diff --git a/internal/util/mmap_windows.go b/internal/util/mmap_windows.go index efff1e73..913a5f73 100644 --- a/internal/util/mmap_windows.go +++ b/internal/util/mmap_windows.go @@ -29,13 +29,13 @@ func MapRegion(ctx context.Context, mod api.Module, f *os.File, offset int64, si return nil, err } - res := &MappedRegion{Handle: h, addr: a} + ret := &MappedRegion{Handle: h, addr: a} // SliceHeader, although deprecated, avoids a go vet warning. - sh := (*reflect.SliceHeader)(unsafe.Pointer(&res.Data)) + sh := (*reflect.SliceHeader)(unsafe.Pointer(&ret.Data)) sh.Len = int(size) sh.Cap = int(size) sh.Data = a - return res, nil + return ret, nil } func (r *MappedRegion) Unmap() error { diff --git a/util/sql3util/parse.go b/util/sql3util/parse.go index f84fc4dd..0cc0d3df 100644 --- a/util/sql3util/parse.go +++ b/util/sql3util/parse.go @@ -96,9 +96,9 @@ func (t *Table) load(mod api.Module, ptr uint32, sql string) { t.IsWithoutRowID = loadBool(mod, ptr+26) t.IsStrict = loadBool(mod, ptr+27) - t.Columns = loadSlice(mod, ptr+28, func(ptr uint32, res *Column) { + t.Columns = loadSlice(mod, ptr+28, func(ptr uint32, ret *Column) { p, _ := mod.Memory().ReadUint32Le(ptr) - res.load(mod, p, sql) + ret.load(mod, p, sql) }) t.Type = loadEnum[StatementType](mod, ptr+44) @@ -166,8 +166,8 @@ type ForeignKey struct { func (f *ForeignKey) load(mod api.Module, ptr uint32, sql string) { f.Table = loadString(mod, ptr+0, sql) - f.Columns = loadSlice(mod, ptr+8, func(ptr uint32, res *string) { - *res = loadString(mod, ptr, sql) + f.Columns = loadSlice(mod, ptr+8, func(ptr uint32, ret *string) { + *ret = loadString(mod, ptr, sql) }) f.OnDelete = loadEnum[FKAction](mod, ptr+16) @@ -191,12 +191,12 @@ func loadSlice[T any](mod api.Module, ptr uint32, fn func(uint32, *T)) []T { return nil } len, _ := mod.Memory().ReadUint32Le(ptr + 0) - res := make([]T, len) - for i := range res { - fn(ref, &res[i]) + ret := make([]T, len) + for i := range ret { + fn(ref, &ret[i]) ref += 4 } - return res + return ret } func loadEnum[T ~uint32](mod api.Module, ptr uint32) T { diff --git a/vfs/cksm.go b/vfs/cksm.go index 51f5e8a4..39493df9 100644 --- a/vfs/cksm.go +++ b/vfs/cksm.go @@ -102,11 +102,11 @@ func (c cksmFile) Pragma(name string, value string) (string, error) { } func (c cksmFile) DeviceCharacteristics() DeviceCharacteristic { - res := c.File.DeviceCharacteristics() + ret := c.File.DeviceCharacteristics() if c.verifyCksm { - res &^= IOCAP_SUBPAGE_READ + ret &^= IOCAP_SUBPAGE_READ } - return res + return ret } func (c cksmFile) fileControl(ctx context.Context, mod api.Module, op _FcntlOpcode, pArg ptr_t) _ErrorCode { diff --git a/vfs/file.go b/vfs/file.go index e028a2a5..bc90555e 100644 --- a/vfs/file.go +++ b/vfs/file.go @@ -186,14 +186,14 @@ func (f *vfsFile) SectorSize() int { } func (f *vfsFile) DeviceCharacteristics() DeviceCharacteristic { - res := IOCAP_SUBPAGE_READ + ret := IOCAP_SUBPAGE_READ if osBatchAtomic(f.File) { - res |= IOCAP_BATCH_ATOMIC + ret |= IOCAP_BATCH_ATOMIC } if f.psow { - res |= IOCAP_POWERSAFE_OVERWRITE + ret |= IOCAP_POWERSAFE_OVERWRITE } - return res + return ret } func (f *vfsFile) SizeHint(size int64) error { diff --git a/vfs/vfs.go b/vfs/vfs.go index 2a7cd9a3..2037cd0a 100644 --- a/vfs/vfs.go +++ b/vfs/vfs.go @@ -123,12 +123,11 @@ func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath ptr_t, flags Acc path := util.ReadString(mod, zPath, _MAX_PATHNAME) ok, err := vfs.Access(path, flags) - var res int32 + var val int32 if ok { - res = 1 + val = 1 } - - util.Write32(mod, pResOut, res) + util.Write32(mod, pResOut, val) return vfsErrorCode(err, _IOERR_ACCESS) } @@ -227,12 +226,11 @@ func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut pt file := vfsFileGet(ctx, mod, pFile).(File) locked, err := file.CheckReservedLock() - var res int32 + var val int32 if locked { - res = 1 + val = 1 } - - util.Write32(mod, pResOut, res) + util.Write32(mod, pResOut, val) return vfsErrorCode(err, _IOERR_CHECKRESERVEDLOCK) } @@ -295,11 +293,11 @@ func vfsFileControlImpl(ctx context.Context, mod api.Module, file File, op _Fcnt case _FCNTL_HAS_MOVED: if file, ok := file.(FileHasMoved); ok { moved, err := file.HasMoved() - var res uint32 + var val uint32 if moved { - res = 1 + val = 1 } - util.Write32(mod, pArg, res) + util.Write32(mod, pArg, val) return vfsErrorCode(err, _IOERR_FSTAT) } diff --git a/vtab.go b/vtab.go index ef9ce94f..9689bead 100644 --- a/vtab.go +++ b/vtab.go @@ -453,10 +453,10 @@ func vtabModuleCallback(i vtabConstructor) func(_ context.Context, _ api.Module, } module := vtabGetHandle(ctx, mod, pMod) - res := reflect.ValueOf(module).Index(int(i)).Call(arg) - err, _ := res[1].Interface().(error) + val := reflect.ValueOf(module).Index(int(i)).Call(arg) + err, _ := val[1].Interface().(error) if err == nil { - vtabPutHandle(ctx, mod, ppVTab, res[0].Interface()) + vtabPutHandle(ctx, mod, ppVTab, val[0].Interface()) } return vtabError(ctx, mod, pzErr, _PTR_ERROR, err) From 14b14535fe31357a6bc2a7dfe2be0d397a0b5ea2 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Mon, 20 Jan 2025 21:27:41 +0000 Subject: [PATCH 3/6] Ports. --- vfs/shm_bsd.go | 2 +- vfs/shm_dotlk.go | 8 ++++---- vfs/shm_windows.go | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/vfs/shm_bsd.go b/vfs/shm_bsd.go index 76e6888e..11e7bb2f 100644 --- a/vfs/shm_bsd.go +++ b/vfs/shm_bsd.go @@ -142,7 +142,7 @@ func (s *vfsShm) shmOpen() _ErrorCode { return _OK } -func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (uint32, _ErrorCode) { +func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (ptr_t, _ErrorCode) { // Ensure size is a multiple of the OS page size. if int(size)&(unix.Getpagesize()-1) != 0 { return 0, _IOERR_SHMMAP diff --git a/vfs/shm_dotlk.go b/vfs/shm_dotlk.go index 842bea8f..e5062481 100644 --- a/vfs/shm_dotlk.go +++ b/vfs/shm_dotlk.go @@ -35,7 +35,7 @@ type vfsShm struct { free api.Function path string shadow [][_WALINDEX_PGSZ]byte - ptrs []uint32 + ptrs []ptr_t stack [1]uint64 lock [_SHM_NLOCK]bool } @@ -96,7 +96,7 @@ func (s *vfsShm) shmOpen() _ErrorCode { return _OK } -func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (uint32, _ErrorCode) { +func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (ptr_t, _ErrorCode) { if size != _WALINDEX_PGSZ { return 0, _IOERR_SHMMAP } @@ -135,8 +135,8 @@ func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, ext if s.stack[0] == 0 { panic(util.OOMErr) } - clear(util.View(s.mod, uint32(s.stack[0]), _WALINDEX_PGSZ)) - s.ptrs = append(s.ptrs, uint32(s.stack[0])) + clear(util.View(s.mod, ptr_t(s.stack[0]), _WALINDEX_PGSZ)) + s.ptrs = append(s.ptrs, ptr_t(s.stack[0])) } s.shadow[0][4] = 1 diff --git a/vfs/shm_windows.go b/vfs/shm_windows.go index 1de57640..29d26ab5 100644 --- a/vfs/shm_windows.go +++ b/vfs/shm_windows.go @@ -26,7 +26,7 @@ type vfsShm struct { regions []*util.MappedRegion shared [][]byte shadow [][_WALINDEX_PGSZ]byte - ptrs []uint32 + ptrs []ptr_t stack [1]uint64 fileLock bool blocking bool @@ -72,7 +72,7 @@ func (s *vfsShm) shmOpen() _ErrorCode { return rc } -func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (_ uint32, rc _ErrorCode) { +func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (_ ptr_t, rc _ErrorCode) { // Ensure size is a multiple of the OS page size. if size != _WALINDEX_PGSZ || (windows.Getpagesize()-1)&_WALINDEX_PGSZ != 0 { return 0, _IOERR_SHMMAP @@ -126,8 +126,8 @@ func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, ext if s.stack[0] == 0 { panic(util.OOMErr) } - clear(util.View(s.mod, uint32(s.stack[0]), _WALINDEX_PGSZ)) - s.ptrs = append(s.ptrs, uint32(s.stack[0])) + clear(util.View(s.mod, ptr_t(s.stack[0]), _WALINDEX_PGSZ)) + s.ptrs = append(s.ptrs, ptr_t(s.stack[0])) } s.shadow[0][4] = 1 From 4abb963aaf4793aa34ab4f4445421f57a653fe6a Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Tue, 21 Jan 2025 00:26:03 +0000 Subject: [PATCH 4/6] More. --- backup.go | 14 +++--- blob.go | 46 +++++++++--------- config.go | 70 ++++++++++++++-------------- conn.go | 68 +++++++++++++-------------- const.go | 1 + context.go | 36 +++++++-------- error.go | 14 +++--- error_test.go | 8 ++-- func.go | 28 +++++------ internal/util/func.go | 4 -- internal/util/mem.go | 25 +++++----- internal/util/mem_test.go | 2 +- internal/util/mmap_unix.go | 8 ++-- sqlite.go | 38 ++++++++------- sqlite_test.go | 16 +++---- stmt.go | 95 +++++++++++++++++++------------------- txn.go | 22 ++++----- util/sql3util/parse.go | 2 +- value.go | 14 +++--- vfs/const.go | 5 +- vfs/filename.go | 14 +++--- vfs/shm_dotlk.go | 6 +-- vfs/shm_windows.go | 6 +-- vfs/vfs.go | 16 +++---- vtab.go | 76 +++++++++++++++--------------- 25 files changed, 317 insertions(+), 317 deletions(-) diff --git a/backup.go b/backup.go index 6378aab8..58b6229a 100644 --- a/backup.go +++ b/backup.go @@ -72,11 +72,11 @@ func (c *Conn) backupInit(dst ptr_t, dstName string, src ptr_t, srcName string) } ptr := ptr_t(c.call("sqlite3_backup_init", - uint64(dst), uint64(dstPtr), - uint64(src), uint64(srcPtr))) + stk_t(dst), stk_t(dstPtr), + stk_t(src), stk_t(srcPtr))) if ptr == 0 { defer c.closeDB(other) - rc := res_t(c.call("sqlite3_errcode", uint64(dst))) + rc := res_t(c.call("sqlite3_errcode", stk_t(dst))) return nil, c.sqlite.error(rc, dst) } @@ -97,7 +97,7 @@ func (b *Backup) Close() error { return nil } - rc := res_t(b.c.call("sqlite3_backup_finish", uint64(b.handle))) + rc := res_t(b.c.call("sqlite3_backup_finish", stk_t(b.handle))) b.c.closeDB(b.otherc) b.handle = 0 return b.c.error(rc) @@ -108,7 +108,7 @@ func (b *Backup) Close() error { // // https://sqlite.org/c3ref/backup_finish.html#sqlite3backupstep func (b *Backup) Step(nPage int) (done bool, err error) { - rc := res_t(b.c.call("sqlite3_backup_step", uint64(b.handle), uint64(nPage))) + rc := res_t(b.c.call("sqlite3_backup_step", stk_t(b.handle), stk_t(nPage))) if rc == _DONE { return true, nil } @@ -120,7 +120,7 @@ func (b *Backup) Step(nPage int) (done bool, err error) { // // https://sqlite.org/c3ref/backup_finish.html#sqlite3backupremaining func (b *Backup) Remaining() int { - n := int32(b.c.call("sqlite3_backup_remaining", uint64(b.handle))) + n := int32(b.c.call("sqlite3_backup_remaining", stk_t(b.handle))) return int(n) } @@ -129,6 +129,6 @@ func (b *Backup) Remaining() int { // // https://sqlite.org/c3ref/backup_finish.html#sqlite3backuppagecount func (b *Backup) PageCount() int { - n := int32(b.c.call("sqlite3_backup_pagecount", uint64(b.handle))) + n := int32(b.c.call("sqlite3_backup_pagecount", stk_t(b.handle))) return int(n) } diff --git a/blob.go b/blob.go index a2e4cfee..2fac7204 100644 --- a/blob.go +++ b/blob.go @@ -37,15 +37,15 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, tablePtr := c.arena.string(table) columnPtr := c.arena.string(column) - var flags uint64 + var flags int32 if write { flags = 1 } c.checkInterrupt(c.handle) - rc := res_t(c.call("sqlite3_blob_open", uint64(c.handle), - uint64(dbPtr), uint64(tablePtr), uint64(columnPtr), - uint64(row), flags, uint64(blobPtr))) + rc := res_t(c.call("sqlite3_blob_open", stk_t(c.handle), + stk_t(dbPtr), stk_t(tablePtr), stk_t(columnPtr), + stk_t(row), stk_t(flags), stk_t(blobPtr))) if err := c.error(rc); err != nil { return nil, err @@ -53,7 +53,7 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, blob := Blob{c: c} blob.handle = util.Read32[ptr_t](c.mod, blobPtr) - blob.bytes = int64(int32(c.call("sqlite3_blob_bytes", uint64(blob.handle)))) + blob.bytes = int64(int32(c.call("sqlite3_blob_bytes", stk_t(blob.handle)))) return &blob, nil } @@ -67,7 +67,7 @@ func (b *Blob) Close() error { return nil } - rc := res_t(b.c.call("sqlite3_blob_close", uint64(b.handle))) + rc := res_t(b.c.call("sqlite3_blob_close", stk_t(b.handle))) b.c.free(b.bufptr) b.handle = 0 return b.c.error(rc) @@ -94,12 +94,12 @@ func (b *Blob) Read(p []byte) (n int, err error) { want = avail } if want > b.buflen { - b.bufptr = b.c.realloc(b.bufptr, uint64(want)) + b.bufptr = b.c.realloc(b.bufptr, want) b.buflen = want } - rc := res_t(b.c.call("sqlite3_blob_read", uint64(b.handle), - uint64(b.bufptr), uint64(want), uint64(b.offset))) + rc := res_t(b.c.call("sqlite3_blob_read", stk_t(b.handle), + stk_t(b.bufptr), stk_t(want), stk_t(b.offset))) err = b.c.error(rc) if err != nil { return 0, err @@ -109,7 +109,7 @@ func (b *Blob) Read(p []byte) (n int, err error) { err = io.EOF } - copy(p, util.View(b.c.mod, b.bufptr, uint64(want))) + copy(p, util.View(b.c.mod, b.bufptr, want)) return int(want), err } @@ -127,19 +127,19 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) { want = avail } if want > b.buflen { - b.bufptr = b.c.realloc(b.bufptr, uint64(want)) + b.bufptr = b.c.realloc(b.bufptr, want) b.buflen = want } for want > 0 { - rc := res_t(b.c.call("sqlite3_blob_read", uint64(b.handle), - uint64(b.bufptr), uint64(want), uint64(b.offset))) + rc := res_t(b.c.call("sqlite3_blob_read", stk_t(b.handle), + stk_t(b.bufptr), stk_t(want), stk_t(b.offset))) err = b.c.error(rc) if err != nil { return n, err } - mem := util.View(b.c.mod, b.bufptr, uint64(want)) + mem := util.View(b.c.mod, b.bufptr, want) m, err := w.Write(mem[:want]) b.offset += int64(m) n += int64(m) @@ -165,13 +165,13 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) { func (b *Blob) Write(p []byte) (n int, err error) { want := int64(len(p)) if want > b.buflen { - b.bufptr = b.c.realloc(b.bufptr, uint64(want)) + b.bufptr = b.c.realloc(b.bufptr, want) b.buflen = want } util.WriteBytes(b.c.mod, b.bufptr, p) - rc := res_t(b.c.call("sqlite3_blob_write", uint64(b.handle), - uint64(b.bufptr), uint64(want), uint64(b.offset))) + rc := res_t(b.c.call("sqlite3_blob_write", stk_t(b.handle), + stk_t(b.bufptr), stk_t(want), stk_t(b.offset))) err = b.c.error(rc) if err != nil { return 0, err @@ -196,16 +196,16 @@ func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) { want = 1 } if want > b.buflen { - b.bufptr = b.c.realloc(b.bufptr, uint64(want)) + b.bufptr = b.c.realloc(b.bufptr, want) b.buflen = want } for { - mem := util.View(b.c.mod, b.bufptr, uint64(want)) + mem := util.View(b.c.mod, b.bufptr, want) m, err := r.Read(mem[:want]) if m > 0 { - rc := res_t(b.c.call("sqlite3_blob_write", uint64(b.handle), - uint64(b.bufptr), uint64(m), uint64(b.offset))) + rc := res_t(b.c.call("sqlite3_blob_write", stk_t(b.handle), + stk_t(b.bufptr), stk_t(m), stk_t(b.offset))) err := b.c.error(rc) if err != nil { return n, err @@ -254,8 +254,8 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) { // https://sqlite.org/c3ref/blob_reopen.html func (b *Blob) Reopen(row int64) error { b.c.checkInterrupt(b.c.handle) - err := b.c.error(res_t(b.c.call("sqlite3_blob_reopen", uint64(b.handle), uint64(row)))) - b.bytes = int64(int32(b.c.call("sqlite3_blob_bytes", uint64(b.handle)))) + err := b.c.error(res_t(b.c.call("sqlite3_blob_reopen", stk_t(b.handle), stk_t(row)))) + b.bytes = int64(int32(b.c.call("sqlite3_blob_bytes", stk_t(b.handle)))) b.offset = 0 return err } diff --git a/config.go b/config.go index 4fb55681..7fff6ead 100644 --- a/config.go +++ b/config.go @@ -43,8 +43,8 @@ func (c *Conn) Config(op DBConfig, arg ...bool) (bool, error) { util.Write32(c.mod, argsPtr+0*ptrlen, flag) util.Write32(c.mod, argsPtr+1*ptrlen, argsPtr) - rc := res_t(c.call("sqlite3_db_config", uint64(c.handle), - uint64(op), uint64(argsPtr))) + rc := res_t(c.call("sqlite3_db_config", stk_t(c.handle), + stk_t(op), stk_t(argsPtr))) return util.Read32[uint32](c.mod, argsPtr) != 0, c.error(rc) } @@ -52,11 +52,11 @@ func (c *Conn) Config(op DBConfig, arg ...bool) (bool, error) { // // https://sqlite.org/errlog.html func (c *Conn) ConfigLog(cb func(code ExtendedErrorCode, msg string)) error { - var enable uint64 + var enable int32 if cb != nil { enable = 1 } - rc := res_t(c.call("sqlite3_config_log_go", enable)) + rc := res_t(c.call("sqlite3_config_log_go", stk_t(enable))) if err := c.error(rc); err != nil { return err } @@ -64,7 +64,7 @@ func (c *Conn) ConfigLog(cb func(code ExtendedErrorCode, msg string)) error { return nil } -func logCallback(ctx context.Context, mod api.Module, _, iCode res_t, zMsg ptr_t) { +func logCallback(ctx context.Context, mod api.Module, _ ptr_t, iCode res_t, zMsg ptr_t) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.log != nil { msg := util.ReadString(mod, zMsg, _MAX_LENGTH) c.log(xErrorCode(iCode), msg) @@ -101,8 +101,8 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro case FCNTL_RESET_CACHE: rc = res_t(c.call("sqlite3_file_control", - uint64(c.handle), uint64(schemaPtr), - uint64(op), 0)) + stk_t(c.handle), stk_t(schemaPtr), + stk_t(op), 0)) case FCNTL_PERSIST_WAL, FCNTL_POWERSAFE_OVERWRITE: var flag int32 @@ -114,15 +114,15 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro } util.Write32(c.mod, ptr, flag) rc = res_t(c.call("sqlite3_file_control", - uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr))) + stk_t(c.handle), stk_t(schemaPtr), + stk_t(op), stk_t(ptr))) ret = util.Read32[uint32](c.mod, ptr) != 0 case FCNTL_CHUNK_SIZE: util.Write32(c.mod, ptr, int32(arg[0].(int))) rc = res_t(c.call("sqlite3_file_control", - uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr))) + stk_t(c.handle), stk_t(schemaPtr), + stk_t(op), stk_t(ptr))) case FCNTL_RESERVE_BYTES: bytes := -1 @@ -131,26 +131,26 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro } util.Write32(c.mod, ptr, int32(bytes)) rc = res_t(c.call("sqlite3_file_control", - uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr))) + stk_t(c.handle), stk_t(schemaPtr), + stk_t(op), stk_t(ptr))) ret = int(util.Read32[int32](c.mod, ptr)) case FCNTL_DATA_VERSION: rc = res_t(c.call("sqlite3_file_control", - uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr))) + stk_t(c.handle), stk_t(schemaPtr), + stk_t(op), stk_t(ptr))) ret = util.Read32[uint32](c.mod, ptr) case FCNTL_LOCKSTATE: rc = res_t(c.call("sqlite3_file_control", - uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr))) + stk_t(c.handle), stk_t(schemaPtr), + stk_t(op), stk_t(ptr))) ret = util.Read32[vfs.LockLevel](c.mod, ptr) case FCNTL_VFS_POINTER: rc = res_t(c.call("sqlite3_file_control", - uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr))) + stk_t(c.handle), stk_t(schemaPtr), + stk_t(op), stk_t(ptr))) if rc == _OK { const zNameOffset = 16 ptr = util.Read32[ptr_t](c.mod, ptr) @@ -161,8 +161,8 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro case FCNTL_FILE_POINTER, FCNTL_JOURNAL_POINTER: rc = res_t(c.call("sqlite3_file_control", - uint64(c.handle), uint64(schemaPtr), - uint64(op), uint64(ptr))) + stk_t(c.handle), stk_t(schemaPtr), + stk_t(op), stk_t(ptr))) if rc == _OK { const fileHandleOffset = 4 ptr = util.Read32[ptr_t](c.mod, ptr) @@ -182,7 +182,7 @@ func (c *Conn) FileControl(schema string, op FcntlOpcode, arg ...any) (any, erro // // https://sqlite.org/c3ref/limit.html func (c *Conn) Limit(id LimitCategory, value int) int { - v := int32(c.call("sqlite3_limit", uint64(c.handle), uint64(id), uint64(value))) + v := int32(c.call("sqlite3_limit", stk_t(c.handle), stk_t(id), stk_t(value))) return int(v) } @@ -190,11 +190,11 @@ func (c *Conn) Limit(id LimitCategory, value int) int { // // https://sqlite.org/c3ref/set_authorizer.html func (c *Conn) SetAuthorizer(cb func(action AuthorizerActionCode, name3rd, name4th, schema, inner string) AuthorizerReturnCode) error { - var enable uint64 + var enable int32 if cb != nil { enable = 1 } - rc := res_t(c.call("sqlite3_set_authorizer_go", uint64(c.handle), enable)) + rc := res_t(c.call("sqlite3_set_authorizer_go", stk_t(c.handle), stk_t(enable))) if err := c.error(rc); err != nil { return err } @@ -227,7 +227,7 @@ func authorizerCallback(ctx context.Context, mod api.Module, pDB ptr_t, action A // // https://sqlite.org/c3ref/trace_v2.html func (c *Conn) Trace(mask TraceEvent, cb func(evt TraceEvent, arg1 any, arg2 any) error) error { - rc := res_t(c.call("sqlite3_trace_go", uint64(c.handle), uint64(mask))) + rc := res_t(c.call("sqlite3_trace_go", stk_t(c.handle), stk_t(mask))) if err := c.error(rc); err != nil { return err } @@ -235,7 +235,7 @@ func (c *Conn) Trace(mask TraceEvent, cb func(evt TraceEvent, arg1 any, arg2 any return nil } -func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pArg1, pArg2 ptr_t) (rc uint32) { +func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pArg1, pArg2 ptr_t) (rc res_t) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.trace != nil { var arg1, arg2 any if evt == TRACE_CLOSE { @@ -270,8 +270,8 @@ func (c *Conn) WALCheckpoint(schema string, mode CheckpointMode) (nLog, nCkpt in nCkptPtr := c.arena.new(ptrlen) schemaPtr := c.arena.string(schema) rc := res_t(c.call("sqlite3_wal_checkpoint_v2", - uint64(c.handle), uint64(schemaPtr), uint64(mode), - uint64(nLogPtr), uint64(nCkptPtr))) + stk_t(c.handle), stk_t(schemaPtr), stk_t(mode), + stk_t(nLogPtr), stk_t(nCkptPtr))) nLog = int(util.Read32[int32](c.mod, nLogPtr)) nCkpt = int(util.Read32[int32](c.mod, nCkptPtr)) return nLog, nCkpt, c.error(rc) @@ -281,7 +281,7 @@ func (c *Conn) WALCheckpoint(schema string, mode CheckpointMode) (nLog, nCkpt in // // https://sqlite.org/c3ref/wal_autocheckpoint.html func (c *Conn) WALAutoCheckpoint(pages int) error { - rc := res_t(c.call("sqlite3_wal_autocheckpoint", uint64(c.handle), uint64(pages))) + rc := res_t(c.call("sqlite3_wal_autocheckpoint", stk_t(c.handle), stk_t(pages))) return c.error(rc) } @@ -290,15 +290,15 @@ func (c *Conn) WALAutoCheckpoint(pages int) error { // // https://sqlite.org/c3ref/wal_hook.html func (c *Conn) WALHook(cb func(db *Conn, schema string, pages int) error) { - var enable uint64 + var enable int32 if cb != nil { enable = 1 } - c.call("sqlite3_wal_hook_go", uint64(c.handle), enable) + c.call("sqlite3_wal_hook_go", stk_t(c.handle), stk_t(enable)) c.wal = cb } -func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema ptr_t, pages int32) (rc uint32) { +func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema ptr_t, pages int32) (rc res_t) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.wal != nil { schema := util.ReadString(mod, zSchema, _MAX_NAME) err := c.wal(c, schema, int(pages)) @@ -315,7 +315,7 @@ func (c *Conn) AutoVacuumPages(cb func(schema string, dbPages, freePages, bytesP if cb != nil { funcPtr = util.AddHandle(c.ctx, cb) } - rc := res_t(c.call("sqlite3_autovacuum_pages_go", uint64(c.handle), uint64(funcPtr))) + rc := res_t(c.call("sqlite3_autovacuum_pages_go", stk_t(c.handle), stk_t(funcPtr))) return c.error(rc) } @@ -329,14 +329,14 @@ func autoVacuumCallback(ctx context.Context, mod api.Module, pApp, zSchema ptr_t // // https://sqlite.org/c3ref/hard_heap_limit64.html func (c *Conn) SoftHeapLimit(n int64) int64 { - return int64(c.call("sqlite3_soft_heap_limit64", uint64(n))) + return int64(c.call("sqlite3_soft_heap_limit64", stk_t(n))) } // HardHeapLimit imposes a hard limit on heap size. // // https://sqlite.org/c3ref/hard_heap_limit64.html func (c *Conn) HardHeapLimit(n int64) int64 { - return int64(c.call("sqlite3_hard_heap_limit64", uint64(n))) + return int64(c.call("sqlite3_hard_heap_limit64", stk_t(n))) } // EnableChecksums enables checksums on a database. diff --git a/conn.go b/conn.go index 94e764b4..79119d7e 100644 --- a/conn.go +++ b/conn.go @@ -108,7 +108,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (ptr_t, error) { namePtr := c.arena.string(filename) flags |= OPEN_EXRESCODE - rc := res_t(c.call("sqlite3_open_v2", uint64(namePtr), uint64(connPtr), uint64(flags), 0)) + rc := res_t(c.call("sqlite3_open_v2", stk_t(namePtr), stk_t(connPtr), stk_t(flags), 0)) handle := util.Read32[ptr_t](c.mod, connPtr) if err := c.sqlite.error(rc, handle); err != nil { @@ -116,7 +116,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (ptr_t, error) { return 0, err } - c.call("sqlite3_progress_handler_go", uint64(handle), 100) + c.call("sqlite3_progress_handler_go", stk_t(handle), 100) if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") { var pragmas strings.Builder if _, after, ok := strings.Cut(filename, "?"); ok { @@ -130,7 +130,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (ptr_t, error) { if pragmas.Len() != 0 { c.checkInterrupt(handle) pragmaPtr := c.arena.string(pragmas.String()) - rc := res_t(c.call("sqlite3_exec", uint64(handle), uint64(pragmaPtr), 0, 0, 0)) + rc := res_t(c.call("sqlite3_exec", stk_t(handle), stk_t(pragmaPtr), 0, 0, 0)) if err := c.sqlite.error(rc, handle, pragmas.String()); err != nil { err = fmt.Errorf("sqlite3: invalid _pragma: %w", err) c.closeDB(handle) @@ -142,7 +142,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (ptr_t, error) { } func (c *Conn) closeDB(handle ptr_t) { - rc := res_t(c.call("sqlite3_close_v2", uint64(handle))) + rc := res_t(c.call("sqlite3_close_v2", stk_t(handle))) if err := c.sqlite.error(rc, handle); err != nil { panic(err) } @@ -165,7 +165,7 @@ func (c *Conn) Close() error { c.pending.Close() c.pending = nil - rc := res_t(c.call("sqlite3_close", uint64(c.handle))) + rc := res_t(c.call("sqlite3_close", stk_t(c.handle))) if err := c.error(rc); err != nil { return err } @@ -183,7 +183,7 @@ func (c *Conn) Exec(sql string) error { sqlPtr := c.arena.string(sql) c.checkInterrupt(c.handle) - rc := res_t(c.call("sqlite3_exec", uint64(c.handle), uint64(sqlPtr), 0, 0, 0)) + rc := res_t(c.call("sqlite3_exec", stk_t(c.handle), stk_t(sqlPtr), 0, 0, 0)) return c.error(rc, sql) } @@ -209,9 +209,9 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str sqlPtr := c.arena.string(sql) c.checkInterrupt(c.handle) - rc := res_t(c.call("sqlite3_prepare_v3", uint64(c.handle), - uint64(sqlPtr), uint64(len(sql)+1), uint64(flags), - uint64(stmtPtr), uint64(tailPtr))) + rc := res_t(c.call("sqlite3_prepare_v3", stk_t(c.handle), + stk_t(sqlPtr), stk_t(len(sql)+1), stk_t(flags), + stk_t(stmtPtr), stk_t(tailPtr))) stmt = &Stmt{c: c} stmt.handle = util.Read32[ptr_t](c.mod, stmtPtr) @@ -233,7 +233,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str // // https://sqlite.org/c3ref/db_name.html func (c *Conn) DBName(n int) string { - ptr := ptr_t(c.call("sqlite3_db_name", uint64(c.handle), uint64(n))) + ptr := ptr_t(c.call("sqlite3_db_name", stk_t(c.handle), stk_t(n))) if ptr == 0 { return "" } @@ -249,7 +249,7 @@ func (c *Conn) Filename(schema string) *vfs.Filename { defer c.arena.mark()() ptr = c.arena.string(schema) } - ptr = ptr_t(c.call("sqlite3_db_filename", uint64(c.handle), uint64(ptr))) + ptr = ptr_t(c.call("sqlite3_db_filename", stk_t(c.handle), stk_t(ptr))) return vfs.GetFilename(c.ctx, c.mod, ptr, vfs.OPEN_MAIN_DB) } @@ -262,7 +262,7 @@ func (c *Conn) ReadOnly(schema string) (ro bool, ok bool) { defer c.arena.mark()() ptr = c.arena.string(schema) } - b := int32(c.call("sqlite3_db_readonly", uint64(c.handle), uint64(ptr))) + b := int32(c.call("sqlite3_db_readonly", stk_t(c.handle), stk_t(ptr))) return b > 0, b < 0 } @@ -270,7 +270,7 @@ func (c *Conn) ReadOnly(schema string) (ro bool, ok bool) { // // https://sqlite.org/c3ref/get_autocommit.html func (c *Conn) GetAutocommit() bool { - b := int32(c.call("sqlite3_get_autocommit", uint64(c.handle))) + b := int32(c.call("sqlite3_get_autocommit", stk_t(c.handle))) return b != 0 } @@ -279,7 +279,7 @@ func (c *Conn) GetAutocommit() bool { // // https://sqlite.org/c3ref/last_insert_rowid.html func (c *Conn) LastInsertRowID() int64 { - return int64(c.call("sqlite3_last_insert_rowid", uint64(c.handle))) + return int64(c.call("sqlite3_last_insert_rowid", stk_t(c.handle))) } // SetLastInsertRowID allows the application to set the value returned by @@ -287,7 +287,7 @@ func (c *Conn) LastInsertRowID() int64 { // // https://sqlite.org/c3ref/set_last_insert_rowid.html func (c *Conn) SetLastInsertRowID(id int64) { - c.call("sqlite3_set_last_insert_rowid", uint64(c.handle), uint64(id)) + c.call("sqlite3_set_last_insert_rowid", stk_t(c.handle), stk_t(id)) } // Changes returns the number of rows modified, inserted or deleted @@ -296,7 +296,7 @@ func (c *Conn) SetLastInsertRowID(id int64) { // // https://sqlite.org/c3ref/changes.html func (c *Conn) Changes() int64 { - return int64(c.call("sqlite3_changes64", uint64(c.handle))) + return int64(c.call("sqlite3_changes64", stk_t(c.handle))) } // TotalChanges returns the number of rows modified, inserted or deleted @@ -305,14 +305,14 @@ func (c *Conn) Changes() int64 { // // https://sqlite.org/c3ref/total_changes.html func (c *Conn) TotalChanges() int64 { - return int64(c.call("sqlite3_total_changes64", uint64(c.handle))) + return int64(c.call("sqlite3_total_changes64", stk_t(c.handle))) } // ReleaseMemory frees memory used by a database connection. // // https://sqlite.org/c3ref/db_release_memory.html func (c *Conn) ReleaseMemory() error { - rc := res_t(c.call("sqlite3_db_release_memory", uint64(c.handle))) + rc := res_t(c.call("sqlite3_db_release_memory", stk_t(c.handle))) return c.error(rc) } @@ -349,8 +349,8 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { defer c.arena.mark()() stmtPtr := c.arena.new(ptrlen) loopPtr := c.arena.string(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`) - c.call("sqlite3_prepare_v3", uint64(c.handle), uint64(loopPtr), math.MaxUint64, - uint64(PREPARE_PERSISTENT), uint64(stmtPtr), 0) + c.call("sqlite3_prepare_v3", stk_t(c.handle), stk_t(loopPtr), math.MaxUint64, + stk_t(PREPARE_PERSISTENT), stk_t(stmtPtr), 0) c.pending = &Stmt{c: c} c.pending.handle = util.Read32[ptr_t](c.mod, stmtPtr) } @@ -366,11 +366,11 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { func (c *Conn) checkInterrupt(handle ptr_t) { if c.interrupt.Err() != nil { - c.call("sqlite3_interrupt", uint64(handle)) + c.call("sqlite3_interrupt", stk_t(handle)) } } -func progressCallback(ctx context.Context, mod api.Module, _ uint32) (interrupt uint32) { +func progressCallback(ctx context.Context, mod api.Module, _ ptr_t) (interrupt int32) { if c, ok := ctx.Value(connKey{}).(*Conn); ok { if c.interrupt.Done() != nil { runtime.Gosched() @@ -387,11 +387,11 @@ func progressCallback(ctx context.Context, mod api.Module, _ uint32) (interrupt // https://sqlite.org/c3ref/busy_timeout.html func (c *Conn) BusyTimeout(timeout time.Duration) error { ms := min((timeout+time.Millisecond-1)/time.Millisecond, math.MaxInt32) - rc := res_t(c.call("sqlite3_busy_timeout", uint64(c.handle), uint64(ms))) + rc := res_t(c.call("sqlite3_busy_timeout", stk_t(c.handle), stk_t(ms))) return c.error(rc) } -func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (retry uint32) { +func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (retry int32) { // https://fractaledmind.github.io/2024/04/15/sqlite-on-rails-the-how-and-why-of-optimal-performance/ if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.interrupt.Err() == nil { switch { @@ -414,11 +414,11 @@ func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (r // // https://sqlite.org/c3ref/busy_handler.html func (c *Conn) BusyHandler(cb func(ctx context.Context, count int) (retry bool)) error { - var enable uint64 + var enable int32 if cb != nil { enable = 1 } - rc := res_t(c.call("sqlite3_busy_handler_go", uint64(c.handle), enable)) + rc := res_t(c.call("sqlite3_busy_handler_go", stk_t(c.handle), stk_t(enable))) if err := c.error(rc); err != nil { return err } @@ -426,7 +426,7 @@ func (c *Conn) BusyHandler(cb func(ctx context.Context, count int) (retry bool)) return nil } -func busyCallback(ctx context.Context, mod api.Module, pDB ptr_t, count int32) (retry uint32) { +func busyCallback(ctx context.Context, mod api.Module, pDB ptr_t, count int32) (retry int32) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil { interrupt := c.interrupt if interrupt == nil { @@ -447,13 +447,13 @@ func (c *Conn) Status(op DBStatus, reset bool) (current, highwater int, err erro hiPtr := c.arena.new(intlen) curPtr := c.arena.new(intlen) - var i uint64 + var i int32 if reset { i = 1 } - rc := res_t(c.call("sqlite3_db_status", uint64(c.handle), - uint64(op), uint64(curPtr), uint64(hiPtr), i)) + rc := res_t(c.call("sqlite3_db_status", stk_t(c.handle), + stk_t(op), stk_t(curPtr), stk_t(hiPtr), stk_t(i))) if err = c.error(rc); err == nil { current = int(util.Read32[int32](c.mod, curPtr)) highwater = int(util.Read32[int32](c.mod, hiPtr)) @@ -481,10 +481,10 @@ func (c *Conn) TableColumnMetadata(schema, table, column string) (declType, coll columnPtr = c.arena.string(column) } - rc := res_t(c.call("sqlite3_table_column_metadata", uint64(c.handle), - uint64(schemaPtr), uint64(tablePtr), uint64(columnPtr), - uint64(declTypePtr), uint64(collSeqPtr), - uint64(notNullPtr), uint64(primaryKeyPtr), uint64(autoIncPtr))) + rc := res_t(c.call("sqlite3_table_column_metadata", stk_t(c.handle), + stk_t(schemaPtr), stk_t(tablePtr), stk_t(columnPtr), + stk_t(declTypePtr), stk_t(collSeqPtr), + stk_t(notNullPtr), stk_t(primaryKeyPtr), stk_t(autoIncPtr))) if err = c.error(rc); err == nil && column != "" { if ptr := util.Read32[ptr_t](c.mod, declTypePtr); ptr != 0 { declType = util.ReadString(c.mod, ptr, _MAX_NAME) diff --git a/const.go b/const.go index 60d2bdc6..086902a6 100644 --- a/const.go +++ b/const.go @@ -21,6 +21,7 @@ const ( ) type ( + stk_t = util.Stk_t ptr_t = util.Ptr_t res_t = util.Res_t ) diff --git a/context.go b/context.go index 34ee92f1..637ddc28 100644 --- a/context.go +++ b/context.go @@ -32,14 +32,14 @@ func (ctx Context) Conn() *Conn { // https://sqlite.org/c3ref/get_auxdata.html func (ctx Context) SetAuxData(n int, data any) { ptr := util.AddHandle(ctx.c.ctx, data) - ctx.c.call("sqlite3_set_auxdata_go", uint64(ctx.handle), uint64(n), uint64(ptr)) + ctx.c.call("sqlite3_set_auxdata_go", stk_t(ctx.handle), stk_t(n), stk_t(ptr)) } // GetAuxData returns metadata for argument n of the function. // // https://sqlite.org/c3ref/get_auxdata.html func (ctx Context) GetAuxData(n int) any { - ptr := ptr_t(ctx.c.call("sqlite3_get_auxdata", uint64(ctx.handle), uint64(n))) + ptr := ptr_t(ctx.c.call("sqlite3_get_auxdata", stk_t(ctx.handle), stk_t(n))) return util.GetHandle(ctx.c.ctx, ptr) } @@ -68,7 +68,7 @@ func (ctx Context) ResultInt(value int) { // https://sqlite.org/c3ref/result_blob.html func (ctx Context) ResultInt64(value int64) { ctx.c.call("sqlite3_result_int64", - uint64(ctx.handle), uint64(value)) + stk_t(ctx.handle), stk_t(value)) } // ResultFloat sets the result of the function to a float64. @@ -76,7 +76,7 @@ func (ctx Context) ResultInt64(value int64) { // https://sqlite.org/c3ref/result_blob.html func (ctx Context) ResultFloat(value float64) { ctx.c.call("sqlite3_result_double", - uint64(ctx.handle), math.Float64bits(value)) + stk_t(ctx.handle), stk_t(math.Float64bits(value))) } // ResultText sets the result of the function to a string. @@ -85,7 +85,7 @@ func (ctx Context) ResultFloat(value float64) { func (ctx Context) ResultText(value string) { ptr := ctx.c.newString(value) ctx.c.call("sqlite3_result_text_go", - uint64(ctx.handle), uint64(ptr), uint64(len(value))) + stk_t(ctx.handle), stk_t(ptr), stk_t(len(value))) } // ResultRawText sets the text result of the function to a []byte. @@ -95,7 +95,7 @@ func (ctx Context) ResultText(value string) { func (ctx Context) ResultRawText(value []byte) { ptr := ctx.c.newBytes(value) ctx.c.call("sqlite3_result_text_go", - uint64(ctx.handle), uint64(ptr), uint64(len(value))) + stk_t(ctx.handle), stk_t(ptr), stk_t(len(value))) } // ResultBlob sets the result of the function to a []byte. @@ -105,7 +105,7 @@ func (ctx Context) ResultRawText(value []byte) { func (ctx Context) ResultBlob(value []byte) { ptr := ctx.c.newBytes(value) ctx.c.call("sqlite3_result_blob_go", - uint64(ctx.handle), uint64(ptr), uint64(len(value))) + stk_t(ctx.handle), stk_t(ptr), stk_t(len(value))) } // ResultZeroBlob sets the result of the function to a zero-filled, length n BLOB. @@ -113,7 +113,7 @@ func (ctx Context) ResultBlob(value []byte) { // https://sqlite.org/c3ref/result_blob.html func (ctx Context) ResultZeroBlob(n int64) { ctx.c.call("sqlite3_result_zeroblob64", - uint64(ctx.handle), uint64(n)) + stk_t(ctx.handle), stk_t(n)) } // ResultNull sets the result of the function to NULL. @@ -121,7 +121,7 @@ func (ctx Context) ResultZeroBlob(n int64) { // https://sqlite.org/c3ref/result_blob.html func (ctx Context) ResultNull() { ctx.c.call("sqlite3_result_null", - uint64(ctx.handle)) + stk_t(ctx.handle)) } // ResultTime sets the result of the function to a [time.Time]. @@ -146,14 +146,14 @@ func (ctx Context) ResultTime(value time.Time, format TimeFormat) { } func (ctx Context) resultRFC3339Nano(value time.Time) { - const maxlen = uint64(len(time.RFC3339Nano)) + 5 + const maxlen = int64(len(time.RFC3339Nano)) + 5 ptr := ctx.c.new(maxlen) buf := util.View(ctx.c.mod, ptr, maxlen) buf = value.AppendFormat(buf[:0], time.RFC3339Nano) ctx.c.call("sqlite3_result_text_go", - uint64(ctx.handle), uint64(ptr), uint64(len(buf))) + stk_t(ctx.handle), stk_t(ptr), stk_t(len(buf))) } // ResultPointer sets the result of the function to NULL, just like [Context.ResultNull], @@ -164,7 +164,7 @@ func (ctx Context) resultRFC3339Nano(value time.Time) { func (ctx Context) ResultPointer(ptr any) { valPtr := util.AddHandle(ctx.c.ctx, ptr) ctx.c.call("sqlite3_result_pointer_go", - uint64(ctx.handle), uint64(valPtr)) + stk_t(ctx.handle), stk_t(valPtr)) } // ResultJSON sets the result of the function to the JSON encoding of value. @@ -188,7 +188,7 @@ func (ctx Context) ResultValue(value Value) { return } ctx.c.call("sqlite3_result_value", - uint64(ctx.handle), uint64(value.handle)) + stk_t(ctx.handle), stk_t(value.handle)) } // ResultError sets the result of the function an error. @@ -196,12 +196,12 @@ func (ctx Context) ResultValue(value Value) { // https://sqlite.org/c3ref/result_blob.html func (ctx Context) ResultError(err error) { if errors.Is(err, NOMEM) { - ctx.c.call("sqlite3_result_error_nomem", uint64(ctx.handle)) + ctx.c.call("sqlite3_result_error_nomem", stk_t(ctx.handle)) return } if errors.Is(err, TOOBIG) { - ctx.c.call("sqlite3_result_error_toobig", uint64(ctx.handle)) + ctx.c.call("sqlite3_result_error_toobig", stk_t(ctx.handle)) return } @@ -210,11 +210,11 @@ func (ctx Context) ResultError(err error) { defer ctx.c.arena.mark()() ptr := ctx.c.arena.string(msg) ctx.c.call("sqlite3_result_error", - uint64(ctx.handle), uint64(ptr), uint64(len(msg))) + stk_t(ctx.handle), stk_t(ptr), stk_t(len(msg))) } if code != _OK { ctx.c.call("sqlite3_result_error_code", - uint64(ctx.handle), uint64(code)) + stk_t(ctx.handle), stk_t(code)) } } @@ -223,6 +223,6 @@ func (ctx Context) ResultError(err error) { // // https://sqlite.org/c3ref/vtab_nochange.html func (ctx Context) VTabNoChange() bool { - b := int32(ctx.c.call("sqlite3_vtab_nochange", uint64(ctx.handle))) + b := int32(ctx.c.call("sqlite3_vtab_nochange", stk_t(ctx.handle))) return b != 0 } diff --git a/error.go b/error.go index 3799416e..6d4bd63f 100644 --- a/error.go +++ b/error.go @@ -146,27 +146,27 @@ func (e ExtendedErrorCode) Code() ErrorCode { return ErrorCode(e) } -func errorCode(err error, def ErrorCode) (msg string, code uint32) { +func errorCode(err error, def ErrorCode) (msg string, code res_t) { switch code := err.(type) { case nil: return "", _OK case ErrorCode: - return "", uint32(code) + return "", res_t(code) case xErrorCode: - return "", uint32(code) + return "", res_t(code) case *Error: - return code.msg, uint32(code.code) + return code.msg, res_t(code.code) } var ecode ErrorCode var xcode xErrorCode switch { case errors.As(err, &xcode): - code = uint32(xcode) + code = res_t(xcode) case errors.As(err, &ecode): - code = uint32(ecode) + code = res_t(ecode) default: - code = uint32(def) + code = res_t(def) } return err.Error(), code } diff --git a/error_test.go b/error_test.go index 1cdc804a..2ec3f49a 100644 --- a/error_test.go +++ b/error_test.go @@ -136,7 +136,7 @@ func Test_ErrorCode_Error(t *testing.T) { // Test all error codes. for i := 0; i == int(ErrorCode(i)); i++ { want := "sqlite3: " - ptr := ptr_t(db.call("sqlite3_errstr", uint64(i))) + ptr := ptr_t(db.call("sqlite3_errstr", stk_t(i))) want += util.ReadString(db.mod, ptr, _MAX_NAME) got := ErrorCode(i).Error() @@ -158,7 +158,7 @@ func Test_ExtendedErrorCode_Error(t *testing.T) { // Test all extended error codes. for i := 0; i == int(ExtendedErrorCode(i)); i++ { want := "sqlite3: " - ptr := ptr_t(db.call("sqlite3_errstr", uint64(i))) + ptr := ptr_t(db.call("sqlite3_errstr", stk_t(i))) want += util.ReadString(db.mod, ptr, _MAX_NAME) got := ExtendedErrorCode(i).Error() @@ -172,7 +172,7 @@ func Test_errorCode(t *testing.T) { tests := []struct { arg error wantMsg string - wantCode uint32 + wantCode res_t }{ {nil, "", _OK}, {ERROR, "", util.ERROR}, @@ -190,7 +190,7 @@ func Test_errorCode(t *testing.T) { if gotMsg != tt.wantMsg { t.Errorf("errorCode() gotMsg = %q, want %q", gotMsg, tt.wantMsg) } - if gotCode != uint32(tt.wantCode) { + if gotCode != tt.wantCode { t.Errorf("errorCode() gotCode = %d, want %d", gotCode, tt.wantCode) } }) diff --git a/func.go b/func.go index cdb4e8e6..f6c488ff 100644 --- a/func.go +++ b/func.go @@ -14,11 +14,11 @@ import ( // // https://sqlite.org/c3ref/collation_needed.html func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error { - var enable uint64 + var enable int32 if cb != nil { enable = 1 } - rc := res_t(c.call("sqlite3_collation_needed_go", uint64(c.handle), enable)) + rc := res_t(c.call("sqlite3_collation_needed_go", stk_t(c.handle), stk_t(enable))) if err := c.error(rc); err != nil { return err } @@ -33,7 +33,7 @@ func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error { // This can be used to load schemas that contain // one or more unknown collating sequences. func (c Conn) AnyCollationNeeded() error { - rc := res_t(c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0)) + rc := res_t(c.call("sqlite3_anycollseq_init", stk_t(c.handle), 0, 0)) if err := c.error(rc); err != nil { return err } @@ -52,7 +52,7 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { funcPtr = util.AddHandle(c.ctx, fn) } rc := res_t(c.call("sqlite3_create_collation_go", - uint64(c.handle), uint64(namePtr), uint64(funcPtr))) + stk_t(c.handle), stk_t(namePtr), stk_t(funcPtr))) return c.error(rc) } @@ -67,8 +67,8 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn Scala funcPtr = util.AddHandle(c.ctx, fn) } rc := res_t(c.call("sqlite3_create_function_go", - uint64(c.handle), uint64(namePtr), uint64(nArg), - uint64(flag), uint64(funcPtr))) + stk_t(c.handle), stk_t(namePtr), stk_t(nArg), + stk_t(flag), stk_t(funcPtr))) return c.error(rc) } @@ -93,8 +93,8 @@ func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn call = "sqlite3_create_window_function_go" } rc := res_t(c.call(call, - uint64(c.handle), uint64(namePtr), uint64(nArg), - uint64(flag), uint64(funcPtr))) + stk_t(c.handle), stk_t(namePtr), stk_t(nArg), + stk_t(flag), stk_t(funcPtr))) return c.error(rc) } @@ -130,7 +130,7 @@ func (c *Conn) OverloadFunction(name string, nArg int) error { defer c.arena.mark()() namePtr := c.arena.string(name) rc := res_t(c.call("sqlite3_overload_function", - uint64(c.handle), uint64(namePtr), uint64(nArg))) + stk_t(c.handle), stk_t(namePtr), stk_t(nArg))) return c.error(rc) } @@ -145,12 +145,12 @@ func collationCallback(ctx context.Context, mod api.Module, pArg, pDB ptr_t, eTe } } -func compareCallback(ctx context.Context, mod api.Module, pApp ptr_t, nKey1 uint32, pKey1 ptr_t, nKey2 uint32, pKey2 ptr_t) uint32 { +func compareCallback(ctx context.Context, mod api.Module, pApp ptr_t, nKey1 int32, pKey1 ptr_t, nKey2 int32, pKey2 ptr_t) uint32 { fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int) - return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2)))) + return uint32(fn(util.View(mod, pKey1, int64(nKey1)), util.View(mod, pKey2, int64(nKey2)))) } -func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp ptr_t, nArg uint32, pArg ptr_t) { +func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp ptr_t, nArg int32, pArg ptr_t) { args := getFuncArgs() defer putFuncArgs(args) db := ctx.Value(connKey{}).(*Conn) @@ -159,7 +159,7 @@ func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp ptr_t, nArg ui fn(Context{db, pCtx}, args[:nArg]...) } -func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, nArg uint32, pArg ptr_t) { +func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, nArg int32, pArg ptr_t) { args := getFuncArgs() defer putFuncArgs(args) db := ctx.Value(connKey{}).(*Conn) @@ -184,7 +184,7 @@ func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t) { fn.Value(Context{db, pCtx}) } -func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t, nArg uint32, pArg ptr_t) { +func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t, nArg int32, pArg ptr_t) { args := getFuncArgs() defer putFuncArgs(args) db := ctx.Value(connKey{}).(*Conn) diff --git a/internal/util/func.go b/internal/util/func.go index d310afc2..e705f318 100644 --- a/internal/util/func.go +++ b/internal/util/func.go @@ -7,10 +7,6 @@ import ( "github.com/tetratelabs/wazero/api" ) -type i8 interface{ ~int8 | ~uint8 } -type i32 interface{ ~int32 | ~uint32 } -type i64 interface{ ~int64 | ~uint64 } - type funcVI[T0 i32] func(context.Context, api.Module, T0) func (fn funcVI[T0]) Call(ctx context.Context, mod api.Module, stack []uint64) { diff --git a/internal/util/mem.go b/internal/util/mem.go index 7172ab08..a4d89445 100644 --- a/internal/util/mem.go +++ b/internal/util/mem.go @@ -13,15 +13,20 @@ const ( ) type ( + i8 interface{ ~int8 | ~uint8 } + i32 interface{ ~int32 | ~uint32 } + i64 interface{ ~int64 | ~uint64 } + + Stk_t = uint64 Ptr_t uint32 Res_t int32 ) -func View(mod api.Module, ptr Ptr_t, size uint64) []byte { +func View(mod api.Module, ptr Ptr_t, size int64) []byte { if ptr == 0 { panic(NilErr) } - if size > math.MaxUint32 { + if uint64(size) > math.MaxUint32 { panic(RangeErr) } if size == 0 { @@ -105,20 +110,16 @@ func WriteFloat64(mod api.Module, ptr Ptr_t, v float64) { Write64(mod, ptr, math.Float64bits(v)) } -func ReadString(mod api.Module, ptr Ptr_t, maxlen uint32) string { +func ReadString(mod api.Module, ptr Ptr_t, maxlen int64) string { if ptr == 0 { panic(NilErr) } - switch maxlen { - case 0: + if maxlen <= 0 { return "" - case math.MaxUint32: - // avoid overflow - default: - maxlen = maxlen + 1 } mem := mod.Memory() - buf, ok := mem.Read(uint32(ptr), maxlen) + maxlen = min(maxlen, math.MaxInt32-1) + 1 + buf, ok := mem.Read(uint32(ptr), uint32(maxlen)) if !ok { buf, ok = mem.Read(uint32(ptr), mem.Size()-uint32(ptr)) if !ok { @@ -133,12 +134,12 @@ func ReadString(mod api.Module, ptr Ptr_t, maxlen uint32) string { } func WriteBytes(mod api.Module, ptr Ptr_t, b []byte) { - buf := View(mod, ptr, uint64(len(b))) + buf := View(mod, ptr, int64(len(b))) copy(buf, b) } func WriteString(mod api.Module, ptr Ptr_t, s string) { - buf := View(mod, ptr, uint64(len(s)+1)) + buf := View(mod, ptr, int64(len(s)+1)) buf[len(s)] = 0 copy(buf, s) } diff --git a/internal/util/mem_test.go b/internal/util/mem_test.go index 0226e7b0..28d28555 100644 --- a/internal/util/mem_test.go +++ b/internal/util/mem_test.go @@ -115,6 +115,6 @@ func TestWriteUint64_range(t *testing.T) { func TestReadString_range(t *testing.T) { defer func() { _ = recover() }() mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize)) - ReadString(mock, wazerotest.PageSize+2, math.MaxUint32) + ReadString(mock, wazerotest.PageSize+2, math.MaxInt) t.Error("want panic") } diff --git a/internal/util/mmap_unix.go b/internal/util/mmap_unix.go index 03ea8ba0..42a24752 100644 --- a/internal/util/mmap_unix.go +++ b/internal/util/mmap_unix.go @@ -25,9 +25,9 @@ func (s *mmapState) new(ctx context.Context, mod api.Module, size int32) *Mapped // Allocate page aligned memmory. alloc := mod.ExportedFunction("aligned_alloc") - stack := [...]uint64{ - uint64(unix.Getpagesize()), - uint64(size), + stack := [...]Stk_t{ + Stk_t(unix.Getpagesize()), + Stk_t(size), } if err := alloc.CallWithStack(ctx, stack[:]); err != nil { panic(err) @@ -38,7 +38,7 @@ func (s *mmapState) new(ctx context.Context, mod api.Module, size int32) *Mapped // Save the newly allocated region. ptr := Ptr_t(stack[0]) - buf := View(mod, ptr, uint64(size)) + buf := View(mod, ptr, int64(size)) ret := &MappedRegion{ Ptr: ptr, size: size, diff --git a/sqlite.go b/sqlite.go index d478dca6..defb3f76 100644 --- a/sqlite.go +++ b/sqlite.go @@ -93,7 +93,7 @@ type sqlite struct { id [32]*byte mask uint32 } - stack [9]uint64 + stack [9]stk_t } func instantiateSQLite() (sqlt *sqlite, err error) { @@ -130,17 +130,17 @@ func (sqlt *sqlite) error(rc res_t, handle ptr_t, sql ...string) error { panic(util.OOMErr) } - if ptr := ptr_t(sqlt.call("sqlite3_errstr", uint64(rc))); ptr != 0 { + if ptr := ptr_t(sqlt.call("sqlite3_errstr", stk_t(rc))); ptr != 0 { err.str = util.ReadString(sqlt.mod, ptr, _MAX_NAME) } if handle != 0 { - if ptr := ptr_t(sqlt.call("sqlite3_errmsg", uint64(handle))); ptr != 0 { + if ptr := ptr_t(sqlt.call("sqlite3_errmsg", stk_t(handle))); ptr != 0 { err.msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH) } if len(sql) != 0 { - if i := int32(sqlt.call("sqlite3_error_offset", uint64(handle))); i != -1 { + if i := int32(sqlt.call("sqlite3_error_offset", stk_t(handle))); i != -1 { err.sql = sql[0][i:] } } @@ -181,9 +181,7 @@ func (sqlt *sqlite) putfn(name string, fn api.Function) { } } -type stk64 uint64 - -func (sqlt *sqlite) call(name string, params ...uint64) stk64 { +func (sqlt *sqlite) call(name string, params ...stk_t) stk_t { copy(sqlt.stack[:], params) fn := sqlt.getfn(name) err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:]) @@ -191,26 +189,26 @@ func (sqlt *sqlite) call(name string, params ...uint64) stk64 { panic(err) } sqlt.putfn(name, fn) - return stk64(sqlt.stack[0]) + return stk_t(sqlt.stack[0]) } func (sqlt *sqlite) free(ptr ptr_t) { if ptr == 0 { return } - sqlt.call("sqlite3_free", uint64(ptr)) + sqlt.call("sqlite3_free", stk_t(ptr)) } -func (sqlt *sqlite) new(size uint64) ptr_t { - ptr := ptr_t(sqlt.call("sqlite3_malloc64", size)) +func (sqlt *sqlite) new(size int64) ptr_t { + ptr := ptr_t(sqlt.call("sqlite3_malloc64", stk_t(size))) if ptr == 0 && size != 0 { panic(util.OOMErr) } return ptr } -func (sqlt *sqlite) realloc(ptr ptr_t, size uint64) ptr_t { - ptr = ptr_t(sqlt.call("sqlite3_realloc64", uint64(ptr), size)) +func (sqlt *sqlite) realloc(ptr ptr_t, size int64) ptr_t { + ptr = ptr_t(sqlt.call("sqlite3_realloc64", stk_t(ptr), stk_t(size))) if ptr == 0 && size != 0 { panic(util.OOMErr) } @@ -225,18 +223,18 @@ func (sqlt *sqlite) newBytes(b []byte) ptr_t { if size == 0 { size = 1 } - ptr := sqlt.new(uint64(size)) + ptr := sqlt.new(int64(size)) util.WriteBytes(sqlt.mod, ptr, b) return ptr } func (sqlt *sqlite) newString(s string) ptr_t { - ptr := sqlt.new(uint64(len(s) + 1)) + ptr := sqlt.new(int64(len(s)) + 1) util.WriteString(sqlt.mod, ptr, s) return ptr } -func (sqlt *sqlite) newArena(size uint64) arena { +func (sqlt *sqlite) newArena(size int64) arena { // Ensure the arena's size is a multiple of 8. size = (size + 7) &^ 7 return arena{ @@ -278,14 +276,14 @@ func (a *arena) mark() (reset func()) { } } -func (a *arena) new(size uint64) ptr_t { +func (a *arena) new(size int64) ptr_t { // Align the next address, to 4 or 8 bytes. if size&7 != 0 { a.next = (a.next + 3) &^ 3 } else { a.next = (a.next + 7) &^ 7 } - if size <= uint64(a.size-a.next) { + if size <= int64(a.size-a.next) { ptr := a.base + ptr_t(a.next) a.next += uint32(size) return ptr_t(ptr) @@ -299,13 +297,13 @@ func (a *arena) bytes(b []byte) ptr_t { if (*[0]byte)(b) == nil { return 0 } - ptr := a.new(uint64(len(b))) + ptr := a.new(int64(len(b))) util.WriteBytes(a.sqlt.mod, ptr, b) return ptr } func (a *arena) string(s string) ptr_t { - ptr := a.new(uint64(len(s) + 1)) + ptr := a.new(int64(len(s)) + 1) util.WriteString(a.sqlt.mod, ptr, s) return ptr } diff --git a/sqlite_test.go b/sqlite_test.go index 0e4b06cc..1b909b9f 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -73,7 +73,7 @@ func Test_sqlite_newArena(t *testing.T) { if ptr == 0 { t.Fatalf("got nullptr") } - if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != title { + if got := util.ReadString(sqlite.mod, ptr, math.MaxInt); got != title { t.Errorf("got %q, want %q", got, title) } @@ -82,7 +82,7 @@ func Test_sqlite_newArena(t *testing.T) { if ptr == 0 { t.Fatalf("got nullptr") } - if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != body { + if got := util.ReadString(sqlite.mod, ptr, math.MaxInt); got != body { t.Errorf("got %q, want %q", got, body) } @@ -94,7 +94,7 @@ func Test_sqlite_newArena(t *testing.T) { if ptr == 0 { t.Fatalf("got nullptr") } - if got := util.View(sqlite.mod, ptr, uint64(len(title))); string(got) != title { + if got := util.View(sqlite.mod, ptr, int64(len(title))); string(got) != title { t.Errorf("got %q, want %q", got, title) } @@ -122,7 +122,7 @@ func Test_sqlite_newBytes(t *testing.T) { } want := buf - if got := util.View(sqlite.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) { + if got := util.View(sqlite.mod, ptr, int64(len(want))); !bytes.Equal(got, want) { t.Errorf("got %q, want %q", got, want) } @@ -157,7 +157,7 @@ func Test_sqlite_newString(t *testing.T) { } want := str + "\000" - if got := util.View(sqlite.mod, ptr, uint64(len(want))); string(got) != want { + if got := util.View(sqlite.mod, ptr, int64(len(want))); string(got) != want { t.Errorf("got %q, want %q", got, want) } } @@ -183,7 +183,7 @@ func Test_sqlite_getString(t *testing.T) { } want := "sqlite3" - if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != want { + if got := util.ReadString(sqlite.mod, ptr, math.MaxInt); got != want { t.Errorf("got %q, want %q", got, want) } if got := util.ReadString(sqlite.mod, ptr, 0); got != "" { @@ -192,13 +192,13 @@ func Test_sqlite_getString(t *testing.T) { func() { defer func() { _ = recover() }() - util.ReadString(sqlite.mod, ptr, uint32(len(want)/2)) + util.ReadString(sqlite.mod, ptr, int64(len(want)/2)) t.Error("want panic") }() func() { defer func() { _ = recover() }() - util.ReadString(sqlite.mod, 0, math.MaxUint32) + util.ReadString(sqlite.mod, 0, math.MaxInt) t.Error("want panic") }() } diff --git a/stmt.go b/stmt.go index 10430075..4e17d103 100644 --- a/stmt.go +++ b/stmt.go @@ -29,7 +29,7 @@ func (s *Stmt) Close() error { return nil } - rc := res_t(s.c.call("sqlite3_finalize", uint64(s.handle))) + rc := res_t(s.c.call("sqlite3_finalize", stk_t(s.handle))) stmts := s.c.stmts for i := range stmts { if s == stmts[i] { @@ -64,7 +64,7 @@ func (s *Stmt) SQL() string { // // https://sqlite.org/c3ref/expanded_sql.html func (s *Stmt) ExpandedSQL() string { - ptr := ptr_t(s.c.call("sqlite3_expanded_sql", uint64(s.handle))) + ptr := ptr_t(s.c.call("sqlite3_expanded_sql", stk_t(s.handle))) sql := util.ReadString(s.c.mod, ptr, _MAX_SQL_LENGTH) s.c.free(ptr) return sql @@ -75,7 +75,7 @@ func (s *Stmt) ExpandedSQL() string { // // https://sqlite.org/c3ref/stmt_readonly.html func (s *Stmt) ReadOnly() bool { - b := int32(s.c.call("sqlite3_stmt_readonly", uint64(s.handle))) + b := int32(s.c.call("sqlite3_stmt_readonly", stk_t(s.handle))) return b != 0 } @@ -83,7 +83,7 @@ func (s *Stmt) ReadOnly() bool { // // https://sqlite.org/c3ref/reset.html func (s *Stmt) Reset() error { - rc := res_t(s.c.call("sqlite3_reset", uint64(s.handle))) + rc := res_t(s.c.call("sqlite3_reset", stk_t(s.handle))) s.err = nil return s.c.error(rc) } @@ -92,7 +92,7 @@ func (s *Stmt) Reset() error { // // https://sqlite.org/c3ref/stmt_busy.html func (s *Stmt) Busy() bool { - rc := res_t(s.c.call("sqlite3_stmt_busy", uint64(s.handle))) + rc := res_t(s.c.call("sqlite3_stmt_busy", stk_t(s.handle))) return rc != 0 } @@ -107,7 +107,7 @@ func (s *Stmt) Busy() bool { // https://sqlite.org/c3ref/step.html func (s *Stmt) Step() bool { s.c.checkInterrupt(s.c.handle) - rc := res_t(s.c.call("sqlite3_step", uint64(s.handle))) + rc := res_t(s.c.call("sqlite3_step", stk_t(s.handle))) switch rc { case _ROW: s.err = nil @@ -143,12 +143,12 @@ func (s *Stmt) Status(op StmtStatus, reset bool) int { if op > STMTSTATUS_FILTER_HIT && op != STMTSTATUS_MEMUSED { return 0 } - var i uint64 + var i int32 if reset { i = 1 } - n := int32(s.c.call("sqlite3_stmt_status", uint64(s.handle), - uint64(op), i)) + n := int32(s.c.call("sqlite3_stmt_status", stk_t(s.handle), + stk_t(op), stk_t(i))) return int(n) } @@ -156,7 +156,7 @@ func (s *Stmt) Status(op StmtStatus, reset bool) int { // // https://sqlite.org/c3ref/clear_bindings.html func (s *Stmt) ClearBindings() error { - rc := res_t(s.c.call("sqlite3_clear_bindings", uint64(s.handle))) + rc := res_t(s.c.call("sqlite3_clear_bindings", stk_t(s.handle))) return s.c.error(rc) } @@ -165,7 +165,7 @@ func (s *Stmt) ClearBindings() error { // https://sqlite.org/c3ref/bind_parameter_count.html func (s *Stmt) BindCount() int { n := int32(s.c.call("sqlite3_bind_parameter_count", - uint64(s.handle))) + stk_t(s.handle))) return int(n) } @@ -177,7 +177,7 @@ func (s *Stmt) BindIndex(name string) int { defer s.c.arena.mark()() namePtr := s.c.arena.string(name) i := int32(s.c.call("sqlite3_bind_parameter_index", - uint64(s.handle), uint64(namePtr))) + stk_t(s.handle), stk_t(namePtr))) return int(i) } @@ -187,7 +187,7 @@ func (s *Stmt) BindIndex(name string) int { // https://sqlite.org/c3ref/bind_parameter_name.html func (s *Stmt) BindName(param int) string { ptr := ptr_t(s.c.call("sqlite3_bind_parameter_name", - uint64(s.handle), uint64(param))) + stk_t(s.handle), stk_t(param))) if ptr == 0 { return "" } @@ -222,7 +222,7 @@ func (s *Stmt) BindInt(param int, value int) error { // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindInt64(param int, value int64) error { rc := res_t(s.c.call("sqlite3_bind_int64", - uint64(s.handle), uint64(param), uint64(value))) + stk_t(s.handle), stk_t(param), stk_t(value))) return s.c.error(rc) } @@ -232,7 +232,8 @@ func (s *Stmt) BindInt64(param int, value int64) error { // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindFloat(param int, value float64) error { rc := res_t(s.c.call("sqlite3_bind_double", - uint64(s.handle), uint64(param), math.Float64bits(value))) + stk_t(s.handle), stk_t(param), + stk_t(math.Float64bits(value)))) return s.c.error(rc) } @@ -246,8 +247,8 @@ func (s *Stmt) BindText(param int, value string) error { } ptr := s.c.newString(value) rc := res_t(s.c.call("sqlite3_bind_text_go", - uint64(s.handle), uint64(param), - uint64(ptr), uint64(len(value)))) + stk_t(s.handle), stk_t(param), + stk_t(ptr), stk_t(len(value)))) return s.c.error(rc) } @@ -262,8 +263,8 @@ func (s *Stmt) BindRawText(param int, value []byte) error { } ptr := s.c.newBytes(value) rc := res_t(s.c.call("sqlite3_bind_text_go", - uint64(s.handle), uint64(param), - uint64(ptr), uint64(len(value)))) + stk_t(s.handle), stk_t(param), + stk_t(ptr), stk_t(len(value)))) return s.c.error(rc) } @@ -278,8 +279,8 @@ func (s *Stmt) BindBlob(param int, value []byte) error { } ptr := s.c.newBytes(value) rc := res_t(s.c.call("sqlite3_bind_blob_go", - uint64(s.handle), uint64(param), - uint64(ptr), uint64(len(value)))) + stk_t(s.handle), stk_t(param), + stk_t(ptr), stk_t(len(value)))) return s.c.error(rc) } @@ -289,7 +290,7 @@ func (s *Stmt) BindBlob(param int, value []byte) error { // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindZeroBlob(param int, n int64) error { rc := res_t(s.c.call("sqlite3_bind_zeroblob64", - uint64(s.handle), uint64(param), uint64(n))) + stk_t(s.handle), stk_t(param), stk_t(n))) return s.c.error(rc) } @@ -299,7 +300,7 @@ func (s *Stmt) BindZeroBlob(param int, n int64) error { // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindNull(param int) error { rc := res_t(s.c.call("sqlite3_bind_null", - uint64(s.handle), uint64(param))) + stk_t(s.handle), stk_t(param))) return s.c.error(rc) } @@ -325,15 +326,15 @@ func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error { } func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error { - const maxlen = uint64(len(time.RFC3339Nano)) + 5 + const maxlen = int64(len(time.RFC3339Nano)) + 5 ptr := s.c.new(maxlen) buf := util.View(s.c.mod, ptr, maxlen) buf = value.AppendFormat(buf[:0], time.RFC3339Nano) rc := res_t(s.c.call("sqlite3_bind_text_go", - uint64(s.handle), uint64(param), - uint64(ptr), uint64(len(buf)))) + stk_t(s.handle), stk_t(param), + stk_t(ptr), stk_t(len(buf)))) return s.c.error(rc) } @@ -346,7 +347,7 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error { func (s *Stmt) BindPointer(param int, ptr any) error { valPtr := util.AddHandle(s.c.ctx, ptr) rc := res_t(s.c.call("sqlite3_bind_pointer_go", - uint64(s.handle), uint64(param), uint64(valPtr))) + stk_t(s.handle), stk_t(param), stk_t(valPtr))) return s.c.error(rc) } @@ -371,7 +372,7 @@ func (s *Stmt) BindValue(param int, value Value) error { return MISUSE } rc := res_t(s.c.call("sqlite3_bind_value", - uint64(s.handle), uint64(param), uint64(value.handle))) + stk_t(s.handle), stk_t(param), stk_t(value.handle))) return s.c.error(rc) } @@ -380,7 +381,7 @@ func (s *Stmt) BindValue(param int, value Value) error { // https://sqlite.org/c3ref/data_count.html func (s *Stmt) DataCount() int { n := int32(s.c.call("sqlite3_data_count", - uint64(s.handle))) + stk_t(s.handle))) return int(n) } @@ -389,7 +390,7 @@ func (s *Stmt) DataCount() int { // https://sqlite.org/c3ref/column_count.html func (s *Stmt) ColumnCount() int { n := int32(s.c.call("sqlite3_column_count", - uint64(s.handle))) + stk_t(s.handle))) return int(n) } @@ -399,7 +400,7 @@ func (s *Stmt) ColumnCount() int { // https://sqlite.org/c3ref/column_name.html func (s *Stmt) ColumnName(col int) string { ptr := ptr_t(s.c.call("sqlite3_column_name", - uint64(s.handle), uint64(col))) + stk_t(s.handle), stk_t(col))) if ptr == 0 { panic(util.OOMErr) } @@ -412,7 +413,7 @@ func (s *Stmt) ColumnName(col int) string { // https://sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnType(col int) Datatype { return Datatype(s.c.call("sqlite3_column_type", - uint64(s.handle), uint64(col))) + stk_t(s.handle), stk_t(col))) } // ColumnDeclType returns the declared datatype of the result column. @@ -421,7 +422,7 @@ func (s *Stmt) ColumnType(col int) Datatype { // https://sqlite.org/c3ref/column_decltype.html func (s *Stmt) ColumnDeclType(col int) string { ptr := ptr_t(s.c.call("sqlite3_column_decltype", - uint64(s.handle), uint64(col))) + stk_t(s.handle), stk_t(col))) if ptr == 0 { return "" } @@ -435,7 +436,7 @@ func (s *Stmt) ColumnDeclType(col int) string { // https://sqlite.org/c3ref/column_database_name.html func (s *Stmt) ColumnDatabaseName(col int) string { ptr := ptr_t(s.c.call("sqlite3_column_database_name", - uint64(s.handle), uint64(col))) + stk_t(s.handle), stk_t(col))) if ptr == 0 { return "" } @@ -449,7 +450,7 @@ func (s *Stmt) ColumnDatabaseName(col int) string { // https://sqlite.org/c3ref/column_database_name.html func (s *Stmt) ColumnTableName(col int) string { ptr := ptr_t(s.c.call("sqlite3_column_table_name", - uint64(s.handle), uint64(col))) + stk_t(s.handle), stk_t(col))) if ptr == 0 { return "" } @@ -463,7 +464,7 @@ func (s *Stmt) ColumnTableName(col int) string { // https://sqlite.org/c3ref/column_database_name.html func (s *Stmt) ColumnOriginName(col int) string { ptr := ptr_t(s.c.call("sqlite3_column_origin_name", - uint64(s.handle), uint64(col))) + stk_t(s.handle), stk_t(col))) if ptr == 0 { return "" } @@ -495,7 +496,7 @@ func (s *Stmt) ColumnInt(col int) int { // https://sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnInt64(col int) int64 { return int64(s.c.call("sqlite3_column_int64", - uint64(s.handle), uint64(col))) + stk_t(s.handle), stk_t(col))) } // ColumnFloat returns the value of the result column as a float64. @@ -504,7 +505,7 @@ func (s *Stmt) ColumnInt64(col int) int64 { // https://sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnFloat(col int) float64 { f := uint64(s.c.call("sqlite3_column_double", - uint64(s.handle), uint64(col))) + stk_t(s.handle), stk_t(col))) return math.Float64frombits(f) } @@ -558,7 +559,7 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte { // https://sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnRawText(col int) []byte { ptr := ptr_t(s.c.call("sqlite3_column_text", - uint64(s.handle), uint64(col))) + stk_t(s.handle), stk_t(col))) return s.columnRawBytes(col, ptr) } @@ -570,13 +571,13 @@ func (s *Stmt) ColumnRawText(col int) []byte { // https://sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnRawBlob(col int) []byte { ptr := ptr_t(s.c.call("sqlite3_column_blob", - uint64(s.handle), uint64(col))) + stk_t(s.handle), stk_t(col))) return s.columnRawBytes(col, ptr) } func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte { if ptr == 0 { - rc := res_t(s.c.call("sqlite3_errcode", uint64(s.c.handle))) + rc := res_t(s.c.call("sqlite3_errcode", stk_t(s.c.handle))) if rc != _ROW && rc != _DONE { s.err = s.c.error(rc) } @@ -584,8 +585,8 @@ func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte { } n := int32(s.c.call("sqlite3_column_bytes", - uint64(s.handle), uint64(col))) - return util.View(s.c.mod, ptr, uint64(n)) + stk_t(s.handle), stk_t(col))) + return util.View(s.c.mod, ptr, int64(n)) } // ColumnJSON parses the JSON-encoded value of the result column @@ -618,7 +619,7 @@ func (s *Stmt) ColumnJSON(col int, ptr any) error { // https://sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnValue(col int) Value { ptr := ptr_t(s.c.call("sqlite3_column_value", - uint64(s.handle), uint64(col))) + stk_t(s.handle), stk_t(col))) return Value{ c: s.c, unprot: true, @@ -636,12 +637,12 @@ func (s *Stmt) ColumnValue(col int) Value { // subsequent calls to [Stmt] methods. func (s *Stmt) Columns(dest ...any) error { defer s.c.arena.mark()() - count := uint64(len(dest)) + count := int64(len(dest)) typePtr := s.c.arena.new(count) dataPtr := s.c.arena.new(count * 8) rc := res_t(s.c.call("sqlite3_columns_go", - uint64(s.handle), count, uint64(typePtr), uint64(dataPtr))) + stk_t(s.handle), stk_t(count), stk_t(typePtr), stk_t(dataPtr))) if err := s.c.error(rc); err != nil { return err } @@ -668,7 +669,7 @@ func (s *Stmt) Columns(dest ...any) error { continue } len := util.Read32[int32](s.c.mod, dataPtr+4) - buf := util.View(s.c.mod, ptr, uint64(len)) + buf := util.View(s.c.mod, ptr, int64(len)) if types[i] == byte(TEXT) { dest[i] = string(buf) } else { diff --git a/txn.go b/txn.go index bdee752e..b24789f8 100644 --- a/txn.go +++ b/txn.go @@ -234,7 +234,7 @@ func (c *Conn) TxnState(schema string) TxnState { defer c.arena.mark()() ptr = c.arena.string(schema) } - return TxnState(c.call("sqlite3_txn_state", uint64(c.handle), uint64(ptr))) + return TxnState(c.call("sqlite3_txn_state", stk_t(c.handle), stk_t(ptr))) } // CommitHook registers a callback function to be invoked @@ -243,11 +243,11 @@ func (c *Conn) TxnState(schema string) TxnState { // // https://sqlite.org/c3ref/commit_hook.html func (c *Conn) CommitHook(cb func() (ok bool)) { - var enable uint64 + var enable int32 if cb != nil { enable = 1 } - c.call("sqlite3_commit_hook_go", uint64(c.handle), enable) + c.call("sqlite3_commit_hook_go", stk_t(c.handle), stk_t(enable)) c.commit = cb } @@ -256,11 +256,11 @@ func (c *Conn) CommitHook(cb func() (ok bool)) { // // https://sqlite.org/c3ref/commit_hook.html func (c *Conn) RollbackHook(cb func()) { - var enable uint64 + var enable int32 if cb != nil { enable = 1 } - c.call("sqlite3_rollback_hook_go", uint64(c.handle), enable) + c.call("sqlite3_rollback_hook_go", stk_t(c.handle), stk_t(enable)) c.rollback = cb } @@ -269,15 +269,15 @@ func (c *Conn) RollbackHook(cb func()) { // // https://sqlite.org/c3ref/update_hook.html func (c *Conn) UpdateHook(cb func(action AuthorizerActionCode, schema, table string, rowid int64)) { - var enable uint64 + var enable int32 if cb != nil { enable = 1 } - c.call("sqlite3_update_hook_go", uint64(c.handle), enable) + c.call("sqlite3_update_hook_go", stk_t(c.handle), stk_t(enable)) c.update = cb } -func commitCallback(ctx context.Context, mod api.Module, pDB ptr_t) (rollback uint32) { +func commitCallback(ctx context.Context, mod api.Module, pDB ptr_t) (rollback int32) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.commit != nil { if !c.commit() { rollback = 1 @@ -292,11 +292,11 @@ func rollbackCallback(ctx context.Context, mod api.Module, pDB ptr_t) { } } -func updateCallback(ctx context.Context, mod api.Module, pDB ptr_t, action AuthorizerActionCode, zSchema, zTabName ptr_t, rowid uint64) { +func updateCallback(ctx context.Context, mod api.Module, pDB ptr_t, action AuthorizerActionCode, zSchema, zTabName ptr_t, rowid int64) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.update != nil { schema := util.ReadString(mod, zSchema, _MAX_NAME) table := util.ReadString(mod, zTabName, _MAX_NAME) - c.update(action, schema, table, int64(rowid)) + c.update(action, schema, table, rowid) } } @@ -304,6 +304,6 @@ func updateCallback(ctx context.Context, mod api.Module, pDB ptr_t, action Autho // // https://sqlite.org/c3ref/db_cacheflush.html func (c *Conn) CacheFlush() error { - rc := res_t(c.call("sqlite3_db_cacheflush", uint64(c.handle))) + rc := res_t(c.call("sqlite3_db_cacheflush", stk_t(c.handle))) return c.error(rc) } diff --git a/util/sql3util/parse.go b/util/sql3util/parse.go index 0cc0d3df..7dd76ceb 100644 --- a/util/sql3util/parse.go +++ b/util/sql3util/parse.go @@ -50,7 +50,7 @@ func ParseTable(sql string) (_ *Table, err error) { copy(buf, sql) } - stack := [...]uint64{sqlp, uint64(len(sql)), errp} + stack := [...]util.Stk_t{sqlp, util.Stk_t(len(sql)), errp} err = mod.ExportedFunction("sql3parse_table").CallWithStack(ctx, stack[:]) if err != nil { return nil, err diff --git a/value.go b/value.go index 2e5d4cf4..a2399fba 100644 --- a/value.go +++ b/value.go @@ -19,18 +19,18 @@ type Value struct { copied bool } -func (v Value) protected() uint64 { +func (v Value) protected() stk_t { if v.unprot { panic(util.ValueErr) } - return uint64(v.handle) + return stk_t(v.handle) } // Dup makes a copy of the SQL value and returns a pointer to that copy. // // https://sqlite.org/c3ref/value_dup.html func (v Value) Dup() *Value { - ptr := ptr_t(v.c.call("sqlite3_value_dup", uint64(v.handle))) + ptr := ptr_t(v.c.call("sqlite3_value_dup", stk_t(v.handle))) return &Value{ c: v.c, copied: true, @@ -45,7 +45,7 @@ func (dup *Value) Close() error { if !dup.copied { panic(util.ValueErr) } - dup.c.call("sqlite3_value_free", uint64(dup.handle)) + dup.c.call("sqlite3_value_free", stk_t(dup.handle)) dup.handle = 0 return nil } @@ -158,7 +158,7 @@ func (v Value) rawBytes(ptr ptr_t) []byte { } n := int32(v.c.call("sqlite3_value_bytes", v.protected())) - return util.View(v.c.mod, ptr, uint64(n)) + return util.View(v.c.mod, ptr, int64(n)) } // Pointer gets the pointer associated with this value, @@ -213,7 +213,7 @@ func (v Value) FromBind() bool { func (v Value) InFirst() (Value, error) { defer v.c.arena.mark()() valPtr := v.c.arena.new(ptrlen) - rc := res_t(v.c.call("sqlite3_vtab_in_first", uint64(v.handle), uint64(valPtr))) + rc := res_t(v.c.call("sqlite3_vtab_in_first", stk_t(v.handle), stk_t(valPtr))) if err := v.c.error(rc); err != nil { return Value{}, err } @@ -230,7 +230,7 @@ func (v Value) InFirst() (Value, error) { func (v Value) InNext() (Value, error) { defer v.c.arena.mark()() valPtr := v.c.arena.new(ptrlen) - rc := res_t(v.c.call("sqlite3_vtab_in_next", uint64(v.handle), uint64(valPtr))) + rc := res_t(v.c.call("sqlite3_vtab_in_next", stk_t(v.handle), stk_t(valPtr))) if err := v.c.error(rc); err != nil { return Value{}, err } diff --git a/vfs/const.go b/vfs/const.go index b1de96f5..dc3b0db8 100644 --- a/vfs/const.go +++ b/vfs/const.go @@ -11,7 +11,10 @@ const ( ptrlen = util.PtrLen ) -type ptr_t = util.Ptr_t +type ( + stk_t = util.Stk_t + ptr_t = util.Ptr_t +) // https://sqlite.org/rescode.html type _ErrorCode uint32 diff --git a/vfs/filename.go b/vfs/filename.go index 7137abd9..965c3b1a 100644 --- a/vfs/filename.go +++ b/vfs/filename.go @@ -18,7 +18,7 @@ type Filename struct { mod api.Module zPath ptr_t flags OpenFlag - stack [2]uint64 + stack [2]stk_t } // GetFilename is an internal API users should not call directly. @@ -71,7 +71,7 @@ func (n *Filename) path(method string) string { return "" } - n.stack[0] = uint64(n.zPath) + n.stack[0] = stk_t(n.zPath) fn := n.mod.ExportedFunction(method) if err := fn.CallWithStack(n.ctx, n.stack[:]); err != nil { panic(err) @@ -90,7 +90,7 @@ func (n *Filename) DatabaseFile() File { return nil } - n.stack[0] = uint64(n.zPath) + n.stack[0] = stk_t(n.zPath) fn := n.mod.ExportedFunction("sqlite3_database_file_object") if err := fn.CallWithStack(n.ctx, n.stack[:]); err != nil { panic(err) @@ -108,8 +108,8 @@ func (n *Filename) URIParameter(key string) string { } uriKey := n.mod.ExportedFunction("sqlite3_uri_key") - n.stack[0] = uint64(n.zPath) - n.stack[1] = uint64(0) + n.stack[0] = stk_t(n.zPath) + n.stack[1] = stk_t(0) if err := uriKey.CallWithStack(n.ctx, n.stack[:]); err != nil { panic(err) } @@ -146,8 +146,8 @@ func (n *Filename) URIParameters() url.Values { } uriKey := n.mod.ExportedFunction("sqlite3_uri_key") - n.stack[0] = uint64(n.zPath) - n.stack[1] = uint64(0) + n.stack[0] = stk_t(n.zPath) + n.stack[1] = stk_t(0) if err := uriKey.CallWithStack(n.ctx, n.stack[:]); err != nil { panic(err) } diff --git a/vfs/shm_dotlk.go b/vfs/shm_dotlk.go index e5062481..cb697a9c 100644 --- a/vfs/shm_dotlk.go +++ b/vfs/shm_dotlk.go @@ -36,7 +36,7 @@ type vfsShm struct { path string shadow [][_WALINDEX_PGSZ]byte ptrs []ptr_t - stack [1]uint64 + stack [1]stk_t lock [_SHM_NLOCK]bool } @@ -128,7 +128,7 @@ func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, ext // Allocate local memory. for int(id) >= len(s.ptrs) { - s.stack[0] = uint64(size) + s.stack[0] = stk_t(size) if err := s.alloc.CallWithStack(ctx, s.stack[:]); err != nil { panic(err) } @@ -168,7 +168,7 @@ func (s *vfsShm) shmUnmap(delete bool) { defer s.Unlock() for _, p := range s.ptrs { - s.stack[0] = uint64(p) + s.stack[0] = stk_t(p) if err := s.free.CallWithStack(context.Background(), s.stack[:]); err != nil { panic(err) } diff --git a/vfs/shm_windows.go b/vfs/shm_windows.go index 29d26ab5..ed2e93f8 100644 --- a/vfs/shm_windows.go +++ b/vfs/shm_windows.go @@ -27,7 +27,7 @@ type vfsShm struct { shared [][]byte shadow [][_WALINDEX_PGSZ]byte ptrs []ptr_t - stack [1]uint64 + stack [1]stk_t fileLock bool blocking bool sync.Mutex @@ -119,7 +119,7 @@ func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, ext // Allocate local memory. for int(id) >= len(s.ptrs) { - s.stack[0] = uint64(size) + s.stack[0] = stk_t(size) if err := s.alloc.CallWithStack(ctx, s.stack[:]); err != nil { panic(err) } @@ -168,7 +168,7 @@ func (s *vfsShm) shmUnmap(delete bool) { // Free local memory. for _, p := range s.ptrs { - s.stack[0] = uint64(p) + s.stack[0] = stk_t(p) if err := s.free.CallWithStack(context.Background(), s.stack[:]); err != nil { panic(err) } diff --git a/vfs/vfs.go b/vfs/vfs.go index 2037cd0a..ff1f646b 100644 --- a/vfs/vfs.go +++ b/vfs/vfs.go @@ -79,7 +79,7 @@ func vfsLocaltime(ctx context.Context, mod api.Module, pTm ptr_t, t int64) _Erro } func vfsRandomness(ctx context.Context, mod api.Module, pVfs ptr_t, nByte int32, zByte ptr_t) uint32 { - mem := util.View(mod, zByte, uint64(nByte)) + mem := util.View(mod, zByte, int64(nByte)) n, _ := rand.Reader.Read(mem) return uint32(n) } @@ -110,7 +110,7 @@ func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative ptr_t, return vfsErrorCode(err, _CANTOPEN_FULLPATH) } -func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath ptr_t, syncDir uint32) _ErrorCode { +func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath ptr_t, syncDir int32) _ErrorCode { vfs := vfsGet(mod, pVfs) path := util.ReadString(mod, zPath, _MAX_PATHNAME) @@ -170,7 +170,7 @@ func vfsClose(ctx context.Context, mod api.Module, pFile ptr_t) _ErrorCode { func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf ptr_t, iAmt int32, iOfst int64) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) - buf := util.View(mod, zBuf, uint64(iAmt)) + buf := util.View(mod, zBuf, int64(iAmt)) n, err := file.ReadAt(buf, iOfst) if n == int(iAmt) { @@ -185,7 +185,7 @@ func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf ptr_t, iAmt int32, func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf ptr_t, iAmt int32, iOfst int64) _ErrorCode { file := vfsFileGet(ctx, mod, pFile).(File) - buf := util.View(mod, zBuf, uint64(iAmt)) + buf := util.View(mod, zBuf, int64(iAmt)) _, err := file.WriteAt(buf, iOfst) return vfsErrorCode(err, _IOERR_WRITE) @@ -367,7 +367,7 @@ func vfsFileControlImpl(ctx context.Context, mod api.Module, file File, op _Fcnt } if out != "" { fn := mod.ExportedFunction("sqlite3_malloc64") - stack := [...]uint64{uint64(len(out) + 1)} + stack := [...]stk_t{stk_t(len(out) + 1)} if err := fn.CallWithStack(ctx, stack[:]); err != nil { panic(err) } @@ -379,10 +379,10 @@ func vfsFileControlImpl(ctx context.Context, mod api.Module, file File, op _Fcnt case _FCNTL_BUSYHANDLER: if file, ok := file.(FileBusyHandler); ok { - arg := util.Read64[uint64](mod, pArg) + arg := util.Read64[stk_t](mod, pArg) fn := mod.ExportedFunction("sqlite3_invoke_busy_handler_go") file.BusyHandler(func() bool { - stack := [...]uint64{arg} + stack := [...]stk_t{arg} if err := fn.CallWithStack(ctx, stack[:]); err != nil { panic(err) } @@ -436,7 +436,7 @@ func vfsShmLock(ctx context.Context, mod api.Module, pFile ptr_t, offset, n int3 return shm.shmLock(offset, n, flags) } -func vfsShmUnmap(ctx context.Context, mod api.Module, pFile ptr_t, bDelete uint32) _ErrorCode { +func vfsShmUnmap(ctx context.Context, mod api.Module, pFile ptr_t, bDelete int32) _ErrorCode { shm := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory() shm.shmUnmap(bDelete != 0) return _OK diff --git a/vtab.go b/vtab.go index 9689bead..8dbef8ab 100644 --- a/vtab.go +++ b/vtab.go @@ -64,8 +64,8 @@ func CreateModule[T VTab](db *Conn, name string, create, connect VTabConstructor if connect != nil { modulePtr = util.AddHandle(db.ctx, module[T]{create, connect}) } - rc := res_t(db.call("sqlite3_create_module_go", uint64(db.handle), - uint64(namePtr), uint64(flags), uint64(modulePtr))) + rc := res_t(db.call("sqlite3_create_module_go", stk_t(db.handle), + stk_t(namePtr), stk_t(flags), stk_t(modulePtr))) return db.error(rc) } @@ -80,7 +80,7 @@ func implements[T any](typ reflect.Type) bool { func (c *Conn) DeclareVTab(sql string) error { defer c.arena.mark()() sqlPtr := c.arena.string(sql) - rc := res_t(c.call("sqlite3_declare_vtab", uint64(c.handle), uint64(sqlPtr))) + rc := res_t(c.call("sqlite3_declare_vtab", stk_t(c.handle), stk_t(sqlPtr))) return c.error(rc) } @@ -101,7 +101,7 @@ const ( // // https://sqlite.org/c3ref/vtab_on_conflict.html func (c *Conn) VTabOnConflict() VTabConflictMode { - return VTabConflictMode(c.call("sqlite3_vtab_on_conflict", uint64(c.handle))) + return VTabConflictMode(c.call("sqlite3_vtab_on_conflict", stk_t(c.handle))) } // VTabConfigOption is a virtual table configuration option. @@ -120,13 +120,13 @@ const ( // // https://sqlite.org/c3ref/vtab_config.html func (c *Conn) VTabConfig(op VTabConfigOption, args ...any) error { - var i uint64 + var i int32 if op == VTAB_CONSTRAINT_SUPPORT && len(args) > 0 { if b, ok := args[0].(bool); ok && b { i = 1 } } - rc := res_t(c.call("sqlite3_vtab_config_go", uint64(c.handle), uint64(op), i)) + rc := res_t(c.call("sqlite3_vtab_config_go", stk_t(c.handle), stk_t(op), stk_t(i))) return c.error(rc) } @@ -308,8 +308,8 @@ type IndexConstraintUsage struct { func (idx *IndexInfo) RHSValue(column int) (Value, error) { defer idx.c.arena.mark()() valPtr := idx.c.arena.new(ptrlen) - rc := res_t(idx.c.call("sqlite3_vtab_rhs_value", uint64(idx.handle), - uint64(column), uint64(valPtr))) + rc := res_t(idx.c.call("sqlite3_vtab_rhs_value", stk_t(idx.handle), + stk_t(column), stk_t(valPtr))) if err := idx.c.error(rc); err != nil { return Value{}, err } @@ -323,8 +323,8 @@ func (idx *IndexInfo) RHSValue(column int) (Value, error) { // // https://sqlite.org/c3ref/vtab_collation.html func (idx *IndexInfo) Collation(column int) string { - ptr := ptr_t(idx.c.call("sqlite3_vtab_collation", uint64(idx.handle), - uint64(column))) + ptr := ptr_t(idx.c.call("sqlite3_vtab_collation", stk_t(idx.handle), + stk_t(column))) return util.ReadString(idx.c.mod, ptr, _MAX_NAME) } @@ -332,7 +332,7 @@ func (idx *IndexInfo) Collation(column int) string { // // https://sqlite.org/c3ref/vtab_distinct.html func (idx *IndexInfo) Distinct() int { - i := int32(idx.c.call("sqlite3_vtab_distinct", uint64(idx.handle))) + i := int32(idx.c.call("sqlite3_vtab_distinct", stk_t(idx.handle))) return int(i) } @@ -340,8 +340,8 @@ func (idx *IndexInfo) Distinct() int { // // https://sqlite.org/c3ref/vtab_in.html func (idx *IndexInfo) In(column, handle int) bool { - b := int32(idx.c.call("sqlite3_vtab_in", uint64(idx.handle), - uint64(column), uint64(handle))) + b := int32(idx.c.call("sqlite3_vtab_in", stk_t(idx.handle), + stk_t(column), stk_t(handle))) return b != 0 } @@ -442,8 +442,8 @@ const ( INDEX_SCAN_UNIQUE IndexScanFlag = 1 ) -func vtabModuleCallback(i vtabConstructor) func(_ context.Context, _ api.Module, _ ptr_t, _ int32, _, _, _ ptr_t) uint32 { - return func(ctx context.Context, mod api.Module, pMod ptr_t, nArg int32, pArg, ppVTab, pzErr ptr_t) uint32 { +func vtabModuleCallback(i vtabConstructor) func(_ context.Context, _ api.Module, _ ptr_t, _ int32, _, _, _ ptr_t) res_t { + return func(ctx context.Context, mod api.Module, pMod ptr_t, nArg int32, pArg, ppVTab, pzErr ptr_t) res_t { arg := make([]reflect.Value, 1+nArg) arg[0] = reflect.ValueOf(ctx.Value(connKey{})) @@ -463,12 +463,12 @@ func vtabModuleCallback(i vtabConstructor) func(_ context.Context, _ api.Module, } } -func vtabDisconnectCallback(ctx context.Context, mod api.Module, pVTab ptr_t) uint32 { +func vtabDisconnectCallback(ctx context.Context, mod api.Module, pVTab ptr_t) res_t { err := vtabDelHandle(ctx, mod, pVTab) return vtabError(ctx, mod, 0, _PTR_ERROR, err) } -func vtabDestroyCallback(ctx context.Context, mod api.Module, pVTab ptr_t) uint32 { +func vtabDestroyCallback(ctx context.Context, mod api.Module, pVTab ptr_t) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabDestroyer) err := vtab.Destroy() if cerr := vtabDelHandle(ctx, mod, pVTab); err == nil { @@ -477,7 +477,7 @@ func vtabDestroyCallback(ctx context.Context, mod api.Module, pVTab ptr_t) uint3 return vtabError(ctx, mod, 0, _PTR_ERROR, err) } -func vtabBestIndexCallback(ctx context.Context, mod api.Module, pVTab, pIdxInfo ptr_t) uint32 { +func vtabBestIndexCallback(ctx context.Context, mod api.Module, pVTab, pIdxInfo ptr_t) res_t { var info IndexInfo info.handle = pIdxInfo info.c = ctx.Value(connKey{}).(*Conn) @@ -490,7 +490,7 @@ func vtabBestIndexCallback(ctx context.Context, mod api.Module, pVTab, pIdxInfo return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabUpdateCallback(ctx context.Context, mod api.Module, pVTab ptr_t, nArg int32, pArg, pRowID ptr_t) uint32 { +func vtabUpdateCallback(ctx context.Context, mod api.Module, pVTab ptr_t, nArg int32, pArg, pRowID ptr_t) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabUpdater) db := ctx.Value(connKey{}).(*Conn) @@ -504,13 +504,13 @@ func vtabUpdateCallback(ctx context.Context, mod api.Module, pVTab ptr_t, nArg i return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabRenameCallback(ctx context.Context, mod api.Module, pVTab, zNew ptr_t) uint32 { +func vtabRenameCallback(ctx context.Context, mod api.Module, pVTab, zNew ptr_t) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabRenamer) err := vtab.Rename(util.ReadString(mod, zNew, _MAX_NAME)) return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabFindFuncCallback(ctx context.Context, mod api.Module, pVTab ptr_t, nArg int32, zName, pxFunc ptr_t) uint32 { +func vtabFindFuncCallback(ctx context.Context, mod api.Module, pVTab ptr_t, nArg int32, zName, pxFunc ptr_t) int32 { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabOverloader) f, op := vtab.FindFunction(int(nArg), util.ReadString(mod, zName, _MAX_NAME)) if op != 0 { @@ -521,10 +521,10 @@ func vtabFindFuncCallback(ctx context.Context, mod api.Module, pVTab ptr_t, nArg }) util.Write32(mod, pxFunc, wrapper) } - return uint32(op) + return int32(op) } -func vtabIntegrityCallback(ctx context.Context, mod api.Module, pVTab, zSchema, zTabName ptr_t, mFlags uint32, pzErr ptr_t) uint32 { +func vtabIntegrityCallback(ctx context.Context, mod api.Module, pVTab, zSchema, zTabName ptr_t, mFlags uint32, pzErr ptr_t) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabChecker) schema := util.ReadString(mod, zSchema, _MAX_NAME) table := util.ReadString(mod, zTabName, _MAX_NAME) @@ -536,49 +536,49 @@ func vtabIntegrityCallback(ctx context.Context, mod api.Module, pVTab, zSchema, return code } -func vtabBeginCallback(ctx context.Context, mod api.Module, pVTab ptr_t) uint32 { +func vtabBeginCallback(ctx context.Context, mod api.Module, pVTab ptr_t) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabTxn) err := vtab.Begin() return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabSyncCallback(ctx context.Context, mod api.Module, pVTab ptr_t) uint32 { +func vtabSyncCallback(ctx context.Context, mod api.Module, pVTab ptr_t) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabTxn) err := vtab.Sync() return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabCommitCallback(ctx context.Context, mod api.Module, pVTab ptr_t) uint32 { +func vtabCommitCallback(ctx context.Context, mod api.Module, pVTab ptr_t) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabTxn) err := vtab.Commit() return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabRollbackCallback(ctx context.Context, mod api.Module, pVTab ptr_t) uint32 { +func vtabRollbackCallback(ctx context.Context, mod api.Module, pVTab ptr_t) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabTxn) err := vtab.Rollback() return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabSavepointCallback(ctx context.Context, mod api.Module, pVTab ptr_t, id int32) uint32 { +func vtabSavepointCallback(ctx context.Context, mod api.Module, pVTab ptr_t, id int32) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabSavepointer) err := vtab.Savepoint(int(id)) return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabReleaseCallback(ctx context.Context, mod api.Module, pVTab ptr_t, id int32) uint32 { +func vtabReleaseCallback(ctx context.Context, mod api.Module, pVTab ptr_t, id int32) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabSavepointer) err := vtab.Release(int(id)) return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func vtabRollbackToCallback(ctx context.Context, mod api.Module, pVTab ptr_t, id int32) uint32 { +func vtabRollbackToCallback(ctx context.Context, mod api.Module, pVTab ptr_t, id int32) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTabSavepointer) err := vtab.RollbackTo(int(id)) return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func cursorOpenCallback(ctx context.Context, mod api.Module, pVTab, ppCur ptr_t) uint32 { +func cursorOpenCallback(ctx context.Context, mod api.Module, pVTab, ppCur ptr_t) res_t { vtab := vtabGetHandle(ctx, mod, pVTab).(VTab) cursor, err := vtab.Open() @@ -589,12 +589,12 @@ func cursorOpenCallback(ctx context.Context, mod api.Module, pVTab, ppCur ptr_t) return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } -func cursorCloseCallback(ctx context.Context, mod api.Module, pCur ptr_t) uint32 { +func cursorCloseCallback(ctx context.Context, mod api.Module, pCur ptr_t) res_t { err := vtabDelHandle(ctx, mod, pCur) return vtabError(ctx, mod, 0, _VTAB_ERROR, err) } -func cursorFilterCallback(ctx context.Context, mod api.Module, pCur ptr_t, idxNum int32, idxStr ptr_t, nArg int32, pArg ptr_t) uint32 { +func cursorFilterCallback(ctx context.Context, mod api.Module, pCur ptr_t, idxNum int32, idxStr ptr_t, nArg int32, pArg ptr_t) res_t { cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) db := ctx.Value(connKey{}).(*Conn) args := make([]Value, nArg) @@ -607,7 +607,7 @@ func cursorFilterCallback(ctx context.Context, mod api.Module, pCur ptr_t, idxNu return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err) } -func cursorEOFCallback(ctx context.Context, mod api.Module, pCur ptr_t) uint32 { +func cursorEOFCallback(ctx context.Context, mod api.Module, pCur ptr_t) int32 { cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) if cursor.EOF() { return 1 @@ -615,20 +615,20 @@ func cursorEOFCallback(ctx context.Context, mod api.Module, pCur ptr_t) uint32 { return 0 } -func cursorNextCallback(ctx context.Context, mod api.Module, pCur ptr_t) uint32 { +func cursorNextCallback(ctx context.Context, mod api.Module, pCur ptr_t) res_t { cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) err := cursor.Next() return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err) } -func cursorColumnCallback(ctx context.Context, mod api.Module, pCur, pCtx ptr_t, n int32) uint32 { +func cursorColumnCallback(ctx context.Context, mod api.Module, pCur, pCtx ptr_t, n int32) res_t { cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) db := ctx.Value(connKey{}).(*Conn) err := cursor.Column(Context{db, pCtx}, int(n)) return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err) } -func cursorRowIDCallback(ctx context.Context, mod api.Module, pCur, pRowID ptr_t) uint32 { +func cursorRowIDCallback(ctx context.Context, mod api.Module, pCur, pRowID ptr_t) res_t { cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) rowID, err := cursor.RowID() @@ -645,7 +645,7 @@ const ( _CURSOR_ERROR ) -func vtabError(ctx context.Context, mod api.Module, ptr ptr_t, kind uint32, err error) uint32 { +func vtabError(ctx context.Context, mod api.Module, ptr ptr_t, kind uint32, err error) res_t { const zErrMsgOffset = 8 msg, code := errorCode(err, ERROR) if msg != "" && ptr != 0 { From 9942050016ffeed1db7742b4bd529995bf84e6f8 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Tue, 21 Jan 2025 01:35:58 +0000 Subject: [PATCH 5/6] More. --- conn.go | 4 ++-- func.go | 2 +- internal/util/mem.go | 8 ++++---- sqlite.go | 16 +++++++--------- sqlite_test.go | 4 ++-- vfs/memdb/memdb.go | 4 +++- vfs/vfs.go | 12 ++++++------ vtab.go | 2 +- 8 files changed, 26 insertions(+), 26 deletions(-) diff --git a/conn.go b/conn.go index 79119d7e..fffc7416 100644 --- a/conn.go +++ b/conn.go @@ -35,10 +35,10 @@ type Conn struct { update func(AuthorizerActionCode, string, string, int64) commit func() bool rollback func() - arena arena busy1st time.Time busylst time.Time + arena arena handle ptr_t } @@ -91,7 +91,7 @@ func newConn(ctx context.Context, filename string, flags OpenFlag) (ret *Conn, _ }() c.ctx = context.WithValue(c.ctx, connKey{}, c) - c.arena = c.newArena(1024) + c.arena = c.newArena() c.handle, err = c.openDB(filename, flags) if err == nil { err = initExtensions(c) diff --git a/func.go b/func.go index f6c488ff..6b69368b 100644 --- a/func.go +++ b/func.go @@ -213,7 +213,7 @@ func callbackArgs(db *Conn, arg []Value, pArg ptr_t) { for i := range arg { arg[i] = Value{ c: db, - handle: util.Read32[ptr_t](db.mod, pArg+ptrlen*ptr_t(i)), + handle: util.Read32[ptr_t](db.mod, pArg+ptr_t(i)*ptrlen), } } } diff --git a/internal/util/mem.go b/internal/util/mem.go index a4d89445..bfb1a644 100644 --- a/internal/util/mem.go +++ b/internal/util/mem.go @@ -26,12 +26,12 @@ func View(mod api.Module, ptr Ptr_t, size int64) []byte { if ptr == 0 { panic(NilErr) } - if uint64(size) > math.MaxUint32 { - panic(RangeErr) - } if size == 0 { return nil } + if uint64(size) > math.MaxUint32 { + panic(RangeErr) + } buf, ok := mod.Memory().Read(uint32(ptr), uint32(size)) if !ok { panic(RangeErr) @@ -139,7 +139,7 @@ func WriteBytes(mod api.Module, ptr Ptr_t, b []byte) { } func WriteString(mod api.Module, ptr Ptr_t, s string) { - buf := View(mod, ptr, int64(len(s)+1)) + buf := View(mod, ptr, int64(len(s))+1) buf[len(s)] = 0 copy(buf, s) } diff --git a/sqlite.go b/sqlite.go index defb3f76..8203603e 100644 --- a/sqlite.go +++ b/sqlite.go @@ -234,13 +234,12 @@ func (sqlt *sqlite) newString(s string) ptr_t { return ptr } -func (sqlt *sqlite) newArena(size int64) arena { - // Ensure the arena's size is a multiple of 8. - size = (size + 7) &^ 7 +const arenaSize = 4096 + +func (sqlt *sqlite) newArena() arena { return arena{ sqlt: sqlt, - size: uint32(size), - base: sqlt.new(size), + base: sqlt.new(arenaSize), } } @@ -248,8 +247,7 @@ type arena struct { sqlt *sqlite ptrs []ptr_t base ptr_t - next uint32 - size uint32 + next int32 } func (a *arena) free() { @@ -283,9 +281,9 @@ func (a *arena) new(size int64) ptr_t { } else { a.next = (a.next + 7) &^ 7 } - if size <= int64(a.size-a.next) { + if size <= arenaSize-int64(a.next) { ptr := a.base + ptr_t(a.next) - a.next += uint32(size) + a.next += int32(size) return ptr_t(ptr) } ptr := a.sqlt.new(size) diff --git a/sqlite_test.go b/sqlite_test.go index 1b909b9f..5b969dea 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -65,7 +65,7 @@ func Test_sqlite_newArena(t *testing.T) { } defer sqlite.close() - arena := sqlite.newArena(16) + arena := sqlite.newArena() defer arena.free() const title = "Lorem ipsum" @@ -192,7 +192,7 @@ func Test_sqlite_getString(t *testing.T) { func() { defer func() { _ = recover() }() - util.ReadString(sqlite.mod, ptr, int64(len(want)/2)) + util.ReadString(sqlite.mod, ptr, int64(len(want))/2) t.Error("want panic") }() diff --git a/vfs/memdb/memdb.go b/vfs/memdb/memdb.go index 4adb2dde..419fd1c6 100644 --- a/vfs/memdb/memdb.go +++ b/vfs/memdb/memdb.go @@ -10,9 +10,11 @@ import ( "github.com/ncruces/go-sqlite3/vfs" ) -// Must be a multiple of 64K (the largest page size). const sectorSize = 65536 +// Ensure sectorSize is a multiple of 64K (the largest page size). +var _ [0]struct{} = [sectorSize & 65535]struct{}{} + type memVFS struct{} func (memVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) { diff --git a/vfs/vfs.go b/vfs/vfs.go index ff1f646b..ca105fff 100644 --- a/vfs/vfs.go +++ b/vfs/vfs.go @@ -123,11 +123,11 @@ func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath ptr_t, flags Acc path := util.ReadString(mod, zPath, _MAX_PATHNAME) ok, err := vfs.Access(path, flags) - var val int32 + var res int32 if ok { - val = 1 + res = 1 } - util.Write32(mod, pResOut, val) + util.Write32(mod, pResOut, res) return vfsErrorCode(err, _IOERR_ACCESS) } @@ -226,11 +226,11 @@ func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut pt file := vfsFileGet(ctx, mod, pFile).(File) locked, err := file.CheckReservedLock() - var val int32 + var res int32 if locked { - val = 1 + res = 1 } - util.Write32(mod, pResOut, val) + util.Write32(mod, pResOut, res) return vfsErrorCode(err, _IOERR_CHECKRESERVEDLOCK) } diff --git a/vtab.go b/vtab.go index 8dbef8ab..7b29a5e4 100644 --- a/vtab.go +++ b/vtab.go @@ -448,7 +448,7 @@ func vtabModuleCallback(i vtabConstructor) func(_ context.Context, _ api.Module, arg[0] = reflect.ValueOf(ctx.Value(connKey{})) for i := int32(0); i < nArg; i++ { - ptr := util.Read32[ptr_t](mod, pArg+ptr_t(i*ptrlen)) + ptr := util.Read32[ptr_t](mod, pArg+ptr_t(i)*ptrlen) arg[i+1] = reflect.ValueOf(util.ReadString(mod, ptr, _MAX_SQL_LENGTH)) } From 5c6194810d33401c1b2d2a9a2d45d7dbdcb82743 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Tue, 21 Jan 2025 01:41:28 +0000 Subject: [PATCH 6/6] Docs. --- ext/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ext/README.md b/ext/README.md index 35028dc7..5760cb78 100644 --- a/ext/README.md +++ b/ext/README.md @@ -25,6 +25,8 @@ you can load into your database connections. creates [pivot tables](https://github.com/jakethaw/pivot_vtab). - [`github.com/ncruces/go-sqlite3/ext/regexp`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/regexp) provides regular expression functions. +- [`github.com/ncruces/go-sqlite3/ext/serdes`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/serdes) + (de)serializes databases. - [`github.com/ncruces/go-sqlite3/ext/statement`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/statement) creates [parameterized views](https://github.com/0x09/sqlite-statement-vtab). - [`github.com/ncruces/go-sqlite3/ext/stats`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/ext/stats)