This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d5bd1344b62 Set up JWT token authentication in Fast APIs (#42634)
d5bd1344b62 is described below

commit d5bd1344b626b0a407e651380c061c363e9cab5a
Author: Vincent <[email protected]>
AuthorDate: Tue Nov 19 14:57:14 2024 -0500

    Set up JWT token authentication in Fast APIs (#42634)
---
 airflow/api_fastapi/app.py                         | 45 ++++++++++
 airflow/api_fastapi/core_api/security.py           | 77 +++++++++++++++++
 airflow/auth/managers/base_auth_manager.py         | 36 +++++---
 .../auth/managers/simple/simple_auth_manager.py    | 69 ++++++++++-----
 airflow/auth/managers/simple/views/auth.py         | 15 +++-
 airflow/config_templates/config.yml                | 20 +++++
 airflow/configuration.py                           |  1 +
 docs/spelling_wordlist.txt                         |  1 +
 .../amazon/aws/auth_manager/aws_auth_manager.py    |  3 +-
 .../providers/fab/auth_manager/fab_auth_manager.py | 32 +++++--
 .../fab/auth_manager/test_fab_auth_manager.py      | 22 ++++-
 tests/api_fastapi/core_api/test_security.py        | 99 ++++++++++++++++++++++
 .../managers/simple/test_simple_auth_manager.py    | 15 +++-
 tests/auth/managers/simple/views/test_auth.py      |  9 +-
 tests/auth/managers/test_base_auth_manager.py      | 12 ++-
 tests/core/test_configuration.py                   |  1 +
 16 files changed, 402 insertions(+), 55 deletions(-)

diff --git a/airflow/api_fastapi/app.py b/airflow/api_fastapi/app.py
index 9ddd97b6cbe..4bf6ae9f6b7 100644
--- a/airflow/api_fastapi/app.py
+++ b/airflow/api_fastapi/app.py
@@ -24,10 +24,14 @@ from starlette.routing import Mount
 
 from airflow.api_fastapi.core_api.app import init_config, init_dag_bag, 
init_plugins, init_views
 from airflow.api_fastapi.execution_api.app import create_task_execution_api_app
+from airflow.auth.managers.base_auth_manager import BaseAuthManager
+from airflow.configuration import conf
+from airflow.exceptions import AirflowConfigException
 
 log = logging.getLogger(__name__)
 
 app: FastAPI | None = None
+auth_manager: BaseAuthManager | None = None
 
 
 @asynccontextmanager
@@ -57,6 +61,7 @@ def create_app(apps: str = "all") -> FastAPI:
         init_dag_bag(app)
         init_views(app)
         init_plugins(app)
+        init_auth_manager()
 
     if "execution" in apps_list or "all" in apps_list:
         task_exec_api_app = create_task_execution_api_app(app)
@@ -79,3 +84,43 @@ def purge_cached_app() -> None:
     """Remove the cached version of the app in global state."""
     global app
     app = None
+
+
+def get_auth_manager_cls() -> type[BaseAuthManager]:
+    """
+    Return just the auth manager class without initializing it.
+
+    Useful to save execution time if only static methods need to be called.
+    """
+    auth_manager_cls = conf.getimport(section="core", key="auth_manager")
+
+    if not auth_manager_cls:
+        raise AirflowConfigException(
+            "No auth manager defined in the config. "
+            "Please specify one using section/key [core/auth_manager]."
+        )
+
+    return auth_manager_cls
+
+
+def init_auth_manager() -> BaseAuthManager:
+    """
+    Initialize the auth manager.
+
+    Import the auth manager class and instantiate it.
+    """
+    global auth_manager
+    auth_manager_cls = get_auth_manager_cls()
+    auth_manager = auth_manager_cls()
+    auth_manager.init()
+    return auth_manager
+
+
+def get_auth_manager() -> BaseAuthManager:
+    """Return the auth manager, provided it's been initialized before."""
+    if auth_manager is None:
+        raise RuntimeError(
+            "Auth Manager has not been initialized yet. "
+            "The `init_auth_manager` method needs to be called first."
+        )
+    return auth_manager
diff --git a/airflow/api_fastapi/core_api/security.py 
b/airflow/api_fastapi/core_api/security.py
new file mode 100644
index 00000000000..ede628e04aa
--- /dev/null
+++ b/airflow/api_fastapi/core_api/security.py
@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from functools import cache
+from typing import Any, Callable
+
+from fastapi import Depends, HTTPException
+from fastapi.security import OAuth2PasswordBearer
+from jwt import InvalidTokenError
+from typing_extensions import Annotated
+
+from airflow.api_fastapi.app import get_auth_manager
+from airflow.auth.managers.base_auth_manager import ResourceMethod
+from airflow.auth.managers.models.base_user import BaseUser
+from airflow.auth.managers.models.resource_details import DagAccessEntity, 
DagDetails
+from airflow.configuration import conf
+from airflow.utils.jwt_signer import JWTSigner
+
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+
+@cache
+def get_signer() -> JWTSigner:
+    return JWTSigner(
+        secret_key=conf.get("api", "auth_jwt_secret"),
+        expiration_time_in_seconds=conf.getint("api", 
"auth_jwt_expiration_time"),
+        audience="front-apis",
+    )
+
+
+def get_user(token_str: Annotated[str, Depends(oauth2_scheme)]) -> BaseUser:
+    try:
+        signer = get_signer()
+        payload: dict[str, Any] = signer.verify_token(token_str)
+        return get_auth_manager().deserialize_user(payload)
+    except InvalidTokenError:
+        raise HTTPException(403, "Forbidden")
+
+
+def requires_access_dag(method: ResourceMethod, access_entity: DagAccessEntity 
| None = None) -> Callable:
+    def inner(
+        dag_id: str | None = None,
+        user: Annotated[BaseUser | None, Depends(get_user)] = None,
+    ) -> None:
+        def callback():
+            return get_auth_manager().is_authorized_dag(
+                method=method, access_entity=access_entity, 
details=DagDetails(id=dag_id), user=user
+            )
+
+        _requires_access(
+            is_authorized_callback=callback,
+        )
+
+    return inner
+
+
+def _requires_access(
+    *,
+    is_authorized_callback: Callable[[], bool],
+) -> None:
+    if not is_authorized_callback():
+        raise HTTPException(403, "Forbidden")
diff --git a/airflow/auth/managers/base_auth_manager.py 
b/airflow/auth/managers/base_auth_manager.py
index 69b5969c827..028d4dadb13 100644
--- a/airflow/auth/managers/base_auth_manager.py
+++ b/airflow/auth/managers/base_auth_manager.py
@@ -19,11 +19,12 @@ from __future__ import annotations
 
 from abc import abstractmethod
 from functools import cached_property
-from typing import TYPE_CHECKING, Container, Literal, Sequence
+from typing import TYPE_CHECKING, Any, Container, Generic, Literal, Sequence, 
TypeVar
 
 from flask_appbuilder.menu import MenuItem
 from sqlalchemy import select
 
+from airflow.auth.managers.models.base_user import BaseUser
 from airflow.auth.managers.models.resource_details import (
     DagDetails,
 )
@@ -37,7 +38,6 @@ if TYPE_CHECKING:
     from flask import Blueprint
     from sqlalchemy.orm import Session
 
-    from airflow.auth.managers.models.base_user import BaseUser
     from airflow.auth.managers.models.batch_apis import (
         IsAuthorizedConnectionRequest,
         IsAuthorizedDagRequest,
@@ -59,8 +59,10 @@ if TYPE_CHECKING:
 
 ResourceMethod = Literal["GET", "POST", "PUT", "DELETE", "MENU"]
 
+T = TypeVar("T", bound=BaseUser)
 
-class BaseAuthManager(LoggingMixin):
+
+class BaseAuthManager(Generic[T], LoggingMixin):
     """
     Class to derive in order to implement concrete auth managers.
 
@@ -69,7 +71,7 @@ class BaseAuthManager(LoggingMixin):
     :param appbuilder: the flask app builder
     """
 
-    def __init__(self, appbuilder: AirflowAppBuilder) -> None:
+    def __init__(self, appbuilder: AirflowAppBuilder | None = None) -> None:
         super().__init__()
         self.appbuilder = appbuilder
 
@@ -93,9 +95,17 @@ class BaseAuthManager(LoggingMixin):
         return self.get_user_name()
 
     @abstractmethod
-    def get_user(self) -> BaseUser | None:
+    def get_user(self) -> T | None:
         """Return the user associated to the user in session."""
 
+    @abstractmethod
+    def deserialize_user(self, token: dict[str, Any]) -> T:
+        """Create a user object from dict."""
+
+    @abstractmethod
+    def serialize_user(self, user: T) -> dict[str, Any]:
+        """Create a dict from a user object."""
+
     def get_user_id(self) -> str | None:
         """Return the user ID associated to the user in session."""
         user = self.get_user()
@@ -132,7 +142,7 @@ class BaseAuthManager(LoggingMixin):
         *,
         method: ResourceMethod,
         details: ConfigurationDetails | None = None,
-        user: BaseUser | None = None,
+        user: T | None = None,
     ) -> bool:
         """
         Return whether the user is authorized to perform a given action on 
configuration.
@@ -148,7 +158,7 @@ class BaseAuthManager(LoggingMixin):
         *,
         method: ResourceMethod,
         details: ConnectionDetails | None = None,
-        user: BaseUser | None = None,
+        user: T | None = None,
     ) -> bool:
         """
         Return whether the user is authorized to perform a given action on a 
connection.
@@ -165,7 +175,7 @@ class BaseAuthManager(LoggingMixin):
         method: ResourceMethod,
         access_entity: DagAccessEntity | None = None,
         details: DagDetails | None = None,
-        user: BaseUser | None = None,
+        user: T | None = None,
     ) -> bool:
         """
         Return whether the user is authorized to perform a given action on a 
DAG.
@@ -183,7 +193,7 @@ class BaseAuthManager(LoggingMixin):
         *,
         method: ResourceMethod,
         details: AssetDetails | None = None,
-        user: BaseUser | None = None,
+        user: T | None = None,
     ) -> bool:
         """
         Return whether the user is authorized to perform a given action on an 
asset.
@@ -199,7 +209,7 @@ class BaseAuthManager(LoggingMixin):
         *,
         method: ResourceMethod,
         details: PoolDetails | None = None,
-        user: BaseUser | None = None,
+        user: T | None = None,
     ) -> bool:
         """
         Return whether the user is authorized to perform a given action on a 
pool.
@@ -215,7 +225,7 @@ class BaseAuthManager(LoggingMixin):
         *,
         method: ResourceMethod,
         details: VariableDetails | None = None,
-        user: BaseUser | None = None,
+        user: T | None = None,
     ) -> bool:
         """
         Return whether the user is authorized to perform a given action on a 
variable.
@@ -230,7 +240,7 @@ class BaseAuthManager(LoggingMixin):
         self,
         *,
         access_view: AccessView,
-        user: BaseUser | None = None,
+        user: T | None = None,
     ) -> bool:
         """
         Return whether the user is authorized to access a read-only state of 
the installation.
@@ -241,7 +251,7 @@ class BaseAuthManager(LoggingMixin):
 
     @abstractmethod
     def is_authorized_custom_view(
-        self, *, method: ResourceMethod | str, resource_name: str, user: 
BaseUser | None = None
+        self, *, method: ResourceMethod | str, resource_name: str, user: T | 
None = None
     ):
         """
         Return whether the user is authorized to perform a given action on a 
custom view.
diff --git a/airflow/auth/managers/simple/simple_auth_manager.py 
b/airflow/auth/managers/simple/simple_auth_manager.py
index 78dccf7c2a9..48baa02e7c7 100644
--- a/airflow/auth/managers/simple/simple_auth_manager.py
+++ b/airflow/auth/managers/simple/simple_auth_manager.py
@@ -22,7 +22,7 @@ import os
 import random
 from collections import namedtuple
 from enum import Enum
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 from flask import session, url_for
 from termcolor import colored
@@ -33,7 +33,6 @@ from airflow.auth.managers.simple.views.auth import 
SimpleAuthManagerAuthenticat
 from airflow.configuration import AIRFLOW_HOME
 
 if TYPE_CHECKING:
-    from airflow.auth.managers.models.base_user import BaseUser
     from airflow.auth.managers.models.resource_details import (
         AccessView,
         AssetDetails,
@@ -68,7 +67,7 @@ class 
SimpleAuthManagerRole(namedtuple("SimpleAuthManagerRole", "name order"), E
     ADMIN = "ADMIN", 3
 
 
-class SimpleAuthManager(BaseAuthManager):
+class SimpleAuthManager(BaseAuthManager[SimpleAuthManagerUser]):
     """
     Simple auth manager.
 
@@ -89,6 +88,8 @@ class SimpleAuthManager(BaseAuthManager):
         )
 
     def init(self) -> None:
+        if not self.appbuilder:
+            return
         user_passwords_from_file = {}
 
         # Read passwords from file
@@ -115,8 +116,9 @@ class SimpleAuthManager(BaseAuthManager):
             file.write(json.dumps(self.passwords))
 
     def is_logged_in(self) -> bool:
-        return "user" in session or self.appbuilder.get_app.config.get(
-            "SIMPLE_AUTH_MANAGER_ALL_ADMINS", False
+        return "user" in session or (
+            self.appbuilder is not None
+            and 
self.appbuilder.get_app.config.get("SIMPLE_AUTH_MANAGER_ALL_ADMINS", False)
         )
 
     def get_url_login(self, **kwargs) -> str:
@@ -128,28 +130,34 @@ class SimpleAuthManager(BaseAuthManager):
     def get_user(self) -> SimpleAuthManagerUser | None:
         if not self.is_logged_in():
             return None
-        if 
self.appbuilder.get_app.config.get("SIMPLE_AUTH_MANAGER_ALL_ADMINS", False):
+        if self.appbuilder and 
self.appbuilder.get_app.config.get("SIMPLE_AUTH_MANAGER_ALL_ADMINS", False):
             return SimpleAuthManagerUser(username="anonymous", role="admin")
         else:
             return session["user"]
 
+    def deserialize_user(self, token: dict[str, Any]) -> SimpleAuthManagerUser:
+        return SimpleAuthManagerUser(username=token["username"], 
role=token["role"])
+
+    def serialize_user(self, user: SimpleAuthManagerUser) -> dict[str, Any]:
+        return {"username": user.username, "role": user.role}
+
     def is_authorized_configuration(
         self,
         *,
         method: ResourceMethod,
         details: ConfigurationDetails | None = None,
-        user: BaseUser | None = None,
+        user: SimpleAuthManagerUser | None = None,
     ) -> bool:
-        return self._is_authorized(method=method, 
allow_role=SimpleAuthManagerRole.OP)
+        return self._is_authorized(method=method, 
allow_role=SimpleAuthManagerRole.OP, user=user)
 
     def is_authorized_connection(
         self,
         *,
         method: ResourceMethod,
         details: ConnectionDetails | None = None,
-        user: BaseUser | None = None,
+        user: SimpleAuthManagerUser | None = None,
     ) -> bool:
-        return self._is_authorized(method=method, 
allow_role=SimpleAuthManagerRole.OP)
+        return self._is_authorized(method=method, 
allow_role=SimpleAuthManagerRole.OP, user=user)
 
     def is_authorized_dag(
         self,
@@ -157,46 +165,65 @@ class SimpleAuthManager(BaseAuthManager):
         method: ResourceMethod,
         access_entity: DagAccessEntity | None = None,
         details: DagDetails | None = None,
-        user: BaseUser | None = None,
+        user: SimpleAuthManagerUser | None = None,
     ) -> bool:
         return self._is_authorized(
             method=method,
             allow_get_role=SimpleAuthManagerRole.VIEWER,
             allow_role=SimpleAuthManagerRole.USER,
+            user=user,
         )
 
     def is_authorized_asset(
-        self, *, method: ResourceMethod, details: AssetDetails | None = None, 
user: BaseUser | None = None
+        self,
+        *,
+        method: ResourceMethod,
+        details: AssetDetails | None = None,
+        user: SimpleAuthManagerUser | None = None,
     ) -> bool:
         return self._is_authorized(
             method=method,
             allow_get_role=SimpleAuthManagerRole.VIEWER,
             allow_role=SimpleAuthManagerRole.OP,
+            user=user,
         )
 
     def is_authorized_pool(
-        self, *, method: ResourceMethod, details: PoolDetails | None = None, 
user: BaseUser | None = None
+        self,
+        *,
+        method: ResourceMethod,
+        details: PoolDetails | None = None,
+        user: SimpleAuthManagerUser | None = None,
     ) -> bool:
         return self._is_authorized(
             method=method,
             allow_get_role=SimpleAuthManagerRole.VIEWER,
             allow_role=SimpleAuthManagerRole.OP,
+            user=user,
         )
 
     def is_authorized_variable(
-        self, *, method: ResourceMethod, details: VariableDetails | None = 
None, user: BaseUser | None = None
+        self,
+        *,
+        method: ResourceMethod,
+        details: VariableDetails | None = None,
+        user: SimpleAuthManagerUser | None = None,
     ) -> bool:
-        return self._is_authorized(method=method, 
allow_role=SimpleAuthManagerRole.OP)
+        return self._is_authorized(method=method, 
allow_role=SimpleAuthManagerRole.OP, user=user)
 
-    def is_authorized_view(self, *, access_view: AccessView, user: BaseUser | 
None = None) -> bool:
-        return self._is_authorized(method="GET", 
allow_role=SimpleAuthManagerRole.VIEWER)
+    def is_authorized_view(
+        self, *, access_view: AccessView, user: SimpleAuthManagerUser | None = 
None
+    ) -> bool:
+        return self._is_authorized(method="GET", 
allow_role=SimpleAuthManagerRole.VIEWER, user=user)
 
     def is_authorized_custom_view(
-        self, *, method: ResourceMethod | str, resource_name: str, user: 
BaseUser | None = None
+        self, *, method: ResourceMethod | str, resource_name: str, user: 
SimpleAuthManagerUser | None = None
     ):
-        return self._is_authorized(method="GET", 
allow_role=SimpleAuthManagerRole.VIEWER)
+        return self._is_authorized(method="GET", 
allow_role=SimpleAuthManagerRole.VIEWER, user=user)
 
     def register_views(self) -> None:
+        if not self.appbuilder:
+            return
         self.appbuilder.add_view_no_menu(
             SimpleAuthManagerAuthenticationViews(
                 
users=self.appbuilder.get_app.config.get("SIMPLE_AUTH_MANAGER_USERS", []),
@@ -210,6 +237,7 @@ class SimpleAuthManager(BaseAuthManager):
         method: ResourceMethod,
         allow_role: SimpleAuthManagerRole,
         allow_get_role: SimpleAuthManagerRole | None = None,
+        user: SimpleAuthManagerUser | None = None,
     ):
         """
         Return whether the user is authorized to access a given resource.
@@ -219,8 +247,9 @@ class SimpleAuthManager(BaseAuthManager):
             equal than this role, they have access
         :param allow_get_role: minimal role giving access to the resource, if 
the user's role is greater or
             equal than this role, they have access. If not provided, 
``allow_role`` is used
+        :param user: the user to check the authorization for. If not provided, 
the current user is used
         """
-        user = self.get_user()
+        user = user or self.get_user()
         if not user:
             return False
 
diff --git a/airflow/auth/managers/simple/views/auth.py 
b/airflow/auth/managers/simple/views/auth.py
index 8ab02d0a015..6e4cf0c3994 100644
--- a/airflow/auth/managers/simple/views/auth.py
+++ b/airflow/auth/managers/simple/views/auth.py
@@ -23,8 +23,10 @@ from flask_appbuilder import expose
 
 from airflow.auth.managers.simple.user import SimpleAuthManagerUser
 from airflow.configuration import conf
+from airflow.utils.jwt_signer import JWTSigner
 from airflow.utils.state import State
 from airflow.www.app import csrf
+from airflow.www.extensions.init_auth_manager import get_auth_manager
 from airflow.www.views import AirflowBaseView
 
 logger = logging.getLogger(__name__)
@@ -80,9 +82,18 @@ class SimpleAuthManagerAuthenticationViews(AirflowBaseView):
         if not username or not password or len(found_users) == 0:
             return 
redirect(url_for("SimpleAuthManagerAuthenticationViews.login", error=["1"]))
 
-        session["user"] = SimpleAuthManagerUser(
+        user = SimpleAuthManagerUser(
             username=username,
             role=found_users[0]["role"],
         )
+        # Will be removed once Airflow uses the new UI
+        session["user"] = user
 
-        return redirect(url_for("Airflow.index"))
+        signer = JWTSigner(
+            secret_key=conf.get("api", "auth_jwt_secret"),
+            expiration_time_in_seconds=conf.getint("api", 
"auth_jwt_expiration_time"),
+            audience="front-apis",
+        )
+        token = 
signer.generate_signed_token(get_auth_manager().serialize_user(user))
+
+        return redirect(url_for("Airflow.index", token=token))
diff --git a/airflow/config_templates/config.yml 
b/airflow/config_templates/config.yml
index 04ac0d802a9..4a7fe5f1897 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -1404,6 +1404,26 @@ api:
       version_added: 2.7.0
       example: ~
       default: "False"
+    auth_jwt_secret:
+      description: |
+        Secret key used to encode and decode JWT tokens to authenticate to 
public and private APIs.
+        It should be as random as possible. However, when running more than one 
instance of the API services,
+        make sure all of them use the same ``auth_jwt_secret``; otherwise calls will 
fail on authentication.
+      version_added: 3.0.0
+      type: string
+      sensitive: true
+      example: ~
+      default: "{JWT_SECRET_KEY}"
+    auth_jwt_expiration_time:
+      description: |
+        Number of seconds until the JWT token used for authentication expires. 
When the token expires,
+        all API calls using this token will fail on authentication.
+        Make sure that time on ALL the machines that you run airflow 
components on is synchronized
+        (for example using ntpd) otherwise you might get "forbidden" errors.
+      version_added: 3.0.0
+      type: integer
+      example: ~
+      default: "86400"
 lineage:
   description: ~
   options:
diff --git a/airflow/configuration.py b/airflow/configuration.py
index bc808d6bfc2..521af6cbe32 100644
--- a/airflow/configuration.py
+++ b/airflow/configuration.py
@@ -2143,6 +2143,7 @@ else:
     TEST_PLUGINS_FOLDER = os.path.join(AIRFLOW_HOME, "plugins")
 
 SECRET_KEY = b64encode(os.urandom(16)).decode("utf-8")
+JWT_SECRET_KEY = b64encode(os.urandom(16)).decode("utf-8")
 FERNET_KEY = ""  # Set only if needed when generating a new file
 WEBSERVER_CONFIG = ""  # Set by initialize_config
 
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 32787428cc5..b2df194de0d 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -898,6 +898,7 @@ Jupyter
 jupyter
 jupytercmd
 JWT
+jwt
 Kafka
 kafka
 Kalibrr
diff --git 
a/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py 
b/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
index 414907961ce..8271aad3302 100644
--- 
a/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
+++ 
b/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
@@ -430,7 +430,8 @@ class AwsAuthManager(BaseAuthManager):
         ]
 
     def register_views(self) -> None:
-        self.appbuilder.add_view_no_menu(AwsAuthManagerAuthenticationViews())
+        if self.appbuilder:
+            
self.appbuilder.add_view_no_menu(AwsAuthManagerAuthenticationViews())
 
     @staticmethod
     def _get_menu_item_request(resource_name: str) -> IsAuthorizedRequest:
diff --git 
a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py 
b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
index e93e440f5dd..adebe30dd53 100644
--- a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
+++ b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
@@ -21,7 +21,7 @@ import argparse
 import warnings
 from functools import cached_property
 from pathlib import Path
-from typing import TYPE_CHECKING, Container
+from typing import TYPE_CHECKING, Any, Container
 
 import packaging.version
 from connexion import FlaskApi
@@ -82,7 +82,7 @@ from airflow.security.permissions import (
     RESOURCE_WEBSITE,
     RESOURCE_XCOM,
 )
-from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.yaml import safe_load
 from airflow.version import version
 from airflow.www.constants import SWAGGER_BUNDLE, SWAGGER_ENABLED
@@ -131,13 +131,18 @@ _MAP_ACCESS_VIEW_TO_FAB_RESOURCE_TYPE = {
 }
 
 
-class FabAuthManager(BaseAuthManager):
+class FabAuthManager(BaseAuthManager[User]):
     """
     Flask-AppBuilder auth manager.
 
     This auth manager is responsible for providing a backward compatible user 
management experience to users.
     """
 
+    def init(self) -> None:
+        """Run operations when Airflow is initializing."""
+        if self.appbuilder:
+            self._sync_appbuilder_roles()
+
     @staticmethod
     def get_cli_commands() -> list[CLICommand]:
         """Vends CLI commands to be included in Airflow CLI."""
@@ -199,9 +204,12 @@ class FabAuthManager(BaseAuthManager):
 
         return current_user
 
-    def init(self) -> None:
-        """Run operations when Airflow is initializing."""
-        self._sync_appbuilder_roles()
+    def deserialize_user(self, token: dict[str, Any]) -> User:
+        with create_session() as session:
+            return session.get(User, token["id"])
+
+    def serialize_user(self, user: User) -> dict[str, Any]:
+        return {"id": user.id}
 
     def is_logged_in(self) -> bool:
         """Return whether the user is logged in."""
@@ -209,8 +217,10 @@ class FabAuthManager(BaseAuthManager):
         if Version(Version(version).base_version) < Version("3.0.0"):
             return not user.is_anonymous and user.is_active
         else:
-            return self.appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC", 
None) or (
-                not user.is_anonymous and user.is_active
+            return (
+                self.appbuilder
+                and self.appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC", 
None)
+                or (not user.is_anonymous and user.is_active)
             )
 
     def is_authorized_configuration(
@@ -376,6 +386,9 @@ class FabAuthManager(BaseAuthManager):
             FabAirflowSecurityManagerOverride,
         )
 
+        if not self.appbuilder:
+            raise AirflowException("AppBuilder is not initialized.")
+
         sm_from_config = 
self.appbuilder.get_app.config.get("SECURITY_MANAGER_CLASS")
         if sm_from_config:
             if not issubclass(sm_from_config, 
FabAirflowSecurityManagerOverride):
@@ -547,6 +560,9 @@ class FabAuthManager(BaseAuthManager):
 
         :meta private:
         """
+        if not self.appbuilder:
+            raise AirflowException("AppBuilder is not initialized.")
+
         if "." in dag_id and hasattr(DagModel, "root_dag_id"):
             return self.appbuilder.get_session.scalar(
                 select(DagModel.dag_id, 
DagModel.root_dag_id).where(DagModel.dag_id == dag_id).limit(1)
diff --git a/providers/tests/fab/auth_manager/test_fab_auth_manager.py 
b/providers/tests/fab/auth_manager/test_fab_auth_manager.py
index d298f7667ea..1994e910f9e 100644
--- a/providers/tests/fab/auth_manager/test_fab_auth_manager.py
+++ b/providers/tests/fab/auth_manager/test_fab_auth_manager.py
@@ -27,6 +27,8 @@ from flask import Flask, g
 
 from airflow.exceptions import AirflowConfigException, AirflowException
 
+from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import 
create_user
+
 try:
     from airflow.auth.managers.models.resource_details import AccessView, 
DagAccessEntity, DagDetails
 except ImportError:
@@ -141,6 +143,16 @@ class TestFabAuthManager:
             with user_set(minimal_app_for_auth_api, flask_g_user):
                 assert auth_manager.get_user() == flask_g_user
 
+    def test_deserialize_user(self, flask_app, auth_manager_with_appbuilder):
+        user = create_user(flask_app, "test")
+        result = auth_manager_with_appbuilder.deserialize_user({"id": user.id})
+        assert user == result
+
+    def test_serialize_user(self, flask_app, auth_manager_with_appbuilder):
+        user = create_user(flask_app, "test")
+        result = auth_manager_with_appbuilder.serialize_user(user)
+        assert result == {"id": user.id}
+
     @pytest.mark.db_test
     @mock.patch.object(FabAuthManager, "get_user")
     def test_is_logged_in(self, mock_get_user, auth_manager_with_appbuilder):
@@ -338,11 +350,17 @@ class TestFabAuthManager:
         ],
     )
     def test_is_authorized_dag(
-        self, method, dag_access_entity, dag_details, user_permissions, 
expected_result, auth_manager
+        self,
+        method,
+        dag_access_entity,
+        dag_details,
+        user_permissions,
+        expected_result,
+        auth_manager_with_appbuilder,
     ):
         user = Mock()
         user.perms = user_permissions
-        result = auth_manager.is_authorized_dag(
+        result = auth_manager_with_appbuilder.is_authorized_dag(
             method=method, access_entity=dag_access_entity, 
details=dag_details, user=user
         )
         assert result == expected_result
diff --git a/tests/api_fastapi/core_api/test_security.py 
b/tests/api_fastapi/core_api/test_security.py
new file mode 100644
index 00000000000..90bf3f647bb
--- /dev/null
+++ b/tests/api_fastapi/core_api/test_security.py
@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import Mock, patch
+
+import pytest
+from fastapi import HTTPException
+from jwt import InvalidTokenError
+
+from airflow.api_fastapi.app import create_app
+from airflow.api_fastapi.core_api.security import get_user, requires_access_dag
+from airflow.auth.managers.models.resource_details import DagAccessEntity
+from airflow.auth.managers.simple.user import SimpleAuthManagerUser
+
+from tests_common.test_utils.config import conf_vars
+
+
+class TestFastApiSecurity:
+    @classmethod
+    def setup_class(cls):
+        with conf_vars(
+            {
+                (
+                    "core",
+                    "auth_manager",
+                ): "airflow.auth.managers.simple.simple_auth_manager.SimpleAuthManager",
+            }
+        ):
+            create_app()
+
+    @patch("airflow.api_fastapi.core_api.security.get_signer")
+    @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
+    def test_get_user(self, mock_get_auth_manager, mock_get_signer):
+        token_str = "test-token"
+        user_dict = {"user": "XXXXXXXXX"}
+        user = SimpleAuthManagerUser(username="username", role="admin")
+
+        auth_manager = Mock()
+        auth_manager.deserialize_user.return_value = user
+        mock_get_auth_manager.return_value = auth_manager
+
+        signer = Mock()
+        signer.verify_token.return_value = user_dict
+        mock_get_signer.return_value = signer
+
+        result = get_user(token_str)
+
+        signer.verify_token.assert_called_once_with(token_str)
+        auth_manager.deserialize_user.assert_called_once_with(user_dict)
+        assert result == user
+
+    @patch("airflow.api_fastapi.core_api.security.get_signer")
+    def test_get_user_unsuccessful(self, mock_get_signer):
+        token_str = "test-token"
+
+        signer = Mock()
+        signer.verify_token.side_effect = InvalidTokenError()
+        mock_get_signer.return_value = signer
+
+        with pytest.raises(HTTPException, match="Forbidden"):
+            get_user(token_str)
+
+        signer.verify_token.assert_called_once_with(token_str)
+
+    @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
+    def test_requires_access_dag_authorized(self, mock_get_auth_manager):
+        auth_manager = Mock()
+        auth_manager.is_authorized_dag.return_value = True
+        mock_get_auth_manager.return_value = auth_manager
+
+        requires_access_dag("GET", DagAccessEntity.CODE)("dag-id", Mock())
+
+        auth_manager.is_authorized_dag.assert_called_once()
+
+    @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
+    def test_requires_access_dag_unauthorized(self, mock_get_auth_manager):
+        auth_manager = Mock()
+        auth_manager.is_authorized_dag.return_value = False
+        mock_get_auth_manager.return_value = auth_manager
+
+        with pytest.raises(HTTPException, match="Forbidden"):
+            requires_access_dag("GET", DagAccessEntity.CODE)("dag-id", Mock())
+
+        auth_manager.is_authorized_dag.assert_called_once()
diff --git a/tests/auth/managers/simple/test_simple_auth_manager.py b/tests/auth/managers/simple/test_simple_auth_manager.py
index 434c0d60fcc..07289f6f002 100644
--- a/tests/auth/managers/simple/test_simple_auth_manager.py
+++ b/tests/auth/managers/simple/test_simple_auth_manager.py
@@ -132,6 +132,16 @@ class TestSimpleAuthManager:
 
         assert result is None
 
+    def test_deserialize_user(self, auth_manager):
+        result = auth_manager.deserialize_user({"username": "test", "role": "admin"})
+        assert result.username == "test"
+        assert result.role == "admin"
+
+    def test_serialize_user(self, auth_manager):
+        user = SimpleAuthManagerUser(username="test", role="admin")
+        result = auth_manager.serialize_user(user)
+        assert result == {"username": "test", "role": "admin"}
+
     @pytest.mark.db_test
     @patch.object(SimpleAuthManager, "is_logged_in")
     @pytest.mark.parametrize(
@@ -280,10 +290,7 @@ class TestSimpleAuthManager:
             assert getattr(auth_manager_with_appbuilder, api)(method=method) is result
 
     @pytest.mark.db_test
-    @patch(
-        "airflow.providers.amazon.aws.auth_manager.views.auth.conf.get_mandatory_value", return_value="test"
-    )
-    def test_register_views(self, _, auth_manager_with_appbuilder):
+    def test_register_views(self, auth_manager_with_appbuilder):
         auth_manager_with_appbuilder.appbuilder.add_view_no_menu = Mock()
         auth_manager_with_appbuilder.register_views()
         auth_manager_with_appbuilder.appbuilder.add_view_no_menu.assert_called_once()
diff --git a/tests/auth/managers/simple/views/test_auth.py b/tests/auth/managers/simple/views/test_auth.py
index cd4628c2d02..e3e7b29e2cc 100644
--- a/tests/auth/managers/simple/views/test_auth.py
+++ b/tests/auth/managers/simple/views/test_auth.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import json
+from unittest.mock import Mock, patch
 
 import pytest
 from flask import session, url_for
@@ -67,11 +68,15 @@ class TestSimpleAuthManagerAuthenticationViews:
         "username, password, is_successful",
         [("test", "test", True), ("test", "test2", False), ("", "", False)],
     )
-    def test_login_submit(self, simple_app, username, password, is_successful):
+    @patch("airflow.auth.managers.simple.views.auth.JWTSigner")
+    def test_login_submit(self, mock_jwt_signer, simple_app, username, password, is_successful):
+        signer = Mock()
+        signer.generate_signed_token.return_value = "token"
+        mock_jwt_signer.return_value = signer
         with simple_app.test_client() as client:
             response = client.post("/login_submit", data={"username": username, "password": password})
             assert response.status_code == 302
             if is_successful:
-                assert response.location == url_for("Airflow.index")
+                assert response.location == url_for("Airflow.index", token="token")
             else:
                 assert response.location == url_for("SimpleAuthManagerAuthenticationViews.login", error=["1"])
diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py
index 82efe20048b..c62076a4654 100644
--- a/tests/auth/managers/test_base_auth_manager.py
+++ b/tests/auth/managers/test_base_auth_manager.py
@@ -16,13 +16,14 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
 from flask_appbuilder.menu import Menu
 
 from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod
+from airflow.auth.managers.models.base_user import BaseUser
 from airflow.auth.managers.models.resource_details import (
     ConnectionDetails,
     DagDetails,
@@ -32,7 +33,6 @@ from airflow.auth.managers.models.resource_details import (
 from airflow.exceptions import AirflowException
 
 if TYPE_CHECKING:
-    from airflow.auth.managers.models.base_user import BaseUser
     from airflow.auth.managers.models.resource_details import (
         AccessView,
         AssetDetails,
@@ -41,10 +41,16 @@ if TYPE_CHECKING:
     )
 
 
-class EmptyAuthManager(BaseAuthManager):
+class EmptyAuthManager(BaseAuthManager[BaseUser]):
     def get_user(self) -> BaseUser:
         raise NotImplementedError()
 
+    def deserialize_user(self, token: dict[str, Any]) -> BaseUser:
+        raise NotImplementedError()
+
+    def serialize_user(self, user: BaseUser) -> dict[str, Any]:
+        raise NotImplementedError()
+
     def is_authorized_configuration(
         self,
         *,
diff --git a/tests/core/test_configuration.py b/tests/core/test_configuration.py
index 2167d71ed5c..1fb3b270573 100644
--- a/tests/core/test_configuration.py
+++ b/tests/core/test_configuration.py
@@ -1584,6 +1584,7 @@ def test_sensitive_values():
         ("database", "sql_alchemy_conn"),
         ("core", "fernet_key"),
         ("core", "internal_api_secret_key"),
+        ("api", "auth_jwt_secret"),
         ("webserver", "secret_key"),
         ("secrets", "backend_kwargs"),
         ("sentry", "sentry_dsn"),


Reply via email to