fix non additive or filters issue #235

Merged: 2 commits merged on Sep 25, 2024

Changes from all commits
2 changes: 2 additions & 0 deletions metrics_layer/cli/seeding.py
@@ -369,6 +369,8 @@ def make_fields(self, column_data, schema_name: str, table_name: str, auto_tag_s
Definitions.duck_db,
Definitions.postgres,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
column_name = '"' + row["COLUMN_NAME"] + '"'
else:
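
The seeding hunk above simply adds SQL Server and Azure Synapse to the set of dialects whose reflected column names are wrapped in ANSI double quotes. A minimal sketch of that branching, using illustrative names rather than the library's actual helpers, and abridged to the set members visible in the hunk:

# Sketch only: DOUBLE_QUOTED_DIALECTS and quote_column are illustrative names,
# not metrics_layer APIs. The point is that sql_server and azure_synapse now
# take the same double-quoting path as the Postgres-style dialects during seeding.
DOUBLE_QUOTED_DIALECTS = {
    "duck_db",
    "postgres",
    "redshift",
    "sql_server",
    "azure_synapse",
}

def quote_column(column_name: str, dialect: str) -> str:
    if dialect in DOUBLE_QUOTED_DIALECTS:
        return '"' + column_name + '"'
    # Other dialects keep whatever handling make_fields applies in its else branch.
    return column_name
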
5 changes: 3 additions & 2 deletions metrics_layer/core/sql/query_generator.py
@@ -14,6 +14,7 @@
from metrics_layer.core.sql.query_dialect import NullSorting, query_lookup
from metrics_layer.core.sql.query_errors import ArgumentError
from metrics_layer.core.sql.query_filter import MetricsLayerFilter
from metrics_layer.core.utils import flatten_filters


class MetricsLayerQuery(MetricsLayerQueryBase):
@@ -80,7 +81,7 @@ def parse_definition(self, definition: dict):
# them as CTE's for the appropriate filters
self.non_additive_ctes = []
metrics_in_select = definition.get("metrics", [])
metrics_in_having = [h.field.id() for h in self.having_filters if h.field]
metrics_in_having = [h["field"] for h in flatten_filters(having)]
for metric in metrics_in_select + metrics_in_having:
metric_field = self.design.get_field(metric)
for ref_field in [metric_field] + metric_field.referenced_fields(metric_field.sql):
@@ -385,7 +386,7 @@ def _non_additive_cte(self, definition: dict, group_by_dimensions: list):
field_lookup[non_additive_dimension.id()] = non_additive_dimension

# We also need to make all fields in the where clause available to the query
for f in self.where:
for f in flatten_filters(self.where):
field = self.design.get_field(f["field"])
field_lookup[field.id()] = field

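
Both hunks in query_generator.py exist for the same reason: with nested conditional filter logic, the having and where definitions are trees rather than flat lists, so iterating only the top level can miss fields buried inside nested "conditions" blocks, and the non-additive CTE then either is not generated or lacks fields it needs. A minimal sketch of the new extraction, assuming the having definition has already been normalized into the nested dict shape that flatten_filters expects (field names borrowed from the tests below):

from metrics_layer.core.utils import flatten_filters

# A having definition whose only metric reference lives inside a nested
# condition block rather than at the top level of the list.
having = [
    {
        "logical_operator": "AND",
        "conditions": [
            {"field": "mrr.mrr_end_of_month", "expression": "greater_than", "value": 1100}
        ],
    }
]

# flatten_filters walks the tree and returns the leaf conditions, so the metric
# is seen and its non-additive CTE is built even when it appears only in HAVING.
# The where loop in _non_additive_cte uses the same flattening for its field_lookup.
metrics_in_having = [f["field"] for f in flatten_filters(having)]
assert metrics_in_having == ["mrr.mrr_end_of_month"]
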
17 changes: 2 additions & 15 deletions metrics_layer/core/sql/single_query_resolve.py
@@ -4,6 +4,7 @@
from metrics_layer.core.sql.query_design import MetricsLayerDesign
from metrics_layer.core.sql.query_funnel import FunnelQuery
from metrics_layer.core.sql.query_generator import MetricsLayerQuery
from metrics_layer.core.utils import flatten_filters


class SingleSQLQueryResolver:
@@ -232,21 +233,7 @@ def parse_identifiers_from_dicts(conditions: list):

@staticmethod
def flatten_filters(filters: list):
flat_list = []

def recurse(filter_obj):
if isinstance(filter_obj, dict):
if "conditions" in filter_obj:
for f in filter_obj["conditions"]:
recurse(f)
else:
flat_list.append(filter_obj)
elif isinstance(filter_obj, list):
for item in filter_obj:
recurse(item)

recurse(filters)
return flat_list
return flatten_filters(filters)

@staticmethod
def _check_for_dict(conditions: list):
18 changes: 18 additions & 0 deletions metrics_layer/core/utils.py
@@ -14,3 +14,21 @@ def generate_random_password(length):
letters = string.ascii_letters
result_str = "".join(random.choice(letters) for i in range(length))
return result_str


def flatten_filters(filters: list):
flat_list = []

def recurse(filter_obj):
if isinstance(filter_obj, dict):
if "conditions" in filter_obj:
for f in filter_obj["conditions"]:
recurse(f)
else:
flat_list.append(filter_obj)
elif isinstance(filter_obj, list):
for item in filter_obj:
recurse(item)

recurse(filters)
return flat_list
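
A quick usage sketch for the shared helper (the filter tree below is hypothetical, but it has the nested "conditions" shape the function recurses over):

from metrics_layer.core.utils import flatten_filters

nested = [
    {
        "logical_operator": "OR",
        "conditions": [
            {"field": "mrr.plan_name", "expression": "equal_to", "value": "Enterprise"},
            {
                "logical_operator": "AND",
                "conditions": [
                    {"field": "mrr.record_date", "expression": "greater_than", "value": "2024-01-01"},
                    {"field": "mrr.mrr", "expression": "greater_than", "value": 0},
                ],
            },
        ],
    }
]

# Leaves come back in depth-first order; any dict without a "conditions" key
# is treated as a leaf condition and returned as-is.
assert [f["field"] for f in flatten_filters(nested)] == [
    "mrr.plan_name",
    "mrr.record_date",
    "mrr.mrr",
]
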
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "metrics_layer"
version = "0.12.38"
version = "0.12.39"
description = "The open source metrics layer."
authors = ["Paul Blankley <[email protected]>"]
keywords = ["Metrics Layer", "Business Intelligence", "Analytics"]
18 changes: 18 additions & 0 deletions tests/test_cli.py
@@ -176,6 +176,8 @@ def yaml_dump_assert(slf, data, file):
Definitions.postgres,
Definitions.trino,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
assert social["sql"] == '${TABLE}."ON_SOCIAL_NETWORK"'
else:
@@ -195,6 +197,8 @@ def yaml_dump_assert(slf, data, file):
Definitions.postgres,
Definitions.trino,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
assert acq_date["sql"] == '${TABLE}."ACQUISITION_DATE"'
else:
@@ -231,6 +235,8 @@ def yaml_dump_assert(slf, data, file):
Definitions.postgres,
Definitions.trino,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
assert date["sql"] == '${TABLE}."ORDER_CREATED_AT"'
else:
@@ -244,6 +250,8 @@ def yaml_dump_assert(slf, data, file):
Definitions.postgres,
Definitions.trino,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
assert new["sql"] == '${TABLE}."NEW_VS_REPEAT"'
else:
@@ -257,6 +265,8 @@ def yaml_dump_assert(slf, data, file):
Definitions.postgres,
Definitions.trino,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
assert num["sql"] == '${TABLE}."REVENUE"'
else:
@@ -307,6 +317,8 @@ def yaml_dump_assert(slf, data, file):
Definitions.postgres,
Definitions.trino,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
assert cross_sell["sql"] == '${TABLE}."@CRoSSell P-roduct:"'
else:
@@ -344,6 +356,8 @@ def yaml_dump_assert(slf, data, file):
Definitions.postgres,
Definitions.trino,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
assert date["sql"] == '${TABLE}."SESSION_DATE"'
else:
@@ -357,6 +371,8 @@ def yaml_dump_assert(slf, data, file):
Definitions.postgres,
Definitions.trino,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
assert pk["sql"] == '${TABLE}."SESSION_ID"'
else:
@@ -370,6 +386,8 @@ def yaml_dump_assert(slf, data, file):
Definitions.postgres,
Definitions.trino,
Definitions.redshift,
Definitions.sql_server,
Definitions.azure_synapse,
}:
assert num["sql"] == '${TABLE}."CONVERSION"'
else:
77 changes: 77 additions & 0 deletions tests/test_non_additive_dimensions.py
@@ -672,3 +672,80 @@ def test_mrr_non_additive_dimension_merged_result_sub_join_where(connection):
" mrr_record__cte_subquery_0.mrr_record_date=z_customer_accounts_created__cte_subquery_1.z_customer_accounts_created_date;" # noqa
)
assert query == correct


@pytest.mark.query
def test_mrr_non_additive_dimension_or_filters_with_select(connection):
query = connection.get_sql_query(
metrics=["mrr.mrr_end_of_month"],
dimensions=[],
where=[
{
"conditional_filter_logic": {
"conditions": [
{"field": "mrr.plan_name", "expression": "equal_to", "value": "Enterprise"}
],
"logical_operator": "AND",
}
},
],
having=[
{
"conditional_filter_logic": {
"conditions": [
{"field": "number_of_billed_accounts", "expression": "greater_than", "value": 1100}
],
"logical_operator": "AND",
}
}
],
)

correct = (
"WITH cte_mrr_end_of_month_record_raw AS (SELECT MAX(mrr.record_date) as mrr_max_record_raw FROM"
" analytics.mrr_by_customer mrr WHERE mrr.plan_name='Enterprise' ORDER BY mrr_max_record_raw DESC"
" NULLS LAST) SELECT SUM(case when mrr.record_date=cte_mrr_end_of_month_record_raw.mrr_max_record_raw"
" then mrr.mrr else 0 end) as mrr_mrr_end_of_month FROM analytics.mrr_by_customer mrr LEFT JOIN"
" cte_mrr_end_of_month_record_raw ON 1=1 WHERE mrr.plan_name='Enterprise' HAVING"
" COUNT(mrr.parent_account_id)>1100 ORDER BY mrr_mrr_end_of_month DESC NULLS LAST;"
)
assert query == correct


@pytest.mark.query
def test_mrr_non_additive_dimension_or_filters(connection):
query = connection.get_sql_query(
metrics=["number_of_billed_accounts"],
dimensions=[],
where=[
{
"conditional_filter_logic": {
"conditions": [
{"field": "mrr.plan_name", "expression": "equal_to", "value": "Enterprise"}
],
"logical_operator": "AND",
}
},
],
having=[
{
"conditional_filter_logic": {
"conditions": [
{"field": "mrr.mrr_end_of_month", "expression": "greater_than", "value": 1100}
],
"logical_operator": "AND",
}
}
],
)

correct = (
"WITH cte_mrr_end_of_month_record_raw AS (SELECT MAX(mrr.record_date) as mrr_max_record_raw FROM"
" analytics.mrr_by_customer mrr WHERE mrr.plan_name='Enterprise' ORDER BY mrr_max_record_raw DESC"
" NULLS LAST) SELECT COUNT(mrr.parent_account_id) as mrr_number_of_billed_accounts FROM"
" analytics.mrr_by_customer mrr LEFT JOIN cte_mrr_end_of_month_record_raw ON 1=1 WHERE"
" mrr.plan_name='Enterprise' HAVING SUM(case when"
" mrr.record_date=cte_mrr_end_of_month_record_raw.mrr_max_record_raw then mrr.mrr else 0 end)>1100"
" ORDER BY mrr_number_of_billed_accounts DESC NULLS LAST;"
)
assert query == correct