From 56faf4806bced50d9786e35697c7e9cb0f3ce7d3 Mon Sep 17 00:00:00 2001 From: Omkar P <45419097+omkar-foss@users.noreply.github.com> Date: Mon, 28 Oct 2024 22:40:45 +0530 Subject: [PATCH] Use pydantic model constructors, minor refactors --- airflow/api_fastapi/common/db/common.py | 2 +- airflow/api_fastapi/common/parameters.py | 5 ++++- .../core_api/routes/public/dag_stats.py | 20 ++++++++++--------- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/airflow/api_fastapi/common/db/common.py b/airflow/api_fastapi/common/db/common.py index b9b90fb6f7434..01e1fe532bf60 100644 --- a/airflow/api_fastapi/common/db/common.py +++ b/airflow/api_fastapi/common/db/common.py @@ -65,7 +65,7 @@ def paginated_select( offset: BaseParam | None = None, limit: BaseParam | None = None, session: Session = NEW_SESSION, - return_total_entries: bool = True + return_total_entries: bool = True, ) -> Select: base_select = apply_filters_to_select( base_select, diff --git a/airflow/api_fastapi/common/parameters.py b/airflow/api_fastapi/common/parameters.py index a441ab37bdc4d..185182d42a5ec 100644 --- a/airflow/api_fastapi/common/parameters.py +++ b/airflow/api_fastapi/common/parameters.py @@ -306,6 +306,7 @@ def _safe_parse_datetime(date_to_check: str) -> datetime: # Common Safe DateTime DateTimeQuery = Annotated[str, AfterValidator(_safe_parse_datetime)] + # DAG QueryLimit = Annotated[_LimitFilter, Depends(_LimitFilter().depends)] QueryOffset = Annotated[_OffsetFilter, Depends(_OffsetFilter().depends)] @@ -320,8 +321,10 @@ def _safe_parse_datetime(date_to_check: str) -> datetime: ] QueryTagsFilter = Annotated[_TagsFilter, Depends(_TagsFilter().depends)] QueryOwnersFilter = Annotated[_OwnersFilter, Depends(_OwnersFilter().depends)] + # DagRun QueryLastDagRunStateFilter = Annotated[_LastDagRunStateFilter, Depends(_LastDagRunStateFilter().depends)] +QueryDagIdsFilter = Annotated[_DagIdsFilter, Depends(_DagIdsFilter().depends)] + # DAGTags QueryDagTagPatternSearch = Annotated[_DagTagNamePatternSearch, Depends(_DagTagNamePatternSearch().depends)] -QueryDagIdsFilter = Annotated[_DagIdsFilter, Depends(_DagIdsFilter().depends)] diff --git a/airflow/api_fastapi/core_api/routes/public/dag_stats.py b/airflow/api_fastapi/core_api/routes/public/dag_stats.py index 9b470a639c153..20b4e2fa8db08 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_stats.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_stats.py @@ -31,6 +31,8 @@ from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc from airflow.api_fastapi.core_api.serializers.dag_stats import ( DagStatsCollectionResponse, + DagStatsResponse, + DagStatsStateResponse, ) from airflow.utils.state import DagRunState @@ -50,7 +52,7 @@ async def get_dag_stats( base_select=dagruns_select_with_state_count, filters=[dag_ids], session=session, - return_total_entries=False + return_total_entries=False, ) query_result = session.execute(dagruns_select) @@ -62,16 +64,16 @@ async def get_dag_stats( result_dag_ids.append(dag_id) dags = [ - { - "dag_id": dag_id, - "stats": [ - { - "state": state, - "count": dag_state_data.get((dag_id, state), 0), - } + DagStatsResponse( + dag_id=dag_id, + stats=[ + DagStatsStateResponse( + state=state, + count=dag_state_data.get((dag_id, state), 0), + ) for state in DagRunState ], - } + ) for dag_id in result_dag_ids ] return DagStatsCollectionResponse(dags=dags, total_entries=len(dags))