Commit
Fix non-additive OR filters issue
pblankley committed Sep 25, 2024
1 parent 93c8651 commit a42b628
Showing 6 changed files with 120 additions and 17 deletions.
2 changes: 2 additions & 0 deletions metrics_layer/cli/seeding.py
@@ -369,6 +369,8 @@ def make_fields(self, column_data, schema_name: str, table_name: str, auto_tag_s
Definitions.duck_db,
Definitions.postgres,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
column_name = '"' + row["COLUMN_NAME"] + '"'
else:
5 changes: 3 additions & 2 deletions metrics_layer/core/sql/query_generator.py
@@ -14,6 +14,7 @@
from metrics_layer.core.sql.query_dialect import NullSorting, query_lookup
from metrics_layer.core.sql.query_errors import ArgumentError
from metrics_layer.core.sql.query_filter import MetricsLayerFilter
from metrics_layer.core.utils import flatten_filters


class MetricsLayerQuery(MetricsLayerQueryBase):
@@ -80,7 +81,7 @@ def parse_definition(self, definition: dict):
# them as CTE's for the appropriate filters
self.non_additive_ctes = []
metrics_in_select = definition.get("metrics", [])
- metrics_in_having = [h.field.id() for h in self.having_filters if h.field]
+ metrics_in_having = [h["field"] for h in flatten_filters(having)]
for metric in metrics_in_select + metrics_in_having:
metric_field = self.design.get_field(metric)
for ref_field in [metric_field] + metric_field.referenced_fields(metric_field.sql):
@@ -385,7 +386,7 @@ def _non_additive_cte(self, definition: dict, group_by_dimensions: list):
field_lookup[non_additive_dimension.id()] = non_additive_dimension

# We also need to make all fields in the where clause available to the query
- for f in self.where:
+ for f in flatten_filters(self.where):
field = self.design.get_field(f["field"])
field_lookup[field.id()] = field

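For context, here is a minimal sketch of what the two flatten_filters call sites above now see when filters arrive as nested condition groups. The dicts below are hypothetical (modeled on the test payloads further down), and the exact shape of having and self.where inside the generator may differ.

from metrics_layer.core.utils import flatten_filters

having = [
    {
        "conditions": [
            {"field": "number_of_billed_accounts", "expression": "greater_than", "value": 1100}
        ],
        "logical_operator": "AND",
    }
]
where = [
    {
        "conditions": [
            {"field": "mrr.plan_name", "expression": "equal_to", "value": "Enterprise"}
        ],
        "logical_operator": "AND",
    }
]

# flatten_filters unnests the "conditions" groups and returns only the leaf dicts,
# so both comprehensions see every referenced field regardless of nesting.
metrics_in_having = [h["field"] for h in flatten_filters(having)]  # ["number_of_billed_accounts"]
where_fields = [f["field"] for f in flatten_filters(where)]  # ["mrr.plan_name"]

The old comprehension read h.field off each having filter object, which presumably yields nothing for a compound group like the one above, so a metric referenced only inside an OR/AND group never produced its non-additive CTE; flattening to the leaf conditions closes that gap.
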
17 changes: 2 additions & 15 deletions metrics_layer/core/sql/single_query_resolve.py
@@ -4,6 +4,7 @@
from metrics_layer.core.sql.query_design import MetricsLayerDesign
from metrics_layer.core.sql.query_funnel import FunnelQuery
from metrics_layer.core.sql.query_generator import MetricsLayerQuery
from metrics_layer.core.utils import flatten_filters


class SingleSQLQueryResolver:
@@ -232,21 +233,7 @@ def parse_identifiers_from_dicts(conditions: list):

    @staticmethod
    def flatten_filters(filters: list):
-        flat_list = []
-
-        def recurse(filter_obj):
-            if isinstance(filter_obj, dict):
-                if "conditions" in filter_obj:
-                    for f in filter_obj["conditions"]:
-                        recurse(f)
-                else:
-                    flat_list.append(filter_obj)
-            elif isinstance(filter_obj, list):
-                for item in filter_obj:
-                    recurse(item)
-
-        recurse(filters)
-        return flat_list
+        return flatten_filters(filters)

    @staticmethod
    def _check_for_dict(conditions: list):
18 changes: 18 additions & 0 deletions metrics_layer/core/utils.py
@@ -14,3 +14,21 @@ def generate_random_password(length):
    letters = string.ascii_letters
    result_str = "".join(random.choice(letters) for i in range(length))
    return result_str


def flatten_filters(filters: list):
    flat_list = []

    def recurse(filter_obj):
        if isinstance(filter_obj, dict):
            if "conditions" in filter_obj:
                for f in filter_obj["conditions"]:
                    recurse(f)
            else:
                flat_list.append(filter_obj)
        elif isinstance(filter_obj, list):
            for item in filter_obj:
                recurse(item)

    recurse(filters)
    return flat_list
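
As an illustration of the new helper, here is a hedged sketch of its behavior on a hypothetical filter tree with one condition group nested inside another (field names invented for the example):

from metrics_layer.core.utils import flatten_filters

filters = [
    {
        "conditions": [
            {"field": "orders.region", "expression": "equal_to", "value": "EMEA"},
            {
                "conditions": [
                    {"field": "orders.revenue", "expression": "greater_than", "value": 100},
                    {"field": "orders.new_vs_repeat", "expression": "equal_to", "value": "New"},
                ],
                "logical_operator": "OR",
            },
        ],
        "logical_operator": "AND",
    }
]

# Only the leaf condition dicts come back; grouping and logical operators are discarded,
# which is all the callers above need when collecting referenced fields.
print(flatten_filters(filters))
# [{'field': 'orders.region', ...}, {'field': 'orders.revenue', ...}, {'field': 'orders.new_vs_repeat', ...}]
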
18 changes: 18 additions & 0 deletions tests/test_cli.py
@@ -176,6 +176,8 @@ def yaml_dump_assert(slf, data, file):
Definitions.postgres,
Definitions.trino,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
assert social["sql"] == '${TABLE}."ON_SOCIAL_NETWORK"'
else:
@@ -195,6 +197,8 @@ def yaml_dump_assert(slf, data, file):
Definitions.postgres,
Definitions.trino,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
assert acq_date["sql"] == '${TABLE}."ACQUISITION_DATE"'
else:
@@ -231,6 +235,8 @@ def yaml_dump_assert(slf, data, file):
Definitions.postgres,
Definitions.trino,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
assert date["sql"] == '${TABLE}."ORDER_CREATED_AT"'
else:
@@ -244,6 +250,8 @@ def yaml_dump_assert(slf, data, file):
Definitions.postgres,
Definitions.trino,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
assert new["sql"] == '${TABLE}."NEW_VS_REPEAT"'
else:
@@ -257,6 +265,8 @@ def yaml_dump_assert(slf, data, file):
Definitions.postgres,
Definitions.trino,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
assert num["sql"] == '${TABLE}."REVENUE"'
else:
@@ -307,6 +317,8 @@ def yaml_dump_assert(slf, data, file):
Definitions.postgres,
Definitions.trino,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
assert cross_sell["sql"] == '${TABLE}."@CRoSSell P-roduct:"'
else:
@@ -344,6 +356,8 @@ def yaml_dump_assert(slf, data, file):
Definitions.postgres,
Definitions.trino,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
assert date["sql"] == '${TABLE}."SESSION_DATE"'
else:
@@ -357,6 +371,8 @@ def yaml_dump_assert(slf, data, file):
Definitions.postgres,
Definitions.trino,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
assert pk["sql"] == '${TABLE}."SESSION_ID"'
else:
@@ -370,6 +386,8 @@ def yaml_dump_assert(slf, data, file):
Definitions.postgres,
Definitions.trino,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
assert num["sql"] == '${TABLE}."CONVERSION"'
else:
77 changes: 77 additions & 0 deletions tests/test_non_additive_dimensions.py
@@ -672,3 +672,80 @@ def test_mrr_non_additive_dimension_merged_result_sub_join_where(connection):
" mrr_record__cte_subquery_0.mrr_record_date=z_customer_accounts_created__cte_subquery_1.z_customer_accounts_created_date;" # noqa
)
assert query == correct


@pytest.mark.query
def test_mrr_non_additive_dimension_or_filters_with_select(connection):
query = connection.get_sql_query(
metrics=["mrr.mrr_end_of_month"],
dimensions=[],
where=[
{
"conditional_filter_logic": {
"conditions": [
{"field": "mrr.plan_name", "expression": "equal_to", "value": "Enterprise"}
],
"logical_operator": "AND",
}
},
],
having=[
{
"conditional_filter_logic": {
"conditions": [
{"field": "number_of_billed_accounts", "expression": "greater_than", "value": 1100}
],
"logical_operator": "AND",
}
}
],
)

correct = (
"WITH cte_mrr_end_of_month_record_raw AS (SELECT MAX(mrr.record_date) as mrr_max_record_raw FROM"
" analytics.mrr_by_customer mrr WHERE mrr.plan_name='Enterprise' ORDER BY mrr_max_record_raw DESC"
" NULLS LAST) SELECT SUM(case when mrr.record_date=cte_mrr_end_of_month_record_raw.mrr_max_record_raw"
" then mrr.mrr else 0 end) as mrr_mrr_end_of_month FROM analytics.mrr_by_customer mrr LEFT JOIN"
" cte_mrr_end_of_month_record_raw ON 1=1 WHERE mrr.plan_name='Enterprise' HAVING"
" COUNT(mrr.parent_account_id)>1100 ORDER BY mrr_mrr_end_of_month DESC NULLS LAST;"
)
assert query == correct


@pytest.mark.query
def test_mrr_non_additive_dimension_or_filters(connection):
query = connection.get_sql_query(
metrics=["number_of_billed_accounts"],
dimensions=[],
where=[
{
"conditional_filter_logic": {
"conditions": [
{"field": "mrr.plan_name", "expression": "equal_to", "value": "Enterprise"}
],
"logical_operator": "AND",
}
},
],
having=[
{
"conditional_filter_logic": {
"conditions": [
{"field": "mrr.mrr_end_of_month", "expression": "greater_than", "value": 1100}
],
"logical_operator": "AND",
}
}
],
)

correct = (
"WITH cte_mrr_end_of_month_record_raw AS (SELECT MAX(mrr.record_date) as mrr_max_record_raw FROM"
" analytics.mrr_by_customer mrr WHERE mrr.plan_name='Enterprise' ORDER BY mrr_max_record_raw DESC"
" NULLS LAST) SELECT COUNT(mrr.parent_account_id) as mrr_number_of_billed_accounts FROM"
" analytics.mrr_by_customer mrr LEFT JOIN cte_mrr_end_of_month_record_raw ON 1=1 WHERE"
" mrr.plan_name='Enterprise' HAVING SUM(case when"
" mrr.record_date=cte_mrr_end_of_month_record_raw.mrr_max_record_raw then mrr.mrr else 0 end)>1100"
" ORDER BY mrr_number_of_billed_accounts DESC NULLS LAST;"
)
assert query == correct
