Skip to content

Commit

Permalink
MINOR: fix(data-diff): get added columns (#18694)
Browse files Browse the repository at this point in the history
* fix(data-diff): get added columns

- use the columns of both tables to calculate the schema diff

* fix tests
  • Loading branch information
sushi30 authored Nov 25, 2024
1 parent c164aff commit d6470b7
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.
"""Module that defines the TableDiffParamsSetter class."""
from ast import literal_eval
from typing import List, Optional
from typing import List, Optional, Set
from urllib.parse import urlparse

from metadata.data_quality.validations import utils
Expand Down Expand Up @@ -75,7 +75,9 @@ def get_parameters(self, test_case) -> TableDiffRuntimeParameters:
DatabaseService, table2.service.id, nullable=False
)
key_columns = self.get_key_columns(test_case)
extra_columns = self.get_extra_columns(key_columns, test_case)
extra_columns = self.get_extra_columns(
key_columns, test_case, self.table_entity.columns, table2.columns
)
return TableDiffRuntimeParameters(
table_profile_config=self.table_entity.tableProfilerConfig,
table1=TableParameter(
Expand Down Expand Up @@ -111,8 +113,8 @@ def get_parameters(self, test_case) -> TableDiffRuntimeParameters:
case_sensitive=case_sensitive_columns,
),
),
keyColumns=key_columns,
extraColumns=extra_columns,
keyColumns=list(key_columns),
extraColumns=list(extra_columns),
whereClause=self.build_where_clause(test_case),
)

Expand All @@ -134,21 +136,25 @@ def build_where_clause(self, test_case) -> Optional[str]:
return " AND ".join(where_clauses)

def get_extra_columns(
self, key_columns: List[str], test_case
) -> Optional[List[str]]:
self,
key_columns: Set[str],
test_case,
left_columns: List[Column],
right_columns: List[Column],
) -> Optional[Set[str]]:
extra_columns_param = self.get_parameter(test_case, "useColumns", None)
if extra_columns_param is not None:
extra_columns: List[str] = literal_eval(extra_columns_param)
self.validate_columns(extra_columns)
return extra_columns
return set(extra_columns)
if extra_columns_param is None:
extra_columns_param = []
for column in self.table_entity.columns:
for column in left_columns + right_columns:
if column.name.root not in key_columns:
extra_columns_param.insert(0, column.name.root)
return extra_columns_param
return set(extra_columns_param)

def get_key_columns(self, test_case) -> List[str]:
def get_key_columns(self, test_case) -> Set[str]:
key_columns_param = self.get_parameter(test_case, "keyColumns", "[]")
key_columns: List[str] = literal_eval(key_columns_param)
if key_columns:
Expand All @@ -167,13 +173,13 @@ def get_key_columns(self, test_case) -> List[str]:
"Could not find primary key or unique constraint columns.\n",
"Specify 'keyColumns' to explicitly set the columns to use as keys.",
)
return key_columns
return set(key_columns)

@staticmethod
def filter_relevant_columns(
columns: List[Column],
key_columns: List[str],
extra_columns: List[str],
key_columns: Set[str],
extra_columns: Set[str],
case_sensitive: bool,
) -> List[Column]:
validated_columns = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,15 +273,16 @@ def get_incomparable_columns(self) -> List[str]:
).with_schema()
result = []
for column in table1.key_columns + table1.extra_columns:
col1_type = self._get_column_python_type(
table1._schema[column] # pylint: disable=protected-access
)
# Skip columns that are not in the second table. We cover this case in get_changed_added_columns.
if table2._schema.get(column) is None: # pylint: disable=protected-access
col1 = table1._schema.get(column) # pylint: disable=protected-access
if col1 is None:
# Skip columns that are not in the first table. We cover this case in get_changed_added_columns.
continue
col2_type = self._get_column_python_type(
table2._schema[column] # pylint: disable=protected-access
)
col2 = table2._schema.get(column) # pylint: disable=protected-access
if col2 is None:
# Skip columns that are not in the second table. We cover this case in get_changed_added_columns.
continue
col1_type = self._get_column_python_type(col1)
col2_type = self._get_column_python_type(col2)
if is_numeric(col1_type) and is_numeric(col2_type):
continue
if col1_type != col2_type:
Expand Down
2 changes: 1 addition & 1 deletion ingestion/tests/integration/data_quality/test_data_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def __init__(self, *args, **kwargs):
testCaseStatus=TestCaseStatus.Failed,
testResultValue=[
TestResultValue(name="removedColumns", value="1"),
TestResultValue(name="addedColumns", value="0"),
TestResultValue(name="addedColumns", value="1"),
TestResultValue(name="changedColumns", value="0"),
],
),
Expand Down

0 comments on commit d6470b7

Please sign in to comment.