This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1242e07c0fe Refactor AuthManager from app.state to FastAPI dependency (#57665)
1242e07c0fe is described below
commit 1242e07c0fe4577059c841c97ba09063ec391247
Author: LIU ZHE YOU <[email protected]>
AuthorDate: Tue Nov 4 23:54:29 2025 +0800
Refactor AuthManager from app.state to FastAPI dependency (#57665)
- Create new common/auth_manager.py module with AuthManagerDep
- Update security.py to use AuthManagerDep instead of
request.app.state.auth_manager
- Update auth.py routes to use AuthManagerDep
- Remove Request parameter where no longer needed
- Follow the same dependency injection pattern as DagBag
Add unit tests for auth_manager dependency injection
- Create test_auth_manager.py to test the new dependency
- Verify auth_manager_from_app correctly retrieves from app.state
- Test integration with existing test client fixture
Fix linting issues with ruff
- Move BaseAuthManager import out of TYPE_CHECKING block
- Fix import ordering in security.py
- Remove unused pytest import from test
- Remove trailing whitespace
- Format code with ruff format
Move auth_manager dependency to security.py module
- Move auth_manager_from_app and AuthManagerDep from common/auth_manager.py
to core_api/security.py
- Update import in routes/public/auth.py to use security module
- Move tests from common/test_auth_manager.py to core_api/test_security.py
- Delete now-unused common/auth_manager.py and common/test_auth_manager.py
- Import BaseAuthManager directly in security.py (not in TYPE_CHECKING)
Co-authored-by: copilot-swe-agent[bot] <[email protected]>
---
.../api_fastapi/core_api/routes/public/auth.py | 12 +++---
.../src/airflow/api_fastapi/core_api/security.py | 45 ++++++++++++++--------
.../unit/api_fastapi/core_api/test_security.py | 37 ++++++++++++++++++
3 files changed, 71 insertions(+), 23 deletions(-)
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py
index b8f6d204d2e..59610e0fc94 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py
@@ -21,7 +21,7 @@ from fastapi.responses import RedirectResponse
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.openapi.exceptions import
create_openapi_http_exception_doc
-from airflow.api_fastapi.core_api.security import is_safe_url
+from airflow.api_fastapi.core_api.security import AuthManagerDep, is_safe_url
auth_router = AirflowRouter(tags=["Login"], prefix="/auth")
@@ -30,9 +30,9 @@ auth_router = AirflowRouter(tags=["Login"], prefix="/auth")
"/login",
responses=create_openapi_http_exception_doc([status.HTTP_307_TEMPORARY_REDIRECT]),
)
-def login(request: Request, next: None | str = None) -> RedirectResponse:
+def login(request: Request, auth_manager: AuthManagerDep, next: None | str =
None) -> RedirectResponse:
"""Redirect to the login URL depending on the AuthManager configured."""
- login_url = request.app.state.auth_manager.get_url_login()
+ login_url = auth_manager.get_url_login()
if next and not is_safe_url(next, request=request):
raise HTTPException(status_code=400, detail="Invalid or unsafe next
URL")
@@ -47,11 +47,11 @@ def login(request: Request, next: None | str = None) ->
RedirectResponse:
"/logout",
responses=create_openapi_http_exception_doc([status.HTTP_307_TEMPORARY_REDIRECT]),
)
-def logout(request: Request, next: None | str = None) -> RedirectResponse:
+def logout(auth_manager: AuthManagerDep, next: None | str = None) ->
RedirectResponse:
"""Logout the user."""
- logout_url = request.app.state.auth_manager.get_url_logout()
+ logout_url = auth_manager.get_url_logout()
if not logout_url:
- logout_url = request.app.state.auth_manager.get_url_login()
+ logout_url = auth_manager.get_url_login()
return RedirectResponse(logout_url)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/security.py
b/airflow-core/src/airflow/api_fastapi/core_api/security.py
index 5e7d676bf9d..05ad624353b 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/security.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/security.py
@@ -27,7 +27,10 @@ from jwt import ExpiredSignatureError, InvalidTokenError
from pydantic import NonNegativeInt
from airflow.api_fastapi.app import get_auth_manager
-from airflow.api_fastapi.auth.managers.base_auth_manager import
COOKIE_NAME_JWT_TOKEN
+from airflow.api_fastapi.auth.managers.base_auth_manager import (
+ COOKIE_NAME_JWT_TOKEN,
+ BaseAuthManager,
+)
from airflow.api_fastapi.auth.managers.models.base_user import BaseUser
from airflow.api_fastapi.auth.managers.models.batch_apis import (
IsAuthorizedConnectionRequest,
@@ -70,7 +73,20 @@ if TYPE_CHECKING:
from fastapi.security import HTTPAuthorizationCredentials
from sqlalchemy.sql import Select
- from airflow.api_fastapi.auth.managers.base_auth_manager import
BaseAuthManager, ResourceMethod
+ from airflow.api_fastapi.auth.managers.base_auth_manager import
ResourceMethod
+
+
+def auth_manager_from_app(request: Request) -> BaseAuthManager:
+ """
+ FastAPI dependency resolver that returns the shared AuthManager instance
from app.state.
+
+ This ensures that all API routes using AuthManager via dependency
injection receive the same
+ singleton instance that was initialized at app startup.
+ """
+ return request.app.state.auth_manager
+
+
+AuthManagerDep = Annotated[BaseAuthManager, Depends(auth_manager_from_app)]
auth_description = (
"To authenticate Airflow API requests, clients must include a JWT (JSON
Web Token) in "
@@ -196,7 +212,7 @@ class PermittedTagFilter(PermittedDagFilter):
def permitted_dag_filter_factory(
method: ResourceMethod, filter_class=PermittedDagFilter
-) -> Callable[[Request, BaseUser], PermittedDagFilter]:
+) -> Callable[[BaseUser, BaseAuthManager], PermittedDagFilter]:
"""
Create a callable for Depends in FastAPI that returns a filter of the
permitted dags for the user.
@@ -205,10 +221,9 @@ def permitted_dag_filter_factory(
"""
def depends_permitted_dags_filter(
- request: Request,
user: GetUserDep,
+ auth_manager: AuthManagerDep,
) -> PermittedDagFilter:
- auth_manager: BaseAuthManager = request.app.state.auth_manager
authorized_dags: set[str] =
auth_manager.get_authorized_dag_ids(user=user, method=method)
return filter_class(authorized_dags)
@@ -260,7 +275,7 @@ class PermittedPoolFilter(OrmClause[set[str]]):
def permitted_pool_filter_factory(
method: ResourceMethod,
-) -> Callable[[Request, BaseUser], PermittedPoolFilter]:
+) -> Callable[[BaseUser, BaseAuthManager], PermittedPoolFilter]:
"""
Create a callable for Depends in FastAPI that returns a filter of the
permitted pools for the user.
@@ -268,10 +283,9 @@ def permitted_pool_filter_factory(
"""
def depends_permitted_pools_filter(
- request: Request,
user: GetUserDep,
+ auth_manager: AuthManagerDep,
) -> PermittedPoolFilter:
- auth_manager: BaseAuthManager = request.app.state.auth_manager
authorized_pools: set[str] =
auth_manager.get_authorized_pools(user=user, method=method)
return PermittedPoolFilter(authorized_pools)
@@ -353,7 +367,7 @@ class PermittedConnectionFilter(OrmClause[set[str]]):
def permitted_connection_filter_factory(
method: ResourceMethod,
-) -> Callable[[Request, BaseUser], PermittedConnectionFilter]:
+) -> Callable[[BaseUser, BaseAuthManager], PermittedConnectionFilter]:
"""
Create a callable for Depends in FastAPI that returns a filter of the
permitted connections for the user.
@@ -361,10 +375,9 @@ def permitted_connection_filter_factory(
"""
def depends_permitted_connections_filter(
- request: Request,
user: GetUserDep,
+ auth_manager: AuthManagerDep,
) -> PermittedConnectionFilter:
- auth_manager: BaseAuthManager = request.app.state.auth_manager
authorized_connections: set[str] =
auth_manager.get_authorized_connections(user=user, method=method)
return PermittedConnectionFilter(authorized_connections)
@@ -470,14 +483,13 @@ class PermittedTeamFilter(OrmClause[set[str]]):
return select.where(Team.name.in_(self.value))
-def permitted_team_filter_factory() -> Callable[[Request, BaseUser],
PermittedTeamFilter]:
+def permitted_team_filter_factory() -> Callable[[BaseUser, BaseAuthManager],
PermittedTeamFilter]:
"""Create a callable for Depends in FastAPI that returns a filter of the
permitted teams for the user."""
def depends_permitted_teams_filter(
- request: Request,
user: GetUserDep,
+ auth_manager: AuthManagerDep,
) -> PermittedTeamFilter:
- auth_manager: BaseAuthManager = request.app.state.auth_manager
authorized_teams: set[str] =
auth_manager.get_authorized_teams(user=user, method="GET")
return PermittedTeamFilter(authorized_teams)
@@ -496,7 +508,7 @@ class PermittedVariableFilter(OrmClause[set[str]]):
def permitted_variable_filter_factory(
method: ResourceMethod,
-) -> Callable[[Request, BaseUser], PermittedVariableFilter]:
+) -> Callable[[BaseUser, BaseAuthManager], PermittedVariableFilter]:
"""
Create a callable for Depends in FastAPI that returns a filter of the
permitted variables for the user.
@@ -504,10 +516,9 @@ def permitted_variable_filter_factory(
"""
def depends_permitted_variables_filter(
- request: Request,
user: GetUserDep,
+ auth_manager: AuthManagerDep,
) -> PermittedVariableFilter:
- auth_manager: BaseAuthManager = request.app.state.auth_manager
authorized_variables: set[str] =
auth_manager.get_authorized_variables(user=user, method=method)
return PermittedVariableFilter(authorized_variables)
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/test_security.py
b/airflow-core/tests/unit/api_fastapi/core_api/test_security.py
index 026dd10fa31..6aea3b23528 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/test_security.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/test_security.py
@@ -484,3 +484,40 @@ class TestFastApiSecurity:
],
user=user,
)
+
+
+class TestAuthManagerDependency:
+ """Test the auth_manager_from_app dependency function."""
+
+ def test_auth_manager_from_app_returns_instance_from_state(self):
+ """Test that auth_manager_from_app correctly retrieves auth_manager
from app.state."""
+ from airflow.api_fastapi.core_api.security import auth_manager_from_app
+
+ # Create a mock auth manager
+ mock_auth_manager = Mock()
+
+ # Create a mock request with app.state.auth_manager
+ mock_request = Mock()
+ mock_request.app.state.auth_manager = mock_auth_manager
+
+ # Call the dependency function
+ result = auth_manager_from_app(mock_request)
+
+ # Assert it returns the correct auth manager
+ assert result is mock_auth_manager
+
+ def test_auth_manager_from_app_integration_with_test_client(self,
test_client):
+ """Test that auth_manager_from_app works with the test client setup."""
+ from airflow.api_fastapi.core_api.security import auth_manager_from_app
+
+ # Create a mock request using the test client's app
+ mock_request = Mock()
+ mock_request.app = test_client.app
+
+ # Get the auth manager
+ auth_manager = auth_manager_from_app(mock_request)
+
+ # Verify it's not None (should be SimpleAuthManager from test fixture)
+ assert auth_manager is not None
+ assert hasattr(auth_manager, "get_url_login")
+ assert hasattr(auth_manager, "get_url_logout")