Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ISSUE #19175: Handle pk for snowflake in data diff #19734

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ingestion/src/metadata/data_quality/validations/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
from metadata.generated.schema.entity.services.databaseService import (
DatabaseServiceType,
)
from metadata.ingestion.models.custom_pydantic import CustomSecretStr


class TableParameter(BaseModel):
    """Connection and column description for one side of a table diff."""

    # URL handed to the diff library to connect to the table's service.
    serviceUrl: str
    # Table path handed to the diff library alongside the URL.
    path: str
    # Columns of the table that take part in the comparison.
    columns: List[Column]
    # Service type of the database service that owns the table.
    database_service_type: DatabaseServiceType
    # Optional key-pair authentication material. Both default to None so that
    # callers which do not use key-pair authentication can keep constructing
    # the model without passing them (Optional[...] alone does not make a
    # field optional when no default is given).
    privateKey: Optional[CustomSecretStr] = None
    passPhrase: Optional[CustomSecretStr] = None


class TableDiffRuntimeParameters(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Base class for param setter logic for table data diff"""

from typing import List, Optional, Set
from urllib.parse import urlparse

from metadata.data_quality.validations.models import Column, TableParameter
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.ingestion.source.connections import get_connection
from metadata.profiler.orm.registry import Dialects
from metadata.utils import fqn
from metadata.utils.collections import CaseInsensitiveList


class BaseTableParameter:
    """Default builder of ``TableParameter`` objects for the table diff test."""

    def get(
        self,
        service: DatabaseService,
        entity: Table,
        key_columns,
        extra_columns,
        case_sensitive_columns,
        service_url: Optional[str],
    ) -> TableParameter:
        """Assemble the table parameter for one side of the diff.

        Args:
            service (DatabaseService): database service owning the table
            entity (Table): table entity to describe
            key_columns: names of the key columns
            extra_columns: names of the extra columns to compare
            case_sensitive_columns: whether column names are matched
                case-sensitively
            service_url (Optional[str]): url used instead of the one derived
                from the service connection, when set

        Returns:
            TableParameter: with no private key and no pass phrase
        """
        service_type = service.serviceType
        table_path = self.get_data_diff_table_path(entity.fullyQualifiedName.root)
        diff_url = self.get_data_diff_url(
            service,
            entity.fullyQualifiedName.root,
            override_url=service_url,
        )
        relevant_columns = self.filter_relevant_columns(
            entity.columns,
            key_columns,
            extra_columns,
            case_sensitive=case_sensitive_columns,
        )
        return TableParameter(
            database_service_type=service_type,
            path=table_path,
            serviceUrl=diff_url,
            columns=relevant_columns,
            privateKey=None,
            passPhrase=None,
        )

    @staticmethod
    def get_data_diff_table_path(table_fqn: str) -> str:
        """Build the ``schema.table`` style path used by the diff.

        Args:
            table_fqn (str): fully qualified name of the table

        Returns:
            str: the name rebuilt from the schema and table parts only
        """
        _service, _database, schema_name, table_name = fqn.split(table_fqn)
        # Build a complete FQN with placeholder service/database parts, then
        # strip that placeholder prefix so only schema and table remain.
        placeholder_fqn = fqn._build(  # pylint: disable=protected-access
            "___SERVICE___", "__DATABASE__", schema_name, table_name
        )
        return placeholder_fqn.replace("___SERVICE___.__DATABASE__.", "")

    @staticmethod
    def get_data_diff_url(
        db_service: DatabaseService, table_fqn, override_url: Optional[str] = None
    ) -> str:
        """Compute the connection url handed to the diff for a table.

        Args:
            db_service (DatabaseService): database service entity
            table_fqn (str): fully qualified name of the table
            override_url (Optional[str], optional): url to use instead of the
                one built from the service connection. Defaults to None.

        Returns:
            str: url without driver suffix and, where applicable, with the
                database (and schema) as path
        """
        if override_url:
            source_url = override_url
        else:
            source_url = str(get_connection(db_service.connection.config).url)
        parsed = urlparse(source_url)
        # table-diff does not support a driver name in the scheme: drop it
        scheme = parsed.scheme.split("+")[0]
        replacements = {"scheme": scheme}
        _, database, schema, _ = fqn.split(table_fqn)
        # some connectors need the database, or database AND schema, in the path
        if hasattr(db_service.connection.config, "supportsDatabase"):
            replacements["path"] = f"/{database}"
        if scheme in {Dialects.MSSQL, Dialects.Snowflake}:
            replacements["path"] = f"/{database}/{schema}"
        return parsed._replace(**replacements).geturl()

    @staticmethod
    def filter_relevant_columns(
        columns: List[Column],
        key_columns: Set[str],
        extra_columns: Set[str],
        case_sensitive: bool,
    ) -> List[Column]:
        """Keep only the columns named as key or extra columns.

        Args:
            columns (List[Column]): candidate columns
            key_columns (Set[str]): names of the key columns
            extra_columns (Set[str]): names of the extra columns
            case_sensitive (bool): match names exactly when True, otherwise
                through a ``CaseInsensitiveList``

        Returns:
            List[Column]: matching columns, in their original order
        """
        wanted_names = [*key_columns, *extra_columns]
        if not case_sensitive:
            wanted_names = CaseInsensitiveList(wanted_names)
        return [column for column in columns if column.name.root in wanted_names]
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,32 @@
from urllib.parse import urlparse

from metadata.data_quality.validations import utils
from metadata.data_quality.validations.models import (
Column,
TableDiffRuntimeParameters,
TableParameter,
)
from metadata.data_quality.validations.models import Column, TableDiffRuntimeParameters
from metadata.data_quality.validations.runtime_param_setter.param_setter import (
RuntimeParameterSetter,
)
from metadata.generated.schema.entity.data.table import Constraint, Table
from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.generated.schema.entity.services.serviceType import ServiceType
from metadata.generated.schema.tests.testCase import TestCase
from metadata.ingestion.source.connections import get_connection
from metadata.profiler.orm.registry import Dialects
from metadata.utils import fqn
from metadata.utils.collections import CaseInsensitiveList
from metadata.utils.importer import get_module_dir, import_from_module


def get_for_source(
    service_type: ServiceType, source_type: str, from_: str = "ingestion"
):
    """Import the ``ServiceSpec`` object declared by a source connector.

    Args:
        service_type (ServiceType): service family; its lower-cased name is
            used as a package segment
        source_type (str): source name, mapped to a module directory with
            ``get_module_dir``
        from_ (str): package under ``metadata`` to look in. Defaults to
            "ingestion".

    Returns:
        The object ``import_from_module`` resolves for the
        ``metadata.<from_>.source.<service>.<source>.service_spec.ServiceSpec``
        path.
    """
    spec_path = (
        f"metadata.{from_}.source.{service_type.name.lower()}"
        f".{get_module_dir(source_type)}.service_spec.ServiceSpec"
    )
    return import_from_module(spec_path)


class TableDiffParamsSetter(RuntimeParameterSetter):
Expand All @@ -51,67 +62,70 @@ def __init__(self, *args, **kwargs):
}

def get_parameters(self, test_case) -> TableDiffRuntimeParameters:
# Using the specs class method causes circular import as TestSuiteInterface
# imports RuntimeParameterSetter
cls_path = get_for_source(
ServiceType.Database,
source_type=self.service_connection_config.type.value.lower(),
).data_diff
cls = import_from_module(cls_path)()

service1: DatabaseService = self.ometa_client.get_by_id(
DatabaseService, self.table_entity.service.id, nullable=False
)

service1_url = (
str(get_connection(self.service_connection_config).url)
if self.service_connection_config
else None
)
service1: DatabaseService = self.ometa_client.get_by_id(
DatabaseService, self.table_entity.service.id, nullable=False
)

table2_fqn = self.get_parameter(test_case, "table2")
case_sensitive_columns: bool = utils.get_bool_test_case_param(
test_case.parameterValues, "caseSensitiveColumns"
)
if table2_fqn is None:
raise ValueError("table2 not set")
table2: Table = self.ometa_client.get_by_name(
Table, fqn=table2_fqn, nullable=False
)
service2_url = (
service1_url if table2.service == self.table_entity.service else None
)
service2: DatabaseService = self.ometa_client.get_by_id(
DatabaseService, table2.service.id, nullable=False
)
# An explicit "service2Url" test parameter always wins. The url of the first
# table is only reused as a fallback when both tables belong to the same
# service. The parentheses matter: `a or b if c else None` would parse as
# `(a or b) if c else None` and drop the explicit url for a different service.
service2_url = self.get_parameter(test_case, "service2Url") or (
    service1_url if table2.service == self.table_entity.service else None
)

key_columns = self.get_key_columns(test_case)
extra_columns = self.get_extra_columns(
key_columns, test_case, self.table_entity.columns, table2.columns
extra_columns = (
self.get_extra_columns(
key_columns, test_case, self.table_entity.columns, table2.columns
)
or set()
)
case_sensitive_columns: bool = (
utils.get_bool_test_case_param(
test_case.parameterValues, "caseSensitiveColumns"
)
or False
)

return TableDiffRuntimeParameters(
table_profile_config=self.table_entity.tableProfilerConfig,
table1=TableParameter(
database_service_type=service1.serviceType,
path=self.get_data_diff_table_path(
self.table_entity.fullyQualifiedName.root
),
serviceUrl=self.get_data_diff_url(
service1,
self.table_entity.fullyQualifiedName.root,
override_url=service1_url,
),
columns=self.filter_relevant_columns(
self.table_entity.columns,
key_columns,
extra_columns,
case_sensitive=case_sensitive_columns,
),
table1=cls.get(
service1,
self.table_entity,
key_columns,
extra_columns,
case_sensitive_columns,
service1_url,
),
table2=TableParameter(
database_service_type=service2.serviceType,
path=self.get_data_diff_table_path(table2_fqn),
serviceUrl=self.get_data_diff_url(
service2,
table2_fqn,
override_url=self.get_parameter(test_case, "service2Url")
or service2_url,
),
columns=self.filter_relevant_columns(
table2.columns,
key_columns,
extra_columns,
case_sensitive=case_sensitive_columns,
),
table2=cls.get(
service2,
table2,
key_columns,
extra_columns,
case_sensitive_columns,
service2_url,
),
keyColumns=list(key_columns),
extraColumns=list(extra_columns),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,13 +265,25 @@ def get_incomparable_columns(self) -> List[str]:
self.runtime_params.keyColumns,
extra_columns=self.runtime_params.extraColumns,
case_sensitive=self.get_case_sensitive(),
key_content=self.runtime_params.table1.privateKey.get_secret_value()
if self.runtime_params.table1.privateKey
else None,
private_key_passphrase=self.runtime_params.table1.passPhrase.get_secret_value()
if self.runtime_params.table1.passPhrase
else None,
).with_schema()
table2 = data_diff.connect_to_table(
self.runtime_params.table2.serviceUrl,
self.runtime_params.table2.path,
self.runtime_params.keyColumns,
extra_columns=self.runtime_params.extraColumns,
case_sensitive=self.get_case_sensitive(),
key_content=self.runtime_params.table2.privateKey.get_secret_value()
if self.runtime_params.table2.privateKey
else None,
private_key_passphrase=self.runtime_params.table2.passPhrase.get_secret_value()
if self.runtime_params.table2.passPhrase
else None,
).with_schema()
result = []
for column in table1.key_columns + table1.extra_columns:
Expand Down Expand Up @@ -332,13 +344,25 @@ def get_table_diff(self) -> DiffResultWrapper:
self.runtime_params.keyColumns, # type: ignore
case_sensitive=self.get_case_sensitive(),
where=left_where,
key_content=self.runtime_params.table1.privateKey.get_secret_value()
if self.runtime_params.table1.privateKey
else None,
private_key_passphrase=self.runtime_params.table1.passPhrase.get_secret_value()
if self.runtime_params.table1.passPhrase
else None,
)
# table2 authenticates with its own key material: it is described by its own
# runtime parameter and must not reuse table1's private key or pass phrase.
table2 = data_diff.connect_to_table(
    self.runtime_params.table2.serviceUrl,
    self.runtime_params.table2.path,
    self.runtime_params.keyColumns,  # type: ignore
    case_sensitive=self.get_case_sensitive(),
    where=right_where,
    key_content=self.runtime_params.table2.privateKey.get_secret_value()
    if self.runtime_params.table2.privateKey
    else None,
    private_key_passphrase=self.runtime_params.table2.passPhrase.get_secret_value()
    if self.runtime_params.table2.passPhrase
    else None,
)
data_diff_kwargs = {
"key_columns": self.runtime_params.keyColumns,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Snowflake spec for data diff"""

from typing import Optional, cast

from metadata.data_quality.validations.models import TableParameter
from metadata.data_quality.validations.runtime_param_setter.base_diff_params_setter import (
BaseTableParameter,
)
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.ingestion.source.database.snowflake.metadata import SnowflakeConnection


class SnowflakeTableParameter(BaseTableParameter):
    """SnowflakeTableParameter class for setting runtime parameters for data diff"""

    def get(
        self,
        service: DatabaseService,
        entity: Table,
        key_columns,
        extra_columns,
        case_sensitive_columns,
        service_url: Optional[str],
    ) -> TableParameter:
        """Build the table parameter and attach key-pair credentials.

        Args:
            service (DatabaseService): database service owning the table
            entity (Table): table entity to describe
            key_columns: names of the key columns
            extra_columns: names of the extra columns to compare
            case_sensitive_columns: whether column names are matched
                case-sensitively
            service_url (Optional[str]): url used instead of the one derived
                from the service connection, when set

        Returns:
            TableParameter: the base parameter, with ``privateKey`` and
                ``passPhrase`` copied from the service connection when the
                connection exposes them, else left as None
        """
        table_param: TableParameter = super().get(
            service,
            entity,
            key_columns,
            extra_columns,
            case_sensitive_columns,
            service_url,
        )
        # Typing aid only: ``cast`` performs no runtime check.
        connection_config = cast(SnowflakeConnection, service.connection.config)
        # Read defensively: the service passed in is not guaranteed to be a
        # Snowflake one (e.g. the second table of a diff may live in another
        # service type), in which case these attributes do not exist.
        table_param.privateKey = getattr(connection_config, "privateKey", None)
        table_param.passPhrase = getattr(
            connection_config, "snowflakePrivatekeyPassphrase", None
        )
        return table_param
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from metadata.data_quality.interface.sqlalchemy.snowflake.test_suite_interface import (
SnowflakeTestSuiteInterface,
)
from metadata.ingestion.source.database.snowflake.data_diff.data_diff import (
SnowflakeTableParameter,
)
from metadata.ingestion.source.database.snowflake.lineage import SnowflakeLineageSource
from metadata.ingestion.source.database.snowflake.metadata import SnowflakeSource
from metadata.ingestion.source.database.snowflake.profiler.profiler import (
Expand All @@ -17,4 +20,5 @@
profiler_class=SnowflakeProfiler,
test_suite_class=SnowflakeTestSuiteInterface,
sampler_class=SnowflakeSampler,
data_diff=SnowflakeTableParameter,
)
Loading
Loading