This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2d9ab54eeff Make FAB auth manager login process compatible with Airflow 3 UI (#45765)
2d9ab54eeff is described below

commit 2d9ab54eeff14b8839a4c82e1713ede9f37f02e2
Author: Vincent <[email protected]>
AuthorDate: Tue Jan 21 10:52:00 2025 -0500

    Make FAB auth manager login process compatible with Airflow 3 UI (#45765)
---
 airflow/api_fastapi/core_api/app.py                |  2 +-
 airflow/auth/managers/base_auth_manager.py         | 11 ++++
 airflow/auth/managers/simple/views/auth.py         |  8 +--
 .../fab/auth_manager/cli_commands/utils.py         |  2 +-
 .../providers/fab/auth_manager/fab_auth_manager.py |  2 +-
 providers/src/airflow/providers/fab/www/app.py     |  9 +--
 .../fab/www/extensions/init_appbuilder.py          | 23 ++++++--
 .../fab/www/extensions/init_jinja_globals.py       |  4 +-
 .../providers/fab/www/templates/airflow/main.html  | 20 +++----
 .../fab/www/templates/appbuilder/navbar.html       |  7 +++
 .../fab/www/templates/appbuilder/navbar_right.html | 64 ++++++++++++++++++++
 providers/src/airflow/providers/fab/www/views.py   | 24 ++++++++
 tests/auth/managers/simple/views/test_auth.py      | 17 ++++--
 tests/auth/managers/test_base_auth_manager.py      | 68 +++++++++++++++++-----
 14 files changed, 213 insertions(+), 48 deletions(-)

diff --git a/airflow/api_fastapi/core_api/app.py 
b/airflow/api_fastapi/core_api/app.py
index 6099c5b654a..08f37812c3c 100644
--- a/airflow/api_fastapi/core_api/app.py
+++ b/airflow/api_fastapi/core_api/app.py
@@ -132,7 +132,7 @@ def init_flask_plugins(app: FastAPI) -> None:
         stacklevel=2,
     )
 
-    flask_app = create_app()
+    flask_app = create_app(enable_plugins=True)
     app.mount("/pluginsv2", WSGIMiddleware(flask_app))
 
 
diff --git a/airflow/auth/managers/base_auth_manager.py 
b/airflow/auth/managers/base_auth_manager.py
index 6a9ef11e3d7..fe86bc8f05a 100644
--- a/airflow/auth/managers/base_auth_manager.py
+++ b/airflow/auth/managers/base_auth_manager.py
@@ -24,9 +24,11 @@ from sqlalchemy import select
 
 from airflow.auth.managers.models.base_user import BaseUser
 from airflow.auth.managers.models.resource_details import DagDetails
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.models import DagModel
 from airflow.typing_compat import Literal
+from airflow.utils.jwt_signer import JWTSigner
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
 
@@ -100,6 +102,15 @@ class BaseAuthManager(Generic[T], LoggingMixin):
     def serialize_user(self, user: T) -> dict[str, Any]:
         """Create a dict from a user object."""
 
+    def get_jwt_token(self, user: T) -> str:
+        """Return a JWT signed with the ``[api]`` settings, whose payload is ``serialize_user(user)``."""
+        signer = JWTSigner(
+            secret_key=conf.get("api", "auth_jwt_secret"),
+            expiration_time_in_seconds=conf.getint("api", 
"auth_jwt_expiration_time"),
+            audience="front-apis",
+        )
+        return signer.generate_signed_token(self.serialize_user(user))
+
     def get_user_id(self) -> str | None:
         """Return the user ID associated to the user in session."""
         user = self.get_user()
diff --git a/airflow/auth/managers/simple/views/auth.py 
b/airflow/auth/managers/simple/views/auth.py
index b292fc05541..64c697ecbcc 100644
--- a/airflow/auth/managers/simple/views/auth.py
+++ b/airflow/auth/managers/simple/views/auth.py
@@ -25,7 +25,6 @@ from flask_appbuilder import expose
 from airflow.api_fastapi.app import get_auth_manager
 from airflow.auth.managers.simple.user import SimpleAuthManagerUser
 from airflow.configuration import conf
-from airflow.utils.jwt_signer import JWTSigner
 from airflow.utils.state import State
 from airflow.www.app import csrf
 from airflow.www.views import AirflowBaseView
@@ -92,12 +91,7 @@ class SimpleAuthManagerAuthenticationViews(AirflowBaseView):
         # Will be removed once Airflow uses the new UI
         session["user"] = user
 
-        signer = JWTSigner(
-            secret_key=conf.get("api", "auth_jwt_secret"),
-            expiration_time_in_seconds=conf.getint("api", 
"auth_jwt_expiration_time"),
-            audience="front-apis",
-        )
-        token = 
signer.generate_signed_token(get_auth_manager().serialize_user(user))
+        token = get_auth_manager().get_jwt_token(user)
 
         if next_url:
             return redirect(self._get_redirect_url(next_url, token))
diff --git 
a/providers/src/airflow/providers/fab/auth_manager/cli_commands/utils.py 
b/providers/src/airflow/providers/fab/auth_manager/cli_commands/utils.py
index ee7c6f8202a..badd7fd08ae 100644
--- a/providers/src/airflow/providers/fab/auth_manager/cli_commands/utils.py
+++ b/providers/src/airflow/providers/fab/auth_manager/cli_commands/utils.py
@@ -41,7 +41,7 @@ if TYPE_CHECKING:
 @cache
 def _return_appbuilder(app: Flask) -> AirflowAppBuilder:
     """Return an appbuilder instance for the given app."""
-    init_appbuilder(app)
+    init_appbuilder(app, enable_plugins=False)
     init_plugins(app)
     init_airflow_session_interface(app)
     return app.appbuilder  # type: ignore[attr-defined]
diff --git 
a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py 
b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
index 4c889a9c14e..2d58d79e41b 100644
--- a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
+++ b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
@@ -181,7 +181,7 @@ class FabAuthManager(BaseAuthManager[User]):
         if not flask_blueprint:
             return None
 
-        flask_app = create_app()
+        flask_app = create_app(enable_plugins=False)
         flask_app.register_blueprint(flask_blueprint)
 
         app = FastAPI(
diff --git a/providers/src/airflow/providers/fab/www/app.py 
b/providers/src/airflow/providers/fab/www/app.py
index 0414fc5e408..6890dc96abb 100644
--- a/providers/src/airflow/providers/fab/www/app.py
+++ b/providers/src/airflow/providers/fab/www/app.py
@@ -41,7 +41,7 @@ app: Flask | None = None
 csrf = CSRFProtect()
 
 
-def create_app():
+def create_app(enable_plugins: bool):
     """Create a new instance of Airflow WWW app."""
     flask_app = Flask(__name__)
     flask_app.secret_key = conf.get("webserver", "SECRET_KEY")
@@ -66,10 +66,11 @@ def create_app():
     init_api_auth(flask_app)
 
     with flask_app.app_context():
-        init_appbuilder(flask_app)
-        init_plugins(flask_app)
+        init_appbuilder(flask_app, enable_plugins=enable_plugins)
+        if enable_plugins:
+            init_plugins(flask_app)
         init_error_handlers(flask_app)
-        init_jinja_globals(flask_app)
+        init_jinja_globals(flask_app, enable_plugins=enable_plugins)
         init_xframe_protection(flask_app)
     return flask_app
 
diff --git 
a/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py 
b/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py
index b3f5551aeee..555f0501a6a 100644
--- a/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py
+++ b/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py
@@ -39,9 +39,10 @@ from flask_appbuilder.menu import Menu
 from flask_appbuilder.views import IndexView
 
 from airflow import settings
-from airflow.api_fastapi.app import create_auth_manager
+from airflow.api_fastapi.app import create_auth_manager, get_auth_manager
 from airflow.configuration import conf
 from airflow.providers.fab.www.security_manager import AirflowSecurityManagerV2
+from airflow.providers.fab.www.views import FabIndexView
 
 if TYPE_CHECKING:
     from flask import Flask
@@ -109,6 +110,7 @@ class AirflowAppBuilder:
         base_template="airflow/main.html",
         static_folder="static/appbuilder",
         static_url_path="/appbuilder",
+        enable_plugins: bool = False,
     ):
         """
         App-builder constructor.
@@ -125,6 +127,15 @@ class AirflowAppBuilder:
             optional, your override for the global static folder
         :param static_url_path:
             optional, your override for the global static url path
+        :param enable_plugins:
+            optional, whether plugins are enabled for this app. 
AirflowAppBuilder from FAB provider can be
+            instantiated in two modes:
+             - Plugins enabled. The Flask application is responsible to 
execute Airflow 2 plugins.
+               This application is only running if there are Airflow 2 plugins 
defined as part of the Airflow
+               environment
+             - Plugins disabled. The Flask application is responsible to 
execute the FAB auth manager login
+               process. This application is only running if FAB auth manager 
is the auth manager configured
+               in the Airflow environment
         """
         from airflow.providers_manager import ProvidersManager
 
@@ -139,6 +150,7 @@ class AirflowAppBuilder:
         self.static_folder = static_folder
         self.static_url_path = static_url_path
         self.app = app
+        self.enable_plugins = enable_plugins
         self.update_perms = conf.getboolean("fab", "UPDATE_FAB_PERMS")
         self.auth_rate_limited = conf.getboolean("fab", "AUTH_RATE_LIMITED")
         self.auth_rate_limit = conf.get("fab", "AUTH_RATE_LIMIT")
@@ -172,8 +184,10 @@ class AirflowAppBuilder:
         _index_view = app.config.get("FAB_INDEX_VIEW", None)
         if _index_view is not None:
             self.indexview = dynamic_class_import(_index_view)
+        elif not self.enable_plugins:
+            self.indexview = FabIndexView
         else:
-            self.indexview = self.indexview or IndexView
+            self.indexview = IndexView
         _menu = app.config.get("FAB_MENU", None)
         if _menu is not None:
             self.menu = dynamic_class_import(_menu)
@@ -282,6 +296,7 @@ class AirflowAppBuilder:
         """Register indexview, utilview (back function), babel views and 
Security views."""
         self.indexview = self._check_and_init(self.indexview)
         self.add_view_no_menu(self.indexview)
+        get_auth_manager().register_views()
 
     def _add_addon_views(self):
         """Register declared addons."""
@@ -500,7 +515,6 @@ class AirflowAppBuilder:
 
     @property
     def get_url_for_index(self):
-        # TODO: Return the fast api application homepage
         return 
url_for(f"{self.indexview.endpoint}.{self.indexview.default_view}")
 
     def get_url_for_locale(self, lang):
@@ -560,10 +574,11 @@ class AirflowAppBuilder:
                         view.get_init_inner_views().append(v)
 
 
-def init_appbuilder(app: Flask) -> AirflowAppBuilder:
+def init_appbuilder(app: Flask, enable_plugins: bool) -> AirflowAppBuilder:
     """Init `Flask App Builder 
<https://flask-appbuilder.readthedocs.io/en/latest/>`__."""
     return AirflowAppBuilder(
         app=app,
         session=settings.Session,
         base_template="airflow/main.html",
+        enable_plugins=enable_plugins,
     )
diff --git 
a/providers/src/airflow/providers/fab/www/extensions/init_jinja_globals.py 
b/providers/src/airflow/providers/fab/www/extensions/init_jinja_globals.py
index f7abe34154d..177ed158b95 100644
--- a/providers/src/airflow/providers/fab/www/extensions/init_jinja_globals.py
+++ b/providers/src/airflow/providers/fab/www/extensions/init_jinja_globals.py
@@ -30,7 +30,7 @@ from airflow.utils.platform import get_airflow_git_version
 logger = logging.getLogger(__name__)
 
 
-def init_jinja_globals(app):
+def init_jinja_globals(app, enable_plugins: bool):
     """Add extra globals variable to Jinja context."""
     server_timezone = conf.get("core", "default_timezone")
     if server_timezone == "system":
@@ -70,6 +70,8 @@ def init_jinja_globals(app):
             "state_color_mapping": STATE_COLORS,
             "airflow_version": airflow_version,
             "git_version": git_version,
+            "show_plugin_message": enable_plugins,
+            "disable_nav_bar": not enable_plugins,
         }
 
         # Extra global specific to auth manager
diff --git 
a/providers/src/airflow/providers/fab/www/templates/airflow/main.html 
b/providers/src/airflow/providers/fab/www/templates/airflow/main.html
index e6c00bd0666..25ce3c0439a 100644
--- a/providers/src/airflow/providers/fab/www/templates/airflow/main.html
+++ b/providers/src/airflow/providers/fab/www/templates/airflow/main.html
@@ -21,7 +21,7 @@
 {% from 'airflow/_messages.html' import show_message %}
 
 {% block page_title -%}
-  Airflow - Airflow 2 plugins compatibility view
+  Airflow
 {% endblock %}
 
 {% block head_css %}
@@ -53,12 +53,14 @@
 {% endblock %}
 
 {% block messages %}
-  {% call show_message(category='warning', dismissible=false) %}
-    <p>
-      You have a plugin that is using a FAB view or Flask Blueprint, which was 
used for the Airflow 2 UI, and is now
-      deprecated. Please update your plugin to be compatible with the Airflow 
3 UI.
-    </p>
-  {% endcall %}
+  {% if show_plugin_message %}
+    {% call show_message(category='warning', dismissible=false) %}
+      <p>
+        You have a plugin that is using a FAB view or Flask Blueprint, which 
was used for the Airflow 2 UI, and is now
+        deprecated. Please update your plugin to be compatible with the 
Airflow 3 UI.
+      </p>
+    {% endcall %}
+  {% endif %}
 {% endblock %}
 
 {% block tail_js %}
@@ -66,10 +68,6 @@
   <script>
     // below variables are used in main.js
     // keep as var, changing to const or let breaks other code
-    var Airflow = {
-      serverTimezone: '{{ server_timezone }}',
-      defaultUITimezone: '{{ default_ui_timezone }}',
-    };
     var hostName = '{{ hostname }}';
     $('time[title]').tooltip();
   </script>
diff --git 
a/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar.html 
b/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar.html
index dba354fb131..76cbcd8e2dd 100644
--- a/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar.html
+++ b/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar.html
@@ -18,6 +18,7 @@
 #}
 
 {% set menu = appbuilder.menu %}
+{% set languages = appbuilder.languages %}
 
 <div class="navbar navbar-fixed-top" role="navigation" 
style="background-color: {{ navbar_color }};">
   <div class="container">
@@ -46,7 +47,13 @@
     </div>
     <div class="navbar-collapse collapse">
       <ul class="nav navbar-nav">
+        {%- if disable_nav_bar is not defined or not disable_nav_bar -%}
         {% include 'appbuilder/navbar_menu.html' %}
+        {%- endif -%}
+      </ul>
+      <ul class="nav navbar-nav navbar-right">
+        {# navbar_right.html renders its own <li> entries, so no wrapping <li> is opened here #}
+        {% include 'appbuilder/navbar_right.html' %}
       </ul>
     </div>
   </div>
diff --git 
a/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar_right.html
 
b/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar_right.html
new file mode 100644
index 00000000000..54254f6d426
--- /dev/null
+++ 
b/providers/src/airflow/providers/fab/www/templates/appbuilder/navbar_right.html
@@ -0,0 +1,64 @@
+{#
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements.  See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership.  The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License.  You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied.  See the License for the
+ specific language governing permissions and limitations
+ under the License.
+#}
+
+{% macro locale_menu(languages) %}
+  {% set locale = session['locale'] %}
+  {% if not locale %}
+      {% set locale = 'en' %}
+  {% endif %}
+  <li class="dropdown">
+    <a class="dropdown-toggle" href="javascript:void(0)">
+       <div class="f16"><i class="flag 
{{languages[locale].get('flag')}}"></i><b class="caret"></b></div>
+    </a>
+    {% if languages.keys()|length > 1 %}
+      <ul class="dropdown-menu">
+        <li class="dropdown">
+          {% for lang in languages %}
+            {% if lang != locale %}
+              <a href="{{appbuilder.get_url_for_locale(lang)}}">
+                <div class="f16"><i class="flag 
{{languages[lang].get('flag')}}"></i> - {{languages[lang].get('name')}}
+              </div></a>
+            {% endif %}
+          {% endfor %}
+        </li>
+      </ul>
+    {% endif %}
+  </li>
+{% endmacro %}
+
+{# clock and timezone menu #}
+<li class="dropdown" id="timezone-dropdown">
+  <a class="dropdown-toggle" style="display:none" href="#">
+    <time id="clock" class="js-tooltip"></time>
+    <b class="caret"></b>
+  </a>
+  <ul class="dropdown-menu" id="timezone-menu">
+    <li id="timezone-utc"><a data-timezone="UTC" href="#">UTC</a></li>
+    <li id="timezone-server" style="display: none;"><a data-timezone="{{ 
server_timezone }}" href="#">{{ server_timezone }}</a></li>
+    <li id="timezone-local"><a href="#">Local</a></li>
+    <li id="timezone-manual" style="display: none"><a data-timezone="" 
href="#"></a></li>
+    <li role="separator" class="divider"></li>
+    <li>
+      <form>
+        <label for="timezone-other">Other</label>
+        <input id="timezone-other" placeholder="Select Timezone name" 
autocomplete="off" tabindex="-1">
+      </form>
+    </li>
+  </ul>
+</li>
diff --git a/providers/src/airflow/providers/fab/www/views.py 
b/providers/src/airflow/providers/fab/www/views.py
index 43ac276897e..925a777c26d 100644
--- a/providers/src/airflow/providers/fab/www/views.py
+++ b/providers/src/airflow/providers/fab/www/views.py
@@ -21,8 +21,11 @@ import sys
 import traceback
 
 from flask import (
+    g,
+    redirect,
     render_template,
 )
+from flask_appbuilder import IndexView, expose
 
 from airflow.api_fastapi.app import get_auth_manager
 from airflow.configuration import conf
@@ -30,6 +33,27 @@ from airflow.utils.net import get_hostname
 from airflow.version import version
 
 
+class FabIndexView(IndexView):
+    """
+    A simple view that inherits from FAB index view.
+
+    The only goal of this view is to redirect the user to the Airflow 3 UI index page if the user is
+    authenticated. It is impossible to redirect the user directly to the Airflow 3 UI index page before
+    redirecting them to this page because FAB itself defines the redirection logic and does not allow external
+    redirects.
+
+    Unauthenticated users get the regular FAB index page rendered by the parent view.
+    """
+
+    @expose("/")
+    def index(self):
+        """Redirect authenticated users to the Airflow 3 UI with a JWT; otherwise render the FAB index."""
+        if g.user is not None and g.user.is_authenticated:
+            token = get_auth_manager().get_jwt_token(g.user)
+            return redirect(f"/webapp?token={token}", code=302)
+        return super().index()
+
+
 def show_traceback(error):
     """Show Traceback for a given error."""
     is_logged_in = get_auth_manager().is_logged_in()
diff --git a/tests/auth/managers/simple/views/test_auth.py 
b/tests/auth/managers/simple/views/test_auth.py
index 0eccf0dc9ec..633dcbd0f21 100644
--- a/tests/auth/managers/simple/views/test_auth.py
+++ b/tests/auth/managers/simple/views/test_auth.py
@@ -64,13 +64,20 @@ class TestSimpleAuthManagerAuthenticationViews:
             ("test", "test", True, {"next": "next_url"}, 
"next_url?token=token"),
         ],
     )
-    @patch("airflow.auth.managers.simple.views.auth.JWTSigner")
+    @patch("airflow.auth.managers.simple.views.auth.get_auth_manager")
     def test_login_submit(
-        self, mock_jwt_signer, simple_app, username, password, is_successful, 
query_params, expected_redirect
+        self,
+        mock_get_auth_manager,
+        simple_app,
+        username,
+        password,
+        is_successful,
+        query_params,
+        expected_redirect,
     ):
-        signer = Mock()
-        signer.generate_signed_token.return_value = "token"
-        mock_jwt_signer.return_value = signer
+        auth_manager = Mock()
+        auth_manager.get_jwt_token.return_value = "token"
+        mock_get_auth_manager.return_value = auth_manager
         with simple_app.test_client() as client:
             response = client.post(
                 "/login_submit", query_string=query_params, data={"username": 
username, "password": password}
diff --git a/tests/auth/managers/test_base_auth_manager.py 
b/tests/auth/managers/test_base_auth_manager.py
index 4406ae9d436..370e401da06 100644
--- a/tests/auth/managers/test_base_auth_manager.py
+++ b/tests/auth/managers/test_base_auth_manager.py
@@ -44,16 +44,27 @@ if TYPE_CHECKING:
     from airflow.www.extensions.init_appbuilder import AirflowAppBuilder
 
 
-class EmptyAuthManager(BaseAuthManager[BaseUser]):
+class BaseAuthManagerUserTest(BaseUser):
+    def __init__(self, *, name: str) -> None:
+        self.name = name
+
+    def get_id(self) -> str:
+        return self.name
+
+    def get_name(self) -> str:
+        return self.name
+
+
+class EmptyAuthManager(BaseAuthManager[BaseAuthManagerUserTest]):
     appbuilder: AirflowAppBuilder | None = None
 
-    def get_user(self) -> BaseUser:
+    def get_user(self) -> BaseAuthManagerUserTest:
         raise NotImplementedError()
 
-    def deserialize_user(self, token: dict[str, Any]) -> BaseUser:
+    def deserialize_user(self, token: dict[str, Any]) -> 
BaseAuthManagerUserTest:
         raise NotImplementedError()
 
-    def serialize_user(self, user: BaseUser) -> dict[str, Any]:
+    def serialize_user(self, user: BaseAuthManagerUserTest) -> dict[str, Any]:
         raise NotImplementedError()
 
     def is_authorized_configuration(
@@ -61,7 +72,7 @@ class EmptyAuthManager(BaseAuthManager[BaseUser]):
         *,
         method: ResourceMethod,
         details: ConfigurationDetails | None = None,
-        user: BaseUser | None = None,
+        user: BaseAuthManagerUserTest | None = None,
     ) -> bool:
         raise NotImplementedError()
 
@@ -70,7 +81,7 @@ class EmptyAuthManager(BaseAuthManager[BaseUser]):
         *,
         method: ResourceMethod,
         details: ConnectionDetails | None = None,
-        user: BaseUser | None = None,
+        user: BaseAuthManagerUserTest | None = None,
     ) -> bool:
         raise NotImplementedError()
 
@@ -80,30 +91,44 @@ class EmptyAuthManager(BaseAuthManager[BaseUser]):
         method: ResourceMethod,
         access_entity: DagAccessEntity | None = None,
         details: DagDetails | None = None,
-        user: BaseUser | None = None,
+        user: BaseAuthManagerUserTest | None = None,
     ) -> bool:
         raise NotImplementedError()
 
     def is_authorized_asset(
-        self, *, method: ResourceMethod, details: AssetDetails | None = None, 
user: BaseUser | None = None
+        self,
+        *,
+        method: ResourceMethod,
+        details: AssetDetails | None = None,
+        user: BaseAuthManagerUserTest | None = None,
     ) -> bool:
         raise NotImplementedError()
 
     def is_authorized_pool(
-        self, *, method: ResourceMethod, details: PoolDetails | None = None, 
user: BaseUser | None = None
+        self,
+        *,
+        method: ResourceMethod,
+        details: PoolDetails | None = None,
+        user: BaseAuthManagerUserTest | None = None,
     ) -> bool:
         raise NotImplementedError()
 
     def is_authorized_variable(
-        self, *, method: ResourceMethod, details: VariableDetails | None = 
None, user: BaseUser | None = None
+        self,
+        *,
+        method: ResourceMethod,
+        details: VariableDetails | None = None,
+        user: BaseAuthManagerUserTest | None = None,
     ) -> bool:
         raise NotImplementedError()
 
-    def is_authorized_view(self, *, access_view: AccessView, user: BaseUser | 
None = None) -> bool:
+    def is_authorized_view(
+        self, *, access_view: AccessView, user: BaseAuthManagerUserTest | None 
= None
+    ) -> bool:
         raise NotImplementedError()
 
     def is_authorized_custom_view(
-        self, *, method: ResourceMethod | str, resource_name: str, user: 
BaseUser | None = None
+        self, *, method: ResourceMethod | str, resource_name: str, user: 
BaseAuthManagerUserTest | None = None
     ):
         raise NotImplementedError()
 
@@ -165,6 +190,23 @@ class TestBaseAuthManager:
     def test_get_url_user_profile_return_none(self, auth_manager):
         assert auth_manager.get_url_user_profile() is None
 
+    @patch("airflow.auth.managers.base_auth_manager.JWTSigner")
+    @patch.object(EmptyAuthManager, "serialize_user")
+    def test_get_jwt_token(self, mock_serialize_user, mock_jwt_signer, 
auth_manager):
+        token = "token"
+        serialized_user = "serialized_user"
+        signer = Mock()
+        signer.generate_signed_token.return_value = token
+        mock_jwt_signer.return_value = signer
+        mock_serialize_user.return_value = serialized_user
+        user = BaseAuthManagerUserTest(name="test")
+
+        result = auth_manager.get_jwt_token(user)
+
+        mock_serialize_user.assert_called_once_with(user)
+        signer.generate_signed_token.assert_called_once_with(serialized_user)
+        assert result == token
+
     @pytest.mark.parametrize(
         "return_values, expected",
         [
@@ -279,7 +321,7 @@ class TestBaseAuthManager:
             method: ResourceMethod,
             access_entity: DagAccessEntity | None = None,
             details: DagDetails | None = None,
-            user: BaseUser | None = None,
+            user: BaseAuthManagerUserTest | None = None,
         ):
             if not details:
                 return access_all

Reply via email to