From d6470b7800c5c41c761992585b8fe1f51ac82950 Mon Sep 17 00:00:00 2001 From: Imri Paran Date: Mon, 25 Nov 2024 15:53:50 +0100 Subject: [PATCH] MINOR: fix(data-diff): get added columns (#18694) * fix(data-diff): get added columns - use both columns to calculate schema diff * fix tests --- .../table_diff_params_setter.py | 32 +++++++++++-------- .../validations/table/sqlalchemy/tableDiff.py | 17 +++++----- .../data_quality/test_data_diff.py | 2 +- 3 files changed, 29 insertions(+), 22 deletions(-) diff --git a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py index b721e22910ba..1bd8c0a7732b 100644 --- a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py +++ b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py @@ -10,7 +10,7 @@ # limitations under the License. """Module that defines the TableDiffParamsSetter class.""" from ast import literal_eval -from typing import List, Optional +from typing import List, Optional, Set from urllib.parse import urlparse from metadata.data_quality.validations import utils @@ -75,7 +75,9 @@ def get_parameters(self, test_case) -> TableDiffRuntimeParameters: DatabaseService, table2.service.id, nullable=False ) key_columns = self.get_key_columns(test_case) - extra_columns = self.get_extra_columns(key_columns, test_case) + extra_columns = self.get_extra_columns( + key_columns, test_case, self.table_entity.columns, table2.columns + ) return TableDiffRuntimeParameters( table_profile_config=self.table_entity.tableProfilerConfig, table1=TableParameter( @@ -111,8 +113,8 @@ def get_parameters(self, test_case) -> TableDiffRuntimeParameters: case_sensitive=case_sensitive_columns, ), ), - keyColumns=key_columns, - extraColumns=extra_columns, + keyColumns=list(key_columns), + extraColumns=list(extra_columns), whereClause=self.build_where_clause(test_case), ) @@ -134,21 +136,25 @@ def build_where_clause(self, test_case) -> Optional[str]: return " AND ".join(where_clauses) def get_extra_columns( - self, key_columns: List[str], test_case - ) -> Optional[List[str]]: + self, + key_columns: Set[str], + test_case, + left_columns: List[Column], + right_columns: List[Column], + ) -> Optional[Set[str]]: extra_columns_param = self.get_parameter(test_case, "useColumns", None) if extra_columns_param is not None: extra_columns: List[str] = literal_eval(extra_columns_param) self.validate_columns(extra_columns) - return extra_columns + return set(extra_columns) if extra_columns_param is None: extra_columns_param = [] - for column in self.table_entity.columns: + for column in left_columns + right_columns: if column.name.root not in key_columns: extra_columns_param.insert(0, column.name.root) - return extra_columns_param + return set(extra_columns_param) - def get_key_columns(self, test_case) -> List[str]: + def get_key_columns(self, test_case) -> Set[str]: key_columns_param = self.get_parameter(test_case, "keyColumns", "[]") key_columns: List[str] = literal_eval(key_columns_param) if key_columns: @@ -167,13 +173,13 @@ def get_key_columns(self, test_case) -> List[str]: "Could not find primary key or unique constraint columns.\n", "Specify 'keyColumns' to explicitly set the columns to use as keys.", ) - return key_columns + return set(key_columns) @staticmethod def filter_relevant_columns( columns: List[Column], - key_columns: List[str], - extra_columns: List[str], + key_columns: Set[str], + extra_columns: Set[str], case_sensitive: bool, ) -> List[Column]: validated_columns = ( diff --git a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py index 08f31daabf41..1221f66f8dff 100644 --- a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py +++ b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py @@ -273,15 +273,16 @@ def get_incomparable_columns(self) -> List[str]: ).with_schema() result = [] for column in table1.key_columns + table1.extra_columns: - col1_type = self._get_column_python_type( - table1._schema[column] # pylint: disable=protected-access - ) - # Skip columns that are not in the second table. We cover this case in get_changed_added_columns. - if table2._schema.get(column) is None: # pylint: disable=protected-access + col1 = table1._schema.get(column) # pylint: disable=protected-access + if col1 is None: + # Skip columns that are not in the first table. We cover this case in get_changed_added_columns. continue - col2_type = self._get_column_python_type( - table2._schema[column] # pylint: disable=protected-access - ) + col2 = table2._schema.get(column) # pylint: disable=protected-access + if col2 is None: + # Skip columns that are not in the second table. We cover this case in get_changed_added_columns. + continue + col1_type = self._get_column_python_type(col1) + col2_type = self._get_column_python_type(col2) if is_numeric(col1_type) and is_numeric(col2_type): continue if col1_type != col2_type: diff --git a/ingestion/tests/integration/data_quality/test_data_diff.py b/ingestion/tests/integration/data_quality/test_data_diff.py index 4f5c93281c94..a515928eb128 100644 --- a/ingestion/tests/integration/data_quality/test_data_diff.py +++ b/ingestion/tests/integration/data_quality/test_data_diff.py @@ -305,7 +305,7 @@ def __init__(self, *args, **kwargs): testCaseStatus=TestCaseStatus.Failed, testResultValue=[ TestResultValue(name="removedColumns", value="1"), - TestResultValue(name="addedColumns", value="0"), + TestResultValue(name="addedColumns", value="1"), TestResultValue(name="changedColumns", value="0"), ], ),