Skip to content
This repository has been archived by the owner on Jul 11, 2021. It is now read-only.

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
siscia committed May 8, 2019
2 parents 75ea042 + f86ac2e commit 823f215
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 38 deletions.
61 changes: 33 additions & 28 deletions redisql_lib/src/community_statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl<'a> fmt::Display for MultiStatement {
}

pub struct Statement {
stmt: *mut ffi::sqlite3_stmt,
stmt: ptr::NonNull<ffi::sqlite3_stmt>,
}

unsafe impl Send for Statement {}
Expand All @@ -48,7 +48,7 @@ unsafe impl Sync for Statement {}
impl<'a> fmt::Display for Statement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let sql = unsafe {
CStr::from_ptr(ffi::sqlite3_sql(self.stmt))
CStr::from_ptr(ffi::sqlite3_sql(self.as_ptr()))
.to_string_lossy()
.into_owned()
};
Expand All @@ -59,7 +59,7 @@ impl<'a> fmt::Display for Statement {
impl<'a> Drop for Statement {
fn drop(&mut self) {
unsafe {
ffi::sqlite3_finalize(self.stmt);
ffi::sqlite3_finalize(self.as_ptr());
};
}
}
Expand Down Expand Up @@ -87,10 +87,13 @@ pub fn generate_statements(
&mut next_query,
)
};

match r {
ffi::SQLITE_OK => {
let stmt = Statement { stmt };
stmts.push(stmt);
if !stmt.is_null() {
let stmt = Statement::from_ptr(stmt);
stmts.push(stmt);
}
if unsafe { *next_query } == 0 {
let (num_parameters, parameters) =
count_parameters(&stmts)?;
Expand All @@ -100,29 +103,34 @@ pub fn generate_statements(
number_parameters: num_parameters,
_parameters: parameters,
});
}
};
}
_ => return Err(conn.get_last_error()),
}
}
}

impl Statement {
fn from_ptr(stmt: *mut ffi::sqlite3_stmt) -> Self {
Statement {
stmt: ptr::NonNull::new(stmt).unwrap(),
}
}
fn execute(
&self,
db: &RawConnection,
) -> Result<Cursor, SQLite3Error> {
match unsafe { ffi::sqlite3_step(self.stmt) } {
match unsafe { ffi::sqlite3_step(self.as_ptr()) } {
ffi::SQLITE_OK => Ok(Cursor::OKCursor {}),
ffi::SQLITE_DONE => {
let modified_rows =
unsafe { ffi::sqlite3_changes(db.get_db()) };
Ok(Cursor::DONECursor { modified_rows })
}
ffi::SQLITE_ROW => {
let num_columns =
unsafe { ffi::sqlite3_column_count(self.stmt) }
as i32;
let num_columns = unsafe {
ffi::sqlite3_column_count(self.as_ptr())
} as i32;
Ok(Cursor::RowsCursor {
stmt: self,
num_columns,
Expand All @@ -134,10 +142,13 @@ impl Statement {
}
}
fn get_last_error(&self) -> SQLite3Error {
let db = unsafe { ffi::sqlite3_db_handle(self.stmt) };
let db = unsafe { ffi::sqlite3_db_handle(self.as_ptr()) };
let rc = RawConnection::from_db_handler(db);
rc.get_last_error()
}
pub fn as_ptr(&self) -> *mut ffi::sqlite3_stmt {
self.stmt.as_ptr()
}
}

impl<'a> StatementTrait<'a> for Statement {
Expand Down Expand Up @@ -165,15 +176,15 @@ impl<'a> StatementTrait<'a> for Statement {
)
};
match r {
ffi::SQLITE_OK => Ok(Statement { stmt }),
ffi::SQLITE_OK => Ok(Statement::from_ptr(stmt)),
_ => Err(conn.get_last_error()),
}
}

fn reset(&self) {
unsafe {
ffi::sqlite3_reset(self.stmt);
ffi::sqlite3_clear_bindings(self.stmt);
ffi::sqlite3_reset(self.as_ptr());
ffi::sqlite3_clear_bindings(self.as_ptr());
}
}

Expand Down Expand Up @@ -202,7 +213,7 @@ impl<'a> StatementTrait<'a> for Statement {
}
match unsafe {
ffi::sqlite3_bind_text(
self.stmt,
self.as_ptr(),
index,
value.as_ptr() as *const c_char,
value.len() as i32,
Expand All @@ -214,19 +225,15 @@ impl<'a> StatementTrait<'a> for Statement {
}
}

fn get_raw_stmt(&self) -> *mut ffi::sqlite3_stmt {
self.stmt
}

fn is_read_only(&self) -> bool {
let v = unsafe { ffi::sqlite3_stmt_readonly(self.stmt) };
let v = unsafe { ffi::sqlite3_stmt_readonly(self.as_ptr()) };
v != 0
}
}

impl<'a> StatementTrait<'a> for MultiStatement {
fn reset(&self) {
self.stmts.iter().map(|stmt| stmt.reset()).count();
self.stmts.iter().map(StatementTrait::reset).count();
}
fn execute(&self) -> Result<Cursor, SQLite3Error> {
let db = self.db.clone();
Expand Down Expand Up @@ -313,9 +320,6 @@ impl<'a> StatementTrait<'a> for MultiStatement {
) -> Result<Self, SQLite3Error> {
generate_statements(conn, query)
}
fn get_raw_stmt(&self) -> *mut ffi::sqlite3_stmt {
self.stmts[0].stmt
}

fn is_read_only(&self) -> bool {
for stmt in &self.stmts {
Expand Down Expand Up @@ -392,11 +396,12 @@ fn get_parameter_name(
stmt: &Statement,
index: i32,
) -> Result<Option<Parameters>, SQLite3Error> {
let parameter_name_ptr =
unsafe { ffi::sqlite3_bind_parameter_name(stmt.stmt, index) };
let parameter_name_ptr = unsafe {
ffi::sqlite3_bind_parameter_name(stmt.as_ptr(), index)
};
let index_parameter = unsafe {
ffi::sqlite3_bind_parameter_index(
stmt.stmt,
stmt.as_ptr(),
parameter_name_ptr,
)
};
Expand All @@ -410,7 +415,7 @@ fn get_parameters(
stmt: &Statement,
) -> Result<Vec<Parameters>, SQLite3Error> {
let total_paramenters =
unsafe { ffi::sqlite3_bind_parameter_count(stmt.stmt) }
unsafe { ffi::sqlite3_bind_parameter_count(stmt.as_ptr()) }
as usize;
if total_paramenters == 0 {
return Ok(vec![]);
Expand Down
15 changes: 7 additions & 8 deletions redisql_lib/src/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ pub trait StatementTrait<'a>: Sized {
index: i32,
value: &str,
) -> Result<SQLiteOK, SQLite3Error>;
fn get_raw_stmt(&self) -> *mut ffi::sqlite3_stmt;
fn is_read_only(&self) -> bool {
false
}
Expand Down Expand Up @@ -198,23 +197,23 @@ pub enum Entity {

impl Entity {
fn new(stmt: &Statement, i: i32) -> Entity {
match get_entity_type(stmt.get_raw_stmt(), i) {
match get_entity_type(stmt.as_ptr(), i) {
EntityType::Integer => {
let int = unsafe {
ffi::sqlite3_column_int64(stmt.get_raw_stmt(), i)
ffi::sqlite3_column_int64(stmt.as_ptr(), i)
};
Entity::Integer { int }
}
EntityType::Float => {
let value = unsafe {
ffi::sqlite3_column_double(stmt.get_raw_stmt(), i)
ffi::sqlite3_column_double(stmt.as_ptr(), i)
};
Entity::Float { float: value }
}
EntityType::Text => {
let value = unsafe {
CStr::from_ptr(ffi::sqlite3_column_text(
stmt.get_raw_stmt(),
stmt.as_ptr(),
i,
)
as *const c_char)
Expand All @@ -227,7 +226,7 @@ impl Entity {
EntityType::Blob => {
let value = unsafe {
CStr::from_ptr(ffi::sqlite3_column_blob(
stmt.get_raw_stmt(),
stmt.as_ptr(),
i,
)
as *const c_char)
Expand Down Expand Up @@ -323,7 +322,7 @@ impl<'a> From<Cursor<'a>> for QueryResult {
for i in 0..num_columns {
let name = unsafe {
CStr::from_ptr(ffi::sqlite3_column_name(
stmt.get_raw_stmt(),
stmt.as_ptr(),
i,
))
.to_string_lossy()
Expand All @@ -340,7 +339,7 @@ impl<'a> From<Cursor<'a>> for QueryResult {
}
unsafe {
*previous_status =
ffi::sqlite3_step(stmt.get_raw_stmt());
ffi::sqlite3_step(stmt.as_ptr());
};

result.push(row);
Expand Down
4 changes: 2 additions & 2 deletions src/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -972,12 +972,12 @@ pub extern "C" fn MakeCopy(
)
.unwrap();
STATISTICS.copy_err();
return unsafe {
unsafe {
r::rm::ffi::RedisModule_ReplyWithError.unwrap()(
context.as_ptr(),
error.as_ptr(),
)
};
}
}
}
}
Expand Down
26 changes: 26 additions & 0 deletions test/correctness/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,32 @@ def test_null_terminated(self):
one = self.exec_naked("REDISQL.EXEC", "NULL", "SELECT 1" + b'\x00')
self.assertEquals(one, [[1]])

class TestBlankAfterSemicolon(TestRediSQLWithExec):
def test_whitespace_after_semicolon(self):
with DB(self, "NULL"):
one = self.exec_naked("REDISQL.EXEC", "NULL", "SELECT 1; ")
self.assertEquals(one, [[1]])

def test_newline_after_semicolon(self):
with DB(self, "NULL"):
one = self.exec_naked("REDISQL.EXEC", "NULL", "SELECT 1;\n")
self.assertEquals(one, [[1]])

def test_mix_after_semicolon(self):
with DB(self, "NULL"):
one = self.exec_naked("REDISQL.EXEC", "NULL", "SELECT 1; \n ")
self.assertEquals(one, [[1]])

def test_whitespace_after_semicolon_then_query(self):
with DB(self, "NULL"):
one = self.exec_naked("REDISQL.EXEC", "NULL", "SELECT 1; SELECT 2;")
self.assertEquals(one, [[2]])

def test_newline_after_semicolon_then_query(self):
with DB(self, "NULL"):
one = self.exec_naked("REDISQL.EXEC", "NULL", "SELECT 1;\nSELECT 2;")
self.assertEquals(one, [[2]])

if __name__ == '__main__':
unittest.main()

0 comments on commit 823f215

Please sign in to comment.