Start porting DAG definition code to the Task SDK #43076

Merged: 32 commits, Oct 30, 2024

Commits (32)
578f21d
Start porting over all the DAG definition code to the Task SDK
ashb Oct 12, 2024
bac7847
Get more tests passing
ashb Oct 18, 2024
ff638f8
[skip ci]
ashb Oct 19, 2024
4c98cf9
[skip ci]
ashb Oct 21, 2024
f32e424
Fix some tests in tests/models/test_dagbag.py [ci skip]
kaxil Oct 22, 2024
b64d680
[skip ci]
ashb Oct 22, 2024
4479c86
More fixes to test_dagbag.py [skip ci]
kaxil Oct 22, 2024
f504e20
Use DAG Context from Task SDK [skip ci]
kaxil Oct 22, 2024
549c18f
Replace DagContext from core to Task SDK [skip ci]
kaxil Oct 23, 2024
b34f7ef
[skip-ci]
ashb Oct 23, 2024
5c14f90
Use attrs converters for access_control [skip ci]
kaxil Oct 23, 2024
c78c0b8
Update pre-commit scripts.
ashb Oct 24, 2024
e6ce661
fix more tests [skip-ci]
ashb Oct 24, 2024
a78fc3d
make mypy happy [skip ci]
ashb Oct 24, 2024
8d8fa7f
Fix default pool
ashb Oct 25, 2024
deff3a0
Fix mypy typing
ashb Oct 25, 2024
e6d360c
Fix serialization
ashb Oct 28, 2024
9b26fb4
fix some tests [skip ci]
ashb Oct 29, 2024
40f31e7
Fix AirflowException error in tests/models/test_taskinstance.py
kaxil Oct 29, 2024
9d38a5d
[skip-ci]
ashb Oct 29, 2024
c4b4e4b
[skip-ci]
ashb Oct 29, 2024
cf1de29
[skip-ci]
ashb Oct 29, 2024
7b1e4a7
Fix tests [skip ci]
ashb Oct 29, 2024
8d8ec63
[skip-ci]
ashb Oct 29, 2024
03b3d99
fix non-db tests
ashb Oct 29, 2024
623fcff
Fix timezone from default_args test
kaxil Oct 29, 2024
3f82122
fix more tests [skip ci]
ashb Oct 29, 2024
75b6848
Install task-sdk when downgrading deps
ashb Oct 29, 2024
4108b2c
[skip ci]
ashb Oct 30, 2024
7de2a64
Upgrade minimum pydantic version to 2.7 to deal with delayed annotati…
ashb Oct 30, 2024
969af3c
Install and build task-sdk in prod images
ashb Oct 30, 2024
1ef4b33
fixup breeze hash
ashb Oct 30, 2024
5 changes: 5 additions & 0 deletions .github/workflows/prod-image-build.yml
@@ -181,6 +181,11 @@ jobs:
run: >
breeze release-management prepare-airflow-package --package-format wheel
if: inputs.do-build == 'true' && inputs.upload-package-artifact == 'true'
- name: "Prepare task-sdk package"
shell: bash
run: >
breeze release-management prepare-task-sdk-package --package-format wheel
if: inputs.do-build == 'true' && inputs.upload-package-artifact == 'true'
- name: "Upload prepared packages as artifacts"
uses: actions/upload-artifact@v4
with:
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
@@ -1189,6 +1189,8 @@ repos:
^airflow/utils/helpers.py$ |
^providers/src/airflow/providers/ |
^(providers/)?tests/ |
task_sdk/src/airflow/sdk/definitions/dag.py$ |
task_sdk/src/airflow/sdk/definitions/node.py$ |
^dev/.*\.py$ |
^scripts/.*\.py$ |
^docker_tests/.*$ |
33 changes: 14 additions & 19 deletions Dockerfile
@@ -718,6 +718,7 @@ COPY <<"EOF" /install_from_docker_context_files.sh


function install_airflow_and_providers_from_docker_context_files(){
local flags=()
if [[ ${INSTALL_MYSQL_CLIENT} != "true" ]]; then
AIRFLOW_EXTRAS=${AIRFLOW_EXTRAS/mysql,}
fi
@@ -756,10 +757,10 @@ function install_airflow_and_providers_from_docker_context_files(){
install_airflow_package=("apache-airflow[${AIRFLOW_EXTRAS}]==${AIRFLOW_VERSION}")
fi

# Find Provider packages in docker-context files
readarray -t installing_providers_packages< <(python /scripts/docker/get_package_specs.py /docker-context-files/apache?airflow?providers*.{whl,tar.gz} 2>/dev/null || true)
# Find Provider/TaskSDK packages in docker-context files
readarray -t airflow_packages< <(python /scripts/docker/get_package_specs.py /docker-context-files/apache?airflow?{providers,task?sdk}*.{whl,tar.gz} 2>/dev/null || true)
echo
echo "${COLOR_BLUE}Found provider packages in docker-context-files folder: ${installing_providers_packages[*]}${COLOR_RESET}"
echo "${COLOR_BLUE}Found provider packages in docker-context-files folder: ${airflow_packages[*]}${COLOR_RESET}"
echo

if [[ ${USE_CONSTRAINTS_FOR_CONTEXT_PACKAGES=} == "true" ]]; then
@@ -772,11 +773,7 @@ function install_airflow_and_providers_from_docker_context_files(){
echo "${COLOR_BLUE}Installing docker-context-files packages with constraints found in ${local_constraints_file}${COLOR_RESET}"
echo
# force reinstall all airflow + provider packages with constraints found in
set -x
${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} --upgrade \
${ADDITIONAL_PIP_INSTALL_FLAGS} --constraint "${local_constraints_file}" \
"${install_airflow_package[@]}" "${installing_providers_packages[@]}"
set +x
flags=(--upgrade --constraint "${local_constraints_file}")
echo
echo "${COLOR_BLUE}Copying ${local_constraints_file} to ${HOME}/constraints.txt${COLOR_RESET}"
echo
@@ -785,23 +782,21 @@ function install_airflow_and_providers_from_docker_context_files(){
echo
echo "${COLOR_BLUE}Installing docker-context-files packages with constraints from GitHub${COLOR_RESET}"
echo
set -x
${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} \
${ADDITIONAL_PIP_INSTALL_FLAGS} \
--constraint "${HOME}/constraints.txt" \
"${install_airflow_package[@]}" "${installing_providers_packages[@]}"
set +x
flags=(--constraint "${HOME}/constraints.txt")
fi
else
echo
echo "${COLOR_BLUE}Installing docker-context-files packages without constraints${COLOR_RESET}"
echo
set -x
${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} \
${ADDITIONAL_PIP_INSTALL_FLAGS} \
"${install_airflow_package[@]}" "${installing_providers_packages[@]}"
set +x
flags=()
fi

set -x
${PACKAGING_TOOL_CMD} install ${EXTRA_INSTALL_FLAGS} \
${ADDITIONAL_PIP_INSTALL_FLAGS} \
"${flags[@]}" \
"${install_airflow_package[@]}" "${airflow_packages[@]}"
set +x
common::install_packaging_tools
pip check
}
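The Dockerfile refactor above collapses three near-identical install commands into one: each constraint branch now only fills a flags array, and a single ${PACKAGING_TOOL_CMD} install call at the end consumes it. A minimal Python sketch of the same consolidation pattern (package names and the pip invocation are illustrative, not the image's actual tooling):

import subprocess

def install(packages: list[str], constraints: str | None = None, upgrade: bool = False) -> None:
    # Branches only decide the flags; the install itself happens once.
    flags: list[str] = []
    if constraints:
        flags += ["--constraint", constraints]
    if upgrade:
        flags += ["--upgrade"]
    subprocess.run(["pip", "install", *flags, *packages], check=True)

install(["apache-airflow", "apache-airflow-task-sdk"], constraints="constraints.txt")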
2 changes: 1 addition & 1 deletion Dockerfile.ci
@@ -1154,7 +1154,7 @@ function check_force_lowest_dependencies() {
echo
fi
set -x
uv pip install --python "$(which python)" --resolution lowest-direct --upgrade --editable ".${EXTRA}"
uv pip install --python "$(which python)" --resolution lowest-direct --upgrade --editable ".${EXTRA}" --editable "./task_sdk"
set +x
}

1 change: 0 additions & 1 deletion airflow/api_connexion/schemas/dag_schema.py
@@ -56,7 +56,6 @@ class Meta:
last_parsed_time = auto_field(dump_only=True)
last_pickled = auto_field(dump_only=True)
last_expired = auto_field(dump_only=True)
pickle_id = auto_field(dump_only=True)
default_view = auto_field(dump_only=True)
fileloc = auto_field(dump_only=True)
file_token = fields.Method("get_token", dump_only=True)
4 changes: 0 additions & 4 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -1981,9 +1981,6 @@ components:
- type: boolean
- type: 'null'
title: Is Paused Upon Creation
orientation:
type: string
title: Orientation
params:
anyOf:
- type: object
@@ -2053,7 +2050,6 @@ components:
- start_date
- end_date
- is_paused_upon_creation
- orientation
- params
- render_template_as_native_obj
- template_search_path
1 change: 0 additions & 1 deletion airflow/api_fastapi/core_api/serializers/dags.py
@@ -116,7 +116,6 @@ class DAGDetailsResponse(DAGResponse):
start_date: datetime | None
end_date: datetime | None
is_paused_upon_creation: bool | None
orientation: str
params: abc.MutableMapping | None
render_template_as_native_obj: bool
template_search_path: Iterable[str] | None
1 change: 0 additions & 1 deletion airflow/cli/commands/dag_command.py
@@ -227,7 +227,6 @@ def _get_dagbag_dag_details(dag: DAG) -> dict:
"last_parsed_time": None,
"last_pickled": None,
"last_expired": None,
"pickle_id": dag.pickle_id,
"default_view": dag.default_view,
"fileloc": dag.fileloc,
"file_token": None,
5 changes: 4 additions & 1 deletion airflow/dag_processing/collection.py
@@ -211,7 +211,10 @@ def update_dags(
dm.has_import_errors = False
dm.last_parsed_time = utcnow()
dm.default_view = dag.default_view
dm._dag_display_property_value = dag._dag_display_property_value
if hasattr(dag, "_dag_display_property_value"):
dm._dag_display_property_value = dag._dag_display_property_value
elif dag.dag_display_name != dag.dag_id:
dm._dag_display_property_value = dag.dag_display_name
dm.description = dag.description
dm.max_active_tasks = dag.max_active_tasks
dm.max_active_runs = dag.max_active_runs
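This hunk is the compatibility shim for the port: a core DAG still carries the private _dag_display_property_value, while a Task SDK DAG only exposes dag_display_name, so the serialized display value is taken from whichever is available and only persisted when it differs from dag_id. A standalone sketch of the fallback (the stub classes are hypothetical):

class LegacyDAG:
    dag_id = "etl"
    dag_display_name = "ETL (nightly)"
    _dag_display_property_value = "ETL (nightly)"  # core DAG keeps the private field

class SdkDAG:
    dag_id = "etl"
    dag_display_name = "ETL (nightly)"  # differs from dag_id, so it was user-set

def resolve_display_value(dag):
    # Prefer the private attribute when the class still defines it.
    if hasattr(dag, "_dag_display_property_value"):
        return dag._dag_display_property_value
    # Otherwise persist the display name only if the user customised it.
    if dag.dag_display_name != dag.dag_id:
        return dag.dag_display_name
    return None

assert resolve_display_value(LegacyDAG()) == "ETL (nightly)"
assert resolve_display_value(SdkDAG()) == "ETL (nightly)"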
44 changes: 18 additions & 26 deletions airflow/decorators/base.py
@@ -41,43 +41,41 @@
import typing_extensions

from airflow.assets import Asset
from airflow.models.abstractoperator import DEFAULT_RETRIES, DEFAULT_RETRY_DELAY
from airflow.models.baseoperator import (
BaseOperator,
coerce_resources,
coerce_timedelta,
get_merged_defaults,
parse_retries,
)
from airflow.models.dag import DagContext
from airflow.models.expandinput import (
EXPAND_INPUT_EMPTY,
DictOfListsExpandInput,
ListOfDictsExpandInput,
is_mappable,
)
from airflow.models.mappedoperator import MappedOperator, ensure_xcomarg_return_value
from airflow.models.pool import Pool
from airflow.models.xcom_arg import XComArg
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext
from airflow.typing_compat import ParamSpec, Protocol
from airflow.utils import timezone
from airflow.utils.context import KNOWN_CONTEXT_KEYS
from airflow.utils.decorators import remove_task_decorator
from airflow.utils.helpers import prevent_duplicates
from airflow.utils.task_group import TaskGroupContext
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.models.dag import DAG
from airflow.models.expandinput import (
ExpandInput,
OperatorExpandArgument,
OperatorExpandKwargsArgument,
)
from airflow.models.mappedoperator import ValidationSource
from airflow.sdk import DAG
from airflow.utils.context import Context
from airflow.utils.task_group import TaskGroup

@@ -141,13 +139,13 @@ def get_unique_task_id(
...
task_id__20
"""
dag = dag or DagContext.get_current_dag()
dag = dag or DagContext.get_current()
if not dag:
return task_id

# We need to check if we are in the context of TaskGroup as the task_id may
# already be altered
task_group = task_group or TaskGroupContext.get_current_task_group(dag)
task_group = task_group or TaskGroupContext.get_current(dag)
tg_task_id = task_group.child_id(task_id) if task_group else task_id

if tg_task_id not in dag.task_ids:
@@ -428,8 +426,8 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
ensure_xcomarg_return_value(expand_input.value)

task_kwargs = self.kwargs.copy()
dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current_task_group(dag)
dag = task_kwargs.pop("dag", None) or DagContext.get_current()
task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current(dag)

default_args, partial_params = get_merged_defaults(
dag=dag,
Expand All @@ -442,7 +440,7 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
"is_teardown": self.is_teardown,
"on_failure_fail_dagrun": self.on_failure_fail_dagrun,
}
base_signature = inspect.signature(BaseOperator)
base_signature = inspect.signature(TaskSDKBaseOperator)
ignore = {
"default_args", # This is target we are working on now.
"kwargs", # A common name for a keyword argument.
@@ -460,32 +458,26 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
task_id = task_group.child_id(task_id)

# Logic here should be kept in sync with BaseOperatorMeta.partial().
if "task_concurrency" in partial_kwargs:
raise TypeError("unexpected argument: task_concurrency")
if partial_kwargs.get("wait_for_downstream"):
partial_kwargs["depends_on_past"] = True
start_date = timezone.convert_to_utc(partial_kwargs.pop("start_date", None))
end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None))
if partial_kwargs.get("pool") is None:
partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
if "pool_slots" in partial_kwargs:
if partial_kwargs["pool_slots"] < 1:
dag_str = ""
if dag:
dag_str = f" in dag {dag.dag_id}"
raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1")
partial_kwargs["retries"] = parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES))
partial_kwargs["retry_delay"] = coerce_timedelta(
partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY),
key="retry_delay",
)
max_retry_delay = partial_kwargs.get("max_retry_delay")
partial_kwargs["max_retry_delay"] = (
max_retry_delay
if max_retry_delay is None
else coerce_timedelta(max_retry_delay, key="max_retry_delay")
)
partial_kwargs["resources"] = coerce_resources(partial_kwargs.get("resources"))

for fld, convert in (
("retries", parse_retries),
("retry_delay", coerce_timedelta),
("max_retry_delay", coerce_timedelta),
("resources", coerce_resources),
):
if (v := partial_kwargs.get(fld, NOTSET)) is not NOTSET:
partial_kwargs[fld] = convert(v) # type: ignore[operator]

partial_kwargs.setdefault("executor_config", {})
partial_kwargs.setdefault("op_args", [])
partial_kwargs.setdefault("op_kwargs", {})
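Besides the import moves (DagContext and TaskGroupContext now come from airflow.sdk.definitions.contextmanager, with shorter get_current() accessors), the key change in this file is the table-driven conversion loop: per-field defaulting is dropped because the Task SDK's attrs-based BaseOperator now owns the defaults, and a converter only runs when the caller actually supplied a value. A self-contained sketch of the NOTSET-sentinel pattern (converters simplified for illustration):

import datetime

NOTSET = object()  # sentinel: distinguishes "argument not passed" from None

def parse_retries(value):
    return int(value)

def coerce_timedelta(value):
    if isinstance(value, datetime.timedelta):
        return value
    return datetime.timedelta(seconds=value)

partial_kwargs = {"retries": "3", "retry_delay": 300}  # user-supplied values only

for fld, convert in (("retries", parse_retries), ("retry_delay", coerce_timedelta)):
    # Convert only fields the caller passed; absent fields keep the
    # defaults declared on the operator class itself.
    if (v := partial_kwargs.get(fld, NOTSET)) is not NOTSET:
        partial_kwargs[fld] = convert(v)

assert partial_kwargs == {"retries": 3, "retry_delay": datetime.timedelta(seconds=300)}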
4 changes: 2 additions & 2 deletions airflow/decorators/bash.py
@@ -18,7 +18,7 @@
from __future__ import annotations

import warnings
from typing import Any, Callable, Collection, Mapping, Sequence
from typing import Any, Callable, ClassVar, Collection, Mapping, Sequence

from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
from airflow.providers.standard.operators.bash import BashOperator
@@ -39,7 +39,7 @@ class _BashDecoratedOperator(DecoratedOperator, BashOperator):
"""

template_fields: Sequence[str] = (*DecoratedOperator.template_fields, *BashOperator.template_fields)
template_fields_renderers: dict[str, str] = {
template_fields_renderers: ClassVar[dict[str, str]] = {
**DecoratedOperator.template_fields_renderers,
**BashOperator.template_fields_renderers,
}
4 changes: 2 additions & 2 deletions airflow/decorators/sensor.py
@@ -17,7 +17,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Sequence
from typing import TYPE_CHECKING, Callable, ClassVar, Sequence

from airflow.decorators.base import get_unique_task_id, task_decorator_factory
from airflow.sensors.python import PythonSensor
@@ -42,7 +42,7 @@ class DecoratedSensorOperator(PythonSensor):
"""

template_fields: Sequence[str] = ("op_args", "op_kwargs")
template_fields_renderers: dict[str, str] = {"op_args": "py", "op_kwargs": "py"}
template_fields_renderers: ClassVar[dict[str, str]] = {"op_args": "py", "op_kwargs": "py"}

custom_operator_name = "@task.sensor"

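Both decorator modules now annotate template_fields_renderers as ClassVar. Once a base class in the MRO is attrs-based (as the Task SDK BaseOperator is), a bare dict[str, str] annotation would be collected as a per-instance attrs field, whereas ClassVar keeps it a plain shared class attribute. A minimal illustration, assuming attrs-style classes like those the Task SDK uses (the classes here are hypothetical):

from typing import ClassVar

import attrs

@attrs.define
class WithClassVar:
    renderers: ClassVar[dict[str, str]] = {"op_args": "py"}  # ignored by attrs

@attrs.define
class WithoutClassVar:
    renderers: dict[str, str] = {"op_args": "py"}  # collected as an attrs field

assert [f.name for f in attrs.fields(WithClassVar)] == []
assert [f.name for f in attrs.fields(WithoutClassVar)] == ["renderers"]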
2 changes: 1 addition & 1 deletion airflow/decorators/task_group.py
@@ -38,8 +38,8 @@
ListOfDictsExpandInput,
MappedArgument,
)
from airflow.models.taskmixin import DAGNode
from airflow.models.xcom_arg import XComArg
from airflow.sdk.definitions.node import DAGNode
from airflow.typing_compat import ParamSpec
from airflow.utils.helpers import prevent_duplicates
from airflow.utils.task_group import MappedTaskGroup, TaskGroup
6 changes: 3 additions & 3 deletions airflow/exceptions.py
@@ -31,7 +31,7 @@
import datetime
from collections.abc import Sized

from airflow.models import DAG, DagRun
from airflow.models import DagRun


class AirflowException(Exception):
@@ -273,13 +273,13 @@ class FailStopDagInvalidTriggerRule(AirflowException):
_allowed_rules = (TriggerRule.ALL_SUCCESS, TriggerRule.ALL_DONE_SETUP_SUCCESS)

@classmethod
def check(cls, *, dag: DAG | None, trigger_rule: TriggerRule):
def check(cls, *, fail_stop: bool, trigger_rule: TriggerRule):
"""
Check that fail_stop dag tasks have allowable trigger rules.

:meta private:
"""
if dag is not None and dag.fail_stop and trigger_rule not in cls._allowed_rules:
if fail_stop and trigger_rule not in cls._allowed_rules:
raise cls()

def __str__(self) -> str:
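Changing check to take a plain fail_stop flag removes this module's dependency on the database-backed DAG model: the caller resolves the flag from whatever DAG object it holds before crossing the boundary. A simplified, standalone restatement of the new contract (string trigger rules stand in for the TriggerRule enum; the call site is illustrative):

class FailStopDagInvalidTriggerRule(Exception):
    _allowed_rules = ("all_success", "all_done_setup_success")

    @classmethod
    def check(cls, *, fail_stop: bool, trigger_rule: str) -> None:
        # The exception only needs the resolved flag, not the DAG itself.
        if fail_stop and trigger_rule not in cls._allowed_rules:
            raise cls()

# Call-site sketch: the caller owns the DAG lookup.
fail_stop = True  # e.g. dag.fail_stop if dag is not None else False
FailStopDagInvalidTriggerRule.check(fail_stop=fail_stop, trigger_rule="all_success")  # passes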