From 21cc1de574a6c38ffa906b6259a5b19401fdc81a Mon Sep 17 00:00:00 2001 From: Paul Blankley Date: Thu, 19 Sep 2024 08:13:31 +0100 Subject: [PATCH] fix issue with merged results filter application + include test for pure organics (remove before merge) --- metrics_layer/core/model/base.py | 9 +++ .../core/sql/merged_query_resolve.py | 2 +- .../sql/query_arbitrary_merged_queries.py | 2 +- metrics_layer/core/sql/query_base.py | 25 +++----- metrics_layer/core/sql/query_filter.py | 54 ++++++++++++++-- tests/test_arbitrary_merged_results.py | 6 +- tests/test_join_query.py | 64 +++++++++++++++++++ 7 files changed, 137 insertions(+), 25 deletions(-) diff --git a/metrics_layer/core/model/base.py b/metrics_layer/core/model/base.py index 4197272..9da1ec1 100644 --- a/metrics_layer/core/model/base.py +++ b/metrics_layer/core/model/base.py @@ -2,6 +2,8 @@ import re from typing import List +from metrics_layer.core.exceptions import QueryError + NAME_REGEX = re.compile(r"([A-Za-z0-9\_]+)") @@ -32,6 +34,13 @@ def name_error(entity_name: str, name: str): "the naming conventions (only letters, numbers, or underscores)" ) + @staticmethod + def _raise_query_error_from_cte(field_name: str): + raise QueryError( + f"Field {field_name} is not present in either source query, so it" + " cannot be applied as a filter. Please add it to one of the source queries." + ) + @staticmethod def line_col(element): line = getattr(getattr(element, "lc", None), "line", None) diff --git a/metrics_layer/core/sql/merged_query_resolve.py b/metrics_layer/core/sql/merged_query_resolve.py index c724dbe..e1abd6d 100644 --- a/metrics_layer/core/sql/merged_query_resolve.py +++ b/metrics_layer/core/sql/merged_query_resolve.py @@ -125,7 +125,7 @@ def derive_sub_queries(self): secondary_metric_ids = [m.id() for m in self.secondary_metrics] merged_metric_ids = [m.id() for m in self.merged_metrics] - for h in self.having: + for h in self.flatten_filters(self.having): field = self.project.get_field(h["field"]) if field.id() not in secondary_metric_ids and not field.is_merged_result: self.secondary_metrics.append(field) diff --git a/metrics_layer/core/sql/query_arbitrary_merged_queries.py b/metrics_layer/core/sql/query_arbitrary_merged_queries.py index 2c90ad8..982cbbf 100644 --- a/metrics_layer/core/sql/query_arbitrary_merged_queries.py +++ b/metrics_layer/core/sql/query_arbitrary_merged_queries.py @@ -48,7 +48,7 @@ def get_query(self, semicolon: bool = True): if order_by_alias in self.cte_alias_lookup: order_by_alias = f"{self.cte_alias_lookup[order_by_alias]}.{order_by_alias}" else: - self._raise_query_error_from_cte(field.id(capitalize_alias=True)) + self._raise_query_error_from_cte(field.id()) order = Order.desc if order_clause.get("sort", "asc").lower() == "desc" else Order.asc complete_query = complete_query.orderby( diff --git a/metrics_layer/core/sql/query_base.py b/metrics_layer/core/sql/query_base.py index 5b24c37..5c9a9fd 100644 --- a/metrics_layer/core/sql/query_base.py +++ b/metrics_layer/core/sql/query_base.py @@ -22,23 +22,18 @@ def get_where_with_aliases( where = [] for filter_clause in filters: filter_clause["query_type"] = self.query_type - f = MetricsLayerFilter(definition=filter_clause, design=None, filter_type="where") - field = project.get_field(filter_clause["field"]) - field_alias = field.alias(with_view=True) - if field_alias in cte_alias_lookup: - field_alias = f"{cte_alias_lookup[field_alias]}.{field_alias}" - elif raise_if_not_in_lookup: - self._raise_query_error_from_cte(field.id(capitalize_alias=True)) - where.append(f.criterion(field_alias)) + f = MetricsLayerFilter( + definition=filter_clause, design=None, filter_type="where", project=project + ) + where.append( + f.sql_query( + alias_query=True, + cte_alias_lookup=cte_alias_lookup, + raise_if_not_in_lookup=raise_if_not_in_lookup, + ) + ) return where - @staticmethod - def _raise_query_error_from_cte(field_name: str): - raise QueryError( - f"Field {field_name} is not present in either source query, so it" - " cannot be applied as a filter. Please add it to one of the source queries." - ) - @staticmethod def parse_identifiers_from_clause(clause: str): if clause is None: diff --git a/metrics_layer/core/sql/query_filter.py b/metrics_layer/core/sql/query_filter.py index f901d52..c858be1 100644 --- a/metrics_layer/core/sql/query_filter.py +++ b/metrics_layer/core/sql/query_filter.py @@ -39,12 +39,13 @@ class MetricsLayerFilter(MetricsLayerBase): """ def __init__( - self, definition: Dict = {}, design: MetricsLayerDesign = None, filter_type: str = None + self, definition: Dict = {}, design: MetricsLayerDesign = None, filter_type: str = None, project=None ) -> None: # The design is used for filters in queries against specific designs # to validate that all the tables and attributes (columns/aggregates) # are properly defined in the design self.design = design + self.project = project self.is_literal_filter = "literal" in definition # This is a filter with parenthesis like (XYZ or ABC) self.is_filter_group = "conditions" in definition @@ -84,6 +85,7 @@ def validate(self, definition: Dict) -> None: if filter_group_conditions: for f in filter_group_conditions: + f["query_type"] = self.query_type MetricsLayerFilter(f, self.design, self.filter_type) if ( @@ -148,12 +150,32 @@ def validate(self, definition: Dict) -> None: if self.field.type == "yesno" and "True" in str(definition["value"]): definition["expression"] = "boolean_true" - def group_sql_query(self, functional_pk: str): + def group_sql_query( + self, + functional_pk: str, + alias_query: bool = False, + cte_alias_lookup: dict = {}, + raise_if_not_in_lookup: bool = False, + ): pypika_conditions = [] for condition in self.conditions: - condition_object = MetricsLayerFilter(condition, self.design, self.filter_type) + condition_object = MetricsLayerFilter(condition, self.design, self.filter_type, self.project) if condition_object.is_filter_group: - pypika_conditions.append(condition_object.group_sql_query(functional_pk)) + pypika_conditions.append( + condition_object.group_sql_query( + functional_pk, + alias_query, + cte_alias_lookup=cte_alias_lookup, + raise_if_not_in_lookup=raise_if_not_in_lookup, + ) + ) + elif alias_query: + if self.project is None: + raise ValueError("Project is not set, but it is required for an alias_query") + field_alias = self._handle_cte_alias_replacement( + condition_object.field, cte_alias_lookup, raise_if_not_in_lookup + ) + pypika_conditions.append(condition_object.criterion(field_alias)) else: pypika_conditions.append( condition_object.criterion( @@ -169,14 +191,36 @@ def group_sql_query(self, functional_pk: str): return Criterion.all(pypika_conditions) raise ParseError(f"Invalid logical operator: {self.logical_operator}") - def sql_query(self): + def sql_query( + self, alias_query: bool = False, cte_alias_lookup: dict = {}, raise_if_not_in_lookup: bool = False + ): if self.is_literal_filter: return LiteralValueCriterion(self.replace_fields_literal_filter()) + + if alias_query and self.is_filter_group: + return self.group_sql_query("NA", alias_query, cte_alias_lookup, raise_if_not_in_lookup) + elif alias_query: + field_alias = self._handle_cte_alias_replacement( + self.field, cte_alias_lookup, raise_if_not_in_lookup + ) + return self.criterion(field_alias) + functional_pk = self.design.functional_pk() if self.is_filter_group: return self.group_sql_query(functional_pk) return self.criterion(self.field.sql_query(self.query_type, functional_pk)) + def _handle_cte_alias_replacement( + self, field_id: str, cte_alias_lookup: dict, raise_if_not_in_lookup: bool + ): + field = self.project.get_field(field_id) + field_alias = field.alias(with_view=True) + if field_alias in cte_alias_lookup: + field_alias = f"{cte_alias_lookup[field_alias]}.{field_alias}" + elif raise_if_not_in_lookup: + self._raise_query_error_from_cte(field.id()) + return field_alias + def isin_sql_query(self, cte_alias, field_name, query_generator): group_by_field = self.design.get_field(field_name) base = query_generator._base_query() diff --git a/tests/test_arbitrary_merged_results.py b/tests/test_arbitrary_merged_results.py index 61b0bf4..32ecaf2 100644 --- a/tests/test_arbitrary_merged_results.py +++ b/tests/test_arbitrary_merged_results.py @@ -648,7 +648,7 @@ def test_query_merged_queries_invalid_where_post_merge(connection): assert isinstance(exc_info.value, QueryError) assert ( str(exc_info.value) - == "Field orders.NEW_VS_REPEAT is not present in either source query, so it cannot be applied as a" + == "Field orders.new_vs_repeat is not present in either source query, so it cannot be applied as a" " filter. Please add it to one of the source queries." ) @@ -670,7 +670,7 @@ def test_query_merged_queries_invalid_having_post_merge(connection): assert isinstance(exc_info.value, QueryError) assert ( str(exc_info.value) - == "Field order_lines.TOTAL_ITEM_COSTS is not present in either source query, so it cannot be applied" + == "Field order_lines.total_item_costs is not present in either source query, so it cannot be applied" " as a filter. Please add it to one of the source queries." ) @@ -691,7 +691,7 @@ def test_query_merged_queries_invalid_order_by_post_merge(connection): assert isinstance(exc_info.value, QueryError) assert ( str(exc_info.value) - == "Field order_lines.TOTAL_ITEM_COSTS is not present in either source query, so it cannot be applied" + == "Field order_lines.total_item_costs is not present in either source query, so it cannot be applied" " as a filter. Please add it to one of the source queries." ) diff --git a/tests/test_join_query.py b/tests/test_join_query.py index 97fcb13..b07ef00 100644 --- a/tests/test_join_query.py +++ b/tests/test_join_query.py @@ -2,6 +2,7 @@ import pytest +from metrics_layer import MetricsLayerConnection from metrics_layer.core.exceptions import JoinError, QueryError from metrics_layer.core.model import Definitions from metrics_layer.core.sql.query_errors import ParseError @@ -1271,3 +1272,66 @@ def test_query_with_or_filters_alternate_syntax(connection): " order_lines_total_item_revenue DESC NULLS LAST;" ) assert query == correct + + +# TODO DELETE BEFORE MERGE +@pytest.mark.queryy +def test_query_with_or_filters_alternate_syntaxx(connection): + connection = MetricsLayerConnection("/Users/pb/src/data_models/demo-data-model") + connection.load() + + query = connection.get_sql_query( + query_type="SNOWFLAKE", + metrics=["number_of_orders"], + dimensions=[], + where=[ + {"field": "date", "expression": "greater_or_equal_than", "value": datetime(2024, 1, 1, 0, 0)}, + { + "field": "date", + "expression": "less_or_equal_than", + "value": datetime(2024, 12, 31, 23, 59, 59), + }, + ], + having=[ + { + "conditional_filter_logic": { + "conditions": [ + { + "field": "order_lines.total_net_revenue", + "expression": "less_than", + "value": 5, + }, + { + "field": "order_lines.total_gross_revenue", + "expression": "greater_than", + "value": 6, + }, + { + "conditions": [ + { + "field": "roas", + "expression": "greater_than", + "value": 1, + }, + ], + "logical_operator": "AND", + }, + ], + "logical_operator": "OR", + } + }, + ], + ) + + correct = ( + "SELECT order_lines.sales_channel as order_lines_channel,SUM(order_lines.revenue) as" + " order_lines_total_item_revenue FROM analytics.order_line_items order_lines LEFT JOIN" + " analytics.customers customers ON order_lines.customer_id=customers.customer_id WHERE" + " customers.gender IN ('M') AND DATE_TRUNC('DAY', order_lines.order_date)>='2024-01-01T00:00:00' AND" + " DATE_TRUNC('DAY', order_lines.order_date)<='2024-12-31T23:59:59' GROUP BY order_lines.sales_channel" + " HAVING SUM(order_lines.revenue)>=100.0 AND SUM(order_lines.revenue)<=200.0 AND" + " (SUM(order_lines.revenue)>100.0 OR SUM(order_lines.revenue)<200.0 OR" + " (SUM(order_lines.revenue)>100.0 AND SUM(order_lines.revenue)<200.0)) ORDER BY" + " order_lines_total_item_revenue DESC NULLS LAST;" + ) + assert query == correct