diff --git a/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/base_proxy_operator.py b/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/base_proxy_operator.py index 2c8f058bb3bfc..572752fc2f8c7 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/base_proxy_operator.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/base_proxy_operator.py @@ -119,6 +119,27 @@ def get_dagster_run_status( ) return self.get_valid_graphql_response(response, "runOrError")["status"] + def get_attribute_from_airflow_context(self, context: Context, attribute: str) -> Any: + if attribute not in context or context[attribute] is None: + raise Exception(f"Attribute {attribute} not found in context.") + return context[attribute] + + def get_airflow_dag_run_id(self, context: Context) -> str: + return self.get_attribute_from_airflow_context(context, "dag_run").run_id + + def get_airflow_dag_id(self, context: Context) -> str: + return self.get_attribute_from_airflow_context(context, "dag_run").dag_id + + def get_airflow_task_id(self, context: Context) -> str: + return self.get_attribute_from_airflow_context(context, "task").task_id + + def default_dagster_run_tags(self, context: Context) -> Dict[str, str]: + return { + DAG_ID_TAG_KEY: self.get_airflow_dag_id(context), + DAG_RUN_ID_TAG_KEY: self.get_airflow_dag_run_id(context), + TASK_ID_TAG_KEY: self.get_airflow_task_id(context), + } + def launch_runs_for_task(self, context: Context, dag_id: str, task_id: str) -> None: """Launches runs for the given task in Dagster.""" session = self._get_validated_session(context) @@ -140,36 +161,23 @@ def launch_runs_for_task(self, context: Context, dag_id: str, task_id: str) -> N ) logger.debug(f"Found assets to trigger: {assets_to_trigger}") - dag_run = context.get("dag_run") - assert dag_run, "dag_run not found in context" - # Get the dag_run_id - dag_run_id = dag_run.run_id - triggered_runs = [] - tags = { - DAG_ID_TAG_KEY: dag_id, - DAG_RUN_ID_TAG_KEY: dag_run_id, - TASK_ID_TAG_KEY: task_id, - } - for (repo_location, repo_name, job_name), asset_keys in assets_to_trigger.items(): - execution_params = { - "mode": "default", - "executionMetadata": { - "tags": [{"key": key, "value": value} for key, value in tags.items()] - }, - "runConfigData": "{}", - "selector": { - "repositoryLocationName": repo_location, - "repositoryName": repo_name, - "pipelineName": job_name, - "assetSelection": [{"path": asset_key} for asset_key in asset_keys], - "assetCheckSelection": [], - }, - } + for (repo_location, repo_name, job_name), asset_key_paths in assets_to_trigger.items(): logger.debug( - f"Triggering run for {repo_location}/{repo_name}/{job_name} with assets {asset_keys}" + f"Triggering run for {repo_location}/{repo_name}/{job_name} with assets {asset_key_paths}" + ) + run_id = self.launch_dagster_run( + context, + session, + dagster_url, + _build_dagster_run_execution_params( + self.default_dagster_run_tags(context), + repo_location, + repo_name, + job_name, + asset_key_paths, + ), ) - run_id = self.launch_dagster_run(context, session, dagster_url, execution_params) triggered_runs.append(run_id) completed_runs = {} # key is run_id, value is status while len(completed_runs) < len(triggered_runs): @@ -196,6 +204,29 @@ def execute(self, context: Context) -> Any: return self.launch_runs_for_task(context, dag_id, task_id) +def _build_dagster_run_execution_params( + tags: Mapping[str, Any], + location_name: str, + repository_name: str, + job_name: str, + asset_key_paths: Sequence[Sequence[str]], +) -> Dict[str, Any]: + return { + "mode": "default", + "executionMetadata": { + "tags": [{"key": key, "value": value} for key, value in tags.items()] + }, + "runConfigData": "{}", + "selector": { + "repositoryLocationName": location_name, + "repositoryName": repository_name, + "pipelineName": job_name, + "assetSelection": [{"path": asset_key} for asset_key in asset_key_paths], + "assetCheckSelection": [], + }, + } + + class DefaultProxyToDagsterOperator(BaseProxyToDagsterOperator): def get_dagster_session(self, context: Context) -> requests.Session: return requests.Session()