Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix site_id tie-breaker #422

Merged
merged 1 commit into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions core/rs/core/src/changes_vtab_write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,45 @@ fn did_cid_win(
reset_cached_stmt(col_val_stmt.stmt)?;
if ret == 0 && unsafe { (*ext_data).mergeEqualValues == 1 } {
// values are the same (ret == 0) and the option to tie break on site_id is true
ret = unsafe {
let my_site_id = core::slice::from_raw_parts((*ext_data).siteId, 16);
insert_site_id.cmp(my_site_id) as c_int
};
let col_site_id_stmt_ref = tbl_info.get_col_site_id_stmt(db)?;
let col_site_id_stmt = col_site_id_stmt_ref.as_ref().ok_or(ResultCode::ERROR)?;

let bind_result = col_site_id_stmt.bind_int64(1, key);
if let Err(rc) = bind_result {
reset_cached_stmt(col_site_id_stmt.stmt)?;
return Err(rc);
}
if let Err(rc) = col_site_id_stmt.bind_text(2, col_name, sqlite::Destructor::STATIC)
{
reset_cached_stmt(col_site_id_stmt.stmt)?;
return Err(rc);
}

match col_site_id_stmt.step() {
Ok(ResultCode::ROW) => {
let local_site_id = col_site_id_stmt.column_blob(0)?;
ret = insert_site_id.cmp(local_site_id) as c_int;

// reset the stmt after, we're accessing a slice in-memory
reset_cached_stmt(col_site_id_stmt.stmt)?;
}
Ok(ResultCode::DONE) => {
reset_cached_stmt(col_site_id_stmt.stmt)?;
let err = CString::new(format!(
"could not find site_id for previous change, cr-sqlite clock table might be corrupt for tbl {}",
insert_tbl
))?;
unsafe { *errmsg = err.into_raw() };
return Err(ResultCode::ERROR);
}
Ok(rc) | Err(rc) => {
reset_cached_stmt(col_site_id_stmt.stmt)?;
let err =
CString::new("Bad return code when selecting local column site_id")?;
unsafe { *errmsg = err.into_raw() };
return Err(rc);
}
}
}
return Ok(ret > 0);
}
Expand Down
17 changes: 17 additions & 0 deletions core/rs/core/src/tableinfo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ pub struct TableInfo {
set_winner_clock_stmt: RefCell<Option<ManagedStmt>>,
local_cl_stmt: RefCell<Option<ManagedStmt>>,
col_version_stmt: RefCell<Option<ManagedStmt>>,
col_site_id_stmt: RefCell<Option<ManagedStmt>>,
merge_pk_only_insert_stmt: RefCell<Option<ManagedStmt>>,
merge_delete_stmt: RefCell<Option<ManagedStmt>>,
merge_delete_drop_clocks_stmt: RefCell<Option<ManagedStmt>>,
Expand Down Expand Up @@ -337,6 +338,21 @@ impl TableInfo {
Ok(self.col_version_stmt.try_borrow()?)
}

pub fn get_col_site_id_stmt(
&self,
db: *mut sqlite3,
) -> Result<Ref<Option<ManagedStmt>>, ResultCode> {
if self.col_site_id_stmt.try_borrow()?.is_none() {
let sql = format!(
"SELECT site_id FROM crsql_site_id WHERE ordinal = (SELECT site_id FROM \"{table_name}__crsql_clock\" WHERE key = ? AND col_name = ?)",
table_name = crate::util::escape_ident(&self.tbl_name),
);
let ret = db.prepare_v3(&sql, sqlite::PREPARE_PERSISTENT)?;
*self.col_site_id_stmt.try_borrow_mut()? = Some(ret);
}
Ok(self.col_site_id_stmt.try_borrow()?)
}

pub fn get_merge_pk_only_insert_stmt(
&self,
db: *mut sqlite3,
Expand Down Expand Up @@ -871,6 +887,7 @@ pub fn pull_table_info(
set_winner_clock_stmt: RefCell::new(None),
local_cl_stmt: RefCell::new(None),
col_version_stmt: RefCell::new(None),
col_site_id_stmt: RefCell::new(None),

select_key_stmt: RefCell::new(None),
insert_key_stmt: RefCell::new(None),
Expand Down
41 changes: 38 additions & 3 deletions py/correctness/tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ def make_dbs():
def test_merge_same_w_tie_breaker():
db1 = create_basic_db()
db2 = create_basic_db()
db3 = create_basic_db()

db1.execute("INSERT INTO foo (a,b) VALUES (1,2);")
db1.execute("SELECT crsql_config_set('merge-equal-values', 1);")
Expand All @@ -392,13 +393,47 @@ def test_merge_same_w_tie_breaker():
db2.execute("SELECT crsql_config_set('merge-equal-values', 1);")
db2.commit()

db3.execute("INSERT INTO foo (a,b) VALUES (1,2);")
db3.execute("SELECT crsql_config_set('merge-equal-values', 1);")
db3.commit()

sync_left_to_right(db1, db2, 0)
changes12 = db2.execute("SELECT \"table\", pk, cid, val, col_version, site_id FROM crsql_changes").fetchall()
changes2 = db2.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall()

sync_left_to_right(db2, db1, 0)
changes21 = db1.execute("SELECT \"table\", pk, cid, val, col_version, site_id FROM crsql_changes").fetchall()
changes1 = db1.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall()

sync_left_to_right(db2, db3, 0)
changes3 = db3.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall()

# check that everything by db_version is the same
assert (changes2[:-6] == changes1[:-6] == changes3[:-6])

assert (changes12 == changes21)
# Test that we're stable / do not loop when we tie break equal values

sync_left_to_right(db2, db1, 0)
changes1_2 = db1.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall()
sync_left_to_right(db3, db2, 0)
changes2_2 = db2.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall()
sync_left_to_right(db1, db3, 0)
changes3_2 = db3.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall()

# everything should stay the same, including db_version
assert (changes1 == changes1_2)
assert (changes2 == changes2_2)
assert (changes3 == changes3_2)

sync_left_to_right(db3, db1, 0)
changes1_2 = db1.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall()
sync_left_to_right(db1, db2, 0)
changes2_2 = db2.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall()
sync_left_to_right(db2, db3, 0)
changes3_2 = db3.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall()

# everything should stay the same, including db_version
assert (changes1 == changes1_2)
assert (changes2 == changes2_2)
assert (changes3 == changes3_2)


def test_merge_matching_clocks_lesser_value():
Expand Down