Skip to content

Commit

Permalink
Add test and fix issues with handling reference dim links
Browse files Browse the repository at this point in the history
  • Loading branch information
shangyian committed Dec 10, 2024
1 parent aac97e0 commit 09f1686
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 6 deletions.
1 change: 0 additions & 1 deletion datajunction-server/datajunction_server/api/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
build_metric_nodes,
get_default_criteria,
rename_columns,
validate_shared_dimensions,
)
from datajunction_server.construction.dj_query import build_dj_query
from datajunction_server.database.attributetype import AttributeType
Expand Down
1 change: 1 addition & 0 deletions datajunction-server/datajunction_server/api/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,7 @@ async def add_reference_dimension_link(
),
)
await session.commit()
await session.refresh(target_column)
return JSONResponse(
status_code=201,
content={
Expand Down
12 changes: 12 additions & 0 deletions datajunction-server/datajunction_server/construction/build_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,7 +1349,19 @@ def build_dimension_attribute(
if dimension_attr.name in link.foreign_keys_reversed
else None
)
reference_links = {
col.name: f"{col.dimension.name}.{col.dimension_column}"
for col in link.dimension.current.columns
if col.dimension
}
for col in node_query.select.projection:
if reference_links.get(col.alias_or_name.name) == full_column_name: # type: ignore
return ast.Column(
name=ast.Name(col.alias_or_name.name), # type: ignore
alias=ast.Name(alias) if alias else None,
_table=node_query,
_type=col.type, # type: ignore
)
if col.alias_or_name.name == dimension_attr.column_name or ( # type: ignore
foreign_key_column_name
and col.alias_or_name.identifier() == foreign_key_column_name # type: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,7 @@
from datajunction_server.database.node import Node, NodeRevision
from datajunction_server.enum import StrEnum
from datajunction_server.errors import DJInvalidInputException
from datajunction_server.sql.dag import (
get_dimensions,
get_shared_dimensions,
get_upstream_nodes,
)
from datajunction_server.sql.dag import get_upstream_nodes
from datajunction_server.sql.parsing import ast
from datajunction_server.sql.parsing.backends.antlr4 import parse
from datajunction_server.typing import UTCDatetime
Expand Down
74 changes: 74 additions & 0 deletions datajunction-server/tests/api/dimension_links_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,27 @@ async def _link_events_to_users_without_role() -> Response:
return _link_events_to_users_without_role


@pytest.fixture
def reference_link_users_date(
dimensions_link_client: AsyncClient, # pylint: disable=redefined-outer-name
):
"""
Create a reference link between users and date
"""

async def _reference_link_users_date() -> Response:
response = await dimensions_link_client.post(
"/nodes/default.users/columns/snapshot_date/link",
params={
"dimension_node": "default.date",
"dimension_column": "dateint",
},
)
return response

return _reference_link_users_date


@pytest.fixture
def link_events_to_users_with_role_direct(
dimensions_link_client: AsyncClient, # pylint: disable=redefined-outer-name
Expand Down Expand Up @@ -964,6 +985,59 @@ async def test_measures_sql_with_reference_dimension_links(
assert response_data[0]["errors"] == []


@pytest.mark.asyncio
async def test_measures_sql_with_ref_link_on_dim_node(
dimensions_link_client: AsyncClient, # pylint: disable=redefined-outer-name
link_events_to_users_without_role, # pylint: disable=redefined-outer-name
reference_link_users_date, # pylint: disable=redefined-outer-name
):
"""
Verify that measures SQL can be retrieved for dimension attributes that come from a
reference dimension link from one dim node to another dim node.
"""
await link_events_to_users_without_role()
await reference_link_users_date()

response = await dimensions_link_client.get(
"/sql/measures/v2",
params={
"metrics": ["default.elapsed_secs"],
"dimensions": [
"default.date.dateint",
],
},
)
response_data = response.json()
expected_sql = """
WITH default_DOT_events AS (
SELECT
default_DOT_events_table.user_id,
default_DOT_events_table.event_start_date,
default_DOT_events_table.event_end_date,
default_DOT_events_table.elapsed_secs,
default_DOT_events_table.user_registration_country
FROM examples.events AS default_DOT_events_table
),
default_DOT_users AS (
SELECT
default_DOT_users_table.user_id,
default_DOT_users_table.snapshot_date,
default_DOT_users_table.registration_country,
default_DOT_users_table.residence_country,
default_DOT_users_table.account_type
FROM examples.users AS default_DOT_users_table
)
SELECT
default_DOT_events.elapsed_secs default_DOT_events_DOT_elapsed_secs,
default_DOT_users.snapshot_date default_DOT_date_DOT_dateint
FROM default_DOT_events
LEFT JOIN default_DOT_users
ON default_DOT_events.user_id = default_DOT_users.user_id
AND default_DOT_events.event_start_date = default_DOT_users.snapshot_date
"""
assert str(parse(response_data[0]["sql"])) == str(parse(expected_sql))


@pytest.mark.asyncio
async def test_dimension_link_cross_join(
dimensions_link_client: AsyncClient, # pylint: disable=redefined-outer-name
Expand Down
10 changes: 10 additions & 0 deletions datajunction-server/tests/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2258,6 +2258,16 @@
"primary_key": ["country_code"],
},
),
(
"/nodes/dimension/",
{
"description": "Date dimension",
"query": """SELECT 1 AS dateint""",
"mode": "published",
"name": "default.date",
"primary_key": ["dateint"],
},
),
(
"/nodes/metric/",
{
Expand Down

0 comments on commit 09f1686

Please sign in to comment.