Skip to content

Commit

Permalink
Add validation rule for SCD query with no time
Browse files Browse the repository at this point in the history
This commit adds a new step to `MetricTimeQueryValidationRule` which
does the following:
1. Selects all the existing SCDs for the queried metric
2. Matches them against the spec pattern of each group-by input
3. Raises a `ScdRequiresMetricTimeIssue` if no `metric_time` was provided
  and there were matches

To accomplish step 1, I had to create a new `LinkableElementProperty`
called `SCD_HOP`. This new property indicates that the join path to the
linkable element goes through an SCD at some point.

I changed the `ValidLinkableSpecResolver` to add `SCD_HOP` to the
properties of every element it finds whenever that element belongs to
an SCD or the join path to it contains an SCD.
  • Loading branch information
serramatutu committed Oct 11, 2024
1 parent d4ecffe commit 3ccc016
Show file tree
Hide file tree
Showing 11 changed files with 281 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class LinkableElementProperty(Enum):
METRIC = "metric"
# A time dimension with a DatePart.
DATE_PART = "date_part"
# A linkable element that is itself part of an SCD model, or a linkable element that gets joined through another SCD model.
SCD_HOP = "scd_hop"

@staticmethod
def all_properties() -> FrozenSet[LinkableElementProperty]: # noqa: D102
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ def __init__(

logger.debug(LazyFormat(lambda: f"Building valid group-by-item indexes took: {time.time() - start_time:.2f}s"))

def _is_semantic_model_scd(self, semantic_model: SemanticModel) -> bool:
    """Tell whether the given semantic model is backed by an SCD table.

    A model counts as an SCD as soon as one of its dimensions declares
    validity params; a model with no dimensions is never an SCD.
    """
    for dimension in semantic_model.dimensions:
        if dimension.validity_params is not None:
            return True
    return False

def _generate_linkable_time_dimensions(
self,
semantic_model_origin: SemanticModelReference,
Expand Down Expand Up @@ -294,6 +298,8 @@ def get_joinable_metrics_for_semantic_model(
necessary.
"""
properties = frozenset({LinkableElementProperty.METRIC, LinkableElementProperty.JOINED})
if self._is_semantic_model_scd(semantic_model):
properties = properties.union({LinkableElementProperty.SCD_HOP})

join_path_has_path_links = len(using_join_path.path_elements) > 0
if join_path_has_path_links:
Expand Down Expand Up @@ -331,8 +337,15 @@ def _get_elements_in_semantic_model(self, semantic_model: SemanticModel) -> Link
Elements related to metric_time are handled separately in _get_metric_time_elements().
Linkable metrics are not considered local to the semantic model since they always require a join.
"""
semantic_model_is_scd = self._is_semantic_model_scd(semantic_model)

linkable_dimensions = []
linkable_entities = []

entity_properties = frozenset({LinkableElementProperty.LOCAL, LinkableElementProperty.ENTITY})
if semantic_model_is_scd:
entity_properties = entity_properties.union({LinkableElementProperty.SCD_HOP})

for entity in semantic_model.entities:
linkable_entities.append(
LinkableEntity.create(
Expand All @@ -342,7 +355,7 @@ def _get_elements_in_semantic_model(self, semantic_model: SemanticModel) -> Link
join_path=SemanticModelJoinPath(
left_semantic_model_reference=semantic_model.reference,
),
properties=frozenset({LinkableElementProperty.LOCAL, LinkableElementProperty.ENTITY}),
properties=entity_properties,
)
)
for entity_link in self._semantic_model_lookup.entity_links_for_local_elements(semantic_model):
Expand All @@ -357,12 +370,15 @@ def _get_elements_in_semantic_model(self, semantic_model: SemanticModel) -> Link
join_path=SemanticModelJoinPath(
left_semantic_model_reference=semantic_model.reference,
),
properties=frozenset({LinkableElementProperty.LOCAL, LinkableElementProperty.ENTITY}),
properties=entity_properties,
)
)

dimension_properties = frozenset({LinkableElementProperty.LOCAL})
if semantic_model_is_scd:
dimension_properties = dimension_properties.union({LinkableElementProperty.SCD_HOP})

for entity_link in self._semantic_model_lookup.entity_links_for_local_elements(semantic_model):
dimension_properties = frozenset({LinkableElementProperty.LOCAL})
for dimension in semantic_model.dimensions:
dimension_type = dimension.type
if dimension_type is DimensionType.CATEGORICAL:
Expand Down Expand Up @@ -464,6 +480,7 @@ def _get_metric_time_elements(self, measure_reference: Optional[MeasureReference
defined_granularity: Optional[ExpandedTimeGranularity] = None
if measure_reference:
measure_semantic_model = self._get_semantic_model_for_measure(measure_reference)
semantic_model_is_scd = self._is_semantic_model_scd(measure_semantic_model)
measure_agg_time_dimension_reference = measure_semantic_model.checked_agg_time_dimension_for_measure(
measure_reference=measure_reference
)
Expand All @@ -476,6 +493,7 @@ def _get_metric_time_elements(self, measure_reference: Optional[MeasureReference
# If querying metric_time without metrics, will query from time spines.
# Defaults to DAY granularity if available in time spines, else smallest available granularity.
min_granularity = min(self._time_spine_sources.keys())
semantic_model_is_scd = False
possible_metric_time_granularities = tuple(
ExpandedTimeGranularity.from_time_granularity(time_granularity)
for time_granularity in TimeGranularity
Expand Down Expand Up @@ -506,6 +524,8 @@ def _get_metric_time_elements(self, measure_reference: Optional[MeasureReference
properties.add(LinkableElementProperty.DERIVED_TIME_GRANULARITY)
if date_part:
properties.add(LinkableElementProperty.DATE_PART)
if semantic_model_is_scd:
properties.add(LinkableElementProperty.SCD_HOP)
linkable_dimension = LinkableDimension.create(
defined_in_semantic_model=measure_semantic_model.reference if measure_semantic_model else None,
element_name=MetricFlowReservedKeywords.METRIC_TIME.value,
Expand Down Expand Up @@ -717,12 +737,25 @@ def create_linkable_element_set_from_join_path(
join_path: SemanticModelJoinPath,
) -> LinkableElementSet:
"""Given the current path, generate the respective linkable elements from the last semantic model in the path."""
semantic_model = self._semantic_model_lookup.get_by_reference(join_path.last_semantic_model_reference)
assert (
semantic_model
), f"Semantic model {join_path.last_semantic_model_reference.semantic_model_name} is in join path but does not exist in SemanticModelLookup"

properties = frozenset({LinkableElementProperty.JOINED})
if len(join_path.path_elements) > 1:
properties = properties.union({LinkableElementProperty.MULTI_HOP})

semantic_model = self._semantic_model_lookup.get_by_reference(join_path.last_semantic_model_reference)
assert semantic_model
# If any of the semantic models in the join path is an SCD, add SCD_HOP
for reference_to_derived_model in join_path.derived_from_semantic_models:
derived_model = self._semantic_model_lookup.get_by_reference(reference_to_derived_model)
assert (
derived_model
), f"Semantic model {reference_to_derived_model.semantic_model_name} is in join path but does not exist in SemanticModelLookup"

if self._is_semantic_model_scd(derived_model):
properties = properties.union({LinkableElementProperty.SCD_HOP})
break

linkable_dimensions: List[LinkableDimension] = []
linkable_entities: List[LinkableEntity] = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from metricflow_semantics.model.semantics.semantic_model_join_evaluator import MAX_JOIN_HOPS
from metricflow_semantics.model.semantics.semantic_model_lookup import SemanticModelLookup
from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.time.granularity import ExpandedTimeGranularity

Expand Down Expand Up @@ -256,3 +257,15 @@ def _get_min_queryable_time_granularity(self, metric_reference: MetricReference)
minimum_queryable_granularity = defined_time_granularity

return minimum_queryable_granularity

def get_joinable_scd_specs_for_metric(self, metric_reference: MetricReference) -> Sequence[LinkableInstanceSpec]:
    """Get the specs of the metric's linkable elements that are tagged with `SCD_HOP`.

    Args:
        metric_reference: The metric whose linkable elements are inspected.

    Returns:
        The specs of every linkable element of the metric that carries
        `LinkableElementProperty.SCD_HOP`.
    """
    # Deliberately not named `filter`, which would shadow the builtin.
    scd_hop_filter = LinkableElementFilter(
        with_any_of=frozenset({LinkableElementProperty.SCD_HOP}),
    )
    scd_elements = self.linkable_elements_for_metrics(
        metric_references=(metric_reference,),
        element_set_filter=scd_hop_filter,
    )

    return scd_elements.specs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass

from typing_extensions import override

from metricflow_semantics.query.group_by_item.resolution_path import MetricFlowQueryResolutionPath
from metricflow_semantics.query.issues.issues_base import (
MetricFlowQueryIssueType,
MetricFlowQueryResolutionIssue,
)
from metricflow_semantics.query.resolver_inputs.base_resolver_inputs import MetricFlowQueryResolverInput
from metricflow_semantics.specs.instance_spec import InstanceSpec


@dataclass(frozen=True)
class ScdRequiresMetricTimeIssue(MetricFlowQueryResolutionIssue):
    """Describes an issue with a query that groups by SCD items but does not include metric_time."""

    # Group-by specs in the query that are SCDs or contain SCDs in their join path.
    scds_in_query: Sequence[InstanceSpec]

    @override
    def ui_description(self, associated_input: MetricFlowQueryResolverInput) -> str:
        """Return the user-facing message, listing the qualified name of each spec in `scds_in_query`.

        `associated_input` is accepted to satisfy the overridden signature; this message does not use it.
        """
        dim_str = ", ".join(scd.qualified_name for scd in self.scds_in_query)
        return (
            "Your query contains the group bys which are SCDs or contain SCDs in the "
            f"join path: [{dim_str}].\n\nA query containing SCDs must also contain the "
            "metric_time dimension in order to join the SCD table to the valid time "
            "range. Please add metric_time to the query and try again. If you're "
            "using agg_time_dimension, use metric_time instead."
        )

    @override
    def with_path_prefix(self, path_prefix: MetricFlowQueryResolutionPath) -> ScdRequiresMetricTimeIssue:
        """Return a copy of this issue whose resolution path is prefixed with `path_prefix`."""
        return ScdRequiresMetricTimeIssue(
            issue_type=self.issue_type,
            parent_issues=self.parent_issues,
            query_resolution_path=self.query_resolution_path.with_path_prefix(path_prefix),
            scds_in_query=self.scds_in_query,
        )

    @staticmethod
    def from_parameters(
        scds_in_query: Sequence[InstanceSpec], query_resolution_path: MetricFlowQueryResolutionPath
    ) -> ScdRequiresMetricTimeIssue:
        """Create an ERROR-level issue with no parent issues for the given specs and resolution path.

        The specs are copied into a tuple so that the frozen instance neither aliases a list the
        caller may keep mutating nor becomes unhashable because one of its fields is a list.
        """
        return ScdRequiresMetricTimeIssue(
            issue_type=MetricFlowQueryIssueType.ERROR,
            parent_issues=(),
            query_resolution_path=query_resolution_path,
            scds_in_query=tuple(scds_in_query),
        )
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Sequence

from dbt_semantic_interfaces.protocols import WhereFilterIntersection
from dbt_semantic_interfaces.references import MetricReference
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

from dataclasses import dataclass
from functools import lru_cache
from typing import Sequence
from typing import List, Sequence

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
from dbt_semantic_interfaces.protocols import WhereFilterIntersection
from dbt_semantic_interfaces.references import MetricReference, TimeDimensionReference
from dbt_semantic_interfaces.references import (
MetricReference,
TimeDimensionReference,
)
from dbt_semantic_interfaces.type_enums import MetricType
from typing_extensions import override

Expand All @@ -20,15 +23,22 @@
from metricflow_semantics.query.issues.parsing.offset_metric_requires_metric_time import (
OffsetMetricRequiresMetricTimeIssue,
)
from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery
from metricflow_semantics.query.issues.parsing.scd_requires_metric_time import (
ScdRequiresMetricTimeIssue,
)
from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import (
ResolverInputForQuery,
)
from metricflow_semantics.query.validation_rules.base_validation_rule import PostResolutionQueryValidationRule
from metricflow_semantics.specs.instance_spec import InstanceSpec
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec


@dataclass(frozen=True)
class QueryItemsAnalysis:
    """Immutable summary of which kinds of group-by items a query contains."""

    # Specs among the metric's joinable SCD specs that the query's group-by inputs matched.
    scds: Sequence[InstanceSpec]

    # True when a group-by input matches one of the metric_time specs.
    has_metric_time: bool

    # True when a group-by input matches a valid agg time dimension spec for the metric.
    has_agg_time_dimension: bool

Expand All @@ -39,7 +49,8 @@ class MetricTimeQueryValidationRule(PostResolutionQueryValidationRule):
Currently, known cases are:
* Cumulative metrics.
* Derived metrics with an offset time.g
* Derived metrics with an offset time.
* Slowly changing dimensions
"""

def __init__( # noqa: D107
Expand All @@ -61,18 +72,26 @@ def _get_query_items_analysis(
) -> QueryItemsAnalysis:
has_agg_time_dimension = False
has_metric_time = False
scds: List[InstanceSpec] = []

valid_agg_time_dimension_specs = self._manifest_lookup.metric_lookup.get_valid_agg_time_dimensions_for_metric(
metric_reference
)

scd_specs = self._manifest_lookup.metric_lookup.get_joinable_scd_specs_for_metric(metric_reference)

for group_by_item_input in query_resolver_input.group_by_item_inputs:
if group_by_item_input.spec_pattern.matches_any(self._metric_time_specs):
has_metric_time = True

if group_by_item_input.spec_pattern.matches_any(valid_agg_time_dimension_specs):
has_agg_time_dimension = True

scd_matches = group_by_item_input.spec_pattern.match(scd_specs)
scds.extend(scd_matches)

return QueryItemsAnalysis(
scds=scds,
has_metric_time=has_metric_time,
has_agg_time_dimension=has_agg_time_dimension,
)
Expand All @@ -89,6 +108,16 @@ def validate_metric_in_resolution_dag(

issues = MetricFlowQueryResolutionIssueSet.empty_instance()

# Queries that join to an SCD don't support direct references to agg_time_dimension, so we
# only check for metric_time. If we decide to support agg_time_dimension, we should also check for it here.
if len(query_items_analysis.scds) > 0 and not query_items_analysis.has_metric_time:
issues = issues.add_issue(
ScdRequiresMetricTimeIssue.from_parameters(
scds_in_query=query_items_analysis.scds,
query_resolution_path=resolution_path,
)
)

if metric.type is MetricType.CUMULATIVE:
if (
metric.type_params is not None
Expand Down Expand Up @@ -134,6 +163,3 @@ def validate_query_in_resolution_dag(
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
return MetricFlowQueryResolutionIssueSet.empty_instance()


__all__ = ["MetricTimeQueryValidationRule"]
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def test_create_linkable_element_set_from_join_path_multi_hop( # noqa: D103
left_semantic_model_reference=SemanticModelReference("views_source"),
path_elements=(
SemanticModelJoinPathElement(
semantic_model_reference=SemanticModelReference("bookings"),
semantic_model_reference=SemanticModelReference("bookings_source"),
join_on_entity=EntityReference("guest"),
),
SemanticModelJoinPathElement(
Expand Down
Loading

0 comments on commit 3ccc016

Please sign in to comment.