Skip to content

Commit

Permalink
Make FAB auth manager login process compatible with Airflow 3 UI (apa…
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck authored and got686-yandex committed Jan 30, 2025
1 parent c92985b commit 24121c4
Show file tree
Hide file tree
Showing 14 changed files with 213 additions and 48 deletions.
2 changes: 1 addition & 1 deletion airflow/api_fastapi/core_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def init_flask_plugins(app: FastAPI) -> None:
stacklevel=2,
)

flask_app = create_app()
flask_app = create_app(enable_plugins=True)
app.mount("/pluginsv2", WSGIMiddleware(flask_app))


Expand Down
11 changes: 11 additions & 0 deletions airflow/auth/managers/base_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@

from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.resource_details import DagDetails
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import DagModel
from airflow.typing_compat import Literal
from airflow.utils.jwt_signer import JWTSigner
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session

Expand Down Expand Up @@ -100,6 +102,15 @@ def deserialize_user(self, token: dict[str, Any]) -> T:
def serialize_user(self, user: T) -> dict[str, Any]:
"""Create a dict from a user object."""

def get_jwt_token(self, user: T) -> str:
"""Return the JWT token from a user object."""
signer = JWTSigner(
secret_key=conf.get("api", "auth_jwt_secret"),
expiration_time_in_seconds=conf.getint("api", "auth_jwt_expiration_time"),
audience="front-apis",
)
return signer.generate_signed_token(self.serialize_user(user))

def get_user_id(self) -> str | None:
"""Return the user ID associated to the user in session."""
user = self.get_user()
Expand Down
8 changes: 1 addition & 7 deletions airflow/auth/managers/simple/views/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from airflow.api_fastapi.app import get_auth_manager
from airflow.auth.managers.simple.user import SimpleAuthManagerUser
from airflow.configuration import conf
from airflow.utils.jwt_signer import JWTSigner
from airflow.utils.state import State
from airflow.www.app import csrf
from airflow.www.views import AirflowBaseView
Expand Down Expand Up @@ -92,12 +91,7 @@ def login_submit(self):
# Will be removed once Airflow uses the new UI
session["user"] = user

signer = JWTSigner(
secret_key=conf.get("api", "auth_jwt_secret"),
expiration_time_in_seconds=conf.getint("api", "auth_jwt_expiration_time"),
audience="front-apis",
)
token = signer.generate_signed_token(get_auth_manager().serialize_user(user))
token = get_auth_manager().get_jwt_token(user)

if next_url:
return redirect(self._get_redirect_url(next_url, token))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
@cache
def _return_appbuilder(app: Flask) -> AirflowAppBuilder:
"""Return an appbuilder instance for the given app."""
init_appbuilder(app)
init_appbuilder(app, enable_plugins=False)
init_plugins(app)
init_airflow_session_interface(app)
return app.appbuilder # type: ignore[attr-defined]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def get_fastapi_app(self) -> FastAPI | None:
if not flask_blueprint:
return None

flask_app = create_app()
flask_app = create_app(enable_plugins=False)
flask_app.register_blueprint(flask_blueprint)

app = FastAPI(
Expand Down
9 changes: 5 additions & 4 deletions providers/src/airflow/providers/fab/www/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
csrf = CSRFProtect()


def create_app():
def create_app(enable_plugins: bool):
"""Create a new instance of Airflow WWW app."""
flask_app = Flask(__name__)
flask_app.secret_key = conf.get("webserver", "SECRET_KEY")
Expand All @@ -66,10 +66,11 @@ def create_app():
init_api_auth(flask_app)

with flask_app.app_context():
init_appbuilder(flask_app)
init_plugins(flask_app)
init_appbuilder(flask_app, enable_plugins=enable_plugins)
if enable_plugins:
init_plugins(flask_app)
init_error_handlers(flask_app)
init_jinja_globals(flask_app)
init_jinja_globals(flask_app, enable_plugins=enable_plugins)
init_xframe_protection(flask_app)
return flask_app

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@
from flask_appbuilder.views import IndexView

from airflow import settings
from airflow.api_fastapi.app import create_auth_manager
from airflow.api_fastapi.app import create_auth_manager, get_auth_manager
from airflow.configuration import conf
from airflow.providers.fab.www.security_manager import AirflowSecurityManagerV2
from airflow.providers.fab.www.views import FabIndexView

if TYPE_CHECKING:
from flask import Flask
Expand Down Expand Up @@ -109,6 +110,7 @@ def __init__(
base_template="airflow/main.html",
static_folder="static/appbuilder",
static_url_path="/appbuilder",
enable_plugins: bool = False,
):
"""
App-builder constructor.
Expand All @@ -125,6 +127,15 @@ def __init__(
optional, your override for the global static folder
:param static_url_path:
optional, your override for the global static url path
:param enable_plugins:
optional, whether plugins are enabled for this app. AirflowAppBuilder from FAB provider can be
instantiated in two modes:
- Plugins enabled. The Flask application is responsible to execute Airflow 2 plugins.
This application is only running if there are Airflow 2 plugins defined as part of the Airflow
environment
- Plugins disabled. The Flask application is responsible to execute the FAB auth manager login
process. This application is only running if FAB auth manager is the auth manager configured
in the Airflow environment
"""
from airflow.providers_manager import ProvidersManager

Expand All @@ -139,6 +150,7 @@ def __init__(
self.static_folder = static_folder
self.static_url_path = static_url_path
self.app = app
self.enable_plugins = enable_plugins
self.update_perms = conf.getboolean("fab", "UPDATE_FAB_PERMS")
self.auth_rate_limited = conf.getboolean("fab", "AUTH_RATE_LIMITED")
self.auth_rate_limit = conf.get("fab", "AUTH_RATE_LIMIT")
Expand Down Expand Up @@ -172,8 +184,10 @@ def init_app(self, app, session):
_index_view = app.config.get("FAB_INDEX_VIEW", None)
if _index_view is not None:
self.indexview = dynamic_class_import(_index_view)
elif not self.enable_plugins:
self.indexview = FabIndexView
else:
self.indexview = self.indexview or IndexView
self.indexview = IndexView
_menu = app.config.get("FAB_MENU", None)
if _menu is not None:
self.menu = dynamic_class_import(_menu)
Expand Down Expand Up @@ -282,6 +296,7 @@ def _add_admin_views(self):
"""Register indexview, utilview (back function), babel views and Security views."""
self.indexview = self._check_and_init(self.indexview)
self.add_view_no_menu(self.indexview)
get_auth_manager().register_views()

def _add_addon_views(self):
"""Register declared addons."""
Expand Down Expand Up @@ -500,7 +515,6 @@ def add_view_no_menu(self, baseview, endpoint=None, static_folder=None):

@property
def get_url_for_index(self):
# TODO: Return the fast api application homepage
return url_for(f"{self.indexview.endpoint}.{self.indexview.default_view}")

def get_url_for_locale(self, lang):
Expand Down Expand Up @@ -560,10 +574,11 @@ def _process_inner_views(self):
view.get_init_inner_views().append(v)


def init_appbuilder(app: Flask) -> AirflowAppBuilder:
def init_appbuilder(app: Flask, enable_plugins: bool) -> AirflowAppBuilder:
"""Init `Flask App Builder <https://flask-appbuilder.readthedocs.io/en/latest/>`__."""
return AirflowAppBuilder(
app=app,
session=settings.Session,
base_template="airflow/main.html",
enable_plugins=enable_plugins,
)
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
logger = logging.getLogger(__name__)


def init_jinja_globals(app):
def init_jinja_globals(app, enable_plugins: bool):
"""Add extra globals variable to Jinja context."""
server_timezone = conf.get("core", "default_timezone")
if server_timezone == "system":
Expand Down Expand Up @@ -70,6 +70,8 @@ def prepare_jinja_globals():
"state_color_mapping": STATE_COLORS,
"airflow_version": airflow_version,
"git_version": git_version,
"show_plugin_message": enable_plugins,
"disable_nav_bar": not enable_plugins,
}

# Extra global specific to auth manager
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
{% from 'airflow/_messages.html' import show_message %}

{% block page_title -%}
Airflow - Airflow 2 plugins compatibility view
Airflow
{% endblock %}

{% block head_css %}
Expand Down Expand Up @@ -53,23 +53,21 @@
{% endblock %}

{% block messages %}
{% call show_message(category='warning', dismissible=false) %}
<p>
You have a plugin that is using a FAB view or Flask Blueprint, which was used for the Airflow 2 UI, and is now
deprecated. Please update your plugin to be compatible with the Airflow 3 UI.
</p>
{% endcall %}
{% if show_plugin_message %}
{% call show_message(category='warning', dismissible=false) %}
<p>
You have a plugin that is using a FAB view or Flask Blueprint, which was used for the Airflow 2 UI, and is now
deprecated. Please update your plugin to be compatible with the Airflow 3 UI.
</p>
{% endcall %}
{% endif %}
{% endblock %}

{% block tail_js %}
{{ super() }}
<script>
// below variables are used in main.js
// keep as var, changing to const or let breaks other code
var Airflow = {
serverTimezone: '{{ server_timezone }}',
defaultUITimezone: '{{ default_ui_timezone }}',
};
var hostName = '{{ hostname }}';
$('time[title]').tooltip();
</script>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#}

{% set menu = appbuilder.menu %}
{% set languages = appbuilder.languages %}

<div class="navbar navbar-fixed-top" role="navigation" style="background-color: {{ navbar_color }};">
<div class="container">
Expand Down Expand Up @@ -46,7 +47,13 @@
</div>
<div class="navbar-collapse collapse">
<ul class="nav navbar-nav">
{%- if disable_nav_bar is not defined or not disable_nav_bar -%}
{% include 'appbuilder/navbar_menu.html' %}
{%- endif -%}
</ul>
<ul class="nav navbar-nav navbar-right">
<li class="active">
{% include 'appbuilder/navbar_right.html' %}
</ul>
</div>
</div>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
{#
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
#}

{% macro locale_menu(languages) %}
{% set locale = session['locale'] %}
{% if not locale %}
{% set locale = 'en' %}
{% endif %}
<li class="dropdown">
<a class="dropdown-toggle" href="javascript:void(0)">
<div class="f16"><i class="flag {{languages[locale].get('flag')}}"></i><b class="caret"></b></div>
</a>
{% if languages.keys()|length > 1 %}
<ul class="dropdown-menu">
<li class="dropdown">
{% for lang in languages %}
{% if lang != locale %}
<a href="{{appbuilder.get_url_for_locale(lang)}}">
<div class="f16"><i class="flag {{languages[lang].get('flag')}}"></i> - {{languages[lang].get('name')}}
</div></a>
{% endif %}
{% endfor %}
</li>
</ul>
{% endif %}
</li>
{% endmacro %}

{# clock and timezone menu #}
<li class="dropdown" id="timezone-dropdown">
<a class="dropdown-toggle" style="display:none" href="#">
<time id="clock" class="js-tooltip"></time>
<b class="caret"></b>
</a>
<ul class="dropdown-menu" id="timezone-menu">
<li id="timezone-utc"><a data-timezone="UTC" href="#">UTC</a></li>
<li id="timezone-server" style="display: none;"><a data-timezone="{{ server_timezone }}" href="#">{{ server_timezone }}</a></li>
<li id="timezone-local"><a href="#">Local</a></li>
<li id="timezone-manual" style="display: none"><a data-timezone="" href="#"></a></li>
<li role="separator" class="divider"></li>
<li>
<form>
<label for="timezone-other">Other</label>
<input id="timezone-other" placeholder="Select Timezone name" autocomplete="off" tabindex="-1">
</form>
</li>
</ul>
</li>
24 changes: 24 additions & 0 deletions providers/src/airflow/providers/fab/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,39 @@
import traceback

from flask import (
g,
redirect,
render_template,
)
from flask_appbuilder import IndexView, expose

from airflow.api_fastapi.app import get_auth_manager
from airflow.configuration import conf
from airflow.utils.net import get_hostname
from airflow.version import version


class FabIndexView(IndexView):
"""
A simple view that inherits from FAB index view.
The only goal of this view is to redirect the user to the Airflow 3 UI index page if the user is
authenticated. It is impossible to redirect the user directly to the Airflow 3 UI index page before
redirecting them to this page because FAB itself defines the logic redirection and does not allow external
redirect.
It is impossible to redirect the user before
"""

@expose("/")
def index(self):
if g.user is not None and g.user.is_authenticated:
token = get_auth_manager().get_jwt_token(g.user)
return redirect(f"/webapp?token={token}", code=302)
else:
super().index(self)


def show_traceback(error):
"""Show Traceback for a given error."""
is_logged_in = get_auth_manager().is_logged_in()
Expand Down
17 changes: 12 additions & 5 deletions tests/auth/managers/simple/views/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,20 @@ def test_logout_redirects_to_login_and_clear_user(self, simple_app):
("test", "test", True, {"next": "next_url"}, "next_url?token=token"),
],
)
@patch("airflow.auth.managers.simple.views.auth.JWTSigner")
@patch("airflow.auth.managers.simple.views.auth.get_auth_manager")
def test_login_submit(
self, mock_jwt_signer, simple_app, username, password, is_successful, query_params, expected_redirect
self,
mock_get_auth_manager,
simple_app,
username,
password,
is_successful,
query_params,
expected_redirect,
):
signer = Mock()
signer.generate_signed_token.return_value = "token"
mock_jwt_signer.return_value = signer
auth_manager = Mock()
auth_manager.get_jwt_token.return_value = "token"
mock_get_auth_manager.return_value = auth_manager
with simple_app.test_client() as client:
response = client.post(
"/login_submit", query_string=query_params, data={"username": username, "password": password}
Expand Down
Loading

0 comments on commit 24121c4

Please sign in to comment.