diff --git a/metrics_layer/cli/seeding.py b/metrics_layer/cli/seeding.py index 381f011..df8dc68 100644 --- a/metrics_layer/cli/seeding.py +++ b/metrics_layer/cli/seeding.py @@ -369,6 +369,8 @@ def make_fields(self, column_data, schema_name: str, table_name: str, auto_tag_s Definitions.duck_db, Definitions.postgres, Definitions.redshift, + Definitions.sql_server, + Definitions.azure_synapse, }: column_name = '"' + row["COLUMN_NAME"] + '"' else: diff --git a/metrics_layer/core/sql/query_generator.py b/metrics_layer/core/sql/query_generator.py index ecee606..5fea256 100644 --- a/metrics_layer/core/sql/query_generator.py +++ b/metrics_layer/core/sql/query_generator.py @@ -14,6 +14,7 @@ from metrics_layer.core.sql.query_dialect import NullSorting, query_lookup from metrics_layer.core.sql.query_errors import ArgumentError from metrics_layer.core.sql.query_filter import MetricsLayerFilter +from metrics_layer.core.utils import flatten_filters class MetricsLayerQuery(MetricsLayerQueryBase): @@ -80,7 +81,7 @@ def parse_definition(self, definition: dict): # them as CTE's for the appropriate filters self.non_additive_ctes = [] metrics_in_select = definition.get("metrics", []) - metrics_in_having = [h.field.id() for h in self.having_filters if h.field] + metrics_in_having = [h["field"] for h in flatten_filters(having)] for metric in metrics_in_select + metrics_in_having: metric_field = self.design.get_field(metric) for ref_field in [metric_field] + metric_field.referenced_fields(metric_field.sql): @@ -385,7 +386,7 @@ def _non_additive_cte(self, definition: dict, group_by_dimensions: list): field_lookup[non_additive_dimension.id()] = non_additive_dimension # We also need to make all fields in the where clause available to the query - for f in self.where: + for f in flatten_filters(self.where): field = self.design.get_field(f["field"]) field_lookup[field.id()] = field diff --git a/metrics_layer/core/sql/single_query_resolve.py b/metrics_layer/core/sql/single_query_resolve.py index 8b17f04..8438af3 100644 --- a/metrics_layer/core/sql/single_query_resolve.py +++ b/metrics_layer/core/sql/single_query_resolve.py @@ -4,6 +4,7 @@ from metrics_layer.core.sql.query_design import MetricsLayerDesign from metrics_layer.core.sql.query_funnel import FunnelQuery from metrics_layer.core.sql.query_generator import MetricsLayerQuery +from metrics_layer.core.utils import flatten_filters class SingleSQLQueryResolver: @@ -232,21 +233,7 @@ def parse_identifiers_from_dicts(conditions: list): @staticmethod def flatten_filters(filters: list): - flat_list = [] - - def recurse(filter_obj): - if isinstance(filter_obj, dict): - if "conditions" in filter_obj: - for f in filter_obj["conditions"]: - recurse(f) - else: - flat_list.append(filter_obj) - elif isinstance(filter_obj, list): - for item in filter_obj: - recurse(item) - - recurse(filters) - return flat_list + return flatten_filters(filters) @staticmethod def _check_for_dict(conditions: list): diff --git a/metrics_layer/core/utils.py b/metrics_layer/core/utils.py index d053722..89d1dd4 100644 --- a/metrics_layer/core/utils.py +++ b/metrics_layer/core/utils.py @@ -14,3 +14,21 @@ def generate_random_password(length): letters = string.ascii_letters result_str = "".join(random.choice(letters) for i in range(length)) return result_str + + +def flatten_filters(filters: list): + flat_list = [] + + def recurse(filter_obj): + if isinstance(filter_obj, dict): + if "conditions" in filter_obj: + for f in filter_obj["conditions"]: + recurse(f) + else: + flat_list.append(filter_obj) + elif isinstance(filter_obj, list): + for item in filter_obj: + recurse(item) + + recurse(filters) + return flat_list diff --git a/tests/test_cli.py b/tests/test_cli.py index c14a031..0f7972d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -176,6 +176,8 @@ def yaml_dump_assert(slf, data, file): Definitions.postgres, Definitions.trino, Definitions.redshift, + Definitions.sql_server, + Definitions.azure_synapse, }: assert social["sql"] == '${TABLE}."ON_SOCIAL_NETWORK"' else: @@ -195,6 +197,8 @@ def yaml_dump_assert(slf, data, file): Definitions.postgres, Definitions.trino, Definitions.redshift, + Definitions.sql_server, + Definitions.azure_synapse, }: assert acq_date["sql"] == '${TABLE}."ACQUISITION_DATE"' else: @@ -231,6 +235,8 @@ def yaml_dump_assert(slf, data, file): Definitions.postgres, Definitions.trino, Definitions.redshift, + Definitions.sql_server, + Definitions.azure_synapse, }: assert date["sql"] == '${TABLE}."ORDER_CREATED_AT"' else: @@ -244,6 +250,8 @@ def yaml_dump_assert(slf, data, file): Definitions.postgres, Definitions.trino, Definitions.redshift, + Definitions.sql_server, + Definitions.azure_synapse, }: assert new["sql"] == '${TABLE}."NEW_VS_REPEAT"' else: @@ -257,6 +265,8 @@ def yaml_dump_assert(slf, data, file): Definitions.postgres, Definitions.trino, Definitions.redshift, + Definitions.sql_server, + Definitions.azure_synapse, }: assert num["sql"] == '${TABLE}."REVENUE"' else: @@ -307,6 +317,8 @@ def yaml_dump_assert(slf, data, file): Definitions.postgres, Definitions.trino, Definitions.redshift, + Definitions.sql_server, + Definitions.azure_synapse, }: assert cross_sell["sql"] == '${TABLE}."@CRoSSell P-roduct:"' else: @@ -344,6 +356,8 @@ def yaml_dump_assert(slf, data, file): Definitions.postgres, Definitions.trino, Definitions.redshift, + Definitions.sql_server, + Definitions.azure_synapse, }: assert date["sql"] == '${TABLE}."SESSION_DATE"' else: @@ -357,6 +371,8 @@ def yaml_dump_assert(slf, data, file): Definitions.postgres, Definitions.trino, Definitions.redshift, + Definitions.sql_server, + Definitions.azure_synapse, }: assert pk["sql"] == '${TABLE}."SESSION_ID"' else: @@ -370,6 +386,8 @@ def yaml_dump_assert(slf, data, file): Definitions.postgres, Definitions.trino, Definitions.redshift, + Definitions.sql_server, + Definitions.azure_synapse, }: assert num["sql"] == '${TABLE}."CONVERSION"' else: diff --git a/tests/test_non_additive_dimensions.py b/tests/test_non_additive_dimensions.py index 4d66c68..5ec78f9 100644 --- a/tests/test_non_additive_dimensions.py +++ b/tests/test_non_additive_dimensions.py @@ -672,3 +672,80 @@ def test_mrr_non_additive_dimension_merged_result_sub_join_where(connection): " mrr_record__cte_subquery_0.mrr_record_date=z_customer_accounts_created__cte_subquery_1.z_customer_accounts_created_date;" # noqa ) assert query == correct + + +@pytest.mark.query +def test_mrr_non_additive_dimension_or_filters_with_select(connection): + query = connection.get_sql_query( + metrics=["mrr.mrr_end_of_month"], + dimensions=[], + where=[ + { + "conditional_filter_logic": { + "conditions": [ + {"field": "mrr.plan_name", "expression": "equal_to", "value": "Enterprise"} + ], + "logical_operator": "AND", + } + }, + ], + having=[ + { + "conditional_filter_logic": { + "conditions": [ + {"field": "number_of_billed_accounts", "expression": "greater_than", "value": 1100} + ], + "logical_operator": "AND", + } + } + ], + ) + + correct = ( + "WITH cte_mrr_end_of_month_record_raw AS (SELECT MAX(mrr.record_date) as mrr_max_record_raw FROM" + " analytics.mrr_by_customer mrr WHERE mrr.plan_name='Enterprise' ORDER BY mrr_max_record_raw DESC" + " NULLS LAST) SELECT SUM(case when mrr.record_date=cte_mrr_end_of_month_record_raw.mrr_max_record_raw" + " then mrr.mrr else 0 end) as mrr_mrr_end_of_month FROM analytics.mrr_by_customer mrr LEFT JOIN" + " cte_mrr_end_of_month_record_raw ON 1=1 WHERE mrr.plan_name='Enterprise' HAVING" + " COUNT(mrr.parent_account_id)>1100 ORDER BY mrr_mrr_end_of_month DESC NULLS LAST;" + ) + assert query == correct + + +@pytest.mark.query +def test_mrr_non_additive_dimension_or_filters(connection): + query = connection.get_sql_query( + metrics=["number_of_billed_accounts"], + dimensions=[], + where=[ + { + "conditional_filter_logic": { + "conditions": [ + {"field": "mrr.plan_name", "expression": "equal_to", "value": "Enterprise"} + ], + "logical_operator": "AND", + } + }, + ], + having=[ + { + "conditional_filter_logic": { + "conditions": [ + {"field": "mrr.mrr_end_of_month", "expression": "greater_than", "value": 1100} + ], + "logical_operator": "AND", + } + } + ], + ) + + correct = ( + "WITH cte_mrr_end_of_month_record_raw AS (SELECT MAX(mrr.record_date) as mrr_max_record_raw FROM" + " analytics.mrr_by_customer mrr WHERE mrr.plan_name='Enterprise' ORDER BY mrr_max_record_raw DESC" + " NULLS LAST) SELECT COUNT(mrr.parent_account_id) as mrr_number_of_billed_accounts FROM" + " analytics.mrr_by_customer mrr LEFT JOIN cte_mrr_end_of_month_record_raw ON 1=1 WHERE" + " mrr.plan_name='Enterprise' HAVING SUM(case when" + " mrr.record_date=cte_mrr_end_of_month_record_raw.mrr_max_record_raw then mrr.mrr else 0 end)>1100" + " ORDER BY mrr_number_of_billed_accounts DESC NULLS LAST;" + ) + assert query == correct