This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1242e07c0fe Refactor AuthManager from app.state to FastAPI dependency (#57665)
1242e07c0fe is described below

commit 1242e07c0fe4577059c841c97ba09063ec391247
Author: LIU ZHE YOU <[email protected]>
AuthorDate: Tue Nov 4 23:54:29 2025 +0800

    Refactor AuthManager from app.state to FastAPI dependency (#57665)
    
    - Create new auth_manager.py module in common with AuthManagerDep
    - Update security.py to use AuthManagerDep instead of request.app.state.auth_manager
    - Update auth.py routes to use AuthManagerDep
    - Remove Request parameter where no longer needed
    - Follow the same dependency injection pattern as DagBag
    
    
    
    Add unit tests for auth_manager dependency injection
    
    - Create test_auth_manager.py to test the new dependency
    - Verify auth_manager_from_app correctly retrieves from app.state
    - Test integration with existing test client fixture
    
    
    
    Fix linting issues with ruff
    
    - Move BaseAuthManager import out of TYPE_CHECKING block
    - Fix import ordering in security.py
    - Remove unused pytest import from test
    - Remove trailing whitespace
    - Format code with ruff format
    
    
    
    Move auth_manager dependency to security.py module
    
    - Move auth_manager_from_app and AuthManagerDep from common/auth_manager.py to core_api/security.py
    - Update import in routes/public/auth.py to use security module
    - Move tests from common/test_auth_manager.py to core_api/test_security.py
    - Delete now-unused common/auth_manager.py and common/test_auth_manager.py
    - Import BaseAuthManager directly in security.py (not in TYPE_CHECKING)
    
    Co-authored-by: copilot-swe-agent[bot] <[email protected]>
---
 .../api_fastapi/core_api/routes/public/auth.py     | 12 +++---
 .../src/airflow/api_fastapi/core_api/security.py   | 45 ++++++++++++++--------
 .../unit/api_fastapi/core_api/test_security.py     | 37 ++++++++++++++++++
 3 files changed, 71 insertions(+), 23 deletions(-)

diff --git 
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py 
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py
index b8f6d204d2e..59610e0fc94 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py
@@ -21,7 +21,7 @@ from fastapi.responses import RedirectResponse
 
 from airflow.api_fastapi.common.router import AirflowRouter
 from airflow.api_fastapi.core_api.openapi.exceptions import 
create_openapi_http_exception_doc
-from airflow.api_fastapi.core_api.security import is_safe_url
+from airflow.api_fastapi.core_api.security import AuthManagerDep, is_safe_url
 
 auth_router = AirflowRouter(tags=["Login"], prefix="/auth")
 
@@ -30,9 +30,9 @@ auth_router = AirflowRouter(tags=["Login"], prefix="/auth")
     "/login",
     
responses=create_openapi_http_exception_doc([status.HTTP_307_TEMPORARY_REDIRECT]),
 )
-def login(request: Request, next: None | str = None) -> RedirectResponse:
+def login(request: Request, auth_manager: AuthManagerDep, next: None | str = 
None) -> RedirectResponse:
     """Redirect to the login URL depending on the AuthManager configured."""
-    login_url = request.app.state.auth_manager.get_url_login()
+    login_url = auth_manager.get_url_login()
 
     if next and not is_safe_url(next, request=request):
         raise HTTPException(status_code=400, detail="Invalid or unsafe next 
URL")
@@ -47,11 +47,11 @@ def login(request: Request, next: None | str = None) -> 
RedirectResponse:
     "/logout",
     
responses=create_openapi_http_exception_doc([status.HTTP_307_TEMPORARY_REDIRECT]),
 )
-def logout(request: Request, next: None | str = None) -> RedirectResponse:
+def logout(auth_manager: AuthManagerDep, next: None | str = None) -> 
RedirectResponse:
     """Logout the user."""
-    logout_url = request.app.state.auth_manager.get_url_logout()
+    logout_url = auth_manager.get_url_logout()
 
     if not logout_url:
-        logout_url = request.app.state.auth_manager.get_url_login()
+        logout_url = auth_manager.get_url_login()
 
     return RedirectResponse(logout_url)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/security.py 
b/airflow-core/src/airflow/api_fastapi/core_api/security.py
index 5e7d676bf9d..05ad624353b 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/security.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/security.py
@@ -27,7 +27,10 @@ from jwt import ExpiredSignatureError, InvalidTokenError
 from pydantic import NonNegativeInt
 
 from airflow.api_fastapi.app import get_auth_manager
-from airflow.api_fastapi.auth.managers.base_auth_manager import 
COOKIE_NAME_JWT_TOKEN
+from airflow.api_fastapi.auth.managers.base_auth_manager import (
+    COOKIE_NAME_JWT_TOKEN,
+    BaseAuthManager,
+)
 from airflow.api_fastapi.auth.managers.models.base_user import BaseUser
 from airflow.api_fastapi.auth.managers.models.batch_apis import (
     IsAuthorizedConnectionRequest,
@@ -70,7 +73,20 @@ if TYPE_CHECKING:
     from fastapi.security import HTTPAuthorizationCredentials
     from sqlalchemy.sql import Select
 
-    from airflow.api_fastapi.auth.managers.base_auth_manager import 
BaseAuthManager, ResourceMethod
+    from airflow.api_fastapi.auth.managers.base_auth_manager import 
ResourceMethod
+
+
+def auth_manager_from_app(request: Request) -> BaseAuthManager:
+    """
+    FastAPI dependency resolver that returns the shared AuthManager instance 
from app.state.
+
+    This ensures that all API routes using AuthManager via dependency 
injection receive the same
+    singleton instance that was initialized at app startup.
+    """
+    return request.app.state.auth_manager
+
+
+AuthManagerDep = Annotated[BaseAuthManager, Depends(auth_manager_from_app)]
 
 auth_description = (
     "To authenticate Airflow API requests, clients must include a JWT (JSON 
Web Token) in "
@@ -196,7 +212,7 @@ class PermittedTagFilter(PermittedDagFilter):
 
 def permitted_dag_filter_factory(
     method: ResourceMethod, filter_class=PermittedDagFilter
-) -> Callable[[Request, BaseUser], PermittedDagFilter]:
+) -> Callable[[BaseUser, BaseAuthManager], PermittedDagFilter]:
     """
     Create a callable for Depends in FastAPI that returns a filter of the 
permitted dags for the user.
 
@@ -205,10 +221,9 @@ def permitted_dag_filter_factory(
     """
 
     def depends_permitted_dags_filter(
-        request: Request,
         user: GetUserDep,
+        auth_manager: AuthManagerDep,
     ) -> PermittedDagFilter:
-        auth_manager: BaseAuthManager = request.app.state.auth_manager
         authorized_dags: set[str] = 
auth_manager.get_authorized_dag_ids(user=user, method=method)
         return filter_class(authorized_dags)
 
@@ -260,7 +275,7 @@ class PermittedPoolFilter(OrmClause[set[str]]):
 
 def permitted_pool_filter_factory(
     method: ResourceMethod,
-) -> Callable[[Request, BaseUser], PermittedPoolFilter]:
+) -> Callable[[BaseUser, BaseAuthManager], PermittedPoolFilter]:
     """
     Create a callable for Depends in FastAPI that returns a filter of the 
permitted pools for the user.
 
@@ -268,10 +283,9 @@ def permitted_pool_filter_factory(
     """
 
     def depends_permitted_pools_filter(
-        request: Request,
         user: GetUserDep,
+        auth_manager: AuthManagerDep,
     ) -> PermittedPoolFilter:
-        auth_manager: BaseAuthManager = request.app.state.auth_manager
         authorized_pools: set[str] = 
auth_manager.get_authorized_pools(user=user, method=method)
         return PermittedPoolFilter(authorized_pools)
 
@@ -353,7 +367,7 @@ class PermittedConnectionFilter(OrmClause[set[str]]):
 
 def permitted_connection_filter_factory(
     method: ResourceMethod,
-) -> Callable[[Request, BaseUser], PermittedConnectionFilter]:
+) -> Callable[[BaseUser, BaseAuthManager], PermittedConnectionFilter]:
     """
     Create a callable for Depends in FastAPI that returns a filter of the 
permitted connections for the user.
 
@@ -361,10 +375,9 @@ def permitted_connection_filter_factory(
     """
 
     def depends_permitted_connections_filter(
-        request: Request,
         user: GetUserDep,
+        auth_manager: AuthManagerDep,
     ) -> PermittedConnectionFilter:
-        auth_manager: BaseAuthManager = request.app.state.auth_manager
         authorized_connections: set[str] = 
auth_manager.get_authorized_connections(user=user, method=method)
         return PermittedConnectionFilter(authorized_connections)
 
@@ -470,14 +483,13 @@ class PermittedTeamFilter(OrmClause[set[str]]):
         return select.where(Team.name.in_(self.value))
 
 
-def permitted_team_filter_factory() -> Callable[[Request, BaseUser], 
PermittedTeamFilter]:
+def permitted_team_filter_factory() -> Callable[[BaseUser, BaseAuthManager], 
PermittedTeamFilter]:
     """Create a callable for Depends in FastAPI that returns a filter of the 
permitted teams for the user."""
 
     def depends_permitted_teams_filter(
-        request: Request,
         user: GetUserDep,
+        auth_manager: AuthManagerDep,
     ) -> PermittedTeamFilter:
-        auth_manager: BaseAuthManager = request.app.state.auth_manager
         authorized_teams: set[str] = 
auth_manager.get_authorized_teams(user=user, method="GET")
         return PermittedTeamFilter(authorized_teams)
 
@@ -496,7 +508,7 @@ class PermittedVariableFilter(OrmClause[set[str]]):
 
 def permitted_variable_filter_factory(
     method: ResourceMethod,
-) -> Callable[[Request, BaseUser], PermittedVariableFilter]:
+) -> Callable[[BaseUser, BaseAuthManager], PermittedVariableFilter]:
     """
     Create a callable for Depends in FastAPI that returns a filter of the 
permitted variables for the user.
 
@@ -504,10 +516,9 @@ def permitted_variable_filter_factory(
     """
 
     def depends_permitted_variables_filter(
-        request: Request,
         user: GetUserDep,
+        auth_manager: AuthManagerDep,
     ) -> PermittedVariableFilter:
-        auth_manager: BaseAuthManager = request.app.state.auth_manager
         authorized_variables: set[str] = 
auth_manager.get_authorized_variables(user=user, method=method)
         return PermittedVariableFilter(authorized_variables)
 
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/test_security.py 
b/airflow-core/tests/unit/api_fastapi/core_api/test_security.py
index 026dd10fa31..6aea3b23528 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/test_security.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/test_security.py
@@ -484,3 +484,40 @@ class TestFastApiSecurity:
             ],
             user=user,
         )
+
+
+class TestAuthManagerDependency:
+    """Test the auth_manager_from_app dependency function."""
+
+    def test_auth_manager_from_app_returns_instance_from_state(self):
+        """Test that auth_manager_from_app correctly retrieves auth_manager 
from app.state."""
+        from airflow.api_fastapi.core_api.security import auth_manager_from_app
+
+        # Create a mock auth manager
+        mock_auth_manager = Mock()
+
+        # Create a mock request with app.state.auth_manager
+        mock_request = Mock()
+        mock_request.app.state.auth_manager = mock_auth_manager
+
+        # Call the dependency function
+        result = auth_manager_from_app(mock_request)
+
+        # Assert it returns the correct auth manager
+        assert result is mock_auth_manager
+
+    def test_auth_manager_from_app_integration_with_test_client(self, 
test_client):
+        """Test that auth_manager_from_app works with the test client setup."""
+        from airflow.api_fastapi.core_api.security import auth_manager_from_app
+
+        # Create a mock request using the test client's app
+        mock_request = Mock()
+        mock_request.app = test_client.app
+
+        # Get the auth manager
+        auth_manager = auth_manager_from_app(mock_request)
+
+        # Verify it's not None (should be SimpleAuthManager from test fixture)
+        assert auth_manager is not None
+        assert hasattr(auth_manager, "get_url_login")
+        assert hasattr(auth_manager, "get_url_logout")

Reply via email to