This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 429f5c6b610 Add TTL cache with single-flight dedup to Keycloak
filter_authorized_dag_ids (#63184)
429f5c6b610 is described below
commit 429f5c6b610122d14cae25e50f30063c46b33410
Author: Mathieu Monet <[email protected]>
AuthorDate: Tue Mar 10 15:19:39 2026 +0100
Add TTL cache with single-flight dedup to Keycloak
filter_authorized_dag_ids (#63184)
---
.../providers/keycloak/auth_manager/cache.py | 84 ++++++++++++
.../keycloak/auth_manager/keycloak_auth_manager.py | 19 +++
.../tests/unit/keycloak/auth_manager/test_cache.py | 137 +++++++++++++++++++
.../auth_manager/test_keycloak_auth_manager.py | 146 +++++++++++++++++++++
4 files changed, 386 insertions(+)
diff --git
a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/cache.py
b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/cache.py
new file mode 100644
index 00000000000..75ccbf99c51
--- /dev/null
+++ b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/cache.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import threading
+import time
+from collections.abc import Callable
+
_CACHE_TTL_SECONDS = 30
_SINGLE_FLIGHT_TIMEOUT_SECONDS = 60

# Maps cache keys to (timestamp, result) pairs for TTL-based expiration.
# Timestamps come from time.monotonic(), so wall-clock jumps do not affect expiry.
_cache: dict[tuple, tuple[float, frozenset[str]]] = {}
# Tracks in-flight requests: maps cache keys to events that waiting threads block on.
# Only the thread that registered an event (the "worker") may remove or signal it.
_pending_requests: dict[tuple, threading.Event] = {}
# Guards writes to ``_cache`` and all access to ``_pending_requests``.
# It is a non-reentrant Lock, so helpers that take it must not be called while holding it.
_cache_lock = threading.Lock()


def _cache_get(key: tuple) -> frozenset[str] | None:
    """
    Return the cached value for ``key`` if it is younger than the TTL, else ``None``.

    Called both with and without ``_cache_lock`` held. The lock-free call relies on a
    single ``dict.get`` being atomic, which holds in CPython.
    """
    entry = _cache.get(key)
    if entry is not None and time.monotonic() - entry[0] < _CACHE_TTL_SECONDS:
        return entry[1]
    return None


def _cache_set(key: tuple, value: frozenset[str]) -> None:
    """
    Store ``value`` under ``key`` and drop entries older than twice the TTL.

    Acquires ``_cache_lock``; must not be called while that lock is already held.
    """
    with _cache_lock:
        now = time.monotonic()
        _cache[key] = (now, value)
        # Opportunistic eviction keeps the cache from growing without bound.
        stale_keys = [k for k, (ts, _) in _cache.items() if now - ts > _CACHE_TTL_SECONDS * 2]
        for k in stale_keys:
            del _cache[k]


def single_flight(cache_key: tuple, query_keycloak: Callable[[], set[str]]) -> set[str]:
    """
    Return cached result, wait for a pending request, or run the query ourselves.

    Concurrent callers using the same ``cache_key`` are coalesced. The first caller
    (the "worker") runs ``query_keycloak`` and caches the result. The other callers
    wait on the worker's event for up to ``_SINGLE_FLIGHT_TIMEOUT_SECONDS``. If the
    worker fails or times out, a waiter runs the query itself without touching the
    pending-request registry, which may by then belong to a newer worker.

    :param cache_key: hashable key identifying the query and its result.
    :param query_keycloak: zero-argument callable performing the real lookup.
    :return: the authorized names; mutating it does not affect the cache.
    :raises: whatever ``query_keycloak`` raises, when this call ends up running it.
    """
    # Fast path: lock-free cache check.
    cached = _cache_get(cache_key)
    if cached is not None:
        return set(cached)

    owned_event: threading.Event | None = None
    with _cache_lock:
        cached = _cache_get(cache_key)
        if cached is not None:
            return set(cached)
        pending_event = _pending_requests.get(cache_key)
        if pending_event is None:
            # Nobody is querying this key yet: become the worker.
            owned_event = threading.Event()
            _pending_requests[cache_key] = owned_event

    if owned_event is None:
        # Another thread is already querying; wait for it to publish its result.
        pending_event.wait(timeout=_SINGLE_FLIGHT_TIMEOUT_SECONDS)
        cached = _cache_get(cache_key)
        if cached is not None:
            return set(cached)
        # The other thread failed or timed out: fall through and do the work ourselves.

    try:
        result = query_keycloak()
        _cache_set(cache_key, frozenset(result))
        return result
    finally:
        if owned_event is not None:
            with _cache_lock:
                # Remove only our own registration, never a newer worker's event.
                if _pending_requests.get(cache_key) is owned_event:
                    del _pending_requests[cache_key]
            # Wake waiters; on success the result has already been cached above.
            owned_event.set()
diff --git
a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
index bc398020844..9844873c031 100644
---
a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
+++
b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
@@ -49,6 +49,7 @@ try:
except ModuleNotFoundError:
from airflow.configuration import conf
from airflow.exceptions import AirflowException
+from airflow.providers.keycloak.auth_manager.cache import single_flight
from airflow.providers.keycloak.auth_manager.constants import (
CONF_CLIENT_ID_KEY,
CONF_CLIENT_SECRET_KEY,
@@ -440,6 +441,24 @@ class
KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
)
raise AirflowException(f"Unexpected error: {resp.status_code} -
{resp.text}")
+ def filter_authorized_dag_ids(
+ self,
+ *,
+ dag_ids: set[str],
+ user: KeycloakAuthManagerUser,
+ method: ResourceMethod = "GET",
+ team_name: str | None = None,
+ ) -> set[str]:
+ cache_key = (user.get_id(), method, team_name, frozenset(dag_ids))
+
+ def query_keycloak() -> set[str]:
+ kwargs: dict = dict(dag_ids=dag_ids, user=user, method=method)
+ if team_name is not None:
+ kwargs["team_name"] = team_name
+ return super(KeycloakAuthManager,
self).filter_authorized_dag_ids(**kwargs)
+
+ return single_flight(cache_key, query_keycloak)
+
def _is_batch_authorized(
self,
*,
diff --git a/providers/keycloak/tests/unit/keycloak/auth_manager/test_cache.py
b/providers/keycloak/tests/unit/keycloak/auth_manager/test_cache.py
new file mode 100644
index 00000000000..9125d8d4b72
--- /dev/null
+++ b/providers/keycloak/tests/unit/keycloak/auth_manager/test_cache.py
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import threading
+import time
+
+import pytest
+
+from airflow.providers.keycloak.auth_manager import cache as cache_module
+from airflow.providers.keycloak.auth_manager.cache import single_flight
+
+
@pytest.fixture(autouse=True)
def _clear_cache():
    """Reset the module-level cache and in-flight registry before and after every test."""

    def reset():
        cache_module._cache.clear()
        cache_module._pending_requests.clear()

    reset()
    yield
    reset()
+
+
class TestSingleFlight:
    """Tests for result caching and request coalescing in ``single_flight``."""

    @staticmethod
    def _make_counting_query(result):
        """Build a query returning a copy of ``result`` and a list recording each invocation."""
        calls = []

        def query():
            calls.append(None)
            return set(result)

        return query, calls

    def test_returns_query_result(self):
        assert single_flight(("key",), lambda: {"a", "b"}) == {"a", "b"}

    def test_cache_hit(self):
        query, calls = self._make_counting_query({"a"})

        for _ in range(2):
            single_flight(("key",), query)

        assert len(calls) == 1

    def test_different_keys_not_cached(self):
        query, calls = self._make_counting_query({"a"})

        for key in (("key1",), ("key2",)):
            single_flight(key, query)

        assert len(calls) == 2

    def test_cache_expires(self):
        query, calls = self._make_counting_query({"a"})
        single_flight(("key",), query)

        # Push every entry's timestamp past the TTL so the next lookup misses.
        backdate = cache_module._CACHE_TTL_SECONDS + 1
        for key, (stamp, value) in list(cache_module._cache.items()):
            cache_module._cache[key] = (stamp - backdate, value)

        single_flight(("key",), query)
        assert len(calls) == 2

    def test_concurrent_dedup(self):
        """Multiple threads with the same key coalesce into one call."""
        release = threading.Event()
        calls = []

        def slow_query():
            calls.append(None)
            release.wait(timeout=5)
            return {"a"}

        thread_count = 5
        results = [None] * thread_count
        errors = []

        def worker(slot):
            try:
                results[slot] = single_flight(("key",), slow_query)
            except Exception as exc:
                errors.append(exc)

        threads = [threading.Thread(target=worker, args=(slot,)) for slot in range(thread_count)]
        for thread in threads:
            thread.start()

        # Give every thread time to start the query or block waiting on it.
        time.sleep(0.1)
        release.set()

        for thread in threads:
            thread.join(timeout=5)

        assert not errors, f"Threads raised errors: {errors}"
        assert all(outcome == {"a"} for outcome in results)
        assert len(calls) == 1

    def test_failed_query_allows_retry(self):
        """If the worker thread fails, another thread can retry."""
        attempts = []

        def failing_then_ok():
            attempts.append(None)
            if len(attempts) == 1:
                raise ValueError("boom")
            return {"a"}

        with pytest.raises(ValueError, match="boom"):
            single_flight(("key",), failing_then_ok)

        assert single_flight(("key",), failing_then_ok) == {"a"}
        assert len(attempts) == 2
diff --git
a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
index b1c28714f6b..9c8ed9dd5b6 100644
---
a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
+++
b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
@@ -51,6 +51,7 @@ try:
from airflow.providers.common.compat.sdk import AirflowException
except ModuleNotFoundError:
from airflow.exceptions import AirflowException
+from airflow.providers.keycloak.auth_manager import cache as cache_module
from airflow.providers.keycloak.auth_manager.constants import (
CONF_CLIENT_ID_KEY,
CONF_CLIENT_SECRET_KEY,
@@ -109,6 +110,16 @@ def user():
return user
@pytest.fixture(autouse=True)
def _clear_filter_cache():
    """Empty the single-flight cache and pending-request registry around each test."""
    shared_state = (cache_module._cache, cache_module._pending_requests)
    for container in shared_state:
        container.clear()
    yield
    for container in shared_state:
        container.clear()
+
+
class TestKeycloakAuthManager:
def test_deserialize_user(self, auth_manager):
result = auth_manager.deserialize_user(
@@ -960,3 +971,138 @@ class TestKeycloakAuthManager:
headers={"Authorization": "Bearer pat-token"},
timeout=5,
)
+
+ @pytest.mark.parametrize(
+ ("status_codes", "expected"),
+ [
+ ([200, 200, 200], True),
+ ([200, 403, 200], False),
+ ([401, 200, 200], False),
+ ([200, 200, 401], False),
+ ],
+ )
+ def test_batch_is_authorized_dag(self, status_codes, expected,
auth_manager, user):
+ return_values = [code == 200 for code in status_codes]
+
+ requests = [{"method": "GET", "details": DagDetails(id=f"dag_{i}")}
for i in range(len(status_codes))]
+
+ with patch.object(KeycloakAuthManager, "_is_authorized",
side_effect=return_values):
+ result = auth_manager.batch_is_authorized_dag(requests, user=user)
+
+ assert result == expected
+
+ @pytest.mark.parametrize(
+ ("status_codes", "expected"),
+ [
+ ([200, 200], True),
+ ([200, 403], False),
+ ],
+ )
+ def test_batch_is_authorized_connection(self, status_codes, expected,
auth_manager, user):
+ return_values = [code == 200 for code in status_codes]
+
+ requests = [
+ {"method": "GET", "details":
ConnectionDetails(conn_id=f"conn_{i}")}
+ for i in range(len(status_codes))
+ ]
+
+ with patch.object(KeycloakAuthManager, "_is_authorized",
side_effect=return_values):
+ result = auth_manager.batch_is_authorized_connection(requests,
user=user)
+
+ assert result == expected
+
+ @pytest.mark.parametrize(
+ ("status_codes", "expected"),
+ [
+ ([200, 200], True),
+ ([403, 200], False),
+ ],
+ )
+ def test_batch_is_authorized_pool(self, status_codes, expected,
auth_manager, user):
+ return_values = [code == 200 for code in status_codes]
+
+ requests = [
+ {"method": "GET", "details": PoolDetails(name=f"pool_{i}")} for i
in range(len(status_codes))
+ ]
+
+ with patch.object(KeycloakAuthManager, "_is_authorized",
side_effect=return_values):
+ result = auth_manager.batch_is_authorized_pool(requests, user=user)
+
+ assert result == expected
+
+ @pytest.mark.parametrize(
+ ("status_codes", "expected"),
+ [
+ ([200, 200], True),
+ ([200, 401], False),
+ ],
+ )
+ def test_batch_is_authorized_variable(self, status_codes, expected,
auth_manager, user):
+ return_values = [code == 200 for code in status_codes]
+
+ requests = [
+ {"method": "GET", "details": VariableDetails(key=f"var_{i}")} for
i in range(len(status_codes))
+ ]
+
+ with patch.object(KeycloakAuthManager, "_is_authorized",
side_effect=return_values):
+ result = auth_manager.batch_is_authorized_variable(requests,
user=user)
+
+ assert result == expected
+
+ def test_batch_is_authorized_dag_empty_requests(self, auth_manager, user):
+ result = auth_manager.batch_is_authorized_dag([], user=user)
+ assert result is True
+
+ def test_batch_is_authorized_dag_with_access_entity(self, auth_manager,
user):
+ requests = [
+ {
+ "method": "GET",
+ "access_entity": DagAccessEntity.TASK_INSTANCE,
+ "details": DagDetails(id="dag_1"),
+ }
+ ]
+
+ with patch.object(KeycloakAuthManager, "_is_authorized",
return_value=True) as mock_is_authorized:
+ result = auth_manager.batch_is_authorized_dag(requests, user=user)
+
+ assert result is True
+ # Verify the call included the dag_entity attribute
+ call_kwargs = mock_is_authorized.call_args
+ assert call_kwargs.kwargs["attributes"] == {"dag_entity":
"TASK_INSTANCE"}
+
+ @patch.object(
+ KeycloakAuthManager,
+ "is_authorized_dag",
+ side_effect=lambda *, details, **kw: {"dag_0": True, "dag_1": False,
"dag_2": True}[details.id],
+ )
+ def test_filter_authorized_dag_ids(self, mock_is_authorized, auth_manager,
user):
+ result = auth_manager.filter_authorized_dag_ids(
+ dag_ids={"dag_0", "dag_1", "dag_2"}, user=user, method="GET"
+ )
+
+ assert result == {"dag_0", "dag_2"}
+ assert mock_is_authorized.call_count == 3
+
+ def test_filter_authorized_dag_ids_empty(self, auth_manager, user):
+ result = auth_manager.filter_authorized_dag_ids(dag_ids=set(),
user=user, method="GET")
+ assert result == set()
+
+ @patch.object(KeycloakAuthManager, "is_authorized_dag", return_value=False)
+ def test_filter_authorized_dag_ids_all_denied(self, mock_is_authorized,
auth_manager, user):
+ result = auth_manager.filter_authorized_dag_ids(dag_ids={"dag_0",
"dag_1"}, user=user, method="GET")
+
+ assert result == set()
+ assert mock_is_authorized.call_count == 2
+
+ @patch.object(KeycloakAuthManager, "is_authorized_dag", return_value=True)
+ def test_filter_authorized_dag_ids_cache_hit(self, mock_is_authorized,
auth_manager, user):
+ """Second call with same args should return cached result without
hitting Keycloak."""
+ dag_ids = {"dag_0", "dag_1"}
+
+ result1 = auth_manager.filter_authorized_dag_ids(dag_ids=dag_ids,
user=user, method="GET")
+ result2 = auth_manager.filter_authorized_dag_ids(dag_ids=dag_ids,
user=user, method="GET")
+
+ assert result1 == dag_ids
+ assert result2 == dag_ids
+ # is_authorized_dag should only be called for the first invocation (2
dag_ids × 1 call)
+ assert mock_is_authorized.call_count == 2