Skip to content

Commit

Permalink
Add reference link detection to dimensions dag
Browse files — browse the repository at this point in the history
  • Loading branch information
shangyian committed Dec 13, 2024
1 parent 43a1b98 commit 9a04642
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 10 deletions.
6 changes: 6 additions & 0 deletions datajunction-server/datajunction_server/api/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
build_metric_nodes,
get_default_criteria,
rename_columns,
validate_shared_dimensions,
)
from datajunction_server.construction.dj_query import build_dj_query
from datajunction_server.database.attributetype import AttributeType
Expand Down Expand Up @@ -497,6 +498,11 @@ async def validate_cube( # pylint: disable=too-many-locals
message=("Metrics and dimensions must be part of a common catalog"),
)

await validate_shared_dimensions(
session,
metric_nodes,
dimension_names,
)
return metrics, metric_nodes, list(dimension_nodes.values()), dimensions, catalog


Expand Down
27 changes: 21 additions & 6 deletions datajunction-server/datajunction_server/construction/build_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1279,12 +1279,27 @@ async def dimension_join_path(
# Check the reference links on this dimension node
await refresh_if_needed(session, current_link.dimension.current, ["columns"])
for col in current_link.dimension.current.columns:
if (
col.dimension
and f"{col.dimension.name}.{col.dimension_column}" == dimension
):
return join_path

if col.dimension:
if f"{col.dimension.name}.{col.dimension_column}" == dimension:
return join_path
await refresh_if_needed(session, col.dimension, ["current"])
await refresh_if_needed(
session,
col.dimension.current,
["dimension_links"],
)
for link in col.dimension.current.dimension_links:
if (
f"{col.dimension.name}.{col.dimension_column}"
in link.foreign_keys
):
if (
link.foreign_keys[
f"{col.dimension.name}.{col.dimension_column}"
]
== dimension
):
return join_path
await refresh_if_needed(
session,
current_link.dimension.current,
Expand Down
25 changes: 24 additions & 1 deletion datajunction-server/datajunction_server/database/queryrequest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
from datajunction_server.database.node import Node, NodeRevision
from datajunction_server.enum import StrEnum
from datajunction_server.errors import DJInvalidInputException
from datajunction_server.sql.dag import get_upstream_nodes
from datajunction_server.sql.dag import (
get_dimensions,
get_shared_dimensions,
get_upstream_nodes,
)
from datajunction_server.sql.parsing import ast
from datajunction_server.sql.parsing.backends.antlr4 import parse
from datajunction_server.typing import UTCDatetime
Expand Down Expand Up @@ -304,6 +308,25 @@ async def to_versioned_query_request( # pylint: disable=too-many-locals
message="At least one metric is required",
http_status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
)
node_columns = []
if len(nodes_objs) == 1:
node_columns = [col.name for col in nodes_objs[0].current.columns] # type: ignore
available_dimensions = {
dim.name
for dim in (
await get_dimensions(session, nodes_objs[0]) # type: ignore
if len(nodes_objs) == 1
else await get_shared_dimensions(session, nodes_objs) # type: ignore
)
}.union(set(node_columns))
invalid_dimensions = sorted(
list(set(dimensions).difference(available_dimensions)),
)
if dimensions and invalid_dimensions:
raise DJInvalidInputException(
f"{', '.join(invalid_dimensions)} are not available "
f"dimensions on {', '.join(nodes)}",
)

dimension_nodes = [
await Node.get_by_name(session, ".".join(dim.split(".")[:-1]), options=[])
Expand Down
6 changes: 3 additions & 3 deletions datajunction-server/datajunction_server/sql/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,13 +352,13 @@ async def get_dimensions_dag( # pylint: disable=too-many-locals
)
.join(
graph_branches,
(current_rev.id == graph_branches.c.node_revision_id)
& (is_(graph_branches.c.dimension_column, None)),
(current_rev.id == graph_branches.c.node_revision_id),
# & (is_(graph_branches.c.dimension_column, None)),
)
.join(
next_node,
(next_node.id == graph_branches.c.dimension_id)
& (is_(graph_branches.c.dimension_column, None))
# & (is_(graph_branches.c.dimension_column, None))
& (is_(next_node.deactivated_at, None)),
)
.join(
Expand Down

0 comments on commit 9a04642

Please sign in to comment.