From cf011a53ad06926c2786a78fd302083d31327480 Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Wed, 30 Oct 2024 18:20:55 +0000
Subject: [PATCH] Start porting DAG definition code to the Task SDK (#43076)

closes #43011

By "definition code" we mean anything needed at definition/parse time, leaving
anything to do with scheduling-time decisions in Airflow's core.

Also in this PR I have _attempted_ to keep it to only porting definition code
for simple DAGs, leaving anything to do with mapped tasks or execution time in
core for now, but a few things "leaked" across. And as the goal of this PR is
to go from working state to working state, some of the code in the Task SDK
still imports from "core" (various types, enums, or helpers) that will need to
be resolved before the 3.0 release, but it is fine for now.

I'm also aware that the class hierarchy with
airflow.models.baseoperator.BaseOperator (and to a lesser extent with DAG) in
particular is very messy right now, and we will need to think about how we want
to add on the scheduling-time functions etc., as I'm not yet sold that having
core Airflow depend upon the Task SDK classes/import the code is the right
structure, but we can address that later.

We will also need to address the rendered docs for the Task SDK in a future PR
-- the goal is that "anything" exposed on `airflow.sdk` directly is part of the
public API, but right now the rendered docs show DAG as
`airflow.sdk.definitions.dag.DAG`, which is certainly not what we want users to
see.

Co-authored-by: Kaxil Naik
---
 .github/workflows/prod-image-build.yml | 5 + .pre-commit-config.yaml | 2 + Dockerfile | 33 +- Dockerfile.ci | 2 +- airflow/api_connexion/schemas/dag_schema.py | 1 - .../core_api/openapi/v1-generated.yaml | 4 - .../api_fastapi/core_api/serializers/dags.py | 1 - airflow/cli/commands/dag_command.py | 1 - airflow/dag_processing/collection.py | 5 +- airflow/decorators/base.py | 44 +- airflow/decorators/bash.py | 4 +- airflow/decorators/sensor.py | 4 +- airflow/decorators/task_group.py | 2 +- airflow/exceptions.py | 6 +- airflow/models/abstractoperator.py | 355 ++--- airflow/models/baseoperator.py | 1190 +++------------- airflow/models/dag.py | 1053 +------------- airflow/models/dagbag.py | 13 +- airflow/models/mappedoperator.py | 7 +- airflow/models/param.py | 2 +- airflow/models/skipmixin.py | 2 +- airflow/models/taskinstance.py | 12 +- airflow/models/taskmixin.py | 270 +--- airflow/models/xcom_arg.py | 21 +- airflow/operators/python.py | 5 +- airflow/sensors/external_task.py | 4 +- airflow/serialization/schema.json | 20 +- airflow/serialization/serialized_objects.py | 61 +- airflow/task/priority_strategy.py | 4 +- airflow/template/templater.py | 17 +- airflow/typing_compat.py | 5 +- .../ui/openapi-gen/requests/schemas.gen.ts | 5 - airflow/ui/openapi-gen/requests/types.gen.ts | 1 - airflow/utils/decorators.py | 9 +- airflow/utils/edgemodifier.py | 154 +-- airflow/utils/log/logging_mixin.py | 1 + airflow/utils/task_group.py | 624 +-------- airflow/utils/types.py | 32 +- dev/mypy/plugin/outputs.py | 1 + hatch_build.py | 3 +- .../amazon/aws/operators/comprehend.py | 6 +- .../providers/amazon/aws/operators/dms.py | 6 +- .../amazon/aws/operators/kinesis_analytics.py | 6 +- .../amazon/aws/operators/sagemaker.py | 4 +- .../providers/apache/drill/operators/drill.py | 4 +- .../cncf/kubernetes/operators/pod.py | 4 +- .../kubernetes/operators/spark_kubernetes.py | 3 +- .../providers/common/sql/operators/sql.py | 18 +- .../providers/common/sql/operators/sql.pyi | 18 +-
.../databricks/operators/databricks_sql.py | 4 +- .../providers/exasol/operators/exasol.py | 4 +- .../auth_manager/security_manager/override.py | 9 +- .../airflow/providers/jdbc/operators/jdbc.py | 4 +- .../microsoft/mssql/operators/mssql.py | 4 +- .../providers/mysql/operators/mysql.py | 4 +- .../providers/oracle/operators/oracle.py | 4 +- .../providers/postgres/operators/postgres.py | 7 +- .../snowflake/operators/snowflake.py | 4 +- .../providers/sqlite/operators/sqlite.py | 4 +- .../providers/teradata/operators/teradata.py | 4 +- .../providers/trino/operators/trino.py | 4 +- .../providers/vertica/operators/vertica.py | 4 +- .../tests/amazon/aws/operators/test_batch.py | 2 +- .../google/cloud/operators/test_bigquery.py | 4 +- .../cloud/operators/test_cloud_build.py | 4 +- .../google/cloud/operators/test_compute.py | 10 +- .../google/cloud/operators/test_dataflow.py | 2 +- .../google/cloud/operators/test_dataproc.py | 6 +- .../cloud/operators/test_kubernetes_engine.py | 16 +- .../cloud/operators/test_speech_to_text.py | 10 +- .../google/cloud/sensors/test_dataproc.py | 2 +- .../cloud/transfers/test_gcs_to_bigquery.py | 7 +- .../tests/salesforce/operators/test_bulk.py | 4 +- .../tests/standard/operators/test_weekday.py | 2 +- pyproject.toml | 1 + .../base_operator_partial_arguments.py | 80 +- scripts/ci/pre_commit/sync_init_decorator.py | 201 ++- scripts/docker/entrypoint_ci.sh | 2 +- .../install_from_docker_context_files.sh | 33 +- task_sdk/pyproject.toml | 35 +- task_sdk/src/airflow/sdk/__init__.py | 39 +- .../airflow/sdk/definitions/__init__.py} | 7 - .../sdk/definitions/abstractoperator.py | 261 ++++ .../airflow/sdk/definitions/baseoperator.py | 1226 +++++++++++++++++ .../airflow/sdk/definitions/contextmanager.py | 125 ++ task_sdk/src/airflow/sdk/definitions/dag.py | 1119 +++++++++++++++ .../src/airflow/sdk/definitions/decorators.py | 42 + task_sdk/src/airflow/sdk/definitions/edges.py | 189 +++ .../src/airflow/sdk/definitions/mixins.py | 121 ++ task_sdk/src/airflow/sdk/definitions/node.py | 222 +++ .../src/airflow/sdk/definitions/taskgroup.py | 683 +++++++++ task_sdk/src/airflow/sdk/exceptions.py | 16 + task_sdk/src/airflow/sdk/types.py | 75 + .../tests/defintions/test_baseoperator.py | 343 +++++ task_sdk/tests/defintions/test_dag.py | 419 ++++++ .../endpoints/test_dag_endpoint.py | 32 - .../api_connexion/schemas/test_dag_schema.py | 5 - .../core_api/routes/public/test_dags.py | 1 - tests/models/test_baseoperator.py | 301 +--- tests/models/test_dag.py | 435 +----- tests/models/test_dagbag.py | 2 +- tests/models/test_taskinstance.py | 10 +- tests/serialization/test_dag_serialization.py | 95 +- tests/serialization/test_pydantic_models.py | 16 +- tests/utils/test_task_group.py | 17 +- tests_common/test_utils/mock_operators.py | 15 +- 106 files changed, 5850 insertions(+), 4501 deletions(-) rename task_sdk/{tests/test_hello.py => src/airflow/sdk/definitions/__init__.py} (85%) create mode 100644 task_sdk/src/airflow/sdk/definitions/abstractoperator.py create mode 100644 task_sdk/src/airflow/sdk/definitions/baseoperator.py create mode 100644 task_sdk/src/airflow/sdk/definitions/contextmanager.py create mode 100644 task_sdk/src/airflow/sdk/definitions/dag.py create mode 100644 task_sdk/src/airflow/sdk/definitions/decorators.py create mode 100644 task_sdk/src/airflow/sdk/definitions/edges.py create mode 100644 task_sdk/src/airflow/sdk/definitions/mixins.py create mode 100644 task_sdk/src/airflow/sdk/definitions/node.py create mode 100644 task_sdk/src/airflow/sdk/definitions/taskgroup.py create 
mode 100644 task_sdk/src/airflow/sdk/exceptions.py create mode 100644 task_sdk/src/airflow/sdk/types.py create mode 100644 task_sdk/tests/defintions/test_baseoperator.py create mode 100644 task_sdk/tests/defintions/test_dag.py diff --git a/.github/workflows/prod-image-build.yml b/.github/workflows/prod-image-build.yml index db80a6ec247e..df4f24981ff3 100644 --- a/.github/workflows/prod-image-build.yml +++ b/.github/workflows/prod-image-build.yml @@ -181,6 +181,11 @@ jobs: run: > breeze release-management prepare-airflow-package --package-format wheel if: inputs.do-build == 'true' && inputs.upload-package-artifact == 'true' + - name: "Prepare task-sdk package" + shell: bash + run: > + breeze release-management prepare-task-sdk-package --package-format wheel + if: inputs.do-build == 'true' && inputs.upload-package-artifact == 'true' - name: "Upload prepared packages as artifacts" uses: actions/upload-artifact@v4 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7e94e4191e8a..ad4b2529b860 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1189,6 +1189,8 @@ repos: ^airflow/utils/helpers.py$ | ^providers/src/airflow/providers/ | ^(providers/)?tests/ | + task_sdk/src/airflow/sdk/definitions/dag.py$ | + task_sdk/src/airflow/sdk/definitions/node.py$ | ^dev/.*\.py$ | ^scripts/.*\.py$ | ^docker_tests/.*$ | diff --git a/Dockerfile b/Dockerfile index 10ceb8996f9d..f874bfe9e579 100644 --- a/Dockerfile +++ b/Dockerfile @@ -718,6 +718,7 @@ COPY <<"EOF" /install_from_docker_context_files.sh function install_airflow_and_providers_from_docker_context_files(){ + local flags=() if [[ ${INSTALL_MYSQL_CLIENT} != "true" ]]; then AIRFLOW_EXTRAS=${AIRFLOW_EXTRAS/mysql,} fi @@ -756,10 +757,10 @@ function install_airflow_and_providers_from_docker_context_files(){ install_airflow_package=("apache-airflow[${AIRFLOW_EXTRAS}]==${AIRFLOW_VERSION}") fi - # Find Provider packages in docker-context files - readarray -t installing_providers_packages< <(python /scripts/docker/get_package_specs.py /docker-context-files/apache?airflow?providers*.{whl,tar.gz} 2>/dev/null || true) + # Find Provider/TaskSDK packages in docker-context files + readarray -t airflow_packages< <(python /scripts/docker/get_package_specs.py /docker-context-files/apache?airflow?{providers,task?sdk}*.{whl,tar.gz} 2>/dev/null || true) echo - echo "${COLOR_BLUE}Found provider packages in docker-context-files folder: ${installing_providers_packages[*]}${COLOR_RESET}" + echo "${COLOR_BLUE}Found provider packages in docker-context-files folder: ${airflow_packages[*]}${COLOR_RESET}" echo if [[ ${USE_CONSTRAINTS_FOR_CONTEXT_PACKAGES=} == "true" ]]; then @@ -772,11 +773,7 @@ function install_airflow_and_providers_from_docker_context_files(){ echo "${COLOR_BLUE}Installing docker-context-files packages with constraints found in ${local_constraints_file}${COLOR_RESET}" echo # force reinstall all airflow + provider packages with constraints found in - set -x - ${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} --upgrade \ - ${ADDITIONAL_PIP_INSTALL_FLAGS} --constraint "${local_constraints_file}" \ - "${install_airflow_package[@]}" "${installing_providers_packages[@]}" - set +x + flags=(--upgrade --constraint "${local_constraints_file}") echo echo "${COLOR_BLUE}Copying ${local_constraints_file} to ${HOME}/constraints.txt${COLOR_RESET}" echo @@ -785,23 +782,21 @@ function install_airflow_and_providers_from_docker_context_files(){ echo echo "${COLOR_BLUE}Installing docker-context-files packages with constraints from 
GitHub${COLOR_RESET}" echo - set -x - ${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} \ - ${ADDITIONAL_PIP_INSTALL_FLAGS} \ - --constraint "${HOME}/constraints.txt" \ - "${install_airflow_package[@]}" "${installing_providers_packages[@]}" - set +x + flags=(--constraint "${HOME}/constraints.txt") fi else echo echo "${COLOR_BLUE}Installing docker-context-files packages without constraints${COLOR_RESET}" echo - set -x - ${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} \ - ${ADDITIONAL_PIP_INSTALL_FLAGS} \ - "${install_airflow_package[@]}" "${installing_providers_packages[@]}" - set +x + flags=() fi + + set -x + ${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} \ + ${ADDITIONAL_PIP_INSTALL_FLAGS} \ + "${flags[@]}" \ + "${install_airflow_package[@]}" "${airflow_packages[@]}" + set +x common::install_packaging_tools pip check } diff --git a/Dockerfile.ci b/Dockerfile.ci index 114f12739014..94d61507ed41 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -1154,7 +1154,7 @@ function check_force_lowest_dependencies() { echo fi set -x - uv pip install --python "$(which python)" --resolution lowest-direct --upgrade --editable ".${EXTRA}" + uv pip install --python "$(which python)" --resolution lowest-direct --upgrade --editable ".${EXTRA}" --editable "./task_sdk" set +x } diff --git a/airflow/api_connexion/schemas/dag_schema.py b/airflow/api_connexion/schemas/dag_schema.py index 4fb1dd6ae4b3..f22812abd111 100644 --- a/airflow/api_connexion/schemas/dag_schema.py +++ b/airflow/api_connexion/schemas/dag_schema.py @@ -56,7 +56,6 @@ class Meta: last_parsed_time = auto_field(dump_only=True) last_pickled = auto_field(dump_only=True) last_expired = auto_field(dump_only=True) - pickle_id = auto_field(dump_only=True) default_view = auto_field(dump_only=True) fileloc = auto_field(dump_only=True) file_token = fields.Method("get_token", dump_only=True) diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 42e371a2c0b3..99f24ffdb728 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -1981,9 +1981,6 @@ components: - type: boolean - type: 'null' title: Is Paused Upon Creation - orientation: - type: string - title: Orientation params: anyOf: - type: object @@ -2053,7 +2050,6 @@ components: - start_date - end_date - is_paused_upon_creation - - orientation - params - render_template_as_native_obj - template_search_path diff --git a/airflow/api_fastapi/core_api/serializers/dags.py b/airflow/api_fastapi/core_api/serializers/dags.py index c6294324c7e6..6e2c3933e176 100644 --- a/airflow/api_fastapi/core_api/serializers/dags.py +++ b/airflow/api_fastapi/core_api/serializers/dags.py @@ -116,7 +116,6 @@ class DAGDetailsResponse(DAGResponse): start_date: datetime | None end_date: datetime | None is_paused_upon_creation: bool | None - orientation: str params: abc.MutableMapping | None render_template_as_native_obj: bool template_search_path: Iterable[str] | None diff --git a/airflow/cli/commands/dag_command.py b/airflow/cli/commands/dag_command.py index 89c5fd477f05..92d1825dc627 100644 --- a/airflow/cli/commands/dag_command.py +++ b/airflow/cli/commands/dag_command.py @@ -227,7 +227,6 @@ def _get_dagbag_dag_details(dag: DAG) -> dict: "last_parsed_time": None, "last_pickled": None, "last_expired": None, - "pickle_id": dag.pickle_id, "default_view": dag.default_view, "fileloc": dag.fileloc, "file_token": None, diff --git a/airflow/dag_processing/collection.py 
b/airflow/dag_processing/collection.py index f27f45dda82e..f608900ee76e 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -211,7 +211,10 @@ def update_dags( dm.has_import_errors = False dm.last_parsed_time = utcnow() dm.default_view = dag.default_view - dm._dag_display_property_value = dag._dag_display_property_value + if hasattr(dag, "_dag_display_property_value"): + dm._dag_display_property_value = dag._dag_display_property_value + elif dag.dag_display_name != dag.dag_id: + dm._dag_display_property_value = dag.dag_display_name dm.description = dag.description dm.max_active_tasks = dag.max_active_tasks dm.max_active_runs = dag.max_active_runs diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index bb9602d50c1c..1c9e441190a0 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -41,7 +41,6 @@ import typing_extensions from airflow.assets import Asset -from airflow.models.abstractoperator import DEFAULT_RETRIES, DEFAULT_RETRY_DELAY from airflow.models.baseoperator import ( BaseOperator, coerce_resources, @@ -49,7 +48,6 @@ get_merged_defaults, parse_retries, ) -from airflow.models.dag import DagContext from airflow.models.expandinput import ( EXPAND_INPUT_EMPTY, DictOfListsExpandInput, @@ -57,27 +55,27 @@ is_mappable, ) from airflow.models.mappedoperator import MappedOperator, ensure_xcomarg_return_value -from airflow.models.pool import Pool from airflow.models.xcom_arg import XComArg +from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator +from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext from airflow.typing_compat import ParamSpec, Protocol from airflow.utils import timezone from airflow.utils.context import KNOWN_CONTEXT_KEYS from airflow.utils.decorators import remove_task_decorator from airflow.utils.helpers import prevent_duplicates -from airflow.utils.task_group import TaskGroupContext from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import NOTSET if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.dag import DAG from airflow.models.expandinput import ( ExpandInput, OperatorExpandArgument, OperatorExpandKwargsArgument, ) from airflow.models.mappedoperator import ValidationSource + from airflow.sdk import DAG from airflow.utils.context import Context from airflow.utils.task_group import TaskGroup @@ -141,13 +139,13 @@ def get_unique_task_id( ... 
task_id__20 """ - dag = dag or DagContext.get_current_dag() + dag = dag or DagContext.get_current() if not dag: return task_id # We need to check if we are in the context of TaskGroup as the task_id may # already be altered - task_group = task_group or TaskGroupContext.get_current_task_group(dag) + task_group = task_group or TaskGroupContext.get_current(dag) tg_task_id = task_group.child_id(task_id) if task_group else task_id if tg_task_id not in dag.task_ids: @@ -428,8 +426,8 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: ensure_xcomarg_return_value(expand_input.value) task_kwargs = self.kwargs.copy() - dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag() - task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current_task_group(dag) + dag = task_kwargs.pop("dag", None) or DagContext.get_current() + task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current(dag) default_args, partial_params = get_merged_defaults( dag=dag, @@ -442,7 +440,7 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: "is_teardown": self.is_teardown, "on_failure_fail_dagrun": self.on_failure_fail_dagrun, } - base_signature = inspect.signature(BaseOperator) + base_signature = inspect.signature(TaskSDKBaseOperator) ignore = { "default_args", # This is target we are working on now. "kwargs", # A common name for a keyword argument. @@ -460,32 +458,26 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: task_id = task_group.child_id(task_id) # Logic here should be kept in sync with BaseOperatorMeta.partial(). - if "task_concurrency" in partial_kwargs: - raise TypeError("unexpected argument: task_concurrency") if partial_kwargs.get("wait_for_downstream"): partial_kwargs["depends_on_past"] = True start_date = timezone.convert_to_utc(partial_kwargs.pop("start_date", None)) end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None)) - if partial_kwargs.get("pool") is None: - partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME if "pool_slots" in partial_kwargs: if partial_kwargs["pool_slots"] < 1: dag_str = "" if dag: dag_str = f" in dag {dag.dag_id}" raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") - partial_kwargs["retries"] = parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES)) - partial_kwargs["retry_delay"] = coerce_timedelta( - partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY), - key="retry_delay", - ) - max_retry_delay = partial_kwargs.get("max_retry_delay") - partial_kwargs["max_retry_delay"] = ( - max_retry_delay - if max_retry_delay is None - else coerce_timedelta(max_retry_delay, key="max_retry_delay") - ) - partial_kwargs["resources"] = coerce_resources(partial_kwargs.get("resources")) + + for fld, convert in ( + ("retries", parse_retries), + ("retry_delay", coerce_timedelta), + ("max_retry_delay", coerce_timedelta), + ("resources", coerce_resources), + ): + if (v := partial_kwargs.get(fld, NOTSET)) is not NOTSET: + partial_kwargs[fld] = convert(v) # type: ignore[operator] + partial_kwargs.setdefault("executor_config", {}) partial_kwargs.setdefault("op_args", []) partial_kwargs.setdefault("op_kwargs", {}) diff --git a/airflow/decorators/bash.py b/airflow/decorators/bash.py index 44738492da09..e4dc19745e0a 100644 --- a/airflow/decorators/bash.py +++ b/airflow/decorators/bash.py @@ -18,7 +18,7 @@ from __future__ import annotations import warnings -from typing import Any, Callable, Collection, Mapping, Sequence +from typing import Any, 
Callable, ClassVar, Collection, Mapping, Sequence from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory from airflow.providers.standard.operators.bash import BashOperator @@ -39,7 +39,7 @@ class _BashDecoratedOperator(DecoratedOperator, BashOperator): """ template_fields: Sequence[str] = (*DecoratedOperator.template_fields, *BashOperator.template_fields) - template_fields_renderers: dict[str, str] = { + template_fields_renderers: ClassVar[dict[str, str]] = { **DecoratedOperator.template_fields_renderers, **BashOperator.template_fields_renderers, } diff --git a/airflow/decorators/sensor.py b/airflow/decorators/sensor.py index c332a78f95c7..9ee4eeb2a79c 100644 --- a/airflow/decorators/sensor.py +++ b/airflow/decorators/sensor.py @@ -17,7 +17,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Sequence +from typing import TYPE_CHECKING, Callable, ClassVar, Sequence from airflow.decorators.base import get_unique_task_id, task_decorator_factory from airflow.sensors.python import PythonSensor @@ -42,7 +42,7 @@ class DecoratedSensorOperator(PythonSensor): """ template_fields: Sequence[str] = ("op_args", "op_kwargs") - template_fields_renderers: dict[str, str] = {"op_args": "py", "op_kwargs": "py"} + template_fields_renderers: ClassVar[dict[str, str]] = {"op_args": "py", "op_kwargs": "py"} custom_operator_name = "@task.sensor" diff --git a/airflow/decorators/task_group.py b/airflow/decorators/task_group.py index 6eee426e936a..daaa81e1ce62 100644 --- a/airflow/decorators/task_group.py +++ b/airflow/decorators/task_group.py @@ -38,8 +38,8 @@ ListOfDictsExpandInput, MappedArgument, ) -from airflow.models.taskmixin import DAGNode from airflow.models.xcom_arg import XComArg +from airflow.sdk.definitions.node import DAGNode from airflow.typing_compat import ParamSpec from airflow.utils.helpers import prevent_duplicates from airflow.utils.task_group import MappedTaskGroup, TaskGroup diff --git a/airflow/exceptions.py b/airflow/exceptions.py index ccf62ca5e817..316fe880b66b 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -31,7 +31,7 @@ import datetime from collections.abc import Sized - from airflow.models import DAG, DagRun + from airflow.models import DagRun class AirflowException(Exception): @@ -273,13 +273,13 @@ class FailStopDagInvalidTriggerRule(AirflowException): _allowed_rules = (TriggerRule.ALL_SUCCESS, TriggerRule.ALL_DONE_SETUP_SUCCESS) @classmethod - def check(cls, *, dag: DAG | None, trigger_rule: TriggerRule): + def check(cls, *, fail_stop: bool, trigger_rule: TriggerRule): """ Check that fail_stop dag tasks have allowable trigger rules. 
:meta private: """ - if dag is not None and dag.fail_stop and trigger_rule not in cls._allowed_rules: + if fail_stop and trigger_rule not in cls._allowed_rules: raise cls() def __str__(self) -> str: diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index 45eb3c5fff18..feafb0b6b637 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -19,9 +19,8 @@ import datetime import inspect -from abc import abstractproperty from functools import cached_property -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Iterator, Sequence +from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence import methodtools from sqlalchemy import select @@ -29,7 +28,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.models.expandinput import NotFullyPopulated -from airflow.models.taskmixin import DAGNode, DependencyMixin +from airflow.sdk.definitions.abstractoperator import AbstractOperator as TaskSDKAbstractOperator from airflow.template.templater import Templater from airflow.utils.context import Context from airflow.utils.db import exists_query @@ -39,25 +38,26 @@ from airflow.utils.state import State, TaskInstanceState from airflow.utils.task_group import MappedTaskGroup from airflow.utils.trigger_rule import TriggerRule -from airflow.utils.types import NOTSET, ArgNotSet from airflow.utils.weight_rule import WeightRule -TaskStateChangeCallback = Callable[[Context], None] - if TYPE_CHECKING: + from collections.abc import Mapping + import jinja2 # Slow import. from sqlalchemy.orm import Session - from airflow.models.baseoperator import BaseOperator from airflow.models.baseoperatorlink import BaseOperatorLink - from airflow.models.dag import DAG + from airflow.models.dag import DAG as SchedulerDAG from airflow.models.mappedoperator import MappedOperator - from airflow.models.operator import Operator from airflow.models.taskinstance import TaskInstance + from airflow.sdk import DAG, BaseOperator + from airflow.sdk.definitions.node import DAGNode from airflow.task.priority_strategy import PriorityWeightStrategy from airflow.triggers.base import StartTriggerArgs from airflow.utils.task_group import TaskGroup +TaskStateChangeCallback = Callable[[Context], None] + DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner") DEFAULT_POOL_SLOTS: int = 1 DEFAULT_PRIORITY_WEIGHT: int = 1 @@ -86,7 +86,7 @@ class NotMapped(Exception): """Raise if a task is neither mapped nor has any parent mapped groups.""" -class AbstractOperator(Templater, DAGNode): +class AbstractOperator(Templater, TaskSDKAbstractOperator): """ Common implementation for operators, including unmapped and mapped. @@ -100,101 +100,8 @@ class AbstractOperator(Templater, DAGNode): :meta private: """ - operator_class: type[BaseOperator] | dict[str, Any] - - weight_rule: PriorityWeightStrategy - priority_weight: int - - # Defines the operator level extra links. 
- operator_extra_links: Collection[BaseOperatorLink] - - owner: str - task_id: str - - outlets: list - inlets: list trigger_rule: TriggerRule - _needs_expansion: bool | None = None - _on_failure_fail_dagrun = False - - HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset( - ( - "log", - "dag", # We show dag_id, don't need to show this too - "node_id", # Duplicates task_id - "task_group", # Doesn't have a useful repr, no point showing in UI - "inherits_from_empty_operator", # impl detail - # Decide whether to start task execution from triggerer - "start_trigger_args", - "start_from_trigger", - # For compatibility with TG, for operators these are just the current task, no point showing - "roots", - "leaves", - # These lists are already shown via *_task_ids - "upstream_list", - "downstream_list", - # Not useful, implementation detail, already shown elsewhere - "global_operator_extra_link_dict", - "operator_extra_link_dict", - ) - ) - - def get_dag(self) -> DAG | None: - raise NotImplementedError() - - @property - def task_type(self) -> str: - raise NotImplementedError() - - @property - def operator_name(self) -> str: - raise NotImplementedError() - - @property - def inherits_from_empty_operator(self) -> bool: - raise NotImplementedError() - - @property - def dag_id(self) -> str: - """Returns dag id if it has one or an adhoc + owner.""" - dag = self.get_dag() - if dag: - return dag.dag_id - return f"adhoc_{self.owner}" - - @property - def node_id(self) -> str: - return self.task_id - - @abstractproperty - def task_display_name(self) -> str: ... - - @property - def label(self) -> str | None: - if self.task_display_name and self.task_display_name != self.task_id: - return self.task_display_name - # Prefix handling if no display is given is cloned from taskmixin for compatibility - tg = self.task_group - if tg and tg.node_id and tg.prefix_group_id: - # "task_group_id.task_id" -> "task_id" - return self.task_id[len(tg.node_id) + 1 :] - return self.task_id - - @property - def is_setup(self) -> bool: - raise NotImplementedError() - - @is_setup.setter - def is_setup(self, value: bool) -> None: - raise NotImplementedError() - - @property - def is_teardown(self) -> bool: - raise NotImplementedError() - - @is_teardown.setter - def is_teardown(self, value: bool) -> None: - raise NotImplementedError() + weight_rule: PriorityWeightStrategy @property def on_failure_fail_dagrun(self): @@ -219,113 +126,71 @@ def on_failure_fail_dagrun(self, value): ) self._on_failure_fail_dagrun = value - def as_setup(self): - self.is_setup = True - return self + def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment: + """Get the template environment for rendering templates.""" + if dag is None: + dag = self.get_dag() + return super().get_template_env(dag=dag) - def as_teardown( + def _render(self, template, context, dag: DAG | None = None): + if dag is None: + dag = self.get_dag() + return super()._render(template, context, dag=dag) + + def _do_render_template_fields( self, - *, - setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET, - on_failure_fail_dagrun=NOTSET, - ): - self.is_teardown = True - self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS - if on_failure_fail_dagrun is not NOTSET: - self.on_failure_fail_dagrun = on_failure_fail_dagrun - if not isinstance(setups, ArgNotSet): - setups = [setups] if isinstance(setups, DependencyMixin) else setups - for s in setups: - s.is_setup = True - s >> self - return self - - def get_direct_relative_ids(self, upstream: bool = False) -> 
set[str]: - """Get direct relative IDs to the current task, upstream or downstream.""" - if upstream: - return self.upstream_task_ids - return self.downstream_task_ids - - def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]: - """ - Get a flat set of relative IDs, upstream or downstream. - - Will recurse each relative found in the direction specified. - - :param upstream: Whether to look for upstream or downstream relatives. - """ - dag = self.get_dag() - if not dag: - return set() - - relatives: set[str] = set() - - # This is intentionally implemented as a loop, instead of calling - # get_direct_relative_ids() recursively, since Python has significant - # limitation on stack level, and a recursive implementation can blow up - # if a DAG contains very long routes. - task_ids_to_trace = self.get_direct_relative_ids(upstream) - while task_ids_to_trace: - task_ids_to_trace_next: set[str] = set() - for task_id in task_ids_to_trace: - if task_id in relatives: + parent: Any, + template_fields: Iterable[str], + context: Mapping[str, Any], + jinja_env: jinja2.Environment, + seen_oids: set[int], + ) -> None: + """Override the base to use custom error logging.""" + for attr_name in template_fields: + try: + value = getattr(parent, attr_name) + except AttributeError: + raise AttributeError( + f"{attr_name!r} is configured as a template field " + f"but {parent.task_type} does not have this attribute." + ) + try: + if not value: continue - task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream)) - relatives.add(task_id) - task_ids_to_trace = task_ids_to_trace_next - - return relatives - - def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]: - """Get a flat list of relatives, either upstream or downstream.""" - dag = self.get_dag() - if not dag: - return set() - return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream)] - - def get_upstreams_follow_setups(self) -> Iterable[Operator]: - """All upstreams and, for each upstream setup, its respective teardowns.""" - for task in self.get_flat_relatives(upstream=True): - yield task - if task.is_setup: - for t in task.downstream_list: - if t.is_teardown and t != self: - yield t - - def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Operator]: - """ - Only *relevant* upstream setups and their teardowns. - - This method is meant to be used when we are clearing the task (non-upstream) and we need - to add in the *relevant* setups and their teardowns. - - Relevant in this case means, the setup has a teardown that is downstream of ``self``, - or the setup has no teardowns. - """ - downstream_teardown_ids = { - x.task_id for x in self.get_flat_relatives(upstream=False) if x.is_teardown - } - for task in self.get_flat_relatives(upstream=True): - if not task.is_setup: - continue - has_no_teardowns = not any(True for x in task.downstream_list if x.is_teardown) - # if task has no teardowns or has teardowns downstream of self - if has_no_teardowns or task.downstream_task_ids.intersection(downstream_teardown_ids): - yield task - for t in task.downstream_list: - if t.is_teardown and t != self: - yield t - - def get_upstreams_only_setups(self) -> Iterable[Operator]: - """ - Return relevant upstream setups. 
+ except Exception: + # This may happen if the templated field points to a class which does not support `__bool__`, + # such as Pandas DataFrames: + # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465 + self.log.info( + "Unable to check if the value of type '%s' is False for task '%s', field '%s'.", + type(value).__name__, + self.task_id, + attr_name, + ) + # We may still want to render custom classes which do not support __bool__ + pass - This method is meant to be used when we are checking task dependencies where we need - to wait for all the upstream setups to complete before we can run the task. - """ - for task in self.get_upstreams_only_setups_and_teardowns(): - if task.is_setup: - yield task + try: + if callable(value): + rendered_content = value(context=context, jinja_env=jinja_env) + else: + rendered_content = self.render_template( + value, + context, + jinja_env, + seen_oids, + ) + except Exception: + value_masked = redact(name=attr_name, value=value) + self.log.exception( + "Exception rendering Jinja template for task '%s', field '%s'. Template: %r", + self.task_id, + attr_name, + value_masked, + ) + raise + else: + setattr(parent, attr_name, rendered_content) def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]: """ @@ -394,7 +259,9 @@ def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: """ if (group := self.task_group) is None: return - yield from group.iter_mapped_task_groups() + # TODO: Task-SDK: this type ignore shouldn't be necessary, revisit once mapping support is fully in the + # SDK + yield from group.iter_mapped_task_groups() # type: ignore[misc] def get_closest_mapped_task_group(self) -> MappedTaskGroup | None: """ @@ -460,6 +327,7 @@ def priority_weight_total(self) -> int: - WeightRule.DOWNSTREAM - adds priority weight of all downstream tasks - WeightRule.UPSTREAM - adds priority weight of all upstream tasks """ + # TODO: This should live in the WeightStragies themselves, not in here from airflow.task.priority_strategy import ( _AbsolutePriorityWeightStrategy, _DownstreamPriorityWeightStrategy, @@ -587,9 +455,9 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence """ from sqlalchemy import func, or_ - from airflow.models.baseoperator import BaseOperator from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance + from airflow.sdk import BaseOperator from airflow.settings import task_instance_mutation_hook if not isinstance(self, (BaseOperator, MappedOperator)): @@ -624,6 +492,9 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence all_expanded_tis: list[TaskInstance] = [] if unmapped_ti: + if TYPE_CHECKING: + assert self.dag is None or isinstance(self.dag, SchedulerDAG) + # The unmapped task instance still exists and is unfinished, i.e. we # haven't tried to run it before. 
if total_length is None: @@ -721,72 +592,6 @@ def render_template_fields( """ raise NotImplementedError() - def _render(self, template, context, dag: DAG | None = None): - if dag is None: - dag = self.get_dag() - return super()._render(template, context, dag=dag) - - def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment: - """Get the template environment for rendering templates.""" - if dag is None: - dag = self.get_dag() - return super().get_template_env(dag=dag) - - def _do_render_template_fields( - self, - parent: Any, - template_fields: Iterable[str], - context: Context, - jinja_env: jinja2.Environment, - seen_oids: set[int], - ) -> None: - """Override the base to use custom error logging.""" - for attr_name in template_fields: - try: - value = getattr(parent, attr_name) - except AttributeError: - raise AttributeError( - f"{attr_name!r} is configured as a template field " - f"but {parent.task_type} does not have this attribute." - ) - try: - if not value: - continue - except Exception: - # This may happen if the templated field points to a class which does not support `__bool__`, - # such as Pandas DataFrames: - # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465 - self.log.info( - "Unable to check if the value of type '%s' is False for task '%s', field '%s'.", - type(value).__name__, - self.task_id, - attr_name, - ) - # We may still want to render custom classes which do not support __bool__ - pass - - try: - if callable(value): - rendered_content = value(context=context, jinja_env=jinja_env) - else: - rendered_content = self.render_template( - value, - context, - jinja_env, - seen_oids, - ) - except Exception: - value_masked = redact(name=attr_name, value=value) - self.log.exception( - "Exception rendering Jinja template for task '%s', field '%s'. 
Template: %r", - self.task_id, - attr_name, - value_masked, - ) - raise - else: - setattr(parent, attr_name, rendered_content) - def __enter__(self): if not self.is_setup and not self.is_teardown: raise AirflowException("Only setup/teardown tasks can be used as context managers.") diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 514553e05a2d..c1448ef9cc55 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -23,17 +23,13 @@ from __future__ import annotations -import abc import collections.abc import contextlib import copy import functools -import inspect import logging -import sys -import warnings from datetime import datetime, timedelta -from functools import total_ordering, wraps +from functools import wraps from threading import local from types import FunctionType from typing import ( @@ -45,10 +41,9 @@ NoReturn, Sequence, TypeVar, - cast, ) -import attr +import methodtools import pendulum from sqlalchemy import select from sqlalchemy.orm.exc import NoResultFound @@ -56,7 +51,6 @@ from airflow.configuration import conf from airflow.exceptions import ( AirflowException, - FailStopDagInvalidTriggerRule, TaskDeferralError, TaskDeferred, ) @@ -78,12 +72,17 @@ ) from airflow.models.base import _sentinel from airflow.models.mappedoperator import OperatorPartial, validate_mapping_kwargs -from airflow.models.param import ParamsDict -from airflow.models.pool import Pool from airflow.models.taskinstance import TaskInstance, clear_task_instances from airflow.models.taskmixin import DependencyMixin + +# Keeping this file at all is a temp thing as we migrate the repo to the task sdk as the base, but to keep +# main working and useful for others to develop against we use the TaskSDK here but keep this file around +from airflow.sdk import DAG, BaseOperator as TaskSDKBaseOperator, EdgeModifier as TaskSDKEdgeModifier +from airflow.sdk.definitions.baseoperator import ( + BaseOperatorMeta as TaskSDKBaseOperatorMeta, + get_merged_defaults, +) from airflow.serialization.enums import DagAttributeTypes -from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy from airflow.ti_deps.deps.mapped_task_upstream_dep import MappedTaskUpstreamDep from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep @@ -91,15 +90,11 @@ from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep from airflow.utils import timezone from airflow.utils.context import Context, context_get_outlet_events -from airflow.utils.decorators import fixup_decorator_warning_stack from airflow.utils.edgemodifier import EdgeModifier -from airflow.utils.helpers import validate_instance_args, validate_key from airflow.utils.operator_helpers import ExecutionCallableRunner from airflow.utils.operator_resources import Resources from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.setup_teardown import SetupTeardownContext -from airflow.utils.trigger_rule import TriggerRule -from airflow.utils.types import NOTSET, AttributeRemoved, DagRunTriggeredByType +from airflow.utils.types import NOTSET, DagRunTriggeredByType from airflow.utils.xcom import XCOM_RETURN_KEY if TYPE_CHECKING: @@ -110,14 +105,18 @@ from airflow.models.abstractoperator import TaskStateChangeCallback from airflow.models.baseoperatorlink import BaseOperatorLink - from airflow.models.dag import DAG + from airflow.models.dag import DAG as 
SchedulerDAG from airflow.models.operator import Operator - from airflow.models.xcom_arg import XComArg + from airflow.task.priority_strategy import PriorityWeightStrategy from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.triggers.base import BaseTrigger, StartTriggerArgs - from airflow.utils.task_group import TaskGroup from airflow.utils.types import ArgNotSet + +# Todo: AIP-44: Once we get rid of AIP-44 we can remove this. But without this here pydantic fails to resolve +# types for serialization +from airflow.utils.task_group import TaskGroup # noqa: TCH001 + TaskPreExecuteHook = Callable[[Context], None] TaskPostExecuteHook = Callable[[Context, Any], None] @@ -139,10 +138,12 @@ def parse_retries(retries: Any) -> int | None: return parsed_retries -def coerce_timedelta(value: float | timedelta, *, key: str) -> timedelta: +def coerce_timedelta(value: float | timedelta, *, key: str | None = None) -> timedelta: if isinstance(value, timedelta): return value - logger.debug("%s isn't a timedelta object, assuming secs", key) + # TODO: remove this log here + if key: + logger.debug("%s isn't a timedelta object, assuming secs", key) return timedelta(seconds=value) @@ -152,38 +153,6 @@ def coerce_resources(resources: dict[str, Any] | None) -> Resources | None: return Resources(**resources) -def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple[dict, ParamsDict]: - if not dag: - return {}, ParamsDict() - dag_args = copy.copy(dag.default_args) - dag_params = copy.deepcopy(dag.params) - if task_group: - if task_group.default_args and not isinstance(task_group.default_args, collections.abc.Mapping): - raise TypeError("default_args must be a mapping") - dag_args.update(task_group.default_args) - return dag_args, dag_params - - -def get_merged_defaults( - dag: DAG | None, - task_group: TaskGroup | None, - task_params: collections.abc.MutableMapping | None, - task_default_args: dict | None, -) -> tuple[dict, ParamsDict]: - args, params = _get_parent_defaults(dag, task_group) - if task_params: - if not isinstance(task_params, collections.abc.Mapping): - raise TypeError("params must be a mapping") - params.update(task_params) - if task_default_args: - if not isinstance(task_default_args, collections.abc.Mapping): - raise TypeError("default_args must be a mapping") - args.update(task_default_args) - with contextlib.suppress(KeyError): - params.update(task_default_args["params"] or {}) - return args, params - - class _PartialDescriptor: """A descriptor that guards against ``.partial`` being called on Task objects.""" @@ -225,161 +194,150 @@ def partial(**kwargs): # This is what handles the actual mapping. 
-def partial( - operator_class: type[BaseOperator], - *, - task_id: str, - dag: DAG | None = None, - task_group: TaskGroup | None = None, - start_date: datetime | ArgNotSet = NOTSET, - end_date: datetime | ArgNotSet = NOTSET, - owner: str | ArgNotSet = NOTSET, - email: None | str | Iterable[str] | ArgNotSet = NOTSET, - params: collections.abc.MutableMapping | None = None, - resources: dict[str, Any] | None | ArgNotSet = NOTSET, - trigger_rule: str | ArgNotSet = NOTSET, - depends_on_past: bool | ArgNotSet = NOTSET, - ignore_first_depends_on_past: bool | ArgNotSet = NOTSET, - wait_for_past_depends_before_skipping: bool | ArgNotSet = NOTSET, - wait_for_downstream: bool | ArgNotSet = NOTSET, - retries: int | None | ArgNotSet = NOTSET, - queue: str | ArgNotSet = NOTSET, - pool: str | ArgNotSet = NOTSET, - pool_slots: int | ArgNotSet = NOTSET, - execution_timeout: timedelta | None | ArgNotSet = NOTSET, - max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET, - retry_delay: timedelta | float | ArgNotSet = NOTSET, - retry_exponential_backoff: bool | ArgNotSet = NOTSET, - priority_weight: int | ArgNotSet = NOTSET, - weight_rule: str | PriorityWeightStrategy | ArgNotSet = NOTSET, - sla: timedelta | None | ArgNotSet = NOTSET, - map_index_template: str | None | ArgNotSet = NOTSET, - max_active_tis_per_dag: int | None | ArgNotSet = NOTSET, - max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET, - on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET, - on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET, - on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET, - on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET, - on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET, - run_as_user: str | None | ArgNotSet = NOTSET, - executor: str | None | ArgNotSet = NOTSET, - executor_config: dict | None | ArgNotSet = NOTSET, - inlets: Any | None | ArgNotSet = NOTSET, - outlets: Any | None | ArgNotSet = NOTSET, - doc: str | None | ArgNotSet = NOTSET, - doc_md: str | None | ArgNotSet = NOTSET, - doc_json: str | None | ArgNotSet = NOTSET, - doc_yaml: str | None | ArgNotSet = NOTSET, - doc_rst: str | None | ArgNotSet = NOTSET, - task_display_name: str | None | ArgNotSet = NOTSET, - logger_name: str | None | ArgNotSet = NOTSET, - allow_nested_operators: bool = True, - **kwargs, -) -> OperatorPartial: - from airflow.models.dag import DagContext - from airflow.utils.task_group import TaskGroupContext - - validate_mapping_kwargs(operator_class, "partial", kwargs) - - dag = dag or DagContext.get_current_dag() - if dag: - task_group = task_group or TaskGroupContext.get_current_task_group(dag) - if task_group: - task_id = task_group.child_id(task_id) - - # Merge DAG and task group level defaults into user-supplied values. 
- dag_default_args, partial_params = get_merged_defaults( - dag=dag, - task_group=task_group, - task_params=params, - task_default_args=kwargs.pop("default_args", None), - ) - # Create partial_kwargs from args and kwargs - partial_kwargs: dict[str, Any] = { +if TYPE_CHECKING: + + def partial( + operator_class: type[BaseOperator], + *, + task_id: str, + dag: DAG | None = None, + task_group: TaskGroup | None = None, + start_date: datetime | ArgNotSet = NOTSET, + end_date: datetime | ArgNotSet = NOTSET, + owner: str | ArgNotSet = NOTSET, + email: None | str | Iterable[str] | ArgNotSet = NOTSET, + params: collections.abc.MutableMapping | None = None, + resources: dict[str, Any] | None | ArgNotSet = NOTSET, + trigger_rule: str | ArgNotSet = NOTSET, + depends_on_past: bool | ArgNotSet = NOTSET, + ignore_first_depends_on_past: bool | ArgNotSet = NOTSET, + wait_for_past_depends_before_skipping: bool | ArgNotSet = NOTSET, + wait_for_downstream: bool | ArgNotSet = NOTSET, + retries: int | None | ArgNotSet = NOTSET, + queue: str | ArgNotSet = NOTSET, + pool: str | ArgNotSet = NOTSET, + pool_slots: int | ArgNotSet = NOTSET, + execution_timeout: timedelta | None | ArgNotSet = NOTSET, + max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET, + retry_delay: timedelta | float | ArgNotSet = NOTSET, + retry_exponential_backoff: bool | ArgNotSet = NOTSET, + priority_weight: int | ArgNotSet = NOTSET, + weight_rule: str | PriorityWeightStrategy | ArgNotSet = NOTSET, + sla: timedelta | None | ArgNotSet = NOTSET, + map_index_template: str | None | ArgNotSet = NOTSET, + max_active_tis_per_dag: int | None | ArgNotSet = NOTSET, + max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET, + on_execute_callback: None + | TaskStateChangeCallback + | list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + on_failure_callback: None + | TaskStateChangeCallback + | list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + on_success_callback: None + | TaskStateChangeCallback + | list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + on_retry_callback: None + | TaskStateChangeCallback + | list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + on_skipped_callback: None + | TaskStateChangeCallback + | list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + run_as_user: str | None | ArgNotSet = NOTSET, + executor: str | None | ArgNotSet = NOTSET, + executor_config: dict | None | ArgNotSet = NOTSET, + inlets: Any | None | ArgNotSet = NOTSET, + outlets: Any | None | ArgNotSet = NOTSET, + doc: str | None | ArgNotSet = NOTSET, + doc_md: str | None | ArgNotSet = NOTSET, + doc_json: str | None | ArgNotSet = NOTSET, + doc_yaml: str | None | ArgNotSet = NOTSET, + doc_rst: str | None | ArgNotSet = NOTSET, + task_display_name: str | None | ArgNotSet = NOTSET, + logger_name: str | None | ArgNotSet = NOTSET, + allow_nested_operators: bool = True, + **kwargs, + ) -> OperatorPartial: ... 
+else: + + def partial( + operator_class: type[BaseOperator], + *, + task_id: str, + dag: DAG | None = None, + task_group: TaskGroup | None = None, + params: collections.abc.MutableMapping | None = None, **kwargs, - "dag": dag, - "task_group": task_group, - "task_id": task_id, - "map_index_template": map_index_template, - "start_date": start_date, - "end_date": end_date, - "owner": owner, - "email": email, - "trigger_rule": trigger_rule, - "depends_on_past": depends_on_past, - "ignore_first_depends_on_past": ignore_first_depends_on_past, - "wait_for_past_depends_before_skipping": wait_for_past_depends_before_skipping, - "wait_for_downstream": wait_for_downstream, - "retries": retries, - "queue": queue, - "pool": pool, - "pool_slots": pool_slots, - "execution_timeout": execution_timeout, - "max_retry_delay": max_retry_delay, - "retry_delay": retry_delay, - "retry_exponential_backoff": retry_exponential_backoff, - "priority_weight": priority_weight, - "weight_rule": weight_rule, - "sla": sla, - "max_active_tis_per_dag": max_active_tis_per_dag, - "max_active_tis_per_dagrun": max_active_tis_per_dagrun, - "on_execute_callback": on_execute_callback, - "on_failure_callback": on_failure_callback, - "on_retry_callback": on_retry_callback, - "on_success_callback": on_success_callback, - "on_skipped_callback": on_skipped_callback, - "run_as_user": run_as_user, - "executor": executor, - "executor_config": executor_config, - "inlets": inlets, - "outlets": outlets, - "resources": resources, - "doc": doc, - "doc_json": doc_json, - "doc_md": doc_md, - "doc_rst": doc_rst, - "doc_yaml": doc_yaml, - "task_display_name": task_display_name, - "logger_name": logger_name, - "allow_nested_operators": allow_nested_operators, - } - - # Inject DAG-level default args into args provided to this function. - partial_kwargs.update((k, v) for k, v in dag_default_args.items() if partial_kwargs.get(k) is NOTSET) - - # Fill fields not provided by the user with default values. - partial_kwargs = {k: _PARTIAL_DEFAULTS.get(k) if v is NOTSET else v for k, v in partial_kwargs.items()} - - # Post-process arguments. Should be kept in sync with _TaskDecorator.expand(). - if "task_concurrency" in kwargs: # Reject deprecated option. - raise TypeError("unexpected argument: task_concurrency") - if partial_kwargs["wait_for_downstream"]: - partial_kwargs["depends_on_past"] = True - partial_kwargs["start_date"] = timezone.convert_to_utc(partial_kwargs["start_date"]) - partial_kwargs["end_date"] = timezone.convert_to_utc(partial_kwargs["end_date"]) - if partial_kwargs["pool"] is None: - partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME - if partial_kwargs["pool_slots"] < 1: - dag_str = "" + ): + from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext + + validate_mapping_kwargs(operator_class, "partial", kwargs) + + dag = dag or DagContext.get_current() if dag: - dag_str = f" in dag {dag.dag_id}" - raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") - partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"]) - partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay") - if partial_kwargs["max_retry_delay"] is not None: - partial_kwargs["max_retry_delay"] = coerce_timedelta( - partial_kwargs["max_retry_delay"], - key="max_retry_delay", + task_group = task_group or TaskGroupContext.get_current(dag) + if task_group: + task_id = task_group.child_id(task_id) + + # Merge DAG and task group level defaults into user-supplied values. 
+ dag_default_args, partial_params = get_merged_defaults( + dag=dag, + task_group=task_group, + task_params=params, + task_default_args=kwargs.pop("default_args", None), ) - partial_kwargs["executor_config"] = partial_kwargs["executor_config"] or {} - partial_kwargs["resources"] = coerce_resources(partial_kwargs["resources"]) - return OperatorPartial( - operator_class=operator_class, - kwargs=partial_kwargs, - params=partial_params, - ) + # Create partial_kwargs from args and kwargs + partial_kwargs: dict[str, Any] = { + "task_id": task_id, + "dag": dag, + "task_group": task_group, + **kwargs, + } + + # Inject DAG-level default args into args provided to this function. + partial_kwargs.update( + (k, v) for k, v in dag_default_args.items() if partial_kwargs.get(k, NOTSET) is NOTSET + ) + + # Fill fields not provided by the user with default values. + for k, v in _PARTIAL_DEFAULTS.items(): + partial_kwargs.setdefault(k, v) + + # Post-process arguments. Should be kept in sync with _TaskDecorator.expand(). + if "task_concurrency" in kwargs: # Reject deprecated option. + raise TypeError("unexpected argument: task_concurrency") + if wait := partial_kwargs.get("wait_for_downstream", False): + partial_kwargs["depends_on_past"] = wait + if start_date := partial_kwargs.get("start_date", None): + partial_kwargs["start_date"] = timezone.convert_to_utc(start_date) + if end_date := partial_kwargs.get("end_date", None): + partial_kwargs["end_date"] = timezone.convert_to_utc(end_date) + if partial_kwargs["pool_slots"] < 1: + dag_str = "" + if dag: + dag_str = f" in dag {dag.dag_id}" + raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") + if retries := partial_kwargs.get("retries"): + partial_kwargs["retries"] = parse_retries(retries) + partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay") + if partial_kwargs.get("max_retry_delay", None) is not None: + partial_kwargs["max_retry_delay"] = coerce_timedelta( + partial_kwargs["max_retry_delay"], + key="max_retry_delay", + ) + partial_kwargs.setdefault("executor_config", {}) + + return OperatorPartial( + operator_class=operator_class, + kwargs=partial_kwargs, + params=partial_params, + ) class ExecutorSafeguard: @@ -419,103 +377,9 @@ def wrapper(self, *args, **kwargs): return wrapper -class BaseOperatorMeta(abc.ABCMeta): - """Metaclass of BaseOperator.""" - - @classmethod - def _apply_defaults(cls, func: T) -> T: - """ - Look for an argument named "default_args", and fill the unspecified arguments from it. - - Since python2.* isn't clear about which arguments are missing when - calling a function, and that this can be quite confusing with multi-level - inheritance and argument defaults, this decorator also alerts with - specific information about the missing arguments. - """ - # Cache inspect.signature for the wrapper closure to avoid calling it - # at every decorated invocation. This is separate sig_cache created - # per decoration, i.e. each function decorated using apply_defaults will - # have a different sig_cache. 
- sig_cache = inspect.signature(func) - non_variadic_params = { - name: param - for (name, param) in sig_cache.parameters.items() - if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) - } - non_optional_args = { - name - for name, param in non_variadic_params.items() - if param.default == param.empty and name != "task_id" - } - - fixup_decorator_warning_stack(func) - - @wraps(func) - def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any: - from airflow.models.dag import DagContext - from airflow.utils.task_group import TaskGroupContext - - if args: - raise AirflowException("Use keyword arguments when initializing operators") - - instantiated_from_mapped = kwargs.pop( - "_airflow_from_mapped", - getattr(self, "_BaseOperator__from_mapped", False), - ) - - dag: DAG | None = kwargs.get("dag") or DagContext.get_current_dag() - task_group: TaskGroup | None = kwargs.get("task_group") - if dag and not task_group: - task_group = TaskGroupContext.get_current_task_group(dag) - - default_args, merged_params = get_merged_defaults( - dag=dag, - task_group=task_group, - task_params=kwargs.pop("params", None), - task_default_args=kwargs.pop("default_args", None), - ) - - for arg in sig_cache.parameters: - if arg not in kwargs and arg in default_args: - kwargs[arg] = default_args[arg] - - missing_args = non_optional_args.difference(kwargs) - if len(missing_args) == 1: - raise AirflowException(f"missing keyword argument {missing_args.pop()!r}") - elif missing_args: - display = ", ".join(repr(a) for a in sorted(missing_args)) - raise AirflowException(f"missing keyword arguments {display}") - - if merged_params: - kwargs["params"] = merged_params - - hook = getattr(self, "_hook_apply_defaults", None) - if hook: - args, kwargs = hook(**kwargs, default_args=default_args) - default_args = kwargs.pop("default_args", {}) - - if not hasattr(self, "_BaseOperator__init_kwargs"): - self._BaseOperator__init_kwargs = {} - self._BaseOperator__from_mapped = instantiated_from_mapped - - result = func(self, **kwargs, default_args=default_args) - - # Store the args passed to init -- we need them to support task.map serialization! - self._BaseOperator__init_kwargs.update(kwargs) # type: ignore - - # Set upstream task defined by XComArgs passed to template fields of the operator. - # BUT: only do this _ONCE_, not once for each class in the hierarchy - if not instantiated_from_mapped and func == self.__init__.__wrapped__: # type: ignore[misc] - self.set_xcomargs_dependencies() - # Mark instance as instantiated. - self._BaseOperator__instantiated = True - - return result - - apply_defaults.__non_optional_args = non_optional_args # type: ignore - apply_defaults.__param_names = set(non_variadic_params) # type: ignore - - return cast(T, apply_defaults) +# TODO: Task-SDK - temporarily extend the metaclass to add in the ExecutorSafeguard. +class BaseOperatorMeta(TaskSDKBaseOperatorMeta): + """:meta private:""" # noqa: D400 def __new__(cls, name, bases, namespace, **kwargs): execute_method = namespace.get("execute") @@ -528,57 +392,10 @@ def __new__(cls, name, bases, namespace, **kwargs): partial_desc = vars(new_cls)["partial"] if isinstance(partial_desc, _PartialDescriptor): partial_desc.class_method = classmethod(partial) - - # We patch `__init__` only if the class defines it. 
- if inspect.getmro(new_cls)[1].__init__ is not new_cls.__init__: - new_cls.__init__ = cls._apply_defaults(new_cls.__init__) - return new_cls -# TODO: The following mapping is used to validate that the arguments passed to the BaseOperator are of the -# correct type. This is a temporary solution until we find a more sophisticated method for argument -# validation. One potential method is to use `get_type_hints` from the typing module. However, this is not -# fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python -# version that supports `get_type_hints` effectively or find a better approach, we can replace this -# manual type-checking method. -BASEOPERATOR_ARGS_EXPECTED_TYPES = { - "task_id": str, - "email": (str, Iterable), - "email_on_retry": bool, - "email_on_failure": bool, - "retries": int, - "retry_exponential_backoff": bool, - "depends_on_past": bool, - "ignore_first_depends_on_past": bool, - "wait_for_past_depends_before_skipping": bool, - "wait_for_downstream": bool, - "priority_weight": int, - "queue": str, - "pool": str, - "pool_slots": int, - "trigger_rule": str, - "run_as_user": str, - "task_concurrency": int, - "map_index_template": str, - "max_active_tis_per_dag": int, - "max_active_tis_per_dagrun": int, - "executor": str, - "do_xcom_push": bool, - "multiple_outputs": bool, - "doc": str, - "doc_md": str, - "doc_json": str, - "doc_yaml": str, - "doc_rst": str, - "task_display_name": str, - "logger_name": str, - "allow_nested_operators": bool, -} - - -@total_ordering -class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): +class BaseOperator(TaskSDKBaseOperator, AbstractOperator, metaclass=BaseOperatorMeta): r""" Abstract base class for all operators. @@ -783,401 +600,71 @@ def say_hello_world(**context): hello_world_task.execute(context) """ - # Implementing Operator. - template_fields: Sequence[str] = () - template_ext: Sequence[str] = () - - template_fields_renderers: dict[str, str] = {} - - # Defines the color in the UI - ui_color: str = "#fff" - ui_fgcolor: str = "#000" - - pool: str = "" - - # base list which includes all the attrs that don't need deep copy. - _base_operator_shallow_copy_attrs: tuple[str, ...] = ( - "user_defined_macros", - "user_defined_filters", - "params", - ) - - # each operator should override this class attr for shallow copy attrs. - shallow_copy_attrs: Sequence[str] = () - - # Defines the operator level extra links - operator_extra_links: Collection[BaseOperatorLink] = () - - # The _serialized_fields are lazily loaded when get_serialized_fields() method is called - __serialized_fields: frozenset[str] | None = None - - partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore - - _comps = { - "task_id", - "dag_id", - "owner", - "email", - "email_on_retry", - "retry_delay", - "retry_exponential_backoff", - "max_retry_delay", - "start_date", - "end_date", - "depends_on_past", - "wait_for_downstream", - "priority_weight", - "sla", - "execution_timeout", - "on_execute_callback", - "on_failure_callback", - "on_success_callback", - "on_retry_callback", - "on_skipped_callback", - "do_xcom_push", - "multiple_outputs", - "allow_nested_operators", - "executor", - } - - # Defines if the operator supports lineage without manual definitions - supports_lineage = False - - # If True then the class constructor was called - __instantiated = False - # List of args as passed to `init()`, after apply_defaults() has been updated. 
Used to "recreate" the task - # when mapping - __init_kwargs: dict[str, Any] - - # Set to True before calling execute method - _lock_for_execution = False - - _dag: DAG | None = None - task_group: TaskGroup | None = None - - start_date: pendulum.DateTime | None = None - end_date: pendulum.DateTime | None = None - - # Set to True for an operator instantiated by a mapped operator. - __from_mapped = False - start_trigger_args: StartTriggerArgs | None = None start_from_trigger: bool = False + on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None + on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None + on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None + on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None + on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None + def __init__( self, - task_id: str, - owner: str = DEFAULT_OWNER, - email: str | Iterable[str] | None = None, - email_on_retry: bool = conf.getboolean("email", "default_email_on_retry", fallback=True), - email_on_failure: bool = conf.getboolean("email", "default_email_on_failure", fallback=True), - retries: int | None = DEFAULT_RETRIES, - retry_delay: timedelta | float = DEFAULT_RETRY_DELAY, - retry_exponential_backoff: bool = False, - max_retry_delay: timedelta | float | None = None, - start_date: datetime | None = None, - end_date: datetime | None = None, - depends_on_past: bool = False, - ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, - wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, - wait_for_downstream: bool = False, - dag: DAG | None = None, - params: collections.abc.MutableMapping | None = None, - default_args: dict | None = None, - priority_weight: int = DEFAULT_PRIORITY_WEIGHT, - weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE, - queue: str = DEFAULT_QUEUE, - pool: str | None = None, - pool_slots: int = DEFAULT_POOL_SLOTS, - sla: timedelta | None = None, - execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT, + pre_execute=None, + post_execute=None, on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, - pre_execute: TaskPreExecuteHook | None = None, - post_execute: TaskPostExecuteHook | None = None, - trigger_rule: str = DEFAULT_TRIGGER_RULE, - resources: dict[str, Any] | None = None, - run_as_user: str | None = None, - map_index_template: str | None = None, - max_active_tis_per_dag: int | None = None, - max_active_tis_per_dagrun: int | None = None, - executor: str | None = None, - executor_config: dict | None = None, - do_xcom_push: bool = True, - multiple_outputs: bool = False, - inlets: Any | None = None, - outlets: Any | None = None, - task_group: TaskGroup | None = None, - doc: str | None = None, - doc_md: str | None = None, - doc_json: str | None = None, - doc_yaml: str | None = None, - doc_rst: str | None = None, - task_display_name: str | None = None, - logger_name: str | None = None, - allow_nested_operators: bool = True, **kwargs, 
): - from airflow.models.dag import DagContext - from airflow.utils.task_group import TaskGroupContext - - self.__init_kwargs = {} - - super().__init__() - - kwargs.pop("_airflow_mapped_validation_only", None) - if kwargs: - raise AirflowException( - f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). " - f"Invalid arguments were:\n**kwargs: {kwargs}", - ) - validate_key(task_id) - - dag = dag or DagContext.get_current_dag() - task_group = task_group or TaskGroupContext.get_current_task_group(dag) - - self.task_id = task_group.child_id(task_id) if task_group else task_id - if not self.__from_mapped and task_group: - task_group.add(self) - - self.owner = owner - self.email = email - self.email_on_retry = email_on_retry - self.email_on_failure = email_on_failure - - if execution_timeout is not None and not isinstance(execution_timeout, timedelta): - raise ValueError( - f"execution_timeout must be timedelta object but passed as type: {type(execution_timeout)}" - ) - self.execution_timeout = execution_timeout + if start_date := kwargs.get("start_date", None): + kwargs["start_date"] = timezone.convert_to_utc(start_date) + if end_date := kwargs.get("end_date", None): + kwargs["end_date"] = timezone.convert_to_utc(end_date) + super().__init__(**kwargs) + self._pre_execute_hook = pre_execute + self._post_execute_hook = post_execute self.on_execute_callback = on_execute_callback self.on_failure_callback = on_failure_callback self.on_success_callback = on_success_callback - self.on_retry_callback = on_retry_callback self.on_skipped_callback = on_skipped_callback - self._pre_execute_hook = pre_execute - self._post_execute_hook = post_execute - - if start_date and not isinstance(start_date, datetime): - self.log.warning("start_date for %s isn't datetime.datetime", self) - elif start_date: - self.start_date = timezone.convert_to_utc(start_date) - - if end_date: - self.end_date = timezone.convert_to_utc(end_date) - - self.executor = executor - self.executor_config = executor_config or {} - self.run_as_user = run_as_user - self.retries = parse_retries(retries) - self.queue = queue - self.pool = Pool.DEFAULT_POOL_NAME if pool is None else pool - self.pool_slots = pool_slots - if self.pool_slots < 1: - dag_str = f" in dag {dag.dag_id}" if dag else "" - raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot be less than 1") - - if sla: - self.log.warning( - "The SLA feature is removed in Airflow 3.0, to be replaced with a new implementation in 3.1" - ) - - if not TriggerRule.is_valid(trigger_rule): - raise AirflowException( - f"The trigger_rule must be one of {TriggerRule.all_triggers()}," - f"'{dag.dag_id if dag else ''}.{task_id}'; received '{trigger_rule}'." 
- ) - - self.trigger_rule: TriggerRule = TriggerRule(trigger_rule) - FailStopDagInvalidTriggerRule.check(dag=dag, trigger_rule=self.trigger_rule) - - self.depends_on_past: bool = depends_on_past - self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past - self.wait_for_past_depends_before_skipping: bool = wait_for_past_depends_before_skipping - self.wait_for_downstream: bool = wait_for_downstream - if wait_for_downstream: - self.depends_on_past = True - - self.retry_delay = coerce_timedelta(retry_delay, key="retry_delay") - self.retry_exponential_backoff = retry_exponential_backoff - self.max_retry_delay = ( - max_retry_delay - if max_retry_delay is None - else coerce_timedelta(max_retry_delay, key="max_retry_delay") - ) - - # At execution_time this becomes a normal dict - self.params: ParamsDict | dict = ParamsDict(params) - if priority_weight is not None and not isinstance(priority_weight, int): - raise AirflowException( - f"`priority_weight` for task '{self.task_id}' only accepts integers, " - f"received '{type(priority_weight)}'." - ) - self.priority_weight = priority_weight - self.weight_rule = validate_and_load_priority_weight_strategy(weight_rule) - self.resources = coerce_resources(resources) - self.max_active_tis_per_dag: int | None = max_active_tis_per_dag - self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun - self.do_xcom_push: bool = do_xcom_push - self.map_index_template: str | None = map_index_template - self.multiple_outputs: bool = multiple_outputs - - self.doc_md = doc_md - self.doc_json = doc_json - self.doc_yaml = doc_yaml - self.doc_rst = doc_rst - self.doc = doc - # Populate the display field only if provided and different from task id - self._task_display_property_value = ( - task_display_name if task_display_name and task_display_name != task_id else None - ) - - self.upstream_task_ids: set[str] = set() - self.downstream_task_ids: set[str] = set() - - if dag: - self.dag = dag - - self._log_config_logger_name = "airflow.task.operators" - self._logger_name = logger_name - self.allow_nested_operators: bool = allow_nested_operators - - # Lineage - self.inlets: list = [] - self.outlets: list = [] - - if inlets: - self.inlets = ( - inlets - if isinstance(inlets, list) - else [ - inlets, - ] - ) - - if outlets: - self.outlets = ( - outlets - if isinstance(outlets, list) - else [ - outlets, - ] - ) - - if isinstance(self.template_fields, str): - warnings.warn( - f"The `template_fields` value for {self.task_type} is a string " - "but should be a list or tuple of string. Wrapping it in a list for execution. " - f"Please update {self.task_type} accordingly.", - UserWarning, - stacklevel=2, - ) - self.template_fields = [self.template_fields] - - self._is_setup = False - self._is_teardown = False - if SetupTeardownContext.active: - SetupTeardownContext.update_context_map(self) - - validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES) - - def __eq__(self, other): - if type(self) is type(other): - # Use getattr() instead of __dict__ as __dict__ doesn't return - # correct values for properties. 
- return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps) - return False - - def __ne__(self, other): - return not self == other - - def __hash__(self): - hash_components = [type(self)] - for component in self._comps: - val = getattr(self, component, None) - try: - hash(val) - hash_components.append(val) - except TypeError: - hash_components.append(repr(val)) - return hash(tuple(hash_components)) - - # including lineage information - def __or__(self, other): - """ - Return [This Operator] | [Operator]. - - The inlets of other will be set to pick up the outlets from this operator. - Other will be set as a downstream task of this operator. - """ - if isinstance(other, BaseOperator): - if not self.outlets and not self.supports_lineage: - raise ValueError("No outlets defined for this operator") - other.add_inlets([self.task_id]) - self.set_downstream(other) - else: - raise TypeError(f"Right hand side ({other}) is not an Operator") - - return self - - # /Composing Operators --------------------------------------------- - - def __gt__(self, other): - """ - Return [Operator] > [Outlet]. - - If other is an attr annotated object it is set as an outlet of this Operator. - """ - if not isinstance(other, Iterable): - other = [other] - - for obj in other: - if not attr.has(obj): - raise TypeError(f"Left hand side ({obj}) is not an outlet") - self.add_outlets(other) + self.on_retry_callback = on_retry_callback - return self + # Defines the operator level extra links + operator_extra_links: Collection[BaseOperatorLink] = () - def __lt__(self, other): - """ - Return [Inlet] > [Operator] or [Operator] < [Inlet]. + if TYPE_CHECKING: - If other is an attr annotated object it is set as an inlet to this operator. - """ - if not isinstance(other, Iterable): - other = [other] + @property # type: ignore[override] + def dag(self) -> SchedulerDAG: # type: ignore[override] + return super().dag # type: ignore[return-value] - for obj in other: - if not attr.has(obj): - raise TypeError(f"{obj} cannot be an inlet") - self.add_inlets(other) + @dag.setter + def dag(self, val: SchedulerDAG): + # For type checking only + ... - return self + partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore - def __setattr__(self, key, value): - super().__setattr__(key, value) - if self.__from_mapped or self._lock_for_execution: - return # Skip any custom behavior for validation and during execute. - if key in self.__init_kwargs: - self.__init_kwargs[key] = value - if self.__instantiated and key in self.template_fields: - # Resolve upstreams set by assigning an XComArg after initializing - # an operator, example: - # op = BashOperator() - # op.bash_command = "sleep 1" - self.set_xcomargs_dependencies() - - def add_inlets(self, inlets: Iterable[Any]): - """Set inlets to this operator.""" - self.inlets.extend(inlets) - - def add_outlets(self, outlets: Iterable[Any]): - """Define the outlets of this operator.""" - self.outlets.extend(outlets) + @classmethod + @methodtools.lru_cache(maxsize=None) + def get_serialized_fields(cls): + """Stringified DAGs and operators contain exactly these fields.""" + # TODO: this ends up caching it once per-subclass, which isn't what we want, but this class is only + # kept around during the development of AIP-72/TaskSDK code. 
+ return TaskSDKBaseOperator.get_serialized_fields() | { + "start_trigger_args", + "start_from_trigger", + "on_execute_callback", + "on_failure_callback", + "on_success_callback", + "on_retry_callback", + "on_skipped_callback", + } def get_inlet_defs(self): """ @@ -1195,55 +682,6 @@ def get_outlet_defs(self): """ return self.outlets - def get_dag(self) -> DAG | None: - return self._dag - - @property # type: ignore[override] - def dag(self) -> DAG: # type: ignore[override] - """Returns the Operator's DAG if set, otherwise raises an error.""" - if self._dag: - return self._dag - else: - raise AirflowException(f"Operator {self} has not been assigned to a DAG yet") - - @dag.setter - def dag(self, dag: DAG | None): - """Operators can be assigned to one DAG, one time. Repeat assignments to that same DAG are ok.""" - if dag is None: - self._dag = None - return - - # if set to removed, then just set and exit - if self._dag.__class__ is AttributeRemoved: - self._dag = dag - return - # if setting to removed, then just set and exit - if dag.__class__ is AttributeRemoved: - self._dag = AttributeRemoved("_dag") # type: ignore[assignment] - return - - from airflow.models.dag import DAG - - if not isinstance(dag, DAG): - raise TypeError(f"Expected DAG; received {dag.__class__.__name__}") - elif self.has_dag() and self.dag is not dag: - raise AirflowException(f"The DAG assigned to {self} can not be changed.") - - if self.__from_mapped: - pass # Don't add to DAG -- the mapped task takes the place. - elif dag.task_dict.get(self.task_id) is not self: - dag.add_task(self) - - self._dag = dag - - @property - def task_display_name(self) -> str: - return self._task_display_property_value or self.task_id - - def has_dag(self): - """Return True if the Operator has been assigned to a DAG.""" - return self._dag is not None - deps: frozenset[BaseTIDep] = frozenset( { NotInRetryPeriodDep(), @@ -1265,33 +703,6 @@ def prepare_for_execution(self) -> BaseOperator: other._lock_for_execution = True return other - def set_xcomargs_dependencies(self) -> None: - """ - Resolve upstream dependencies of a task. - - In this way passing an ``XComArg`` as value for a template field - will result in creating upstream relation between two tasks. - - **Example**: :: - - with DAG(...): - generate_content = GenerateContentOperator(task_id="generate_content") - send_email = EmailOperator(..., html_content=generate_content.output) - - # This is equivalent to - with DAG(...): - generate_content = GenerateContentOperator(task_id="generate_content") - send_email = EmailOperator(..., html_content="{{ task_instance.xcom_pull('generate_content') }}") - generate_content >> send_email - - """ - from airflow.models.xcom_arg import XComArg - - for field in self.template_fields: - if hasattr(self, field): - arg = getattr(self, field) - XComArg.apply_upstream_relationship(self, arg) - @prepare_lineage def pre_execute(self, context: Any): """Execute right before self.execute() is called.""" @@ -1328,46 +739,6 @@ def post_execute(self, context: Any, result: Any = None): logger=self.log, ).run(context, result) - def on_kill(self) -> None: - """ - Override this method to clean up subprocesses when a task instance gets killed. - - Any use of the threading, subprocess or multiprocessing module within an - operator needs to be cleaned up, or it will leave ghost processes behind. - """ - - def __deepcopy__(self, memo): - # Hack sorting double chained task lists by task_id to avoid hitting - # max_depth on deepcopy operations. 
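
# Editor's sketch: the get_serialized_fields() override above extends the Task SDK's
# field set with scheduler-only attributes and caches the result (the PR uses
# methodtools.lru_cache; the TODO notes it ends up cached once per subclass). A rough
# standalone approximation of that pattern, using a plain dict keyed by class and
# made-up field names:
_serialized_fields_cache: dict[type, frozenset[str]] = {}


class _SdkBase:
    @classmethod
    def get_serialized_fields(cls) -> frozenset[str]:
        return frozenset({"task_id", "retries", "queue"})


class _SchedulerOp(_SdkBase):
    @classmethod
    def get_serialized_fields(cls) -> frozenset[str]:
        if cls not in _serialized_fields_cache:
            # Start from the definition-time fields and add scheduler-only ones.
            _serialized_fields_cache[cls] = _SdkBase.get_serialized_fields() | {
                "on_failure_callback",
                "start_from_trigger",
            }
        return _serialized_fields_cache[cls]


assert "queue" in _SchedulerOp.get_serialized_fields()
assert "start_from_trigger" in _SchedulerOp.get_serialized_fields()
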
- sys.setrecursionlimit(5000) # TODO fix this in a better way - - cls = self.__class__ - result = cls.__new__(cls) - memo[id(self)] = result - - shallow_copy = cls.shallow_copy_attrs + cls._base_operator_shallow_copy_attrs - - for k, v in self.__dict__.items(): - if k == "_BaseOperator__instantiated": - # Don't set this until the _end_, as it changes behaviour of __setattr__ - continue - if k not in shallow_copy: - setattr(result, k, copy.deepcopy(v, memo)) - else: - setattr(result, k, copy.copy(v)) - result.__instantiated = self.__instantiated - return result - - def __getstate__(self): - state = dict(self.__dict__) - if self._log: - del state["_log"] - - return state - - def __setstate__(self, state): - self.__dict__ = state - def render_template_fields( self, context: Context, @@ -1413,6 +784,12 @@ def clear( qry = qry.where(TaskInstance.task_id.in_(tasks)) results = session.scalars(qry).all() count = len(results) + + if TYPE_CHECKING: + # TODO: Task-SDK: We need to set this to the scheduler DAG until we fully separate scheduling and + # definition code + assert isinstance(self.dag, SchedulerDAG) + clear_task_instances(results, session, dag=self.dag) session.commit() return count @@ -1460,6 +837,10 @@ def run( if TYPE_CHECKING: assert self.start_date + # TODO: Task-SDK: We need to set this to the scheduler DAG until we fully separate scheduling and + # definition code + assert isinstance(self.dag, SchedulerDAG) + start_date = pendulum.instance(start_date or self.start_date) end_date = pendulum.instance(end_date or self.end_date or timezone.utcnow()) @@ -1520,83 +901,6 @@ def get_direct_relatives(self, upstream: bool = False) -> Iterable[Operator]: else: return self.downstream_list - def __repr__(self): - return f"" - - @property - def operator_class(self) -> type[BaseOperator]: # type: ignore[override] - return self.__class__ - - @property - def task_type(self) -> str: - """@property: type of the task.""" - return self.__class__.__name__ - - @property - def operator_name(self) -> str: - """@property: use a more friendly display name for the operator, if set.""" - try: - return self.custom_operator_name # type: ignore - except AttributeError: - return self.task_type - - @property - def roots(self) -> list[BaseOperator]: - """Required by DAGNode.""" - return [self] - - @property - def leaves(self) -> list[BaseOperator]: - """Required by DAGNode.""" - return [self] - - @property - def output(self) -> XComArg: - """Returns reference to XCom pushed by current operator.""" - from airflow.models.xcom_arg import XComArg - - return XComArg(operator=self) - - @property - def is_setup(self) -> bool: - """ - Whether the operator is a setup task. - - :meta private: - """ - return self._is_setup - - @is_setup.setter - def is_setup(self, value: bool) -> None: - """ - Setter for is_setup property. - - :meta private: - """ - if self.is_teardown and value: - raise ValueError(f"Cannot mark task '{self.task_id}' as setup; task is already a teardown.") - self._is_setup = value - - @property - def is_teardown(self) -> bool: - """ - Whether the operator is a teardown task. - - :meta private: - """ - return self._is_teardown - - @is_teardown.setter - def is_teardown(self, value: bool) -> None: - """ - Setter for is_teardown property. 
- - :meta private: - """ - if self.is_setup and value: - raise ValueError(f"Cannot mark task '{self.task_id}' as teardown; task is already a setup.") - self._is_teardown = value - @staticmethod def xcom_push( context: Any, @@ -1657,68 +961,10 @@ def xcom_pull( session=session, ) - @classmethod - def get_serialized_fields(cls): - """Stringified DAGs and operators contain exactly these fields.""" - if not cls.__serialized_fields: - from airflow.models.dag import DagContext - - # make sure the following dummy task is not added to current active - # dag in context, otherwise, it will result in - # `RuntimeError: dictionary changed size during iteration` - # Exception in SerializedDAG.serialize_dag() call. - DagContext.push_context_managed_dag(None) - cls.__serialized_fields = frozenset( - vars(BaseOperator(task_id="test")).keys() - - { - "upstream_task_ids", - "default_args", - "dag", - "_dag", - "label", - "_BaseOperator__instantiated", - "_BaseOperator__init_kwargs", - "_BaseOperator__from_mapped", - "_is_setup", - "_is_teardown", - "_on_failure_fail_dagrun", - } - | { # Class level defaults need to be added to this list - "start_date", - "end_date", - "_task_type", - "_operator_name", - "ui_color", - "ui_fgcolor", - "template_ext", - "template_fields", - "template_fields_renderers", - "params", - "is_setup", - "is_teardown", - "on_failure_fail_dagrun", - "map_index_template", - "start_trigger_args", - "_needs_expansion", - "start_from_trigger", - } - ) - DagContext.pop_context_managed_dag() - - return cls.__serialized_fields - def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: """Serialize; required by DAGNode.""" return DagAttributeTypes.OP, self.task_id - @property - def inherits_from_empty_operator(self): - """Used to determine if an Operator is inherited from EmptyOperator.""" - # This looks like `isinstance(self, EmptyOperator) would work, but this also - # needs to cope when `self` is a Serialized instance of a EmptyOperator or one - # of its subclasses (which don't inherit from anything but BaseOperator). 
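
# Editor's sketch: the xcom_push/xcom_pull helpers kept above are thin static wrappers
# that delegate to the TaskInstance carried in the runtime context. A toy illustration
# of that delegation shape (fake context/TI objects with made-up in-memory storage,
# not Airflow's real XCom backend):
class _FakeTI:
    def __init__(self):
        self._store = {}

    def xcom_push(self, key, value):
        self._store[key] = value

    def xcom_pull(self, key, default=None):
        return self._store.get(key, default)


def xcom_push(context, key, value):
    # Same shape as the static helper: look up the TI in the context and delegate.
    context["ti"].xcom_push(key=key, value=value)


def xcom_pull(context, key):
    return context["ti"].xcom_pull(key=key)


context = {"ti": _FakeTI()}
xcom_push(context, "rows_loaded", 42)
assert xcom_pull(context, "rows_loaded") == 42
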
- return getattr(self, "_is_empty", False) - def defer( self, *, @@ -2038,7 +1284,7 @@ def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]): prev_elem = None deps_set = False for curr_elem in elements: - if isinstance(curr_elem, EdgeModifier): + if isinstance(curr_elem, (EdgeModifier, TaskSDKEdgeModifier)): raise ValueError("Labels are not supported by chain_linear") if prev_elem is not None: for task in prev_elem: diff --git a/airflow/models/dag.py b/airflow/models/dag.py index d72eb0928cb4..851d2a512934 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -20,19 +20,15 @@ import asyncio import copy import functools -import itertools import logging -import os import pathlib import pickle import sys import time import traceback -import weakref -from collections import abc, defaultdict, deque +from collections import defaultdict from contextlib import ExitStack from datetime import datetime, timedelta -from inspect import signature from typing import ( TYPE_CHECKING, Any, @@ -40,17 +36,15 @@ Collection, Container, Iterable, - Iterator, - MutableSet, Pattern, Sequence, Union, cast, overload, ) -from urllib.parse import urlsplit -import jinja2 +import attrs +import methodtools import pendulum import re2 import sqlalchemy_jsonfield @@ -77,22 +71,16 @@ from sqlalchemy.orm import backref, relationship from sqlalchemy.sql import Select, expression -import airflow.templates from airflow import settings, utils from airflow.api_internal.internal_api_call import internal_api_call -from airflow.assets import Asset, AssetAlias, AssetAll, BaseAsset +from airflow.assets import Asset, AssetAlias, BaseAsset from airflow.configuration import conf as airflow_conf, secrets_backend_list from airflow.exceptions import ( AirflowException, - DuplicateTaskIdFound, - FailStopDagInvalidTriggerRule, - ParamValidationError, TaskDeferred, - TaskNotFound, UnknownExecutorException, ) from airflow.executors.executor_loader import ExecutorLoader -from airflow.models.abstractoperator import AbstractOperator, TaskStateChangeCallback from airflow.models.asset import ( AssetDagRunQueue, AssetModel, @@ -102,7 +90,6 @@ from airflow.models.dagcode import DagCode from airflow.models.dagpickle import DagPickle from airflow.models.dagrun import RUN_ID_REGEX, DagRun -from airflow.models.param import DagParam, ParamsDict from airflow.models.taskinstance import ( Context, TaskInstance, @@ -110,6 +97,7 @@ clear_task_instances, ) from airflow.models.tasklog import LogTemplate +from airflow.sdk import DAG as TaskSDKDag, dag as task_sdk_dag_decorator from airflow.secrets.local_filesystem import LocalFilesystemBackend from airflow.security import permissions from airflow.settings import json @@ -118,36 +106,28 @@ from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable from airflow.timetables.simple import ( AssetTriggeredTimetable, - ContinuousTimetable, NullTimetable, OnceTimetable, ) -from airflow.timetables.trigger import CronTriggerTimetable from airflow.utils import timezone from airflow.utils.dag_cycle_tester import check_cycle -from airflow.utils.decorators import fixup_decorator_warning_stack -from airflow.utils.helpers import exactly_one, validate_instance_args, validate_key +from airflow.utils.helpers import exactly_one from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import UtcDateTime, lock_rows, tuple_in_condition, with_row_locks from airflow.utils.state 
import DagRunState, State, TaskInstanceState -from airflow.utils.trigger_rule import TriggerRule -from airflow.utils.types import NOTSET, DagRunTriggeredByType, DagRunType, EdgeInfoType +from airflow.utils.types import DagRunTriggeredByType, DagRunType if TYPE_CHECKING: - from types import ModuleType - - from pendulum.tz.timezone import FixedTimezone, Timezone from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session - from airflow.decorators import TaskDecoratorCollection + from airflow.models.abstractoperator import TaskStateChangeCallback from airflow.models.dagbag import DagBag from airflow.models.operator import Operator from airflow.serialization.pydantic.dag import DagModelPydantic from airflow.serialization.pydantic.dag_run import DagRunPydantic from airflow.typing_compat import Literal - from airflow.utils.task_group import TaskGroup log = logging.getLogger(__name__) @@ -206,24 +186,6 @@ def _get_model_data_interval( return DataInterval(start, end) -def create_timetable(interval: ScheduleInterval, timezone: Timezone | FixedTimezone) -> Timetable: - """Create a Timetable instance from a plain ``schedule`` value.""" - if interval is None: - return NullTimetable() - if interval == "@once": - return OnceTimetable() - if interval == "@continuous": - return ContinuousTimetable() - if isinstance(interval, (timedelta, relativedelta)): - return DeltaDataIntervalTimetable(interval) - if isinstance(interval, str): - if airflow_conf.getboolean("scheduler", "create_cron_data_intervals"): - return CronDataIntervalTimetable(interval, timezone) - else: - return CronTriggerTimetable(interval, timezone=timezone) - raise ValueError(f"{interval!r} is not a valid schedule.") - - def get_last_dagrun(dag_id, session, include_externally_triggered=False): """ Return the last dag run for a dag, None if there was none. @@ -332,34 +294,28 @@ def _create_orm_dagrun( return run -# TODO: The following mapping is used to validate that the arguments passed to the DAG are of the correct -# type. This is a temporary solution until we find a more sophisticated method for argument validation. -# One potential method is to use `get_type_hints` from the typing module. However, this is not fully -# compatible with future annotations for Python versions below 3.10. Once we require a minimum Python -# version that supports `get_type_hints` effectively or find a better approach, we can replace this -# manual type-checking method. -DAG_ARGS_EXPECTED_TYPES = { - "dag_id": str, - "description": str, - "max_active_tasks": int, - "max_active_runs": int, - "max_consecutive_failed_dag_runs": int, - "dagrun_timeout": timedelta, - "default_view": str, - "orientation": str, - "catchup": bool, - "doc_md": str, - "is_paused_upon_creation": bool, - "render_template_as_native_obj": bool, - "tags": Collection, - "auto_register": bool, - "fail_stop": bool, - "dag_display_name": str, -} +if TYPE_CHECKING: + dag = task_sdk_dag_decorator +else: + + def dag(dag_id: str = "", **kwargs): + return task_sdk_dag_decorator(dag_id, __DAG_class=DAG, __warnings_stacklevel_delta=3, **kwargs) + + +def _convert_max_consecutive_failed_dag_runs(val: int) -> int: + if val == 0: + val = airflow_conf.getint("core", "max_consecutive_failed_dag_runs_per_dag") + if val < 0: + raise ValueError( + f"Invalid max_consecutive_failed_dag_runs: {val}." 
+ f"Requires max_consecutive_failed_dag_runs >= 0" + ) + return val @functools.total_ordering -class DAG(LoggingMixin): +@attrs.define(hash=False, repr=False, eq=False, slots=False) +class DAG(TaskSDKDag, LoggingMixin): """ A dag (directed acyclic graph) is a collection of tasks with directional dependencies. @@ -473,250 +429,30 @@ class DAG(LoggingMixin): :param dag_display_name: The display name of the DAG which appears on the UI. """ - _comps = { - "dag_id", - "task_ids", - "start_date", - "end_date", - "fileloc", - "template_searchpath", - "last_loaded", - } - - __serialized_fields: frozenset[str] | None = None - - fileloc: str - """ - File path that needs to be imported to load this DAG. - - This may not be an actual file on disk in the case when this DAG is loaded - from a ZIP file or other DAG distribution format. - """ - - # NOTE: When updating arguments here, please also keep arguments in @dag() - # below in sync. (Search for 'def dag(' in this file.) - def __init__( - self, - dag_id: str, - description: str | None = None, - schedule: ScheduleArg = None, - start_date: datetime | None = None, - end_date: datetime | None = None, - template_searchpath: str | Iterable[str] | None = None, - template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined, - user_defined_macros: dict | None = None, - user_defined_filters: dict | None = None, - default_args: dict | None = None, - max_active_tasks: int = airflow_conf.getint("core", "max_active_tasks_per_dag"), - max_active_runs: int = airflow_conf.getint("core", "max_active_runs_per_dag"), - max_consecutive_failed_dag_runs: int = airflow_conf.getint( - "core", "max_consecutive_failed_dag_runs_per_dag" - ), - dagrun_timeout: timedelta | None = None, - sla_miss_callback: Any = None, - default_view: str = airflow_conf.get_mandatory_value("webserver", "dag_default_view").lower(), - orientation: str = airflow_conf.get_mandatory_value("webserver", "dag_orientation"), - catchup: bool = airflow_conf.getboolean("scheduler", "catchup_by_default"), - on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, - on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, - doc_md: str | None = None, - params: abc.MutableMapping | None = None, - access_control: dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | None = None, - is_paused_upon_creation: bool | None = None, - jinja_environment_kwargs: dict | None = None, - render_template_as_native_obj: bool = False, - tags: Collection[str] | None = None, - owner_links: dict[str, str] | None = None, - auto_register: bool = True, - fail_stop: bool = False, - dag_display_name: str | None = None, - ): - from airflow.utils.task_group import TaskGroup - - if tags and any(len(tag) > TAG_MAX_LEN for tag in tags): - raise AirflowException(f"tag cannot be longer than {TAG_MAX_LEN} characters") - - self.owner_links = owner_links or {} - self.user_defined_macros = user_defined_macros - self.user_defined_filters = user_defined_filters - if default_args and not isinstance(default_args, dict): - raise TypeError("default_args must be a dict") - self.default_args = copy.deepcopy(default_args or {}) - params = params or {} - - # merging potentially conflicting default_args['params'] into params - if "params" in self.default_args: - params.update(self.default_args["params"]) - del self.default_args["params"] - - # check self.params and convert them into ParamsDict - self.params = ParamsDict(params) - - validate_key(dag_id) - - 
self._dag_id = dag_id - self._dag_display_property_value = dag_display_name - - self._max_active_tasks = max_active_tasks - self._pickle_id: int | None = None - - self._description = description - # set file location to caller source path - back = sys._getframe().f_back - self.fileloc = back.f_code.co_filename if back else "" - self.task_dict: dict[str, Operator] = {} - - # set timezone from start_date - tz = None - if start_date and start_date.tzinfo: - tzinfo = None if start_date.tzinfo else settings.TIMEZONE - tz = pendulum.instance(start_date, tz=tzinfo).timezone - elif date := self.default_args.get("start_date"): - if not isinstance(date, datetime): - date = timezone.parse(date) - self.default_args["start_date"] = date - start_date = date - - tzinfo = None if date.tzinfo else settings.TIMEZONE - tz = pendulum.instance(date, tz=tzinfo).timezone - self.timezone: Timezone | FixedTimezone = tz or settings.TIMEZONE - - # Apply the timezone we settled on to end_date if it wasn't supplied - if isinstance(_end_date := self.default_args.get("end_date"), str): - self.default_args["end_date"] = timezone.parse(_end_date, timezone=self.timezone) - - self.start_date = timezone.convert_to_utc(start_date) - self.end_date = timezone.convert_to_utc(end_date) - - # also convert tasks - if "start_date" in self.default_args: - self.default_args["start_date"] = timezone.convert_to_utc(self.default_args["start_date"]) - if "end_date" in self.default_args: - self.default_args["end_date"] = timezone.convert_to_utc(self.default_args["end_date"]) - - if isinstance(schedule, Timetable): - self.timetable = schedule - elif isinstance(schedule, BaseAsset): - self.timetable = AssetTriggeredTimetable(schedule) - elif isinstance(schedule, Collection) and not isinstance(schedule, str): - if not all(isinstance(x, (Asset, AssetAlias)) for x in schedule): - raise ValueError("All elements in 'schedule' should be assets or asset aliases") - self.timetable = AssetTriggeredTimetable(AssetAll(*schedule)) - else: - self.timetable = create_timetable(schedule, self.timezone) - - requires_automatic_backfilling = self.timetable.can_be_scheduled and catchup - if requires_automatic_backfilling and not ("start_date" in self.default_args or self.start_date): - raise ValueError("start_date is required when catchup=True") - - if isinstance(template_searchpath, str): - template_searchpath = [template_searchpath] - self.template_searchpath = template_searchpath - self.template_undefined = template_undefined - self.last_loaded: datetime = timezone.utcnow() - self.safe_dag_id = dag_id.replace(".", "__dot__") - self.max_active_runs = max_active_runs - self.max_consecutive_failed_dag_runs = max_consecutive_failed_dag_runs - if self.max_consecutive_failed_dag_runs == 0: - self.max_consecutive_failed_dag_runs = airflow_conf.getint( - "core", "max_consecutive_failed_dag_runs_per_dag" - ) - if self.max_consecutive_failed_dag_runs < 0: - raise AirflowException( - f"Invalid max_consecutive_failed_dag_runs: {self.max_consecutive_failed_dag_runs}." 
- f"Requires max_consecutive_failed_dag_runs >= 0" - ) - if self.timetable.active_runs_limit is not None: - if self.timetable.active_runs_limit < self.max_active_runs: - raise AirflowException( - f"Invalid max_active_runs: {type(self.timetable)} " - f"requires max_active_runs <= {self.timetable.active_runs_limit}" - ) - self.dagrun_timeout = dagrun_timeout - if sla_miss_callback: - log.warning( - "The SLA feature is removed in Airflow 3.0, to be replaced with a new implementation in 3.1" - ) - if default_view in DEFAULT_VIEW_PRESETS: - self._default_view: str = default_view - else: - raise AirflowException( - f"Invalid values of dag.default_view: only support " - f"{DEFAULT_VIEW_PRESETS}, but get {default_view}" - ) - if orientation in ORIENTATION_PRESETS: - self.orientation = orientation - else: - raise AirflowException( - f"Invalid values of dag.orientation: only support " - f"{ORIENTATION_PRESETS}, but get {orientation}" - ) - self.catchup: bool = catchup - - self.partial: bool = False - self.on_success_callback = on_success_callback - self.on_failure_callback = on_failure_callback - - # Keeps track of any extra edge metadata (sparse; will not contain all - # edges, so do not iterate over it for that). Outer key is upstream - # task ID, inner key is downstream task ID. - self.edge_info: dict[str, dict[str, EdgeInfoType]] = {} - - # To keep it in parity with Serialized DAGs - # and identify if DAG has on_*_callback without actually storing them in Serialized JSON - self.has_on_success_callback: bool = self.on_success_callback is not None - self.has_on_failure_callback: bool = self.on_failure_callback is not None - - self._access_control = DAG._upgrade_outdated_dag_access_control(access_control) - self.is_paused_upon_creation = is_paused_upon_creation - self.auto_register = auto_register - - self.fail_stop: bool = fail_stop - - self.jinja_environment_kwargs = jinja_environment_kwargs - self.render_template_as_native_obj = render_template_as_native_obj + partial: bool = False + last_loaded: datetime | None = attrs.field(factory=timezone.utcnow) - self.doc_md = self.get_doc_md(doc_md) + default_view: str = airflow_conf.get_mandatory_value("webserver", "dag_default_view").lower() + orientation: str = airflow_conf.get_mandatory_value("webserver", "dag_orientation") - self.tags: MutableSet[str] = set(tags or []) - self._task_group = TaskGroup.create_root(self) - self.validate_schedule_and_params() - wrong_links = dict(self.iter_invalid_owner_links()) - if wrong_links: - raise AirflowException( - "Wrong link format was used for the owner. 
Use a valid link \n" - f"Bad formatted links are: {wrong_links}" - ) - - # this will only be set at serialization time - # it's only use is for determining the relative - # fileloc based only on the serialize dag - self._processor_dags_folder = None - - validate_instance_args(self, DAG_ARGS_EXPECTED_TYPES) + # this will only be set at serialization time + # it's only use is for determining the relative fileloc based only on the serialize dag + _processor_dags_folder: str | None = attrs.field(init=False, default=None) - def get_doc_md(self, doc_md: str | None) -> str | None: - if doc_md is None: - return doc_md - - if doc_md.endswith(".md"): - try: - return open(doc_md).read() - except FileNotFoundError: - return doc_md + # Override the default from parent class to use config + max_consecutive_failed_dag_runs: int = attrs.field( + default=0, + converter=_convert_max_consecutive_failed_dag_runs, + validator=attrs.validators.instance_of(int), + ) - return doc_md + @property + def safe_dag_id(self): + return self.dag_id.replace(".", "__dot__") def validate(self): - """ - Validate the DAG has a coherent setup. - - This is called by the DAG bag before bagging the DAG. - """ + super().validate() self.validate_executor_field() - self.validate_schedule_and_params() - self.timetable.validate() - self.validate_setup_teardown() def validate_executor_field(self): for task in self.tasks: @@ -730,63 +466,6 @@ def validate_executor_field(self): "update the executor configuration for this task." ) - def validate_setup_teardown(self): - """ - Validate that setup and teardown tasks are configured properly. - - :meta private: - """ - for task in self.tasks: - if task.is_setup: - for down_task in task.downstream_list: - if not down_task.is_teardown and down_task.trigger_rule != TriggerRule.ALL_SUCCESS: - # todo: we can relax this to allow out-of-scope tasks to have other trigger rules - # this is required to ensure consistent behavior of dag - # when clearing an indirect setup - raise ValueError("Setup tasks must be followed with trigger rule ALL_SUCCESS.") - FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule) - - def __repr__(self): - return f"" - - def __eq__(self, other): - if type(self) is type(other): - # Use getattr() instead of __dict__ as __dict__ doesn't return - # correct values for properties. 
- return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps) - return False - - def __ne__(self, other): - return not self == other - - def __lt__(self, other): - return self.dag_id < other.dag_id - - def __hash__(self): - hash_components = [type(self)] - for c in self._comps: - # task_ids returns a list and lists can't be hashed - if c == "task_ids": - val = tuple(self.task_dict) - else: - val = getattr(self, c, None) - try: - hash(val) - hash_components.append(val) - except TypeError: - hash_components.append(repr(val)) - return hash(tuple(hash_components)) - - # Context Manager ----------------------------------------------- - def __enter__(self): - DagContext.push_context_managed_dag(self) - return self - - def __exit__(self, _type, _value, _tb): - DagContext.pop_context_managed_dag() - - # /Context Manager ---------------------------------------------- - @staticmethod def _upgrade_outdated_dag_access_control(access_control=None): """Look for outdated dag level actions in DAG access_controls and replace them with updated actions.""" @@ -951,7 +630,7 @@ def _time_restriction(self) -> TimeRestriction: earliest = None if start_dates: earliest = timezone.coerce_datetime(min(start_dates)) - latest = self.end_date + latest = timezone.coerce_datetime(self.end_date) end_dates = [t.end_date for t in self.tasks if t.end_date] if len(end_dates) == len(self.tasks): # not exists null end_date if self.end_date is not None: @@ -962,8 +641,8 @@ def _time_restriction(self) -> TimeRestriction: def iter_dagrun_infos_between( self, - earliest: pendulum.DateTime | None, - latest: pendulum.DateTime, + earliest: pendulum.DateTime | datetime | None, + latest: pendulum.DateTime | datetime, *, align: bool = True, ) -> Iterable[DagRunInfo]: @@ -1060,34 +739,6 @@ def dag_id(self, value: str) -> None: def timetable_summary(self) -> str: return self.timetable.summary - @property - def max_active_tasks(self) -> int: - return self._max_active_tasks - - @max_active_tasks.setter - def max_active_tasks(self, value: int): - self._max_active_tasks = value - - @property - def access_control(self): - return self._access_control - - @access_control.setter - def access_control(self, value): - self._access_control = DAG._upgrade_outdated_dag_access_control(value) - - @property - def dag_display_name(self) -> str: - return self._dag_display_property_value or self._dag_id - - @property - def description(self) -> str | None: - return self._description - - @property - def default_view(self) -> str: - return self._default_view - @property def pickle_id(self) -> int | None: return self._pickle_id @@ -1096,41 +747,6 @@ def pickle_id(self) -> int | None: def pickle_id(self, value: int) -> None: self._pickle_id = value - def param(self, name: str, default: Any = NOTSET) -> DagParam: - """ - Return a DagParam object for current dag. - - :param name: dag parameter name. - :param default: fallback value for dag parameter. - :return: DagParam instance for specified name and current dag. - """ - return DagParam(current_dag=self, name=name, default=default) - - @property - def tasks(self) -> list[Operator]: - return list(self.task_dict.values()) - - @tasks.setter - def tasks(self, val): - raise AttributeError("DAG.tasks can not be modified. 
Use dag.add_task() instead.") - - @property - def task_ids(self) -> list[str]: - return list(self.task_dict) - - @property - def teardowns(self) -> list[Operator]: - return [task for task in self.tasks if getattr(task, "is_teardown", None)] - - @property - def tasks_upstream_of_teardowns(self) -> list[Operator]: - upstream_tasks = [t.upstream_list for t in self.teardowns] - return [val for sublist in upstream_tasks for val in sublist if not getattr(val, "is_teardown", None)] - - @property - def task_group(self) -> TaskGroup: - return self._task_group - @property def relative_fileloc(self) -> pathlib.Path: """File location of the importable dag 'file' relative to the configured DAGs folder.""" @@ -1145,24 +761,6 @@ def relative_fileloc(self) -> pathlib.Path: # Not relative to DAGS_FOLDER. return path - @property - def folder(self) -> str: - """Folder location of where the DAG object is instantiated.""" - return os.path.dirname(self.fileloc) - - @property - def owner(self) -> str: - """ - Return list of all owners found in DAG tasks. - - :return: Comma separated list of owners in DAG tasks - """ - return ", ".join({t.owner for t in self.tasks}) - - @property - def allow_future_exec_dates(self) -> bool: - return settings.ALLOW_FUTURE_EXEC_DATES and not self.timetable.can_be_scheduled - @provide_session def get_concurrency_reached(self, session=NEW_SESSION) -> bool: """Return a boolean indicating whether the max_active_tasks limit for this DAG has been reached.""" @@ -1185,6 +783,14 @@ def get_is_paused(self, session=NEW_SESSION) -> None: """Return a boolean indicating whether this DAG is paused.""" return session.scalar(select(DagModel.is_paused).where(DagModel.dag_id == self.dag_id)) + @methodtools.lru_cache(maxsize=None) + @classmethod + def get_serialized_fields(cls): + """Stringified DAGs and operators contain exactly these fields.""" + return TaskSDKDag.get_serialized_fields() | { + "_processor_dags_folder", + } + @staticmethod @internal_api_call @provide_session @@ -1346,45 +952,6 @@ def get_latest_execution_date(self, session: Session = NEW_SESSION) -> pendulum. """Return the latest date for which at least one dag run exists.""" return session.scalar(select(func.max(DagRun.execution_date)).where(DagRun.dag_id == self.dag_id)) - def resolve_template_files(self): - for t in self.tasks: - t.resolve_template_files() - - def get_template_env(self, *, force_sandboxed: bool = False) -> jinja2.Environment: - """Build a Jinja2 environment.""" - # Collect directories to search for template files - searchpath = [self.folder] - if self.template_searchpath: - searchpath += self.template_searchpath - - # Default values (for backward compatibility) - jinja_env_options = { - "loader": jinja2.FileSystemLoader(searchpath), - "undefined": self.template_undefined, - "extensions": ["jinja2.ext.do"], - "cache_size": 0, - } - if self.jinja_environment_kwargs: - jinja_env_options.update(self.jinja_environment_kwargs) - env: jinja2.Environment - if self.render_template_as_native_obj and not force_sandboxed: - env = airflow.templates.NativeEnvironment(**jinja_env_options) - else: - env = airflow.templates.SandboxedEnvironment(**jinja_env_options) - - # Add any user defined items. Safe to edit globals as long as no templates are rendered yet. 
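
# Editor's sketch: get_template_env(), being removed from core here, built a Jinja2
# environment from the DAG folder plus template_searchpath and layered user-defined
# macros/filters on top; that responsibility is expected to sit with the Task SDK DAG
# now. A minimal standalone version of the same construction (toy values, not the
# Airflow wrapper environments):
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from jinja2.nativetypes import NativeEnvironment


def build_template_env(searchpath, macros=None, filters=None, native=False):
    options = {
        "loader": FileSystemLoader(searchpath),
        "undefined": StrictUndefined,
        "extensions": ["jinja2.ext.do"],
        "cache_size": 0,
    }
    env = (NativeEnvironment if native else Environment)(**options)
    # Safe to mutate globals/filters as long as no template has been rendered yet.
    if macros:
        env.globals.update(macros)
    if filters:
        env.filters.update(filters)
    return env


env = build_template_env(["/tmp"], macros={"answer": 42})
assert env.from_string("{{ answer }}").render() == "42"
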
- # http://jinja.pocoo.org/docs/2.10/api/#jinja2.Environment.globals - if self.user_defined_macros: - env.globals.update(self.user_defined_macros) - if self.user_defined_filters: - env.filters.update(self.user_defined_filters) - - return env - - def set_dependency(self, upstream_task_id, downstream_task_id): - """Set dependency between two tasks that already have been added to the DAG using add_task().""" - self.get_task(upstream_task_id).set_downstream(self.get_task(downstream_task_id)) - @provide_session def get_task_instances_before( self, @@ -1849,33 +1416,6 @@ def set_task_group_state( return altered - @property - def roots(self) -> list[Operator]: - """Return nodes with no parents. These are first to execute and are called roots or root nodes.""" - return [task for task in self.tasks if not task.upstream_list] - - @property - def leaves(self) -> list[Operator]: - """Return nodes with no children. These are last to execute and are called leaves or leaf nodes.""" - return [task for task in self.tasks if not task.downstream_list] - - def topological_sort(self): - """ - Sorts tasks in topographical order, such that a task comes after any of its upstream dependencies. - - Deprecated in place of ``task_group.topological_sort`` - """ - from airflow.utils.task_group import TaskGroup - - def nested_topo(group): - for node in group.topological_sort(): - if isinstance(node, TaskGroup): - yield from nested_topo(node) - else: - yield node - - return tuple(nested_topo(self.task_group)) - @provide_session def clear( self, @@ -2009,169 +1549,6 @@ def clear_dags( print("Cancelled, nothing was cleared.") return count - def __deepcopy__(self, memo): - # Switcharoo to go around deepcopying objects coming through the - # backdoor - cls = self.__class__ - result = cls.__new__(cls) - memo[id(self)] = result - for k, v in self.__dict__.items(): - if k not in ("user_defined_macros", "user_defined_filters", "_log"): - setattr(result, k, copy.deepcopy(v, memo)) - - result.user_defined_macros = self.user_defined_macros - result.user_defined_filters = self.user_defined_filters - if hasattr(self, "_log"): - result._log = self._log - return result - - def partial_subset( - self, - task_ids_or_regex: str | Pattern | Iterable[str], - include_downstream=False, - include_upstream=True, - include_direct_upstream=False, - ): - """ - Return a subset of the current dag based on regex matching one or more tasks. - - Returns a subset of the current dag as a deep copy of the current dag - based on a regex that should match one or many tasks, and includes - upstream and downstream neighbours based on the flag passed. - - :param task_ids_or_regex: Either a list of task_ids, or a regex to - match against task ids (as a string, or compiled regex pattern). - :param include_downstream: Include all downstream tasks of matched - tasks, in addition to matched tasks. - :param include_upstream: Include all upstream tasks of matched tasks, - in addition to matched tasks. 
- :param include_direct_upstream: Include all tasks directly upstream of matched - and downstream (if include_downstream = True) tasks - """ - from airflow.models.baseoperator import BaseOperator - from airflow.models.mappedoperator import MappedOperator - - # deep-copying self.task_dict and self._task_group takes a long time, and we don't want all - # the tasks anyway, so we copy the tasks manually later - memo = {id(self.task_dict): None, id(self._task_group): None} - dag = copy.deepcopy(self, memo) # type: ignore - - if isinstance(task_ids_or_regex, (str, Pattern)): - matched_tasks = [t for t in self.tasks if re2.findall(task_ids_or_regex, t.task_id)] - else: - matched_tasks = [t for t in self.tasks if t.task_id in task_ids_or_regex] - - also_include_ids: set[str] = set() - for t in matched_tasks: - if include_downstream: - for rel in t.get_flat_relatives(upstream=False): - also_include_ids.add(rel.task_id) - if rel not in matched_tasks: # if it's in there, we're already processing it - # need to include setups and teardowns for tasks that are in multiple - # non-collinear setup/teardown paths - if not rel.is_setup and not rel.is_teardown: - also_include_ids.update( - x.task_id for x in rel.get_upstreams_only_setups_and_teardowns() - ) - if include_upstream: - also_include_ids.update(x.task_id for x in t.get_upstreams_follow_setups()) - else: - if not t.is_setup and not t.is_teardown: - also_include_ids.update(x.task_id for x in t.get_upstreams_only_setups_and_teardowns()) - if t.is_setup and not include_downstream: - also_include_ids.update(x.task_id for x in t.downstream_list if x.is_teardown) - - also_include: list[Operator] = [self.task_dict[x] for x in also_include_ids] - direct_upstreams: list[Operator] = [] - if include_direct_upstream: - for t in itertools.chain(matched_tasks, also_include): - upstream = (u for u in t.upstream_list if isinstance(u, (BaseOperator, MappedOperator))) - direct_upstreams.extend(upstream) - - # Compiling the unique list of tasks that made the cut - # Make sure to not recursively deepcopy the dag or task_group while copying the task. - # task_group is reset later - def _deepcopy_task(t) -> Operator: - memo.setdefault(id(t.task_group), None) - return copy.deepcopy(t, memo) - - dag.task_dict = { - t.task_id: _deepcopy_task(t) - for t in itertools.chain(matched_tasks, also_include, direct_upstreams) - } - - def filter_task_group(group, parent_group): - """Exclude tasks not included in the subdag from the given TaskGroup.""" - # We want to deepcopy _most but not all_ attributes of the task group, so we create a shallow copy - # and then manually deep copy the instances. (memo argument to deepcopy only works for instances - # of classes, not "native" properties of an instance) - copied = copy.copy(group) - - memo[id(group.children)] = {} - if parent_group: - memo[id(group.parent_group)] = parent_group - for attr, value in copied.__dict__.items(): - if id(value) in memo: - value = memo[id(value)] - else: - value = copy.deepcopy(value, memo) - copied.__dict__[attr] = value - - proxy = weakref.proxy(copied) - - for child in group.children.values(): - if isinstance(child, AbstractOperator): - if child.task_id in dag.task_dict: - task = copied.children[child.task_id] = dag.task_dict[child.task_id] - task.task_group = proxy - else: - copied.used_group_ids.discard(child.task_id) - else: - filtered_child = filter_task_group(child, proxy) - - # Only include this child TaskGroup if it is non-empty. 
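
# Editor's sketch: partial_subset(), being removed from core here, matches tasks by
# id or regex and then pulls in upstream and/or downstream neighbours before deep
# copying. A toy version of just that selection step, on plain dict-based dependency
# maps (no TaskGroups, no setup/teardown handling, `re` instead of re2):
import re


def select_subset(downstream_of, upstream_of, pattern, include_upstream=True, include_downstream=False):
    """downstream_of / upstream_of: task_id -> set of neighbouring task_ids."""
    matched = {t for t in downstream_of if re.search(pattern, t)}

    def walk(start, edges):
        seen, stack = set(), [start]
        while stack:
            for nxt in edges.get(stack.pop(), ()):
                if nxt not in seen:
                    seen.add(nxt)
                    stack.append(nxt)
        return seen

    selected = set(matched)
    for t in matched:
        if include_upstream:
            selected |= walk(t, upstream_of)
        if include_downstream:
            selected |= walk(t, downstream_of)
    return selected


down = {"extract": {"transform"}, "transform": {"load"}, "load": set()}
up = {"extract": set(), "transform": {"extract"}, "load": {"transform"}}
assert select_subset(down, up, "transform") == {"extract", "transform"}
assert select_subset(down, up, "transform", include_downstream=True) == {"extract", "transform", "load"}
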
- if filtered_child.children: - copied.children[child.group_id] = filtered_child - - return copied - - dag._task_group = filter_task_group(self.task_group, None) - - # Removing upstream/downstream references to tasks and TaskGroups that did not make - # the cut. - subdag_task_groups = dag.task_group.get_task_group_dict() - for group in subdag_task_groups.values(): - group.upstream_group_ids.intersection_update(subdag_task_groups) - group.downstream_group_ids.intersection_update(subdag_task_groups) - group.upstream_task_ids.intersection_update(dag.task_dict) - group.downstream_task_ids.intersection_update(dag.task_dict) - - for t in dag.tasks: - # Removing upstream/downstream references to tasks that did not - # make the cut - t.upstream_task_ids.intersection_update(dag.task_dict) - t.downstream_task_ids.intersection_update(dag.task_dict) - - if len(dag.tasks) < len(self.tasks): - dag.partial = True - - return dag - - def has_task(self, task_id: str): - return task_id in self.task_dict - - def has_task_group(self, task_group_id: str) -> bool: - return task_group_id in self.task_group_dict - - @functools.cached_property - def task_group_dict(self): - return {k: v for k, v in self._task_group.get_task_group_dict().items() if k is not None} - - def get_task(self, task_id: str) -> Operator: - if task_id in self.task_dict: - return self.task_dict[task_id] - raise TaskNotFound(f"Task {task_id} not found") - def pickle_info(self): d = {} d["is_picklable"] = True @@ -2201,76 +1578,6 @@ def pickle(self, session=NEW_SESSION) -> DagPickle: return dp - @property - def task(self) -> TaskDecoratorCollection: - from airflow.decorators import task - - return cast("TaskDecoratorCollection", functools.partial(task, dag=self)) - - def add_task(self, task: Operator) -> None: - """ - Add a task to the DAG. - - :param task: the task you want to add - """ - FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule) - - from airflow.utils.task_group import TaskGroupContext - - # if the task has no start date, assign it the same as the DAG - if not task.start_date: - task.start_date = self.start_date - # otherwise, the task will start on the later of its own start date and - # the DAG's start date - elif self.start_date: - task.start_date = max(task.start_date, self.start_date) - - # if the task has no end date, assign it the same as the dag - if not task.end_date: - task.end_date = self.end_date - # otherwise, the task will end on the earlier of its own end date and - # the DAG's end date - elif task.end_date and self.end_date: - task.end_date = min(task.end_date, self.end_date) - - task_id = task.task_id - if not task.task_group: - task_group = TaskGroupContext.get_current_task_group(self) - if task_group: - task_id = task_group.child_id(task_id) - task_group.add(task) - - if ( - task_id in self.task_dict and self.task_dict[task_id] is not task - ) or task_id in self._task_group.used_group_ids: - raise DuplicateTaskIdFound(f"Task id '{task_id}' has already been added to the DAG") - else: - self.task_dict[task_id] = task - task.dag = self - # Add task_id to used_group_ids to prevent group_id and task_id collisions. - self._task_group.used_group_ids.add(task_id) - - self.task_count = len(self.task_dict) - - def add_tasks(self, tasks: Iterable[Operator]) -> None: - """ - Add a list of tasks to the DAG. 
- - :param tasks: a lit of tasks you want to add - """ - for task in tasks: - self.add_task(task) - - def _remove_task(self, task_id: str) -> None: - # This is "private" as removing could leave a hole in dependencies if done incorrectly, and this - # doesn't guard against that - task = self.task_dict.pop(task_id) - tg = getattr(task, "task_group", None) - if tg: - tg._remove(task) - - self.task_count = len(self.task_dict) - def cli(self): """Exposes a CLI specific to this DAG.""" check_cycle(self) @@ -2515,6 +1822,9 @@ def create_dagrun( # todo: AIP-78 add verification that if run type is backfill then we have a backfill id + if TYPE_CHECKING: + # TODO: Task-SDK: remove this assert + assert self.params # create a copy of params before validating copied_params = copy.deepcopy(self.params) copied_params.update(conf or {}) @@ -2669,88 +1979,6 @@ def get_num_task_instances(dag_id, run_id=None, task_ids=None, states=None, sess qry = qry.where(TaskInstance.state.in_(states)) return session.scalar(qry) - @classmethod - def get_serialized_fields(cls): - """Stringified DAGs and operators contain exactly these fields.""" - if not cls.__serialized_fields: - exclusion_list = { - "schedule_asset_references", - "schedule_asset_alias_references", - "task_outlet_asset_references", - "_old_context_manager_dags", - "safe_dag_id", - "last_loaded", - "user_defined_filters", - "user_defined_macros", - "partial", - "params", - "_pickle_id", - "_log", - "task_dict", - "template_searchpath", - "sla_miss_callback", - "on_success_callback", - "on_failure_callback", - "template_undefined", - "jinja_environment_kwargs", - # has_on_*_callback are only stored if the value is True, as the default is False - "has_on_success_callback", - "has_on_failure_callback", - "auto_register", - "fail_stop", - } - cls.__serialized_fields = frozenset(vars(DAG(dag_id="test", schedule=None))) - exclusion_list - return cls.__serialized_fields - - def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType: - """Return edge information for the given pair of tasks or an empty edge if there is no information.""" - # Note - older serialized DAGs may not have edge_info being a dict at all - empty = cast(EdgeInfoType, {}) - if self.edge_info: - return self.edge_info.get(upstream_task_id, {}).get(downstream_task_id, empty) - else: - return empty - - def set_edge_info(self, upstream_task_id: str, downstream_task_id: str, info: EdgeInfoType): - """ - Set the given edge information on the DAG. - - Note that this will overwrite, rather than merge with, existing info. - """ - self.edge_info.setdefault(upstream_task_id, {})[downstream_task_id] = info - - def validate_schedule_and_params(self): - """ - Validate Param values when the DAG has schedule defined. - - Raise exception if there are any Params which can not be resolved by their schema definition. - """ - if not self.timetable.can_be_scheduled: - return - - try: - self.params.validate() - except ParamValidationError as pverr: - raise AirflowException( - "DAG is not allowed to define a Schedule, " - "if there are any required params without default values or default values are not valid." - ) from pverr - - def iter_invalid_owner_links(self) -> Iterator[tuple[str, str]]: - """ - Parse a given link, and verifies if it's a valid URL, or a 'mailto' link. - - Returns an iterator of invalid (owner, link) pairs. 
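`DAG.get_serialized_fields`, removed just above, derives the serializable field set from a throwaway default instance minus an exclusion list rather than hard-coding the names. A standalone sketch of the same pattern with an illustrative class:

.. code-block:: python

    from __future__ import annotations

    class Widget:
        _serialized_fields: frozenset[str] | None = None

        def __init__(self, widget_id: str, color: str = "blue"):
            self.widget_id = widget_id
            self.color = color
            self._cache = {}  # runtime-only, never serialized
            self._log = None  # runtime-only, never serialized

        @classmethod
        def get_serialized_fields(cls) -> frozenset[str]:
            # Build the list once from a default instance, then subtract
            # anything that should never leave the process.
            if cls._serialized_fields is None:
                exclusion_list = {"_cache", "_log"}
                cls._serialized_fields = frozenset(vars(Widget(widget_id="probe"))) - exclusion_list
            return cls._serialized_fields

    assert Widget.get_serialized_fields() == frozenset({"widget_id", "color"})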
- """ - for owner, link in self.owner_links.items(): - result = urlsplit(link) - if result.scheme == "mailto": - # netloc is not existing for 'mailto' link, so we are checking that the path is parsed - if not result.path: - yield result.path, link - elif not result.scheme or not result.netloc: - yield owner, link - class DagTag(Base): """A tag name per dag, to allow quick filtering in the DAG view.""" @@ -3191,123 +2419,6 @@ def get_asset_triggered_next_run_info(self, *, session=NEW_SESSION) -> dict[str, return get_asset_triggered_next_run_info([self.dag_id], session=session).get(self.dag_id, None) -# NOTE: Please keep the list of arguments in sync with DAG.__init__. -# Only exception: dag_id here should have a default value, but not in DAG. -def dag( - dag_id: str = "", - description: str | None = None, - schedule: ScheduleArg = None, - start_date: datetime | None = None, - end_date: datetime | None = None, - template_searchpath: str | Iterable[str] | None = None, - template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined, - user_defined_macros: dict | None = None, - user_defined_filters: dict | None = None, - default_args: dict | None = None, - max_active_tasks: int = airflow_conf.getint("core", "max_active_tasks_per_dag"), - max_active_runs: int = airflow_conf.getint("core", "max_active_runs_per_dag"), - max_consecutive_failed_dag_runs: int = airflow_conf.getint( - "core", "max_consecutive_failed_dag_runs_per_dag" - ), - dagrun_timeout: timedelta | None = None, - sla_miss_callback: Any = None, - default_view: str = airflow_conf.get_mandatory_value("webserver", "dag_default_view").lower(), - orientation: str = airflow_conf.get_mandatory_value("webserver", "dag_orientation"), - catchup: bool = airflow_conf.getboolean("scheduler", "catchup_by_default"), - on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, - on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, - doc_md: str | None = None, - params: abc.MutableMapping | None = None, - access_control: dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | None = None, - is_paused_upon_creation: bool | None = None, - jinja_environment_kwargs: dict | None = None, - render_template_as_native_obj: bool = False, - tags: Collection[str] | None = None, - owner_links: dict[str, str] | None = None, - auto_register: bool = True, - fail_stop: bool = False, - dag_display_name: str | None = None, -) -> Callable[[Callable], Callable[..., DAG]]: - """ - Python dag decorator which wraps a function into an Airflow DAG. - - Accepts kwargs for operator kwarg. Can be used to parameterize DAGs. - - :param dag_args: Arguments for DAG object - :param dag_kwargs: Kwargs for DAG object. - """ - - def wrapper(f: Callable) -> Callable[..., DAG]: - @functools.wraps(f) - def factory(*args, **kwargs): - # Generate signature for decorated function and bind the arguments when called - # we do this to extract parameters, so we can annotate them on the DAG object. - # In addition, this fails if we are missing any args/kwargs with TypeError as expected. - f_sig = signature(f).bind(*args, **kwargs) - # Apply defaults to capture default values if set. 
- f_sig.apply_defaults() - - # Initialize DAG with bound arguments - with DAG( - dag_id or f.__name__, - description=description, - start_date=start_date, - end_date=end_date, - template_searchpath=template_searchpath, - template_undefined=template_undefined, - user_defined_macros=user_defined_macros, - user_defined_filters=user_defined_filters, - default_args=default_args, - max_active_tasks=max_active_tasks, - max_active_runs=max_active_runs, - max_consecutive_failed_dag_runs=max_consecutive_failed_dag_runs, - dagrun_timeout=dagrun_timeout, - sla_miss_callback=sla_miss_callback, - default_view=default_view, - orientation=orientation, - catchup=catchup, - on_success_callback=on_success_callback, - on_failure_callback=on_failure_callback, - doc_md=doc_md, - params=params, - access_control=access_control, - is_paused_upon_creation=is_paused_upon_creation, - jinja_environment_kwargs=jinja_environment_kwargs, - render_template_as_native_obj=render_template_as_native_obj, - tags=tags, - schedule=schedule, - owner_links=owner_links, - auto_register=auto_register, - fail_stop=fail_stop, - dag_display_name=dag_display_name, - ) as dag_obj: - # Set DAG documentation from function documentation if it exists and doc_md is not set. - if f.__doc__ and not dag_obj.doc_md: - dag_obj.doc_md = f.__doc__ - - # Generate DAGParam for each function arg/kwarg and replace it for calling the function. - # All args/kwargs for function will be DAGParam object and replaced on execution time. - f_kwargs = {} - for name, value in f_sig.arguments.items(): - f_kwargs[name] = dag_obj.param(name, value) - - # set file location to caller source path - back = sys._getframe().f_back - dag_obj.fileloc = back.f_code.co_filename if back else "" - - # Invoke function to create operators in the DAG scope. - f(**f_kwargs) - - # Return dag object such that it's accessible in Globals. - return dag_obj - - # Ensure that warnings from inside DAG() are emitted from the caller, not here - fixup_decorator_warning_stack(factory) - return factory - - return wrapper - - STATICA_HACK = True globals()["kcah_acitats"[::-1].upper()] = False if STATICA_HACK: # pragma: no cover @@ -3317,54 +2428,6 @@ def factory(*args, **kwargs): """:sphinx-autoapi-skip:""" -class DagContext: - """ - DAG context is used to keep the current DAG when DAG is used as ContextManager. - - You can use DAG as context: - - .. code-block:: python - - with DAG( - dag_id="example_dag", - default_args=default_args, - schedule="0 0 * * *", - dagrun_timeout=timedelta(minutes=60), - ) as dag: - ... - - If you do this the context stores the DAG and whenever new task is created, it will use - such stored DAG as the parent DAG. 
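The `dag()` factory above binds the wrapped function's arguments with `inspect.signature(...).bind()` so declared defaults are honoured and an invalid call still fails with the usual `TypeError`. A self-contained sketch of that binding step (the decorator and function names here are made up for illustration):

.. code-block:: python

    import functools
    from inspect import signature

    def capture_call_args(f):
        """Toy decorator that records the bound arguments of each call, defaults included."""

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            bound = signature(f).bind(*args, **kwargs)  # raises TypeError on bad calls
            bound.apply_defaults()                      # fill in declared defaults
            wrapper.last_call = dict(bound.arguments)
            return f(*args, **kwargs)

        return wrapper

    @capture_call_args
    def build(run_date: str, retries: int = 2):
        return run_date, retries

    build("2024-01-01")
    assert build.last_call == {"run_date": "2024-01-01", "retries": 2}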
- - """ - - _context_managed_dags: deque[DAG] = deque() - autoregistered_dags: set[tuple[DAG, ModuleType]] = set() - current_autoregister_module_name: str | None = None - - @classmethod - def push_context_managed_dag(cls, dag: DAG): - cls._context_managed_dags.appendleft(dag) - - @classmethod - def pop_context_managed_dag(cls) -> DAG | None: - dag = cls._context_managed_dags.popleft() - - # In a few cases around serialization we explicitly push None in to the stack - if cls.current_autoregister_module_name is not None and dag and dag.auto_register: - mod = sys.modules[cls.current_autoregister_module_name] - cls.autoregistered_dags.add((dag, mod)) - - return dag - - @classmethod - def get_current_dag(cls) -> DAG | None: - try: - return cls._context_managed_dags[0] - except IndexError: - return None - - def _run_inline_trigger(trigger): async def _run_inline_trigger_main(): async for event in trigger.run(): diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index b2d45a133187..c9ad8edaa401 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -245,7 +245,9 @@ def get_dag(self, dag_id, session: Session = None): # If the dag corresponding to root_dag_id is absent or expired is_missing = root_dag_id not in self.dags - is_expired = orm_dag.last_expired and dag and dag.last_loaded < orm_dag.last_expired + is_expired = ( + orm_dag.last_expired and dag and dag.last_loaded and dag.last_loaded < orm_dag.last_expired + ) if is_expired: # Remove associated dags so we can re-add them. self.dags = {key: dag for key, dag in self.dags.items()} @@ -278,7 +280,7 @@ def _add_dag_from_db(self, dag_id: str, session: Session): def process_file(self, filepath, only_if_updated=True, safe_mode=True): """Given a path to a python module or zip file, import the module and look for dag objects within.""" - from airflow.models.dag import DagContext + from airflow.sdk.definitions.contextmanager import DagContext # if the source file no longer exists in the DB or in the filesystem, # return an empty list @@ -326,7 +328,7 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True): return found_dags def _load_modules_from_file(self, filepath, safe_mode): - from airflow.models.dag import DagContext + from airflow.sdk.definitions.contextmanager import DagContext if not might_contain_dag(filepath, safe_mode): # Don't want to spam user with skip messages @@ -382,7 +384,7 @@ def parse(mod_name, filepath): return parse(mod_name, filepath) def _load_modules_from_zip(self, filepath, safe_mode): - from airflow.models.dag import DagContext + from airflow.sdk.definitions.contextmanager import DagContext mods = [] with zipfile.ZipFile(filepath) as current_zip_file: @@ -431,7 +433,8 @@ def _load_modules_from_zip(self, filepath, safe_mode): return mods def _process_modules(self, filepath, mods, file_last_changed_on_disk): - from airflow.models.dag import DAG, DagContext # Avoid circular import + from airflow.models.dag import DAG # Avoid circular import + from airflow.sdk.definitions.contextmanager import DagContext top_level_dags = {(o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)} diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 8a9e790ea7fc..925acfc16f0c 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -201,8 +201,8 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: task_id = partial_kwargs.pop("task_id") dag = partial_kwargs.pop("dag") task_group = 
partial_kwargs.pop("task_group") - start_date = partial_kwargs.pop("start_date") - end_date = partial_kwargs.pop("end_date") + start_date = partial_kwargs.pop("start_date", None) + end_date = partial_kwargs.pop("end_date", None) try: operator_name = self.operator_class.custom_operator_name # type: ignore @@ -333,7 +333,8 @@ def __attrs_post_init__(self): @classmethod def get_serialized_fields(cls): # Not using 'cls' here since we only want to serialize base fields. - return frozenset(attr.fields_dict(MappedOperator)) - { + return (frozenset(attr.fields_dict(MappedOperator)) | {"task_type"}) - { + "_task_type", "dag", "deps", "expand_input", # This is needed to be able to accept XComArg. diff --git a/airflow/models/param.py b/airflow/models/param.py index 895cd2af8bb4..28253a86f5ca 100644 --- a/airflow/models/param.py +++ b/airflow/models/param.py @@ -27,9 +27,9 @@ from airflow.utils.types import NOTSET, ArgNotSet if TYPE_CHECKING: - from airflow.models.dag import DAG from airflow.models.dagrun import DagRun from airflow.models.operator import Operator + from airflow.sdk import DAG from airflow.serialization.pydantic.dag_run import DagRunPydantic from airflow.utils.context import Context diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index 12cbdb380b92..a67c7cf310ba 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -36,7 +36,7 @@ from airflow.models.dagrun import DagRun from airflow.models.operator import Operator - from airflow.models.taskmixin import DAGNode + from airflow.sdk.definitions.node import DAGNode from airflow.serialization.pydantic.dag_run import DagRunPydantic from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 95098181c07e..bb07ba6d848a 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -157,9 +157,10 @@ from airflow.models.abstractoperator import TaskStateChangeCallback from airflow.models.baseoperator import BaseOperator - from airflow.models.dag import DAG, DagModel + from airflow.models.dag import DAG as SchedulerDAG, DagModel from airflow.models.dagrun import DagRun from airflow.models.operator import Operator + from airflow.sdk import DAG from airflow.serialization.pydantic.asset import AssetEventPydantic from airflow.serialization.pydantic.dag import DagModelPydantic from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic @@ -931,7 +932,7 @@ def _clear_next_method_args(*, task_instance: TaskInstance | TaskInstancePydanti def _get_template_context( *, task_instance: TaskInstance | TaskInstancePydantic, - dag: DAG, + dag: SchedulerDAG, session: Session | None = None, ignore_param_exceptions: bool = True, ) -> Context: @@ -961,7 +962,8 @@ def _get_template_context( assert task.dag if task.dag.__class__ is AttributeRemoved: - task.dag = dag # required after deserialization + # TODO: Task-SDK: Remove this after AIP-44 code is removed + task.dag = dag # type: ignore[assignment] # required after deserialization dag_run = task_instance.get_dagrun(session) data_interval = dag.get_run_data_interval(dag_run) @@ -1319,8 +1321,10 @@ def _record_task_map_for_downstreams( """ from airflow.models.mappedoperator import MappedOperator + # TODO: Task-SDK: Remove this after AIP-44 code is removed if task.dag.__class__ is AttributeRemoved: - task.dag = dag # required after deserialization + # required after deserialization + task.dag = dag # type: ignore[assignment] if 
next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate. return diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py index 05768ff36fe1..fa76a3815cb8 100644 --- a/airflow/models/taskmixin.py +++ b/airflow/models/taskmixin.py @@ -16,271 +16,13 @@ # under the License. from __future__ import annotations -from abc import ABCMeta, abstractmethod -from typing import TYPE_CHECKING, Any, Iterable, Sequence - -from airflow.exceptions import AirflowException -from airflow.utils.types import NOTSET +from typing import TYPE_CHECKING if TYPE_CHECKING: - from logging import Logger - - import pendulum - - from airflow.models.baseoperator import BaseOperator - from airflow.models.dag import DAG - from airflow.models.operator import Operator - from airflow.serialization.enums import DagAttributeTypes - from airflow.utils.edgemodifier import EdgeModifier - from airflow.utils.task_group import TaskGroup - from airflow.utils.types import ArgNotSet - - -class DependencyMixin: - """Mixing implementing common dependency setting methods like >> and <<.""" - - @property - def roots(self) -> Sequence[DependencyMixin]: - """ - List of root nodes -- ones with no upstream dependencies. - - a.k.a. the "start" of this sub-graph - """ - raise NotImplementedError() - - @property - def leaves(self) -> Sequence[DependencyMixin]: - """ - List of leaf nodes -- ones with only upstream dependencies. - - a.k.a. the "end" of this sub-graph - """ - raise NotImplementedError() - - @abstractmethod - def set_upstream( - self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None - ): - """Set a task or a task list to be directly upstream from the current task.""" - raise NotImplementedError() - - @abstractmethod - def set_downstream( - self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None - ): - """Set a task or a task list to be directly downstream from the current task.""" - raise NotImplementedError() - - def as_setup(self) -> DependencyMixin: - """Mark a task as setup task.""" - raise NotImplementedError() - - def as_teardown( - self, - *, - setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET, - on_failure_fail_dagrun=NOTSET, - ) -> DependencyMixin: - """Mark a task as teardown and set its setups as direct relatives.""" - raise NotImplementedError() - - def update_relative( - self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None - ) -> None: - """ - Update relationship information about another DependencyMixin. Default is no-op. - - Override if necessary. 
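The `DagContext` class removed further up is essentially a class-level LIFO stack of "current" DAGs kept in a `deque`, which is what lets `with DAG(...)` blocks nest. A standalone sketch of that pattern (the `Dag`/`DagStack` names are stand-ins, not the real classes):

.. code-block:: python

    from __future__ import annotations

    from collections import deque

    class Dag:
        def __init__(self, dag_id: str):
            self.dag_id = dag_id

        def __enter__(self) -> Dag:
            DagStack.push(self)
            return self

        def __exit__(self, *exc) -> None:
            DagStack.pop()

    class DagStack:
        _stack: deque[Dag] = deque()

        @classmethod
        def push(cls, dag: Dag) -> None:
            cls._stack.appendleft(dag)

        @classmethod
        def pop(cls) -> Dag:
            return cls._stack.popleft()

        @classmethod
        def get_current(cls) -> Dag | None:
            try:
                return cls._stack[0]
            except IndexError:
                return None

    with Dag("outer"):
        with Dag("inner"):
            assert DagStack.get_current().dag_id == "inner"
        assert DagStack.get_current().dag_id == "outer"
    assert DagStack.get_current() is None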
- """ - - def __lshift__(self, other: DependencyMixin | Sequence[DependencyMixin]): - """Implement Task << Task.""" - self.set_upstream(other) - return other - - def __rshift__(self, other: DependencyMixin | Sequence[DependencyMixin]): - """Implement Task >> Task.""" - self.set_downstream(other) - return other - - def __rrshift__(self, other: DependencyMixin | Sequence[DependencyMixin]): - """Implement Task >> [Task] because list don't have __rshift__ operators.""" - self.__lshift__(other) - return self - - def __rlshift__(self, other: DependencyMixin | Sequence[DependencyMixin]): - """Implement Task << [Task] because list don't have __lshift__ operators.""" - self.__rshift__(other) - return self - - @classmethod - def _iter_references(cls, obj: Any) -> Iterable[tuple[DependencyMixin, str]]: - from airflow.models.baseoperator import AbstractOperator - from airflow.utils.mixins import ResolveMixin - - if isinstance(obj, AbstractOperator): - yield obj, "operator" - elif isinstance(obj, ResolveMixin): - yield from obj.iter_references() - elif isinstance(obj, Sequence): - for o in obj: - yield from cls._iter_references(o) - - -class DAGNode(DependencyMixin, metaclass=ABCMeta): - """ - A base class for a node in the graph of a workflow. - - A node may be an Operator or a Task Group, either mapped or unmapped. - """ - - dag: DAG | None = None - task_group: TaskGroup | None = None - """The task_group that contains this node""" - - @property - @abstractmethod - def node_id(self) -> str: - raise NotImplementedError() - - @property - def label(self) -> str | None: - tg = self.task_group - if tg and tg.node_id and tg.prefix_group_id: - # "task_group_id.task_id" -> "task_id" - return self.node_id[len(tg.node_id) + 1 :] - return self.node_id - - start_date: pendulum.DateTime | None - end_date: pendulum.DateTime | None - upstream_task_ids: set[str] - downstream_task_ids: set[str] - - def has_dag(self) -> bool: - return self.dag is not None - - @property - def dag_id(self) -> str: - """Returns dag id if it has one or an adhoc/meaningless ID.""" - if self.dag: - return self.dag.dag_id - return "_in_memory_dag_" - - @property - def log(self) -> Logger: - raise NotImplementedError() - - @property - @abstractmethod - def roots(self) -> Sequence[DAGNode]: - raise NotImplementedError() - - @property - @abstractmethod - def leaves(self) -> Sequence[DAGNode]: - raise NotImplementedError() - - def _set_relatives( - self, - task_or_task_list: DependencyMixin | Sequence[DependencyMixin], - upstream: bool = False, - edge_modifier: EdgeModifier | None = None, - ) -> None: - """Set relatives for the task or task list.""" - from airflow.models.baseoperator import BaseOperator - from airflow.models.mappedoperator import MappedOperator - - if not isinstance(task_or_task_list, Sequence): - task_or_task_list = [task_or_task_list] - - task_list: list[Operator] = [] - for task_object in task_or_task_list: - task_object.update_relative(self, not upstream, edge_modifier=edge_modifier) - relatives = task_object.leaves if upstream else task_object.roots - for task in relatives: - if not isinstance(task, (BaseOperator, MappedOperator)): - raise AirflowException( - f"Relationships can only be set between Operators; received {task.__class__.__name__}" - ) - task_list.append(task) - - # relationships can only be set if the tasks share a single DAG. Tasks - # without a DAG are assigned to that DAG. 
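`DependencyMixin`, now living in `airflow.sdk.definitions.mixins`, gets the `>>`/`<<` syntax from the dunder methods above, including the reflected `__rrshift__`/`__rlshift__` variants that make `[t1, t2] >> t3` work even though `list` has no shift operators. A minimal standalone re-implementation of the idea:

.. code-block:: python

    from __future__ import annotations

    class Node:
        def __init__(self, name: str):
            self.name = name
            self.upstream: set[str] = set()
            self.downstream: set[str] = set()

        def _link(self, others, *, upstream: bool) -> None:
            others = others if isinstance(others, (list, tuple)) else [others]
            for other in others:
                if upstream:
                    self.upstream.add(other.name)
                    other.downstream.add(self.name)
                else:
                    self.downstream.add(other.name)
                    other.upstream.add(self.name)

        def __rshift__(self, other):   # self >> other
            self._link(other, upstream=False)
            return other

        def __lshift__(self, other):   # self << other
            self._link(other, upstream=True)
            return other

        def __rrshift__(self, other):  # [a, b] >> self (list has no __rshift__)
            self.__lshift__(other)
            return self

        def __rlshift__(self, other):  # [a, b] << self
            self.__rshift__(other)
            return self

    a, b, c = Node("a"), Node("b"), Node("c")
    a >> b >> c      # chains left to right because __rshift__ returns `other`
    [a, b] >> c      # falls back to c.__rrshift__([a, b])
    assert c.upstream == {"a", "b"}
    assert a.downstream == {"b", "c"}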
- dags: set[DAG] = {task.dag for task in [*self.roots, *task_list] if task.has_dag() and task.dag} - - if len(dags) > 1: - raise AirflowException(f"Tried to set relationships between tasks in more than one DAG: {dags}") - elif len(dags) == 1: - dag = dags.pop() - else: - raise AirflowException( - f"Tried to create relationships between tasks that don't have DAGs yet. " - f"Set the DAG for at least one task and try again: {[self, *task_list]}" - ) - - if not self.has_dag(): - # If this task does not yet have a dag, add it to the same dag as the other task. - self.dag = dag - - for task in task_list: - if dag and not task.has_dag(): - # If the other task does not yet have a dag, add it to the same dag as this task and - dag.add_task(task) - if upstream: - task.downstream_task_ids.add(self.node_id) - self.upstream_task_ids.add(task.node_id) - if edge_modifier: - edge_modifier.add_edge_info(self.dag, task.node_id, self.node_id) - else: - self.downstream_task_ids.add(task.node_id) - task.upstream_task_ids.add(self.node_id) - if edge_modifier: - edge_modifier.add_edge_info(self.dag, self.node_id, task.node_id) - - def set_downstream( - self, - task_or_task_list: DependencyMixin | Sequence[DependencyMixin], - edge_modifier: EdgeModifier | None = None, - ) -> None: - """Set a node (or nodes) to be directly downstream from the current node.""" - self._set_relatives(task_or_task_list, upstream=False, edge_modifier=edge_modifier) - - def set_upstream( - self, - task_or_task_list: DependencyMixin | Sequence[DependencyMixin], - edge_modifier: EdgeModifier | None = None, - ) -> None: - """Set a node (or nodes) to be directly upstream from the current node.""" - self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier) - - @property - def downstream_list(self) -> Iterable[Operator]: - """List of nodes directly downstream.""" - if not self.dag: - raise AirflowException(f"Operator {self} has not been assigned to a DAG yet") - return [self.dag.get_task(tid) for tid in self.downstream_task_ids] - - @property - def upstream_list(self) -> Iterable[Operator]: - """List of nodes directly upstream.""" - if not self.dag: - raise AirflowException(f"Operator {self} has not been assigned to a DAG yet") - return [self.dag.get_task(tid) for tid in self.upstream_task_ids] - - def get_direct_relative_ids(self, upstream: bool = False) -> set[str]: - """Get set of the direct relative ids to the current task, upstream or downstream.""" - if upstream: - return self.upstream_task_ids - else: - return self.downstream_task_ids + from airflow.typing_compat import TypeAlias - def get_direct_relatives(self, upstream: bool = False) -> Iterable[DAGNode]: - """Get list of the direct relatives to the current task, upstream or downstream.""" - if upstream: - return self.upstream_list - else: - return self.downstream_list +import airflow.sdk.definitions.mixins +import airflow.sdk.definitions.node - def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: - """Serialize a task group's content; used by TaskGroupSerialization.""" - raise NotImplementedError() +DependencyMixin: TypeAlias = airflow.sdk.definitions.mixins.DependencyMixin +DAGNode: TypeAlias = airflow.sdk.definitions.node.DAGNode diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 83ff4f25c637..c28af6acbe5a 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -29,22 +29,21 @@ from airflow.models import MappedOperator, TaskInstance from airflow.models.abstractoperator import AbstractOperator from 
airflow.models.taskmixin import DependencyMixin +from airflow.sdk.types import NOTSET, ArgNotSet from airflow.utils.db import exists_query from airflow.utils.mixins import ResolveMixin from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.state import State from airflow.utils.trigger_rule import TriggerRule -from airflow.utils.types import NOTSET, ArgNotSet from airflow.utils.xcom import XCOM_RETURN_KEY if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.baseoperator import BaseOperator - from airflow.models.dag import DAG + # from airflow.models.dag import DAG from airflow.models.operator import Operator - from airflow.models.taskmixin import DAGNode + from airflow.sdk import DAG, BaseOperator from airflow.utils.context import Context from airflow.utils.edgemodifier import EdgeModifier @@ -122,7 +121,7 @@ def iter_xcom_references(arg: Any) -> Iterator[tuple[Operator, str]]: yield from XComArg.iter_xcom_references(getattr(arg, attr)) @staticmethod - def apply_upstream_relationship(op: Operator, arg: Any): + def apply_upstream_relationship(op: DependencyMixin, arg: Any): """ Set dependency for XComArgs. @@ -134,12 +133,12 @@ def apply_upstream_relationship(op: Operator, arg: Any): op.set_upstream(operator) @property - def roots(self) -> list[DAGNode]: + def roots(self) -> list[Operator]: """Required by DependencyMixin.""" return [op for op, _ in self.iter_references()] @property - def leaves(self) -> list[DAGNode]: + def leaves(self) -> list[Operator]: """Required by DependencyMixin.""" return [op for op, _ in self.iter_references()] @@ -394,15 +393,15 @@ def as_setup(self) -> DependencyMixin: def as_teardown( self, *, - setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET, - on_failure_fail_dagrun=NOTSET, + setups: BaseOperator | Iterable[BaseOperator] | None = None, + on_failure_fail_dagrun: bool | None = None, ): for operator, _ in self.iter_references(): operator.is_teardown = True operator.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS - if on_failure_fail_dagrun is not NOTSET: + if on_failure_fail_dagrun is not None: operator.on_failure_fail_dagrun = on_failure_fail_dagrun - if not isinstance(setups, ArgNotSet): + if setups is not None: setups = [setups] if isinstance(setups, DependencyMixin) else setups for s in setups: s.is_setup = True diff --git a/airflow/operators/python.py b/airflow/operators/python.py index b032b45ed3e6..3d40ad2c8450 100644 --- a/airflow/operators/python.py +++ b/airflow/operators/python.py @@ -197,10 +197,7 @@ def my_python_callable(**kwargs): # since we won't mutate the arguments, we should just do the shallow copy # there are some cases we can't deepcopy the objects(e.g protobuf). 
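`airflow/models/taskmixin.py` is reduced above to `TypeAlias` re-exports so that existing `from airflow.models.taskmixin import DAGNode` imports keep resolving while the implementation moves into the Task SDK. A runnable sketch of the same shim pattern, using a stdlib class as a stand-in for the relocated implementation:

.. code-block:: python

    # compat.py -- keeps an old import path alive after the implementation moved.
    from __future__ import annotations

    from typing import TYPE_CHECKING

    import collections.abc

    if TYPE_CHECKING:
        from typing import TypeAlias  # Python 3.10+; typing_extensions otherwise

    # Old code did `from compat import SizedMapping`; the class itself now lives
    # elsewhere (here the stdlib Mapping ABC stands in for the relocated class).
    SizedMapping: TypeAlias = collections.abc.Mapping

    # The alias behaves exactly like the original for isinstance/issubclass checks.
    assert issubclass(dict, SizedMapping)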
- shallow_copy_attrs: Sequence[str] = ( - "python_callable", - "op_kwargs", - ) + shallow_copy_attrs: Sequence[str] = ("python_callable", "op_kwargs") def __init__( self, diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py index 8eb501e281dd..331e17168bab 100644 --- a/airflow/sensors/external_task.py +++ b/airflow/sensors/external_task.py @@ -20,7 +20,7 @@ import datetime import os import warnings -from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowSkipException @@ -476,7 +476,7 @@ class ExternalTaskMarker(EmptyOperator): operator_extra_links = [ExternalDagLink()] # The _serialized_fields are lazily loaded when get_serialized_fields() method is called - __serialized_fields: frozenset[str] | None = None + __serialized_fields: ClassVar[frozenset[str] | None] = None def __init__( self, diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json index fe1e63c49035..32ccd3dfff9c 100644 --- a/airflow/serialization/schema.json +++ b/airflow/serialization/schema.json @@ -137,7 +137,7 @@ "type": "object", "properties": { "params": { "$ref": "#/definitions/params" }, - "_dag_id": { "type": "string" }, + "dag_id": { "type": "string" }, "tasks": { "$ref": "#/definitions/tasks" }, "timezone": { "$ref": "#/definitions/timezone" }, "owner_links": { "type": "object" }, @@ -156,11 +156,10 @@ {"type": "string"} ] }, - "orientation": { "type" : "string"}, - "_dag_display_property_value": { "type" : "string"}, - "_description": { "type" : "string"}, + "dag_display_name": { "type" : "string"}, + "description": { "type" : "string"}, "_concurrency": { "type" : "number"}, - "_max_active_tasks": { "type" : "number"}, + "max_active_tasks": { "type" : "number"}, "max_active_runs": { "type" : "number"}, "max_consecutive_failed_dag_runs": { "type" : "number"}, "default_args": { "$ref": "#/definitions/dict" }, @@ -168,14 +167,13 @@ "end_date": { "$ref": "#/definitions/datetime" }, "dagrun_timeout": { "$ref": "#/definitions/timedelta" }, "doc_md": { "type" : "string"}, - "_default_view": { "type" : "string"}, - "_access_control": {"$ref": "#/definitions/dict" }, + "access_control": {"$ref": "#/definitions/dict" }, "is_paused_upon_creation": { "type": "boolean" }, "has_on_success_callback": { "type": "boolean" }, "has_on_failure_callback": { "type": "boolean" }, "render_template_as_native_obj": { "type": "boolean" }, "tags": { "type": "array" }, - "_task_group": {"anyOf": [ + "task_group": {"anyOf": [ { "type": "null" }, { "$ref": "#/definitions/task_group" } ]}, @@ -183,7 +181,7 @@ "dag_dependencies": { "$ref": "#/definitions/dag_dependencies" } }, "required": [ - "_dag_id", + "dag_id", "fileloc", "tasks" ], @@ -219,7 +217,7 @@ "$comment": "A task/operator in a DAG", "type": "object", "required": [ - "_task_type", + "task_type", "_task_module", "task_id", "ui_color", @@ -227,7 +225,7 @@ "template_fields" ], "properties": { - "_task_type": { "type": "string" }, + "task_type": { "type": "string" }, "_task_module": { "type": "string" }, "_operator_extra_links": { "$ref": "#/definitions/extra_links" }, "task_id": { "type": "string" }, diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 14528f62bddf..79403860f5fa 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -21,7 
+21,7 @@ import collections.abc import datetime import enum -import inspect +import itertools import logging import weakref from functools import cache @@ -59,6 +59,7 @@ from airflow.models.tasklog import LogTemplate from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg from airflow.providers_manager import ProvidersManager +from airflow.sdk import BaseOperator as TaskSDKBaseOperator from airflow.serialization.dag_dependency import DagDependency from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.helpers import serialize_template_field @@ -101,7 +102,7 @@ from airflow.models.baseoperatorlink import BaseOperatorLink from airflow.models.expandinput import ExpandInput from airflow.models.operator import Operator - from airflow.models.taskmixin import DAGNode + from airflow.sdk.definitions.node import DAGNode from airflow.serialization.json_schema import Validator from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.timetables.base import Timetable @@ -597,7 +598,7 @@ def serialize_to_json( if key == "_operator_name": # when operator_name matches task_type, we can remove # it to reduce the JSON payload - task_type = getattr(object_to_serialize, "_task_type", None) + task_type = getattr(object_to_serialize, "task_type", None) if value != task_type: serialized_object[key] = cls.serialize(value) elif key in decorated_fields: @@ -920,10 +921,11 @@ def _value_is_hardcoded_default(cls, attrname: str, value: Any, instance: Any) - to account for the case where the default value of the field is None but has the ``field = field or {}`` set. """ - if attrname in cls._CONSTRUCTOR_PARAMS and ( - cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []]) - ): - return True + if attrname in cls._CONSTRUCTOR_PARAMS: + if cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []]): + return True + if cls._CONSTRUCTOR_PARAMS[attrname] is attrs.NOTHING and value is None: + return True return False @classmethod @@ -1079,7 +1081,10 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): _CONSTRUCTOR_PARAMS = { k: v.default - for k, v in signature(BaseOperator.__init__).parameters.items() + for k, v in itertools.chain( + signature(BaseOperator.__init__).parameters.items(), + signature(TaskSDKBaseOperator.__init__).parameters.items(), + ) if v.default is not v.empty } @@ -1151,9 +1156,9 @@ def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool) """Serialize operator into a JSON object.""" serialize_op = cls.serialize_to_json(op, cls._decorated_fields) - serialize_op["_task_type"] = getattr(op, "_task_type", type(op).__name__) + serialize_op["task_type"] = getattr(op, "task_type", type(op).__name__) serialize_op["_task_module"] = getattr(op, "_task_module", type(op).__module__) - if op.operator_name != serialize_op["_task_type"]: + if op.operator_name != serialize_op["task_type"]: serialize_op["_operator_name"] = op.operator_name # Used to determine if an Operator is inherited from EmptyOperator @@ -1177,7 +1182,7 @@ def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool) # Store all template_fields as they are if there are JSON Serializable # If not, store them as strings # And raise an exception if the field is not templateable - forbidden_fields = set(inspect.signature(BaseOperator.__init__).parameters.keys()) + forbidden_fields = set(SerializedBaseOperator._CONSTRUCTOR_PARAMS.keys()) # Though allow some of the BaseOperator fields to be templated 
anyway forbidden_fields.difference_update({"email"}) if op.template_fields: @@ -1242,7 +1247,7 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None: op_extra_links_from_plugin = {} if "_operator_name" not in encoded_op: - encoded_op["_operator_name"] = encoded_op["_task_type"] + encoded_op["_operator_name"] = encoded_op["task_type"] # We don't want to load Extra Operator links in Scheduler if cls._load_operator_extra_links: @@ -1256,7 +1261,7 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None: for ope in plugins_manager.operator_extra_links: for operator in ope.operators: if ( - operator.__name__ == encoded_op["_task_type"] + operator.__name__ == encoded_op["task_type"] and operator.__module__ == encoded_op["_task_module"] ): op_extra_links_from_plugin.update({ope.name: ope}) @@ -1272,6 +1277,8 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None: if k in ("_outlets", "_inlets"): # `_outlets` -> `outlets` k = k[1:] + elif k == "task_type": + k = "_task_type" if k == "_downstream_task_ids": # Upgrade from old format/name k = "downstream_task_ids" @@ -1383,7 +1390,7 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator: try: operator_name = encoded_op["_operator_name"] except KeyError: - operator_name = encoded_op["_task_type"] + operator_name = encoded_op["task_type"] op = MappedOperator( operator_class=op_data, @@ -1400,7 +1407,7 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator: ui_fgcolor=BaseOperator.ui_fgcolor, is_empty=False, task_module=encoded_op["_task_module"], - task_type=encoded_op["_task_type"], + task_type=encoded_op["task_type"], operator_name=operator_name, dag=None, task_group=None, @@ -1576,16 +1583,13 @@ class SerializedDAG(DAG, BaseSerialization): not pickle-able. SerializedDAG works for all DAGs. 
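Because the base operator's constructor is now split between core and the Task SDK, `_CONSTRUCTOR_PARAMS` above chains both `__init__` signatures and keeps only the parameters that actually declare a default. A standalone sketch of that collection step (the two classes here are placeholders):

.. code-block:: python

    import itertools
    from inspect import Parameter, signature

    class CoreBase:
        def __init__(self, task_id: str, retries: int = 0, queue: str = "default"):
            ...

    class SDKBase:
        def __init__(self, task_id: str, owner: str = "airflow"):
            ...

    constructor_defaults = {
        name: param.default
        for name, param in itertools.chain(
            signature(CoreBase.__init__).parameters.items(),
            signature(SDKBase.__init__).parameters.items(),
        )
        if param.default is not Parameter.empty  # drops `self` and required params
    }

    assert constructor_defaults == {"retries": 0, "queue": "default", "owner": "airflow"}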
""" - _decorated_fields = {"default_args", "_access_control"} + _decorated_fields = {"default_args", "access_control"} @staticmethod def __get_constructor_defaults(): param_to_attr = { - "max_active_tasks": "_max_active_tasks", - "dag_display_name": "_dag_display_property_value", "description": "_description", "default_view": "_default_view", - "access_control": "_access_control", } return { param_to_attr.get(k, k): v.default @@ -1613,7 +1617,7 @@ def serialize_dag(cls, dag: DAG) -> dict: ] dag_deps.extend(DependencyDetector.detect_dag_dependencies(dag)) serialized_dag["dag_dependencies"] = [x.__dict__ for x in sorted(dag_deps)] - serialized_dag["_task_group"] = TaskGroupSerialization.serialize_task_group(dag.task_group) + serialized_dag["task_group"] = TaskGroupSerialization.serialize_task_group(dag.task_group) # Edge info in the JSON exactly matches our internal structure serialized_dag["edge_info"] = dag.edge_info @@ -1633,7 +1637,7 @@ def serialize_dag(cls, dag: DAG) -> dict: @classmethod def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG: """Deserializes a DAG from a JSON object.""" - dag = SerializedDAG(dag_id=encoded_dag["_dag_id"], schedule=None) + dag = SerializedDAG(dag_id=encoded_dag["dag_id"], schedule=None) for k, v in encoded_dag.items(): if k == "_downstream_task_ids": @@ -1668,20 +1672,21 @@ def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG: v = set(v) # else use v as it is - setattr(dag, k, v) + object.__setattr__(dag, k, v) # Set _task_group - if "_task_group" in encoded_dag: - dag._task_group = TaskGroupSerialization.deserialize_task_group( - encoded_dag["_task_group"], + if "task_group" in encoded_dag: + tg = TaskGroupSerialization.deserialize_task_group( + encoded_dag["task_group"], None, dag.task_dict, dag, ) + object.__setattr__(dag, "task_group", tg) else: # This must be old data that had no task_group. Create a root TaskGroup and add # all tasks to it. 
- dag._task_group = TaskGroup.create_root(dag) + object.__setattr__(dag, "task_group", TaskGroup.create_root(dag)) for task in dag.tasks: dag.task_group.add(task) @@ -1704,8 +1709,10 @@ def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG: def _is_excluded(cls, var: Any, attrname: str, op: DAGNode): # {} is explicitly different from None in the case of DAG-level access control # and as a result we need to preserve empty dicts through serialization for this field - if attrname == "_access_control" and var is not None: + if attrname == "access_control" and var is not None: return False + if attrname == "dag_display_name" and var == op.dag_id: + return True return super()._is_excluded(var, attrname, op) @classmethod diff --git a/airflow/task/priority_strategy.py b/airflow/task/priority_strategy.py index c22bdfa9940c..dcef1c865b6e 100644 --- a/airflow/task/priority_strategy.py +++ b/airflow/task/priority_strategy.py @@ -22,8 +22,6 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any -from airflow.exceptions import AirflowException - if TYPE_CHECKING: from airflow.models.taskinstance import TaskInstance @@ -150,5 +148,5 @@ def validate_and_load_priority_weight_strategy( priority_weight_strategy_class = qualname(priority_weight_strategy) loaded_priority_weight_strategy = _get_registered_priority_weight_strategy(priority_weight_strategy_class) if loaded_priority_weight_strategy is None: - raise AirflowException(f"Unknown priority strategy {priority_weight_strategy_class}") + raise ValueError(f"Unknown priority strategy {priority_weight_strategy_class}") return loaded_priority_weight_strategy() diff --git a/airflow/template/templater.py b/airflow/template/templater.py index fc37e18e0cdc..70be10136495 100644 --- a/airflow/template/templater.py +++ b/airflow/template/templater.py @@ -26,10 +26,12 @@ from airflow.utils.mixins import ResolveMixin if TYPE_CHECKING: + from collections.abc import Mapping + import jinja2 - from airflow import DAG from airflow.models.operator import Operator + from airflow.sdk import DAG from airflow.utils.context import Context @@ -106,7 +108,7 @@ def _do_render_template_fields( self, parent: Any, template_fields: Iterable[str], - context: Context, + context: Mapping[str, Any], jinja_env: jinja2.Environment, seen_oids: set[int], ) -> None: @@ -121,7 +123,7 @@ def _do_render_template_fields( if rendered_content: setattr(parent, attr_name, rendered_content) - def _render(self, template, context, dag: DAG | None = None) -> Any: + def _render(self, template, context, dag=None) -> Any: if dag and dag.render_template_as_native_obj: return render_template_as_native(template, context) return render_template_to_string(template, context) @@ -129,7 +131,7 @@ def _render(self, template, context, dag: DAG | None = None) -> Any: def render_template( self, content: Any, - context: Context, + context: Mapping[str, Any], jinja_env: jinja2.Environment | None = None, seen_oids: set[int] | None = None, ) -> Any: @@ -172,7 +174,8 @@ def render_template( if isinstance(value, ObjectStoragePath): return self._render_object_storage_path(value, context, jinja_env) if isinstance(value, ResolveMixin): - return value.resolve(context, include_xcom=True) + # TODO: Task-SDK: Tidy up the typing on template context + return value.resolve(context, include_xcom=True) # type: ignore[arg-type] # Fast path for common built-in collections. 
if value.__class__ is tuple: @@ -191,7 +194,7 @@ def render_template( return value def _render_object_storage_path( - self, value: ObjectStoragePath, context: Context, jinja_env: jinja2.Environment + self, value: ObjectStoragePath, context: Mapping[str, Any], jinja_env: jinja2.Environment ) -> ObjectStoragePath: serialized_path = value.serialize() path_version = value.__version__ @@ -201,7 +204,7 @@ def _render_object_storage_path( def _render_nested_template_fields( self, value: Any, - context: Context, + context: Mapping[str, Any], jinja_env: jinja2.Environment, seen_oids: set[int], ) -> None: diff --git a/airflow/typing_compat.py b/airflow/typing_compat.py index ba96c92d77f0..946fdce5f26c 100644 --- a/airflow/typing_compat.py +++ b/airflow/typing_compat.py @@ -24,6 +24,7 @@ "ParamSpec", "Protocol", "Self", + "TypeAlias", "TypedDict", "TypeGuard", "runtime_checkable", @@ -43,9 +44,9 @@ from typing_extensions import Literal # type: ignore[assignment] if sys.version_info >= (3, 10): - from typing import ParamSpec, TypeGuard + from typing import ParamSpec, TypeAlias, TypeGuard else: - from typing_extensions import ParamSpec, TypeGuard + from typing_extensions import ParamSpec, TypeAlias, TypeGuard if sys.version_info >= (3, 11): from typing import Self diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index a991cce370bd..f6b1efeb1f1d 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -517,10 +517,6 @@ export const $DAGDetailsResponse = { ], title: "Is Paused Upon Creation", }, - orientation: { - type: "string", - title: "Orientation", - }, params: { anyOf: [ { @@ -619,7 +615,6 @@ export const $DAGDetailsResponse = { "start_date", "end_date", "is_paused_upon_creation", - "orientation", "params", "render_template_as_native_obj", "template_search_path", diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index d375978b91f1..52862ae12998 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -95,7 +95,6 @@ export type DAGDetailsResponse = { start_date: string | null; end_date: string | null; is_paused_upon_creation: boolean | null; - orientation: string; params: { [key: string]: unknown; } | null; diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py index e299999423e5..78044e4e3576 100644 --- a/airflow/utils/decorators.py +++ b/airflow/utils/decorators.py @@ -69,8 +69,9 @@ def _balance_parens(after_decorator): class _autostacklevel_warn: - def __init__(self): + def __init__(self, delta): self.warnings = __import__("warnings") + self.delta = delta def __getattr__(self, name): return getattr(self.warnings, name) @@ -79,11 +80,11 @@ def __dir__(self): return dir(self.warnings) def warn(self, message, category=None, stacklevel=1, source=None): - self.warnings.warn(message, category, stacklevel + 2, source) + self.warnings.warn(message, category, stacklevel + self.delta, source) -def fixup_decorator_warning_stack(func): +def fixup_decorator_warning_stack(func, delta: int = 2): if func.__globals__.get("warnings") is sys.modules["warnings"]: # Yes, this is more than slightly hacky, but it _automatically_ sets the right stacklevel parameter to # `warnings.warn` to ignore the decorator. 
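The `delta` parameter threaded through `_autostacklevel_warn` exists so the shim can bump `stacklevel` by however many wrapper frames sit between the user's code and the eventual `warnings.warn()` call. What `stacklevel` buys you, in a standalone example:

.. code-block:: python

    import warnings

    def deprecated_helper():
        # stacklevel=1 would point the warning at this line; stacklevel=2 skips one
        # frame so the report blames the caller instead, which is what users need.
        warnings.warn("deprecated_helper() is deprecated", DeprecationWarning, stacklevel=2)

    def user_code():
        deprecated_helper()

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        user_code()

    # The captured warning is attributed to user_code()'s call site, not the helper.
    assert "deprecated_helper() is deprecated" in str(caught[0].message)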
- func.__globals__["warnings"] = _autostacklevel_warn() + func.__globals__["warnings"] = _autostacklevel_warn(delta) diff --git a/airflow/utils/edgemodifier.py b/airflow/utils/edgemodifier.py index a78e6c649992..b4a7a3d09462 100644 --- a/airflow/utils/edgemodifier.py +++ b/airflow/utils/edgemodifier.py @@ -16,158 +16,14 @@ # under the License. from __future__ import annotations -from typing import Sequence +from typing import TYPE_CHECKING -from airflow.models.taskmixin import DAGNode, DependencyMixin -from airflow.utils.task_group import TaskGroup +import airflow.sdk +if TYPE_CHECKING: + from airflow.typing_compat import TypeAlias -class EdgeModifier(DependencyMixin): - """ - Class that represents edge information to be added between two tasks/operators. - - Has shorthand factory functions, like Label("hooray"). - - Current implementation supports - t1 >> Label("Success route") >> t2 - t2 << Label("Success route") << t2 - - Note that due to the potential for use in either direction, this waits - to make the actual connection between both sides until both are declared, - and will do so progressively if multiple ups/downs are added. - - This and EdgeInfo are related - an EdgeModifier is the Python object you - use to add information to (potentially multiple) edges, and EdgeInfo - is the representation of the information for one specific edge. - """ - - def __init__(self, label: str | None = None): - self.label = label - self._upstream: list[DependencyMixin] = [] - self._downstream: list[DependencyMixin] = [] - - @property - def roots(self): - return self._downstream - - @property - def leaves(self): - return self._upstream - - @staticmethod - def _make_list(item_or_list: DependencyMixin | Sequence[DependencyMixin]) -> Sequence[DependencyMixin]: - if not isinstance(item_or_list, Sequence): - return [item_or_list] - return item_or_list - - def _save_nodes( - self, - nodes: DependencyMixin | Sequence[DependencyMixin], - stream: list[DependencyMixin], - ): - from airflow.models.xcom_arg import XComArg - - for node in self._make_list(nodes): - if isinstance(node, (TaskGroup, XComArg, DAGNode)): - stream.append(node) - else: - raise TypeError( - f"Cannot use edge labels with {type(node).__name__}, " - f"only tasks, XComArg or TaskGroups" - ) - - def _convert_streams_to_task_groups(self): - """ - Convert a node to a TaskGroup or leave it as a DAGNode. - - Requires both self._upstream and self._downstream. - - To do this, we keep a set of group_ids seen among the streams. 
If we find that - the nodes are from the same TaskGroup, we will leave them as DAGNodes and not - convert them to TaskGroups - """ - from airflow.models.xcom_arg import XComArg - - group_ids = set() - for node in [*self._upstream, *self._downstream]: - if isinstance(node, DAGNode) and node.task_group: - if node.task_group.is_root: - group_ids.add("root") - else: - group_ids.add(node.task_group.group_id) - elif isinstance(node, TaskGroup): - group_ids.add(node.group_id) - elif isinstance(node, XComArg): - if isinstance(node.operator, DAGNode) and node.operator.task_group: - if node.operator.task_group.is_root: - group_ids.add("root") - else: - group_ids.add(node.operator.task_group.group_id) - - # If all nodes originate from the same TaskGroup, we will not convert them - if len(group_ids) != 1: - self._upstream = self._convert_stream_to_task_groups(self._upstream) - self._downstream = self._convert_stream_to_task_groups(self._downstream) - - def _convert_stream_to_task_groups(self, stream: Sequence[DependencyMixin]) -> Sequence[DependencyMixin]: - return [ - node.task_group - if isinstance(node, DAGNode) and node.task_group and not node.task_group.is_root - else node - for node in stream - ] - - def set_upstream( - self, - other: DependencyMixin | Sequence[DependencyMixin], - edge_modifier: EdgeModifier | None = None, - ): - """ - Set the given task/list onto the upstream attribute, then attempt to resolve the relationship. - - Providing this also provides << via DependencyMixin. - """ - self._save_nodes(other, self._upstream) - if self._upstream and self._downstream: - # Convert _upstream and _downstream to task_groups only after both are set - self._convert_streams_to_task_groups() - for node in self._downstream: - node.set_upstream(other, edge_modifier=self) - - def set_downstream( - self, - other: DependencyMixin | Sequence[DependencyMixin], - edge_modifier: EdgeModifier | None = None, - ): - """ - Set the given task/list onto the downstream attribute, then attempt to resolve the relationship. - - Providing this also provides >> via DependencyMixin. - """ - self._save_nodes(other, self._downstream) - if self._upstream and self._downstream: - # Convert _upstream and _downstream to task_groups only after both are set - self._convert_streams_to_task_groups() - for node in self._upstream: - node.set_downstream(other, edge_modifier=self) - - def update_relative( - self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None - ) -> None: - """Update relative if we're not the "main" side of a relationship; still run the same logic.""" - if upstream: - self.set_upstream(other) - else: - self.set_downstream(other) - - def add_edge_info(self, dag, upstream_id: str, downstream_id: str): - """ - Add or update task info on the DAG for this specific pair of tasks. - - Called either from our relationship trigger methods above, or directly - by set_upstream/set_downstream in operators. 
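For reference, this is the user-facing behaviour the `EdgeModifier`/`Label` code above provides. The sketch below assumes the pre-move import paths for `DAG`, `EmptyOperator` and `Label`, which this PR (or follow-ups) may relocate:

.. code-block:: python

    import pendulum
    from airflow.models.dag import DAG
    from airflow.operators.empty import EmptyOperator
    from airflow.utils.edgemodifier import Label

    with DAG(
        dag_id="edge_label_demo",
        schedule=None,
        start_date=pendulum.datetime(2024, 1, 1),
    ) as dag:
        start = EmptyOperator(task_id="start")
        done = EmptyOperator(task_id="done")
        # The label is buffered until both sides are declared, then recorded on
        # the DAG via set_edge_info() as {"label": ...}.
        start >> Label("Success route") >> done

    assert dag.get_edge_info("start", "done") == {"label": "Success route"}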
- """ - dag.set_edge_info(upstream_id, downstream_id, {"label": self.label}) +EdgeModifier: TypeAlias = airflow.sdk.EdgeModifier # Factory functions diff --git a/airflow/utils/log/logging_mixin.py b/airflow/utils/log/logging_mixin.py index 270d3fb6cc00..cc9f3b2357f5 100644 --- a/airflow/utils/log/logging_mixin.py +++ b/airflow/utils/log/logging_mixin.py @@ -79,6 +79,7 @@ class LoggingMixin: def __init__(self, context=None): self._set_context(context) + super().__init__() @staticmethod def _create_logger_name( diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 69a5d015bd42..1f94880902c9 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -19,566 +19,22 @@ from __future__ import annotations -import copy import functools import operator -import weakref -from typing import TYPE_CHECKING, Any, Generator, Iterator, Sequence +from typing import TYPE_CHECKING, Iterator -import methodtools -import re2 - -from airflow.exceptions import ( - AirflowDagCycleException, - AirflowException, - DuplicateTaskIdFound, - TaskAlreadyInTaskGroup, -) -from airflow.models.taskmixin import DAGNode -from airflow.serialization.enums import DagAttributeTypes -from airflow.utils.helpers import validate_group_key, validate_instance_args +import airflow.sdk.definitions.taskgroup if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.abstractoperator import AbstractOperator - from airflow.models.baseoperator import BaseOperator - from airflow.models.dag import DAG - from airflow.models.expandinput import ExpandInput from airflow.models.operator import Operator - from airflow.models.taskmixin import DependencyMixin - from airflow.utils.edgemodifier import EdgeModifier - -# TODO: The following mapping is used to validate that the arguments passed to the TaskGroup are of the -# correct type. This is a temporary solution until we find a more sophisticated method for argument -# validation. One potential method is to use get_type_hints from the typing module. However, this is not -# fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python -# version that supports `get_type_hints` effectively or find a better approach, we can replace this -# manual type-checking method. -TASKGROUP_ARGS_EXPECTED_TYPES = { - "group_id": str, - "prefix_group_id": bool, - "tooltip": str, - "ui_color": str, - "ui_fgcolor": str, - "add_suffix_on_collision": bool, -} + from airflow.typing_compat import TypeAlias +TaskGroup: TypeAlias = airflow.sdk.definitions.taskgroup.TaskGroup -class TaskGroup(DAGNode): - """ - A collection of tasks. - - When set_downstream() or set_upstream() are called on the TaskGroup, it is applied across - all tasks within the group if necessary. - :param group_id: a unique, meaningful id for the TaskGroup. group_id must not conflict - with group_id of TaskGroup or task_id of tasks in the DAG. Root TaskGroup has group_id - set to None. - :param prefix_group_id: If set to True, child task_id and group_id will be prefixed with - this TaskGroup's group_id. If set to False, child task_id and group_id are not prefixed. - Default is True. - :param parent_group: The parent TaskGroup of this TaskGroup. parent_group is set to None - for the root TaskGroup. - :param dag: The DAG that this TaskGroup belongs to. - :param default_args: A dictionary of default parameters to be used - as constructor keyword parameters when initialising operators, - will override default_args defined in the DAG level. 
- Note that operators have the same hook, and precede those defined - here, meaning that if your dict contains `'depends_on_past': True` - here and `'depends_on_past': False` in the operator's call - `default_args`, the actual value will be `False`. - :param tooltip: The tooltip of the TaskGroup node when displayed in the UI - :param ui_color: The fill color of the TaskGroup node when displayed in the UI - :param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI - :param add_suffix_on_collision: If this task group name already exists, - automatically add `__1` etc suffixes - """ - - used_group_ids: set[str | None] - - def __init__( - self, - group_id: str | None, - prefix_group_id: bool = True, - parent_group: TaskGroup | None = None, - dag: DAG | None = None, - default_args: dict[str, Any] | None = None, - tooltip: str = "", - ui_color: str = "CornflowerBlue", - ui_fgcolor: str = "#000", - add_suffix_on_collision: bool = False, - ): - from airflow.models.dag import DagContext - - self.prefix_group_id = prefix_group_id - self.default_args = copy.deepcopy(default_args or {}) - - dag = dag or DagContext.get_current_dag() - - if group_id is None: - # This creates a root TaskGroup. - if parent_group: - raise AirflowException("Root TaskGroup cannot have parent_group") - # used_group_ids is shared across all TaskGroups in the same DAG to keep track - # of used group_id to avoid duplication. - self.used_group_ids = set() - self.dag = dag - else: - if prefix_group_id: - # If group id is used as prefix, it should not contain spaces nor dots - # because it is used as prefix in the task_id - validate_group_key(group_id) - else: - if not isinstance(group_id, str): - raise ValueError("group_id must be str") - if not group_id: - raise ValueError("group_id must not be empty") - - if not parent_group and not dag: - raise AirflowException("TaskGroup can only be used inside a dag") - - parent_group = parent_group or TaskGroupContext.get_current_task_group(dag) - if not parent_group: - raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup") - if dag is not parent_group.dag: - raise RuntimeError( - "Cannot mix TaskGroups from different DAGs: %s and %s", dag, parent_group.dag - ) - - self.used_group_ids = parent_group.used_group_ids - - # if given group_id already used assign suffix by incrementing largest used suffix integer - # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3 - self._group_id = group_id - self._check_for_group_id_collisions(add_suffix_on_collision) - - self.children: dict[str, DAGNode] = {} - - if parent_group: - parent_group.add(self) - self._update_default_args(parent_group) - - self.used_group_ids.add(self.group_id) - if self.group_id: - self.used_group_ids.add(self.downstream_join_id) - self.used_group_ids.add(self.upstream_join_id) - - self.tooltip = tooltip - self.ui_color = ui_color - self.ui_fgcolor = ui_fgcolor - - # Keep track of TaskGroups or tasks that depend on this entire TaskGroup separately - # so that we can optimize the number of edges when entire TaskGroups depend on each other. 
- self.upstream_group_ids: set[str | None] = set() - self.downstream_group_ids: set[str | None] = set() - self.upstream_task_ids = set() - self.downstream_task_ids = set() - - validate_instance_args(self, TASKGROUP_ARGS_EXPECTED_TYPES) - - def _check_for_group_id_collisions(self, add_suffix_on_collision: bool): - if self._group_id is None: - return - # if given group_id already used assign suffix by incrementing largest used suffix integer - # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3 - if self._group_id in self.used_group_ids: - if not add_suffix_on_collision: - raise DuplicateTaskIdFound(f"group_id '{self._group_id}' has already been added to the DAG") - base = re2.split(r"__\d+$", self._group_id)[0] - suffixes = sorted( - int(re2.split(r"^.+__", used_group_id)[1]) - for used_group_id in self.used_group_ids - if used_group_id is not None and re2.match(rf"^{base}__\d+$", used_group_id) - ) - if not suffixes: - self._group_id += "__1" - else: - self._group_id = f"{base}__{suffixes[-1] + 1}" - - def _update_default_args(self, parent_group: TaskGroup): - if parent_group.default_args: - self.default_args = {**parent_group.default_args, **self.default_args} - - @classmethod - def create_root(cls, dag: DAG) -> TaskGroup: - """Create a root TaskGroup with no group_id or parent.""" - return cls(group_id=None, dag=dag) - - @property - def node_id(self): - return self.group_id - - @property - def is_root(self) -> bool: - """Returns True if this TaskGroup is the root TaskGroup. Otherwise False.""" - return not self.group_id - - @property - def parent_group(self) -> TaskGroup | None: - return self.task_group - - def __iter__(self): - for child in self.children.values(): - if isinstance(child, TaskGroup): - yield from child - else: - yield child - - def add(self, task: DAGNode) -> DAGNode: - """ - Add a task to this TaskGroup. - - :meta private: - """ - from airflow.models.abstractoperator import AbstractOperator - - if TaskGroupContext.active: - if task.task_group and task.task_group != self: - task.task_group.children.pop(task.node_id, None) - task.task_group = self - existing_tg = task.task_group - if isinstance(task, AbstractOperator) and existing_tg is not None and existing_tg != self: - raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id, self.node_id) - - # Set the TG first, as setting it might change the return value of node_id! 
- task.task_group = weakref.proxy(self) - key = task.node_id - - if key in self.children: - node_type = "Task" if hasattr(task, "task_id") else "Task Group" - raise DuplicateTaskIdFound(f"{node_type} id '{key}' has already been added to the DAG") - - if isinstance(task, TaskGroup): - if self.dag: - if task.dag is not None and self.dag is not task.dag: - raise RuntimeError( - "Cannot mix TaskGroups from different DAGs: %s and %s", self.dag, task.dag - ) - task.dag = self.dag - if task.children: - raise AirflowException("Cannot add a non-empty TaskGroup") - - self.children[key] = task - return task - - def _remove(self, task: DAGNode) -> None: - key = task.node_id - - if key not in self.children: - raise KeyError(f"Node id {key!r} not part of this task group") - - self.used_group_ids.remove(key) - del self.children[key] - - @property - def group_id(self) -> str | None: - """group_id of this TaskGroup.""" - if self.task_group and self.task_group.prefix_group_id and self.task_group._group_id: - # defer to parent whether it adds a prefix - return self.task_group.child_id(self._group_id) - - return self._group_id - - @property - def label(self) -> str | None: - """group_id excluding parent's group_id used as the node label in UI.""" - return self._group_id - - def update_relative( - self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None - ) -> None: - """ - Override TaskMixin.update_relative. - - Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids - accordingly so that we can reduce the number of edges when displaying Graph view. - """ - if isinstance(other, TaskGroup): - # Handles setting relationship between a TaskGroup and another TaskGroup - if upstream: - parent, child = (self, other) - if edge_modifier: - edge_modifier.add_edge_info(self.dag, other.downstream_join_id, self.upstream_join_id) - else: - parent, child = (other, self) - if edge_modifier: - edge_modifier.add_edge_info(self.dag, self.downstream_join_id, other.upstream_join_id) - - parent.upstream_group_ids.add(child.group_id) - child.downstream_group_ids.add(parent.group_id) - else: - # Handles setting relationship between a TaskGroup and a task - for task in other.roots: - if not isinstance(task, DAGNode): - raise AirflowException( - "Relationships can only be set between TaskGroup " - f"or operators; received {task.__class__.__name__}" - ) - - # Do not set a relationship between a TaskGroup and a Label's roots - if self == task: - continue - - if upstream: - self.upstream_task_ids.add(task.node_id) - if edge_modifier: - edge_modifier.add_edge_info(self.dag, task.node_id, self.upstream_join_id) - else: - self.downstream_task_ids.add(task.node_id) - if edge_modifier: - edge_modifier.add_edge_info(self.dag, self.downstream_join_id, task.node_id) - - def _set_relatives( - self, - task_or_task_list: DependencyMixin | Sequence[DependencyMixin], - upstream: bool = False, - edge_modifier: EdgeModifier | None = None, - ) -> None: - """ - Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup. - - Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids. 
- """ - if not isinstance(task_or_task_list, Sequence): - task_or_task_list = [task_or_task_list] - - for task_like in task_or_task_list: - self.update_relative(task_like, upstream, edge_modifier=edge_modifier) - - if upstream: - for task in self.get_roots(): - task.set_upstream(task_or_task_list) - else: - for task in self.get_leaves(): - task.set_downstream(task_or_task_list) - - def __enter__(self) -> TaskGroup: - TaskGroupContext.push_context_managed_task_group(self) - return self - - def __exit__(self, _type, _value, _tb): - TaskGroupContext.pop_context_managed_task_group() - - def has_task(self, task: BaseOperator) -> bool: - """Return True if this TaskGroup or its children TaskGroups contains the given task.""" - if task.task_id in self.children: - return True - - return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup)) - - @property - def roots(self) -> list[BaseOperator]: - """Required by DependencyMixin.""" - return list(self.get_roots()) - - @property - def leaves(self) -> list[BaseOperator]: - """Required by DependencyMixin.""" - return list(self.get_leaves()) - - def get_roots(self) -> Generator[BaseOperator, None, None]: - """Return a generator of tasks with no upstream dependencies within the TaskGroup.""" - tasks = list(self) - ids = {x.task_id for x in tasks} - for task in tasks: - if task.upstream_task_ids.isdisjoint(ids): - yield task - - def get_leaves(self) -> Generator[BaseOperator, None, None]: - """Return a generator of tasks with no downstream dependencies within the TaskGroup.""" - tasks = list(self) - ids = {x.task_id for x in tasks} - - def has_non_teardown_downstream(task, exclude: str): - for down_task in task.downstream_list: - if down_task.task_id == exclude: - continue - elif down_task.task_id not in ids: - continue - elif not down_task.is_teardown: - return True - return False - - def recurse_for_first_non_teardown(task): - for upstream_task in task.upstream_list: - if upstream_task.task_id not in ids: - # upstream task is not in task group - continue - elif upstream_task.is_teardown: - yield from recurse_for_first_non_teardown(upstream_task) - elif task.is_teardown and upstream_task.is_setup: - # don't go through the teardown-to-setup path - continue - # return unless upstream task already has non-teardown downstream in group - elif not has_non_teardown_downstream(upstream_task, exclude=task.task_id): - yield upstream_task - - for task in tasks: - if task.downstream_task_ids.isdisjoint(ids): - if not task.is_teardown: - yield task - else: - yield from recurse_for_first_non_teardown(task) - - def child_id(self, label): - """Prefix label with group_id if prefix_group_id is True. Otherwise return the label as-is.""" - if self.prefix_group_id: - group_id = self.group_id - if group_id: - return f"{group_id}.{label}" - - return label - - @property - def upstream_join_id(self) -> str: - """ - Creates a unique ID for upstream dependencies of this TaskGroup. - - If this TaskGroup has immediate upstream TaskGroups or tasks, a proxy node called - upstream_join_id will be created in Graph view to join the outgoing edges from this - TaskGroup to reduce the total number of edges needed to be displayed. - """ - return f"{self.group_id}.upstream_join_id" - - @property - def downstream_join_id(self) -> str: - """ - Creates a unique ID for downstream dependencies of this TaskGroup. 
- - If this TaskGroup has immediate downstream TaskGroups or tasks, a proxy node called - downstream_join_id will be created in Graph view to join the outgoing edges from this - TaskGroup to reduce the total number of edges needed to be displayed. - """ - return f"{self.group_id}.downstream_join_id" - - def get_task_group_dict(self) -> dict[str, TaskGroup]: - """Return a flat dictionary of group_id: TaskGroup.""" - task_group_map = {} - - def build_map(task_group): - if not isinstance(task_group, TaskGroup): - return - - task_group_map[task_group.group_id] = task_group - - for child in task_group.children.values(): - build_map(child) - - build_map(self) - return task_group_map - - def get_child_by_label(self, label: str) -> DAGNode: - """Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix).""" - return self.children[self.child_id(label)] - - def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: - """Serialize task group; required by DAGNode.""" - from airflow.serialization.serialized_objects import TaskGroupSerialization - - return DagAttributeTypes.TASK_GROUP, TaskGroupSerialization.serialize_task_group(self) - - def hierarchical_alphabetical_sort(self): - """ - Sort children in hierarchical alphabetical order. - - - groups in alphabetical order first - - tasks in alphabetical order after them. - - :return: list of tasks in hierarchical alphabetical order - """ - return sorted( - self.children.values(), key=lambda node: (not isinstance(node, TaskGroup), node.node_id) - ) - - def topological_sort(self): - """ - Sorts children in topographical order, such that a task comes after any of its upstream dependencies. - - :return: list of tasks in topological order - """ - # This uses a modified version of Kahn's Topological Sort algorithm to - # not have to pre-compute the "in-degree" of the nodes. - graph_unsorted = copy.copy(self.children) - - graph_sorted: list[DAGNode] = [] - - # special case - if not self.children: - return graph_sorted - - # Run until the unsorted graph is empty. - while graph_unsorted: - # Go through each of the node/edges pairs in the unsorted graph. If a set of edges doesn't contain - # any nodes that haven't been resolved, that is, that are still in the unsorted graph, remove the - # pair from the unsorted graph, and append it to the sorted graph. Note here that by using - # the values() method for iterating, a copy of the unsorted graph is used, allowing us to modify - # the unsorted graph as we move through it. - # - # We also keep a flag for checking that graph is acyclic, which is true if any nodes are resolved - # during each pass through the graph. If not, we need to exit as the graph therefore can't be - # sorted. - acyclic = False - for node in list(graph_unsorted.values()): - for edge in node.upstream_list: - if edge.node_id in graph_unsorted: - break - # Check for task's group is a child (or grand child) of this TG, - tg = edge.task_group - while tg: - if tg.node_id in graph_unsorted: - break - tg = tg.task_group - - if tg: - # We are already going to visit that TG - break - else: - acyclic = True - del graph_unsorted[node.node_id] - graph_sorted.append(node) - - if not acyclic: - raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}") - - return graph_sorted - - def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: - """ - Return mapped task groups in the hierarchy. - - Groups are returned from the closest to the outmost. 
If *self* is a - mapped task group, it is returned first. - - :meta private: - """ - group: TaskGroup | None = self - while group is not None: - if isinstance(group, MappedTaskGroup): - yield group - group = group.task_group - - def iter_tasks(self) -> Iterator[AbstractOperator]: - """Return an iterator of the child tasks.""" - from airflow.models.abstractoperator import AbstractOperator - - groups_to_visit = [self] - - while groups_to_visit: - visiting = groups_to_visit.pop(0) - - for child in visiting.children.values(): - if isinstance(child, AbstractOperator): - yield child - elif isinstance(child, TaskGroup): - groups_to_visit.append(child) - else: - raise ValueError( - f"Encountered a DAGNode that is not a TaskGroup or an AbstractOperator: {type(child)}" - ) - - -class MappedTaskGroup(TaskGroup): +class MappedTaskGroup(airflow.sdk.definitions.taskgroup.MappedTaskGroup): """ A mapped task group. @@ -589,10 +45,6 @@ class MappedTaskGroup(TaskGroup): a ``@task_group`` function instead. """ - def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._expand_input = expand_input - def iter_mapped_dependencies(self) -> Iterator[Operator]: """Upstream dependencies that provide XComs used by this mapped task group.""" from airflow.models.xcom_arg import XComArg @@ -600,27 +52,6 @@ def iter_mapped_dependencies(self) -> Iterator[Operator]: for op, _ in XComArg.iter_xcom_references(self._expand_input): yield op - @methodtools.lru_cache(maxsize=None) - def get_parse_time_mapped_ti_count(self) -> int: - """ - Return the Number of instances a task in this group should be mapped to, when a DAG run is created. - - This only considers literal mapped arguments, and would return *None* - when any non-literal values are used for mapping. - - If this group is inside mapped task groups, all the nested counts are - multiplied and accounted. - - :meta private: - - :raise NotFullyPopulated: If any non-literal mapped arguments are encountered. - :return: The total number of mapped instances each task should have. - """ - return functools.reduce( - operator.mul, - (g._expand_input.get_parse_time_mapped_ti_count() for g in self.iter_mapped_task_groups()), - ) - def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int: """ Return the number of instances a task in this group should be mapped to at run time. 
@@ -644,51 +75,6 @@ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
             (g._expand_input.get_total_map_length(run_id, session=session) for g in groups),
         )
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        for op, _ in self._expand_input.iter_references():
-            self.set_upstream(op)
-        super().__exit__(exc_type, exc_val, exc_tb)
-
-
-class TaskGroupContext:
-    """TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager."""
-
-    active: bool = False
-    _context_managed_task_group: TaskGroup | None = None
-    _previous_context_managed_task_groups: list[TaskGroup] = []
-
-    @classmethod
-    def push_context_managed_task_group(cls, task_group: TaskGroup):
-        """Push a TaskGroup into the list of managed TaskGroups."""
-        if cls._context_managed_task_group:
-            cls._previous_context_managed_task_groups.append(cls._context_managed_task_group)
-        cls._context_managed_task_group = task_group
-        cls.active = True
-
-    @classmethod
-    def pop_context_managed_task_group(cls) -> TaskGroup | None:
-        """Pops the last TaskGroup from the list of managed TaskGroups and update the current TaskGroup."""
-        old_task_group = cls._context_managed_task_group
-        if cls._previous_context_managed_task_groups:
-            cls._context_managed_task_group = cls._previous_context_managed_task_groups.pop()
-        else:
-            cls._context_managed_task_group = None
-            cls.active = False
-        return old_task_group
-
-    @classmethod
-    def get_current_task_group(cls, dag: DAG | None) -> TaskGroup | None:
-        """Get the current TaskGroup."""
-        from airflow.models.dag import DagContext
-
-        if not cls._context_managed_task_group:
-            dag = dag or DagContext.get_current_dag()
-            if dag:
-                # If there's currently a DAG but no TaskGroup, return the root TaskGroup of the dag.
-                return dag.task_group
-
-        return cls._context_managed_task_group
-
 
 def task_group_to_dict(task_item_or_group):
     """Create a nested dict representation of this TaskGroup and its children used to construct the Graph."""
diff --git a/airflow/utils/types.py b/airflow/utils/types.py
index ab4024907af1..7dd1ce02b609 100644
--- a/airflow/utils/types.py
+++ b/airflow/utils/types.py
@@ -19,39 +19,15 @@
 import enum
 from typing import TYPE_CHECKING
 
-from airflow.typing_compat import TypedDict
+import airflow.sdk.types
+from airflow.typing_compat import TypeAlias, TypedDict
 
 if TYPE_CHECKING:
     from datetime import datetime
 
+ArgNotSet: TypeAlias = airflow.sdk.types.ArgNotSet
 
-class ArgNotSet:
-    """
-    Sentinel type for annotations, useful when None is not viable.
-
-    Use like this::
-
-        def is_arg_passed(arg: Union[ArgNotSet, None] = NOTSET) -> bool:
-            if arg is NOTSET:
-                return False
-            return True
-
-
-        is_arg_passed()  # False.
-        is_arg_passed(None)  # True.
-    """
-
-    @staticmethod
-    def serialize():
-        return "NOTSET"
-
-    @classmethod
-    def deserialize(cls):
-        return cls
-
-
-NOTSET = ArgNotSet()
-"""Sentinel value for argument default.
See ``ArgNotSet``.""" +NOTSET = airflow.sdk.types.NOTSET class AttributeRemoved: diff --git a/dev/mypy/plugin/outputs.py b/dev/mypy/plugin/outputs.py index fe1ccd5e7cf2..a3ba7351f556 100644 --- a/dev/mypy/plugin/outputs.py +++ b/dev/mypy/plugin/outputs.py @@ -25,6 +25,7 @@ OUTPUT_PROPERTIES = { "airflow.models.baseoperator.BaseOperator.output", "airflow.models.mappedoperator.MappedOperator.output", + "airflow.sdk.definitions.baseoperator.BaseOperator.output", } TASK_CALL_FUNCTIONS = { diff --git a/hatch_build.py b/hatch_build.py index 7dbf4f5d3c37..4c0e7f377112 100644 --- a/hatch_build.py +++ b/hatch_build.py @@ -254,6 +254,7 @@ # Coverage 7.4.0 added experimental support for Python 3.12 PEP669 which we use in Airflow "coverage>=7.4.0", "jmespath>=0.7.0", + "kgb>=7.0.0", "pytest-asyncio>=0.23.6", "pytest-cov>=4.1.0", "pytest-custom-exit-code>=0.3.0", @@ -411,7 +412,7 @@ 'pendulum>=3.0.0,<4.0;python_version>="3.12"', "pluggy>=1.5.0", "psutil>=5.8.0", - "pydantic>=2.6.4", + "pydantic>=2.7.0", "pygments>=2.0.1", "pyjwt>=2.0.0", "python-daemon>=3.0.0", diff --git a/providers/src/airflow/providers/amazon/aws/operators/comprehend.py b/providers/src/airflow/providers/amazon/aws/operators/comprehend.py index 880440726c60..88bbbf9bf46d 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/comprehend.py +++ b/providers/src/airflow/providers/amazon/aws/operators/comprehend.py @@ -17,7 +17,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Any, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, Sequence from airflow.configuration import conf from airflow.exceptions import AirflowException @@ -55,7 +55,7 @@ class ComprehendBaseOperator(AwsBaseOperator[ComprehendHook]): "input_data_config", "output_data_config", "data_access_role_arn", "language_code" ) - template_fields_renderers: dict = {"input_data_config": "json", "output_data_config": "json"} + template_fields_renderers: ClassVar[dict] = {"input_data_config": "json", "output_data_config": "json"} def __init__( self, @@ -248,7 +248,7 @@ class ComprehendCreateDocumentClassifierOperator(AwsBaseOperator[ComprehendHook] "document_classifier_kwargs", ) - template_fields_renderers: dict = { + template_fields_renderers: ClassVar[dict] = { "input_data_config": "json", "output_data_config": "json", "document_classifier_kwargs": "json", diff --git a/providers/src/airflow/providers/amazon/aws/operators/dms.py b/providers/src/airflow/providers/amazon/aws/operators/dms.py index 9fb33173884f..915b85cd6c37 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/dms.py +++ b/providers/src/airflow/providers/amazon/aws/operators/dms.py @@ -18,7 +18,7 @@ from __future__ import annotations from datetime import datetime -from typing import TYPE_CHECKING, Any, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, Sequence from airflow.configuration import conf from airflow.exceptions import AirflowException @@ -75,7 +75,7 @@ class DmsCreateTaskOperator(AwsBaseOperator[DmsHook]): "migration_type", "create_task_kwargs", ) - template_fields_renderers = { + template_fields_renderers: ClassVar[dict] = { "table_mappings": "json", "create_task_kwargs": "json", } @@ -184,7 +184,7 @@ class DmsDescribeTasksOperator(AwsBaseOperator[DmsHook]): aws_hook_class = DmsHook template_fields: Sequence[str] = aws_template_fields("describe_tasks_kwargs") - template_fields_renderers: dict[str, str] = {"describe_tasks_kwargs": "json"} + template_fields_renderers: ClassVar[dict[str, str]] = 
{"describe_tasks_kwargs": "json"} def __init__(self, *, describe_tasks_kwargs: dict | None = None, **kwargs): super().__init__(**kwargs) diff --git a/providers/src/airflow/providers/amazon/aws/operators/kinesis_analytics.py b/providers/src/airflow/providers/amazon/aws/operators/kinesis_analytics.py index 727aa714c614..93f8bc6b805e 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/kinesis_analytics.py +++ b/providers/src/airflow/providers/amazon/aws/operators/kinesis_analytics.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, Sequence from botocore.exceptions import ClientError @@ -70,7 +70,7 @@ class KinesisAnalyticsV2CreateApplicationOperator(AwsBaseOperator[KinesisAnalyti "create_application_kwargs", "application_description", ) - template_fields_renderers: dict = { + template_fields_renderers: ClassVar[dict] = { "create_application_kwargs": "json", } @@ -149,7 +149,7 @@ class KinesisAnalyticsV2StartApplicationOperator(AwsBaseOperator[KinesisAnalytic "application_name", "run_configuration", ) - template_fields_renderers: dict = { + template_fields_renderers: ClassVar[dict] = { "run_configuration": "json", } diff --git a/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py b/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py index 57a919452622..e2fad21f09df 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py @@ -20,7 +20,7 @@ import json import time from functools import cached_property -from typing import TYPE_CHECKING, Any, Callable, Sequence +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Sequence from botocore.exceptions import ClientError @@ -65,7 +65,7 @@ class SageMakerBaseOperator(BaseOperator): template_fields: Sequence[str] = ("config",) template_ext: Sequence[str] = () - template_fields_renderers: dict = {"config": "json"} + template_fields_renderers: ClassVar[dict] = {"config": "json"} ui_color: str = "#ededed" integer_fields: list[list[Any]] = [] diff --git a/providers/src/airflow/providers/apache/drill/operators/drill.py b/providers/src/airflow/providers/apache/drill/operators/drill.py index 5aa0f061bafe..edf9d0f73590 100644 --- a/providers/src/airflow/providers/apache/drill/operators/drill.py +++ b/providers/src/airflow/providers/apache/drill/operators/drill.py @@ -17,7 +17,7 @@ # under the License. 
from __future__ import annotations -from typing import Sequence +from typing import ClassVar, Sequence from deprecated import deprecated @@ -46,7 +46,7 @@ class DrillOperator(SQLExecuteQueryOperator): """ template_fields: Sequence[str] = ("sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} template_ext: Sequence[str] = (".sql",) ui_color = "#ededed" diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py index e51397447c3c..62f08439d416 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py @@ -27,7 +27,7 @@ import shlex import string import warnings -from collections.abc import Container +from collections.abc import Container, Mapping from contextlib import AbstractContextManager from enum import Enum from functools import cached_property @@ -436,7 +436,7 @@ def _incluster_namespace(self): def _render_nested_template_fields( self, content: Any, - context: Context, + context: Mapping[str, Any], jinja_env: jinja2.Environment, seen_oids: set, ) -> None: diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py index c3dd4755b983..c1f5b36d6d35 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py +++ b/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +from collections.abc import Mapping from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING, Any @@ -127,7 +128,7 @@ def __init__( def _render_nested_template_fields( self, content: Any, - context: Context, + context: Mapping[str, Any], jinja_env: jinja2.Environment, seen_oids: set, ) -> None: diff --git a/providers/src/airflow/providers/common/sql/operators/sql.py b/providers/src/airflow/providers/common/sql/operators/sql.py index dae389be0289..44982b8e1efa 100644 --- a/providers/src/airflow/providers/common/sql/operators/sql.py +++ b/providers/src/airflow/providers/common/sql/operators/sql.py @@ -20,7 +20,7 @@ import ast import re from functools import cached_property -from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, NoReturn, Sequence, SupportsAbs +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Mapping, NoReturn, Sequence, SupportsAbs from airflow.exceptions import AirflowException, AirflowFailException from airflow.hooks.base import BaseHook @@ -224,7 +224,7 @@ class SQLExecuteQueryOperator(BaseSQLOperator): template_fields: Sequence[str] = ("sql", "parameters", *BaseSQLOperator.template_fields) template_ext: Sequence[str] = (".sql", ".json") - template_fields_renderers = {"sql": "sql", "parameters": "json"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql", "parameters": "json"} ui_color = "#cdaaed" def __init__( @@ -428,7 +428,7 @@ class SQLColumnCheckOperator(BaseSQLOperator): """ template_fields: Sequence[str] = ("table", "partition_clause", "sql", *BaseSQLOperator.template_fields) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} sql_check_template = """ SELECT '{column}' AS col_name, '{check}' AS check_type, {column}_{check} AS check_result @@ -657,7 +657,7 @@ class 
SQLTableCheckOperator(BaseSQLOperator): template_fields: Sequence[str] = ("table", "partition_clause", "sql", *BaseSQLOperator.template_fields) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} sql_check_template = """ SELECT '{check_name}' AS check_name, MIN({check_name}) AS check_result @@ -776,7 +776,7 @@ class SQLCheckOperator(BaseSQLOperator): ".hql", ".sql", ) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#fff7e6" def __init__( @@ -822,7 +822,7 @@ class SQLValueCheckOperator(BaseSQLOperator): ".hql", ".sql", ) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#fff7e6" def __init__( @@ -919,7 +919,7 @@ class SQLIntervalCheckOperator(BaseSQLOperator): ".hql", ".sql", ) - template_fields_renderers = {"sql1": "sql", "sql2": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql1": "sql", "sql2": "sql"} ui_color = "#fff7e6" ratio_formulas = { @@ -1052,7 +1052,7 @@ class SQLThresholdCheckOperator(BaseSQLOperator): ".hql", ".sql", ) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} def __init__( self, @@ -1147,7 +1147,7 @@ class BranchSQLOperator(BaseSQLOperator, SkipMixin): template_fields: Sequence[str] = ("sql", *BaseSQLOperator.template_fields) template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#a22034" ui_fgcolor = "#F7F7F7" diff --git a/providers/src/airflow/providers/common/sql/operators/sql.pyi b/providers/src/airflow/providers/common/sql/operators/sql.pyi index 1b97cec5023c..6921e3411ea0 100644 --- a/providers/src/airflow/providers/common/sql/operators/sql.pyi +++ b/providers/src/airflow/providers/common/sql/operators/sql.pyi @@ -36,7 +36,7 @@ from airflow.models import BaseOperator as BaseOperator, SkipMixin as SkipMixin from airflow.providers.common.sql.hooks.sql import DbApiHook as DbApiHook from airflow.providers.openlineage.extractors import OperatorLineage as OperatorLineage from airflow.utils.context import Context as Context -from typing import Any, Callable, Iterable, Mapping, Sequence, SupportsAbs +from typing import Any, Callable, ClassVar, Iterable, Mapping, Sequence, SupportsAbs def parse_boolean(val: str) -> str | bool: ... @@ -62,7 +62,7 @@ class SQLExecuteQueryOperator(BaseSQLOperator): def _raise_exception(self, exception_string: str) -> Incomplete: ... 
template_fields: Sequence[str] template_ext: Sequence[str] - template_fields_renderers: Incomplete + template_fields_renderers: ClassVar[dict] ui_color: str sql: Incomplete autocommit: Incomplete @@ -92,7 +92,7 @@ class SQLExecuteQueryOperator(BaseSQLOperator): class SQLColumnCheckOperator(BaseSQLOperator): template_fields: Sequence[str] - template_fields_renderers: Incomplete + template_fields_renderers: ClassVar[dict] sql_check_template: str column_checks: Incomplete table: Incomplete @@ -115,7 +115,7 @@ class SQLColumnCheckOperator(BaseSQLOperator): class SQLTableCheckOperator(BaseSQLOperator): template_fields: Sequence[str] - template_fields_renderers: Incomplete + template_fields_renderers: ClassVar[dict] sql_check_template: str table: Incomplete checks: Incomplete @@ -136,7 +136,7 @@ class SQLTableCheckOperator(BaseSQLOperator): class SQLCheckOperator(BaseSQLOperator): template_fields: Sequence[str] template_ext: Sequence[str] - template_fields_renderers: Incomplete + template_fields_renderers: ClassVar[dict] ui_color: str sql: Incomplete parameters: Incomplete @@ -155,7 +155,7 @@ class SQLValueCheckOperator(BaseSQLOperator): __mapper_args__: Incomplete template_fields: Sequence[str] template_ext: Sequence[str] - template_fields_renderers: Incomplete + template_fields_renderers: ClassVar[dict] ui_color: str sql: Incomplete pass_value: Incomplete @@ -178,7 +178,7 @@ class SQLIntervalCheckOperator(BaseSQLOperator): __mapper_args__: Incomplete template_fields: Sequence[str] template_ext: Sequence[str] - template_fields_renderers: Incomplete + template_fields_renderers: ClassVar[dict] ui_color: str ratio_formulas: Incomplete ratio_formula: Incomplete @@ -208,7 +208,7 @@ class SQLIntervalCheckOperator(BaseSQLOperator): class SQLThresholdCheckOperator(BaseSQLOperator): template_fields: Sequence[str] template_ext: Sequence[str] - template_fields_renderers: Incomplete + template_fields_renderers: ClassVar[dict] sql: Incomplete min_threshold: Incomplete max_threshold: Incomplete @@ -228,7 +228,7 @@ class SQLThresholdCheckOperator(BaseSQLOperator): class BranchSQLOperator(BaseSQLOperator, SkipMixin): template_fields: Sequence[str] template_ext: Sequence[str] - template_fields_renderers: Incomplete + template_fields_renderers: ClassVar[dict] ui_color: str ui_fgcolor: str sql: Incomplete diff --git a/providers/src/airflow/providers/databricks/operators/databricks_sql.py b/providers/src/airflow/providers/databricks/operators/databricks_sql.py index 1975b4de2149..7e59fc2a9d59 100644 --- a/providers/src/airflow/providers/databricks/operators/databricks_sql.py +++ b/providers/src/airflow/providers/databricks/operators/databricks_sql.py @@ -21,7 +21,7 @@ import csv import json -from typing import TYPE_CHECKING, Any, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, Sequence from databricks.sql.utils import ParamEscaper @@ -72,7 +72,7 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator): ) template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} conn_id_field = "databricks_conn_id" def __init__( diff --git a/providers/src/airflow/providers/exasol/operators/exasol.py b/providers/src/airflow/providers/exasol/operators/exasol.py index 407fdf659166..51c0131fa5b5 100644 --- a/providers/src/airflow/providers/exasol/operators/exasol.py +++ b/providers/src/airflow/providers/exasol/operators/exasol.py @@ -17,7 +17,7 @@ # under the License. 
from __future__ import annotations -from typing import Sequence +from typing import ClassVar, Sequence from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator from airflow.providers.exasol.hooks.exasol import exasol_fetch_all_handler @@ -40,7 +40,7 @@ class ExasolOperator(SQLExecuteQueryOperator): template_fields: Sequence[str] = ("sql", "exasol_conn_id") template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#ededed" conn_id_field = "exasol_conn_id" diff --git a/providers/src/airflow/providers/fab/auth_manager/security_manager/override.py b/providers/src/airflow/providers/fab/auth_manager/security_manager/override.py index 97e13398ffb6..9c49845b4206 100644 --- a/providers/src/airflow/providers/fab/auth_manager/security_manager/override.py +++ b/providers/src/airflow/providers/fab/auth_manager/security_manager/override.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import copy import datetime import itertools import logging @@ -24,7 +25,7 @@ import random import uuid import warnings -from typing import TYPE_CHECKING, Any, Callable, Collection, Container, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Collection, Container, Iterable, Mapping, Sequence import jwt import packaging.version @@ -1107,7 +1108,7 @@ def is_dag_resource(self, resource_name: str) -> bool: def sync_perm_for_dag( self, dag_id: str, - access_control: dict[str, dict[str, Collection[str]] | Collection[str]] | None = None, + access_control: Mapping[str, Mapping[str, Collection[str]] | Collection[str]] | None = None, ) -> None: """ Sync permissions for given dag id. @@ -1128,7 +1129,7 @@ def sync_perm_for_dag( if access_control is not None: self.log.debug("Syncing DAG-level permissions for DAG '%s'", dag_id) - self._sync_dag_view_permissions(dag_id, access_control.copy()) + self._sync_dag_view_permissions(dag_id, copy.copy(access_control)) else: self.log.debug( "Not syncing DAG-level permissions for DAG '%s' as access control is unset.", @@ -1149,7 +1150,7 @@ def _resource_name(self, dag_id: str, resource_name: str) -> str: def _sync_dag_view_permissions( self, dag_id: str, - access_control: dict[str, dict[str, Collection[str]] | Collection[str]], + access_control: Mapping[str, Mapping[str, Collection[str]] | Collection[str]], ) -> None: """ Set the access policy on the given DAG's ViewModel. diff --git a/providers/src/airflow/providers/jdbc/operators/jdbc.py b/providers/src/airflow/providers/jdbc/operators/jdbc.py index b889eb645182..3357b569c8fb 100644 --- a/providers/src/airflow/providers/jdbc/operators/jdbc.py +++ b/providers/src/airflow/providers/jdbc/operators/jdbc.py @@ -17,7 +17,7 @@ # under the License. 
from __future__ import annotations -from typing import Sequence +from typing import ClassVar, Sequence from deprecated import deprecated @@ -54,7 +54,7 @@ class JdbcOperator(SQLExecuteQueryOperator): template_fields: Sequence[str] = ("sql",) template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#ededed" def __init__(self, *, jdbc_conn_id: str = "jdbc_default", **kwargs) -> None: diff --git a/providers/src/airflow/providers/microsoft/mssql/operators/mssql.py b/providers/src/airflow/providers/microsoft/mssql/operators/mssql.py index 5c24831ef1d2..e21c78c3bf6a 100644 --- a/providers/src/airflow/providers/microsoft/mssql/operators/mssql.py +++ b/providers/src/airflow/providers/microsoft/mssql/operators/mssql.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import Sequence +from typing import ClassVar, Sequence from deprecated import deprecated @@ -59,7 +59,7 @@ class MsSqlOperator(SQLExecuteQueryOperator): template_fields: Sequence[str] = ("sql",) template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "tsql"} + template_fields_renderers: ClassVar[dict] = {"sql": "tsql"} ui_color = "#ededed" def __init__( diff --git a/providers/src/airflow/providers/mysql/operators/mysql.py b/providers/src/airflow/providers/mysql/operators/mysql.py index 2c2436b4d9da..7a47dd0a68cf 100644 --- a/providers/src/airflow/providers/mysql/operators/mysql.py +++ b/providers/src/airflow/providers/mysql/operators/mysql.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import Sequence +from typing import ClassVar, Sequence from deprecated import deprecated @@ -58,7 +58,7 @@ class MySqlOperator(SQLExecuteQueryOperator): """ template_fields: Sequence[str] = ("sql", "parameters") - template_fields_renderers = { + template_fields_renderers: ClassVar[dict] = { "sql": "mysql", "parameters": "json", } diff --git a/providers/src/airflow/providers/oracle/operators/oracle.py b/providers/src/airflow/providers/oracle/operators/oracle.py index 0debfed2c6b5..5770271d6360 100644 --- a/providers/src/airflow/providers/oracle/operators/oracle.py +++ b/providers/src/airflow/providers/oracle/operators/oracle.py @@ -18,7 +18,7 @@ from __future__ import annotations import re -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, ClassVar, Sequence import oracledb from deprecated import deprecated @@ -60,7 +60,7 @@ class OracleOperator(SQLExecuteQueryOperator): "sql", ) template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#ededed" def __init__(self, *, oracle_conn_id: str = "oracle_default", **kwargs) -> None: diff --git a/providers/src/airflow/providers/postgres/operators/postgres.py b/providers/src/airflow/providers/postgres/operators/postgres.py index 424a86b66690..f936d4626978 100644 --- a/providers/src/airflow/providers/postgres/operators/postgres.py +++ b/providers/src/airflow/providers/postgres/operators/postgres.py @@ -18,7 +18,7 @@ from __future__ import annotations import warnings -from typing import Mapping +from typing import ClassVar, Mapping from deprecated import deprecated @@ -55,7 +55,10 @@ class PostgresOperator(SQLExecuteQueryOperator): Deprecated - use `hook_params={'options': '-c '}` instead. 
""" - template_fields_renderers = {**SQLExecuteQueryOperator.template_fields_renderers, "sql": "postgresql"} + template_fields_renderers: ClassVar[dict] = { + **SQLExecuteQueryOperator.template_fields_renderers, + "sql": "postgresql", + } ui_color = "#ededed" def __init__( diff --git a/providers/src/airflow/providers/snowflake/operators/snowflake.py b/providers/src/airflow/providers/snowflake/operators/snowflake.py index 89af8fbb6fdc..c7c19d740b1f 100644 --- a/providers/src/airflow/providers/snowflake/operators/snowflake.py +++ b/providers/src/airflow/providers/snowflake/operators/snowflake.py @@ -19,7 +19,7 @@ import time from datetime import timedelta -from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Sequence, SupportsAbs, cast +from typing import TYPE_CHECKING, Any, ClassVar, Iterable, List, Mapping, Sequence, SupportsAbs, cast from deprecated import deprecated @@ -88,7 +88,7 @@ class SnowflakeOperator(SQLExecuteQueryOperator): template_fields: Sequence[str] = ("sql",) template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#ededed" def __init__( diff --git a/providers/src/airflow/providers/sqlite/operators/sqlite.py b/providers/src/airflow/providers/sqlite/operators/sqlite.py index 38c085178f2b..1e38696263f9 100644 --- a/providers/src/airflow/providers/sqlite/operators/sqlite.py +++ b/providers/src/airflow/providers/sqlite/operators/sqlite.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import Sequence +from typing import ClassVar, Sequence from deprecated import deprecated @@ -51,7 +51,7 @@ class SqliteOperator(SQLExecuteQueryOperator): template_fields: Sequence[str] = ("sql",) template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#cdaaed" def __init__(self, *, sqlite_conn_id: str = "sqlite_default", **kwargs) -> None: diff --git a/providers/src/airflow/providers/teradata/operators/teradata.py b/providers/src/airflow/providers/teradata/operators/teradata.py index c15fc290385d..edb1331c6122 100644 --- a/providers/src/airflow/providers/teradata/operators/teradata.py +++ b/providers/src/airflow/providers/teradata/operators/teradata.py @@ -17,7 +17,7 @@ # under the License. 
from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, ClassVar, Sequence from airflow.models import BaseOperator from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator @@ -49,7 +49,7 @@ class TeradataOperator(SQLExecuteQueryOperator): "parameters", ) template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#e07c24" def __init__( diff --git a/providers/src/airflow/providers/trino/operators/trino.py b/providers/src/airflow/providers/trino/operators/trino.py index 76856728a483..9ff9768d745d 100644 --- a/providers/src/airflow/providers/trino/operators/trino.py +++ b/providers/src/airflow/providers/trino/operators/trino.py @@ -19,7 +19,7 @@ from __future__ import annotations -from typing import Any, Sequence +from typing import Any, ClassVar, Sequence from deprecated import deprecated from trino.exceptions import TrinoQueryError @@ -56,7 +56,7 @@ class TrinoOperator(SQLExecuteQueryOperator): """ template_fields: Sequence[str] = ("sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} template_ext: Sequence[str] = (".sql",) ui_color = "#ededed" diff --git a/providers/src/airflow/providers/vertica/operators/vertica.py b/providers/src/airflow/providers/vertica/operators/vertica.py index 6373dfdf8b49..03cf14f000e3 100644 --- a/providers/src/airflow/providers/vertica/operators/vertica.py +++ b/providers/src/airflow/providers/vertica/operators/vertica.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import Any, Sequence +from typing import Any, ClassVar, Sequence from deprecated import deprecated @@ -45,7 +45,7 @@ class VerticaOperator(SQLExecuteQueryOperator): template_fields: Sequence[str] = ("sql",) template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#b4e0ff" def __init__(self, *, vertica_conn_id: str = "vertica_default", **kwargs: Any) -> None: diff --git a/providers/tests/amazon/aws/operators/test_batch.py b/providers/tests/amazon/aws/operators/test_batch.py index 1389099e4445..0c14c256edba 100644 --- a/providers/tests/amazon/aws/operators/test_batch.py +++ b/providers/tests/amazon/aws/operators/test_batch.py @@ -441,7 +441,7 @@ def test_override_not_sent_if_not_set(self, client_mock, override): client_mock().submit_job.assert_called_once_with(**expected_args) def test_cant_set_old_and_new_override_param(self): - with pytest.raises(AirflowException): + with pytest.raises((TypeError, AirflowException), match="Invalid arguments were passed"): _ = BatchOperator( task_id="task", job_name=JOB_NAME, diff --git a/providers/tests/google/cloud/operators/test_bigquery.py b/providers/tests/google/cloud/operators/test_bigquery.py index a3beddb9d873..4269d02377b2 100644 --- a/providers/tests/google/cloud/operators/test_bigquery.py +++ b/providers/tests/google/cloud/operators/test_bigquery.py @@ -2578,7 +2578,7 @@ def test_bigquery_value_check_missing_param(self, kwargs, expected): """ Assert the exception if require param not pass to BigQueryValueCheckOperator with deferrable=True """ - with pytest.raises(AirflowException) as missing_param: + with pytest.raises((TypeError, AirflowException)) as missing_param: BigQueryValueCheckOperator(deferrable=True, **kwargs) assert missing_param.value.args[0] == expected @@ -2590,7 
+2590,7 @@ def test_bigquery_value_check_empty(self): "missing keyword arguments 'sql', 'pass_value'", "missing keyword arguments 'pass_value', 'sql'", ) - with pytest.raises(AirflowException) as missing_param: + with pytest.raises((TypeError, AirflowException)) as missing_param: BigQueryValueCheckOperator(deferrable=True, kwargs={}) assert missing_param.value.args[0] in (expected, expected1) diff --git a/providers/tests/google/cloud/operators/test_cloud_build.py b/providers/tests/google/cloud/operators/test_cloud_build.py index 3bcc8ac66aa1..958190c942ae 100644 --- a/providers/tests/google/cloud/operators/test_cloud_build.py +++ b/providers/tests/google/cloud/operators/test_cloud_build.py @@ -134,7 +134,7 @@ def test_create_build(self, mock_hook): @mock.patch(CLOUD_BUILD_HOOK_PATH) def test_create_build_with_missing_build(self, mock_hook): mock_hook.return_value.create_build_without_waiting_for_result.return_value = Build() - with pytest.raises(AirflowException, match="missing keyword argument 'build'"): + with pytest.raises((TypeError, AirflowException), match="missing keyword argument 'build'"): CloudBuildCreateBuildOperator(task_id="id") @pytest.mark.parametrize( @@ -479,7 +479,7 @@ def test_async_create_build_error_event_should_throw_exception(): @mock.patch(CLOUD_BUILD_HOOK_PATH) def test_async_create_build_with_missing_build_should_throw_exception(mock_hook): mock_hook.return_value.create_build.return_value = Build() - with pytest.raises(AirflowException, match="missing keyword argument 'build'"): + with pytest.raises((TypeError, AirflowException), match="missing keyword argument 'build'"): CloudBuildCreateBuildOperator(task_id="id") diff --git a/providers/tests/google/cloud/operators/test_compute.py b/providers/tests/google/cloud/operators/test_compute.py index fac74bae4837..7913143618cb 100644 --- a/providers/tests/google/cloud/operators/test_compute.py +++ b/providers/tests/google/cloud/operators/test_compute.py @@ -349,7 +349,9 @@ def test_insert_instance_from_template_should_throw_ex_when_missing_zone(self): ) def test_insert_instance_from_template_should_throw_ex_when_missing_source_instance_template(self): - with pytest.raises(AirflowException, match=r"missing keyword argument 'source_instance_template'"): + with pytest.raises( + (TypeError, AirflowException), match=r"missing keyword argument 'source_instance_template'" + ): ComputeEngineInsertInstanceFromTemplateOperator( project_id=GCP_PROJECT_ID, body=GCP_INSTANCE_BODY_FROM_TEMPLATE, @@ -360,7 +362,7 @@ def test_insert_instance_from_template_should_throw_ex_when_missing_source_insta ) def test_insert_instance_from_template_should_throw_ex_when_missing_body(self): - with pytest.raises(AirflowException, match=r"missing keyword argument 'body'"): + with pytest.raises((TypeError, AirflowException), match=r"missing keyword argument 'body'"): ComputeEngineInsertInstanceFromTemplateOperator( project_id=GCP_PROJECT_ID, source_instance_template=SOURCE_INSTANCE_TEMPLATE, @@ -910,7 +912,7 @@ def test_insert_template_should_not_throw_ex_when_project_id_none(self, mock_hoo ) def test_insert_template_should_throw_ex_when_missing_body(self): - with pytest.raises(AirflowException, match=r"missing keyword argument 'body'"): + with pytest.raises((TypeError, AirflowException), match=r"missing keyword argument 'body'"): ComputeEngineInsertInstanceTemplateOperator( task_id=TASK_ID, project_id=GCP_PROJECT_ID, @@ -1552,7 +1554,7 @@ def test_insert_igm_should_not_throw_ex_when_project_id_none(self, mock_hook): ) def 
test_insert_igm_should_throw_ex_when_missing_body(self): - with pytest.raises(AirflowException, match=r"missing keyword argument 'body'"): + with pytest.raises((TypeError, AirflowException), match=r"missing keyword argument 'body'"): ComputeEngineInsertInstanceGroupManagerOperator( zone=GCE_ZONE, task_id=TASK_ID, diff --git a/providers/tests/google/cloud/operators/test_dataflow.py b/providers/tests/google/cloud/operators/test_dataflow.py index 4263d3300f11..bc23d84d2285 100644 --- a/providers/tests/google/cloud/operators/test_dataflow.py +++ b/providers/tests/google/cloud/operators/test_dataflow.py @@ -1077,7 +1077,7 @@ def test_invalid_response(self): "location": TEST_LOCATION, "gcp_conn_id": GCP_CONN_ID, } - with pytest.raises(AirflowException): + with pytest.raises((TypeError, AirflowException), match="missing keyword argument"): DataflowRunPipelineOperator(**init_kwargs).execute(mock.MagicMock()).return_value = { "error": {"message": "example error"} } diff --git a/providers/tests/google/cloud/operators/test_dataproc.py b/providers/tests/google/cloud/operators/test_dataproc.py index babe432a2fe1..2ec6ceb9babb 100644 --- a/providers/tests/google/cloud/operators/test_dataproc.py +++ b/providers/tests/google/cloud/operators/test_dataproc.py @@ -1577,7 +1577,7 @@ def test_on_kill_after_execution_timeout(self, mock_hook): ) def test_missing_region_parameter(self): - with pytest.raises(AirflowException): + with pytest.raises((TypeError, AirflowException), match="missing keyword argument 'region'"): DataprocSubmitJobOperator( task_id=TASK_ID, project_id=GCP_PROJECT, @@ -1692,7 +1692,7 @@ def test_execute(self, mock_hook): ) def test_missing_region_parameter(self): - with pytest.raises(AirflowException): + with pytest.raises((TypeError, AirflowException), match="missing keyword argument 'region'"): DataprocUpdateClusterOperator( task_id=TASK_ID, cluster_name=CLUSTER_NAME, @@ -2678,7 +2678,7 @@ def test_execute(self, mock_hook): ) def test_missing_region_parameter(self): - with pytest.raises(AirflowException): + with pytest.raises((TypeError, AirflowException), match="missing keyword argument 'region'"): DataprocCreateWorkflowTemplateOperator( task_id=TASK_ID, gcp_conn_id=GCP_CONN_ID, diff --git a/providers/tests/google/cloud/operators/test_kubernetes_engine.py b/providers/tests/google/cloud/operators/test_kubernetes_engine.py index f0f42745c1c4..3127b5d89ca9 100644 --- a/providers/tests/google/cloud/operators/test_kubernetes_engine.py +++ b/providers/tests/google/cloud/operators/test_kubernetes_engine.py @@ -227,7 +227,7 @@ def test_create_execute_error_project_id(self, mock_hook): @mock.patch(GKE_HOOK_PATH) def test_create_execute_error_location(self, mock_hook): - with pytest.raises(AirflowException): + with pytest.raises((TypeError, AirflowException), match="missing keyword argument 'location'"): GKECreateClusterOperator( project_id=TEST_GCP_PROJECT_ID, body=PROJECT_BODY, task_id=PROJECT_TASK_ID ) @@ -270,14 +270,14 @@ def test_delete_execute_error_project_id(self, mock_hook): @mock.patch(GKE_HOOK_PATH) def test_delete_execute_error_cluster_name(self, mock_hook): - with pytest.raises(AirflowException): + with pytest.raises((TypeError, AirflowException), match="missing keyword argument 'name'"): GKEDeleteClusterOperator( project_id=TEST_GCP_PROJECT_ID, location=PROJECT_LOCATION, task_id=PROJECT_TASK_ID ) @mock.patch(GKE_HOOK_PATH) def test_delete_execute_error_location(self, mock_hook): - with pytest.raises(AirflowException): + with pytest.raises((TypeError, AirflowException), match="missing 
keyword argument 'location'"): GKEDeleteClusterOperator( project_id=TEST_GCP_PROJECT_ID, name=CLUSTER_NAME, task_id=PROJECT_TASK_ID ) @@ -1270,7 +1270,7 @@ def test_execute(self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock) fetch_cluster_info_mock.assert_called_once() def test_config_file_throws_error(self): - with pytest.raises(AirflowException): + with pytest.raises((TypeError, AirflowException), match="missing keyword argument 'queue_name'"): GKEStartKueueJobOperator( project_id=TEST_GCP_PROJECT_ID, location=PROJECT_LOCATION, @@ -1478,7 +1478,9 @@ def setup_method(self): ) def test_config_file_throws_error(self): - with pytest.raises(AirflowException): + with pytest.raises( + (TypeError, AirflowException), match="Invalid arguments were passed to .*\n.*'config_file'" + ): GKESuspendJobOperator( project_id=TEST_GCP_PROJECT_ID, location=PROJECT_LOCATION, @@ -1586,7 +1588,9 @@ def setup_method(self): ) def test_config_file_throws_error(self): - with pytest.raises(AirflowException): + with pytest.raises( + (TypeError, AirflowException), match="Invalid arguments were passed to .*\n.*'config_file'" + ): GKEResumeJobOperator( project_id=TEST_GCP_PROJECT_ID, location=PROJECT_LOCATION, diff --git a/providers/tests/google/cloud/operators/test_speech_to_text.py b/providers/tests/google/cloud/operators/test_speech_to_text.py index 1d7fa9ca37fe..155658976f3f 100644 --- a/providers/tests/google/cloud/operators/test_speech_to_text.py +++ b/providers/tests/google/cloud/operators/test_speech_to_text.py @@ -59,26 +59,20 @@ def test_recognize_speech_green_path(self, mock_hook): def test_missing_config(self, mock_hook): mock_hook.return_value.recognize_speech.return_value = True - with pytest.raises(AirflowException) as ctx: + with pytest.raises((TypeError, AirflowException), match="missing keyword argument 'config'"): CloudSpeechToTextRecognizeSpeechOperator( project_id=PROJECT_ID, gcp_conn_id=GCP_CONN_ID, audio=AUDIO, task_id="id" ).execute(context={"task_instance": Mock()}) - - err = ctx.value - assert "config" in str(err) mock_hook.assert_not_called() @patch("airflow.providers.google.cloud.operators.speech_to_text.CloudSpeechToTextHook") def test_missing_audio(self, mock_hook): mock_hook.return_value.recognize_speech.return_value = True - with pytest.raises(AirflowException) as ctx: + with pytest.raises((TypeError, AirflowException), match="missing keyword argument 'audio'"): CloudSpeechToTextRecognizeSpeechOperator( project_id=PROJECT_ID, gcp_conn_id=GCP_CONN_ID, config=CONFIG, task_id="id" ).execute(context={"task_instance": Mock()}) - - err = ctx.value - assert "audio" in str(err) mock_hook.assert_not_called() @patch("airflow.providers.google.cloud.operators.speech_to_text.FileDetailsLink.persist") diff --git a/providers/tests/google/cloud/sensors/test_dataproc.py b/providers/tests/google/cloud/sensors/test_dataproc.py index 669a9a09f2a9..5fea3e8d3bad 100644 --- a/providers/tests/google/cloud/sensors/test_dataproc.py +++ b/providers/tests/google/cloud/sensors/test_dataproc.py @@ -132,7 +132,7 @@ def test_cancelled(self, mock_hook): @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_missing_region(self, mock_hook): - with pytest.raises(AirflowException): + with pytest.raises((TypeError, AirflowException), match="missing keyword argument 'region'"): DataprocJobSensor( task_id=TASK_ID, project_id=GCP_PROJECT, diff --git a/providers/tests/google/cloud/transfers/test_gcs_to_bigquery.py b/providers/tests/google/cloud/transfers/test_gcs_to_bigquery.py index 05ef254cb43d..d8a6c550c51e 
100644 --- a/providers/tests/google/cloud/transfers/test_gcs_to_bigquery.py +++ b/providers/tests/google/cloud/transfers/test_gcs_to_bigquery.py @@ -906,7 +906,7 @@ def test_execute_should_throw_ex_when_no_bucket_specified(self, hook): ] hook.return_value.generate_job_id.return_value = REAL_JOB_ID hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE) - with pytest.raises(AirflowException, match=r"missing keyword argument 'bucket'"): + with pytest.raises((TypeError, AirflowException), match=r"missing keyword argument 'bucket'"): GCSToBigQueryOperator( task_id=TASK_ID, source_objects=TEST_SOURCE_OBJECTS, @@ -926,7 +926,7 @@ def test_execute_should_throw_ex_when_no_source_objects_specified(self, hook): ] hook.return_value.generate_job_id.return_value = REAL_JOB_ID hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE) - with pytest.raises(AirflowException, match=r"missing keyword argument 'source_objects'"): + with pytest.raises((TypeError, AirflowException), match=r"missing keyword argument 'source_objects'"): GCSToBigQueryOperator( task_id=TASK_ID, destination_project_dataset_table=TEST_EXPLICIT_DEST, @@ -947,7 +947,8 @@ def test_execute_should_throw_ex_when_no_destination_project_dataset_table_speci hook.return_value.generate_job_id.return_value = REAL_JOB_ID hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE) with pytest.raises( - AirflowException, match=r"missing keyword argument 'destination_project_dataset_table'" + (TypeError, AirflowException), + match=r"missing keyword argument 'destination_project_dataset_table'", ): GCSToBigQueryOperator( task_id=TASK_ID, diff --git a/providers/tests/salesforce/operators/test_bulk.py b/providers/tests/salesforce/operators/test_bulk.py index a28cafc73a13..3ec701ecf41d 100644 --- a/providers/tests/salesforce/operators/test_bulk.py +++ b/providers/tests/salesforce/operators/test_bulk.py @@ -33,7 +33,7 @@ def test_execute_missing_operation(self): """ Test execute missing operation """ - with pytest.raises(AirflowException): + with pytest.raises((TypeError, AirflowException), match="missing keyword argument 'operation'"): SalesforceBulkOperator( task_id="no_missing_operation_arg", object_name="Account", @@ -52,7 +52,7 @@ def test_execute_missing_object_name(self): """ Test execute missing object_name """ - with pytest.raises(AirflowException): + with pytest.raises((TypeError, AirflowException), match="missing keyword argument 'object_name'"): SalesforceBulkOperator( task_id="no_object_name_arg", operation="insert", diff --git a/providers/tests/standard/operators/test_weekday.py b/providers/tests/standard/operators/test_weekday.py index 87ceed00ef62..127d735c3da0 100644 --- a/providers/tests/standard/operators/test_weekday.py +++ b/providers/tests/standard/operators/test_weekday.py @@ -208,7 +208,7 @@ def test_branch_follow_false(self, dag_maker): def test_branch_with_no_weekday(self, dag_maker): """Check if BranchDayOfWeekOperator raises exception on missing weekday""" - with pytest.raises(AirflowException): + with pytest.raises((TypeError, AirflowException), match="missing keyword argument 'week_day'"): with dag_maker( "branch_day_of_week_operator_test", start_date=DEFAULT_DATE, diff --git a/pyproject.toml b/pyproject.toml index b97d86540b8b..1ed9136da3f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -470,6 +470,7 @@ fixture-parentheses = false ## pytest settings ## [tool.pytest.ini_options] addopts = [ + "--tb=short", "-rasl", "--verbosity=2", # Disable `flaky` plugin for 
pytest. This plugin conflicts with `rerunfailures` because provide the same marker. diff --git a/scripts/ci/pre_commit/base_operator_partial_arguments.py b/scripts/ci/pre_commit/base_operator_partial_arguments.py index 14999e034edb..b50705331700 100755 --- a/scripts/ci/pre_commit/base_operator_partial_arguments.py +++ b/scripts/ci/pre_commit/base_operator_partial_arguments.py @@ -27,6 +27,7 @@ ROOT_DIR = pathlib.Path(__file__).resolve().parents[3] BASEOPERATOR_PY = ROOT_DIR.joinpath("airflow", "models", "baseoperator.py") +SDK_BASEOPERATOR_PY = ROOT_DIR.joinpath("task_sdk", "src", "airflow", "sdk", "definitions", "baseoperator.py") MAPPEDOPERATOR_PY = ROOT_DIR.joinpath("airflow", "models", "mappedoperator.py") IGNORED = { @@ -51,12 +52,31 @@ # Only on MappedOperator. "expand_input", "partial_kwargs", + "operator_class", + # Task-SDK migration ones. + "deps", + "downstream_task_ids", + "on_execute_callback", + "on_failure_callback", + "on_retry_callback", + "on_skipped_callback", + "on_success_callback", + "operator_extra_links", + "start_from_trigger", + "start_trigger_args", + "upstream_task_ids", + "logger_name", + "sla", } BO_MOD = ast.parse(BASEOPERATOR_PY.read_text("utf-8"), str(BASEOPERATOR_PY)) +SDK_BO_MOD = ast.parse(SDK_BASEOPERATOR_PY.read_text("utf-8"), str(SDK_BASEOPERATOR_PY)) MO_MOD = ast.parse(MAPPEDOPERATOR_PY.read_text("utf-8"), str(MAPPEDOPERATOR_PY)) +# TODO: Task-SDK: Look at the BaseOperator init functions in both airflow.models.baseoperator and combine +# them, until we fully remove BaseOperator class from core. + BO_CLS = next( node for node in ast.iter_child_nodes(BO_MOD) @@ -67,9 +87,27 @@ for node in ast.iter_child_nodes(BO_CLS) if isinstance(node, ast.FunctionDef) and node.name == "__init__" ) -BO_PARTIAL = next( + +SDK_BO_CLS = next( + node + for node in ast.iter_child_nodes(SDK_BO_MOD) + if isinstance(node, ast.ClassDef) and node.name == "BaseOperator" +) +SDK_BO_INIT = next( + node + for node in ast.iter_child_nodes(SDK_BO_CLS) + if isinstance(node, ast.FunctionDef) and node.name == "__init__" +) + +# We now define the signature in a type checking block, the runtime impl uses **kwargs +BO_TYPE_CHECKING_BLOCKS = ( node for node in ast.iter_child_nodes(BO_MOD) + if isinstance(node, ast.If) and node.test.id == "TYPE_CHECKING" # type: ignore[attr-defined] +) +BO_PARTIAL = next( + node + for node in itertools.chain.from_iterable(map(ast.iter_child_nodes, BO_TYPE_CHECKING_BLOCKS)) if isinstance(node, ast.FunctionDef) and node.name == "partial" ) MO_CLS = next( @@ -79,23 +117,27 @@ ) -def _compare(a: set[str], b: set[str], *, excludes: set[str]) -> tuple[set[str], set[str]]: - only_in_a = {n for n in a if n not in b and n not in excludes and n[0] != "_"} - only_in_b = {n for n in b if n not in a and n not in excludes and n[0] != "_"} +def _compare(a: set[str], b: set[str]) -> tuple[set[str], set[str]]: + only_in_a = a - b - IGNORED + only_in_b = b - a - IGNORED return only_in_a, only_in_b -def _iter_arg_names(func: ast.FunctionDef) -> typing.Iterator[str]: - func_args = func.args - for arg in itertools.chain(func_args.args, getattr(func_args, "posonlyargs", ()), func_args.kwonlyargs): - yield arg.arg +def _iter_arg_names(*funcs: ast.FunctionDef) -> typing.Iterator[str]: + for func in funcs: + func_args = func.args + for arg in itertools.chain( + func_args.args, getattr(func_args, "posonlyargs", ()), func_args.kwonlyargs + ): + if arg.arg == "self" or arg.arg.startswith("_"): + continue + yield arg.arg def check_baseoperator_partial_arguments() -> bool: only_in_init, 
only_in_partial = _compare( - set(itertools.islice(_iter_arg_names(BO_INIT), 1, None)), - set(itertools.islice(_iter_arg_names(BO_PARTIAL), 1, None)), - excludes=IGNORED, + set(_iter_arg_names(SDK_BO_INIT, BO_INIT)), + set(_iter_arg_names(BO_PARTIAL)), ) if only_in_init: print("Arguments in BaseOperator missing from partial():", ", ".join(sorted(only_in_init))) @@ -109,6 +151,8 @@ def check_baseoperator_partial_arguments() -> bool: def _iter_assignment_to_self_attributes(targets: typing.Iterable[ast.expr]) -> typing.Iterator[str]: for t in targets: if isinstance(t, ast.Attribute) and isinstance(t.value, ast.Name) and t.value.id == "self": + if t.attr.startswith("_"): + continue yield t.attr # Something like "self.foo = ...". else: # Recursively visit nodes in unpacking assignments like "a, b = ...". @@ -132,20 +176,24 @@ def _is_property(f: ast.FunctionDef) -> bool: def _iter_member_names(klass: ast.ClassDef) -> typing.Iterator[str]: for node in ast.iter_child_nodes(klass): + name = "" if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name): - yield node.target.id + name = node.target.id elif isinstance(node, ast.FunctionDef) and _is_property(node): - yield node.name + name = node.name elif isinstance(node, ast.Assign): if len(node.targets) == 1 and isinstance(target := node.targets[0], ast.Name): - yield target.id + name = target.id + else: + continue + if not name.startswith("_"): + yield name def check_operator_member_parity() -> bool: only_in_base, only_in_mapped = _compare( - set(itertools.chain(_iter_assignment_targets(BO_INIT), _iter_member_names(BO_CLS))), + set(itertools.chain(_iter_assignment_targets(SDK_BO_INIT), _iter_member_names(SDK_BO_CLS))), set(_iter_member_names(MO_CLS)), - excludes=IGNORED, ) if only_in_base: print("Members on BaseOperator missing from MappedOperator:", ", ".join(sorted(only_in_base))) diff --git a/scripts/ci/pre_commit/sync_init_decorator.py b/scripts/ci/pre_commit/sync_init_decorator.py index fc99fc48d7b3..7b02136ead31 100755 --- a/scripts/ci/pre_commit/sync_init_decorator.py +++ b/scripts/ci/pre_commit/sync_init_decorator.py @@ -26,38 +26,71 @@ import sys from typing import TYPE_CHECKING -PACKAGE_ROOT = pathlib.Path(__file__).resolve().parents[3].joinpath("airflow") -DAG_PY = PACKAGE_ROOT.joinpath("models", "dag.py") -UTILS_TG_PY = PACKAGE_ROOT.joinpath("utils", "task_group.py") +ROOT = pathlib.Path(__file__).resolve().parents[3] +PACKAGE_ROOT = ROOT.joinpath("airflow") +SDK_DEFINITIONS_PKG = ROOT.joinpath("task_sdk", "src", "airflow", "sdk", "definitions") +DAG_PY = SDK_DEFINITIONS_PKG.joinpath("dag.py") +TG_PY = SDK_DEFINITIONS_PKG.joinpath("taskgroup.py") DECOS_TG_PY = PACKAGE_ROOT.joinpath("decorators", "task_group.py") -def _find_dag_init(mod: ast.Module) -> ast.FunctionDef: - """Find definition of the ``DAG`` class's ``__init__``.""" - dag_class = next(n for n in ast.iter_child_nodes(mod) if isinstance(n, ast.ClassDef) and n.name == "DAG") - return next( - node - for node in ast.iter_child_nodes(dag_class) - if isinstance(node, ast.FunctionDef) and node.name == "__init__" +def _name(node: ast.expr) -> str: + if not isinstance(node, ast.Name): + raise TypeError("node was not an ast.Name node") + return node.id + + +def _find_cls_attrs( + mod: ast.Module, class_name: str, ignore: list[str] | None = None +) -> collections.abc.Iterable[ast.AnnAssign]: + """Find the type-annotated/attrs properties in the body of the specified class.""" + dag_class = next( + n for n in ast.iter_child_nodes(mod) if isinstance(n, ast.ClassDef) and 
n.name == class_name ) + ignore = ignore or [] + + for node in ast.iter_child_nodes(dag_class): + if not isinstance(node, ast.AnnAssign) or not node.annotation: + continue + + # ClassVar[Any] + if isinstance(node.annotation, ast.Subscript) and _name(node.annotation.value) == "ClassVar": + continue + + # Skip private attrs fields, ones with `attrs.field(init=False)` kwargs + if isinstance(node.value, ast.Call): + # Lazy coding: since init=True is the default, we're just looking for the presence of the init + # arg name + if TYPE_CHECKING: + assert isinstance(node.value.func, ast.Attribute) + if ( + node.value.func.attr == "field" + and _name(node.value.func.value) == "attrs" + and any(arg.arg == "init" for arg in node.value.keywords) + ): + continue + if _name(node.target) in ignore: + continue + + # Attrs treats `_group_id: str` as `group_id` arg to __init__ + if _name(node.target).startswith("_"): + node.target.id = node.target.id[1:] # type: ignore[union-attr] + yield node + def _find_dag_deco(mod: ast.Module) -> ast.FunctionDef: """Find definition of the ``@dag`` decorator.""" - return next(n for n in ast.iter_child_nodes(mod) if isinstance(n, ast.FunctionDef) and n.name == "dag") - - -def _find_tg_init(mod: ast.Module) -> ast.FunctionDef: - """Find definition of the ``TaskGroup`` class's ``__init__``.""" - task_group_class = next( + # We now define the signature in a type checking block, the runtime impl uses **kwargs + type_checking_blocks = ( node for node in ast.iter_child_nodes(mod) - if isinstance(node, ast.ClassDef) and node.name == "TaskGroup" + if isinstance(node, ast.If) and node.test.id == "TYPE_CHECKING" # type: ignore[attr-defined] ) return next( - node - for node in ast.iter_child_nodes(task_group_class) - if isinstance(node, ast.FunctionDef) and node.name == "__init__" + n + for n in itertools.chain.from_iterable(map(ast.iter_child_nodes, type_checking_blocks)) + if isinstance(n, ast.FunctionDef) and n.name == "dag" ) @@ -74,43 +107,78 @@ def _find_tg_deco(mod: ast.Module) -> ast.FunctionDef: ) + +# Hard-code some specific examples of allowable decorator type annotation -> class type annotation mappings +# where they don't match exactly + + +def _expr_to_ast_dump(expr: str) -> str: + return ast.dump(ast.parse(expr).body[0].value) # type: ignore[attr-defined] + + +ALLOWABLE_TYPE_ANNOTATIONS = { + # Mapping of allowable Decorator type -> Class attribute type + _expr_to_ast_dump("Collection[str] | None"): _expr_to_ast_dump("MutableSet[str]"), + _expr_to_ast_dump("ParamsDict | dict[str, Any] | None"): _expr_to_ast_dump("ParamsDict"), + # TODO: This one is legacy access control. Remove it in 3.0.
RemovedInAirflow3Warning + _expr_to_ast_dump( + "dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | None" + ): _expr_to_ast_dump("dict[str, dict[str, Collection[str]]] | None"), +} + + def _match_arguments( - init_def: tuple[str, list[ast.arg]], + init_def: tuple[str, list[ast.AnnAssign]], deco_def: tuple[str, list[ast.arg]], ) -> collections.abc.Iterator[str]: init_name, init_args = init_def deco_name, deco_args = deco_def + init_args.sort(key=lambda a: _name(a.target)) + deco_args.sort(key=lambda a: a.arg) for i, (ini, dec) in enumerate(itertools.zip_longest(init_args, deco_args, fillvalue=None)): if ini is None and dec is not None: yield f"Argument present in @{deco_name} but missing from {init_name}: {dec.arg}" return if dec is None and ini is not None: - yield f"Argument present in {init_name} but missing from @{deco_name}: {ini.arg}" + yield f"Argument present in {init_name} but missing from @{deco_name}: {_name(ini.target)}" return + if TYPE_CHECKING: + # Mypy can't work out that zip_longest means one of ini or dec must be non None + assert ini is not None + + if not isinstance(ini.target, ast.Name): + raise RuntimeError(f"Don't know how to examine {ast.unparse(ini)!r}") + attr_name = _name(ini.target) if TYPE_CHECKING: assert ini is not None and dec is not None # Because None is only possible as fillvalue. - if ini.arg != dec.arg: - yield f"Argument {i + 1} mismatch: {init_name} has {ini.arg} but @{deco_name} has {dec.arg}" + if attr_name != dec.arg: + yield f"Argument {i + 1} mismatch: {init_name} has {attr_name} but @{deco_name} has {dec.arg}" return if getattr(ini, "type_comment", None): # 3.8+ - yield f"Do not use type comments on {init_name} argument: {ini.arg}" + yield f"Do not use type comments on {init_name} argument: {ini}" if getattr(dec, "type_comment", None): # 3.8+ yield f"Do not use type comments on @{deco_name} argument: {dec.arg}" # Poorly implemented node equality check. 
- if ini.annotation and dec.annotation and ast.dump(ini.annotation) != ast.dump(dec.annotation): - yield ( - f"Type annotations differ on argument {ini.arg} between {init_name} and @{deco_name}: " - f"{ast.unparse(ini.annotation)} != {ast.unparse(dec.annotation)}" - ) - else: - if not ini.annotation: - yield f"Type annotation missing on {init_name} argument: {ini.arg}" - if not dec.annotation: - yield f"Type annotation missing on @{deco_name} argument: {ini.arg}" + if ini.annotation and dec.annotation: + ini_anno = ast.dump(ini.annotation) + dec_anno = ast.dump(dec.annotation) + if ( + ini_anno != dec_anno + # The decorator can have `| None` type in addaition to the base attribute + and dec_anno != f"BinOp(left={ini_anno}, op=BitOr(), right=Constant(value=None))" + and ALLOWABLE_TYPE_ANNOTATIONS.get(dec_anno) != ini_anno + ): + yield ( + f"Type annotations differ on argument {attr_name!r} between {init_name} and @{deco_name}: " + f"{ast.unparse(ini.annotation)} != {ast.unparse(dec.annotation)}" + ) + elif not ini.annotation: + yield f"Type annotation missing on {init_name} argument: {attr_name}" + elif not dec.annotation: + yield f"Type annotation missing on @{deco_name} argument: {attr_name}" def _match_defaults( @@ -130,47 +198,38 @@ def _match_defaults( def check_dag_init_decorator_arguments() -> int: dag_mod = ast.parse(DAG_PY.read_text("utf-8"), str(DAG_PY)) - - utils_tg = ast.parse(UTILS_TG_PY.read_text("utf-8"), str(UTILS_TG_PY)) + tg_mod = ast.parse(TG_PY.read_text("utf-8"), str(TG_PY)) decos_tg = ast.parse(DECOS_TG_PY.read_text("utf-8"), str(DECOS_TG_PY)) items_to_check = [ - ("DAG", _find_dag_init(dag_mod), "dag", _find_dag_deco(dag_mod), "dag_id", ""), - ("TaskGroup", _find_tg_init(utils_tg), "task_group", _find_tg_deco(decos_tg), "group_id", None), + ( + "DAG", + list(_find_cls_attrs(dag_mod, "DAG", ignore=["full_filepath", "task_group"])), + "dag", + _find_dag_deco(dag_mod), + "dag_id", + "", + ), + ( + "TaskGroup", + list(_find_cls_attrs(tg_mod, "TaskGroup")), + "_task_group", + _find_tg_deco(decos_tg), + "group_id", + None, + ), ] - for init_name, init, deco_name, deco, id_arg, id_default in items_to_check: - if getattr(init.args, "posonlyargs", None) or getattr(deco.args, "posonlyargs", None): - print(f"{init_name} and @{deco_name} should not declare positional-only arguments") - return -1 - if init.args.vararg or init.args.kwarg or deco.args.vararg or deco.args.kwarg: - print(f"{init_name} and @{deco_name} should not declare *args and **kwargs") - return -1 - - # Feel free to change this and make some of the arguments keyword-only! 
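The rewritten annotation check above compares `ast.dump()` strings instead of raw source text, and lets the decorator widen an attribute's annotation with `| None` (or one of the hard-coded ALLOWABLE_TYPE_ANNOTATIONS pairs). A minimal standalone sketch of that equivalence test, with invented attribute/decorator annotations; the string comparison holds on the Python versions the Task SDK targets (>=3.9, <3.13):

    import ast

    def dump_expr(src: str) -> str:
        # Parse a type expression and dump its AST into a comparable string,
        # the same trick _expr_to_ast_dump() uses above.
        return ast.dump(ast.parse(src).body[0].value)

    attr_anno = dump_expr("str")         # annotation on the attrs class attribute
    deco_anno = dump_expr("str | None")  # annotation on the decorator parameter

    # A decorator annotation of "<attr> | None" is accepted by comparing it
    # against the BinOp dump of the attribute annotation joined with None.
    assert deco_anno == f"BinOp(left={attr_anno}, op=BitOr(), right=Constant(value=None))"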
- if init.args.kwonlyargs or deco.args.kwonlyargs: - print(f"{init_name}() and @{deco_name}() should not declare keyword-only arguments") - return -2 - if init.args.kw_defaults or deco.args.kw_defaults: - print(f"{init_name}() and @{deco_name}() should not declare keyword-only arguments") - return -2 - - init_arg_names = [a.arg for a in init.args.args] + for init_name, cls_attrs, deco_name, deco, id_arg, id_default in items_to_check: deco_arg_names = [a.arg for a in deco.args.args] - if init_arg_names[0] != "self": - print(f"First argument in {init_name} must be 'self'") - return -3 - if init_arg_names[1] != id_arg: - print(f"Second argument in {init_name} must be {id_arg!r}") + if _name(cls_attrs[0].target) != id_arg: + print(f"First attribute in {init_name} must be {id_arg!r} (got {cls_attrs[0]!r})") return -3 if deco_arg_names[0] != id_arg: - print(f"First argument in @{deco_name} must be {id_arg!r}") + print(f"First argument in @{deco_name} must be {id_arg!r} (got {deco_arg_names[0]!r})") return -3 - if len(init.args.defaults) != len(init_arg_names) - 2: - print(f"All arguments on {init_name} except self and {id_arg} must have defaults") - return -4 if len(deco.args.defaults) != len(deco_arg_names): print(f"All arguments on @{deco_name} must have defaults") return -4 @@ -178,13 +237,11 @@ def check_dag_init_decorator_arguments() -> int: print(f"Default {id_arg} on @{deco_name} must be {id_default!r}") return -4 - for init_name, init, deco_name, deco, _, _ in items_to_check: - errors = list(_match_arguments((init_name, init.args.args[1:]), (deco_name, deco.args.args))) - if errors: - break - init_defaults_def = (init_name, init.args.defaults) - deco_defaults_def = (deco_name, deco.args.defaults[1:]) - errors = list(_match_defaults(deco_arg_names, init_defaults_def, deco_defaults_def)) + errors = [] + for init_name, cls_attrs, deco_name, deco, _, _ in items_to_check: + errors = list( + _match_arguments((init_name, cls_attrs), (deco_name, deco.args.args + deco.args.kwonlyargs)) + ) if errors: break diff --git a/scripts/docker/entrypoint_ci.sh b/scripts/docker/entrypoint_ci.sh index a96e58b9c21a..cbd7bdce141e 100755 --- a/scripts/docker/entrypoint_ci.sh +++ b/scripts/docker/entrypoint_ci.sh @@ -377,7 +377,7 @@ function check_force_lowest_dependencies() { echo fi set -x - uv pip install --python "$(which python)" --resolution lowest-direct --upgrade --editable ".${EXTRA}" + uv pip install --python "$(which python)" --resolution lowest-direct --upgrade --editable ".${EXTRA}" --editable "./task_sdk" set +x } diff --git a/scripts/docker/install_from_docker_context_files.sh b/scripts/docker/install_from_docker_context_files.sh index edcb50c82e05..4ce7cbffb740 100644 --- a/scripts/docker/install_from_docker_context_files.sh +++ b/scripts/docker/install_from_docker_context_files.sh @@ -27,6 +27,7 @@ # TODO: rewrite it all in Python (and all other scripts in scripts/docker) function install_airflow_and_providers_from_docker_context_files(){ + local flags=() if [[ ${INSTALL_MYSQL_CLIENT} != "true" ]]; then AIRFLOW_EXTRAS=${AIRFLOW_EXTRAS/mysql,} fi @@ -65,10 +66,10 @@ function install_airflow_and_providers_from_docker_context_files(){ install_airflow_package=("apache-airflow[${AIRFLOW_EXTRAS}]==${AIRFLOW_VERSION}") fi - # Find Provider packages in docker-context files - readarray -t installing_providers_packages< <(python /scripts/docker/get_package_specs.py /docker-context-files/apache?airflow?providers*.{whl,tar.gz} 2>/dev/null || true) + # Find Provider/TaskSDK packages in docker-context files + 
readarray -t airflow_packages< <(python /scripts/docker/get_package_specs.py /docker-context-files/apache?airflow?{providers,task?sdk}*.{whl,tar.gz} 2>/dev/null || true) echo - echo "${COLOR_BLUE}Found provider packages in docker-context-files folder: ${installing_providers_packages[*]}${COLOR_RESET}" + echo "${COLOR_BLUE}Found provider packages in docker-context-files folder: ${airflow_packages[*]}${COLOR_RESET}" echo if [[ ${USE_CONSTRAINTS_FOR_CONTEXT_PACKAGES=} == "true" ]]; then @@ -81,11 +82,7 @@ function install_airflow_and_providers_from_docker_context_files(){ echo "${COLOR_BLUE}Installing docker-context-files packages with constraints found in ${local_constraints_file}${COLOR_RESET}" echo # force reinstall all airflow + provider packages with constraints found in - set -x - ${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} --upgrade \ - ${ADDITIONAL_PIP_INSTALL_FLAGS} --constraint "${local_constraints_file}" \ - "${install_airflow_package[@]}" "${installing_providers_packages[@]}" - set +x + flags=(--upgrade --constraint "${local_constraints_file}") echo echo "${COLOR_BLUE}Copying ${local_constraints_file} to ${HOME}/constraints.txt${COLOR_RESET}" echo @@ -94,23 +91,21 @@ function install_airflow_and_providers_from_docker_context_files(){ echo echo "${COLOR_BLUE}Installing docker-context-files packages with constraints from GitHub${COLOR_RESET}" echo - set -x - ${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} \ - ${ADDITIONAL_PIP_INSTALL_FLAGS} \ - --constraint "${HOME}/constraints.txt" \ - "${install_airflow_package[@]}" "${installing_providers_packages[@]}" - set +x + flags=(--constraint "${HOME}/constraints.txt") fi else echo echo "${COLOR_BLUE}Installing docker-context-files packages without constraints${COLOR_RESET}" echo - set -x - ${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} \ - ${ADDITIONAL_PIP_INSTALL_FLAGS} \ - "${install_airflow_package[@]}" "${installing_providers_packages[@]}" - set +x + flags=() fi + + set -x + ${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} \ + ${ADDITIONAL_PIP_INSTALL_FLAGS} \ + "${flags[@]}" \ + "${install_airflow_package[@]}" "${airflow_packages[@]}" + set +x common::install_packaging_tools pip check } diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml index be2be98baf86..37ea2d300ab4 100644 --- a/task_sdk/pyproject.toml +++ b/task_sdk/pyproject.toml @@ -21,7 +21,11 @@ version = "0.1.0.dev0" description = "Python Task SDK for Apache Airflow DAG Authors" readme = { file = "README.md", content-type = "text/markdown" } requires-python = ">=3.9, <3.13" -dependencies = [] +dependencies = [ + "attrs>=24.2.0", + "google-re2>=1.1.20240702", + "methodtools>=0.4.7", +] [build-system] requires = ["hatchling"] @@ -33,8 +37,37 @@ packages = ["src/airflow"] [tool.ruff] extend = "../pyproject.toml" src = ["src"] +namespace-packages = ["src/airflow"] [tool.ruff.lint.per-file-ignores] # Ignore Doc rules et al for anything outside of tests "!src/*" = ["D", "TID253", "S101", "TRY002"] + +"src/airflow/sdk/__init__.py" = ["TCH004"] + +[tool.uv] +dev-dependencies = [ + "kgb>=7.1.1", + "pytest-asyncio>=0.24.0", + "pytest-mock>=3.14.0", + "pytest>=8.3.3", +] + +[tool.coverage.run] +branch = true +relative_files = true +source = ["src/airflow"] +include_namespace_packages = true + +[tool.coverage.report] +skip_empty = true +exclude_also = [ + "def __repr__", + "raise AssertionError", + "raise NotImplementedError", + "if __name__ == .__main__.:", + "@(abc\\.)?abstractmethod", + "@(typing(_extensions)?\\.)?overload", + "if 
(typing(_extensions)?\\.)?TYPE_CHECKING:", +] diff --git a/task_sdk/src/airflow/sdk/__init__.py b/task_sdk/src/airflow/sdk/__init__.py index 2a3e01b64bc4..f538baedff01 100644 --- a/task_sdk/src/airflow/sdk/__init__.py +++ b/task_sdk/src/airflow/sdk/__init__.py @@ -16,6 +16,41 @@ # under the License. from __future__ import annotations +from typing import TYPE_CHECKING -def hello() -> str: - return "Hello from task-sdk!" +__all__ = [ + "BaseOperator", + "DAG", + "EdgeModifier", + "Label", + "TaskGroup", + "dag", +] + +if TYPE_CHECKING: + from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.definitions.dag import DAG, dag + from airflow.sdk.definitions.edges import EdgeModifier, Label + from airflow.sdk.definitions.taskgroup import TaskGroup + +__lazy_imports: dict[str, str] = { + "DAG": ".definitions.dag", + "dag": ".definitions.dag", + "BaseOperator": ".definitions.baseoperator", + "TaskGroup": ".definitions.taskgroup", + "EdgeModifier": ".definitions.edges", + "Label": ".definitions.edges", +} + + +def __getattr__(name: str): + if module_path := __lazy_imports.get(name): + import importlib + + mod = importlib.import_module(module_path, __name__) + val = getattr(mod, name) + + # Store for next time + globals()[name] = val + return val + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/task_sdk/tests/test_hello.py b/task_sdk/src/airflow/sdk/definitions/__init__.py similarity index 85% rename from task_sdk/tests/test_hello.py rename to task_sdk/src/airflow/sdk/definitions/__init__.py index 62cfdc069ca0..13a83393a912 100644 --- a/task_sdk/tests/test_hello.py +++ b/task_sdk/src/airflow/sdk/definitions/__init__.py @@ -14,10 +14,3 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import annotations - -from airflow.sdk import hello - - -def test_hello(): - assert hello() == "Hello from task-sdk!" diff --git a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py new file mode 100644 index 000000000000..5285bd97ef43 --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py @@ -0,0 +1,261 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
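The new `airflow/sdk/__init__.py` above resolves its public names lazily: a module-level `__getattr__` (PEP 562) imports the defining module on first access and caches the attribute in `globals()` so later lookups bypass the hook entirely. A stripped-down sketch of the same pattern for a hypothetical package `__init__.py` (only the `DAG` entry from the real mapping is shown):

    import importlib

    __lazy_imports = {"DAG": ".definitions.dag"}  # attribute name -> relative module path

    def __getattr__(name: str):
        # Only called when `name` is not already present in the module namespace.
        if module_path := __lazy_imports.get(name):
            mod = importlib.import_module(module_path, __name__)
            val = getattr(mod, name)
            globals()[name] = val  # cache so the next access skips __getattr__
            return val
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")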
+from __future__ import annotations + +import datetime +from abc import abstractmethod +from collections.abc import ( + Collection, + Iterable, +) +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, +) + +from airflow.sdk.definitions.mixins import DependencyMixin +from airflow.sdk.definitions.node import DAGNode +from airflow.utils.trigger_rule import TriggerRule +from airflow.utils.weight_rule import WeightRule + +if TYPE_CHECKING: + from airflow.models.baseoperatorlink import BaseOperatorLink + from airflow.models.operator import Operator + from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.definitions.dag import DAG + +DEFAULT_OWNER: str = "airflow" +DEFAULT_POOL_SLOTS: int = 1 +DEFAULT_PRIORITY_WEIGHT: int = 1 +DEFAULT_EXECUTOR: str | None = None +DEFAULT_QUEUE: str = "default" +DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = False +DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False +DEFAULT_RETRIES: int = 0 +DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(seconds=300) +MAX_RETRY_DELAY: int = 24 * 60 * 60 + +# TODO: Task-SDK -- these defaults should be overridable from the Airflow config +DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS +DEFAULT_WEIGHT_RULE: WeightRule = WeightRule.DOWNSTREAM +DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = None + + +class NotMapped(Exception): + """Raise if a task is neither mapped nor has any parent mapped groups.""" + + +class AbstractOperator(DAGNode): + """ + Common implementation for operators, including unmapped and mapped. + + This base class is more about sharing implementations, not defining a common + interface. Unfortunately it's difficult to use this as the common base class + for typing due to BaseOperator carrying too much historical baggage. + + The union type ``from airflow.models.operator import Operator`` is easier + to use for typing purposes. + + :meta private: + """ + + operator_class: type[BaseOperator] | dict[str, Any] + + priority_weight: int + + # Defines the operator level extra links. 
+ operator_extra_links: Collection[BaseOperatorLink] + + owner: str + task_id: str + + outlets: list + inlets: list + # TODO: + trigger_rule: TriggerRule + _needs_expansion: bool | None = None + _on_failure_fail_dagrun = False + is_setup: bool = False + is_teardown: bool = False + + HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset( + ( + "log", + "dag", # We show dag_id, don't need to show this too + "node_id", # Duplicates task_id + "task_group", # Doesn't have a useful repr, no point showing in UI + "inherits_from_empty_operator", # impl detail + # Decide whether to start task execution from triggerer + "start_trigger_args", + "start_from_trigger", + # For compatibility with TG, for operators these are just the current task, no point showing + "roots", + "leaves", + # These lists are already shown via *_task_ids + "upstream_list", + "downstream_list", + # Not useful, implementation detail, already shown elsewhere + "global_operator_extra_link_dict", + "operator_extra_link_dict", + ) + ) + + def get_dag(self) -> DAG | None: + raise NotImplementedError() + + @property + def task_type(self) -> str: + raise NotImplementedError() + + @property + def operator_name(self) -> str: + raise NotImplementedError() + + @property + def inherits_from_empty_operator(self) -> bool: + raise NotImplementedError() + + @property + def dag_id(self) -> str: + """Returns dag id if it has one or an adhoc + owner.""" + dag = self.get_dag() + if dag: + return dag.dag_id + return f"adhoc_{self.owner}" + + @property + def node_id(self) -> str: + return self.task_id + + @property + @abstractmethod + def task_display_name(self) -> str: ... + + @property + def label(self) -> str | None: + if self.task_display_name and self.task_display_name != self.task_id: + return self.task_display_name + # Prefix handling if no display is given is cloned from taskmixin for compatibility + tg = self.task_group + if tg and tg.node_id and tg.prefix_group_id: + # "task_group_id.task_id" -> "task_id" + return self.task_id[len(tg.node_id) + 1 :] + return self.task_id + + def as_setup(self): + self.is_setup = True + return self + + def as_teardown( + self, + *, + setups: BaseOperator | Iterable[BaseOperator] | None = None, + on_failure_fail_dagrun: bool | None = None, + ): + self.is_teardown = True + self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS + if on_failure_fail_dagrun is not None: + self.on_failure_fail_dagrun = on_failure_fail_dagrun + if setups is not None: + setups = [setups] if isinstance(setups, DependencyMixin) else setups + for s in setups: + s.is_setup = True + s >> self + return self + + def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]: + """ + Get a flat set of relative IDs, upstream or downstream. + + Will recurse each relative found in the direction specified. + + :param upstream: Whether to look for upstream or downstream relatives. + """ + dag = self.get_dag() + if not dag: + return set() + + relatives: set[str] = set() + + # This is intentionally implemented as a loop, instead of calling + # get_direct_relative_ids() recursively, since Python has significant + # limitation on stack level, and a recursive implementation can blow up + # if a DAG contains very long routes. 
+ task_ids_to_trace = self.get_direct_relative_ids(upstream) + while task_ids_to_trace: + task_ids_to_trace_next: set[str] = set() + for task_id in task_ids_to_trace: + if task_id in relatives: + continue + task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream)) + relatives.add(task_id) + task_ids_to_trace = task_ids_to_trace_next + + return relatives + + def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]: + """Get a flat list of relatives, either upstream or downstream.""" + dag = self.get_dag() + if not dag: + return set() + return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream)] + + def get_upstreams_follow_setups(self) -> Iterable[Operator]: + """All upstreams and, for each upstream setup, its respective teardowns.""" + for task in self.get_flat_relatives(upstream=True): + yield task + if task.is_setup: + for t in task.downstream_list: + if t.is_teardown and t != self: + yield t + + def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Operator]: + """ + Only *relevant* upstream setups and their teardowns. + + This method is meant to be used when we are clearing the task (non-upstream) and we need + to add in the *relevant* setups and their teardowns. + + Relevant in this case means, the setup has a teardown that is downstream of ``self``, + or the setup has no teardowns. + """ + downstream_teardown_ids = { + x.task_id for x in self.get_flat_relatives(upstream=False) if x.is_teardown + } + for task in self.get_flat_relatives(upstream=True): + if not task.is_setup: + continue + has_no_teardowns = not any(True for x in task.downstream_list if x.is_teardown) + # if task has no teardowns or has teardowns downstream of self + if has_no_teardowns or task.downstream_task_ids.intersection(downstream_teardown_ids): + yield task + for t in task.downstream_list: + if t.is_teardown and t != self: + yield t + + def get_upstreams_only_setups(self) -> Iterable[Operator]: + """ + Return relevant upstream setups. + + This method is meant to be used when we are checking task dependencies where we need + to wait for all the upstream setups to complete before we can run the task. + """ + for task in self.get_upstreams_only_setups_and_teardowns(): + if task.is_setup: + yield task diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py new file mode 100644 index 000000000000..fc16682a63cd --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py @@ -0,0 +1,1226 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
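`get_flat_relative_ids()` above deliberately walks relatives with a work-list loop rather than recursion so that very long task chains cannot exhaust Python's recursion limit. A self-contained sketch of the same traversal over a toy downstream-id mapping (the graph literal is invented; `downstream` stands in for the `dag.task_dict` lookups):

    downstream: dict[str, set[str]] = {"a": {"b"}, "b": {"c"}, "c": set()}

    def flat_relative_ids(task_id: str) -> set[str]:
        relatives: set[str] = set()
        to_trace = set(downstream[task_id])
        while to_trace:
            next_level: set[str] = set()
            for tid in to_trace:
                if tid in relatives:
                    continue
                next_level.update(downstream[tid])
                relatives.add(tid)
            to_trace = next_level
        return relatives

    assert flat_relative_ids("a") == {"b", "c"}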
+ +from __future__ import annotations + +import abc +import collections.abc +import contextlib +import copy +import inspect +import sys +import warnings +from collections.abc import Collection, Iterable, Sequence +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from functools import total_ordering, wraps +from types import FunctionType +from typing import TYPE_CHECKING, Any, ClassVar, Final, TypeVar, cast + +import attrs + +from airflow.models.param import ParamsDict +from airflow.sdk.definitions.abstractoperator import ( + DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, + DEFAULT_OWNER, + DEFAULT_POOL_SLOTS, + DEFAULT_PRIORITY_WEIGHT, + DEFAULT_QUEUE, + DEFAULT_RETRIES, + DEFAULT_RETRY_DELAY, + DEFAULT_TASK_EXECUTION_TIMEOUT, + DEFAULT_TRIGGER_RULE, + DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, + DEFAULT_WEIGHT_RULE, + AbstractOperator, +) +from airflow.sdk.definitions.decorators import fixup_decorator_warning_stack +from airflow.sdk.definitions.node import validate_key +from airflow.sdk.types import NOTSET, validate_instance_args +from airflow.task.priority_strategy import ( + PriorityWeightStrategy, + airflow_priority_weight_strategies, + validate_and_load_priority_weight_strategy, +) +from airflow.utils import timezone +from airflow.utils.setup_teardown import SetupTeardownContext +from airflow.utils.trigger_rule import TriggerRule +from airflow.utils.types import AttributeRemoved + +T = TypeVar("T", bound=FunctionType) + +if TYPE_CHECKING: + from airflow.models.xcom_arg import XComArg + from airflow.sdk.definitions.dag import DAG + from airflow.sdk.definitions.taskgroup import TaskGroup + from airflow.serialization.enums import DagAttributeTypes + from airflow.utils.operator_resources import Resources + +# TODO: Task-SDK +AirflowException = RuntimeError + + +def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple[dict, ParamsDict]: + if not dag: + return {}, ParamsDict() + dag_args = copy.copy(dag.default_args) + dag_params = copy.deepcopy(dag.params) + if task_group: + if task_group.default_args and not isinstance(task_group.default_args, collections.abc.Mapping): + raise TypeError("default_args must be a mapping") + dag_args.update(task_group.default_args) + return dag_args, dag_params + + +def get_merged_defaults( + dag: DAG | None, + task_group: TaskGroup | None, + task_params: collections.abc.MutableMapping | None, + task_default_args: dict | None, +) -> tuple[dict, ParamsDict]: + args, params = _get_parent_defaults(dag, task_group) + if task_params: + if not isinstance(task_params, collections.abc.Mapping): + raise TypeError(f"params must be a mapping, got {type(task_params)}") + params.update(task_params) + if task_default_args: + if not isinstance(task_default_args, collections.abc.Mapping): + raise TypeError(f"default_args must be a mapping, got {type(task_params)}") + args.update(task_default_args) + with contextlib.suppress(KeyError): + params.update(task_default_args["params"] or {}) + return args, params + + +class BaseOperatorMeta(abc.ABCMeta): + """Metaclass of BaseOperator.""" + + @classmethod + def _apply_defaults(cls, func: T) -> T: + """ + Look for an argument named "default_args", and fill the unspecified arguments from it. + + Since python2.* isn't clear about which arguments are missing when + calling a function, and that this can be quite confusing with multi-level + inheritance and argument defaults, this decorator also alerts with + specific information about the missing arguments. 
+ """ + # Cache inspect.signature for the wrapper closure to avoid calling it + # at every decorated invocation. This is separate sig_cache created + # per decoration, i.e. each function decorated using apply_defaults will + # have a different sig_cache. + sig_cache = inspect.signature(func) + non_variadic_params = { + name: param + for (name, param) in sig_cache.parameters.items() + if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) + } + non_optional_args = { + name + for name, param in non_variadic_params.items() + if param.default == param.empty and name != "task_id" + } + + fixup_decorator_warning_stack(func) + + @wraps(func) + def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any: + from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext + + if args: + raise TypeError("Use keyword arguments when initializing operators") + + instantiated_from_mapped = kwargs.pop( + "_airflow_from_mapped", + getattr(self, "_BaseOperator__from_mapped", False), + ) + + dag: DAG | None = kwargs.get("dag") + if dag is None: + dag = DagContext.get_current() + if dag is not None: + kwargs["dag"] = dag + + task_group: TaskGroup | None = kwargs.get("task_group") + if dag and not task_group: + task_group = TaskGroupContext.get_current(dag) + if task_group is not None: + kwargs["task_group"] = task_group + + default_args, merged_params = get_merged_defaults( + dag=dag, + task_group=task_group, + task_params=kwargs.pop("params", None), + task_default_args=kwargs.pop("default_args", None), + ) + + for arg in sig_cache.parameters: + if arg not in kwargs and arg in default_args: + kwargs[arg] = default_args[arg] + + missing_args = non_optional_args.difference(kwargs) + if len(missing_args) == 1: + raise TypeError(f"missing keyword argument {missing_args.pop()!r}") + elif missing_args: + display = ", ".join(repr(a) for a in sorted(missing_args)) + raise TypeError(f"missing keyword arguments {display}") + + if merged_params: + kwargs["params"] = merged_params + + hook = getattr(self, "_hook_apply_defaults", None) + if hook: + args, kwargs = hook(**kwargs, default_args=default_args) + default_args = kwargs.pop("default_args", {}) + + if not hasattr(self, "_BaseOperator__init_kwargs"): + object.__setattr__(self, "_BaseOperator__init_kwargs", {}) + object.__setattr__(self, "_BaseOperator__from_mapped", instantiated_from_mapped) + + result = func(self, **kwargs, default_args=default_args) + + # Store the args passed to init -- we need them to support task.map serialization! + self._BaseOperator__init_kwargs.update(kwargs) # type: ignore + + # Set upstream task defined by XComArgs passed to template fields of the operator. + # BUT: only do this _ONCE_, not once for each class in the hierarchy + if not instantiated_from_mapped and func == self.__init__.__wrapped__: # type: ignore[misc] + self._set_xcomargs_dependencies() + # Mark instance as instantiated so that futre attr setting updates xcomarg-based deps. 
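`_apply_defaults` above is also the source of the `missing keyword argument '...'` TypeError, which is why the provider tests earlier in this diff now accept `(TypeError, AirflowException)`. The fill-in order it implements is: DAG `default_args`, overlaid by TaskGroup `default_args`, overlaid by the task's own `default_args`, and only parameters the caller did not pass explicitly are taken from the merged mapping. A rough standalone sketch of that layering, with invented values:

    dag_default_args = {"owner": "data-eng", "retries": 2}
    task_group_default_args = {"retries": 4}
    task_default_args = {"queue": "batch"}

    # DAG defaults < TaskGroup defaults < task-level default_args
    merged = {**dag_default_args, **task_group_default_args, **task_default_args}

    # Explicitly passed keyword arguments always win over default_args.
    explicit_kwargs = {"task_id": "load", "retries": 0}
    final_kwargs = {**merged, **explicit_kwargs}

    assert final_kwargs == {"owner": "data-eng", "retries": 0, "queue": "batch", "task_id": "load"}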
+ object.__setattr__(self, "_BaseOperator__instantiated", True) + + return result + + apply_defaults.__non_optional_args = non_optional_args # type: ignore + apply_defaults.__param_names = set(non_variadic_params) # type: ignore + + return cast(T, apply_defaults) + + def __new__(cls, name, bases, namespace, **kwargs): + # TODO: Task-SDK + # execute_method = namespace.get("execute") + # if callable(execute_method) and not getattr(execute_method, "__isabstractmethod__", False): + # namespace["execute"] = ExecutorSafeguard().decorator(execute_method) + new_cls = super().__new__(cls, name, bases, namespace, **kwargs) + with contextlib.suppress(KeyError): + # Update the partial descriptor with the class method, so it calls the actual function + # (but let subclasses override it if they need to) + # TODO: Task-SDK + # partial_desc = vars(new_cls)["partial"] + # if isinstance(partial_desc, _PartialDescriptor): + # partial_desc.class_method = classmethod(partial) + ... + + # We patch `__init__` only if the class defines it. + if inspect.getmro(new_cls)[1].__init__ is not new_cls.__init__: + new_cls.__init__ = cls._apply_defaults(new_cls.__init__) + + return new_cls + + +# TODO: The following mapping is used to validate that the arguments passed to the BaseOperator are of the +# correct type. This is a temporary solution until we find a more sophisticated method for argument +# validation. One potential method is to use `get_type_hints` from the typing module. However, this is not +# fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python +# version that supports `get_type_hints` effectively or find a better approach, we can replace this +# manual type-checking method. +BASEOPERATOR_ARGS_EXPECTED_TYPES = { + "task_id": str, + "email": (str, Sequence), + "email_on_retry": bool, + "email_on_failure": bool, + "retries": int, + "retry_exponential_backoff": bool, + "depends_on_past": bool, + "ignore_first_depends_on_past": bool, + "wait_for_past_depends_before_skipping": bool, + "wait_for_downstream": bool, + "priority_weight": int, + "queue": str, + "pool": str, + "pool_slots": int, + "trigger_rule": str, + "run_as_user": str, + "task_concurrency": int, + "map_index_template": str, + "max_active_tis_per_dag": int, + "max_active_tis_per_dagrun": int, + "executor": str, + "do_xcom_push": bool, + "multiple_outputs": bool, + "doc": str, + "doc_md": str, + "doc_json": str, + "doc_yaml": str, + "doc_rst": str, + "task_display_name": str, + "logger_name": str, + "allow_nested_operators": bool, + "start_date": datetime, + "end_date": datetime, +} + + +# Note: BaseOperator is defined as a dataclass, and not an attrs class as we do too much metaprogramming in +# here (metaclass, custom `__setattr__` behaviour) and this fights with attrs too much to make it worth it. +# +# To future reader: if you want to try and make this a "normal" attrs class, go ahead and attempt it. If you +# get no where leave your record here for the next poor soul and what problems you ran in to. +# +# @ashb, 2024/10/14 +# - "Can't combine custom __setattr__ with on_setattr hooks" +# - Setting class-wide `define(on_setarrs=...)` isn't called for non-attrs subclasses +@total_ordering +@dataclass(repr=False) +class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): + r""" + Abstract base class for all operators. + + Since operators create objects that become nodes in the DAG, BaseOperator + contains many recursive methods for DAG crawling behavior. 
To derive from + this class, you are expected to override the constructor and the 'execute' + method. + + Operators derived from this class should perform or trigger certain tasks + synchronously (wait for completion). Example of operators could be an + operator that runs a Pig job (PigOperator), a sensor operator that + waits for a partition to land in Hive (HiveSensorOperator), or one that + moves data from Hive to MySQL (Hive2MySqlOperator). Instances of these + operators (tasks) target specific operations, running specific scripts, + functions or data transfers. + + This class is abstract and shouldn't be instantiated. Instantiating a + class derived from this one results in the creation of a task object, + which ultimately becomes a node in DAG objects. Task dependencies should + be set by using the set_upstream and/or set_downstream methods. + + :param task_id: a unique, meaningful id for the task + :param owner: the owner of the task. Using a meaningful description + (e.g. user/person/team/role name) to clarify ownership is recommended. + :param email: the 'to' email address(es) used in email alerts. This can be a + single email or multiple ones. Multiple addresses can be specified as a + comma or semicolon separated string or by passing a list of strings. + :param email_on_retry: Indicates whether email alerts should be sent when a + task is retried + :param email_on_failure: Indicates whether email alerts should be sent when + a task failed + :param retries: the number of retries that should be performed before + failing the task + :param retry_delay: delay between retries, can be set as ``timedelta`` or + ``float`` seconds, which will be converted into ``timedelta``, + the default is ``timedelta(seconds=300)``. + :param retry_exponential_backoff: allow progressively longer waits between + retries by using exponential backoff algorithm on retry delay (delay + will be converted into seconds) + :param max_retry_delay: maximum delay interval between retries, can be set as + ``timedelta`` or ``float`` seconds, which will be converted into ``timedelta``. + :param start_date: The ``start_date`` for the task, determines + the ``execution_date`` for the first task instance. The best practice + is to have the start_date rounded + to your DAG's ``schedule_interval``. Daily jobs have their start_date + some day at 00:00:00, hourly jobs have their start_date at 00:00 + of a specific hour. Note that Airflow simply looks at the latest + ``execution_date`` and adds the ``schedule_interval`` to determine + the next ``execution_date``. It is also very important + to note that different tasks' dependencies + need to line up in time. If task A depends on task B and their + start_date are offset in a way that their execution_date don't line + up, A's dependencies will never be met. If you are looking to delay + a task, for example running a daily task at 2AM, look into the + ``TimeSensor`` and ``TimeDeltaSensor``. We advise against using + dynamic ``start_date`` and recommend using fixed ones. Read the + FAQ entry about start_date for more information. + :param end_date: if specified, the scheduler won't go beyond this date + :param depends_on_past: when set to true, task instances will run + sequentially and only if the previous instance has succeeded or has been skipped. + The task instance for the start_date is allowed to run. 
+ :param wait_for_past_depends_before_skipping: when set to true, if the task instance + should be marked as skipped, and depends_on_past is true, the ti will stay on None state + waiting the task of the previous run + :param wait_for_downstream: when set to true, an instance of task + X will wait for tasks immediately downstream of the previous instance + of task X to finish successfully or be skipped before it runs. This is useful if the + different instances of a task X alter the same asset, and this asset + is used by tasks downstream of task X. Note that depends_on_past + is forced to True wherever wait_for_downstream is used. Also note that + only tasks *immediately* downstream of the previous task instance are waited + for; the statuses of any tasks further downstream are ignored. + :param dag: a reference to the dag the task is attached to (if any) + :param priority_weight: priority weight of this task against other task. + This allows the executor to trigger higher priority tasks before + others when things get backed up. Set priority_weight as a higher + number for more important tasks. + :param weight_rule: weighting method used for the effective total + priority weight of the task. Options are: + ``{ downstream | upstream | absolute }`` default is ``downstream`` + When set to ``downstream`` the effective weight of the task is the + aggregate sum of all downstream descendants. As a result, upstream + tasks will have higher weight and will be scheduled more aggressively + when using positive weight values. This is useful when you have + multiple dag run instances and desire to have all upstream tasks to + complete for all runs before each dag can continue processing + downstream tasks. When set to ``upstream`` the effective weight is the + aggregate sum of all upstream ancestors. This is the opposite where + downstream tasks have higher weight and will be scheduled more + aggressively when using positive weight values. This is useful when you + have multiple dag run instances and prefer to have each dag complete + before starting upstream tasks of other dags. When set to + ``absolute``, the effective weight is the exact ``priority_weight`` + specified without additional weighting. You may want to do this when + you know exactly what priority weight each task should have. + Additionally, when set to ``absolute``, there is bonus effect of + significantly speeding up the task creation process as for very large + DAGs. Options can be set as string or using the constants defined in + the static class ``airflow.utils.WeightRule`` + |experimental| + Since 2.9.0, Airflow allows to define custom priority weight strategy, + by creating a subclass of + ``airflow.task.priority_strategy.PriorityWeightStrategy`` and registering + in a plugin, then providing the class path or the class instance via + ``weight_rule`` parameter. The custom priority weight strategy will be + used to calculate the effective total priority weight of the task instance. + :param queue: which queue to target when running this job. Not + all executors implement queue management, the CeleryExecutor + does support targeting specific queues. + :param pool: the slot pool this task should run in, slot pools are a + way to limit concurrency for certain tasks + :param pool_slots: the number of pool slots this task should use (>= 1) + Values less than 1 are not allowed. + :param sla: time by which the job is expected to succeed. Note that + this represents the ``timedelta`` after the period is closed. 
For + example if you set an SLA of 1 hour, the scheduler would send an email + soon after 1:00AM on the ``2016-01-02`` if the ``2016-01-01`` instance + has not succeeded yet. + The scheduler pays special attention for jobs with an SLA and + sends alert + emails for SLA misses. SLA misses are also recorded in the database + for future reference. All tasks that share the same SLA time + get bundled in a single email, sent soon after that time. SLA + notification are sent once and only once for each task instance. + :param execution_timeout: max time allowed for the execution of + this task instance, if it goes beyond it will raise and fail. + :param on_failure_callback: a function or list of functions to be called when a task instance + of this task fails. a context dictionary is passed as a single + parameter to this function. Context contains references to related + objects to the task instance and is documented under the macros + section of the API. + :param on_execute_callback: much like the ``on_failure_callback`` except + that it is executed right before the task is executed. + :param on_retry_callback: much like the ``on_failure_callback`` except + that it is executed when retries occur. + :param on_success_callback: much like the ``on_failure_callback`` except + that it is executed when the task succeeds. + :param on_skipped_callback: much like the ``on_failure_callback`` except + that it is executed when skipped occur; this callback will be called only if AirflowSkipException get raised. + Explicitly it is NOT called if a task is not started to be executed because of a preceding branching + decision in the DAG or a trigger rule which causes execution to skip so that the task execution + is never scheduled. + :param pre_execute: a function to be called immediately before task + execution, receiving a context dictionary; raising an exception will + prevent the task from being executed. + + |experimental| + :param post_execute: a function to be called immediately after task + execution, receiving a context dictionary and task result; raising an + exception will prevent the task from succeeding. + + |experimental| + :param trigger_rule: defines the rule by which dependencies are applied + for the task to get triggered. Options are: + ``{ all_success | all_failed | all_done | all_skipped | one_success | one_done | + one_failed | none_failed | none_failed_min_one_success | none_skipped | always}`` + default is ``all_success``. Options can be set as string or + using the constants defined in the static class + ``airflow.utils.TriggerRule`` + :param resources: A map of resource parameter names (the argument names of the + Resources constructor) to their values. + :param run_as_user: unix username to impersonate while running the task + :param max_active_tis_per_dag: When set, a task will be able to limit the concurrent + runs across execution_dates. + :param max_active_tis_per_dagrun: When set, a task will be able to limit the concurrent + task instances per DAG run. + :param executor: Which executor to target when running this task. NOT YET SUPPORTED + :param executor_config: Additional task-level configuration parameters that are + interpreted by a specific executor. Parameters are namespaced by the name of + executor. 
+ + **Example**: to run this task in a specific docker container through + the KubernetesExecutor :: + + MyOperator(..., executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}}) + + :param do_xcom_push: if True, an XCom is pushed containing the Operator's + result + :param multiple_outputs: if True and do_xcom_push is True, pushes multiple XComs, one for each + key in the returned dictionary result. If False and do_xcom_push is True, pushes a single XCom. + :param task_group: The TaskGroup to which the task should belong. This is typically provided when not + using a TaskGroup as a context manager. + :param doc: Add documentation or notes to your Task objects that is visible in + Task Instance details View in the Webserver + :param doc_md: Add documentation (in Markdown format) or notes to your Task objects + that is visible in Task Instance details View in the Webserver + :param doc_rst: Add documentation (in RST format) or notes to your Task objects + that is visible in Task Instance details View in the Webserver + :param doc_json: Add documentation (in JSON format) or notes to your Task objects + that is visible in Task Instance details View in the Webserver + :param doc_yaml: Add documentation (in YAML format) or notes to your Task objects + that is visible in Task Instance details View in the Webserver + :param task_display_name: The display name of the task which appears on the UI. + :param logger_name: Name of the logger used by the Operator to emit logs. + If set to `None` (default), the logger name will fall back to + `airflow.task.operators.{class.__module__}.{class.__name__}` (e.g. SimpleHttpOperator will have + *airflow.task.operators.airflow.providers.http.operators.http.SimpleHttpOperator* as logger). + :param allow_nested_operators: if True, when an operator is executed within another one a warning message + will be logged. If False, then an exception will be raised if the operator is badly used (e.g. nested + within another one). In future releases of Airflow this parameter will be removed and an exception + will always be thrown when operators are nested within each other (default is True). 
+ + **Example**: example of a bad operator mixin usage:: + + @task(provide_context=True) + def say_hello_world(**context): + hello_world_task = BashOperator( + task_id="hello_world_task", + bash_command="python -c \"print('Hello, world!')\"", + dag=dag, + ) + hello_world_task.execute(context) + """ + + task_id: str + owner: str = DEFAULT_OWNER + email: str | Sequence[str] | None = None + email_on_retry: bool = True + email_on_failure: bool = True + retries: int | None = DEFAULT_RETRIES + retry_delay: timedelta = DEFAULT_RETRY_DELAY + retry_exponential_backoff: bool = False + max_retry_delay: timedelta | float | None = None + start_date: datetime | None = None + end_date: datetime | None = None + depends_on_past: bool = False + ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST + wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING + wait_for_downstream: bool = False + + # At execution_time this becomes a normal dict + params: ParamsDict | dict = field(default_factory=ParamsDict) + default_args: dict | None = None + priority_weight: int = DEFAULT_PRIORITY_WEIGHT + weight_rule: PriorityWeightStrategy = field( + default_factory=airflow_priority_weight_strategies[DEFAULT_WEIGHT_RULE] + ) + queue: str = DEFAULT_QUEUE + pool: str = "default" + pool_slots: int = DEFAULT_POOL_SLOTS + execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT + # on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None + # on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None + # on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None + # on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None + # on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None + # pre_execute: TaskPreExecuteHook | None = None + # post_execute: TaskPostExecuteHook | None = None + trigger_rule: TriggerRule = DEFAULT_TRIGGER_RULE + resources: dict[str, Any] | None = None + run_as_user: str | None = None + task_concurrency: int | None = None + map_index_template: str | None = None + max_active_tis_per_dag: int | None = None + max_active_tis_per_dagrun: int | None = None + executor: str | None = None + executor_config: dict | None = None + do_xcom_push: bool = True + multiple_outputs: bool = False + inlets: list[Any] = field(default_factory=list) + outlets: list[Any] = field(default_factory=list) + task_group: TaskGroup | None = None + doc: str | None = None + doc_md: str | None = None + doc_json: str | None = None + doc_yaml: str | None = None + doc_rst: str | None = None + _task_display_name: str | None = None + logger_name: str | None = None + allow_nested_operators: bool = True + + is_setup: bool = False + is_teardown: bool = False + + # TODO: Task-SDK: Make these ClassVar[]? 
+ template_fields: Collection[str] = () + template_ext: Sequence[str] = () + + template_fields_renderers: ClassVar[dict[str, str]] = {} + + # Defines the color in the UI + ui_color: str = "#fff" + ui_fgcolor: str = "#000" + + # TODO: Task-SDK Mapping + # partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore + + _dag: DAG | None = field(init=False, default=None) + + # Make this optional so the type matches the one define in LoggingMixin + _log_config_logger_name: str | None = field(default="airflow.task.operators", init=False) + _logger_name: str | None = None + + # The _serialized_fields are lazily loaded when get_serialized_fields() method is called + __serialized_fields: ClassVar[frozenset[str] | None] = None + + _comps: ClassVar[set[str]] = { + "task_id", + "dag_id", + "owner", + "email", + "email_on_retry", + "retry_delay", + "retry_exponential_backoff", + "max_retry_delay", + "start_date", + "end_date", + "depends_on_past", + "wait_for_downstream", + "priority_weight", + "sla", + "execution_timeout", + "on_execute_callback", + "on_failure_callback", + "on_success_callback", + "on_retry_callback", + "on_skipped_callback", + "do_xcom_push", + "multiple_outputs", + "allow_nested_operators", + "executor", + } + + # Defines if the operator supports lineage without manual definitions + supports_lineage: bool = False + + # If True then the class constructor was called + __instantiated: bool = False + # List of args as passed to `init()`, after apply_defaults() has been updated. Used to "recreate" the task + # when mapping + # Set via the metaclass + __init_kwargs: dict[str, Any] = field(init=False) + + # Set to True before calling execute method + _lock_for_execution: bool = False + + # Set to True for an operator instantiated by a mapped operator. + __from_mapped: bool = False + + # TODO: + # start_trigger_args: StartTriggerArgs | None = None + # start_from_trigger: bool = False + + # base list which includes all the attrs that don't need deep copy. + _base_operator_shallow_copy_attrs: Final[tuple[str, ...]] = ( + "user_defined_macros", + "user_defined_filters", + "params", + ) + + # each operator should override this class attr for shallow copy attrs. + shallow_copy_attrs: Sequence[str] = () + + def __setattr__(self: BaseOperator, key: str, value: Any): + if converter := getattr(self, f"_convert_{key}", None): + value = converter(value) + super().__setattr__(key, value) + if self.__from_mapped or self._lock_for_execution: + return # Skip any custom behavior for validation and during execute. 
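# Illustrative, standalone sketch of the `_convert_{key}` hook pattern used in __setattr__
# above. `ConverterDemo` and its `retry_delay` field are hypothetical and exist only for
# this example; the real converters are the `_convert_*` methods defined further down.
from datetime import timedelta


class ConverterDemo:
    def __setattr__(self, key, value):
        # Route every assignment through an optional per-field converter hook.
        if converter := getattr(self, f"_convert_{key}", None):
            value = converter(value)
        super().__setattr__(key, value)

    @staticmethod
    def _convert_retry_delay(value):
        # Accept plain seconds or a timedelta; always store a timedelta.
        return value if isinstance(value, timedelta) else timedelta(seconds=value)


demo = ConverterDemo()
demo.retry_delay = 300
assert demo.retry_delay == timedelta(seconds=300)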
+ if key in self.__init_kwargs: + self.__init_kwargs[key] = value + if self.__instantiated and key in self.template_fields: + # Resolve upstreams set by assigning an XComArg after initializing + # an operator, example: + # op = BashOperator() + # op.bash_command = "sleep 1" + self._set_xcomargs_dependency(key, value) + + def __init__( + self, + *, + task_id: str, + owner: str = DEFAULT_OWNER, + email: str | Sequence[str] | None = None, + email_on_retry: bool = True, + email_on_failure: bool = True, + retries: int | None = DEFAULT_RETRIES, + retry_delay: timedelta | float = DEFAULT_RETRY_DELAY, + retry_exponential_backoff: bool = False, + max_retry_delay: timedelta | float | None = None, + start_date: datetime | None = None, + end_date: datetime | None = None, + depends_on_past: bool = False, + ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, + wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, + wait_for_downstream: bool = False, + dag: DAG | None = None, + params: collections.abc.MutableMapping[str, Any] | None = None, + default_args: dict | None = None, + priority_weight: int = DEFAULT_PRIORITY_WEIGHT, + weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE, + queue: str = DEFAULT_QUEUE, + pool: str | None = None, + pool_slots: int = DEFAULT_POOL_SLOTS, + sla: timedelta | None = None, + execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT, + # on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, + # on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, + # on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, + # on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, + # on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, + # pre_execute: TaskPreExecuteHook | None = None, + # post_execute: TaskPostExecuteHook | None = None, + trigger_rule: str = DEFAULT_TRIGGER_RULE, + resources: dict[str, Any] | None = None, + run_as_user: str | None = None, + map_index_template: str | None = None, + max_active_tis_per_dag: int | None = None, + max_active_tis_per_dagrun: int | None = None, + executor: str | None = None, + executor_config: dict | None = None, + do_xcom_push: bool = True, + multiple_outputs: bool = False, + inlets: Any | None = None, + outlets: Any | None = None, + task_group: TaskGroup | None = None, + doc: str | None = None, + doc_md: str | None = None, + doc_json: str | None = None, + doc_yaml: str | None = None, + doc_rst: str | None = None, + task_display_name: str | None = None, + logger_name: str | None = None, + allow_nested_operators: bool = True, + **kwargs: Any, + ): + # Note: Metaclass handles passing in the DAG/TaskGroup from active context manager, if any + + self.task_id = task_group.child_id(task_id) if task_group else task_id + if not self.__from_mapped and task_group: + task_group.add(self) + + super().__init__() + self.task_group = task_group + + kwargs.pop("_airflow_mapped_validation_only", None) + if kwargs: + raise TypeError( + f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). 
" + f"Invalid arguments were:\n**kwargs: {kwargs}", + ) + validate_key(task_id) + + self.owner = owner + self.email = email + self.email_on_retry = email_on_retry + self.email_on_failure = email_on_failure + + if execution_timeout is not None and not isinstance(execution_timeout, timedelta): + raise ValueError( + f"execution_timeout must be timedelta object but passed as type: {type(execution_timeout)}" + ) + self.execution_timeout = execution_timeout + + # TODO: + # self.on_execute_callback = on_execute_callback + # self.on_failure_callback = on_failure_callback + # self.on_success_callback = on_success_callback + # self.on_retry_callback = on_retry_callback + # self.on_skipped_callback = on_skipped_callback + # self._pre_execute_hook = pre_execute + # self._post_execute_hook = post_execute + + if start_date: + self.start_date = timezone.convert_to_utc(start_date) + + if end_date: + self.end_date = timezone.convert_to_utc(end_date) + + if executor: + warnings.warn( + "Specifying executors for operators is not yet supported, the value {executor!r} will have no effect", + category=UserWarning, + stacklevel=2, + ) + self.executor = executor + self.executor_config = executor_config or {} + self.run_as_user = run_as_user + # TODO: + # self.retries = parse_retries(retries) + self.retries = retries + self.queue = queue + # TODO: Task-SDK: pull this default name from Pool constant? + self.pool = "default_pool" if pool is None else pool + self.pool_slots = pool_slots + if self.pool_slots < 1: + dag_str = f" in dag {dag.dag_id}" if dag else "" + raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot be less than 1") + self.sla = sla + + if not TriggerRule.is_valid(trigger_rule): + raise ValueError( + f"The trigger_rule must be one of {TriggerRule.all_triggers()}," + f"'{dag.dag_id if dag else ''}.{task_id}'; received '{trigger_rule}'." 
+ ) + + self.trigger_rule: TriggerRule = TriggerRule(trigger_rule) + + self.depends_on_past: bool = depends_on_past + self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past + self.wait_for_past_depends_before_skipping: bool = wait_for_past_depends_before_skipping + self.wait_for_downstream: bool = wait_for_downstream + if wait_for_downstream: + self.depends_on_past = True + + # Converted by setattr + self.retry_delay = retry_delay # type: ignore[assignment] + self.retry_exponential_backoff = retry_exponential_backoff + if max_retry_delay is not None: + self.max_retry_delay = max_retry_delay + + self.resources = resources + + self.params = ParamsDict(params) + + self.priority_weight = priority_weight + self.weight_rule = validate_and_load_priority_weight_strategy(weight_rule) + + self.max_active_tis_per_dag: int | None = max_active_tis_per_dag + self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun + self.do_xcom_push: bool = do_xcom_push + self.map_index_template: str | None = map_index_template + self.multiple_outputs: bool = multiple_outputs + + self.doc_md = doc_md + self.doc_json = doc_json + self.doc_yaml = doc_yaml + self.doc_rst = doc_rst + self.doc = doc + + self._task_display_name = task_display_name + + self.allow_nested_operators = allow_nested_operators + + self._logger_name = logger_name + + # Lineage + if inlets: + self.inlets = ( + inlets + if isinstance(inlets, list) + else [ + inlets, + ] + ) + else: + self.inlets = [] + + if outlets: + self.outlets = ( + outlets + if isinstance(outlets, list) + else [ + outlets, + ] + ) + else: + self.outlets = [] + + if isinstance(self.template_fields, str): + warnings.warn( + f"The `template_fields` value for {self.task_type} is a string " + "but should be a list or tuple of string. Wrapping it in a list for execution. " + f"Please update {self.task_type} accordingly.", + UserWarning, + stacklevel=2, + ) + self.template_fields = [self.template_fields] + + self.is_setup = False + self.is_teardown = False + + if SetupTeardownContext.active: + SetupTeardownContext.update_context_map(self) + + # We set self.dag right at the end as `_convert_dag` calls `dag.add_task` for us, and we need all the + # other properties to be set at that point + if dag is not None: + self.dag = dag + + validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES) + + def __eq__(self, other): + if type(self) is type(other): + # Use getattr() instead of __dict__ as __dict__ doesn't return + # correct values for properties. + return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps) + return False + + def __ne__(self, other): + return not self == other + + def __hash__(self): + hash_components = [type(self)] + for component in self._comps: + val = getattr(self, component, None) + try: + hash(val) + hash_components.append(val) + except TypeError: + hash_components.append(repr(val)) + return hash(tuple(hash_components)) + + # including lineage information + def __or__(self, other): + """ + Return [This Operator] | [Operator]. + + The inlets of other will be set to pick up the outlets from this operator. + Other will be set as a downstream task of this operator. 
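# Illustrative reading of the lineage composition implemented below; `extract` and `load`
# are hypothetical BaseOperator instances whose outlets are already populated:
#     extract | load
# behaves like:
#     load.add_inlets([extract.task_id])
#     extract.set_downstream(load)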
+ """ + if isinstance(other, BaseOperator): + if not self.outlets and not self.supports_lineage: + raise ValueError("No outlets defined for this operator") + other.add_inlets([self.task_id]) + self.set_downstream(other) + else: + raise TypeError(f"Right hand side ({other}) is not an Operator") + + return self + + # /Composing Operators --------------------------------------------- + + def __gt__(self, other): + """ + Return [Operator] > [Outlet]. + + If other is an attr annotated object it is set as an outlet of this Operator. + """ + if not isinstance(other, Iterable): + other = [other] + + for obj in other: + if not attrs.has(obj): + raise TypeError(f"Left hand side ({obj}) is not an outlet") + self.add_outlets(other) + + return self + + def __lt__(self, other): + """ + Return [Inlet] > [Operator] or [Operator] < [Inlet]. + + If other is an attr annotated object it is set as an inlet to this operator. + """ + if not isinstance(other, Iterable): + other = [other] + + for obj in other: + if not attrs.has(obj): + raise TypeError(f"{obj} cannot be an inlet") + self.add_inlets(other) + + return self + + def __deepcopy__(self, memo: dict[int, Any]): + # Hack sorting double chained task lists by task_id to avoid hitting + # max_depth on deepcopy operations. + sys.setrecursionlimit(5000) # TODO fix this in a better way + + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + + shallow_copy = tuple(cls.shallow_copy_attrs) + cls._base_operator_shallow_copy_attrs + + for k, v in self.__dict__.items(): + if k not in shallow_copy: + v = copy.deepcopy(v, memo) + else: + v = copy.copy(v) + + # Bypass any setters, and set it on the object directly. This works since we are cloning ourself so + # we know the type is already fine + object.__setattr__(result, k, v) + return result + + def __getstate__(self): + state = dict(self.__dict__) + if self._log: + del state["_log"] + + return state + + def __setstate__(self, state): + self.__dict__ = state + + def add_inlets(self, inlets: Iterable[Any]): + """Set inlets to this operator.""" + self.inlets.extend(inlets) + + def add_outlets(self, outlets: Iterable[Any]): + """Define the outlets of this operator.""" + self.outlets.extend(outlets) + + def get_dag(self) -> DAG | None: + return self._dag + + @property # type: ignore[override] + def dag(self) -> DAG: + """Returns the Operator's DAG if set, otherwise raises an error.""" + if dag := self._dag: + return dag + else: + raise RuntimeError(f"Operator {self} has not been assigned to a DAG yet") + + @dag.setter + def dag(self, dag: DAG | None | AttributeRemoved) -> None: + """Operators can be assigned to one DAG, one time. 
Repeat assignments to that same DAG are ok.""" + # TODO: Task-SDK: Remove the AttributeRemoved and this type ignore once we remove AIP-44 code + self._dag = dag # type: ignore[assignment] + + def _convert__dag(self, dag: DAG | None | AttributeRemoved) -> DAG | None | AttributeRemoved: + # Called automatically by __setattr__ method + from airflow.sdk.definitions.dag import DAG + + if dag is None: + return dag + + # if set to removed, then just set and exit + if type(self._dag) is AttributeRemoved: + return dag + # if setting to removed, then just set and exit + if type(dag) is AttributeRemoved: + return AttributeRemoved("_dag") # type: ignore[assignment] + + if not isinstance(dag, DAG): + raise TypeError(f"Expected DAG; received {dag.__class__.__name__}") + elif self._dag is not None and self._dag is not dag: + raise ValueError(f"The DAG assigned to {self} can not be changed.") + + if self.__from_mapped: + pass # Don't add to DAG -- the mapped task takes the place. + elif dag.task_dict.get(self.task_id) is not self: + # TODO: Task-SDK: Remove this type ignore + dag.add_task(self) # type: ignore[arg-type] + return dag + + @staticmethod + def _convert_retries(retries: Any) -> int | None: + if retries is None: + return 0 + elif type(retries) == int: # noqa: E721 + return retries + try: + parsed_retries = int(retries) + except (TypeError, ValueError): + raise TypeError(f"'retries' type must be int, not {type(retries).__name__}") + return parsed_retries + + @staticmethod + def _convert_timedelta(value: float | timedelta | None) -> timedelta | None: + if value is None or isinstance(value, timedelta): + return value + return timedelta(seconds=value) + + _convert_retry_delay = _convert_timedelta + _convert_max_retry_delay = _convert_timedelta + + @staticmethod + def _convert_resources(resources: dict[str, Any] | None) -> Resources | None: + if resources is None: + return None + + from airflow.utils.operator_resources import Resources + + if isinstance(resources, Resources): + return resources + + return Resources(**resources) + + def _convert_is_setup(self, value: bool) -> bool: + """ + Setter for is_setup property. + + :meta private: + """ + if self.is_teardown and value: + raise ValueError(f"Cannot mark task '{self.task_id}' as setup; task is already a teardown.") + return value + + def _convert_is_teardown(self, value: bool) -> bool: + if self.is_setup and value: + raise ValueError(f"Cannot mark task '{self.task_id}' as teardown; task is already a setup.") + return value + + @property + def task_display_name(self) -> str: + return self._task_display_name or self.task_id + + def has_dag(self): + """Return True if the Operator has been assigned to a DAG.""" + return self._dag is not None + + def _set_xcomargs_dependencies(self) -> None: + from airflow.models.xcom_arg import XComArg + + for f in self.template_fields: + arg = getattr(self, f, NOTSET) + if arg is not NOTSET: + XComArg.apply_upstream_relationship(self, arg) + + def _set_xcomargs_dependency(self, field: str, newvalue: Any) -> None: + """ + Resolve upstream dependencies of a task. + + In this way passing an ``XComArg`` as value for a template field + will result in creating upstream relation between two tasks. 
+
+        **Example**: ::
+
+            with DAG(...):
+                generate_content = GenerateContentOperator(task_id="generate_content")
+                send_email = EmailOperator(..., html_content=generate_content.output)
+
+            # This is equivalent to
+            with DAG(...):
+                generate_content = GenerateContentOperator(task_id="generate_content")
+                send_email = EmailOperator(..., html_content="{{ task_instance.xcom_pull('generate_content') }}")
+                generate_content >> send_email
+
+        """
+        from airflow.models.xcom_arg import XComArg
+
+        if field not in self.template_fields:
+            return
+        XComArg.apply_upstream_relationship(self, newvalue)
+
+    def on_kill(self) -> None:
+        """
+        Override this method to clean up subprocesses when a task instance gets killed.
+
+        Any use of the threading, subprocess or multiprocessing module within an
+        operator needs to be cleaned up, or it will leave ghost processes behind.
+        """
+
+    def __repr__(self):
+        return f"<Task({self.task_type}): {self.task_id}>"
+
+    @property
+    def operator_class(self) -> type[BaseOperator]:  # type: ignore[override]
+        return self.__class__
+
+    @property
+    def task_type(self) -> str:
+        """@property: type of the task."""
+        return self.__class__.__name__
+
+    @property
+    def operator_name(self) -> str:
+        """@property: use a more friendly display name for the operator, if set."""
+        try:
+            return self.custom_operator_name  # type: ignore
+        except AttributeError:
+            return self.task_type
+
+    @property
+    def roots(self) -> list[BaseOperator]:
+        """Required by DAGNode."""
+        return [self]
+
+    @property
+    def leaves(self) -> list[BaseOperator]:
+        """Required by DAGNode."""
+        return [self]
+
+    @property
+    def output(self) -> XComArg:
+        """Returns reference to XCom pushed by current operator."""
+        from airflow.models.xcom_arg import XComArg
+
+        # TODO: Task-SDK: remove this type ignore once XComArg is ported over
+        return XComArg(operator=self)  # type: ignore[call-overload]
+
+    @classmethod
+    def get_serialized_fields(cls):
+        """Stringified DAGs and operators contain exactly these fields."""
+        if not cls.__serialized_fields:
+            from airflow.sdk.definitions.contextmanager import DagContext
+
+            # make sure the following "fake" task is not added to current active
+            # dag in context, otherwise, it will result in
+            # `RuntimeError: dictionary changed size during iteration`
+            # Exception in SerializedDAG.serialize_dag() call.
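# For context: the frozenset computed below is what the serializer keeps for each operator.
# A hypothetical spot-check of the result:
#     fields = BaseOperator.get_serialized_fields()
#     assert "retries" in fields and "_dag" not in fields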
+ DagContext.push(None) + cls.__serialized_fields = frozenset( + vars(BaseOperator(task_id="test")).keys() + - { + "upstream_task_ids", + "default_args", + "dag", + "_dag", + "label", + "_BaseOperator__instantiated", + "_BaseOperator__init_kwargs", + "_BaseOperator__from_mapped", + "on_failure_fail_dagrun", + "task_group", + "_task_type", + } + | { # Class level defaults, or `@property` need to be added to this list + "start_date", + "end_date", + "task_type", + "ui_color", + "ui_fgcolor", + "template_ext", + "template_fields", + "template_fields_renderers", + "params", + "is_setup", + "is_teardown", + "on_failure_fail_dagrun", + "map_index_template", + "start_trigger_args", + "_needs_expansion", + "start_from_trigger", + "max_retry_delay", + } + ) + DagContext.pop() + + return cls.__serialized_fields + + def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: + """Serialize; required by DAGNode.""" + from airflow.serialization.enums import DagAttributeTypes + + return DagAttributeTypes.OP, self.task_id + + @property + def inherits_from_empty_operator(self): + """Used to determine if an Operator is inherited from EmptyOperator.""" + # This looks like `isinstance(self, EmptyOperator) would work, but this also + # needs to cope when `self` is a Serialized instance of a EmptyOperator or one + # of its subclasses (which don't inherit from anything but BaseOperator). + return getattr(self, "_is_empty", False) diff --git a/task_sdk/src/airflow/sdk/definitions/contextmanager.py b/task_sdk/src/airflow/sdk/definitions/contextmanager.py new file mode 100644 index 000000000000..ac50dcadbfc7 --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/contextmanager.py @@ -0,0 +1,125 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import sys +from collections import deque +from types import ModuleType +from typing import Any, Generic, TypeVar + +from airflow.sdk.definitions.dag import DAG +from airflow.sdk.definitions.taskgroup import TaskGroup + +T = TypeVar("T") + +__all__ = [ + "DagContext", + "TaskGroupContext", +] + + +# In order to add a `@classproperty`-like thing we need to define a property on a metaclass. 
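# A minimal usage sketch of the context stacks defined below, assuming the DAG class from
# this patch's `dag.py` pushes onto / pops from DagContext in __enter__/__exit__
# (the dag_id is hypothetical):
from airflow.sdk.definitions.contextmanager import DagContext
from airflow.sdk.definitions.dag import DAG

with DAG(dag_id="context_demo") as d:
    # While inside the `with` block this DAG is the current one on the stack.
    assert DagContext.get_current_dag() is d
# On exit the DAG is popped again (assuming nothing else was pushed beforehand).
assert DagContext.get_current_dag() is None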
+class ContextStackMeta(type): + _context: deque + + # TODO: Task-SDK: + # share_parent_context can go away once the DAG and TaskContext manager in airflow.models are removed and + # everything uses sdk fully for definition/parsing + def __new__(cls, name, bases, namespace, share_parent_context: bool = False, **kwargs: Any): + if not share_parent_context: + namespace["_context"] = deque() + + new_cls = super().__new__(cls, name, bases, namespace, **kwargs) + + return new_cls + + @property + def active(self) -> bool: + """The active property says if any object is currently in scope.""" + return bool(self._context) + + +class ContextStack(Generic[T], metaclass=ContextStackMeta): + _context: deque[T] + + @classmethod + def push(cls, obj: T): + cls._context.appendleft(obj) + + @classmethod + def pop(cls) -> T | None: + return cls._context.popleft() + + @classmethod + def get_current(cls) -> T | None: + try: + return cls._context[0] + except IndexError: + return None + + +class DagContext(ContextStack[DAG]): + """ + DAG context is used to keep the current DAG when DAG is used as ContextManager. + + You can use DAG as context: + + .. code-block:: python + + with DAG( + dag_id="example_dag", + default_args=default_args, + schedule="0 0 * * *", + dagrun_timeout=timedelta(minutes=60), + ) as dag: + ... + + If you do this the context stores the DAG and whenever new task is created, it will use + such stored DAG as the parent DAG. + + """ + + # TODO: Task-SDK, should module type be optional? Will that break more? + autoregistered_dags: set[tuple[DAG, ModuleType | None]] = set() + current_autoregister_module_name: str | None = None + + @classmethod + def pop(cls) -> DAG | None: + dag = super().pop() + # In a few cases around serialization we explicitly push None in to the stack + if cls.current_autoregister_module_name is not None and dag and getattr(dag, "auto_register", True): + mod = sys.modules[cls.current_autoregister_module_name] + cls.autoregistered_dags.add((dag, mod)) + return dag + + @classmethod + def get_current_dag(cls) -> DAG | None: + return cls.get_current() + + +class TaskGroupContext(ContextStack[TaskGroup]): + """TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager.""" + + @classmethod + def get_current(cls, dag: DAG | None = None) -> TaskGroup | None: + if current := super().get_current(): + return current + if dag := dag or DagContext.get_current(): + # If there's currently a DAG but no TaskGroup, return the root TaskGroup of the dag. + return dag.task_group + return None diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py new file mode 100644 index 000000000000..479c1ea09b80 --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/dag.py @@ -0,0 +1,1119 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import copy +import functools +import itertools +import logging +import os +import sys +import weakref +from collections import abc +from collections.abc import Collection, Iterable, MutableSet +from datetime import datetime, timedelta +from inspect import signature +from re import Pattern +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Union, + cast, +) +from urllib.parse import urlsplit + +import attrs +import jinja2 +import re2 +from dateutil.relativedelta import relativedelta + +from airflow import settings +from airflow.assets import Asset, AssetAlias, BaseAsset +from airflow.exceptions import ( + DuplicateTaskIdFound, + FailStopDagInvalidTriggerRule, + ParamValidationError, + TaskNotFound, +) +from airflow.models.param import DagParam, ParamsDict +from airflow.sdk.definitions.abstractoperator import AbstractOperator +from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.types import NOTSET +from airflow.timetables.base import Timetable +from airflow.timetables.simple import ( + AssetTriggeredTimetable, + ContinuousTimetable, + NullTimetable, + OnceTimetable, +) +from airflow.utils.context import Context +from airflow.utils.dag_cycle_tester import check_cycle +from airflow.utils.decorators import fixup_decorator_warning_stack +from airflow.utils.trigger_rule import TriggerRule +from airflow.utils.types import EdgeInfoType + +if TYPE_CHECKING: + # TODO: Task-SDK: Remove pendulum core dep + from pendulum.tz.timezone import FixedTimezone, Timezone + + from airflow.decorators import TaskDecoratorCollection + from airflow.models.operator import Operator + from airflow.sdk.definitions.taskgroup import TaskGroup + from airflow.typing_compat import Self + + +log = logging.getLogger(__name__) + +TAG_MAX_LEN = 100 + +__all__ = [ + "DAG", + "dag", +] + + +DagStateChangeCallback = Callable[[Context], None] +ScheduleInterval = Union[None, str, timedelta, relativedelta] + +ScheduleArg = Union[ + ScheduleInterval, + Timetable, + BaseAsset, + Collection[Union["Asset", "AssetAlias"]], +] + + +_DAG_HASH_ATTRS = frozenset( + { + "dag_id", + "task_ids", + "start_date", + "end_date", + "fileloc", + "template_searchpath", + "last_loaded", + "schedule", + # TODO: Task-SDK: we should be hashing on timetable now, not scheulde! + # "timetable", + } +) + + +def _create_timetable(interval: ScheduleInterval, timezone: Timezone | FixedTimezone) -> Timetable: + """Create a Timetable instance from a plain ``schedule`` value.""" + from airflow.configuration import conf as airflow_conf + from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable + from airflow.timetables.trigger import CronTriggerTimetable + + if interval is None: + return NullTimetable() + if interval == "@once": + return OnceTimetable() + if interval == "@continuous": + return ContinuousTimetable() + if isinstance(interval, (timedelta, relativedelta)): + return DeltaDataIntervalTimetable(interval) + if isinstance(interval, str): + if airflow_conf.getboolean("scheduler", "create_cron_data_intervals"): + return CronDataIntervalTimetable(interval, timezone) + else: + return CronTriggerTimetable(interval, timezone=timezone) + raise ValueError(f"{interval!r} is not a valid schedule.") + + +def _convert_params(val: abc.MutableMapping | None, self_: DAG) -> ParamsDict: + """ + Convert the plain dict into a ParamsDict. 
+ + This will also merge in params from default_args + """ + val = val or {} + + # merging potentially conflicting default_args['params'] into params + if "params" in self_.default_args: + val.update(self_.default_args["params"]) + del self_.default_args["params"] + + params = ParamsDict(val) + object.__setattr__(self_, "params", params) + + return params + + +def _convert_str_to_tuple(val: str | Iterable[str] | None) -> Iterable[str] | None: + if isinstance(val, str): + return (val,) + return val + + +def _convert_tags(tags: Collection[str] | None) -> MutableSet[str]: + return set(tags or []) + + +def _convert_access_control(value, self_: DAG): + if hasattr(self_, "_upgrade_outdated_dag_access_control"): + return self_._upgrade_outdated_dag_access_control(value) + else: + return value + + +def _convert_doc_md(doc_md: str | None) -> str | None: + if doc_md is None: + return doc_md + + if doc_md.endswith(".md"): + try: + with open(doc_md) as fh: + return fh.read() + except FileNotFoundError: + return doc_md + + return doc_md + + +def _all_after_dag_id_to_kw_only(cls, fields: list[attrs.Attribute]): + i = iter(fields) + f = next(i) + if f.name != "dag_id": + raise RuntimeError("dag_id was not the first field") + yield f + + for f in i: + yield f.evolve(kw_only=True) + + +if TYPE_CHECKING: + # Given this attrs field: + # + # default_args: dict[str, Any] = attrs.field(factory=dict, converter=copy.copy) + # + # mypy ignores the type of the attrs and works out the type as the converter function. However it doesn't + # cope with generics properly and errors with 'incompatible type "dict[str, object]"; expected "_T"' + # + # https://github.com/python/mypy/issues/8625 + def dict_copy(_: dict[str, Any]) -> dict[str, Any]: ... +else: + dict_copy = copy.copy + + +def _default_start_date(instance: DAG): + # Find start date inside default_args for compat with Airflow 2. + from airflow.utils import timezone + + if date := instance.default_args.get("start_date"): + if not isinstance(date, datetime): + date = timezone.parse(date) + instance.default_args["start_date"] = date + return date + return None + + +def _default_dag_display_name(instance: DAG) -> str: + return instance.dag_id + + +def _default_fileloc() -> str: + # Skip over this frame, and the 'attrs generated init' + back = sys._getframe().f_back + if not back or not (back := back.f_back): + # We expect two frames back, if not we don't know where we are + return "" + return back.f_code.co_filename if back else "" + + +def _default_task_group(instance: DAG) -> TaskGroup: + from airflow.sdk.definitions.taskgroup import TaskGroup + + return TaskGroup.create_root(dag=instance) + + +# TODO: Task-SDK: look at re-enabling slots after we remove pickling +@attrs.define(repr=False, field_transformer=_all_after_dag_id_to_kw_only, slots=False) +class DAG: + """ + A dag (directed acyclic graph) is a collection of tasks with directional dependencies. + + A dag also has a schedule, a start date and an end date (optional). For each schedule, + (say daily or hourly), the DAG needs to run each individual tasks as their dependencies + are met. Certain tasks have the property of depending on their own past, meaning that + they can't run until their previous schedule (and upstream tasks) are completed. + + DAGs essentially act as namespaces for tasks. A task_id can only be + added once to a DAG. + + Note that if you plan to use time zones all the dates provided should be pendulum + dates. See :ref:`timezone_aware_dags`. + + .. 
versionadded:: 2.4 + The *schedule* argument to specify either time-based scheduling logic + (timetable), or dataset-driven triggers. + + .. versionchanged:: 3.0 + The default value of *schedule* has been changed to *None* (no schedule). + The previous default was ``timedelta(days=1)``. + + :param dag_id: The id of the DAG; must consist exclusively of alphanumeric + characters, dashes, dots and underscores (all ASCII) + :param description: The description for the DAG to e.g. be shown on the webserver + :param schedule: If provided, this defines the rules according to which DAG + runs are scheduled. Possible values include a cron expression string, + timedelta object, Timetable, or list of Asset objects. + See also :doc:`/howto/timetable`. + :param start_date: The timestamp from which the scheduler will + attempt to backfill. If this is not provided, backfilling must be done + manually with an explicit time range. + :param end_date: A date beyond which your DAG won't run, leave to None + for open-ended scheduling. + :param template_searchpath: This list of folders (non-relative) + defines where jinja will look for your templates. Order matters. + Note that jinja/airflow includes the path of your DAG file by + default + :param template_undefined: Template undefined type. + :param user_defined_macros: a dictionary of macros that will be exposed + in your jinja templates. For example, passing ``dict(foo='bar')`` + to this argument allows you to ``{{ foo }}`` in all jinja + templates related to this DAG. Note that you can pass any + type of object here. + :param user_defined_filters: a dictionary of filters that will be exposed + in your jinja templates. For example, passing + ``dict(hello=lambda name: 'Hello %s' % name)`` to this argument allows + you to ``{{ 'world' | hello }}`` in all jinja templates related to + this DAG. + :param default_args: A dictionary of default parameters to be used + as constructor keyword parameters when initialising operators. + Note that operators have the same hook, and precede those defined + here, meaning that if your dict contains `'depends_on_past': True` + here and `'depends_on_past': False` in the operator's call + `default_args`, the actual value will be `False`. + :param params: a dictionary of DAG level parameters that are made + accessible in templates, namespaced under `params`. These + params can be overridden at the task level. + :param max_active_tasks: the number of task instances allowed to run + concurrently + :param max_active_runs: maximum number of active DAG runs, beyond this + number of DAG runs in a running state, the scheduler won't create + new active DAG runs + :param max_consecutive_failed_dag_runs: (experimental) maximum number of consecutive failed DAG runs, + beyond this the scheduler will disable the DAG + :param dagrun_timeout: Specify the duration a DagRun should be allowed to run before it times out or + fails. Task instances that are running when a DagRun is timed out will be marked as skipped. + :param sla_miss_callback: DEPRECATED - The SLA feature is removed in Airflow 3.0, to be replaced with a new implementation in 3.1 + :param catchup: Perform scheduler catchup (or only run latest)? Defaults to True + :param on_failure_callback: A function or list of functions to be called when a DagRun of this dag fails. + A context dictionary is passed as a single parameter to this function. + :param on_success_callback: Much like the ``on_failure_callback`` except + that it is executed when the dag succeeds. 
+
+    :param access_control: Specify optional DAG-level actions, e.g.,
+        "{'role1': {'can_read'}, 'role2': {'can_read', 'can_edit', 'can_delete'}}"
+        or it can specify the resource name if there is a DAGs Run resource, e.g.,
+        "{'role1': {'DAG Runs': {'can_create'}}, 'role2': {'DAGs': {'can_read', 'can_edit', 'can_delete'}}"
+    :param is_paused_upon_creation: Specifies if the dag is paused when created for the first time.
+        If the dag exists already, this flag will be ignored. If this optional parameter
+        is not specified, the global config setting will be used.
+    :param jinja_environment_kwargs: additional configuration options to be passed to Jinja
+        ``Environment`` for template rendering
+
+        **Example**: to avoid Jinja from removing a trailing newline from template strings ::
+
+            DAG(
+                dag_id="my-dag",
+                jinja_environment_kwargs={
+                    "keep_trailing_newline": True,
+                    # some other jinja2 Environment options here
+                },
+            )
+
+        **See**: `Jinja Environment documentation
+        <https://jinja.palletsprojects.com/en/2.11.x/api/#jinja2.Environment>`_
+
+    :param render_template_as_native_obj: If True, uses a Jinja ``NativeEnvironment``
+        to render templates as native Python types. If False, a Jinja
+        ``Environment`` is used to render templates as string values.
+    :param tags: List of tags to help filtering DAGs in the UI.
+    :param owner_links: Dict of owners and their links, that will be clickable on the DAGs view UI.
+        Can be used as an HTTP link (for example the link to your Slack channel), or a mailto link.
+        e.g: {"dag_owner": "https://airflow.apache.org/"}
+    :param auto_register: Automatically register this DAG when it is used in a ``with`` block
+    :param fail_stop: Fails currently running tasks when task in DAG fails.
+        **Warning**: A fail stop dag can only have tasks with the default trigger rule ("all_success").
+        An exception will be thrown if any task in a fail stop dag has a non default trigger rule.
+    :param dag_display_name: The display name of the DAG which appears on the UI.
+    """
+
+    __serialized_fields: ClassVar[frozenset[str] | None] = None
+
+    # Note: mypy gets very confused about the use of `@${attr}.default` for attrs without init=False -- and it
+    # doesn't correctly track/notice that they have default values (it gives errors about `Missing positional
+    # argument "description" in call to "DAG"`` etc), so for init=True args we use the `default=Factory()`
+    # style
+
+    # NOTE: When updating arguments here, please also keep arguments in @dag()
+    # below in sync. (Search for 'def dag(' in this file.)
+    dag_id: str = attrs.field(kw_only=False, validator=attrs.validators.instance_of(str))
+    description: str | None = attrs.field(
+        default=None,
+        validator=attrs.validators.optional(attrs.validators.instance_of(str)),
+    )
+    default_args: dict[str, Any] = attrs.field(
+        factory=dict, validator=attrs.validators.instance_of(dict), converter=dict_copy
+    )
+    start_date: datetime | None = attrs.field(
+        default=attrs.Factory(_default_start_date, takes_self=True),
+    )
+
+    end_date: datetime | None = None
+    timezone: FixedTimezone | Timezone = attrs.field(init=False)
+    schedule: ScheduleArg = attrs.field(default=None, on_setattr=attrs.setters.frozen)
+    timetable: Timetable = attrs.field(init=False)
+    template_searchpath: str | Iterable[str] | None = attrs.field(
+        default=None, converter=_convert_str_to_tuple
+    )
+    # TODO: Task-SDK: Work out how to not import jinja2 until we need it!
It's expensive + template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined + user_defined_macros: dict | None = None + user_defined_filters: dict | None = None + max_active_tasks: int = attrs.field(default=16, validator=attrs.validators.instance_of(int)) + max_active_runs: int = attrs.field(default=16, validator=attrs.validators.instance_of(int)) + max_consecutive_failed_dag_runs: int = attrs.field( + default=-1, validator=attrs.validators.instance_of(int) + ) + dagrun_timeout: timedelta | None = attrs.field( + default=None, + validator=attrs.validators.optional(attrs.validators.instance_of(timedelta)), + ) + # sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None + catchup: bool = attrs.field(default=True, converter=bool) + on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None + on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None + doc_md: str | None = attrs.field(default=None, converter=_convert_doc_md) + params: ParamsDict = attrs.field( + # mypy doesn't really like passing the Converter object + default=None, + converter=attrs.Converter(_convert_params, takes_self=True), # type: ignore[misc, call-overload] + ) + access_control: dict[str, dict[str, Collection[str]]] | None = attrs.field( + default=None, + converter=attrs.Converter(_convert_access_control, takes_self=True), # type: ignore[misc, call-overload] + ) + is_paused_upon_creation: bool | None = None + jinja_environment_kwargs: dict | None = None + render_template_as_native_obj: bool = attrs.field(default=False, converter=bool) + tags: MutableSet[str] = attrs.field(factory=set, converter=_convert_tags) + owner_links: dict[str, str] = attrs.field(factory=dict) + auto_register: bool = attrs.field(default=True, converter=bool) + fail_stop: bool = attrs.field(default=False, converter=bool) + dag_display_name: str = attrs.field( + default=attrs.Factory(_default_dag_display_name, takes_self=True), + validator=attrs.validators.instance_of(str), + ) + + task_dict: dict[str, Operator] = attrs.field(factory=dict, init=False) + + task_group: TaskGroup = attrs.field( + on_setattr=attrs.setters.frozen, default=attrs.Factory(_default_task_group, takes_self=True) + ) + + fileloc: str = attrs.field(init=False, factory=_default_fileloc) + partial: bool = attrs.field(init=False, default=False) + + edge_info: dict[str, dict[str, EdgeInfoType]] = attrs.field(init=False, factory=dict) + + has_on_success_callback: bool = attrs.field(init=False) + has_on_failure_callback: bool = attrs.field(init=False) + + def __attrs_post_init__(self): + from airflow.utils import timezone + + # Apply the timezone we settled on to end_date if it wasn't supplied + if isinstance(_end_date := self.default_args.get("end_date"), str): + self.default_args["end_date"] = timezone.parse(_end_date, timezone=self.timezone) + + self.start_date = timezone.convert_to_utc(self.start_date) + self.end_date = timezone.convert_to_utc(self.end_date) + if start_date := self.default_args.get("start_date", None): + self.default_args["start_date"] = timezone.convert_to_utc(start_date) + if end_date := self.default_args.get("end_date", None): + self.default_args["end_date"] = timezone.convert_to_utc(end_date) + + @params.validator + def _validate_params(self, _, params: ParamsDict): + """ + Validate Param values when the DAG has schedule defined. + + Raise exception if there are any Params which can not be resolved by their schema definition. 
+        """
+        if not self.timetable or not self.timetable.can_be_scheduled:
+            return
+
+        try:
+            params.validate()
+        except ParamValidationError as pverr:
+            raise ValueError(
+                f"DAG {self.dag_id!r} is not allowed to define a Schedule, "
+                "as there are required params without default values, or the default values are not valid."
+            ) from pverr
+
+    @catchup.validator
+    def _validate_catchup(self, _, catchup: bool):
+        requires_automatic_backfilling = self.timetable.can_be_scheduled and catchup
+        if requires_automatic_backfilling and not ("start_date" in self.default_args or self.start_date):
+            raise ValueError("start_date is required when catchup=True")
+
+    @tags.validator
+    def _validate_tags(self, _, tags: Collection[str]):
+        if tags and any(len(tag) > TAG_MAX_LEN for tag in tags):
+            raise ValueError(f"tag cannot be longer than {TAG_MAX_LEN} characters")
+
+    @max_active_runs.validator
+    def _validate_max_active_runs(self, _, max_active_runs):
+        if self.timetable.active_runs_limit is not None:
+            if self.timetable.active_runs_limit < self.max_active_runs:
+                raise ValueError(
+                    f"Invalid max_active_runs: {type(self.timetable).__name__} "
+                    f"requires max_active_runs <= {self.timetable.active_runs_limit}"
+                )
+
+    @timetable.default
+    def _default_timetable(instance: DAG):
+        from airflow.assets import AssetAll
+
+        schedule = instance.schedule
+        # TODO: Once
+        # delattr(self, "schedule")
+        if isinstance(schedule, Timetable):
+            return schedule
+        elif isinstance(schedule, BaseAsset):
+            return AssetTriggeredTimetable(schedule)
+        elif isinstance(schedule, Collection) and not isinstance(schedule, str):
+            if not all(isinstance(x, (Asset, AssetAlias)) for x in schedule):
+                raise ValueError("All elements in 'schedule' should be assets or asset aliases")
+            return AssetTriggeredTimetable(AssetAll(*schedule))
+        else:
+            return _create_timetable(schedule, instance.timezone)
+
+    @timezone.default
+    def _extract_tz(instance):
+        import pendulum
+
+        from airflow.utils import timezone
+
+        start_date = instance.start_date or instance.default_args.get("start_date")
+
+        if start_date:
+            if not isinstance(start_date, datetime):
+                start_date = timezone.parse(start_date)
+            tzinfo = start_date.tzinfo or settings.TIMEZONE
+            tz = pendulum.instance(start_date, tz=tzinfo).timezone
+        else:
+            tz = settings.TIMEZONE
+
+        return tz
+
+    @has_on_success_callback.default
+    def _has_on_success_callback(self) -> bool:
+        return self.on_success_callback is not None
+
+    @has_on_failure_callback.default
+    def _has_on_failure_callback(self) -> bool:
+        return self.on_failure_callback is not None
+
+    def __repr__(self):
+        return f"<DAG: {self.dag_id}>"
+
+    def __eq__(self, other: Self | Any):
+        # TODO: This subclassing behaviour seems wrong, but it's what Airflow has done for ~ever.
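# A sketch of the identity semantics implemented here: equality and hashing are driven by
# _DAG_HASH_ATTRS rather than object identity, so two DAGs defined in the same file with the
# same dag_id and schedule compare equal (ids below are hypothetical):
#     d1 = DAG("same_id", schedule=None)
#     d2 = DAG("same_id", schedule=None)
#     assert d1 == d2 and hash(d1) == hash(d2)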
+ if type(self) is not type(other): + return False + return all(getattr(self, c, None) == getattr(other, c, None) for c in _DAG_HASH_ATTRS) + + def __ne__(self, other: Any): + return not self == other + + def __lt__(self, other): + return self.dag_id < other.dag_id + + def __hash__(self): + hash_components: list[Any] = [type(self)] + for c in _DAG_HASH_ATTRS: + # task_ids returns a list and lists can't be hashed + if c == "task_ids": + val = tuple(self.task_dict) + else: + val = getattr(self, c, None) + try: + hash(val) + hash_components.append(val) + except TypeError: + hash_components.append(repr(val)) + return hash(tuple(hash_components)) + + def __enter__(self) -> Self: + from airflow.sdk.definitions.contextmanager import DagContext + + DagContext.push(self) + return self + + def __exit__(self, _type, _value, _tb): + from airflow.sdk.definitions.contextmanager import DagContext + + _ = DagContext.pop() + + def validate(self): + """ + Validate the DAG has a coherent setup. + + This is called by the DAG bag before bagging the DAG. + """ + self.timetable.validate() + self.validate_setup_teardown() + + # We validate owner links on set, but since it's a dict it could be mutated without calling the + # setter. Validate again here + self._validate_owner_links(None, self.owner_links) + + def validate_setup_teardown(self): + """ + Validate that setup and teardown tasks are configured properly. + + :meta private: + """ + for task in self.tasks: + if task.is_setup: + for down_task in task.downstream_list: + if not down_task.is_teardown and down_task.trigger_rule != TriggerRule.ALL_SUCCESS: + # todo: we can relax this to allow out-of-scope tasks to have other trigger rules + # this is required to ensure consistent behavior of dag + # when clearing an indirect setup + raise ValueError("Setup tasks must be followed with trigger rule ALL_SUCCESS.") + + def param(self, name: str, default: Any = NOTSET) -> DagParam: + """ + Return a DagParam object for current dag. + + :param name: dag parameter name. + :param default: fallback value for dag parameter. + :return: DagParam instance for specified name and current dag. + """ + return DagParam(current_dag=self, name=name, default=default) + + @property + def tasks(self) -> list[Operator]: + return list(self.task_dict.values()) + + @tasks.setter + def tasks(self, val): + raise AttributeError("DAG.tasks can not be modified. Use dag.add_task() instead.") + + @property + def task_ids(self) -> list[str]: + return list(self.task_dict) + + @property + def teardowns(self) -> list[Operator]: + return [task for task in self.tasks if getattr(task, "is_teardown", None)] + + @property + def tasks_upstream_of_teardowns(self) -> list[Operator]: + upstream_tasks = [t.upstream_list for t in self.teardowns] + return [val for sublist in upstream_tasks for val in sublist if not getattr(val, "is_teardown", None)] + + @property + def folder(self) -> str: + """Folder location of where the DAG object is instantiated.""" + return os.path.dirname(self.fileloc) + + @property + def owner(self) -> str: + """ + Return list of all owners found in DAG tasks. 
+ + :return: Comma separated list of owners in DAG tasks + """ + return ", ".join({t.owner for t in self.tasks}) + + @property + def allow_future_exec_dates(self) -> bool: + return settings.ALLOW_FUTURE_EXEC_DATES and not self.timetable.can_be_scheduled + + def resolve_template_files(self): + for t in self.tasks: + t.resolve_template_files() + + def get_template_env(self, *, force_sandboxed: bool = False) -> jinja2.Environment: + """Build a Jinja2 environment.""" + import airflow.templates + + # Collect directories to search for template files + searchpath = [self.folder] + if self.template_searchpath: + searchpath += self.template_searchpath + + # Default values (for backward compatibility) + jinja_env_options = { + "loader": jinja2.FileSystemLoader(searchpath), + "undefined": self.template_undefined, + "extensions": ["jinja2.ext.do"], + "cache_size": 0, + } + if self.jinja_environment_kwargs: + jinja_env_options.update(self.jinja_environment_kwargs) + env: jinja2.Environment + if self.render_template_as_native_obj and not force_sandboxed: + env = airflow.templates.NativeEnvironment(**jinja_env_options) + else: + env = airflow.templates.SandboxedEnvironment(**jinja_env_options) + + # Add any user defined items. Safe to edit globals as long as no templates are rendered yet. + # http://jinja.pocoo.org/docs/2.10/api/#jinja2.Environment.globals + if self.user_defined_macros: + env.globals.update(self.user_defined_macros) + if self.user_defined_filters: + env.filters.update(self.user_defined_filters) + + return env + + def set_dependency(self, upstream_task_id, downstream_task_id): + """Set dependency between two tasks that already have been added to the DAG using add_task().""" + self.get_task(upstream_task_id).set_downstream(self.get_task(downstream_task_id)) + + @property + def roots(self) -> list[Operator]: + """Return nodes with no parents. These are first to execute and are called roots or root nodes.""" + return [task for task in self.tasks if not task.upstream_list] + + @property + def leaves(self) -> list[Operator]: + """Return nodes with no children. These are last to execute and are called leaves or leaf nodes.""" + return [task for task in self.tasks if not task.downstream_list] + + def topological_sort(self): + """ + Sorts tasks in topographical order, such that a task comes after any of its upstream dependencies. 
+ + Deprecated in place of ``task_group.topological_sort`` + """ + from airflow.sdk.definitions.taskgroup import TaskGroup + + # TODO: Remove in RemovedInAirflow3Warning + def nested_topo(group): + for node in group.topological_sort(): + if isinstance(node, TaskGroup): + yield from nested_topo(node) + else: + yield node + + return tuple(nested_topo(self.task_group)) + + def __deepcopy__(self, memo: dict[int, Any]): + # Switcharoo to go around deepcopying objects coming through the + # backdoor + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k not in ("user_defined_macros", "user_defined_filters", "_log"): + object.__setattr__(result, k, copy.deepcopy(v, memo)) + + result.user_defined_macros = self.user_defined_macros + result.user_defined_filters = self.user_defined_filters + if hasattr(self, "_log"): + result._log = self._log # type: ignore[attr-defined] + return result + + def partial_subset( + self, + task_ids_or_regex: str | Pattern | Iterable[str], + include_downstream=False, + include_upstream=True, + include_direct_upstream=False, + ): + """ + Return a subset of the current dag based on regex matching one or more tasks. + + Returns a subset of the current dag as a deep copy of the current dag + based on a regex that should match one or many tasks, and includes + upstream and downstream neighbours based on the flag passed. + + :param task_ids_or_regex: Either a list of task_ids, or a regex to + match against task ids (as a string, or compiled regex pattern). + :param include_downstream: Include all downstream tasks of matched + tasks, in addition to matched tasks. + :param include_upstream: Include all upstream tasks of matched tasks, + in addition to matched tasks. + :param include_direct_upstream: Include all tasks directly upstream of matched + and downstream (if include_downstream = True) tasks + """ + from airflow.models.mappedoperator import MappedOperator + + # deep-copying self.task_dict and self.task_group takes a long time, and we don't want all + # the tasks anyway, so we copy the tasks manually later + memo = {id(self.task_dict): None, id(self.task_group): None} + dag = copy.deepcopy(self, memo) # type: ignore + + if isinstance(task_ids_or_regex, (str, Pattern)): + matched_tasks = [t for t in self.tasks if re2.findall(task_ids_or_regex, t.task_id)] + else: + matched_tasks = [t for t in self.tasks if t.task_id in task_ids_or_regex] + + also_include_ids: set[str] = set() + for t in matched_tasks: + if include_downstream: + for rel in t.get_flat_relatives(upstream=False): + also_include_ids.add(rel.task_id) + if rel not in matched_tasks: # if it's in there, we're already processing it + # need to include setups and teardowns for tasks that are in multiple + # non-collinear setup/teardown paths + if not rel.is_setup and not rel.is_teardown: + also_include_ids.update( + x.task_id for x in rel.get_upstreams_only_setups_and_teardowns() + ) + if include_upstream: + also_include_ids.update(x.task_id for x in t.get_upstreams_follow_setups()) + else: + if not t.is_setup and not t.is_teardown: + also_include_ids.update(x.task_id for x in t.get_upstreams_only_setups_and_teardowns()) + if t.is_setup and not include_downstream: + also_include_ids.update(x.task_id for x in t.downstream_list if x.is_teardown) + + also_include: list[Operator] = [self.task_dict[x] for x in also_include_ids] + direct_upstreams: list[Operator] = [] + if include_direct_upstream: + for t in itertools.chain(matched_tasks, also_include): + 
upstream = (u for u in t.upstream_list if isinstance(u, (BaseOperator, MappedOperator))) + direct_upstreams.extend(upstream) + + # Make sure to not recursively deepcopy the dag or task_group while copying the task. + # task_group is reset later + def _deepcopy_task(t) -> Operator: + memo.setdefault(id(t.task_group), None) + return copy.deepcopy(t, memo) + + # Compiling the unique list of tasks that made the cut + dag.task_dict = { + t.task_id: _deepcopy_task(t) + for t in itertools.chain(matched_tasks, also_include, direct_upstreams) + } + + def filter_task_group(group, parent_group): + """Exclude tasks not included in the subdag from the given TaskGroup.""" + # We want to deepcopy _most but not all_ attributes of the task group, so we create a shallow copy + # and then manually deep copy the instances. (memo argument to deepcopy only works for instances + # of classes, not "native" properties of an instance) + copied = copy.copy(group) + + memo[id(group.children)] = {} + if parent_group: + memo[id(group.parent_group)] = parent_group + for attr in type(group).__slots__: + value = getattr(group, attr) + value = copy.deepcopy(value, memo) + object.__setattr__(copied, attr, value) + + proxy = weakref.proxy(copied) + + for child in group.children.values(): + if isinstance(child, AbstractOperator): + if child.task_id in dag.task_dict: + task = copied.children[child.task_id] = dag.task_dict[child.task_id] + task.task_group = proxy + else: + copied.used_group_ids.discard(child.task_id) + else: + filtered_child = filter_task_group(child, proxy) + + # Only include this child TaskGroup if it is non-empty. + if filtered_child.children: + copied.children[child.group_id] = filtered_child + + return copied + + object.__setattr__(dag, "task_group", filter_task_group(self.task_group, None)) + + # Removing upstream/downstream references to tasks and TaskGroups that did not make + # the cut. + subdag_task_groups = dag.task_group.get_task_group_dict() + for group in subdag_task_groups.values(): + group.upstream_group_ids.intersection_update(subdag_task_groups) + group.downstream_group_ids.intersection_update(subdag_task_groups) + group.upstream_task_ids.intersection_update(dag.task_dict) + group.downstream_task_ids.intersection_update(dag.task_dict) + + for t in dag.tasks: + # Removing upstream/downstream references to tasks that did not + # make the cut + t.upstream_task_ids.intersection_update(dag.task_dict) + t.downstream_task_ids.intersection_update(dag.task_dict) + + dag.partial = len(dag.tasks) < len(self.tasks) + + return dag + + def has_task(self, task_id: str): + return task_id in self.task_dict + + def has_task_group(self, task_group_id: str) -> bool: + return task_group_id in self.task_group_dict + + @functools.cached_property + def task_group_dict(self): + return {k: v for k, v in self.task_group.get_task_group_dict().items() if k is not None} + + def get_task(self, task_id: str) -> Operator: + if task_id in self.task_dict: + return self.task_dict[task_id] + raise TaskNotFound(f"Task {task_id} not found") + + @property + def task(self) -> TaskDecoratorCollection: + from airflow.decorators import task + + return cast("TaskDecoratorCollection", functools.partial(task, dag=self)) + + def add_task(self, task: Operator) -> None: + """ + Add a task to the DAG. 
+ + :param task: the task you want to add + """ + # FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule) + + from airflow.sdk.definitions.contextmanager import TaskGroupContext + + # if the task has no start date, assign it the same as the DAG + if not task.start_date: + task.start_date = self.start_date + # otherwise, the task will start on the later of its own start date and + # the DAG's start date + elif self.start_date: + task.start_date = max(task.start_date, self.start_date) + + # if the task has no end date, assign it the same as the dag + if not task.end_date: + task.end_date = self.end_date + # otherwise, the task will end on the earlier of its own end date and + # the DAG's end date + elif task.end_date and self.end_date: + task.end_date = min(task.end_date, self.end_date) + + task_id = task.node_id + if not task.task_group: + task_group = TaskGroupContext.get_current(self) + if task_group: + task_id = task_group.child_id(task_id) + task_group.add(task) + + if ( + task_id in self.task_dict and self.task_dict[task_id] is not task + ) or task_id in self.task_group.used_group_ids: + raise DuplicateTaskIdFound(f"Task id '{task_id}' has already been added to the DAG") + else: + self.task_dict[task_id] = task + # TODO: Task-SDK: this type ignore shouldn't be needed! + task.dag = self # type: ignore[assignment] + # Add task_id to used_group_ids to prevent group_id and task_id collisions. + self.task_group.used_group_ids.add(task_id) + + FailStopDagInvalidTriggerRule.check(fail_stop=self.fail_stop, trigger_rule=task.trigger_rule) + + def add_tasks(self, tasks: Iterable[Operator]) -> None: + """ + Add a list of tasks to the DAG. + + :param tasks: a lit of tasks you want to add + """ + for task in tasks: + self.add_task(task) + + def _remove_task(self, task_id: str) -> None: + # This is "private" as removing could leave a hole in dependencies if done incorrectly, and this + # doesn't guard against that + task = self.task_dict.pop(task_id) + tg = getattr(task, "task_group", None) + if tg: + tg._remove(task) + + def cli(self): + """Exposes a CLI specific to this DAG.""" + check_cycle(self) + + from airflow.cli import cli_parser + + parser = cli_parser.get_parser(dag_parser=True) + args = parser.parse_args() + args.func(args, self) + + @classmethod + def get_serialized_fields(cls): + """Stringified DAGs and operators contain exactly these fields.""" + if not cls.__serialized_fields: + exclusion_list = { + "schedule_asset_references", + "schedule_asset_alias_references", + "task_outlet_asset_references", + "_old_context_manager_dags", + "safe_dag_id", + "last_loaded", + "user_defined_filters", + "user_defined_macros", + "partial", + "params", + "_pickle_id", + "_log", + "task_dict", + "template_searchpath", + # "sla_miss_callback", + "on_success_callback", + "on_failure_callback", + "template_undefined", + "jinja_environment_kwargs", + # has_on_*_callback are only stored if the value is True, as the default is False + "has_on_success_callback", + "has_on_failure_callback", + "auto_register", + "fail_stop", + "schedule", + } + cls.__serialized_fields = frozenset(vars(DAG(dag_id="test", schedule=None))) - exclusion_list + return cls.__serialized_fields + + def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType: + """Return edge information for the given pair of tasks or an empty edge if there is no information.""" + # Note - older serialized DAGs may not have edge_info being a dict at all + empty = cast(EdgeInfoType, {}) + if 
self.edge_info: + return self.edge_info.get(upstream_task_id, {}).get(downstream_task_id, empty) + else: + return empty + + def set_edge_info(self, upstream_task_id: str, downstream_task_id: str, info: EdgeInfoType): + """ + Set the given edge information on the DAG. + + Note that this will overwrite, rather than merge with, existing info. + """ + self.edge_info.setdefault(upstream_task_id, {})[downstream_task_id] = info + + @owner_links.validator + def _validate_owner_links(self, _, owner_links): + wrong_links = {} + + for owner, link in owner_links.items(): + result = urlsplit(link) + if result.scheme == "mailto": + # netloc is not existing for 'mailto' link, so we are checking that the path is parsed + if not result.path: + wrong_links[result.path] = link + elif not result.scheme or not result.netloc: + wrong_links[owner] = link + if wrong_links: + raise ValueError( + "Wrong link format was used for the owner. Use a valid link \n" + f"Bad formatted links are: {wrong_links}" + ) + + +if TYPE_CHECKING: + # NOTE: Please keep the list of arguments in sync with DAG.__init__. + # Only exception: dag_id here should have a default value, but not in DAG. + def dag( + dag_id: str = "", + *, + description: str | None = None, + schedule: ScheduleArg = None, + start_date: datetime | None = None, + end_date: datetime | None = None, + template_searchpath: str | Iterable[str] | None = None, + template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined, + user_defined_macros: dict | None = None, + user_defined_filters: dict | None = None, + default_args: dict[str, Any] | None = None, + max_active_tasks: int = ..., + max_active_runs: int = ..., + max_consecutive_failed_dag_runs: int = ..., + dagrun_timeout: timedelta | None = None, + # sla_miss_callback: Any = None, + catchup: bool = ..., + on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, + on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, + doc_md: str | None = None, + params: ParamsDict | dict[str, Any] | None = None, + access_control: dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | None = None, + is_paused_upon_creation: bool | None = None, + jinja_environment_kwargs: dict | None = None, + render_template_as_native_obj: bool = False, + tags: Collection[str] | None = None, + owner_links: dict[str, str] | None = None, + auto_register: bool = True, + fail_stop: bool = False, + dag_display_name: str | None = None, + ) -> Callable[[Callable], Callable[..., DAG]]: + """ + Python dag decorator which wraps a function into an Airflow DAG. + + Accepts kwargs for operator kwarg. Can be used to parameterize DAGs. + + :param dag_args: Arguments for DAG object + :param dag_kwargs: Kwargs for DAG object. + """ +else: + + def dag(dag_id="", __DAG_class=DAG, __warnings_stacklevel_delta=2, **decorator_kwargs): + # TODO: Task-SDK: remove __DAG_class + # __DAG_class is a temporary hack to allow the dag decorator in airflow.models.dag to continue to + # return SchedulerDag objects + DAG = __DAG_class + + def wrapper(f: Callable) -> Callable[..., DAG]: + @functools.wraps(f) + def factory(*args, **kwargs): + # Generate signature for decorated function and bind the arguments when called + # we do this to extract parameters, so we can annotate them on the DAG object. + # In addition, this fails if we are missing any args/kwargs with TypeError as expected. + f_sig = signature(f).bind(*args, **kwargs) + # Apply defaults to capture default values if set. 
+ f_sig.apply_defaults() + + # Initialize DAG with bound arguments + with DAG(dag_id or f.__name__, **decorator_kwargs) as dag_obj: + # Set DAG documentation from function documentation if it exists and doc_md is not set. + if f.__doc__ and not dag_obj.doc_md: + dag_obj.doc_md = f.__doc__ + + # Generate DAGParam for each function arg/kwarg and replace it for calling the function. + # All args/kwargs for function will be DAGParam object and replaced on execution time. + f_kwargs = {} + for name, value in f_sig.arguments.items(): + f_kwargs[name] = dag_obj.param(name, value) + + # set file location to caller source path + back = sys._getframe().f_back + dag_obj.fileloc = back.f_code.co_filename if back else "" + + # Invoke function to create operators in the DAG scope. + f(**f_kwargs) + + # Return dag object such that it's accessible in Globals. + return dag_obj + + # Ensure that warnings from inside DAG() are emitted from the caller, not here + fixup_decorator_warning_stack(factory) + return factory + + return wrapper diff --git a/task_sdk/src/airflow/sdk/definitions/decorators.py b/task_sdk/src/airflow/sdk/definitions/decorators.py new file mode 100644 index 000000000000..ab73ba0c9242 --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/decorators.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import sys +from types import FunctionType + + +class _autostacklevel_warn: + def __init__(self): + self.warnings = __import__("warnings") + + def __getattr__(self, name: str): + return getattr(self.warnings, name) + + def __dir__(self): + return dir(self.warnings) + + def warn(self, message, category=None, stacklevel=1, source=None): + self.warnings.warn(message, category, stacklevel + 2, source) + + +def fixup_decorator_warning_stack(func: FunctionType): + if func.__globals__.get("warnings") is sys.modules["warnings"]: + # Yes, this is more than slightly hacky, but it _automatically_ sets the right stacklevel parameter to + # `warnings.warn` to ignore the decorator. + func.__globals__["warnings"] = _autostacklevel_warn() diff --git a/task_sdk/src/airflow/sdk/definitions/edges.py b/task_sdk/src/airflow/sdk/definitions/edges.py new file mode 100644 index 000000000000..7e50431b497e --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/edges.py @@ -0,0 +1,189 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING + +from airflow.sdk.definitions.mixins import DependencyMixin + +if TYPE_CHECKING: + from airflow.sdk.definitions.dag import DAG + + +class EdgeModifier(DependencyMixin): + """ + Class that represents edge information to be added between two tasks/operators. + + Has shorthand factory functions, like Label("hooray"). + + Current implementation supports + t1 >> Label("Success route") >> t2 + t2 << Label("Success route") << t2 + + Note that due to the potential for use in either direction, this waits + to make the actual connection between both sides until both are declared, + and will do so progressively if multiple ups/downs are added. + + This and EdgeInfo are related - an EdgeModifier is the Python object you + use to add information to (potentially multiple) edges, and EdgeInfo + is the representation of the information for one specific edge. + """ + + def __init__(self, label: str | None = None): + self.label = label + self._upstream: list[DependencyMixin] = [] + self._downstream: list[DependencyMixin] = [] + + @property + def roots(self): + return self._downstream + + @property + def leaves(self): + return self._upstream + + @staticmethod + def _make_list( + item_or_list: DependencyMixin | Sequence[DependencyMixin], + ) -> Sequence[DependencyMixin]: + if not isinstance(item_or_list, Sequence): + return [item_or_list] + return item_or_list + + def _save_nodes( + self, + nodes: DependencyMixin | Sequence[DependencyMixin], + stream: list[DependencyMixin], + ): + from airflow.models.xcom_arg import XComArg + from airflow.sdk.definitions.node import DAGNode + from airflow.sdk.definitions.taskgroup import TaskGroup + + for node in self._make_list(nodes): + if isinstance(node, (TaskGroup, XComArg, DAGNode)): + stream.append(node) + else: + raise TypeError( + f"Cannot use edge labels with {type(node).__name__}, only tasks, XComArg or TaskGroups" + ) + + def _convert_streams_to_task_groups(self): + """ + Convert a node to a TaskGroup or leave it as a DAGNode. + + Requires both self._upstream and self._downstream. + + To do this, we keep a set of group_ids seen among the streams. 
If we find that + the nodes are from the same TaskGroup, we will leave them as DAGNodes and not + convert them to TaskGroups + """ + from airflow.models.xcom_arg import XComArg + from airflow.sdk.definitions.node import DAGNode + from airflow.sdk.definitions.taskgroup import TaskGroup + + group_ids = set() + for node in [*self._upstream, *self._downstream]: + if isinstance(node, DAGNode) and node.task_group: + if node.task_group.is_root: + group_ids.add("root") + else: + group_ids.add(node.task_group.group_id) + elif isinstance(node, TaskGroup): + group_ids.add(node.group_id) + elif isinstance(node, XComArg): + if isinstance(node.operator, DAGNode) and node.operator.task_group: + if node.operator.task_group.is_root: + group_ids.add("root") + else: + group_ids.add(node.operator.task_group.group_id) + + # If all nodes originate from the same TaskGroup, we will not convert them + if len(group_ids) != 1: + self._upstream = self._convert_stream_to_task_groups(self._upstream) + self._downstream = self._convert_stream_to_task_groups(self._downstream) + + def _convert_stream_to_task_groups(self, stream: Sequence[DependencyMixin]) -> Sequence[DependencyMixin]: + from airflow.sdk.definitions.node import DAGNode + + return [ + node.task_group + if isinstance(node, DAGNode) and node.task_group and not node.task_group.is_root + else node + for node in stream + ] + + def set_upstream( + self, + other: DependencyMixin | Sequence[DependencyMixin], + edge_modifier: EdgeModifier | None = None, + ): + """ + Set the given task/list onto the upstream attribute, then attempt to resolve the relationship. + + Providing this also provides << via DependencyMixin. + """ + self._save_nodes(other, self._upstream) + if self._upstream and self._downstream: + # Convert _upstream and _downstream to task_groups only after both are set + self._convert_streams_to_task_groups() + for node in self._downstream: + node.set_upstream(other, edge_modifier=self) + + def set_downstream( + self, + other: DependencyMixin | Sequence[DependencyMixin], + edge_modifier: EdgeModifier | None = None, + ): + """ + Set the given task/list onto the downstream attribute, then attempt to resolve the relationship. + + Providing this also provides >> via DependencyMixin. + """ + self._save_nodes(other, self._downstream) + if self._upstream and self._downstream: + # Convert _upstream and _downstream to task_groups only after both are set + self._convert_streams_to_task_groups() + for node in self._upstream: + node.set_downstream(other, edge_modifier=self) + + def update_relative( + self, + other: DependencyMixin, + upstream: bool = True, + edge_modifier: EdgeModifier | None = None, + ) -> None: + """Update relative if we're not the "main" side of a relationship; still run the same logic.""" + if upstream: + self.set_upstream(other) + else: + self.set_downstream(other) + + def add_edge_info(self, dag: DAG, upstream_id: str, downstream_id: str): + """ + Add or update task info on the DAG for this specific pair of tasks. + + Called either from our relationship trigger methods above, or directly + by set_upstream/set_downstream in operators. 
+ """ + dag.set_edge_info(upstream_id, downstream_id, {"label": self.label}) + + +# Factory functions +def Label(label: str): + """Create an EdgeModifier that sets a human-readable label on the edge.""" + return EdgeModifier(label=label) diff --git a/task_sdk/src/airflow/sdk/definitions/mixins.py b/task_sdk/src/airflow/sdk/definitions/mixins.py new file mode 100644 index 000000000000..de63772615de --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/mixins.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from abc import abstractmethod +from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.definitions.edges import EdgeModifier + +# TODO: Should this all just live on DAGNode? + + +class DependencyMixin: + """Mixing implementing common dependency setting methods like >> and <<.""" + + @property + def roots(self) -> Iterable[DependencyMixin]: + """ + List of root nodes -- ones with no upstream dependencies. + + a.k.a. the "start" of this sub-graph + """ + raise NotImplementedError() + + @property + def leaves(self) -> Iterable[DependencyMixin]: + """ + List of leaf nodes -- ones with only upstream dependencies. + + a.k.a. the "end" of this sub-graph + """ + raise NotImplementedError() + + @abstractmethod + def set_upstream( + self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None + ): + """Set a task or a task list to be directly upstream from the current task.""" + raise NotImplementedError() + + @abstractmethod + def set_downstream( + self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None + ): + """Set a task or a task list to be directly downstream from the current task.""" + raise NotImplementedError() + + def as_setup(self) -> DependencyMixin: + """Mark a task as setup task.""" + raise NotImplementedError() + + def as_teardown( + self, + *, + setups: BaseOperator | Iterable[BaseOperator] | None = None, + on_failure_fail_dagrun: bool | None = None, + ) -> DependencyMixin: + """Mark a task as teardown and set its setups as direct relatives.""" + raise NotImplementedError() + + def update_relative( + self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None + ) -> None: + """ + Update relationship information about another TaskMixin. Default is no-op. + + Override if necessary. 
+ """ + + def __lshift__(self, other: DependencyMixin | Sequence[DependencyMixin]): + """Implement Task << Task.""" + self.set_upstream(other) + return other + + def __rshift__(self, other: DependencyMixin | Sequence[DependencyMixin]): + """Implement Task >> Task.""" + self.set_downstream(other) + return other + + def __rrshift__(self, other: DependencyMixin | Sequence[DependencyMixin]): + """Implement Task >> [Task] because list don't have __rshift__ operators.""" + self.__lshift__(other) + return self + + def __rlshift__(self, other: DependencyMixin | Sequence[DependencyMixin]): + """Implement Task << [Task] because list don't have __lshift__ operators.""" + self.__rshift__(other) + return self + + @classmethod + def _iter_references(cls, obj: Any) -> Iterable[tuple[DependencyMixin, str]]: + from airflow.sdk.definitions.abstractoperator import AbstractOperator + + # TODO:Task-SDK + from airflow.utils.mixins import ResolveMixin + + if isinstance(obj, AbstractOperator): + yield obj, "operator" + elif isinstance(obj, ResolveMixin): + yield from obj.iter_references() + elif isinstance(obj, Sequence): + for o in obj: + yield from cls._iter_references(o) diff --git a/task_sdk/src/airflow/sdk/definitions/node.py b/task_sdk/src/airflow/sdk/definitions/node.py new file mode 100644 index 000000000000..a29a36a7b7ad --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/node.py @@ -0,0 +1,222 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import logging +import re +from abc import ABCMeta, abstractmethod +from collections.abc import Iterable, Sequence +from datetime import datetime +from typing import TYPE_CHECKING, Any + +import methodtools +import re2 + +from airflow.sdk.definitions.mixins import DependencyMixin + +if TYPE_CHECKING: + from airflow.models.operator import Operator + from airflow.sdk.definitions.dag import DAG + from airflow.sdk.definitions.edges import EdgeModifier + from airflow.sdk.definitions.taskgroup import TaskGroup + from airflow.sdk.types import Logger + from airflow.serialization.enums import DagAttributeTypes + + +KEY_REGEX = re2.compile(r"^[\w.-]+$") +GROUP_KEY_REGEX = re2.compile(r"^[\w-]+$") +CAMELCASE_TO_SNAKE_CASE_REGEX = re.compile(r"(?!^)([A-Z]+)") + + +def validate_key(k: str, max_length: int = 250): + """Validate value used as a key.""" + if not isinstance(k, str): + raise TypeError(f"The key has to be a string and is {type(k)}:{k}") + if len(k) > max_length: + raise ValueError(f"The key has to be less than {max_length} characters") + if not KEY_REGEX.match(k): + raise ValueError( + f"The key {k!r} has to be made of alphanumeric characters, dashes, " + "dots and underscores exclusively" + ) + + +class DAGNode(DependencyMixin, metaclass=ABCMeta): + """ + A base class for a node in the graph of a workflow. + + A node may be an Operator or a Task Group, either mapped or unmapped. + """ + + dag: DAG | None + task_group: TaskGroup | None + """The task_group that contains this node""" + start_date: datetime | None + end_date: datetime | None + upstream_task_ids: set[str] + downstream_task_ids: set[str] + + def __init__(self): + self.upstream_task_ids = set() + self.downstream_task_ids = set() + super().__init__() + + @property + @abstractmethod + def node_id(self) -> str: + raise NotImplementedError() + + @property + def label(self) -> str | None: + tg = self.task_group + if tg and tg.node_id and tg.prefix_group_id: + # "task_group_id.task_id" -> "task_id" + return self.node_id[len(tg.node_id) + 1 :] + return self.node_id + + def has_dag(self) -> bool: + return self.dag is not None + + @property + def dag_id(self) -> str: + """Returns dag id if it has one or an adhoc/meaningless ID.""" + if self.dag: + return self.dag.dag_id + return "_in_memory_dag_" + + @property + @methodtools.lru_cache() + def log(self) -> Logger: + typ = type(self) + name = f"{typ.__module__}.{typ.__qualname__}" + return logging.getLogger(name) + + @property + @abstractmethod + def roots(self) -> Sequence[DAGNode]: + raise NotImplementedError() + + @property + @abstractmethod + def leaves(self) -> Sequence[DAGNode]: + raise NotImplementedError() + + def _set_relatives( + self, + task_or_task_list: DependencyMixin | Sequence[DependencyMixin], + upstream: bool = False, + edge_modifier: EdgeModifier | None = None, + ) -> None: + """Set relatives for the task or task list.""" + from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.definitions.baseoperator import BaseOperator + + if not isinstance(task_or_task_list, Sequence): + task_or_task_list = [task_or_task_list] + + task_list: list[BaseOperator | MappedOperator] = [] + for task_object in task_or_task_list: + task_object.update_relative(self, not upstream, edge_modifier=edge_modifier) + relatives = task_object.leaves if upstream else task_object.roots + for task in relatives: + if not isinstance(task, (BaseOperator, MappedOperator)): + raise TypeError( + f"Relationships can only be set between Operators; 
received {task.__class__.__name__}" + ) + task_list.append(task) + + # relationships can only be set if the tasks share a single DAG. Tasks + # without a DAG are assigned to that DAG. + dags: set[DAG] = {task.dag for task in [*self.roots, *task_list] if task.has_dag() and task.dag} + + if len(dags) > 1: + raise RuntimeError(f"Tried to set relationships between tasks in more than one DAG: {dags}") + elif len(dags) == 1: + dag = dags.pop() + else: + raise ValueError( + "Tried to create relationships between tasks that don't have DAGs yet. " + f"Set the DAG for at least one task and try again: {[self, *task_list]}" + ) + + if not self.has_dag(): + # If this task does not yet have a dag, add it to the same dag as the other task. + self.dag = dag + + for task in task_list: + if dag and not task.has_dag(): + # If the other task does not yet have a dag, add it to the same dag as this task and + dag.add_task(task) # type: ignore[arg-type] + if upstream: + task.downstream_task_ids.add(self.node_id) + self.upstream_task_ids.add(task.node_id) + if edge_modifier: + edge_modifier.add_edge_info(dag, task.node_id, self.node_id) + else: + self.downstream_task_ids.add(task.node_id) + task.upstream_task_ids.add(self.node_id) + if edge_modifier: + edge_modifier.add_edge_info(dag, self.node_id, task.node_id) + + def set_downstream( + self, + task_or_task_list: DependencyMixin | Sequence[DependencyMixin], + edge_modifier: EdgeModifier | None = None, + ) -> None: + """Set a node (or nodes) to be directly downstream from the current node.""" + self._set_relatives(task_or_task_list, upstream=False, edge_modifier=edge_modifier) + + def set_upstream( + self, + task_or_task_list: DependencyMixin | Sequence[DependencyMixin], + edge_modifier: EdgeModifier | None = None, + ) -> None: + """Set a node (or nodes) to be directly upstream from the current node.""" + self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier) + + @property + def downstream_list(self) -> Iterable[Operator]: + """List of nodes directly downstream.""" + if not self.dag: + raise RuntimeError(f"Operator {self} has not been assigned to a DAG yet") + return [self.dag.get_task(tid) for tid in self.downstream_task_ids] + + @property + def upstream_list(self) -> Iterable[Operator]: + """List of nodes directly upstream.""" + if not self.dag: + raise RuntimeError(f"Operator {self} has not been assigned to a DAG yet") + return [self.dag.get_task(tid) for tid in self.upstream_task_ids] + + def get_direct_relative_ids(self, upstream: bool = False) -> set[str]: + """Get set of the direct relative ids to the current task, upstream or downstream.""" + if upstream: + return self.upstream_task_ids + else: + return self.downstream_task_ids + + def get_direct_relatives(self, upstream: bool = False) -> Iterable[Operator]: + """Get list of the direct relatives to the current task, upstream or downstream.""" + if upstream: + return self.upstream_list + else: + return self.downstream_list + + def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: + """Serialize a task group's content; used by TaskGroupSerialization.""" + raise NotImplementedError() diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py b/task_sdk/src/airflow/sdk/definitions/taskgroup.py new file mode 100644 index 000000000000..de4bd0c771ad --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py @@ -0,0 +1,683 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A collection of closely related tasks on the same DAG that should be grouped together visually."""
+
+from __future__ import annotations
+
+import copy
+import functools
+import operator
+import weakref
+from collections.abc import Generator, Iterator, Sequence
+from typing import TYPE_CHECKING, Any
+
+import attrs
+import methodtools
+import re2
+
+from airflow.exceptions import (
+    AirflowDagCycleException,
+    AirflowException,
+    DuplicateTaskIdFound,
+    TaskAlreadyInTaskGroup,
+)
+from airflow.sdk.definitions.node import DAGNode
+
+if TYPE_CHECKING:
+    from airflow.models.expandinput import ExpandInput
+    from airflow.sdk.definitions.abstractoperator import AbstractOperator
+    from airflow.sdk.definitions.baseoperator import BaseOperator
+    from airflow.sdk.definitions.dag import DAG
+    from airflow.sdk.definitions.edges import EdgeModifier
+    from airflow.sdk.definitions.mixins import DependencyMixin
+    from airflow.serialization.enums import DagAttributeTypes
+
+
+def _default_parent_group() -> TaskGroup | None:
+    from airflow.sdk.definitions.contextmanager import TaskGroupContext
+
+    return TaskGroupContext.get_current()
+
+
+def _parent_used_group_ids(tg: TaskGroup) -> set:
+    if tg.parent_group:
+        return tg.parent_group.used_group_ids
+    return set()
+
+
+# This could be achieved with `@dag.default` and make this a method, but for some unknown reason when we do
+# that it makes Mypy (1.9.0 and 1.13.0 tested) seem to entirely lose track that this is an Attrs class. So
+# we've gone with this and moved on with our lives; mypy is too much of a dark beast to battle over this.
+def _default_dag(instance: TaskGroup):
+    from airflow.sdk.definitions.contextmanager import DagContext
+
+    if (pg := instance.parent_group) is not None:
+        return pg.dag
+    return DagContext.get_current()
+
+
+@attrs.define(repr=False)
+class TaskGroup(DAGNode):
+    """
+    A collection of tasks.
+
+    When set_downstream() or set_upstream() are called on the TaskGroup, it is applied across
+    all tasks within the group if necessary.
+
+    :param group_id: a unique, meaningful id for the TaskGroup. group_id must not conflict
+        with group_id of TaskGroup or task_id of tasks in the DAG. Root TaskGroup has group_id
+        set to None.
+    :param prefix_group_id: If set to True, child task_id and group_id will be prefixed with
+        this TaskGroup's group_id. If set to False, child task_id and group_id are not prefixed.
+        Default is True.
+    :param parent_group: The parent TaskGroup of this TaskGroup. parent_group is set to None
+        for the root TaskGroup.
+    :param dag: The DAG that this TaskGroup belongs to.
+    :param default_args: A dictionary of default parameters to be used
+        as constructor keyword parameters when initialising operators,
+        will override default_args defined at the DAG level.
+ Note that operators have the same hook, and precede those defined + here, meaning that if your dict contains `'depends_on_past': True` + here and `'depends_on_past': False` in the operator's call + `default_args`, the actual value will be `False`. + :param tooltip: The tooltip of the TaskGroup node when displayed in the UI + :param ui_color: The fill color of the TaskGroup node when displayed in the UI + :param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI + :param add_suffix_on_collision: If this task group name already exists, + automatically add `__1` etc suffixes + """ + + _group_id: str | None = attrs.field( + validator=attrs.validators.optional(attrs.validators.instance_of(str)) + ) + prefix_group_id: bool = attrs.field(default=True) + parent_group: TaskGroup | None = attrs.field(factory=_default_parent_group) + dag: DAG = attrs.field(default=attrs.Factory(_default_dag, takes_self=True)) + default_args: dict[str, Any] = attrs.field(factory=dict, converter=copy.deepcopy) + tooltip: str = attrs.field(default="", validator=attrs.validators.instance_of(str)) + children: dict[str, DAGNode] = attrs.field(factory=dict, init=False) + + upstream_group_ids: set[str | None] = attrs.field(factory=set, init=False) + downstream_group_ids: set[str | None] = attrs.field(factory=set, init=False) + upstream_task_ids: set[str] = attrs.field(factory=set, init=False) + downstream_task_ids: set[str] = attrs.field(factory=set, init=False) + + used_group_ids: set[str] = attrs.field( + default=attrs.Factory(_parent_used_group_ids, takes_self=True), + init=False, + on_setattr=attrs.setters.frozen, + ) + + ui_color: str = attrs.field(default="CornflowerBlue", validator=attrs.validators.instance_of(str)) + ui_fgcolor: str = attrs.field(default="#000", validator=attrs.validators.instance_of(str)) + + add_suffix_on_collision: bool = False + + @dag.validator + def _validate_dag(self, _attr, dag): + if not dag: + raise RuntimeError("TaskGroup can only be used inside a dag") + + def __attrs_post_init__(self): + # TODO: If attrs supported init only args we could use that here + # https://github.com/python-attrs/attrs/issues/342 + self._check_for_group_id_collisions(self.add_suffix_on_collision) + + if self._group_id and not self.parent_group and self.dag: + # Support `tg = TaskGroup(x, dag=dag)` + self.parent_group = self.dag.task_group + + if self.parent_group: + self.parent_group.add(self) + if self.parent_group.default_args: + self.default_args = {**self.parent_group.default_args, **self.default_args} + + if self._group_id: + self.used_group_ids.add(self.group_id) + self.used_group_ids.add(self.downstream_join_id) + self.used_group_ids.add(self.upstream_join_id) + + def _check_for_group_id_collisions(self, add_suffix_on_collision: bool): + if self._group_id is None: + return + # if given group_id already used assign suffix by incrementing largest used suffix integer + # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3 + if self.group_id in self.used_group_ids: + if not add_suffix_on_collision: + raise DuplicateTaskIdFound(f"group_id '{self._group_id}' has already been added to the DAG") + base = re2.split(r"__\d+$", self._group_id)[0] + suffixes = sorted( + int(re2.split(r"^.+__", used_group_id)[1]) + for used_group_id in self.used_group_ids + if used_group_id is not None and re2.match(rf"^{base}__\d+$", used_group_id) + ) + if not suffixes: + self._group_id += "__1" + else: + self._group_id = f"{base}__{suffixes[-1] + 1}" + + @classmethod + def 
create_root(cls, dag: DAG) -> TaskGroup: + """Create a root TaskGroup with no group_id or parent.""" + return cls(group_id=None, dag=dag) + + @property + def node_id(self): + return self.group_id + + @property + def is_root(self) -> bool: + """Returns True if this TaskGroup is the root TaskGroup. Otherwise False.""" + return not self._group_id + + @property + def task_group(self) -> TaskGroup | None: + return self.parent_group + + @task_group.setter + def task_group(self, value: TaskGroup | None): + self.parent_group = value + + def __iter__(self): + for child in self.children.values(): + if isinstance(child, TaskGroup): + yield from child + else: + yield child + + def add(self, task: DAGNode) -> DAGNode: + """ + Add a task or TaskGroup to this TaskGroup. + + :meta private: + """ + from airflow.sdk.definitions.abstractoperator import AbstractOperator + from airflow.sdk.definitions.contextmanager import TaskGroupContext + + if TaskGroupContext.active: + if task.task_group and task.task_group != self: + task.task_group.children.pop(task.node_id, None) + task.task_group = self + existing_tg = task.task_group + if isinstance(task, AbstractOperator) and existing_tg is not None and existing_tg != self: + raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id, self.node_id) + + # Set the TG first, as setting it might change the return value of node_id! + task.task_group = weakref.proxy(self) + key = task.node_id + + if key in self.children: + node_type = "Task" if hasattr(task, "task_id") else "Task Group" + raise DuplicateTaskIdFound(f"{node_type} id '{key}' has already been added to the DAG") + + if isinstance(task, TaskGroup): + if self.dag: + if task.dag is not None and self.dag is not task.dag: + raise RuntimeError( + "Cannot mix TaskGroups from different DAGs: %s and %s", self.dag, task.dag + ) + task.dag = self.dag + if task.children: + raise AirflowException("Cannot add a non-empty TaskGroup") + + self.children[key] = task + return task + + def _remove(self, task: DAGNode) -> None: + key = task.node_id + + if key not in self.children: + raise KeyError(f"Node id {key!r} not part of this task group") + + self.used_group_ids.remove(key) + del self.children[key] + + @property + def group_id(self) -> str | None: + """group_id of this TaskGroup.""" + if self.parent_group and self.parent_group.prefix_group_id and self.parent_group._group_id: + # defer to parent whether it adds a prefix + return self.parent_group.child_id(self._group_id) + + return self._group_id + + @property + def label(self) -> str | None: + """group_id excluding parent's group_id used as the node label in UI.""" + return self._group_id + + def update_relative( + self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None + ) -> None: + """ + Override TaskMixin.update_relative. + + Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids + accordingly so that we can reduce the number of edges when displaying Graph view. 
+ """ + if isinstance(other, TaskGroup): + # Handles setting relationship between a TaskGroup and another TaskGroup + if upstream: + parent, child = (self, other) + if edge_modifier: + edge_modifier.add_edge_info(self.dag, other.downstream_join_id, self.upstream_join_id) + else: + parent, child = (other, self) + if edge_modifier: + edge_modifier.add_edge_info(self.dag, self.downstream_join_id, other.upstream_join_id) + + parent.upstream_group_ids.add(child.group_id) + child.downstream_group_ids.add(parent.group_id) + else: + # Handles setting relationship between a TaskGroup and a task + for task in other.roots: + if not isinstance(task, DAGNode): + raise AirflowException( + "Relationships can only be set between TaskGroup " + f"or operators; received {task.__class__.__name__}" + ) + + # Do not set a relationship between a TaskGroup and a Label's roots + if self == task: + continue + + if upstream: + self.upstream_task_ids.add(task.node_id) + if edge_modifier: + edge_modifier.add_edge_info(self.dag, task.node_id, self.upstream_join_id) + else: + self.downstream_task_ids.add(task.node_id) + if edge_modifier: + edge_modifier.add_edge_info(self.dag, self.downstream_join_id, task.node_id) + + def _set_relatives( + self, + task_or_task_list: DependencyMixin | Sequence[DependencyMixin], + upstream: bool = False, + edge_modifier: EdgeModifier | None = None, + ) -> None: + """ + Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup. + + Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids. + """ + if not isinstance(task_or_task_list, Sequence): + task_or_task_list = [task_or_task_list] + + for task_like in task_or_task_list: + self.update_relative(task_like, upstream, edge_modifier=edge_modifier) + + if upstream: + for task in self.get_roots(): + task.set_upstream(task_or_task_list) + else: + for task in self.get_leaves(): + task.set_downstream(task_or_task_list) + + def __enter__(self) -> TaskGroup: + from airflow.sdk.definitions.contextmanager import TaskGroupContext + + TaskGroupContext.push(self) + return self + + def __exit__(self, _type, _value, _tb): + from airflow.sdk.definitions.contextmanager import TaskGroupContext + + TaskGroupContext.pop() + + def has_task(self, task: BaseOperator) -> bool: + """Return True if this TaskGroup or its children TaskGroups contains the given task.""" + if task.task_id in self.children: + return True + + return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup)) + + @property + def roots(self) -> list[BaseOperator]: + """Required by DependencyMixin.""" + return list(self.get_roots()) + + @property + def leaves(self) -> list[BaseOperator]: + """Required by DependencyMixin.""" + return list(self.get_leaves()) + + def get_roots(self) -> Generator[BaseOperator, None, None]: + """Return a generator of tasks with no upstream dependencies within the TaskGroup.""" + tasks = list(self) + ids = {x.task_id for x in tasks} + for task in tasks: + if task.upstream_task_ids.isdisjoint(ids): + yield task + + def get_leaves(self) -> Generator[BaseOperator, None, None]: + """Return a generator of tasks with no downstream dependencies within the TaskGroup.""" + tasks = list(self) + ids = {x.task_id for x in tasks} + + def has_non_teardown_downstream(task, exclude: str): + for down_task in task.downstream_list: + if down_task.task_id == exclude: + continue + elif down_task.task_id not in ids: + continue + elif not down_task.is_teardown: + return True + return False + + def 
recurse_for_first_non_teardown(task):
+            for upstream_task in task.upstream_list:
+                if upstream_task.task_id not in ids:
+                    # upstream task is not in task group
+                    continue
+                elif upstream_task.is_teardown:
+                    yield from recurse_for_first_non_teardown(upstream_task)
+                elif task.is_teardown and upstream_task.is_setup:
+                    # don't go through the teardown-to-setup path
+                    continue
+                # return unless upstream task already has non-teardown downstream in group
+                elif not has_non_teardown_downstream(upstream_task, exclude=task.task_id):
+                    yield upstream_task
+
+        for task in tasks:
+            if task.downstream_task_ids.isdisjoint(ids):
+                if not task.is_teardown:
+                    yield task
+                else:
+                    yield from recurse_for_first_non_teardown(task)
+
+    def child_id(self, label):
+        """Prefix label with group_id if prefix_group_id is True. Otherwise return the label as-is."""
+        if self.prefix_group_id:
+            group_id = self.group_id
+            if group_id:
+                return f"{group_id}.{label}"
+
+        return label
+
+    @property
+    def upstream_join_id(self) -> str:
+        """
+        Create a unique ID for upstream dependencies of this TaskGroup.
+
+        If this TaskGroup has immediate upstream TaskGroups or tasks, a proxy node called
+        upstream_join_id will be created in Graph view to join the outgoing edges from this
+        TaskGroup to reduce the total number of edges needed to be displayed.
+        """
+        return f"{self.group_id}.upstream_join_id"
+
+    @property
+    def downstream_join_id(self) -> str:
+        """
+        Create a unique ID for downstream dependencies of this TaskGroup.
+
+        If this TaskGroup has immediate downstream TaskGroups or tasks, a proxy node called
+        downstream_join_id will be created in Graph view to join the outgoing edges from this
+        TaskGroup to reduce the total number of edges needed to be displayed.
+        """
+        return f"{self.group_id}.downstream_join_id"
+
+    def get_task_group_dict(self) -> dict[str, TaskGroup]:
+        """Return a flat dictionary of group_id: TaskGroup."""
+        task_group_map = {}
+
+        def build_map(task_group):
+            if not isinstance(task_group, TaskGroup):
+                return
+
+            task_group_map[task_group.group_id] = task_group
+
+            for child in task_group.children.values():
+                build_map(child)
+
+        build_map(self)
+        return task_group_map
+
+    def get_child_by_label(self, label: str) -> DAGNode:
+        """Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)."""
+        return self.children[self.child_id(label)]
+
+    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
+        """Serialize task group; required by DAGNode."""
+        from airflow.serialization.enums import DagAttributeTypes
+        from airflow.serialization.serialized_objects import TaskGroupSerialization
+
+        return DagAttributeTypes.TASK_GROUP, TaskGroupSerialization.serialize_task_group(self)
+
+    def hierarchical_alphabetical_sort(self):
+        """
+        Sort children in hierarchical alphabetical order.
+
+        - groups in alphabetical order first
+        - tasks in alphabetical order after them.
+
+        :return: list of tasks in hierarchical alphabetical order
+        """
+        return sorted(
+            self.children.values(), key=lambda node: (not isinstance(node, TaskGroup), node.node_id)
+        )
+
+    def topological_sort(self):
+        """
+        Sort children in topological order, such that a task comes after any of its upstream dependencies.
+
+        :return: list of tasks in topological order
+        """
+        # This uses a modified version of Kahn's Topological Sort algorithm to
+        # not have to pre-compute the "in-degree" of the nodes.
+        graph_unsorted = copy.copy(self.children)
+
+        graph_sorted: list[DAGNode] = []
+
+        # special case
+        if not self.children:
+            return graph_sorted
+
+        # Run until the unsorted graph is empty.
+        while graph_unsorted:
+            # Go through each of the node/edges pairs in the unsorted graph. If a set of edges doesn't contain
+            # any nodes that haven't been resolved, that is, that are still in the unsorted graph, remove the
+            # pair from the unsorted graph, and append it to the sorted graph. Note here that by using
+            # the values() method for iterating, a copy of the unsorted graph is used, allowing us to modify
+            # the unsorted graph as we move through it.
+            #
+            # We also keep a flag for checking that graph is acyclic, which is true if any nodes are resolved
+            # during each pass through the graph. If not, we need to exit as the graph therefore can't be
+            # sorted.
+            acyclic = False
+            for node in list(graph_unsorted.values()):
+                for edge in node.upstream_list:
+                    if edge.node_id in graph_unsorted:
+                        break
+                    # Check if the task's group is a child (or grandchild) of this TG.
+                    tg = edge.task_group
+                    while tg:
+                        if tg.node_id in graph_unsorted:
+                            break
+                        tg = tg.parent_group
+
+                    if tg:
+                        # We are already going to visit that TG
+                        break
+                else:
+                    acyclic = True
+                    del graph_unsorted[node.node_id]
+                    graph_sorted.append(node)
+
+            if not acyclic:
+                raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")
+
+        return graph_sorted
+
+    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
+        """
+        Return mapped task groups in the hierarchy.
+
+        Groups are returned from the closest to the outermost. If *self* is a
+        mapped task group, it is returned first.
+
+        :meta private:
+        """
+        group: TaskGroup | None = self
+        while group is not None:
+            if isinstance(group, MappedTaskGroup):
+                yield group
+            group = group.parent_group
+
+    def iter_tasks(self) -> Iterator[AbstractOperator]:
+        """Return an iterator of the child tasks."""
+        from airflow.models.abstractoperator import AbstractOperator
+
+        groups_to_visit = [self]
+
+        while groups_to_visit:
+            visiting = groups_to_visit.pop(0)
+
+            for child in visiting.children.values():
+                if isinstance(child, AbstractOperator):
+                    yield child
+                elif isinstance(child, TaskGroup):
+                    groups_to_visit.append(child)
+                else:
+                    raise ValueError(
+                        f"Encountered a DAGNode that is not a TaskGroup or an AbstractOperator: {type(child)}"
+                    )
+
+
+class MappedTaskGroup(TaskGroup):
+    """
+    A mapped task group.
+
+    This doesn't really do anything special, just holds some additional metadata
+    for expansion later.
+
+    Don't instantiate this class directly; call *expand* or *expand_kwargs* on
+    a ``@task_group`` function instead.
+    """
+
+    def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self._expand_input = expand_input
+
+    def iter_mapped_dependencies(self) -> Iterator[DAGNode]:
+        """Upstream dependencies that provide XComs used by this mapped task group."""
+        from airflow.models.xcom_arg import XComArg
+
+        for op, _ in XComArg.iter_xcom_references(self._expand_input):
+            yield op
+
+    @methodtools.lru_cache(maxsize=None)
+    def get_parse_time_mapped_ti_count(self) -> int:
+        """
+        Return the number of instances a task in this group should be mapped to when a DAG run is created.
+
+        This only considers literal mapped arguments, and would return *None*
+        when any non-literal values are used for mapping.
+
+        If this group is inside mapped task groups, all the nested counts are
+        multiplied and accounted for.
+ + :meta private: + + :raise NotFullyPopulated: If any non-literal mapped arguments are encountered. + :return: The total number of mapped instances each task should have. + """ + return functools.reduce( + operator.mul, + (g._expand_input.get_parse_time_mapped_ti_count() for g in self.iter_mapped_task_groups()), + ) + + def __exit__(self, exc_type, exc_val, exc_tb): + for op, _ in self._expand_input.iter_references(): + self.set_upstream(op) + super().__exit__(exc_type, exc_val, exc_tb) + + +def task_group_to_dict(task_item_or_group): + """Create a nested dict representation of this TaskGroup and its children used to construct the Graph.""" + from airflow.models.abstractoperator import AbstractOperator + from airflow.models.mappedoperator import MappedOperator + + if isinstance(task := task_item_or_group, AbstractOperator): + setup_teardown_type = {} + is_mapped = {} + if task.is_setup is True: + setup_teardown_type["setupTeardownType"] = "setup" + elif task.is_teardown is True: + setup_teardown_type["setupTeardownType"] = "teardown" + if isinstance(task, MappedOperator): + is_mapped["isMapped"] = True + return { + "id": task.task_id, + "value": { + "label": task.label, + "labelStyle": f"fill:{task.ui_fgcolor};", + "style": f"fill:{task.ui_color};", + "rx": 5, + "ry": 5, + **is_mapped, + **setup_teardown_type, + }, + } + task_group = task_item_or_group + is_mapped = isinstance(task_group, MappedTaskGroup) + children = [ + task_group_to_dict(child) for child in sorted(task_group.children.values(), key=lambda t: t.label) + ] + + if task_group.upstream_group_ids or task_group.upstream_task_ids: + children.append( + { + "id": task_group.upstream_join_id, + "value": { + "label": "", + "labelStyle": f"fill:{task_group.ui_fgcolor};", + "style": f"fill:{task_group.ui_color};", + "shape": "circle", + }, + } + ) + + if task_group.downstream_group_ids or task_group.downstream_task_ids: + # This is the join node used to reduce the number of edges between two TaskGroup. + children.append( + { + "id": task_group.downstream_join_id, + "value": { + "label": "", + "labelStyle": f"fill:{task_group.ui_fgcolor};", + "style": f"fill:{task_group.ui_color};", + "shape": "circle", + }, + } + ) + + return { + "id": task_group.group_id, + "value": { + "label": task_group.label, + "labelStyle": f"fill:{task_group.ui_fgcolor};", + "style": f"fill:{task_group.ui_color}", + "rx": 5, + "ry": 5, + "clusterLabelPos": "top", + "tooltip": task_group.tooltip, + "isMapped": is_mapped, + }, + "children": children, + } diff --git a/task_sdk/src/airflow/sdk/exceptions.py b/task_sdk/src/airflow/sdk/exceptions.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/task_sdk/src/airflow/sdk/exceptions.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
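A short definition-time sketch of the TaskGroup behaviour implemented above (hypothetical ids; child task ids are prefixed with the group_id because prefix_group_id defaults to True):

    from airflow.sdk.definitions.baseoperator import BaseOperator
    from airflow.sdk.definitions.dag import DAG
    from airflow.sdk.definitions.taskgroup import TaskGroup

    with DAG(dag_id="tg_example", schedule=None) as dag:
        with TaskGroup(group_id="extract") as extract:
            first = BaseOperator(task_id="first")    # registered as "extract.first"
            second = BaseOperator(task_id="second")  # registered as "extract.second"
            first >> second

        done = BaseOperator(task_id="done")
        extract >> done  # applied to the group's leaf tasks via TaskGroup._set_relatives

    assert "extract.first" in dag.task_dict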
diff --git a/task_sdk/src/airflow/sdk/types.py b/task_sdk/src/airflow/sdk/types.py new file mode 100644 index 000000000000..232d08e27f90 --- /dev/null +++ b/task_sdk/src/airflow/sdk/types.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + + +class ArgNotSet: + """ + Sentinel type for annotations, useful when None is not viable. + + Use like this:: + + def is_arg_passed(arg: Union[ArgNotSet, None] = NOTSET) -> bool: + if arg is NOTSET: + return False + return True + + + is_arg_passed() # False. + is_arg_passed(None) # True. + """ + + @staticmethod + def serialize(): + return "NOTSET" + + @classmethod + def deserialize(cls): + return cls + + +NOTSET = ArgNotSet() +"""Sentinel value for argument default. See ``ArgNotSet``.""" + + +if TYPE_CHECKING: + import logging + + from airflow.sdk.definitions.node import DAGNode + + Logger = logging.Logger +else: + + class Logger: ... # noqa: D101 + + +def validate_instance_args(instance: DAGNode, expected_arg_types: dict[str, Any]) -> None: + """Validate that the instance has the expected types for the arguments.""" + from airflow.sdk.definitions.taskgroup import TaskGroup + + typ = "task group" if isinstance(instance, TaskGroup) else "task" + + for arg_name, expected_arg_type in expected_arg_types.items(): + instance_arg_value = getattr(instance, arg_name, None) + if instance_arg_value is not None and not isinstance(instance_arg_value, expected_arg_type): + raise TypeError( + f"{arg_name!r} for {typ} {instance.node_id!r} expects {expected_arg_type}, got {type(instance_arg_value)} with value " + f"{instance_arg_value!r}" + ) diff --git a/task_sdk/tests/defintions/test_baseoperator.py b/task_sdk/tests/defintions/test_baseoperator.py new file mode 100644 index 000000000000..427d1ee0e3ef --- /dev/null +++ b/task_sdk/tests/defintions/test_baseoperator.py @@ -0,0 +1,343 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
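validate_instance_args above is a plain isinstance check over attributes that have already been set; a minimal sketch of the failure mode it guards against (the ExampleNode stand-in is hypothetical):

    from airflow.sdk.types import validate_instance_args


    class ExampleNode:
        # Minimal stand-in exposing the attributes the helper reads.
        node_id = "example_task"
        retries = "3"


    # Raises TypeError: 'retries' for task 'example_task' expects <class 'int'>, got <class 'str'> ...
    validate_instance_args(ExampleNode(), {"retries": int})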
+
+from __future__ import annotations
+
+import warnings
+from datetime import datetime, timedelta, timezone
+
+import pytest
+
+from airflow.sdk.definitions.baseoperator import BaseOperator, BaseOperatorMeta
+from airflow.sdk.definitions.dag import DAG
+from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _UpstreamPriorityWeightStrategy
+
+DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc)
+
+
+# Essentially similar to airflow.models.baseoperator.BaseOperator
+class FakeOperator(metaclass=BaseOperatorMeta):
+    def __init__(self, test_param, params=None, default_args=None):
+        self.test_param = test_param
+
+    def _set_xcomargs_dependencies(self): ...
+
+
+class FakeSubClass(FakeOperator):
+    def __init__(self, test_sub_param, test_param, **kwargs):
+        super().__init__(test_param=test_param, **kwargs)
+        self.test_sub_param = test_sub_param
+
+
+class DeprecatedOperator(BaseOperator):
+    def __init__(self, **kwargs):
+        warnings.warn("This operator is deprecated.", DeprecationWarning, stacklevel=2)
+        super().__init__(**kwargs)
+
+
+class MockOperator(BaseOperator):
+    """Operator for testing purposes."""
+
+    template_fields = ("arg1", "arg2")
+
+    def __init__(self, arg1: str = "", arg2: str = "", **kwargs):
+        super().__init__(**kwargs)
+        self.arg1 = arg1
+        self.arg2 = arg2
+
+
+class TestBaseOperator:
+    # Since we have a custom metaclass, let's double-check the behaviour of passing args in the wrong
+    # way (args etc)
+    def test_kwargs_only(self):
+        with pytest.raises(TypeError, match="keyword arguments"):
+            BaseOperator("task_id")
+
+    def test_missing_kwarg(self):
+        with pytest.raises(TypeError, match="missing keyword argument"):
+            FakeOperator(task_id="task_id")
+
+    def test_missing_kwargs(self):
+        with pytest.raises(TypeError, match="missing keyword arguments"):
+            FakeSubClass(task_id="task_id")
+
+    def test_hash(self):
+        """Two operators created equally should hash equally."""
+        # Include a "non-hashable" type too
+        assert hash(MockOperator(task_id="one", retries=1024 * 1024, arg1="abcef", params={"a": 1})) == hash(
+            MockOperator(task_id="one", retries=1024 * 1024, arg1="abcef", params={"a": 2})
+        )
+
+    def test_expand(self):
+        op = FakeOperator(test_param=True)
+        assert op.test_param
+
+        with pytest.raises(TypeError, match="missing keyword argument 'test_param'"):
+            FakeSubClass(test_sub_param=True)
+
+    def test_default_args(self):
+        default_args = {"test_param": True}
+        op = FakeOperator(default_args=default_args)
+        assert op.test_param
+
+        default_args = {"test_param": True, "test_sub_param": True}
+        op = FakeSubClass(default_args=default_args)
+        assert op.test_param
+        assert op.test_sub_param
+
+        default_args = {"test_param": True}
+        op = FakeSubClass(default_args=default_args, test_sub_param=True)
+        assert op.test_param
+        assert op.test_sub_param
+
+        with pytest.raises(TypeError, match="missing keyword argument 'test_sub_param'"):
+            FakeSubClass(default_args=default_args)
+
+    def test_execution_timeout_type(self):
+        with pytest.raises(
+            ValueError, match="execution_timeout must be timedelta object but passed as type: <class 'str'>"
+        ):
+            BaseOperator(task_id="test", execution_timeout="1")
+
+        with pytest.raises(
+            ValueError, match="execution_timeout must be timedelta object but passed as type: <class 'int'>"
+        ):
+            BaseOperator(task_id="test", execution_timeout=1)
+
+    def test_default_resources(self):
+        task = BaseOperator(task_id="default-resources")
+        assert task.resources is None
+
+    def test_custom_resources(self):
+        task = BaseOperator(task_id="custom-resources", resources={"cpus": 1, "ram": 1024})
resources={"cpus": 1, "ram": 1024}) + assert task.resources.cpus.qty == 1 + assert task.resources.ram.qty == 1024 + + def test_default_email_on_actions(self): + test_task = BaseOperator(task_id="test_default_email_on_actions") + assert test_task.email_on_retry is True + assert test_task.email_on_failure is True + + def test_email_on_actions(self): + test_task = BaseOperator( + task_id="test_default_email_on_actions", email_on_retry=False, email_on_failure=True + ) + assert test_task.email_on_retry is False + assert test_task.email_on_failure is True + + def test_incorrect_default_args(self): + default_args = {"test_param": True, "extra_param": True} + op = FakeOperator(default_args=default_args) + assert op.test_param + + default_args = {"random_params": True} + with pytest.raises(TypeError, match="missing keyword argument 'test_param'"): + FakeOperator(default_args=default_args) + + def test_incorrect_priority_weight(self): + error_msg = "'priority_weight' for task 'test_op' expects , got " + with pytest.raises(TypeError, match=error_msg): + BaseOperator(task_id="test_op", priority_weight="2") + + def test_illegal_args_forbidden(self): + """ + Tests that operators raise exceptions on illegal arguments when + illegal arguments are not allowed. + """ + msg = r"Invalid arguments were passed to BaseOperator \(task_id: test_illegal_args\)" + with pytest.raises(TypeError, match=msg): + BaseOperator( + task_id="test_illegal_args", + illegal_argument_1234="hello?", + ) + + def test_invalid_type_for_default_arg(self): + error_msg = "'max_active_tis_per_dag' for task 'test' expects , got with value 'not_an_int'" + with pytest.raises(TypeError, match=error_msg): + BaseOperator(task_id="test", default_args={"max_active_tis_per_dag": "not_an_int"}) + + def test_invalid_type_for_operator_arg(self): + error_msg = "'max_active_tis_per_dag' for task 'test' expects , got with value 'not_an_int'" + with pytest.raises(TypeError, match=error_msg): + BaseOperator(task_id="test", max_active_tis_per_dag="not_an_int") + + def test_weight_rule_default(self): + op = BaseOperator(task_id="test_task") + assert _DownstreamPriorityWeightStrategy() == op.weight_rule + + def test_weight_rule_override(self): + op = BaseOperator(task_id="test_task", weight_rule="upstream") + assert _UpstreamPriorityWeightStrategy() == op.weight_rule + + def test_dag_task_invalid_weight_rule(self): + # Test if we enter an invalid weight rule + with pytest.raises(ValueError): + BaseOperator(task_id="should_fail", weight_rule="no rule") + + def test_dag_task_not_registered_weight_strategy(self): + from airflow.task.priority_strategy import PriorityWeightStrategy + + class NotRegisteredPriorityWeightStrategy(PriorityWeightStrategy): + def get_weight(self, ti): + return 99 + + with pytest.raises(ValueError, match="Unknown priority strategy"): + BaseOperator( + task_id="empty_task", + weight_rule=NotRegisteredPriorityWeightStrategy(), + ) + + def test_warnings_are_properly_propagated(self): + with pytest.warns(DeprecationWarning) as warnings: + DeprecatedOperator(task_id="test") + assert len(warnings) == 1 + warning = warnings[0] + # Here we check that the trace points to the place + # where the deprecated class was used + assert warning.filename == __file__ + + def test_setattr_performs_no_custom_action_at_execute_time(self, spy_agency): + from airflow.models.xcom_arg import XComArg + + op = MockOperator(task_id="test_task") + + op._lock_for_execution = True + # TODO: Task-SDK + # op_copy = op.prepare_for_execution() + op_copy = op + + 
spy_agency.spy_on(XComArg.apply_upstream_relationship, call_original=False) + op_copy.arg1 = "b" + assert XComArg.apply_upstream_relationship.called is False + + def test_upstream_is_set_when_template_field_is_xcomarg(self): + with DAG("xcomargs_test", schedule=None): + op1 = BaseOperator(task_id="op1") + op2 = MockOperator(task_id="op2", arg1=op1.output) + + assert op1.task_id in op2.upstream_task_ids + assert op2.task_id in op1.downstream_task_ids + + def test_set_xcomargs_dependencies_works_recursively(self): + with DAG("xcomargs_test", schedule=None): + op1 = BaseOperator(task_id="op1") + op2 = BaseOperator(task_id="op2") + op3 = MockOperator(task_id="op3", arg1=[op1.output, op2.output]) + op4 = MockOperator(task_id="op4", arg1={"op1": op1.output, "op2": op2.output}) + + assert op1.task_id in op3.upstream_task_ids + assert op2.task_id in op3.upstream_task_ids + assert op1.task_id in op4.upstream_task_ids + assert op2.task_id in op4.upstream_task_ids + + def test_set_xcomargs_dependencies_works_when_set_after_init(self): + with DAG(dag_id="xcomargs_test", schedule=None): + op1 = BaseOperator(task_id="op1") + op2 = MockOperator(task_id="op2") + op2.arg1 = op1.output # value is set after init + + assert op1.task_id in op2.upstream_task_ids + + def test_set_xcomargs_dependencies_error_when_outside_dag(self): + op1 = BaseOperator(task_id="op1") + with pytest.raises(ValueError): + MockOperator(task_id="op2", arg1=op1.output) + + def test_cannot_change_dag(self): + with DAG(dag_id="dag1", schedule=None): + op1 = BaseOperator(task_id="op1") + with pytest.raises(ValueError, match="can not be changed"): + op1.dag = DAG(dag_id="dag2") + + def test_invalid_trigger_rule(self): + with pytest.raises( + ValueError, + match=(r"The trigger_rule must be one of .*,'\.op1'; received 'some_rule'\."), + ): + BaseOperator(task_id="op1", trigger_rule="some_rule") + + +def test_init_subclass_args(): + class InitSubclassOp(BaseOperator): + class_arg = None + + def __init_subclass__(cls, class_arg=None, **kwargs) -> None: + cls.class_arg = class_arg + super().__init_subclass__() + + class_arg = "foo" + + class ConcreteSubclassOp(InitSubclassOp, class_arg=class_arg): + pass + + task = ConcreteSubclassOp(task_id="op1") + + assert task.class_arg == class_arg + + +class CustomInt(int): + def __int__(self): + raise ValueError("Cannot cast to int") + + +@pytest.mark.parametrize( + ("retries", "expected"), + [ + pytest.param("foo", "'retries' type must be int, not str", id="string"), + pytest.param(CustomInt(10), "'retries' type must be int, not CustomInt", id="custom int"), + ], +) +def test_operator_retries_invalid(dag_maker, retries, expected): + with pytest.raises(TypeError) as ctx: + BaseOperator(task_id="test_illegal_args", retries=retries) + assert str(ctx.value) == expected + + +@pytest.mark.parametrize( + ("retries", "expected"), + [ + pytest.param(None, 0, id="None"), + pytest.param("5", 5, id="str"), + pytest.param(1, 1, id="int"), + ], +) +def test_operator_retries_conversion(retries, expected): + op = BaseOperator( + task_id="test_illegal_args", + retries=retries, + ) + assert op.retries == expected + + +def test_default_retry_delay(): + task1 = BaseOperator(task_id="test_no_explicit_retry_delay") + + assert task1.retry_delay == timedelta(seconds=300) + + +def test_dag_level_retry_delay(): + with DAG(dag_id="test_dag_level_retry_delay", default_args={"retry_delay": timedelta(seconds=100)}): + task1 = BaseOperator(task_id="test_no_explicit_retry_delay") + + assert task1.retry_delay == timedelta(seconds=100) 
+ + +def test_task_level_retry_delay(): + with DAG(dag_id="test_task_level_retry_delay", default_args={"retry_delay": timedelta(seconds=100)}): + task1 = BaseOperator(task_id="test_no_explicit_retry_delay", retry_delay=200) + + assert task1.retry_delay == timedelta(seconds=200) diff --git a/task_sdk/tests/defintions/test_dag.py b/task_sdk/tests/defintions/test_dag.py new file mode 100644 index 000000000000..b2481a49b6a1 --- /dev/null +++ b/task_sdk/tests/defintions/test_dag.py @@ -0,0 +1,419 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import weakref +from datetime import datetime, timedelta, timezone +from typing import Any + +import pytest + +from airflow.exceptions import DuplicateTaskIdFound +from airflow.models.param import Param, ParamsDict +from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.definitions.dag import DAG, dag as dag_decorator + +DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc) + + +class TestDag: + def test_dag_topological_sort_dag_without_tasks(self): + dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE, default_args={"owner": "owner1"}) + + assert () == dag.topological_sort() + + def test_dag_naive_start_date_string(self): + DAG("DAG", schedule=None, default_args={"start_date": "2019-06-01"}) + + def test_dag_naive_start_end_dates_strings(self): + DAG("DAG", schedule=None, default_args={"start_date": "2019-06-01", "end_date": "2019-06-05"}) + + def test_dag_start_date_propagates_to_end_date(self): + """ + Tests that a start_date string with a timezone and an end_date string without a timezone + are accepted, and that the timezone from the start date carries over to the end date. + + This test is a little indirect; it works by setting start and end equal except for the + timezone and then testing for equality after the DAG construction. They'll be equal + only if the same timezone was applied to both. + + An explicit check that the `tzinfo` attributes of both are the same serves as an extra check. + """ + dag = DAG( + "DAG", + schedule=None, + default_args={"start_date": "2019-06-05T00:00:00+05:00", "end_date": "2019-06-05T00:00:00"}, + ) + assert dag.default_args["start_date"] == dag.default_args["end_date"] + assert dag.default_args["start_date"].tzinfo == dag.default_args["end_date"].tzinfo + + def test_dag_as_context_manager(self): + """ + Test DAG as a context manager.
+ When used as a context manager, Operators are automatically added to + the DAG (unless they specify a different DAG) + """ + dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE, default_args={"owner": "owner1"}) + dag2 = DAG("dag2", schedule=None, start_date=DEFAULT_DATE, default_args={"owner": "owner2"}) + + with dag: + op1 = BaseOperator(task_id="op1") + op2 = BaseOperator(task_id="op2", dag=dag2) + + assert op1.dag is dag + assert op1.owner == "owner1" + assert op2.dag is dag2 + assert op2.owner == "owner2" + + with dag2: + op3 = BaseOperator(task_id="op3") + + assert op3.dag is dag2 + assert op3.owner == "owner2" + + with dag: + with dag2: + op4 = BaseOperator(task_id="op4") + op5 = BaseOperator(task_id="op5") + + assert op4.dag is dag2 + assert op5.dag is dag + assert op4.owner == "owner2" + assert op5.owner == "owner1" + + with DAG("creating_dag_in_cm", schedule=None, start_date=DEFAULT_DATE) as dag: + BaseOperator(task_id="op6") + + assert dag.dag_id == "creating_dag_in_cm" + assert dag.tasks[0].task_id == "op6" + + with dag: + with dag: + op7 = BaseOperator(task_id="op7") + op8 = BaseOperator(task_id="op8") + op9 = BaseOperator(task_id="op8") + op9.dag = dag2 + + assert op7.dag == dag + assert op8.dag == dag + assert op9.dag == dag2 + + def test_params_not_passed_is_empty_dict(self): + """ + Test that when 'params' is _not_ passed to a new Dag, that the params + attribute is set to an empty dictionary. + """ + dag = DAG("test-dag", schedule=None) + + assert isinstance(dag.params, ParamsDict) + assert 0 == len(dag.params) + + def test_params_passed_and_params_in_default_args_no_override(self): + """ + Test that when 'params' exists as a key passed to the default_args dict + in addition to params being passed explicitly as an argument to the + dag, that the 'params' key of the default_args dict is merged with the + dict of the params argument. + """ + params1 = {"parameter1": 1} + params2 = {"parameter2": 2} + + dag = DAG("test-dag", schedule=None, default_args={"params": params1}, params=params2) + + assert params1["parameter1"] == dag.params["parameter1"] + assert params2["parameter2"] == dag.params["parameter2"] + + def test_not_none_schedule_with_non_default_params(self): + """ + Test if there is a DAG with a schedule and have some params that don't have a default value raise a + error while DAG parsing. 
(Because we can't schedule them if there we don't know what value to use) + """ + params = {"param1": Param(type="string")} + + with pytest.raises(ValueError): + DAG("my-dag", schedule=timedelta(days=1), start_date=DEFAULT_DATE, params=params) + + def test_roots(self): + """Verify if dag.roots returns the root tasks of a DAG.""" + with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: + op1 = BaseOperator(task_id="t1") + op2 = BaseOperator(task_id="t2") + op3 = BaseOperator(task_id="t3") + op4 = BaseOperator(task_id="t4") + op5 = BaseOperator(task_id="t5") + [op1, op2] >> op3 >> [op4, op5] + + assert set(dag.roots) == {op1, op2} + + def test_leaves(self): + """Verify if dag.leaves returns the leaf tasks of a DAG.""" + with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: + op1 = BaseOperator(task_id="t1") + op2 = BaseOperator(task_id="t2") + op3 = BaseOperator(task_id="t3") + op4 = BaseOperator(task_id="t4") + op5 = BaseOperator(task_id="t5") + [op1, op2] >> op3 >> [op4, op5] + + assert set(dag.leaves) == {op4, op5} + + def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self): + """Verify tasks with Duplicate task_id raises error""" + with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: + op1 = BaseOperator(task_id="t1") + with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"): + BaseOperator(task_id="t1") + + assert dag.task_dict == {op1.task_id: op1} + + def test_duplicate_task_ids_not_allowed_without_dag_context_manager(self): + """Verify tasks with Duplicate task_id raises error""" + dag = DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) + op1 = BaseOperator(task_id="t1", dag=dag) + with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"): + BaseOperator(task_id="t1", dag=dag) + + assert dag.task_dict == {op1.task_id: op1} + + def test_duplicate_task_ids_for_same_task_is_allowed(self): + """Verify that same tasks with Duplicate task_id do not raise error""" + with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: + op1 = op2 = BaseOperator(task_id="t1") + op3 = BaseOperator(task_id="t3") + op1 >> op3 + op2 >> op3 + + assert op1 == op2 + assert dag.task_dict == {op1.task_id: op1, op3.task_id: op3} + assert dag.task_dict == {op2.task_id: op2, op3.task_id: op3} + + def test_fail_dag_when_schedule_is_non_none_and_empty_start_date(self): + # Check that we get a ValueError 'start_date' for self.start_date when schedule is non-none + with pytest.raises(ValueError, match="start_date is required when catchup=True"): + DAG(dag_id="dag_with_non_none_schedule_and_empty_start_date", schedule="@hourly", catchup=True) + + def test_partial_subset_updates_all_references_while_deepcopy(self): + with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: + op1 = BaseOperator(task_id="t1") + op2 = BaseOperator(task_id="t2") + op3 = BaseOperator(task_id="t3") + op1 >> op2 + op2 >> op3 + + partial = dag.partial_subset("t2", include_upstream=True, include_downstream=False) + assert id(partial.task_dict["t1"].downstream_list[0].dag) == id(partial) + + # Copied DAG should not include unused task IDs in used_group_ids + assert "t3" not in partial.task_group.used_group_ids + + def test_partial_subset_taskgroup_join_ids(self): + from airflow.sdk import TaskGroup + + with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: + start = BaseOperator(task_id="start") + with TaskGroup(group_id="outer", prefix_group_id=False) as outer_group: 
+ with TaskGroup(group_id="tg1", prefix_group_id=False) as tg1: + BaseOperator(task_id="t1") + with TaskGroup(group_id="tg2", prefix_group_id=False) as tg2: + BaseOperator(task_id="t2") + + start >> tg1 >> tg2 + + # Pre-condition checks + task = dag.get_task("t2") + assert task.task_group.upstream_group_ids == {"tg1"} + assert isinstance(task.task_group.parent_group, weakref.ProxyType) + assert task.task_group.parent_group == outer_group + + partial = dag.partial_subset(["t2"], include_upstream=True, include_downstream=False) + copied_task = partial.get_task("t2") + assert copied_task.task_group.upstream_group_ids == {"tg1"} + assert isinstance(copied_task.task_group.parent_group, weakref.ProxyType) + assert copied_task.task_group.parent_group + + # Make sure we don't affect the original! + assert task.task_group.upstream_group_ids is not copied_task.task_group.upstream_group_ids + + def test_dag_owner_links(self): + dag = DAG( + "dag", + schedule=None, + start_date=DEFAULT_DATE, + owner_links={"owner1": "https://mylink.com", "owner2": "mailto:someone@yoursite.com"}, + ) + + assert dag.owner_links == {"owner1": "https://mylink.com", "owner2": "mailto:someone@yoursite.com"} + + # Check wrong formatted owner link + with pytest.raises(ValueError, match="Wrong link format"): + DAG("dag", schedule=None, start_date=DEFAULT_DATE, owner_links={"owner1": "my-bad-link"}) + + dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE) + dag.owner_links["owner1"] = "my-bad-link" + with pytest.raises(ValueError, match="Wrong link format"): + dag.validate() + + def test_continuous_schedule_linmits_max_active_runs(self): + from airflow.timetables.simple import ContinuousTimetable + + dag = DAG("continuous", start_date=DEFAULT_DATE, schedule="@continuous", max_active_runs=1) + assert isinstance(dag.timetable, ContinuousTimetable) + assert dag.max_active_runs == 1 + + dag = DAG("continuous", start_date=DEFAULT_DATE, schedule="@continuous", max_active_runs=0) + assert isinstance(dag.timetable, ContinuousTimetable) + assert dag.max_active_runs == 0 + + with pytest.raises(ValueError, match="ContinuousTimetable requires max_active_runs <= 1"): + dag = DAG("continuous", start_date=DEFAULT_DATE, schedule="@continuous", max_active_runs=25) + + +# Test some of the arg valiadtion. This is not all the validations we perform, just some of them. 
+@pytest.mark.parametrize( + ["attr", "value"], + [ + pytest.param("max_consecutive_failed_dag_runs", "not_an_int", id="max_consecutive_failed_dag_runs"), + pytest.param("dagrun_timeout", "not_an_int", id="dagrun_timeout"), + pytest.param("max_active_runs", "not_an_int", id="max_active_runs"), + ], +) +def test_invalid_type_for_args(attr: str, value: Any): + with pytest.raises(TypeError): + DAG("invalid-default-args", **{attr: value}) + + +@pytest.mark.parametrize( + "tags, should_pass", + [ + pytest.param([], True, id="empty tags"), + pytest.param(["a normal tag"], True, id="one tag"), + pytest.param(["a normal tag", "another normal tag"], True, id="two tags"), + pytest.param(["a" * 100], True, id="a tag that's of just length 100"), + pytest.param(["a normal tag", "a" * 101], False, id="two tags and one of them is of length > 100"), + ], +) +def test__tags_length(tags: list[str], should_pass: bool): + if should_pass: + DAG("test-dag", schedule=None, tags=tags) + else: + with pytest.raises(ValueError): + DAG("test-dag", schedule=None, tags=tags) + + +@pytest.mark.parametrize( + "input_tags, expected_result", + [ + pytest.param([], set(), id="empty tags"), + pytest.param( + ["a normal tag"], + {"a normal tag"}, + id="one tag", + ), + pytest.param( + ["a normal tag", "another normal tag"], + {"a normal tag", "another normal tag"}, + id="two different tags", + ), + pytest.param( + ["a", "a"], + {"a"}, + id="two same tags", + ), + ], +) +def test__tags_duplicates(input_tags: list[str], expected_result: set[str]): + result = DAG("test-dag", tags=input_tags) + assert result.tags == expected_result + + +def test__tags_mutable(): + expected_tags = {"6", "7"} + test_dag = DAG("test-dag") + test_dag.tags.add("6") + test_dag.tags.add("7") + test_dag.tags.add("8") + test_dag.tags.remove("8") + assert test_dag.tags == expected_tags + + +class TestDagDecorator: + DEFAULT_ARGS = { + "owner": "test", + "depends_on_past": True, + "start_date": datetime.now(tz=timezone.utc), + "retries": 1, + "retry_delay": timedelta(minutes=1), + } + VALUE = 42 + + def test_fileloc(self): + @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS) + def noop_pipeline(): ... + + dag = noop_pipeline() + assert isinstance(dag, DAG) + assert dag.dag_id == "noop_pipeline" + assert dag.fileloc == __file__ + + def test_set_dag_id(self): + """Test that checks you can set dag_id from decorator.""" + + @dag_decorator("test", schedule=None, default_args=self.DEFAULT_ARGS) + def noop_pipeline(): ... + + dag = noop_pipeline() + assert isinstance(dag, DAG) + assert dag.dag_id == "test" + + def test_default_dag_id(self): + """Test that @dag uses function name as default dag id.""" + + @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS) + def noop_pipeline(): ... 
+ + dag = noop_pipeline() + assert isinstance(dag, DAG) + assert dag.dag_id == "noop_pipeline" + + @pytest.mark.parametrize( + argnames=["dag_doc_md", "expected_doc_md"], + argvalues=[ + pytest.param("dag docs.", "dag docs.", id="use_dag_doc_md"), + pytest.param(None, "Regular DAG documentation", id="use_dag_docstring"), + ], + ) + def test_documentation_added(self, dag_doc_md, expected_doc_md): + """Test that @dag uses function docs as doc_md for DAG object if doc_md is not explicitly set.""" + + @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS, doc_md=dag_doc_md) + def noop_pipeline(): + """Regular DAG documentation""" + + dag = noop_pipeline() + assert isinstance(dag, DAG) + assert dag.dag_id == "noop_pipeline" + assert dag.doc_md == expected_doc_md + + def test_fails_if_arg_not_set(self): + """Test that @dag decorated function fails if positional argument is not set""" + + @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS) + def noop_pipeline(value): ... + + # Test that if arg is not passed it raises a type error as expected. + with pytest.raises(TypeError): + noop_pipeline() diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index ecb0052a8a03..5249944ea113 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -190,7 +190,6 @@ def test_should_respond_200(self): "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, } == response.json @conf_vars({("webserver", "secret_key"): "mysecret"}) @@ -230,7 +229,6 @@ def test_should_respond_200_with_schedule_none(self, session): "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, } == response.json def test_should_respond_404(self): @@ -331,7 +329,6 @@ def test_should_respond_200(self, url_safe_serializer): "value": 1, } }, - "pickle_id": None, "render_template_as_native_obj": False, "timetable_summary": "2 2 * * *", "start_date": "2020-06-15T00:00:00+00:00", @@ -393,7 +390,6 @@ def test_should_respond_200_with_asset_expression(self, url_safe_serializer): "value": 1, } }, - "pickle_id": None, "render_template_as_native_obj": False, "timetable_summary": "2 2 * * *", "start_date": "2020-06-15T00:00:00+00:00", @@ -443,7 +439,6 @@ def test_should_response_200_with_doc_md_none(self, url_safe_serializer): "orientation": "LR", "owners": [], "params": {}, - "pickle_id": None, "render_template_as_native_obj": False, "timetable_summary": "2 2 * * *", "start_date": "2020-06-15T00:00:00+00:00", @@ -493,7 +488,6 @@ def test_should_response_200_for_null_start_date(self, url_safe_serializer): "orientation": "LR", "owners": [], "params": {}, - "pickle_id": None, "render_template_as_native_obj": False, "timetable_summary": "2 2 * * *", "start_date": None, @@ -552,7 +546,6 @@ def test_should_respond_200_serialized(self, url_safe_serializer): "value": 1, } }, - "pickle_id": None, "render_template_as_native_obj": False, "timetable_summary": "2 2 * * *", "start_date": "2020-06-15T00:00:00+00:00", @@ -612,7 +605,6 @@ def test_should_respond_200_serialized(self, url_safe_serializer): "value": 1, } }, - "pickle_id": None, "render_template_as_native_obj": False, "timetable_summary": "2 2 * * *", "start_date": "2020-06-15T00:00:00+00:00", @@ -712,7 +704,6 @@ def test_should_respond_200(self, session, url_safe_serializer): "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - 
"pickle_id": None, }, { "dag_id": "TEST_DAG_2", @@ -739,7 +730,6 @@ def test_should_respond_200(self, session, url_safe_serializer): "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, }, ], "total_entries": 2, @@ -778,7 +768,6 @@ def test_only_active_true_returns_active_dags(self, url_safe_serializer): "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, } ], "total_entries": 1, @@ -818,7 +807,6 @@ def test_only_active_false_returns_all_dags(self, url_safe_serializer): "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, }, { "dag_id": "TEST_DAG_DELETED_1", @@ -845,7 +833,6 @@ def test_only_active_false_returns_all_dags(self, url_safe_serializer): "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, }, ], "total_entries": 2, @@ -1001,7 +988,6 @@ def test_paused_true_returns_paused_dags(self, url_safe_serializer): "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, } ], "total_entries": 1, @@ -1040,7 +1026,6 @@ def test_paused_false_returns_unpaused_dags(self, url_safe_serializer): "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, } ], "total_entries": 1, @@ -1079,7 +1064,6 @@ def test_paused_none_returns_all_dags(self, url_safe_serializer): "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, }, { "dag_id": "TEST_DAG_UNPAUSED_1", @@ -1106,7 +1090,6 @@ def test_paused_none_returns_all_dags(self, url_safe_serializer): "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, }, ], "total_entries": 2, @@ -1195,7 +1178,6 @@ def test_should_respond_200_on_patch_is_paused(self, url_safe_serializer, sessio "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, } assert response.json == expected_response _check_last_log( @@ -1293,7 +1275,6 @@ def test_should_respond_200_with_update_mask(self, url_safe_serializer): "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, } assert response.json == expected_response @@ -1387,7 +1368,6 @@ def test_should_respond_200_on_patch_is_paused(self, session, url_safe_serialize "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, }, { "dag_id": "TEST_DAG_2", @@ -1414,7 +1394,6 @@ def test_should_respond_200_on_patch_is_paused(self, session, url_safe_serialize "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, }, ], "total_entries": 2, @@ -1466,7 +1445,6 @@ def test_should_respond_200_on_patch_is_paused_using_update_mask(self, session, "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, }, { "dag_id": "TEST_DAG_2", @@ -1493,7 +1471,6 @@ def test_should_respond_200_on_patch_is_paused_using_update_mask(self, session, "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, }, ], "total_entries": 2, @@ -1585,7 +1562,6 @@ def test_only_active_true_returns_active_dags(self, url_safe_serializer, session "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, } ], "total_entries": 1, @@ -1633,7 
+1609,6 @@ def test_only_active_false_returns_all_dags(self, url_safe_serializer, session): "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, }, { "dag_id": "TEST_DAG_DELETED_1", @@ -1660,7 +1635,6 @@ def test_only_active_false_returns_all_dags(self, url_safe_serializer, session): "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, }, ], "total_entries": 2, @@ -1858,7 +1832,6 @@ def test_should_respond_200_and_pause_dags(self, url_safe_serializer): "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, }, { "dag_id": "TEST_DAG_2", @@ -1885,7 +1858,6 @@ def test_should_respond_200_and_pause_dags(self, url_safe_serializer): "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, }, ], "total_entries": 2, @@ -1933,7 +1905,6 @@ def test_should_respond_200_and_pause_dag_pattern(self, session, url_safe_serial "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, }, { "dag_id": "TEST_DAG_10", @@ -1960,7 +1931,6 @@ def test_should_respond_200_and_pause_dag_pattern(self, session, url_safe_serial "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, }, ], "total_entries": 2, @@ -2010,7 +1980,6 @@ def test_should_respond_200_and_reverse_ordering(self, session, url_safe_seriali "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, }, { "dag_id": "TEST_DAG_1", @@ -2037,7 +2006,6 @@ def test_should_respond_200_and_reverse_ordering(self, session, url_safe_seriali "last_parsed_time": None, "timetable_description": None, "has_import_errors": False, - "pickle_id": None, }, ], "total_entries": 2, diff --git a/tests/api_connexion/schemas/test_dag_schema.py b/tests/api_connexion/schemas/test_dag_schema.py index 1576f9cf3d6e..4a6829a5c831 100644 --- a/tests/api_connexion/schemas/test_dag_schema.py +++ b/tests/api_connexion/schemas/test_dag_schema.py @@ -72,7 +72,6 @@ def test_serialize_test_dag_schema(url_safe_serializer): "last_parsed_time": None, "timetable_description": None, "has_import_errors": None, - "pickle_id": None, } == serialized_dag @@ -108,7 +107,6 @@ def test_serialize_test_dag_collection_schema(url_safe_serializer): "last_parsed_time": None, "timetable_description": None, "has_import_errors": None, - "pickle_id": None, }, { "dag_id": "test_dag_id_b", @@ -135,7 +133,6 @@ def test_serialize_test_dag_collection_schema(url_safe_serializer): "last_parsed_time": None, "timetable_description": None, "has_import_errors": None, - "pickle_id": None, }, ], "total_entries": 2, @@ -190,7 +187,6 @@ def test_serialize_test_dag_detail_schema(url_safe_serializer): "timezone": UTC_JSON_REPR, "max_active_runs": 16, "max_consecutive_failed_dag_runs": 0, - "pickle_id": None, "end_date": None, "is_paused_upon_creation": None, "render_template_as_native_obj": False, @@ -254,7 +250,6 @@ def test_serialize_test_dag_with_asset_schedule_detail_schema(url_safe_serialize "timezone": UTC_JSON_REPR, "max_active_runs": 16, "max_consecutive_failed_dag_runs": 0, - "pickle_id": None, "end_date": None, "is_paused_upon_creation": None, "render_template_as_native_obj": False, diff --git a/tests/api_fastapi/core_api/routes/public/test_dags.py b/tests/api_fastapi/core_api/routes/public/test_dags.py index 72f4f70179f3..0e9b7a408583 100644 --- 
a/tests/api_fastapi/core_api/routes/public/test_dags.py +++ b/tests/api_fastapi/core_api/routes/public/test_dags.py @@ -320,7 +320,6 @@ def test_dag_details( "next_dagrun_create_after": None, "next_dagrun_data_interval_end": None, "next_dagrun_data_interval_start": None, - "orientation": "LR", "owners": ["airflow"], "params": { "foo": { diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 31bb0c9dd66a..1d8e4f4d32a8 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -21,20 +21,18 @@ import logging import uuid from collections import defaultdict -from datetime import date, datetime, timedelta -from typing import TYPE_CHECKING, Any, NamedTuple +from datetime import date, datetime +from typing import NamedTuple from unittest import mock import jinja2 import pytest from airflow.decorators import task as task_decorator -from airflow.exceptions import AirflowException, FailStopDagInvalidTriggerRule +from airflow.exceptions import AirflowException from airflow.lineage.entities import File from airflow.models.baseoperator import ( - BASEOPERATOR_ARGS_EXPECTED_TYPES, BaseOperator, - BaseOperatorMeta, chain, chain_linear, cross_downstream, @@ -43,7 +41,6 @@ from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance from airflow.providers.common.sql.operators import sql -from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _UpstreamPriorityWeightStrategy from airflow.utils.edgemodifier import Label from airflow.utils.task_group import TaskGroup from airflow.utils.template import literal @@ -51,10 +48,7 @@ from airflow.utils.types import DagRunType from tests.models import DEFAULT_DATE -from tests_common.test_utils.mock_operators import DeprecatedOperator, MockOperator - -if TYPE_CHECKING: - from airflow.utils.context import Context +from tests_common.test_utils.mock_operators import MockOperator class ClassWithCustomAttributes: @@ -83,93 +77,12 @@ def __ne__(self, other): setattr(object1, "ref", object2) -# Essentially similar to airflow.models.baseoperator.BaseOperator -class DummyClass(metaclass=BaseOperatorMeta): - def __init__(self, test_param, params=None, default_args=None): - self.test_param = test_param - - def set_xcomargs_dependencies(self): ... 
- - -class DummySubClass(DummyClass): - def __init__(self, test_sub_param, **kwargs): - super().__init__(**kwargs) - self.test_sub_param = test_sub_param - - class MockNamedTuple(NamedTuple): var1: str var2: str -class CustomInt(int): - def __int__(self): - raise ValueError("Cannot cast to int") - - class TestBaseOperator: - def test_expand(self): - dummy = DummyClass(test_param=True) - assert dummy.test_param - - with pytest.raises(AirflowException, match="missing keyword argument 'test_param'"): - DummySubClass(test_sub_param=True) - - def test_default_args(self): - default_args = {"test_param": True} - dummy_class = DummyClass(default_args=default_args) - assert dummy_class.test_param - - default_args = {"test_param": True, "test_sub_param": True} - dummy_subclass = DummySubClass(default_args=default_args) - assert dummy_class.test_param - assert dummy_subclass.test_sub_param - - default_args = {"test_param": True} - dummy_subclass = DummySubClass(default_args=default_args, test_sub_param=True) - assert dummy_class.test_param - assert dummy_subclass.test_sub_param - - with pytest.raises(AirflowException, match="missing keyword argument 'test_sub_param'"): - DummySubClass(default_args=default_args) - - def test_execution_timeout_type(self): - with pytest.raises( - ValueError, match="execution_timeout must be timedelta object but passed as type: " - ): - BaseOperator(task_id="test", execution_timeout="1") - - with pytest.raises( - ValueError, match="execution_timeout must be timedelta object but passed as type: " - ): - BaseOperator(task_id="test", execution_timeout=1) - - def test_incorrect_default_args(self): - default_args = {"test_param": True, "extra_param": True} - dummy_class = DummyClass(default_args=default_args) - assert dummy_class.test_param - - default_args = {"random_params": True} - with pytest.raises(AirflowException, match="missing keyword argument 'test_param'"): - DummyClass(default_args=default_args) - - def test_incorrect_priority_weight(self): - error_msg = "`priority_weight` for task 'test_op' only accepts integers, received ''." - with pytest.raises(AirflowException, match=error_msg): - BaseOperator(task_id="test_op", priority_weight="2") - - def test_illegal_args_forbidden(self): - """ - Tests that operators raise exceptions on illegal arguments when - illegal arguments are not allowed. 
- """ - msg = r"Invalid arguments were passed to BaseOperator \(task_id: test_illegal_args\)" - with pytest.raises(AirflowException, match=msg): - BaseOperator( - task_id="test_illegal_args", - illegal_argument_1234="hello?", - ) - def test_trigger_rule_validation(self): from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE @@ -192,11 +105,6 @@ def test_trigger_rule_validation(self): BaseOperator( task_id="test_valid_trigger_rule", dag=non_fail_stop_dag, trigger_rule=TriggerRule.ALWAYS ) - # An operator with non default trigger rule and a fail stop dag should not be allowed - with pytest.raises(FailStopDagInvalidTriggerRule): - BaseOperator( - task_id="test_invalid_trigger_rule", dag=fail_stop_dag, trigger_rule=TriggerRule.ALWAYS - ) @pytest.mark.db_test @pytest.mark.parametrize( @@ -403,27 +311,6 @@ def test_jinja_env_creation(self, mock_jinja_env): task.render_template_fields(context={"foo": "whatever", "bar": "whatever"}) assert mock_jinja_env.call_count == 1 - def test_default_resources(self): - task = BaseOperator(task_id="default-resources") - assert task.resources is None - - def test_custom_resources(self): - task = BaseOperator(task_id="custom-resources", resources={"cpus": 1, "ram": 1024}) - assert task.resources.cpus.qty == 1 - assert task.resources.ram.qty == 1024 - - def test_default_email_on_actions(self): - test_task = BaseOperator(task_id="test_default_email_on_actions") - assert test_task.email_on_retry is True - assert test_task.email_on_failure is True - - def test_email_on_actions(self): - test_task = BaseOperator( - task_id="test_default_email_on_actions", email_on_retry=False, email_on_failure=True - ) - assert test_task.email_on_retry is False - assert test_task.email_on_failure is True - def test_cross_downstream(self): """Test if all dependencies between tasks are all set correctly.""" dag = DAG(dag_id="test_dag", schedule=None, start_date=datetime.now()) @@ -659,15 +546,6 @@ def test_lineage_composition(self): task4 > [inlet, outlet, extra] assert task4.get_outlet_defs() == [inlet, outlet, extra] - def test_warnings_are_properly_propagated(self): - with pytest.warns(DeprecationWarning) as warnings: - DeprecatedOperator(task_id="test") - assert len(warnings) == 1 - warning = warnings[0] - # Here we check that the trace points to the place - # where the deprecated class was used - assert warning.filename == __file__ - def test_pre_execute_hook(self): hook = mock.MagicMock() @@ -694,65 +572,6 @@ def test_task_naive_datetime(self): assert op_no_dag.start_date.tzinfo assert op_no_dag.end_date.tzinfo - def test_setattr_performs_no_custom_action_at_execute_time(self): - op = MockOperator(task_id="test_task") - op_copy = op.prepare_for_execution() - - with mock.patch("airflow.models.baseoperator.BaseOperator.set_xcomargs_dependencies") as method_mock: - op_copy.execute({}) - assert method_mock.call_count == 0 - - def test_upstream_is_set_when_template_field_is_xcomarg(self): - with DAG("xcomargs_test", schedule=None, default_args={"start_date": datetime.today()}): - op1 = BaseOperator(task_id="op1") - op2 = MockOperator(task_id="op2", arg1=op1.output) - - assert op1 in op2.upstream_list - assert op2 in op1.downstream_list - - def test_set_xcomargs_dependencies_works_recursively(self): - with DAG("xcomargs_test", schedule=None, default_args={"start_date": datetime.today()}): - op1 = BaseOperator(task_id="op1") - op2 = BaseOperator(task_id="op2") - op3 = MockOperator(task_id="op3", arg1=[op1.output, op2.output]) - op4 = MockOperator(task_id="op4", arg1={"op1": 
op1.output, "op2": op2.output}) - - assert op1 in op3.upstream_list - assert op2 in op3.upstream_list - assert op1 in op4.upstream_list - assert op2 in op4.upstream_list - - def test_set_xcomargs_dependencies_works_when_set_after_init(self): - with DAG(dag_id="xcomargs_test", schedule=None, default_args={"start_date": datetime.today()}): - op1 = BaseOperator(task_id="op1") - op2 = MockOperator(task_id="op2") - op2.arg1 = op1.output # value is set after init - - assert op1 in op2.upstream_list - - def test_set_xcomargs_dependencies_error_when_outside_dag(self): - op1 = BaseOperator(task_id="op1") - with pytest.raises(AirflowException): - MockOperator(task_id="op2", arg1=op1.output) - - def test_invalid_trigger_rule(self): - with pytest.raises( - AirflowException, - match=( - f"The trigger_rule must be one of {TriggerRule.all_triggers()}," - "'.op1'; received 'some_rule'." - ), - ): - BaseOperator(task_id="op1", trigger_rule="some_rule") - - def test_weight_rule_default(self): - op = BaseOperator(task_id="test_task") - assert _DownstreamPriorityWeightStrategy() == op.weight_rule - - def test_weight_rule_override(self): - op = BaseOperator(task_id="test_task", weight_rule="upstream") - assert _UpstreamPriorityWeightStrategy() == op.weight_rule - # ensure the default logging config is used for this test, no matter what ran before @pytest.mark.usefixtures("reset_logging_config") def test_logging_propogated_by_default(self, caplog): @@ -763,118 +582,6 @@ def test_logging_propogated_by_default(self, caplog): # leaking a lot of state) assert caplog.messages == ["test"] - def test_invalid_type_for_default_arg(self): - error_msg = "'max_active_tis_per_dag' has an invalid type with value not_an_int, expected type is " - with pytest.raises(TypeError, match=error_msg): - BaseOperator(task_id="test", default_args={"max_active_tis_per_dag": "not_an_int"}) - - def test_invalid_type_for_operator_arg(self): - error_msg = "'max_active_tis_per_dag' has an invalid type with value not_an_int, expected type is " - with pytest.raises(TypeError, match=error_msg): - BaseOperator(task_id="test", max_active_tis_per_dag="not_an_int") - - @mock.patch("airflow.models.baseoperator.validate_instance_args") - def test_baseoperator_init_validates_arg_types(self, mock_validate_instance_args): - operator = BaseOperator(task_id="test") - - mock_validate_instance_args.assert_called_once_with(operator, BASEOPERATOR_ARGS_EXPECTED_TYPES) - - -def test_init_subclass_args(): - class InitSubclassOp(BaseOperator): - _class_arg: Any - - def __init_subclass__(cls, class_arg=None, **kwargs) -> None: - cls._class_arg = class_arg - super().__init_subclass__() - - def execute(self, context: Context): - self.context_arg = context - - class_arg = "foo" - context = {"key": "value"} - - class ConcreteSubclassOp(InitSubclassOp, class_arg=class_arg): - pass - - task = ConcreteSubclassOp(task_id="op1") - task_copy = task.prepare_for_execution() - - task_copy.execute(context) - - assert task_copy._class_arg == class_arg - assert task_copy.context_arg == context - - -@pytest.mark.db_test -@pytest.mark.parametrize( - ("retries", "expected"), - [ - pytest.param("foo", "'retries' type must be int, not str", id="string"), - pytest.param(CustomInt(10), "'retries' type must be int, not CustomInt", id="custom int"), - ], -) -def test_operator_retries_invalid(dag_maker, retries, expected): - with pytest.raises(AirflowException) as ctx: - with dag_maker(): - BaseOperator(task_id="test_illegal_args", retries=retries) - assert str(ctx.value) == expected - - 
-@pytest.mark.db_test -@pytest.mark.parametrize( - ("retries", "expected"), - [ - pytest.param(None, [], id="None"), - pytest.param(5, [], id="5"), - pytest.param( - "1", - [ - ( - "airflow.models.baseoperator.BaseOperator", - logging.WARNING, - "Implicitly converting 'retries' from '1' to int", - ), - ], - id="str", - ), - ], -) -def test_operator_retries(caplog, dag_maker, retries, expected): - with caplog.at_level(logging.WARNING): - with dag_maker(): - BaseOperator( - task_id="test_illegal_args", - retries=retries, - ) - assert caplog.record_tuples == expected - - -@pytest.mark.db_test -def test_default_retry_delay(dag_maker): - with dag_maker(dag_id="test_default_retry_delay"): - task1 = BaseOperator(task_id="test_no_explicit_retry_delay") - - assert task1.retry_delay == timedelta(seconds=300) - - -@pytest.mark.db_test -def test_dag_level_retry_delay(dag_maker): - with dag_maker(dag_id="test_dag_level_retry_delay", default_args={"retry_delay": timedelta(seconds=100)}): - task1 = BaseOperator(task_id="test_no_explicit_retry_delay") - - assert task1.retry_delay == timedelta(seconds=100) - - -@pytest.mark.db_test -def test_task_level_retry_delay(dag_maker): - with dag_maker( - dag_id="test_task_level_retry_delay", default_args={"retry_delay": timedelta(seconds=100)} - ): - task1 = BaseOperator(task_id="test_no_explicit_retry_delay", retry_delay=timedelta(seconds=200)) - - assert task1.retry_delay == timedelta(seconds=200) - def test_deepcopy(): # Test bug when copying an operator attached to a DAG diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 2bca1c24acd4..4c1d8a67960c 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -23,7 +23,6 @@ import os import pickle import re -import weakref from datetime import timedelta from pathlib import Path from typing import TYPE_CHECKING @@ -42,7 +41,6 @@ from airflow.decorators import setup, task as task_decorator, teardown from airflow.exceptions import ( AirflowException, - DuplicateTaskIdFound, ParamValidationError, UnknownExecutorException, ) @@ -57,7 +55,6 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.dag import ( DAG, - DAG_ARGS_EXPECTED_TYPES, DagModel, DagOwnerAttributes, DagTag, @@ -66,18 +63,19 @@ get_asset_triggered_next_run_info, ) from airflow.models.dagrun import DagRun -from airflow.models.param import DagParam, Param, ParamsDict +from airflow.models.param import DagParam, Param from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance as TI from airflow.operators.empty import EmptyOperator from airflow.operators.python import PythonOperator from airflow.providers.standard.operators.bash import BashOperator +from airflow.sdk import TaskGroup +from airflow.sdk.definitions.contextmanager import TaskGroupContext from airflow.security import permissions from airflow.templates import NativeEnvironment, SandboxedEnvironment from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable from airflow.timetables.simple import ( AssetTriggeredTimetable, - ContinuousTimetable, NullTimetable, OnceTimetable, ) @@ -85,7 +83,6 @@ from airflow.utils.file import list_py_file_paths from airflow.utils.session import create_session from airflow.utils.state import DagRunState, State, TaskInstanceState -from airflow.utils.task_group import TaskGroup, TaskGroupContext from airflow.utils.timezone import datetime as datetime_tz from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import 
DagRunType @@ -94,7 +91,6 @@ from tests.models import DEFAULT_DATE from tests.plugins.priority_weight_strategy import ( FactorPriorityWeightStrategy, - NotRegisteredPriorityWeightStrategy, StaticTestPriorityWeightStrategy, TestPriorityWeightStrategyPlugin, ) @@ -174,150 +170,6 @@ def _occur_before(a, b, list_): b_index = i return 0 <= a_index < b_index - def test_params_not_passed_is_empty_dict(self): - """ - Test that when 'params' is _not_ passed to a new Dag, that the params - attribute is set to an empty dictionary. - """ - dag = DAG("test-dag", schedule=None) - - assert isinstance(dag.params, ParamsDict) - assert 0 == len(dag.params) - - def test_params_passed_and_params_in_default_args_no_override(self): - """ - Test that when 'params' exists as a key passed to the default_args dict - in addition to params being passed explicitly as an argument to the - dag, that the 'params' key of the default_args dict is merged with the - dict of the params argument. - """ - params1 = {"parameter1": 1} - params2 = {"parameter2": 2} - - dag = DAG("test-dag", schedule=None, default_args={"params": params1}, params=params2) - - assert params1["parameter1"] == dag.params["parameter1"] - assert params2["parameter2"] == dag.params["parameter2"] - - def test_not_none_schedule_with_non_default_params(self): - """ - Test if there is a DAG with not None schedule and have some params that - don't have a default value raise a error while DAG parsing - """ - params = {"param1": Param(type="string")} - - with pytest.raises(AirflowException): - DAG("dummy-dag", schedule=timedelta(days=1), start_date=DEFAULT_DATE, params=params) - - def test_dag_invalid_default_view(self): - """ - Test invalid `default_view` of DAG initialization - """ - with pytest.raises(AirflowException, match="Invalid values of dag.default_view: only support"): - DAG(dag_id="test-invalid-default_view", schedule=None, default_view="airflow") - - def test_dag_default_view_default_value(self): - """ - Test `default_view` default value of DAG initialization - """ - dag = DAG(dag_id="test-default_default_view", schedule=None) - assert conf.get("webserver", "dag_default_view").lower() == dag.default_view - - def test_dag_invalid_orientation(self): - """ - Test invalid `orientation` of DAG initialization - """ - with pytest.raises(AirflowException, match="Invalid values of dag.orientation: only support"): - DAG(dag_id="test-invalid-orientation", schedule=None, orientation="airflow") - - def test_dag_orientation_default_value(self): - """ - Test `orientation` default value of DAG initialization - """ - dag = DAG(dag_id="test-default_orientation", schedule=None) - assert conf.get("webserver", "dag_orientation") == dag.orientation - - def test_dag_as_context_manager(self): - """ - Test DAG as a context manager. 
- When used as a context manager, Operators are automatically added to - the DAG (unless they specify a different DAG) - """ - dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE, default_args={"owner": "owner1"}) - dag2 = DAG("dag2", schedule=None, start_date=DEFAULT_DATE, default_args={"owner": "owner2"}) - - with dag: - op1 = EmptyOperator(task_id="op1") - op2 = EmptyOperator(task_id="op2", dag=dag2) - - assert op1.dag is dag - assert op1.owner == "owner1" - assert op2.dag is dag2 - assert op2.owner == "owner2" - - with dag2: - op3 = EmptyOperator(task_id="op3") - - assert op3.dag is dag2 - assert op3.owner == "owner2" - - with dag: - with dag2: - op4 = EmptyOperator(task_id="op4") - op5 = EmptyOperator(task_id="op5") - - assert op4.dag is dag2 - assert op5.dag is dag - assert op4.owner == "owner2" - assert op5.owner == "owner1" - - with DAG("creating_dag_in_cm", schedule=None, start_date=DEFAULT_DATE) as dag: - EmptyOperator(task_id="op6") - - assert dag.dag_id == "creating_dag_in_cm" - assert dag.tasks[0].task_id == "op6" - - with dag: - with dag: - op7 = EmptyOperator(task_id="op7") - op8 = EmptyOperator(task_id="op8") - op9 = EmptyOperator(task_id="op8") - op9.dag = dag2 - - assert op7.dag == dag - assert op8.dag == dag - assert op9.dag == dag2 - - def test_dag_topological_sort_dag_without_tasks(self): - dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE, default_args={"owner": "owner1"}) - - assert () == dag.topological_sort() - - def test_dag_naive_start_date_string(self): - DAG("DAG", schedule=None, default_args={"start_date": "2019-06-01"}) - - def test_dag_naive_start_end_dates_strings(self): - DAG("DAG", schedule=None, default_args={"start_date": "2019-06-01", "end_date": "2019-06-05"}) - - def test_dag_start_date_propagates_to_end_date(self): - """ - Tests that a start_date string with a timezone and an end_date string without a timezone - are accepted and that the timezone from the start carries over the end - - This test is a little indirect, it works by setting start and end equal except for the - timezone and then testing for equality after the DAG construction. They'll be equal - only if the same timezone was applied to both. - - An explicit check the `tzinfo` attributes for both are the same is an extra check. 
- """ - dag = DAG( - "DAG", - schedule=None, - default_args={"start_date": "2019-06-05T00:00:00+05:00", "end_date": "2019-06-05T00:00:00"}, - ) - assert dag.default_args["start_date"] == dag.default_args["end_date"] - assert dag.default_args["start_date"].tzinfo == dag.default_args["end_date"].tzinfo - def test_dag_naive_default_args_start_date(self): dag = DAG("DAG", schedule=None, default_args={"start_date": datetime.datetime(2018, 1, 1)}) assert dag.timezone == settings.TIMEZONE @@ -416,12 +268,6 @@ def test_dag_task_priority_weight_total_using_absolute(self): calculated_weight = task.priority_weight_total assert calculated_weight == correct_weight - def test_dag_task_invalid_weight_rule(self): - # Test if we enter an invalid weight rule - with DAG("dag", schedule=None, start_date=DEFAULT_DATE, default_args={"owner": "owner1"}): - with pytest.raises(AirflowException): - EmptyOperator(task_id="should_fail", weight_rule="no rule") - @pytest.mark.parametrize( "cls, expected", [ @@ -448,16 +294,6 @@ def test_dag_task_custom_weight_strategy(self, cls, expected): ti = dr.get_task_instance(task.task_id) assert ti.priority_weight == expected - def test_dag_task_not_registered_weight_strategy(self): - with mock_plugin_manager(plugins=[TestPriorityWeightStrategyPlugin]), DAG( - "dag", schedule=None, start_date=DEFAULT_DATE, default_args={"owner": "owner1"} - ): - with pytest.raises(AirflowException, match="Unknown priority strategy"): - EmptyOperator( - task_id="empty_task", - weight_rule=NotRegisteredPriorityWeightStrategy(), - ) - def test_get_num_task_instances(self): test_dag_id = "test_get_num_task_instances_dag" test_task_id = "task_1" @@ -766,11 +602,6 @@ def test_create_dagrun_when_schedule_is_none_and_empty_start_date(self): ) assert dagrun is not None - def test_fail_dag_when_schedule_is_non_none_and_empty_start_date(self): - # Check that we get a ValueError 'start_date' for self.start_date when schedule is non-none - with pytest.raises(ValueError, match="start_date is required when catchup=True"): - DAG(dag_id="dag_with_non_none_schedule_and_empty_start_date", schedule="@hourly", catchup=True) - def test_dagtag_repr(self): clear_db_dags() dag = DAG("dag-test-dagtag", schedule=None, start_date=DEFAULT_DATE, tags=["tag-1", "tag-2"]) @@ -1283,100 +1114,6 @@ def test_dag_naive_default_args_start_date_with_timezone(self): dag = DAG("DAG", schedule=None, default_args=default_args) assert dag.timezone.name == local_tz.name - def test_roots(self): - """Verify if dag.roots returns the root tasks of a DAG.""" - with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: - op1 = EmptyOperator(task_id="t1") - op2 = EmptyOperator(task_id="t2") - op3 = EmptyOperator(task_id="t3") - op4 = EmptyOperator(task_id="t4") - op5 = EmptyOperator(task_id="t5") - [op1, op2] >> op3 >> [op4, op5] - - assert set(dag.roots) == {op1, op2} - - def test_leaves(self): - """Verify if dag.leaves returns the leaf tasks of a DAG.""" - with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: - op1 = EmptyOperator(task_id="t1") - op2 = EmptyOperator(task_id="t2") - op3 = EmptyOperator(task_id="t3") - op4 = EmptyOperator(task_id="t4") - op5 = EmptyOperator(task_id="t5") - [op1, op2] >> op3 >> [op4, op5] - - assert set(dag.leaves) == {op4, op5} - - def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self): - """Verify tasks with Duplicate task_id raises error""" - with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: - op1 = EmptyOperator(task_id="t1") - with 
pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"): - BashOperator(task_id="t1", bash_command="sleep 1") - - assert dag.task_dict == {op1.task_id: op1} - - def test_duplicate_task_ids_not_allowed_without_dag_context_manager(self): - """Verify tasks with Duplicate task_id raises error""" - dag = DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) - op1 = EmptyOperator(task_id="t1", dag=dag) - with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"): - EmptyOperator(task_id="t1", dag=dag) - - assert dag.task_dict == {op1.task_id: op1} - - def test_duplicate_task_ids_for_same_task_is_allowed(self): - """Verify that same tasks with Duplicate task_id do not raise error""" - with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: - op1 = op2 = EmptyOperator(task_id="t1") - op3 = EmptyOperator(task_id="t3") - op1 >> op3 - op2 >> op3 - - assert op1 == op2 - assert dag.task_dict == {op1.task_id: op1, op3.task_id: op3} - assert dag.task_dict == {op2.task_id: op2, op3.task_id: op3} - - def test_partial_subset_updates_all_references_while_deepcopy(self): - with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: - op1 = EmptyOperator(task_id="t1") - op2 = EmptyOperator(task_id="t2") - op3 = EmptyOperator(task_id="t3") - op1 >> op2 - op2 >> op3 - - partial = dag.partial_subset("t2", include_upstream=True, include_downstream=False) - assert id(partial.task_dict["t1"].downstream_list[0].dag) == id(partial) - - # Copied DAG should not include unused task IDs in used_group_ids - assert "t3" not in partial.task_group.used_group_ids - - def test_partial_subset_taskgroup_join_ids(self): - with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: - start = EmptyOperator(task_id="start") - with TaskGroup(group_id="outer", prefix_group_id=False) as outer_group: - with TaskGroup(group_id="tg1", prefix_group_id=False) as tg1: - EmptyOperator(task_id="t1") - with TaskGroup(group_id="tg2", prefix_group_id=False) as tg2: - EmptyOperator(task_id="t2") - - start >> tg1 >> tg2 - - # Pre-condition checks - task = dag.get_task("t2") - assert task.task_group.upstream_group_ids == {"tg1"} - assert isinstance(task.task_group.parent_group, weakref.ProxyType) - assert task.task_group.parent_group == outer_group - - partial = dag.partial_subset(["t2"], include_upstream=True, include_downstream=False) - copied_task = partial.get_task("t2") - assert copied_task.task_group.upstream_group_ids == {"tg1"} - assert isinstance(copied_task.task_group.parent_group, weakref.ProxyType) - assert copied_task.task_group.parent_group - - # Make sure we don't affect the original! 
- assert task.task_group.upstream_group_ids is not copied_task.task_group.upstream_group_ids - def test_schedule_dag_no_previous_runs(self): """ Tests scheduling a dag with no previous runs @@ -1759,13 +1496,16 @@ def test_dag_add_task_checks_trigger_rule(self): fail_stop_dag.add_task(task_with_default_trigger_rule) # a fail stop dag should not allow a non-default trigger rule + task_with_non_default_trigger_rule = EmptyOperator( + task_id="task_with_non_default_trigger_rule", trigger_rule=TriggerRule.ALWAYS + ) with pytest.raises(FailStopDagInvalidTriggerRule): fail_stop_dag.add_task(task_with_non_default_trigger_rule) def test_dag_add_task_sets_default_task_group(self): dag = DAG(dag_id="test_dag_add_task_sets_default_task_group", schedule=None, start_date=DEFAULT_DATE) task_without_task_group = EmptyOperator(task_id="task_without_group_id") - default_task_group = TaskGroupContext.get_current_task_group(dag) + default_task_group = TaskGroupContext.get_current(dag) dag.add_task(task_without_task_group) assert default_task_group.get_child_by_label("task_without_group_id") == task_without_task_group @@ -1942,7 +1682,15 @@ def check_task_2(my_input): with dag: check_task_2(check_task()) - dag.test() + dr = dag.test() + + ti1 = dr.get_task_instance("check_task") + ti2 = dr.get_task_instance("check_task_2") + + assert ti1 + assert ti2 + assert ti1.state == TaskInstanceState.FAILED + assert ti2.state == TaskInstanceState.UPSTREAM_FAILED mock_handle_object_1.assert_called_with("task check_task failed...") mock_handle_object_2.assert_called_with("dag test_local_testing_conn_file run failed...") @@ -2380,22 +2128,6 @@ def test_dag_owner_links(self): orm_dag_owners = session.query(DagOwnerAttributes).all() assert not orm_dag_owners - # Check wrong formatted owner link - with pytest.raises(AirflowException): - DAG("dag", schedule=None, start_date=DEFAULT_DATE, owner_links={"owner1": "my-bad-link"}) - - def test_continuous_schedule_linmits_max_active_runs(self): - dag = DAG("continuous", start_date=DEFAULT_DATE, schedule="@continuous", max_active_runs=1) - assert isinstance(dag.timetable, ContinuousTimetable) - assert dag.max_active_runs == 1 - - dag = DAG("continuous", start_date=DEFAULT_DATE, schedule="@continuous", max_active_runs=0) - assert isinstance(dag.timetable, ContinuousTimetable) - assert dag.max_active_runs == 0 - - with pytest.raises(AirflowException): - dag = DAG("continuous", start_date=DEFAULT_DATE, schedule="@continuous", max_active_runs=25) - class TestDagModel: def _clean(self): @@ -2775,54 +2507,6 @@ def setup_method(self): def teardown_method(self): clear_db_runs() - def test_fileloc(self): - @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS) - def noop_pipeline(): ... - - dag = noop_pipeline() - assert isinstance(dag, DAG) - assert dag.dag_id == "noop_pipeline" - assert dag.fileloc == __file__ - - def test_set_dag_id(self): - """Test that checks you can set dag_id from decorator.""" - - @dag_decorator("test", schedule=None, default_args=self.DEFAULT_ARGS) - def noop_pipeline(): ... - - dag = noop_pipeline() - assert isinstance(dag, DAG) - assert dag.dag_id == "test" - - def test_default_dag_id(self): - """Test that @dag uses function name as default dag id.""" - - @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS) - def noop_pipeline(): ... 
- - dag = noop_pipeline() - assert isinstance(dag, DAG) - assert dag.dag_id == "noop_pipeline" - - @pytest.mark.parametrize( - argnames=["dag_doc_md", "expected_doc_md"], - argvalues=[ - pytest.param("dag docs.", "dag docs.", id="use_dag_doc_md"), - pytest.param(None, "Regular DAG documentation", id="use_dag_docstring"), - ], - ) - def test_documentation_added(self, dag_doc_md, expected_doc_md): - """Test that @dag uses function docs as doc_md for DAG object if doc_md is not explicitly set.""" - - @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS, doc_md=dag_doc_md) - def noop_pipeline(): - """Regular DAG documentation""" - - dag = noop_pipeline() - assert isinstance(dag, DAG) - assert dag.dag_id == "noop_pipeline" - assert dag.doc_md == expected_doc_md - def test_documentation_template_rendered(self): """Test that @dag uses function docs as doc_md for DAG object""" @@ -2835,7 +2519,6 @@ def noop_pipeline(): """ dag = noop_pipeline() - assert isinstance(dag, DAG) assert dag.dag_id == "noop_pipeline" assert "Regular DAG documentation" in dag.doc_md @@ -2855,25 +2538,9 @@ def test_resolve_documentation_template_file_not_rendered(self, tmp_path): def markdown_docs(): ... dag = markdown_docs() - assert isinstance(dag, DAG) assert dag.dag_id == "test-dag" assert dag.doc_md == raw_content - def test_fails_if_arg_not_set(self): - """Test that @dag decorated function fails if positional argument is not set""" - - @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS) - def noop_pipeline(value): - @task_decorator - def return_num(num): - return num - - return_num(value) - - # Test that if arg is not passed it raises a type error as expected. - with pytest.raises(TypeError): - noop_pipeline() - def test_dag_param_resolves(self): """Test that dag param is correctly resolved by operator""" @@ -3178,7 +2845,7 @@ def get_ti_from_db(task): } -def test_dag_teardowns_property_lists_all_teardown_tasks(dag_maker): +def test_dag_teardowns_property_lists_all_teardown_tasks(): @setup def setup_task(): return 1 @@ -3199,7 +2866,7 @@ def teardown_task3(): def mytask(): return 1 - with dag_maker() as dag: + with DAG("dag") as dag: t1 = setup_task() t2 = teardown_task() t3 = teardown_task2() @@ -3358,60 +3025,6 @@ def test__time_restriction(dag_maker, dag_date, tasks_date, restrict): assert dag._time_restriction == restrict -@pytest.mark.parametrize( - "tags, should_pass", - [ - pytest.param([], True, id="empty tags"), - pytest.param(["a normal tag"], True, id="one tag"), - pytest.param(["a normal tag", "another normal tag"], True, id="two tags"), - pytest.param(["a" * 100], True, id="a tag that's of just length 100"), - pytest.param(["a normal tag", "a" * 101], False, id="two tags and one of them is of length > 100"), - ], -) -def test__tags_length(tags: list[str], should_pass: bool): - if should_pass: - DAG("test-dag", schedule=None, tags=tags) - else: - with pytest.raises(AirflowException): - DAG("test-dag", schedule=None, tags=tags) - - -@pytest.mark.parametrize( - "input_tags, expected_result", - [ - pytest.param([], set(), id="empty tags"), - pytest.param( - ["a normal tag"], - {"a normal tag"}, - id="one tag", - ), - pytest.param( - ["a normal tag", "another normal tag"], - {"a normal tag", "another normal tag"}, - id="two different tags", - ), - pytest.param( - ["a", "a"], - {"a"}, - id="two same tags", - ), - ], -) -def test__tags_duplicates(input_tags: list[str], expected_result: set[str]): - result = DAG("test-dag", tags=input_tags) - assert result.tags == expected_result - - -def 
test__tags_mutable(): - expected_tags = {"6", "7"} - test_dag = DAG("test-dag") - test_dag.tags.add("6") - test_dag.tags.add("7") - test_dag.tags.add("8") - test_dag.tags.remove("8") - assert test_dag.tags == expected_tags - - @pytest.mark.need_serialized_dag def test_get_asset_triggered_next_run_info(dag_maker, clear_assets): asset1 = Asset(uri="ds1") @@ -3520,18 +3133,6 @@ def test_create_dagrun_disallow_manual_to_use_automated_run_id(run_id_type: DagR ) -def test_invalid_type_for_args(): - with pytest.raises(TypeError): - DAG("invalid-default-args", schedule=None, max_consecutive_failed_dag_runs="not_an_int") - - -@mock.patch("airflow.models.dag.validate_instance_args") -def test_dag_init_validates_arg_types(mock_validate_instance_args): - dag = DAG("dag_with_expected_args", schedule=None) - - mock_validate_instance_args.assert_called_once_with(dag, DAG_ARGS_EXPECTED_TYPES) - - class TestTaskClearingSetupTeardownBehavior: """ Task clearing behavior is mainly controlled by dag.partial_subset. diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py index 89e899543eb0..f563c72f5451 100644 --- a/tests/models/test_dagbag.py +++ b/tests/models/test_dagbag.py @@ -667,7 +667,7 @@ def test_sync_to_db_is_retried(self, mock_bulk_write_to_db, mock_s10n_write_dag, """Test that dagbag.sync_to_db is retried on OperationalError""" dagbag = DagBag("/dev/null") - mock_dag = mock.MagicMock(spec=DAG) + mock_dag = mock.MagicMock() dagbag.dags["mock_dag"] = mock_dag op_error = OperationalError(statement=mock.ANY, params=mock.ANY, orig=mock.ANY) diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index de90889cffd7..ef213f438377 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -229,7 +229,7 @@ def test_set_dag(self, dag_maker): # no dag assigned assert not op.has_dag() - with pytest.raises(AirflowException): + with pytest.raises(RuntimeError): getattr(op, "dag") # no improper assignment @@ -239,7 +239,7 @@ def test_set_dag(self, dag_maker): op.dag = dag # no reassignment - with pytest.raises(AirflowException): + with pytest.raises(ValueError): op.dag = dag2 # but assigning the same dag is ok @@ -261,7 +261,7 @@ def test_infer_dag(self, create_dummy_dag): assert [i.has_dag() for i in [op1, op2, op3, op4]] == [False, False, True, True] # can't combine operators with no dags - with pytest.raises(AirflowException): + with pytest.raises(ValueError): op1.set_downstream(op2) # op2 should infer dag from op1 @@ -270,9 +270,9 @@ def test_infer_dag(self, create_dummy_dag): assert op2.dag is dag # can't assign across multiple DAGs - with pytest.raises(AirflowException): + with pytest.raises(RuntimeError): op1.set_downstream(op4) - with pytest.raises(AirflowException): + with pytest.raises(RuntimeError): op1.set_downstream([op3, op4]) def test_bitshift_compose_operators(self, dag_maker): diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index b54ded3a55c6..0111cc669cd0 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -20,6 +20,7 @@ from __future__ import annotations import copy +import dataclasses import importlib import importlib.util import json @@ -124,7 +125,7 @@ "delta": 86400.0, }, }, - "_task_group": { + "task_group": { "_group_id": None, "prefix_group_id": True, "children": {"bash_task": ("operator", "bash_task"), "custom_task": ("operator", "custom_task")}, @@ -137,7 +138,7 @@ "downstream_task_ids": [], }, 
"is_paused_upon_creation": False, - "_dag_id": "simple_dag", + "dag_id": "simple_dag", "doc_md": "### DAG Tutorial Documentation", "fileloc": None, "_processor_dags_folder": f"{repo_root}/tests/dags", @@ -146,7 +147,6 @@ "__type": "operator", "__var": { "task_id": "bash_task", - "owner": "airflow", "retries": 1, "retry_delay": 300.0, "max_retry_delay": 600.0, @@ -158,7 +158,7 @@ "template_fields": ["bash_command", "env", "cwd"], "template_fields_renderers": {"bash_command": "bash", "env": "json"}, "bash_command": "echo {{ task.task_id }}", - "_task_type": "BashOperator", + "task_type": "BashOperator", "_task_module": "airflow.providers.standard.operators.bash", "pool": "default_pool", "is_setup": False, @@ -174,7 +174,6 @@ }, }, "doc_md": "### Task Tutorial Documentation", - "_log_config_logger_name": "airflow.task.operators", "_needs_expansion": False, "weight_rule": "downstream", "start_trigger_args": None, @@ -196,14 +195,13 @@ "template_ext": [], "template_fields": ["bash_command"], "template_fields_renderers": {}, - "_task_type": "CustomOperator", + "task_type": "CustomOperator", "_operator_name": "@custom", "_task_module": "tests_common.test_utils.mock_operators", "pool": "default_pool", "is_setup": False, "is_teardown": False, "on_failure_fail_dagrun": False, - "_log_config_logger_name": "airflow.task.operators", "_needs_expansion": False, "weight_rule": "downstream", "start_trigger_args": None, @@ -212,7 +210,7 @@ }, ], "timezone": "UTC", - "_access_control": { + "access_control": { "__type": "dict", "__var": { "test_role": { @@ -456,7 +454,7 @@ def test_dag_serialization_preserves_empty_access_roles(self): serialized_dag = SerializedDAG.to_dict(dag) SerializedDAG.validate_schema(serialized_dag) - assert serialized_dag["dag"]["_access_control"] == {"__type": "dict", "__var": {}} + assert serialized_dag["dag"]["access_control"] == {"__type": "dict", "__var": {}} @pytest.mark.db_test def test_dag_serialization_unregistered_custom_timetable(self): @@ -491,8 +489,8 @@ def sorted_serialized_dag(dag_dict: dict): task["__var"] = dict(sorted(task["__var"].items(), key=lambda x: x[0])) tasks.append(task) dag_dict["dag"]["tasks"] = tasks - dag_dict["dag"]["_access_control"]["__var"]["test_role"]["__var"] = sorted( - dag_dict["dag"]["_access_control"]["__var"]["test_role"]["__var"] + dag_dict["dag"]["access_control"]["__var"]["test_role"]["__var"] = sorted( + dag_dict["dag"]["access_control"]["__var"]["test_role"]["__var"] ) return dag_dict @@ -567,7 +565,7 @@ def validate_deserialized_dag(self, serialized_dag: DAG, dag: DAG): "timezone", # Need to check fields in it, to exclude functions. 
"default_args", - "_task_group", + "task_group", "params", "_processor_dags_folder", } @@ -592,7 +590,7 @@ def validate_deserialized_dag(self, serialized_dag: DAG, dag: DAG): assert serialized_dag.timetable.summary == dag.timetable.summary assert serialized_dag.timetable.serialize() == dag.timetable.serialize() - assert serialized_dag.timezone.name == dag.timezone.name + assert serialized_dag.timezone == dag.timezone for task_id in dag.task_ids: self.validate_deserialized_task(serialized_dag.get_task(task_id), dag.get_task(task_id)) @@ -613,7 +611,7 @@ def validate_deserialized_task( assert isinstance(serialized_task, SerializedBaseOperator) fields_to_check = task.get_serialized_fields() - { # Checked separately - "_task_type", + "task_type", "_operator_name", # Type is excluded, so don't check it "_log", @@ -675,7 +673,7 @@ def validate_deserialized_task( # MappedOperator.operator_class holds a backup of the serialized # data; checking its entirety basically duplicates this validation # function, so we just do some sanity checks. - serialized_task.operator_class["_task_type"] == type(task).__name__ + serialized_task.operator_class["task_type"] == type(task).__name__ if isinstance(serialized_task.operator_class, DecoratedOperator): serialized_task.operator_class["_operator_name"] == task._operator_name @@ -804,7 +802,7 @@ def test_deserialization_timetable( "__version": 1, "dag": { "default_args": {"__type": "dict", "__var": {}}, - "_dag_id": "simple_dag", + "dag_id": "simple_dag", "fileloc": __file__, "tasks": [], "timezone": "UTC", @@ -820,7 +818,7 @@ def test_deserialization_timetable_unregistered(self): "__version": 1, "dag": { "default_args": {"__type": "dict", "__var": {}}, - "_dag_id": "simple_dag", + "dag_id": "simple_dag", "fileloc": __file__, "tasks": [], "timezone": "UTC", @@ -1060,6 +1058,7 @@ def test_extra_serialized_field_and_operator_links( link = simple_task.get_extra_links(ti, GoogleLink.name) assert "https://www.google.com" == link + @pytest.mark.usefixtures("clear_all_logger_handlers") def test_extra_operator_links_logs_error_for_non_registered_extra_links(self, caplog): """ Assert OperatorLinks not registered via Plugins and if it is not an inbuilt Operator Link, @@ -1221,14 +1220,20 @@ def test_no_new_fields_added_to_base_operator(self): This test verifies that there are no new fields added to BaseOperator. And reminds that tests should be added for it. 
""" + from airflow.utils.trigger_rule import TriggerRule + base_operator = BaseOperator(task_id="10") - fields = {k: v for (k, v) in vars(base_operator).items() if k in BaseOperator.get_serialized_fields()} + # Return the name of any annotated class property, or anything explicitly listed in serialized fields + field_names = { + fld.name + for fld in dataclasses.fields(BaseOperator) + if fld.name in BaseOperator.get_serialized_fields() + } | BaseOperator.get_serialized_fields() + fields = {k: getattr(base_operator, k) for k in field_names} assert fields == { "_logger_name": None, - "_log_config_logger_name": "airflow.task.operators", - "_post_execute_hook": None, - "_pre_execute_hook": None, - "_task_display_property_value": None, + "_needs_expansion": None, + "_task_display_name": None, "allow_nested_operators": True, "depends_on_past": False, "do_xcom_push": True, @@ -1238,19 +1243,23 @@ def test_no_new_fields_added_to_base_operator(self): "doc_rst": None, "doc_yaml": None, "downstream_task_ids": set(), + "end_date": None, "email": None, "email_on_failure": True, "email_on_retry": True, "execution_timeout": None, "executor": None, "executor_config": {}, - "ignore_first_depends_on_past": True, + "ignore_first_depends_on_past": False, + "is_setup": False, + "is_teardown": False, "inlets": [], "map_index_template": None, "max_active_tis_per_dag": None, "max_active_tis_per_dagrun": None, "max_retry_delay": None, "on_execute_callback": None, + "on_failure_fail_dagrun": False, "on_failure_callback": None, "on_retry_callback": None, "on_skipped_callback": None, @@ -1267,8 +1276,18 @@ def test_no_new_fields_added_to_base_operator(self): "retry_delay": timedelta(0, 300), "retry_exponential_backoff": False, "run_as_user": None, + "sla": None, + "start_date": None, + "start_from_trigger": False, + "start_trigger_args": None, "task_id": "10", - "trigger_rule": "all_success", + "task_type": "BaseOperator", + "template_ext": (), + "template_fields": (), + "template_fields_renderers": {}, + "trigger_rule": TriggerRule.ALL_SUCCESS, + "ui_color": "#fff", + "ui_fgcolor": "#000", "wait_for_downstream": False, "wait_for_past_depends_before_skipping": False, "weight_rule": _DownstreamPriorityWeightStrategy(), @@ -1294,7 +1313,7 @@ def test_operator_deserialize_old_names(self): "template_ext": [], "template_fields": ["bash_command"], "template_fields_renderers": {}, - "_task_type": "CustomOperator", + "task_type": "CustomOperator", "_task_module": "tests_common.test_utils.mock_operators", "pool": "default_pool", "ui_color": "#fff", @@ -2052,7 +2071,7 @@ def test_params_upgrade(self): serialized = { "__version": 1, "dag": { - "_dag_id": "simple_dag", + "dag_id": "simple_dag", "fileloc": "/path/to/file.py", "tasks": [], "timezone": "UTC", @@ -2071,7 +2090,7 @@ def test_params_serialization_from_dict_upgrade(self): serialized = { "__version": 1, "dag": { - "_dag_id": "simple_dag", + "dag_id": "simple_dag", "fileloc": "/path/to/file.py", "tasks": [], "timezone": "UTC", @@ -2091,7 +2110,7 @@ def test_params_serialize_default_2_2_0(self): serialized = { "__version": 1, "dag": { - "_dag_id": "simple_dag", + "dag_id": "simple_dag", "fileloc": "/path/to/file.py", "tasks": [], "timezone": "UTC", @@ -2108,7 +2127,7 @@ def test_params_serialize_default(self): serialized = { "__version": 1, "dag": { - "_dag_id": "simple_dag", + "dag_id": "simple_dag", "fileloc": "/path/to/file.py", "tasks": [], "timezone": "UTC", @@ -2290,7 +2309,7 @@ def test_operator_expand_serde(): "_is_mapped": True, "_needs_expansion": True, 
"_task_module": "airflow.providers.standard.operators.bash", - "_task_type": "BashOperator", + "task_type": "BashOperator", "start_trigger_args": None, "start_from_trigger": False, "downstream_task_ids": [], @@ -2323,7 +2342,7 @@ def test_operator_expand_serde(): assert op.deps is MappedOperator.deps_for(BaseOperator) assert op.operator_class == { - "_task_type": "BashOperator", + "task_type": "BashOperator", "_needs_expansion": True, "start_trigger_args": None, "start_from_trigger": False, @@ -2353,7 +2372,7 @@ def test_operator_expand_xcomarg_serde(): "_is_mapped": True, "_needs_expansion": True, "_task_module": "tests_common.test_utils.mock_operators", - "_task_type": "MockOperator", + "task_type": "MockOperator", "downstream_task_ids": [], "expand_input": { "type": "dict-of-lists", @@ -2408,7 +2427,7 @@ def test_operator_expand_kwargs_literal_serde(strict): "_is_mapped": True, "_needs_expansion": True, "_task_module": "tests_common.test_utils.mock_operators", - "_task_type": "MockOperator", + "task_type": "MockOperator", "downstream_task_ids": [], "expand_input": { "type": "list-of-dicts", @@ -2463,7 +2482,7 @@ def test_operator_expand_kwargs_xcomarg_serde(strict): "_is_mapped": True, "_needs_expansion": True, "_task_module": "tests_common.test_utils.mock_operators", - "_task_type": "MockOperator", + "task_type": "MockOperator", "downstream_task_ids": [], "expand_input": { "type": "list-of-dicts", @@ -2575,7 +2594,7 @@ def x(arg1, arg2, arg3): "_is_mapped": True, "_needs_expansion": True, "_task_module": "airflow.decorators.python", - "_task_type": "_PythonDecoratedOperator", + "task_type": "_PythonDecoratedOperator", "_operator_name": "@task", "downstream_task_ids": [], "partial_kwargs": { @@ -2677,7 +2696,7 @@ def x(arg1, arg2, arg3): "_is_mapped": True, "_needs_expansion": True, "_task_module": "airflow.decorators.python", - "_task_type": "_PythonDecoratedOperator", + "task_type": "_PythonDecoratedOperator", "_operator_name": "@task", "start_trigger_args": None, "start_from_trigger": False, @@ -2771,7 +2790,7 @@ def tg(a: str) -> None: tg.expand(a=[".", ".."]) ser_dag = SerializedBaseOperator.serialize(dag) - assert ser_dag[Encoding.VAR]["_task_group"]["children"]["tg"] == ( + assert ser_dag[Encoding.VAR]["task_group"]["children"]["tg"] == ( "taskgroup", { "_group_id": "tg", @@ -2831,7 +2850,7 @@ def operator_extra_links(self): "template_ext": [], "template_fields": [], "template_fields_renderers": {}, - "_task_type": "_DummyOperator", + "task_type": "_DummyOperator", "_task_module": "tests.serialization.test_dag_serialization", "_is_empty": False, "_is_mapped": True, diff --git a/tests/serialization/test_pydantic_models.py b/tests/serialization/test_pydantic_models.py index 83d677df91b7..7bfc7eda32af 100644 --- a/tests/serialization/test_pydantic_models.py +++ b/tests/serialization/test_pydantic_models.py @@ -33,7 +33,7 @@ DagScheduleAssetReference, TaskOutletAssetReference, ) -from airflow.models.dag import DAG, DagModel, create_timetable +from airflow.models.dag import DAG, DagModel from airflow.serialization.pydantic.asset import AssetEventPydantic from airflow.serialization.pydantic.dag import DagModelPydantic from airflow.serialization.pydantic.dag_run import DagRunPydantic @@ -81,11 +81,10 @@ def test_serializing_pydantic_task_instance(session, create_task_instance): def test_deserialize_ti_mapped_op_reserialized_with_refresh_from_task(session, dag_maker): op_class_dict_expected = { "_needs_expansion": True, - "_task_type": "_PythonDecoratedOperator", + "task_type": 
"_PythonDecoratedOperator", "downstream_task_ids": [], "start_from_trigger": False, "start_trigger_args": None, - "_operator_name": "@task", "ui_fgcolor": "#000", "ui_color": "#ffefeb", "template_fields": ["templates_dict", "op_args", "op_kwargs"], @@ -128,6 +127,8 @@ def target(val=None): assert desered.task.__class__ == MappedOperator assert desered.task.operator_class == op_class_dict_expected + assert desered.task.task_type == "_PythonDecoratedOperator" + assert desered.task.operator_name == "@task" desered.refresh_from_task(deser_task) @@ -180,12 +181,11 @@ def test_serializing_pydantic_dagrun(session, create_task_instance): ], ) def test_serializing_pydantic_dagmodel(schedule): - timetable = create_timetable(schedule, timezone.utc) dag_model = DagModel( dag_id="test-dag", fileloc="/tmp/dag_1.py", - timetable_summary=timetable.summary, - timetable_description=timetable.description, + timetable_summary="summary", + timetable_description="desc", is_active=True, is_paused=False, ) @@ -196,8 +196,8 @@ def test_serializing_pydantic_dagmodel(schedule): deserialized_model = DagModelPydantic.model_validate_json(json_string) assert deserialized_model.dag_id == "test-dag" assert deserialized_model.fileloc == "/tmp/dag_1.py" - assert deserialized_model.timetable_summary == timetable.summary - assert deserialized_model.timetable_description == timetable.description + assert deserialized_model.timetable_summary == "summary" + assert deserialized_model.timetable_description == "desc" assert deserialized_model.is_active is True assert deserialized_model.is_paused is False diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py index 5abedfe55386..b973a1f615e0 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -18,7 +18,6 @@ from __future__ import annotations from datetime import timedelta -from unittest import mock import pendulum import pytest @@ -37,7 +36,7 @@ from airflow.operators.empty import EmptyOperator from airflow.operators.python import PythonOperator from airflow.utils.dag_edges import dag_edges -from airflow.utils.task_group import TASKGROUP_ARGS_EXPECTED_TYPES, TaskGroup, task_group_to_dict +from airflow.utils.task_group import TaskGroup, task_group_to_dict from tests.models import DEFAULT_DATE from tests_common.test_utils.compat import BashOperator @@ -1660,17 +1659,7 @@ def work(): ... def test_task_group_with_invalid_arg_type_raises_error(): - error_msg = "'ui_color' has an invalid type with value 123, expected type is " + error_msg = r"'ui_color' must be \(got 123 that is a \)\." with DAG(dag_id="dag_with_tg_invalid_arg_type", schedule=None): with pytest.raises(TypeError, match=error_msg): - with TaskGroup("group_1", ui_color=123): - EmptyOperator(task_id="task1") - - -@mock.patch("airflow.utils.task_group.validate_instance_args") -def test_task_group_init_validates_arg_types(mock_validate_instance_args): - with DAG(dag_id="dag_with_tg_valid_arg_types", schedule=None): - with TaskGroup("group_1", ui_color="red") as tg: - EmptyOperator(task_id="task1") - - mock_validate_instance_args.assert_called_with(tg, TASKGROUP_ARGS_EXPECTED_TYPES) + _ = TaskGroup("group_1", ui_color=123) diff --git a/tests_common/test_utils/mock_operators.py b/tests_common/test_utils/mock_operators.py index 4bf53551164c..6a88f41c22ad 100644 --- a/tests_common/test_utils/mock_operators.py +++ b/tests_common/test_utils/mock_operators.py @@ -16,7 +16,7 @@ # under the License. 
from __future__ import annotations -import warnings +from collections.abc import Mapping from typing import TYPE_CHECKING, Any, Sequence import attr @@ -67,7 +67,7 @@ def __init__(self, arg1: str = "", arg2: NestedFields | None = None, **kwargs): def _render_nested_template_fields( self, content: Any, - context: Context, + context: Mapping[str, Any], jinja_env: jinja2.Environment, seen_oids: set, ) -> None: @@ -204,14 +204,3 @@ class GithubLink(BaseOperatorLink): def get_link(self, operator, *, ti_key): return "https://github.com/apache/airflow" - - -class DeprecatedOperator(BaseOperator): - """Deprecated Operator for testing purposes.""" - - def __init__(self, **kwargs): - warnings.warn("This operator is deprecated.", DeprecationWarning, stacklevel=2) - super().__init__(**kwargs) - - def execute(self, context: Context): - pass
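As a quick orientation for the test changes above, a minimal sketch of the definition-time API those tests exercise: DAG as a context manager, dependency chaining with >>, and the roots/leaves properties. It uses the same import paths the tests use (airflow.models.dag, airflow.operators.empty); the dag id and task ids are invented for illustration only.

    from __future__ import annotations

    import datetime

    from airflow.models.dag import DAG
    from airflow.operators.empty import EmptyOperator

    # Tasks created inside the context manager are registered on the DAG at parse time.
    with DAG("example_definition", schedule=None, start_date=datetime.datetime(2024, 1, 1)) as dag:
        start = EmptyOperator(task_id="start")
        middle = EmptyOperator(task_id="middle")
        end = EmptyOperator(task_id="end")

        # >> builds the upstream/downstream links that tests such as test_roots/test_leaves check.
        start >> middle >> end

    assert set(dag.roots) == {start}
    assert set(dag.leaves) == {end}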