From 24121c46cd4f87501497b048912e753eb42c26dc Mon Sep 17 00:00:00 2001 From: Vincent <97131062+vincbeck@users.noreply.github.com> Date: Tue, 21 Jan 2025 10:52:00 -0500 Subject: [PATCH] Make FAB auth manager login process compatible with Airflow 3 UI (#45765) --- airflow/api_fastapi/core_api/app.py | 2 +- airflow/auth/managers/base_auth_manager.py | 11 +++ airflow/auth/managers/simple/views/auth.py | 8 +-- .../fab/auth_manager/cli_commands/utils.py | 2 +- .../fab/auth_manager/fab_auth_manager.py | 2 +- .../src/airflow/providers/fab/www/app.py | 9 +-- .../fab/www/extensions/init_appbuilder.py | 23 +++++-- .../fab/www/extensions/init_jinja_globals.py | 4 +- .../fab/www/templates/airflow/main.html | 20 +++--- .../fab/www/templates/appbuilder/navbar.html | 7 ++ .../templates/appbuilder/navbar_right.html | 64 +++++++++++++++++ .../src/airflow/providers/fab/www/views.py | 24 +++++++ tests/auth/managers/simple/views/test_auth.py | 17 +++-- tests/auth/managers/test_base_auth_manager.py | 68 +++++++++++++++---- 14 files changed, 213 insertions(+), 48 deletions(-) create mode 100644 providers/src/airflow/providers/fab/www/templates/appbuilder/navbar_right.html diff --git a/airflow/api_fastapi/core_api/app.py b/airflow/api_fastapi/core_api/app.py index 6099c5b654ac0..08f37812c3c50 100644 --- a/airflow/api_fastapi/core_api/app.py +++ b/airflow/api_fastapi/core_api/app.py @@ -132,7 +132,7 @@ def init_flask_plugins(app: FastAPI) -> None: stacklevel=2, ) - flask_app = create_app() + flask_app = create_app(enable_plugins=True) app.mount("/pluginsv2", WSGIMiddleware(flask_app)) diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index 6a9ef11e3d785..fe86bc8f05acf 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -24,9 +24,11 @@ from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.resource_details import DagDetails +from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.models import DagModel from airflow.typing_compat import Literal +from airflow.utils.jwt_signer import JWTSigner from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session @@ -100,6 +102,15 @@ def deserialize_user(self, token: dict[str, Any]) -> T: def serialize_user(self, user: T) -> dict[str, Any]: """Create a dict from a user object.""" + def get_jwt_token(self, user: T) -> str: + """Return the JWT token from a user object.""" + signer = JWTSigner( + secret_key=conf.get("api", "auth_jwt_secret"), + expiration_time_in_seconds=conf.getint("api", "auth_jwt_expiration_time"), + audience="front-apis", + ) + return signer.generate_signed_token(self.serialize_user(user)) + def get_user_id(self) -> str | None: """Return the user ID associated to the user in session.""" user = self.get_user() diff --git a/airflow/auth/managers/simple/views/auth.py b/airflow/auth/managers/simple/views/auth.py index b292fc05541b6..64c697ecbcc3d 100644 --- a/airflow/auth/managers/simple/views/auth.py +++ b/airflow/auth/managers/simple/views/auth.py @@ -25,7 +25,6 @@ from airflow.api_fastapi.app import get_auth_manager from airflow.auth.managers.simple.user import SimpleAuthManagerUser from airflow.configuration import conf -from airflow.utils.jwt_signer import JWTSigner from airflow.utils.state import State from airflow.www.app import csrf from airflow.www.views import AirflowBaseView @@ -92,12 +91,7 @@ def login_submit(self): # Will be removed once Airflow uses the new UI session["user"] = user - signer = JWTSigner( - secret_key=conf.get("api", "auth_jwt_secret"), - expiration_time_in_seconds=conf.getint("api", "auth_jwt_expiration_time"), - audience="front-apis", - ) - token = signer.generate_signed_token(get_auth_manager().serialize_user(user)) + token = get_auth_manager().get_jwt_token(user) if next_url: return redirect(self._get_redirect_url(next_url, token)) diff --git a/providers/src/airflow/providers/fab/auth_manager/cli_commands/utils.py b/providers/src/airflow/providers/fab/auth_manager/cli_commands/utils.py index ee7c6f8202a25..badd7fd08aeaa 100644 --- a/providers/src/airflow/providers/fab/auth_manager/cli_commands/utils.py +++ b/providers/src/airflow/providers/fab/auth_manager/cli_commands/utils.py @@ -41,7 +41,7 @@ @cache def _return_appbuilder(app: Flask) -> AirflowAppBuilder: """Return an appbuilder instance for the given app.""" - init_appbuilder(app) + init_appbuilder(app, enable_plugins=False) init_plugins(app) init_airflow_session_interface(app) return app.appbuilder # type: ignore[attr-defined] diff --git a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py index 4c889a9c14e3f..2d58d79e41b00 100644 --- a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -181,7 +181,7 @@ def get_fastapi_app(self) -> FastAPI | None: if not flask_blueprint: return None - flask_app = create_app() + flask_app = create_app(enable_plugins=False) flask_app.register_blueprint(flask_blueprint) app = FastAPI( diff --git a/providers/src/airflow/providers/fab/www/app.py b/providers/src/airflow/providers/fab/www/app.py index 0414fc5e408b5..6890dc96abbe3 100644 --- a/providers/src/airflow/providers/fab/www/app.py +++ b/providers/src/airflow/providers/fab/www/app.py @@ -41,7 +41,7 @@ csrf = CSRFProtect() -def create_app(): +def create_app(enable_plugins: bool): """Create a new instance of Airflow WWW app.""" flask_app = Flask(__name__) flask_app.secret_key = conf.get("webserver", "SECRET_KEY") @@ -66,10 +66,11 @@ def create_app(): init_api_auth(flask_app) with flask_app.app_context(): - init_appbuilder(flask_app) - init_plugins(flask_app) + init_appbuilder(flask_app, enable_plugins=enable_plugins) + if enable_plugins: + init_plugins(flask_app) init_error_handlers(flask_app) - init_jinja_globals(flask_app) + init_jinja_globals(flask_app, enable_plugins=enable_plugins) init_xframe_protection(flask_app) return flask_app diff --git a/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py b/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py index b3f5551aeee3b..555f0501a6a61 100644 --- a/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py +++ b/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py @@ -39,9 +39,10 @@ from flask_appbuilder.views import IndexView from airflow import settings -from airflow.api_fastapi.app import create_auth_manager +from airflow.api_fastapi.app import create_auth_manager, get_auth_manager from airflow.configuration import conf from airflow.providers.fab.www.security_manager import AirflowSecurityManagerV2 +from airflow.providers.fab.www.views import FabIndexView if TYPE_CHECKING: from flask import Flask @@ -109,6 +110,7 @@ def __init__( base_template="airflow/main.html", static_folder="static/appbuilder", static_url_path="/appbuilder", + enable_plugins: bool = False, ): """ App-builder constructor. @@ -125,6 +127,15 @@ def __init__( optional, your override for the global static folder :param static_url_path: optional, your override for the global static url path + :param enable_plugins: + optional, whether plugins are enabled for this app. AirflowAppBuilder from FAB provider can be + instantiated in two modes: + - Plugins enabled. The Flask application is responsible to execute Airflow 2 plugins. + This application is only running if there are Airflow 2 plugins defined as part of the Airflow + environment + - Plugins disabled. The Flask application is responsible to execute the FAB auth manager login + process. This application is only running if FAB auth manager is the auth manager configured + in the Airflow environment """ from airflow.providers_manager import ProvidersManager @@ -139,6 +150,7 @@ def __init__( self.static_folder = static_folder self.static_url_path = static_url_path self.app = app + self.enable_plugins = enable_plugins self.update_perms = conf.getboolean("fab", "UPDATE_FAB_PERMS") self.auth_rate_limited = conf.getboolean("fab", "AUTH_RATE_LIMITED") self.auth_rate_limit = conf.get("fab", "AUTH_RATE_LIMIT") @@ -172,8 +184,10 @@ def init_app(self, app, session): _index_view = app.config.get("FAB_INDEX_VIEW", None) if _index_view is not None: self.indexview = dynamic_class_import(_index_view) + elif not self.enable_plugins: + self.indexview = FabIndexView else: - self.indexview = self.indexview or IndexView + self.indexview = IndexView _menu = app.config.get("FAB_MENU", None) if _menu is not None: self.menu = dynamic_class_import(_menu) @@ -282,6 +296,7 @@ def _add_admin_views(self): """Register indexview, utilview (back function), babel views and Security views.""" self.indexview = self._check_and_init(self.indexview) self.add_view_no_menu(self.indexview) + get_auth_manager().register_views() def _add_addon_views(self): """Register declared addons.""" @@ -500,7 +515,6 @@ def add_view_no_menu(self, baseview, endpoint=None, static_folder=None): @property def get_url_for_index(self): - # TODO: Return the fast api application homepage return url_for(f"{self.indexview.endpoint}.{self.indexview.default_view}") def get_url_for_locale(self, lang): @@ -560,10 +574,11 @@ def _process_inner_views(self): view.get_init_inner_views().append(v) -def init_appbuilder(app: Flask) -> AirflowAppBuilder: +def init_appbuilder(app: Flask, enable_plugins: bool) -> AirflowAppBuilder: """Init `Flask App Builder `__.""" return AirflowAppBuilder( app=app, session=settings.Session, base_template="airflow/main.html", + enable_plugins=enable_plugins, ) diff --git a/providers/src/airflow/providers/fab/www/extensions/init_jinja_globals.py b/providers/src/airflow/providers/fab/www/extensions/init_jinja_globals.py index f7abe34154dc5..177ed158b959b 100644 --- a/providers/src/airflow/providers/fab/www/extensions/init_jinja_globals.py +++ b/providers/src/airflow/providers/fab/www/extensions/init_jinja_globals.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) -def init_jinja_globals(app): +def init_jinja_globals(app, enable_plugins: bool): """Add extra globals variable to Jinja context.""" server_timezone = conf.get("core", "default_timezone") if server_timezone == "system": @@ -70,6 +70,8 @@ def prepare_jinja_globals(): "state_color_mapping": STATE_COLORS, "airflow_version": airflow_version, "git_version": git_version, + "show_plugin_message": enable_plugins, + "disable_nav_bar": not enable_plugins, } # Extra global specific to auth manager diff --git a/providers/src/airflow/providers/fab/www/templates/airflow/main.html b/providers/src/airflow/providers/fab/www/templates/airflow/main.html index e6c00bd06660f..25ce3c0439a01 100644 --- a/providers/src/airflow/providers/fab/www/templates/airflow/main.html +++ b/providers/src/airflow/providers/fab/www/templates/airflow/main.html @@ -21,7 +21,7 @@ {% from 'airflow/_messages.html' import show_message %} {% block page_title -%} - Airflow - Airflow 2 plugins compatibility view + Airflow {% endblock %} {% block head_css %} @@ -53,12 +53,14 @@ {% endblock %} {% block messages %} - {% call show_message(category='warning', dismissible=false) %} -

- You have a plugin that is using a FAB view or Flask Blueprint, which was used for the Airflow 2 UI, and is now - deprecated. Please update your plugin to be compatible with the Airflow 3 UI. -

- {% endcall %} + {% if show_plugin_message %} + {% call show_message(category='warning', dismissible=false) %} +

+ You have a plugin that is using a FAB view or Flask Blueprint, which was used for the Airflow 2 UI, and is now + deprecated. Please update your plugin to be compatible with the Airflow 3 UI. +

+ {% endcall %} + {% endif %} {% endblock %} {% block tail_js %} @@ -66,10 +68,6 @@ diff --git a/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar.html b/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar.html index dba354fb1310a..76cbcd8e2ddb4 100644 --- a/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar.html +++ b/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar.html @@ -18,6 +18,7 @@ #} {% set menu = appbuilder.menu %} +{% set languages = appbuilder.languages %} diff --git a/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar_right.html b/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar_right.html new file mode 100644 index 0000000000000..54254f6d4266c --- /dev/null +++ b/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar_right.html @@ -0,0 +1,64 @@ +{# + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +#} + +{% macro locale_menu(languages) %} + {% set locale = session['locale'] %} + {% if not locale %} + {% set locale = 'en' %} + {% endif %} + +{% endmacro %} + +{# clock and timezone menu #} + diff --git a/providers/src/airflow/providers/fab/www/views.py b/providers/src/airflow/providers/fab/www/views.py index 43ac276897e85..925a777c26d37 100644 --- a/providers/src/airflow/providers/fab/www/views.py +++ b/providers/src/airflow/providers/fab/www/views.py @@ -21,8 +21,11 @@ import traceback from flask import ( + g, + redirect, render_template, ) +from flask_appbuilder import IndexView, expose from airflow.api_fastapi.app import get_auth_manager from airflow.configuration import conf @@ -30,6 +33,27 @@ from airflow.version import version +class FabIndexView(IndexView): + """ + A simple view that inherits from FAB index view. + + The only goal of this view is to redirect the user to the Airflow 3 UI index page if the user is + authenticated. It is impossible to redirect the user directly to the Airflow 3 UI index page before + redirecting them to this page because FAB itself defines the logic redirection and does not allow external + redirect. + + It is impossible to redirect the user before + """ + + @expose("/") + def index(self): + if g.user is not None and g.user.is_authenticated: + token = get_auth_manager().get_jwt_token(g.user) + return redirect(f"/webapp?token={token}", code=302) + else: + super().index(self) + + def show_traceback(error): """Show Traceback for a given error.""" is_logged_in = get_auth_manager().is_logged_in() diff --git a/tests/auth/managers/simple/views/test_auth.py b/tests/auth/managers/simple/views/test_auth.py index 0eccf0dc9ec1d..633dcbd0f2170 100644 --- a/tests/auth/managers/simple/views/test_auth.py +++ b/tests/auth/managers/simple/views/test_auth.py @@ -64,13 +64,20 @@ def test_logout_redirects_to_login_and_clear_user(self, simple_app): ("test", "test", True, {"next": "next_url"}, "next_url?token=token"), ], ) - @patch("airflow.auth.managers.simple.views.auth.JWTSigner") + @patch("airflow.auth.managers.simple.views.auth.get_auth_manager") def test_login_submit( - self, mock_jwt_signer, simple_app, username, password, is_successful, query_params, expected_redirect + self, + mock_get_auth_manager, + simple_app, + username, + password, + is_successful, + query_params, + expected_redirect, ): - signer = Mock() - signer.generate_signed_token.return_value = "token" - mock_jwt_signer.return_value = signer + auth_manager = Mock() + auth_manager.get_jwt_token.return_value = "token" + mock_get_auth_manager.return_value = auth_manager with simple_app.test_client() as client: response = client.post( "/login_submit", query_string=query_params, data={"username": username, "password": password} diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py index 4406ae9d43607..370e401da0609 100644 --- a/tests/auth/managers/test_base_auth_manager.py +++ b/tests/auth/managers/test_base_auth_manager.py @@ -44,16 +44,27 @@ from airflow.www.extensions.init_appbuilder import AirflowAppBuilder -class EmptyAuthManager(BaseAuthManager[BaseUser]): +class BaseAuthManagerUserTest(BaseUser): + def __init__(self, *, name: str) -> None: + self.name = name + + def get_id(self) -> str: + return self.name + + def get_name(self) -> str: + return self.name + + +class EmptyAuthManager(BaseAuthManager[BaseAuthManagerUserTest]): appbuilder: AirflowAppBuilder | None = None - def get_user(self) -> BaseUser: + def get_user(self) -> BaseAuthManagerUserTest: raise NotImplementedError() - def deserialize_user(self, token: dict[str, Any]) -> BaseUser: + def deserialize_user(self, token: dict[str, Any]) -> BaseAuthManagerUserTest: raise NotImplementedError() - def serialize_user(self, user: BaseUser) -> dict[str, Any]: + def serialize_user(self, user: BaseAuthManagerUserTest) -> dict[str, Any]: raise NotImplementedError() def is_authorized_configuration( @@ -61,7 +72,7 @@ def is_authorized_configuration( *, method: ResourceMethod, details: ConfigurationDetails | None = None, - user: BaseUser | None = None, + user: BaseAuthManagerUserTest | None = None, ) -> bool: raise NotImplementedError() @@ -70,7 +81,7 @@ def is_authorized_connection( *, method: ResourceMethod, details: ConnectionDetails | None = None, - user: BaseUser | None = None, + user: BaseAuthManagerUserTest | None = None, ) -> bool: raise NotImplementedError() @@ -80,30 +91,44 @@ def is_authorized_dag( method: ResourceMethod, access_entity: DagAccessEntity | None = None, details: DagDetails | None = None, - user: BaseUser | None = None, + user: BaseAuthManagerUserTest | None = None, ) -> bool: raise NotImplementedError() def is_authorized_asset( - self, *, method: ResourceMethod, details: AssetDetails | None = None, user: BaseUser | None = None + self, + *, + method: ResourceMethod, + details: AssetDetails | None = None, + user: BaseAuthManagerUserTest | None = None, ) -> bool: raise NotImplementedError() def is_authorized_pool( - self, *, method: ResourceMethod, details: PoolDetails | None = None, user: BaseUser | None = None + self, + *, + method: ResourceMethod, + details: PoolDetails | None = None, + user: BaseAuthManagerUserTest | None = None, ) -> bool: raise NotImplementedError() def is_authorized_variable( - self, *, method: ResourceMethod, details: VariableDetails | None = None, user: BaseUser | None = None + self, + *, + method: ResourceMethod, + details: VariableDetails | None = None, + user: BaseAuthManagerUserTest | None = None, ) -> bool: raise NotImplementedError() - def is_authorized_view(self, *, access_view: AccessView, user: BaseUser | None = None) -> bool: + def is_authorized_view( + self, *, access_view: AccessView, user: BaseAuthManagerUserTest | None = None + ) -> bool: raise NotImplementedError() def is_authorized_custom_view( - self, *, method: ResourceMethod | str, resource_name: str, user: BaseUser | None = None + self, *, method: ResourceMethod | str, resource_name: str, user: BaseAuthManagerUserTest | None = None ): raise NotImplementedError() @@ -165,6 +190,23 @@ def test_get_user_id_raise_exception_when_no_user(self, auth_manager): def test_get_url_user_profile_return_none(self, auth_manager): assert auth_manager.get_url_user_profile() is None + @patch("airflow.auth.managers.base_auth_manager.JWTSigner") + @patch.object(EmptyAuthManager, "serialize_user") + def test_get_jwt_token(self, mock_serialize_user, mock_jwt_signer, auth_manager): + token = "token" + serialized_user = "serialized_user" + signer = Mock() + signer.generate_signed_token.return_value = token + mock_jwt_signer.return_value = signer + mock_serialize_user.return_value = serialized_user + user = BaseAuthManagerUserTest(name="test") + + result = auth_manager.get_jwt_token(user) + + mock_serialize_user.assert_called_once_with(user) + signer.generate_signed_token.assert_called_once_with(serialized_user) + assert result == token + @pytest.mark.parametrize( "return_values, expected", [ @@ -279,7 +321,7 @@ def side_effect_func( method: ResourceMethod, access_entity: DagAccessEntity | None = None, details: DagDetails | None = None, - user: BaseUser | None = None, + user: BaseAuthManagerUserTest | None = None, ): if not details: return access_all