This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ab1e40c029a Check teams defined in auth managers exist in DB when spinning up API server (#62527)
ab1e40c029a is described below

commit ab1e40c029a592fa91c34ea549ebbcc947f43bb6
Author: Vincent <[email protected]>
AuthorDate: Thu Feb 26 14:20:34 2026 -0500

    Check teams defined in auth managers exist in DB when spinning up API server (#62527)
---
 .../api_fastapi/auth/managers/base_auth_manager.py | 22 ++++++++++++----
 .../auth/managers/simple/simple_auth_manager.py    |  5 ++++
 .../auth/managers/simple/services/test_login.py    |  5 +++-
 .../managers/simple/test_simple_auth_manager.py    | 13 ++++++++++
 .../auth/managers/test_base_auth_manager.py        | 30 ++++++++++++++++++++++
 .../amazon/aws/auth_manager/aws_auth_manager.py    |  1 +
 6 files changed, 70 insertions(+), 6 deletions(-)

diff --git 
a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py 
b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
index 122111a85e5..a26a4413aad 100644
--- a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
+++ b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
@@ -119,11 +119,15 @@ class BaseAuthManager(Generic[T], LoggingMixin, 
metaclass=ABCMeta):
     """
 
     def init(self) -> None:
-        """
-        Run operations when Airflow is initializing.
-
-        By default, do nothing.
-        """
+        """Run operations when Airflow is initializing."""
+        if conf.getboolean("core", "multi_team"):
+            am_teams = self._get_teams()
+            db_teams = Team.get_all_team_names()
+
+            if not db_teams.issuperset(am_teams):
+                raise ValueError(
+                    f"Teams defined in the auth manager ({am_teams}) are not 
present in the database ({db_teams})."
+                )
 
     @abstractmethod
     def deserialize_user(self, token: dict[str, Any]) -> T:
@@ -798,6 +802,14 @@ class BaseAuthManager(Generic[T], LoggingMixin, 
metaclass=ABCMeta):
         """
         return []
 
+    def _get_teams(self) -> set[str]:
+        """
+        Return the set of teams defined in the auth manager.
+
+        This method is used only when the Airflow environment is configured in 
multi-team mode.
+        """
+        raise NotImplementedError()
+
     @staticmethod
     def get_db_manager() -> str | None:
         """
diff --git 
a/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py
 
b/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py
index 1c546798616..4d7c9850351 100644
--- 
a/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py
+++ 
b/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py
@@ -118,6 +118,7 @@ class 
SimpleAuthManager(BaseAuthManager[SimpleAuthManagerUser]):
             return SimpleAuthManager._get_passwords(file)
 
     def init(self) -> None:
+        super().init()
         is_simple_auth_manager_all_admins = conf.getboolean("core", 
"simple_auth_manager_all_admins")
         if is_simple_auth_manager_all_admins:
             return
@@ -360,6 +361,10 @@ class 
SimpleAuthManager(BaseAuthManager[SimpleAuthManagerUser]):
 
         return app
 
+    def _get_teams(self) -> set[str]:
+        users = self.get_users()
+        return {team for user in users for team in user.teams}
+
     @staticmethod
     def _is_admin(user: SimpleAuthManagerUser) -> bool:
         """Return whether the user has the Admin role."""
diff --git 
a/airflow-core/tests/unit/api_fastapi/auth/managers/simple/services/test_login.py
 
b/airflow-core/tests/unit/api_fastapi/auth/managers/simple/services/test_login.py
index 7c1a3f95677..3318f803cc7 100644
--- 
a/airflow-core/tests/unit/api_fastapi/auth/managers/simple/services/test_login.py
+++ 
b/airflow-core/tests/unit/api_fastapi/auth/managers/simple/services/test_login.py
@@ -25,6 +25,7 @@ from fastapi import HTTPException
 
 from airflow.api_fastapi.auth.managers.simple.datamodels.login import LoginBody
 from airflow.api_fastapi.auth.managers.simple.services.login import 
SimpleAuthManagerLogin
+from airflow.models.team import Team
 
 from tests_common.test_utils.config import conf_vars
 
@@ -48,10 +49,12 @@ class TestLogin:
             (TEST_USER_3, TEST_ROLE_3, ["test", "marketing"]),
         ],
     )
+    @patch.object(Team, "get_all_team_names")
     
@patch("airflow.api_fastapi.auth.managers.simple.services.login.get_auth_manager")
-    def test_create_token(self, get_auth_manager, auth_manager, user, role, 
teams):
+    def test_create_token(self, get_auth_manager, mock_get_all_team_names, 
auth_manager, user, role, teams):
         mock_am = Mock(wraps=auth_manager)
         get_auth_manager.return_value = mock_am
+        mock_get_all_team_names.return_value = {"test", "marketing"}
 
         with conf_vars(
             {
diff --git 
a/airflow-core/tests/unit/api_fastapi/auth/managers/simple/test_simple_auth_manager.py
 
b/airflow-core/tests/unit/api_fastapi/auth/managers/simple/test_simple_auth_manager.py
index 3ce33c3efc4..1e9f195c8b7 100644
--- 
a/airflow-core/tests/unit/api_fastapi/auth/managers/simple/test_simple_auth_manager.py
+++ 
b/airflow-core/tests/unit/api_fastapi/auth/managers/simple/test_simple_auth_manager.py
@@ -354,3 +354,16 @@ class TestSimpleAuthManager:
             user = SimpleAuthManagerUser(username=user_id, role="user")
             result = 
auth_manager.is_authorized_hitl_task(assigned_users=assigned_users, user=user)
             assert result == expected
+
+    @conf_vars(
+        {
+            ("core", "multi_team"): "true",
+            (
+                "core",
+                "simple_auth_manager_users",
+            ): "test1:viewer,test2:viewer:test,test3:viewer:test|marketing",
+        }
+    )
+    def test_get_teams(self, auth_manager):
+        teams = auth_manager._get_teams()
+        assert teams == {"test", "marketing"}
diff --git 
a/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py 
b/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py
index 22efbc426ae..d82b9009c49 100644
--- 
a/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py
+++ 
b/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py
@@ -33,6 +33,9 @@ from 
airflow.api_fastapi.auth.managers.models.resource_details import (
 )
 from airflow.api_fastapi.auth.tokens import JWTGenerator, JWTValidator
 from airflow.api_fastapi.common.types import MenuItem
+from airflow.models.team import Team
+
+from tests_common.test_utils.config import conf_vars
 
 if TYPE_CHECKING:
     from airflow.api_fastapi.auth.managers.base_auth_manager import 
ResourceMethod
@@ -150,6 +153,33 @@ def auth_manager():
 
 
 class TestBaseAuthManager:
+    def test_init_non_multi_team_mode(self, auth_manager):
+        assert auth_manager.init() is None
+
+    @conf_vars({("core", "multi_team"): "True"})
+    @pytest.mark.parametrize(
+        ("auth_manager_teams", "db_teams", "expected"),
+        [
+            ({"teamA", "teamB"}, {"teamA", "teamB"}, True),
+            ({"teamA", "teamB"}, {"teamA", "teamB", "teamC"}, True),
+            (set(), set(), True),
+            ({"teamA", "teamB"}, {"teamA", "teamC"}, False),
+        ],
+    )
+    @patch.object(Team, "get_all_team_names")
+    @patch.object(EmptyAuthManager, "_get_teams")
+    def test_init_multi_team_mode(
+        self, mock_get_teams, mock_get_all_team_names, auth_manager_teams, 
db_teams, expected, auth_manager
+    ):
+        mock_get_teams.return_value = auth_manager_teams
+        mock_get_all_team_names.return_value = db_teams
+
+        if expected:
+            assert auth_manager.init() is None
+        else:
+            with pytest.raises(ValueError, match="Teams defined in the auth 
manager"):
+                auth_manager.init()
+
     def test_get_cli_commands_return_empty_list(self, auth_manager):
         assert auth_manager.get_cli_commands() == []
 
diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
 
b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
index 9d3f2fabe9e..f20f4f0a075 100644
--- 
a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
+++ 
b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
@@ -71,6 +71,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
     """
 
     def init(self) -> None:
+        super().init()
         if not AIRFLOW_V_3_0_PLUS:
             raise AirflowOptionalProviderFeatureException(
                 "AWS auth manager is only compatible with Airflow versions >= 
3.0.0"

Reply via email to