Skip to content

Commit

Permalink
AIP-84 Migrate get connections to FastAPI API apache#42571 (apache#42782
Browse files Browse the repository at this point in the history
)

* Make SortParam parent for Model Specific SortParams, Include get connections endpoint to fastapi

* Change depends() method regular method in SortParam due to parent class already have abstract

* Remove subclass, get default order_by from primary key, change alias strategy for backcompat

* pre-commit hooks

* Dynamic return value of SortParam generated within openapi specs and removed unnecessary attribute mapping keys

* Include connection_id to attr_mapping again

* Dynamic depends with correct documentation

* Add more tests

---------

Co-authored-by: pierrejeambrun <[email protected]>
  • Loading branch information
2 people authored and PaulKobow7536 committed Oct 24, 2024
1 parent 0346615 commit 8426509
Show file tree
Hide file tree
Showing 14 changed files with 466 additions and 24 deletions.
1 change: 1 addition & 0 deletions airflow/api_connexion/endpoints/connection_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> API
@security.requires_access_connection("GET")
@format_parameters({"limit": check_limit})
@provide_session
@mark_fastapi_migration_done
def get_connections(
*,
limit: int,
Expand Down
82 changes: 79 additions & 3 deletions airflow/api_fastapi/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,66 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/connections/:
get:
tags:
- Connection
summary: Get Connections
description: Get all connection entries.
operationId: get_connections
parameters:
- name: limit
in: query
required: false
schema:
type: integer
default: 100
title: Limit
- name: offset
in: query
required: false
schema:
type: integer
default: 0
title: Offset
- name: order_by
in: query
required: false
schema:
type: string
default: id
title: Order By
responses:
'200':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/ConnectionCollectionResponse'
'401':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Unauthorized
'403':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Forbidden
'404':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Not Found
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/variables/{variable_key}:
delete:
tags:
Expand Down Expand Up @@ -886,11 +946,27 @@ paths:
$ref: '#/components/schemas/HTTPValidationError'
components:
schemas:
ConnectionCollectionResponse:
properties:
connections:
items:
$ref: '#/components/schemas/ConnectionResponse'
type: array
title: Connections
total_entries:
type: integer
title: Total Entries
type: object
required:
- connections
- total_entries
title: ConnectionCollectionResponse
description: DAG Collection serializer for responses.
ConnectionResponse:
properties:
conn_id:
connection_id:
type: string
title: Conn Id
title: Connection Id
conn_type:
type: string
title: Conn Type
Expand Down Expand Up @@ -926,7 +1002,7 @@ components:
title: Extra
type: object
required:
- conn_id
- connection_id
- conn_type
- description
- host
Expand Down
48 changes: 40 additions & 8 deletions airflow/api_fastapi/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@

from __future__ import annotations

import importlib
from abc import ABC, abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Any, Generic, List, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Generic, List, TypeVar

from fastapi import Depends, HTTPException, Query
from pendulum.parsing.exceptions import ParserError
from pydantic import AfterValidator
from sqlalchemy import case, or_
from sqlalchemy import Column, case, or_
from sqlalchemy.inspection import inspect
from typing_extensions import Annotated, Self

from airflow.models import Base, Connection
from airflow.models.dag import DagModel, DagTag
from airflow.models.dagrun import DagRun
from airflow.utils import timezone
Expand Down Expand Up @@ -154,11 +157,17 @@ class SortParam(BaseParam[str]):
attr_mapping = {
"last_run_state": DagRun.state,
"last_run_start_date": DagRun.start_date,
"connection_id": Connection.conn_id,
}

def __init__(self, allowed_attrs: list[str]) -> None:
def __init__(
self,
allowed_attrs: list[str],
model: Base,
) -> None:
super().__init__()
self.allowed_attrs = allowed_attrs
self.model = model

def to_orm(self, select: Select) -> Select:
if self.skip_none is False:
Expand All @@ -175,7 +184,9 @@ def to_orm(self, select: Select) -> Select:
f"the attribute does not exist on the model",
)

column = self.attr_mapping.get(lstriped_orderby, None) or getattr(DagModel, lstriped_orderby)
column: Column = self.attr_mapping.get(lstriped_orderby, None) or getattr(
self.model, lstriped_orderby
)

# MySQL does not support `nullslast`, and True/False ordering depends on the
# database implementation.
Expand All @@ -185,12 +196,33 @@ def to_orm(self, select: Select) -> Select:
select = select.order_by(None)

if self.value[0] == "-":
return select.order_by(nullscheck, column.desc(), DagModel.dag_id.desc())
return select.order_by(nullscheck, column.desc(), column.desc())
else:
return select.order_by(nullscheck, column.asc(), DagModel.dag_id.asc())
return select.order_by(nullscheck, column.asc(), column.asc())

def get_primary_key(self) -> str:
"""Get the primary key of the model of SortParam object."""
return inspect(self.model).primary_key[0].name

@staticmethod
def get_primary_key_of_given_model_string(model_string: str) -> str:
"""
Get the primary key of given 'airflow.models' class as a string. The class should have driven be from 'airflow.models.base'.
:param model_string: The string representation of the model class.
:return: The primary key of the model class.
"""
dynamic_return_model = getattr(importlib.import_module("airflow.models"), model_string)
return inspect(dynamic_return_model).primary_key[0].name

def depends(self, *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError("Use dynamic_depends, depends not implemented.")

def dynamic_depends(self) -> Callable:
def inner(order_by: str = self.get_primary_key()) -> SortParam:
return self.set_value(self.get_primary_key() if order_by == "" else order_by)

def depends(self, order_by: str = "dag_id") -> SortParam:
return self.set_value(order_by)
return inner


class _TagsFilter(BaseParam[List[str]]):
Expand Down
9 changes: 8 additions & 1 deletion airflow/api_fastapi/serializers/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
class ConnectionResponse(BaseModel):
"""Connection serializer for responses."""

connection_id: str = Field(alias="conn_id")
connection_id: str = Field(serialization_alias="connection_id", validation_alias="conn_id")
conn_type: str
description: str | None
host: str | None
Expand All @@ -48,3 +48,10 @@ def redact_extra(cls, v: str | None) -> str | None:
except json.JSONDecodeError:
# we can't redact fields in an unstructured `extra`
return v


class ConnectionCollectionResponse(BaseModel):
"""DAG Collection serializer for responses."""

connections: list[ConnectionResponse]
total_entries: int
42 changes: 40 additions & 2 deletions airflow/api_fastapi/views/public/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
from sqlalchemy.orm import Session
from typing_extensions import Annotated

from airflow.api_fastapi.db.common import get_session
from airflow.api_fastapi.db.common import get_session, paginated_select
from airflow.api_fastapi.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.serializers.connections import ConnectionResponse
from airflow.api_fastapi.parameters import QueryLimit, QueryOffset, SortParam
from airflow.api_fastapi.serializers.connections import ConnectionCollectionResponse, ConnectionResponse
from airflow.api_fastapi.views.router import AirflowRouter
from airflow.models import Connection

Expand Down Expand Up @@ -63,3 +64,40 @@ async def get_connection(
raise HTTPException(404, f"The Connection with connection_id: `{connection_id}` was not found")

return ConnectionResponse.model_validate(connection, from_attributes=True)


@connections_router.get(
"/",
responses=create_openapi_http_exception_doc([401, 403, 404]),
)
async def get_connections(
limit: QueryLimit,
offset: QueryOffset,
order_by: Annotated[
SortParam,
Depends(
SortParam(
["connection_id", "conn_type", "description", "host", "port", "id"], Connection
).dynamic_depends()
),
],
session: Annotated[Session, Depends(get_session)],
) -> ConnectionCollectionResponse:
"""Get all connection entries."""
connection_select, total_entries = paginated_select(
select(Connection),
[],
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)

connections = session.scalars(connection_select).all()

return ConnectionCollectionResponse(
connections=[
ConnectionResponse.model_validate(connection, from_attributes=True) for connection in connections
],
total_entries=total_entries,
)
5 changes: 3 additions & 2 deletions airflow/api_fastapi/views/public/dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ async def get_dags(
SortParam,
Depends(
SortParam(
["dag_id", "dag_display_name", "next_dagrun", "last_run_state", "last_run_start_date"]
).depends
["dag_id", "dag_display_name", "next_dagrun", "last_run_state", "last_run_start_date"],
DagModel,
).dynamic_depends()
),
],
session: Annotated[Session, Depends(get_session)],
Expand Down
24 changes: 24 additions & 0 deletions airflow/ui/openapi-gen/queries/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,30 @@ export const UseConnectionServiceGetConnectionKeyFn = (
useConnectionServiceGetConnectionKey,
...(queryKey ?? [{ connectionId }]),
];
export type ConnectionServiceGetConnectionsDefaultResponse = Awaited<
ReturnType<typeof ConnectionService.getConnections>
>;
export type ConnectionServiceGetConnectionsQueryResult<
TData = ConnectionServiceGetConnectionsDefaultResponse,
TError = unknown,
> = UseQueryResult<TData, TError>;
export const useConnectionServiceGetConnectionsKey =
"ConnectionServiceGetConnections";
export const UseConnectionServiceGetConnectionsKeyFn = (
{
limit,
offset,
orderBy,
}: {
limit?: number;
offset?: number;
orderBy?: string;
} = {},
queryKey?: Array<unknown>,
) => [
useConnectionServiceGetConnectionsKey,
...(queryKey ?? [{ limit, offset, orderBy }]),
];
export type VariableServiceGetVariableDefaultResponse = Awaited<
ReturnType<typeof VariableService.getVariable>
>;
Expand Down
30 changes: 30 additions & 0 deletions airflow/ui/openapi-gen/queries/prefetch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,36 @@ export const prefetchUseConnectionServiceGetConnection = (
queryKey: Common.UseConnectionServiceGetConnectionKeyFn({ connectionId }),
queryFn: () => ConnectionService.getConnection({ connectionId }),
});
/**
* Get Connections
* Get all connection entries.
* @param data The data for the request.
* @param data.limit
* @param data.offset
* @param data.orderBy
* @returns ConnectionCollectionResponse Successful Response
* @throws ApiError
*/
export const prefetchUseConnectionServiceGetConnections = (
queryClient: QueryClient,
{
limit,
offset,
orderBy,
}: {
limit?: number;
offset?: number;
orderBy?: string;
} = {},
) =>
queryClient.prefetchQuery({
queryKey: Common.UseConnectionServiceGetConnectionsKeyFn({
limit,
offset,
orderBy,
}),
queryFn: () => ConnectionService.getConnections({ limit, offset, orderBy }),
});
/**
* Get Variable
* Get a variable entry.
Expand Down
Loading

0 comments on commit 8426509

Please sign in to comment.