refactor dbt test logic into dedicated file
hsheth2 committed Oct 11, 2023
1 parent 5686090 commit 8801376
Showing 4 changed files with 276 additions and 269 deletions.
@@ -20,9 +20,8 @@
     DBTCommonConfig,
     DBTNode,
     DBTSourceBase,
-    DBTTest,
-    DBTTestResult,
 )
+from datahub.ingestion.source.dbt.dbt_tests import DBTTest, DBTTestResult
 
 logger = logging.getLogger(__name__)
 
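For this consumer only the import path changes; the two types keep the same shape. A minimal construction sketch with placeholder values, based on the dataclass definitions this commit removes from dbt_common.py (shown further below):

from datetime import datetime

from datahub.ingestion.source.dbt.dbt_tests import DBTTest, DBTTestResult

# Field layout taken from the dataclass definitions that this commit moves
# out of dbt_common.py; all values here are made-up placeholders.
test = DBTTest(
    qualified_test_name="not_null",
    column_name="customer_id",
    kw_args={"column_name": "customer_id"},
)
result = DBTTestResult(
    invocation_id="manifest-run-0001",  # placeholder invocation id
    status="pass",
    execution_time=datetime.now(),
    native_results={"failures": "0"},
)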
278 changes: 13 additions & 265 deletions metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py
@@ -1,11 +1,10 @@
-import json
 import logging
 import re
 from abc import abstractmethod
 from dataclasses import dataclass, field
 from datetime import datetime
 from enum import auto
-from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import pydantic
 from pydantic import root_validator, validator
@@ -34,6 +33,12 @@
 from datahub.ingestion.api.source import MetadataWorkUnitProcessor
 from datahub.ingestion.api.workunit import MetadataWorkUnit
 from datahub.ingestion.source.common.subtypes import DatasetSubTypes
+from datahub.ingestion.source.dbt.dbt_tests import (
+    DBTTest,
+    DBTTestResult,
+    make_assertion_from_test,
+    make_assertion_result_from_test,
+)
 from datahub.ingestion.source.sql.sql_types import (
     ATHENA_SQL_TYPES_MAP,
     BIGQUERY_TYPES_MAP,
@@ -81,20 +86,7 @@
     TimeTypeClass,
 )
 from datahub.metadata.schema_classes import (
-    AssertionInfoClass,
-    AssertionResultClass,
-    AssertionResultTypeClass,
-    AssertionRunEventClass,
-    AssertionRunStatusClass,
-    AssertionStdAggregationClass,
-    AssertionStdOperatorClass,
-    AssertionStdParameterClass,
-    AssertionStdParametersClass,
-    AssertionStdParameterTypeClass,
-    AssertionTypeClass,
     DataPlatformInstanceClass,
-    DatasetAssertionInfoClass,
-    DatasetAssertionScopeClass,
     DatasetPropertiesClass,
     GlobalTagsClass,
     GlossaryTermsClass,
@@ -551,134 +543,6 @@ def get_column_type(
     return SchemaFieldDataType(type=TypeClass())
 
 
-@dataclass
-class AssertionParams:
-    scope: Union[DatasetAssertionScopeClass, str]
-    operator: Union[AssertionStdOperatorClass, str]
-    aggregation: Union[AssertionStdAggregationClass, str]
-    parameters: Optional[Callable[[Dict[str, str]], AssertionStdParametersClass]] = None
-    logic_fn: Optional[Callable[[Dict[str, str]], Optional[str]]] = None
-
-
-def _get_name_for_relationship_test(kw_args: Dict[str, str]) -> Optional[str]:
-    """
-    Try to produce a useful string for the name of a relationship constraint.
-    Return None if we fail to
-    """
-    destination_ref = kw_args.get("to")
-    source_ref = kw_args.get("model")
-    column_name = kw_args.get("column_name")
-    dest_field_name = kw_args.get("field")
-    if not destination_ref or not source_ref or not column_name or not dest_field_name:
-        # base assertions are violated, bail early
-        return None
-    m = re.match(r"^ref\(\'(.*)\'\)$", destination_ref)
-    if m:
-        destination_table = m.group(1)
-    else:
-        destination_table = destination_ref
-    m = re.search(r"ref\(\'(.*)\'\)", source_ref)
-    if m:
-        source_table = m.group(1)
-    else:
-        source_table = source_ref
-    return f"{source_table}.{column_name} referential integrity to {destination_table}.{dest_field_name}"
-
-
-@dataclass
-class DBTTest:
-    qualified_test_name: str
-    column_name: Optional[str]
-    kw_args: dict
-
-    TEST_NAME_TO_ASSERTION_MAP: ClassVar[Dict[str, AssertionParams]] = {
-        "not_null": AssertionParams(
-            scope=DatasetAssertionScopeClass.DATASET_COLUMN,
-            operator=AssertionStdOperatorClass.NOT_NULL,
-            aggregation=AssertionStdAggregationClass.IDENTITY,
-        ),
-        "unique": AssertionParams(
-            scope=DatasetAssertionScopeClass.DATASET_COLUMN,
-            operator=AssertionStdOperatorClass.EQUAL_TO,
-            aggregation=AssertionStdAggregationClass.UNIQUE_PROPOTION,
-            parameters=lambda _: AssertionStdParametersClass(
-                value=AssertionStdParameterClass(
-                    value="1.0",
-                    type=AssertionStdParameterTypeClass.NUMBER,
-                )
-            ),
-        ),
-        "accepted_values": AssertionParams(
-            scope=DatasetAssertionScopeClass.DATASET_COLUMN,
-            operator=AssertionStdOperatorClass.IN,
-            aggregation=AssertionStdAggregationClass.IDENTITY,
-            parameters=lambda kw_args: AssertionStdParametersClass(
-                value=AssertionStdParameterClass(
-                    value=json.dumps(kw_args.get("values")),
-                    type=AssertionStdParameterTypeClass.SET,
-                ),
-            ),
-        ),
-        "relationships": AssertionParams(
-            scope=DatasetAssertionScopeClass.DATASET_COLUMN,
-            operator=AssertionStdOperatorClass._NATIVE_,
-            aggregation=AssertionStdAggregationClass.IDENTITY,
-            parameters=lambda kw_args: AssertionStdParametersClass(
-                value=AssertionStdParameterClass(
-                    value=json.dumps(kw_args.get("values")),
-                    type=AssertionStdParameterTypeClass.SET,
-                ),
-            ),
-            logic_fn=_get_name_for_relationship_test,
-        ),
-        "dbt_expectations.expect_column_values_to_not_be_null": AssertionParams(
-            scope=DatasetAssertionScopeClass.DATASET_COLUMN,
-            operator=AssertionStdOperatorClass.NOT_NULL,
-            aggregation=AssertionStdAggregationClass.IDENTITY,
-        ),
-        "dbt_expectations.expect_column_values_to_be_between": AssertionParams(
-            scope=DatasetAssertionScopeClass.DATASET_COLUMN,
-            operator=AssertionStdOperatorClass.BETWEEN,
-            aggregation=AssertionStdAggregationClass.IDENTITY,
-            parameters=lambda x: AssertionStdParametersClass(
-                minValue=AssertionStdParameterClass(
-                    value=str(x.get("min_value", "unknown")),
-                    type=AssertionStdParameterTypeClass.NUMBER,
-                ),
-                maxValue=AssertionStdParameterClass(
-                    value=str(x.get("max_value", "unknown")),
-                    type=AssertionStdParameterTypeClass.NUMBER,
-                ),
-            ),
-        ),
-        "dbt_expectations.expect_column_values_to_be_in_set": AssertionParams(
-            scope=DatasetAssertionScopeClass.DATASET_COLUMN,
-            operator=AssertionStdOperatorClass.IN,
-            aggregation=AssertionStdAggregationClass.IDENTITY,
-            parameters=lambda kw_args: AssertionStdParametersClass(
-                value=AssertionStdParameterClass(
-                    value=json.dumps(kw_args.get("value_set")),
-                    type=AssertionStdParameterTypeClass.SET,
-                ),
-            ),
-        ),
-    }
-
-
-@dataclass
-class DBTTestResult:
-    invocation_id: str
-
-    status: str
-    execution_time: datetime
-
-    native_results: Dict[str, str]
-
-
-def string_map(input_map: Dict[str, Any]) -> Dict[str, str]:
-    return {k: str(v) for k, v in input_map.items()}
-
-
 @platform_name("dbt")
 @config_class(DBTCommonConfig)
 @support_status(SupportStatus.CERTIFIED)
@@ -750,7 +614,7 @@ def create_test_entity_mcps(
 
         for upstream_urn in sorted(upstream_urns):
             if self.config.entities_enabled.can_emit_node_type("test"):
-                yield self._make_assertion_from_test(
+                yield make_assertion_from_test(
                     custom_props,
                     node,
                     assertion_urn,
@@ -759,133 +623,17 @@
 
         if node.test_result:
             if self.config.entities_enabled.can_emit_test_results:
-                yield self._make_assertion_result_from_test(
-                    node, assertion_urn, upstream_urn
+                yield make_assertion_result_from_test(
+                    node,
+                    assertion_urn,
+                    upstream_urn,
+                    test_warnings_are_errors=self.config.test_warnings_are_errors,
                 )
             else:
                 logger.debug(
                     f"Skipping test result {node.name} emission since it is turned off."
                 )
 
-    def _make_assertion_from_test(
-        self,
-        extra_custom_props: Dict[str, str],
-        node: DBTNode,
-        assertion_urn: str,
-        upstream_urn: str,
-    ) -> MetadataWorkUnit:
-        assert node.test_info
-        qualified_test_name = node.test_info.qualified_test_name
-        column_name = node.test_info.column_name
-        kw_args = node.test_info.kw_args
-
-        if qualified_test_name in DBTTest.TEST_NAME_TO_ASSERTION_MAP:
-            assertion_params = DBTTest.TEST_NAME_TO_ASSERTION_MAP[qualified_test_name]
-            assertion_info = AssertionInfoClass(
-                type=AssertionTypeClass.DATASET,
-                customProperties=extra_custom_props,
-                datasetAssertion=DatasetAssertionInfoClass(
-                    dataset=upstream_urn,
-                    scope=assertion_params.scope,
-                    operator=assertion_params.operator,
-                    fields=[
-                        mce_builder.make_schema_field_urn(upstream_urn, column_name)
-                    ]
-                    if (
-                        assertion_params.scope
-                        == DatasetAssertionScopeClass.DATASET_COLUMN
-                        and column_name
-                    )
-                    else [],
-                    nativeType=node.name,
-                    aggregation=assertion_params.aggregation,
-                    parameters=assertion_params.parameters(kw_args)
-                    if assertion_params.parameters
-                    else None,
-                    logic=assertion_params.logic_fn(kw_args)
-                    if assertion_params.logic_fn
-                    else None,
-                    nativeParameters=string_map(kw_args),
-                ),
-            )
-        elif column_name:
-            # no match with known test types, column-level test
-            assertion_info = AssertionInfoClass(
-                type=AssertionTypeClass.DATASET,
-                customProperties=extra_custom_props,
-                datasetAssertion=DatasetAssertionInfoClass(
-                    dataset=upstream_urn,
-                    scope=DatasetAssertionScopeClass.DATASET_COLUMN,
-                    operator=AssertionStdOperatorClass._NATIVE_,
-                    fields=[
-                        mce_builder.make_schema_field_urn(upstream_urn, column_name)
-                    ],
-                    nativeType=node.name,
-                    logic=node.compiled_code or node.raw_code,
-                    aggregation=AssertionStdAggregationClass._NATIVE_,
-                    nativeParameters=string_map(kw_args),
-                ),
-            )
-        else:
-            # no match with known test types, default to row-level test
-            assertion_info = AssertionInfoClass(
-                type=AssertionTypeClass.DATASET,
-                customProperties=extra_custom_props,
-                datasetAssertion=DatasetAssertionInfoClass(
-                    dataset=upstream_urn,
-                    scope=DatasetAssertionScopeClass.DATASET_ROWS,
-                    operator=AssertionStdOperatorClass._NATIVE_,
-                    logic=node.compiled_code or node.raw_code,
-                    nativeType=node.name,
-                    aggregation=AssertionStdAggregationClass._NATIVE_,
-                    nativeParameters=string_map(kw_args),
-                ),
-            )
-
-        wu = MetadataChangeProposalWrapper(
-            entityUrn=assertion_urn,
-            aspect=assertion_info,
-        ).as_workunit()
-
-        return wu
-
-    def _make_assertion_result_from_test(
-        self,
-        node: DBTNode,
-        assertion_urn: str,
-        upstream_urn: str,
-    ) -> MetadataWorkUnit:
-        assert node.test_result
-        test_result = node.test_result
-
-        assertionResult = AssertionRunEventClass(
-            timestampMillis=int(test_result.execution_time.timestamp() * 1000.0),
-            assertionUrn=assertion_urn,
-            asserteeUrn=upstream_urn,
-            runId=test_result.invocation_id,
-            result=AssertionResultClass(
-                type=AssertionResultTypeClass.SUCCESS
-                if test_result.status == "pass"
-                or (
-                    not self.config.test_warnings_are_errors
-                    and test_result.status == "warn"
-                )
-                else AssertionResultTypeClass.FAILURE,
-                nativeResults=test_result.native_results,
-            ),
-            status=AssertionRunStatusClass.COMPLETE,
-        )
-
-        event = MetadataChangeProposalWrapper(
-            entityUrn=assertion_urn,
-            aspect=assertionResult,
-        )
-        wu = MetadataWorkUnit(
-            id=f"{assertion_urn}-assertionRunEvent-{upstream_urn}",
-            mcp=event,
-        )
-        return wu
-
     @abstractmethod
     def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]:
         # return dbt nodes + global custom properties
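The diff for the new dbt_tests.py module itself did not render on this page, but the removed method bodies above and the new call sites pin down roughly what the relocated result helper looks like. A sketch, with import paths assumed from the diff (not the actual file contents):

from typing import TYPE_CHECKING

from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.metadata.schema_classes import (
    AssertionResultClass,
    AssertionResultTypeClass,
    AssertionRunEventClass,
    AssertionRunStatusClass,
)

if TYPE_CHECKING:
    # Type-checking-only import; a runtime import would be circular,
    # since dbt_common now imports from dbt_tests.
    from datahub.ingestion.source.dbt.dbt_common import DBTNode


def make_assertion_result_from_test(
    node: "DBTNode",
    assertion_urn: str,
    upstream_urn: str,
    test_warnings_are_errors: bool,
) -> MetadataWorkUnit:
    assert node.test_result
    test_result = node.test_result

    assertion_result = AssertionRunEventClass(
        timestampMillis=int(test_result.execution_time.timestamp() * 1000.0),
        assertionUrn=assertion_urn,
        asserteeUrn=upstream_urn,
        runId=test_result.invocation_id,
        result=AssertionResultClass(
            # "pass" always succeeds; "warn" succeeds only when warnings are
            # not promoted to errors; any other status is a failure.
            type=AssertionResultTypeClass.SUCCESS
            if test_result.status == "pass"
            or (not test_warnings_are_errors and test_result.status == "warn")
            else AssertionResultTypeClass.FAILURE,
            nativeResults=test_result.native_results,
        ),
        status=AssertionRunStatusClass.COMPLETE,
    )

    return MetadataWorkUnit(
        id=f"{assertion_urn}-assertionRunEvent-{upstream_urn}",
        mcp=MetadataChangeProposalWrapper(
            entityUrn=assertion_urn,
            aspect=assertion_result,
        ),
    )

Turning the method into a free function means it can no longer read self.config, which is why the call site above now passes test_warnings_are_errors explicitly.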
@@ -26,9 +26,8 @@
     DBTNode,
     DBTSourceBase,
     DBTSourceReport,
-    DBTTest,
-    DBTTestResult,
 )
+from datahub.ingestion.source.dbt.dbt_tests import DBTTest, DBTTestResult
 
 logger = logging.getLogger(__name__)
 
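Assuming TEST_NAME_TO_ASSERTION_MAP and its helpers moved into dbt_tests verbatim (the imports above suggest exactly that), a short usage sketch; the model names in kw_args are made up:

from datahub.ingestion.source.dbt.dbt_tests import DBTTest

# Each supported dbt test name maps to DataHub assertion parameters.
unique = DBTTest.TEST_NAME_TO_ASSERTION_MAP["unique"]
print(unique.operator)     # EQUAL_TO
print(unique.aggregation)  # UNIQUE_PROPOTION (sic - the upstream constant is spelled this way)

# The "relationships" entry also carries a logic_fn that renders a readable
# constraint name from the test's kw_args:
kw_args = {
    "to": "ref('dim_customers')",
    "model": "ref('fct_orders')",
    "column_name": "customer_id",
    "field": "id",
}
relationships = DBTTest.TEST_NAME_TO_ASSERTION_MAP["relationships"]
assert relationships.logic_fn is not None
print(relationships.logic_fn(kw_args))
# -> fct_orders.customer_id referential integrity to dim_customers.id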