Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dagster-airlift] refactor run execution params #25088

Open
wants to merge 1 commit into
base: dpeng817/factor_out_gql
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,27 @@ def get_dagster_run_status(
)
return self.get_valid_graphql_response(response, "runOrError")["status"]

def get_attribute_from_airflow_context(self, context: Context, attribute: str) -> Any:
if attribute not in context or context[attribute] is None:
raise Exception(f"Attribute {attribute} not found in context.")
return context[attribute]

def get_airflow_dag_run_id(self, context: Context) -> str:
return self.get_attribute_from_airflow_context(context, "dag_run").run_id

def get_airflow_dag_id(self, context: Context) -> str:
return self.get_attribute_from_airflow_context(context, "dag_run").dag_id

def get_airflow_task_id(self, context: Context) -> str:
return self.get_attribute_from_airflow_context(context, "task").task_id
Comment on lines +122 to +134
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to constrast from the previous PR, get is appropriate here since it is in-memory/cheap


def default_dagster_run_tags(self, context: Context) -> Dict[str, str]:
return {
DAG_ID_TAG_KEY: self.get_airflow_dag_id(context),
DAG_RUN_ID_TAG_KEY: self.get_airflow_dag_run_id(context),
TASK_ID_TAG_KEY: self.get_airflow_task_id(context),
}

def launch_runs_for_task(self, context: Context, dag_id: str, task_id: str) -> None:
"""Launches runs for the given task in Dagster."""
session = self._get_validated_session(context)
Expand All @@ -140,36 +161,23 @@ def launch_runs_for_task(self, context: Context, dag_id: str, task_id: str) -> N
)
logger.debug(f"Found assets to trigger: {assets_to_trigger}")

dag_run = context.get("dag_run")
assert dag_run, "dag_run not found in context"
# Get the dag_run_id
dag_run_id = dag_run.run_id

triggered_runs = []
tags = {
DAG_ID_TAG_KEY: dag_id,
DAG_RUN_ID_TAG_KEY: dag_run_id,
TASK_ID_TAG_KEY: task_id,
}
for (repo_location, repo_name, job_name), asset_keys in assets_to_trigger.items():
execution_params = {
"mode": "default",
"executionMetadata": {
"tags": [{"key": key, "value": value} for key, value in tags.items()]
},
"runConfigData": "{}",
"selector": {
"repositoryLocationName": repo_location,
"repositoryName": repo_name,
"pipelineName": job_name,
"assetSelection": [{"path": asset_key} for asset_key in asset_keys],
"assetCheckSelection": [],
},
}
for (repo_location, repo_name, job_name), asset_key_paths in assets_to_trigger.items():
logger.debug(
f"Triggering run for {repo_location}/{repo_name}/{job_name} with assets {asset_keys}"
f"Triggering run for {repo_location}/{repo_name}/{job_name} with assets {asset_key_paths}"
)
run_id = self.launch_dagster_run(
context,
session,
dagster_url,
_build_dagster_run_execution_params(
self.default_dagster_run_tags(context),
repo_location,
repo_name,
job_name,
asset_key_paths,
),
)
run_id = self.launch_dagster_run(context, session, dagster_url, execution_params)
triggered_runs.append(run_id)
completed_runs = {} # key is run_id, value is status
while len(completed_runs) < len(triggered_runs):
Expand All @@ -196,6 +204,29 @@ def execute(self, context: Context) -> Any:
return self.launch_runs_for_task(context, dag_id, task_id)


def _build_dagster_run_execution_params(
tags: Mapping[str, Any],
location_name: str,
repository_name: str,
job_name: str,
asset_key_paths: Sequence[Sequence[str]],
) -> Dict[str, Any]:
return {
"mode": "default",
"executionMetadata": {
"tags": [{"key": key, "value": value} for key, value in tags.items()]
},
"runConfigData": "{}",
"selector": {
"repositoryLocationName": location_name,
"repositoryName": repository_name,
"pipelineName": job_name,
"assetSelection": [{"path": asset_key} for asset_key in asset_key_paths],
"assetCheckSelection": [],
},
}


class DefaultProxyToDagsterOperator(BaseProxyToDagsterOperator):
def get_dagster_session(self, context: Context) -> requests.Session:
return requests.Session()
Expand Down