This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ab1e40c029a Check teams defined in auth managers exist in DB when spinning up API server (#62527)
ab1e40c029a is described below
commit ab1e40c029a592fa91c34ea549ebbcc947f43bb6
Author: Vincent <[email protected]>
AuthorDate: Thu Feb 26 14:20:34 2026 -0500
Check teams defined in auth managers exist in DB when spinning up API server (#62527)
---
.../api_fastapi/auth/managers/base_auth_manager.py | 22 ++++++++++++----
.../auth/managers/simple/simple_auth_manager.py | 5 ++++
.../auth/managers/simple/services/test_login.py | 5 +++-
.../managers/simple/test_simple_auth_manager.py | 13 ++++++++++
.../auth/managers/test_base_auth_manager.py | 30 ++++++++++++++++++++++
.../amazon/aws/auth_manager/aws_auth_manager.py | 1 +
6 files changed, 70 insertions(+), 6 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
index 122111a85e5..a26a4413aad 100644
--- a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
+++ b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
@@ -119,11 +119,15 @@ class BaseAuthManager(Generic[T], LoggingMixin, metaclass=ABCMeta):
     """
 
     def init(self) -> None:
-        """
-        Run operations when Airflow is initializing.
-
-        By default, do nothing.
-        """
+        """Run operations when Airflow is initializing."""
+        if conf.getboolean("core", "multi_team"):
+            am_teams = self._get_teams()
+            db_teams = Team.get_all_team_names()
+
+            if not db_teams.issuperset(am_teams):
+                raise ValueError(
+                    f"Teams defined in the auth manager ({am_teams}) are not present in the database ({db_teams})."
+                )
 
     @abstractmethod
     def deserialize_user(self, token: dict[str, Any]) -> T:
@@ -798,6 +802,14 @@ class BaseAuthManager(Generic[T], LoggingMixin, metaclass=ABCMeta):
         """
         return []
 
+    def _get_teams(self) -> set[str]:
+        """
+        Return the set of teams defined in the auth manager.
+
+        This method is used only when the Airflow environment is configured in multi-team mode.
+        """
+        raise NotImplementedError()
+
     @staticmethod
     def get_db_manager() -> str | None:
         """
diff --git a/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py b/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py
index 1c546798616..4d7c9850351 100644
--- a/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py
+++ b/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py
@@ -118,6 +118,7 @@ class SimpleAuthManager(BaseAuthManager[SimpleAuthManagerUser]):
             return SimpleAuthManager._get_passwords(file)
 
     def init(self) -> None:
+        super().init()
         is_simple_auth_manager_all_admins = conf.getboolean("core", "simple_auth_manager_all_admins")
         if is_simple_auth_manager_all_admins:
             return
@@ -360,6 +361,10 @@ class SimpleAuthManager(BaseAuthManager[SimpleAuthManagerUser]):
 
         return app
 
+    def _get_teams(self) -> set[str]:
+        users = self.get_users()
+        return {team for user in users for team in user.teams}
+
     @staticmethod
     def _is_admin(user: SimpleAuthManagerUser) -> bool:
         """Return whether the user has the Admin role."""
diff --git a/airflow-core/tests/unit/api_fastapi/auth/managers/simple/services/test_login.py b/airflow-core/tests/unit/api_fastapi/auth/managers/simple/services/test_login.py
index 7c1a3f95677..3318f803cc7 100644
--- a/airflow-core/tests/unit/api_fastapi/auth/managers/simple/services/test_login.py
+++ b/airflow-core/tests/unit/api_fastapi/auth/managers/simple/services/test_login.py
@@ -25,6 +25,7 @@ from fastapi import HTTPException
 
 from airflow.api_fastapi.auth.managers.simple.datamodels.login import LoginBody
 from airflow.api_fastapi.auth.managers.simple.services.login import SimpleAuthManagerLogin
+from airflow.models.team import Team
 
 from tests_common.test_utils.config import conf_vars
 
@@ -48,10 +49,12 @@ class TestLogin:
             (TEST_USER_3, TEST_ROLE_3, ["test", "marketing"]),
         ],
     )
+    @patch.object(Team, "get_all_team_names")
     @patch("airflow.api_fastapi.auth.managers.simple.services.login.get_auth_manager")
-    def test_create_token(self, get_auth_manager, auth_manager, user, role, teams):
+    def test_create_token(self, get_auth_manager, mock_get_all_team_names, auth_manager, user, role, teams):
         mock_am = Mock(wraps=auth_manager)
         get_auth_manager.return_value = mock_am
+        mock_get_all_team_names.return_value = {"test", "marketing"}
 
         with conf_vars(
             {
diff --git a/airflow-core/tests/unit/api_fastapi/auth/managers/simple/test_simple_auth_manager.py b/airflow-core/tests/unit/api_fastapi/auth/managers/simple/test_simple_auth_manager.py
index 3ce33c3efc4..1e9f195c8b7 100644
--- a/airflow-core/tests/unit/api_fastapi/auth/managers/simple/test_simple_auth_manager.py
+++ b/airflow-core/tests/unit/api_fastapi/auth/managers/simple/test_simple_auth_manager.py
@@ -354,3 +354,16 @@ class TestSimpleAuthManager:
         user = SimpleAuthManagerUser(username=user_id, role="user")
         result = auth_manager.is_authorized_hitl_task(assigned_users=assigned_users, user=user)
         assert result == expected
+
+    @conf_vars(
+        {
+            ("core", "multi_team"): "true",
+            (
+                "core",
+                "simple_auth_manager_users",
+            ): "test1:viewer,test2:viewer:test,test3:viewer:test|marketing",
+        }
+    )
+    def test_get_teams(self, auth_manager):
+        teams = auth_manager._get_teams()
+        assert teams == {"test", "marketing"}
diff --git a/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py b/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py
index 22efbc426ae..d82b9009c49 100644
--- a/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py
+++ b/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py
@@ -33,6 +33,9 @@ from airflow.api_fastapi.auth.managers.models.resource_details import (
 )
 from airflow.api_fastapi.auth.tokens import JWTGenerator, JWTValidator
 from airflow.api_fastapi.common.types import MenuItem
+from airflow.models.team import Team
+
+from tests_common.test_utils.config import conf_vars
 
 if TYPE_CHECKING:
     from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
@@ -150,6 +153,33 @@ def auth_manager():
 
 
 class TestBaseAuthManager:
+    def test_init_non_multi_team_mode(self, auth_manager):
+        assert auth_manager.init() is None
+
+    @conf_vars({("core", "multi_team"): "True"})
+    @pytest.mark.parametrize(
+        ("auth_manager_teams", "db_teams", "expected"),
+        [
+            ({"teamA", "teamB"}, {"teamA", "teamB"}, True),
+            ({"teamA", "teamB"}, {"teamA", "teamB", "teamC"}, True),
+            (set(), set(), True),
+            ({"teamA", "teamB"}, {"teamA", "teamC"}, False),
+        ],
+    )
+    @patch.object(Team, "get_all_team_names")
+    @patch.object(EmptyAuthManager, "_get_teams")
+    def test_init_multi_team_mode(
+        self, mock_get_teams, mock_get_all_team_names, auth_manager_teams, db_teams, expected, auth_manager
+    ):
+        mock_get_teams.return_value = auth_manager_teams
+        mock_get_all_team_names.return_value = db_teams
+
+        if expected:
+            assert auth_manager.init() is None
+        else:
+            with pytest.raises(ValueError, match="Teams defined in the auth manager"):
+                auth_manager.init()
+
     def test_get_cli_commands_return_empty_list(self, auth_manager):
         assert auth_manager.get_cli_commands() == []
 
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
index 9d3f2fabe9e..f20f4f0a075 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
@@ -71,6 +71,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
     """
 
     def init(self) -> None:
+        super().init()
         if not AIRFLOW_V_3_0_PLUS:
             raise AirflowOptionalProviderFeatureException(
                 "AWS auth manager is only compatible with Airflow versions >= 3.0.0"