diff --git a/data_diff/__init__.py b/data_diff/__init__.py index 0a3a9d71..44e38a3e 100644 --- a/data_diff/__init__.py +++ b/data_diff/__init__.py @@ -79,6 +79,8 @@ def diff_tables( materialize_all_rows: bool = False, # Maximum number of rows to write when materializing, per thread. (joindiff only) table_write_limit: int = TABLE_WRITE_LIMIT, + # Skips diffing any rows with null keys. (joindiff only) + skip_null_keys: bool = False, ) -> Iterator: """Finds the diff between table1 and table2. @@ -107,6 +109,7 @@ def diff_tables( materialize_to_table (Union[str, DbPath], optional): Path of new table to write diff results to. Disabled if not provided. Used for `JOINDIFF`. materialize_all_rows (bool): Materialize every row, not just those that are different. (used for `JOINDIFF`. default: False) table_write_limit (int): Maximum number of rows to write when materializing, per thread. + skip_null_keys (bool): Skips diffing any rows with null PKs (displays a warning if any are null) (used for `JOINDIFF`. default: False) Note: The following parameters are used to override the corresponding attributes of the given :class:`TableSegment` instances: @@ -168,6 +171,7 @@ def diff_tables( materialize_to_table=materialize_to_table, materialize_all_rows=materialize_all_rows, table_write_limit=table_write_limit, + skip_null_keys=skip_null_keys, ) else: raise ValueError(f"Unknown algorithm: {algorithm}") diff --git a/data_diff/dbt.py b/data_diff/dbt.py index 4255fa9b..fe1396d3 100644 --- a/data_diff/dbt.py +++ b/data_diff/dbt.py @@ -238,6 +238,7 @@ def _local_diff(diff_vars: TDiffVars) -> None: algorithm=Algorithm.JOINDIFF, extra_columns=extra_columns, where=diff_vars.where_filter, + skip_null_keys=True, ) if list(diff): diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 93d806df..bfe77204 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -131,6 +131,7 @@ class JoinDiffer(TableDiffer): materialize_to_table (DbPath, optional): Path of new table to write diff results to. Disabled if not provided. materialize_all_rows (bool): Materialize every row, not just those that are different. (default: False) table_write_limit (int): Maximum number of rows to write when materializing, per thread. + skip_null_keys (bool): Skips diffing any rows with null PKs (displays a warning if any are null) (default: False) """ validate_unique_key: bool = True @@ -138,6 +139,7 @@ class JoinDiffer(TableDiffer): materialize_to_table: DbPath = None materialize_all_rows: bool = False table_write_limit: int = TABLE_WRITE_LIMIT + skip_null_keys: bool = False stats: dict = {} @@ -209,7 +211,11 @@ def _diff_segments( if is_xa and is_xb: # Can't both be exclusive, meaning a pk is NULL # This can happen if the explicit null test didn't finish running yet - raise ValueError("NULL values in one or more primary keys") + if self.skip_null_keys: + # warning is thrown in explicit null test + continue + else: + raise ValueError("NULL values in one or more primary keys") # _is_diff, a_row, b_row = _slice_tuple(x, len(is_diff_cols), len(a_cols), len(b_cols)) _is_diff, ab_row = _slice_tuple(x, len(is_diff_cols), len(a_cols) + len(b_cols)) a_row, b_row = ab_row[::2], ab_row[1::2] @@ -252,7 +258,12 @@ def _test_null_keys(self, table1, table2): q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns)) nulls = ts.database.query(q, list) if nulls: - raise ValueError(f"NULL values in one or more primary keys of {ts.table_path}") + if self.skip_null_keys: + logger.warning( + f"NULL values in one or more primary keys of {ts.table_path}. Skipping rows with NULL keys." + ) + else: + raise ValueError(f"NULL values in one or more primary keys of {ts.table_path}") def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree): logger.debug(f"Collecting stats for table #{i}") diff --git a/tests/test_dbt.py b/tests/test_dbt.py index 78eaa506..fa11a835 100644 --- a/tests/test_dbt.py +++ b/tests/test_dbt.py @@ -505,6 +505,7 @@ def test_local_diff(self, mock_diff_tables): algorithm=Algorithm.JOINDIFF, extra_columns=ANY, where=where, + skip_null_keys=True, ) self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2) self.assertEqual(mock_connect.call_count, 2) @@ -549,6 +550,7 @@ def test_local_diff_types_differ(self, mock_diff_tables): algorithm=Algorithm.JOINDIFF, extra_columns=ANY, where=where, + skip_null_keys=True, ) self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 1) self.assertEqual(mock_connect.call_count, 2) @@ -584,7 +586,13 @@ def test_local_diff_no_diffs(self, mock_diff_tables): _local_diff(diff_vars) mock_diff_tables.assert_called_once_with( - mock_table1, mock_table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=ANY, where=where + mock_table1, + mock_table2, + threaded=True, + algorithm=Algorithm.JOINDIFF, + extra_columns=ANY, + where=where, + skip_null_keys=True, ) self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2) self.assertEqual(mock_connect.call_count, 2)