refactor(ingest/dbt): move dbt tests logic to dedicated file (#8984)
hsheth2 authored Oct 11, 2023
1 parent 4b6b941 commit 932fbcd
Showing 6 changed files with 288 additions and 274 deletions.
9 changes: 9 additions & 0 deletions metadata-ingestion/src/datahub/ingestion/api/common.py
@@ -2,6 +2,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Generic, Iterable, Optional, Tuple, TypeVar

from datahub.configuration.common import ConfigurationError
from datahub.emitter.mce_builder import set_dataset_urn_to_lower
from datahub.ingestion.api.committable import Committable
from datahub.ingestion.graph.client import DataHubGraph
@@ -75,3 +76,11 @@ def register_checkpointer(self, committable: Committable) -> None:

def get_committables(self) -> Iterable[Tuple[str, Committable]]:
yield from self.checkpointers.items()

def require_graph(self, operation: Optional[str] = None) -> DataHubGraph:
if not self.graph:
raise ConfigurationError(
f"{operation or 'This operation'} requires a graph, but none was provided. "
"To provide one, either use the datahub-rest sink or set the top-level datahub_api config in the recipe."
)
return self.graph
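
A minimal sketch of how a source can use the new helper; the MySource class and the operation string are hypothetical, while require_graph is the method added in the hunk above:

from datahub.ingestion.api.common import PipelineContext

class MySource:  # hypothetical source, for illustration only
    def __init__(self, ctx: PipelineContext) -> None:
        # require_graph either returns a usable DataHubGraph or raises
        # ConfigurationError with guidance on configuring a datahub_api.
        self.graph = ctx.require_graph(operation="MySource's lineage extraction")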
(changed file; path not captured, presumably metadata-ingestion/src/datahub/ingestion/source/csv_enricher.py)
@@ -129,11 +129,9 @@ def __init__(self, config: CSVEnricherConfig, ctx: PipelineContext):
# Map from entity urn to a list of SubResourceRow.
self.editable_schema_metadata_map: Dict[str, List[SubResourceRow]] = {}
self.should_overwrite: bool = self.config.write_semantics == "OVERRIDE"
if not self.should_overwrite and not self.ctx.graph:
raise ConfigurationError(
"With PATCH semantics, the csv-enricher source requires a datahub_api to connect to. "
"Consider using the datahub-rest sink or provide a datahub_api: configuration on your ingestion recipe."
)

if not self.should_overwrite:
self.ctx.require_graph(operation="The csv-enricher's PATCH semantics flag")

def get_resource_glossary_terms_work_unit(
self,
(changed file; path not captured, presumably metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py)
@@ -20,9 +20,8 @@
DBTCommonConfig,
DBTNode,
DBTSourceBase,
DBTTest,
DBTTestResult,
)
from datahub.ingestion.source.dbt.dbt_tests import DBTTest, DBTTestResult

logger = logging.getLogger(__name__)

278 changes: 13 additions & 265 deletions metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py
@@ -1,11 +1,10 @@
import json
import logging
import re
from abc import abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from enum import auto
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple

import pydantic
from pydantic import root_validator, validator
@@ -34,6 +33,12 @@
from datahub.ingestion.api.source import MetadataWorkUnitProcessor
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.common.subtypes import DatasetSubTypes
from datahub.ingestion.source.dbt.dbt_tests import (
DBTTest,
DBTTestResult,
make_assertion_from_test,
make_assertion_result_from_test,
)
from datahub.ingestion.source.sql.sql_types import (
ATHENA_SQL_TYPES_MAP,
BIGQUERY_TYPES_MAP,
@@ -81,20 +86,7 @@
TimeTypeClass,
)
from datahub.metadata.schema_classes import (
AssertionInfoClass,
AssertionResultClass,
AssertionResultTypeClass,
AssertionRunEventClass,
AssertionRunStatusClass,
AssertionStdAggregationClass,
AssertionStdOperatorClass,
AssertionStdParameterClass,
AssertionStdParametersClass,
AssertionStdParameterTypeClass,
AssertionTypeClass,
DataPlatformInstanceClass,
DatasetAssertionInfoClass,
DatasetAssertionScopeClass,
DatasetPropertiesClass,
GlobalTagsClass,
GlossaryTermsClass,
@@ -551,134 +543,6 @@ def get_column_type(
return SchemaFieldDataType(type=TypeClass())


@dataclass
class AssertionParams:
scope: Union[DatasetAssertionScopeClass, str]
operator: Union[AssertionStdOperatorClass, str]
aggregation: Union[AssertionStdAggregationClass, str]
parameters: Optional[Callable[[Dict[str, str]], AssertionStdParametersClass]] = None
logic_fn: Optional[Callable[[Dict[str, str]], Optional[str]]] = None


def _get_name_for_relationship_test(kw_args: Dict[str, str]) -> Optional[str]:
"""
Try to produce a useful string for the name of a relationship constraint.
Return None if we fail to do so.
"""
destination_ref = kw_args.get("to")
source_ref = kw_args.get("model")
column_name = kw_args.get("column_name")
dest_field_name = kw_args.get("field")
if not destination_ref or not source_ref or not column_name or not dest_field_name:
# base assertions are violated, bail early
return None
m = re.match(r"^ref\(\'(.*)\'\)$", destination_ref)
if m:
destination_table = m.group(1)
else:
destination_table = destination_ref
m = re.search(r"ref\(\'(.*)\'\)", source_ref)
if m:
source_table = m.group(1)
else:
source_table = source_ref
return f"{source_table}.{column_name} referential integrity to {destination_table}.{dest_field_name}"


@dataclass
class DBTTest:
qualified_test_name: str
column_name: Optional[str]
kw_args: dict

TEST_NAME_TO_ASSERTION_MAP: ClassVar[Dict[str, AssertionParams]] = {
"not_null": AssertionParams(
scope=DatasetAssertionScopeClass.DATASET_COLUMN,
operator=AssertionStdOperatorClass.NOT_NULL,
aggregation=AssertionStdAggregationClass.IDENTITY,
),
"unique": AssertionParams(
scope=DatasetAssertionScopeClass.DATASET_COLUMN,
operator=AssertionStdOperatorClass.EQUAL_TO,
aggregation=AssertionStdAggregationClass.UNIQUE_PROPOTION,
parameters=lambda _: AssertionStdParametersClass(
value=AssertionStdParameterClass(
value="1.0",
type=AssertionStdParameterTypeClass.NUMBER,
)
),
),
"accepted_values": AssertionParams(
scope=DatasetAssertionScopeClass.DATASET_COLUMN,
operator=AssertionStdOperatorClass.IN,
aggregation=AssertionStdAggregationClass.IDENTITY,
parameters=lambda kw_args: AssertionStdParametersClass(
value=AssertionStdParameterClass(
value=json.dumps(kw_args.get("values")),
type=AssertionStdParameterTypeClass.SET,
),
),
),
"relationships": AssertionParams(
scope=DatasetAssertionScopeClass.DATASET_COLUMN,
operator=AssertionStdOperatorClass._NATIVE_,
aggregation=AssertionStdAggregationClass.IDENTITY,
parameters=lambda kw_args: AssertionStdParametersClass(
value=AssertionStdParameterClass(
value=json.dumps(kw_args.get("values")),
type=AssertionStdParameterTypeClass.SET,
),
),
logic_fn=_get_name_for_relationship_test,
),
"dbt_expectations.expect_column_values_to_not_be_null": AssertionParams(
scope=DatasetAssertionScopeClass.DATASET_COLUMN,
operator=AssertionStdOperatorClass.NOT_NULL,
aggregation=AssertionStdAggregationClass.IDENTITY,
),
"dbt_expectations.expect_column_values_to_be_between": AssertionParams(
scope=DatasetAssertionScopeClass.DATASET_COLUMN,
operator=AssertionStdOperatorClass.BETWEEN,
aggregation=AssertionStdAggregationClass.IDENTITY,
parameters=lambda x: AssertionStdParametersClass(
minValue=AssertionStdParameterClass(
value=str(x.get("min_value", "unknown")),
type=AssertionStdParameterTypeClass.NUMBER,
),
maxValue=AssertionStdParameterClass(
value=str(x.get("max_value", "unknown")),
type=AssertionStdParameterTypeClass.NUMBER,
),
),
),
"dbt_expectations.expect_column_values_to_be_in_set": AssertionParams(
scope=DatasetAssertionScopeClass.DATASET_COLUMN,
operator=AssertionStdOperatorClass.IN,
aggregation=AssertionStdAggregationClass.IDENTITY,
parameters=lambda kw_args: AssertionStdParametersClass(
value=AssertionStdParameterClass(
value=json.dumps(kw_args.get("value_set")),
type=AssertionStdParameterTypeClass.SET,
),
),
),
}
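
# Example lookup (illustrative): DBTTest.TEST_NAME_TO_ASSERTION_MAP["accepted_values"]
# yields a column-scope IN assertion whose parameters callback JSON-encodes the
# configured values, so kw_args {"values": ["a", "b"]} produces the parameter
# value '["a", "b"]' typed as SET.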


@dataclass
class DBTTestResult:
invocation_id: str

status: str
execution_time: datetime

native_results: Dict[str, str]


def string_map(input_map: Dict[str, Any]) -> Dict[str, str]:
return {k: str(v) for k, v in input_map.items()}
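
# e.g. string_map({"min_value": 0, "values": ["a", "b"]})
#   == {"min_value": "0", "values": "['a', 'b']"}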


@platform_name("dbt")
@config_class(DBTCommonConfig)
@support_status(SupportStatus.CERTIFIED)
@@ -750,7 +614,7 @@ def create_test_entity_mcps(

for upstream_urn in sorted(upstream_urns):
if self.config.entities_enabled.can_emit_node_type("test"):
yield self._make_assertion_from_test(
yield make_assertion_from_test(
custom_props,
node,
assertion_urn,
@@ -759,133 +623,17 @@

if node.test_result:
if self.config.entities_enabled.can_emit_test_results:
yield self._make_assertion_result_from_test(
node, assertion_urn, upstream_urn
yield make_assertion_result_from_test(
node,
assertion_urn,
upstream_urn,
test_warnings_are_errors=self.config.test_warnings_are_errors,
)
else:
logger.debug(
f"Skipping test result {node.name} emission since it is turned off."
)

def _make_assertion_from_test(
self,
extra_custom_props: Dict[str, str],
node: DBTNode,
assertion_urn: str,
upstream_urn: str,
) -> MetadataWorkUnit:
assert node.test_info
qualified_test_name = node.test_info.qualified_test_name
column_name = node.test_info.column_name
kw_args = node.test_info.kw_args
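
# Three cases follow: (1) a known dbt / dbt_expectations test name is mapped to a
# typed assertion via DBTTest.TEST_NAME_TO_ASSERTION_MAP; (2) an unknown test with a
# column_name becomes a native column-scope assertion; (3) an unknown test without a
# column becomes a native row-scope assertion over the whole dataset.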

if qualified_test_name in DBTTest.TEST_NAME_TO_ASSERTION_MAP:
assertion_params = DBTTest.TEST_NAME_TO_ASSERTION_MAP[qualified_test_name]
assertion_info = AssertionInfoClass(
type=AssertionTypeClass.DATASET,
customProperties=extra_custom_props,
datasetAssertion=DatasetAssertionInfoClass(
dataset=upstream_urn,
scope=assertion_params.scope,
operator=assertion_params.operator,
fields=[
mce_builder.make_schema_field_urn(upstream_urn, column_name)
]
if (
assertion_params.scope
== DatasetAssertionScopeClass.DATASET_COLUMN
and column_name
)
else [],
nativeType=node.name,
aggregation=assertion_params.aggregation,
parameters=assertion_params.parameters(kw_args)
if assertion_params.parameters
else None,
logic=assertion_params.logic_fn(kw_args)
if assertion_params.logic_fn
else None,
nativeParameters=string_map(kw_args),
),
)
elif column_name:
# no match with known test types, column-level test
assertion_info = AssertionInfoClass(
type=AssertionTypeClass.DATASET,
customProperties=extra_custom_props,
datasetAssertion=DatasetAssertionInfoClass(
dataset=upstream_urn,
scope=DatasetAssertionScopeClass.DATASET_COLUMN,
operator=AssertionStdOperatorClass._NATIVE_,
fields=[
mce_builder.make_schema_field_urn(upstream_urn, column_name)
],
nativeType=node.name,
logic=node.compiled_code or node.raw_code,
aggregation=AssertionStdAggregationClass._NATIVE_,
nativeParameters=string_map(kw_args),
),
)
else:
# no match with known test types, default to row-level test
assertion_info = AssertionInfoClass(
type=AssertionTypeClass.DATASET,
customProperties=extra_custom_props,
datasetAssertion=DatasetAssertionInfoClass(
dataset=upstream_urn,
scope=DatasetAssertionScopeClass.DATASET_ROWS,
operator=AssertionStdOperatorClass._NATIVE_,
logic=node.compiled_code or node.raw_code,
nativeType=node.name,
aggregation=AssertionStdAggregationClass._NATIVE_,
nativeParameters=string_map(kw_args),
),
)

wu = MetadataChangeProposalWrapper(
entityUrn=assertion_urn,
aspect=assertion_info,
).as_workunit()

return wu

def _make_assertion_result_from_test(
self,
node: DBTNode,
assertion_urn: str,
upstream_urn: str,
) -> MetadataWorkUnit:
assert node.test_result
test_result = node.test_result

assertionResult = AssertionRunEventClass(
timestampMillis=int(test_result.execution_time.timestamp() * 1000.0),
assertionUrn=assertion_urn,
asserteeUrn=upstream_urn,
runId=test_result.invocation_id,
result=AssertionResultClass(
type=AssertionResultTypeClass.SUCCESS
if test_result.status == "pass"
or (
not self.config.test_warnings_are_errors
and test_result.status == "warn"
)
else AssertionResultTypeClass.FAILURE,
nativeResults=test_result.native_results,
),
status=AssertionRunStatusClass.COMPLETE,
)
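
# Result mapping, for reference:
#   status == "pass"                              -> SUCCESS
#   status == "warn", warnings not errors         -> SUCCESS
#   status == "warn", test_warnings_are_errors    -> FAILURE
#   anything else (e.g. "fail", "error")          -> FAILURE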

event = MetadataChangeProposalWrapper(
entityUrn=assertion_urn,
aspect=assertionResult,
)
wu = MetadataWorkUnit(
id=f"{assertion_urn}-assertionRunEvent-{upstream_urn}",
mcp=event,
)
return wu

@abstractmethod
def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]:
# return dbt nodes + global custom properties
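
The new dbt_tests.py module itself is not rendered in this view, but its public surface can be inferred from the imports and call sites above. A sketch of the expected interface (signatures inferred from this diff, not copied from the file):

from dataclasses import dataclass
from datetime import datetime
from typing import Dict, Optional

@dataclass
class DBTTest:
    qualified_test_name: str
    column_name: Optional[str]
    kw_args: dict
    # plus the TEST_NAME_TO_ASSERTION_MAP ClassVar moved out of dbt_common.py

@dataclass
class DBTTestResult:
    invocation_id: str
    status: str
    execution_time: datetime
    native_results: Dict[str, str]

def make_assertion_from_test(
    extra_custom_props: Dict[str, str],
    node: "DBTNode",  # string annotations avoid a circular import with dbt_common
    assertion_urn: str,
    upstream_urn: str,
) -> "MetadataWorkUnit": ...

def make_assertion_result_from_test(
    node: "DBTNode",
    assertion_urn: str,
    upstream_urn: str,
    test_warnings_are_errors: bool,
) -> "MetadataWorkUnit": ...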
(changed file; path not captured, presumably metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py)
@@ -26,9 +26,8 @@
DBTNode,
DBTSourceBase,
DBTSourceReport,
DBTTest,
DBTTestResult,
)
from datahub.ingestion.source.dbt.dbt_tests import DBTTest, DBTTestResult

logger = logging.getLogger(__name__)

(sixth file diff not loaded; presumably the new metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_tests.py)
