From 7af76b88a0932ea3bbd8527c66f72cbbfc116198 Mon Sep 17 00:00:00 2001 From: Pierre Jeambrun Date: Mon, 30 Sep 2024 23:04:25 +0800 Subject: [PATCH] AIP-84 Migrate patch dags to FastAPI API (#42545) * AIP-84 Migrate patch dags to FastAPI API * Fix CI --- .../api_connexion/endpoints/dag_endpoint.py | 1 + airflow/api_fastapi/db/__init__.py | 16 +++ airflow/api_fastapi/db/common.py | 83 ++++++++++++ airflow/api_fastapi/{db.py => db/dags.py} | 54 +++----- airflow/api_fastapi/openapi/v1-generated.yaml | 123 +++++++++++++++++- airflow/api_fastapi/parameters.py | 50 +++++-- airflow/api_fastapi/views/public/dags.py | 93 ++++++++----- airflow/api_fastapi/views/ui/assets.py | 2 +- airflow/ui/openapi-gen/queries/common.ts | 3 + airflow/ui/openapi-gen/queries/queries.ts | 88 ++++++++++++- .../ui/openapi-gen/requests/services.gen.ts | 50 ++++++- airflow/ui/openapi-gen/requests/types.gen.ts | 44 +++++++ tests/api_fastapi/views/public/test_dags.py | 94 ++++++++++--- 13 files changed, 596 insertions(+), 105 deletions(-) create mode 100644 airflow/api_fastapi/db/__init__.py create mode 100644 airflow/api_fastapi/db/common.py rename airflow/api_fastapi/{db.py => db/dags.py} (55%) diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py index 6fca5ae7c93d..5d10a97dedce 100644 --- a/airflow/api_connexion/endpoints/dag_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_endpoint.py @@ -165,6 +165,7 @@ def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session = return dag_schema.dump(dag) +@mark_fastapi_migration_done @security.requires_access_dag("PUT") @format_parameters({"limit": check_limit}) @action_logging diff --git a/airflow/api_fastapi/db/__init__.py b/airflow/api_fastapi/db/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/airflow/api_fastapi/db/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/api_fastapi/db/common.py b/airflow/api_fastapi/db/common.py new file mode 100644 index 000000000000..f611eaa64f07 --- /dev/null +++ b/airflow/api_fastapi/db/common.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence + +from airflow.utils.db import get_query_count +from airflow.utils.session import NEW_SESSION, create_session, provide_session + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + from sqlalchemy.sql import Select + + from airflow.api_fastapi.parameters import BaseParam + + +async def get_session() -> Session: + """ + Dependency for providing a session. + + For non route function please use the :class:`airflow.utils.session.provide_session` decorator. + + Example usage: + + .. code:: python + + @router.get("/your_path") + def your_route(session: Annotated[Session, Depends(get_session)]): + pass + """ + with create_session() as session: + yield session + + +def apply_filters_to_select(base_select: Select, filters: Sequence[BaseParam | None]) -> Select: + base_select = base_select + for filter in filters: + if filter is None: + continue + base_select = filter.to_orm(base_select) + + return base_select + + +@provide_session +def paginated_select( + base_select: Select, + filters: Sequence[BaseParam], + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, + session: Session = NEW_SESSION, +) -> Select: + base_select = apply_filters_to_select( + base_select, + filters, + ) + + total_entries = get_query_count(base_select, session=session) + + # TODO: Re-enable when permissions are handled. Readable / writable entities, + # for instance: + # readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user) + # dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags)) + + base_select = apply_filters_to_select(base_select, [order_by, offset, limit]) + + return base_select, total_entries diff --git a/airflow/api_fastapi/db.py b/airflow/api_fastapi/db/dags.py similarity index 55% rename from airflow/api_fastapi/db.py rename to airflow/api_fastapi/db/dags.py index c3ed01a0aefe..7cd7cc9cd955 100644 --- a/airflow/api_fastapi/db.py +++ b/airflow/api_fastapi/db/dags.py @@ -17,45 +17,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING - from sqlalchemy import func, select +from airflow.models.dag import DagModel from airflow.models.dagrun import DagRun -from airflow.utils.session import create_session - -if TYPE_CHECKING: - from sqlalchemy.orm import Session - from sqlalchemy.sql import Select - - from airflow.api_fastapi.parameters import BaseParam - - -async def get_session() -> Session: - """ - Dependency for providing a session. - - For non route function please use the :class:`airflow.utils.session.provide_session` decorator. - - Example usage: - - .. code:: python - - @router.get("/your_path") - def your_route(session: Annotated[Session, Depends(get_session)]): - pass - """ - with create_session() as session: - yield session - - -def apply_filters_to_select(base_select: Select, filters: list[BaseParam]) -> Select: - select = base_select - for filter in filters: - select = filter.to_orm(select) - - return select - latest_dag_run_per_dag_id_cte = ( select(DagRun.dag_id, func.max(DagRun.start_date).label("start_date")) @@ -63,3 +28,20 @@ def apply_filters_to_select(base_select: Select, filters: list[BaseParam]) -> Se .group_by(DagRun.dag_id) .cte() ) + + +dags_select_with_latest_dag_run = ( + select(DagModel) + .join( + latest_dag_run_per_dag_id_cte, + DagModel.dag_id == latest_dag_run_per_dag_id_cte.c.dag_id, + isouter=True, + ) + .join( + DagRun, + DagRun.start_date == latest_dag_run_per_dag_id_cte.c.start_date + and DagRun.dag_id == latest_dag_run_per_dag_id_cte.c.dag_id, + isouter=True, + ) + .order_by(DagModel.dag_id) +) diff --git a/airflow/api_fastapi/openapi/v1-generated.yaml b/airflow/api_fastapi/openapi/v1-generated.yaml index c130f3162c6e..a38a1021890d 100644 --- a/airflow/api_fastapi/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/openapi/v1-generated.yaml @@ -131,12 +131,133 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + patch: + tags: + - DAG + summary: Patch Dags + description: Patch multiple DAGs. + operationId: patch_dags_public_dags_patch + parameters: + - name: update_mask + in: query + required: false + schema: + anyOf: + - type: array + items: + type: string + - type: 'null' + title: Update Mask + - name: limit + in: query + required: false + schema: + type: integer + default: 100 + title: Limit + - name: offset + in: query + required: false + schema: + type: integer + default: 0 + title: Offset + - name: tags + in: query + required: false + schema: + type: array + items: + type: string + title: Tags + - name: owners + in: query + required: false + schema: + type: array + items: + type: string + title: Owners + - name: dag_id_pattern + in: query + required: false + schema: + anyOf: + - type: string + - type: 'null' + title: Dag Id Pattern + - name: only_active + in: query + required: false + schema: + type: boolean + default: true + title: Only Active + - name: paused + in: query + required: false + schema: + anyOf: + - type: boolean + - type: 'null' + title: Paused + - name: last_dag_run_state + in: query + required: false + schema: + anyOf: + - $ref: '#/components/schemas/DagRunState' + - type: 'null' + title: Last Dag Run State + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/DAGPatchBody' + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/DAGCollectionResponse' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Bad Request + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' /public/dags/{dag_id}: patch: tags: - DAG summary: Patch Dag - description: Update the specific DAG. + description: Patch the specific DAG. operationId: patch_dag_public_dags__dag_id__patch parameters: - name: dag_id diff --git a/airflow/api_fastapi/parameters.py b/airflow/api_fastapi/parameters.py index 09eea5f6e055..504014602f3b 100644 --- a/airflow/api_fastapi/parameters.py +++ b/airflow/api_fastapi/parameters.py @@ -37,9 +37,10 @@ class BaseParam(Generic[T], ABC): """Base class for filters.""" - def __init__(self) -> None: + def __init__(self, skip_none: bool = True) -> None: self.value: T | None = None self.attribute: ColumnElement | None = None + self.skip_none = skip_none @abstractmethod def to_orm(self, select: Select) -> Select: @@ -58,7 +59,7 @@ class _LimitFilter(BaseParam[int]): """Filter on the limit.""" def to_orm(self, select: Select) -> Select: - if self.value is None: + if self.value is None and self.skip_none: return select return select.limit(self.value) @@ -71,7 +72,7 @@ class _OffsetFilter(BaseParam[int]): """Filter on offset.""" def to_orm(self, select: Select) -> Select: - if self.value is None: + if self.value is None and self.skip_none: return select return select.offset(self.value) @@ -83,7 +84,7 @@ class _PausedFilter(BaseParam[bool]): """Filter on is_paused.""" def to_orm(self, select: Select) -> Select: - if self.value is None: + if self.value is None and self.skip_none: return select return select.where(DagModel.is_paused == self.value) @@ -95,7 +96,7 @@ class _OnlyActiveFilter(BaseParam[bool]): """Filter on is_active.""" def to_orm(self, select: Select) -> Select: - if self.value: + if self.value and self.skip_none: return select.where(DagModel.is_active == self.value) return select @@ -106,33 +107,40 @@ def depends(self, only_active: bool = True) -> _OnlyActiveFilter: class _SearchParam(BaseParam[str]): """Search on attribute.""" - def __init__(self, attribute: ColumnElement) -> None: - super().__init__() + def __init__(self, attribute: ColumnElement, skip_none: bool = True) -> None: + super().__init__(skip_none) self.attribute: ColumnElement = attribute def to_orm(self, select: Select) -> Select: - if self.value is None: + if self.value is None and self.skip_none: return select return select.where(self.attribute.ilike(f"%{self.value}")) + def transform_aliases(self, value: str | None) -> str | None: + if value == "~": + value = "%" + return value + class _DagIdPatternSearch(_SearchParam): """Search on dag_id.""" - def __init__(self) -> None: - super().__init__(DagModel.dag_id) + def __init__(self, skip_none: bool = True) -> None: + super().__init__(DagModel.dag_id, skip_none) def depends(self, dag_id_pattern: str | None = None) -> _DagIdPatternSearch: + dag_id_pattern = super().transform_aliases(dag_id_pattern) return self.set_value(dag_id_pattern) class _DagDisplayNamePatternSearch(_SearchParam): """Search on dag_display_name.""" - def __init__(self) -> None: - super().__init__(DagModel.dag_display_name) + def __init__(self, skip_none: bool = True) -> None: + super().__init__(DagModel.dag_display_name, skip_none) def depends(self, dag_display_name_pattern: str | None = None) -> _DagDisplayNamePatternSearch: + dag_display_name_pattern = super().transform_aliases(dag_display_name_pattern) return self.set_value(dag_display_name_pattern) @@ -149,6 +157,9 @@ def __init__(self, allowed_attrs: list[str]) -> None: self.allowed_attrs = allowed_attrs def to_orm(self, select: Select) -> Select: + if self.skip_none is False: + raise ValueError(f"Cannot set 'skip_none' to False on a {type(self)}") + if self.value is None: return select @@ -165,6 +176,10 @@ def to_orm(self, select: Select) -> Select: # MySQL does not support `nullslast`, and True/False ordering depends on the # database implementation. nullscheck = case((column.isnot(None), 0), else_=1) + + # Reset default sorting + select = select.order_by(None) + if self.value[0] == "-": return select.order_by(nullscheck, column.desc(), DagModel.dag_id.desc()) else: @@ -178,6 +193,9 @@ class _TagsFilter(BaseParam[List[str]]): """Filter on tags.""" def to_orm(self, select: Select) -> Select: + if self.skip_none is False: + raise ValueError(f"Cannot set 'skip_none' to False on a {type(self)}") + if not self.value: return select @@ -192,6 +210,9 @@ class _OwnersFilter(BaseParam[List[str]]): """Filter on owners.""" def to_orm(self, select: Select) -> Select: + if self.skip_none is False: + raise ValueError(f"Cannot set 'skip_none' to False on a {type(self)}") + if not self.value: return select @@ -206,7 +227,7 @@ class _LastDagRunStateFilter(BaseParam[DagRunState]): """Filter on the state of the latest DagRun.""" def to_orm(self, select: Select) -> Select: - if self.value is None: + if self.value is None and self.skip_none: return select return select.where(DagRun.state == self.value) @@ -223,6 +244,9 @@ def depends(self, last_dag_run_state: DagRunState | None = None) -> _LastDagRunS QueryDagDisplayNamePatternSearch = Annotated[ _DagDisplayNamePatternSearch, Depends(_DagDisplayNamePatternSearch().depends) ] +QueryDagIdPatternSearchWithNone = Annotated[ + _DagIdPatternSearch, Depends(_DagIdPatternSearch(skip_none=False).depends) +] QueryTagsFilter = Annotated[_TagsFilter, Depends(_TagsFilter().depends)] QueryOwnersFilter = Annotated[_OwnersFilter, Depends(_OwnersFilter().depends)] # DagRun diff --git a/airflow/api_fastapi/views/public/dags.py b/airflow/api_fastapi/views/public/dags.py index a9fe87eef095..a6c25d6568c1 100644 --- a/airflow/api_fastapi/views/public/dags.py +++ b/airflow/api_fastapi/views/public/dags.py @@ -18,15 +18,20 @@ from __future__ import annotations from fastapi import APIRouter, Depends, HTTPException, Query -from sqlalchemy import select +from sqlalchemy import update from sqlalchemy.orm import Session from typing_extensions import Annotated -from airflow.api_fastapi.db import apply_filters_to_select, get_session, latest_dag_run_per_dag_id_cte +from airflow.api_fastapi.db.common import ( + get_session, + paginated_select, +) +from airflow.api_fastapi.db.dags import dags_select_with_latest_dag_run from airflow.api_fastapi.openapi.exceptions import create_openapi_http_exception_doc from airflow.api_fastapi.parameters import ( QueryDagDisplayNamePatternSearch, QueryDagIdPatternSearch, + QueryDagIdPatternSearchWithNone, QueryLastDagRunStateFilter, QueryLimit, QueryOffset, @@ -38,8 +43,6 @@ ) from airflow.api_fastapi.serializers.dags import DAGCollectionResponse, DAGPatchBody, DAGResponse from airflow.models import DagModel -from airflow.models.dagrun import DagRun -from airflow.utils.db import get_query_count dags_router = APIRouter(tags=["DAG"]) @@ -66,35 +69,16 @@ async def get_dags( session: Annotated[Session, Depends(get_session)], ) -> DAGCollectionResponse: """Get all DAGs.""" - dags_query = ( - select(DagModel) - .join( - latest_dag_run_per_dag_id_cte, - DagModel.dag_id == latest_dag_run_per_dag_id_cte.c.dag_id, - isouter=True, - ) - .join( - DagRun, - DagRun.start_date == latest_dag_run_per_dag_id_cte.c.start_date - and DagRun.dag_id == latest_dag_run_per_dag_id_cte.c.dag_id, - isouter=True, - ) - ) - - dags_query = apply_filters_to_select( - dags_query, + dags_select, total_entries = paginated_select( + dags_select_with_latest_dag_run, [only_active, paused, dag_id_pattern, dag_display_name_pattern, tags, owners, last_dag_run_state], + order_by, + offset, + limit, + session, ) - # TODO: Re-enable when permissions are handled. - # readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user) - # dags_query = dags_query.where(DagModel.dag_id.in_(readable_dags)) - - total_entries = get_query_count(dags_query, session=session) - - dags_query = apply_filters_to_select(dags_query, [order_by, offset, limit]) - - dags = session.scalars(dags_query).all() + dags = session.scalars(dags_select).all() return DAGCollectionResponse( dags=[DAGResponse.model_validate(dag, from_attributes=True) for dag in dags], @@ -109,7 +93,7 @@ async def patch_dag( session: Annotated[Session, Depends(get_session)], update_mask: list[str] | None = Query(None), ) -> DAGResponse: - """Update the specific DAG.""" + """Patch the specific DAG.""" dag = session.get(DagModel, dag_id) if dag is None: @@ -127,3 +111,50 @@ async def patch_dag( setattr(dag, attr_name, attr_value) return DAGResponse.model_validate(dag, from_attributes=True) + + +@dags_router.patch("/dags", responses=create_openapi_http_exception_doc([400, 401, 403, 404])) +async def patch_dags( + patch_body: DAGPatchBody, + limit: QueryLimit, + offset: QueryOffset, + tags: QueryTagsFilter, + owners: QueryOwnersFilter, + dag_id_pattern: QueryDagIdPatternSearchWithNone, + only_active: QueryOnlyActiveFilter, + paused: QueryPausedFilter, + last_dag_run_state: QueryLastDagRunStateFilter, + session: Annotated[Session, Depends(get_session)], + update_mask: list[str] | None = Query(None), +) -> DAGCollectionResponse: + """Patch multiple DAGs.""" + if update_mask: + if update_mask != ["is_paused"]: + raise HTTPException(400, "Only `is_paused` field can be updated through the REST API") + else: + update_mask = ["is_paused"] + + dags_select, total_entries = paginated_select( + dags_select_with_latest_dag_run, + [only_active, paused, dag_id_pattern, tags, owners, last_dag_run_state], + None, + offset, + limit, + session, + ) + + dags = session.scalars(dags_select).all() + + dags_to_update = {dag.dag_id for dag in dags} + + session.execute( + update(DagModel) + .where(DagModel.dag_id.in_(dags_to_update)) + .values(is_paused=patch_body.is_paused) + .execution_options(synchronize_session="fetch") + ) + + return DAGCollectionResponse( + dags=[DAGResponse.model_validate(dag, from_attributes=True) for dag in dags], + total_entries=total_entries, + ) diff --git a/airflow/api_fastapi/views/ui/assets.py b/airflow/api_fastapi/views/ui/assets.py index 458d531facf6..739c7d64af43 100644 --- a/airflow/api_fastapi/views/ui/assets.py +++ b/airflow/api_fastapi/views/ui/assets.py @@ -22,7 +22,7 @@ from sqlalchemy.orm import Session from typing_extensions import Annotated -from airflow.api_fastapi.db import get_session +from airflow.api_fastapi.db.common import get_session from airflow.models import DagModel from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel, DagScheduleAssetReference diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index 46694939ed74..b1508c86c0c4 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -76,6 +76,9 @@ export const UseDagServiceGetDagsPublicDagsGetKeyFn = ( }, ]), ]; +export type DagServicePatchDagsPublicDagsPatchMutationResult = Awaited< + ReturnType +>; export type DagServicePatchDagPublicDagsDagIdPatchMutationResult = Awaited< ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index 7cbaac5b2c77..5eda2a3d0e4d 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -118,9 +118,95 @@ export const useDagServiceGetDagsPublicDagsGet = < }) as TData, ...options, }); +/** + * Patch Dags + * Patch multiple DAGs. + * @param data The data for the request. + * @param data.requestBody + * @param data.updateMask + * @param data.limit + * @param data.offset + * @param data.tags + * @param data.owners + * @param data.dagIdPattern + * @param data.onlyActive + * @param data.paused + * @param data.lastDagRunState + * @returns DAGCollectionResponse Successful Response + * @throws ApiError + */ +export const useDagServicePatchDagsPublicDagsPatch = < + TData = Common.DagServicePatchDagsPublicDagsPatchMutationResult, + TError = unknown, + TContext = unknown, +>( + options?: Omit< + UseMutationOptions< + TData, + TError, + { + dagIdPattern?: string; + lastDagRunState?: DagRunState; + limit?: number; + offset?: number; + onlyActive?: boolean; + owners?: string[]; + paused?: boolean; + requestBody: DAGPatchBody; + tags?: string[]; + updateMask?: string[]; + }, + TContext + >, + "mutationFn" + >, +) => + useMutation< + TData, + TError, + { + dagIdPattern?: string; + lastDagRunState?: DagRunState; + limit?: number; + offset?: number; + onlyActive?: boolean; + owners?: string[]; + paused?: boolean; + requestBody: DAGPatchBody; + tags?: string[]; + updateMask?: string[]; + }, + TContext + >({ + mutationFn: ({ + dagIdPattern, + lastDagRunState, + limit, + offset, + onlyActive, + owners, + paused, + requestBody, + tags, + updateMask, + }) => + DagService.patchDagsPublicDagsPatch({ + dagIdPattern, + lastDagRunState, + limit, + offset, + onlyActive, + owners, + paused, + requestBody, + tags, + updateMask, + }) as unknown as Promise, + ...options, + }); /** * Patch Dag - * Update the specific DAG. + * Patch the specific DAG. * @param data The data for the request. * @param data.dagId * @param data.requestBody diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 5aa5876d112a..7fb6306afbc6 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -7,6 +7,8 @@ import type { NextRunAssetsUiNextRunDatasetsDagIdGetResponse, GetDagsPublicDagsGetData, GetDagsPublicDagsGetResponse, + PatchDagsPublicDagsPatchData, + PatchDagsPublicDagsPatchResponse, PatchDagPublicDagsDagIdPatchData, PatchDagPublicDagsDagIdPatchResponse, } from "./types.gen"; @@ -77,9 +79,55 @@ export class DagService { }); } + /** + * Patch Dags + * Patch multiple DAGs. + * @param data The data for the request. + * @param data.requestBody + * @param data.updateMask + * @param data.limit + * @param data.offset + * @param data.tags + * @param data.owners + * @param data.dagIdPattern + * @param data.onlyActive + * @param data.paused + * @param data.lastDagRunState + * @returns DAGCollectionResponse Successful Response + * @throws ApiError + */ + public static patchDagsPublicDagsPatch( + data: PatchDagsPublicDagsPatchData, + ): CancelablePromise { + return __request(OpenAPI, { + method: "PATCH", + url: "/public/dags", + query: { + update_mask: data.updateMask, + limit: data.limit, + offset: data.offset, + tags: data.tags, + owners: data.owners, + dag_id_pattern: data.dagIdPattern, + only_active: data.onlyActive, + paused: data.paused, + last_dag_run_state: data.lastDagRunState, + }, + body: data.requestBody, + mediaType: "application/json", + errors: { + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Not Found", + 422: "Validation Error", + }, + }); + } + /** * Patch Dag - * Update the specific DAG. + * Patch the specific DAG. * @param data The data for the request. * @param data.dagId * @param data.requestBody diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index bc455f63b644..0fe7134ba8c3 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -111,6 +111,21 @@ export type GetDagsPublicDagsGetData = { export type GetDagsPublicDagsGetResponse = DAGCollectionResponse; +export type PatchDagsPublicDagsPatchData = { + dagIdPattern?: string | null; + lastDagRunState?: DagRunState | null; + limit?: number; + offset?: number; + onlyActive?: boolean; + owners?: Array; + paused?: boolean | null; + requestBody: DAGPatchBody; + tags?: Array; + updateMask?: Array | null; +}; + +export type PatchDagsPublicDagsPatchResponse = DAGCollectionResponse; + export type PatchDagPublicDagsDagIdPatchData = { dagId: string; requestBody: DAGPatchBody; @@ -151,6 +166,35 @@ export type $OpenApiTs = { 422: HTTPValidationError; }; }; + patch: { + req: PatchDagsPublicDagsPatchData; + res: { + /** + * Successful Response + */ + 200: DAGCollectionResponse; + /** + * Bad Request + */ + 400: HTTPExceptionResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; }; "/public/dags/{dag_id}": { patch: { diff --git a/tests/api_fastapi/views/public/test_dags.py b/tests/api_fastapi/views/public/test_dags.py index 6e400f11cc0d..7b68ebe512a2 100644 --- a/tests/api_fastapi/views/public/test_dags.py +++ b/tests/api_fastapi/views/public/test_dags.py @@ -112,37 +112,37 @@ def setup(dag_maker) -> None: "query_params, expected_total_entries, expected_ids", [ # Filters - ({}, 2, ["test_dag1", "test_dag2"]), - ({"limit": 1}, 2, ["test_dag1"]), - ({"offset": 1}, 2, ["test_dag2"]), - ({"tags": ["example"]}, 1, ["test_dag1"]), - ({"only_active": False}, 3, ["test_dag1", "test_dag2", "test_dag3"]), - ({"paused": True, "only_active": False}, 1, ["test_dag3"]), - ({"paused": False}, 2, ["test_dag1", "test_dag2"]), - ({"owners": ["airflow"]}, 2, ["test_dag1", "test_dag2"]), - ({"owners": ["test_owner"], "only_active": False}, 1, ["test_dag3"]), - ({"last_dag_run_state": "success", "only_active": False}, 1, ["test_dag3"]), - ({"last_dag_run_state": "failed", "only_active": False}, 1, ["test_dag1"]), + ({}, 2, [DAG1_ID, DAG2_ID]), + ({"limit": 1}, 2, [DAG1_ID]), + ({"offset": 1}, 2, [DAG2_ID]), + ({"tags": ["example"]}, 1, [DAG1_ID]), + ({"only_active": False}, 3, [DAG1_ID, DAG2_ID, DAG3_ID]), + ({"paused": True, "only_active": False}, 1, [DAG3_ID]), + ({"paused": False}, 2, [DAG1_ID, DAG2_ID]), + ({"owners": ["airflow"]}, 2, [DAG1_ID, DAG2_ID]), + ({"owners": ["test_owner"], "only_active": False}, 1, [DAG3_ID]), + ({"last_dag_run_state": "success", "only_active": False}, 1, [DAG3_ID]), + ({"last_dag_run_state": "failed", "only_active": False}, 1, [DAG1_ID]), # # Sort - ({"order_by": "-dag_id"}, 2, ["test_dag2", "test_dag1"]), - ({"order_by": "-dag_display_name"}, 2, ["test_dag2", "test_dag1"]), - ({"order_by": "dag_display_name"}, 2, ["test_dag1", "test_dag2"]), - ({"order_by": "next_dagrun", "only_active": False}, 3, ["test_dag3", "test_dag1", "test_dag2"]), - ({"order_by": "last_run_state", "only_active": False}, 3, ["test_dag1", "test_dag3", "test_dag2"]), - ({"order_by": "-last_run_state", "only_active": False}, 3, ["test_dag3", "test_dag1", "test_dag2"]), + ({"order_by": "-dag_id"}, 2, [DAG2_ID, DAG1_ID]), + ({"order_by": "-dag_display_name"}, 2, [DAG2_ID, DAG1_ID]), + ({"order_by": "dag_display_name"}, 2, [DAG1_ID, DAG2_ID]), + ({"order_by": "next_dagrun", "only_active": False}, 3, [DAG3_ID, DAG1_ID, DAG2_ID]), + ({"order_by": "last_run_state", "only_active": False}, 3, [DAG1_ID, DAG3_ID, DAG2_ID]), + ({"order_by": "-last_run_state", "only_active": False}, 3, [DAG3_ID, DAG1_ID, DAG2_ID]), ( {"order_by": "last_run_start_date", "only_active": False}, 3, - ["test_dag1", "test_dag3", "test_dag2"], + [DAG1_ID, DAG3_ID, DAG2_ID], ), ( {"order_by": "-last_run_start_date", "only_active": False}, 3, - ["test_dag3", "test_dag1", "test_dag2"], + [DAG3_ID, DAG1_ID, DAG2_ID], ), # Search - ({"dag_id_pattern": "1"}, 1, ["test_dag1"]), - ({"dag_display_name_pattern": "display2"}, 1, ["test_dag2"]), + ({"dag_id_pattern": "1"}, 1, [DAG1_ID]), + ({"dag_display_name_pattern": "display2"}, 1, [DAG2_ID]), ], ) def test_get_dags(test_client, query_params, expected_total_entries, expected_ids): @@ -173,3 +173,55 @@ def test_patch_dag(test_client, query_params, dag_id, body, expected_status_code if expected_status_code == 200: body = response.json() assert body["is_paused"] == expected_is_paused + + +@pytest.mark.parametrize( + "query_params, body, expected_status_code, expected_ids, expected_paused_ids", + [ + ({"update_mask": ["field_1", "is_paused"]}, {"is_paused": True}, 400, None, None), + ( + {"only_active": False}, + {"is_paused": True}, + 200, + [], + [], + ), # no-op because the dag_id_pattern is not provided + ( + {"only_active": False, "dag_id_pattern": "~"}, + {"is_paused": True}, + 200, + [DAG1_ID, DAG2_ID, DAG3_ID], + [DAG1_ID, DAG2_ID, DAG3_ID], + ), + ( + {"only_active": False, "dag_id_pattern": "~"}, + {"is_paused": False}, + 200, + [DAG1_ID, DAG2_ID, DAG3_ID], + [], + ), + ( + {"dag_id_pattern": "~"}, + {"is_paused": True}, + 200, + [DAG1_ID, DAG2_ID], + [DAG1_ID, DAG2_ID], + ), + ( + {"dag_id_pattern": "dag1"}, + {"is_paused": True}, + 200, + [DAG1_ID], + [DAG1_ID], + ), + ], +) +def test_patch_dags(test_client, query_params, body, expected_status_code, expected_ids, expected_paused_ids): + response = test_client.patch("/public/dags", json=body, params=query_params) + + assert response.status_code == expected_status_code + if expected_status_code == 200: + body = response.json() + assert [dag["dag_id"] for dag in body["dags"]] == expected_ids + paused_dag_ids = [dag["dag_id"] for dag in body["dags"] if dag["is_paused"]] + assert paused_dag_ids == expected_paused_ids