Skip to content

Commit

Permalink
[dagster-airlift][sensor] make sensor implementation pluggable at top…
Browse files Browse the repository at this point in the history
… level (#25052)

## Summary & Motivation
Introduce a pluggability point for the sensor. Only bit that has changed
from offline discussion is that I added a context argument so that users
can perform their own logging.

## How I Tested These Changes
Added a new test for pluggable implementation.
Additional testing I would like to do:
- ensure that an error in the pluggable function doesn't advance the cursor
- ensure that we handle errors in the pluggable function well.
- ensure that we gracefully error when effective timestamp metadata is
missing.

## Changelog
NOCHANGELOG
  • Loading branch information
dpeng817 authored Oct 4, 2024
1 parent f510688 commit 9bd94e9
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS,
build_airflow_polling_sensor_defs,
)
from dagster_airlift.core.sensor.event_translation import DagsterEventTransformerFn
from dagster_airlift.core.serialization.compute import compute_serialized_data
from dagster_airlift.core.serialization.defs_construction import (
construct_automapped_dag_assets_defs,
Expand Down Expand Up @@ -57,6 +58,7 @@ def build_defs_from_airflow_instance(
airflow_instance: AirflowInstance,
defs: Optional[Definitions] = None,
sensor_minimum_interval_seconds: int = DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS,
event_transformer_fn: Optional[DagsterEventTransformerFn] = None,
) -> Definitions:
resolved_defs = AirflowInstanceDefsLoader(
airflow_instance=airflow_instance,
Expand All @@ -70,7 +72,7 @@ def build_defs_from_airflow_instance(
airflow_instance=airflow_instance, resolved_airflow_defs=resolved_defs
),
minimum_interval_seconds=sensor_minimum_interval_seconds,
event_translation_fn=None,
event_transformer_fn=event_transformer_fn,
),
)

Expand Down Expand Up @@ -117,7 +119,7 @@ def build_full_automapped_dags_from_airflow_instance(
airflow_data=AirflowDefinitionsData(
resolved_airflow_defs=resolved_defs, airflow_instance=airflow_instance
),
event_translation_fn=None,
event_transformer_fn=None,
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,18 @@
_check as check,
sensor,
)
from dagster._core.definitions.asset_check_evaluation import AssetCheckEvaluation
from dagster._core.definitions.asset_selection import AssetSelection
from dagster._core.definitions.definitions_class import Definitions
from dagster._core.definitions.events import AssetObservation
from dagster._core.definitions.repository_definition.repository_definition import (
RepositoryDefinition,
)
from dagster._core.errors import (
DagsterInvariantViolationError,
DagsterUserCodeExecutionError,
user_code_error_boundary,
)
from dagster._core.storage.dagster_run import RunsFilter
from dagster._grpc.client import DEFAULT_SENSOR_GRPC_TIMEOUT
from dagster._record import record
Expand All @@ -27,12 +34,14 @@
from dagster_airlift.constants import (
AUTOMAPPED_TASK_METADATA_KEY,
DAG_RUN_ID_TAG_KEY,
EFFECTIVE_TIMESTAMP_METADATA_KEY,
TASK_ID_TAG_KEY,
)
from dagster_airlift.core.airflow_defs_data import AirflowDefinitionsData
from dagster_airlift.core.airflow_instance import AirflowInstance, DagRun, TaskInstance
from dagster_airlift.core.sensor.event_translation import (
AirflowEventTranslationFn,
AssetEvent,
DagsterEventTransformerFn,
get_timestamp_from_materialization,
materializations_for_dag_run,
synthetic_mats_for_mapped_asset_keys,
Expand All @@ -54,6 +63,10 @@ class AirflowPollingSensorCursor:
dag_query_offset: Optional[int] = None


class AirliftSensorEventTransformerError(DagsterUserCodeExecutionError):
    """Error raised when an error occurs in the event transformer function.

    Used as the error class for the ``user_code_error_boundary`` wrapping the
    user-provided event transformer, so that failures inside the pluggable
    function are attributed to user code rather than the framework.
    """


def check_keys_for_asset_keys(
repository_def: RepositoryDefinition, asset_keys: Set[AssetKey]
) -> Iterable[AssetCheckKey]:
Expand All @@ -65,7 +78,7 @@ def check_keys_for_asset_keys(

def build_airflow_polling_sensor_defs(
airflow_data: AirflowDefinitionsData,
event_translation_fn: Optional[AirflowEventTranslationFn],
event_transformer_fn: Optional[DagsterEventTransformerFn],
minimum_interval_seconds: int = DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS,
) -> Definitions:
@sensor(
Expand Down Expand Up @@ -102,7 +115,6 @@ def airflow_dag_sensor(context: SensorEvaluationContext) -> SensorResult:
end_date_lte=end_date_lte,
offset=current_dag_offset,
airflow_data=airflow_data,
event_translation_fn=event_translation_fn,
)
all_asset_events: List[AssetMaterialization] = []
all_check_keys: Set[AssetCheckKey] = set()
Expand Down Expand Up @@ -132,13 +144,20 @@ def airflow_dag_sensor(context: SensorEvaluationContext) -> SensorResult:
end_date_lte=None,
dag_query_offset=0,
)
updated_asset_events = _get_transformer_result(
event_transformer_fn=event_transformer_fn,
context=context,
airflow_data=airflow_data,
all_asset_events=all_asset_events,
)

context.update_cursor(serialize_value(new_cursor))

context.log.info(
f"************Exitting sensor for {airflow_data.airflow_instance.name}***********"
f"************Exiting sensor for {airflow_data.airflow_instance.name}***********"
)
return SensorResult(
asset_events=sorted_asset_events(all_asset_events, repository_def),
asset_events=sorted_asset_events(updated_asset_events, repository_def),
run_requests=[RunRequest(asset_check_keys=list(all_check_keys))]
if all_check_keys
else None,
Expand All @@ -148,22 +167,51 @@ def airflow_dag_sensor(context: SensorEvaluationContext) -> SensorResult:


def sorted_asset_events(
    asset_events: Sequence[AssetEvent],
    repository_def: RepositoryDefinition,
) -> List[AssetEvent]:
    """Order asset events by effective timestamp, breaking ties by topological order.

    Events with equal timestamps are ordered by their asset key's position in
    the repository asset graph's toposorted key list, so upstream assets are
    emitted before their downstreams.
    """
    topo_order = repository_def.asset_graph.toposorted_asset_keys
    # Sort key: (effective timestamp, position of the asset key in topo order).
    return sorted(
        asset_events,
        key=lambda event: (
            get_timestamp_from_materialization(event),
            topo_order.index(event.asset_key),
        ),
    )


def _get_transformer_result(
    event_transformer_fn: Optional[DagsterEventTransformerFn],
    context: SensorEvaluationContext,
    airflow_data: AirflowDefinitionsData,
    all_asset_events: Sequence[AssetMaterialization],
) -> Sequence[AssetEvent]:
    """Apply the user-supplied event transformer, if any, and validate its output.

    Returns the input events untouched when no transformer is configured.

    Raises:
        AirliftSensorEventTransformerError: if the transformer itself raises.
        DagsterInvariantViolationError: if the transformer returns an object
            that is not an AssetMaterialization/AssetObservation/
            AssetCheckEvaluation, or an event missing effective-timestamp
            metadata.
    """
    if not event_transformer_fn:
        return all_asset_events

    # Run the user function inside an error boundary so any exception is
    # surfaced as user-code error rather than a framework failure.
    describe_error = (
        lambda: f"Error occurred during event transformation for {airflow_data.airflow_instance.name}"
    )
    with user_code_error_boundary(AirliftSensorEventTransformerError, describe_error):
        transformed_events = list(event_transformer_fn(context, airflow_data, all_asset_events))

    permitted_types = (AssetMaterialization, AssetObservation, AssetCheckEvaluation)
    for event in transformed_events:
        if not isinstance(event, permitted_types):
            raise DagsterInvariantViolationError(
                f"Event transformer function must return AssetMaterialization, AssetObservation, or AssetCheckEvaluation objects. Got {type(event)}."
            )
        # Every event must carry an effective timestamp so the sensor can
        # order events chronologically downstream.
        if EFFECTIVE_TIMESTAMP_METADATA_KEY not in event.metadata:
            raise DagsterInvariantViolationError(
                f"All returned events must have an effective timestamp, but {event} does not. An effective timestamp can be used by setting dagster_airlift.constants.EFFECTIVE_TIMESTAMP_METADATA_KEY with a dagster.TimestampMetadataValue."
            )
    return transformed_events


@record
class BatchResult:
idx: int
Expand All @@ -177,7 +225,6 @@ def materializations_and_requests_from_batch_iter(
end_date_lte: float,
offset: int,
airflow_data: AirflowDefinitionsData,
event_translation_fn: Optional[AirflowEventTranslationFn],
) -> Iterator[Optional[BatchResult]]:
runs = airflow_data.airflow_instance.get_dag_runs_batch(
dag_ids=list(airflow_data.all_dag_ids),
Expand All @@ -188,7 +235,7 @@ def materializations_and_requests_from_batch_iter(
context.log.info(f"Found {len(runs)} dag runs for {airflow_data.airflow_instance.name}")
context.log.info(f"All runs {runs}")
for i, dag_run in enumerate(runs):
# TODO: add pluggability here (ignoring `event_translation_fn` for now)
# TODO: add pluggability here (ignoring `event_transformer_fn` for now)

dag_mats = materializations_for_dag_run(dag_run, airflow_data)
synthetic_mats = build_synthetic_asset_materializations(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
from typing import AbstractSet, Any, Callable, Iterable, Mapping, Sequence
from typing import AbstractSet, Any, Callable, Iterable, Mapping, Sequence, Union

from dagster import (
AssetMaterialization,
AssetObservation,
JsonMetadataValue,
MarkdownMetadataValue,
SensorEvaluationContext,
TimestampMetadataValue,
_check as check,
)
from dagster._core.definitions.asset_check_evaluation import AssetCheckEvaluation
from dagster._core.definitions.asset_key import AssetKey
from dagster._time import get_current_timestamp

from dagster_airlift.constants import EFFECTIVE_TIMESTAMP_METADATA_KEY
from dagster_airlift.core.airflow_defs_data import AirflowDefinitionsData
from dagster_airlift.core.airflow_instance import DagRun, TaskInstance

AirflowEventTranslationFn = Callable[
[DagRun, Sequence[TaskInstance], AirflowDefinitionsData], Iterable[AssetMaterialization]
AssetEvent = Union[AssetMaterialization, AssetObservation, AssetCheckEvaluation]
DagsterEventTransformerFn = Callable[
[SensorEvaluationContext, AirflowDefinitionsData, Sequence[AssetMaterialization]],
Iterable[AssetEvent],
]


def get_timestamp_from_materialization(event: AssetEvent) -> float:
    """Return the effective timestamp stored in an asset event's metadata.

    Assumes the event carries EFFECTIVE_TIMESTAMP_METADATA_KEY metadata; a
    KeyError propagates if it does not, and a check failure is raised if the
    stored value is not a float.
    """
    timestamp = event.metadata[EFFECTIVE_TIMESTAMP_METADATA_KEY].value
    return check.float_param(timestamp, "Materialization Effective Timestamp")


Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, Generator, List, Sequence, Tuple, Union
from typing import Dict, Generator, List, Optional, Sequence, Tuple, Union

import pytest
from dagster import (
Expand Down Expand Up @@ -28,6 +28,7 @@
from dagster_airlift.core import (
build_defs_from_airflow_instance as build_defs_from_airflow_instance,
)
from dagster_airlift.core.sensor.event_translation import DagsterEventTransformerFn
from dagster_airlift.core.utils import metadata_for_task_mapping
from dagster_airlift.test import make_dag_run, make_instance

Expand All @@ -40,9 +41,13 @@ def fully_loaded_repo_from_airflow_asset_graph(
assets_per_task: Dict[str, Dict[str, List[Tuple[str, List[str]]]]],
additional_defs: Definitions = Definitions(),
create_runs: bool = True,
event_transformer_fn: Optional[DagsterEventTransformerFn] = None,
) -> RepositoryDefinition:
defs = load_definitions_airflow_asset_graph(
assets_per_task, additional_defs=additional_defs, create_runs=create_runs
assets_per_task,
additional_defs=additional_defs,
create_runs=create_runs,
event_transformer_fn=event_transformer_fn,
)
repo_def = defs.get_repository_def()
repo_def.load_all_definitions()
Expand All @@ -54,6 +59,7 @@ def load_definitions_airflow_asset_graph(
additional_defs: Definitions = Definitions(),
create_runs: bool = True,
create_assets_defs: bool = True,
event_transformer_fn: Optional[DagsterEventTransformerFn] = None,
) -> Definitions:
assets = []
dag_and_task_structure = defaultdict(list)
Expand Down Expand Up @@ -96,17 +102,20 @@ def _asset():
additional_defs,
Definitions(assets=assets),
)
return build_defs_from_airflow_instance(airflow_instance=instance, defs=defs)
return build_defs_from_airflow_instance(
airflow_instance=instance, defs=defs, event_transformer_fn=event_transformer_fn
)


def build_and_invoke_sensor(
*,
assets_per_task: Dict[str, Dict[str, List[Tuple[str, List[str]]]]],
instance: DagsterInstance,
additional_defs: Definitions = Definitions(),
event_transformer_fn: Optional[DagsterEventTransformerFn] = None,
) -> Tuple[SensorResult, SensorEvaluationContext]:
repo_def = fully_loaded_repo_from_airflow_asset_graph(
assets_per_task, additional_defs=additional_defs
assets_per_task, additional_defs=additional_defs, event_transformer_fn=event_transformer_fn
)
sensor = next(iter(repo_def.sensor_defs))
sensor_context = build_sensor_context(repository_def=repo_def, instance=instance)
Expand Down
Loading

0 comments on commit 9bd94e9

Please sign in to comment.