diff --git a/redisql_lib/src/community_statement.rs b/redisql_lib/src/community_statement.rs index 339cce9..8158954 100644 --- a/redisql_lib/src/community_statement.rs +++ b/redisql_lib/src/community_statement.rs @@ -39,7 +39,7 @@ impl<'a> fmt::Display for MultiStatement { } pub struct Statement { - stmt: *mut ffi::sqlite3_stmt, + stmt: ptr::NonNull, } unsafe impl Send for Statement {} @@ -48,7 +48,7 @@ unsafe impl Sync for Statement {} impl<'a> fmt::Display for Statement { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let sql = unsafe { - CStr::from_ptr(ffi::sqlite3_sql(self.stmt)) + CStr::from_ptr(ffi::sqlite3_sql(self.as_ptr())) .to_string_lossy() .into_owned() }; @@ -59,7 +59,7 @@ impl<'a> fmt::Display for Statement { impl<'a> Drop for Statement { fn drop(&mut self) { unsafe { - ffi::sqlite3_finalize(self.stmt); + ffi::sqlite3_finalize(self.as_ptr()); }; } } @@ -87,10 +87,13 @@ pub fn generate_statements( &mut next_query, ) }; + match r { ffi::SQLITE_OK => { - let stmt = Statement { stmt }; - stmts.push(stmt); + if !stmt.is_null() { + let stmt = Statement::from_ptr(stmt); + stmts.push(stmt); + } if unsafe { *next_query } == 0 { let (num_parameters, parameters) = count_parameters(&stmts)?; @@ -100,7 +103,7 @@ pub fn generate_statements( number_parameters: num_parameters, _parameters: parameters, }); - } + }; } _ => return Err(conn.get_last_error()), } @@ -108,11 +111,16 @@ pub fn generate_statements( } impl Statement { + fn from_ptr(stmt: *mut ffi::sqlite3_stmt) -> Self { + Statement { + stmt: ptr::NonNull::new(stmt).unwrap(), + } + } fn execute( &self, db: &RawConnection, ) -> Result { - match unsafe { ffi::sqlite3_step(self.stmt) } { + match unsafe { ffi::sqlite3_step(self.as_ptr()) } { ffi::SQLITE_OK => Ok(Cursor::OKCursor {}), ffi::SQLITE_DONE => { let modified_rows = @@ -120,9 +128,9 @@ impl Statement { Ok(Cursor::DONECursor { modified_rows }) } ffi::SQLITE_ROW => { - let num_columns = - unsafe { ffi::sqlite3_column_count(self.stmt) } - as i32; + let num_columns = unsafe { + ffi::sqlite3_column_count(self.as_ptr()) + } as i32; Ok(Cursor::RowsCursor { stmt: self, num_columns, @@ -134,10 +142,13 @@ impl Statement { } } fn get_last_error(&self) -> SQLite3Error { - let db = unsafe { ffi::sqlite3_db_handle(self.stmt) }; + let db = unsafe { ffi::sqlite3_db_handle(self.as_ptr()) }; let rc = RawConnection::from_db_handler(db); rc.get_last_error() } + pub fn as_ptr(&self) -> *mut ffi::sqlite3_stmt { + self.stmt.as_ptr() + } } impl<'a> StatementTrait<'a> for Statement { @@ -165,15 +176,15 @@ impl<'a> StatementTrait<'a> for Statement { ) }; match r { - ffi::SQLITE_OK => Ok(Statement { stmt }), + ffi::SQLITE_OK => Ok(Statement::from_ptr(stmt)), _ => Err(conn.get_last_error()), } } fn reset(&self) { unsafe { - ffi::sqlite3_reset(self.stmt); - ffi::sqlite3_clear_bindings(self.stmt); + ffi::sqlite3_reset(self.as_ptr()); + ffi::sqlite3_clear_bindings(self.as_ptr()); } } @@ -202,7 +213,7 @@ impl<'a> StatementTrait<'a> for Statement { } match unsafe { ffi::sqlite3_bind_text( - self.stmt, + self.as_ptr(), index, value.as_ptr() as *const c_char, value.len() as i32, @@ -214,19 +225,15 @@ impl<'a> StatementTrait<'a> for Statement { } } - fn get_raw_stmt(&self) -> *mut ffi::sqlite3_stmt { - self.stmt - } - fn is_read_only(&self) -> bool { - let v = unsafe { ffi::sqlite3_stmt_readonly(self.stmt) }; + let v = unsafe { ffi::sqlite3_stmt_readonly(self.as_ptr()) }; v != 0 } } impl<'a> StatementTrait<'a> for MultiStatement { fn reset(&self) { - self.stmts.iter().map(|stmt| stmt.reset()).count(); + self.stmts.iter().map(StatementTrait::reset).count(); } fn execute(&self) -> Result { let db = self.db.clone(); @@ -313,9 +320,6 @@ impl<'a> StatementTrait<'a> for MultiStatement { ) -> Result { generate_statements(conn, query) } - fn get_raw_stmt(&self) -> *mut ffi::sqlite3_stmt { - self.stmts[0].stmt - } fn is_read_only(&self) -> bool { for stmt in &self.stmts { @@ -392,11 +396,12 @@ fn get_parameter_name( stmt: &Statement, index: i32, ) -> Result, SQLite3Error> { - let parameter_name_ptr = - unsafe { ffi::sqlite3_bind_parameter_name(stmt.stmt, index) }; + let parameter_name_ptr = unsafe { + ffi::sqlite3_bind_parameter_name(stmt.as_ptr(), index) + }; let index_parameter = unsafe { ffi::sqlite3_bind_parameter_index( - stmt.stmt, + stmt.as_ptr(), parameter_name_ptr, ) }; @@ -410,7 +415,7 @@ fn get_parameters( stmt: &Statement, ) -> Result, SQLite3Error> { let total_paramenters = - unsafe { ffi::sqlite3_bind_parameter_count(stmt.stmt) } + unsafe { ffi::sqlite3_bind_parameter_count(stmt.as_ptr()) } as usize; if total_paramenters == 0 { return Ok(vec![]); diff --git a/redisql_lib/src/sqlite.rs b/redisql_lib/src/sqlite.rs index 050bf35..3ee0a1e 100644 --- a/redisql_lib/src/sqlite.rs +++ b/redisql_lib/src/sqlite.rs @@ -168,7 +168,6 @@ pub trait StatementTrait<'a>: Sized { index: i32, value: &str, ) -> Result; - fn get_raw_stmt(&self) -> *mut ffi::sqlite3_stmt; fn is_read_only(&self) -> bool { false } @@ -198,23 +197,23 @@ pub enum Entity { impl Entity { fn new(stmt: &Statement, i: i32) -> Entity { - match get_entity_type(stmt.get_raw_stmt(), i) { + match get_entity_type(stmt.as_ptr(), i) { EntityType::Integer => { let int = unsafe { - ffi::sqlite3_column_int64(stmt.get_raw_stmt(), i) + ffi::sqlite3_column_int64(stmt.as_ptr(), i) }; Entity::Integer { int } } EntityType::Float => { let value = unsafe { - ffi::sqlite3_column_double(stmt.get_raw_stmt(), i) + ffi::sqlite3_column_double(stmt.as_ptr(), i) }; Entity::Float { float: value } } EntityType::Text => { let value = unsafe { CStr::from_ptr(ffi::sqlite3_column_text( - stmt.get_raw_stmt(), + stmt.as_ptr(), i, ) as *const c_char) @@ -227,7 +226,7 @@ impl Entity { EntityType::Blob => { let value = unsafe { CStr::from_ptr(ffi::sqlite3_column_blob( - stmt.get_raw_stmt(), + stmt.as_ptr(), i, ) as *const c_char) @@ -323,7 +322,7 @@ impl<'a> From> for QueryResult { for i in 0..num_columns { let name = unsafe { CStr::from_ptr(ffi::sqlite3_column_name( - stmt.get_raw_stmt(), + stmt.as_ptr(), i, )) .to_string_lossy() @@ -340,7 +339,7 @@ impl<'a> From> for QueryResult { } unsafe { *previous_status = - ffi::sqlite3_step(stmt.get_raw_stmt()); + ffi::sqlite3_step(stmt.as_ptr()); }; result.push(row); diff --git a/src/commands.rs b/src/commands.rs index 9d158de..c3d975b 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -972,12 +972,12 @@ pub extern "C" fn MakeCopy( ) .unwrap(); STATISTICS.copy_err(); - return unsafe { + unsafe { r::rm::ffi::RedisModule_ReplyWithError.unwrap()( context.as_ptr(), error.as_ptr(), ) - }; + } } } } diff --git a/test/correctness/test.py b/test/correctness/test.py index b6ed0d2..9cf307e 100755 --- a/test/correctness/test.py +++ b/test/correctness/test.py @@ -860,6 +860,32 @@ def test_null_terminated(self): one = self.exec_naked("REDISQL.EXEC", "NULL", "SELECT 1" + b'\x00') self.assertEquals(one, [[1]]) +class TestBlankAfterSemicolon(TestRediSQLWithExec): + def test_whitespace_after_semicolon(self): + with DB(self, "NULL"): + one = self.exec_naked("REDISQL.EXEC", "NULL", "SELECT 1; ") + self.assertEquals(one, [[1]]) + + def test_newline_after_semicolon(self): + with DB(self, "NULL"): + one = self.exec_naked("REDISQL.EXEC", "NULL", "SELECT 1;\n") + self.assertEquals(one, [[1]]) + + def test_mix_after_semicolon(self): + with DB(self, "NULL"): + one = self.exec_naked("REDISQL.EXEC", "NULL", "SELECT 1; \n ") + self.assertEquals(one, [[1]]) + + def test_whitespace_after_semicolon_then_query(self): + with DB(self, "NULL"): + one = self.exec_naked("REDISQL.EXEC", "NULL", "SELECT 1; SELECT 2;") + self.assertEquals(one, [[2]]) + + def test_newline_after_semicolon_then_query(self): + with DB(self, "NULL"): + one = self.exec_naked("REDISQL.EXEC", "NULL", "SELECT 1;\nSELECT 2;") + self.assertEquals(one, [[2]]) + if __name__ == '__main__': unittest.main()