Skip to content
This repository has been archived by the owner on May 17, 2024. It is now read-only.

continue --dbt diff when null PKs exist #585

Merged
merged 4 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions data_diff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def diff_tables(
materialize_all_rows: bool = False,
# Maximum number of rows to write when materializing, per thread. (joindiff only)
table_write_limit: int = TABLE_WRITE_LIMIT,
# Skips diffing any rows with null keys. (joindiff only)
skip_null_keys: bool = False,
) -> Iterator:
"""Finds the diff between table1 and table2.

Expand Down Expand Up @@ -107,6 +109,7 @@ def diff_tables(
materialize_to_table (Union[str, DbPath], optional): Path of new table to write diff results to. Disabled if not provided. Used for `JOINDIFF`.
materialize_all_rows (bool): Materialize every row, not just those that are different. (used for `JOINDIFF`. default: False)
table_write_limit (int): Maximum number of rows to write when materializing, per thread.
skip_null_keys (bool): Skips diffing any rows with null PKs (displays a warning if any are null) (used for `JOINDIFF`. default: False)

Note:
The following parameters are used to override the corresponding attributes of the given :class:`TableSegment` instances:
Expand Down Expand Up @@ -168,6 +171,7 @@ def diff_tables(
materialize_to_table=materialize_to_table,
materialize_all_rows=materialize_all_rows,
table_write_limit=table_write_limit,
skip_null_keys=skip_null_keys,
)
else:
raise ValueError(f"Unknown algorithm: {algorithm}")
Expand Down
1 change: 1 addition & 0 deletions data_diff/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def _local_diff(diff_vars: TDiffVars) -> None:
algorithm=Algorithm.JOINDIFF,
extra_columns=extra_columns,
where=diff_vars.where_filter,
skip_null_keys=True,
)

if list(diff):
Expand Down
15 changes: 13 additions & 2 deletions data_diff/joindiff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,15 @@ class JoinDiffer(TableDiffer):
materialize_to_table (DbPath, optional): Path of new table to write diff results to. Disabled if not provided.
materialize_all_rows (bool): Materialize every row, not just those that are different. (default: False)
table_write_limit (int): Maximum number of rows to write when materializing, per thread.
skip_null_keys (bool): Skips diffing any rows with null PKs (displays a warning if any are null) (default: False)
"""

validate_unique_key: bool = True
sample_exclusive_rows: bool = False
materialize_to_table: DbPath = None
materialize_all_rows: bool = False
table_write_limit: int = TABLE_WRITE_LIMIT
skip_null_keys: bool = False

stats: dict = {}

Expand Down Expand Up @@ -209,7 +211,11 @@ def _diff_segments(
if is_xa and is_xb:
# Can't both be exclusive, meaning a pk is NULL
# This can happen if the explicit null test didn't finish running yet
raise ValueError("NULL values in one or more primary keys")
if self.skip_null_keys:
# the warning is logged (not raised) by the explicit null test, _test_null_keys
continue
else:
raise ValueError("NULL values in one or more primary keys")
# _is_diff, a_row, b_row = _slice_tuple(x, len(is_diff_cols), len(a_cols), len(b_cols))
_is_diff, ab_row = _slice_tuple(x, len(is_diff_cols), len(a_cols) + len(b_cols))
a_row, b_row = ab_row[::2], ab_row[1::2]
Expand Down Expand Up @@ -252,7 +258,12 @@ def _test_null_keys(self, table1, table2):
q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns))
nulls = ts.database.query(q, list)
if nulls:
raise ValueError(f"NULL values in one or more primary keys of {ts.table_path}")
if self.skip_null_keys:
logger.warning(
f"NULL values in one or more primary keys of {ts.table_path}. Skipping rows with NULL keys."
)
else:
raise ValueError(f"NULL values in one or more primary keys of {ts.table_path}")

def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree):
logger.debug(f"Collecting stats for table #{i}")
Expand Down
10 changes: 9 additions & 1 deletion tests/test_dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ def test_local_diff(self, mock_diff_tables):
algorithm=Algorithm.JOINDIFF,
extra_columns=ANY,
where=where,
skip_null_keys=True,
)
self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2)
self.assertEqual(mock_connect.call_count, 2)
Expand Down Expand Up @@ -549,6 +550,7 @@ def test_local_diff_types_differ(self, mock_diff_tables):
algorithm=Algorithm.JOINDIFF,
extra_columns=ANY,
where=where,
skip_null_keys=True,
)
self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 1)
self.assertEqual(mock_connect.call_count, 2)
Expand Down Expand Up @@ -584,7 +586,13 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
_local_diff(diff_vars)

mock_diff_tables.assert_called_once_with(
mock_table1, mock_table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=ANY, where=where
mock_table1,
mock_table2,
threaded=True,
algorithm=Algorithm.JOINDIFF,
extra_columns=ANY,
where=where,
skip_null_keys=True,
)
self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2)
self.assertEqual(mock_connect.call_count, 2)
Expand Down