This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 429f5c6b610 Add TTL cache with single-flight dedup to Keycloak filter_authorized_dag_ids (#63184)
429f5c6b610 is described below

commit 429f5c6b610122d14cae25e50f30063c46b33410
Author: Mathieu Monet <[email protected]>
AuthorDate: Tue Mar 10 15:19:39 2026 +0100

    Add TTL cache with single-flight dedup to Keycloak filter_authorized_dag_ids (#63184)
---
 .../providers/keycloak/auth_manager/cache.py       |  84 ++++++++++++
 .../keycloak/auth_manager/keycloak_auth_manager.py |  19 +++
 .../tests/unit/keycloak/auth_manager/test_cache.py | 137 +++++++++++++++++++
 .../auth_manager/test_keycloak_auth_manager.py     | 146 +++++++++++++++++++++
 4 files changed, 386 insertions(+)

diff --git a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/cache.py b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/cache.py
new file mode 100644
index 00000000000..75ccbf99c51
--- /dev/null
+++ b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/cache.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import threading
+import time
+from collections.abc import Callable
+
+_CACHE_TTL_SECONDS = 30
+_SINGLE_FLIGHT_TIMEOUT_SECONDS = 60
+
+# Maps cache keys to (timestamp, result) pairs for TTL-based expiration.
+_cache: dict[tuple, tuple[float, frozenset[str]]] = {}
+# Tracks in-flight requests: maps cache keys to events that waiting threads 
block on.
+_pending_requests: dict[tuple, threading.Event] = {}
+_cache_lock = threading.Lock()
+
+
+def _cache_get(key: tuple) -> frozenset[str] | None:
+    entry = _cache.get(key)
+    if entry and (time.monotonic() - entry[0]) < _CACHE_TTL_SECONDS:
+        return entry[1]
+    return None
+
+
+def _cache_set(key: tuple, value: frozenset[str]) -> None:
+    with _cache_lock:
+        _cache[key] = (time.monotonic(), value)
+        now = time.monotonic()
+        for k in [k for k, (ts, _) in _cache.items() if now - ts > 
_CACHE_TTL_SECONDS * 2]:
+            _cache.pop(k, None)
+
+
+def single_flight(cache_key: tuple, query_keycloak: Callable[[], set[str]]) -> 
set[str]:
+    """Return cached result, wait for a pending request, or run the query 
ourselves."""
+    # Fast path: check cache without lock
+    cached = _cache_get(cache_key)
+    if cached is not None:
+        return set(cached)
+
+    with _cache_lock:
+        cached = _cache_get(cache_key)
+        if cached is not None:
+            return set(cached)
+
+        event = _pending_requests.get(cache_key)
+        if event is not None:
+            is_worker = False
+        else:
+            event = threading.Event()
+            _pending_requests[cache_key] = event
+            is_worker = True
+
+    if not is_worker:
+        # Wait for the other thread to finish
+        event.wait(timeout=_SINGLE_FLIGHT_TIMEOUT_SECONDS)
+        cached = _cache_get(cache_key)
+        if cached is not None:
+            return set(cached)
+        # If the other thread failed, fall through and do the work ourselves
+
+    try:
+        result = query_keycloak()
+        _cache_set(cache_key, frozenset(result))
+        return result
+    finally:
+        with _cache_lock:
+            event = _pending_requests.pop(cache_key, None)
+        if event is not None:
+            event.set()
diff --git a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
index bc398020844..9844873c031 100644
--- a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
+++ b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
@@ -49,6 +49,7 @@ try:
 except ModuleNotFoundError:
     from airflow.configuration import conf
     from airflow.exceptions import AirflowException
+from airflow.providers.keycloak.auth_manager.cache import single_flight
 from airflow.providers.keycloak.auth_manager.constants import (
     CONF_CLIENT_ID_KEY,
     CONF_CLIENT_SECRET_KEY,
@@ -440,6 +441,24 @@ class KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
             )
         raise AirflowException(f"Unexpected error: {resp.status_code} - 
{resp.text}")
 
+    def filter_authorized_dag_ids(
+        self,
+        *,
+        dag_ids: set[str],
+        user: KeycloakAuthManagerUser,
+        method: ResourceMethod = "GET",
+        team_name: str | None = None,
+    ) -> set[str]:
+        cache_key = (user.get_id(), method, team_name, frozenset(dag_ids))
+
+        def query_keycloak() -> set[str]:
+            kwargs: dict = dict(dag_ids=dag_ids, user=user, method=method)
+            if team_name is not None:
+                kwargs["team_name"] = team_name
+            return super(KeycloakAuthManager, 
self).filter_authorized_dag_ids(**kwargs)
+
+        return single_flight(cache_key, query_keycloak)
+
     def _is_batch_authorized(
         self,
         *,
diff --git a/providers/keycloak/tests/unit/keycloak/auth_manager/test_cache.py b/providers/keycloak/tests/unit/keycloak/auth_manager/test_cache.py
new file mode 100644
index 00000000000..9125d8d4b72
--- /dev/null
+++ b/providers/keycloak/tests/unit/keycloak/auth_manager/test_cache.py
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import threading
+import time
+
+import pytest
+
+from airflow.providers.keycloak.auth_manager import cache as cache_module
+from airflow.providers.keycloak.auth_manager.cache import single_flight
+
+
[email protected](autouse=True)
+def _clear_cache():
+    cache_module._cache.clear()
+    cache_module._pending_requests.clear()
+    yield
+    cache_module._cache.clear()
+    cache_module._pending_requests.clear()
+
+
+class TestSingleFlight:
+    def test_returns_query_result(self):
+        result = single_flight(("key",), lambda: {"a", "b"})
+        assert result == {"a", "b"}
+
+    def test_cache_hit(self):
+        call_count = 0
+
+        def query():
+            nonlocal call_count
+            call_count += 1
+            return {"a"}
+
+        single_flight(("key",), query)
+        single_flight(("key",), query)
+
+        assert call_count == 1
+
+    def test_different_keys_not_cached(self):
+        call_count = 0
+
+        def query():
+            nonlocal call_count
+            call_count += 1
+            return {"a"}
+
+        single_flight(("key1",), query)
+        single_flight(("key2",), query)
+
+        assert call_count == 2
+
+    def test_cache_expires(self):
+        call_count = 0
+
+        def query():
+            nonlocal call_count
+            call_count += 1
+            return {"a"}
+
+        single_flight(("key",), query)
+
+        # Expire the cache entry by backdating its timestamp
+        for k in cache_module._cache:
+            ts, val = cache_module._cache[k]
+            cache_module._cache[k] = (ts - cache_module._CACHE_TTL_SECONDS - 
1, val)
+
+        single_flight(("key",), query)
+        assert call_count == 2
+
+    def test_concurrent_dedup(self):
+        """Multiple threads with the same key coalesce into one call."""
+        gate = threading.Event()
+        call_count = 0
+
+        def slow_query():
+            nonlocal call_count
+            call_count += 1
+            gate.wait(timeout=5)
+            return {"a"}
+
+        results = [None] * 5
+        errors = []
+
+        def run(index):
+            try:
+                results[index] = single_flight(("key",), slow_query)
+            except Exception as e:
+                errors.append(e)
+
+        threads = [threading.Thread(target=run, args=(i,)) for i in range(5)]
+        for t in threads:
+            t.start()
+
+        time.sleep(0.1)
+        gate.set()
+
+        for t in threads:
+            t.join(timeout=5)
+
+        assert not errors, f"Threads raised errors: {errors}"
+        for r in results:
+            assert r == {"a"}
+        assert call_count == 1
+
+    def test_failed_query_allows_retry(self):
+        """If the worker thread fails, another thread can retry."""
+        call_count = 0
+
+        def failing_then_ok():
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                raise ValueError("boom")
+            return {"a"}
+
+        with pytest.raises(ValueError, match="boom"):
+            single_flight(("key",), failing_then_ok)
+
+        result = single_flight(("key",), failing_then_ok)
+        assert result == {"a"}
+        assert call_count == 2
diff --git a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
index b1c28714f6b..9c8ed9dd5b6 100644
--- a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
+++ b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
@@ -51,6 +51,7 @@ try:
     from airflow.providers.common.compat.sdk import AirflowException
 except ModuleNotFoundError:
     from airflow.exceptions import AirflowException
+from airflow.providers.keycloak.auth_manager import cache as cache_module
 from airflow.providers.keycloak.auth_manager.constants import (
     CONF_CLIENT_ID_KEY,
     CONF_CLIENT_SECRET_KEY,
@@ -109,6 +110,16 @@ def user():
     return user
 
 
[email protected](autouse=True)
+def _clear_filter_cache():
+    """Clear module-level single-flight cache between tests."""
+    cache_module._cache.clear()
+    cache_module._pending_requests.clear()
+    yield
+    cache_module._cache.clear()
+    cache_module._pending_requests.clear()
+
+
 class TestKeycloakAuthManager:
     def test_deserialize_user(self, auth_manager):
         result = auth_manager.deserialize_user(
@@ -960,3 +971,138 @@ class TestKeycloakAuthManager:
             headers={"Authorization": "Bearer pat-token"},
             timeout=5,
         )
+
+    @pytest.mark.parametrize(
+        ("status_codes", "expected"),
+        [
+            ([200, 200, 200], True),
+            ([200, 403, 200], False),
+            ([401, 200, 200], False),
+            ([200, 200, 401], False),
+        ],
+    )
+    def test_batch_is_authorized_dag(self, status_codes, expected, 
auth_manager, user):
+        return_values = [code == 200 for code in status_codes]
+
+        requests = [{"method": "GET", "details": DagDetails(id=f"dag_{i}")} 
for i in range(len(status_codes))]
+
+        with patch.object(KeycloakAuthManager, "_is_authorized", 
side_effect=return_values):
+            result = auth_manager.batch_is_authorized_dag(requests, user=user)
+
+        assert result == expected
+
+    @pytest.mark.parametrize(
+        ("status_codes", "expected"),
+        [
+            ([200, 200], True),
+            ([200, 403], False),
+        ],
+    )
+    def test_batch_is_authorized_connection(self, status_codes, expected, 
auth_manager, user):
+        return_values = [code == 200 for code in status_codes]
+
+        requests = [
+            {"method": "GET", "details": 
ConnectionDetails(conn_id=f"conn_{i}")}
+            for i in range(len(status_codes))
+        ]
+
+        with patch.object(KeycloakAuthManager, "_is_authorized", 
side_effect=return_values):
+            result = auth_manager.batch_is_authorized_connection(requests, 
user=user)
+
+        assert result == expected
+
+    @pytest.mark.parametrize(
+        ("status_codes", "expected"),
+        [
+            ([200, 200], True),
+            ([403, 200], False),
+        ],
+    )
+    def test_batch_is_authorized_pool(self, status_codes, expected, 
auth_manager, user):
+        return_values = [code == 200 for code in status_codes]
+
+        requests = [
+            {"method": "GET", "details": PoolDetails(name=f"pool_{i}")} for i 
in range(len(status_codes))
+        ]
+
+        with patch.object(KeycloakAuthManager, "_is_authorized", 
side_effect=return_values):
+            result = auth_manager.batch_is_authorized_pool(requests, user=user)
+
+        assert result == expected
+
+    @pytest.mark.parametrize(
+        ("status_codes", "expected"),
+        [
+            ([200, 200], True),
+            ([200, 401], False),
+        ],
+    )
+    def test_batch_is_authorized_variable(self, status_codes, expected, 
auth_manager, user):
+        return_values = [code == 200 for code in status_codes]
+
+        requests = [
+            {"method": "GET", "details": VariableDetails(key=f"var_{i}")} for 
i in range(len(status_codes))
+        ]
+
+        with patch.object(KeycloakAuthManager, "_is_authorized", 
side_effect=return_values):
+            result = auth_manager.batch_is_authorized_variable(requests, 
user=user)
+
+        assert result == expected
+
+    def test_batch_is_authorized_dag_empty_requests(self, auth_manager, user):
+        result = auth_manager.batch_is_authorized_dag([], user=user)
+        assert result is True
+
+    def test_batch_is_authorized_dag_with_access_entity(self, auth_manager, 
user):
+        requests = [
+            {
+                "method": "GET",
+                "access_entity": DagAccessEntity.TASK_INSTANCE,
+                "details": DagDetails(id="dag_1"),
+            }
+        ]
+
+        with patch.object(KeycloakAuthManager, "_is_authorized", 
return_value=True) as mock_is_authorized:
+            result = auth_manager.batch_is_authorized_dag(requests, user=user)
+
+        assert result is True
+        # Verify the call included the dag_entity attribute
+        call_kwargs = mock_is_authorized.call_args
+        assert call_kwargs.kwargs["attributes"] == {"dag_entity": 
"TASK_INSTANCE"}
+
+    @patch.object(
+        KeycloakAuthManager,
+        "is_authorized_dag",
+        side_effect=lambda *, details, **kw: {"dag_0": True, "dag_1": False, 
"dag_2": True}[details.id],
+    )
+    def test_filter_authorized_dag_ids(self, mock_is_authorized, auth_manager, 
user):
+        result = auth_manager.filter_authorized_dag_ids(
+            dag_ids={"dag_0", "dag_1", "dag_2"}, user=user, method="GET"
+        )
+
+        assert result == {"dag_0", "dag_2"}
+        assert mock_is_authorized.call_count == 3
+
+    def test_filter_authorized_dag_ids_empty(self, auth_manager, user):
+        result = auth_manager.filter_authorized_dag_ids(dag_ids=set(), 
user=user, method="GET")
+        assert result == set()
+
+    @patch.object(KeycloakAuthManager, "is_authorized_dag", return_value=False)
+    def test_filter_authorized_dag_ids_all_denied(self, mock_is_authorized, 
auth_manager, user):
+        result = auth_manager.filter_authorized_dag_ids(dag_ids={"dag_0", 
"dag_1"}, user=user, method="GET")
+
+        assert result == set()
+        assert mock_is_authorized.call_count == 2
+
+    @patch.object(KeycloakAuthManager, "is_authorized_dag", return_value=True)
+    def test_filter_authorized_dag_ids_cache_hit(self, mock_is_authorized, 
auth_manager, user):
+        """Second call with same args should return cached result without 
hitting Keycloak."""
+        dag_ids = {"dag_0", "dag_1"}
+
+        result1 = auth_manager.filter_authorized_dag_ids(dag_ids=dag_ids, 
user=user, method="GET")
+        result2 = auth_manager.filter_authorized_dag_ids(dag_ids=dag_ids, 
user=user, method="GET")
+
+        assert result1 == dag_ids
+        assert result2 == dag_ids
+        # is_authorized_dag should only be called for the first invocation (2 
dag_ids × 1 call)
+        assert mock_is_authorized.call_count == 2

Reply via email to