Start porting DAG definition code to the Task SDK (apache#43076)
closes apache#43011

By "definition code" we mean anything needed at definition/parse time, leaving
anything to do with scheduling time decisions in Airflow's core.
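
As a rough illustration (not code taken from this commit), a trivial DAG like the sketch below exercises only definition/parse-time code -- DAG construction, operator instantiation, dependency wiring -- which is the layer moving into the Task SDK, while deciding when and whether to run it stays in core. The dag_id, task ids, and `schedule=None` are arbitrary placeholders.

# Illustrative sketch only: everything in here is definition-time behaviour.
from airflow.sdk import DAG  # the intended public import path
from airflow.providers.standard.operators.bash import BashOperator

with DAG(dag_id="example_definition_only", schedule=None):
    first = BashOperator(task_id="first", bash_command="echo hello")
    second = BashOperator(task_id="second", bash_command="echo world")
    first >> second  # dependency wiring is also parse-time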

Also in this PR I have _attempted_ to keep it to porting only definition code
for simple DAGs, leaving anything to do with mapped tasks or execution time in
core for now, though a few things "leaked" across.

And since the goal of this PR is to go from working state to working state,
some of the code in the Task SDK still imports from "core" (various types,
enums, or helpers). That will need to be resolved before the 3.0 release, but
it is fine for now.
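
(Purely hypothetical example of the kind of leftover dependency meant here -- the exact SDK modules still importing from core are not listed in this message:)

# task_sdk/src/airflow/sdk/definitions/dag.py (hypothetical excerpt)
from airflow.utils.types import NOTSET  # helper still provided by Airflow core for now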

I'm also aware that the class hierarchy around
airflow.models.baseoperator.BaseOperator (and to a lesser extent DAG) is very
messy right now, and we will need to think about how we want to add the
scheduling-time functions etc. I'm not yet sold that having core Airflow
depend upon the Task SDK classes/import the code is the right structure, but
we can address that later.
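
For context, the layering concern above is roughly the pattern sketched below -- illustrative only, not code from this commit, and not necessarily the final design:

# The Task SDK class carries definition-time behaviour; core Airflow's class
# would layer scheduling-time concerns on top by subclassing it.
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator


class BaseOperator(TaskSDKBaseOperator):  # i.e. airflow.models.baseoperator.BaseOperator
    """Adds scheduling-time behaviour (DB state, scheduler decisions) on top of the SDK class."""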

We will also need to address the rendered docs for the Task SDK in a future
PR -- the goal is that anything exposed on `airflow.sdk` directly is part of
the public API, but right now the rendered docs show DAG as
`airflow.sdk.definitions.dag.DAG`, which is certainly not what we want users
to see.
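
Concretely, the intent is that users (and the docs) use the first import below, while the second path stays an implementation detail:

# Intended public API surface -- what the rendered docs should present:
from airflow.sdk import DAG

# Where the class actually lives after this PR; should not be shown as the
# canonical path in user-facing docs:
# from airflow.sdk.definitions.dag import DAG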

Co-authored-by: Kaxil Naik <[email protected]>
2 people authored and ellisms committed Nov 13, 2024
1 parent de5b1df commit cf011a5
Showing 106 changed files with 5,850 additions and 4,501 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/prod-image-build.yml
@@ -181,6 +181,11 @@ jobs:
run: >
breeze release-management prepare-airflow-package --package-format wheel
if: inputs.do-build == 'true' && inputs.upload-package-artifact == 'true'
- name: "Prepare task-sdk package"
shell: bash
run: >
breeze release-management prepare-task-sdk-package --package-format wheel
if: inputs.do-build == 'true' && inputs.upload-package-artifact == 'true'
- name: "Upload prepared packages as artifacts"
uses: actions/upload-artifact@v4
with:
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
@@ -1189,6 +1189,8 @@ repos:
^airflow/utils/helpers.py$ |
^providers/src/airflow/providers/ |
^(providers/)?tests/ |
task_sdk/src/airflow/sdk/definitions/dag.py$ |
task_sdk/src/airflow/sdk/definitions/node.py$ |
^dev/.*\.py$ |
^scripts/.*\.py$ |
^docker_tests/.*$ |
33 changes: 14 additions & 19 deletions Dockerfile
@@ -718,6 +718,7 @@ COPY <<"EOF" /install_from_docker_context_files.sh


function install_airflow_and_providers_from_docker_context_files(){
local flags=()
if [[ ${INSTALL_MYSQL_CLIENT} != "true" ]]; then
AIRFLOW_EXTRAS=${AIRFLOW_EXTRAS/mysql,}
fi
@@ -756,10 +757,10 @@ function install_airflow_and_providers_from_docker_context_files(){
install_airflow_package=("apache-airflow[${AIRFLOW_EXTRAS}]==${AIRFLOW_VERSION}")
fi

# Find Provider packages in docker-context files
readarray -t installing_providers_packages< <(python /scripts/docker/get_package_specs.py /docker-context-files/apache?airflow?providers*.{whl,tar.gz} 2>/dev/null || true)
# Find Provider/TaskSDK packages in docker-context files
readarray -t airflow_packages< <(python /scripts/docker/get_package_specs.py /docker-context-files/apache?airflow?{providers,task?sdk}*.{whl,tar.gz} 2>/dev/null || true)
echo
echo "${COLOR_BLUE}Found provider packages in docker-context-files folder: ${installing_providers_packages[*]}${COLOR_RESET}"
echo "${COLOR_BLUE}Found provider packages in docker-context-files folder: ${airflow_packages[*]}${COLOR_RESET}"
echo

if [[ ${USE_CONSTRAINTS_FOR_CONTEXT_PACKAGES=} == "true" ]]; then
@@ -772,11 +773,7 @@ function install_airflow_and_providers_from_docker_context_files(){
echo "${COLOR_BLUE}Installing docker-context-files packages with constraints found in ${local_constraints_file}${COLOR_RESET}"
echo
# force reinstall all airflow + provider packages with constraints found in
set -x
${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} --upgrade \
${ADDITIONAL_PIP_INSTALL_FLAGS} --constraint "${local_constraints_file}" \
"${install_airflow_package[@]}" "${installing_providers_packages[@]}"
set +x
flags=(--upgrade --constraint "${local_constraints_file}")
echo
echo "${COLOR_BLUE}Copying ${local_constraints_file} to ${HOME}/constraints.txt${COLOR_RESET}"
echo
@@ -785,23 +782,21 @@
echo
echo "${COLOR_BLUE}Installing docker-context-files packages with constraints from GitHub${COLOR_RESET}"
echo
set -x
${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} \
${ADDITIONAL_PIP_INSTALL_FLAGS} \
--constraint "${HOME}/constraints.txt" \
"${install_airflow_package[@]}" "${installing_providers_packages[@]}"
set +x
flags=(--constraint "${HOME}/constraints.txt")
fi
else
echo
echo "${COLOR_BLUE}Installing docker-context-files packages without constraints${COLOR_RESET}"
echo
set -x
${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} \
${ADDITIONAL_PIP_INSTALL_FLAGS} \
"${install_airflow_package[@]}" "${installing_providers_packages[@]}"
set +x
flags=()
fi

set -x
${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} \
${ADDITIONAL_PIP_INSTALL_FLAGS} \
"${flags[@]}" \
"${install_airflow_package[@]}" "${airflow_packages[@]}"
set +x
common::install_packaging_tools
pip check
}
2 changes: 1 addition & 1 deletion Dockerfile.ci
@@ -1154,7 +1154,7 @@ function check_force_lowest_dependencies() {
echo
fi
set -x
uv pip install --python "$(which python)" --resolution lowest-direct --upgrade --editable ".${EXTRA}"
uv pip install --python "$(which python)" --resolution lowest-direct --upgrade --editable ".${EXTRA}" --editable "./task_sdk"
set +x
}

1 change: 0 additions & 1 deletion airflow/api_connexion/schemas/dag_schema.py
@@ -56,7 +56,6 @@ class Meta:
last_parsed_time = auto_field(dump_only=True)
last_pickled = auto_field(dump_only=True)
last_expired = auto_field(dump_only=True)
pickle_id = auto_field(dump_only=True)
default_view = auto_field(dump_only=True)
fileloc = auto_field(dump_only=True)
file_token = fields.Method("get_token", dump_only=True)
4 changes: 0 additions & 4 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -1981,9 +1981,6 @@ components:
- type: boolean
- type: 'null'
title: Is Paused Upon Creation
orientation:
type: string
title: Orientation
params:
anyOf:
- type: object
@@ -2053,7 +2050,6 @@ components:
- start_date
- end_date
- is_paused_upon_creation
- orientation
- params
- render_template_as_native_obj
- template_search_path
1 change: 0 additions & 1 deletion airflow/api_fastapi/core_api/serializers/dags.py
@@ -116,7 +116,6 @@ class DAGDetailsResponse(DAGResponse):
start_date: datetime | None
end_date: datetime | None
is_paused_upon_creation: bool | None
orientation: str
params: abc.MutableMapping | None
render_template_as_native_obj: bool
template_search_path: Iterable[str] | None
1 change: 0 additions & 1 deletion airflow/cli/commands/dag_command.py
@@ -227,7 +227,6 @@ def _get_dagbag_dag_details(dag: DAG) -> dict:
"last_parsed_time": None,
"last_pickled": None,
"last_expired": None,
"pickle_id": dag.pickle_id,
"default_view": dag.default_view,
"fileloc": dag.fileloc,
"file_token": None,
5 changes: 4 additions & 1 deletion airflow/dag_processing/collection.py
@@ -211,7 +211,10 @@ def update_dags(
dm.has_import_errors = False
dm.last_parsed_time = utcnow()
dm.default_view = dag.default_view
dm._dag_display_property_value = dag._dag_display_property_value
if hasattr(dag, "_dag_display_property_value"):
dm._dag_display_property_value = dag._dag_display_property_value
elif dag.dag_display_name != dag.dag_id:
dm._dag_display_property_value = dag.dag_display_name
dm.description = dag.description
dm.max_active_tasks = dag.max_active_tasks
dm.max_active_runs = dag.max_active_runs
44 changes: 18 additions & 26 deletions airflow/decorators/base.py
@@ -41,43 +41,41 @@
import typing_extensions

from airflow.assets import Asset
from airflow.models.abstractoperator import DEFAULT_RETRIES, DEFAULT_RETRY_DELAY
from airflow.models.baseoperator import (
BaseOperator,
coerce_resources,
coerce_timedelta,
get_merged_defaults,
parse_retries,
)
from airflow.models.dag import DagContext
from airflow.models.expandinput import (
EXPAND_INPUT_EMPTY,
DictOfListsExpandInput,
ListOfDictsExpandInput,
is_mappable,
)
from airflow.models.mappedoperator import MappedOperator, ensure_xcomarg_return_value
from airflow.models.pool import Pool
from airflow.models.xcom_arg import XComArg
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext
from airflow.typing_compat import ParamSpec, Protocol
from airflow.utils import timezone
from airflow.utils.context import KNOWN_CONTEXT_KEYS
from airflow.utils.decorators import remove_task_decorator
from airflow.utils.helpers import prevent_duplicates
from airflow.utils.task_group import TaskGroupContext
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.models.dag import DAG
from airflow.models.expandinput import (
ExpandInput,
OperatorExpandArgument,
OperatorExpandKwargsArgument,
)
from airflow.models.mappedoperator import ValidationSource
from airflow.sdk import DAG
from airflow.utils.context import Context
from airflow.utils.task_group import TaskGroup

@@ -141,13 +139,13 @@ def get_unique_task_id(
...
task_id__20
"""
dag = dag or DagContext.get_current_dag()
dag = dag or DagContext.get_current()
if not dag:
return task_id

# We need to check if we are in the context of TaskGroup as the task_id may
# already be altered
task_group = task_group or TaskGroupContext.get_current_task_group(dag)
task_group = task_group or TaskGroupContext.get_current(dag)
tg_task_id = task_group.child_id(task_id) if task_group else task_id

if tg_task_id not in dag.task_ids:
@@ -428,8 +426,8 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
ensure_xcomarg_return_value(expand_input.value)

task_kwargs = self.kwargs.copy()
dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current_task_group(dag)
dag = task_kwargs.pop("dag", None) or DagContext.get_current()
task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current(dag)

default_args, partial_params = get_merged_defaults(
dag=dag,
@@ -442,7 +440,7 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
"is_teardown": self.is_teardown,
"on_failure_fail_dagrun": self.on_failure_fail_dagrun,
}
base_signature = inspect.signature(BaseOperator)
base_signature = inspect.signature(TaskSDKBaseOperator)
ignore = {
"default_args", # This is target we are working on now.
"kwargs", # A common name for a keyword argument.
@@ -460,32 +458,26 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
task_id = task_group.child_id(task_id)

# Logic here should be kept in sync with BaseOperatorMeta.partial().
if "task_concurrency" in partial_kwargs:
raise TypeError("unexpected argument: task_concurrency")
if partial_kwargs.get("wait_for_downstream"):
partial_kwargs["depends_on_past"] = True
start_date = timezone.convert_to_utc(partial_kwargs.pop("start_date", None))
end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None))
if partial_kwargs.get("pool") is None:
partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
if "pool_slots" in partial_kwargs:
if partial_kwargs["pool_slots"] < 1:
dag_str = ""
if dag:
dag_str = f" in dag {dag.dag_id}"
raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1")
partial_kwargs["retries"] = parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES))
partial_kwargs["retry_delay"] = coerce_timedelta(
partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY),
key="retry_delay",
)
max_retry_delay = partial_kwargs.get("max_retry_delay")
partial_kwargs["max_retry_delay"] = (
max_retry_delay
if max_retry_delay is None
else coerce_timedelta(max_retry_delay, key="max_retry_delay")
)
partial_kwargs["resources"] = coerce_resources(partial_kwargs.get("resources"))

for fld, convert in (
("retries", parse_retries),
("retry_delay", coerce_timedelta),
("max_retry_delay", coerce_timedelta),
("resources", coerce_resources),
):
if (v := partial_kwargs.get(fld, NOTSET)) is not NOTSET:
partial_kwargs[fld] = convert(v) # type: ignore[operator]

partial_kwargs.setdefault("executor_config", {})
partial_kwargs.setdefault("op_args", [])
partial_kwargs.setdefault("op_kwargs", {})
4 changes: 2 additions & 2 deletions airflow/decorators/bash.py
@@ -18,7 +18,7 @@
from __future__ import annotations

import warnings
from typing import Any, Callable, Collection, Mapping, Sequence
from typing import Any, Callable, ClassVar, Collection, Mapping, Sequence

from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
from airflow.providers.standard.operators.bash import BashOperator
@@ -39,7 +39,7 @@ class _BashDecoratedOperator(DecoratedOperator, BashOperator):
"""

template_fields: Sequence[str] = (*DecoratedOperator.template_fields, *BashOperator.template_fields)
template_fields_renderers: dict[str, str] = {
template_fields_renderers: ClassVar[dict[str, str]] = {
**DecoratedOperator.template_fields_renderers,
**BashOperator.template_fields_renderers,
}
4 changes: 2 additions & 2 deletions airflow/decorators/sensor.py
@@ -17,7 +17,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Sequence
from typing import TYPE_CHECKING, Callable, ClassVar, Sequence

from airflow.decorators.base import get_unique_task_id, task_decorator_factory
from airflow.sensors.python import PythonSensor
@@ -42,7 +42,7 @@ class DecoratedSensorOperator(PythonSensor):
"""

template_fields: Sequence[str] = ("op_args", "op_kwargs")
template_fields_renderers: dict[str, str] = {"op_args": "py", "op_kwargs": "py"}
template_fields_renderers: ClassVar[dict[str, str]] = {"op_args": "py", "op_kwargs": "py"}

custom_operator_name = "@task.sensor"

2 changes: 1 addition & 1 deletion airflow/decorators/task_group.py
@@ -38,8 +38,8 @@
ListOfDictsExpandInput,
MappedArgument,
)
from airflow.models.taskmixin import DAGNode
from airflow.models.xcom_arg import XComArg
from airflow.sdk.definitions.node import DAGNode
from airflow.typing_compat import ParamSpec
from airflow.utils.helpers import prevent_duplicates
from airflow.utils.task_group import MappedTaskGroup, TaskGroup
6 changes: 3 additions & 3 deletions airflow/exceptions.py
@@ -31,7 +31,7 @@
import datetime
from collections.abc import Sized

from airflow.models import DAG, DagRun
from airflow.models import DagRun


class AirflowException(Exception):
@@ -273,13 +273,13 @@ class FailStopDagInvalidTriggerRule(AirflowException):
_allowed_rules = (TriggerRule.ALL_SUCCESS, TriggerRule.ALL_DONE_SETUP_SUCCESS)

@classmethod
def check(cls, *, dag: DAG | None, trigger_rule: TriggerRule):
def check(cls, *, fail_stop: bool, trigger_rule: TriggerRule):
"""
Check that fail_stop dag tasks have allowable trigger rules.
:meta private:
"""
if dag is not None and dag.fail_stop and trigger_rule not in cls._allowed_rules:
if fail_stop and trigger_rule not in cls._allowed_rules:
raise cls()

def __str__(self) -> str: