This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4ed9873db7e Keycloak auth manager: enforce team‑scoped authorization (AIP‑67) (#61351)
4ed9873db7e is described below
commit 4ed9873db7efa26ed03e84ffd4dcc40b204d0fc9
Author: Mathieu Monet <[email protected]>
AuthorDate: Wed Feb 18 16:45:36 2026 +0100
Keycloak auth manager: enforce team‑scoped authorization (AIP‑67) (#61351)
* feat: update keycloak manager for multiteam setup
* feat: refactor authorization checks for multi-team setup in KeycloakAuthManager
* feat: refactor team name retrieval for multi-team support in KeycloakAuthManager
* feat: update KeycloakAuthManager to retrieve multi-team configuration dynamically
Regenerate Keycloak OpenAPI spec
* Apply suggestion from @vincbeck
Co-authored-by: Vincent <[email protected]>
* feat: added error handling for missing team_name in multi-team mode
* generalize team-scoped auth cases and add compat guards
* test(keycloak): gate multiteam auth manager tests on Airflow 3.2+
* fix(keycloak): allow global team-scoped resources in multi-team mode
* refactor(keycloak): inline resource name resolution in auth manager
* test(keycloak): align multiteam skip reasons with 3.2.0 gate
* Update providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
Co-authored-by: Vincent <[email protected]>
* Made type of _get_team_name explicit
* refactored/renamed multiteam tests
* Updated test style to enhance readability
* moved auth manager patch from a with block to a decorator
---------
Co-authored-by: Vincent <[email protected]>
---
.../keycloak/auth_manager/keycloak_auth_manager.py | 62 +++++-
.../auth_manager/test_keycloak_auth_manager.py | 246 ++++++++++++++++++++-
2 files changed, 291 insertions(+), 17 deletions(-)
diff --git
a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
index 4c844987d59..d28c04da598 100644
---
a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
+++
b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
@@ -43,7 +43,12 @@ except ImportError:
from airflow.api_fastapi.common.types import MenuItem
from airflow.cli.cli_config import CLICommand
-from airflow.providers.common.compat.sdk import AirflowException, conf
+
+try:
+ from airflow.providers.common.compat.sdk import AirflowException, conf
+except ModuleNotFoundError:
+ from airflow.configuration import conf
+ from airflow.exceptions import AirflowException
from airflow.providers.keycloak.auth_manager.constants import (
CONF_CLIENT_ID_KEY,
CONF_CLIENT_SECRET_KEY,
@@ -76,6 +81,14 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
RESOURCE_ID_ATTRIBUTE_NAME = "resource_id"
+TEAM_SCOPED_RESOURCES = frozenset(
+ {
+ KeycloakResource.DAG,
+ KeycloakResource.CONNECTION,
+ KeycloakResource.POOL,
+ KeycloakResource.VARIABLE,
+ }
+)
class KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
@@ -184,10 +197,7 @@ class
KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
) -> bool:
config_section = details.section if details else None
return self._is_authorized(
- method=method,
- resource_type=KeycloakResource.CONFIGURATION,
- user=user,
- resource_id=config_section,
+ method=method, resource_type=KeycloakResource.CONFIGURATION,
user=user, resource_id=config_section
)
def is_authorized_connection(
@@ -198,8 +208,13 @@ class
KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
details: ConnectionDetails | None = None,
) -> bool:
connection_id = details.conn_id if details else None
+ team_name = self._get_team_name(details)
return self._is_authorized(
- method=method, resource_type=KeycloakResource.CONNECTION,
user=user, resource_id=connection_id
+ method=method,
+ resource_type=KeycloakResource.CONNECTION,
+ user=user,
+ resource_id=connection_id,
+ team_name=team_name,
)
def is_authorized_dag(
@@ -211,12 +226,14 @@ class
KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
details: DagDetails | None = None,
) -> bool:
dag_id = details.id if details else None
+ team_name = self._get_team_name(details)
access_entity_str = access_entity.value if access_entity else None
return self._is_authorized(
method=method,
resource_type=KeycloakResource.DAG,
user=user,
resource_id=dag_id,
+ team_name=team_name,
attributes={"dag_entity": access_entity_str},
)
@@ -262,16 +279,26 @@ class
KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
self, *, method: ResourceMethod, user: KeycloakAuthManagerUser,
details: VariableDetails | None = None
) -> bool:
variable_key = details.key if details else None
+ team_name = self._get_team_name(details)
return self._is_authorized(
- method=method, resource_type=KeycloakResource.VARIABLE, user=user,
resource_id=variable_key
+ method=method,
+ resource_type=KeycloakResource.VARIABLE,
+ user=user,
+ resource_id=variable_key,
+ team_name=team_name,
)
def is_authorized_pool(
self, *, method: ResourceMethod, user: KeycloakAuthManagerUser,
details: PoolDetails | None = None
) -> bool:
pool_name = details.name if details else None
+ team_name = self._get_team_name(details)
return self._is_authorized(
- method=method, resource_type=KeycloakResource.POOL, user=user,
resource_id=pool_name
+ method=method,
+ resource_type=KeycloakResource.POOL,
+ user=user,
+ resource_id=pool_name,
+ team_name=team_name,
)
def is_authorized_view(self, *, access_view: AccessView, user:
KeycloakAuthManagerUser) -> bool:
@@ -356,6 +383,7 @@ class
KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
resource_type: KeycloakResource,
user: KeycloakAuthManagerUser,
resource_id: str | None = None,
+ team_name: str | None = None,
attributes: dict[str, str | None] | None = None,
) -> bool:
client_id = conf.get(CONF_SECTION_NAME, CONF_CLIENT_ID_KEY)
@@ -368,9 +396,19 @@ class
KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
elif method == "GET":
method = "LIST"
+ if (
+ team_name
+ and conf.getboolean("core", "multi_team", fallback=False)
+ and resource_type in TEAM_SCOPED_RESOURCES
+ ):
+ resource_name = f"{resource_type.value}:{team_name}"
+ else:
+ resource_name = resource_type.value
+ permission = f"{resource_name}#{method}"
+
resp = self.http_session.post(
self._get_token_url(server_url, realm),
- data=self._get_payload(client_id,
f"{resource_type.value}#{method}", context_attributes),
+ data=self._get_payload(client_id, permission, context_attributes),
headers=self._get_headers(user.access_token),
timeout=5,
)
@@ -425,6 +463,12 @@ class
KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
# Normalize server_url to avoid double slashes (required for Keycloak 26.4+ strict path validation).
return f"{server_url.rstrip('/')}/realms/{realm}/protocol/openid-connect/token"
+ @staticmethod
+ def _get_team_name(
+ details: ConnectionDetails | DagDetails | PoolDetails |
VariableDetails | None,
+ ) -> str | None:
+ return getattr(details, "team_name", None) if details else None
+
@staticmethod
def _get_payload(client_id: str, permission: str, attributes: dict[str,
str] | None = None):
payload: dict[str, Any] = {
diff --git
a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
index 0a234fc1f02..ef90a603d53 100644
---
a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
+++
b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import base64
import json
from contextlib import ExitStack
from unittest.mock import Mock, patch
@@ -38,7 +39,11 @@ from
airflow.api_fastapi.auth.managers.models.resource_details import (
)
from airflow.api_fastapi.common.types import MenuItem
from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.common.compat.sdk import AirflowException
+
+try:
+ from airflow.providers.common.compat.sdk import AirflowException
+except ModuleNotFoundError:
+ from airflow.exceptions import AirflowException
from airflow.providers.keycloak.auth_manager.constants import (
CONF_CLIENT_ID_KEY,
CONF_CLIENT_SECRET_KEY,
@@ -53,7 +58,14 @@ from
airflow.providers.keycloak.auth_manager.keycloak_auth_manager import (
from airflow.providers.keycloak.auth_manager.user import
KeycloakAuthManagerUser
from tests_common.test_utils.config import conf_vars
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_7_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_7_PLUS,
AIRFLOW_V_3_2_PLUS
+
+
+def _build_access_token(payload: dict[str, object]) -> str:
+ header = {"alg": "none", "typ": "JWT"}
+ header_b64 =
base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
+ payload_b64 =
base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
+ return f"{header_b64}.{payload_b64}."
@pytest.fixture
@@ -69,6 +81,20 @@ def auth_manager():
yield KeycloakAuthManager()
+@pytest.fixture
+def auth_manager_multi_team():
+ with conf_vars(
+ {
+ ("core", "multi_team"): "True",
+ (CONF_SECTION_NAME, CONF_CLIENT_ID_KEY): "client_id",
+ (CONF_SECTION_NAME, CONF_CLIENT_SECRET_KEY): "client_secret",
+ (CONF_SECTION_NAME, CONF_REALM_KEY): "realm",
+ (CONF_SECTION_NAME, CONF_SERVER_URL_KEY): "server_url",
+ }
+ ):
+ yield KeycloakAuthManager()
+
+
@pytest.fixture
def user():
user = Mock()
@@ -278,7 +304,7 @@ class TestKeycloakAuthManager:
token_url = auth_manager._get_token_url("server_url", "realm")
payload = auth_manager._get_payload("client_id", permission,
attributes)
- headers = auth_manager._get_headers("access_token")
+ headers = auth_manager._get_headers(user.access_token)
auth_manager.http_session.post.assert_called_once_with(
token_url, data=payload, headers=headers, timeout=5
)
@@ -411,12 +437,216 @@ class TestKeycloakAuthManager:
token_url = auth_manager._get_token_url("server_url", "realm")
payload = auth_manager._get_payload("client_id", permission,
attributes)
- headers = auth_manager._get_headers("access_token")
+ headers = auth_manager._get_headers(user.access_token)
auth_manager.http_session.post.assert_called_once_with(
token_url, data=payload, headers=headers, timeout=5
)
assert result == expected
+ @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="team_name not
supported before Airflow 3.2.0")
+ @pytest.mark.parametrize(
+ ("function", "method", "details_cls", "details_kwargs", "permission"),
+ [
+ ("is_authorized_dag", "GET", DagDetails, {"id": "test",
"team_name": "team-a"}, "Dag#GET"),
+ (
+ "is_authorized_connection",
+ "DELETE",
+ ConnectionDetails,
+ {"conn_id": "test", "team_name": "team-a"},
+ "Connection#DELETE",
+ ),
+ (
+ "is_authorized_variable",
+ "PUT",
+ VariableDetails,
+ {"key": "test", "team_name": "team-a"},
+ "Variable#PUT",
+ ),
+ ("is_authorized_pool", "POST", PoolDetails, {"name": "test",
"team_name": "team-a"}, "Pool#POST"),
+ ],
+ )
+ def test_team_name_ignored_when_multi_team_disabled(
+ self, auth_manager, user, function, method, details_cls,
details_kwargs, permission
+ ):
+ details = details_cls(**details_kwargs)
+ mock_response = Mock()
+ mock_response.status_code = 200
+ auth_manager.http_session.post = Mock(return_value=mock_response)
+
+ getattr(auth_manager, function)(method=method, user=user,
details=details)
+
+ actual_permission =
auth_manager.http_session.post.call_args.kwargs["data"]["permission"]
+ assert actual_permission == permission
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="team_name not
supported before Airflow 3.2.0")
+ @pytest.mark.parametrize(
+ ("function", "details_cls", "details_kwargs", "permission"),
+ [
+ ("is_authorized_dag", DagDetails, {"id": "test", "team_name":
"team-a"}, "Dag:team-a#GET"),
+ (
+ "is_authorized_connection",
+ ConnectionDetails,
+ {"conn_id": "test", "team_name": "team-a"},
+ "Connection:team-a#GET",
+ ),
+ (
+ "is_authorized_variable",
+ VariableDetails,
+ {"key": "test", "team_name": "team-a"},
+ "Variable:team-a#GET",
+ ),
+ ("is_authorized_pool", PoolDetails, {"name": "test", "team_name":
"team-a"}, "Pool:team-a#GET"),
+ ],
+ )
+ def test_with_team_name_uses_team_scoped_permission(
+ self, auth_manager_multi_team, user, function, details_cls,
details_kwargs, permission
+ ):
+ details = details_cls(**details_kwargs)
+ mock_response = Mock()
+ mock_response.status_code = 200
+ auth_manager_multi_team.http_session.post =
Mock(return_value=mock_response)
+
+ getattr(auth_manager_multi_team, function)(method="GET", user=user,
details=details)
+
+ actual_permission =
auth_manager_multi_team.http_session.post.call_args.kwargs["data"]["permission"]
+ assert actual_permission == permission
+
+ @pytest.mark.parametrize(
+ ("function", "details", "permission"),
+ [
+ ("is_authorized_dag", DagDetails(id="test"), "Dag#GET"),
+ ("is_authorized_connection", ConnectionDetails(conn_id="test"),
"Connection#GET"),
+ ("is_authorized_variable", VariableDetails(key="test"),
"Variable#GET"),
+ ("is_authorized_pool", PoolDetails(name="test"), "Pool#GET"),
+ ],
+ )
+ def test_without_team_name_uses_global_permission(
+ self, auth_manager_multi_team, user, function, details, permission
+ ):
+ mock_response = Mock()
+ mock_response.status_code = 200
+ auth_manager_multi_team.http_session.post =
Mock(return_value=mock_response)
+
+ getattr(auth_manager_multi_team, function)(method="GET", user=user,
details=details)
+
+ actual_permission =
auth_manager_multi_team.http_session.post.call_args.kwargs["data"]["permission"]
+ assert actual_permission == permission
+
+ @pytest.mark.parametrize(
+ ("function", "permission"),
+ [
+ ("is_authorized_dag", "Dag#LIST"),
+ ("is_authorized_connection", "Connection#LIST"),
+ ("is_authorized_variable", "Variable#LIST"),
+ ("is_authorized_pool", "Pool#LIST"),
+ ],
+ )
+ def test_list_without_team_name_uses_global_permission(
+ self, auth_manager_multi_team, user, function, permission
+ ):
+ mock_response = Mock()
+ mock_response.status_code = 200
+ auth_manager_multi_team.http_session.post =
Mock(return_value=mock_response)
+
+ getattr(auth_manager_multi_team, function)(method="GET", user=user)
+
+ actual_permission =
auth_manager_multi_team.http_session.post.call_args.kwargs["data"]["permission"]
+ assert actual_permission == permission
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="team_name not
supported before Airflow 3.2.0")
+ @pytest.mark.parametrize(
+ ("function", "details_cls", "details_kwargs", "permission"),
+ [
+ ("is_authorized_dag", DagDetails, {"team_name": "team-a"},
"Dag:team-a#LIST"),
+ (
+ "is_authorized_connection",
+ ConnectionDetails,
+ {"team_name": "team-a"},
+ "Connection:team-a#LIST",
+ ),
+ ("is_authorized_variable", VariableDetails, {"team_name":
"team-a"}, "Variable:team-a#LIST"),
+ ("is_authorized_pool", PoolDetails, {"team_name": "team-a"},
"Pool:team-a#LIST"),
+ ],
+ )
+ def test_list_with_team_name_uses_team_scoped_permission(
+ self, auth_manager_multi_team, user, function, details_cls,
details_kwargs, permission
+ ):
+ details = details_cls(**details_kwargs)
+ user.access_token = _build_access_token({"groups": ["team-a"]})
+ mock_response = Mock()
+ mock_response.status_code = 200
+ auth_manager_multi_team.http_session.post =
Mock(return_value=mock_response)
+
+ getattr(auth_manager_multi_team, function)(method="GET", user=user,
details=details)
+
+ actual_permission =
auth_manager_multi_team.http_session.post.call_args.kwargs["data"]["permission"]
+ assert actual_permission == permission
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="team_name not
supported before Airflow 3.2.0")
+ @patch.object(KeycloakAuthManager, "is_authorized_dag", return_value=False)
+ def test_filter_authorized_dag_ids_team_mismatch(self, mock_is_authorized,
auth_manager_multi_team, user):
+ result = auth_manager_multi_team.filter_authorized_dag_ids(
+ dag_ids={"dag-a"}, user=user, team_name="team-b"
+ )
+
+ mock_is_authorized.assert_called_once()
+ assert result == set()
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="team_name not
supported before Airflow 3.2.0")
+ @patch.object(KeycloakAuthManager, "is_authorized_dag", return_value=True)
+ def test_filter_authorized_dag_ids_team_match(self, mock_is_authorized,
auth_manager_multi_team, user):
+ result = auth_manager_multi_team.filter_authorized_dag_ids(
+ dag_ids={"dag-a"}, user=user, team_name="team-a"
+ )
+
+ mock_is_authorized.assert_called_once()
+ assert result == {"dag-a"}
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="team_name not
supported before Airflow 3.2.0")
+ @patch.object(KeycloakAuthManager, "is_authorized_pool",
return_value=False)
+ def test_filter_authorized_pools_no_team_returns_empty(
+ self, mock_is_authorized, auth_manager_multi_team, user
+ ):
+ result = auth_manager_multi_team.filter_authorized_pools(
+ pool_names={"pool-a"}, user=user, team_name=None
+ )
+
+ mock_is_authorized.assert_called_once()
+ assert result == set()
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="team_name not
supported before Airflow 3.2.0")
+ @pytest.mark.parametrize(
+ ("function", "details_cls", "details_kwargs"),
+ [
+ ("is_authorized_dag", DagDetails, {"team_name": "team-b"}),
+ ("is_authorized_connection", ConnectionDetails, {"team_name":
"team-b"}),
+ ("is_authorized_variable", VariableDetails, {"team_name":
"team-b"}),
+ ("is_authorized_pool", PoolDetails, {"team_name": "team-b"}),
+ ],
+ )
+ def test_list_with_mismatched_team_delegates_to_keycloak(
+ self, auth_manager_multi_team, user, function, details_cls,
details_kwargs
+ ):
+ details = details_cls(**details_kwargs)
+ mock_response = Mock()
+ mock_response.status_code = 403
+ auth_manager_multi_team.http_session.post =
Mock(return_value=mock_response)
+
+ result = getattr(auth_manager_multi_team, function)(method="GET",
user=user, details=details)
+
+ auth_manager_multi_team.http_session.post.assert_called_once()
+ assert result is False
+
+ def test_filter_authorized_menu_items_with_batch_authorized(self,
auth_manager, user):
+ with patch.object(
+ KeycloakAuthManager,
+ "_is_batch_authorized",
+ return_value={("MENU", menu.value) for menu in MenuItem},
+ ):
+ result = auth_manager.filter_authorized_menu_items(list(MenuItem),
user=user)
+
+ assert set(result) == set(MenuItem)
+
@pytest.mark.parametrize(
("status_code", "expected"),
[
@@ -441,7 +671,7 @@ class TestKeycloakAuthManager:
payload = auth_manager._get_payload(
"client_id", "View#GET", {RESOURCE_ID_ATTRIBUTE_NAME:
"CLUSTER_ACTIVITY"}
)
- headers = auth_manager._get_headers("access_token")
+ headers = auth_manager._get_headers(user.access_token)
auth_manager.http_session.post.assert_called_once_with(
token_url, data=payload, headers=headers, timeout=5
)
@@ -469,7 +699,7 @@ class TestKeycloakAuthManager:
token_url = auth_manager._get_token_url("server_url", "realm")
payload = auth_manager._get_payload("client_id", "Custom#GET",
{RESOURCE_ID_ATTRIBUTE_NAME: "test"})
- headers = auth_manager._get_headers("access_token")
+ headers = auth_manager._get_headers(user.access_token)
auth_manager.http_session.post.assert_called_once_with(
token_url, data=payload, headers=headers, timeout=5
)
@@ -502,7 +732,7 @@ class TestKeycloakAuthManager:
payload = auth_manager._get_batch_payload(
"client_id", [("MENU", MenuItem.ASSETS.value), ("MENU",
MenuItem.CONNECTIONS.value)]
)
- headers = auth_manager._get_headers("access_token")
+ headers = auth_manager._get_headers(user.access_token)
auth_manager.http_session.post.assert_called_once_with(
token_url, data=payload, headers=headers, timeout=5
)
@@ -527,7 +757,7 @@ class TestKeycloakAuthManager:
payload = auth_manager._get_batch_payload(
"client_id", [("MENU", MenuItem.ASSETS.value), ("MENU",
MenuItem.CONNECTIONS.value)]
)
- headers = auth_manager._get_headers("access_token")
+ headers = auth_manager._get_headers(user.access_token)
auth_manager.http_session.post.assert_called_once_with(
token_url, data=payload, headers=headers, timeout=5
)