From f0f18505b357b02535686ee89de7ed74b544fe1f Mon Sep 17 00:00:00 2001 From: Omkar P <45419097+omkar-foss@users.noreply.github.com> Date: Fri, 8 Nov 2024 19:20:19 +0530 Subject: [PATCH] Migrate public endpoint Get Task to FastAPI, with main resynced --- .../api_connexion/endpoints/task_endpoint.py | 2 + airflow/api_fastapi/common/types.py | 64 +++- .../api_fastapi/core_api/datamodels/tasks.py | 83 +++++ .../core_api/openapi/v1-generated.yaml | 246 +++++++++++++ .../core_api/routes/public/__init__.py | 3 + .../core_api/routes/public/tasks.py | 56 +++ airflow/ui/openapi-gen/queries/common.ts | 19 + airflow/ui/openapi-gen/queries/prefetch.ts | 24 ++ airflow/ui/openapi-gen/queries/queries.ts | 30 ++ airflow/ui/openapi-gen/queries/suspense.ts | 30 ++ .../ui/openapi-gen/requests/schemas.gen.ts | 344 ++++++++++++++++++ .../ui/openapi-gen/requests/services.gen.ts | 31 ++ airflow/ui/openapi-gen/requests/types.gen.ts | 88 +++++ .../core_api/routes/public/test_tasks.py | 297 +++++++++++++++ 14 files changed, 1316 insertions(+), 1 deletion(-) create mode 100644 airflow/api_fastapi/core_api/datamodels/tasks.py create mode 100644 airflow/api_fastapi/core_api/routes/public/tasks.py create mode 100644 tests/api_fastapi/core_api/routes/public/test_tasks.py diff --git a/airflow/api_connexion/endpoints/task_endpoint.py b/airflow/api_connexion/endpoints/task_endpoint.py index 4c5954d2ac5f0..abc28cfee6fbb 100644 --- a/airflow/api_connexion/endpoints/task_endpoint.py +++ b/airflow/api_connexion/endpoints/task_endpoint.py @@ -25,12 +25,14 @@ from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.exceptions import TaskNotFound from airflow.utils.airflow_flask_app import get_airflow_app +from airflow.utils.api_migration import mark_fastapi_migration_done if TYPE_CHECKING: from airflow import DAG from airflow.api_connexion.types import APIResponse +@mark_fastapi_migration_done @security.requires_access_dag("GET", DagAccessEntity.TASK) def get_task(*, dag_id: 
str, task_id: str) -> APIResponse:
     """Get simplified representation of a task."""
diff --git a/airflow/api_fastapi/common/types.py b/airflow/api_fastapi/common/types.py
index d9664c072213a..ec619aaea6a94 100644
--- a/airflow/api_fastapi/common/types.py
+++ b/airflow/api_fastapi/common/types.py
@@ -16,10 +16,72 @@
 # under the License.
 from __future__ import annotations
 
-from pydantic import AfterValidator, AwareDatetime
+import inspect
+from datetime import timedelta
+
+from pydantic import AfterValidator, AliasGenerator, AwareDatetime, BaseModel, BeforeValidator, ConfigDict
 from typing_extensions import Annotated
 
+from airflow.models.mappedoperator import MappedOperator
+from airflow.models.operator import Operator
+from airflow.serialization.serialized_objects import SerializedBaseOperator
 from airflow.utils import timezone
 
 UtcDateTime = Annotated[AwareDatetime, AfterValidator(lambda d: d.astimezone(timezone.utc))]
 """UTCDateTime is a datetime with timezone information"""
+
+
+def _validate_timedelta_field(td: timedelta | None) -> TimeDelta | None:
+    """
+    Convert a ``datetime.timedelta`` into the ``TimeDelta`` response model.
+
+    Used as the "before" validator behind ``TimeDeltaWithValidation``.
+    Anything that is not a ``timedelta`` (``None``, a dict, a ``TimeDelta``)
+    is returned untouched so that pydantic validates it itself instead of
+    this hook failing on a missing ``days`` attribute.
+    """
+    if not isinstance(td, timedelta):
+        # Not a timedelta: leave it for pydantic to accept or reject.
+        return td
+    return TimeDelta(days=td.days, seconds=td.seconds, microseconds=td.microseconds)
+
+
+class TimeDelta(BaseModel):
+    """TimeDelta can be used to interact with datetime.timedelta objects."""
+
+    # Type marker; the alias generator below emits it as ``__type``.
+    object_type: str = "TimeDelta"
+    days: int
+    seconds: int
+    microseconds: int
+
+    # Only ``object_type`` is renamed on serialization; other fields keep their names.
+    model_config = ConfigDict(
+        alias_generator=AliasGenerator(
+            serialization_alias=lambda field_name: {
+                "object_type": "__type",
+            }.get(field_name, field_name),
+        )
+    )
+
+
+TimeDeltaWithValidation = Annotated[TimeDelta, BeforeValidator(_validate_timedelta_field)]
+
+
+def get_class_ref(obj: Operator) -> dict[str, str | None]:
+    """
+    Return the class_ref dict for obj.
+
+    Mapped and serialized operators report their ``_task_module`` and
+    ``_task_type``; for any other object both values come from the object itself.
+    """
+    if isinstance(obj, (MappedOperator, SerializedBaseOperator)):
+        return {"module_path": obj._task_module, "class_name": obj._task_type}
+
+    module = inspect.getmodule(obj)
+    # isinstance() also recognises classes created by a metaclass, unlike ``__class__ is type``.
+    class_name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
+    return {
+        "module_path": module.__name__ if module else None,
+        "class_name": class_name,
+    }
diff --git a/airflow/api_fastapi/core_api/datamodels/tasks.py b/airflow/api_fastapi/core_api/datamodels/tasks.py
new file mode 100644
index 0000000000000..7caaf9c02f473
--- /dev/null
+++ b/airflow/api_fastapi/core_api/datamodels/tasks.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from collections import abc
+from datetime import datetime
+
+from pydantic import BaseModel, computed_field, field_validator
+
+from airflow.api_fastapi.common.types import TimeDeltaWithValidation
+from airflow.serialization.serialized_objects import encode_priority_weight_strategy
+from airflow.task.priority_strategy import PriorityWeightStrategy
+
+
+class TaskResponse(BaseModel):
+    """Task serializer for responses."""
+
+    task_id: str | None
+    task_display_name: str | None
+    owner: str | None
+    start_date: datetime | None
+    end_date: datetime | None
+    trigger_rule: str | None
+    depends_on_past: bool
+    wait_for_downstream: bool
+    retries: float | None
+    queue: str | None
+    pool: str | None
+    pool_slots: float | None
+    execution_timeout: TimeDeltaWithValidation | None
+    retry_delay: TimeDeltaWithValidation | None
+    retry_exponential_backoff: bool
+    priority_weight: float | None
+    weight_rule: str | None
+    ui_color: str | None
+    ui_fgcolor: str | None
+    template_fields: list[str] | None
+    downstream_task_ids: list[str] | None
+    doc_md: str | None
+    operator_name: str | None
+    params: abc.MutableMapping | None
+    class_ref: dict | None
+    is_mapped: bool | None
+
+    @field_validator("weight_rule", mode="before")
+    @classmethod
+    def validate_weight_rule(cls, wr: str | PriorityWeightStrategy | None) -> str | None:
+        """Encode a priority-weight strategy object; strings and None pass through unchanged."""
+        if wr is None or isinstance(wr, str):
+            return wr
+        return encode_priority_weight_strategy(wr)
+
+    @field_validator("params", mode="before")
+    @classmethod
+    def get_params(cls, params: abc.MutableMapping | None) -> dict | None:
+        """Dump each param via its ``dump()`` method; values without one are kept as they are."""
+        if params is None:
+            return None
+        return {name: (val.dump() if hasattr(val, "dump") else val) for name, val in params.items()}
+
+    # NOTE(review): ``operator_extra_links`` is not a declared field of this model, so the lookup
+    # below presumably always falls back to ``[]`` - confirm whether operator links should be exposed.
+    # Mypy issue https://github.com/python/mypy/issues/1362
+    @computed_field  # type: ignore[misc]
+    @property
+    def extra_links(self) -> list[str]:
+        """Extract and return extra_links."""
+        return getattr(self, "operator_extra_links", [])
diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index a4e270ed7e6e3..315cfaadf0123 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -3104,6 +3104,62 @@ paths:
             application/json:
               schema:
                 $ref: '#/components/schemas/HTTPValidationError'
+  /public/dags/{dag_id}/tasks/{task_id}:
+    get:
+      tags:
+      - Task
+      summary: Get Task
+      description: Get simplified representation of a task.
+      operationId: get_task
+      parameters:
+      - name: dag_id
+        in: path
+        required: true
+        schema:
+          type: string
+          title: Dag Id
+      - name: task_id
+        in: path
+        required: true
+        schema:
+          title: Task Id
+      responses:
+        '200':
+          description: Successful Response
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/TaskResponse'
+        '400':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Bad Request
+        '401':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Unauthorized
+        '403':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Forbidden
+        '404':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Not Found
+        '422':
+          description: Validation Error
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPValidationError'
 components:
   schemas:
     AppBuilderMenuItemResponse:
@@ -4913,6 +4969,196 @@ components:
       - triggerer_job
       title: TaskInstanceResponse
       description: TaskInstance serializer for responses.
+ TaskResponse: + properties: + task_id: + anyOf: + - type: string + - type: 'null' + title: Task Id + task_display_name: + anyOf: + - type: string + - type: 'null' + title: Task Display Name + owner: + anyOf: + - type: string + - type: 'null' + title: Owner + start_date: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Start Date + end_date: + anyOf: + - type: string + format: date-time + - type: 'null' + title: End Date + trigger_rule: + anyOf: + - type: string + - type: 'null' + title: Trigger Rule + depends_on_past: + type: boolean + title: Depends On Past + wait_for_downstream: + type: boolean + title: Wait For Downstream + retries: + anyOf: + - type: number + - type: 'null' + title: Retries + queue: + anyOf: + - type: string + - type: 'null' + title: Queue + pool: + anyOf: + - type: string + - type: 'null' + title: Pool + pool_slots: + anyOf: + - type: number + - type: 'null' + title: Pool Slots + execution_timeout: + anyOf: + - $ref: '#/components/schemas/TimeDelta' + - type: 'null' + retry_delay: + anyOf: + - $ref: '#/components/schemas/TimeDelta' + - type: 'null' + retry_exponential_backoff: + type: boolean + title: Retry Exponential Backoff + priority_weight: + anyOf: + - type: number + - type: 'null' + title: Priority Weight + weight_rule: + anyOf: + - type: string + - type: 'null' + title: Weight Rule + ui_color: + anyOf: + - type: string + - type: 'null' + title: Ui Color + ui_fgcolor: + anyOf: + - type: string + - type: 'null' + title: Ui Fgcolor + template_fields: + anyOf: + - items: + type: string + type: array + - type: 'null' + title: Template Fields + downstream_task_ids: + anyOf: + - items: + type: string + type: array + - type: 'null' + title: Downstream Task Ids + doc_md: + anyOf: + - type: string + - type: 'null' + title: Doc Md + operator_name: + anyOf: + - type: string + - type: 'null' + title: Operator Name + params: + anyOf: + - type: object + - type: 'null' + title: Params + class_ref: + anyOf: + - type: object + - 
type: 'null' + title: Class Ref + is_mapped: + anyOf: + - type: boolean + - type: 'null' + title: Is Mapped + extra_links: + items: + type: string + type: array + title: Extra Links + description: Extract and return extra_links. + readOnly: true + type: object + required: + - task_id + - task_display_name + - owner + - start_date + - end_date + - trigger_rule + - depends_on_past + - wait_for_downstream + - retries + - queue + - pool + - pool_slots + - execution_timeout + - retry_delay + - retry_exponential_backoff + - priority_weight + - weight_rule + - ui_color + - ui_fgcolor + - template_fields + - downstream_task_ids + - doc_md + - operator_name + - params + - class_ref + - is_mapped + - extra_links + title: TaskResponse + description: Task serializer for responses. + TimeDelta: + properties: + __type: + type: string + title: ' Type' + default: TimeDelta + days: + type: integer + title: Days + seconds: + type: integer + title: Seconds + microseconds: + type: integer + title: Microseconds + type: object + required: + - days + - seconds + - microseconds + title: TimeDelta + description: TimeDelta can be used to interact with datetime.timedelta objects. 
TriggerResponse: properties: id: diff --git a/airflow/api_fastapi/core_api/routes/public/__init__.py b/airflow/api_fastapi/core_api/routes/public/__init__.py index b7c8affe4a9cb..07f7b163ecde9 100644 --- a/airflow/api_fastapi/core_api/routes/public/__init__.py +++ b/airflow/api_fastapi/core_api/routes/public/__init__.py @@ -32,6 +32,7 @@ from airflow.api_fastapi.core_api.routes.public.pools import pools_router from airflow.api_fastapi.core_api.routes.public.providers import providers_router from airflow.api_fastapi.core_api.routes.public.task_instances import task_instances_router +from airflow.api_fastapi.core_api.routes.public.tasks import tasks_router from airflow.api_fastapi.core_api.routes.public.variables import variables_router from airflow.api_fastapi.core_api.routes.public.version import version_router @@ -56,3 +57,5 @@ public_router.include_router(variables_router) public_router.include_router(version_router) public_router.include_router(dag_stats_router) +public_router.include_router(plugins_router) +public_router.include_router(tasks_router) diff --git a/airflow/api_fastapi/core_api/routes/public/tasks.py b/airflow/api_fastapi/core_api/routes/public/tasks.py new file mode 100644 index 0000000000000..0d39266419ce9 --- /dev/null +++ b/airflow/api_fastapi/core_api/routes/public/tasks.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from fastapi import HTTPException, Request, status
+
+from airflow.api_fastapi.common.router import AirflowRouter
+from airflow.api_fastapi.common.types import get_class_ref
+from airflow.api_fastapi.core_api.datamodels.tasks import TaskResponse
+from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
+from airflow.exceptions import TaskNotFound
+from airflow.models import DAG
+from airflow.models.mappedoperator import MappedOperator
+from airflow.models.operator import Operator
+
+tasks_router = AirflowRouter(tags=["Task"], prefix="/dags/{dag_id}/tasks")
+
+
+@tasks_router.get(
+    "/{task_id}",
+    responses=create_openapi_http_exception_doc(
+        [
+            status.HTTP_400_BAD_REQUEST,
+            status.HTTP_401_UNAUTHORIZED,
+            status.HTTP_403_FORBIDDEN,
+            status.HTTP_404_NOT_FOUND,
+        ]
+    ),
+)
+def get_task(dag_id: str, task_id, request: Request) -> TaskResponse:
+    """Get simplified representation of a task."""
+    dag: DAG = request.app.state.dag_bag.get_dag(dag_id)
+    if not dag:
+        raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found")
+    try:
+        task: Operator = dag.get_task(task_id=task_id)
+    except TaskNotFound:  # only the lookup is guarded; the 404 replaces the original error
+        raise HTTPException(status.HTTP_404_NOT_FOUND, f"Task with id {task_id} was not found") from None
+    task.class_ref = get_class_ref(task)  # read by TaskResponse through from_attributes
+    task.is_mapped = isinstance(task, MappedOperator)
+    return TaskResponse.model_validate(task, from_attributes=True)
diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts
index 
cfbb945b4e57f..5827f00f0ca10 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -19,6 +19,7 @@ import { PoolService, ProviderService, TaskInstanceService, + TaskService, VariableService, VersionService, } from "../requests/services.gen"; @@ -905,6 +906,24 @@ export const UseDagStatsServiceGetDagStatsKeyFn = ( } = {}, queryKey?: Array, ) => [useDagStatsServiceGetDagStatsKey, ...(queryKey ?? [{ dagIds }])]; +export type TaskServiceGetTaskDefaultResponse = Awaited< + ReturnType +>; +export type TaskServiceGetTaskQueryResult< + TData = TaskServiceGetTaskDefaultResponse, + TError = unknown, +> = UseQueryResult; +export const useTaskServiceGetTaskKey = "TaskServiceGetTask"; +export const UseTaskServiceGetTaskKeyFn = ( + { + dagId, + taskId, + }: { + dagId: string; + taskId: unknown; + }, + queryKey?: Array, +) => [useTaskServiceGetTaskKey, ...(queryKey ?? [{ dagId, taskId }])]; export type BackfillServiceCreateBackfillMutationResult = Awaited< ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow/ui/openapi-gen/queries/prefetch.ts index b8dc71dd821cd..d5efbf73b3a45 100644 --- a/airflow/ui/openapi-gen/queries/prefetch.ts +++ b/airflow/ui/openapi-gen/queries/prefetch.ts @@ -19,6 +19,7 @@ import { PoolService, ProviderService, TaskInstanceService, + TaskService, VariableService, VersionService, } from "../requests/services.gen"; @@ -1214,3 +1215,26 @@ export const prefetchUseDagStatsServiceGetDagStats = ( queryKey: Common.UseDagStatsServiceGetDagStatsKeyFn({ dagIds }), queryFn: () => DagStatsService.getDagStats({ dagIds }), }); +/** + * Get Task + * Get simplified representation of a task. + * @param data The data for the request. 
+ * @param data.dagId + * @param data.taskId + * @returns TaskResponse Successful Response + * @throws ApiError + */ +export const prefetchUseTaskServiceGetTask = ( + queryClient: QueryClient, + { + dagId, + taskId, + }: { + dagId: string; + taskId: unknown; + }, +) => + queryClient.prefetchQuery({ + queryKey: Common.UseTaskServiceGetTaskKeyFn({ dagId, taskId }), + queryFn: () => TaskService.getTask({ dagId, taskId }), + }); diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index 3a8d508a8c426..b66fac68206ca 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -24,6 +24,7 @@ import { PoolService, ProviderService, TaskInstanceService, + TaskService, VariableService, VersionService, } from "../requests/services.gen"; @@ -1464,6 +1465,35 @@ export const useDagStatsServiceGetDagStats = < queryFn: () => DagStatsService.getDagStats({ dagIds }) as TData, ...options, }); +/** + * Get Task + * Get simplified representation of a task. + * @param data The data for the request. + * @param data.dagId + * @param data.taskId + * @returns TaskResponse Successful Response + * @throws ApiError + */ +export const useTaskServiceGetTask = < + TData = Common.TaskServiceGetTaskDefaultResponse, + TError = unknown, + TQueryKey extends Array = unknown[], +>( + { + dagId, + taskId, + }: { + dagId: string; + taskId: unknown; + }, + queryKey?: TQueryKey, + options?: Omit, "queryKey" | "queryFn">, +) => + useQuery({ + queryKey: Common.UseTaskServiceGetTaskKeyFn({ dagId, taskId }, queryKey), + queryFn: () => TaskService.getTask({ dagId, taskId }) as TData, + ...options, + }); /** * Create Backfill * @param data The data for the request. 
diff --git a/airflow/ui/openapi-gen/queries/suspense.ts b/airflow/ui/openapi-gen/queries/suspense.ts index 3672219a23f2f..f984a1cc03843 100644 --- a/airflow/ui/openapi-gen/queries/suspense.ts +++ b/airflow/ui/openapi-gen/queries/suspense.ts @@ -19,6 +19,7 @@ import { PoolService, ProviderService, TaskInstanceService, + TaskService, VariableService, VersionService, } from "../requests/services.gen"; @@ -1449,3 +1450,32 @@ export const useDagStatsServiceGetDagStatsSuspense = < queryFn: () => DagStatsService.getDagStats({ dagIds }) as TData, ...options, }); +/** + * Get Task + * Get simplified representation of a task. + * @param data The data for the request. + * @param data.dagId + * @param data.taskId + * @returns TaskResponse Successful Response + * @throws ApiError + */ +export const useTaskServiceGetTaskSuspense = < + TData = Common.TaskServiceGetTaskDefaultResponse, + TError = unknown, + TQueryKey extends Array = unknown[], +>( + { + dagId, + taskId, + }: { + dagId: string; + taskId: unknown; + }, + queryKey?: TQueryKey, + options?: Omit, "queryKey" | "queryFn">, +) => + useSuspenseQuery({ + queryKey: Common.UseTaskServiceGetTaskKeyFn({ dagId, taskId }, queryKey), + queryFn: () => TaskService.getTask({ dagId, taskId }) as TData, + ...options, + }); diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index 0e583e014f0d3..80de709844510 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -2788,6 +2788,350 @@ export const $TaskInstanceResponse = { description: "TaskInstance serializer for responses.", } as const; +export const $TaskResponse = { + properties: { + task_id: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Task Id", + }, + task_display_name: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Task Display Name", + }, + owner: { + anyOf: [ + { + type: "string", + }, + { + 
type: "null", + }, + ], + title: "Owner", + }, + start_date: { + anyOf: [ + { + type: "string", + format: "date-time", + }, + { + type: "null", + }, + ], + title: "Start Date", + }, + end_date: { + anyOf: [ + { + type: "string", + format: "date-time", + }, + { + type: "null", + }, + ], + title: "End Date", + }, + trigger_rule: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Trigger Rule", + }, + depends_on_past: { + type: "boolean", + title: "Depends On Past", + }, + wait_for_downstream: { + type: "boolean", + title: "Wait For Downstream", + }, + retries: { + anyOf: [ + { + type: "number", + }, + { + type: "null", + }, + ], + title: "Retries", + }, + queue: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Queue", + }, + pool: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Pool", + }, + pool_slots: { + anyOf: [ + { + type: "number", + }, + { + type: "null", + }, + ], + title: "Pool Slots", + }, + execution_timeout: { + anyOf: [ + { + $ref: "#/components/schemas/TimeDelta", + }, + { + type: "null", + }, + ], + }, + retry_delay: { + anyOf: [ + { + $ref: "#/components/schemas/TimeDelta", + }, + { + type: "null", + }, + ], + }, + retry_exponential_backoff: { + type: "boolean", + title: "Retry Exponential Backoff", + }, + priority_weight: { + anyOf: [ + { + type: "number", + }, + { + type: "null", + }, + ], + title: "Priority Weight", + }, + weight_rule: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Weight Rule", + }, + ui_color: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Ui Color", + }, + ui_fgcolor: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Ui Fgcolor", + }, + template_fields: { + anyOf: [ + { + items: { + type: "string", + }, + type: "array", + }, + { + type: "null", + }, + ], + title: "Template Fields", + }, + downstream_task_ids: { + anyOf: [ + { + items: { + type: 
"string", + }, + type: "array", + }, + { + type: "null", + }, + ], + title: "Downstream Task Ids", + }, + doc_md: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Doc Md", + }, + operator_name: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Operator Name", + }, + params: { + anyOf: [ + { + type: "object", + }, + { + type: "null", + }, + ], + title: "Params", + }, + class_ref: { + anyOf: [ + { + type: "object", + }, + { + type: "null", + }, + ], + title: "Class Ref", + }, + is_mapped: { + anyOf: [ + { + type: "boolean", + }, + { + type: "null", + }, + ], + title: "Is Mapped", + }, + extra_links: { + items: { + type: "string", + }, + type: "array", + title: "Extra Links", + description: "Extract and return extra_links.", + readOnly: true, + }, + }, + type: "object", + required: [ + "task_id", + "task_display_name", + "owner", + "start_date", + "end_date", + "trigger_rule", + "depends_on_past", + "wait_for_downstream", + "retries", + "queue", + "pool", + "pool_slots", + "execution_timeout", + "retry_delay", + "retry_exponential_backoff", + "priority_weight", + "weight_rule", + "ui_color", + "ui_fgcolor", + "template_fields", + "downstream_task_ids", + "doc_md", + "operator_name", + "params", + "class_ref", + "is_mapped", + "extra_links", + ], + title: "TaskResponse", + description: "Task serializer for responses.", +} as const; + +export const $TimeDelta = { + properties: { + __type: { + type: "string", + title: " Type", + default: "TimeDelta", + }, + days: { + type: "integer", + title: "Days", + }, + seconds: { + type: "integer", + title: "Seconds", + }, + microseconds: { + type: "integer", + title: "Microseconds", + }, + }, + type: "object", + required: ["days", "seconds", "microseconds"], + title: "TimeDelta", + description: + "TimeDelta can be used to interact with datetime.timedelta objects.", +} as const; + export const $TriggerResponse = { properties: { id: { diff --git 
a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 8a6cd3e4f702a..cd9caaefe3dd7 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -103,6 +103,8 @@ import type { GetVersionResponse, GetDagStatsData, GetDagStatsResponse, + GetTaskData, + GetTaskResponse, } from "./types.gen"; export class AssetService { @@ -1689,3 +1691,32 @@ export class DagStatsService { }); } } + +export class TaskService { + /** + * Get Task + * Get simplified representation of a task. + * @param data The data for the request. + * @param data.dagId + * @param data.taskId + * @returns TaskResponse Successful Response + * @throws ApiError + */ + public static getTask(data: GetTaskData): CancelablePromise { + return __request(OpenAPI, { + method: "GET", + url: "/public/dags/{dag_id}/tasks/{task_id}", + path: { + dag_id: data.dagId, + task_id: data.taskId, + }, + errors: { + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Not Found", + 422: "Validation Error", + }, + }); + } +} diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index 2b96ab7140a62..efe56f4cb5752 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -653,6 +653,56 @@ export type TaskInstanceResponse = { triggerer_job: JobResponse | null; }; +/** + * Task serializer for responses. 
+ */ +export type TaskResponse = { + task_id: string | null; + task_display_name: string | null; + owner: string | null; + start_date: string | null; + end_date: string | null; + trigger_rule: string | null; + depends_on_past: boolean; + wait_for_downstream: boolean; + retries: number | null; + queue: string | null; + pool: string | null; + pool_slots: number | null; + execution_timeout: TimeDelta | null; + retry_delay: TimeDelta | null; + retry_exponential_backoff: boolean; + priority_weight: number | null; + weight_rule: string | null; + ui_color: string | null; + ui_fgcolor: string | null; + template_fields: Array | null; + downstream_task_ids: Array | null; + doc_md: string | null; + operator_name: string | null; + params: { + [key: string]: unknown; + } | null; + class_ref: { + [key: string]: unknown; + } | null; + is_mapped: boolean | null; + /** + * Extract and return extra_links. + */ + readonly extra_links: Array; +}; + +/** + * TimeDelta can be used to interact with datetime.timedelta objects. + */ +export type TimeDelta = { + __type?: string; + days: number; + seconds: number; + microseconds: number; +}; + /** * Trigger serializer for responses. 
*/ @@ -1176,6 +1226,13 @@ export type GetDagStatsData = { export type GetDagStatsResponse = DagStatsCollectionResponse; +export type GetTaskData = { + dagId: string; + taskId: unknown; +}; + +export type GetTaskResponse = TaskResponse; + export type $OpenApiTs = { "/ui/next_run_assets/{dag_id}": { get: { @@ -2464,4 +2521,35 @@ export type $OpenApiTs = { }; }; }; + "/public/dags/{dag_id}/tasks/{task_id}": { + get: { + req: GetTaskData; + res: { + /** + * Successful Response + */ + 200: TaskResponse; + /** + * Bad Request + */ + 400: HTTPExceptionResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; }; diff --git a/tests/api_fastapi/core_api/routes/public/test_tasks.py b/tests/api_fastapi/core_api/routes/public/test_tasks.py new file mode 100644 index 0000000000000..e73498d3a6610 --- /dev/null +++ b/tests/api_fastapi/core_api/routes/public/test_tasks.py @@ -0,0 +1,297 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations
+
+import os
+import unittest.mock
+from datetime import datetime
+
+import pytest
+
+from airflow.models.dag import DAG
+from airflow.models.dagbag import DagBag
+from airflow.models.expandinput import EXPAND_INPUT_EMPTY
+from airflow.models.serialized_dag import SerializedDagModel
+from airflow.operators.empty import EmptyOperator
+
+from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags
+
+pytestmark = pytest.mark.db_test
+
+
+class TestTaskEndpoint:
+    dag_id = "test_dag"
+    mapped_dag_id = "test_mapped_task"
+    unscheduled_dag_id = "test_unscheduled_dag"
+    task_id = "op1"
+    task_id2 = "op2"
+    task_id3 = "op3"
+    mapped_task_id = "mapped_task"
+    unscheduled_task_id1 = "unscheduled_task_1"
+    unscheduled_task_id2 = "unscheduled_task_2"
+    task1_start_date = datetime(2020, 6, 15)
+    task2_start_date = datetime(2020, 6, 16)
+
+    def create_dags(self, test_client):
+        """Build three in-memory DAGs and install them as the app's DagBag."""
+        with DAG(self.dag_id, schedule=None, start_date=self.task1_start_date, doc_md="details") as dag:
+            task1 = EmptyOperator(task_id=self.task_id, params={"foo": "bar"})
+            task2 = EmptyOperator(task_id=self.task_id2, start_date=self.task2_start_date)
+
+        with DAG(self.mapped_dag_id, schedule=None, start_date=self.task1_start_date) as mapped_dag:
+            EmptyOperator(task_id=self.task_id3)
+            # Use the private _expand() method to avoid the empty kwargs check.
+            # We don't care about how the operator runs here, only its presence.
+            EmptyOperator.partial(task_id=self.mapped_task_id)._expand(EXPAND_INPUT_EMPTY, strict=False)
+
+        with DAG(self.unscheduled_dag_id, start_date=None, schedule=None) as unscheduled_dag:
+            task4 = EmptyOperator(task_id=self.unscheduled_task_id1, params={"is_unscheduled": True})
+            task5 = EmptyOperator(task_id=self.unscheduled_task_id2, params={"is_unscheduled": True})
+
+        task1 >> task2
+        task4 >> task5
+        dag_bag = DagBag(os.devnull, include_examples=False)
+        dag_bag.dags = {d.dag_id: d for d in (dag, mapped_dag, unscheduled_dag)}
+        test_client.app.state.dag_bag = dag_bag
+
+    @staticmethod
+    def clear_db():
+        """Call the shared clear_db_* helpers for runs, DAGs and serialized DAGs."""
+        clear_db_runs()
+        clear_db_dags()
+        clear_db_serialized_dags()
+
+    @pytest.fixture(autouse=True)
+    def setup(self, test_client) -> None:
+        """Clear the DB and rebuild the DagBag before every test (autouse)."""
+        self.clear_db()
+        self.create_dags(test_client)
+
+    def teardown_method(self) -> None:
+        """Clear the DB again after each test."""
+        self.clear_db()
+
+
+class TestGetTask(TestTaskEndpoint):
+    def test_should_respond_200(self, test_client):
+        expected = {
+            "class_ref": {
+                "class_name": "EmptyOperator",
+                "module_path": "airflow.operators.empty",
+            },
+            "depends_on_past": False,
+            "downstream_task_ids": [self.task_id2],
+            "end_date": None,
+            "execution_timeout": None,
+            "extra_links": [],
+            "operator_name": "EmptyOperator",
+            "owner": "airflow",
+            "params": {
+                "foo": {
+                    "__class": "airflow.models.param.Param",
+                    "value": "bar",
+                    "description": None,
+                    "schema": {},
+                }
+            },
+            "pool": "default_pool",
+            "pool_slots": 1.0,
+            "priority_weight": 1.0,
+            "queue": "default",
+            "retries": 0.0,
+            "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0},
+            "retry_exponential_backoff": False,
+            "start_date": self.task1_start_date.replace(tzinfo=None).isoformat()
+            + "Z",  # pydantic datetime format
+            "task_id": "op1",
+            "task_display_name": "op1",
+            "template_fields": [],
+            "trigger_rule": "all_success",
+            "ui_color": "#e8f7e4",
+            "ui_fgcolor": "#000",
+            "wait_for_downstream": False,
+            "weight_rule": 
"downstream", + "is_mapped": False, + "doc_md": None, + } + response = test_client.get( + f"/public/dags/{self.dag_id}/tasks/{self.task_id}", + ) + assert response.status_code == 200 + assert response.json() == expected + + def test_mapped_task(self, test_client): + expected = { + "class_ref": {"class_name": "EmptyOperator", "module_path": "airflow.operators.empty"}, + "depends_on_past": False, + "downstream_task_ids": [], + "end_date": None, + "execution_timeout": None, + "extra_links": [], + "is_mapped": True, + "operator_name": "EmptyOperator", + "owner": "airflow", + "params": {}, + "pool": "default_pool", + "pool_slots": 1.0, + "priority_weight": 1.0, + "queue": "default", + "retries": 0.0, + "retry_delay": {"__type": "TimeDelta", "days": 0, "microseconds": 0, "seconds": 300}, + "retry_exponential_backoff": False, + "start_date": self.task1_start_date.replace(tzinfo=None).isoformat() + + "Z", # pydantic datetime format + "task_id": "mapped_task", + "task_display_name": "mapped_task", + "template_fields": [], + "trigger_rule": "all_success", + "ui_color": "#e8f7e4", + "ui_fgcolor": "#000", + "wait_for_downstream": False, + "weight_rule": "downstream", + "doc_md": None, + } + response = test_client.get( + f"/public/dags/{self.mapped_dag_id}/tasks/{self.mapped_task_id}", + ) + assert response.status_code == 200 + assert response.json() == expected + + def test_unscheduled_task(self, test_client): + expected = { + "class_ref": { + "class_name": "EmptyOperator", + "module_path": "airflow.operators.empty", + }, + "depends_on_past": False, + "downstream_task_ids": [], + "end_date": None, + "execution_timeout": None, + "extra_links": [], + "operator_name": "EmptyOperator", + "owner": "airflow", + "params": { + "is_unscheduled": { + "__class": "airflow.models.param.Param", + "value": True, + "description": None, + "schema": {}, + } + }, + "pool": "default_pool", + "pool_slots": 1.0, + "priority_weight": 1.0, + "queue": "default", + "retries": 0.0, + "retry_delay": 
{"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0}, + "retry_exponential_backoff": False, + "start_date": None, + "task_id": None, + "task_display_name": None, + "template_fields": [], + "trigger_rule": "all_success", + "ui_color": "#e8f7e4", + "ui_fgcolor": "#000", + "wait_for_downstream": False, + "weight_rule": "downstream", + "is_mapped": False, + "doc_md": None, + } + downstream_dict = { + self.unscheduled_task_id1: self.unscheduled_task_id2, + self.unscheduled_task_id2: None, + } + for task_id, downstream_task_id in downstream_dict.items(): + response = test_client.get( + f"/public/dags/{self.unscheduled_dag_id}/tasks/{task_id}", + ) + assert response.status_code == 200 + expected["downstream_task_ids"] = [downstream_task_id] if downstream_task_id else [] + expected["task_id"] = task_id + expected["task_display_name"] = task_id + assert response.json() == expected + + def test_should_respond_200_serialized(self, test_client): + # Get the dag out of the dagbag before we patch it to an empty one + dag = test_client.app.state.dag_bag.get_dag(self.dag_id) + dag.sync_to_db() + SerializedDagModel.write_dag(dag) + + dag_bag = DagBag(os.devnull, include_examples=False, read_dags_from_db=True) + patcher = unittest.mock.patch.object(test_client.app.state, "dag_bag", dag_bag) + patcher.start() + + expected = { + "class_ref": { + "class_name": "EmptyOperator", + "module_path": "airflow.operators.empty", + }, + "depends_on_past": False, + "downstream_task_ids": [self.task_id2], + "end_date": None, + "execution_timeout": None, + "extra_links": [], + "operator_name": "EmptyOperator", + "owner": "airflow", + "params": { + "foo": { + "__class": "airflow.models.param.Param", + "value": "bar", + "description": None, + "schema": {}, + } + }, + "pool": "default_pool", + "pool_slots": 1.0, + "priority_weight": 1.0, + "queue": "default", + "retries": 0.0, + "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0}, + 
"retry_exponential_backoff": False, + "start_date": self.task1_start_date.replace(tzinfo=None).isoformat() + + "Z", # pydantic datetime format + "task_id": "op1", + "task_display_name": "op1", + "template_fields": [], + "trigger_rule": "all_success", + "ui_color": "#e8f7e4", + "ui_fgcolor": "#000", + "wait_for_downstream": False, + "weight_rule": "downstream", + "is_mapped": False, + "doc_md": None, + } + response = test_client.get( + f"/public/dags/{self.dag_id}/tasks/{self.task_id}", + ) + assert response.status_code == 200 + assert response.json() == expected + patcher.stop() + + def test_should_respond_404(self, test_client): + task_id = "xxxx_not_existing" + response = test_client.get( + f"/public/dags/{self.dag_id}/tasks/{task_id}", + ) + assert response.status_code == 404 + + def test_should_respond_404_when_dag_not_found(self, test_client): + dag_id = "xxxx_not_existing" + response = test_client.get( + f"/public/dags/{dag_id}/tasks/{self.task_id}", + ) + assert response.status_code == 404