Skip to content

Commit

Permalink
Improve get|set auxdata. credit: asg017#22
Browse files Browse the repository at this point in the history
  • Loading branch information
PThorpe92 committed Jan 6, 2025
1 parent 2c5c049 commit 8e34664
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 11 deletions.
25 changes: 14 additions & 11 deletions src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,23 +372,26 @@ pub fn result_pointer<T>(context: *mut sqlite3_context, name: &[u8], object: T)
};
}

// TODO maybe take in a Box<T>?
/// [`sqlite3_set_auxdata`](https://www.sqlite.org/c3ref/get_auxdata.html)
pub fn auxdata_set(
context: *mut sqlite3_context,
col: i32,
p: *mut c_void,
d: Option<unsafe extern "C" fn(*mut c_void)>,
) {
pub fn auxdata_set<T>(context: *mut sqlite3_context, col: i32, p: Box<T>) {
unsafe extern "C" fn cleanup<U>(p: *mut c_void) {
drop(Box::from_raw(p.cast::<U>()));
}

let raw = Box::into_raw(p).cast::<c_void>();
unsafe {
sqlite3ext_set_auxdata(context, col, p, d);
sqlite3ext_set_auxdata(context, col, raw, Some(cleanup::<T>));
}
}

// TODO maybe return a Box<T>?
/// [`sqlite3_get_auxdata`](https://www.sqlite.org/c3ref/get_auxdata.html)
pub fn auxdata_get(context: *mut sqlite3_context, col: i32) -> *mut c_void {
unsafe { sqlite3ext_get_auxdata(context, col) }
pub fn auxdata_get<'a, T>(context: *mut sqlite3_context, col: i32) -> Option<&'a mut T> {
let ptr = unsafe { sqlite3ext_get_auxdata(context, col).cast::<T>() };
if ptr.is_null() {
None
} else {
Some(unsafe { &mut *ptr })
}
}

pub fn context_db_handle(context: *mut sqlite3_context) -> *mut sqlite3 {
Expand Down
59 changes: 59 additions & 0 deletions tests/test_auxdata.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
use sqlite_loadable::prelude::*;
use sqlite_loadable::{api, define_scalar_function, Result};

pub fn check_auxdata(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> {
let label = api::value_text(values.first().unwrap()).unwrap();
let value = api::value_text(values.get(1).unwrap()).unwrap();

assert!(api::auxdata_get::<String>(context, 1).is_none());

let b = Box::new(String::from(value));
api::auxdata_set::<String>(context, 1, b);

let entry = api::auxdata_get::<String>(context, 1).unwrap();
assert!(entry == value);

api::result_text(context, format!("{label}={value}")).unwrap();

Ok(())
}

#[sqlite_entrypoint]
pub fn sqlite3_test_auxdata_init(db: *mut sqlite3) -> Result<()> {
define_scalar_function(db, "check_auxdata", 2, check_auxdata, FunctionFlags::UTF8)?;
Ok(())
}

#[cfg(test)]
mod tests {
use super::*;

use rusqlite::{ffi::sqlite3_auto_extension, Connection};

#[test]
fn test_rusqlite_auto_extension() {
unsafe {
sqlite3_auto_extension(Some(std::mem::transmute(
sqlite3_test_auxdata_init as *const (),
)));
}

let conn = Connection::open_in_memory().unwrap();

// NOTE: even nested expressions are evaluated in different contexts leading to an
// auxdata_get miss. auxdata_get/set is not suitable for naive caching across function
// evaluations.
let result: String = conn
.query_row(
"SELECT (check_auxdata(?1, check_auxdata(?2, ?3)))",
("outer_label", "inner_label", "value"),
|row| {
println!("ROW {row:?}");
row.get(0)
},
)
.unwrap();

assert_eq!(result, "outer_label=inner_label=value");
}
}

0 comments on commit 8e34664

Please sign in to comment.