This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2d9ab54eeff Make FAB auth manager login process compatible with
Airflow 3 UI (#45765)
2d9ab54eeff is described below
commit 2d9ab54eeff14b8839a4c82e1713ede9f37f02e2
Author: Vincent <[email protected]>
AuthorDate: Tue Jan 21 10:52:00 2025 -0500
Make FAB auth manager login process compatible with Airflow 3 UI (#45765)
---
airflow/api_fastapi/core_api/app.py | 2 +-
airflow/auth/managers/base_auth_manager.py | 11 ++++
airflow/auth/managers/simple/views/auth.py | 8 +--
.../fab/auth_manager/cli_commands/utils.py | 2 +-
.../providers/fab/auth_manager/fab_auth_manager.py | 2 +-
providers/src/airflow/providers/fab/www/app.py | 9 +--
.../fab/www/extensions/init_appbuilder.py | 23 ++++++--
.../fab/www/extensions/init_jinja_globals.py | 4 +-
.../providers/fab/www/templates/airflow/main.html | 20 +++----
.../fab/www/templates/appbuilder/navbar.html | 7 +++
.../fab/www/templates/appbuilder/navbar_right.html | 64 ++++++++++++++++++++
providers/src/airflow/providers/fab/www/views.py | 24 ++++++++
tests/auth/managers/simple/views/test_auth.py | 17 ++++--
tests/auth/managers/test_base_auth_manager.py | 68 +++++++++++++++++-----
14 files changed, 213 insertions(+), 48 deletions(-)
diff --git a/airflow/api_fastapi/core_api/app.py
b/airflow/api_fastapi/core_api/app.py
index 6099c5b654a..08f37812c3c 100644
--- a/airflow/api_fastapi/core_api/app.py
+++ b/airflow/api_fastapi/core_api/app.py
@@ -132,7 +132,7 @@ def init_flask_plugins(app: FastAPI) -> None:
stacklevel=2,
)
- flask_app = create_app()
+ flask_app = create_app(enable_plugins=True)
app.mount("/pluginsv2", WSGIMiddleware(flask_app))
diff --git a/airflow/auth/managers/base_auth_manager.py
b/airflow/auth/managers/base_auth_manager.py
index 6a9ef11e3d7..fe86bc8f05a 100644
--- a/airflow/auth/managers/base_auth_manager.py
+++ b/airflow/auth/managers/base_auth_manager.py
@@ -24,9 +24,11 @@ from sqlalchemy import select
from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.resource_details import DagDetails
+from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import DagModel
from airflow.typing_compat import Literal
+from airflow.utils.jwt_signer import JWTSigner
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
@@ -100,6 +102,15 @@ class BaseAuthManager(Generic[T], LoggingMixin):
def serialize_user(self, user: T) -> dict[str, Any]:
"""Create a dict from a user object."""
+ def get_jwt_token(self, user: T) -> str:
+ """
+ Return a signed JWT token for ``user``.
+
+ The token payload is the dict returned by :meth:`serialize_user`. It is signed with the
+ ``[api] auth_jwt_secret`` config value, uses ``[api] auth_jwt_expiration_time`` as its
+ expiration time in seconds, and targets the ``front-apis`` audience.
+
+ :param user: the user to generate the token for
+ :return: the signed JWT token as a string
+ """
+ # The signer is rebuilt on every call from the current ``[api]`` configuration values.
+ signer = JWTSigner(
+ secret_key=conf.get("api", "auth_jwt_secret"),
+ expiration_time_in_seconds=conf.getint("api",
"auth_jwt_expiration_time"),
+ audience="front-apis",
+ )
+ return signer.generate_signed_token(self.serialize_user(user))
+
def get_user_id(self) -> str | None:
"""Return the user ID associated to the user in session."""
user = self.get_user()
diff --git a/airflow/auth/managers/simple/views/auth.py
b/airflow/auth/managers/simple/views/auth.py
index b292fc05541..64c697ecbcc 100644
--- a/airflow/auth/managers/simple/views/auth.py
+++ b/airflow/auth/managers/simple/views/auth.py
@@ -25,7 +25,6 @@ from flask_appbuilder import expose
from airflow.api_fastapi.app import get_auth_manager
from airflow.auth.managers.simple.user import SimpleAuthManagerUser
from airflow.configuration import conf
-from airflow.utils.jwt_signer import JWTSigner
from airflow.utils.state import State
from airflow.www.app import csrf
from airflow.www.views import AirflowBaseView
@@ -92,12 +91,7 @@ class SimpleAuthManagerAuthenticationViews(AirflowBaseView):
# Will be removed once Airflow uses the new UI
session["user"] = user
- signer = JWTSigner(
- secret_key=conf.get("api", "auth_jwt_secret"),
- expiration_time_in_seconds=conf.getint("api",
"auth_jwt_expiration_time"),
- audience="front-apis",
- )
- token =
signer.generate_signed_token(get_auth_manager().serialize_user(user))
+ token = get_auth_manager().get_jwt_token(user)
if next_url:
return redirect(self._get_redirect_url(next_url, token))
diff --git
a/providers/src/airflow/providers/fab/auth_manager/cli_commands/utils.py
b/providers/src/airflow/providers/fab/auth_manager/cli_commands/utils.py
index ee7c6f8202a..badd7fd08ae 100644
--- a/providers/src/airflow/providers/fab/auth_manager/cli_commands/utils.py
+++ b/providers/src/airflow/providers/fab/auth_manager/cli_commands/utils.py
@@ -41,7 +41,7 @@ if TYPE_CHECKING:
@cache
def _return_appbuilder(app: Flask) -> AirflowAppBuilder:
"""Return an appbuilder instance for the given app."""
- init_appbuilder(app)
+ # enable_plugins=False selects AirflowAppBuilder's FAB auth manager mode (FabIndexView is used
+ # as the index view unless a FAB_INDEX_VIEW config value overrides it).
+ # NOTE(review): create_app() only calls init_plugins() when enable_plugins is True, whereas this
+ # helper still calls it unconditionally — confirm this is intended for the CLI.
+ init_appbuilder(app, enable_plugins=False)
init_plugins(app)
init_airflow_session_interface(app)
return app.appbuilder # type: ignore[attr-defined]
diff --git
a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
index 4c889a9c14e..2d58d79e41b 100644
--- a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
+++ b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
@@ -181,7 +181,7 @@ class FabAuthManager(BaseAuthManager[User]):
if not flask_blueprint:
return None
- flask_app = create_app()
+ flask_app = create_app(enable_plugins=False)
flask_app.register_blueprint(flask_blueprint)
app = FastAPI(
diff --git a/providers/src/airflow/providers/fab/www/app.py
b/providers/src/airflow/providers/fab/www/app.py
index 0414fc5e408..6890dc96abb 100644
--- a/providers/src/airflow/providers/fab/www/app.py
+++ b/providers/src/airflow/providers/fab/www/app.py
@@ -41,7 +41,7 @@ app: Flask | None = None
csrf = CSRFProtect()
-def create_app():
+def create_app(enable_plugins: bool):
"""Create a new instance of Airflow WWW app."""
flask_app = Flask(__name__)
flask_app.secret_key = conf.get("webserver", "SECRET_KEY")
@@ -66,10 +66,11 @@ def create_app():
init_api_auth(flask_app)
with flask_app.app_context():
- init_appbuilder(flask_app)
- init_plugins(flask_app)
+ init_appbuilder(flask_app, enable_plugins=enable_plugins)
+ if enable_plugins:
+ init_plugins(flask_app)
init_error_handlers(flask_app)
- init_jinja_globals(flask_app)
+ init_jinja_globals(flask_app, enable_plugins=enable_plugins)
init_xframe_protection(flask_app)
return flask_app
diff --git
a/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py
b/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py
index b3f5551aeee..555f0501a6a 100644
--- a/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py
+++ b/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py
@@ -39,9 +39,10 @@ from flask_appbuilder.menu import Menu
from flask_appbuilder.views import IndexView
from airflow import settings
-from airflow.api_fastapi.app import create_auth_manager
+from airflow.api_fastapi.app import create_auth_manager, get_auth_manager
from airflow.configuration import conf
from airflow.providers.fab.www.security_manager import AirflowSecurityManagerV2
+from airflow.providers.fab.www.views import FabIndexView
if TYPE_CHECKING:
from flask import Flask
@@ -109,6 +110,7 @@ class AirflowAppBuilder:
base_template="airflow/main.html",
static_folder="static/appbuilder",
static_url_path="/appbuilder",
+ enable_plugins: bool = False,
):
"""
App-builder constructor.
@@ -125,6 +127,15 @@ class AirflowAppBuilder:
optional, your override for the global static folder
:param static_url_path:
optional, your override for the global static url path
+ :param enable_plugins:
+ optional, whether plugins are enabled for this app. AirflowAppBuilder from the FAB provider can be
+ instantiated in two modes:
+ - Plugins enabled. The Flask application is responsible for executing Airflow 2 plugins.
+ This application only runs if Airflow 2 plugins are defined as part of the Airflow
+ environment.
+ - Plugins disabled. The Flask application is responsible for executing the FAB auth manager login
+ process. This application only runs if the FAB auth manager is the auth manager configured
+ in the Airflow environment.
"""
from airflow.providers_manager import ProvidersManager
@@ -139,6 +150,7 @@ class AirflowAppBuilder:
self.static_folder = static_folder
self.static_url_path = static_url_path
self.app = app
+ self.enable_plugins = enable_plugins
self.update_perms = conf.getboolean("fab", "UPDATE_FAB_PERMS")
self.auth_rate_limited = conf.getboolean("fab", "AUTH_RATE_LIMITED")
self.auth_rate_limit = conf.get("fab", "AUTH_RATE_LIMIT")
@@ -172,8 +184,10 @@ class AirflowAppBuilder:
_index_view = app.config.get("FAB_INDEX_VIEW", None)
if _index_view is not None:
self.indexview = dynamic_class_import(_index_view)
+ elif not self.enable_plugins:
+ self.indexview = FabIndexView
else:
- self.indexview = self.indexview or IndexView
+ self.indexview = IndexView
_menu = app.config.get("FAB_MENU", None)
if _menu is not None:
self.menu = dynamic_class_import(_menu)
@@ -282,6 +296,7 @@ class AirflowAppBuilder:
"""Register indexview, utilview (back function), babel views and
Security views."""
self.indexview = self._check_and_init(self.indexview)
self.add_view_no_menu(self.indexview)
+ get_auth_manager().register_views()
def _add_addon_views(self):
"""Register declared addons."""
@@ -500,7 +515,6 @@ class AirflowAppBuilder:
@property
def get_url_for_index(self):
- # TODO: Return the fast api application homepage
return
url_for(f"{self.indexview.endpoint}.{self.indexview.default_view}")
def get_url_for_locale(self, lang):
@@ -560,10 +574,11 @@ class AirflowAppBuilder:
view.get_init_inner_views().append(v)
-def init_appbuilder(app: Flask) -> AirflowAppBuilder:
+# NOTE(review): ``enable_plugins`` is a required positional parameter with no default, so any
+# remaining ``init_appbuilder(app)`` call fails with a TypeError — every caller must pass it.
+def init_appbuilder(app: Flask, enable_plugins: bool) -> AirflowAppBuilder:
"""Init `Flask App Builder
<https://flask-appbuilder.readthedocs.io/en/latest/>`__."""
return AirflowAppBuilder(
app=app,
session=settings.Session,
base_template="airflow/main.html",
+ # True: the app serves Airflow 2 plugins (FAB ``IndexView`` as index view).
+ # False: the app serves the FAB auth manager login process (``FabIndexView`` as index view).
+ # A ``FAB_INDEX_VIEW`` config value overrides the index view in both modes.
+ enable_plugins=enable_plugins,
)
diff --git
a/providers/src/airflow/providers/fab/www/extensions/init_jinja_globals.py
b/providers/src/airflow/providers/fab/www/extensions/init_jinja_globals.py
index f7abe34154d..177ed158b95 100644
--- a/providers/src/airflow/providers/fab/www/extensions/init_jinja_globals.py
+++ b/providers/src/airflow/providers/fab/www/extensions/init_jinja_globals.py
@@ -30,7 +30,7 @@ from airflow.utils.platform import get_airflow_git_version
logger = logging.getLogger(__name__)
-def init_jinja_globals(app):
+def init_jinja_globals(app, enable_plugins: bool):
"""Add extra globals variable to Jinja context."""
server_timezone = conf.get("core", "default_timezone")
if server_timezone == "system":
@@ -70,6 +70,8 @@ def init_jinja_globals(app):
"state_color_mapping": STATE_COLORS,
"airflow_version": airflow_version,
"git_version": git_version,
+ "show_plugin_message": enable_plugins,
+ "disable_nav_bar": not enable_plugins,
}
# Extra global specific to auth manager
diff --git
a/providers/src/airflow/providers/fab/www/templates/airflow/main.html
b/providers/src/airflow/providers/fab/www/templates/airflow/main.html
index e6c00bd0666..25ce3c0439a 100644
--- a/providers/src/airflow/providers/fab/www/templates/airflow/main.html
+++ b/providers/src/airflow/providers/fab/www/templates/airflow/main.html
@@ -21,7 +21,7 @@
{% from 'airflow/_messages.html' import show_message %}
{% block page_title -%}
- Airflow - Airflow 2 plugins compatibility view
+ Airflow
{% endblock %}
{% block head_css %}
@@ -53,12 +53,14 @@
{% endblock %}
{% block messages %}
- {% call show_message(category='warning', dismissible=false) %}
- <p>
- You have a plugin that is using a FAB view or Flask Blueprint, which was
used for the Airflow 2 UI, and is now
- deprecated. Please update your plugin to be compatible with the Airflow
3 UI.
- </p>
- {% endcall %}
+ {% if show_plugin_message %}
+ {% call show_message(category='warning', dismissible=false) %}
+ <p>
+ You have a plugin that is using a FAB view or Flask Blueprint, which
was used for the Airflow 2 UI, and is now
+ deprecated. Please update your plugin to be compatible with the
Airflow 3 UI.
+ </p>
+ {% endcall %}
+ {% endif %}
{% endblock %}
{% block tail_js %}
@@ -66,10 +68,6 @@
<script>
// below variables are used in main.js
// keep as var, changing to const or let breaks other code
- var Airflow = {
- serverTimezone: '{{ server_timezone }}',
- defaultUITimezone: '{{ default_ui_timezone }}',
- };
var hostName = '{{ hostname }}';
$('time[title]').tooltip();
</script>
diff --git
a/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar.html
b/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar.html
index dba354fb131..76cbcd8e2dd 100644
--- a/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar.html
+++ b/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar.html
@@ -18,6 +18,7 @@
#}
{% set menu = appbuilder.menu %}
+{% set languages = appbuilder.languages %}
<div class="navbar navbar-fixed-top" role="navigation"
style="background-color: {{ navbar_color }};">
<div class="container">
@@ -46,7 +47,13 @@
</div>
<div class="navbar-collapse collapse">
<ul class="nav navbar-nav">
+ {%- if disable_nav_bar is not defined or not disable_nav_bar -%}
{% include 'appbuilder/navbar_menu.html' %}
+ {%- endif -%}
+ </ul>
+ <ul class="nav navbar-nav navbar-right">
+ <li class="active">
+ {% include 'appbuilder/navbar_right.html' %}
</ul>
</div>
</div>
diff --git
a/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar_right.html
b/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar_right.html
new file mode 100644
index 00000000000..54254f6d426
--- /dev/null
+++
b/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar_right.html
@@ -0,0 +1,64 @@
+{#
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+#}
+
+{#
+ Language selector dropdown. The current locale is read from session['locale'] and falls back
+ to 'en'. Every other entry of ``languages`` links to appbuilder.get_url_for_locale(lang).
+ The sub-menu is only rendered when more than one language is configured.
+ NOTE(review): this macro is not called anywhere in this template, so the language selector is
+ never rendered unless another template imports and calls it. Confirm this is intended.
+#}
+{% macro locale_menu(languages) %}
+ {% set locale = session['locale'] %}
+ {% if not locale %}
+ {% set locale = 'en' %}
+ {% endif %}
+ <li class="dropdown">
+ <a class="dropdown-toggle" href="javascript:void(0)">
+ <div class="f16"><i class="flag
{{languages[locale].get('flag')}}"></i><b class="caret"></b></div>
+ </a>
+ {% if languages.keys()|length > 1 %}
+ <ul class="dropdown-menu">
+ <li class="dropdown">
+ {% for lang in languages %}
+ {% if lang != locale %}
+ <a href="{{appbuilder.get_url_for_locale(lang)}}">
+ <div class="f16"><i class="flag
{{languages[lang].get('flag')}}"></i> - {{languages[lang].get('name')}}
+ </div></a>
+ {% endif %}
+ {% endfor %}
+ </li>
+ </ul>
+ {% endif %}
+ </li>
+{% endmacro %}
+
+{#
+ Clock and timezone selector. The toggle link and the server/manual timezone entries are
+ rendered hidden (display: none).
+ NOTE(review): presumably JavaScript reveals and wires them up. That script likely reads the
+ ``Airflow.serverTimezone`` / ``defaultUITimezone`` globals, which this same commit removes from
+ main.html. Confirm the timezone menu still works.
+#}
+<li class="dropdown" id="timezone-dropdown">
+ <a class="dropdown-toggle" style="display:none" href="#">
+ <time id="clock" class="js-tooltip"></time>
+ <b class="caret"></b>
+ </a>
+ <ul class="dropdown-menu" id="timezone-menu">
+ <li id="timezone-utc"><a data-timezone="UTC" href="#">UTC</a></li>
+ <li id="timezone-server" style="display: none;"><a data-timezone="{{
server_timezone }}" href="#">{{ server_timezone }}</a></li>
+ <li id="timezone-local"><a href="#">Local</a></li>
+ <li id="timezone-manual" style="display: none"><a data-timezone=""
href="#"></a></li>
+ <li role="separator" class="divider"></li>
+ <li>
+ <form>
+ <label for="timezone-other">Other</label>
+ <input id="timezone-other" placeholder="Select Timezone name"
autocomplete="off" tabindex="-1">
+ </form>
+ </li>
+ </ul>
+</li>
diff --git a/providers/src/airflow/providers/fab/www/views.py
b/providers/src/airflow/providers/fab/www/views.py
index 43ac276897e..925a777c26d 100644
--- a/providers/src/airflow/providers/fab/www/views.py
+++ b/providers/src/airflow/providers/fab/www/views.py
@@ -21,8 +21,11 @@ import sys
import traceback
from flask import (
+ g,
+ redirect,
render_template,
)
+from flask_appbuilder import IndexView, expose
from airflow.api_fastapi.app import get_auth_manager
from airflow.configuration import conf
@@ -30,6 +33,27 @@ from airflow.utils.net import get_hostname
from airflow.version import version
+class FabIndexView(IndexView):
+    """
+    A simple view that inherits from FAB index view.
+
+    The only goal of this view is to redirect the user to the Airflow 3 UI index page if the user is
+    authenticated. It is impossible to redirect the user directly to the Airflow 3 UI index page before
+    redirecting them to this page because FAB itself defines the redirection logic and does not allow
+    external redirects.
+    """
+
+    @expose("/")
+    def index(self):
+        """
+        Redirect authenticated users to the Airflow 3 UI, otherwise render the FAB index page.
+
+        For an authenticated user, a JWT token is generated by the auth manager and passed to the
+        Airflow 3 UI (``/webapp``) as the ``token`` query parameter.
+
+        :return: a 302 redirect response for authenticated users, otherwise the response rendered
+            by the parent FAB ``IndexView.index``.
+        """
+        if g.user is not None and g.user.is_authenticated:
+            token = get_auth_manager().get_jwt_token(g.user)
+            return redirect(f"/webapp?token={token}", code=302)
+        # ``super().index`` is already bound to ``self``: passing ``self`` again raises a TypeError.
+        # The parent's rendered page must also be returned, otherwise Flask receives ``None``.
+        return super().index()
+
+
def show_traceback(error):
"""Show Traceback for a given error."""
is_logged_in = get_auth_manager().is_logged_in()
diff --git a/tests/auth/managers/simple/views/test_auth.py
b/tests/auth/managers/simple/views/test_auth.py
index 0eccf0dc9ec..633dcbd0f21 100644
--- a/tests/auth/managers/simple/views/test_auth.py
+++ b/tests/auth/managers/simple/views/test_auth.py
@@ -64,13 +64,20 @@ class TestSimpleAuthManagerAuthenticationViews:
("test", "test", True, {"next": "next_url"},
"next_url?token=token"),
],
)
- @patch("airflow.auth.managers.simple.views.auth.JWTSigner")
+ @patch("airflow.auth.managers.simple.views.auth.get_auth_manager")
def test_login_submit(
- self, mock_jwt_signer, simple_app, username, password, is_successful,
query_params, expected_redirect
+ self,
+ mock_get_auth_manager,
+ simple_app,
+ username,
+ password,
+ is_successful,
+ query_params,
+ expected_redirect,
):
- signer = Mock()
- signer.generate_signed_token.return_value = "token"
- mock_jwt_signer.return_value = signer
+ auth_manager = Mock()
+ auth_manager.get_jwt_token.return_value = "token"
+ mock_get_auth_manager.return_value = auth_manager
with simple_app.test_client() as client:
response = client.post(
"/login_submit", query_string=query_params, data={"username":
username, "password": password}
diff --git a/tests/auth/managers/test_base_auth_manager.py
b/tests/auth/managers/test_base_auth_manager.py
index 4406ae9d436..370e401da06 100644
--- a/tests/auth/managers/test_base_auth_manager.py
+++ b/tests/auth/managers/test_base_auth_manager.py
@@ -44,16 +44,27 @@ if TYPE_CHECKING:
from airflow.www.extensions.init_appbuilder import AirflowAppBuilder
-class EmptyAuthManager(BaseAuthManager[BaseUser]):
+class BaseAuthManagerUserTest(BaseUser):
+ """Minimal concrete ``BaseUser`` for tests; both the ID and the name are the given ``name``."""
+
+ def __init__(self, *, name: str) -> None:
+ self.name = name
+
+ def get_id(self) -> str:
+ return self.name
+
+ def get_name(self) -> str:
+ return self.name
+
+
+class EmptyAuthManager(BaseAuthManager[BaseAuthManagerUserTest]):
appbuilder: AirflowAppBuilder | None = None
- def get_user(self) -> BaseUser:
+ def get_user(self) -> BaseAuthManagerUserTest:
raise NotImplementedError()
- def deserialize_user(self, token: dict[str, Any]) -> BaseUser:
+ def deserialize_user(self, token: dict[str, Any]) ->
BaseAuthManagerUserTest:
raise NotImplementedError()
- def serialize_user(self, user: BaseUser) -> dict[str, Any]:
+ def serialize_user(self, user: BaseAuthManagerUserTest) -> dict[str, Any]:
raise NotImplementedError()
def is_authorized_configuration(
@@ -61,7 +72,7 @@ class EmptyAuthManager(BaseAuthManager[BaseUser]):
*,
method: ResourceMethod,
details: ConfigurationDetails | None = None,
- user: BaseUser | None = None,
+ user: BaseAuthManagerUserTest | None = None,
) -> bool:
raise NotImplementedError()
@@ -70,7 +81,7 @@ class EmptyAuthManager(BaseAuthManager[BaseUser]):
*,
method: ResourceMethod,
details: ConnectionDetails | None = None,
- user: BaseUser | None = None,
+ user: BaseAuthManagerUserTest | None = None,
) -> bool:
raise NotImplementedError()
@@ -80,30 +91,44 @@ class EmptyAuthManager(BaseAuthManager[BaseUser]):
method: ResourceMethod,
access_entity: DagAccessEntity | None = None,
details: DagDetails | None = None,
- user: BaseUser | None = None,
+ user: BaseAuthManagerUserTest | None = None,
) -> bool:
raise NotImplementedError()
def is_authorized_asset(
- self, *, method: ResourceMethod, details: AssetDetails | None = None,
user: BaseUser | None = None
+ self,
+ *,
+ method: ResourceMethod,
+ details: AssetDetails | None = None,
+ user: BaseAuthManagerUserTest | None = None,
) -> bool:
raise NotImplementedError()
def is_authorized_pool(
- self, *, method: ResourceMethod, details: PoolDetails | None = None,
user: BaseUser | None = None
+ self,
+ *,
+ method: ResourceMethod,
+ details: PoolDetails | None = None,
+ user: BaseAuthManagerUserTest | None = None,
) -> bool:
raise NotImplementedError()
def is_authorized_variable(
- self, *, method: ResourceMethod, details: VariableDetails | None =
None, user: BaseUser | None = None
+ self,
+ *,
+ method: ResourceMethod,
+ details: VariableDetails | None = None,
+ user: BaseAuthManagerUserTest | None = None,
) -> bool:
raise NotImplementedError()
- def is_authorized_view(self, *, access_view: AccessView, user: BaseUser |
None = None) -> bool:
+ def is_authorized_view(
+ self, *, access_view: AccessView, user: BaseAuthManagerUserTest | None
= None
+ ) -> bool:
raise NotImplementedError()
def is_authorized_custom_view(
- self, *, method: ResourceMethod | str, resource_name: str, user:
BaseUser | None = None
+ self, *, method: ResourceMethod | str, resource_name: str, user:
BaseAuthManagerUserTest | None = None
):
raise NotImplementedError()
@@ -165,6 +190,23 @@ class TestBaseAuthManager:
def test_get_url_user_profile_return_none(self, auth_manager):
assert auth_manager.get_url_user_profile() is None
+ @patch("airflow.auth.managers.base_auth_manager.JWTSigner")
+ @patch.object(EmptyAuthManager, "serialize_user")
+ def test_get_jwt_token(self, mock_serialize_user, mock_jwt_signer,
auth_manager):
+ """``get_jwt_token`` signs the serialized user and returns the signer's token."""
+ # Patch decorators apply bottom-up, so the ``serialize_user`` mock is the first mock argument.
+ # NOTE(review): the JWTSigner constructor arguments (secret, expiration, audience) are not
+ # asserted. Consider adding ``mock_jwt_signer.assert_called_once_with(...)``.
+ token = "token"
+ serialized_user = "serialized_user"
+ signer = Mock()
+ signer.generate_signed_token.return_value = token
+ mock_jwt_signer.return_value = signer
+ mock_serialize_user.return_value = serialized_user
+ user = BaseAuthManagerUserTest(name="test")
+
+ result = auth_manager.get_jwt_token(user)
+
+ mock_serialize_user.assert_called_once_with(user)
+ signer.generate_signed_token.assert_called_once_with(serialized_user)
+ assert result == token
+
@pytest.mark.parametrize(
"return_values, expected",
[
@@ -279,7 +321,7 @@ class TestBaseAuthManager:
method: ResourceMethod,
access_entity: DagAccessEntity | None = None,
details: DagDetails | None = None,
- user: BaseUser | None = None,
+ user: BaseAuthManagerUserTest | None = None,
):
if not details:
return access_all