Skip to content

Commit

Permalink
Migrate public endpoint Get Task to FastAPI, with main resynced
Browse files Browse the repository at this point in the history
  • Loading branch information
omkar-foss committed Nov 8, 2024
1 parent 2b79d18 commit f0f1850
Show file tree
Hide file tree
Showing 14 changed files with 1,316 additions and 1 deletion.
2 changes: 2 additions & 0 deletions airflow/api_connexion/endpoints/task_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
from airflow.auth.managers.models.resource_details import DagAccessEntity
from airflow.exceptions import TaskNotFound
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.api_migration import mark_fastapi_migration_done

if TYPE_CHECKING:
from airflow import DAG
from airflow.api_connexion.types import APIResponse


@mark_fastapi_migration_done
@security.requires_access_dag("GET", DagAccessEntity.TASK)
def get_task(*, dag_id: str, task_id: str) -> APIResponse:
"""Get simplified representation of a task."""
Expand Down
64 changes: 63 additions & 1 deletion airflow/api_fastapi/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,72 @@
# under the License.
from __future__ import annotations

from pydantic import AfterValidator, AwareDatetime
import inspect
from datetime import timedelta

from pydantic import AfterValidator, AliasGenerator, AwareDatetime, BaseModel, BeforeValidator, ConfigDict
from typing_extensions import Annotated

from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.utils import timezone

UtcDateTime = Annotated[AwareDatetime, AfterValidator(lambda d: d.astimezone(timezone.utc))]
"""UTCDateTime is a datetime with timezone information"""


def _validate_timedelta_field(td: timedelta | None) -> TimeDelta | None:
"""Validate the execution_timeout property."""
if td is None:
return None
return TimeDelta(
days=td.days,
seconds=td.seconds,
microseconds=td.microseconds,
)


class TimeDelta(BaseModel):
"""TimeDelta can be used to interact with datetime.timedelta objects."""

object_type: str = "TimeDelta"
days: int
seconds: int
microseconds: int

model_config = ConfigDict(
alias_generator=AliasGenerator(
serialization_alias=lambda field_name: {
"object_type": "__type",
}.get(field_name, field_name),
)
)


TimeDeltaWithValidation = Annotated[TimeDelta, BeforeValidator(_validate_timedelta_field)]


def get_class_ref(obj: Operator) -> dict[str, str | None]:
"""Return the class_ref dict for obj."""
is_mapped_or_serialized = isinstance(obj, (MappedOperator, SerializedBaseOperator))

module_path = None
if is_mapped_or_serialized:
module_path = obj._task_module
else:
module_type = inspect.getmodule(obj)
module_path = module_type.__name__ if module_type else None

class_name = None
if is_mapped_or_serialized:
class_name = obj._task_type
elif obj.__class__ is type:
class_name = obj.__name__
else:
class_name = type(obj).__name__

return {
"module_path": module_path,
"class_name": class_name,
}
83 changes: 83 additions & 0 deletions airflow/api_fastapi/core_api/datamodels/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from collections import abc
from datetime import datetime

from pydantic import BaseModel, computed_field, field_validator

from airflow.api_fastapi.common.types import TimeDeltaWithValidation
from airflow.serialization.serialized_objects import encode_priority_weight_strategy
from airflow.task.priority_strategy import PriorityWeightStrategy


class TaskResponse(BaseModel):
"""Task serializer for responses."""

task_id: str | None
task_display_name: str | None
owner: str | None
start_date: datetime | None
end_date: datetime | None
trigger_rule: str | None
depends_on_past: bool
wait_for_downstream: bool
retries: float | None
queue: str | None
pool: str | None
pool_slots: float | None
execution_timeout: TimeDeltaWithValidation | None
retry_delay: TimeDeltaWithValidation | None
retry_exponential_backoff: bool
priority_weight: float | None
weight_rule: str | None
ui_color: str | None
ui_fgcolor: str | None
template_fields: list[str] | None
downstream_task_ids: list[str] | None
doc_md: str | None
operator_name: str | None
params: abc.MutableMapping | None
class_ref: dict | None
is_mapped: bool | None

@field_validator("weight_rule", mode="before")
@classmethod
def validate_weight_rule(cls, wr: str | PriorityWeightStrategy | None) -> str | None:
"""Validate the weight_rule property."""
if wr is None:
return None
if isinstance(wr, str):
return wr
return encode_priority_weight_strategy(wr)

@field_validator("params", mode="before")
@classmethod
def get_params(cls, params: abc.MutableMapping | None) -> dict | None:
"""Convert params attribute to dict representation."""
if params is None:
return None
return {param_name: param_val.dump() for param_name, param_val in params.items()}

# Mypy issue https://github.com/python/mypy/issues/1362
@computed_field # type: ignore[misc]
@property
def extra_links(self) -> list[str]:
"""Extract and return extra_links."""
return getattr(self, "operator_extra_links", [])
Loading

0 comments on commit f0f1850

Please sign in to comment.