Skip to content

Commit

Permalink
fix issue with merged results filter application + include test for p…
Browse files Browse the repository at this point in the history
…ure organics (remove before merge)
  • Loading branch information
pblankley committed Sep 19, 2024
1 parent bf46d00 commit 21cc1de
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 25 deletions.
9 changes: 9 additions & 0 deletions metrics_layer/core/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import re
from typing import List

from metrics_layer.core.exceptions import QueryError

NAME_REGEX = re.compile(r"([A-Za-z0-9\_]+)")


Expand Down Expand Up @@ -32,6 +34,13 @@ def name_error(entity_name: str, name: str):
"the naming conventions (only letters, numbers, or underscores)"
)

@staticmethod
def _raise_query_error_from_cte(field_name: str):
raise QueryError(
f"Field {field_name} is not present in either source query, so it"
" cannot be applied as a filter. Please add it to one of the source queries."
)

@staticmethod
def line_col(element):
line = getattr(getattr(element, "lc", None), "line", None)
Expand Down
2 changes: 1 addition & 1 deletion metrics_layer/core/sql/merged_query_resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def derive_sub_queries(self):

secondary_metric_ids = [m.id() for m in self.secondary_metrics]
merged_metric_ids = [m.id() for m in self.merged_metrics]
for h in self.having:
for h in self.flatten_filters(self.having):
field = self.project.get_field(h["field"])
if field.id() not in secondary_metric_ids and not field.is_merged_result:
self.secondary_metrics.append(field)
Expand Down
2 changes: 1 addition & 1 deletion metrics_layer/core/sql/query_arbitrary_merged_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_query(self, semicolon: bool = True):
if order_by_alias in self.cte_alias_lookup:
order_by_alias = f"{self.cte_alias_lookup[order_by_alias]}.{order_by_alias}"
else:
self._raise_query_error_from_cte(field.id(capitalize_alias=True))
self._raise_query_error_from_cte(field.id())

order = Order.desc if order_clause.get("sort", "asc").lower() == "desc" else Order.asc
complete_query = complete_query.orderby(
Expand Down
25 changes: 10 additions & 15 deletions metrics_layer/core/sql/query_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,18 @@ def get_where_with_aliases(
where = []
for filter_clause in filters:
filter_clause["query_type"] = self.query_type
f = MetricsLayerFilter(definition=filter_clause, design=None, filter_type="where")
field = project.get_field(filter_clause["field"])
field_alias = field.alias(with_view=True)
if field_alias in cte_alias_lookup:
field_alias = f"{cte_alias_lookup[field_alias]}.{field_alias}"
elif raise_if_not_in_lookup:
self._raise_query_error_from_cte(field.id(capitalize_alias=True))
where.append(f.criterion(field_alias))
f = MetricsLayerFilter(
definition=filter_clause, design=None, filter_type="where", project=project
)
where.append(
f.sql_query(
alias_query=True,
cte_alias_lookup=cte_alias_lookup,
raise_if_not_in_lookup=raise_if_not_in_lookup,
)
)
return where

@staticmethod
def _raise_query_error_from_cte(field_name: str):
raise QueryError(
f"Field {field_name} is not present in either source query, so it"
" cannot be applied as a filter. Please add it to one of the source queries."
)

@staticmethod
def parse_identifiers_from_clause(clause: str):
if clause is None:
Expand Down
54 changes: 49 additions & 5 deletions metrics_layer/core/sql/query_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@ class MetricsLayerFilter(MetricsLayerBase):
"""

def __init__(
self, definition: Dict = {}, design: MetricsLayerDesign = None, filter_type: str = None
self, definition: Dict = {}, design: MetricsLayerDesign = None, filter_type: str = None, project=None
) -> None:
# The design is used for filters in queries against specific designs
# to validate that all the tables and attributes (columns/aggregates)
# are properly defined in the design
self.design = design
self.project = project
self.is_literal_filter = "literal" in definition
# This is a filter with parenthesis like (XYZ or ABC)
self.is_filter_group = "conditions" in definition
Expand Down Expand Up @@ -84,6 +85,7 @@ def validate(self, definition: Dict) -> None:

if filter_group_conditions:
for f in filter_group_conditions:
f["query_type"] = self.query_type
MetricsLayerFilter(f, self.design, self.filter_type)

if (
Expand Down Expand Up @@ -148,12 +150,32 @@ def validate(self, definition: Dict) -> None:
if self.field.type == "yesno" and "True" in str(definition["value"]):
definition["expression"] = "boolean_true"

def group_sql_query(self, functional_pk: str):
def group_sql_query(
self,
functional_pk: str,
alias_query: bool = False,
cte_alias_lookup: dict = {},
raise_if_not_in_lookup: bool = False,
):
pypika_conditions = []
for condition in self.conditions:
condition_object = MetricsLayerFilter(condition, self.design, self.filter_type)
condition_object = MetricsLayerFilter(condition, self.design, self.filter_type, self.project)
if condition_object.is_filter_group:
pypika_conditions.append(condition_object.group_sql_query(functional_pk))
pypika_conditions.append(
condition_object.group_sql_query(
functional_pk,
alias_query,
cte_alias_lookup=cte_alias_lookup,
raise_if_not_in_lookup=raise_if_not_in_lookup,
)
)
elif alias_query:
if self.project is None:
raise ValueError("Project is not set, but it is required for an alias_query")
field_alias = self._handle_cte_alias_replacement(
condition_object.field, cte_alias_lookup, raise_if_not_in_lookup
)
pypika_conditions.append(condition_object.criterion(field_alias))
else:
pypika_conditions.append(
condition_object.criterion(
Expand All @@ -169,14 +191,36 @@ def group_sql_query(self, functional_pk: str):
return Criterion.all(pypika_conditions)
raise ParseError(f"Invalid logical operator: {self.logical_operator}")

def sql_query(self):
def sql_query(
self, alias_query: bool = False, cte_alias_lookup: dict = {}, raise_if_not_in_lookup: bool = False
):
if self.is_literal_filter:
return LiteralValueCriterion(self.replace_fields_literal_filter())

if alias_query and self.is_filter_group:
return self.group_sql_query("NA", alias_query, cte_alias_lookup, raise_if_not_in_lookup)
elif alias_query:
field_alias = self._handle_cte_alias_replacement(
self.field, cte_alias_lookup, raise_if_not_in_lookup
)
return self.criterion(field_alias)

functional_pk = self.design.functional_pk()
if self.is_filter_group:
return self.group_sql_query(functional_pk)
return self.criterion(self.field.sql_query(self.query_type, functional_pk))

def _handle_cte_alias_replacement(
self, field_id: str, cte_alias_lookup: dict, raise_if_not_in_lookup: bool
):
field = self.project.get_field(field_id)
field_alias = field.alias(with_view=True)
if field_alias in cte_alias_lookup:
field_alias = f"{cte_alias_lookup[field_alias]}.{field_alias}"
elif raise_if_not_in_lookup:
self._raise_query_error_from_cte(field.id())
return field_alias

def isin_sql_query(self, cte_alias, field_name, query_generator):
group_by_field = self.design.get_field(field_name)
base = query_generator._base_query()
Expand Down
6 changes: 3 additions & 3 deletions tests/test_arbitrary_merged_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ def test_query_merged_queries_invalid_where_post_merge(connection):
assert isinstance(exc_info.value, QueryError)
assert (
str(exc_info.value)
== "Field orders.NEW_VS_REPEAT is not present in either source query, so it cannot be applied as a"
== "Field orders.new_vs_repeat is not present in either source query, so it cannot be applied as a"
" filter. Please add it to one of the source queries."
)

Expand All @@ -670,7 +670,7 @@ def test_query_merged_queries_invalid_having_post_merge(connection):
assert isinstance(exc_info.value, QueryError)
assert (
str(exc_info.value)
== "Field order_lines.TOTAL_ITEM_COSTS is not present in either source query, so it cannot be applied"
== "Field order_lines.total_item_costs is not present in either source query, so it cannot be applied"
" as a filter. Please add it to one of the source queries."
)

Expand All @@ -691,7 +691,7 @@ def test_query_merged_queries_invalid_order_by_post_merge(connection):
assert isinstance(exc_info.value, QueryError)
assert (
str(exc_info.value)
== "Field order_lines.TOTAL_ITEM_COSTS is not present in either source query, so it cannot be applied"
== "Field order_lines.total_item_costs is not present in either source query, so it cannot be applied"
" as a filter. Please add it to one of the source queries."
)

Expand Down
64 changes: 64 additions & 0 deletions tests/test_join_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest

from metrics_layer import MetricsLayerConnection
from metrics_layer.core.exceptions import JoinError, QueryError
from metrics_layer.core.model import Definitions
from metrics_layer.core.sql.query_errors import ParseError
Expand Down Expand Up @@ -1271,3 +1272,66 @@ def test_query_with_or_filters_alternate_syntax(connection):
" order_lines_total_item_revenue DESC NULLS LAST;"
)
assert query == correct


# TODO DELETE BEFORE MERGE
@pytest.mark.queryy
def test_query_with_or_filters_alternate_syntaxx(connection):
connection = MetricsLayerConnection("/Users/pb/src/data_models/demo-data-model")
connection.load()

query = connection.get_sql_query(
query_type="SNOWFLAKE",
metrics=["number_of_orders"],
dimensions=[],
where=[
{"field": "date", "expression": "greater_or_equal_than", "value": datetime(2024, 1, 1, 0, 0)},
{
"field": "date",
"expression": "less_or_equal_than",
"value": datetime(2024, 12, 31, 23, 59, 59),
},
],
having=[
{
"conditional_filter_logic": {
"conditions": [
{
"field": "order_lines.total_net_revenue",
"expression": "less_than",
"value": 5,
},
{
"field": "order_lines.total_gross_revenue",
"expression": "greater_than",
"value": 6,
},
{
"conditions": [
{
"field": "roas",
"expression": "greater_than",
"value": 1,
},
],
"logical_operator": "AND",
},
],
"logical_operator": "OR",
}
},
],
)

correct = (
"SELECT order_lines.sales_channel as order_lines_channel,SUM(order_lines.revenue) as"
" order_lines_total_item_revenue FROM analytics.order_line_items order_lines LEFT JOIN"
" analytics.customers customers ON order_lines.customer_id=customers.customer_id WHERE"
" customers.gender IN ('M') AND DATE_TRUNC('DAY', order_lines.order_date)>='2024-01-01T00:00:00' AND"
" DATE_TRUNC('DAY', order_lines.order_date)<='2024-12-31T23:59:59' GROUP BY order_lines.sales_channel"
" HAVING SUM(order_lines.revenue)>=100.0 AND SUM(order_lines.revenue)<=200.0 AND"
" (SUM(order_lines.revenue)>100.0 OR SUM(order_lines.revenue)<200.0 OR"
" (SUM(order_lines.revenue)>100.0 AND SUM(order_lines.revenue)<200.0)) ORDER BY"
" order_lines_total_item_revenue DESC NULLS LAST;"
)
assert query == correct

0 comments on commit 21cc1de

Please sign in to comment.