From b00a81000b1efe7c47a4c0e8a6167a0718dc909e Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Thu, 14 Nov 2024 15:59:35 +0530 Subject: [PATCH 01/33] Remove ORM references from datamodels for assets (#44010) --- airflow/api_fastapi/core_api/datamodels/assets.py | 6 +++--- airflow/api_fastapi/core_api/openapi/v1-generated.yaml | 6 +++--- airflow/ui/openapi-gen/requests/schemas.gen.ts | 9 +++------ airflow/ui/openapi-gen/requests/types.gen.ts | 6 +++--- 4 files changed, 12 insertions(+), 15 deletions(-) diff --git a/airflow/api_fastapi/core_api/datamodels/assets.py b/airflow/api_fastapi/core_api/datamodels/assets.py index 85e41ff7b569..1295b33fbf76 100644 --- a/airflow/api_fastapi/core_api/datamodels/assets.py +++ b/airflow/api_fastapi/core_api/datamodels/assets.py @@ -23,7 +23,7 @@ class DagScheduleAssetReference(BaseModel): - """Serializable version of the DagScheduleAssetReference ORM SqlAlchemyModel.""" + """DAG schedule reference serializer for assets.""" dag_id: str created_at: datetime @@ -31,7 +31,7 @@ class DagScheduleAssetReference(BaseModel): class TaskOutletAssetReference(BaseModel): - """Serializable version of the TaskOutletAssetReference ORM SqlAlchemyModel.""" + """Task outlet reference serializer for assets.""" dag_id: str task_id: str @@ -40,7 +40,7 @@ class TaskOutletAssetReference(BaseModel): class AssetAliasSchema(BaseModel): - """Serializable version of the AssetAliasSchema ORM SqlAlchemyModel.""" + """Asset alias serializer for assets.""" id: int name: str diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index b99b389de51f..b5dc0db0741b 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -3650,7 +3650,7 @@ components: - id - name title: AssetAliasSchema - description: Serializable version of the AssetAliasSchema ORM SqlAlchemyModel. + description: Asset alias serializer for assets. AssetCollectionResponse: properties: assets: @@ -4900,7 +4900,7 @@ components: - created_at - updated_at title: DagScheduleAssetReference - description: Serializable version of the DagScheduleAssetReference ORM SqlAlchemyModel. + description: DAG schedule reference serializer for assets. DagStatsCollectionResponse: properties: dags: @@ -5800,7 +5800,7 @@ components: - created_at - updated_at title: TaskOutletAssetReference - description: Serializable version of the TaskOutletAssetReference ORM SqlAlchemyModel. + description: Task outlet reference serializer for assets. 
TaskResponse: properties: task_id: diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index 7bf8f4b02966..8a8a50bb7437 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -103,8 +103,7 @@ export const $AssetAliasSchema = { type: "object", required: ["id", "name"], title: "AssetAliasSchema", - description: - "Serializable version of the AssetAliasSchema ORM SqlAlchemyModel.", + description: "Asset alias serializer for assets.", } as const; export const $AssetCollectionResponse = { @@ -2012,8 +2011,7 @@ export const $DagScheduleAssetReference = { type: "object", required: ["dag_id", "created_at", "updated_at"], title: "DagScheduleAssetReference", - description: - "Serializable version of the DagScheduleAssetReference ORM SqlAlchemyModel.", + description: "DAG schedule reference serializer for assets.", } as const; export const $DagStatsCollectionResponse = { @@ -3327,8 +3325,7 @@ export const $TaskOutletAssetReference = { type: "object", required: ["dag_id", "task_id", "created_at", "updated_at"], title: "TaskOutletAssetReference", - description: - "Serializable version of the TaskOutletAssetReference ORM SqlAlchemyModel.", + description: "Task outlet reference serializer for assets.", } as const; export const $TaskResponse = { diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index c4eee6b97c21..5267c193f480 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -22,7 +22,7 @@ export type AppBuilderViewResponse = { }; /** - * Serializable version of the AssetAliasSchema ORM SqlAlchemyModel. + * Asset alias serializer for assets. */ export type AssetAliasSchema = { id: number; @@ -457,7 +457,7 @@ export type DagRunType = | "asset_triggered"; /** - * Serializable version of the DagScheduleAssetReference ORM SqlAlchemyModel. + * DAG schedule reference serializer for assets. */ export type DagScheduleAssetReference = { dag_id: string; @@ -809,7 +809,7 @@ export type TaskInstanceStateCount = { }; /** - * Serializable version of the TaskOutletAssetReference ORM SqlAlchemyModel. + * Task outlet reference serializer for assets. */ export type TaskOutletAssetReference = { dag_id: string; From 2ef8438eecb35027601982bd00865acca737a5b3 Mon Sep 17 00:00:00 2001 From: GPK Date: Thu, 14 Nov 2024 13:10:44 +0000 Subject: [PATCH 02/33] move version imports to inside utils (#44018) --- .../airflow/providers/standard/__init__.py | 9 ------- .../providers/standard/operators/python.py | 2 +- .../providers/standard/sensors/date_time.py | 2 +- .../providers/standard/sensors/time.py | 2 +- .../providers/standard/sensors/time_delta.py | 2 +- .../standard/utils/version_references.py | 26 +++++++++++++++++++ 6 files changed, 30 insertions(+), 13 deletions(-) create mode 100644 providers/src/airflow/providers/standard/utils/version_references.py diff --git a/providers/src/airflow/providers/standard/__init__.py b/providers/src/airflow/providers/standard/__init__.py index 47fc7a1e8009..217e5db96078 100644 --- a/providers/src/airflow/providers/standard/__init__.py +++ b/providers/src/airflow/providers/standard/__init__.py @@ -15,12 +15,3 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from __future__ import annotations - -from packaging.version import Version - -from airflow import __version__ as airflow_version - -AIRFLOW_VERSION = Version(airflow_version) -AIRFLOW_V_2_10_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("2.10.0") -AIRFLOW_V_3_0_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("3.0.0") diff --git a/providers/src/airflow/providers/standard/operators/python.py b/providers/src/airflow/providers/standard/operators/python.py index 8c1f440c73a2..0e8b3843c5ce 100644 --- a/providers/src/airflow/providers/standard/operators/python.py +++ b/providers/src/airflow/providers/standard/operators/python.py @@ -48,8 +48,8 @@ from airflow.models.taskinstance import _CURRENT_CONTEXT from airflow.models.variable import Variable from airflow.operators.branch import BranchMixIn -from airflow.providers.standard import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS from airflow.providers.standard.utils.python_virtualenv import prepare_virtualenv, write_python_script +from airflow.providers.standard.utils.version_references import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS from airflow.settings import _ENABLE_AIP_44 from airflow.typing_compat import Literal from airflow.utils import hashlib_wrapper diff --git a/providers/src/airflow/providers/standard/sensors/date_time.py b/providers/src/airflow/providers/standard/sensors/date_time.py index 35e88df07ba7..63917e3c2239 100644 --- a/providers/src/airflow/providers/standard/sensors/date_time.py +++ b/providers/src/airflow/providers/standard/sensors/date_time.py @@ -21,7 +21,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, NoReturn, Sequence -from airflow.providers.standard import AIRFLOW_V_3_0_PLUS +from airflow.providers.standard.utils.version_references import AIRFLOW_V_3_0_PLUS from airflow.sensors.base import BaseSensorOperator try: diff --git a/providers/src/airflow/providers/standard/sensors/time.py b/providers/src/airflow/providers/standard/sensors/time.py index 5c1629495297..8b727cb1cf1d 100644 --- a/providers/src/airflow/providers/standard/sensors/time.py +++ b/providers/src/airflow/providers/standard/sensors/time.py @@ -21,7 +21,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, NoReturn -from airflow.providers.standard import AIRFLOW_V_2_10_PLUS +from airflow.providers.standard.utils.version_references import AIRFLOW_V_2_10_PLUS from airflow.sensors.base import BaseSensorOperator try: diff --git a/providers/src/airflow/providers/standard/sensors/time_delta.py b/providers/src/airflow/providers/standard/sensors/time_delta.py index eb8bac1c57ea..0b50c5cef863 100644 --- a/providers/src/airflow/providers/standard/sensors/time_delta.py +++ b/providers/src/airflow/providers/standard/sensors/time_delta.py @@ -23,7 +23,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowSkipException -from airflow.providers.standard import AIRFLOW_V_3_0_PLUS +from airflow.providers.standard.utils.version_references import AIRFLOW_V_3_0_PLUS from airflow.sensors.base import BaseSensorOperator from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger from airflow.utils import timezone diff --git a/providers/src/airflow/providers/standard/utils/version_references.py b/providers/src/airflow/providers/standard/utils/version_references.py new file mode 100644 index 000000000000..47fc7a1e8009 --- /dev/null +++ b/providers/src/airflow/providers/standard/utils/version_references.py @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) 
under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from packaging.version import Version + +from airflow import __version__ as airflow_version + +AIRFLOW_VERSION = Version(airflow_version) +AIRFLOW_V_2_10_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("2.10.0") +AIRFLOW_V_3_0_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("3.0.0") From de15523d7c709184c9042b1854dcadeb6d8e2e3e Mon Sep 17 00:00:00 2001 From: GPK Date: Thu, 14 Nov 2024 13:11:06 +0000 Subject: [PATCH 03/33] fix openlineage tests (#44025) --- providers/tests/openlineage/plugins/test_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/providers/tests/openlineage/plugins/test_utils.py b/providers/tests/openlineage/plugins/test_utils.py index cb294169bfd7..22b80120bb6a 100644 --- a/providers/tests/openlineage/plugins/test_utils.py +++ b/providers/tests/openlineage/plugins/test_utils.py @@ -346,14 +346,14 @@ def test_serialize_timetable(): "asset_condition": { "__type": DagAttributeTypes.ASSET_ANY, "objects": [ - {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "2"}, + {"__type": DagAttributeTypes.ASSET, "extra": {}, "name": "2", "uri": "2"}, {"__type": DagAttributeTypes.ASSET_ANY, "objects": []}, - {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "3"}, + {"__type": DagAttributeTypes.ASSET, "extra": {}, "name": "3", "uri": "3"}, { "__type": DagAttributeTypes.ASSET_ALL, "objects": [ {"__type": DagAttributeTypes.ASSET_ANY, "objects": []}, - {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "4"}, + {"__type": DagAttributeTypes.ASSET, "extra": {}, "name": "4", "uri": "4"}, ], }, ], From 75de1d877108ee9859fd4c57054d6775daa27256 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Thu, 14 Nov 2024 14:31:50 +0000 Subject: [PATCH 04/33] Start building the replacement task runner for Task Execution SDK (#43893) The eventual goal of this "airflow.sdk.execution_time" package is to replace LocalTaskJob and StandardTaskRunner, but at this stage it co-exists with the code it replaces. As this PR is not a complete re-implementation of all the features that exist currently (no handling of task-level callbacks yet, no AirflowSkipException etc.) the current tests are skeletal at best. Once we get closer to feature parity (in future PRs) the tests will grow to match. This supervisor and task runner operate slightly differently from the current classes in these ways: **Logs from the subprocess are sent over a different channel to stdout/stderr** This makes the task supervisor a little bit more complex, as it now has to read stdout, stderr and a logs channel.
The advantage of this approach is that it makes the logging setup in the task process itself markedly simpler -- all it has to do is write log output to the custom file handle as JSON and it will show up "natively" as logs. structlog has been chosen as the logging engine over stdlib's own logging because the ability to have structured fields in the logs is nice, and stdlib logging is configured to send its records to a structlog processor. **Direct database access is replaced with an HTTP API client** This is the crux of this feature and of AIP-72 in general -- tasks run via this runner can no longer access DB models or a DB session directly. This PR doesn't yet implement the code/shims to make `Connection.get_connection_from_secrets` use this client - that will be future work. Tasks don't speak directly to the API server for two main reasons: 1. The supervisor process already needs to maintain an HTTP session in order to report the task as started, to heartbeat it, and to mark it as finished; and because of that 2. It reduces the number of active HTTP connections to 1 per task (instead of 2 per task). The other reason we have this interface is that DAG parsing code will very soon need to be updated to not have direct DB access either, and having this "in process" interface ability already means that we can support commands like `airflow dags reserialize` without having a running API server. The API client itself is not auto-generated: I tried a number of different client generators based on the OpenAPI spec and found them all lacking or buggy in different ways, and the HTTP client side itself is very simple; the only interesting/difficult bit is the generation of the datamodels from the OpenAPI spec, for which a suitable generator was found. --------- Co-authored-by: Kaxil Naik --- Dockerfile | 2 +- Dockerfile.ci | 2 +- airflow/utils/net.py | 4 +- scripts/docker/install_airflow.sh | 2 +- task_sdk/pyproject.toml | 59 +- task_sdk/src/airflow/sdk/__init__.py | 3 + task_sdk/src/airflow/sdk/api/__init__.py | 16 + task_sdk/src/airflow/sdk/api/client.py | 216 +++++++ .../airflow/sdk/api/datamodels/__init__.py | 16 + .../airflow/sdk/api/datamodels/_generated.py | 148 +++++ .../airflow/sdk/api/datamodels/activities.py | 31 + task_sdk/src/airflow/sdk/api/datamodels/ti.py | 32 + .../airflow/sdk/execution_time/__init__.py | 17 + .../src/airflow/sdk/execution_time/comms.py | 120 ++++ .../airflow/sdk/execution_time/supervisor.py | 599 ++++++++++++++++++ .../airflow/sdk/execution_time/task_runner.py | 191 ++++++ task_sdk/src/airflow/sdk/log.py | 372 +++++++++++ task_sdk/src/airflow/sdk/types.py | 2 +- task_sdk/tests/api/__init__.py | 16 + task_sdk/tests/api/test_client.py | 62 ++ task_sdk/tests/conftest.py | 58 ++ task_sdk/tests/defintions/__init__.py | 16 + .../tests/defintions/test_baseoperator.py | 19 + task_sdk/tests/execution_time/__init__.py | 16 + task_sdk/tests/execution_time/conftest.py | 33 + .../tests/execution_time/test_supervisor.py | 150 +++++ .../tests/execution_time/test_task_runner.py | 56 ++ tests/cli/commands/test_celery_command.py | 1 + 28 files changed, 2251 insertions(+), 8 deletions(-) create mode 100644 task_sdk/src/airflow/sdk/api/__init__.py create mode 100644 task_sdk/src/airflow/sdk/api/client.py create mode 100644 task_sdk/src/airflow/sdk/api/datamodels/__init__.py create mode 100644 task_sdk/src/airflow/sdk/api/datamodels/_generated.py create mode 100644 task_sdk/src/airflow/sdk/api/datamodels/activities.py create mode 100644 task_sdk/src/airflow/sdk/api/datamodels/ti.py create mode
100644 task_sdk/src/airflow/sdk/execution_time/__init__.py create mode 100644 task_sdk/src/airflow/sdk/execution_time/comms.py create mode 100644 task_sdk/src/airflow/sdk/execution_time/supervisor.py create mode 100644 task_sdk/src/airflow/sdk/execution_time/task_runner.py create mode 100644 task_sdk/src/airflow/sdk/log.py create mode 100644 task_sdk/tests/api/__init__.py create mode 100644 task_sdk/tests/api/test_client.py create mode 100644 task_sdk/tests/defintions/__init__.py create mode 100644 task_sdk/tests/execution_time/__init__.py create mode 100644 task_sdk/tests/execution_time/conftest.py create mode 100644 task_sdk/tests/execution_time/test_supervisor.py create mode 100644 task_sdk/tests/execution_time/test_task_runner.py diff --git a/Dockerfile b/Dockerfile index 5ca9949b0213..d9fb1878f116 100644 --- a/Dockerfile +++ b/Dockerfile @@ -890,7 +890,7 @@ function install_airflow() { # Similarly we need _a_ file for task_sdk too mkdir -p ./task_sdk/src/airflow/sdk/ - touch ./task_sdk/src/airflow/sdk/__init__.py + echo '__version__ = "0.0.0dev0"' > ./task_sdk/src/airflow/sdk/__init__.py trap 'rm -f ./providers/src/airflow/providers/__init__.py ./task_sdk/src/airflow/__init__.py 2>/dev/null' EXIT diff --git a/Dockerfile.ci b/Dockerfile.ci index 943270aec693..952993984e56 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -660,7 +660,7 @@ function install_airflow() { # Similarly we need _a_ file for task_sdk too mkdir -p ./task_sdk/src/airflow/sdk/ - touch ./task_sdk/src/airflow/sdk/__init__.py + echo '__version__ = "0.0.0dev0"' > ./task_sdk/src/airflow/sdk/__init__.py trap 'rm -f ./providers/src/airflow/providers/__init__.py ./task_sdk/src/airflow/__init__.py 2>/dev/null' EXIT diff --git a/airflow/utils/net.py b/airflow/utils/net.py index 992aee67e800..9fc79b3842c3 100644 --- a/airflow/utils/net.py +++ b/airflow/utils/net.py @@ -20,8 +20,6 @@ import socket from functools import lru_cache -from airflow.configuration import conf - # patched version of socket.getfqdn() - see https://github.com/python/cpython/issues/49254 @lru_cache(maxsize=None) @@ -53,4 +51,6 @@ def get_host_ip_address(): def get_hostname(): """Fetch the hostname using the callable from config or use `airflow.utils.net.getfqdn` as a fallback.""" + from airflow.configuration import conf + return conf.getimport("core", "hostname_callable", fallback="airflow.utils.net.getfqdn")() diff --git a/scripts/docker/install_airflow.sh b/scripts/docker/install_airflow.sh index 2975c50c2d61..27dd25ba2608 100644 --- a/scripts/docker/install_airflow.sh +++ b/scripts/docker/install_airflow.sh @@ -54,7 +54,7 @@ function install_airflow() { # Similarly we need _a_ file for task_sdk too mkdir -p ./task_sdk/src/airflow/sdk/ - touch ./task_sdk/src/airflow/sdk/__init__.py + echo '__version__ = "0.0.0dev0"' > ./task_sdk/src/airflow/sdk/__init__.py trap 'rm -f ./providers/src/airflow/providers/__init__.py ./task_sdk/src/airflow/__init__.py 2>/dev/null' EXIT diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml index f290dfb17fdb..5da673a79bf0 100644 --- a/task_sdk/pyproject.toml +++ b/task_sdk/pyproject.toml @@ -17,20 +17,30 @@ [project] name = "apache-airflow-task-sdk" -version = "0.1.0.dev0" +dynamic = ["version"] description = "Python Task SDK for Apache Airflow DAG Authors" readme = { file = "README.md", content-type = "text/markdown" } requires-python = ">=3.9, <3.13" dependencies = [ "attrs>=24.2.0", "google-re2>=1.1.20240702", + "httpx>=0.27.0", "methodtools>=0.4.7", + "msgspec>=0.18.6", + "psutil>=6.1.0", + "structlog>=24.4.0", 
+] +classifiers = [ + "Framework :: Apache Airflow", ] [build-system] requires = ["hatchling"] build-backend = "hatchling.build" +[tool.hatch.version] +path = "src/airflow/sdk/__init__.py" + [tool.hatch.build.targets.wheel] packages = ["src/airflow"] # This file only exists to make pyright/VSCode happy, don't ship it @@ -46,11 +56,24 @@ namespace-packages = ["src/airflow"] # Ignore Doc rules et al for anything outside of tests "!src/*" = ["D", "TID253", "S101", "TRY002"] -"src/airflow/sdk/__init__.py" = ["TCH004"] +# Ignore the pytest rules outside the tests folder - https://github.com/astral-sh/ruff/issues/14205 +"!tests/*" = ["PT"] # Pycharm barfs if this "stub" file has future imports "src/airflow/__init__.py" = ["I002"] +"src/airflow/sdk/__init__.py" = ["TCH004"] + +# msgspec needs types for annotations to be defined, even with future +# annotations, so disable the "type check only import" for these files +"src/airflow/sdk/api/datamodels/*.py" = ["TCH001"] + +# Only the public API should _require_ docstrings on classes +"!src/airflow/sdk/definitions/*" = ["D101"] + +# Generated file, be less strict +"src/airflow/sdk/*/_generated.py" = ["D"] + [tool.uv] dev-dependencies = [ "kgb>=7.1.1", @@ -59,6 +82,7 @@ dev-dependencies = [ "pytest>=8.3.3", ] + [tool.coverage.run] branch = true relative_files = true @@ -76,3 +100,34 @@ exclude_also = [ "@(typing(_extensions)?\\.)?overload", "if (typing(_extensions)?\\.)?TYPE_CHECKING:", ] + +[dependency-groups] +codegen = [ + "datamodel-code-generator[http]>=0.26.3", +] + +[tool.black] +# This is needed for datamodel-codegen to treat this as the "project" file + +# To use: +# +# uv run --group codegen --project apache-airflow-task-sdk --directory task_sdk datamodel-codegen +[tool.datamodel-codegen] +capitalise-enum-members=true # `State.RUNNING` not `State.running` +disable-timestamp=true +enable-version-header=true +enum-field-as-literal='one' # When a single enum member, make it output a `Literal["..."]` +input-file-type='openapi' +output-model-type='pydantic_v2.BaseModel' +output-datetime-class='datetime' +target-python-version='3.9' +use-annotated=true +use-default=true +use-double-quotes=true +use-schema-description=true # Desc becomes class doc comment +use-standard-collections=true # list[] not List[] +use-subclass-enum=true # enum, not union of Literals +use-union-operator=true # 3.9+annotations, not `Union[]` + +url = 'http://0.0.0.0:9091/execution/openapi.json' +output = 'src/airflow/sdk/api/datamodels/_generated.py' diff --git a/task_sdk/src/airflow/sdk/__init__.py b/task_sdk/src/airflow/sdk/__init__.py index f538baedff01..bd882f43dd0b 100644 --- a/task_sdk/src/airflow/sdk/__init__.py +++ b/task_sdk/src/airflow/sdk/__init__.py @@ -25,8 +25,11 @@ "Label", "TaskGroup", "dag", + "__version__", ] +__version__ = "1.0.0.dev1" + if TYPE_CHECKING: from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG, dag diff --git a/task_sdk/src/airflow/sdk/api/__init__.py b/task_sdk/src/airflow/sdk/api/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/task_sdk/src/airflow/sdk/api/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py new file mode 100644 index 000000000000..ece3bc96009b --- /dev/null +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -0,0 +1,216 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import sys +import uuid +from typing import TYPE_CHECKING, Any + +import httpx +import methodtools +import structlog +from pydantic import BaseModel +from uuid6 import uuid7 + +from airflow.sdk import __version__ +from airflow.sdk.api.datamodels._generated import ( + ConnectionResponse, + State1 as TerminalState, + TaskInstanceState, + TIEnterRunningPayload, + TITerminalStatePayload, + ValidationError as RemoteValidationError, +) +from airflow.utils.net import get_hostname +from airflow.utils.platform import getuser + +if TYPE_CHECKING: + from datetime import datetime + + +log = structlog.get_logger(logger_name=__name__) + +__all__ = [ + "Client", + "ConnectionOperations", + "ErrorBody", + "ServerResponseError", + "TaskInstanceOperations", +] + + +def get_json_error(response: httpx.Response): + """Raise a ServerResponseError if we can extract error info from the error.""" + err = ServerResponseError.from_response(response) + if err: + log.warning("Server error", detail=err.detail) + raise err + + +def raise_on_4xx_5xx(response: httpx.Response): + return get_json_error(response) or response.raise_for_status() + + +# Py 3.11+ version +def raise_on_4xx_5xx_with_note(response: httpx.Response): + try: + return get_json_error(response) or response.raise_for_status() + except httpx.HTTPStatusError as e: + if TYPE_CHECKING: + assert hasattr(e, "add_note") + e.add_note( + f"Correlation-id={response.headers.get('correlation-id', None) or response.request.headers.get('correlation-id', 'no-correlction-id')}" + ) + raise + + +if hasattr(BaseException, "add_note"): + # Py 3.11+ + raise_on_4xx_5xx = raise_on_4xx_5xx_with_note + + +def add_correlation_id(request: httpx.Request): + request.headers["correlation-id"] = str(uuid7()) + + +class TaskInstanceOperations: + __slots__ = ("client",) + + def __init__(self, client: Client): + self.client = client + + def start(self, id: uuid.UUID, pid: int, 
when: datetime): + """Tell the API server that this TI has started running.""" + body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(), unixname=getuser(), start_date=when) + + self.client.patch(f"task-instance/{id}/state", content=body.model_dump_json()) + + def finish(self, id: uuid.UUID, state: TaskInstanceState, when: datetime): + """Tell the API server that this TI has reached a terminal state.""" + body = TITerminalStatePayload(end_date=when, state=TerminalState(state)) + + self.client.patch(f"task-instance/{id}/state", content=body.model_dump_json()) + + def heartbeat(self, id: uuid.UUID): + self.client.put(f"task-instance/{id}/heartbeat") + + +class ConnectionOperations: + __slots__ = ("client", "decoder") + + def __init__(self, client: Client): + self.client = client + + def get(self, id: str) -> ConnectionResponse: + """Get a connection from the API server.""" + resp = self.client.get(f"connection/{id}") + return ConnectionResponse.model_validate_json(resp.read()) + + +class BearerAuth(httpx.Auth): + def __init__(self, token: str): + self.token: str = token + + def auth_flow(self, request: httpx.Request): + if self.token: + request.headers["Authorization"] = "Bearer " + self.token + yield request + + +# This exists as a aid for debugging or local running via the `dry_run` argument to Client. It doesn't make +# sense for returning connections etc. +def noop_handler(request: httpx.Request) -> httpx.Response: + log.debug("Dry-run request", method=request.method, path=request.url.path) + return httpx.Response(200, json={"text": "Hello, world!"}) + + +class Client(httpx.Client): + def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any): + if (not base_url) ^ dry_run: + raise ValueError(f"Can only specify one of {base_url=} or {dry_run=}") + auth = BearerAuth(token) + + if dry_run: + # If dry run is requested, install a no op handler so that simple tasks can "heartbeat" using a + # real client, but just don't make any HTTP requests + kwargs["transport"] = httpx.MockTransport(noop_handler) + kwargs["base_url"] = "dry-run://server" + else: + kwargs["base_url"] = base_url + pyver = f"{'.'.join(map(str, sys.version_info[:3]))}" + super().__init__( + auth=auth, + headers={"user-agent": f"apache-airflow-task-sdk/{__version__} (Python/{pyver})"}, + event_hooks={"response": [raise_on_4xx_5xx], "request": [add_correlation_id]}, + **kwargs, + ) + + # We "group" or "namespace" operations by what they operate on, rather than a flat namespace with all + # methods on one object prefixed with the object type (`.task_instances.update` rather than + # `task_instance_update` etc.) 
+ + @methodtools.lru_cache() # type: ignore[misc] + @property + def task_instances(self) -> TaskInstanceOperations: + """Operations related to TaskInstances.""" + return TaskInstanceOperations(self) + + @methodtools.lru_cache() # type: ignore[misc] + @property + def connections(self) -> ConnectionOperations: + """Operations related to TaskInstances.""" + return ConnectionOperations(self) + + +class ErrorBody(BaseModel): + detail: list[RemoteValidationError] | dict[str, Any] + + def __repr__(self): + return repr(self.detail) + + +class ServerResponseError(httpx.HTTPStatusError): + def __init__(self, message: str, *, request: httpx.Request, response: httpx.Response): + super().__init__(message, request=request, response=response) + + detail: ErrorBody + + @classmethod + def from_response(cls, response: httpx.Response) -> ServerResponseError | None: + if response.is_success: + return None + # 4xx or 5xx error? + if 400 < (response.status_code // 100) >= 600: + return None + + if response.headers.get("content-type") != "application/json": + return None + + try: + err = ErrorBody.model_validate_json(response.read()) + if isinstance(err.detail, list): + msg = "Remote server returned validation error" + else: + msg = err.detail.get("message", "") or "Un-parseable error" + except Exception: + err = ErrorBody.model_validate_json(response.content) + msg = "Server returned error" + + self = cls(msg, request=response.request, response=response) + self.detail = err + return self diff --git a/task_sdk/src/airflow/sdk/api/datamodels/__init__.py b/task_sdk/src/airflow/sdk/api/datamodels/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/task_sdk/src/airflow/sdk/api/datamodels/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py new file mode 100644 index 000000000000..f41508cae2a2 --- /dev/null +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# generated by datamodel-codegen: +# filename: http://0.0.0.0:9091/execution/openapi.json +# version: 0.26.3 + +from __future__ import annotations + +from datetime import datetime +from enum import Enum +from typing import Annotated, Any, Literal + +from pydantic import BaseModel, Field + + +class ConnectionResponse(BaseModel): + """ + Connection schema for responses with fields that are needed for Runtime. + """ + + conn_id: Annotated[str, Field(title="Conn Id")] + conn_type: Annotated[str, Field(title="Conn Type")] + host: Annotated[str | None, Field(title="Host")] = None + schema_: Annotated[str | None, Field(alias="schema", title="Schema")] = None + login: Annotated[str | None, Field(title="Login")] = None + password: Annotated[str | None, Field(title="Password")] = None + port: Annotated[int | None, Field(title="Port")] = None + extra: Annotated[str | None, Field(title="Extra")] = None + + +class TIEnterRunningPayload(BaseModel): + """ + Schema for updating TaskInstance to 'RUNNING' state with minimal required fields. + """ + + state: Annotated[Literal["running"] | None, Field(title="State")] = "running" + hostname: Annotated[str, Field(title="Hostname")] + unixname: Annotated[str, Field(title="Unixname")] + pid: Annotated[int, Field(title="Pid")] + start_date: Annotated[datetime, Field(title="Start Date")] + + +class TIHeartbeatInfo(BaseModel): + """ + Schema for TaskInstance heartbeat endpoint. + """ + + hostname: Annotated[str, Field(title="Hostname")] + pid: Annotated[int, Field(title="Pid")] + + +class State(Enum): + REMOVED = "removed" + SCHEDULED = "scheduled" + QUEUED = "queued" + RUNNING = "running" + RESTARTING = "restarting" + UP_FOR_RETRY = "up_for_retry" + UP_FOR_RESCHEDULE = "up_for_reschedule" + UPSTREAM_FAILED = "upstream_failed" + DEFERRED = "deferred" + + +class TITargetStatePayload(BaseModel): + """ + Schema for updating TaskInstance to a target state, excluding terminal and running states. + """ + + state: State + + +class State1(Enum): + FAILED = "failed" + SUCCESS = "success" + SKIPPED = "skipped" + + +class TITerminalStatePayload(BaseModel): + """ + Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED). + """ + + state: Annotated[State1, Field(title="TerminalState")] + end_date: Annotated[datetime, Field(title="End Date")] + + +class TaskInstanceState(str, Enum): + """ + All possible states that a Task Instance can be in. + + Note that None is also allowed, so always use this in a type hint with Optional. + """ + + REMOVED = "removed" + SCHEDULED = "scheduled" + QUEUED = "queued" + RUNNING = "running" + SUCCESS = "success" + RESTARTING = "restarting" + FAILED = "failed" + UP_FOR_RETRY = "up_for_retry" + UP_FOR_RESCHEDULE = "up_for_reschedule" + UPSTREAM_FAILED = "upstream_failed" + SKIPPED = "skipped" + DEFERRED = "deferred" + + +class ValidationError(BaseModel): + loc: Annotated[list[str | int], Field(title="Location")] + msg: Annotated[str, Field(title="Message")] + type: Annotated[str, Field(title="Error Type")] + + +class VariableResponse(BaseModel): + """ + Variable schema for responses with fields that are needed for Runtime. + """ + + key: Annotated[str, Field(title="Key")] + value: Annotated[str | None, Field(title="Value")] = None + + +class XComResponse(BaseModel): + """ + XCom schema for responses with fields that are needed for Runtime. 
+ """ + + key: Annotated[str, Field(title="Key")] + value: Annotated[Any, Field(title="Value")] + + +class HTTPValidationError(BaseModel): + detail: Annotated[list[ValidationError] | None, Field(title="Detail")] = None diff --git a/task_sdk/src/airflow/sdk/api/datamodels/activities.py b/task_sdk/src/airflow/sdk/api/datamodels/activities.py new file mode 100644 index 000000000000..04f2b389d5db --- /dev/null +++ b/task_sdk/src/airflow/sdk/api/datamodels/activities.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import os + +from pydantic import BaseModel + +from airflow.sdk.api.datamodels.ti import TaskInstance + + +class ExecuteTaskActivity(BaseModel): + ti: TaskInstance + path: os.PathLike[str] + token: str + """The identity token for this workload""" diff --git a/task_sdk/src/airflow/sdk/api/datamodels/ti.py b/task_sdk/src/airflow/sdk/api/datamodels/ti.py new file mode 100644 index 000000000000..ce9e1e870ae2 --- /dev/null +++ b/task_sdk/src/airflow/sdk/api/datamodels/ti.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import uuid + +from pydantic import BaseModel + + +class TaskInstance(BaseModel): + id: uuid.UUID + + task_id: str + dag_id: str + run_id: str + try_number: int + map_index: int | None = None diff --git a/task_sdk/src/airflow/sdk/execution_time/__init__.py b/task_sdk/src/airflow/sdk/execution_time/__init__.py new file mode 100644 index 000000000000..217e5db96078 --- /dev/null +++ b/task_sdk/src/airflow/sdk/execution_time/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py new file mode 100644 index 000000000000..3128e98bf437 --- /dev/null +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -0,0 +1,120 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +r""" +Communication protocol between the Supervisor and the task process +================================================================== + +* All communication is done over stdout/stdin in the form of "JSON lines" (each + message is a single JSON document terminated by `\n` character) +* Messages from the subprocess are all log messages and are sent directly to the log +* No messages are sent to task process except in response to a request. (This is because the task process will + be running user's code, so we can't read from stdin until we enter our code, such as when requesting an XCom + value etc.) + +The reason this communication protocol exists, rather than the task process speaking directly to the Task +Execution API server is because: + +1. To reduce the number of concurrent HTTP connections on the API server. + + The supervisor already has to speak to that to heartbeat the running Task, so having the task speak to its + parent process and having all API traffic go through that means that the number of HTTP connections is + "halved". (Not every task will make API calls, so it's not always halved, but it is reduced.) + +2. This means that the user Task code doesn't ever directly see the task identity JWT token. + + This is a short lived token tied to one specific task instance try, so it being leaked/exfiltrated is not a + large risk, but it's easy to not give it to the user code, so lets do that. +""" # noqa: D400, D205 + +from __future__ import annotations + +from typing import Annotated, Any, Literal, Union + +from pydantic import BaseModel, ConfigDict, Field + +from airflow.sdk.api.datamodels._generated import TaskInstanceState # noqa: TCH001 +from airflow.sdk.api.datamodels.ti import TaskInstance # noqa: TCH001 + + +class StartupDetails(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + ti: TaskInstance + file: str + requests_fd: int + """ + The channel for the task to send requests over. 
+ + Responses will come back on stdin + """ + type: Literal["StartupDetails"] = "StartupDetails" + + +class XComResponse(BaseModel): + """Response to ReadXCom request.""" + + key: str + value: Any + + type: Literal["XComResponse"] = "XComResponse" + + +class ConnectionResponse(BaseModel): + conn: Any + + type: Literal["ConnectionResponse"] = "ConnectionResponse" + + +ToTask = Annotated[ + Union[StartupDetails, XComResponse, ConnectionResponse], + Field(discriminator="type"), +] + + +class TaskState(BaseModel): + """ + Update a task's state. + + If a process exits without sending one of these the state will be derived from the exit code: + - 0 = SUCCESS + - anything else = FAILED + """ + + state: TaskInstanceState + type: Literal["TaskState"] = "TaskState" + + +class ReadXCom(BaseModel): + key: str + type: Literal["ReadXCom"] = "ReadXCom" + + +class GetConnection(BaseModel): + id: str + type: Literal["GetConnection"] = "GetConnection" + + +class GetVariable(BaseModel): + id: str + type: Literal["GetVariable"] = "GetVariable" + + +ToSupervisor = Annotated[ + Union[TaskState, ReadXCom, GetConnection, GetVariable], + Field(discriminator="type"), +] diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py new file mode 100644 index 000000000000..3c0623ba1b04 --- /dev/null +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -0,0 +1,599 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Supervise and run Tasks in a subprocess.""" + +from __future__ import annotations + +import atexit +import io +import logging +import os +import selectors +import signal +import sys +import time +import weakref +from collections.abc import Generator +from contextlib import suppress +from datetime import datetime, timezone +from socket import socket, socketpair +from typing import TYPE_CHECKING, BinaryIO, Callable, ClassVar, Literal, NoReturn, cast, overload +from uuid import UUID + +import attrs +import httpx +import msgspec +import psutil +import structlog +from pydantic import TypeAdapter + +from airflow.sdk.api.client import Client +from airflow.sdk.api.datamodels._generated import TaskInstanceState +from airflow.sdk.execution_time.comms import ( + ConnectionResponse, + GetConnection, + StartupDetails, + ToSupervisor, +) + +if TYPE_CHECKING: + from structlog.typing import FilteringBoundLogger + + from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity + from airflow.sdk.api.datamodels.ti import TaskInstance + + +__all__ = ["WatchedSubprocess", "supervise"] + +log: FilteringBoundLogger = structlog.get_logger(logger_name="supervisor") + +# TODO: Pull this from config +SLOWEST_HEARTBEAT_INTERVAL: int = 30 +# Don't heartbeat more often than this +FASTEST_HEARTBEAT_INTERVAL: int = 5 + + +@overload +def mkpipe() -> tuple[socket, socket]: ... + + +@overload +def mkpipe(remote_read: Literal[True]) -> tuple[socket, BinaryIO]: ... + + +def mkpipe( + remote_read: bool = False, +) -> tuple[socket, socket | BinaryIO]: + """ + Create a pair of connected sockets. + + The inheritable flag will be set correctly so that the end destined for the subprocess is kept open but + the end for this process is closed automatically by the OS. + """ + rsock, wsock = socketpair() + local, remote = (wsock, rsock) if remote_read else (rsock, wsock) + + remote.set_inheritable(True) + local.setblocking(False) + + io: BinaryIO | socket + if remote_read: + # If _we_ are writing, we don't want to buffer + io = cast(BinaryIO, local.makefile("wb", buffering=0)) + else: + io = local + + return remote, io + + +def _subprocess_main(): + from airflow.sdk.execution_time.task_runner import main + + main() + + +def _reset_signals(): + # Uninstall the rich etc. exception handler + sys.excepthook = sys.__excepthook__ + signal.signal(signal.SIGINT, signal.SIG_DFL) + signal.signal(signal.SIGUSR2, signal.SIG_DFL) + + +def _configure_logs_over_json_channel(log_fd: int): + # A channel that the task can send JSON-formated logs over. + # + # JSON logs sent this way will be handled nicely + from airflow.sdk.log import configure_logging + + log_io = os.fdopen(log_fd, "wb", buffering=0) + configure_logging(enable_pretty_log=False, output=log_io) + + +def _reopen_std_io_handles(child_stdin, child_stdout, child_stderr): + if "PYTEST_CURRENT_TEST" in os.environ: + # When we are running in pytest, it's output capturing messes us up. 
This works around it + sys.stdout = sys.__stdout__ + sys.stderr = sys.__stderr__ + + # Ensure that sys.stdout et al (and the underlying filehandles for C libraries etc) are connected to the + # pipes form the supervisor + + for handle_name, sock, mode, close in ( + ("stdin", child_stdin, "r", True), + ("stdout", child_stdout, "w", True), + ("stderr", child_stderr, "w", False), + ): + handle = getattr(sys, handle_name) + try: + fd = handle.fileno() + os.dup2(sock.fileno(), fd) + if close: + handle.close() + except io.UnsupportedOperation: + if "PYTEST_CURRENT_TEST" in os.environ: + # When we're running under pytest, the stdin is not a real filehandle with an fd, so we need + # to handle that differently + fd = sock.fileno() + else: + raise + + setattr(sys, handle_name, os.fdopen(fd, mode)) + + +def _fork_main( + child_stdin: socket, + child_stdout: socket, + child_stderr: socket, + log_fd: int, + target: Callable[[], None], +) -> NoReturn: + """ + "Entrypoint" of the child process. + + Ultimately this process will be running the user's code in the operators ``execute()`` function. + + The responsibility of this function is to: + + - Reset any signals handlers we inherited from the parent process (so they don't fire twice - once in + parent, and once in child) + - Set up the out/err handles to the streams created in the parent (to capture stdout and stderr for + logging) + - Configure the loggers in the child (both stdlib logging and Structlog) to send JSON logs back to the + supervisor for processing/output. + - Catch un-handled exceptions and attempt to show _something_ in case of error + - Finally, run the actual task runner code (``target`` argument, defaults to ``.task_runner:main`) + """ + # TODO: Make this process a session leader + + # Store original stderr for last-chance exception handling + last_chance_stderr = sys.__stderr__ or sys.stderr + + _reset_signals() + if log_fd: + _configure_logs_over_json_channel(log_fd) + _reopen_std_io_handles(child_stdin, child_stdout, child_stderr) + + def exit(n: int) -> NoReturn: + with suppress(ValueError, OSError): + sys.stdout.flush() + with suppress(ValueError, OSError): + sys.stderr.flush() + with suppress(ValueError, OSError): + last_chance_stderr.flush() + os._exit(n) + + if hasattr(atexit, "_clear"): + # Since we're in a fork we want to try and clear them. If we can't do it cleanly, then we won't try + # and run new atexit handlers. + with suppress(Exception): + atexit._clear() + base_exit = exit + + def exit(n: int) -> NoReturn: + # This will only run any atexit funcs registered after we've forked. + atexit._run_exitfuncs() + base_exit(n) + + try: + target() + exit(0) + except SystemExit as e: + code = 1 + if isinstance(e.code, int): + code = e.code + elif e.code: + print(e.code, file=sys.stderr) + exit(code) + except Exception: + # Last ditch log attempt + exc, v, tb = sys.exc_info() + + import traceback + + try: + last_chance_stderr.write("--- Last chance exception handler ---\n") + traceback.print_exception(exc, value=v, tb=tb, file=last_chance_stderr) + # Exit code 126 and 125 don't have any "special" meaning, they are only meant to serve as an + # identifier that the task process died in a really odd way. 
+ exit(126) + except Exception as e: + with suppress(Exception): + print( + f"--- Last chance exception handler failed --- {repr(str(e))}\n", file=last_chance_stderr + ) + exit(125) + + +@attrs.define() +class WatchedSubprocess: + ti_id: UUID + pid: int + + stdin: BinaryIO + stdout: socket + stderr: socket + + client: Client + + _process: psutil.Process + _exit_code: int | None = None + _terminal_state: str | None = None + + _last_heartbeat: float = 0 + + selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector) + + procs: ClassVar[weakref.WeakValueDictionary[int, WatchedSubprocess]] = weakref.WeakValueDictionary() + + def __attrs_post_init__(self): + self.procs[self.pid] = self + + @classmethod + def start( + cls, + path: str | os.PathLike[str], + ti: TaskInstance, + client: Client, + target: Callable[[], None] = _subprocess_main, + ) -> WatchedSubprocess: + """Fork and start a new subprocess to execute the given task.""" + # Create socketpairs/"pipes" to connect to the stdin and out from the subprocess + child_stdin, feed_stdin = mkpipe(remote_read=True) + child_stdout, read_stdout = mkpipe() + child_stderr, read_stderr = mkpipe() + + # Open these socketpair before forking off the child, so that it is open when we fork. + child_comms, read_msgs = mkpipe() + child_logs, read_logs = mkpipe() + + pid = os.fork() + if pid == 0: + # Parent ends of the sockets are closed by the OS as they are set as non-inheritable + + # Run the child entryoint + _fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target) + + proc = cls( + ti_id=ti.id, + pid=pid, + stdin=feed_stdin, + stdout=read_stdout, + stderr=read_stderr, + process=psutil.Process(pid), + client=client, + ) + + # We've forked, but the task won't start until we send it the StartupDetails message. But before we do + # that, we need to tell the server it's started (so it has the chance to tell us "no, stop!" for any + # reason) + try: + client.task_instances.start(ti.id, pid, datetime.now(tz=timezone.utc)) + proc._last_heartbeat = time.monotonic() + except Exception: + # On any error kill that subprocess! + proc.kill(signal.SIGKILL) + raise + + # TODO: Use logging providers to handle the chunked upload for us + task_logger: FilteringBoundLogger = structlog.get_logger(logger_name="task").bind() + + # proc.selector is a way of registering a handler/callback to be called when the given IO channel has + # activity to read on (https://www.man7.org/linux/man-pages/man2/select.2.html etc, but better + # alternatives are used automatically) -- this is a way of having "event-based" code, but without + # needing full async, to read and process output from each socket as it is received. + + cb = make_buffered_socket_reader(forward_to_log(task_logger.bind(chan="stdout"), level=logging.INFO)) + proc.selector.register(read_stdout, selectors.EVENT_READ, cb) + + cb = make_buffered_socket_reader(forward_to_log(task_logger.bind(chan="stderr"), level=logging.ERROR)) + proc.selector.register(read_stderr, selectors.EVENT_READ, cb) + + proc.selector.register( + read_logs, + selectors.EVENT_READ, + make_buffered_socket_reader(process_log_messages_from_subprocess(task_logger)), + ) + proc.selector.register( + read_msgs, + selectors.EVENT_READ, + make_buffered_socket_reader(proc.handle_requests(log=log)), + ) + + # Close the remaining parent-end of the sockets we've passed to the child via fork. 
We still have the + # other end of the pair open + child_stdout.close() + child_stdin.close() + child_comms.close() + child_logs.close() + + # Tell the task process what it needs to do! + + msg = StartupDetails( + ti=ti, + file=str(path), + requests_fd=child_comms.fileno(), + ) + + # Send the message to tell the process what it needs to execute + log.debug("Sending", msg=msg) + feed_stdin.write(msg.model_dump_json().encode()) + feed_stdin.write(b"\n") + + return proc + + def kill(self, signal: signal.Signals = signal.SIGINT): + if self._exit_code is not None: + return + + with suppress(ProcessLookupError): + os.kill(self.pid, signal) + + def wait(self) -> int: + if self._exit_code is not None: + return self._exit_code + + # Until we have a selector for the process, don't poll for more than 10s, just in case it exists but + # doesn't produce any output + max_poll_interval = 10 + + try: + while self._exit_code is None or len(self.selector.get_map()): + last_heartbeat_ago = time.monotonic() - self._last_heartbeat + # Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible + # so we notice the subprocess finishing as quick as we can. + max_wait_time = max( + 0, # Make sure this value is never negative, + min( + # Ensure we heartbeat _at most_ 75% through time the zombie threshold time + SLOWEST_HEARTBEAT_INTERVAL - last_heartbeat_ago * 0.75, + max_poll_interval, + ), + ) + events = self.selector.select(timeout=max_wait_time) + for key, _ in events: + socket_handler = key.data + need_more = socket_handler(key.fileobj) + + if not need_more: + self.selector.unregister(key.fileobj) + key.fileobj.close() # type: ignore[union-attr] + + if self._exit_code is None: + try: + self._exit_code = self._process.wait(timeout=0) + log.debug("Task process exited", exit_code=self._exit_code) + except psutil.TimeoutExpired: + pass + + if last_heartbeat_ago < FASTEST_HEARTBEAT_INTERVAL: + # Avoid heartbeating too frequently + continue + + try: + self.client.task_instances.heartbeat(self.ti_id) + self._last_heartbeat = time.monotonic() + except Exception: + log.warning("Couldn't heartbeat", exc_info=True) + # TODO: If we couldn't heartbeat for X times the interval, kill ourselves + pass + finally: + self.selector.close() + + self.client.task_instances.finish( + id=self.ti_id, state=self.final_state, when=datetime.now(tz=timezone.utc) + ) + return self._exit_code + + @property + def final_state(self): + """ + The final state of the TaskInstance. + + By default this will be derived from the exit code of the task + (0=success, failed otherwise) but can be changed by the subprocess + sending a TaskState message, as long as the process exits with 0 + + Not valid before the process has finished. 
+ """ + if self._exit_code == 0: + return self._terminal_state or TaskInstanceState.SUCCESS + return TaskInstanceState.FAILED + + def __rich_repr__(self): + yield "pid", self.pid + yield "exit_code", self._exit_code, None + + __rich_repr__.angular = True # type: ignore[attr-defined] + + def __repr__(self) -> str: + rep = f"" + + def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, None]: + encoder = ConnectionResponse.model_dump_json + # Use a buffer to avoid small allocations + buffer = bytearray(64) + + decoder = TypeAdapter[ToSupervisor](ToSupervisor) + + while True: + line = yield + + try: + msg = decoder.validate_json(line) + except Exception: + log.exception("Unable to decode message", line=line) + continue + + # if isinstnace(msg, TaskState): + # self._terminal_state = msg.state + # elif isinstance(msg, ReadXCom): + # resp = XComResponse(key="secret", value=True) + # encoder.encode_into(resp, buffer) + # self.stdin.write(buffer + b"\n") + if isinstance(msg, GetConnection): + conn = self.client.connections.get(msg.id) + resp = ConnectionResponse(conn=conn) + encoded_resp = encoder(resp) + buffer.extend(encoded_resp.encode()) + else: + log.error("Unhandled request", msg=msg) + continue + + buffer.extend(b"\n") + self.stdin.write(buffer) + + # Ensure the buffer doesn't grow and stay large if a large payload is used. This won't grow it + # larger than it is, but it will shrink it + if len(buffer) > 1024: + buffer = buffer[:1024] + + +# Sockets, even the `.makefile()` function don't correctly do line buffering on reading. If a chunk is read +# and it doesn't contain a new line character, `.readline()` will just return the chunk as is. +# +# This returns a callback suitable for attaching to a `selector` that reads in to a buffer, and yields lines +# to a (sync) generator +def make_buffered_socket_reader( + gen: Generator[None, bytes, None], buffer_size: int = 4096 +) -> Callable[[socket], bool]: + buffer = bytearray() # This will hold our accumulated binary data + read_buffer = bytearray(buffer_size) # Temporary buffer for each read + + # We need to start up the generator to get it to the point it's at waiting on the yield + next(gen) + + def cb(sock: socket): + nonlocal buffer, read_buffer + # Read up to `buffer_size` bytes of data from the socket + n_received = sock.recv_into(read_buffer) + + if not n_received: + # If no data is returned, the connection is closed. Return whatever is left in the buffer + if len(buffer): + gen.send(buffer) + # Tell loop to close this selector + return False + + buffer.extend(read_buffer[:n_received]) + + # We could have read multiple lines in one go, yield them all + while (newline_pos := buffer.find(b"\n")) != -1: + if TYPE_CHECKING: + # We send in a memoryvuew, but pretend it's a bytes, as Buffer is only in 3.12+ + line = buffer[: newline_pos + 1] + else: + line = memoryview(buffer)[: newline_pos + 1] # Include the newline character + gen.send(line) + buffer = buffer[newline_pos + 1 :] # Update the buffer with remaining data + + return True + + return cb + + +def process_log_messages_from_subprocess(log: FilteringBoundLogger) -> Generator[None, bytes, None]: + from structlog.stdlib import NAME_TO_LEVEL + + while True: + # Generator receive syntax, values are "sent" in by the `make_buffered_socket_reader` and returned to + # the yield. 
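+        #
+        # Illustrative sketch only: `make_buffered_socket_reader` above primes this generator
+        # with next(gen) and then pushes each complete line in via gen.send(line), roughly:
+        #
+        #   gen = process_log_messages_from_subprocess(log)
+        #   next(gen)  # advance to the first `yield`
+        #   gen.send(b'{"timestamp": "2024-11-07T12:34:56+00:00", "level": "info", "event": "hi"}\n')
+        #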
+ line = yield + + try: + event = msgspec.json.decode(line) + except Exception: + log.exception("Malformed json log line", line=line) + continue + + if ts := event.get("timestamp"): + # We use msgspec to decode the timestamp as it does it orders of magnitude quicker than + # datetime.strptime cn + # + # We remove the timezone info here, as the json encoding has `+00:00`, and since the log came + # from a subprocess we know that the timezone of the log message is the same, so having some + # messages include tz (from subprocess) but others not (ones from supervisor process) is + # confusing. + event["timestamp"] = msgspec.json.decode(f'"{ts}"', type=datetime).replace(tzinfo=None) + + if exc := event.pop("exception", None): + # TODO: convert the dict back to a pretty stack trace + event["error_detail"] = exc + log.log(NAME_TO_LEVEL[event.pop("level")], event.pop("event", None), **event) + + +def forward_to_log(target_log: FilteringBoundLogger, level: int) -> Generator[None, bytes, None]: + while True: + buf = yield + line = bytes(buf) + # Strip off new line + line = line.rstrip() + try: + msg = line.decode("utf-8", errors="replace") + target_log.log(level, msg) + except UnicodeDecodeError: + msg = line.decode("ascii", errors="replace") + target_log.log(level, msg) + + +def supervise(activity: ExecuteTaskActivity, server: str | None = None, dry_run: bool = False) -> int: + """ + Run a single task execution to completion. + + Returns the exit code of the process + """ + # One or the other + if (server == "") ^ dry_run: + raise ValueError(f"Can only specify one of {server=} or {dry_run=}") + + if not activity.path: + raise ValueError("path filed of activity missing") + + limits = httpx.Limits(max_keepalive_connections=1, max_connections=10) + client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=activity.token) + + start = time.monotonic() + + process = WatchedSubprocess.start(activity.path, activity.ti, client=client) + + exit_code = process.wait() + end = time.monotonic() + log.debug("Task finished", exit_code=exit_code, duration=end - start) + return exit_code diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py new file mode 100644 index 000000000000..382e29c59b6a --- /dev/null +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -0,0 +1,191 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
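+#
+# NOTE (illustrative): the runner speaks a line-delimited JSON protocol with the supervisor.
+# The first line it reads from stdin is a StartupDetails message, along the lines of
+#
+#   {"type": "StartupDetails", "ti": {"id": "...", "task_id": "a", "dag_id": "c",
+#    "run_id": "b", "try_number": 1}, "file": "/dev/null", "requests_fd": 7}
+#
+# (the fd value here is made up). Requests such as GetConnection are written back as one JSON
+# object per line on the file descriptor named by requests_fd, with responses coming back on
+# stdin.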
+"""The entrypoint for the actual task execution process.""" + +from __future__ import annotations + +import os +import sys +from io import FileIO +from typing import TYPE_CHECKING, TextIO + +import attrs +import structlog +from pydantic import ConfigDict, TypeAdapter + +from airflow.sdk import BaseOperator +from airflow.sdk.execution_time.comms import StartupDetails, TaskInstance, ToSupervisor, ToTask + +if TYPE_CHECKING: + from structlog.typing import FilteringBoundLogger as Logger + + +class RuntimeTaskInstance(TaskInstance): + model_config = ConfigDict(arbitrary_types_allowed=True) + + task: BaseOperator + + +def parse(what: StartupDetails) -> RuntimeTaskInstance: + # TODO: Task-SDK: + # Using DagBag here is aoubt 98% wrong, but it'll do for now + + from airflow.models.dagbag import DagBag + + bag = DagBag( + dag_folder=what.file, + include_examples=False, + safe_mode=False, + load_op_links=False, + ) + if TYPE_CHECKING: + assert what.ti.dag_id + + dag = bag.dags[what.ti.dag_id] + + # install_loader() + + # TODO: Handle task not found + task = dag.task_dict[what.ti.task_id] + if not isinstance(task, BaseOperator): + raise TypeError(f"task is of the wrong type, got {type(task)}, wanted {BaseOperator}") + return RuntimeTaskInstance(**what.ti.model_dump(exclude_unset=True), task=task) + + +@attrs.define() +class CommsDecoder: + """Handle communication between the task in this process and the supervisor parent process.""" + + input: TextIO = sys.stdin + + request_socket: FileIO = attrs.field(init=False, default=None) + + decoder: TypeAdapter[ToTask] = attrs.field(init=False, factory=lambda: TypeAdapter(ToTask)) + + def get_message(self) -> ToTask: + """ + Get a message from the parent. + + This will block until the message has been received. + """ + line = self.input.readline() + try: + msg = self.decoder.validate_json(line) + except Exception: + structlog.get_logger(logger_name="CommsDecoder").exception("Unable to decode message", line=line) + raise + + if isinstance(msg, StartupDetails): + # If we read a startup message, pull out the FDs we care about! + if msg.requests_fd > 0: + self.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) + return msg + + def send_request(self, log: Logger, msg: ToSupervisor): + encoded_msg = msg.model_dump_json().encode() + b"\n" + + log.debug("Sending request", json=encoded_msg) + self.request_socket.write(encoded_msg) + + +# This global variable will be used by Connection/Variable classes etc to send requests to +SUPERVISOR_COMMS: CommsDecoder + +# State machine! +# 1. Start up (receive details from supervisor) +# 2. Execution (run task code, possibly send requests) +# 3. 
Shutdown and report status + + +def startup() -> tuple[RuntimeTaskInstance, Logger]: + msg = SUPERVISOR_COMMS.get_message() + + if isinstance(msg, StartupDetails): + log = structlog.get_logger(logger_name="task") + # TODO: set the "magic loop" context vars for parsing + ti = parse(msg) + log.debug("DAG file parsed", file=msg.file) + return ti, log + else: + raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") + + # TODO: Render fields here + + +def run(ti: RuntimeTaskInstance, log: Logger): + """Run the task in this process.""" + from airflow.exceptions import ( + AirflowException, + AirflowFailException, + AirflowRescheduleException, + AirflowSensorTimeout, + AirflowSkipException, + AirflowTaskTerminated, + AirflowTaskTimeout, + TaskDeferred, + ) + + if TYPE_CHECKING: + assert ti.task is not None + assert isinstance(ti.task, BaseOperator) + try: + # TODO: pre execute etc. + # TODO next_method to support resuming from deferred + # TODO: Get a real context object + ti.task.execute({"task_instance": ti}) # type: ignore[attr-defined] + except TaskDeferred: + ... + except AirflowSkipException: + ... + except AirflowRescheduleException: + ... + except (AirflowFailException, AirflowSensorTimeout): + # If AirflowFailException is raised, task should not retry. + ... + except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated): + ... + except SystemExit: + ... + except BaseException: + ... + + +def finalize(log: Logger): ... + + +def main(): + # TODO: add an exception here, it causes an oof of a stack trace! + + global SUPERVISOR_COMMS + SUPERVISOR_COMMS = CommsDecoder() + try: + ti, log = startup() + run(ti, log) + finalize(log) + except KeyboardInterrupt: + log = structlog.get_logger(logger_name="task") + log.exception("Ctrl-c hit") + exit(2) + except Exception: + log = structlog.get_logger(logger_name="task") + log.exception("Top level error") + exit(1) + + +if __name__ == "__main__": + main() diff --git a/task_sdk/src/airflow/sdk/log.py b/task_sdk/src/airflow/sdk/log.py new file mode 100644 index 000000000000..f8e06eda4a65 --- /dev/null +++ b/task_sdk/src/airflow/sdk/log.py @@ -0,0 +1,372 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
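+#
+# NOTE (illustrative): configure_logging() below supports two modes. With defaults it emits
+# pretty, colored console output; with enable_pretty_log=False it emits JSON lines to a binary
+# stream, which is the form the supervisor parses from the forked task process. A minimal
+# sketch, where `log_stream` is any writable binary stream:
+#
+#   from airflow.sdk.log import configure_logging
+#   configure_logging(enable_pretty_log=False, output=log_stream)
+#
+# (output may only be passed when enable_pretty_log is False.)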
+from __future__ import annotations + +import itertools +import logging.config +import os +import sys +import warnings +from functools import cache +from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Generic, TextIO, TypeVar + +import msgspec +import structlog + +if TYPE_CHECKING: + from structlog.typing import EventDict, ExcInfo, Processor + + +__all__ = [ + "configure_logging", + "reset_logging", +] + + +def exception_group_tracebacks(format_exception: Callable[[ExcInfo], list[dict[str, Any]]]) -> Processor: + # Make mypy happy + if not hasattr(__builtins__, "BaseExceptionGroup"): + T = TypeVar("T") + + class BaseExceptionGroup(Generic[T]): + exceptions: list[T] + + def _exception_group_tracebacks(logger: Any, method_name: Any, event_dict: EventDict) -> EventDict: + if exc_info := event_dict.get("exc_info", None): + group: BaseExceptionGroup[Exception] | None = None + if exc_info is True: + # `log.exception('mesg")` case + exc_info = sys.exc_info() + if exc_info[0] is None: + exc_info = None + + if ( + isinstance(exc_info, tuple) + and len(exc_info) == 3 + and isinstance(exc_info[1], BaseExceptionGroup) + ): + group = exc_info[1] + elif isinstance(exc_info, BaseExceptionGroup): + group = exc_info + + if group: + # Only remove it from event_dict if we handle it + del event_dict["exc_info"] + event_dict["exception"] = list( + itertools.chain.from_iterable( + format_exception((type(exc), exc, exc.__traceback__)) # type: ignore[attr-defined,arg-type] + for exc in (*group.exceptions, group) + ) + ) + + return event_dict + + return _exception_group_tracebacks + + +def logger_name(logger: Any, method_name: Any, event_dict: EventDict) -> EventDict: + if logger_name := event_dict.pop("logger_name", None): + event_dict.setdefault("logger", logger_name) + return event_dict + + +def redact_jwt(logger: Any, method_name: str, event_dict: EventDict) -> EventDict: + for k, v in event_dict.items(): + if isinstance(v, str) and v.startswith("eyJ"): + event_dict[k] = "eyJ***" + return event_dict + + +def drop_positional_args(logger: Any, method_name: Any, event_dict: EventDict) -> EventDict: + event_dict.pop("positional_args", None) + return event_dict + + +class StdBinaryStreamHandler(logging.StreamHandler): + """A logging.StreamHandler that sends logs as binary JSON over the given stream.""" + + stream: BinaryIO + + def __init__(self, stream: BinaryIO): + super().__init__(stream) + + def emit(self, record: logging.LogRecord): + try: + msg = self.format(record) + buffer = bytearray(msg, "ascii", "backslashreplace") + + buffer += b"\n" + + stream = self.stream + stream.write(buffer) + self.flush() + except RecursionError: # See issue 36272 + raise + except Exception: + self.handleError(record) + + +@cache +def logging_processors( + enable_pretty_log: bool, +): + if enable_pretty_log: + timestamper = structlog.processors.MaybeTimeStamper(fmt="%Y-%m-%d %H:%M:%S.%f") + else: + timestamper = structlog.processors.MaybeTimeStamper(fmt="iso") + + processors: list[structlog.typing.Processor] = [ + timestamper, + structlog.contextvars.merge_contextvars, + structlog.processors.add_log_level, + structlog.stdlib.PositionalArgumentsFormatter(), + logger_name, + redact_jwt, + structlog.processors.StackInfoRenderer(), + ] + + # Imports to suppress showing code from these modules. We need the import to get the filepath for + # structlog to ignore. 
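+    # (The module objects collected in `suppress` below are handed to structlog's traceback
+    # formatters via their `suppress` argument, so frames originating in those modules are
+    # left out of rendered tracebacks.)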
+ import contextlib + + import click + import httpcore + import httpx + + suppress = ( + click, + contextlib, + httpx, + httpcore, + httpx, + ) + + if enable_pretty_log: + rich_exc_formatter = structlog.dev.RichTracebackFormatter( + # These values are picked somewhat arbitrarily to produce useful-but-compact tracebacks. If + # we ever need to change these then they should be configurable. + extra_lines=0, + max_frames=30, + indent_guides=False, + suppress=suppress, + ) + my_styles = structlog.dev.ConsoleRenderer.get_default_level_styles() + my_styles["debug"] = structlog.dev.CYAN + + console = structlog.dev.ConsoleRenderer( + exception_formatter=rich_exc_formatter, level_styles=my_styles + ) + processors.append(console) + return processors, { + "timestamper": timestamper, + "console": console, + } + else: + # Imports to suppress showing code from these modules + import contextlib + + import click + import httpcore + import httpx + + dict_exc_formatter = structlog.tracebacks.ExceptionDictTransformer( + use_rich=False, show_locals=False, suppress=suppress + ) + + dict_tracebacks = structlog.processors.ExceptionRenderer(dict_exc_formatter) + if hasattr(__builtins__, "BaseExceptionGroup"): + exc_group_processor = exception_group_tracebacks(dict_exc_formatter) + processors.append(exc_group_processor) + else: + exc_group_processor = None + + encoder = msgspec.json.Encoder() + + def json_dumps(msg, default): + return encoder.encode(msg) + + def json_processor(logger: Any, method_name: Any, event_dict: EventDict) -> str: + return encoder.encode(event_dict).decode("ascii") + + json = structlog.processors.JSONRenderer(serializer=json_dumps) + + processors.extend( + ( + dict_tracebacks, + structlog.processors.UnicodeDecoder(), + json, + ), + ) + + return processors, { + "timestamper": timestamper, + "exc_group_processor": exc_group_processor, + "dict_tracebacks": dict_tracebacks, + "json": json_processor, + } + + +@cache +def configure_logging( + enable_pretty_log: bool = True, + log_level: str = "DEBUG", + output: BinaryIO | None = None, + cache_logger_on_first_use: bool = True, +): + """Set up struct logging and stdlib logging config.""" + if enable_pretty_log and output is not None: + raise ValueError("output can only be set if enable_pretty_log is not") + + lvl = structlog.stdlib.NAME_TO_LEVEL[log_level.lower()] + + if enable_pretty_log: + formatter = "colored" + else: + formatter = "plain" + processors, named = logging_processors(enable_pretty_log) + timestamper = named["timestamper"] + + pre_chain: list[structlog.typing.Processor] = [ + # Add the log level and a timestamp to the event_dict if the log entry + # is not from structlog. 
+ structlog.stdlib.add_log_level, + structlog.stdlib.add_logger_name, + timestamper, + ] + + # Don't cache the loggers during tests, it make it hard to capture them + if "PYTEST_CURRENT_TEST" in os.environ: + cache_logger_on_first_use = False + + color_formatter: list[structlog.typing.Processor] = [ + structlog.stdlib.ProcessorFormatter.remove_processors_meta, + drop_positional_args, + ] + std_lib_formatter: list[structlog.typing.Processor] = [ + structlog.stdlib.ProcessorFormatter.remove_processors_meta, + drop_positional_args, + ] + + wrapper_class = structlog.make_filtering_bound_logger(lvl) + if enable_pretty_log: + structlog.configure( + processors=processors, + cache_logger_on_first_use=cache_logger_on_first_use, + wrapper_class=wrapper_class, + ) + color_formatter.append(named["console"]) + else: + structlog.configure( + processors=processors, + cache_logger_on_first_use=cache_logger_on_first_use, + wrapper_class=wrapper_class, + logger_factory=structlog.BytesLoggerFactory(output), + ) + + if processor := named["exc_group_processor"]: + pre_chain.append(processor) + pre_chain.append(named["dict_tracebacks"]) + color_formatter.append(named["json"]) + std_lib_formatter.append(named["json"]) + + global _warnings_showwarning + _warnings_showwarning = warnings.showwarning + # Capture warnings and show them via structlog + warnings.showwarning = _showwarning + + logging.config.dictConfig( + { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "plain": { + "()": structlog.stdlib.ProcessorFormatter, + "processors": std_lib_formatter, + "foreign_pre_chain": pre_chain, + "pass_foreign_args": True, + }, + "colored": { + "()": structlog.stdlib.ProcessorFormatter, + "processors": color_formatter, + "foreign_pre_chain": pre_chain, + "pass_foreign_args": True, + }, + }, + "handlers": { + "default": { + "level": log_level.upper(), + "class": "logging.StreamHandler", + "formatter": formatter, + }, + "to_supervisor": { + "level": log_level.upper(), + "()": StdBinaryStreamHandler, + "formatter": formatter, + "stream": output, + }, + }, + "loggers": { + "": { + "handlers": ["to_supervisor" if output else "default"], + "level": log_level.upper(), + "propagate": True, + }, + # Some modules we _never_ want at debug level + "asyncio": {"level": "INFO"}, + "alembic": {"level": "INFO"}, + "httpcore": {"level": "INFO"}, + "httpx": {"level": "WARN"}, + "psycopg.pq": {"level": "INFO"}, + "sqlalchemy.engine": {"level": "WARN"}, + }, + } + ) + + +def reset_logging(): + global _warnings_showwarning + warnings.showwarning = _warnings_showwarning + configure_logging.cache_clear() + + +_warnings_showwarning = None + + +def _showwarning( + message: str | Warning, + category: type[Warning], + filename: str, + lineno: int, + file: TextIO | None = None, + line: str | None = None, +): + """ + Redirects warnings to structlog so they appear in task logs etc. + + Implementation of showwarnings which redirects to logging, which will first + check to see if the file parameter is None. If a file is specified, it will + delegate to the original warnings implementation of showwarning. Otherwise, + it will call warnings.formatwarning and will log the resulting string to a + warnings logger named "py.warnings" with level logging.WARNING. 
+ """ + if file is not None: + if _warnings_showwarning is not None: + _warnings_showwarning(message, category, filename, lineno, file, line) + else: + log = structlog.get_logger(logger_name="py.warnings") + log.warning(str(message), category=category.__name__, filename=filename, lineno=lineno) diff --git a/task_sdk/src/airflow/sdk/types.py b/task_sdk/src/airflow/sdk/types.py index 232d08e27f90..ffde2170b17f 100644 --- a/task_sdk/src/airflow/sdk/types.py +++ b/task_sdk/src/airflow/sdk/types.py @@ -57,7 +57,7 @@ def deserialize(cls): Logger = logging.Logger else: - class Logger: ... # noqa: D101 + class Logger: ... def validate_instance_args(instance: DAGNode, expected_arg_types: dict[str, Any]) -> None: diff --git a/task_sdk/tests/api/__init__.py b/task_sdk/tests/api/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/task_sdk/tests/api/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py new file mode 100644 index 000000000000..a32b321545dd --- /dev/null +++ b/task_sdk/tests/api/test_client.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
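+#
+# NOTE (illustrative): these tests exercise the client entirely offline by mounting an
+# httpx.MockTransport, along the lines of
+#
+#   transport = httpx.MockTransport(lambda request: httpx.Response(422, json={"detail": "err"}))
+#   client = Client(base_url=None, dry_run=True, token="", mounts={"http://": transport})
+#
+# so no real API server is required.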
+ +from __future__ import annotations + +import httpx +import pytest + +from airflow.sdk.api.client import Client, ErrorBody, RemoteValidationError, ServerResponseError + + +class TestClient: + def test_error_parsing(self): + def handle_request(request: httpx.Request) -> httpx.Response: + """ + A transport handle that always returns errors + """ + + return httpx.Response(422, json={"detail": [{"loc": ["#0"], "msg": "err", "type": "required"}]}) + + client = Client( + base_url=None, dry_run=True, token="", mounts={"'http://": httpx.MockTransport(handle_request)} + ) + + with pytest.raises(ServerResponseError) as err: + client.get("http://error") + + assert isinstance(err.value, ServerResponseError) + assert isinstance(err.value.detail, ErrorBody) + assert err.value.detail.detail == [ + RemoteValidationError(loc=["#0"], msg="err", type="required"), + ] + + def test_error_parsing_plain_text(self): + def handle_request(request: httpx.Request) -> httpx.Response: + """ + A transport handle that always returns errors + """ + + return httpx.Response(422, content=b"Internal Server Error") + + client = Client( + base_url=None, dry_run=True, token="", mounts={"'http://": httpx.MockTransport(handle_request)} + ) + + with pytest.raises(httpx.HTTPStatusError) as err: + client.get("http://error") + assert not isinstance(err.value, ServerResponseError) diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py index ddc7c61656a0..dffd1370f4e3 100644 --- a/task_sdk/tests/conftest.py +++ b/task_sdk/tests/conftest.py @@ -17,6 +17,7 @@ from __future__ import annotations import os +from typing import TYPE_CHECKING, NoReturn import pytest @@ -25,7 +26,64 @@ # Task SDK does not need access to the Airflow database os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true" +if TYPE_CHECKING: + from structlog.typing import EventDict, WrappedLogger + + +@pytest.hookimpl() +def pytest_addhooks(pluginmanager: pytest.PytestPluginManager): + # Python 3.12 starts warning about mixing os.fork + Threads, and the pytest-rerunfailures plugin uses + # threads internally. Since this is new code, and it should be flake free, we disable the re-run failures + # plugin early (so that it doesn't run it's pytest_configure which is where the thread starts up if xdist + # is discovered). 
+ pluginmanager.set_blocked("rerunfailures") + @pytest.hookimpl(tryfirst=True) def pytest_configure(config: pytest.Config) -> None: config.inicfg["airflow_deprecations_ignore"] = [] + + +class LogCapture: + # Like structlog.typing.LogCapture, but that doesn't add log_level in to the event dict + entries: list[EventDict] + + def __init__(self) -> None: + self.entries = [] + + def __call__(self, _: WrappedLogger, method_name: str, event_dict: EventDict) -> NoReturn: + from structlog.exceptions import DropEvent + + if "level" not in event_dict: + event_dict["_log_level"] = method_name + + self.entries.append(event_dict) + + raise DropEvent + + +@pytest.fixture +def captured_logs(): + import structlog + + from airflow.sdk.log import configure_logging, reset_logging + + # Use our real log config + reset_logging() + configure_logging(enable_pretty_log=False) + + # But we need to replace remove the last processor (the one that turns JSON into text, as we want the + # event dict for tests) + cur_processors = structlog.get_config()["processors"] + processors = cur_processors.copy() + proc = processors.pop() + assert isinstance( + proc, (structlog.dev.ConsoleRenderer, structlog.processors.JSONRenderer) + ), "Pre-condition" + try: + cap = LogCapture() + processors.append(cap) + structlog.configure(processors=processors) + yield cap.entries + finally: + structlog.configure(processors=cur_processors) diff --git a/task_sdk/tests/defintions/__init__.py b/task_sdk/tests/defintions/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/task_sdk/tests/defintions/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/task_sdk/tests/defintions/test_baseoperator.py b/task_sdk/tests/defintions/test_baseoperator.py index 427d1ee0e3ef..19035319cdcb 100644 --- a/task_sdk/tests/defintions/test_baseoperator.py +++ b/task_sdk/tests/defintions/test_baseoperator.py @@ -29,6 +29,25 @@ DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc) +@pytest.fixture(autouse=True, scope="module") +def _disable_ol_plugin(): + # The OpenLineage plugin imports setproctitle, and that now causes (C) level thread calls, which on Py + # 3.12+ issues a warning when os.fork happens. So for this plugin we disable it + + # And we load plugins when setting the priorty_weight field + import airflow.plugins_manager + + old = airflow.plugins_manager.plugins + + assert old is None, "Plugins already loaded, too late to stop them being loaded!" 
+ + airflow.plugins_manager.plugins = [] + + yield + + airflow.plugins_manager.plugins = None + + # Essentially similar to airflow.models.baseoperator.BaseOperator class FakeOperator(metaclass=BaseOperatorMeta): def __init__(self, test_param, params=None, default_args=None): diff --git a/task_sdk/tests/execution_time/__init__.py b/task_sdk/tests/execution_time/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/task_sdk/tests/execution_time/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/task_sdk/tests/execution_time/conftest.py b/task_sdk/tests/execution_time/conftest.py new file mode 100644 index 000000000000..4a537373363a --- /dev/null +++ b/task_sdk/tests/execution_time/conftest.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import sys + +import pytest + + +@pytest.fixture +def disable_capturing(): + old_in, old_out, old_err = sys.stdin, sys.stdout, sys.stderr + + sys.stdin = sys.__stdin__ + sys.stdout = sys.__stdout__ + sys.stderr = sys.__stderr__ + yield + sys.stdin, sys.stdout, sys.stderr = old_in, old_out, old_err diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py new file mode 100644 index 000000000000..f1bf287cd222 --- /dev/null +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import inspect +import logging +import os +import signal +import sys +from unittest.mock import MagicMock + +import pytest +import structlog +import structlog.testing + +from airflow.sdk.api import client as sdk_client +from airflow.sdk.api.datamodels.ti import TaskInstance +from airflow.sdk.execution_time.supervisor import WatchedSubprocess +from airflow.utils import timezone as tz + + +def lineno(): + """Returns the current line number in our program.""" + return inspect.currentframe().f_back.f_lineno + + +@pytest.mark.usefixtures("disable_capturing") +class TestWatchedSubprocess: + def test_reading_from_pipes(self, captured_logs, time_machine): + # Ignore anything lower than INFO for this test. Captured_logs resets things for us afterwards + structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO)) + + line = lineno() + + def subprocess_main(): + # This is run in the subprocess! + + # Flush calls are to ensure ordering of output for predictable tests + import logging + import warnings + + print("I'm a short message") + sys.stdout.write("Message ") + sys.stdout.write("split across two writes\n") + sys.stdout.flush() + + print("stderr message", file=sys.stderr) + sys.stderr.flush() + + logging.getLogger("airflow.foobar").error("An error message") + + warnings.warn("Warning should be captured too", stacklevel=1) + + instant = tz.datetime(2024, 11, 7, 12, 34, 56, 78901) + time_machine.move_to(instant, tick=False) + + proc = WatchedSubprocess.start( + path=os.devnull, + ti=TaskInstance( + id="4d828a62-a417-4936-a7a6-2b3fabacecab", + task_id="b", + dag_id="c", + run_id="d", + try_number=1, + ), + client=MagicMock(spec=sdk_client.Client), + target=subprocess_main, + ) + + rc = proc.wait() + + assert rc == 0 + assert captured_logs == [ + { + "chan": "stdout", + "event": "I'm a short message", + "level": "info", + "logger": "task", + "timestamp": "2024-11-07T12:34:56.078901Z", + }, + { + "chan": "stdout", + "event": "Message split across two writes", + "level": "info", + "logger": "task", + "timestamp": "2024-11-07T12:34:56.078901Z", + }, + { + "chan": "stderr", + "event": "stderr message", + "level": "error", + "logger": "task", + "timestamp": "2024-11-07T12:34:56.078901Z", + }, + { + "event": "An error message", + "level": "error", + "logger": "airflow.foobar", + "timestamp": instant.replace(tzinfo=None), + }, + { + "category": "UserWarning", + "event": "Warning should be captured too", + "filename": __file__, + "level": "warning", + "lineno": line + 19, + "logger": "py.warnings", + "timestamp": instant.replace(tzinfo=None), + }, + ] + + def test_subprocess_sigkilled(self): + main_pid = os.getpid() + + def subprocess_main(): + # This is run in the subprocess! 
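+            # Sanity check that we really are in the forked child, then send ourselves an
+            # uncatchable SIGKILL; the parent's wait() below should observe a -9 return code.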
+ assert os.getpid() != main_pid + os.kill(os.getpid(), signal.SIGKILL) + + proc = WatchedSubprocess.start( + path=os.devnull, + ti=TaskInstance( + id="4d828a62-a417-4936-a7a6-2b3fabacecab", + task_id="b", + dag_id="c", + run_id="d", + try_number=1, + ), + client=MagicMock(spec=sdk_client.Client), + target=subprocess_main, + ) + + rc = proc.wait() + + assert rc == -9 diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py new file mode 100644 index 000000000000..5a90701cb2c0 --- /dev/null +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import uuid +from socket import socketpair + +import pytest + +from airflow.sdk.execution_time.comms import StartupDetails +from airflow.sdk.execution_time.task_runner import CommsDecoder + + +class TestCommsDecoder: + """Test the communication between the subprocess and the "supervisor".""" + + @pytest.mark.usefixtures("disable_capturing") + def test_recv_StartupDetails(self): + r, w = socketpair() + # Create a valid FD for the decoder to open + _, w2 = socketpair() + + w.makefile("wb").write( + b'{"type":"StartupDetails", "ti": {' + b'"id": "4d828a62-a417-4936-a7a6-2b3fabacecab", "task_id": "a", "try_number": 1, "run_id": "b", "dag_id": "c" }, ' + b'"file": "/dev/null", "requests_fd": ' + str(w2.fileno()).encode("ascii") + b"}\n" + ) + + decoder = CommsDecoder(input=r.makefile("r")) + + msg = decoder.get_message() + assert isinstance(msg, StartupDetails) + assert msg.ti.id == uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab") + assert msg.ti.task_id == "a" + assert msg.ti.dag_id == "c" + assert msg.file == "/dev/null" + + # Since this was a StartupDetails message, the decoder should open the other socket + assert decoder.request_socket is not None + assert decoder.request_socket.writable() + assert decoder.request_socket.fileno() == w2.fileno() diff --git a/tests/cli/commands/test_celery_command.py b/tests/cli/commands/test_celery_command.py index ae2b71a59094..c417b6eb61cf 100644 --- a/tests/cli/commands/test_celery_command.py +++ b/tests/cli/commands/test_celery_command.py @@ -276,6 +276,7 @@ def test_run_command(self, mock_celery_app): @mock.patch("airflow.cli.commands.daemon_utils.setup_locations") @mock.patch("airflow.cli.commands.daemon_utils.daemon") @mock.patch("airflow.providers.celery.executors.celery_executor.app") + @pytest.mark.usefixtures("capfd") # This test needs fd capturing to work def test_run_command_daemon(self, mock_celery_app, mock_daemon, mock_setup_locations, mock_pid_file): mock_setup_locations.return_value = ( mock.MagicMock(name="pidfile"), From 339bc7748156a4d099121092addb5af1da4e81e8 Mon Sep 17 00:00:00 2001 From: Kaxil Naik 
Date: Thu, 14 Nov 2024 14:33:00 +0000 Subject: [PATCH 05/33] Standardize timer metrics to milliseconds and remove config (#43975) - Removes the `timer_unit_consistency` configuration option, standardizing all timer and timing metrics to milliseconds by default. - Updates metric loggers (e.g., Datadog, OpenTelemetry) to ensure uniform milliseconds-based reporting. - Cleans up related warnings! This is a follow-up of https://github.com/apache/airflow/pull/43966 for main & Airflow 3 --- airflow/config_templates/config.yml | 18 ------ airflow/metrics/datadog_logger.py | 15 +---- airflow/metrics/otel_logger.py | 14 +---- airflow/metrics/protocols.py | 16 +----- airflow/models/taskinstance.py | 20 +------ newsfragments/39908.significant.rst | 11 ---- newsfragments/43975.significant.rst | 8 +++ tests/core/test_otel_logger.py | 56 ++++--------------- tests/core/test_stats.py | 25 ++------- tests_common/_internals/forbidden_warnings.py | 5 -- 10 files changed, 27 insertions(+), 161 deletions(-) delete mode 100644 newsfragments/39908.significant.rst create mode 100644 newsfragments/43975.significant.rst diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index c77f9476b0d2..eba9f7b8c70e 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -1064,24 +1064,6 @@ metrics: example: "\"scheduler,executor,dagrun,pool,triggerer,celery\" or \"^scheduler,^executor,heartbeat|timeout\"" default: "" - # TODO: Remove 'timer_unit_consistency' in Airflow 3.0 - timer_unit_consistency: - description: | - Controls the consistency of timer units across all metrics loggers - (e.g., Statsd, Datadog, OpenTelemetry) - for timing and duration-based metrics. When enabled, all timers will publish - metrics in milliseconds for consistency and alignment with Airflow's default - metrics behavior in version 3.0+. - - .. warning:: - - It will be the default behavior from Airflow 3.0. If disabled, timers may publish - in seconds for backwards compatibility, though it is recommended to enable this - setting to ensure metric uniformity and forward-compat with Airflow 3. - version_added: 2.11.0 - type: string - example: ~ - default: "False" statsd_on: description: | Enables sending metrics to StatsD. diff --git a/airflow/metrics/datadog_logger.py b/airflow/metrics/datadog_logger.py index 81926716eb25..a166c6fcb169 100644 --- a/airflow/metrics/datadog_logger.py +++ b/airflow/metrics/datadog_logger.py @@ -19,11 +19,9 @@ import datetime import logging -import warnings from typing import TYPE_CHECKING from airflow.configuration import conf -from airflow.exceptions import RemovedInAirflow3Warning from airflow.metrics.protocols import Timer from airflow.metrics.validators import ( PatternAllowListValidator, @@ -42,14 +40,6 @@ log = logging.getLogger(__name__) -timer_unit_consistency = conf.getboolean("metrics", "timer_unit_consistency") -if not timer_unit_consistency: - warnings.warn( - "Timer and timing metrics publish in seconds were deprecated. It is enabled by default from Airflow 3 onwards. 
Enable timer_unit_consistency to publish all the timer and timing metrics in milliseconds.", - RemovedInAirflow3Warning, - stacklevel=2, - ) - class SafeDogStatsdLogger: """DogStatsd Logger.""" @@ -144,10 +134,7 @@ def timing( tags_list = [] if self.metrics_validator.test(stat): if isinstance(dt, datetime.timedelta): - if timer_unit_consistency: - dt = dt.total_seconds() * 1000.0 - else: - dt = dt.total_seconds() + dt = dt.total_seconds() * 1000.0 return self.dogstatsd.timing(metric=stat, value=dt, tags=tags_list) return None diff --git a/airflow/metrics/otel_logger.py b/airflow/metrics/otel_logger.py index ed123608626f..c3633212cd27 100644 --- a/airflow/metrics/otel_logger.py +++ b/airflow/metrics/otel_logger.py @@ -31,7 +31,6 @@ from opentelemetry.sdk.resources import HOST_NAME, SERVICE_NAME, Resource from airflow.configuration import conf -from airflow.exceptions import RemovedInAirflow3Warning from airflow.metrics.protocols import Timer from airflow.metrics.validators import ( OTEL_NAME_MAX_LENGTH, @@ -73,14 +72,6 @@ # Delimiter is placed between the universal metric prefix and the unique metric name. DEFAULT_METRIC_NAME_DELIMITER = "." -timer_unit_consistency = conf.getboolean("metrics", "timer_unit_consistency") -if not timer_unit_consistency: - warnings.warn( - "Timer and timing metrics publish in seconds were deprecated. It is enabled by default from Airflow 3 onwards. Enable timer_unit_consistency to publish all the timer and timing metrics in milliseconds.", - RemovedInAirflow3Warning, - stacklevel=2, - ) - def full_name(name: str, *, prefix: str = DEFAULT_METRIC_NAME_PREFIX) -> str: """Assembles the prefix, delimiter, and name and returns it as a string.""" @@ -284,10 +275,7 @@ def timing( """OTel does not have a native timer, stored as a Gauge whose value is number of seconds elapsed.""" if self.metrics_validator.test(stat) and name_is_otel_safe(self.prefix, stat): if isinstance(dt, datetime.timedelta): - if timer_unit_consistency: - dt = dt.total_seconds() * 1000.0 - else: - dt = dt.total_seconds() + dt = dt.total_seconds() * 1000.0 self.metrics_map.set_gauge_value(full_name(prefix=self.prefix, name=stat), float(dt), False, tags) def timer( diff --git a/airflow/metrics/protocols.py b/airflow/metrics/protocols.py index 0d12704e87a3..8cfe4d8e7ea3 100644 --- a/airflow/metrics/protocols.py +++ b/airflow/metrics/protocols.py @@ -19,23 +19,12 @@ import datetime import time -import warnings from typing import Union -from airflow.configuration import conf -from airflow.exceptions import RemovedInAirflow3Warning from airflow.typing_compat import Protocol DeltaType = Union[int, float, datetime.timedelta] -timer_unit_consistency = conf.getboolean("metrics", "timer_unit_consistency") -if not timer_unit_consistency: - warnings.warn( - "Timer and timing metrics publish in seconds were deprecated. It is enabled by default from Airflow 3 onwards. Enable timer_unit_consistency to publish all the timer and timing metrics in milliseconds.", - RemovedInAirflow3Warning, - stacklevel=2, - ) - class TimerProtocol(Protocol): """Type protocol for StatsLogger.timer.""" @@ -127,9 +116,6 @@ def start(self) -> Timer: def stop(self, send: bool = True) -> None: """Stop the timer, and optionally send it to stats backend.""" if self._start_time is not None: - if timer_unit_consistency: - self.duration = 1000.0 * (time.perf_counter() - self._start_time) # Convert to milliseconds. 
- else: - self.duration = time.perf_counter() - self._start_time + self.duration = 1000.0 * (time.perf_counter() - self._start_time) # Convert to milliseconds. if send and self.real_timer: self.real_timer.stop() diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 410cdd8773d3..30ef941ceea5 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -26,7 +26,6 @@ import operator import os import signal -import warnings from collections import defaultdict from contextlib import nullcontext from datetime import timedelta @@ -85,7 +84,6 @@ AirflowSkipException, AirflowTaskTerminated, AirflowTaskTimeout, - RemovedInAirflow3Warning, TaskDeferralError, TaskDeferred, UnmappableXComLengthPushed, @@ -176,14 +174,6 @@ PAST_DEPENDS_MET = "past_depends_met" -timer_unit_consistency = conf.getboolean("metrics", "timer_unit_consistency") -if not timer_unit_consistency: - warnings.warn( - "Timer and timing metrics publish in seconds were deprecated. It is enabled by default from Airflow 3 onwards. Enable timer_unit_consistency to publish all the timer and timing metrics in milliseconds.", - RemovedInAirflow3Warning, - stacklevel=2, - ) - class TaskReturnCode(Enum): """ @@ -2831,10 +2821,7 @@ def emit_state_change_metric(self, new_state: TaskInstanceState) -> None: self.task_id, ) return - if timer_unit_consistency: - timing = timezone.utcnow() - self.queued_dttm - else: - timing = (timezone.utcnow() - self.queued_dttm).total_seconds() + timing = timezone.utcnow() - self.queued_dttm elif new_state == TaskInstanceState.QUEUED: metric_name = "scheduled_duration" if self.start_date is None: @@ -2847,10 +2834,7 @@ def emit_state_change_metric(self, new_state: TaskInstanceState) -> None: self.task_id, ) return - if timer_unit_consistency: - timing = timezone.utcnow() - self.start_date - else: - timing = (timezone.utcnow() - self.start_date).total_seconds() + timing = timezone.utcnow() - self.start_date else: raise NotImplementedError("no metric emission setup for state %s", new_state) diff --git a/newsfragments/39908.significant.rst b/newsfragments/39908.significant.rst deleted file mode 100644 index d5ba99fa9fa5..000000000000 --- a/newsfragments/39908.significant.rst +++ /dev/null @@ -1,11 +0,0 @@ -Publishing timer and timing metrics in seconds is now deprecated. - -In Airflow 3.0, the ``timer_unit_consistency`` setting in the ``metrics`` section will be -enabled by default and setting itself will be removed. This will standardize all timer and timing metrics to -milliseconds across all metric loggers. - -**Users Integrating with Datadog, OpenTelemetry, or other metric backends** should enable this setting. For users, using -``statsd``, this change will not affect you. - -If you need backward compatibility, you can leave this setting disabled temporarily, but enabling -``timer_unit_consistency`` is encouraged to future-proof your metrics setup. diff --git a/newsfragments/43975.significant.rst b/newsfragments/43975.significant.rst new file mode 100644 index 000000000000..6d116ac1eedf --- /dev/null +++ b/newsfragments/43975.significant.rst @@ -0,0 +1,8 @@ +Timer and timing metrics are now standardized to milliseconds + +In Airflow 3.0, the ``timer_unit_consistency`` setting in the ``metrics`` section is removed as it is now the default behaviour. +This is done to standardize all timer and timing metrics to milliseconds across all metric loggers. 
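+
+For example, with the Datadog or OpenTelemetry metrics backends,
+``Stats.timing("my_metric", datetime.timedelta(seconds=1))`` now emits ``1000.0`` (milliseconds)
+where it previously emitted ``1.0`` (seconds), and the ``duration`` attribute of ``Stats.timer()``
+is likewise measured in milliseconds.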
+ +Airflow 2.11 introduced the ``timer_unit_consistency`` setting in the ``metrics`` section of the configuration file. The +default value was ``False`` which meant that the timer and timing metrics were logged in seconds. This was done to maintain +backwards compatibility with the previous versions of Airflow. diff --git a/tests/core/test_otel_logger.py b/tests/core/test_otel_logger.py index a4bf7c4c4156..c88dcdbab01c 100644 --- a/tests/core/test_otel_logger.py +++ b/tests/core/test_otel_logger.py @@ -25,7 +25,6 @@ from opentelemetry.metrics import MeterProvider from airflow.exceptions import InvalidStatsNameException -from airflow.metrics import otel_logger, protocols from airflow.metrics.otel_logger import ( OTEL_NAME_MAX_LENGTH, UP_DOWN_COUNTERS, @@ -235,21 +234,15 @@ def test_gauge_value_is_correct(self, name): assert self.map[full_name(name)].value == 1 - @pytest.mark.parametrize( - "timer_unit_consistency", - [True, False], - ) - def test_timing_new_metric(self, timer_unit_consistency, name): + def test_timing_new_metric(self, name): import datetime - otel_logger.timer_unit_consistency = timer_unit_consistency - self.stats.timing(name, dt=datetime.timedelta(seconds=123)) self.meter.get_meter().create_observable_gauge.assert_called_once_with( name=full_name(name), callbacks=ANY ) - expected_value = 123000.0 if timer_unit_consistency else 123 + expected_value = 123000.0 assert self.map[full_name(name)].value == expected_value def test_timing_new_metric_with_tags(self, name): @@ -276,81 +269,52 @@ def test_timing_existing_metric(self, name): # time.perf_count() is called once to get the starting timestamp and again # to get the end timestamp. timer() should return the difference as a float. - @pytest.mark.parametrize( - "timer_unit_consistency", - [True, False], - ) @mock.patch.object(time, "perf_counter", side_effect=[0.0, 3.14]) - def test_timer_with_name_returns_float_and_stores_value(self, mock_time, timer_unit_consistency, name): - protocols.timer_unit_consistency = timer_unit_consistency + def test_timer_with_name_returns_float_and_stores_value(self, mock_time, name): with self.stats.timer(name) as timer: pass assert isinstance(timer.duration, float) - expected_duration = 3140.0 if timer_unit_consistency else 3.14 + expected_duration = 3140.0 assert timer.duration == expected_duration assert mock_time.call_count == 2 self.meter.get_meter().create_observable_gauge.assert_called_once_with( name=full_name(name), callbacks=ANY ) - @pytest.mark.parametrize( - "timer_unit_consistency", - [True, False], - ) @mock.patch.object(time, "perf_counter", side_effect=[0.0, 3.14]) - def test_timer_no_name_returns_float_but_does_not_store_value( - self, mock_time, timer_unit_consistency, name - ): - protocols.timer_unit_consistency = timer_unit_consistency + def test_timer_no_name_returns_float_but_does_not_store_value(self, mock_time, name): with self.stats.timer() as timer: pass assert isinstance(timer.duration, float) - expected_duration = 3140.0 if timer_unit_consistency else 3.14 + expected_duration = 3140.0 assert timer.duration == expected_duration assert mock_time.call_count == 2 self.meter.get_meter().create_observable_gauge.assert_not_called() - @pytest.mark.parametrize( - "timer_unit_consistency", - [ - True, - False, - ], - ) @mock.patch.object(time, "perf_counter", side_effect=[0.0, 3.14]) - def test_timer_start_and_stop_manually_send_false(self, mock_time, timer_unit_consistency, name): - protocols.timer_unit_consistency = timer_unit_consistency - + def 
test_timer_start_and_stop_manually_send_false(self, mock_time, name): timer = self.stats.timer(name) timer.start() # Perform some task timer.stop(send=False) assert isinstance(timer.duration, float) - expected_value = 3140.0 if timer_unit_consistency else 3.14 + expected_value = 3140.0 assert timer.duration == expected_value assert mock_time.call_count == 2 self.meter.get_meter().create_observable_gauge.assert_not_called() - @pytest.mark.parametrize( - "timer_unit_consistency", - [ - True, - False, - ], - ) @mock.patch.object(time, "perf_counter", side_effect=[0.0, 3.14]) - def test_timer_start_and_stop_manually_send_true(self, mock_time, timer_unit_consistency, name): - protocols.timer_unit_consistency = timer_unit_consistency + def test_timer_start_and_stop_manually_send_true(self, mock_time, name): timer = self.stats.timer(name) timer.start() # Perform some task timer.stop(send=True) assert isinstance(timer.duration, float) - expected_value = 3140.0 if timer_unit_consistency else 3.14 + expected_value = 3140.0 assert timer.duration == expected_value assert mock_time.call_count == 2 self.meter.get_meter().create_observable_gauge.assert_called_once_with( diff --git a/tests/core/test_stats.py b/tests/core/test_stats.py index 1b30bc9990af..92b4b63f7b6a 100644 --- a/tests/core/test_stats.py +++ b/tests/core/test_stats.py @@ -29,7 +29,6 @@ import airflow from airflow.exceptions import AirflowConfigException, InvalidStatsNameException -from airflow.metrics import datadog_logger, protocols from airflow.metrics.datadog_logger import SafeDogStatsdLogger from airflow.metrics.statsd_logger import SafeStatsdLogger from airflow.metrics.validators import ( @@ -221,20 +220,12 @@ def test_does_send_stats_using_dogstatsd_when_statsd_and_dogstatsd_both_on(self) metric="empty_key", sample_rate=1, tags=[], value=1 ) - @pytest.mark.parametrize( - "timer_unit_consistency", - [True, False], - ) @mock.patch.object(time, "perf_counter", side_effect=[0.0, 100.0]) - def test_timer(self, time_mock, timer_unit_consistency): - protocols.timer_unit_consistency = timer_unit_consistency - + def test_timer(self, time_mock): with self.dogstatsd.timer("empty_timer") as timer: pass self.dogstatsd_client.timed.assert_called_once_with("empty_timer", tags=[]) - expected_duration = 100.0 - if timer_unit_consistency: - expected_duration = 1000.0 * 100.0 + expected_duration = 1000.0 * 100.0 assert expected_duration == timer.duration assert time_mock.call_count == 2 @@ -243,22 +234,14 @@ def test_empty_timer(self): pass self.dogstatsd_client.timed.assert_not_called() - @pytest.mark.parametrize( - "timer_unit_consistency", - [True, False], - ) - def test_timing(self, timer_unit_consistency): + def test_timing(self): import datetime - datadog_logger.timer_unit_consistency = timer_unit_consistency - self.dogstatsd.timing("empty_timer", 123) self.dogstatsd_client.timing.assert_called_once_with(metric="empty_timer", value=123, tags=[]) self.dogstatsd.timing("empty_timer", datetime.timedelta(seconds=123)) - self.dogstatsd_client.timing.assert_called_with( - metric="empty_timer", value=123000.0 if timer_unit_consistency else 123.0, tags=[] - ) + self.dogstatsd_client.timing.assert_called_with(metric="empty_timer", value=123000.0, tags=[]) def test_gauge(self): self.dogstatsd.gauge("empty", 123) diff --git a/tests_common/_internals/forbidden_warnings.py b/tests_common/_internals/forbidden_warnings.py index 856960935bd4..1217927f1101 100644 --- a/tests_common/_internals/forbidden_warnings.py +++ 
b/tests_common/_internals/forbidden_warnings.py @@ -73,11 +73,6 @@ def pytest_itemcollected(self, item: pytest.Item): # Add marker at the beginning of the markers list. In this case, it does not conflict with # filterwarnings markers, which are set explicitly in the test suite. item.add_marker(pytest.mark.filterwarnings(f"error::{fw}"), append=False) - item.add_marker( - pytest.mark.filterwarnings( - "ignore:Timer and timing metrics publish in seconds were deprecated. It is enabled by default from Airflow 3 onwards. Enable timer_unit_consistency to publish all the timer and timing metrics in milliseconds.:DeprecationWarning" - ) - ) @pytest.hookimpl(hookwrapper=True, trylast=True) def pytest_sessionfinish(self, session: pytest.Session, exitstatus: int): From 77281399de2c0fed224e16047eaadd51c657b289 Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Thu, 14 Nov 2024 15:38:37 +0100 Subject: [PATCH 06/33] Fix side effect of bad grpc.Chanel mocking (#44029) The grpc.Channel has been patched but not relased in the test_grpc.py and it could have caused other tests failing - when they were run later in the same interpreter. For example it failed in in #44011 in the https://github.com/apache/airflow/pull/44011#issuecomment-2476439247 Patching is now fixed via using fixtures. --- providers/tests/grpc/hooks/test_grpc.py | 74 +++++++++++++++---------- 1 file changed, 44 insertions(+), 30 deletions(-) diff --git a/providers/tests/grpc/hooks/test_grpc.py b/providers/tests/grpc/hooks/test_grpc.py index 8536188bb135..ed185b4bf282 100644 --- a/providers/tests/grpc/hooks/test_grpc.py +++ b/providers/tests/grpc/hooks/test_grpc.py @@ -59,21 +59,21 @@ def stream_call(self, data): return ["streaming", "call"] -class TestGrpcHook: - def setup_method(self): - self.channel_mock = mock.patch("grpc.Channel").start() +@pytest.fixture +def channel_mock(): + """We mock run_command to capture its call args; it returns nothing so mock training is unnecessary.""" + with patch("grpc.Channel") as grpc_channel: + yield grpc_channel - def custom_conn_func(self, _): - mocked_channel = self.channel_mock.return_value - return mocked_channel +class TestGrpcHook: @mock.patch("grpc.insecure_channel") @mock.patch("airflow.hooks.base.BaseHook.get_connection") - def test_no_auth_connection(self, mock_get_connection, mock_insecure_channel): + def test_no_auth_connection(self, mock_get_connection, mock_insecure_channel, channel_mock): conn = get_airflow_connection() mock_get_connection.return_value = conn hook = GrpcHook("grpc_default") - mocked_channel = self.channel_mock.return_value + mocked_channel = channel_mock.return_value mock_insecure_channel.return_value = mocked_channel channel = hook.get_conn() @@ -84,11 +84,11 @@ def test_no_auth_connection(self, mock_get_connection, mock_insecure_channel): @mock.patch("grpc.insecure_channel") @mock.patch("airflow.hooks.base.BaseHook.get_connection") - def test_connection_with_port(self, mock_get_connection, mock_insecure_channel): + def test_connection_with_port(self, mock_get_connection, mock_insecure_channel, channel_mock): conn = get_airflow_connection_with_port() mock_get_connection.return_value = conn hook = GrpcHook("grpc_default") - mocked_channel = self.channel_mock.return_value + mocked_channel = channel_mock.return_value mock_insecure_channel.return_value = mocked_channel channel = hook.get_conn() @@ -102,13 +102,13 @@ def test_connection_with_port(self, mock_get_connection, mock_insecure_channel): @mock.patch("grpc.ssl_channel_credentials") @mock.patch("grpc.secure_channel") def 
-        self, mock_secure_channel, mock_channel_credentials, mock_get_connection, mock_open
+        self, mock_secure_channel, mock_channel_credentials, mock_get_connection, mock_open, channel_mock
     ):
         conn = get_airflow_connection(auth_type="SSL", credential_pem_file="pem")
         mock_get_connection.return_value = conn
         mock_open.return_value = StringIO("credential")
         hook = GrpcHook("grpc_default")
-        mocked_channel = self.channel_mock.return_value
+        mocked_channel = channel_mock.return_value
         mock_secure_channel.return_value = mocked_channel
         mock_credential_object = "test_credential_object"
         mock_channel_credentials.return_value = mock_credential_object
@@ -126,13 +126,13 @@ def test_connection_with_ssl(
     @mock.patch("grpc.ssl_channel_credentials")
     @mock.patch("grpc.secure_channel")
     def test_connection_with_tls(
-        self, mock_secure_channel, mock_channel_credentials, mock_get_connection, mock_open
+        self, mock_secure_channel, mock_channel_credentials, mock_get_connection, mock_open, channel_mock
     ):
         conn = get_airflow_connection(auth_type="TLS", credential_pem_file="pem")
         mock_get_connection.return_value = conn
         mock_open.return_value = StringIO("credential")
         hook = GrpcHook("grpc_default")
-        mocked_channel = self.channel_mock.return_value
+        mocked_channel = channel_mock.return_value
         mock_secure_channel.return_value = mocked_channel
         mock_credential_object = "test_credential_object"
         mock_channel_credentials.return_value = mock_credential_object
@@ -150,12 +150,17 @@ def test_connection_with_tls(
     @mock.patch("google.auth.default")
     @mock.patch("google.auth.transport.grpc.secure_authorized_channel")
     def test_connection_with_jwt(
-        self, mock_secure_channel, mock_google_default_auth, mock_google_cred, mock_get_connection
+        self,
+        mock_secure_channel,
+        mock_google_default_auth,
+        mock_google_cred,
+        mock_get_connection,
+        channel_mock,
     ):
         conn = get_airflow_connection(auth_type="JWT_GOOGLE")
         mock_get_connection.return_value = conn
         hook = GrpcHook("grpc_default")
-        mocked_channel = self.channel_mock.return_value
+        mocked_channel = channel_mock.return_value
         mock_secure_channel.return_value = mocked_channel
         mock_credential_object = "test_credential_object"
         mock_google_default_auth.return_value = (mock_credential_object, "")
@@ -173,12 +178,17 @@ def test_connection_with_jwt(
     @mock.patch("google.auth.default")
     @mock.patch("google.auth.transport.grpc.secure_authorized_channel")
     def test_connection_with_google_oauth(
-        self, mock_secure_channel, mock_google_default_auth, mock_google_auth_request, mock_get_connection
+        self,
+        mock_secure_channel,
+        mock_google_default_auth,
+        mock_google_auth_request,
+        mock_get_connection,
+        channel_mock,
     ):
         conn = get_airflow_connection(auth_type="OATH_GOOGLE", scopes="grpc,gcs")
         mock_get_connection.return_value = conn
         hook = GrpcHook("grpc_default")
-        mocked_channel = self.channel_mock.return_value
+        mocked_channel = channel_mock.return_value
         mock_secure_channel.return_value = mocked_channel
         mock_credential_object = "test_credential_object"
         mock_google_default_auth.return_value = (mock_credential_object, "")
@@ -192,18 +202,22 @@ def test_connection_with_google_oauth(
         assert channel == mocked_channel

     @mock.patch("airflow.hooks.base.BaseHook.get_connection")
-    def test_custom_connection(self, mock_get_connection):
+    def test_custom_connection(self, mock_get_connection, channel_mock):
+        def custom_conn_func(_):
+            mocked_channel = channel_mock.return_value
+            return mocked_channel
+
         conn = get_airflow_connection("CUSTOM")
         mock_get_connection.return_value = conn
-        mocked_channel = self.channel_mock.return_value
-        hook = GrpcHook("grpc_default", custom_connection_func=self.custom_conn_func)
+        mocked_channel = channel_mock.return_value
+        hook = GrpcHook("grpc_default", custom_connection_func=custom_conn_func)
         channel = hook.get_conn()

         assert channel == mocked_channel

     @mock.patch("airflow.hooks.base.BaseHook.get_connection")
-    def test_custom_connection_with_no_connection_func(self, mock_get_connection):
+    def test_custom_connection_with_no_connection_func(self, mock_get_connection, channel_mock):
         conn = get_airflow_connection("CUSTOM")
         mock_get_connection.return_value = conn
         hook = GrpcHook("grpc_default")
@@ -212,7 +226,7 @@ def test_custom_connection_with_no_connection_func(self, mock_get_connection):
             hook.get_conn()

     @mock.patch("airflow.hooks.base.BaseHook.get_connection")
-    def test_connection_type_not_supported(self, mock_get_connection):
+    def test_connection_type_not_supported(self, mock_get_connection, channel_mock):
         conn = get_airflow_connection("NOT_SUPPORT")
         mock_get_connection.return_value = conn
         hook = GrpcHook("grpc_default")
@@ -224,11 +238,11 @@ def test_connection_type_not_supported(self, mock_get_connection):
     @mock.patch("airflow.hooks.base.BaseHook.get_connection")
     @mock.patch("grpc.insecure_channel")
     def test_connection_with_interceptors(
-        self, mock_insecure_channel, mock_get_connection, mock_intercept_channel
+        self, mock_insecure_channel, mock_get_connection, mock_intercept_channel, channel_mock
     ):
         conn = get_airflow_connection()
         mock_get_connection.return_value = conn
-        mocked_channel = self.channel_mock.return_value
+        mocked_channel = channel_mock.return_value
         hook = GrpcHook("grpc_default", interceptors=["test1"])
         mock_insecure_channel.return_value = mocked_channel
         mock_intercept_channel.return_value = mocked_channel
@@ -240,7 +254,7 @@ def test_connection_with_interceptors(

     @mock.patch("airflow.hooks.base.BaseHook.get_connection")
     @mock.patch("airflow.providers.grpc.hooks.grpc.GrpcHook.get_conn")
-    def test_simple_run(self, mock_get_conn, mock_get_connection):
+    def test_simple_run(self, mock_get_conn, mock_get_connection, channel_mock):
         conn = get_airflow_connection()
         mock_get_connection.return_value = conn
         mocked_channel = mock.Mock()
@@ -255,7 +269,7 @@ def test_simple_run(self, mock_get_conn, mock_get_connection):

     @mock.patch("airflow.hooks.base.BaseHook.get_connection")
     @mock.patch("airflow.providers.grpc.hooks.grpc.GrpcHook.get_conn")
-    def test_stream_run(self, mock_get_conn, mock_get_connection):
+    def test_stream_run(self, mock_get_conn, mock_get_connection, channel_mock):
         conn = get_airflow_connection()
         mock_get_connection.return_value = conn
         mocked_channel = mock.Mock()
@@ -279,13 +293,13 @@ def test_stream_run(self, mock_get_conn, mock_get_connection):
         ],
     )
     @patch("airflow.providers.grpc.hooks.grpc.grpc.insecure_channel")
-    def test_backcompat_prefix_works(self, channel_mock, uri):
+    def test_backcompat_prefix_works(self, insecure_channel_mock, uri):
         with patch.dict(os.environ, {"AIRFLOW_CONN_MY_CONN": uri}):
             hook = GrpcHook("my_conn")
             hook.get_conn()
-            channel_mock.assert_called_with("abc:50")
+            insecure_channel_mock.assert_called_with("abc:50")

-    def test_backcompat_prefix_both_prefers_short(self):
+    def test_backcompat_prefix_both_prefers_short(self, channel_mock):
         with patch.dict(
             os.environ,
             {"AIRFLOW_CONN_MY_CONN": "a://abc:50?extra__grpc__auth_type=non-pref&auth_type=pref"},

From f60886cf368b943120af20889b83704ccdbb8c91 Mon Sep 17 00:00:00 2001
From: Maciej Obuchowski
Date: Thu, 14 Nov 2024 15:57:27 +0100
Subject: [PATCH 07/33] add ProcessingEngineRunFacet to OL DAG Start event (#43213)

Signed-off-by: Maciej Obuchowski
---
 .../providers/openlineage/plugins/adapter.py  | 14 +++-----------
 .../providers/openlineage/utils/utils.py      | 16 ++++++++++++++--
 .../tests/openlineage/plugins/test_adapter.py |  3 +++
 .../tests/openlineage/plugins/test_utils.py   | 17 +++++++++++++++++
 4 files changed, 37 insertions(+), 13 deletions(-)

diff --git a/providers/src/airflow/providers/openlineage/plugins/adapter.py b/providers/src/airflow/providers/openlineage/plugins/adapter.py
index fb58cc5dc022..199df880e79a 100644
--- a/providers/src/airflow/providers/openlineage/plugins/adapter.py
+++ b/providers/src/airflow/providers/openlineage/plugins/adapter.py
@@ -32,7 +32,6 @@
     nominal_time_run,
     ownership_job,
     parent_run,
-    processing_engine_run,
     source_code_location_job,
 )
 from openlineage.client.uuid import generate_static_uuid
@@ -42,6 +41,7 @@
     OpenLineageRedactor,
     get_airflow_debug_facet,
     get_airflow_state_run_facet,
+    get_processing_engine_facet,
 )
 from airflow.stats import Stats
 from airflow.utils.log.logging_mixin import LoggingMixin
@@ -195,18 +195,10 @@ def start_task(
         :param task: metadata container with information extracted from operator
         :param run_facets: custom run facets
         """
-        from airflow.version import version as AIRFLOW_VERSION
-
-        processing_engine_version_facet = processing_engine_run.ProcessingEngineRunFacet(
-            version=AIRFLOW_VERSION,
-            name="Airflow",
-            openlineageAdapterVersion=OPENLINEAGE_PROVIDER_VERSION,
-        )
-
         run_facets = run_facets or {}
         if task:
             run_facets = {**task.run_facets, **run_facets}
-        run_facets["processing_engine"] = processing_engine_version_facet  # type: ignore
+        run_facets = {**run_facets, **get_processing_engine_facet()}  # type: ignore
         event = RunEvent(
             eventType=RunState.START,
             eventTime=event_time,
@@ -362,7 +354,7 @@ def dag_started(
                 job_name=dag_id,
                 nominal_start_time=nominal_start_time,
                 nominal_end_time=nominal_end_time,
-                run_facets={**run_facets, **get_airflow_debug_facet()},
+                run_facets={**run_facets, **get_airflow_debug_facet(), **get_processing_engine_facet()},
             ),
             inputs=[],
             outputs=[],
diff --git a/providers/src/airflow/providers/openlineage/utils/utils.py b/providers/src/airflow/providers/openlineage/utils/utils.py
index 8c67c32f95b8..99faa3c4d5ce 100644
--- a/providers/src/airflow/providers/openlineage/utils/utils.py
+++ b/providers/src/airflow/providers/openlineage/utils/utils.py
@@ -38,7 +38,7 @@
 # TODO: move this maybe to Airflow's logic?
 from airflow.models import DAG, BaseOperator, DagRun, MappedOperator
 from airflow.providers.common.compat.assets import Asset
-from airflow.providers.openlineage import conf
+from airflow.providers.openlineage import __version__ as OPENLINEAGE_PROVIDER_VERSION, conf
 from airflow.providers.openlineage.plugins.facets import (
     AirflowDagRunFacet,
     AirflowDebugRunFacet,
@@ -65,7 +65,7 @@

 if TYPE_CHECKING:
     from openlineage.client.event_v2 import Dataset as OpenLineageDataset
-    from openlineage.client.facet_v2 import RunFacet
+    from openlineage.client.facet_v2 import RunFacet, processing_engine_run

     from airflow.models import TaskInstance
     from airflow.utils.state import DagRunState, TaskInstanceState
@@ -428,6 +428,18 @@ def _get_all_packages_installed() -> dict[str, str]:
     return {dist.metadata["Name"]: dist.version for dist in metadata.distributions()}


+def get_processing_engine_facet() -> dict[str, processing_engine_run.ProcessingEngineRunFacet]:
+    from openlineage.client.facet_v2 import processing_engine_run
+
+    return {
+        "processing_engine": processing_engine_run.ProcessingEngineRunFacet(
+            version=AIRFLOW_VERSION,
+            name="Airflow",
+            openlineageAdapterVersion=OPENLINEAGE_PROVIDER_VERSION,
+        )
+    }
+
+
 def get_airflow_debug_facet() -> dict[str, AirflowDebugRunFacet]:
     if not conf.debug_mode():
         return {}
diff --git a/providers/tests/openlineage/plugins/test_adapter.py b/providers/tests/openlineage/plugins/test_adapter.py
index f0928dd70db0..73145f9e4b1c 100644
--- a/providers/tests/openlineage/plugins/test_adapter.py
+++ b/providers/tests/openlineage/plugins/test_adapter.py
@@ -606,6 +606,9 @@ def test_emit_dag_started_event(mock_stats_incr, mock_stats_timer, generate_stat
                 nominalStartTime=event_time.isoformat(),
                 nominalEndTime=event_time.isoformat(),
             ),
+            "processing_engine": processing_engine_run.ProcessingEngineRunFacet(
+                version=ANY, name="Airflow", openlineageAdapterVersion=ANY
+            ),
             "airflowDagRun": AirflowDagRunFacet(
                 dag=expected_dag_info,
                 dagRun={
diff --git a/providers/tests/openlineage/plugins/test_utils.py b/providers/tests/openlineage/plugins/test_utils.py
index 22b80120bb6a..e84fac118657 100644
--- a/providers/tests/openlineage/plugins/test_utils.py
+++ b/providers/tests/openlineage/plugins/test_utils.py
@@ -40,6 +40,7 @@
     get_airflow_debug_facet,
     get_airflow_run_facet,
     get_fully_qualified_class_name,
+    get_processing_engine_facet,
     is_operator_disabled,
 )
 from airflow.serialization.enums import DagAttributeTypes
@@ -438,3 +438,19 @@ def test_serialize_timetable_2_8():
         ],
     }
 }
+
+
+@pytest.mark.parametrize(
+    ("airflow_version", "ol_version"),
+    [
+        ("2.9.3", "1.12.2"),
+        ("2.10.1", "1.13.0"),
+        ("3.0.0", "1.14.0"),
+    ],
+)
+def test_get_processing_engine_facet(airflow_version, ol_version):
+    with patch("airflow.providers.openlineage.utils.utils.AIRFLOW_VERSION", airflow_version):
+        with patch("airflow.providers.openlineage.utils.utils.OPENLINEAGE_PROVIDER_VERSION", ol_version):
+            result = get_processing_engine_facet()
+            assert result["processing_engine"].version == airflow_version
+            assert result["processing_engine"].openlineageAdapterVersion == ol_version

From 61076f0ab57feda67ab4d013ed41365774be9bbd Mon Sep 17 00:00:00 2001
From: Brent Bovenzi
Date: Thu, 14 Nov 2024 10:03:55 -0500
Subject: [PATCH 08/33] Improve ability to see and clear dags list filters (#43981)

* Improve ability to see and clear dags list filters

* Refine filter buttons again
---
 airflow/ui/src/components/DagRunInfo.tsx      | 12 ++--
 .../DataTable/ToggleTableDisplay.tsx          | 12 +++-
 .../ui/src/components/QuickFilterButton.tsx   | 12 +++-
 airflow/ui/src/components/SearchBar.tsx       |  3 +-
 airflow/ui/src/components/StateCircle.tsx     | 38 +++++++++++
 .../TriggerDag/TriggerDAGIconButton.tsx       | 12 +++-
 .../ui/src/components/ui/Select/Trigger.tsx   | 49 +++++++-------
 airflow/ui/src/pages/DagsList/DagsFilters.tsx | 67 +++++++++++++++----
 airflow/ui/src/utils/advancedSelectStyles.ts  | 18 +++++
 9 files changed, 168 insertions(+), 55 deletions(-)
 create mode 100644 airflow/ui/src/components/StateCircle.tsx
 create mode 100644 airflow/ui/src/utils/advancedSelectStyles.ts

diff --git a/airflow/ui/src/components/DagRunInfo.tsx b/airflow/ui/src/components/DagRunInfo.tsx
index 4cc2f7027370..0d30e9c7667c 100644
--- a/airflow/ui/src/components/DagRunInfo.tsx
+++ b/airflow/ui/src/components/DagRunInfo.tsx
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-import { VStack, Text, Box, HStack } from "@chakra-ui/react";
+import { VStack, Text, HStack } from "@chakra-ui/react";
 import dayjs from "dayjs";

 import type { DAGRunResponse } from "openapi/requests/types.gen";
@@ -24,6 +24,8 @@ import Time from "src/components/Time";
 import { Tooltip } from "src/components/ui";
 import { stateColor } from "src/utils/stateColor";

+import { StateCircle } from "./StateCircle";
+
 type Props = {
   readonly dataIntervalEnd?: string | null;
   readonly dataIntervalStart?: string | null;
@@ -81,13 const DagRunInfo = ({