diff --git a/.changes/unreleased/Features-20240905-180956.yaml b/.changes/unreleased/Features-20240905-180956.yaml new file mode 100644 index 000000000..6ca843c43 --- /dev/null +++ b/.changes/unreleased/Features-20240905-180956.yaml @@ -0,0 +1,7 @@ +kind: Features +body: Add EventTimeFilter to BaseRelation, which renders a filtered relation when + start or end is set +time: 2024-09-05T18:09:56.159385-04:00 +custom: + Author: 'michelleark QMalcolm' + Issue: "294" diff --git a/dbt/adapters/base/relation.py b/dbt/adapters/base/relation.py index 1aab7b2fe..0053265f0 100644 --- a/dbt/adapters/base/relation.py +++ b/dbt/adapters/base/relation.py @@ -1,5 +1,6 @@ from collections.abc import Hashable from dataclasses import dataclass, field +from datetime import datetime from typing import ( Any, Dict, @@ -36,6 +37,13 @@ SerializableIterable = Union[Tuple, FrozenSet] +@dataclass(frozen=True, eq=False, repr=False) +class EventTimeFilter(FakeAPIObject, Hashable): + field_name: str + start: Optional[datetime] = None + end: Optional[datetime] = None + + @dataclass(frozen=True, eq=False, repr=False) class BaseRelation(FakeAPIObject, Hashable): path: Path @@ -47,6 +55,7 @@ class BaseRelation(FakeAPIObject, Hashable): quote_policy: Policy = field(default_factory=lambda: Policy()) dbt_created: bool = False limit: Optional[int] = None + event_time_filter: Optional[EventTimeFilter] = None require_alias: bool = ( True # used to govern whether to add an alias when render_limited is called ) @@ -208,14 +217,19 @@ def render(self) -> str: # if there is nothing set, this will return the empty string. return ".".join(part for _, part in self._render_iterator() if part is not None) - def _render_limited_alias(self) -> str: + def _render_subquery_alias(self, namespace: str) -> str: """Some databases require an alias for subqueries (postgres, mysql) for all others we want to avoid adding an alias as it has the potential to introduce issues with the query if the user also defines an alias. """ if self.require_alias: - return f" _dbt_limit_subq_{self.table}" + return f" _dbt_{namespace}_subq_{self.table}" return "" + def _render_limited_alias( + self, + ) -> str: + return self._render_subquery_alias(namespace="limit") + def render_limited(self) -> str: rendered = self.render() if self.limit is None: @@ -225,6 +239,31 @@ def render_limited(self) -> str: else: return f"(select * from {rendered} limit {self.limit}){self._render_limited_alias()}" + def render_event_time_filtered(self, rendered: Optional[str] = None) -> str: + rendered = rendered or self.render() + if self.event_time_filter is None: + return rendered + + filter = self._render_event_time_filtered(self.event_time_filter) + if not filter: + return rendered + + return f"(select * from {rendered} where {filter}){self._render_subquery_alias(namespace='et_filter')}" + + def _render_event_time_filtered(self, event_time_filter: EventTimeFilter) -> str: + """ + Returns "" if start and end are both None + """ + filter = "" + if event_time_filter.start and event_time_filter.end: + filter = f"{event_time_filter.field_name} >= '{event_time_filter.start}' and {event_time_filter.field_name} < '{event_time_filter.end}'" + elif event_time_filter.start: + filter = f"{event_time_filter.field_name} >= '{event_time_filter.start}'" + elif event_time_filter.end: + filter = f"{event_time_filter.field_name} < '{event_time_filter.end}'" + + return filter + def quoted(self, identifier): return "{quote_char}{identifier}{quote_char}".format( quote_char=self.quote_character, @@ -240,6 +279,7 @@ def create_ephemeral_from( cls: Type[Self], relation_config: RelationConfig, limit: Optional[int] = None, + event_time_filter: Optional[EventTimeFilter] = None, ) -> Self: # Note that ephemeral models are based on the identifier, which will # point to the model's alias if one exists and otherwise fall back to @@ -250,6 +290,7 @@ def create_ephemeral_from( type=cls.CTE, identifier=identifier, limit=limit, + event_time_filter=event_time_filter, ).quote(identifier=False) @classmethod @@ -315,7 +356,14 @@ def __hash__(self) -> int: return hash(self.render()) def __str__(self) -> str: - return self.render() if self.limit is None else self.render_limited() + rendered = self.render() if self.limit is None else self.render_limited() + + # Limited subquery is wrapped by the event time filter subquery, and not the other way around. + # This is because in the context of resolving limited refs, we care more about performance than reliably producing a sample of a certain size. + if self.event_time_filter: + rendered = self.render_event_time_filtered(rendered) + + return rendered @property def database(self) -> Optional[str]: diff --git a/tests/unit/test_relation.py b/tests/unit/test_relation.py index 97d564192..6d835e0d2 100644 --- a/tests/unit/test_relation.py +++ b/tests/unit/test_relation.py @@ -1,8 +1,9 @@ from dataclasses import dataclass, replace - +from datetime import datetime import pytest from dbt.adapters.base import BaseRelation +from dbt.adapters.base.relation import EventTimeFilter from dbt.adapters.contracts.relation import RelationType @@ -81,6 +82,80 @@ def test_render_limited(limit, require_alias, expected_result): assert str(my_relation) == expected_result +@pytest.mark.parametrize( + "event_time_filter,require_alias,expected_result", + [ + (None, False, '"test_database"."test_schema"."test_identifier"'), + ( + EventTimeFilter(field_name="column"), + False, + '"test_database"."test_schema"."test_identifier"', + ), + (None, True, '"test_database"."test_schema"."test_identifier"'), + ( + EventTimeFilter(field_name="column"), + True, + '"test_database"."test_schema"."test_identifier"', + ), + ( + EventTimeFilter(field_name="column", start=datetime(year=2020, month=1, day=1)), + False, + """(select * from "test_database"."test_schema"."test_identifier" where column >= '2020-01-01 00:00:00')""", + ), + ( + EventTimeFilter(field_name="column", start=datetime(year=2020, month=1, day=1)), + True, + """(select * from "test_database"."test_schema"."test_identifier" where column >= '2020-01-01 00:00:00') _dbt_et_filter_subq_test_identifier""", + ), + ( + EventTimeFilter(field_name="column", end=datetime(year=2020, month=1, day=1)), + False, + """(select * from "test_database"."test_schema"."test_identifier" where column < '2020-01-01 00:00:00')""", + ), + ( + EventTimeFilter( + field_name="column", + start=datetime(year=2020, month=1, day=1), + end=datetime(year=2020, month=1, day=2), + ), + False, + """(select * from "test_database"."test_schema"."test_identifier" where column >= '2020-01-01 00:00:00' and column < '2020-01-02 00:00:00')""", + ), + ], +) +def test_render_event_time_filtered(event_time_filter, require_alias, expected_result): + my_relation = BaseRelation.create( + database="test_database", + schema="test_schema", + identifier="test_identifier", + event_time_filter=event_time_filter, + require_alias=require_alias, + ) + actual_result = my_relation.render_event_time_filtered() + assert actual_result == expected_result + assert str(my_relation) == expected_result + + +def test_render_event_time_filtered_and_limited(): + my_relation = BaseRelation.create( + database="test_database", + schema="test_schema", + identifier="test_identifier", + event_time_filter=EventTimeFilter( + field_name="column", + start=datetime(year=2020, month=1, day=1), + end=datetime(year=2020, month=1, day=2), + ), + limit=0, + require_alias=False, + ) + expected_result = """(select * from (select * from "test_database"."test_schema"."test_identifier" where false limit 0) where column >= '2020-01-01 00:00:00' and column < '2020-01-02 00:00:00')""" + + actual_result = my_relation.render_event_time_filtered(my_relation.render_limited()) + assert actual_result == expected_result + assert str(my_relation) == expected_result + + def test_create_ephemeral_from_uses_identifier(): @dataclass class Node: