Skip to content

Commit

Permalink
add dimension support + smart filter consolidation to group by filters
Browse files Browse the repository at this point in the history
  • Loading branch information
pblankley committed Jul 31, 2024
1 parent 158d762 commit 8f309dd
Show file tree
Hide file tree
Showing 3 changed files with 260 additions and 8 deletions.
44 changes: 38 additions & 6 deletions metrics_layer/core/sql/query_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
else:
self.query_type = definition["query_type"]
self.filter_type = filter_type
self._extra_group_by_filter_conditions = []

self.validate(definition)

Expand Down Expand Up @@ -176,16 +177,47 @@ def criterion(self, field_sql: str) -> Criterion:
return Criterion.all(criteria)
return Filter.sql_query(field_sql, self.expression_type, self.value)

def consolidate_group_by_filter(self, filter_class_to_consolidate: "MetricsLayerFilter") -> None:
    """Fold another group_by filter into this one so both conditions are
    applied inside a single filter subquery (CTE) instead of two.

    Args:
        filter_class_to_consolidate: The filter to merge into this one. It
            must target the same ``group_by`` field and its field must share
            at least one non-merged-result join graph with this filter's field.

    Raises:
        QueryError: If this filter has no ``group_by`` property, if the two
            filters' ``group_by`` fields differ, or if the two fields share
            no join path (in which case they must stay separate CTEs).
    """
    if not self.is_group_by:
        raise QueryError("A group_by filter is invalid for a filter with no group_by property")

    if self.group_by != filter_class_to_consolidate.group_by:
        raise QueryError("The group_by field must be the same for both filters")

    # Only consider non-merged-result join graphs: consolidation requires the
    # two fields to be joinable within one ordinary (non-merged) query.
    joinable_graphs = {jg for jg in self.field.join_graphs() if "merged_result" not in jg}
    consolidate_joinable_graphs = {
        jg for jg in filter_class_to_consolidate.field.join_graphs() if "merged_result" not in jg
    }
    if not joinable_graphs & consolidate_joinable_graphs:
        raise QueryError("The filters must have a join path in common to be consolidated")

    self._extra_group_by_filter_conditions.append(filter_class_to_consolidate)

def cte(self, query_class, design_class):
if not self.is_group_by:
raise QueryError("A CTE is invalid for a filter with no group_by property")

having_filter = {k: v for k, v in self._definition.items() if k != "group_by"}
field_names = [self.group_by, having_filter["field"]]
group_by_filters = [{k: v for k, v in self._definition.items() if k != "group_by"}]
for f in self._extra_group_by_filter_conditions:
group_by_filters.append({k: v for k, v in f._definition.items() if k != "group_by"})

field_lookup = {}
for n in field_names:
field = self.design.get_field(n)
field_lookup[field.id()] = field
group_by_field = self.design.get_field(self.group_by)
field_lookup[group_by_field.id()] = group_by_field

filter_dict_args = {"where": [], "having": []}
for group_by_filter in group_by_filters:
filter_field = self.design.get_field(group_by_filter["field"])
field_lookup[filter_field.id()] = filter_field

if filter_field.field_type == "measure":
filter_dict_args["having"].append(group_by_filter)
else:
filter_dict_args["where"].append(group_by_filter)

design = design_class(
no_group_by=False,
Expand All @@ -198,7 +230,7 @@ def cte(self, query_class, design_class):
config = {
"metrics": [],
"dimensions": [self.group_by],
"having": [having_filter],
**filter_dict_args,
"return_pypika_query": True,
}
generator = query_class(config, design=design)
Expand Down
26 changes: 25 additions & 1 deletion metrics_layer/core/sql/query_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,32 @@ def get_query(self, semicolon: bool = True):
base_query = base_query.having(Criterion.all(having))

if self.having_group_by_filters:
# First consolidate them if they share the same group by field
seen = {}
for f in self.having_group_by_filters:
matches = [
consolidated_filter for k, consolidated_filter in seen.items() if f.group_by == k[1]
]
if matches:
consolidated = False
for match in matches:
try:
# The consolidation step will raise a QueryError if the filters can't be
# consolidated, which means the field cannot be joined.
# These filters must remain separate.
match.consolidate_group_by_filter(f)
consolidated = True
break
except QueryError:
pass

if not consolidated:
seen[(len(matches), f.group_by)] = f
else:
seen[(0, f.group_by)] = f

group_by_where = []
for i, f in enumerate(sorted(self.having_group_by_filters)):
for i, (_, f) in enumerate(sorted(seen.items(), key=lambda x: x[0])):
cte_alias = f"filter_subquery_{i}"
cte_query = f.cte(query_class=MetricsLayerQuery, design_class=MetricsLayerDesign)
base_query = base_query.with_(Table(cte_query), cte_alias)
Expand Down
198 changes: 197 additions & 1 deletion tests/test_join_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ def test_query_bool_and_date_filter(connection, bool_value):

@pytest.mark.query
@pytest.mark.parametrize("filter_type", ["having", "where"])
def test_query_sub_group_by_filter(connection, filter_type):
def test_query_sub_group_by_filter_measure(connection, filter_type):
query = connection.get_sql_query(
metrics=["number_of_orders"],
dimensions=["region"],
Expand Down Expand Up @@ -719,6 +719,202 @@ def test_query_sub_group_by_filter(connection, filter_type):
assert query == correct


@pytest.mark.query
@pytest.mark.parametrize("filter_type", ["having", "where"])
def test_query_sub_group_by_filter_dimension(connection, filter_type):
    # A group_by filter on a dimension ("channel") must land in the filter
    # subquery's WHERE clause, whether the user passed it via `where` or `having`.
    group_by_filter = {
        "field": "channel",
        "group_by": "customers.customer_id",
        "expression": "contains_case_insensitive",
        "value": "social",
    }
    sql = connection.get_sql_query(
        metrics=["number_of_orders"],
        dimensions=["region"],
        **{filter_type: [group_by_filter]},
    )

    expected = (
        "WITH filter_subquery_0 AS (SELECT customers.customer_id as customers_customer_id FROM"
        " analytics.order_line_items order_lines LEFT JOIN analytics.customers customers ON"
        " order_lines.customer_id=customers.customer_id WHERE LOWER(order_lines.sales_channel) LIKE"
        " LOWER('%social%') GROUP BY customers.customer_id ORDER BY customers_customer_id ASC) SELECT"
        " customers.region as customers_region,COUNT(orders.id) as orders_number_of_orders FROM"
        " analytics.orders orders LEFT JOIN analytics.customers customers ON"
        " orders.customer_id=customers.customer_id WHERE customers.customer_id IN (SELECT DISTINCT"
        " customers_customer_id FROM filter_subquery_0) GROUP BY customers.region ORDER BY"
        " orders_number_of_orders DESC;"
    )
    assert sql == expected


@pytest.mark.query
@pytest.mark.parametrize("filter_type", ["having", "where"])
def test_query_sub_group_by_filter_dimension_group(connection, filter_type):
    # A group_by filter on a dimension group (a date) should be rendered as a
    # DATE_TRUNC comparison inside the filter subquery's WHERE clause.
    date_filter = {
        "field": "orders.order_date",
        "group_by": "customers.customer_id",
        "expression": "less_than",
        "value": "2024-02-03",
    }
    sql = connection.get_sql_query(
        metrics=["number_of_orders"],
        dimensions=["region"],
        **{filter_type: [date_filter]},
    )

    expected = (
        "WITH filter_subquery_0 AS (SELECT customers.customer_id as customers_customer_id FROM"
        " analytics.orders orders LEFT JOIN analytics.customers customers ON"
        " orders.customer_id=customers.customer_id WHERE DATE_TRUNC('DAY', orders.order_date)<'2024-02-03'"
        " GROUP BY customers.customer_id ORDER BY customers_customer_id ASC) SELECT customers.region as"
        " customers_region,COUNT(orders.id) as orders_number_of_orders FROM analytics.orders orders LEFT JOIN"
        " analytics.customers customers ON orders.customer_id=customers.customer_id WHERE"
        " customers.customer_id IN (SELECT DISTINCT customers_customer_id FROM filter_subquery_0) GROUP BY"
        " customers.region ORDER BY orders_number_of_orders DESC;"
    )
    assert sql == expected


@pytest.mark.query
def test_query_sub_group_by_filter_consolidated(connection):
    # Three group_by filters on the same field (customers.customer_id) whose
    # fields share a join path must be consolidated into ONE CTE: the two
    # dimension conditions go to its WHERE, the measure condition to its
    # HAVING. The plain product_name filter stays in the outer query.
    date_condition = {
        "field": "orders.order_date",
        "group_by": "customers.customer_id",
        "expression": "less_than",
        "value": "2024-02-03",
    }
    channel_condition = {
        "field": "channel",
        "group_by": "customers.customer_id",
        "expression": "contains_case_insensitive",
        "value": "social",
    }
    revenue_condition = {
        "field": "total_item_revenue",
        "group_by": "customers.customer_id",
        "expression": "greater_than",
        "value": 1000,
    }
    plain_condition = {"field": "product_name", "expression": "not_equal_to", "value": "Shipping Protection"}

    sql = connection.get_sql_query(
        metrics=["number_of_orders"],
        dimensions=["region"],
        where=[date_condition, channel_condition, revenue_condition, plain_condition],
        having=[{"field": "total_item_revenue", "expression": "less_than", "value": 300_000}],
    )

    expected = (
        "WITH filter_subquery_0 AS (SELECT customers.customer_id as customers_customer_id FROM"
        " analytics.order_line_items order_lines LEFT JOIN analytics.orders orders ON"
        " order_lines.order_unique_id=orders.id LEFT JOIN analytics.customers customers ON"
        " order_lines.customer_id=customers.customer_id WHERE DATE_TRUNC('DAY',"
        " orders.order_date)<'2024-02-03' AND LOWER(order_lines.sales_channel) LIKE LOWER('%social%') GROUP"
        " BY customers.customer_id HAVING SUM(order_lines.revenue)>1000 ORDER BY customers_customer_id ASC)"
        " SELECT customers.region as customers_region,NULLIF(COUNT(DISTINCT CASE WHEN (orders.id) IS NOT"
        " NULL THEN orders.id ELSE NULL END), 0) as orders_number_of_orders FROM analytics.order_line_items"
        " order_lines LEFT JOIN analytics.orders orders ON order_lines.order_unique_id=orders.id LEFT JOIN"
        " analytics.customers customers ON order_lines.customer_id=customers.customer_id WHERE"
        " order_lines.product_name<>'Shipping Protection' AND customers.customer_id IN (SELECT DISTINCT"
        " customers_customer_id FROM filter_subquery_0) GROUP BY customers.region HAVING"
        " SUM(order_lines.revenue)<300000 ORDER BY orders_number_of_orders DESC;"
    )
    assert sql == expected


@pytest.mark.query
def test_query_sub_group_by_filter_consolidated_no_join(connection):
    # Two group_by filters on the same field whose underlying fields share NO
    # join path (sessions vs. order lines) cannot be consolidated, so each
    # must produce its own CTE and its own IN-subquery condition.
    sessions_condition = {
        "field": "number_of_sessions",
        "group_by": "customers.customer_id",
        "expression": "less_than",
        "value": 10_000,
    }
    revenue_condition = {
        "field": "total_item_revenue",
        "group_by": "customers.customer_id",
        "expression": "greater_than",
        "value": 1000,
    }
    plain_condition = {"field": "product_name", "expression": "not_equal_to", "value": "Shipping Protection"}

    sql = connection.get_sql_query(
        metrics=["number_of_orders"],
        dimensions=["region"],
        where=[sessions_condition, revenue_condition, plain_condition],
        having=[{"field": "total_item_revenue", "expression": "less_than", "value": 300_000}],
    )

    expected = (
        "WITH filter_subquery_0 AS (SELECT customers.customer_id as customers_customer_id FROM"
        " analytics.sessions sessions LEFT JOIN analytics.customers customers ON"
        " sessions.customer_id=customers.customer_id GROUP BY customers.customer_id HAVING"
        " COUNT(sessions.id)<10000 ORDER BY customers_customer_id ASC) ,filter_subquery_1 AS (SELECT"
        " customers.customer_id as customers_customer_id FROM analytics.order_line_items order_lines LEFT"
        " JOIN analytics.customers customers ON order_lines.customer_id=customers.customer_id GROUP BY"
        " customers.customer_id HAVING SUM(order_lines.revenue)>1000 ORDER BY customers_customer_id ASC)"
        " SELECT customers.region as customers_region,NULLIF(COUNT(DISTINCT CASE WHEN (orders.id) IS NOT"
        " NULL THEN orders.id ELSE NULL END), 0) as orders_number_of_orders FROM analytics.order_line_items"
        " order_lines LEFT JOIN analytics.orders orders ON order_lines.order_unique_id=orders.id LEFT JOIN"
        " analytics.customers customers ON order_lines.customer_id=customers.customer_id WHERE"
        " order_lines.product_name<>'Shipping Protection' AND customers.customer_id IN (SELECT DISTINCT"
        " customers_customer_id FROM filter_subquery_0) AND customers.customer_id IN (SELECT DISTINCT"
        " customers_customer_id FROM filter_subquery_1) GROUP BY customers.region HAVING"
        " SUM(order_lines.revenue)<300000 ORDER BY orders_number_of_orders DESC;"
    )
    assert sql == expected


@pytest.mark.query
def test_query_sub_group_by_filter_not_consolidated(connection):
    # Group_by filters that target DIFFERENT group_by fields (orders.order_id
    # vs. customers.customer_id) must never be consolidated: each gets its own
    # CTE keyed on its own group_by field.
    channel_condition = {
        "field": "channel",
        "group_by": "orders.order_id",
        "expression": "contains_case_insensitive",
        "value": "social",
    }
    revenue_condition = {
        "field": "total_item_revenue",
        "group_by": "customers.customer_id",
        "expression": "greater_than",
        "value": 1000,
    }
    plain_condition = {"field": "product_name", "expression": "not_equal_to", "value": "Shipping Protection"}

    sql = connection.get_sql_query(
        metrics=["number_of_orders"],
        dimensions=["region"],
        where=[channel_condition, revenue_condition, plain_condition],
        having=[{"field": "total_item_revenue", "expression": "less_than", "value": 300_000}],
    )

    expected = (
        "WITH filter_subquery_0 AS (SELECT customers.customer_id as customers_customer_id FROM"
        " analytics.order_line_items order_lines LEFT JOIN analytics.customers customers ON"
        " order_lines.customer_id=customers.customer_id GROUP BY customers.customer_id HAVING"
        " SUM(order_lines.revenue)>1000 ORDER BY customers_customer_id ASC) ,filter_subquery_1 AS (SELECT"
        " orders.id as orders_order_id FROM analytics.order_line_items order_lines LEFT JOIN analytics.orders"
        " orders ON order_lines.order_unique_id=orders.id WHERE LOWER(order_lines.sales_channel) LIKE"
        " LOWER('%social%') GROUP BY orders.id ORDER BY orders_order_id ASC) SELECT customers.region as"
        " customers_region,NULLIF(COUNT(DISTINCT CASE WHEN (orders.id) IS NOT NULL THEN orders.id ELSE"
        " NULL END), 0) as orders_number_of_orders FROM analytics.order_line_items order_lines LEFT JOIN"
        " analytics.orders orders ON order_lines.order_unique_id=orders.id LEFT JOIN analytics.customers"
        " customers ON order_lines.customer_id=customers.customer_id WHERE"
        " order_lines.product_name<>'Shipping Protection' AND customers.customer_id IN (SELECT DISTINCT"
        " customers_customer_id FROM filter_subquery_0) AND orders.id IN (SELECT DISTINCT orders_order_id"
        " FROM filter_subquery_1) GROUP BY customers.region HAVING SUM(order_lines.revenue)<300000 ORDER BY"
        " orders_number_of_orders DESC;"
    )
    assert sql == expected


@pytest.mark.query
def test_query_sum_when_should_be_number(connection):
with pytest.raises(QueryError) as exc_info:
Expand Down

0 comments on commit 8f309dd

Please sign in to comment.