diff --git a/core/rs/core/src/changes_vtab_write.rs b/core/rs/core/src/changes_vtab_write.rs index 3f26b07a..dee60f70 100644 --- a/core/rs/core/src/changes_vtab_write.rs +++ b/core/rs/core/src/changes_vtab_write.rs @@ -95,10 +95,45 @@ fn did_cid_win( reset_cached_stmt(col_val_stmt.stmt)?; if ret == 0 && unsafe { (*ext_data).mergeEqualValues == 1 } { // values are the same (ret == 0) and the option to tie break on site_id is true - ret = unsafe { - let my_site_id = core::slice::from_raw_parts((*ext_data).siteId, 16); - insert_site_id.cmp(my_site_id) as c_int - }; + let col_site_id_stmt_ref = tbl_info.get_col_site_id_stmt(db)?; + let col_site_id_stmt = col_site_id_stmt_ref.as_ref().ok_or(ResultCode::ERROR)?; + + let bind_result = col_site_id_stmt.bind_int64(1, key); + if let Err(rc) = bind_result { + reset_cached_stmt(col_site_id_stmt.stmt)?; + return Err(rc); + } + if let Err(rc) = col_site_id_stmt.bind_text(2, col_name, sqlite::Destructor::STATIC) + { + reset_cached_stmt(col_site_id_stmt.stmt)?; + return Err(rc); + } + + match col_site_id_stmt.step() { + Ok(ResultCode::ROW) => { + let local_site_id = col_site_id_stmt.column_blob(0)?; + ret = insert_site_id.cmp(local_site_id) as c_int; + + // reset the stmt after, we're accessing a slice in-memory + reset_cached_stmt(col_site_id_stmt.stmt)?; + } + Ok(ResultCode::DONE) => { + reset_cached_stmt(col_site_id_stmt.stmt)?; + let err = CString::new(format!( + "could not find site_id for previous change, cr-sqlite clock table might be corrupt for tbl {}", + insert_tbl + ))?; + unsafe { *errmsg = err.into_raw() }; + return Err(ResultCode::ERROR); + } + Ok(rc) | Err(rc) => { + reset_cached_stmt(col_site_id_stmt.stmt)?; + let err = + CString::new("Bad return code when selecting local column site_id")?; + unsafe { *errmsg = err.into_raw() }; + return Err(rc); + } + } } return Ok(ret > 0); } diff --git a/core/rs/core/src/tableinfo.rs b/core/rs/core/src/tableinfo.rs index 536abdaf..b36ea9a6 100644 --- a/core/rs/core/src/tableinfo.rs +++ b/core/rs/core/src/tableinfo.rs @@ -45,6 +45,7 @@ pub struct TableInfo { set_winner_clock_stmt: RefCell>, local_cl_stmt: RefCell>, col_version_stmt: RefCell>, + col_site_id_stmt: RefCell>, merge_pk_only_insert_stmt: RefCell>, merge_delete_stmt: RefCell>, merge_delete_drop_clocks_stmt: RefCell>, @@ -337,6 +338,21 @@ impl TableInfo { Ok(self.col_version_stmt.try_borrow()?) } + pub fn get_col_site_id_stmt( + &self, + db: *mut sqlite3, + ) -> Result>, ResultCode> { + if self.col_site_id_stmt.try_borrow()?.is_none() { + let sql = format!( + "SELECT site_id FROM crsql_site_id WHERE ordinal = (SELECT site_id FROM \"{table_name}__crsql_clock\" WHERE key = ? AND col_name = ?)", + table_name = crate::util::escape_ident(&self.tbl_name), + ); + let ret = db.prepare_v3(&sql, sqlite::PREPARE_PERSISTENT)?; + *self.col_site_id_stmt.try_borrow_mut()? = Some(ret); + } + Ok(self.col_site_id_stmt.try_borrow()?) + } + pub fn get_merge_pk_only_insert_stmt( &self, db: *mut sqlite3, @@ -871,6 +887,7 @@ pub fn pull_table_info( set_winner_clock_stmt: RefCell::new(None), local_cl_stmt: RefCell::new(None), col_version_stmt: RefCell::new(None), + col_site_id_stmt: RefCell::new(None), select_key_stmt: RefCell::new(None), insert_key_stmt: RefCell::new(None), diff --git a/py/correctness/tests/test_sync.py b/py/correctness/tests/test_sync.py index a1d74d91..fd30f2b1 100644 --- a/py/correctness/tests/test_sync.py +++ b/py/correctness/tests/test_sync.py @@ -383,6 +383,7 @@ def make_dbs(): def test_merge_same_w_tie_breaker(): db1 = create_basic_db() db2 = create_basic_db() + db3 = create_basic_db() db1.execute("INSERT INTO foo (a,b) VALUES (1,2);") db1.execute("SELECT crsql_config_set('merge-equal-values', 1);") @@ -392,13 +393,47 @@ def test_merge_same_w_tie_breaker(): db2.execute("SELECT crsql_config_set('merge-equal-values', 1);") db2.commit() + db3.execute("INSERT INTO foo (a,b) VALUES (1,2);") + db3.execute("SELECT crsql_config_set('merge-equal-values', 1);") + db3.commit() + sync_left_to_right(db1, db2, 0) - changes12 = db2.execute("SELECT \"table\", pk, cid, val, col_version, site_id FROM crsql_changes").fetchall() + changes2 = db2.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall() sync_left_to_right(db2, db1, 0) - changes21 = db1.execute("SELECT \"table\", pk, cid, val, col_version, site_id FROM crsql_changes").fetchall() + changes1 = db1.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall() + + sync_left_to_right(db2, db3, 0) + changes3 = db3.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall() + + # check that everything by db_version is the same + assert (changes2[:-6] == changes1[:-6] == changes3[:-6]) - assert (changes12 == changes21) + # Test that we're stable / do not loop when we tie break equal values + + sync_left_to_right(db2, db1, 0) + changes1_2 = db1.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall() + sync_left_to_right(db3, db2, 0) + changes2_2 = db2.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall() + sync_left_to_right(db1, db3, 0) + changes3_2 = db3.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall() + + # everything should stay the same, including db_version + assert (changes1 == changes1_2) + assert (changes2 == changes2_2) + assert (changes3 == changes3_2) + + sync_left_to_right(db3, db1, 0) + changes1_2 = db1.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall() + sync_left_to_right(db1, db2, 0) + changes2_2 = db2.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall() + sync_left_to_right(db2, db3, 0) + changes3_2 = db3.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall() + + # everything should stay the same, including db_version + assert (changes1 == changes1_2) + assert (changes2 == changes2_2) + assert (changes3 == changes3_2) def test_merge_matching_clocks_lesser_value():