From 1f16c57faf0f10edfaeb51b2d8c0046a8843023d Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Mon, 30 Sep 2024 12:58:31 -0700 Subject: [PATCH 01/10] Speed up boring cyborg consistency pre-commit check (#42589) This is typically the slowest pre-commit besides mypy, and it runs every time. Previously it loaded all filenames into memory and ran glob filter on that. It seems faster to apply glob against the file system directly. This makes pre-commit much faster. Previously took around 4 seconds, now about a half a second. --- scripts/ci/pre_commit/boring_cyborg.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/scripts/ci/pre_commit/boring_cyborg.py b/scripts/ci/pre_commit/boring_cyborg.py index cf852b12bb6d..ec674485b545 100755 --- a/scripts/ci/pre_commit/boring_cyborg.py +++ b/scripts/ci/pre_commit/boring_cyborg.py @@ -17,13 +17,11 @@ # under the License. from __future__ import annotations -import subprocess import sys from pathlib import Path import yaml from termcolor import colored -from wcmatch import glob if __name__ not in ("__main__", "__mp_main__"): raise SystemExit( @@ -33,9 +31,8 @@ CONFIG_KEY = "labelPRBasedOnFilePath" -current_files = subprocess.check_output(["git", "ls-files"]).decode().splitlines() -git_root = Path(subprocess.check_output(["git", "rev-parse", "--show-toplevel"]).decode().strip()) -cyborg_config_path = git_root / ".github" / "boring-cyborg.yml" +repo_root = Path(__file__).parent.parent.parent.parent +cyborg_config_path = repo_root / ".github" / "boring-cyborg.yml" cyborg_config = yaml.safe_load(cyborg_config_path.read_text()) if CONFIG_KEY not in cyborg_config: raise SystemExit(f"Missing section {CONFIG_KEY}") @@ -43,12 +40,14 @@ errors = [] for label, patterns in cyborg_config[CONFIG_KEY].items(): for pattern in patterns: - if glob.globfilter(current_files, pattern, flags=glob.G | glob.E): + try: + next(Path(repo_root).glob(pattern)) continue - yaml_path = f"{CONFIG_KEY}.{label}" - errors.append( - f"Unused pattern [{colored(pattern, 'cyan')}] in [{colored(yaml_path, 'cyan')}] section." - ) + except StopIteration: + yaml_path = f"{CONFIG_KEY}.{label}" + errors.append( + f"Unused pattern [{colored(pattern, 'cyan')}] in [{colored(yaml_path, 'cyan')}] section." + ) if errors: print(f"Found {colored(str(len(errors)), 'red')} problems:") From f4f38f15fe4e6c32614686ceb3b5991dd91ef7ee Mon Sep 17 00:00:00 2001 From: JISHAN GARGACHARYA <34843832+jishangarg@users.noreply.github.com> Date: Tue, 1 Oct 2024 12:02:59 +0530 Subject: [PATCH 02/10] Doc update - Airflow local settings no longer importable from dags folder (#42231) --------- Co-authored-by: Jishan Garg --- docs/apache-airflow/howto/set-config.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/apache-airflow/howto/set-config.rst b/docs/apache-airflow/howto/set-config.rst index 4f19159a810d..2a03b2bbf5ee 100644 --- a/docs/apache-airflow/howto/set-config.rst +++ b/docs/apache-airflow/howto/set-config.rst @@ -179,6 +179,8 @@ where you can configure such local settings - This is usually done in the ``airf You should create a ``airflow_local_settings.py`` file and put it in a directory in ``sys.path`` or in the ``$AIRFLOW_HOME/config`` folder. (Airflow adds ``$AIRFLOW_HOME/config`` to ``sys.path`` when Airflow is initialized) +Starting from Airflow 2.10.1, the $AIRFLOW_HOME/dags folder is no longer included in sys.path at initialization, so any local settings in that folder will not be imported. 
Ensure that airflow_local_settings.py is located in a path that is part of sys.path during initialization, like $AIRFLOW_HOME/config. +For more context about this change, see the `mailing list announcement `_. You can see the example of such local settings here: From 9b29394a309895b63f7b60ba11789d52c62e288e Mon Sep 17 00:00:00 2001 From: Dewen Kong Date: Tue, 1 Oct 2024 02:40:33 -0400 Subject: [PATCH 03/10] add flexibility for redis service (#41811) * add service type options for redis * additional value * update based on testing * fix syntax * update description * Update chart/templates/redis/redis-service.yaml Co-authored-by: rom sharon <33751805+romsharon98@users.noreply.github.com> * Update chart/templates/redis/redis-service.yaml Co-authored-by: rom sharon <33751805+romsharon98@users.noreply.github.com> --------- Co-authored-by: rom sharon <33751805+romsharon98@users.noreply.github.com> --- chart/templates/redis/redis-service.yaml | 10 ++++++ chart/values.schema.json | 33 +++++++++++++++++++ chart/values.yaml | 8 +++++ helm_tests/other/test_redis.py | 41 ++++++++++++++++++++++++ 4 files changed, 92 insertions(+) diff --git a/chart/templates/redis/redis-service.yaml b/chart/templates/redis/redis-service.yaml index 17d4c8d5e483..ee010901ef84 100644 --- a/chart/templates/redis/redis-service.yaml +++ b/chart/templates/redis/redis-service.yaml @@ -35,7 +35,14 @@ metadata: {{- toYaml . | nindent 4 }} {{- end }} spec: +{{- if eq .Values.redis.service.type "ClusterIP" }} type: ClusterIP + {{- if .Values.redis.service.clusterIP }} + clusterIP: {{ .Values.redis.service.clusterIP }} + {{- end }} +{{- else }} + type: {{ .Values.redis.service.type }} +{{- end }} selector: tier: airflow component: redis @@ -45,4 +52,7 @@ spec: protocol: TCP port: {{ .Values.ports.redisDB }} targetPort: {{ .Values.ports.redisDB }} + {{- if (and (eq .Values.redis.service.type "NodePort") (not (empty .Values.redis.service.nodePort))) }} + nodePort: {{ .Values.redis.service.nodePort }} + {{- end }} {{- end }} diff --git a/chart/values.schema.json b/chart/values.schema.json index 948f09f3b9a4..d8b5de41c8eb 100644 --- a/chart/values.schema.json +++ b/chart/values.schema.json @@ -7670,6 +7670,39 @@ "type": "integer", "default": 600 }, + "service": { + "description": "service configuration.", + "type": "object", + "additionalProperties": false, + "properties": { + "type": { + "description": "Service type.", + "enum": [ + "ClusterIP", + "NodePort", + "LoadBalancer" + ], + "type": "string", + "default": "ClusterIP" + }, + "clusterIP": { + "description": "If using `ClusterIP` service type, custom IP address can be specified.", + "type": [ + "string", + "null" + ], + "default": null + }, + "nodePort": { + "description": "If using `NodePort` service type, custom node port can be specified.", + "type": [ + "integer", + "null" + ], + "default": null + } + } + }, "persistence": { "description": "Persistence configuration.", "type": "object", diff --git a/chart/values.yaml b/chart/values.yaml index 7bfa733a905b..0edb9f2bd7cd 100644 --- a/chart/values.yaml +++ b/chart/values.yaml @@ -2378,6 +2378,14 @@ redis: # Annotations to add to worker kubernetes service account. 
annotations: {} + service: + # service type, default: ClusterIP + type: "ClusterIP" + # If using ClusterIP service type, custom IP address can be specified + clusterIP: + # If using NodePort service type, custom node port can be specified + nodePort: + persistence: # Enable persistent volumes enabled: true diff --git a/helm_tests/other/test_redis.py b/helm_tests/other/test_redis.py index a5a6f2099e4a..8c4456742031 100644 --- a/helm_tests/other/test_redis.py +++ b/helm_tests/other/test_redis.py @@ -452,3 +452,44 @@ def test_overridden_automount_service_account_token(self): show_only=["templates/redis/redis-serviceaccount.yaml"], ) assert jmespath.search("automountServiceAccountToken", docs[0]) is False + + +class TestRedisService: + """Tests redis service.""" + + @pytest.mark.parametrize( + "redis_values, expected", + [ + ({"redis": {"service": {"type": "ClusterIP"}}}, "ClusterIP"), + ({"redis": {"service": {"type": "NodePort"}}}, "NodePort"), + ({"redis": {"service": {"type": "LoadBalancer"}}}, "LoadBalancer"), + ], + ) + def test_redis_service_type(self, redis_values, expected): + docs = render_chart( + values=redis_values, + show_only=["templates/redis/redis-service.yaml"], + ) + assert expected == jmespath.search("spec.type", docs[0]) + + def test_redis_service_nodeport(self): + docs = render_chart( + values={ + "redis": { + "service": {"type": "NodePort", "nodePort": 11111}, + }, + }, + show_only=["templates/redis/redis-service.yaml"], + ) + assert 11111 == jmespath.search("spec.ports[0].nodePort", docs[0]) + + def test_redis_service_clusterIP(self): + docs = render_chart( + values={ + "redis": { + "service": {"type": "ClusterIP", "clusterIP": "127.0.0.1"}, + }, + }, + show_only=["templates/redis/redis-service.yaml"], + ) + assert "127.0.0.1" == jmespath.search("spec.clusterIP", docs[0]) From db06cb8e893addeef4d49479beb1f2387fa63993 Mon Sep 17 00:00:00 2001 From: Howard Yoo <32691630+howardyoo@users.noreply.github.com> Date: Tue, 1 Oct 2024 02:04:06 -0500 Subject: [PATCH 04/10] Support of host.name in OTEL metrics and usage of OTEL_RESOURCE_ATTRIBUTES in metrics (#42428) * fixes: 42425, and 42424 * fixed static type check failure --- airflow/metrics/otel_logger.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/airflow/metrics/otel_logger.py b/airflow/metrics/otel_logger.py index 14080eb2d831..6d7d6e8fffa1 100644 --- a/airflow/metrics/otel_logger.py +++ b/airflow/metrics/otel_logger.py @@ -28,7 +28,7 @@ from opentelemetry.metrics import Observation from opentelemetry.sdk.metrics import MeterProvider from opentelemetry.sdk.metrics._internal.export import ConsoleMetricExporter, PeriodicExportingMetricReader -from opentelemetry.sdk.resources import SERVICE_NAME, Resource +from opentelemetry.sdk.resources import HOST_NAME, SERVICE_NAME, Resource from airflow.configuration import conf from airflow.exceptions import AirflowProviderDeprecationWarning @@ -40,6 +40,7 @@ get_validator, stat_name_otel_handler, ) +from airflow.utils.net import get_hostname if TYPE_CHECKING: from opentelemetry.metrics import Instrument @@ -410,7 +411,7 @@ def get_otel_logger(cls) -> SafeOtelLogger: debug = conf.getboolean("metrics", "otel_debugging_on") service_name = conf.get("metrics", "otel_service") - resource = Resource(attributes={SERVICE_NAME: service_name}) + resource = Resource.create(attributes={HOST_NAME: get_hostname(), SERVICE_NAME: service_name}) protocol = "https" if ssl_active else "http" endpoint = f"{protocol}://{host}:{port}/v1/metrics" From 
14b63920a3e05142452e3170112a3098295cf437 Mon Sep 17 00:00:00 2001 From: Pierre Jeambrun Date: Tue, 1 Oct 2024 15:42:37 +0800 Subject: [PATCH 05/10] Update fastapi operation ids (#42588) * Update operation id automatically * Cherry pick Brent change --------- Co-authored-by: Brent Bovenzi --- airflow/api_fastapi/openapi/v1-generated.yaml | 10 +- airflow/api_fastapi/views/public/__init__.py | 5 +- airflow/api_fastapi/views/public/dags.py | 5 +- airflow/api_fastapi/views/router.py | 93 +++++++++++++++++++ airflow/api_fastapi/views/ui/__init__.py | 5 +- airflow/api_fastapi/views/ui/assets.py | 5 +- airflow/ui/openapi-gen/queries/common.ts | 44 ++++----- airflow/ui/openapi-gen/queries/prefetch.ts | 15 ++- airflow/ui/openapi-gen/queries/queries.ts | 32 +++---- airflow/ui/openapi-gen/queries/suspense.ts | 20 ++-- .../ui/openapi-gen/requests/services.gen.ts | 40 ++++---- airflow/ui/openapi-gen/requests/types.gen.ts | 24 ++--- airflow/ui/package.json | 2 +- airflow/ui/src/App.test.tsx | 7 +- airflow/ui/src/pages/DagsList.tsx | 4 +- 15 files changed, 193 insertions(+), 118 deletions(-) create mode 100644 airflow/api_fastapi/views/router.py diff --git a/airflow/api_fastapi/openapi/v1-generated.yaml b/airflow/api_fastapi/openapi/v1-generated.yaml index a38a1021890d..b08ef42c16df 100644 --- a/airflow/api_fastapi/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/openapi/v1-generated.yaml @@ -12,7 +12,7 @@ paths: tags: - Asset summary: Next Run Assets - operationId: next_run_assets_ui_next_run_datasets__dag_id__get + operationId: next_run_assets parameters: - name: dag_id in: path @@ -27,7 +27,7 @@ paths: application/json: schema: type: object - title: Response Next Run Assets Ui Next Run Datasets Dag Id Get + title: Response Next Run Assets '422': description: Validation Error content: @@ -40,7 +40,7 @@ paths: - DAG summary: Get Dags description: Get all DAGs. - operationId: get_dags_public_dags_get + operationId: get_dags parameters: - name: limit in: query @@ -136,7 +136,7 @@ paths: - DAG summary: Patch Dags description: Patch multiple DAGs. - operationId: patch_dags_public_dags_patch + operationId: patch_dags parameters: - name: update_mask in: query @@ -258,7 +258,7 @@ paths: - DAG summary: Patch Dag description: Patch the specific DAG. 
- operationId: patch_dag_public_dags__dag_id__patch + operationId: patch_dag parameters: - name: dag_id in: path diff --git a/airflow/api_fastapi/views/public/__init__.py b/airflow/api_fastapi/views/public/__init__.py index b6466536c335..1c2511fc82ac 100644 --- a/airflow/api_fastapi/views/public/__init__.py +++ b/airflow/api_fastapi/views/public/__init__.py @@ -17,11 +17,10 @@ from __future__ import annotations -from fastapi import APIRouter - from airflow.api_fastapi.views.public.dags import dags_router +from airflow.api_fastapi.views.router import AirflowRouter -public_router = APIRouter(prefix="/public") +public_router = AirflowRouter(prefix="/public") public_router.include_router(dags_router) diff --git a/airflow/api_fastapi/views/public/dags.py b/airflow/api_fastapi/views/public/dags.py index a6c25d6568c1..3761d593d2fd 100644 --- a/airflow/api_fastapi/views/public/dags.py +++ b/airflow/api_fastapi/views/public/dags.py @@ -17,7 +17,7 @@ from __future__ import annotations -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import Depends, HTTPException, Query from sqlalchemy import update from sqlalchemy.orm import Session from typing_extensions import Annotated @@ -42,9 +42,10 @@ SortParam, ) from airflow.api_fastapi.serializers.dags import DAGCollectionResponse, DAGPatchBody, DAGResponse +from airflow.api_fastapi.views.router import AirflowRouter from airflow.models import DagModel -dags_router = APIRouter(tags=["DAG"]) +dags_router = AirflowRouter(tags=["DAG"]) @dags_router.get("/dags") diff --git a/airflow/api_fastapi/views/router.py b/airflow/api_fastapi/views/router.py new file mode 100644 index 000000000000..5bf07e0fe834 --- /dev/null +++ b/airflow/api_fastapi/views/router.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from enum import Enum +from typing import Any, Callable, Sequence + +from fastapi import APIRouter, params +from fastapi.datastructures import Default +from fastapi.routing import APIRoute +from fastapi.types import DecoratedCallable, IncEx +from fastapi.utils import generate_unique_id +from starlette.responses import JSONResponse, Response +from starlette.routing import BaseRoute + + +class AirflowRouter(APIRouter): + """Extends the FastAPI default router.""" + + def api_route( + self, + path: str, + *, + response_model: Any = Default(None), + status_code: int | None = None, + tags: list[str | Enum] | None = None, + dependencies: Sequence[params.Depends] | None = None, + summary: str | None = None, + description: str | None = None, + response_description: str = "Successful Response", + responses: dict[int | str, dict[str, Any]] | None = None, + deprecated: bool | None = None, + methods: list[str] | None = None, + operation_id: str | None = None, + response_model_include: IncEx | None = None, + response_model_exclude: IncEx | None = None, + response_model_by_alias: bool = True, + response_model_exclude_unset: bool = False, + response_model_exclude_defaults: bool = False, + response_model_exclude_none: bool = False, + include_in_schema: bool = True, + response_class: type[Response] = Default(JSONResponse), + name: str | None = None, + callbacks: list[BaseRoute] | None = None, + openapi_extra: dict[str, Any] | None = None, + generate_unique_id_function: Callable[[APIRoute], str] = Default(generate_unique_id), + ) -> Callable[[DecoratedCallable], DecoratedCallable]: + def decorator(func: DecoratedCallable) -> DecoratedCallable: + self.add_api_route( + path, + func, + response_model=response_model, + status_code=status_code, + tags=tags, + dependencies=dependencies, + summary=summary, + description=description, + response_description=response_description, + responses=responses, + deprecated=deprecated, + methods=methods, + operation_id=operation_id or func.__name__, + response_model_include=response_model_include, + response_model_exclude=response_model_exclude, + response_model_by_alias=response_model_by_alias, + response_model_exclude_unset=response_model_exclude_unset, + response_model_exclude_defaults=response_model_exclude_defaults, + response_model_exclude_none=response_model_exclude_none, + include_in_schema=include_in_schema, + response_class=response_class, + name=name, + callbacks=callbacks, + openapi_extra=openapi_extra, + generate_unique_id_function=generate_unique_id_function, + ) + return func + + return decorator diff --git a/airflow/api_fastapi/views/ui/__init__.py b/airflow/api_fastapi/views/ui/__init__.py index edba930c3d1d..8495ac5e5e6a 100644 --- a/airflow/api_fastapi/views/ui/__init__.py +++ b/airflow/api_fastapi/views/ui/__init__.py @@ -16,10 +16,9 @@ # under the License. 
from __future__ import annotations -from fastapi import APIRouter - +from airflow.api_fastapi.views.router import AirflowRouter from airflow.api_fastapi.views.ui.assets import assets_router -ui_router = APIRouter(prefix="/ui") +ui_router = AirflowRouter(prefix="/ui") ui_router.include_router(assets_router) diff --git a/airflow/api_fastapi/views/ui/assets.py b/airflow/api_fastapi/views/ui/assets.py index 739c7d64af43..01cc9fd1cfbf 100644 --- a/airflow/api_fastapi/views/ui/assets.py +++ b/airflow/api_fastapi/views/ui/assets.py @@ -17,16 +17,17 @@ from __future__ import annotations -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import Depends, HTTPException, Request from sqlalchemy import and_, func, select from sqlalchemy.orm import Session from typing_extensions import Annotated from airflow.api_fastapi.db.common import get_session +from airflow.api_fastapi.views.router import AirflowRouter from airflow.models import DagModel from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel, DagScheduleAssetReference -assets_router = APIRouter(tags=["Asset"]) +assets_router = AirflowRouter(tags=["Asset"]) @assets_router.get("/next_run_datasets/{dag_id}", include_in_schema=False) diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index b1508c86c0c4..96e49cc6d767 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -4,37 +4,31 @@ import { UseQueryResult } from "@tanstack/react-query"; import { AssetService, DagService } from "../requests/services.gen"; import { DagRunState } from "../requests/types.gen"; -export type AssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetDefaultResponse = - Awaited< - ReturnType - >; -export type AssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetQueryResult< - TData = AssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetDefaultResponse, +export type AssetServiceNextRunAssetsDefaultResponse = Awaited< + ReturnType +>; +export type AssetServiceNextRunAssetsQueryResult< + TData = AssetServiceNextRunAssetsDefaultResponse, TError = unknown, > = UseQueryResult; -export const useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetKey = - "AssetServiceNextRunAssetsUiNextRunDatasetsDagIdGet"; -export const UseAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetKeyFn = ( +export const useAssetServiceNextRunAssetsKey = "AssetServiceNextRunAssets"; +export const UseAssetServiceNextRunAssetsKeyFn = ( { dagId, }: { dagId: string; }, queryKey?: Array, -) => [ - useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetKey, - ...(queryKey ?? [{ dagId }]), -]; -export type DagServiceGetDagsPublicDagsGetDefaultResponse = Awaited< - ReturnType +) => [useAssetServiceNextRunAssetsKey, ...(queryKey ?? 
[{ dagId }])]; +export type DagServiceGetDagsDefaultResponse = Awaited< + ReturnType >; -export type DagServiceGetDagsPublicDagsGetQueryResult< - TData = DagServiceGetDagsPublicDagsGetDefaultResponse, +export type DagServiceGetDagsQueryResult< + TData = DagServiceGetDagsDefaultResponse, TError = unknown, > = UseQueryResult; -export const useDagServiceGetDagsPublicDagsGetKey = - "DagServiceGetDagsPublicDagsGet"; -export const UseDagServiceGetDagsPublicDagsGetKeyFn = ( +export const useDagServiceGetDagsKey = "DagServiceGetDags"; +export const UseDagServiceGetDagsKeyFn = ( { dagDisplayNamePattern, dagIdPattern, @@ -60,7 +54,7 @@ export const UseDagServiceGetDagsPublicDagsGetKeyFn = ( } = {}, queryKey?: Array, ) => [ - useDagServiceGetDagsPublicDagsGetKey, + useDagServiceGetDagsKey, ...(queryKey ?? [ { dagDisplayNamePattern, @@ -76,9 +70,9 @@ export const UseDagServiceGetDagsPublicDagsGetKeyFn = ( }, ]), ]; -export type DagServicePatchDagsPublicDagsPatchMutationResult = Awaited< - ReturnType +export type DagServicePatchDagsMutationResult = Awaited< + ReturnType >; -export type DagServicePatchDagPublicDagsDagIdPatchMutationResult = Awaited< - ReturnType +export type DagServicePatchDagMutationResult = Awaited< + ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow/ui/openapi-gen/queries/prefetch.ts index 7de7282a9bd0..95c2c7b73734 100644 --- a/airflow/ui/openapi-gen/queries/prefetch.ts +++ b/airflow/ui/openapi-gen/queries/prefetch.ts @@ -12,7 +12,7 @@ import * as Common from "./common"; * @returns unknown Successful Response * @throws ApiError */ -export const prefetchUseAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGet = ( +export const prefetchUseAssetServiceNextRunAssets = ( queryClient: QueryClient, { dagId, @@ -21,11 +21,8 @@ export const prefetchUseAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGet = ( }, ) => queryClient.prefetchQuery({ - queryKey: Common.UseAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetKeyFn( - { dagId }, - ), - queryFn: () => - AssetService.nextRunAssetsUiNextRunDatasetsDagIdGet({ dagId }), + queryKey: Common.UseAssetServiceNextRunAssetsKeyFn({ dagId }), + queryFn: () => AssetService.nextRunAssets({ dagId }), }); /** * Get Dags @@ -44,7 +41,7 @@ export const prefetchUseAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGet = ( * @returns DAGCollectionResponse Successful Response * @throws ApiError */ -export const prefetchUseDagServiceGetDagsPublicDagsGet = ( +export const prefetchUseDagServiceGetDags = ( queryClient: QueryClient, { dagDisplayNamePattern, @@ -71,7 +68,7 @@ export const prefetchUseDagServiceGetDagsPublicDagsGet = ( } = {}, ) => queryClient.prefetchQuery({ - queryKey: Common.UseDagServiceGetDagsPublicDagsGetKeyFn({ + queryKey: Common.UseDagServiceGetDagsKeyFn({ dagDisplayNamePattern, dagIdPattern, lastDagRunState, @@ -84,7 +81,7 @@ export const prefetchUseDagServiceGetDagsPublicDagsGet = ( tags, }), queryFn: () => - DagService.getDagsPublicDagsGet({ + DagService.getDags({ dagDisplayNamePattern, dagIdPattern, lastDagRunState, diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index 5eda2a3d0e4d..985bf952e3eb 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -17,8 +17,8 @@ import * as Common from "./common"; * @returns unknown Successful Response * @throws ApiError */ -export const useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGet = < - TData = Common.AssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetDefaultResponse, 
+export const useAssetServiceNextRunAssets = < + TData = Common.AssetServiceNextRunAssetsDefaultResponse, TError = unknown, TQueryKey extends Array = unknown[], >( @@ -31,12 +31,8 @@ export const useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGet = < options?: Omit, "queryKey" | "queryFn">, ) => useQuery({ - queryKey: Common.UseAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetKeyFn( - { dagId }, - queryKey, - ), - queryFn: () => - AssetService.nextRunAssetsUiNextRunDatasetsDagIdGet({ dagId }) as TData, + queryKey: Common.UseAssetServiceNextRunAssetsKeyFn({ dagId }, queryKey), + queryFn: () => AssetService.nextRunAssets({ dagId }) as TData, ...options, }); /** @@ -56,8 +52,8 @@ export const useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGet = < * @returns DAGCollectionResponse Successful Response * @throws ApiError */ -export const useDagServiceGetDagsPublicDagsGet = < - TData = Common.DagServiceGetDagsPublicDagsGetDefaultResponse, +export const useDagServiceGetDags = < + TData = Common.DagServiceGetDagsDefaultResponse, TError = unknown, TQueryKey extends Array = unknown[], >( @@ -88,7 +84,7 @@ export const useDagServiceGetDagsPublicDagsGet = < options?: Omit, "queryKey" | "queryFn">, ) => useQuery({ - queryKey: Common.UseDagServiceGetDagsPublicDagsGetKeyFn( + queryKey: Common.UseDagServiceGetDagsKeyFn( { dagDisplayNamePattern, dagIdPattern, @@ -104,7 +100,7 @@ export const useDagServiceGetDagsPublicDagsGet = < queryKey, ), queryFn: () => - DagService.getDagsPublicDagsGet({ + DagService.getDags({ dagDisplayNamePattern, dagIdPattern, lastDagRunState, @@ -135,8 +131,8 @@ export const useDagServiceGetDagsPublicDagsGet = < * @returns DAGCollectionResponse Successful Response * @throws ApiError */ -export const useDagServicePatchDagsPublicDagsPatch = < - TData = Common.DagServicePatchDagsPublicDagsPatchMutationResult, +export const useDagServicePatchDags = < + TData = Common.DagServicePatchDagsMutationResult, TError = unknown, TContext = unknown, >( @@ -190,7 +186,7 @@ export const useDagServicePatchDagsPublicDagsPatch = < tags, updateMask, }) => - DagService.patchDagsPublicDagsPatch({ + DagService.patchDags({ dagIdPattern, lastDagRunState, limit, @@ -214,8 +210,8 @@ export const useDagServicePatchDagsPublicDagsPatch = < * @returns DAGResponse Successful Response * @throws ApiError */ -export const useDagServicePatchDagPublicDagsDagIdPatch = < - TData = Common.DagServicePatchDagPublicDagsDagIdPatchMutationResult, +export const useDagServicePatchDag = < + TData = Common.DagServicePatchDagMutationResult, TError = unknown, TContext = unknown, >( @@ -244,7 +240,7 @@ export const useDagServicePatchDagPublicDagsDagIdPatch = < TContext >({ mutationFn: ({ dagId, requestBody, updateMask }) => - DagService.patchDagPublicDagsDagIdPatch({ + DagService.patchDag({ dagId, requestBody, updateMask, diff --git a/airflow/ui/openapi-gen/queries/suspense.ts b/airflow/ui/openapi-gen/queries/suspense.ts index 18dba7acb4b5..dc8b99dfb218 100644 --- a/airflow/ui/openapi-gen/queries/suspense.ts +++ b/airflow/ui/openapi-gen/queries/suspense.ts @@ -12,8 +12,8 @@ import * as Common from "./common"; * @returns unknown Successful Response * @throws ApiError */ -export const useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetSuspense = < - TData = Common.AssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetDefaultResponse, +export const useAssetServiceNextRunAssetsSuspense = < + TData = Common.AssetServiceNextRunAssetsDefaultResponse, TError = unknown, TQueryKey extends Array = unknown[], >( @@ -26,12 +26,8 @@ export 
const useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetSuspense = < options?: Omit, "queryKey" | "queryFn">, ) => useSuspenseQuery({ - queryKey: Common.UseAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetKeyFn( - { dagId }, - queryKey, - ), - queryFn: () => - AssetService.nextRunAssetsUiNextRunDatasetsDagIdGet({ dagId }) as TData, + queryKey: Common.UseAssetServiceNextRunAssetsKeyFn({ dagId }, queryKey), + queryFn: () => AssetService.nextRunAssets({ dagId }) as TData, ...options, }); /** @@ -51,8 +47,8 @@ export const useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetSuspense = < * @returns DAGCollectionResponse Successful Response * @throws ApiError */ -export const useDagServiceGetDagsPublicDagsGetSuspense = < - TData = Common.DagServiceGetDagsPublicDagsGetDefaultResponse, +export const useDagServiceGetDagsSuspense = < + TData = Common.DagServiceGetDagsDefaultResponse, TError = unknown, TQueryKey extends Array = unknown[], >( @@ -83,7 +79,7 @@ export const useDagServiceGetDagsPublicDagsGetSuspense = < options?: Omit, "queryKey" | "queryFn">, ) => useSuspenseQuery({ - queryKey: Common.UseDagServiceGetDagsPublicDagsGetKeyFn( + queryKey: Common.UseDagServiceGetDagsKeyFn( { dagDisplayNamePattern, dagIdPattern, @@ -99,7 +95,7 @@ export const useDagServiceGetDagsPublicDagsGetSuspense = < queryKey, ), queryFn: () => - DagService.getDagsPublicDagsGet({ + DagService.getDags({ dagDisplayNamePattern, dagIdPattern, lastDagRunState, diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 7fb6306afbc6..be216bd534c6 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -3,14 +3,14 @@ import type { CancelablePromise } from "./core/CancelablePromise"; import { OpenAPI } from "./core/OpenAPI"; import { request as __request } from "./core/request"; import type { - NextRunAssetsUiNextRunDatasetsDagIdGetData, - NextRunAssetsUiNextRunDatasetsDagIdGetResponse, - GetDagsPublicDagsGetData, - GetDagsPublicDagsGetResponse, - PatchDagsPublicDagsPatchData, - PatchDagsPublicDagsPatchResponse, - PatchDagPublicDagsDagIdPatchData, - PatchDagPublicDagsDagIdPatchResponse, + NextRunAssetsData, + NextRunAssetsResponse, + GetDagsData, + GetDagsResponse, + PatchDagsData, + PatchDagsResponse, + PatchDagData, + PatchDagResponse, } from "./types.gen"; export class AssetService { @@ -21,9 +21,9 @@ export class AssetService { * @returns unknown Successful Response * @throws ApiError */ - public static nextRunAssetsUiNextRunDatasetsDagIdGet( - data: NextRunAssetsUiNextRunDatasetsDagIdGetData, - ): CancelablePromise { + public static nextRunAssets( + data: NextRunAssetsData, + ): CancelablePromise { return __request(OpenAPI, { method: "GET", url: "/ui/next_run_datasets/{dag_id}", @@ -55,9 +55,9 @@ export class DagService { * @returns DAGCollectionResponse Successful Response * @throws ApiError */ - public static getDagsPublicDagsGet( - data: GetDagsPublicDagsGetData = {}, - ): CancelablePromise { + public static getDags( + data: GetDagsData = {}, + ): CancelablePromise { return __request(OpenAPI, { method: "GET", url: "/public/dags", @@ -96,9 +96,9 @@ export class DagService { * @returns DAGCollectionResponse Successful Response * @throws ApiError */ - public static patchDagsPublicDagsPatch( - data: PatchDagsPublicDagsPatchData, - ): CancelablePromise { + public static patchDags( + data: PatchDagsData, + ): CancelablePromise { return __request(OpenAPI, { method: "PATCH", url: "/public/dags", @@ -135,9 
+135,9 @@ export class DagService { * @returns DAGResponse Successful Response * @throws ApiError */ - public static patchDagPublicDagsDagIdPatch( - data: PatchDagPublicDagsDagIdPatchData, - ): CancelablePromise { + public static patchDag( + data: PatchDagData, + ): CancelablePromise { return __request(OpenAPI, { method: "PATCH", url: "/public/dags/{dag_id}", diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index 0fe7134ba8c3..e1db8310a1dc 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -88,15 +88,15 @@ export type ValidationError = { type: string; }; -export type NextRunAssetsUiNextRunDatasetsDagIdGetData = { +export type NextRunAssetsData = { dagId: string; }; -export type NextRunAssetsUiNextRunDatasetsDagIdGetResponse = { +export type NextRunAssetsResponse = { [key: string]: unknown; }; -export type GetDagsPublicDagsGetData = { +export type GetDagsData = { dagDisplayNamePattern?: string | null; dagIdPattern?: string | null; lastDagRunState?: DagRunState | null; @@ -109,9 +109,9 @@ export type GetDagsPublicDagsGetData = { tags?: Array; }; -export type GetDagsPublicDagsGetResponse = DAGCollectionResponse; +export type GetDagsResponse = DAGCollectionResponse; -export type PatchDagsPublicDagsPatchData = { +export type PatchDagsData = { dagIdPattern?: string | null; lastDagRunState?: DagRunState | null; limit?: number; @@ -124,20 +124,20 @@ export type PatchDagsPublicDagsPatchData = { updateMask?: Array | null; }; -export type PatchDagsPublicDagsPatchResponse = DAGCollectionResponse; +export type PatchDagsResponse = DAGCollectionResponse; -export type PatchDagPublicDagsDagIdPatchData = { +export type PatchDagData = { dagId: string; requestBody: DAGPatchBody; updateMask?: Array | null; }; -export type PatchDagPublicDagsDagIdPatchResponse = DAGResponse; +export type PatchDagResponse = DAGResponse; export type $OpenApiTs = { "/ui/next_run_datasets/{dag_id}": { get: { - req: NextRunAssetsUiNextRunDatasetsDagIdGetData; + req: NextRunAssetsData; res: { /** * Successful Response @@ -154,7 +154,7 @@ export type $OpenApiTs = { }; "/public/dags": { get: { - req: GetDagsPublicDagsGetData; + req: GetDagsData; res: { /** * Successful Response @@ -167,7 +167,7 @@ export type $OpenApiTs = { }; }; patch: { - req: PatchDagsPublicDagsPatchData; + req: PatchDagsData; res: { /** * Successful Response @@ -198,7 +198,7 @@ export type $OpenApiTs = { }; "/public/dags/{dag_id}": { patch: { - req: PatchDagPublicDagsDagIdPatchData; + req: PatchDagData; res: { /** * Successful Response diff --git a/airflow/ui/package.json b/airflow/ui/package.json index c7d79f792a59..1f77334074f0 100644 --- a/airflow/ui/package.json +++ b/airflow/ui/package.json @@ -11,7 +11,7 @@ "lint:fix": "eslint --fix && tsc --p tsconfig.app.json", "format": "pnpm prettier --write .", "preview": "vite preview", - "codegen": "openapi-rq -i \"../api_fastapi/openapi/v1-generated.yaml\" -c axios --format prettier -o openapi-gen", + "codegen": "openapi-rq -i \"../api_fastapi/openapi/v1-generated.yaml\" -c axios --format prettier -o openapi-gen --operationId", "test": "vitest run", "coverage": "vitest run --coverage" }, diff --git a/airflow/ui/src/App.test.tsx b/airflow/ui/src/App.test.tsx index d34cf016befd..5efcf90f1a05 100644 --- a/airflow/ui/src/App.test.tsx +++ b/airflow/ui/src/App.test.tsx @@ -105,10 +105,9 @@ beforeEach(() => { isLoading: false, } as QueryObserverSuccessResult; - vi.spyOn( - openapiQueriesModule, - 
"useDagServiceGetDagsPublicDagsGet", - ).mockImplementation(() => returnValue); + vi.spyOn(openapiQueriesModule, "useDagServiceGetDags").mockImplementation( + () => returnValue, + ); }); afterEach(() => { diff --git a/airflow/ui/src/pages/DagsList.tsx b/airflow/ui/src/pages/DagsList.tsx index fe764f117e45..ab480d2cbabd 100644 --- a/airflow/ui/src/pages/DagsList.tsx +++ b/airflow/ui/src/pages/DagsList.tsx @@ -30,7 +30,7 @@ import { Select as ReactSelect } from "chakra-react-select"; import { type ChangeEventHandler, useCallback } from "react"; import { useSearchParams } from "react-router-dom"; -import { useDagServiceGetDagsPublicDagsGet } from "openapi/queries"; +import { useDagServiceGetDags } from "openapi/queries"; import type { DAGResponse } from "openapi/requests/types.gen"; import { DataTable } from "../components/DataTable"; @@ -93,7 +93,7 @@ export const DagsList = ({ cardView = false }) => { const [sort] = sorting; const orderBy = sort ? `${sort.desc ? "-" : ""}${sort.id}` : undefined; - const { data, isLoading } = useDagServiceGetDagsPublicDagsGet({ + const { data, isLoading } = useDagServiceGetDags({ limit: pagination.pageSize, offset: pagination.pageIndex * pagination.pageSize, onlyActive: true, From 8c9d251efbd07d4c30a90264364847eb529c90d1 Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Tue, 1 Oct 2024 01:02:38 -0700 Subject: [PATCH 06/10] Limit build-images workflow to main and v2-10 branches (#42601) There is no need to run image builds for PRs to old branches. --- .github/workflows/build-images.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/build-images.yml b/.github/workflows/build-images.yml index 1256fd2f0da6..abf966faede0 100644 --- a/.github/workflows/build-images.yml +++ b/.github/workflows/build-images.yml @@ -21,6 +21,10 @@ run-name: > Build images for ${{ github.event.pull_request.title }} ${{ github.event.pull_request._links.html.href }} on: # yamllint disable-line rule:truthy pull_request_target: + branches: + - main + - v2-10-stable + - v2-10-test permissions: # all other permissions are set to none contents: read From e46365d96ca2bd39a8148e98d38551d2184a78c7 Mon Sep 17 00:00:00 2001 From: Jakub Dardzinski Date: Tue, 1 Oct 2024 14:16:03 +0200 Subject: [PATCH 07/10] openlineage: add unit test for listener hooks on dag run state changes. (#42554) openlineage: cover task instance failure in unit tests. 
Signed-off-by: Jakub Dardzinski --- tests/dags/test_openlineage_execution.py | 12 +++++- .../openlineage/plugins/test_execution.py | 11 +++++ .../openlineage/plugins/test_listener.py | 43 +++++++++++++++++++ 3 files changed, 65 insertions(+), 1 deletion(-) diff --git a/tests/dags/test_openlineage_execution.py b/tests/dags/test_openlineage_execution.py index 475e43ef6ac2..f8db91611e84 100644 --- a/tests/dags/test_openlineage_execution.py +++ b/tests/dags/test_openlineage_execution.py @@ -27,13 +27,16 @@ class OpenLineageExecutionOperator(BaseOperator): - def __init__(self, *, stall_amount=0, **kwargs) -> None: + def __init__(self, *, stall_amount=0, fail=False, **kwargs) -> None: super().__init__(**kwargs) self.stall_amount = stall_amount + self.fail = fail def execute(self, context): self.log.error("STALL AMOUNT %s", self.stall_amount) time.sleep(1) + if self.fail: + raise Exception("Failed") def get_openlineage_facets_on_start(self): return OperatorLineage(inputs=[Dataset(namespace="test", name="on-start")]) @@ -43,6 +46,11 @@ def get_openlineage_facets_on_complete(self, task_instance): time.sleep(self.stall_amount) return OperatorLineage(inputs=[Dataset(namespace="test", name="on-complete")]) + def get_openlineage_facets_on_failure(self, task_instance): + self.log.error("STALL AMOUNT %s", self.stall_amount) + time.sleep(self.stall_amount) + return OperatorLineage(inputs=[Dataset(namespace="test", name="on-failure")]) + with DAG( dag_id="test_openlineage_execution", @@ -57,3 +65,5 @@ def get_openlineage_facets_on_complete(self, task_instance): mid_stall = OpenLineageExecutionOperator(task_id="execute_mid_stall", stall_amount=15) long_stall = OpenLineageExecutionOperator(task_id="execute_long_stall", stall_amount=30) + + fail = OpenLineageExecutionOperator(task_id="execute_fail", fail=True) diff --git a/tests/providers/openlineage/plugins/test_execution.py b/tests/providers/openlineage/plugins/test_execution.py index 3adaaac582dd..8c0bdd55a1f9 100644 --- a/tests/providers/openlineage/plugins/test_execution.py +++ b/tests/providers/openlineage/plugins/test_execution.py @@ -124,6 +124,17 @@ def test_not_stalled_task_emits_proper_lineage(self): assert has_value_in_events(events, ["inputs", "name"], "on-start") assert has_value_in_events(events, ["inputs", "name"], "on-complete") + @pytest.mark.db_test + @conf_vars({("openlineage", "transport"): f'{{"type": "file", "log_file_path": "{listener_path}"}}'}) + def test_not_stalled_failing_task_emits_proper_lineage(self): + task_name = "execute_fail" + run_id = "test_failure" + self.setup_job(task_name, run_id) + + events = get_sorted_events(tmp_dir) + assert has_value_in_events(events, ["inputs", "name"], "on-start") + assert has_value_in_events(events, ["inputs", "name"], "on-failure") + @conf_vars( { ("openlineage", "transport"): f'{{"type": "file", "log_file_path": "{listener_path}"}}', diff --git a/tests/providers/openlineage/plugins/test_listener.py b/tests/providers/openlineage/plugins/test_listener.py index 92467a58af8c..57c0134f79d8 100644 --- a/tests/providers/openlineage/plugins/test_listener.py +++ b/tests/providers/openlineage/plugins/test_listener.py @@ -606,6 +606,49 @@ def test_listener_on_dag_run_state_changes_configure_process_pool_size(mock_exec mock_executor.return_value.submit.assert_called_once() +class MockExecutor: + def __init__(self, *args, **kwargs): + self.submitted = False + self.succeeded = False + self.result = None + + def submit(self, fn, /, *args, **kwargs): + self.submitted = True + try: + fn(*args, **kwargs) + 
self.succeeded = True + except Exception: + pass + return MagicMock() + + def shutdown(self, *args, **kwargs): + print("Shutting down") + + +@pytest.mark.parametrize( + ("method", "dag_run_state"), + [ + ("on_dag_run_running", DagRunState.RUNNING), + ("on_dag_run_success", DagRunState.SUCCESS), + ("on_dag_run_failed", DagRunState.FAILED), + ], +) +@patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.emit") +def test_listener_on_dag_run_state_changes(mock_emit, method, dag_run_state, create_task_instance): + mock_executor = MockExecutor() + ti = create_task_instance(dag_id="dag", task_id="op") + # Change the state explicitly to set end_date following the logic in the method + ti.dag_run.set_state(dag_run_state) + with mock.patch( + "airflow.providers.openlineage.plugins.listener.ProcessPoolExecutor", return_value=mock_executor + ): + listener = OpenLineageListener() + getattr(listener, method)(ti.dag_run, None) + assert mock_executor.submitted is True + assert mock_executor.succeeded is True + mock_emit.assert_called_once() + + def test_listener_logs_failed_serialization(): listener = OpenLineageListener() callback_future = Future() From 9536c98a439fc028542bb9b8eb9b76c24e2ee02b Mon Sep 17 00:00:00 2001 From: Vincent <97131062+vincbeck@users.noreply.github.com> Date: Tue, 1 Oct 2024 09:59:11 -0400 Subject: [PATCH 08/10] Update Rest API tests to no longer rely on FAB auth manager. Move tests specific to FAB permissions to FAB provider (#42523) --- .../managers/simple/simple_auth_manager.py | 7 +- airflow/auth/managers/simple/user.py | 6 +- .../0034_3_0_0_update_user_id_type.py | 52 +++ ..._3_0_0_add_name_field_to_dataset_model.py} | 4 +- airflow/models/dagrun.py | 2 +- airflow/models/taskinstance.py | 2 +- .../api/auth/backend/basic_auth.py | 4 +- .../api/auth/backend/kerberos_auth.py | 2 +- .../fab/auth_manager/models/anonymous_user.py | 2 +- docs/apache-airflow/img/airflow_erd.sha256 | 2 +- docs/apache-airflow/img/airflow_erd.svg | 4 +- docs/apache-airflow/migrations-ref.rst | 5 +- tests/api_connexion/conftest.py | 11 +- .../endpoints/test_backfill_endpoint.py | 31 +- .../endpoints/test_config_endpoint.py | 12 +- .../endpoints/test_connection_endpoint.py | 17 +- .../endpoints/test_dag_endpoint.py | 100 +--- .../endpoints/test_dag_parsing.py | 16 +- .../endpoints/test_dag_run_endpoint.py | 154 +------ .../endpoints/test_dag_source_endpoint.py | 71 +-- .../endpoints/test_dag_stats_endpoint.py | 15 +- .../endpoints/test_dag_warning_endpoint.py | 33 +- .../endpoints/test_dataset_endpoint.py | 234 +--------- .../endpoints/test_event_log_endpoint.py | 80 +--- .../endpoints/test_extra_link_endpoint.py | 20 +- .../endpoints/test_import_error_endpoint.py | 170 +------ .../endpoints/test_log_endpoint.py | 9 +- .../test_mapped_task_instance_endpoint.py | 25 +- .../endpoints/test_plugin_endpoint.py | 12 +- .../endpoints/test_pool_endpoint.py | 17 +- .../endpoints/test_provider_endpoint.py | 12 +- .../endpoints/test_task_endpoint.py | 16 +- .../endpoints/test_task_instance_endpoint.py | 219 +-------- .../endpoints/test_variable_endpoint.py | 37 +- .../endpoints/test_xcom_endpoint.py | 74 +-- tests/api_connexion/test_auth.py | 188 ++------ tests/api_connexion/test_security.py | 8 +- .../api_endpoints/api_connexion_utils.py | 116 +++++ .../remote_user_api_auth_backend.py | 81 ++++ .../auth_manager/api_endpoints/test_auth.py | 176 ++++++++ .../api_endpoints/test_backfill_endpoint.py | 264 +++++++++++ .../auth_manager/api_endpoints}/test_cors.py | 35 +- 
.../api_endpoints/test_dag_endpoint.py | 252 +++++++++++ .../api_endpoints/test_dag_run_endpoint.py | 273 +++++++++++ .../api_endpoints/test_dag_source_endpoint.py | 144 ++++++ .../test_dag_warning_endpoint.py | 84 ++++ .../api_endpoints/test_dataset_endpoint.py | 327 ++++++++++++++ .../api_endpoints/test_event_log_endpoint.py | 151 +++++++ .../test_import_error_endpoint.py | 221 +++++++++ .../test_role_and_permission_endpoint.py | 22 +- .../test_role_and_permission_schema.py | 22 +- .../test_task_instance_endpoint.py | 427 ++++++++++++++++++ .../api_endpoints/test_user_endpoint.py | 15 +- .../api_endpoints/test_user_schema.py | 3 +- .../api_endpoints/test_variable_endpoint.py | 88 ++++ .../api_endpoints/test_xcom_endpoint.py | 230 ++++++++++ tests/providers/fab/auth_manager/conftest.py | 17 +- .../fab/auth_manager/test_security.py | 2 +- .../auth_manager/views/test_permissions.py | 2 +- .../fab/auth_manager/views/test_roles_list.py | 2 +- .../fab/auth_manager/views/test_user.py | 2 +- .../fab/auth_manager/views/test_user_edit.py | 2 +- .../fab/auth_manager/views/test_user_stats.py | 2 +- tests/test_utils/api_connexion_utils.py | 64 +-- .../remote_user_api_auth_backend.py | 32 +- .../www/views/test_views_custom_user_views.py | 5 +- tests/www/views/test_views_dagrun.py | 6 +- tests/www/views/test_views_home.py | 2 +- tests/www/views/test_views_tasks.py | 6 +- tests/www/views/test_views_variable.py | 2 +- 70 files changed, 3208 insertions(+), 1542 deletions(-) create mode 100644 airflow/migrations/versions/0034_3_0_0_update_user_id_type.py rename airflow/migrations/versions/{0034_3_0_0_add_name_field_to_dataset_model.py => 0035_3_0_0_add_name_field_to_dataset_model.py} (98%) create mode 100644 tests/providers/fab/auth_manager/api_endpoints/api_connexion_utils.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/remote_user_api_auth_backend.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_auth.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_backfill_endpoint.py rename tests/{api_connexion => providers/fab/auth_manager/api_endpoints}/test_cors.py (81%) create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_dag_endpoint.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_dataset_endpoint.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_event_log_endpoint.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_import_error_endpoint.py rename tests/{api_connexion/schemas => providers/fab/auth_manager/api_endpoints}/test_role_and_permission_schema.py (85%) create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_variable_endpoint.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_xcom_endpoint.py diff --git a/airflow/auth/managers/simple/simple_auth_manager.py b/airflow/auth/managers/simple/simple_auth_manager.py index 451068733667..4a9639a998c4 100644 --- a/airflow/auth/managers/simple/simple_auth_manager.py +++ b/airflow/auth/managers/simple/simple_auth_manager.py @@ -221,7 +221,12 @@ def _is_authorized( user = self.get_user() if not user: 
return False - role_str = user.get_role().upper() + + user_role = user.get_role() + if not user_role: + return False + + role_str = user_role.upper() role = SimpleAuthManagerRole[role_str] if role == SimpleAuthManagerRole.ADMIN: return True diff --git a/airflow/auth/managers/simple/user.py b/airflow/auth/managers/simple/user.py index fa032f596ee4..f4591b0b1c75 100644 --- a/airflow/auth/managers/simple/user.py +++ b/airflow/auth/managers/simple/user.py @@ -24,10 +24,10 @@ class SimpleAuthManagerUser(BaseUser): User model for users managed by the simple auth manager. :param username: The username - :param role: The role associated to the user + :param role: The role associated to the user. If not provided, the user has no permission """ - def __init__(self, *, username: str, role: str) -> None: + def __init__(self, *, username: str, role: str | None) -> None: self.username = username self.role = role @@ -37,5 +37,5 @@ def get_id(self) -> str: def get_name(self) -> str: return self.username - def get_role(self): + def get_role(self) -> str | None: return self.role diff --git a/airflow/migrations/versions/0034_3_0_0_update_user_id_type.py b/airflow/migrations/versions/0034_3_0_0_update_user_id_type.py new file mode 100644 index 000000000000..321a1e2bbafa --- /dev/null +++ b/airflow/migrations/versions/0034_3_0_0_update_user_id_type.py @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Update dag_run_note.user_id and task_instance_note.user_id columns to String. + +Revision ID: 44eabb1904b4 +Revises: 16cbcb1c8c36 +Create Date: 2024-09-27 09:57:29.830521 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = "44eabb1904b4" +down_revision = "16cbcb1c8c36" +branch_labels = None +depends_on = None +airflow_version = "3.0.0" + + +def upgrade(): + with op.batch_alter_table("dag_run_note") as batch_op: + batch_op.alter_column("user_id", type_=sa.String(length=128)) + with op.batch_alter_table("task_instance_note") as batch_op: + batch_op.alter_column("user_id", type_=sa.String(length=128)) + + +def downgrade(): + with op.batch_alter_table("dag_run_note") as batch_op: + batch_op.alter_column("user_id", type_=sa.Integer(), postgresql_using="user_id::integer") + with op.batch_alter_table("task_instance_note") as batch_op: + batch_op.alter_column("user_id", type_=sa.Integer(), postgresql_using="user_id::integer") diff --git a/airflow/migrations/versions/0034_3_0_0_add_name_field_to_dataset_model.py b/airflow/migrations/versions/0035_3_0_0_add_name_field_to_dataset_model.py similarity index 98% rename from airflow/migrations/versions/0034_3_0_0_add_name_field_to_dataset_model.py rename to airflow/migrations/versions/0035_3_0_0_add_name_field_to_dataset_model.py index 5c8aec69e9be..6016dd965890 100644 --- a/airflow/migrations/versions/0034_3_0_0_add_name_field_to_dataset_model.py +++ b/airflow/migrations/versions/0035_3_0_0_add_name_field_to_dataset_model.py @@ -30,7 +30,7 @@ also rename the one on DatasetAliasModel here for consistency. Revision ID: 0d9e73a75ee4 -Revises: 16cbcb1c8c36 +Revises: 44eabb1904b4 Create Date: 2024-08-13 09:45:32.213222 """ @@ -42,7 +42,7 @@ # revision identifiers, used by Alembic. revision = "0d9e73a75ee4" -down_revision = "16cbcb1c8c36" +down_revision = "44eabb1904b4" branch_labels = None depends_on = None airflow_version = "3.0.0" diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 5d53e51763df..4928c7fcbd8f 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -1687,7 +1687,7 @@ class DagRunNote(Base): __tablename__ = "dag_run_note" - user_id = Column(Integer, nullable=True) + user_id = Column(String(128), nullable=True) dag_run_id = Column(Integer, primary_key=True, nullable=False) content = Column(String(1000).with_variant(Text(1000), "mysql")) created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index b19e65486307..333a4cad91cb 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -4002,7 +4002,7 @@ class TaskInstanceNote(TaskInstanceDependencies): __tablename__ = "task_instance_note" - user_id = Column(Integer, nullable=True) + user_id = Column(String(128), nullable=True) task_id = Column(StringID(), primary_key=True, nullable=False) dag_id = Column(StringID(), primary_key=True, nullable=False) run_id = Column(StringID(), primary_key=True, nullable=False) diff --git a/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py b/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py index 3a0328fe9962..7b5033873345 100644 --- a/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py +++ b/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py @@ -62,9 +62,7 @@ def requires_authentication(function: T): @wraps(function) def decorated(*args, **kwargs): - if auth_current_user() is not None or current_app.appbuilder.get_app.config.get( - "AUTH_ROLE_PUBLIC", None - ): + if auth_current_user() is not None or current_app.config.get("AUTH_ROLE_PUBLIC", None): return function(*args, **kwargs) else: return Response("Unauthorized", 401, {"WWW-Authenticate": "Basic"}) 
diff --git a/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py b/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py index d8d5a95ee676..f2038b27597c 100644 --- a/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py +++ b/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py @@ -124,7 +124,7 @@ def requires_authentication(function: T, find_user: Callable[[str], BaseUser] | @wraps(function) def decorated(*args, **kwargs): - if current_app.appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC", None): + if current_app.config.get("AUTH_ROLE_PUBLIC", None): response = function(*args, **kwargs) return make_response(response) diff --git a/airflow/providers/fab/auth_manager/models/anonymous_user.py b/airflow/providers/fab/auth_manager/models/anonymous_user.py index 2f294fd9e5d0..9afb2cdff635 100644 --- a/airflow/providers/fab/auth_manager/models/anonymous_user.py +++ b/airflow/providers/fab/auth_manager/models/anonymous_user.py @@ -35,7 +35,7 @@ class AnonymousUser(AnonymousUserMixin, BaseUser): @property def roles(self): if not self._roles: - public_role = current_app.appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC", None) + public_role = current_app.config.get("AUTH_ROLE_PUBLIC", None) self._roles = {current_app.appbuilder.sm.find_role(public_role)} if public_role else set() return self._roles diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256 index e4a952da1b9f..bca068fde674 100644 --- a/docs/apache-airflow/img/airflow_erd.sha256 +++ b/docs/apache-airflow/img/airflow_erd.sha256 @@ -1 +1 @@ -c33e9a583a5b29eb748ebd50e117643e11bcb2a9b61ec017efd690621e22769b \ No newline at end of file +64dfad12dfd49f033c4723c2f3bb3bac58dd956136fb24a87a2e5a6ae176ec1a \ No newline at end of file diff --git a/docs/apache-airflow/img/airflow_erd.svg b/docs/apache-airflow/img/airflow_erd.svg index 76fbd8f841f2..4eb6c2ee7091 100644 --- a/docs/apache-airflow/img/airflow_erd.svg +++ b/docs/apache-airflow/img/airflow_erd.svg @@ -1394,7 +1394,7 @@ user_id - [INTEGER] + [VARCHAR(100)] @@ -1813,7 +1813,7 @@ user_id - [INTEGER] + [VARCHAR(100)] diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst index a547d03d75be..e4fb2dfa332e 100644 --- a/docs/apache-airflow/migrations-ref.rst +++ b/docs/apache-airflow/migrations-ref.rst @@ -39,7 +39,10 @@ Here's the list of all the Database Migrations that are executed via when you ru +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=========================+==================+===================+==============================================================+ -| ``0d9e73a75ee4`` (head) | ``16cbcb1c8c36`` | ``3.0.0`` | Add name and group fields to DatasetModel. | +| ``0d9e73a75ee4`` (head) | ``44eabb1904b4`` | ``3.0.0`` | Add name and group fields to DatasetModel. | ++-------------------------+------------------+-------------------+--------------------------------------------------------------+ +| ``44eabb1904b4`` | ``16cbcb1c8c36`` | ``3.0.0`` | Update dag_run_note.user_id and task_instance_note.user_id | +| | | | columns to String. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``16cbcb1c8c36`` | ``522625f6d606`` | ``3.0.0`` | Remove redundant index. 
| +-------------------------+------------------+-------------------+--------------------------------------------------------------+ diff --git a/tests/api_connexion/conftest.py b/tests/api_connexion/conftest.py index 38e7b58cb598..6a23b2cf11d9 100644 --- a/tests/api_connexion/conftest.py +++ b/tests/api_connexion/conftest.py @@ -36,9 +36,16 @@ def minimal_app_for_api(): ] ) def factory(): - with conf_vars({("api", "auth_backends"): "tests.test_utils.remote_user_api_auth_backend"}): + with conf_vars( + { + ("api", "auth_backends"): "tests.test_utils.remote_user_api_auth_backend", + ( + "core", + "auth_manager", + ): "airflow.auth.managers.simple.simple_auth_manager.SimpleAuthManager", + } + ): _app = app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore - _app.config["AUTH_ROLE_PUBLIC"] = None return _app return factory() diff --git a/tests/api_connexion/endpoints/test_backfill_endpoint.py b/tests/api_connexion/endpoints/test_backfill_endpoint.py index 51a4faf40055..07b2a3fd56c2 100644 --- a/tests/api_connexion/endpoints/test_backfill_endpoint.py +++ b/tests/api_connexion/endpoints/test_backfill_endpoint.py @@ -29,7 +29,6 @@ from airflow.models.dag import DAG from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.empty import EmptyOperator -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import provide_session from tests.test_utils.api_connexion_utils import create_user, delete_user @@ -50,25 +49,11 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG), - ], - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - create_user(app, username="test_granular_permissions", role_name="TestGranularDag") # type: ignore - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "TEST_DAG_1", - access_control={ - "TestGranularDag": { - permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} - }, - }, + role_name="admin", ) + create_user(app, username="test_no_permissions", role_name=None) with DAG( DAG_ID, @@ -93,9 +78,8 @@ def configured_app(minimal_app_for_api): yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_granular_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestBackfillEndpoint: @@ -178,7 +162,6 @@ def test_should_respond_200(self, session): @pytest.mark.parametrize( "user, expected", [ - ("test_granular_permissions", 200), ("test_no_permissions", 403), ("test", 200), (None, 401), @@ -240,7 +223,6 @@ def test_no_exist(self, session): @pytest.mark.parametrize( "user, expected", [ - ("test_granular_permissions", 200), ("test_no_permissions", 403), ("test", 200), (None, 401), @@ -268,7 +250,6 @@ class TestCreateBackfill(TestBackfillEndpoint): @pytest.mark.parametrize( "user, expected", [ - ("test_granular_permissions", 200), ("test_no_permissions", 403), ("test", 200), (None, 401), @@ -347,7 +328,6 @@ def test_should_respond_200(self, session): @pytest.mark.parametrize( "user, expected", [ - ("test_granular_permissions", 200), 
("test_no_permissions", 403), ("test", 200), (None, 401), @@ -409,7 +389,6 @@ def test_should_respond_200(self, session): @pytest.mark.parametrize( "user, expected", [ - ("test_granular_permissions", 200), ("test_no_permissions", 403), ("test", 200), (None, 401), diff --git a/tests/api_connexion/endpoints/test_config_endpoint.py b/tests/api_connexion/endpoints/test_config_endpoint.py index 475753a4a902..bd88c491c952 100644 --- a/tests/api_connexion/endpoints/test_config_endpoint.py +++ b/tests/api_connexion/endpoints/test_config_endpoint.py @@ -21,7 +21,6 @@ import pytest -from airflow.security import permissions from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars @@ -54,18 +53,17 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)], # type: ignore + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) with conf_vars({("webserver", "expose_config"): "True"}): yield minimal_app_for_api - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestGetConfig: diff --git a/tests/api_connexion/endpoints/test_connection_endpoint.py b/tests/api_connexion/endpoints/test_connection_endpoint.py index a19b046aa274..a140046656e3 100644 --- a/tests/api_connexion/endpoints/test_connection_endpoint.py +++ b/tests/api_connexion/endpoints/test_connection_endpoint.py @@ -24,7 +24,6 @@ from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import Connection from airflow.secrets.environment_variables import CONN_ENV_PREFIX -from airflow.security import permissions from airflow.utils.session import provide_session from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars @@ -38,22 +37,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_CONNECTION), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_CONNECTION), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestConnectionEndpoint: diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index 9905b4e27ab2..6d4ffc2d06d2 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -28,7 +28,6 @@ from airflow.models.dag import DAG from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.empty import EmptyOperator -from 
airflow.security import permissions from airflow.utils.session import provide_session from airflow.utils.state import TaskInstanceState from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user @@ -56,33 +55,11 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG), - ], - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - create_user(app, username="test_granular_permissions", role_name="TestGranularDag") # type: ignore - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "TEST_DAG_1", - access_control={ - "TestGranularDag": { - permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} - }, - }, - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "TEST_DAG_1", - access_control={ - "TestGranularDag": { - permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} - }, - }, + role_name="admin", ) + create_user(app, username="test_no_permissions", role_name=None) with DAG( DAG_ID, @@ -107,9 +84,8 @@ def configured_app(minimal_app_for_api): yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_granular_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestDagEndpoint: @@ -258,13 +234,6 @@ def test_should_respond_200_with_schedule_none(self, session): "pickle_id": None, } == response.json - def test_should_respond_200_with_granular_dag_access(self): - self._create_dag_models(1) - response = self.client.get( - "/api/v1/dags/TEST_DAG_1", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 200 - def test_should_respond_404(self): response = self.client.get("/api/v1/dags/INVALID_DAG", environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 404 @@ -282,13 +251,6 @@ def test_should_raise_403_forbidden(self): ) assert response.status_code == 403 - def test_should_respond_403_with_granular_access_for_different_dag(self): - self._create_dag_models(3) - response = self.client.get( - "/api/v1/dags/TEST_DAG_2", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 403 - @pytest.mark.parametrize( "fields", [ @@ -961,15 +923,6 @@ def test_filter_dags_by_dag_id_works(self, url, expected_dag_ids): assert expected_dag_ids == dag_ids - def test_should_respond_200_with_granular_dag_access(self): - self._create_dag_models(3) - response = self.client.get( - "/api/v1/dags", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 200 - assert len(response.json["dags"]) == 1 - assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" - @pytest.mark.parametrize( "url, expected_dag_ids", [ @@ -1252,18 +1205,6 @@ def test_should_respond_200_on_patch_is_paused(self, url_safe_serializer, sessio session, dag_id="TEST_DAG_1", event="api.patch_dag", execution_date=None, expected_extra=payload ) - def test_should_respond_200_on_patch_with_granular_dag_access(self, session): - self._create_dag_models(1) - response = self.client.patch( - "/api/v1/dags/TEST_DAG_1", - json={ - 
"is_paused": False, - }, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 200 - _check_last_log(session, dag_id="TEST_DAG_1", event="api.patch_dag", execution_date=None) - def test_should_respond_400_on_invalid_request(self): patch_body = { "is_paused": True, @@ -1279,24 +1220,6 @@ def test_should_respond_400_on_invalid_request(self): "type": EXCEPTIONS_LINK_MAP[400], } - def test_validation_error_raises_400(self): - patch_body = { - "ispaused": True, - } - dag_model = self._create_dag_model() - response = self.client.patch( - f"/api/v1/dags/{dag_model.dag_id}", - json=patch_body, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 400 - assert response.json == { - "detail": "{'ispaused': ['Unknown field.']}", - "status": 400, - "title": "Bad Request", - "type": EXCEPTIONS_LINK_MAP[400], - } - def test_non_existing_dag_raises_not_found(self): patch_body = { "is_paused": True, @@ -1820,19 +1743,6 @@ def test_filter_dags_by_dag_id_works(self, url, expected_dag_ids): assert expected_dag_ids == dag_ids - def test_should_respond_200_with_granular_dag_access(self): - self._create_dag_models(3) - response = self.client.patch( - "api/v1/dags?dag_id_pattern=~", - json={ - "is_paused": False, - }, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 200 - assert len(response.json["dags"]) == 1 - assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" - @pytest.mark.parametrize( "url, expected_dag_ids", [ diff --git a/tests/api_connexion/endpoints/test_dag_parsing.py b/tests/api_connexion/endpoints/test_dag_parsing.py index 521d8d9e8cd9..ae42a565dd05 100644 --- a/tests/api_connexion/endpoints/test_dag_parsing.py +++ b/tests/api_connexion/endpoints/test_dag_parsing.py @@ -24,7 +24,6 @@ from airflow.models import DagBag from airflow.models.dagbag import DagPriorityParsingRequest -from airflow.security import permissions from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.db import clear_db_dag_parsing_requests @@ -45,21 +44,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)], # type: ignore + role_name="admin", ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - TEST_DAG_ID, - access_control={"Test": [permissions.ACTION_CAN_EDIT]}, - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestDagParsingRequest: diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index f3921da7b9c2..73c75b98a43b 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -30,12 +30,11 @@ from airflow.models.dagrun import DagRun from airflow.models.param import Param from airflow.operators.empty import EmptyOperator -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import create_session, provide_session from airflow.utils.state 
import DagRunState, State from airflow.utils.types import DagRunType -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user +from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags @@ -52,79 +51,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - ], - ) - create_user( - app, # type: ignore - username="test_no_dag_run_create_permission", - role_name="TestNoDagRunCreatePermission", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - ], - ) - create_user( - app, # type: ignore - username="test_dag_view_only", - role_name="TestViewDags", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - ], - ) - create_user( - app, # type: ignore - username="test_view_dags", - role_name="TestViewDags", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), - ], + role_name="admin", ) - create_user( - app, # type: ignore - username="test_granular_permissions", - role_name="TestGranularDag", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN)], - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "TEST_DAG_ID", - access_control={ - "TestGranularDag": {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ}, - "TestNoDagRunCreatePermission": {permissions.RESOURCE_DAG_RUN: {permissions.ACTION_CAN_CREATE}}, - }, - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_dag_view_only") # type: ignore - delete_user(app, username="test_view_dags") # type: ignore - delete_user(app, username="test_granular_permissions") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_no_dag_run_create_permission") # type: ignore - delete_roles(app) + delete_user(app, username="test") + delete_user(app, 
username="test_no_permissions") class TestDagRunEndpoint: @@ -499,16 +435,6 @@ def test_should_return_all_with_tilde_as_dag_id_and_all_dag_permissions(self): dag_run_ids = [dag_run["dag_id"] for dag_run in response.json["dag_runs"]] assert dag_run_ids == expected_dag_run_ids - def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): - self._create_test_dag_run(extra_dag=True) - expected_dag_run_ids = ["TEST_DAG_ID", "TEST_DAG_ID"] - response = self.client.get( - "api/v1/dags/~/dagRuns", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 200 - dag_run_ids = [dag_run["dag_id"] for dag_run in response.json["dag_runs"]] - assert dag_run_ids == expected_dag_run_ids - def test_should_raises_401_unauthenticated(self): self._create_test_dag_run() @@ -907,57 +833,6 @@ def test_order_by_raises_for_invalid_attr(self): msg = "Ordering with 'dag_ru' is disallowed or the attribute does not exist on the model" assert response.json["detail"] == msg - def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): - self._create_test_dag_run(extra_dag=True) - expected_response_json_1 = { - "dag_id": "TEST_DAG_ID", - "dag_run_id": "TEST_DAG_RUN_ID_1", - "end_date": None, - "state": "running", - "execution_date": self.default_time, - "logical_date": self.default_time, - "external_trigger": True, - "start_date": self.default_time, - "conf": {}, - "data_interval_end": None, - "data_interval_start": None, - "last_scheduling_decision": None, - "run_type": "manual", - "note": None, - } - expected_response_json_1.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {}) - expected_response_json_2 = { - "dag_id": "TEST_DAG_ID", - "dag_run_id": "TEST_DAG_RUN_ID_2", - "end_date": None, - "state": "running", - "execution_date": self.default_time_2, - "logical_date": self.default_time_2, - "external_trigger": True, - "start_date": self.default_time, - "conf": {}, - "data_interval_end": None, - "data_interval_start": None, - "last_scheduling_decision": None, - "run_type": "manual", - "note": None, - } - expected_response_json_2.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {}) - - response = self.client.post( - "api/v1/dags/~/dagRuns/list", - json={"dag_ids": []}, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 200 - assert response.json == { - "dag_runs": [ - expected_response_json_1, - expected_response_json_2, - ], - "total_entries": 2, - } - @pytest.mark.parametrize( "payload, error", [ @@ -1328,15 +1203,6 @@ def test_raises_validation_error_for_invalid_params(self): assert response.status_code == 400 assert "Invalid input for param" in response.json["detail"] - def test_dagrun_trigger_with_dag_level_permissions(self): - self._create_dag("TEST_DAG_ID") - response = self.client.post( - "api/v1/dags/TEST_DAG_ID/dagRuns", - json={"conf": {"validated_number": 1}}, - environ_overrides={"REMOTE_USER": "test_no_dag_run_create_permission"}, - ) - assert response.status_code == 200 - @mock.patch("airflow.api_connexion.endpoints.dag_run_endpoint.get_airflow_app") def test_dagrun_creation_exception_is_handled(self, mock_get_app, session): self._create_dag("TEST_DAG_ID") @@ -1627,11 +1493,7 @@ def test_should_raises_401_unauthenticated(self): assert_401(response) - @pytest.mark.parametrize( - "username", - ["test_dag_view_only", "test_view_dags", "test_granular_permissions", "test_no_permissions"], - ) - def test_should_raises_403_unauthorized(self, 
username): + def test_should_raises_403_unauthorized(self): self._create_dag("TEST_DAG_ID") response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", @@ -1639,7 +1501,7 @@ def test_should_raises_403_unauthorized(self, username): "dag_run_id": "TEST_DAG_RUN_ID_1", "execution_date": self.default_time, }, - environ_overrides={"REMOTE_USER": username}, + environ_overrides={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 diff --git a/tests/api_connexion/endpoints/test_dag_source_endpoint.py b/tests/api_connexion/endpoints/test_dag_source_endpoint.py index a8d1224e034c..f4df56ba629a 100644 --- a/tests/api_connexion/endpoints/test_dag_source_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_source_endpoint.py @@ -23,7 +23,6 @@ import pytest from airflow.models import DagBag -from airflow.security import permissions from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.db import clear_db_dag_code, clear_db_dags, clear_db_serialized_dags @@ -44,29 +43,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)], # type: ignore + role_name="admin", ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - TEST_DAG_ID, - access_control={"Test": [permissions.ACTION_CAN_READ]}, - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - EXAMPLE_DAG_ID, - access_control={"Test": [permissions.ACTION_CAN_READ]}, - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - TEST_MULTIPLE_DAGS_ID, - access_control={"Test": [permissions.ACTION_CAN_READ]}, - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestGetSource: @@ -123,18 +109,6 @@ def test_should_respond_200_json(self, url_safe_serializer): assert dag_docstring in response.json["content"] assert "application/json" == response.headers["Content-Type"] - def test_should_respond_406(self, url_safe_serializer): - dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) - dagbag.sync_to_db() - test_dag: DAG = dagbag.dags[TEST_DAG_ID] - - url = f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}" - response = self.client.get( - url, headers={"Accept": "image/webp"}, environ_overrides={"REMOTE_USER": "test"} - ) - - assert 406 == response.status_code - def test_should_respond_404(self): wrong_fileloc = "abcd1234" url = f"/api/v1/dagSources/{wrong_fileloc}" @@ -167,38 +141,3 @@ def test_should_raise_403_forbidden(self, url_safe_serializer): environ_overrides={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 - - def test_should_respond_403_not_readable(self, url_safe_serializer): - dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) - dagbag.sync_to_db() - dag: DAG = dagbag.dags[NOT_READABLE_DAG_ID] - - response = self.client.get( - f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, - ) - read_dag = self.client.get( - f"/api/v1/dags/{NOT_READABLE_DAG_ID}", - environ_overrides={"REMOTE_USER": "test"}, - ) - assert response.status_code == 403 - assert 
read_dag.status_code == 403 - - def test_should_respond_403_some_dags_not_readable_in_the_file(self, url_safe_serializer): - dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) - dagbag.sync_to_db() - dag: DAG = dagbag.dags[TEST_MULTIPLE_DAGS_ID] - - response = self.client.get( - f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, - ) - - read_dag = self.client.get( - f"/api/v1/dags/{TEST_MULTIPLE_DAGS_ID}", - environ_overrides={"REMOTE_USER": "test"}, - ) - assert response.status_code == 403 - assert read_dag.status_code == 200 diff --git a/tests/api_connexion/endpoints/test_dag_stats_endpoint.py b/tests/api_connexion/endpoints/test_dag_stats_endpoint.py index 36fc54d3a5b1..9ab5b4976593 100644 --- a/tests/api_connexion/endpoints/test_dag_stats_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_stats_endpoint.py @@ -22,7 +22,6 @@ from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import DagRunState @@ -38,21 +37,17 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestDagStatsEndpoint: diff --git a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py index 3e7c805173b3..f156d8921c0e 100644 --- a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py @@ -22,7 +22,6 @@ from airflow.models.dag import DagModel from airflow.models.dagwarning import DagWarning -from airflow.security import permissions from airflow.utils.session import create_session from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.db import clear_db_dag_warnings, clear_db_dags @@ -34,30 +33,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - ], # type: ignore - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - create_user( - app, # type:ignore - username="test_with_dag2_read", - role_name="TestWithDag2Read", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), - (permissions.ACTION_CAN_READ, f"{permissions.RESOURCE_DAG_PREFIX}dag2"), - ], # type: ignore + role_name="admin", ) + create_user(app, username="test_no_permissions", role_name=None) yield minimal_app_for_api - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, 
username="test_with_dag2_read") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestBaseDagWarning: @@ -162,11 +147,3 @@ def test_should_raise_403_forbidden(self): "/api/v1/dagWarnings", environ_overrides={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 - - def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self): - response = self.client.get( - "/api/v1/dagWarnings", - environ_overrides={"REMOTE_USER": "test_with_dag2_read"}, - query_string={"dag_id": "dag1"}, - ) - assert response.status_code == 403 diff --git a/tests/api_connexion/endpoints/test_dataset_endpoint.py b/tests/api_connexion/endpoints/test_dataset_endpoint.py index 5caec0ac2a13..76c164654c9d 100644 --- a/tests/api_connexion/endpoints/test_dataset_endpoint.py +++ b/tests/api_connexion/endpoints/test_dataset_endpoint.py @@ -33,7 +33,6 @@ TaskOutletAssetReference, ) from airflow.models.dagrun import DagRun -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import provide_session from airflow.utils.types import DagRunType @@ -50,31 +49,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_ASSET), - ], - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - create_user( - app, # type: ignore - username="test_queued_event", - role_name="TestQueuedEvent", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ASSET), - ], + role_name="admin", ) + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test_queued_event") # type: ignore - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestDatasetEndpoint: @@ -768,43 +752,6 @@ def _create_dataset_dag_run_queues(self, dag_id, dataset_id, session): class TestGetDagDatasetQueuedEvent(TestQueuedEventEndpoint): - @pytest.mark.usefixtures("time_freezer") - def test_should_respond_200(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - dataset_id = self._create_dataset(session).id - self._create_dataset_dag_run_queues(dag_id, dataset_id, session) - dataset_uri = "s3://bucket/key" - - response = self.client.get( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 200 - assert response.json == { - "created_at": self.default_time, - "uri": "s3://bucket/key", - "dag_id": "dag", - } - - def test_should_respond_404(self): - dag_id = "not_exists" - dataset_uri = "not_exists" - - response = self.client.get( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert { - "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } == response.json - def 
test_should_raises_401_unauthenticated(self, session): dag_id = "dummy" dataset_uri = "dummy" @@ -826,47 +773,6 @@ def test_should_raise_403_forbidden(self, session): class TestDeleteDagDatasetQueuedEvent(TestDatasetEndpoint): - def test_delete_should_respond_204(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - dataset_uri = "s3://bucket/key" - dataset_id = self._create_dataset(session).id - - adrq = AssetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset_id) - session.add(adrq) - session.commit() - conn = session.query(AssetDagRunQueue).all() - assert len(conn) == 1 - - response = self.client.delete( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 204 - conn = session.query(AssetDagRunQueue).all() - assert len(conn) == 0 - _check_last_log( - session, dag_id=dag_id, event="api.delete_dag_dataset_queued_event", execution_date=None - ) - - def test_should_respond_404(self): - dag_id = "not_exists" - dataset_uri = "not_exists" - - response = self.client.delete( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert { - "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } == response.json - def test_should_raises_401_unauthenticated(self, session): dag_id = "dummy" dataset_uri = "dummy" @@ -884,46 +790,6 @@ def test_should_raise_403_forbidden(self, session): class TestGetDagDatasetQueuedEvents(TestQueuedEventEndpoint): - @pytest.mark.usefixtures("time_freezer") - def test_should_respond_200(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - dataset_id = self._create_dataset(session).id - self._create_dataset_dag_run_queues(dag_id, dataset_id, session) - - response = self.client.get( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 200 - assert response.json == { - "queued_events": [ - { - "created_at": self.default_time, - "uri": "s3://bucket/key", - "dag_id": "dag", - } - ], - "total_entries": 1, - } - - def test_should_respond_404(self): - dag_id = "not_exists" - - response = self.client.get( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert { - "detail": "Queue event with dag_id: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } == response.json - def test_should_raises_401_unauthenticated(self): dag_id = "dummy" @@ -943,22 +809,6 @@ def test_should_raise_403_forbidden(self): class TestDeleteDagDatasetQueuedEvents(TestDatasetEndpoint): - def test_should_respond_404(self): - dag_id = "not_exists" - - response = self.client.delete( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert { - "detail": "Queue event with dag_id: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } == response.json - def test_should_raises_401_unauthenticated(self): dag_id = "dummy" @@ -978,47 +828,6 @@ def 
test_should_raise_403_forbidden(self): class TestGetDatasetQueuedEvents(TestQueuedEventEndpoint): - @pytest.mark.usefixtures("time_freezer") - def test_should_respond_200(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - dataset_id = self._create_dataset(session).id - self._create_dataset_dag_run_queues(dag_id, dataset_id, session) - dataset_uri = "s3://bucket/key" - - response = self.client.get( - f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 200 - assert response.json == { - "queued_events": [ - { - "created_at": self.default_time, - "uri": "s3://bucket/key", - "dag_id": "dag", - } - ], - "total_entries": 1, - } - - def test_should_respond_404(self): - dataset_uri = "not_exists" - - response = self.client.get( - f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert { - "detail": "Queue event with asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } == response.json - def test_should_raises_401_unauthenticated(self): dataset_uri = "not_exists" @@ -1038,39 +847,6 @@ def test_should_raise_403_forbidden(self): class TestDeleteDatasetQueuedEvents(TestQueuedEventEndpoint): - def test_delete_should_respond_204(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - dataset_id = self._create_dataset(session).id - self._create_dataset_dag_run_queues(dag_id, dataset_id, session) - dataset_uri = "s3://bucket/key" - - response = self.client.delete( - f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 204 - conn = session.query(AssetDagRunQueue).all() - assert len(conn) == 0 - _check_last_log(session, dag_id=None, event="api.delete_dataset_queued_events", execution_date=None) - - def test_should_respond_404(self): - dataset_uri = "not_exists" - - response = self.client.delete( - f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert { - "detail": "Queue event with asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } == response.json - def test_should_raises_401_unauthenticated(self): dataset_uri = "not_exists" diff --git a/tests/api_connexion/endpoints/test_event_log_endpoint.py b/tests/api_connexion/endpoints/test_event_log_endpoint.py index 0fdef1a3af2b..e5ca3d301765 100644 --- a/tests/api_connexion/endpoints/test_event_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_event_log_endpoint.py @@ -20,7 +20,6 @@ from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import Log -from airflow.security import permissions from airflow.utils import timezone from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars @@ -33,32 +32,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], # type: ignore + role_name="admin", ) - create_user( - app, # type:ignore - username="test_granular", - role_name="TestGranular", - 
permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], # type: ignore - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "TEST_DAG_ID_1", - access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "TEST_DAG_ID_2", - access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_granular") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") @pytest.fixture @@ -274,33 +257,6 @@ def test_should_raises_401_unauthenticated(self, log_model): assert_401(response) - def test_should_filter_eventlogs_by_allowed_attributes(self, create_log_model, session): - eventlog1 = create_log_model( - event="TEST_EVENT_1", - dag_id="TEST_DAG_ID_1", - task_id="TEST_TASK_ID_1", - owner="TEST_OWNER_1", - when=self.default_time, - ) - eventlog2 = create_log_model( - event="TEST_EVENT_2", - dag_id="TEST_DAG_ID_2", - task_id="TEST_TASK_ID_2", - owner="TEST_OWNER_2", - when=self.default_time_2, - ) - session.add_all([eventlog1, eventlog2]) - session.commit() - for attr in ["dag_id", "task_id", "owner", "event"]: - attr_value = f"TEST_{attr}_1".upper() - response = self.client.get( - f"/api/v1/eventLogs?{attr}={attr_value}", environ_overrides={"REMOTE_USER": "test_granular"} - ) - assert response.status_code == 200 - assert response.json["total_entries"] == 1 - assert len(response.json["event_logs"]) == 1 - assert response.json["event_logs"][0][attr] == attr_value - def test_should_filter_eventlogs_by_when(self, create_log_model, session): eventlog1 = create_log_model(event="TEST_EVENT_1", when=self.default_time) eventlog2 = create_log_model(event="TEST_EVENT_2", when=self.default_time_2) @@ -339,32 +295,6 @@ def test_should_filter_eventlogs_by_run_id(self, create_log_model, session): assert {eventlog["event"] for eventlog in response.json["event_logs"]} == expected_eventlogs assert all({eventlog["run_id"] == run_id for eventlog in response.json["event_logs"]}) - def test_should_filter_eventlogs_by_included_events(self, create_log_model): - for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]: - create_log_model(event=event, when=self.default_time) - response = self.client.get( - "/api/v1/eventLogs?included_events=TEST_EVENT_1,TEST_EVENT_2", - environ_overrides={"REMOTE_USER": "test_granular"}, - ) - assert response.status_code == 200 - response_data = response.json - assert len(response_data["event_logs"]) == 2 - assert response_data["total_entries"] == 2 - assert {"TEST_EVENT_1", "TEST_EVENT_2"} == {x["event"] for x in response_data["event_logs"]} - - def test_should_filter_eventlogs_by_excluded_events(self, create_log_model): - for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]: - create_log_model(event=event, when=self.default_time) - response = self.client.get( - "/api/v1/eventLogs?excluded_events=TEST_EVENT_1,TEST_EVENT_2", - environ_overrides={"REMOTE_USER": "test_granular"}, - ) - assert response.status_code == 200 - response_data = response.json - assert len(response_data["event_logs"]) == 1 - assert response_data["total_entries"] == 1 - assert {"cli_scheduler"} == {x["event"] for x in response_data["event_logs"]} - 
class TestGetEventLogPagination(TestEventLogEndpoint): @pytest.mark.parametrize( diff --git a/tests/api_connexion/endpoints/test_extra_link_endpoint.py b/tests/api_connexion/endpoints/test_extra_link_endpoint.py index 1e9226ede984..2c3eacdc91dc 100644 --- a/tests/api_connexion/endpoints/test_extra_link_endpoint.py +++ b/tests/api_connexion/endpoints/test_extra_link_endpoint.py @@ -26,7 +26,6 @@ from airflow.models.dagbag import DagBag from airflow.models.xcom import XCom from airflow.plugins_manager import AirflowPlugin -from airflow.security import permissions from airflow.timetables.base import DataInterval from airflow.utils import timezone from airflow.utils.state import DagRunState @@ -48,21 +47,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestGetExtraLinks: @@ -78,8 +72,8 @@ def setup_attrs(self, configured_app, session) -> None: self.dag = self._create_dag() self.app.dag_bag = DagBag(os.devnull, include_examples=False) - self.app.dag_bag.dags = {self.dag.dag_id: self.dag} # type: ignore - self.app.dag_bag.sync_to_db() # type: ignore + self.app.dag_bag.dags = {self.dag.dag_id: self.dag} + self.app.dag_bag.sync_to_db() triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} self.dag.create_dagrun( diff --git a/tests/api_connexion/endpoints/test_import_error_endpoint.py b/tests/api_connexion/endpoints/test_import_error_endpoint.py index 635e159bb292..af2b83ebb1ee 100644 --- a/tests/api_connexion/endpoints/test_import_error_endpoint.py +++ b/tests/api_connexion/endpoints/test_import_error_endpoint.py @@ -21,15 +21,12 @@ import pytest from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP -from airflow.models.dag import DagModel -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import provide_session from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.compat import ParseImportError from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_import_errors -from tests.test_utils.permissions import _resource_name pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] @@ -40,42 +37,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR), - ], # type: ignore - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - create_user( - app, # type:ignore - username="test_single_dag", - role_name="TestSingleDAG", - permissions=[(permissions.ACTION_CAN_READ, 
permissions.RESOURCE_IMPORT_ERROR)], # type: ignore - ) - # For some reason, DAG level permissions are not synced when in the above list of perms, - # so do it manually here: - app.appbuilder.sm.bulk_sync_roles( - [ - { - "role": "TestSingleDAG", - "perms": [ - ( - permissions.ACTION_CAN_READ, - _resource_name(TEST_DAG_IDS[0], permissions.RESOURCE_DAG), - ) - ], - } - ] + role_name="admin", ) + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_single_dag") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestBaseImportError: @@ -152,72 +123,6 @@ def test_should_raise_403_forbidden(self): ) assert response.status_code == 403 - def test_should_raise_403_forbidden_without_dag_read(self, session): - import_error = ParseImportError( - filename="Lorem_ipsum.py", - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(import_error) - session.commit() - - response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 403 - - def test_should_return_200_with_single_dag_read(self, session): - dag_model = DagModel(dag_id=TEST_DAG_IDS[0], fileloc="Lorem_ipsum.py") - session.add(dag_model) - import_error = ParseImportError( - filename="Lorem_ipsum.py", - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(import_error) - session.commit() - - response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - response_data["import_error_id"] = 1 - assert { - "filename": "Lorem_ipsum.py", - "import_error_id": 1, - "stack_trace": "Lorem ipsum", - "timestamp": "2020-06-10T12:00:00+00:00", - } == response_data - - def test_should_return_200_redacted_with_single_dag_read_in_dagfile(self, session): - for dag_id in TEST_DAG_IDS: - dag_model = DagModel(dag_id=dag_id, fileloc="Lorem_ipsum.py") - session.add(dag_model) - import_error = ParseImportError( - filename="Lorem_ipsum.py", - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(import_error) - session.commit() - - response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - response_data["import_error_id"] = 1 - assert { - "filename": "Lorem_ipsum.py", - "import_error_id": 1, - "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file", - "timestamp": "2020-06-10T12:00:00+00:00", - } == response_data - class TestGetImportErrorsEndpoint(TestBaseImportError): def test_get_import_errors(self, session): @@ -328,71 +233,6 @@ def test_should_raises_401_unauthenticated(self, session): assert_401(response) - def test_get_import_errors_single_dag(self, session): - for dag_id in TEST_DAG_IDS: - fake_filename = f"/tmp/{dag_id}.py" - dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename) - session.add(dag_model) - importerror = ParseImportError( - filename=fake_filename, - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(importerror) 
- session.commit() - - response = self.client.get( - "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - self._normalize_import_errors(response_data["import_errors"]) - assert { - "import_errors": [ - { - "filename": "/tmp/test_dag.py", - "import_error_id": 1, - "stack_trace": "Lorem ipsum", - "timestamp": "2020-06-10T12:00:00+00:00", - }, - ], - "total_entries": 1, - } == response_data - - def test_get_import_errors_single_dag_in_dagfile(self, session): - for dag_id in TEST_DAG_IDS: - fake_filename = "/tmp/all_in_one.py" - dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename) - session.add(dag_model) - - importerror = ParseImportError( - filename="/tmp/all_in_one.py", - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(importerror) - session.commit() - - response = self.client.get( - "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - self._normalize_import_errors(response_data["import_errors"]) - assert { - "import_errors": [ - { - "filename": "/tmp/all_in_one.py", - "import_error_id": 1, - "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file", - "timestamp": "2020-06-10T12:00:00+00:00", - }, - ], - "total_entries": 1, - } == response_data - class TestGetImportErrorsEndpointPagination(TestBaseImportError): @pytest.mark.parametrize( diff --git a/tests/api_connexion/endpoints/test_log_endpoint.py b/tests/api_connexion/endpoints/test_log_endpoint.py index 420d2dd65f89..2b112e322184 100644 --- a/tests/api_connexion/endpoints/test_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_log_endpoint.py @@ -30,7 +30,6 @@ from airflow.decorators import task from airflow.models.dag import DAG from airflow.operators.empty import EmptyOperator -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.types import DagRunType from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user @@ -46,13 +45,9 @@ def configured_app(minimal_app_for_api): create_user( app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") + create_user(app, username="test_no_permissions", role_name=None) yield app diff --git a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py index 72cdccdee68d..fc53b8952f4a 100644 --- a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py @@ -28,12 +28,11 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.dagbag import DagBag from airflow.models.taskmap import TaskMap -from airflow.security import permissions from airflow.utils.platform import getuser from airflow.utils.session import provide_session from airflow.utils.state import State, TaskInstanceState from airflow.utils.timezone import datetime -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user +from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.db 
import clear_db_runs, clear_db_sla_miss, clear_rendered_ti_fields from tests.test_utils.mock_operators import MockOperator @@ -50,24 +49,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_roles(app) + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestMappedTaskInstanceEndpoint: @@ -133,8 +124,8 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None): session.add(ti) self.app.dag_bag = DagBag(os.devnull, include_examples=False) - self.app.dag_bag.dags = {dag_id: dag_maker.dag} # type: ignore - self.app.dag_bag.sync_to_db() # type: ignore + self.app.dag_bag.dags = {dag_id: dag_maker.dag} + self.app.dag_bag.sync_to_db() session.flush() mapped.expand_mapped_task(dr.run_id, session=session) diff --git a/tests/api_connexion/endpoints/test_plugin_endpoint.py b/tests/api_connexion/endpoints/test_plugin_endpoint.py index edf925cf0fa7..0cd630375a28 100644 --- a/tests/api_connexion/endpoints/test_plugin_endpoint.py +++ b/tests/api_connexion/endpoints/test_plugin_endpoint.py @@ -24,7 +24,6 @@ from airflow.hooks.base import BaseHook from airflow.plugins_manager import AirflowPlugin -from airflow.security import permissions from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.timetables.base import Timetable from airflow.utils.module_loading import qualname @@ -105,17 +104,16 @@ class MockPlugin(AirflowPlugin): def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_PLUGIN)], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestPluginsEndpoint: diff --git a/tests/api_connexion/endpoints/test_pool_endpoint.py b/tests/api_connexion/endpoints/test_pool_endpoint.py index 87439a581194..2cc095d077aa 100644 --- a/tests/api_connexion/endpoints/test_pool_endpoint.py +++ b/tests/api_connexion/endpoints/test_pool_endpoint.py @@ -20,7 +20,6 @@ from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models.pool import Pool -from airflow.security import permissions from airflow.utils.session import provide_session from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars @@ -35,22 +34,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: 
ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_POOL), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_POOL), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_POOL), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_POOL), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestBasePoolEndpoints: diff --git a/tests/api_connexion/endpoints/test_provider_endpoint.py b/tests/api_connexion/endpoints/test_provider_endpoint.py index 16e5989cc56d..b4cf8f10a92a 100644 --- a/tests/api_connexion/endpoints/test_provider_endpoint.py +++ b/tests/api_connexion/endpoints/test_provider_endpoint.py @@ -21,7 +21,6 @@ import pytest from airflow.providers_manager import ProviderInfo -from airflow.security import permissions from tests.test_utils.api_connexion_utils import create_user, delete_user pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] @@ -54,17 +53,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_PROVIDER)], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestBaseProviderEndpoint: diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py index d0a4fb903c8b..b2e068bd507f 100644 --- a/tests/api_connexion/endpoints/test_task_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_endpoint.py @@ -27,7 +27,6 @@ from airflow.models.expandinput import EXPAND_INPUT_EMPTY from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.empty import EmptyOperator -from airflow.security import permissions from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags @@ -38,21 +37,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestTaskEndpoint: diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py 
b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index 25ded6c814b7..b5b3163e988d 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -25,19 +25,17 @@ from sqlalchemy import select from sqlalchemy.orm import contains_eager -from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import TriggererJobRunner from airflow.models import DagRun, SlaMiss, TaskInstance, Trigger from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF from airflow.models.taskinstancehistory import TaskInstanceHistory -from airflow.security import permissions from airflow.utils.platform import getuser from airflow.utils.session import provide_session from airflow.utils.state import State from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user +from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.db import clear_db_runs, clear_db_sla_miss, clear_rendered_ti_fields from tests.test_utils.www import _check_last_log @@ -55,69 +53,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], - ) - create_user( - app, # type: ignore - username="test_dag_read_only", - role_name="TestDagReadOnly", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], - ) - create_user( - app, # type: ignore - username="test_task_read_only", - role_name="TestTaskReadOnly", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], - ) - create_user( - app, # type: ignore - username="test_read_only_one_dag", - role_name="TestReadOnlyOneDag", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], - ) - # For some reason, "DAG:example_python_operator" is not synced when in the above list of perms, - # so do it manually here: - app.appbuilder.sm.bulk_sync_roles( - [ - { - "role": "TestReadOnlyOneDag", - "perms": [(permissions.ACTION_CAN_READ, "DAG:example_python_operator")], - } - ] + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_dag_read_only") # type: ignore - delete_user(app, username="test_task_read_only") # type: 
ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_read_only_one_dag") # type: ignore - delete_roles(app) + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestTaskInstanceEndpoint: @@ -219,9 +164,8 @@ def setup_method(self): def teardown_method(self): clear_db_runs() - @pytest.mark.parametrize("username", ["test", "test_dag_read_only", "test_task_read_only"]) @provide_session - def test_should_respond_200(self, username, session): + def test_should_respond_200(self, session): self.create_task_instances(session) # Update ti and set operator to None to # test that operator field is nullable. @@ -232,7 +176,7 @@ def test_should_respond_200(self, username, session): session.commit() response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", - environ_overrides={"REMOTE_USER": username}, + environ_overrides={"REMOTE_USER": "test"}, ) assert response.status_code == 200 assert response.json == { @@ -723,36 +667,11 @@ def test_should_respond_200(self, task_instances, update_extras, url, expected_t assert response.json["total_entries"] == expected_ti assert len(response.json["task_instances"]) == expected_ti - @pytest.mark.parametrize( - "task_instances, user, expected_ti", - [ - pytest.param( - { - "example_python_operator": 2, - "example_skip_dag": 1, - }, - "test_read_only_one_dag", - 2, - ), - pytest.param( - { - "example_python_operator": 1, - "example_skip_dag": 2, - }, - "test_read_only_one_dag", - 1, - ), - pytest.param( - { - "example_python_operator": 1, - "example_skip_dag": 2, - }, - "test", - 3, - ), - ], - ) - def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ti, session): + def test_return_TI_only_from_readable_dags(self, session): + task_instances = { + "example_python_operator": 1, + "example_skip_dag": 2, + } for dag_id in task_instances: self.create_task_instances( session, @@ -763,11 +682,11 @@ def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ dag_id=dag_id, ) response = self.client.get( - "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": user} + "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json["total_entries"] == expected_ti - assert len(response.json["task_instances"]) == expected_ti + assert response.json["total_entries"] == 3 + assert len(response.json["task_instances"]) == 3 def test_should_respond_200_for_dag_id_filter(self, session): self.create_task_instances(session) @@ -898,44 +817,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): "test", id="test executor filter", ), - pytest.param( - [ - {"pool": "test_pool_1"}, - {"pool": "test_pool_2"}, - {"pool": "test_pool_3"}, - ], - True, - {"pool": ["test_pool_1", "test_pool_2"]}, - 2, - "test_dag_read_only", - id="test pool filter", - ), - pytest.param( - [ - {"state": State.RUNNING}, - {"state": State.QUEUED}, - {"state": State.SUCCESS}, - {"state": State.NONE}, - ], - False, - {"state": ["running", "queued", "none"]}, - 3, - "test_task_read_only", - id="test state filter", - ), - pytest.param( - [ - {"state": State.NONE}, - {"state": State.NONE}, - {"state": State.NONE}, - {"state": State.NONE}, - ], - False, - {}, - 4, - "test_task_read_only", - id="test dag with null states", - ), pytest.param( [ {"duration": 100}, @@ -948,36 +829,6 @@ class 
TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): "test", id="test duration filter", ), - pytest.param( - [ - {"end_date": DEFAULT_DATETIME_1}, - {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, - {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, - ], - True, - { - "end_date_gte": DEFAULT_DATETIME_STR_1, - "end_date_lte": DEFAULT_DATETIME_STR_2, - }, - 2, - "test_task_read_only", - id="test end date filter", - ), - pytest.param( - [ - {"start_date": DEFAULT_DATETIME_1}, - {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, - {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, - ], - True, - { - "start_date_gte": DEFAULT_DATETIME_STR_1, - "start_date_lte": DEFAULT_DATETIME_STR_2, - }, - 2, - "test_dag_read_only", - id="test start date filter", - ), pytest.param( [ {"execution_date": DEFAULT_DATETIME_1}, @@ -1162,24 +1013,6 @@ def test_should_raise_403_forbidden(self): ) assert response.status_code == 403 - def test_returns_403_forbidden_when_user_has_access_to_only_some_dags(self, session): - self.create_task_instances(session=session) - self.create_task_instances(session=session, dag_id="example_skip_dag") - payload = {"dag_ids": ["example_python_operator", "example_skip_dag"]} - - response = self.client.post( - "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": "test_read_only_one_dag"}, - json=payload, - ) - assert response.status_code == 403 - assert response.json == { - "detail": "User not allowed to access some of these DAGs: ['example_python_operator', 'example_skip_dag']", - "status": 403, - "title": "Forbidden", - "type": EXCEPTIONS_LINK_MAP[403], - } - def test_should_raise_400_for_no_json(self): response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", @@ -1794,11 +1627,10 @@ def test_should_raises_401_unauthenticated(self): ) assert_401(response) - @pytest.mark.parametrize("username", ["test_no_permissions", "test_dag_read_only", "test_task_read_only"]) - def test_should_raise_403_forbidden(self, username: str): + def test_should_raise_403_forbidden(self): response = self.client.post( "/api/v1/dags/example_python_operator/clearTaskInstances", - environ_overrides={"REMOTE_USER": username}, + environ_overrides={"REMOTE_USER": "test_no_permissions"}, json={ "dry_run": False, "reset_dag_runs": True, @@ -2043,11 +1875,10 @@ def test_should_raises_401_unauthenticated(self): ) assert_401(response) - @pytest.mark.parametrize("username", ["test_no_permissions", "test_dag_read_only", "test_task_read_only"]) - def test_should_raise_403_forbidden(self, username): + def test_should_raise_403_forbidden(self): response = self.client.post( "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": username}, + environ_overrides={"REMOTE_USER": "test_no_permissions"}, json={ "dry_run": True, "task_id": "print_the_context", @@ -2386,11 +2217,10 @@ def test_should_raises_401_unauthenticated(self): ) assert_401(response) - @pytest.mark.parametrize("username", ["test_no_permissions", "test_dag_read_only", "test_task_read_only"]) - def test_should_raise_403_forbidden(self, username): + def test_should_raise_403_forbidden(self): response = self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": username}, + environ_overrides={"REMOTE_USER": "test_no_permissions"}, json={ "dry_run": True, "new_state": "failed", @@ -2748,14 +2578,13 @@ def setup_method(self): def teardown_method(self): clear_db_runs() - @pytest.mark.parametrize("username", ["test", 
"test_dag_read_only", "test_task_read_only"]) @provide_session - def test_should_respond_200(self, username, session): + def test_should_respond_200(self, session): self.create_task_instances(session, task_instances=[{"state": State.SUCCESS}], with_ti_history=True) response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries/1", - environ_overrides={"REMOTE_USER": username}, + environ_overrides={"REMOTE_USER": "test"}, ) assert response.status_code == 200 assert response.json == { diff --git a/tests/api_connexion/endpoints/test_variable_endpoint.py b/tests/api_connexion/endpoints/test_variable_endpoint.py index 81405df08b04..aa5f7c99674f 100644 --- a/tests/api_connexion/endpoints/test_variable_endpoint.py +++ b/tests/api_connexion/endpoints/test_variable_endpoint.py @@ -22,7 +22,6 @@ from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import Variable -from airflow.security import permissions from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_variables @@ -36,40 +35,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_VARIABLE), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE), - ], - ) - create_user( - app, # type: ignore - username="test_read_only", - role_name="TestReadOnly", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE), - ], - ) - create_user( - app, # type: ignore - username="test_delete_only", - role_name="TestDeleteOnly", - permissions=[ - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_read_only") # type: ignore - delete_user(app, username="test_delete_only") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestVariableEndpoint: @@ -131,8 +106,6 @@ class TestGetVariable(TestVariableEndpoint): "user, expected_status_code", [ ("test", 200), - ("test_read_only", 200), - ("test_delete_only", 403), ("test_no_permissions", 403), ], ) diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py index 7a51714c5b29..809e537f9f88 100644 --- a/tests/api_connexion/endpoints/test_xcom_endpoint.py +++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py @@ -26,7 +26,6 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.xcom import BaseXCom, XCom, resolve_xcom_backend from airflow.operators.empty import EmptyOperator -from airflow.security import permissions from airflow.utils.dates import parse_execution_date from airflow.utils.session import create_session from airflow.utils.timezone import utcnow @@ -52,32 +51,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, 
username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), - ], - ) - create_user( - app, # type: ignore - username="test_granular_permissions", - role_name="TestGranularDag", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), - ], - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "test-dag-id-1", - access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]}, + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") def _compare_xcom_collections(collection1: dict, collection_2: dict): @@ -435,53 +418,6 @@ def test_should_respond_200_with_tilde_and_access_to_all_dags(self): }, ) - def test_should_respond_200_with_tilde_and_granular_dag_access(self): - dag_id_1 = "test-dag-id-1" - task_id_1 = "test-task-id-1" - execution_date = "2005-04-02T00:00:00+00:00" - execution_date_parsed = parse_execution_date(execution_date) - dag_run_id_1 = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed) - self._create_xcom_entries(dag_id_1, dag_run_id_1, execution_date_parsed, task_id_1) - - dag_id_2 = "test-dag-id-2" - task_id_2 = "test-task-id-2" - run_id_2 = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed) - self._create_xcom_entries(dag_id_2, run_id_2, execution_date_parsed, task_id_2) - self._create_invalid_xcom_entries(execution_date_parsed) - response = self.client.get( - "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries", - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - - assert 200 == response.status_code - response_data = response.json - for xcom_entry in response_data["xcom_entries"]: - xcom_entry["timestamp"] = "TIMESTAMP" - _compare_xcom_collections( - response_data, - { - "xcom_entries": [ - { - "dag_id": dag_id_1, - "execution_date": execution_date, - "key": "test-xcom-key-1", - "task_id": task_id_1, - "timestamp": "TIMESTAMP", - "map_index": -1, - }, - { - "dag_id": dag_id_1, - "execution_date": execution_date, - "key": "test-xcom-key-2", - "task_id": task_id_1, - "timestamp": "TIMESTAMP", - "map_index": -1, - }, - ], - "total_entries": 2, - }, - ) - def test_should_respond_200_with_map_index(self): dag_id = "test-dag-id" task_id = "test-task-id" diff --git a/tests/api_connexion/test_auth.py b/tests/api_connexion/test_auth.py index 7d1dcc088273..54e5632ad84d 100644 --- a/tests/api_connexion/test_auth.py +++ b/tests/api_connexion/test_auth.py @@ -16,15 +16,15 @@ # under the License. 
from __future__ import annotations -from base64 import b64encode +from unittest.mock import patch import pytest -from flask_login import current_user +from airflow.auth.managers.simple.simple_auth_manager import SimpleAuthManager +from airflow.auth.managers.simple.user import SimpleAuthManagerUser from tests.test_utils.api_connexion_utils import assert_401 from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_pools -from tests.test_utils.www import client_with_login pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] @@ -34,101 +34,6 @@ class BaseTestAuth: def set_attrs(self, minimal_app_for_api): self.app = minimal_app_for_api - sm = self.app.appbuilder.sm - tester = sm.find_user(username="test") - if not tester: - role_admin = sm.find_role("Admin") - sm.add_user( - username="test", - first_name="test", - last_name="test", - email="test@fab.org", - role=role_admin, - password="test", - ) - - -class TestBasicAuth(BaseTestAuth): - @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_api): - from airflow.www.extensions.init_security import init_api_auth - - old_auth = getattr(minimal_app_for_api, "api_auth") - - try: - with conf_vars( - {("api", "auth_backends"): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"} - ): - init_api_auth(minimal_app_for_api) - yield - finally: - setattr(minimal_app_for_api, "api_auth", old_auth) - - def test_success(self): - token = "Basic " + b64encode(b"test:test").decode() - clear_db_pools() - - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) - assert current_user.email == "test@fab.org" - - assert response.status_code == 200 - assert response.json == { - "pools": [ - { - "name": "default_pool", - "slots": 128, - "occupied_slots": 0, - "running_slots": 0, - "queued_slots": 0, - "scheduled_slots": 0, - "deferred_slots": 0, - "open_slots": 128, - "description": "Default pool", - "include_deferred": False, - }, - ], - "total_entries": 1, - } - - @pytest.mark.parametrize( - "token", - [ - "basic", - "basic ", - "bearer", - "test:test", - b64encode(b"test:test").decode(), - "bearer ", - "basic: ", - "basic 123", - ], - ) - def test_malformed_headers(self, token): - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) - assert response.status_code == 401 - assert response.headers["Content-Type"] == "application/problem+json" - assert response.headers["WWW-Authenticate"] == "Basic" - assert_401(response) - - @pytest.mark.parametrize( - "token", - [ - "basic " + b64encode(b"test").decode(), - "basic " + b64encode(b"test:").decode(), - "basic " + b64encode(b"test:123").decode(), - "basic " + b64encode(b"test test").decode(), - ], - ) - def test_invalid_auth_header(self, token): - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) - assert response.status_code == 401 - assert response.headers["Content-Type"] == "application/problem+json" - assert response.headers["WWW-Authenticate"] == "Basic" - assert_401(response) - class TestSessionAuth(BaseTestAuth): @pytest.fixture(autouse=True, scope="class") @@ -144,74 +49,37 @@ def with_session_backend(self, minimal_app_for_api): finally: setattr(minimal_app_for_api, "api_auth", old_auth) - def test_success(self): + @patch.object(SimpleAuthManager, "is_logged_in", return_value=True) + 
@patch.object( + SimpleAuthManager, "get_user", return_value=SimpleAuthManagerUser(username="test", role="admin") + ) + def test_success(self, *args): clear_db_pools() - admin_user = client_with_login(self.app, username="test", password="test") - response = admin_user.get("/api/v1/pools") - assert response.status_code == 200 - assert response.json == { - "pools": [ - { - "name": "default_pool", - "slots": 128, - "occupied_slots": 0, - "running_slots": 0, - "queued_slots": 0, - "scheduled_slots": 0, - "deferred_slots": 0, - "open_slots": 128, - "description": "Default pool", - "include_deferred": False, - }, - ], - "total_entries": 1, - } - - def test_failure(self): with self.app.test_client() as test_client: response = test_client.get("/api/v1/pools") - assert response.status_code == 401 - assert response.headers["Content-Type"] == "application/problem+json" - assert_401(response) - - -class TestSessionWithBasicAuthFallback(BaseTestAuth): - @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_api): - from airflow.www.extensions.init_security import init_api_auth - - old_auth = getattr(minimal_app_for_api, "api_auth") - - try: - with conf_vars( - { - ( - "api", - "auth_backends", - ): "airflow.api.auth.backend.session,airflow.providers.fab.auth_manager.api.auth.backend.basic_auth" - } - ): - init_api_auth(minimal_app_for_api) - yield - finally: - setattr(minimal_app_for_api, "api_auth", old_auth) - - def test_basic_auth_fallback(self): - token = "Basic " + b64encode(b"test:test").decode() - clear_db_pools() - - # request uses session - admin_user = client_with_login(self.app, username="test", password="test") - response = admin_user.get("/api/v1/pools") - assert response.status_code == 200 - - # request uses basic auth - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) assert response.status_code == 200 + assert response.json == { + "pools": [ + { + "name": "default_pool", + "slots": 128, + "occupied_slots": 0, + "running_slots": 0, + "queued_slots": 0, + "scheduled_slots": 0, + "deferred_slots": 0, + "open_slots": 128, + "description": "Default pool", + "include_deferred": False, + }, + ], + "total_entries": 1, + } - # request without session or basic auth header + def test_failure(self): with self.app.test_client() as test_client: response = test_client.get("/api/v1/pools") assert response.status_code == 401 + assert response.headers["Content-Type"] == "application/problem+json" + assert_401(response) diff --git a/tests/api_connexion/test_security.py b/tests/api_connexion/test_security.py index 13a5dd4e25af..c6a112b1a1bb 100644 --- a/tests/api_connexion/test_security.py +++ b/tests/api_connexion/test_security.py @@ -18,7 +18,6 @@ import pytest -from airflow.security import permissions from tests.test_utils.api_connexion_utils import create_user, delete_user pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] @@ -28,15 +27,14 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)], # type: ignore + role_name="admin", ) yield minimal_app_for_api - delete_user(app, username="test") # type: ignore + delete_user(app, username="test") class TestSession: diff --git a/tests/providers/fab/auth_manager/api_endpoints/api_connexion_utils.py 
b/tests/providers/fab/auth_manager/api_endpoints/api_connexion_utils.py new file mode 100644 index 000000000000..61d923d5ff12 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/api_connexion_utils.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from contextlib import contextmanager + +from tests.test_utils.compat import ignore_provider_compatibility_error + +with ignore_provider_compatibility_error("2.9.0+", __file__): + from airflow.providers.fab.auth_manager.security_manager.override import EXISTING_ROLES + + +@contextmanager +def create_test_client(app, user_name, role_name, permissions): + """ + Helper function to create a client with a temporary user which will be deleted once done + """ + client = app.test_client() + with create_user_scope(app, username=user_name, role_name=role_name, permissions=permissions) as _: + resp = client.post("/login/", data={"username": user_name, "password": user_name}) + assert resp.status_code == 302 + yield client + + +@contextmanager +def create_user_scope(app, username, **kwargs): + """ + Helper function designed to be used with pytest fixture mainly. + It will create a user and provide it for the fixture via YIELD (generator) + then will tidy up once test is complete + """ + test_user = create_user(app, username, **kwargs) + + try: + yield test_user + finally: + delete_user(app, username) + + +def create_user(app, username, role_name=None, email=None, permissions=None): + appbuilder = app.appbuilder + + # Removes user and role so each test has isolated test data. 
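+    # Any pre-existing user with this username is removed first, the role is rebuilt
+    # from the supplied permissions (or left empty when no role_name is given),
+    # and a fresh user is created with the username doubling as the password.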
+ delete_user(app, username) + role = None + if role_name: + delete_role(app, role_name) + role = create_role(app, role_name, permissions) + else: + role = [] + + return appbuilder.sm.add_user( + username=username, + first_name=username, + last_name=username, + email=email or f"{username}@example.org", + role=role, + password=username, + ) + + +def create_role(app, name, permissions=None): + appbuilder = app.appbuilder + role = appbuilder.sm.find_role(name) + if not role: + role = appbuilder.sm.add_role(name) + if not permissions: + permissions = [] + for permission in permissions: + perm_object = appbuilder.sm.get_permission(*permission) + appbuilder.sm.add_permission_to_role(role, perm_object) + return role + + +def set_user_single_role(app, user, role_name): + role = create_role(app, role_name) + if role not in user.roles: + user.roles = [role] + app.appbuilder.sm.update_user(user) + user._perms = None + + +def delete_role(app, name): + if name not in EXISTING_ROLES: + if app.appbuilder.sm.find_role(name): + app.appbuilder.sm.delete_role(name) + + +def delete_roles(app): + for role in app.appbuilder.sm.get_all_roles(): + delete_role(app, role.name) + + +def delete_user(app, username): + appbuilder = app.appbuilder + for user in appbuilder.sm.get_all_users(): + if user.username == username: + _ = [ + delete_role(app, role.name) for role in user.roles if role and role.name not in EXISTING_ROLES + ] + appbuilder.sm.del_register_user(user) + break diff --git a/tests/providers/fab/auth_manager/api_endpoints/remote_user_api_auth_backend.py b/tests/providers/fab/auth_manager/api_endpoints/remote_user_api_auth_backend.py new file mode 100644 index 000000000000..b7714e5192e6 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/remote_user_api_auth_backend.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
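+#
+# Test-only auth backend: it trusts the REMOTE_USER value set on the request
+# environ, looks up the matching active Airflow user by email or username,
+# logs that user in via flask_login, and returns 403 Forbidden otherwise.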
+"""Default authentication backend - everything is allowed""" + +from __future__ import annotations + +import logging +from functools import wraps +from typing import TYPE_CHECKING, Callable, TypeVar, cast + +from flask import Response, request +from flask_login import login_user + +from airflow.utils.airflow_flask_app import get_airflow_app + +if TYPE_CHECKING: + from requests.auth import AuthBase + +log = logging.getLogger(__name__) + +CLIENT_AUTH: tuple[str, str] | AuthBase | None = None + + +def init_app(_): + """Initializes authentication backend""" + + +T = TypeVar("T", bound=Callable) + + +def _lookup_user(user_email_or_username: str): + security_manager = get_airflow_app().appbuilder.sm + user = security_manager.find_user(email=user_email_or_username) or security_manager.find_user( + username=user_email_or_username + ) + if not user: + return None + + if not user.is_active: + return None + + return user + + +def requires_authentication(function: T): + """Decorator for functions that require authentication""" + + @wraps(function) + def decorated(*args, **kwargs): + user_id = request.remote_user + if not user_id: + log.debug("Missing REMOTE_USER.") + return Response("Forbidden", 403) + + log.debug("Looking for user: %s", user_id) + + user = _lookup_user(user_id) + if not user: + return Response("Forbidden", 403) + + log.debug("Found user: %s", user) + + login_user(user, remember=False) + return function(*args, **kwargs) + + return cast(T, decorated) diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_auth.py b/tests/providers/fab/auth_manager/api_endpoints/test_auth.py new file mode 100644 index 000000000000..d3012e2f1b43 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_auth.py @@ -0,0 +1,176 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from base64 import b64encode + +import pytest +from flask_login import current_user + +from tests.test_utils.api_connexion_utils import assert_401 +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.config import conf_vars +from tests.test_utils.db import clear_db_pools +from tests.test_utils.www import client_with_login + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +class BaseTestAuth: + @pytest.fixture(autouse=True) + def set_attrs(self, minimal_app_for_auth_api): + self.app = minimal_app_for_auth_api + + sm = self.app.appbuilder.sm + tester = sm.find_user(username="test") + if not tester: + role_admin = sm.find_role("Admin") + sm.add_user( + username="test", + first_name="test", + last_name="test", + email="test@fab.org", + role=role_admin, + password="test", + ) + + +class TestBasicAuth(BaseTestAuth): + @pytest.fixture(autouse=True, scope="class") + def with_basic_auth_backend(self, minimal_app_for_auth_api): + from airflow.www.extensions.init_security import init_api_auth + + old_auth = getattr(minimal_app_for_auth_api, "api_auth") + + try: + with conf_vars( + {("api", "auth_backends"): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"} + ): + init_api_auth(minimal_app_for_auth_api) + yield + finally: + setattr(minimal_app_for_auth_api, "api_auth", old_auth) + + def test_success(self): + token = "Basic " + b64encode(b"test:test").decode() + clear_db_pools() + + with self.app.test_client() as test_client: + response = test_client.get("/api/v1/pools", headers={"Authorization": token}) + assert current_user.email == "test@fab.org" + + assert response.status_code == 200 + assert response.json == { + "pools": [ + { + "name": "default_pool", + "slots": 128, + "occupied_slots": 0, + "running_slots": 0, + "queued_slots": 0, + "scheduled_slots": 0, + "deferred_slots": 0, + "open_slots": 128, + "description": "Default pool", + "include_deferred": False, + }, + ], + "total_entries": 1, + } + + @pytest.mark.parametrize( + "token", + [ + "basic", + "basic ", + "bearer", + "test:test", + b64encode(b"test:test").decode(), + "bearer ", + "basic: ", + "basic 123", + ], + ) + def test_malformed_headers(self, token): + with self.app.test_client() as test_client: + response = test_client.get("/api/v1/pools", headers={"Authorization": token}) + assert response.status_code == 401 + assert response.headers["Content-Type"] == "application/problem+json" + assert response.headers["WWW-Authenticate"] == "Basic" + assert_401(response) + + @pytest.mark.parametrize( + "token", + [ + "basic " + b64encode(b"test").decode(), + "basic " + b64encode(b"test:").decode(), + "basic " + b64encode(b"test:123").decode(), + "basic " + b64encode(b"test test").decode(), + ], + ) + def test_invalid_auth_header(self, token): + with self.app.test_client() as test_client: + response = test_client.get("/api/v1/pools", headers={"Authorization": token}) + assert response.status_code == 401 + assert response.headers["Content-Type"] == "application/problem+json" + assert response.headers["WWW-Authenticate"] == "Basic" + assert_401(response) + + +class TestSessionWithBasicAuthFallback(BaseTestAuth): + @pytest.fixture(autouse=True, scope="class") + def with_basic_auth_backend(self, minimal_app_for_auth_api): + from airflow.www.extensions.init_security import init_api_auth + + old_auth = getattr(minimal_app_for_auth_api, 
"api_auth") + + try: + with conf_vars( + { + ( + "api", + "auth_backends", + ): "airflow.api.auth.backend.session,airflow.providers.fab.auth_manager.api.auth.backend.basic_auth" + } + ): + init_api_auth(minimal_app_for_auth_api) + yield + finally: + setattr(minimal_app_for_auth_api, "api_auth", old_auth) + + def test_basic_auth_fallback(self): + token = "Basic " + b64encode(b"test:test").decode() + clear_db_pools() + + # request uses session + admin_user = client_with_login(self.app, username="test", password="test") + response = admin_user.get("/api/v1/pools") + assert response.status_code == 200 + + # request uses basic auth + with self.app.test_client() as test_client: + response = test_client.get("/api/v1/pools", headers={"Authorization": token}) + assert response.status_code == 200 + + # request without session or basic auth header + with self.app.test_client() as test_client: + response = test_client.get("/api/v1/pools") + assert response.status_code == 401 diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_backfill_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_backfill_endpoint.py new file mode 100644 index 000000000000..56f135d457e9 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_backfill_endpoint.py @@ -0,0 +1,264 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import os +from datetime import datetime +from unittest import mock +from urllib.parse import urlencode + +import pendulum +import pytest + +from airflow.models import DagBag, DagModel +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS + +try: + from airflow.models.backfill import Backfill +except ImportError: + if AIRFLOW_V_3_0_PLUS: + raise + else: + pass +from airflow.models.dag import DAG +from airflow.models.serialized_dag import SerializedDagModel +from airflow.operators.empty import EmptyOperator +from airflow.security import permissions +from airflow.utils import timezone +from airflow.utils.session import provide_session +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.db import clear_db_backfills, clear_db_dags, clear_db_runs, clear_db_serialized_dags + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +DAG_ID = "test_dag" +TASK_ID = "op1" +DAG2_ID = "test_dag2" +DAG3_ID = "test_dag3" +UTC_JSON_REPR = "UTC" if pendulum.__version__.startswith("3") else "Timezone('UTC')" + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + + create_user(app, username="test_granular_permissions", role_name="TestGranularDag") + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_1", + access_control={ + "TestGranularDag": { + permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} + }, + }, + ) + + with DAG( + DAG_ID, + schedule=None, + start_date=datetime(2020, 6, 15), + doc_md="details", + params={"foo": 1}, + tags=["example"], + ) as dag: + EmptyOperator(task_id=TASK_ID) + + with DAG(DAG2_ID, schedule=None, start_date=datetime(2020, 6, 15)) as dag2: # no doc_md + EmptyOperator(task_id=TASK_ID) + + with DAG(DAG3_ID, schedule=None) as dag3: # DAG start_date set to None + EmptyOperator(task_id=TASK_ID, start_date=datetime(2019, 6, 12)) + + dag_bag = DagBag(os.devnull, include_examples=False) + dag_bag.dags = {dag.dag_id: dag, dag2.dag_id: dag2, dag3.dag_id: dag3} + + app.dag_bag = dag_bag + + yield app + + delete_user(app, username="test_granular_permissions") + + +class TestBackfillEndpoint: + @staticmethod + def clean_db(): + clear_db_backfills() + clear_db_runs() + clear_db_dags() + clear_db_serialized_dags() + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.clean_db() + self.app = configured_app + self.client = self.app.test_client() # type:ignore + self.dag_id = DAG_ID + self.dag2_id = DAG2_ID + self.dag3_id = DAG3_ID + + def teardown_method(self) -> None: + self.clean_db() + + @provide_session + def _create_dag_models(self, *, count=1, dag_id_prefix="TEST_DAG", is_paused=False, session=None): + dags = [] + for num in range(1, count + 1): + dag_model = DagModel( + dag_id=f"{dag_id_prefix}_{num}", + fileloc=f"/tmp/dag_{num}.py", + is_active=True, + timetable_summary="0 0 * * *", + is_paused=is_paused, + ) + session.add(dag_model) + dags.append(dag_model) + return dags + + @provide_session + def _create_deactivated_dag(self, session=None): + dag_model = DagModel( + dag_id="TEST_DAG_DELETED_1", + fileloc="/tmp/dag_del_1.py", + schedule_interval="2 2 * * *", + is_active=False, + ) + session.add(dag_model) + + +class TestListBackfills(TestBackfillEndpoint): + def test_should_respond_200_with_granular_dag_access(self, session): + (dag,) = self._create_dag_models() + 
from_date = timezone.utcnow() + to_date = timezone.utcnow() + b = Backfill( + dag_id=dag.dag_id, + from_date=from_date, + to_date=to_date, + ) + + session.add(b) + session.commit() + kwargs = {} + kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"}) + response = self.client.get("/api/v1/backfills?dag_id=TEST_DAG_1", **kwargs) + assert response.status_code == 200 + + +class TestGetBackfill(TestBackfillEndpoint): + def test_should_respond_200_with_granular_dag_access(self, session): + (dag,) = self._create_dag_models() + from_date = timezone.utcnow() + to_date = timezone.utcnow() + backfill = Backfill( + dag_id=dag.dag_id, + from_date=from_date, + to_date=to_date, + ) + session.add(backfill) + session.commit() + kwargs = {} + kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"}) + response = self.client.get(f"/api/v1/backfills/{backfill.id}", **kwargs) + assert response.status_code == 200 + + +class TestCreateBackfill(TestBackfillEndpoint): + def test_create_backfill(self, session, dag_maker): + with dag_maker(session=session, dag_id="TEST_DAG_1", schedule="0 * * * *") as dag: + EmptyOperator(task_id="mytask") + session.add(SerializedDagModel(dag)) + session.commit() + session.query(DagModel).all() + from_date = pendulum.parse("2024-01-01") + from_date_iso = from_date.isoformat() + to_date = pendulum.parse("2024-02-01") + to_date_iso = to_date.isoformat() + max_active_runs = 5 + query = urlencode( + query={ + "dag_id": dag.dag_id, + "from_date": f"{from_date_iso}", + "to_date": f"{to_date_iso}", + "max_active_runs": max_active_runs, + "reverse": False, + } + ) + kwargs = {} + kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"}) + + response = self.client.post( + f"/api/v1/backfills?{query}", + **kwargs, + ) + assert response.status_code == 200 + assert response.json == { + "completed_at": mock.ANY, + "created_at": mock.ANY, + "dag_id": "TEST_DAG_1", + "dag_run_conf": None, + "from_date": from_date_iso, + "id": mock.ANY, + "is_paused": False, + "max_active_runs": 5, + "to_date": to_date_iso, + "updated_at": mock.ANY, + } + + +class TestPauseBackfill(TestBackfillEndpoint): + def test_should_respond_200_with_granular_dag_access(self, session): + (dag,) = self._create_dag_models() + from_date = timezone.utcnow() + to_date = timezone.utcnow() + backfill = Backfill( + dag_id=dag.dag_id, + from_date=from_date, + to_date=to_date, + ) + session.add(backfill) + session.commit() + kwargs = {} + kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"}) + response = self.client.post(f"/api/v1/backfills/{backfill.id}/pause", **kwargs) + assert response.status_code == 200 + + +class TestCancelBackfill(TestBackfillEndpoint): + def test_should_respond_200_with_granular_dag_access(self, session): + (dag,) = self._create_dag_models() + from_date = timezone.utcnow() + to_date = timezone.utcnow() + backfill = Backfill( + dag_id=dag.dag_id, + from_date=from_date, + to_date=to_date, + ) + session.add(backfill) + session.commit() + kwargs = {} + kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"}) + response = self.client.post(f"/api/v1/backfills/{backfill.id}/cancel", **kwargs) + assert response.status_code == 200 + # now it is marked as completed + assert pendulum.parse(response.json["completed_at"]) + + # get conflict when canceling already-canceled backfill + response = self.client.post(f"/api/v1/backfills/{backfill.id}/cancel", **kwargs) + assert response.status_code == 409 diff --git 
a/tests/api_connexion/test_cors.py b/tests/providers/fab/auth_manager/api_endpoints/test_cors.py similarity index 81% rename from tests/api_connexion/test_cors.py rename to tests/providers/fab/auth_manager/api_endpoints/test_cors.py index a2b7f0ebca74..b44eab8820ec 100644 --- a/tests/api_connexion/test_cors.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_cors.py @@ -20,16 +20,21 @@ import pytest +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_pools -pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] class BaseTestAuth: @pytest.fixture(autouse=True) - def set_attrs(self, minimal_app_for_api): - self.app = minimal_app_for_api + def set_attrs(self, minimal_app_for_auth_api): + self.app = minimal_app_for_auth_api sm = self.app.appbuilder.sm tester = sm.find_user(username="test") @@ -47,19 +52,19 @@ def set_attrs(self, minimal_app_for_api): class TestEmptyCors(BaseTestAuth): @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_api): + def with_basic_auth_backend(self, minimal_app_for_auth_api): from airflow.www.extensions.init_security import init_api_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + old_auth = getattr(minimal_app_for_auth_api, "api_auth") try: with conf_vars( {("api", "auth_backends"): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"} ): - init_api_auth(minimal_app_for_api) + init_api_auth(minimal_app_for_auth_api) yield finally: - setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(minimal_app_for_auth_api, "api_auth", old_auth) def test_empty_cors_headers(self): token = "Basic " + b64encode(b"test:test").decode() @@ -75,10 +80,10 @@ def test_empty_cors_headers(self): class TestCorsOrigin(BaseTestAuth): @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_api): + def with_basic_auth_backend(self, minimal_app_for_auth_api): from airflow.www.extensions.init_security import init_api_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + old_auth = getattr(minimal_app_for_auth_api, "api_auth") try: with conf_vars( @@ -90,10 +95,10 @@ def with_basic_auth_backend(self, minimal_app_for_api): ("api", "access_control_allow_origins"): "http://apache.org http://example.com", } ): - init_api_auth(minimal_app_for_api) + init_api_auth(minimal_app_for_auth_api) yield finally: - setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(minimal_app_for_auth_api, "api_auth", old_auth) def test_cors_origin_reflection(self): token = "Basic " + b64encode(b"test:test").decode() @@ -119,10 +124,10 @@ def test_cors_origin_reflection(self): class TestCorsWildcard(BaseTestAuth): @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_api): + def with_basic_auth_backend(self, minimal_app_for_auth_api): from airflow.www.extensions.init_security import init_api_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + old_auth = getattr(minimal_app_for_auth_api, "api_auth") try: with conf_vars( @@ -134,10 +139,10 @@ def with_basic_auth_backend(self, minimal_app_for_api): ("api", "access_control_allow_origins"): "*", } ): - init_api_auth(minimal_app_for_api) + init_api_auth(minimal_app_for_auth_api) yield finally: - 
setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(minimal_app_for_auth_api, "api_auth", old_auth) def test_cors_origin_reflection(self): token = "Basic " + b64encode(b"test:test").decode() diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dag_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dag_endpoint.py new file mode 100644 index 000000000000..b78ac58e442e --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_dag_endpoint.py @@ -0,0 +1,252 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +from datetime import datetime + +import pendulum +import pytest + +from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from airflow.models import DagBag, DagModel +from airflow.models.dag import DAG +from airflow.operators.empty import EmptyOperator +from airflow.security import permissions +from airflow.utils.session import provide_session +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags +from tests.test_utils.www import _check_last_log + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture +def current_file_token(url_safe_serializer) -> str: + return url_safe_serializer.dumps(__file__) + + +DAG_ID = "test_dag" +TASK_ID = "op1" +DAG2_ID = "test_dag2" +DAG3_ID = "test_dag3" +UTC_JSON_REPR = "UTC" if pendulum.__version__.startswith("3") else "Timezone('UTC')" + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + + create_user(app, username="test_granular_permissions", role_name="TestGranularDag") + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_1", + access_control={ + "TestGranularDag": { + permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} + }, + }, + ) + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_1", + access_control={ + "TestGranularDag": { + permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} + }, + }, + ) + + with DAG( + DAG_ID, + schedule=None, + start_date=datetime(2020, 6, 15), + doc_md="details", + params={"foo": 1}, + tags=["example"], + ) as dag: + EmptyOperator(task_id=TASK_ID) + + with DAG(DAG2_ID, schedule=None, start_date=datetime(2020, 6, 15)) as dag2: # no doc_md + EmptyOperator(task_id=TASK_ID) + + with DAG(DAG3_ID, schedule=None) as dag3: # DAG start_date set to None + EmptyOperator(task_id=TASK_ID, start_date=datetime(2019, 6, 12)) + + dag_bag = 
DagBag(os.devnull, include_examples=False) + dag_bag.dags = {dag.dag_id: dag, dag2.dag_id: dag2, dag3.dag_id: dag3} + + app.dag_bag = dag_bag + + yield app + + delete_user(app, username="test_granular_permissions") + + +class TestDagEndpoint: + @staticmethod + def clean_db(): + clear_db_runs() + clear_db_dags() + clear_db_serialized_dags() + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.clean_db() + self.app = configured_app + self.client = self.app.test_client() # type:ignore + self.dag_id = DAG_ID + self.dag2_id = DAG2_ID + self.dag3_id = DAG3_ID + + def teardown_method(self) -> None: + self.clean_db() + + @provide_session + def _create_dag_models(self, count, dag_id_prefix="TEST_DAG", is_paused=False, session=None): + for num in range(1, count + 1): + dag_model = DagModel( + dag_id=f"{dag_id_prefix}_{num}", + fileloc=f"/tmp/dag_{num}.py", + timetable_summary="2 2 * * *", + is_active=True, + is_paused=is_paused, + ) + session.add(dag_model) + + @provide_session + def _create_dag_model_for_details_endpoint(self, dag_id, session=None): + dag_model = DagModel( + dag_id=dag_id, + fileloc="/tmp/dag.py", + timetable_summary="2 2 * * *", + is_active=True, + is_paused=False, + ) + session.add(dag_model) + + @provide_session + def _create_dag_model_for_details_endpoint_with_dataset_expression(self, dag_id, session=None): + dag_model = DagModel( + dag_id=dag_id, + fileloc="/tmp/dag.py", + timetable_summary="2 2 * * *", + is_active=True, + is_paused=False, + dataset_expression={ + "any": [ + "s3://dag1/output_1.txt", + {"all": ["s3://dag2/output_1.txt", "s3://dag3/output_3.txt"]}, + ] + }, + ) + session.add(dag_model) + + @provide_session + def _create_deactivated_dag(self, session=None): + dag_model = DagModel( + dag_id="TEST_DAG_DELETED_1", + fileloc="/tmp/dag_del_1.py", + timetable_summary="2 2 * * *", + is_active=False, + ) + session.add(dag_model) + + +class TestGetDag(TestDagEndpoint): + def test_should_respond_200_with_granular_dag_access(self): + self._create_dag_models(1) + response = self.client.get( + "/api/v1/dags/TEST_DAG_1", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + ) + assert response.status_code == 200 + + def test_should_respond_403_with_granular_access_for_different_dag(self): + self._create_dag_models(3) + response = self.client.get( + "/api/v1/dags/TEST_DAG_2", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + ) + assert response.status_code == 403 + + +class TestGetDags(TestDagEndpoint): + def test_should_respond_200_with_granular_dag_access(self): + self._create_dag_models(3) + response = self.client.get( + "/api/v1/dags", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + ) + assert response.status_code == 200 + assert len(response.json["dags"]) == 1 + assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" + + +class TestPatchDag(TestDagEndpoint): + @provide_session + def _create_dag_model(self, session=None): + dag_model = DagModel( + dag_id="TEST_DAG_1", fileloc="/tmp/dag_1.py", timetable_summary="2 2 * * *", is_paused=True + ) + session.add(dag_model) + return dag_model + + def test_should_respond_200_on_patch_with_granular_dag_access(self, session): + self._create_dag_models(1) + response = self.client.patch( + "/api/v1/dags/TEST_DAG_1", + json={ + "is_paused": False, + }, + environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + ) + assert response.status_code == 200 + _check_last_log(session, dag_id="TEST_DAG_1", event="api.patch_dag", execution_date=None) + + def 
test_validation_error_raises_400(self): + patch_body = { + "ispaused": True, + } + dag_model = self._create_dag_model() + response = self.client.patch( + f"/api/v1/dags/{dag_model.dag_id}", + json=patch_body, + environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + ) + assert response.status_code == 400 + assert response.json == { + "detail": "{'ispaused': ['Unknown field.']}", + "status": 400, + "title": "Bad Request", + "type": EXCEPTIONS_LINK_MAP[400], + } + + +class TestPatchDags(TestDagEndpoint): + def test_should_respond_200_with_granular_dag_access(self): + self._create_dag_models(3) + response = self.client.patch( + "api/v1/dags?dag_id_pattern=~", + json={ + "is_paused": False, + }, + environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + ) + assert response.status_code == 200 + assert len(response.json["dags"]) == 1 + assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py new file mode 100644 index 000000000000..a58ea08ff31c --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py @@ -0,0 +1,273 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
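+# DAG run API endpoint tests covering users with restricted FAB permission sets,
+# including per-DAG ("granular") access on TEST_DAG_ID.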
+from __future__ import annotations + +from datetime import timedelta + +import pytest + +from airflow.models.dag import DAG, DagModel +from airflow.models.dagrun import DagRun +from airflow.models.param import Param +from airflow.security import permissions +from airflow.utils import timezone +from airflow.utils.session import create_session +from airflow.utils.state import DagRunState +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS + +try: + from airflow.utils.types import DagRunTriggeredByType, DagRunType +except ImportError: + if AIRFLOW_V_3_0_PLUS: + raise + else: + pass +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user, + delete_roles, + delete_user, +) +from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + + create_user( + app, + username="test_no_dag_run_create_permission", + role_name="TestNoDagRunCreatePermission", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), + ], + ) + create_user( + app, + username="test_dag_view_only", + role_name="TestViewDags", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), + ], + ) + create_user( + app, + username="test_view_dags", + role_name="TestViewDags", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), + ], + ) + create_user( + app, + username="test_granular_permissions", + role_name="TestGranularDag", + permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN)], + ) + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_ID", + access_control={ + "TestGranularDag": {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ}, + "TestNoDagRunCreatePermission": {permissions.RESOURCE_DAG_RUN: {permissions.ACTION_CAN_CREATE}}, + }, + ) + + yield app + + delete_user(app, username="test_dag_view_only") + delete_user(app, username="test_view_dags") + delete_user(app, username="test_granular_permissions") + delete_user(app, username="test_no_dag_run_create_permission") + delete_roles(app) + + +class TestDagRunEndpoint: + default_time = "2020-06-11T18:00:00+00:00" + default_time_2 = "2020-06-12T18:00:00+00:00" + default_time_3 = "2020-06-13T18:00:00+00:00" + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + clear_db_runs() + clear_db_serialized_dags() + clear_db_dags() + + def teardown_method(self) -> None: + clear_db_runs() + clear_db_dags() + clear_db_serialized_dags() + + def 
_create_dag(self, dag_id): + dag_instance = DagModel(dag_id=dag_id) + dag_instance.is_active = True + with create_session() as session: + session.add(dag_instance) + dag = DAG(dag_id=dag_id, schedule=None, params={"validated_number": Param(1, minimum=1, maximum=10)}) + self.app.dag_bag.bag_dag(dag) + return dag_instance + + def _create_test_dag_run(self, state=DagRunState.RUNNING, extra_dag=False, commit=True, idx_start=1): + dag_runs = [] + dags = [] + triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} + + for i in range(idx_start, idx_start + 2): + if i == 1: + dags.append(DagModel(dag_id="TEST_DAG_ID", is_active=True)) + dagrun_model = DagRun( + dag_id="TEST_DAG_ID", + run_id=f"TEST_DAG_RUN_ID_{i}", + run_type=DagRunType.MANUAL, + execution_date=timezone.parse(self.default_time) + timedelta(days=i - 1), + start_date=timezone.parse(self.default_time), + external_trigger=True, + state=state, + **triggered_by_kwargs, + ) + dagrun_model.updated_at = timezone.parse(self.default_time) + dag_runs.append(dagrun_model) + + if extra_dag: + for i in range(idx_start + 2, idx_start + 4): + dags.append(DagModel(dag_id=f"TEST_DAG_ID_{i}")) + dag_runs.append( + DagRun( + dag_id=f"TEST_DAG_ID_{i}", + run_id=f"TEST_DAG_RUN_ID_{i}", + run_type=DagRunType.MANUAL, + execution_date=timezone.parse(self.default_time_2), + start_date=timezone.parse(self.default_time), + external_trigger=True, + state=state, + ) + ) + if commit: + with create_session() as session: + session.add_all(dag_runs) + session.add_all(dags) + return dag_runs + + +class TestGetDagRuns(TestDagRunEndpoint): + def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): + self._create_test_dag_run(extra_dag=True) + expected_dag_run_ids = ["TEST_DAG_ID", "TEST_DAG_ID"] + response = self.client.get( + "api/v1/dags/~/dagRuns", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + ) + assert response.status_code == 200 + dag_run_ids = [dag_run["dag_id"] for dag_run in response.json["dag_runs"]] + assert dag_run_ids == expected_dag_run_ids + + +class TestGetDagRunBatch(TestDagRunEndpoint): + def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): + self._create_test_dag_run(extra_dag=True) + expected_response_json_1 = { + "dag_id": "TEST_DAG_ID", + "dag_run_id": "TEST_DAG_RUN_ID_1", + "end_date": None, + "state": "running", + "execution_date": self.default_time, + "logical_date": self.default_time, + "external_trigger": True, + "start_date": self.default_time, + "conf": {}, + "data_interval_end": None, + "data_interval_start": None, + "last_scheduling_decision": None, + "run_type": "manual", + "note": None, + } + expected_response_json_1.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {}) + expected_response_json_2 = { + "dag_id": "TEST_DAG_ID", + "dag_run_id": "TEST_DAG_RUN_ID_2", + "end_date": None, + "state": "running", + "execution_date": self.default_time_2, + "logical_date": self.default_time_2, + "external_trigger": True, + "start_date": self.default_time, + "conf": {}, + "data_interval_end": None, + "data_interval_start": None, + "last_scheduling_decision": None, + "run_type": "manual", + "note": None, + } + expected_response_json_2.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {}) + + response = self.client.post( + "api/v1/dags/~/dagRuns/list", + json={"dag_ids": []}, + environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + ) + assert response.status_code == 200 + assert response.json == { 
+ "dag_runs": [ + expected_response_json_1, + expected_response_json_2, + ], + "total_entries": 2, + } + + +class TestPostDagRun(TestDagRunEndpoint): + def test_dagrun_trigger_with_dag_level_permissions(self): + self._create_dag("TEST_DAG_ID") + response = self.client.post( + "api/v1/dags/TEST_DAG_ID/dagRuns", + json={"conf": {"validated_number": 1}}, + environ_overrides={"REMOTE_USER": "test_no_dag_run_create_permission"}, + ) + assert response.status_code == 200 + + @pytest.mark.parametrize( + "username", + ["test_dag_view_only", "test_view_dags", "test_granular_permissions"], + ) + def test_should_raises_403_unauthorized(self, username): + self._create_dag("TEST_DAG_ID") + response = self.client.post( + "api/v1/dags/TEST_DAG_ID/dagRuns", + json={ + "dag_run_id": "TEST_DAG_RUN_ID_1", + "execution_date": self.default_time, + }, + environ_overrides={"REMOTE_USER": username}, + ) + assert response.status_code == 403 diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py new file mode 100644 index 000000000000..f0d9b0da298c --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import ast +import os +from typing import TYPE_CHECKING + +import pytest + +from airflow.models import DagBag +from airflow.security import permissions +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_dag_code, clear_db_dags, clear_db_serialized_dags + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + +if TYPE_CHECKING: + from airflow.models.dag import DAG + +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) +EXAMPLE_DAG_FILE = os.path.join("airflow", "example_dags", "example_bash_operator.py") +EXAMPLE_DAG_ID = "example_bash_operator" +TEST_DAG_ID = "latest_only" +NOT_READABLE_DAG_ID = "latest_only_with_trigger" +TEST_MULTIPLE_DAGS_ID = "asset_produces_1" + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, + username="test", + role_name="Test", + permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)], + ) + app.appbuilder.sm.sync_perm_for_dag( + TEST_DAG_ID, + access_control={"Test": [permissions.ACTION_CAN_READ]}, + ) + app.appbuilder.sm.sync_perm_for_dag( + EXAMPLE_DAG_ID, + access_control={"Test": [permissions.ACTION_CAN_READ]}, + ) + app.appbuilder.sm.sync_perm_for_dag( + TEST_MULTIPLE_DAGS_ID, + access_control={"Test": [permissions.ACTION_CAN_READ]}, + ) + + yield app + + delete_user(app, username="test") + + +class TestGetSource: + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + self.clear_db() + + def teardown_method(self) -> None: + self.clear_db() + + @staticmethod + def clear_db(): + clear_db_dags() + clear_db_serialized_dags() + clear_db_dag_code() + + @staticmethod + def _get_dag_file_docstring(fileloc: str) -> str | None: + with open(fileloc) as f: + file_contents = f.read() + module = ast.parse(file_contents) + docstring = ast.get_docstring(module) + return docstring + + def test_should_respond_406(self, url_safe_serializer): + dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) + dagbag.sync_to_db() + test_dag: DAG = dagbag.dags[TEST_DAG_ID] + + url = f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}" + response = self.client.get( + url, headers={"Accept": "image/webp"}, environ_overrides={"REMOTE_USER": "test"} + ) + + assert 406 == response.status_code + + def test_should_respond_403_not_readable(self, url_safe_serializer): + dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) + dagbag.sync_to_db() + dag: DAG = dagbag.dags[NOT_READABLE_DAG_ID] + + response = self.client.get( + f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}", + headers={"Accept": "text/plain"}, + environ_overrides={"REMOTE_USER": "test"}, + ) + read_dag = self.client.get( + f"/api/v1/dags/{NOT_READABLE_DAG_ID}", + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 403 + assert read_dag.status_code == 403 + + def test_should_respond_403_some_dags_not_readable_in_the_file(self, url_safe_serializer): + dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) + dagbag.sync_to_db() + dag: DAG = dagbag.dags[TEST_MULTIPLE_DAGS_ID] + + response = self.client.get( + 
f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}", + headers={"Accept": "text/plain"}, + environ_overrides={"REMOTE_USER": "test"}, + ) + + read_dag = self.client.get( + f"/api/v1/dags/{TEST_MULTIPLE_DAGS_ID}", + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 403 + assert read_dag.status_code == 200 diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py new file mode 100644 index 000000000000..adfde1cc5b3e --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.models.dag import DagModel +from airflow.models.dagwarning import DagWarning +from airflow.security import permissions +from airflow.utils.session import create_session +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_dag_warnings, clear_db_dags + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, # type:ignore + username="test_with_dag2_read", + role_name="TestWithDag2Read", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), + (permissions.ACTION_CAN_READ, f"{permissions.RESOURCE_DAG_PREFIX}dag2"), + ], + ) + + yield app + + delete_user(app, username="test_with_dag2_read") + + +class TestBaseDagWarning: + timestamp = "2020-06-10T12:00" + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + + def teardown_method(self) -> None: + clear_db_dag_warnings() + clear_db_dags() + + +class TestGetDagWarningEndpoint(TestBaseDagWarning): + def setup_class(self): + clear_db_dag_warnings() + clear_db_dags() + + def setup_method(self): + with create_session() as session: + session.add(DagModel(dag_id="dag1")) + session.add(DagWarning("dag1", "non-existent pool", "test message")) + session.commit() + + def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self): + response = self.client.get( + "/api/v1/dagWarnings", + environ_overrides={"REMOTE_USER": "test_with_dag2_read"}, + query_string={"dag_id": "dag1"}, + ) + assert response.status_code == 403 diff --git 
a/tests/providers/fab/auth_manager/api_endpoints/test_dataset_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dataset_endpoint.py new file mode 100644 index 000000000000..4d302722223d --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_dataset_endpoint.py @@ -0,0 +1,327 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Generator + +import pytest +import time_machine + +from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS + +try: + from airflow.models.asset import AssetDagRunQueue, AssetModel +except ImportError: + if AIRFLOW_V_3_0_PLUS: + raise + else: + pass +from airflow.security import permissions +from airflow.utils import timezone +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.db import clear_db_assets, clear_db_runs +from tests.test_utils.www import _check_last_log + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, + username="test_queued_event", + role_name="TestQueuedEvent", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ASSET), + ], + ) + + yield app + + delete_user(app, username="test_queued_event") + + +class TestAssetEndpoint: + default_time = "2020-06-11T18:00:00+00:00" + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() + clear_db_assets() + clear_db_runs() + + def teardown_method(self) -> None: + clear_db_assets() + clear_db_runs() + + def _create_asset(self, session): + asset_model = AssetModel( + id=1, + uri="s3://bucket/key", + extra={"foo": "bar"}, + created_at=timezone.parse(self.default_time), + updated_at=timezone.parse(self.default_time), + ) + session.add(asset_model) + session.commit() + return asset_model + + +class TestQueuedEventEndpoint(TestAssetEndpoint): + @pytest.fixture + def time_freezer(self) -> Generator: + freezer = time_machine.travel(self.default_time, tick=False) + freezer.start() + + yield + + freezer.stop() + + def _create_asset_dag_run_queues(self, dag_id, dataset_id, session): + ddrq = AssetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset_id) + session.add(ddrq) + session.commit() + return ddrq + + +class TestGetDagDatasetQueuedEvent(TestQueuedEventEndpoint): + 
@pytest.mark.usefixtures("time_freezer") + def test_should_respond_200(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_id = self._create_asset(session).id + self._create_asset_dag_run_queues(dag_id, dataset_id, session) + dataset_uri = "s3://bucket/key" + + response = self.client.get( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 200 + assert response.json == { + "created_at": self.default_time, + "uri": "s3://bucket/key", + "dag_id": "dag", + } + + def test_should_respond_404(self): + dag_id = "not_exists" + dataset_uri = "not_exists" + + response = self.client.get( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestDeleteDagDatasetQueuedEvent(TestAssetEndpoint): + def test_delete_should_respond_204(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_uri = "s3://bucket/key" + dataset_id = self._create_asset(session).id + + ddrq = AssetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset_id) + session.add(ddrq) + session.commit() + conn = session.query(AssetDagRunQueue).all() + assert len(conn) == 1 + + response = self.client.delete( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 204 + conn = session.query(AssetDagRunQueue).all() + assert len(conn) == 0 + _check_last_log( + session, dag_id=dag_id, event="api.delete_dag_dataset_queued_event", execution_date=None + ) + + def test_should_respond_404(self): + dag_id = "not_exists" + dataset_uri = "not_exists" + + response = self.client.delete( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestGetDagDatasetQueuedEvents(TestQueuedEventEndpoint): + @pytest.mark.usefixtures("time_freezer") + def test_should_respond_200(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_id = self._create_asset(session).id + self._create_asset_dag_run_queues(dag_id, dataset_id, session) + + response = self.client.get( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 200 + assert response.json == { + "queued_events": [ + { + "created_at": self.default_time, + "uri": "s3://bucket/key", + "dag_id": "dag", + } + ], + "total_entries": 1, + } + + def test_should_respond_404(self): + dag_id = "not_exists" + + response = self.client.get( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with dag_id: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": 
EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestDeleteDagDatasetQueuedEvents(TestAssetEndpoint): + def test_should_respond_404(self): + dag_id = "not_exists" + + response = self.client.delete( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with dag_id: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestGetDatasetQueuedEvents(TestQueuedEventEndpoint): + @pytest.mark.usefixtures("time_freezer") + def test_should_respond_200(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_id = self._create_asset(session).id + self._create_asset_dag_run_queues(dag_id, dataset_id, session) + dataset_uri = "s3://bucket/key" + + response = self.client.get( + f"/api/v1/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 200 + assert response.json == { + "queued_events": [ + { + "created_at": self.default_time, + "uri": "s3://bucket/key", + "dag_id": "dag", + } + ], + "total_entries": 1, + } + + def test_should_respond_404(self): + dataset_uri = "not_exists" + + response = self.client.get( + f"/api/v1/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with asset uri: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestDeleteDatasetQueuedEvents(TestQueuedEventEndpoint): + def test_delete_should_respond_204(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_id = self._create_asset(session).id + self._create_asset_dag_run_queues(dag_id, dataset_id, session) + dataset_uri = "s3://bucket/key" + + response = self.client.delete( + f"/api/v1/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 204 + conn = session.query(AssetDagRunQueue).all() + assert len(conn) == 0 + _check_last_log(session, dag_id=None, event="api.delete_dataset_queued_events", execution_date=None) + + def test_should_respond_404(self): + dataset_uri = "not_exists" + + response = self.client.delete( + f"/api/v1/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with asset uri: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_event_log_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_event_log_endpoint.py new file mode 100644 index 000000000000..acf3ca62684a --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_event_log_endpoint.py @@ -0,0 +1,151 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.models import Log +from airflow.security import permissions +from airflow.utils import timezone +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_logs + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, + username="test_granular", + role_name="TestGranular", + permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], + ) + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_ID_1", + access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, + ) + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_ID_2", + access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, + ) + + yield app + + delete_user(app, username="test_granular") + + +@pytest.fixture +def task_instance(session, create_task_instance, request): + return create_task_instance( + session=session, + dag_id="TEST_DAG_ID", + task_id="TEST_TASK_ID", + run_id="TEST_RUN_ID", + execution_date=request.instance.default_time, + ) + + +@pytest.fixture +def create_log_model(create_task_instance, task_instance, session, request): + def maker(event, when, **kwargs): + log_model = Log( + event=event, + task_instance=task_instance, + **kwargs, + ) + log_model.dttm = when + + session.add(log_model) + session.flush() + return log_model + + return maker + + +class TestEventLogEndpoint: + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + clear_db_logs() + self.default_time = timezone.parse("2020-06-10T20:00:00+00:00") + self.default_time_2 = timezone.parse("2020-06-11T07:00:00+00:00") + + def teardown_method(self) -> None: + clear_db_logs() + + +class TestGetEventLogs(TestEventLogEndpoint): + def test_should_filter_eventlogs_by_allowed_attributes(self, create_log_model, session): + eventlog1 = create_log_model( + event="TEST_EVENT_1", + dag_id="TEST_DAG_ID_1", + task_id="TEST_TASK_ID_1", + owner="TEST_OWNER_1", + when=self.default_time, + ) + eventlog2 = create_log_model( + event="TEST_EVENT_2", + dag_id="TEST_DAG_ID_2", + task_id="TEST_TASK_ID_2", + owner="TEST_OWNER_2", + when=self.default_time_2, + ) + session.add_all([eventlog1, eventlog2]) + session.commit() + for attr in ["dag_id", "task_id", "owner", "event"]: + attr_value = f"TEST_{attr}_1".upper() + response = self.client.get( + f"/api/v1/eventLogs?{attr}={attr_value}", environ_overrides={"REMOTE_USER": "test_granular"} + ) + assert response.status_code == 200 + assert response.json["total_entries"] == 1 + assert len(response.json["event_logs"]) == 1 + assert response.json["event_logs"][0][attr] == attr_value + + def test_should_filter_eventlogs_by_included_events(self, 
create_log_model): + for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]: + create_log_model(event=event, when=self.default_time) + response = self.client.get( + "/api/v1/eventLogs?included_events=TEST_EVENT_1,TEST_EVENT_2", + environ_overrides={"REMOTE_USER": "test_granular"}, + ) + assert response.status_code == 200 + response_data = response.json + assert len(response_data["event_logs"]) == 2 + assert response_data["total_entries"] == 2 + assert {"TEST_EVENT_1", "TEST_EVENT_2"} == {x["event"] for x in response_data["event_logs"]} + + def test_should_filter_eventlogs_by_excluded_events(self, create_log_model): + for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]: + create_log_model(event=event, when=self.default_time) + response = self.client.get( + "/api/v1/eventLogs?excluded_events=TEST_EVENT_1,TEST_EVENT_2", + environ_overrides={"REMOTE_USER": "test_granular"}, + ) + assert response.status_code == 200 + response_data = response.json + assert len(response_data["event_logs"]) == 1 + assert response_data["total_entries"] == 1 + assert {"cli_scheduler"} == {x["event"] for x in response_data["event_logs"]} diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_import_error_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_import_error_endpoint.py new file mode 100644 index 000000000000..a2fa1d028a3f --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_import_error_endpoint.py @@ -0,0 +1,221 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
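+
+# Module overview: these tests exercise the importErrors endpoints under the FAB
+# auth manager, verifying that import errors are hidden (403) or their stack
+# traces redacted unless the user can read every DAG defined in the failing file.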
+from __future__ import annotations + +import pytest + +from airflow.models.dag import DagModel +from airflow.security import permissions +from airflow.utils import timezone +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS, ParseImportError +from tests.test_utils.db import clear_db_dags, clear_db_import_errors +from tests.test_utils.permissions import _resource_name + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + +TEST_DAG_IDS = ["test_dag", "test_dag2"] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, + username="test_single_dag", + role_name="TestSingleDAG", + permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)], + ) + # For some reason, DAG level permissions are not synced when in the above list of perms, + # so do it manually here: + app.appbuilder.sm.bulk_sync_roles( + [ + { + "role": "TestSingleDAG", + "perms": [ + ( + permissions.ACTION_CAN_READ, + _resource_name(TEST_DAG_IDS[0], permissions.RESOURCE_DAG), + ) + ], + } + ] + ) + + yield app + + delete_user(app, username="test_single_dag") + + +class TestBaseImportError: + timestamp = "2020-06-10T12:00" + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + + clear_db_import_errors() + clear_db_dags() + + def teardown_method(self) -> None: + clear_db_import_errors() + clear_db_dags() + + @staticmethod + def _normalize_import_errors(import_errors): + for i, import_error in enumerate(import_errors, 1): + import_error["import_error_id"] = i + + +class TestGetImportErrorEndpoint(TestBaseImportError): + def test_should_raise_403_forbidden_without_dag_read(self, session): + import_error = ParseImportError( + filename="Lorem_ipsum.py", + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + session.add(import_error) + session.commit() + + response = self.client.get( + f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 403 + + def test_should_return_200_with_single_dag_read(self, session): + dag_model = DagModel(dag_id=TEST_DAG_IDS[0], fileloc="Lorem_ipsum.py") + session.add(dag_model) + import_error = ParseImportError( + filename="Lorem_ipsum.py", + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + session.add(import_error) + session.commit() + + response = self.client.get( + f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 200 + response_data = response.json + response_data["import_error_id"] = 1 + assert { + "filename": "Lorem_ipsum.py", + "import_error_id": 1, + "stack_trace": "Lorem ipsum", + "timestamp": "2020-06-10T12:00:00+00:00", + } == response_data + + def test_should_return_200_redacted_with_single_dag_read_in_dagfile(self, session): + for dag_id in TEST_DAG_IDS: + dag_model = DagModel(dag_id=dag_id, fileloc="Lorem_ipsum.py") + session.add(dag_model) + import_error = ParseImportError( + filename="Lorem_ipsum.py", + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + 
session.add(import_error) + session.commit() + + response = self.client.get( + f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 200 + response_data = response.json + response_data["import_error_id"] = 1 + assert { + "filename": "Lorem_ipsum.py", + "import_error_id": 1, + "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file", + "timestamp": "2020-06-10T12:00:00+00:00", + } == response_data + + +class TestGetImportErrorsEndpoint(TestBaseImportError): + def test_get_import_errors_single_dag(self, session): + for dag_id in TEST_DAG_IDS: + fake_filename = f"/tmp/{dag_id}.py" + dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename) + session.add(dag_model) + importerror = ParseImportError( + filename=fake_filename, + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + session.add(importerror) + session.commit() + + response = self.client.get( + "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 200 + response_data = response.json + self._normalize_import_errors(response_data["import_errors"]) + assert { + "import_errors": [ + { + "filename": "/tmp/test_dag.py", + "import_error_id": 1, + "stack_trace": "Lorem ipsum", + "timestamp": "2020-06-10T12:00:00+00:00", + }, + ], + "total_entries": 1, + } == response_data + + def test_get_import_errors_single_dag_in_dagfile(self, session): + for dag_id in TEST_DAG_IDS: + fake_filename = "/tmp/all_in_one.py" + dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename) + session.add(dag_model) + + importerror = ParseImportError( + filename="/tmp/all_in_one.py", + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + session.add(importerror) + session.commit() + + response = self.client.get( + "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 200 + response_data = response.json + self._normalize_import_errors(response_data["import_errors"]) + assert { + "import_errors": [ + { + "filename": "/tmp/all_in_one.py", + "import_error_id": 1, + "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file", + "timestamp": "2020-06-10T12:00:00+00:00", + }, + ], + "total_entries": 1, + } == response_data diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py index 30cfaeb22790..413a49a9d86a 100644 --- a/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py @@ -19,6 +19,13 @@ import pytest from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_role, + create_user, + delete_role, + delete_user, +) +from tests.test_utils.api_connexion_utils import assert_401 from tests.test_utils.compat import ignore_provider_compatibility_error with ignore_provider_compatibility_error("2.9.0+", __file__): @@ -27,13 +34,6 @@ from airflow.security import permissions -from tests.test_utils.api_connexion_utils import ( - assert_401, - create_role, - create_user, - delete_role, - delete_user, -) pytestmark = pytest.mark.db_test @@ -42,7 +42,7 @@ def configured_app(minimal_app_for_auth_api): app = 
minimal_app_for_auth_api create_user( - app, # type: ignore + app, username="test", role_name="Test", permissions=[ @@ -53,11 +53,11 @@ def configured_app(minimal_app_for_auth_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_ACTION), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name="TestNoPermissions") yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestRoleEndpoint: diff --git a/tests/api_connexion/schemas/test_role_and_permission_schema.py b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_schema.py similarity index 85% rename from tests/api_connexion/schemas/test_role_and_permission_schema.py rename to tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_schema.py index f2967d519794..4a2f0068e5e4 100644 --- a/tests/api_connexion/schemas/test_role_and_permission_schema.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_schema.py @@ -31,19 +31,19 @@ class TestRoleCollectionItemSchema: @pytest.fixture(scope="class") - def role(self, minimal_app_for_api): + def role(self, minimal_app_for_auth_api): yield create_role( - minimal_app_for_api, # type: ignore + minimal_app_for_auth_api, # type: ignore name="Test", permissions=[ (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION), ], ) - delete_role(minimal_app_for_api, "Test") + delete_role(minimal_app_for_auth_api, "Test") @pytest.fixture(autouse=True) - def _set_attrs(self, minimal_app_for_api, role): - self.app = minimal_app_for_api + def _set_attrs(self, minimal_app_for_auth_api, role): + self.app = minimal_app_for_auth_api self.role = role def test_serialize(self): @@ -67,26 +67,26 @@ def test_deserialize(self): class TestRoleCollectionSchema: @pytest.fixture(scope="class") - def role1(self, minimal_app_for_api): + def role1(self, minimal_app_for_auth_api): yield create_role( - minimal_app_for_api, # type: ignore + minimal_app_for_auth_api, # type: ignore name="Test1", permissions=[ (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION), ], ) - delete_role(minimal_app_for_api, "Test1") + delete_role(minimal_app_for_auth_api, "Test1") @pytest.fixture(scope="class") - def role2(self, minimal_app_for_api): + def role2(self, minimal_app_for_auth_api): yield create_role( - minimal_app_for_api, # type: ignore + minimal_app_for_auth_api, # type: ignore name="Test2", permissions=[ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), ], ) - delete_role(minimal_app_for_api, "Test2") + delete_role(minimal_app_for_auth_api, "Test2") def test_serialize(self, role1, role2): instance = RoleCollection([role1, role2], total_entries=2) diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py new file mode 100644 index 000000000000..69b3c221eae9 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py @@ -0,0 +1,427 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime as dt +import urllib + +import pytest + +from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from airflow.models import DagRun, TaskInstance +from airflow.security import permissions +from airflow.utils.session import provide_session +from airflow.utils.state import State +from airflow.utils.timezone import datetime +from airflow.utils.types import DagRunType +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user, + delete_roles, + delete_user, +) +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_runs, clear_db_sla_miss, clear_rendered_ti_fields + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + +DEFAULT_DATETIME_1 = datetime(2020, 1, 1) +DEFAULT_DATETIME_STR_1 = "2020-01-01T00:00:00+00:00" +DEFAULT_DATETIME_STR_2 = "2020-01-02T00:00:00+00:00" + +QUOTED_DEFAULT_DATETIME_STR_1 = urllib.parse.quote(DEFAULT_DATETIME_STR_1) +QUOTED_DEFAULT_DATETIME_STR_2 = urllib.parse.quote(DEFAULT_DATETIME_STR_2) + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, + username="test_dag_read_only", + role_name="TestDagReadOnly", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), + ], + ) + create_user( + app, + username="test_task_read_only", + role_name="TestTaskReadOnly", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + ], + ) + create_user( + app, + username="test_read_only_one_dag", + role_name="TestReadOnlyOneDag", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + ], + ) + # For some reason, "DAG:example_python_operator" is not synced when in the above list of perms, + # so do it manually here: + app.appbuilder.sm.bulk_sync_roles( + [ + { + "role": "TestReadOnlyOneDag", + "perms": [(permissions.ACTION_CAN_READ, "DAG:example_python_operator")], + } + ] + ) + + yield app + + delete_user(app, username="test_dag_read_only") + delete_user(app, username="test_task_read_only") + delete_user(app, username="test_read_only_one_dag") + delete_roles(app) + + +class TestTaskInstanceEndpoint: + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app, dagbag) -> None: + self.default_time = DEFAULT_DATETIME_1 + self.ti_init = { + 
"execution_date": self.default_time, + "state": State.RUNNING, + } + self.ti_extras = { + "start_date": self.default_time + dt.timedelta(days=1), + "end_date": self.default_time + dt.timedelta(days=2), + "pid": 100, + "duration": 10000, + "pool": "default_pool", + "queue": "default_queue", + "job_id": 0, + } + self.app = configured_app + self.client = self.app.test_client() # type:ignore + clear_db_runs() + clear_db_sla_miss() + clear_rendered_ti_fields() + self.dagbag = dagbag + + def create_task_instances( + self, + session, + dag_id: str = "example_python_operator", + update_extras: bool = True, + task_instances=None, + dag_run_state=State.RUNNING, + with_ti_history=False, + ): + """Method to create task instances using kwargs and default arguments""" + + dag = self.dagbag.get_dag(dag_id) + tasks = dag.tasks + counter = len(tasks) + if task_instances is not None: + counter = min(len(task_instances), counter) + + run_id = "TEST_DAG_RUN_ID" + execution_date = self.ti_init.pop("execution_date", self.default_time) + dr = None + + tis = [] + for i in range(counter): + if task_instances is None: + pass + elif update_extras: + self.ti_extras.update(task_instances[i]) + else: + self.ti_init.update(task_instances[i]) + + if "execution_date" in self.ti_init: + run_id = f"TEST_DAG_RUN_ID_{i}" + execution_date = self.ti_init.pop("execution_date") + dr = None + + if not dr: + dr = DagRun( + run_id=run_id, + dag_id=dag_id, + execution_date=execution_date, + run_type=DagRunType.MANUAL, + state=dag_run_state, + ) + session.add(dr) + ti = TaskInstance(task=tasks[i], **self.ti_init) + session.add(ti) + ti.dag_run = dr + ti.note = "placeholder-note" + + for key, value in self.ti_extras.items(): + setattr(ti, key, value) + tis.append(ti) + + session.commit() + if with_ti_history: + for ti in tis: + ti.try_number = 1 + session.merge(ti) + session.commit() + dag.clear() + for ti in tis: + ti.try_number = 2 + ti.queue = "default_queue" + session.merge(ti) + session.commit() + return tis + + +class TestGetTaskInstance(TestTaskInstanceEndpoint): + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + + @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) + @provide_session + def test_should_respond_200(self, username, session): + self.create_task_instances(session) + # Update ti and set operator to None to + # test that operator field is nullable. 
+ # This prevents issue when users upgrade to 2.0+ + # from 1.10.x + # https://github.com/apache/airflow/issues/14421 + session.query(TaskInstance).update({TaskInstance.operator: None}, synchronize_session="fetch") + session.commit() + response = self.client.get( + "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", + environ_overrides={"REMOTE_USER": username}, + ) + assert response.status_code == 200 + + +class TestGetTaskInstances(TestTaskInstanceEndpoint): + @pytest.mark.parametrize( + "task_instances, user, expected_ti", + [ + pytest.param( + { + "example_python_operator": 2, + "example_skip_dag": 1, + }, + "test_read_only_one_dag", + 2, + ), + pytest.param( + { + "example_python_operator": 1, + "example_skip_dag": 2, + }, + "test_read_only_one_dag", + 1, + ), + ], + ) + def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ti, session): + for dag_id in task_instances: + self.create_task_instances( + session, + task_instances=[ + {"execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=i)} + for i in range(task_instances[dag_id]) + ], + dag_id=dag_id, + ) + response = self.client.get( + "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": user} + ) + assert response.status_code == 200 + assert response.json["total_entries"] == expected_ti + assert len(response.json["task_instances"]) == expected_ti + + +class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): + @pytest.mark.parametrize( + "task_instances, update_extras, payload, expected_ti_count, username", + [ + pytest.param( + [ + {"pool": "test_pool_1"}, + {"pool": "test_pool_2"}, + {"pool": "test_pool_3"}, + ], + True, + {"pool": ["test_pool_1", "test_pool_2"]}, + 2, + "test_dag_read_only", + id="test pool filter", + ), + pytest.param( + [ + {"state": State.RUNNING}, + {"state": State.QUEUED}, + {"state": State.SUCCESS}, + {"state": State.NONE}, + ], + False, + {"state": ["running", "queued", "none"]}, + 3, + "test_task_read_only", + id="test state filter", + ), + pytest.param( + [ + {"state": State.NONE}, + {"state": State.NONE}, + {"state": State.NONE}, + {"state": State.NONE}, + ], + False, + {}, + 4, + "test_task_read_only", + id="test dag with null states", + ), + pytest.param( + [ + {"end_date": DEFAULT_DATETIME_1}, + {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, + {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, + ], + True, + { + "end_date_gte": DEFAULT_DATETIME_STR_1, + "end_date_lte": DEFAULT_DATETIME_STR_2, + }, + 2, + "test_task_read_only", + id="test end date filter", + ), + pytest.param( + [ + {"start_date": DEFAULT_DATETIME_1}, + {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, + {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, + ], + True, + { + "start_date_gte": DEFAULT_DATETIME_STR_1, + "start_date_lte": DEFAULT_DATETIME_STR_2, + }, + 2, + "test_dag_read_only", + id="test start date filter", + ), + ], + ) + def test_should_respond_200( + self, task_instances, update_extras, payload, expected_ti_count, username, session + ): + self.create_task_instances( + session, + update_extras=update_extras, + task_instances=task_instances, + ) + response = self.client.post( + "/api/v1/dags/~/dagRuns/~/taskInstances/list", + environ_overrides={"REMOTE_USER": username}, + json=payload, + ) + assert response.status_code == 200, response.json + assert expected_ti_count == response.json["total_entries"] + assert expected_ti_count == len(response.json["task_instances"]) + + def 
test_returns_403_forbidden_when_user_has_access_to_only_some_dags(self, session): + self.create_task_instances(session=session) + self.create_task_instances(session=session, dag_id="example_skip_dag") + payload = {"dag_ids": ["example_python_operator", "example_skip_dag"]} + + response = self.client.post( + "/api/v1/dags/~/dagRuns/~/taskInstances/list", + environ_overrides={"REMOTE_USER": "test_read_only_one_dag"}, + json=payload, + ) + assert response.status_code == 403 + assert response.json == { + "detail": "User not allowed to access some of these DAGs: ['example_python_operator', 'example_skip_dag']", + "status": 403, + "title": "Forbidden", + "type": EXCEPTIONS_LINK_MAP[403], + } + + +class TestPostSetTaskInstanceState(TestTaskInstanceEndpoint): + @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) + def test_should_raise_403_forbidden(self, username): + response = self.client.post( + "/api/v1/dags/example_python_operator/updateTaskInstancesState", + environ_overrides={"REMOTE_USER": username}, + json={ + "dry_run": True, + "task_id": "print_the_context", + "execution_date": DEFAULT_DATETIME_1.isoformat(), + "include_upstream": True, + "include_downstream": True, + "include_future": True, + "include_past": True, + "new_state": "failed", + }, + ) + assert response.status_code == 403 + + +class TestPatchTaskInstance(TestTaskInstanceEndpoint): + ENDPOINT_URL = ( + "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context" + ) + + @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) + def test_should_raise_403_forbidden(self, username): + response = self.client.patch( + self.ENDPOINT_URL, + environ_overrides={"REMOTE_USER": username}, + json={ + "dry_run": True, + "new_state": "failed", + }, + ) + assert response.status_code == 403 + + +class TestGetTaskInstanceTry(TestTaskInstanceEndpoint): + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + + @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) + @provide_session + def test_should_respond_200(self, username, session): + self.create_task_instances(session, task_instances=[{"state": State.SUCCESS}], with_ti_history=True) + + response = self.client.get( + "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries/1", + environ_overrides={"REMOTE_USER": username}, + ) + assert response.status_code == 200 diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py index bc400c8a43fa..7f2c885bab52 100644 --- a/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py @@ -30,7 +30,12 @@ with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.models import User -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_role, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user, + delete_role, + delete_user, +) +from tests.test_utils.api_connexion_utils import assert_401 from tests.test_utils.config import conf_vars pytestmark = pytest.mark.db_test @@ -43,7 +48,7 @@ def configured_app(minimal_app_for_auth_api): app = minimal_app_for_auth_api create_user( - app, # type: ignore + app, username="test", role_name="Test", permissions=[ @@ 
-53,12 +58,12 @@ def configured_app(minimal_app_for_auth_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_USER), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name="TestNoPermissions") yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") delete_role(app, name="TestNoPermissions") diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py b/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py index 265407622e26..f3399de6a977 100644 --- a/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py @@ -18,6 +18,7 @@ import pytest +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_role, delete_role from tests.test_utils.compat import ignore_provider_compatibility_error with ignore_provider_compatibility_error("2.9.0+", __file__): @@ -30,8 +31,6 @@ DEFAULT_TIME = "2021-01-09T13:59:56.336000+00:00" -from tests.test_utils.api_connexion_utils import create_role, delete_role # noqa: E402 - pytestmark = pytest.mark.db_test diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_variable_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_variable_endpoint.py new file mode 100644 index 000000000000..a8e71e1a8246 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_variable_endpoint.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
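+
+# Module overview: these tests exercise the variables endpoint under the FAB
+# auth manager, checking that reads succeed only for users holding the variable
+# read permission.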
+from __future__ import annotations + +import pytest + +from airflow.models import Variable +from airflow.security import permissions +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_variables + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + + create_user( + app, + username="test_read_only", + role_name="TestReadOnly", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE), + ], + ) + create_user( + app, + username="test_delete_only", + role_name="TestDeleteOnly", + permissions=[ + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE), + ], + ) + + yield app + + delete_user(app, username="test_read_only") + delete_user(app, username="test_delete_only") + + +class TestVariableEndpoint: + @pytest.fixture(autouse=True) + def setup_method(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + clear_db_variables() + + def teardown_method(self) -> None: + clear_db_variables() + + +class TestGetVariable(TestVariableEndpoint): + @pytest.mark.parametrize( + "user, expected_status_code", + [ + ("test_read_only", 200), + ("test_delete_only", 403), + ], + ) + def test_read_variable(self, user, expected_status_code): + expected_value = '{"foo": 1}' + Variable.set("TEST_VARIABLE_KEY", expected_value) + response = self.client.get( + "/api/v1/variables/TEST_VARIABLE_KEY", environ_overrides={"REMOTE_USER": user} + ) + assert response.status_code == expected_status_code + if expected_status_code == 200: + assert response.json == {"key": "TEST_VARIABLE_KEY", "value": expected_value, "description": None} diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_xcom_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_xcom_endpoint.py new file mode 100644 index 000000000000..01336f9957c6 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_xcom_endpoint.py @@ -0,0 +1,230 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from datetime import timedelta + +import pytest + +from airflow.models.dag import DagModel +from airflow.models.dagrun import DagRun +from airflow.models.taskinstance import TaskInstance +from airflow.models.xcom import BaseXCom, XCom +from airflow.operators.empty import EmptyOperator +from airflow.security import permissions +from airflow.utils.dates import parse_execution_date +from airflow.utils.session import create_session +from airflow.utils.types import DagRunType +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +class CustomXCom(BaseXCom): + @classmethod + def deserialize_value(cls, xcom: XCom): + return f"real deserialized {super().deserialize_value(xcom)}" + + def orm_deserialize_value(self): + return f"orm deserialized {super().orm_deserialize_value()}" + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + + create_user( + app, + username="test_granular_permissions", + role_name="TestGranularDag", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), + ], + ) + app.appbuilder.sm.sync_perm_for_dag( + "test-dag-id-1", + access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]}, + ) + + yield app + + delete_user(app, username="test_granular_permissions") + + +def _compare_xcom_collections(collection1: dict, collection_2: dict): + assert collection1.get("total_entries") == collection_2.get("total_entries") + + def sort_key(record): + return ( + record.get("dag_id"), + record.get("task_id"), + record.get("execution_date"), + record.get("map_index"), + record.get("key"), + ) + + assert sorted(collection1.get("xcom_entries", []), key=sort_key) == sorted( + collection_2.get("xcom_entries", []), key=sort_key + ) + + +class TestXComEndpoint: + @staticmethod + def clean_db(): + clear_db_dags() + clear_db_runs() + clear_db_xcom() + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + """ + Setup For XCom endpoint TC + """ + self.app = configured_app + self.client = self.app.test_client() # type:ignore + # clear existing xcoms + self.clean_db() + + def teardown_method(self) -> None: + """ + Clear Hanging XComs + """ + self.clean_db() + + +class TestGetXComEntries(TestXComEndpoint): + def test_should_respond_200_with_tilde_and_granular_dag_access(self): + dag_id_1 = "test-dag-id-1" + task_id_1 = "test-task-id-1" + execution_date = "2005-04-02T00:00:00+00:00" + execution_date_parsed = parse_execution_date(execution_date) + dag_run_id_1 = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed) + self._create_xcom_entries(dag_id_1, dag_run_id_1, execution_date_parsed, task_id_1) + + dag_id_2 = "test-dag-id-2" + task_id_2 = "test-task-id-2" + run_id_2 = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed) + self._create_xcom_entries(dag_id_2, run_id_2, execution_date_parsed, task_id_2) + self._create_invalid_xcom_entries(execution_date_parsed) + response = self.client.get( + "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries", + environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + ) + + assert 200 == 
response.status_code + response_data = response.json + for xcom_entry in response_data["xcom_entries"]: + xcom_entry["timestamp"] = "TIMESTAMP" + _compare_xcom_collections( + response_data, + { + "xcom_entries": [ + { + "dag_id": dag_id_1, + "execution_date": execution_date, + "key": "test-xcom-key-1", + "task_id": task_id_1, + "timestamp": "TIMESTAMP", + "map_index": -1, + }, + { + "dag_id": dag_id_1, + "execution_date": execution_date, + "key": "test-xcom-key-2", + "task_id": task_id_1, + "timestamp": "TIMESTAMP", + "map_index": -1, + }, + ], + "total_entries": 2, + }, + ) + + def _create_xcom_entries(self, dag_id, run_id, execution_date, task_id, mapped_ti=False): + with create_session() as session: + dag = DagModel(dag_id=dag_id) + session.add(dag) + dagrun = DagRun( + dag_id=dag_id, + run_id=run_id, + execution_date=execution_date, + start_date=execution_date, + run_type=DagRunType.MANUAL, + ) + session.add(dagrun) + if mapped_ti: + for i in [0, 1]: + ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id, map_index=i) + ti.dag_id = dag_id + session.add(ti) + else: + ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id) + ti.dag_id = dag_id + session.add(ti) + + for i in [1, 2]: + if mapped_ti: + key = "test-xcom-key" + map_index = i - 1 + else: + key = f"test-xcom-key-{i}" + map_index = -1 + + XCom.set( + key=key, value="TEST", run_id=run_id, task_id=task_id, dag_id=dag_id, map_index=map_index + ) + + def _create_invalid_xcom_entries(self, execution_date): + """ + Invalid XCom entries to test join query + """ + with create_session() as session: + dag = DagModel(dag_id="invalid_dag") + session.add(dag) + dagrun = DagRun( + dag_id="invalid_dag", + run_id="invalid_run_id", + execution_date=execution_date + timedelta(days=1), + start_date=execution_date, + run_type=DagRunType.MANUAL, + ) + session.add(dagrun) + dagrun1 = DagRun( + dag_id="invalid_dag", + run_id="not_this_run_id", + execution_date=execution_date, + start_date=execution_date, + run_type=DagRunType.MANUAL, + ) + session.add(dagrun1) + ti = TaskInstance(EmptyOperator(task_id="invalid_task"), run_id="not_this_run_id") + ti.dag_id = "invalid_dag" + session.add(ti) + for i in [1, 2]: + XCom.set( + key=f"invalid-xcom-key-{i}", + value="TEST", + run_id="not_this_run_id", + task_id="invalid_task", + dag_id="invalid_dag", + ) diff --git a/tests/providers/fab/auth_manager/conftest.py b/tests/providers/fab/auth_manager/conftest.py index 22c29dd229fa..a8fbe5fbdaaa 100644 --- a/tests/providers/fab/auth_manager/conftest.py +++ b/tests/providers/fab/auth_manager/conftest.py @@ -30,7 +30,10 @@ def minimal_app_for_auth_api(): "init_appbuilder", "init_api_auth", "init_api_auth_provider", + "init_api_connexion", "init_api_error_handlers", + "init_airflow_session_interface", + "init_appbuilder_views", ] ) def factory(): @@ -39,7 +42,11 @@ def factory(): ( "api", "auth_backends", - ): "tests.test_utils.remote_user_api_auth_backend,airflow.api.auth.backend.session" + ): "tests.providers.fab.auth_manager.api_endpoints.remote_user_api_auth_backend,airflow.api.auth.backend.session", + ( + "core", + "auth_manager", + ): "airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager", } ): _app = app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore @@ -58,3 +65,11 @@ def set_auth_role_public(request): yield app.config["AUTH_ROLE_PUBLIC"] = auto_role_public + + +@pytest.fixture(scope="module") +def dagbag(): + from airflow.models import DagBag + + DagBag(include_examples=True, 
read_dags_from_db=False).sync_to_db() + return DagBag(include_examples=True, read_dags_from_db=True) diff --git a/tests/providers/fab/auth_manager/test_security.py b/tests/providers/fab/auth_manager/test_security.py index 156b5cf62627..bebb52c256fc 100644 --- a/tests/providers/fab/auth_manager/test_security.py +++ b/tests/providers/fab/auth_manager/test_security.py @@ -49,7 +49,7 @@ from airflow.www.auth import get_access_denied_message from airflow.www.extensions.init_auth_manager import get_auth_manager from airflow.www.utils import CustomSQLAInterface -from tests.test_utils.api_connexion_utils import ( +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( create_user, create_user_scope, delete_role, diff --git a/tests/providers/fab/auth_manager/views/test_permissions.py b/tests/providers/fab/auth_manager/views/test_permissions.py index 0b1073df287f..f24d9b738343 100644 --- a/tests/providers/fab/auth_manager/views/test_permissions.py +++ b/tests/providers/fab/auth_manager/views/test_permissions.py @@ -21,7 +21,7 @@ from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS from tests.test_utils.www import client_with_login diff --git a/tests/providers/fab/auth_manager/views/test_roles_list.py b/tests/providers/fab/auth_manager/views/test_roles_list.py index 156f07df4120..8de63ad5ba88 100644 --- a/tests/providers/fab/auth_manager/views/test_roles_list.py +++ b/tests/providers/fab/auth_manager/views/test_roles_list.py @@ -21,7 +21,7 @@ from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS from tests.test_utils.www import client_with_login diff --git a/tests/providers/fab/auth_manager/views/test_user.py b/tests/providers/fab/auth_manager/views/test_user.py index 6660ab926d88..62b03a99e7c2 100644 --- a/tests/providers/fab/auth_manager/views/test_user.py +++ b/tests/providers/fab/auth_manager/views/test_user.py @@ -21,7 +21,7 @@ from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS from tests.test_utils.www import client_with_login diff --git a/tests/providers/fab/auth_manager/views/test_user_edit.py b/tests/providers/fab/auth_manager/views/test_user_edit.py index 65937b6f83d3..8099f6794818 100644 --- a/tests/providers/fab/auth_manager/views/test_user_edit.py +++ b/tests/providers/fab/auth_manager/views/test_user_edit.py @@ -21,7 +21,7 @@ from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS from tests.test_utils.www import client_with_login diff --git a/tests/providers/fab/auth_manager/views/test_user_stats.py 
b/tests/providers/fab/auth_manager/views/test_user_stats.py index 8cb260fcf1ec..ae09cf92252c 100644 --- a/tests/providers/fab/auth_manager/views/test_user_stats.py +++ b/tests/providers/fab/auth_manager/views/test_user_stats.py @@ -21,7 +21,7 @@ from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS from tests.test_utils.www import client_with_login diff --git a/tests/test_utils/api_connexion_utils.py b/tests/test_utils/api_connexion_utils.py index af746b2d5546..48869ee48078 100644 --- a/tests/test_utils/api_connexion_utils.py +++ b/tests/test_utils/api_connexion_utils.py @@ -17,6 +17,7 @@ from __future__ import annotations from contextlib import contextmanager +from typing import TYPE_CHECKING from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from tests.test_utils.compat import ignore_provider_compatibility_error @@ -24,6 +25,9 @@ with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.security_manager.override import EXISTING_ROLES +if TYPE_CHECKING: + from flask import Flask + @contextmanager def create_test_client(app, user_name, role_name, permissions): @@ -44,7 +48,11 @@ def create_user_scope(app, username, **kwargs): It will create a user and provide it for the fixture via YIELD (generator) then will tidy up once test is complete """ - test_user = create_user(app, username, **kwargs) + from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user as create_user_fab, + ) + + test_user = create_user_fab(app, username, **kwargs) try: yield test_user @@ -52,27 +60,20 @@ def create_user_scope(app, username, **kwargs): delete_user(app, username) -def create_user(app, username, role_name=None, email=None, permissions=None): - appbuilder = app.appbuilder - +def create_user(app: Flask, username: str, role_name: str | None): # Removes user and role so each test has isolated test data. 
delete_user(app, username) - role = None - if role_name: - delete_role(app, role_name) - role = create_role(app, role_name, permissions) - else: - role = [] - - return appbuilder.sm.add_user( - username=username, - first_name=username, - last_name=username, - email=email or f"{username}@example.org", - role=role, - password=username, + + users = app.config.get("SIMPLE_AUTH_MANAGER_USERS", []) + users.append( + { + "username": username, + "role": role_name, + } ) + app.config["SIMPLE_AUTH_MANAGER_USERS"] = users + def create_role(app, name, permissions=None): appbuilder = app.appbuilder @@ -87,14 +88,6 @@ def create_role(app, name, permissions=None): return role -def set_user_single_role(app, user, role_name): - role = create_role(app, role_name) - if role not in user.roles: - user.roles = [role] - app.appbuilder.sm.update_user(user) - user._perms = None - - def delete_role(app, name): if name not in EXISTING_ROLES: if app.appbuilder.sm.find_role(name): @@ -106,20 +99,11 @@ def delete_roles(app): delete_role(app, role.name) -def delete_user(app, username): - appbuilder = app.appbuilder - for user in appbuilder.sm.get_all_users(): - if user.username == username: - _ = [ - delete_role(app, role.name) for role in user.roles if role and role.name not in EXISTING_ROLES - ] - appbuilder.sm.del_register_user(user) - break - - -def delete_users(app): - for user in app.appbuilder.sm.get_all_users(): - delete_user(app, user.username) +def delete_user(app: Flask, username): + users = app.config.get("SIMPLE_AUTH_MANAGER_USERS", []) + users = [user for user in users if user["username"] != username] + + app.config["SIMPLE_AUTH_MANAGER_USERS"] = users def assert_401(response): diff --git a/tests/test_utils/remote_user_api_auth_backend.py b/tests/test_utils/remote_user_api_auth_backend.py index b7714e5192e6..59df201e530e 100644 --- a/tests/test_utils/remote_user_api_auth_backend.py +++ b/tests/test_utils/remote_user_api_auth_backend.py @@ -15,17 +15,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Default authentication backend - everything is allowed""" - from __future__ import annotations import logging from functools import wraps from typing import TYPE_CHECKING, Callable, TypeVar, cast -from flask import Response, request -from flask_login import login_user +from flask import Response, request, session +from airflow.auth.managers.simple.user import SimpleAuthManagerUser from airflow.utils.airflow_flask_app import get_airflow_app if TYPE_CHECKING: @@ -36,25 +34,15 @@ CLIENT_AUTH: tuple[str, str] | AuthBase | None = None -def init_app(_): - """Initializes authentication backend""" +def init_app(_): ... 
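A minimal usage sketch for the reworked test helpers above (illustrative only, not part of this patch). With this change, create_user and delete_user simply maintain the SIMPLE_AUTH_MANAGER_USERS list in the Flask app config instead of creating FAB security-manager users and roles, so a core test module could register and tear down a user roughly as follows; the fixture name minimal_app_for_api and the role string "Viewer" are assumptions for the sketch:

    import pytest

    from tests.test_utils.api_connexion_utils import create_user, delete_user


    @pytest.fixture(scope="module")
    def app_with_reader(minimal_app_for_api):
        app = minimal_app_for_api
        # Appends {"username": ..., "role": ...} to
        # app.config["SIMPLE_AUTH_MANAGER_USERS"]; no FAB security-manager
        # user or role objects are created anymore.
        create_user(app, username="test_reader", role_name="Viewer")
        yield app
        # Filters the same entry back out of the config list so other
        # test modules start from a clean state.
        delete_user(app, username="test_reader")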
T = TypeVar("T", bound=Callable) -def _lookup_user(user_email_or_username: str): - security_manager = get_airflow_app().appbuilder.sm - user = security_manager.find_user(email=user_email_or_username) or security_manager.find_user( - username=user_email_or_username - ) - if not user: - return None - - if not user.is_active: - return None - - return user +def _lookup_user(username: str): + users = get_airflow_app().config.get("SIMPLE_AUTH_MANAGER_USERS", []) + return next((user for user in users if user["username"] == username), None) def requires_authentication(function: T): @@ -69,13 +57,13 @@ def decorated(*args, **kwargs): log.debug("Looking for user: %s", user_id) - user = _lookup_user(user_id) - if not user: + user_dict = _lookup_user(user_id) + if not user_dict: return Response("Forbidden", 403) - log.debug("Found user: %s", user) + log.debug("Found user: %s", user_dict) + session["user"] = SimpleAuthManagerUser(username=user_dict["username"], role=user_dict["role"]) - login_user(user, remember=False) return function(*args, **kwargs) return cast(T, decorated) diff --git a/tests/www/views/test_views_custom_user_views.py b/tests/www/views/test_views_custom_user_views.py index ae6d0132827c..84947a8e5f36 100644 --- a/tests/www/views/test_views_custom_user_views.py +++ b/tests/www/views/test_views_custom_user_views.py @@ -27,7 +27,10 @@ from airflow import settings from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_role +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user as create_user, + delete_role, +) from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login pytestmark = pytest.mark.db_test diff --git a/tests/www/views/test_views_dagrun.py b/tests/www/views/test_views_dagrun.py index 39c17d086f37..d95955246ac7 100644 --- a/tests/www/views/test_views_dagrun.py +++ b/tests/www/views/test_views_dagrun.py @@ -24,7 +24,11 @@ from airflow.utils import timezone from airflow.utils.session import create_session from airflow.www.views import DagRunModelView -from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user, + delete_roles, + delete_user, +) from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login from tests.www.views.test_views_tasks import _get_appbuilder_pk_string diff --git a/tests/www/views/test_views_home.py b/tests/www/views/test_views_home.py index 539311504139..ddec0c0bcfed 100644 --- a/tests/www/views/test_views_home.py +++ b/tests/www/views/test_views_home.py @@ -27,7 +27,7 @@ from airflow.utils.state import State from airflow.www.utils import UIAlert from airflow.www.views import FILTER_LASTRUN_COOKIE, FILTER_STATUS_COOKIE, FILTER_TAGS_COOKIE -from tests.test_utils.api_connexion_utils import create_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user from tests.test_utils.db import clear_db_dags, clear_db_import_errors, clear_db_serialized_dags from tests.test_utils.permissions import _resource_name from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index 
f5cc011fb6f0..7b65051724c2 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -44,7 +44,11 @@ from airflow.utils.state import DagRunState, State from airflow.utils.types import DagRunType from airflow.www.views import TaskInstanceModelView, _safe_parse_datetime -from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user, + delete_roles, + delete_user, +) from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_runs, clear_db_xcom diff --git a/tests/www/views/test_views_variable.py b/tests/www/views/test_views_variable.py index a91a12ddc470..b7fa8b37c52c 100644 --- a/tests/www/views/test_views_variable.py +++ b/tests/www/views/test_views_variable.py @@ -25,7 +25,7 @@ from airflow.models import Variable from airflow.security import permissions from airflow.utils.session import create_session -from tests.test_utils.api_connexion_utils import create_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user from tests.test_utils.www import ( _check_last_log, check_content_in_response, From 7219d3049ae8f793e7bf7c80505e27afcd00e7e1 Mon Sep 17 00:00:00 2001 From: Brent Bovenzi Date: Tue, 1 Oct 2024 17:25:18 +0200 Subject: [PATCH 09/10] Add Docs button to Nav (#42586) * Add Docs button to new UI nav * Add Docs menu button to Nav * Use src alias * Address PR feedback, update documentation * Delete airflow/ui/.env.local --- .gitignore | 1 + airflow/ui/.env.example | 23 +++++++ airflow/ui/src/layouts/Nav/DocsButton.tsx | 67 +++++++++++++++++++ airflow/ui/src/layouts/{ => Nav}/Nav.tsx | 9 ++- .../ui/src/layouts/{ => Nav}/NavButton.tsx | 17 ++--- airflow/ui/src/layouts/Nav/index.tsx | 20 ++++++ airflow/ui/src/layouts/Nav/navButtonProps.ts | 30 +++++++++ airflow/ui/src/main.tsx | 2 +- airflow/ui/src/vite-env.d.ts | 9 +++ .../14_node_environment_setup.rst | 17 +++++ 10 files changed, 178 insertions(+), 17 deletions(-) create mode 100644 airflow/ui/.env.example create mode 100644 airflow/ui/src/layouts/Nav/DocsButton.tsx rename airflow/ui/src/layouts/{ => Nav}/Nav.tsx (94%) rename airflow/ui/src/layouts/{ => Nav}/NavButton.tsx (83%) create mode 100644 airflow/ui/src/layouts/Nav/index.tsx create mode 100644 airflow/ui/src/layouts/Nav/navButtonProps.ts diff --git a/.gitignore b/.gitignore index 257331cb4e90..a9c055041d98 100644 --- a/.gitignore +++ b/.gitignore @@ -111,6 +111,7 @@ celerybeat-schedule # dotenv .env +.env.local .autoenv*.zsh # virtualenv diff --git a/airflow/ui/.env.example b/airflow/ui/.env.example new file mode 100644 index 000000000000..9374d93de6bc --- /dev/null +++ b/airflow/ui/.env.example @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +#/ + + +# This is an example. You should make your own `.env.local` file for development + +VITE_FASTAPI_URL="http://localhost:29091" diff --git a/airflow/ui/src/layouts/Nav/DocsButton.tsx b/airflow/ui/src/layouts/Nav/DocsButton.tsx new file mode 100644 index 000000000000..07a4b93dfaed --- /dev/null +++ b/airflow/ui/src/layouts/Nav/DocsButton.tsx @@ -0,0 +1,67 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import { + IconButton, + Link, + Menu, + MenuButton, + MenuItem, + MenuList, +} from "@chakra-ui/react"; +import { FiBookOpen } from "react-icons/fi"; + +import { navButtonProps } from "./navButtonProps"; + +const links = [ + { + href: "https://airflow.apache.org/docs/", + title: "Documentation", + }, + { + href: "https://github.com/apache/airflow", + title: "GitHub Repo", + }, + { + href: `${import.meta.env.VITE_FASTAPI_URL}/docs`, + title: "REST API Reference", + }, +]; + +export const DocsButton = () => ( + + } + {...navButtonProps} + /> + + {links.map((link) => ( + + {link.title} + + ))} + + +); diff --git a/airflow/ui/src/layouts/Nav.tsx b/airflow/ui/src/layouts/Nav/Nav.tsx similarity index 94% rename from airflow/ui/src/layouts/Nav.tsx rename to airflow/ui/src/layouts/Nav/Nav.tsx index 4900540cd96d..55bfd4480e0f 100644 --- a/airflow/ui/src/layouts/Nav.tsx +++ b/airflow/ui/src/layouts/Nav/Nav.tsx @@ -37,8 +37,10 @@ import { FiSun, } from "react-icons/fi"; -import { AirflowPin } from "../assets/AirflowPin"; -import { DagIcon } from "../assets/DagIcon"; +import { AirflowPin } from "src/assets/AirflowPin"; +import { DagIcon } from "src/assets/DagIcon"; + +import { DocsButton } from "./DocsButton"; import { NavButton } from "./NavButton"; export const Nav = () => { @@ -78,7 +80,7 @@ export const Nav = () => { } isDisabled - title="Datasets" + title="Assets" /> } @@ -103,6 +105,7 @@ export const Nav = () => { icon={} title="Return to legacy UI" /> + ( - diff --git a/airflow/ui/src/layouts/Nav/index.tsx b/airflow/ui/src/layouts/Nav/index.tsx new file mode 100644 index 000000000000..403e140919b0 --- /dev/null +++ b/airflow/ui/src/layouts/Nav/index.tsx @@ -0,0 +1,20 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +export { Nav } from "./Nav"; diff --git a/airflow/ui/src/layouts/Nav/navButtonProps.ts b/airflow/ui/src/layouts/Nav/navButtonProps.ts new file mode 100644 index 000000000000..740348bc9676 --- /dev/null +++ b/airflow/ui/src/layouts/Nav/navButtonProps.ts @@ -0,0 +1,30 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import type { ButtonProps } from "@chakra-ui/react"; + +export const navButtonProps: ButtonProps = { + alignItems: "center", + borderRadius: "none", + flexDir: "column", + height: 16, + transition: "0.2s background-color ease-in-out", + variant: "ghost", + whiteSpace: "wrap", + width: 24, +}; diff --git a/airflow/ui/src/main.tsx b/airflow/ui/src/main.tsx index ca5fbed04b6c..7b762508ea7b 100644 --- a/airflow/ui/src/main.tsx +++ b/airflow/ui/src/main.tsx @@ -43,7 +43,7 @@ const queryClient = new QueryClient({ }, }); -axios.defaults.baseURL = "http://localhost:29091"; +axios.defaults.baseURL = import.meta.env.VITE_FASTAPI_URL; // redirect to login page if the API responds with unauthorized or forbidden errors axios.interceptors.response.use( diff --git a/airflow/ui/src/vite-env.d.ts b/airflow/ui/src/vite-env.d.ts index a1fdcdd1e6fc..193866687bff 100644 --- a/airflow/ui/src/vite-env.d.ts +++ b/airflow/ui/src/vite-env.d.ts @@ -1,3 +1,4 @@ +/* eslint-disable @typescript-eslint/consistent-type-definitions */ /*! * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -18,3 +19,11 @@ */ /// + +interface ImportMetaEnv { + readonly VITE_FASTAPI_URL: string; +} + +interface ImportMeta { + readonly env: ImportMetaEnv; +} diff --git a/contributing-docs/14_node_environment_setup.rst b/contributing-docs/14_node_environment_setup.rst index 8d98f0860fc8..7b10f0b0d5ed 100644 --- a/contributing-docs/14_node_environment_setup.rst +++ b/contributing-docs/14_node_environment_setup.rst @@ -84,6 +84,23 @@ Project Structure - ``/src/components`` shared components across the UI - ``/dist`` build files +Local Environment Variables +--------------------------- + +Copy the example environment + +.. code-block:: bash + + cp .env.example .env.local + +If you run into CORS issues, you may need to add some variables to your Breeze config, ``files/airflow-breeze-config/variables.env``: + +.. 
code-block:: bash + + export AIRFLOW__API__ACCESS_CONTROL_ALLOW_HEADERS="Origin, Access-Control-Request-Method" + export AIRFLOW__API__ACCESS_CONTROL_ALLOW_METHODS="*" + export AIRFLOW__API__ACCESS_CONTROL_ALLOW_ORIGINS="http://localhost:28080,http://localhost:8080" + DEPRECATED Airflow WWW From 05c43eeacc537cfa6b1affa6fdd8c3202f70c14b Mon Sep 17 00:00:00 2001 From: Elad Kalif <45845474+eladkal@users.noreply.github.com> Date: Tue, 1 Oct 2024 22:31:17 +0700 Subject: [PATCH 10/10] Update providers metadata 2024-10-01 (#42611) --- generated/provider_metadata.json | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/generated/provider_metadata.json b/generated/provider_metadata.json index a73e3da9f6fc..56199e2f82c3 100644 --- a/generated/provider_metadata.json +++ b/generated/provider_metadata.json @@ -2763,6 +2763,10 @@ "1.17.0": { "associated_airflow_version": "2.10.1", "date_released": "2024-09-24T13:49:56Z" + }, + "1.17.1": { + "associated_airflow_version": "2.10.1", + "date_released": "2024-10-01T09:05:14Z" } }, "databricks": { @@ -6225,6 +6229,10 @@ "1.12.0": { "associated_airflow_version": "2.10.1", "date_released": "2024-09-24T13:49:56Z" + }, + "1.12.1": { + "associated_airflow_version": "2.10.1", + "date_released": "2024-10-01T09:05:14Z" } }, "opensearch": {