Consolidate to Single Cube Materialization Option #1304

Merged · 23 commits · Feb 10, 2025
11 changes: 10 additions & 1 deletion datajunction-clients/python/tests/conftest.py
@@ -26,6 +26,7 @@
get_session,
get_settings,
)
from fastapi import Request
from httpx import AsyncClient
from pytest_mock import MockerFixture
from sqlalchemy import create_engine
@@ -273,12 +274,15 @@ def module__server( # pylint: disable=too-many-statements
module__session: AsyncSession,
module__settings: Settings,
module__query_service_client: QueryServiceClient,
module_mocker,
) -> Iterator[TestClient]:
"""
Create a mock server for testing APIs that contains a mock query service.
"""

def get_query_service_client_override() -> QueryServiceClient:
def get_query_service_client_override(
request: Request = None, # pylint: disable=unused-argument
) -> QueryServiceClient:
return module__query_service_client

async def get_session_override() -> AsyncSession:
@@ -287,6 +291,11 @@ async def get_session_override() -> AsyncSession:
def get_settings_override() -> Settings:
return module__settings

module_mocker.patch(
"datajunction_server.api.materializations.get_query_service_client",
get_query_service_client_override,
)

app.dependency_overrides[get_session] = get_session_override
app.dependency_overrides[get_settings] = get_settings_override
app.dependency_overrides[
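Note: `app.dependency_overrides` only applies where FastAPI resolves `Depends(...)`; a direct call such as `get_query_service_client(request)` inside a route body bypasses it, which appears to be why this fixture also patches the symbol in `datajunction_server.api.materializations` via `module_mocker.patch`. A minimal sketch of the two lookup paths (the toy app, routes, and test below are hypothetical and not part of this PR):

```python
# Minimal sketch (hypothetical app): Depends() respects dependency_overrides,
# but a direct function call inside a route only sees a monkeypatched symbol.
from fastapi import Depends, FastAPI, Request
from fastapi.testclient import TestClient

app = FastAPI()


def get_query_service_client(request: Request = None) -> str:
    return "real-client"


@app.get("/via-depends")
def via_depends(client: str = Depends(get_query_service_client)) -> dict:
    return {"client": client}  # overridden through app.dependency_overrides


@app.get("/direct-call")
def direct_call(request: Request) -> dict:
    return {"client": get_query_service_client(request)}  # bypasses overrides


def test_overrides(mocker) -> None:  # pytest-mock's `mocker` fixture
    app.dependency_overrides[get_query_service_client] = lambda: "fake-client"
    mocker.patch(f"{__name__}.get_query_service_client", return_value="fake-client")
    client = TestClient(app)
    assert client.get("/via-depends").json() == {"client": "fake-client"}
    assert client.get("/direct-call").json() == {"client": "fake-client"}
```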
100 changes: 100 additions & 0 deletions datajunction-server/datajunction_server/api/cubes.py
@@ -12,15 +12,26 @@
from datajunction_server.construction.dimensions import build_dimensions_from_cube_query
from datajunction_server.database.node import Node
from datajunction_server.database.user import User
from datajunction_server.errors import DJInvalidInputException
from datajunction_server.internal.access.authentication.http import SecureAPIRouter
from datajunction_server.internal.access.authorization import validate_access
from datajunction_server.internal.materializations import build_cube_materialization
from datajunction_server.internal.nodes import get_cube_revision_metadata
from datajunction_server.models import access
from datajunction_server.models.cube import (
CubeRevisionMetadata,
DimensionValue,
DimensionValues,
)
from datajunction_server.models.cube_materialization import (
DruidCubeMaterializationInput,
UpsertCubeMaterialization,
)
from datajunction_server.models.materialization import (
Granularity,
MaterializationJobTypeEnum,
MaterializationStrategy,
)
from datajunction_server.models.metric import TranslatedSQL
from datajunction_server.models.query import QueryCreate
from datajunction_server.naming import from_amenable_name
@@ -47,6 +58,95 @@ async def get_cube(
return await get_cube_revision_metadata(session, name)


@router.get("/cubes/{name}/materialization", name="Cube Materialization Config")
async def cube_materialization_info(
name: str,
session: AsyncSession = Depends(get_session),
) -> DruidCubeMaterializationInput:
"""
The standard cube materialization config. DJ makes sensible materialization choices
where possible.

Requirements:
- The cube must have a temporal partition column specified.
- The job strategy will always be "incremental time".
Review comment (Member): Why always? I think we should provide a "full replacement" as an option.

Reply (Contributor Author): Oh yeah, I definitely agree that we should provide it as an option, but for this first cut I didn't want to account for supporting both "full" and "incremental". It was easier to just raise an error until we're ready to implement "full".

Outputs:
"measures_materializations":
We group the metrics by parent node. Then we try to pre-aggregate each parent node as
much as possible to prepare for metric queries on the cube's dimensions.
"combiners":
Review comment (Member): This is a great feature, but I wonder how often it will be used.

Reply (Contributor Author): Hard to say, but to be honest I came across this problem almost immediately, so I think it's more often than we think.

We combine each set of measures materializations on their shared grain. Note that we don't
support materializing cubes with measures materializations that don't share the same grain.
However, we keep `combiners` as a list for an eventual future where we do support that.
"metrics":
We include a list of metrics, their required measures, and the derived expression (i.e., the
metric's expression rewritten to use the pre-aggregated measures).

Once we create a scheduled materialization workflow, we freeze the metadata for that particular
materialized dataset. This allows us to reconstruct metrics SQL from the dataset when needed.
To request metrics from the materialized cube, use the metrics' measures metadata.
"""
node = await Node.get_cube_by_name(session, name)
temporal_partitions = node.current.temporal_partition_columns() # type: ignore
if len(temporal_partitions) != 1:
raise DJInvalidInputException(
"The cube must have a single temporal partition column set "
"in order for it to be materialized.",
)
temporal_partition = temporal_partitions[0] if temporal_partitions else None
Review comment (Member): nit: Add a comment that if more than 2 temporal partitions are defined we pick a random one.

Reply (Contributor Author): I actually changed the earlier part that checks for temporal partitions to be if len(temporal_partitions) != 1 -- then we can catch and raise for cases where there are multiple temporal partitions. The only use case I can see for multiple temporal partitions is if there's a "date int" and an "hour" partition, but we'll need more metadata to support that anyway, so better to raise I think.

granularity_lookback_defaults = {
Granularity.MINUTE: "1 MINUTE",
Granularity.HOUR: "1 HOUR",
Granularity.DAY: "1 DAY",
Granularity.WEEK: "1 WEEK",
Granularity.MONTH: "1 MONTH",
Granularity.QUARTER: "1 QUARTER",
Granularity.YEAR: "1 YEAR",
}
granularity_cron_defaults = {
Granularity.MINUTE: "* * * * *", # Runs every minute
Granularity.HOUR: "0 * * * *", # Runs at the start of every hour
Granularity.DAY: "0 0 * * *", # Runs at midnight every day
Granularity.WEEK: "0 0 * * 0", # Runs at midnight on Sundays
Granularity.MONTH: "0 0 1 * *", # Runs at midnight on the first of every month
Granularity.QUARTER: "0 0 1 */3 *", # Runs at midnight on the first day of each quarter
Granularity.YEAR: "0 0 1 1 *", # Runs at midnight on January 1st every year
}
upsert = UpsertCubeMaterialization(
job=MaterializationJobTypeEnum.DRUID_CUBE,
strategy=(
MaterializationStrategy.INCREMENTAL_TIME
if temporal_partition
else MaterializationStrategy.FULL
Review comment (Member): We may never reach this code path... unless we convert this to support full refresh.

Reply (Contributor Author): Yeah, good point. We can do that in a follow-up PR I think.

),
lookback_window=granularity_lookback_defaults.get(
temporal_partition.partition.granularity,
granularity_lookback_defaults[Granularity.DAY],
),
schedule=granularity_cron_defaults.get(
temporal_partition.partition.granularity,
granularity_cron_defaults[Granularity.DAY],
),
)
cube_config = await build_cube_materialization(
session,
node.current, # type: ignore
upsert,
)
return DruidCubeMaterializationInput(
name="",
cube=cube_config.cube,
dimensions=cube_config.dimensions,
metrics=cube_config.metrics,
strategy=upsert.strategy,
schedule=upsert.schedule,
job=upsert.job.name,
measures_materializations=cube_config.measures_materializations,
combiners=cube_config.combiners,
)


@router.get("/cubes/{name}/dimensions/sql", name="Dimensions SQL for Cube")
async def get_cube_dimension_sql(
name: str,
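A rough usage sketch of the new endpoint (the base URL and cube name are placeholders, and authentication headers are omitted):

```python
# Rough usage sketch for GET /cubes/{name}/materialization; the base URL and
# cube name are placeholders, and auth headers are omitted.
import httpx

BASE_URL = "http://localhost:8000"  # assumed local DJ server


def fetch_cube_materialization_config(cube_name: str) -> dict:
    """Fetch the default Druid cube materialization config that DJ derives."""
    response = httpx.get(f"{BASE_URL}/cubes/{cube_name}/materialization")
    response.raise_for_status()
    config = response.json()
    # Per the endpoint docstring, expect "measures_materializations", "combiners",
    # and "metrics", alongside "strategy", "schedule", and "job".
    return config


if __name__ == "__main__":
    print(fetch_cube_materialization_config("default.example_cube"))
```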
4 changes: 4 additions & 0 deletions datajunction-server/datajunction_server/api/data.py
@@ -2,6 +2,7 @@
"""
Data related APIs.
"""
import logging
from typing import Callable, Dict, List, Optional

from fastapi import BackgroundTasks, Depends, Query, Request
@@ -42,6 +43,8 @@
get_settings,
)

_logger = logging.getLogger(__name__)

settings = get_settings()
router = SecureAPIRouter(tags=["data"])

@@ -61,6 +64,7 @@ async def add_availability_state(
"""
Add an availability state to a node.
"""
_logger.info("Storing availability for node=%s", node_name)

node = await Node.get_by_name(
session,
@@ -4,10 +4,14 @@

import strawberry

from datajunction_server.models.cube_materialization import (
Aggregability as Aggregability_,
)
from datajunction_server.models.cube_materialization import (
AggregationRule as AggregationRule_,
)
from datajunction_server.models.cube_materialization import Measure as Measure_
from datajunction_server.models.node import MetricDirection as MetricDirection_
from datajunction_server.sql.decompose import Aggregability as Aggregability_
from datajunction_server.sql.decompose import AggregationRule as AggregationRule_
from datajunction_server.sql.decompose import Measure as Measure_

MetricDirection = strawberry.enum(MetricDirection_)
Aggregability = strawberry.enum(Aggregability_)
2 changes: 2 additions & 0 deletions datajunction-server/datajunction_server/api/helpers.py
@@ -366,6 +366,8 @@ async def validate_cube( # pylint: disable=too-many-locals

# Verify that the provided metrics are metric nodes
metrics: List[Column] = [metric.current.columns[0] for metric in metric_nodes]
for metric in metrics:
await session.refresh(metric, ["node_revisions"])
if not metrics:
raise DJInvalidInputException(
message=("At least one metric is required"),
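The explicit refresh is presumably needed because async SQLAlchemy sessions cannot lazily load relationships on attribute access; a generic sketch of the pattern (the argument below stands in for DJ's real column objects):

```python
# Generic sketch of eager relationship loading on an AsyncSession; `column`
# stands in for the real DJ Column objects refreshed in validate_cube above.
from sqlalchemy.ext.asyncio import AsyncSession


async def ensure_node_revisions_loaded(session: AsyncSession, column) -> None:
    # Touching `column.node_revisions` without this refresh would trigger a lazy
    # load, which raises (MissingGreenlet) under asyncio instead of querying.
    await session.refresh(column, ["node_revisions"])
    print(f"{column.name} is attached to {len(column.node_revisions)} revisions")
```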
43 changes: 37 additions & 6 deletions datajunction-server/datajunction_server/api/materializations.py
@@ -27,6 +27,7 @@
from datajunction_server.materialization.jobs import MaterializationJob
from datajunction_server.models import access
from datajunction_server.models.base import labelize
from datajunction_server.models.cube_materialization import UpsertCubeMaterialization
from datajunction_server.models.materialization import (
MaterializationConfigInfoUnified,
MaterializationConfigOutput,
@@ -80,7 +81,7 @@ def materialization_jobs_info() -> JSONResponse:
)
async def upsert_materialization( # pylint: disable=too-many-locals
node_name: str,
data: UpsertMaterialization,
data: UpsertMaterialization | UpsertCubeMaterialization,
*,
session: AsyncSession = Depends(get_session),
request: Request,
@@ -95,14 +96,19 @@ async def upsert_materialization( # pylint: disable=too-many-locals
for the materialization config, it will always update that named config.
"""
request_headers = dict(request.headers)
node = await Node.get_by_name(session, node_name)
node = await Node.get_by_name(session, node_name, raise_if_not_exists=True)
if node.type == NodeType.SOURCE: # type: ignore
raise DJInvalidInputException(
http_status_code=HTTPStatus.BAD_REQUEST,
message=f"Cannot set materialization config for source node `{node_name}`!",
)
if node.type == NodeType.CUBE: # type: ignore
node = await Node.get_cube_by_name(session, node_name)
_logger.info(
"Upserting materialization for node=%s version=%s",
node.name, # type: ignore
node.current_version, # type: ignore
)

current_revision = node.current # type: ignore
old_materializations = {mat.name: mat for mat in current_revision.materializations}
@@ -131,6 +137,11 @@
existing_materialization
and existing_materialization.config == new_materialization.config
):
_logger.info(
"Existing materialization found for node=%s version=%s",
node.name, # type: ignore
node.current_version, # type: ignore
)
new_materialization.node_revision = None # type: ignore
# if the materialization was deactivated before, restore it
if existing_materialization.deactivated_at is not None:
@@ -151,15 +162,20 @@ async def upsert_materialization( # pylint: disable=too-many-locals
existing_materialization_info = query_service_client.get_materialization_info(
node_name,
current_revision.version, # type: ignore
current_revision.type,
new_materialization.name, # type: ignore
request_headers=request_headers,
)
# refresh existing materialization job
_logger.info(
"Refresh materialization workflows for node=%s version=%s",
node.name, # type: ignore
node.current_version, # type: ignore
)
await schedule_materialization_jobs(
session,
node_revision_id=current_revision.id,
materialization_names=[new_materialization.name],
query_service_client=query_service_client,
query_service_client=get_query_service_client(request), # type: ignore
request_headers=request_headers,
)
return JSONResponse(
Expand All @@ -178,12 +194,22 @@ async def upsert_materialization( # pylint: disable=too-many-locals
)
# If changes are detected, update the existing or save the new materialization
if existing_materialization:
_logger.info(
"Updating existing materialization for node=%s version=%s",
node.name, # type: ignore
node.current_version, # type: ignore
)
existing_materialization.config = new_materialization.config
existing_materialization.schedule = new_materialization.schedule
new_materialization.node_revision = None # type: ignore
new_materialization = existing_materialization
new_materialization.deactivated_at = None
else:
_logger.info(
"Adding new materialization for node=%s version=%s",
node.name, # type: ignore
node.current_version, # type: ignore
)
unchanged_existing_materializations = [
config
for config in current_revision.materializations
@@ -215,12 +241,16 @@ async def upsert_materialization( # pylint: disable=too-many-locals
),
)
await session.commit()

_logger.info(
"Scheduling materialization workflows for node=%s version=%s",
node.name, # type: ignore
node.current_version, # type: ignore
)
materialization_response = await schedule_materialization_jobs(
session,
node_revision_id=current_revision.id,
materialization_names=[new_materialization.name],
query_service_client=query_service_client,
query_service_client=get_query_service_client(request), # type: ignore
request_headers=request_headers,
)
return JSONResponse(
@@ -260,6 +290,7 @@ async def list_node_materializations(
info = query_service_client.get_materialization_info(
node_name,
node.current.version, # type: ignore
node.type, # type: ignore
materialization.name, # type: ignore
request_headers=request_headers,
)
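For cube nodes, the widened `data: UpsertMaterialization | UpsertCubeMaterialization` parameter accepts a payload built from the same models used in `cubes.py` above; a sketch mirroring the daily defaults (the serialization call and the target route path are assumptions, not shown in this diff):

```python
# Sketch of building an UpsertCubeMaterialization payload for a cube node,
# mirroring the daily defaults computed in cube_materialization_info above.
from datajunction_server.models.cube_materialization import UpsertCubeMaterialization
from datajunction_server.models.materialization import (
    MaterializationJobTypeEnum,
    MaterializationStrategy,
)

upsert = UpsertCubeMaterialization(
    job=MaterializationJobTypeEnum.DRUID_CUBE,
    strategy=MaterializationStrategy.INCREMENTAL_TIME,
    schedule="0 0 * * *",      # daily, matching the DAY cron default
    lookback_window="1 DAY",   # daily, matching the DAY lookback default
)
# The serialized form is what a client would send to the upsert materialization
# route for a cube node (exact route path not shown in this diff).
print(upsert.dict())  # assumes a pydantic v1-style .dict() on the model
```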
@@ -14,12 +14,13 @@
from datajunction_server.errors import DJError, DJInvalidInputException, ErrorCode
from datajunction_server.internal.engines import get_engine
from datajunction_server.models import access
from datajunction_server.models.cube_materialization import Measure
from datajunction_server.models.engine import Dialect
from datajunction_server.models.materialization import GenericCubeConfig
from datajunction_server.models.node import BuildCriteria
from datajunction_server.naming import LOOKUP_CHARS, amenable_name, from_amenable_name
from datajunction_server.sql.dag import get_shared_dimensions
from datajunction_server.sql.decompose import Measure, MeasureExtractor
from datajunction_server.sql.decompose import MeasureExtractor
from datajunction_server.sql.parsing.backends.antlr4 import ast, parse
from datajunction_server.sql.parsing.types import ColumnType
from datajunction_server.utils import SEPARATOR
11 changes: 10 additions & 1 deletion datajunction-server/datajunction_server/construction/build_v2.py
@@ -23,12 +23,12 @@
from datajunction_server.internal.engines import get_engine
from datajunction_server.models import access
from datajunction_server.models.column import SemanticType
from datajunction_server.models.cube_materialization import Aggregability, Measure
from datajunction_server.models.engine import Dialect
from datajunction_server.models.node import BuildCriteria
from datajunction_server.models.node_type import NodeType
from datajunction_server.models.sql import GeneratedSQL
from datajunction_server.naming import amenable_name, from_amenable_name
from datajunction_server.sql.decompose import Aggregability, Measure
from datajunction_server.sql.parsing.ast import CompileContext
from datajunction_server.sql.parsing.backends.antlr4 import ast, cached_parse, parse
from datajunction_server.utils import SEPARATOR, refresh_if_needed
@@ -262,6 +262,15 @@ async def get_measures_query( # pylint: disable=too-many-locals
else [pk_col.name for pk_col in parent_node.current.primary_key()]
),
errors=query_builder.errors,
metrics={
metric.name: (
metrics2measures[metric.name][0],
str(metrics2measures[metric.name][1]).replace("\n", "")
if preaggregate
else metric.query,
)
for metric in children
},
),
)
return measures_queries
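To make the new `metrics` field concrete: it maps each metric name to a pair of (required measures, expression), where the expression is the measure-based derived form when `preaggregate` is set and the metric's original query otherwise. A toy illustration of the shape (the metric name, measures, and expression are invented, not taken from this PR):

```python
# Toy illustration of the shape of the `metrics` field added to the measures
# query output; the metric name, measures, and expression are invented.
from typing import Any, Dict, List, Tuple

measures_for_metric: List[Any] = []  # would hold the Measure objects the metric needs
metrics: Dict[str, Tuple[List[Any], str]] = {
    "default.num_repair_orders": (
        measures_for_metric,
        "SUM(repair_order_id_count)",  # derived expression when preaggregate=True;
                                       # otherwise the metric's original query string
    ),
}
```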