This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4ed9873db7e Keycloak auth manager: enforce team‑scoped authorization (AIP‑67) (#61351)
4ed9873db7e is described below

commit 4ed9873db7efa26ed03e84ffd4dcc40b204d0fc9
Author: Mathieu Monet <[email protected]>
AuthorDate: Wed Feb 18 16:45:36 2026 +0100

    Keycloak auth manager: enforce team‑scoped authorization (AIP‑67) (#61351)
    
    * feat: update keycloak manager for multiteam setup
    
    * feat: refactor authorization checks for multi-team setup in KeycloakAuthManager
    
    * feat: refactor team name retrieval for multi-team support in KeycloakAuthManager
    
    * feat: update KeycloakAuthManager to retrieve multi-team configuration dynamically
    Regenerate Keycloak OpenAPI spec
    
    * Apply suggestion from @vincbeck
    
    Co-authored-by: Vincent <[email protected]>
    
    * feat: added error handling for missing team_name in multi-team mode
    
    * generalize team-scoped auth cases and add compat guards
    
    * test(keycloak): gate multiteam auth manager tests on Airflow 3.2+
    
    * fix(keycloak): allow global team-scoped resources in multi-team mode
    
    * refactor(keycloak): inline resource name resolution in auth manager
    
    * test(keycloak): align multiteam skip reasons with 3.2.0 gate
    
    * Update providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
    
    Co-authored-by: Vincent <[email protected]>
    
    * Made type of _get_team_name explicit
    
    * refactored/renamed multiteam tests
    
    * Updated test style to enhance readability
    
    * moved auth manager patch from a with block to a decorator
    
    ---------
    
    Co-authored-by: Vincent <[email protected]>
---
 .../keycloak/auth_manager/keycloak_auth_manager.py |  62 +++++-
 .../auth_manager/test_keycloak_auth_manager.py     | 246 ++++++++++++++++++++-
 2 files changed, 291 insertions(+), 17 deletions(-)

diff --git 
a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
 
b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
index 4c844987d59..d28c04da598 100644
--- 
a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
+++ 
b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
@@ -43,7 +43,12 @@ except ImportError:
 
 from airflow.api_fastapi.common.types import MenuItem
 from airflow.cli.cli_config import CLICommand
-from airflow.providers.common.compat.sdk import AirflowException, conf
+
+try:
+    from airflow.providers.common.compat.sdk import AirflowException, conf
+except ModuleNotFoundError:
+    from airflow.configuration import conf
+    from airflow.exceptions import AirflowException
 from airflow.providers.keycloak.auth_manager.constants import (
     CONF_CLIENT_ID_KEY,
     CONF_CLIENT_SECRET_KEY,
@@ -76,6 +81,14 @@ if TYPE_CHECKING:
 log = logging.getLogger(__name__)
 
 RESOURCE_ID_ATTRIBUTE_NAME = "resource_id"
+TEAM_SCOPED_RESOURCES = frozenset(
+    {
+        KeycloakResource.DAG,
+        KeycloakResource.CONNECTION,
+        KeycloakResource.POOL,
+        KeycloakResource.VARIABLE,
+    }
+)
 
 
 class KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
@@ -184,10 +197,7 @@ class 
KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
     ) -> bool:
         config_section = details.section if details else None
         return self._is_authorized(
-            method=method,
-            resource_type=KeycloakResource.CONFIGURATION,
-            user=user,
-            resource_id=config_section,
+            method=method, resource_type=KeycloakResource.CONFIGURATION, 
user=user, resource_id=config_section
         )
 
     def is_authorized_connection(
@@ -198,8 +208,13 @@ class 
KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
         details: ConnectionDetails | None = None,
     ) -> bool:
         connection_id = details.conn_id if details else None
+        team_name = self._get_team_name(details)
         return self._is_authorized(
-            method=method, resource_type=KeycloakResource.CONNECTION, 
user=user, resource_id=connection_id
+            method=method,
+            resource_type=KeycloakResource.CONNECTION,
+            user=user,
+            resource_id=connection_id,
+            team_name=team_name,
         )
 
     def is_authorized_dag(
@@ -211,12 +226,14 @@ class 
KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
         details: DagDetails | None = None,
     ) -> bool:
         dag_id = details.id if details else None
+        team_name = self._get_team_name(details)
         access_entity_str = access_entity.value if access_entity else None
         return self._is_authorized(
             method=method,
             resource_type=KeycloakResource.DAG,
             user=user,
             resource_id=dag_id,
+            team_name=team_name,
             attributes={"dag_entity": access_entity_str},
         )
 
@@ -262,16 +279,26 @@ class 
KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
         self, *, method: ResourceMethod, user: KeycloakAuthManagerUser, 
details: VariableDetails | None = None
     ) -> bool:
         variable_key = details.key if details else None
+        team_name = self._get_team_name(details)
         return self._is_authorized(
-            method=method, resource_type=KeycloakResource.VARIABLE, user=user, 
resource_id=variable_key
+            method=method,
+            resource_type=KeycloakResource.VARIABLE,
+            user=user,
+            resource_id=variable_key,
+            team_name=team_name,
         )
 
     def is_authorized_pool(
         self, *, method: ResourceMethod, user: KeycloakAuthManagerUser, 
details: PoolDetails | None = None
     ) -> bool:
         pool_name = details.name if details else None
+        team_name = self._get_team_name(details)
         return self._is_authorized(
-            method=method, resource_type=KeycloakResource.POOL, user=user, 
resource_id=pool_name
+            method=method,
+            resource_type=KeycloakResource.POOL,
+            user=user,
+            resource_id=pool_name,
+            team_name=team_name,
         )
 
     def is_authorized_view(self, *, access_view: AccessView, user: 
KeycloakAuthManagerUser) -> bool:
@@ -356,6 +383,7 @@ class 
KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
         resource_type: KeycloakResource,
         user: KeycloakAuthManagerUser,
         resource_id: str | None = None,
+        team_name: str | None = None,
         attributes: dict[str, str | None] | None = None,
     ) -> bool:
         client_id = conf.get(CONF_SECTION_NAME, CONF_CLIENT_ID_KEY)
@@ -368,9 +396,19 @@ class 
KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
         elif method == "GET":
             method = "LIST"
 
+        if (
+            team_name
+            and conf.getboolean("core", "multi_team", fallback=False)
+            and resource_type in TEAM_SCOPED_RESOURCES
+        ):
+            resource_name = f"{resource_type.value}:{team_name}"
+        else:
+            resource_name = resource_type.value
+        permission = f"{resource_name}#{method}"
+
         resp = self.http_session.post(
             self._get_token_url(server_url, realm),
-            data=self._get_payload(client_id, 
f"{resource_type.value}#{method}", context_attributes),
+            data=self._get_payload(client_id, permission, context_attributes),
             headers=self._get_headers(user.access_token),
             timeout=5,
         )
@@ -425,6 +463,12 @@ class 
KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
         # Normalize server_url to avoid double slashes (required for Keycloak 
26.4+ strict path validation).
         return 
f"{server_url.rstrip('/')}/realms/{realm}/protocol/openid-connect/token"
 
+    @staticmethod
+    def _get_team_name(
+        details: ConnectionDetails | DagDetails | PoolDetails | 
VariableDetails | None,
+    ) -> str | None:
+        return getattr(details, "team_name", None) if details else None
+
     @staticmethod
     def _get_payload(client_id: str, permission: str, attributes: dict[str, 
str] | None = None):
         payload: dict[str, Any] = {
diff --git 
a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
 
b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
index 0a234fc1f02..ef90a603d53 100644
--- 
a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
+++ 
b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import base64
 import json
 from contextlib import ExitStack
 from unittest.mock import Mock, patch
@@ -38,7 +39,11 @@ from 
airflow.api_fastapi.auth.managers.models.resource_details import (
 )
 from airflow.api_fastapi.common.types import MenuItem
 from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.common.compat.sdk import AirflowException
+
+try:
+    from airflow.providers.common.compat.sdk import AirflowException
+except ModuleNotFoundError:
+    from airflow.exceptions import AirflowException
 from airflow.providers.keycloak.auth_manager.constants import (
     CONF_CLIENT_ID_KEY,
     CONF_CLIENT_SECRET_KEY,
@@ -53,7 +58,14 @@ from 
airflow.providers.keycloak.auth_manager.keycloak_auth_manager import (
 from airflow.providers.keycloak.auth_manager.user import 
KeycloakAuthManagerUser
 
 from tests_common.test_utils.config import conf_vars
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_7_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_7_PLUS, 
AIRFLOW_V_3_2_PLUS
+
+
+def _build_access_token(payload: dict[str, object]) -> str:
+    header = {"alg": "none", "typ": "JWT"}
+    header_b64 = 
base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
+    payload_b64 = 
base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
+    return f"{header_b64}.{payload_b64}."
 
 
 @pytest.fixture
@@ -69,6 +81,20 @@ def auth_manager():
         yield KeycloakAuthManager()
 
 
+@pytest.fixture
+def auth_manager_multi_team():
+    with conf_vars(
+        {
+            ("core", "multi_team"): "True",
+            (CONF_SECTION_NAME, CONF_CLIENT_ID_KEY): "client_id",
+            (CONF_SECTION_NAME, CONF_CLIENT_SECRET_KEY): "client_secret",
+            (CONF_SECTION_NAME, CONF_REALM_KEY): "realm",
+            (CONF_SECTION_NAME, CONF_SERVER_URL_KEY): "server_url",
+        }
+    ):
+        yield KeycloakAuthManager()
+
+
 @pytest.fixture
 def user():
     user = Mock()
@@ -278,7 +304,7 @@ class TestKeycloakAuthManager:
 
         token_url = auth_manager._get_token_url("server_url", "realm")
         payload = auth_manager._get_payload("client_id", permission, 
attributes)
-        headers = auth_manager._get_headers("access_token")
+        headers = auth_manager._get_headers(user.access_token)
         auth_manager.http_session.post.assert_called_once_with(
             token_url, data=payload, headers=headers, timeout=5
         )
@@ -411,12 +437,216 @@ class TestKeycloakAuthManager:
 
         token_url = auth_manager._get_token_url("server_url", "realm")
         payload = auth_manager._get_payload("client_id", permission, 
attributes)
-        headers = auth_manager._get_headers("access_token")
+        headers = auth_manager._get_headers(user.access_token)
         auth_manager.http_session.post.assert_called_once_with(
             token_url, data=payload, headers=headers, timeout=5
         )
         assert result == expected
 
+    @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="team_name not 
supported before Airflow 3.2.0")
+    @pytest.mark.parametrize(
+        ("function", "method", "details_cls", "details_kwargs", "permission"),
+        [
+            ("is_authorized_dag", "GET", DagDetails, {"id": "test", 
"team_name": "team-a"}, "Dag#GET"),
+            (
+                "is_authorized_connection",
+                "DELETE",
+                ConnectionDetails,
+                {"conn_id": "test", "team_name": "team-a"},
+                "Connection#DELETE",
+            ),
+            (
+                "is_authorized_variable",
+                "PUT",
+                VariableDetails,
+                {"key": "test", "team_name": "team-a"},
+                "Variable#PUT",
+            ),
+            ("is_authorized_pool", "POST", PoolDetails, {"name": "test", 
"team_name": "team-a"}, "Pool#POST"),
+        ],
+    )
+    def test_team_name_ignored_when_multi_team_disabled(
+        self, auth_manager, user, function, method, details_cls, 
details_kwargs, permission
+    ):
+        details = details_cls(**details_kwargs)
+        mock_response = Mock()
+        mock_response.status_code = 200
+        auth_manager.http_session.post = Mock(return_value=mock_response)
+
+        getattr(auth_manager, function)(method=method, user=user, 
details=details)
+
+        actual_permission = 
auth_manager.http_session.post.call_args.kwargs["data"]["permission"]
+        assert actual_permission == permission
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="team_name not 
supported before Airflow 3.2.0")
+    @pytest.mark.parametrize(
+        ("function", "details_cls", "details_kwargs", "permission"),
+        [
+            ("is_authorized_dag", DagDetails, {"id": "test", "team_name": 
"team-a"}, "Dag:team-a#GET"),
+            (
+                "is_authorized_connection",
+                ConnectionDetails,
+                {"conn_id": "test", "team_name": "team-a"},
+                "Connection:team-a#GET",
+            ),
+            (
+                "is_authorized_variable",
+                VariableDetails,
+                {"key": "test", "team_name": "team-a"},
+                "Variable:team-a#GET",
+            ),
+            ("is_authorized_pool", PoolDetails, {"name": "test", "team_name": 
"team-a"}, "Pool:team-a#GET"),
+        ],
+    )
+    def test_with_team_name_uses_team_scoped_permission(
+        self, auth_manager_multi_team, user, function, details_cls, 
details_kwargs, permission
+    ):
+        details = details_cls(**details_kwargs)
+        mock_response = Mock()
+        mock_response.status_code = 200
+        auth_manager_multi_team.http_session.post = 
Mock(return_value=mock_response)
+
+        getattr(auth_manager_multi_team, function)(method="GET", user=user, 
details=details)
+
+        actual_permission = 
auth_manager_multi_team.http_session.post.call_args.kwargs["data"]["permission"]
+        assert actual_permission == permission
+
+    @pytest.mark.parametrize(
+        ("function", "details", "permission"),
+        [
+            ("is_authorized_dag", DagDetails(id="test"), "Dag#GET"),
+            ("is_authorized_connection", ConnectionDetails(conn_id="test"), 
"Connection#GET"),
+            ("is_authorized_variable", VariableDetails(key="test"), 
"Variable#GET"),
+            ("is_authorized_pool", PoolDetails(name="test"), "Pool#GET"),
+        ],
+    )
+    def test_without_team_name_uses_global_permission(
+        self, auth_manager_multi_team, user, function, details, permission
+    ):
+        mock_response = Mock()
+        mock_response.status_code = 200
+        auth_manager_multi_team.http_session.post = 
Mock(return_value=mock_response)
+
+        getattr(auth_manager_multi_team, function)(method="GET", user=user, 
details=details)
+
+        actual_permission = 
auth_manager_multi_team.http_session.post.call_args.kwargs["data"]["permission"]
+        assert actual_permission == permission
+
+    @pytest.mark.parametrize(
+        ("function", "permission"),
+        [
+            ("is_authorized_dag", "Dag#LIST"),
+            ("is_authorized_connection", "Connection#LIST"),
+            ("is_authorized_variable", "Variable#LIST"),
+            ("is_authorized_pool", "Pool#LIST"),
+        ],
+    )
+    def test_list_without_team_name_uses_global_permission(
+        self, auth_manager_multi_team, user, function, permission
+    ):
+        mock_response = Mock()
+        mock_response.status_code = 200
+        auth_manager_multi_team.http_session.post = 
Mock(return_value=mock_response)
+
+        getattr(auth_manager_multi_team, function)(method="GET", user=user)
+
+        actual_permission = 
auth_manager_multi_team.http_session.post.call_args.kwargs["data"]["permission"]
+        assert actual_permission == permission
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="team_name not 
supported before Airflow 3.2.0")
+    @pytest.mark.parametrize(
+        ("function", "details_cls", "details_kwargs", "permission"),
+        [
+            ("is_authorized_dag", DagDetails, {"team_name": "team-a"}, 
"Dag:team-a#LIST"),
+            (
+                "is_authorized_connection",
+                ConnectionDetails,
+                {"team_name": "team-a"},
+                "Connection:team-a#LIST",
+            ),
+            ("is_authorized_variable", VariableDetails, {"team_name": 
"team-a"}, "Variable:team-a#LIST"),
+            ("is_authorized_pool", PoolDetails, {"team_name": "team-a"}, 
"Pool:team-a#LIST"),
+        ],
+    )
+    def test_list_with_team_name_uses_team_scoped_permission(
+        self, auth_manager_multi_team, user, function, details_cls, 
details_kwargs, permission
+    ):
+        details = details_cls(**details_kwargs)
+        user.access_token = _build_access_token({"groups": ["team-a"]})
+        mock_response = Mock()
+        mock_response.status_code = 200
+        auth_manager_multi_team.http_session.post = 
Mock(return_value=mock_response)
+
+        getattr(auth_manager_multi_team, function)(method="GET", user=user, 
details=details)
+
+        actual_permission = 
auth_manager_multi_team.http_session.post.call_args.kwargs["data"]["permission"]
+        assert actual_permission == permission
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="team_name not 
supported before Airflow 3.2.0")
+    @patch.object(KeycloakAuthManager, "is_authorized_dag", return_value=False)
+    def test_filter_authorized_dag_ids_team_mismatch(self, mock_is_authorized, 
auth_manager_multi_team, user):
+        result = auth_manager_multi_team.filter_authorized_dag_ids(
+            dag_ids={"dag-a"}, user=user, team_name="team-b"
+        )
+
+        mock_is_authorized.assert_called_once()
+        assert result == set()
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="team_name not 
supported before Airflow 3.2.0")
+    @patch.object(KeycloakAuthManager, "is_authorized_dag", return_value=True)
+    def test_filter_authorized_dag_ids_team_match(self, mock_is_authorized, 
auth_manager_multi_team, user):
+        result = auth_manager_multi_team.filter_authorized_dag_ids(
+            dag_ids={"dag-a"}, user=user, team_name="team-a"
+        )
+
+        mock_is_authorized.assert_called_once()
+        assert result == {"dag-a"}
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="team_name not 
supported before Airflow 3.2.0")
+    @patch.object(KeycloakAuthManager, "is_authorized_pool", 
return_value=False)
+    def test_filter_authorized_pools_no_team_returns_empty(
+        self, mock_is_authorized, auth_manager_multi_team, user
+    ):
+        result = auth_manager_multi_team.filter_authorized_pools(
+            pool_names={"pool-a"}, user=user, team_name=None
+        )
+
+        mock_is_authorized.assert_called_once()
+        assert result == set()
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="team_name not 
supported before Airflow 3.2.0")
+    @pytest.mark.parametrize(
+        ("function", "details_cls", "details_kwargs"),
+        [
+            ("is_authorized_dag", DagDetails, {"team_name": "team-b"}),
+            ("is_authorized_connection", ConnectionDetails, {"team_name": 
"team-b"}),
+            ("is_authorized_variable", VariableDetails, {"team_name": 
"team-b"}),
+            ("is_authorized_pool", PoolDetails, {"team_name": "team-b"}),
+        ],
+    )
+    def test_list_with_mismatched_team_delegates_to_keycloak(
+        self, auth_manager_multi_team, user, function, details_cls, 
details_kwargs
+    ):
+        details = details_cls(**details_kwargs)
+        mock_response = Mock()
+        mock_response.status_code = 403
+        auth_manager_multi_team.http_session.post = 
Mock(return_value=mock_response)
+
+        result = getattr(auth_manager_multi_team, function)(method="GET", 
user=user, details=details)
+
+        auth_manager_multi_team.http_session.post.assert_called_once()
+        assert result is False
+
+    def test_filter_authorized_menu_items_with_batch_authorized(self, 
auth_manager, user):
+        with patch.object(
+            KeycloakAuthManager,
+            "_is_batch_authorized",
+            return_value={("MENU", menu.value) for menu in MenuItem},
+        ):
+            result = auth_manager.filter_authorized_menu_items(list(MenuItem), 
user=user)
+
+        assert set(result) == set(MenuItem)
+
     @pytest.mark.parametrize(
         ("status_code", "expected"),
         [
@@ -441,7 +671,7 @@ class TestKeycloakAuthManager:
         payload = auth_manager._get_payload(
             "client_id", "View#GET", {RESOURCE_ID_ATTRIBUTE_NAME: 
"CLUSTER_ACTIVITY"}
         )
-        headers = auth_manager._get_headers("access_token")
+        headers = auth_manager._get_headers(user.access_token)
         auth_manager.http_session.post.assert_called_once_with(
             token_url, data=payload, headers=headers, timeout=5
         )
@@ -469,7 +699,7 @@ class TestKeycloakAuthManager:
 
         token_url = auth_manager._get_token_url("server_url", "realm")
         payload = auth_manager._get_payload("client_id", "Custom#GET", 
{RESOURCE_ID_ATTRIBUTE_NAME: "test"})
-        headers = auth_manager._get_headers("access_token")
+        headers = auth_manager._get_headers(user.access_token)
         auth_manager.http_session.post.assert_called_once_with(
             token_url, data=payload, headers=headers, timeout=5
         )
@@ -502,7 +732,7 @@ class TestKeycloakAuthManager:
         payload = auth_manager._get_batch_payload(
             "client_id", [("MENU", MenuItem.ASSETS.value), ("MENU", 
MenuItem.CONNECTIONS.value)]
         )
-        headers = auth_manager._get_headers("access_token")
+        headers = auth_manager._get_headers(user.access_token)
         auth_manager.http_session.post.assert_called_once_with(
             token_url, data=payload, headers=headers, timeout=5
         )
@@ -527,7 +757,7 @@ class TestKeycloakAuthManager:
         payload = auth_manager._get_batch_payload(
             "client_id", [("MENU", MenuItem.ASSETS.value), ("MENU", 
MenuItem.CONNECTIONS.value)]
         )
-        headers = auth_manager._get_headers("access_token")
+        headers = auth_manager._get_headers(user.access_token)
         auth_manager.http_session.post.assert_called_once_with(
             token_url, data=payload, headers=headers, timeout=5
         )

Reply via email to