Copilot commented on code in PR #60804:
URL: https://github.com/apache/airflow/pull/60804#discussion_r3067262350
##########
airflow-core/src/airflow/config_templates/config.yml:
##########
@@ -1606,6 +1606,30 @@ api:
type: string
example: ~
default: "False"
+ dag_cache_size:
+ description: |
+ Size of the LRU cache for SerializedDAG objects in the API server.
+ Set to 0 to disable caching. The cache is keyed by DAG version ID,
+ so lookups by DAG ID (e.g., viewing a DAG's details) always query
+ the database for the latest version, but the deserialized result is
+ cached for subsequent version-specific lookups.
Review Comment:
The config text says “Set to 0 to disable caching”, but `DBDagBag` treats
`cache_size=0` as “use unbounded dict (no eviction)”, which still caches and
can cause the memory growth this feature aims to mitigate. Please align this
description with the actual behavior (e.g., “0 = unbounded dict, pre-3.2
behavior” or implement true “no caching” semantics for 0).
```suggestion
          Set to 0 to use an unbounded cache (no LRU eviction), which preserves
          the pre-3.2 behavior. The cache is keyed by DAG version ID, so lookups
          by DAG ID (e.g., viewing a DAG's details) always query the database for
          the latest version, but the deserialized result is cached for subsequent
          version-specific lookups.
```
##########
airflow-core/src/airflow/models/dagbag.py:
##########
@@ -39,50 +44,122 @@
class DBDagBag:
"""
- Internal class for retrieving and caching dags in the scheduler.
+ Internal class for retrieving dags from the database.
+
+ Optionally supports LRU+TTL caching when cache_size is provided.
+ The scheduler uses this without caching, while the API server can
+ enable caching via configuration.
:meta private:
"""
- def __init__(self, load_op_links: bool = True) -> None:
- self._dags: dict[UUID, SerializedDagModel] = {} # dag_version_id to
dag
- self.load_op_links = load_op_links
+ def __init__(
+ self,
+ load_op_links: bool = True,
+ cache_size: int | None = None,
+ cache_ttl: int | None = None,
+ ) -> None:
+ """
+ Initialize DBDagBag.
- def _read_dag(self, serialized_dag_model: SerializedDagModel) ->
SerializedDAG | None:
- serialized_dag_model.load_op_links = self.load_op_links
- if dag := serialized_dag_model.dag:
- self._dags[serialized_dag_model.dag_version_id] =
serialized_dag_model
+ :param load_op_links: Should the extra operator link be loaded when
de-serializing the DAG?
+ :param cache_size: Size of LRU cache. If None or 0, uses unbounded
dict (no eviction).
+ :param cache_ttl: Time-to-live for cache entries in seconds. If None
or 0, no TTL (LRU only).
+ """
+ self.load_op_links = load_op_links
+ self._dags: MutableMapping[UUID | str, SerializedDAG] = {}
+ self._dag_models: dict[UUID | str, SerializedDagModel] = {}
+ self._use_cache = False
+
+ # Initialize bounded cache if cache_size is provided and > 0
+ if cache_size and cache_size > 0:
+ if cache_ttl and cache_ttl > 0:
+ self._dags = TTLCache(maxsize=cache_size, ttl=cache_ttl)
+ else:
+ self._dags = LRUCache(maxsize=cache_size)
+ self._use_cache = True
+
+ # Lock required for bounded caches: cachetools caches are NOT
thread-safe
+ # (LRU reordering and TTL cleanup mutate internal linked lists).
+ # nullcontext for unbounded dict avoids lock overhead in the scheduler
path.
+ self._lock: RLock | nullcontext = RLock() if self._use_cache else
nullcontext()
+
+ def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
+ """Read and optionally cache a SerializedDAG from a
SerializedDagModel."""
+ serdag.load_op_links = self.load_op_links
+ dag = serdag.dag
+ if not dag:
+ return None
+ with self._lock:
+ self._dags[serdag.dag_version_id] = dag
+ cache_size = len(self._dags)
+ if self._use_cache:
+ Stats.gauge("api_server.dag_bag.cache_size", cache_size, rate=0.1)
return dag
- def get_serialized_dag_model(self, version_id: UUID, session: Session) ->
SerializedDagModel | None:
+ def _get_dag(self, version_id: UUID | str, session: Session) ->
SerializedDAG | None:
+ # Check cache first
+ with self._lock:
+ dag = self._dags.get(version_id)
+
+ if dag:
+ if self._use_cache:
+ Stats.incr("api_server.dag_bag.cache_hit")
+ return dag
+
+ dag_version = session.get(DagVersion, version_id,
options=[joinedload(DagVersion.serialized_dag)])
+ if not dag_version:
+ return None
+ if not (serdag := dag_version.serialized_dag):
+ return None
+
+ # Double-checked locking: another thread may have cached it while we
queried DB.
+ # Only emit the miss metric after confirming no other thread cached
it, to avoid
+ # counting a single lookup as both a miss and a hit.
+ if self._use_cache:
+ with self._lock:
+ if dag := self._dags.get(version_id):
+ Stats.incr("api_server.dag_bag.cache_hit")
+ return dag
+ Stats.incr("api_server.dag_bag.cache_miss")
+ return self._read_dag(serdag)
Review Comment:
The “double-checked locking” here does not prevent redundant DB queries on
concurrent cache misses: multiple threads can miss, release the lock, and each
call `session.get(...)` before any one thread stores the result. If the goal is
to ensure only one DB query per `version_id` under concurrency, you’ll need to
serialize the DB fetch (global lock or per-key lock / in-flight tracking).
Otherwise, please adjust the comment/expectations to clarify that this pattern
mainly avoids double-counting hit/miss metrics and prevents cache corruption,
but doesn’t eliminate duplicate queries.
##########
airflow-core/tests/unit/models/test_dagbag.py:
##########
@@ -63,45 +67,239 @@ def test__read_dag_returns_none_when_no_dag(self):
assert result is None
assert "v1" not in self.db_dag_bag._dags
- def test_get_serialized_dag_model(self):
- """It should return the cached SerializedDagModel if already loaded."""
+ def test_get_dag_fetches_from_db_on_miss(self):
+ """It should query the DB and cache the result when not in cache."""
+ mock_dag = MagicMock(spec=SerializedDAG)
mock_serdag = MagicMock(spec=SerializedDagModel)
+ mock_serdag.dag = mock_dag
mock_serdag.dag_version_id = "v1"
mock_dag_version = MagicMock()
mock_dag_version.serialized_dag = mock_serdag
self.session.get.return_value = mock_dag_version
- self.db_dag_bag.get_serialized_dag_model("v1", session=self.session)
- result = self.db_dag_bag.get_serialized_dag_model("v1",
session=self.session)
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
- assert result == mock_serdag
self.session.get.assert_called_once()
+ assert result == mock_dag
- def test_get_serialized_dag_model_returns_none_when_not_found(self):
+ def test_get_dag_returns_cached_on_hit(self):
+ """It should return cached DAG without querying DB."""
+ mock_dag = MagicMock(spec=SerializedDAG)
+ self.db_dag_bag._dags["v1"] = mock_dag
+
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
+
+ assert result == mock_dag
+ self.session.get.assert_not_called()
+
+ def test_get_dag_returns_none_when_not_found(self):
"""It should return None if version_id not found in DB."""
self.session.get.return_value = None
- result = self.db_dag_bag.get_serialized_dag_model("v1",
session=self.session)
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
assert result is None
- def test_get_dag_calls_get_dag_model_and__read_dag(self):
- """It should call get_dag_model and then _read_dag."""
+
+class TestDBDagBagCache:
+ """Tests for DBDagBag optional caching behavior."""
+
+ def test_no_caching_by_default(self):
+ """Test that DBDagBag uses a simple dict without caching by default."""
+ dag_bag = DBDagBag()
+ assert dag_bag._use_cache is False
+ assert isinstance(dag_bag._dags, dict)
+
+ def test_lru_cache_enabled_with_cache_size(self):
+ """Test that LRU cache is enabled when cache_size is provided."""
+ dag_bag = DBDagBag(cache_size=10)
+ assert dag_bag._use_cache is True
+ assert isinstance(dag_bag._dags, LRUCache)
+
+ def test_ttl_cache_enabled_with_cache_size_and_ttl(self):
+ """Test that TTL cache is enabled when both cache_size and cache_ttl
are provided."""
+ dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
+ assert dag_bag._use_cache is True
+ assert isinstance(dag_bag._dags, TTLCache)
+
+ def test_zero_cache_size_uses_unbounded_dict(self):
+ """Test that cache_size=0 uses unbounded dict (same as no caching)."""
+ dag_bag = DBDagBag(cache_size=0, cache_ttl=60)
+ assert dag_bag._use_cache is False
+ assert isinstance(dag_bag._dags, dict)
+
+ def test_clear_cache_with_caching(self):
+ """Test clear_cache() with caching enabled."""
+ dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
+
+ mock_dag = MagicMock()
+ dag_bag._dags["version_1"] = mock_dag
+ dag_bag._dags["version_2"] = mock_dag
+ assert len(dag_bag._dags) == 2
+
+ count = dag_bag.clear_cache()
+ assert count == 2
+ assert len(dag_bag._dags) == 0
+
+ def test_clear_cache_without_caching(self):
+ """Test clear_cache() without caching enabled."""
+ dag_bag = DBDagBag()
+
+ mock_dag = MagicMock()
+ dag_bag._dags["version_1"] = mock_dag
+ assert len(dag_bag._dags) == 1
+
+ count = dag_bag.clear_cache()
+ assert count == 1
+ assert len(dag_bag._dags) == 0
+
+ def test_ttl_cache_expiry(self):
+ """Test that cached DAGs expire after TTL."""
+ # TTLCache defaults to time.monotonic which time_machine cannot
control.
+ # Use time.time as the timer so time_machine can advance it.
+ dag_bag = DBDagBag(cache_size=10, cache_ttl=1)
+ dag_bag._dags = TTLCache(maxsize=10, ttl=1, timer=time.time)
+
+ with time_machine.travel("2025-01-01 00:00:00", tick=False):
+ dag_bag._dags["test_version_id"] = MagicMock()
+ assert "test_version_id" in dag_bag._dags
Review Comment:
This TTL expiry test uses `MagicMock()` without a `spec`, which can hide
interface mistakes. Prefer a `MagicMock(spec=SerializedDAG)` (or a minimal real
`SerializedDAG`) for the cached value to keep the test meaningful.
##########
airflow-core/tests/unit/models/test_dagbag.py:
##########
@@ -63,45 +67,239 @@ def test__read_dag_returns_none_when_no_dag(self):
assert result is None
assert "v1" not in self.db_dag_bag._dags
- def test_get_serialized_dag_model(self):
- """It should return the cached SerializedDagModel if already loaded."""
+ def test_get_dag_fetches_from_db_on_miss(self):
+ """It should query the DB and cache the result when not in cache."""
+ mock_dag = MagicMock(spec=SerializedDAG)
mock_serdag = MagicMock(spec=SerializedDagModel)
+ mock_serdag.dag = mock_dag
mock_serdag.dag_version_id = "v1"
mock_dag_version = MagicMock()
mock_dag_version.serialized_dag = mock_serdag
self.session.get.return_value = mock_dag_version
- self.db_dag_bag.get_serialized_dag_model("v1", session=self.session)
- result = self.db_dag_bag.get_serialized_dag_model("v1",
session=self.session)
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
- assert result == mock_serdag
self.session.get.assert_called_once()
+ assert result == mock_dag
- def test_get_serialized_dag_model_returns_none_when_not_found(self):
+ def test_get_dag_returns_cached_on_hit(self):
+ """It should return cached DAG without querying DB."""
+ mock_dag = MagicMock(spec=SerializedDAG)
+ self.db_dag_bag._dags["v1"] = mock_dag
+
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
+
+ assert result == mock_dag
+ self.session.get.assert_not_called()
+
+ def test_get_dag_returns_none_when_not_found(self):
"""It should return None if version_id not found in DB."""
self.session.get.return_value = None
- result = self.db_dag_bag.get_serialized_dag_model("v1",
session=self.session)
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
assert result is None
- def test_get_dag_calls_get_dag_model_and__read_dag(self):
- """It should call get_dag_model and then _read_dag."""
+
+class TestDBDagBagCache:
+ """Tests for DBDagBag optional caching behavior."""
+
+ def test_no_caching_by_default(self):
+ """Test that DBDagBag uses a simple dict without caching by default."""
+ dag_bag = DBDagBag()
+ assert dag_bag._use_cache is False
+ assert isinstance(dag_bag._dags, dict)
+
+ def test_lru_cache_enabled_with_cache_size(self):
+ """Test that LRU cache is enabled when cache_size is provided."""
+ dag_bag = DBDagBag(cache_size=10)
+ assert dag_bag._use_cache is True
+ assert isinstance(dag_bag._dags, LRUCache)
+
+ def test_ttl_cache_enabled_with_cache_size_and_ttl(self):
+ """Test that TTL cache is enabled when both cache_size and cache_ttl
are provided."""
+ dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
+ assert dag_bag._use_cache is True
+ assert isinstance(dag_bag._dags, TTLCache)
+
+ def test_zero_cache_size_uses_unbounded_dict(self):
+ """Test that cache_size=0 uses unbounded dict (same as no caching)."""
+ dag_bag = DBDagBag(cache_size=0, cache_ttl=60)
+ assert dag_bag._use_cache is False
+ assert isinstance(dag_bag._dags, dict)
+
+ def test_clear_cache_with_caching(self):
+ """Test clear_cache() with caching enabled."""
+ dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
+
+ mock_dag = MagicMock()
+ dag_bag._dags["version_1"] = mock_dag
+ dag_bag._dags["version_2"] = mock_dag
+ assert len(dag_bag._dags) == 2
+
+ count = dag_bag.clear_cache()
+ assert count == 2
+ assert len(dag_bag._dags) == 0
+
+ def test_clear_cache_without_caching(self):
+ """Test clear_cache() without caching enabled."""
+ dag_bag = DBDagBag()
+
+ mock_dag = MagicMock()
+ dag_bag._dags["version_1"] = mock_dag
+ assert len(dag_bag._dags) == 1
+
+ count = dag_bag.clear_cache()
+ assert count == 1
+ assert len(dag_bag._dags) == 0
+
+ def test_ttl_cache_expiry(self):
+ """Test that cached DAGs expire after TTL."""
+ # TTLCache defaults to time.monotonic which time_machine cannot
control.
+ # Use time.time as the timer so time_machine can advance it.
+ dag_bag = DBDagBag(cache_size=10, cache_ttl=1)
+ dag_bag._dags = TTLCache(maxsize=10, ttl=1, timer=time.time)
+
+ with time_machine.travel("2025-01-01 00:00:00", tick=False):
+ dag_bag._dags["test_version_id"] = MagicMock()
+ assert "test_version_id" in dag_bag._dags
+
+ # Jump ahead beyond TTL
+ with time_machine.travel("2025-01-01 00:00:02", tick=False):
+ assert dag_bag._dags.get("test_version_id") is None
+
+ def test_lru_eviction(self):
+ """Test that LRU eviction works when cache is full."""
+ dag_bag = DBDagBag(cache_size=2)
+
+ dag_bag._dags["version_1"] = MagicMock()
+ dag_bag._dags["version_2"] = MagicMock()
+ dag_bag._dags["version_3"] = MagicMock()
+
+ # version_1 should be evicted (LRU)
+ assert dag_bag._dags.get("version_1") is None
+ assert dag_bag._dags.get("version_2") is not None
+ assert dag_bag._dags.get("version_3") is not None
+
+ def test_thread_safety_with_caching(self):
+ """Test concurrent access doesn't cause race conditions with caching
enabled."""
+ dag_bag = DBDagBag(cache_size=100, cache_ttl=60)
+ errors = []
+ mock_session = MagicMock()
+
+ def make_dag_version(version_id):
+ serdag = MagicMock()
+ serdag.dag = MagicMock()
+ serdag.dag_version_id = version_id
+ return MagicMock(serialized_dag=serdag)
+
+ def get_dag_version(model, version_id, options=None):
+ return make_dag_version(version_id)
+
+ mock_session.get.side_effect = get_dag_version
+
+ def access_cache(i):
+ try:
+ dag_bag._get_dag(f"version_{i % 5}", mock_session)
+ except Exception as e:
+ errors.append(e)
+
+ with ThreadPoolExecutor(max_workers=10) as executor:
+ futures = [executor.submit(access_cache, i) for i in range(100)]
+ for f in futures:
+ f.result()
+
+ assert not errors
+
+ def test_read_dag_stores_in_bounded_cache(self):
+ """Test that _read_dag stores DAG in bounded cache when cache_size >
0."""
+ dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
+
+ mock_sdm = MagicMock()
+ mock_sdm.dag = MagicMock()
+ mock_sdm.dag_version_id = "test_version"
+
+ result = dag_bag._read_dag(mock_sdm)
+
+ assert result == mock_sdm.dag
+ assert "test_version" in dag_bag._dags
+
+ def test_read_dag_stores_in_unbounded_dict(self):
+ """Test that _read_dag stores DAG in unbounded dict when no
cache_size."""
+ dag_bag = DBDagBag()
+
+ mock_sdm = MagicMock()
+ mock_sdm.dag = MagicMock()
+ mock_sdm.dag_version_id = "test_version"
+
+ result = dag_bag._read_dag(mock_sdm)
+
+ assert result == mock_sdm.dag
+ assert "test_version" in dag_bag._dags
+
+ def test_iter_all_latest_version_dags_does_not_cache(self):
+ """Test that iter_all_latest_version_dags does not cache to prevent
thrashing."""
+ dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
+
+ mock_session = MagicMock()
+ mock_sdm = MagicMock()
+ mock_sdm.dag = MagicMock()
+ mock_sdm.dag_version_id = "test_version"
+ mock_session.scalars.return_value = [mock_sdm]
+
+ list(dag_bag.iter_all_latest_version_dags(session=mock_session))
+
+ # Cache should be empty -- iter doesn't cache to prevent thrashing
+ assert len(dag_bag._dags) == 0
+
+ @patch("airflow.models.dagbag.Stats")
+ def test_cache_hit_metric_emitted(self, mock_stats):
+ """Test that cache hit metric is emitted when caching is enabled."""
+ dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
+ mock_session = MagicMock()
+ dag_bag._dags["test_version"] = MagicMock()
+
Review Comment:
These metrics tests populate the cache with `MagicMock()` without
`spec`/`autospec`. Prefer specced mocks for cached DAG objects so the test
continues to validate realistic behavior.
##########
airflow-core/src/airflow/api_fastapi/common/dagbag.py:
##########
@@ -16,21 +16,42 @@
# under the License.
from __future__ import annotations
+import logging
from typing import TYPE_CHECKING, Annotated
from fastapi import Depends, HTTPException, Request, status
from sqlalchemy.orm import Session
+from airflow.configuration import conf
from airflow.models.dagbag import DBDagBag
if TYPE_CHECKING:
from airflow.models.dagrun import DagRun
from airflow.serialization.definitions.dag import SerializedDAG
+log = logging.getLogger(__name__)
+
def create_dag_bag() -> DBDagBag:
- """Create DagBag to retrieve DAGs from the database."""
- return DBDagBag()
+ """Create DagBag with configurable LRU+TTL caching for API server usage."""
+ cache_size = conf.getint("api", "dag_cache_size", fallback=64)
+ cache_ttl_config = conf.getint("api", "dag_cache_ttl", fallback=3600)
+
+ if cache_size < 0:
+ log.warning("dag_cache_size must be >= 0, disabling cache")
+ cache_size = 0
+ if cache_ttl_config < 0:
+ log.warning("dag_cache_ttl must be >= 0, disabling TTL")
+ cache_ttl_config = 0
+
+ # Disable caching if cache_size is 0
+ if cache_size <= 0:
+ return DBDagBag(cache_size=0)
+
+ # Disable TTL if cache_ttl is 0
Review Comment:
`dag_cache_size=0` is treated here as “disabling cache”, but
`DBDagBag(cache_size=0)` still uses the unbounded dict cache (i.e., it will
continue to cache DAGs and can still grow without bound). This is inconsistent
with the PR description/FAQ (“0 = unbounded, pre-3.2 behavior”) and can mislead
users trying to turn caching off.
Consider either (a) updating the wording/comments/log messages to say
“disable bounded LRU/TTL cache” (and document that 0 restores the unbounded
dict behavior), or (b) if the intent is truly “no caching”, change `DBDagBag`
so it skips storing entries when `cache_size=0` and adjust docs/tests
accordingly.
```suggestion
    """Create DagBag with configurable bounded LRU caching and optional TTL for API server usage."""
    cache_size = conf.getint("api", "dag_cache_size", fallback=64)
    cache_ttl_config = conf.getint("api", "dag_cache_ttl", fallback=3600)

    if cache_size < 0:
        log.warning("dag_cache_size must be >= 0, disabling bounded LRU cache and using unbounded cache")
        cache_size = 0
    if cache_ttl_config < 0:
        log.warning("dag_cache_ttl must be >= 0, disabling TTL expiration")
        cache_ttl_config = 0

    # cache_size == 0 restores DBDagBag's unbounded cache behavior rather than disabling caching.
    if cache_size <= 0:
        return DBDagBag(cache_size=0)

    # cache_ttl == 0 disables TTL expiration.
```
##########
airflow-core/src/airflow/models/dagbag.py:
##########
@@ -39,50 +44,122 @@
class DBDagBag:
"""
- Internal class for retrieving and caching dags in the scheduler.
+ Internal class for retrieving dags from the database.
+
+ Optionally supports LRU+TTL caching when cache_size is provided.
+ The scheduler uses this without caching, while the API server can
+ enable caching via configuration.
:meta private:
"""
- def __init__(self, load_op_links: bool = True) -> None:
- self._dags: dict[UUID, SerializedDagModel] = {} # dag_version_id to
dag
- self.load_op_links = load_op_links
+ def __init__(
+ self,
+ load_op_links: bool = True,
+ cache_size: int | None = None,
+ cache_ttl: int | None = None,
+ ) -> None:
+ """
+ Initialize DBDagBag.
- def _read_dag(self, serialized_dag_model: SerializedDagModel) ->
SerializedDAG | None:
- serialized_dag_model.load_op_links = self.load_op_links
- if dag := serialized_dag_model.dag:
- self._dags[serialized_dag_model.dag_version_id] =
serialized_dag_model
+ :param load_op_links: Should the extra operator link be loaded when
de-serializing the DAG?
+ :param cache_size: Size of LRU cache. If None or 0, uses unbounded
dict (no eviction).
+ :param cache_ttl: Time-to-live for cache entries in seconds. If None
or 0, no TTL (LRU only).
+ """
+ self.load_op_links = load_op_links
+ self._dags: MutableMapping[UUID | str, SerializedDAG] = {}
+ self._dag_models: dict[UUID | str, SerializedDagModel] = {}
+ self._use_cache = False
+
+ # Initialize bounded cache if cache_size is provided and > 0
+ if cache_size and cache_size > 0:
+ if cache_ttl and cache_ttl > 0:
+ self._dags = TTLCache(maxsize=cache_size, ttl=cache_ttl)
+ else:
+ self._dags = LRUCache(maxsize=cache_size)
+ self._use_cache = True
+
+ # Lock required for bounded caches: cachetools caches are NOT
thread-safe
+ # (LRU reordering and TTL cleanup mutate internal linked lists).
+ # nullcontext for unbounded dict avoids lock overhead in the scheduler
path.
+ self._lock: RLock | nullcontext = RLock() if self._use_cache else
nullcontext()
+
Review Comment:
The `_lock` type annotation is incorrect: `nullcontext()` returns a
context-manager instance, not the `nullcontext` type itself. Consider
annotating this as an `AbstractContextManager[None]`/`ContextManager[None]` (or
similar) instead of `RLock | nullcontext` so type checkers accurately
understand `with self._lock:`.
##########
airflow-core/tests/unit/models/test_dagbag.py:
##########
@@ -63,45 +67,239 @@ def test__read_dag_returns_none_when_no_dag(self):
assert result is None
assert "v1" not in self.db_dag_bag._dags
- def test_get_serialized_dag_model(self):
- """It should return the cached SerializedDagModel if already loaded."""
+ def test_get_dag_fetches_from_db_on_miss(self):
+ """It should query the DB and cache the result when not in cache."""
+ mock_dag = MagicMock(spec=SerializedDAG)
mock_serdag = MagicMock(spec=SerializedDagModel)
+ mock_serdag.dag = mock_dag
mock_serdag.dag_version_id = "v1"
mock_dag_version = MagicMock()
mock_dag_version.serialized_dag = mock_serdag
self.session.get.return_value = mock_dag_version
- self.db_dag_bag.get_serialized_dag_model("v1", session=self.session)
- result = self.db_dag_bag.get_serialized_dag_model("v1",
session=self.session)
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
- assert result == mock_serdag
self.session.get.assert_called_once()
+ assert result == mock_dag
- def test_get_serialized_dag_model_returns_none_when_not_found(self):
+ def test_get_dag_returns_cached_on_hit(self):
+ """It should return cached DAG without querying DB."""
+ mock_dag = MagicMock(spec=SerializedDAG)
+ self.db_dag_bag._dags["v1"] = mock_dag
+
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
+
+ assert result == mock_dag
+ self.session.get.assert_not_called()
+
+ def test_get_dag_returns_none_when_not_found(self):
"""It should return None if version_id not found in DB."""
self.session.get.return_value = None
- result = self.db_dag_bag.get_serialized_dag_model("v1",
session=self.session)
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
assert result is None
- def test_get_dag_calls_get_dag_model_and__read_dag(self):
- """It should call get_dag_model and then _read_dag."""
+
+class TestDBDagBagCache:
+ """Tests for DBDagBag optional caching behavior."""
+
+ def test_no_caching_by_default(self):
+ """Test that DBDagBag uses a simple dict without caching by default."""
+ dag_bag = DBDagBag()
+ assert dag_bag._use_cache is False
+ assert isinstance(dag_bag._dags, dict)
+
+ def test_lru_cache_enabled_with_cache_size(self):
+ """Test that LRU cache is enabled when cache_size is provided."""
+ dag_bag = DBDagBag(cache_size=10)
+ assert dag_bag._use_cache is True
+ assert isinstance(dag_bag._dags, LRUCache)
+
+ def test_ttl_cache_enabled_with_cache_size_and_ttl(self):
+ """Test that TTL cache is enabled when both cache_size and cache_ttl
are provided."""
+ dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
+ assert dag_bag._use_cache is True
+ assert isinstance(dag_bag._dags, TTLCache)
+
+ def test_zero_cache_size_uses_unbounded_dict(self):
+ """Test that cache_size=0 uses unbounded dict (same as no caching)."""
+ dag_bag = DBDagBag(cache_size=0, cache_ttl=60)
+ assert dag_bag._use_cache is False
+ assert isinstance(dag_bag._dags, dict)
+
+ def test_clear_cache_with_caching(self):
+ """Test clear_cache() with caching enabled."""
+ dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
+
+ mock_dag = MagicMock()
+ dag_bag._dags["version_1"] = mock_dag
+ dag_bag._dags["version_2"] = mock_dag
+ assert len(dag_bag._dags) == 2
+
+ count = dag_bag.clear_cache()
+ assert count == 2
+ assert len(dag_bag._dags) == 0
+
+ def test_clear_cache_without_caching(self):
+ """Test clear_cache() without caching enabled."""
+ dag_bag = DBDagBag()
+
+ mock_dag = MagicMock()
+ dag_bag._dags["version_1"] = mock_dag
+ assert len(dag_bag._dags) == 1
+
+ count = dag_bag.clear_cache()
+ assert count == 1
+ assert len(dag_bag._dags) == 0
+
+ def test_ttl_cache_expiry(self):
+ """Test that cached DAGs expire after TTL."""
+ # TTLCache defaults to time.monotonic which time_machine cannot
control.
+ # Use time.time as the timer so time_machine can advance it.
+ dag_bag = DBDagBag(cache_size=10, cache_ttl=1)
+ dag_bag._dags = TTLCache(maxsize=10, ttl=1, timer=time.time)
+
+ with time_machine.travel("2025-01-01 00:00:00", tick=False):
+ dag_bag._dags["test_version_id"] = MagicMock()
+ assert "test_version_id" in dag_bag._dags
+
+ # Jump ahead beyond TTL
+ with time_machine.travel("2025-01-01 00:00:02", tick=False):
+ assert dag_bag._dags.get("test_version_id") is None
+
+ def test_lru_eviction(self):
+ """Test that LRU eviction works when cache is full."""
+ dag_bag = DBDagBag(cache_size=2)
+
+ dag_bag._dags["version_1"] = MagicMock()
+ dag_bag._dags["version_2"] = MagicMock()
+ dag_bag._dags["version_3"] = MagicMock()
Review Comment:
The LRU eviction test uses multiple `MagicMock()` instances without
`spec`/`autospec`. Using `MagicMock(spec=SerializedDAG)` (or similar) here
would better reflect real cached objects and avoid brittle tests.
##########
airflow-core/tests/unit/models/test_dagbag.py:
##########
@@ -63,45 +67,239 @@ def test__read_dag_returns_none_when_no_dag(self):
assert result is None
assert "v1" not in self.db_dag_bag._dags
- def test_get_serialized_dag_model(self):
- """It should return the cached SerializedDagModel if already loaded."""
+ def test_get_dag_fetches_from_db_on_miss(self):
+ """It should query the DB and cache the result when not in cache."""
+ mock_dag = MagicMock(spec=SerializedDAG)
mock_serdag = MagicMock(spec=SerializedDagModel)
+ mock_serdag.dag = mock_dag
mock_serdag.dag_version_id = "v1"
mock_dag_version = MagicMock()
mock_dag_version.serialized_dag = mock_serdag
self.session.get.return_value = mock_dag_version
- self.db_dag_bag.get_serialized_dag_model("v1", session=self.session)
- result = self.db_dag_bag.get_serialized_dag_model("v1",
session=self.session)
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
- assert result == mock_serdag
self.session.get.assert_called_once()
+ assert result == mock_dag
- def test_get_serialized_dag_model_returns_none_when_not_found(self):
+ def test_get_dag_returns_cached_on_hit(self):
+ """It should return cached DAG without querying DB."""
+ mock_dag = MagicMock(spec=SerializedDAG)
+ self.db_dag_bag._dags["v1"] = mock_dag
+
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
+
+ assert result == mock_dag
+ self.session.get.assert_not_called()
+
+ def test_get_dag_returns_none_when_not_found(self):
"""It should return None if version_id not found in DB."""
self.session.get.return_value = None
- result = self.db_dag_bag.get_serialized_dag_model("v1",
session=self.session)
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
assert result is None
- def test_get_dag_calls_get_dag_model_and__read_dag(self):
- """It should call get_dag_model and then _read_dag."""
+
+class TestDBDagBagCache:
+ """Tests for DBDagBag optional caching behavior."""
+
+ def test_no_caching_by_default(self):
+ """Test that DBDagBag uses a simple dict without caching by default."""
+ dag_bag = DBDagBag()
+ assert dag_bag._use_cache is False
+ assert isinstance(dag_bag._dags, dict)
+
Review Comment:
The test name/docstring implies “no caching”, but the default behavior is
still an unbounded in-memory dict cache; only the bounded LRU/TTL cachetools
mode is disabled. Renaming this test (and related docstrings) to reflect “no
bounded LRU/TTL cache by default” would avoid confusion, especially since the
user-facing `dag_cache_size=0` semantics are also being documented.
##########
airflow-core/tests/unit/models/test_dagbag.py:
##########
@@ -63,45 +67,239 @@ def test__read_dag_returns_none_when_no_dag(self):
assert result is None
assert "v1" not in self.db_dag_bag._dags
- def test_get_serialized_dag_model(self):
- """It should return the cached SerializedDagModel if already loaded."""
+ def test_get_dag_fetches_from_db_on_miss(self):
+ """It should query the DB and cache the result when not in cache."""
+ mock_dag = MagicMock(spec=SerializedDAG)
mock_serdag = MagicMock(spec=SerializedDagModel)
+ mock_serdag.dag = mock_dag
mock_serdag.dag_version_id = "v1"
mock_dag_version = MagicMock()
mock_dag_version.serialized_dag = mock_serdag
self.session.get.return_value = mock_dag_version
- self.db_dag_bag.get_serialized_dag_model("v1", session=self.session)
- result = self.db_dag_bag.get_serialized_dag_model("v1",
session=self.session)
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
- assert result == mock_serdag
self.session.get.assert_called_once()
+ assert result == mock_dag
- def test_get_serialized_dag_model_returns_none_when_not_found(self):
+ def test_get_dag_returns_cached_on_hit(self):
+ """It should return cached DAG without querying DB."""
+ mock_dag = MagicMock(spec=SerializedDAG)
+ self.db_dag_bag._dags["v1"] = mock_dag
+
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
+
+ assert result == mock_dag
+ self.session.get.assert_not_called()
+
+ def test_get_dag_returns_none_when_not_found(self):
"""It should return None if version_id not found in DB."""
self.session.get.return_value = None
- result = self.db_dag_bag.get_serialized_dag_model("v1",
session=self.session)
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
assert result is None
- def test_get_dag_calls_get_dag_model_and__read_dag(self):
- """It should call get_dag_model and then _read_dag."""
+
+class TestDBDagBagCache:
+ """Tests for DBDagBag optional caching behavior."""
+
+ def test_no_caching_by_default(self):
+ """Test that DBDagBag uses a simple dict without caching by default."""
+ dag_bag = DBDagBag()
+ assert dag_bag._use_cache is False
+ assert isinstance(dag_bag._dags, dict)
+
+ def test_lru_cache_enabled_with_cache_size(self):
+ """Test that LRU cache is enabled when cache_size is provided."""
+ dag_bag = DBDagBag(cache_size=10)
+ assert dag_bag._use_cache is True
+ assert isinstance(dag_bag._dags, LRUCache)
+
+ def test_ttl_cache_enabled_with_cache_size_and_ttl(self):
+ """Test that TTL cache is enabled when both cache_size and cache_ttl
are provided."""
+ dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
+ assert dag_bag._use_cache is True
+ assert isinstance(dag_bag._dags, TTLCache)
+
+ def test_zero_cache_size_uses_unbounded_dict(self):
+ """Test that cache_size=0 uses unbounded dict (same as no caching)."""
+ dag_bag = DBDagBag(cache_size=0, cache_ttl=60)
+ assert dag_bag._use_cache is False
+ assert isinstance(dag_bag._dags, dict)
+
+ def test_clear_cache_with_caching(self):
+ """Test clear_cache() with caching enabled."""
+ dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
+
+ mock_dag = MagicMock()
+ dag_bag._dags["version_1"] = mock_dag
+ dag_bag._dags["version_2"] = mock_dag
+ assert len(dag_bag._dags) == 2
+
+ count = dag_bag.clear_cache()
+ assert count == 2
+ assert len(dag_bag._dags) == 0
+
+ def test_clear_cache_without_caching(self):
+ """Test clear_cache() without caching enabled."""
+ dag_bag = DBDagBag()
+
+ mock_dag = MagicMock()
+ dag_bag._dags["version_1"] = mock_dag
+ assert len(dag_bag._dags) == 1
+
+ count = dag_bag.clear_cache()
+ assert count == 1
+ assert len(dag_bag._dags) == 0
+
+ def test_ttl_cache_expiry(self):
+ """Test that cached DAGs expire after TTL."""
+ # TTLCache defaults to time.monotonic which time_machine cannot
control.
+ # Use time.time as the timer so time_machine can advance it.
+ dag_bag = DBDagBag(cache_size=10, cache_ttl=1)
+ dag_bag._dags = TTLCache(maxsize=10, ttl=1, timer=time.time)
+
+ with time_machine.travel("2025-01-01 00:00:00", tick=False):
+ dag_bag._dags["test_version_id"] = MagicMock()
+ assert "test_version_id" in dag_bag._dags
+
+ # Jump ahead beyond TTL
+ with time_machine.travel("2025-01-01 00:00:02", tick=False):
+ assert dag_bag._dags.get("test_version_id") is None
+
+ def test_lru_eviction(self):
+ """Test that LRU eviction works when cache is full."""
+ dag_bag = DBDagBag(cache_size=2)
+
+ dag_bag._dags["version_1"] = MagicMock()
+ dag_bag._dags["version_2"] = MagicMock()
+ dag_bag._dags["version_3"] = MagicMock()
+
+ # version_1 should be evicted (LRU)
+ assert dag_bag._dags.get("version_1") is None
+ assert dag_bag._dags.get("version_2") is not None
+ assert dag_bag._dags.get("version_3") is not None
+
+ def test_thread_safety_with_caching(self):
+ """Test concurrent access doesn't cause race conditions with caching
enabled."""
+ dag_bag = DBDagBag(cache_size=100, cache_ttl=60)
+ errors = []
+ mock_session = MagicMock()
+
+ def make_dag_version(version_id):
+ serdag = MagicMock()
+ serdag.dag = MagicMock()
+ serdag.dag_version_id = version_id
+ return MagicMock(serialized_dag=serdag)
+
Review Comment:
The thread-safety test constructs several `MagicMock()` objects without a
`spec` (including the session and the model objects). Using
`create_autospec(Session)` / `MagicMock(spec=SerializedDagModel)` /
`MagicMock(spec=SerializedDAG)` would ensure the test fails if the
production interfaces change.
##########
airflow-core/tests/unit/models/test_dagbag.py:
##########
@@ -63,45 +67,239 @@ def test__read_dag_returns_none_when_no_dag(self):
assert result is None
assert "v1" not in self.db_dag_bag._dags
- def test_get_serialized_dag_model(self):
- """It should return the cached SerializedDagModel if already loaded."""
+ def test_get_dag_fetches_from_db_on_miss(self):
+ """It should query the DB and cache the result when not in cache."""
+ mock_dag = MagicMock(spec=SerializedDAG)
mock_serdag = MagicMock(spec=SerializedDagModel)
+ mock_serdag.dag = mock_dag
mock_serdag.dag_version_id = "v1"
mock_dag_version = MagicMock()
mock_dag_version.serialized_dag = mock_serdag
self.session.get.return_value = mock_dag_version
- self.db_dag_bag.get_serialized_dag_model("v1", session=self.session)
- result = self.db_dag_bag.get_serialized_dag_model("v1",
session=self.session)
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
- assert result == mock_serdag
self.session.get.assert_called_once()
+ assert result == mock_dag
- def test_get_serialized_dag_model_returns_none_when_not_found(self):
+ def test_get_dag_returns_cached_on_hit(self):
+ """It should return cached DAG without querying DB."""
+ mock_dag = MagicMock(spec=SerializedDAG)
+ self.db_dag_bag._dags["v1"] = mock_dag
+
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
+
+ assert result == mock_dag
+ self.session.get.assert_not_called()
+
+ def test_get_dag_returns_none_when_not_found(self):
"""It should return None if version_id not found in DB."""
self.session.get.return_value = None
- result = self.db_dag_bag.get_serialized_dag_model("v1",
session=self.session)
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
assert result is None
- def test_get_dag_calls_get_dag_model_and__read_dag(self):
- """It should call get_dag_model and then _read_dag."""
+
+class TestDBDagBagCache:
+ """Tests for DBDagBag optional caching behavior."""
+
+ def test_no_caching_by_default(self):
+ """Test that DBDagBag uses a simple dict without caching by default."""
+ dag_bag = DBDagBag()
+ assert dag_bag._use_cache is False
+ assert isinstance(dag_bag._dags, dict)
+
+ def test_lru_cache_enabled_with_cache_size(self):
+ """Test that LRU cache is enabled when cache_size is provided."""
+ dag_bag = DBDagBag(cache_size=10)
+ assert dag_bag._use_cache is True
+ assert isinstance(dag_bag._dags, LRUCache)
+
+ def test_ttl_cache_enabled_with_cache_size_and_ttl(self):
+ """Test that TTL cache is enabled when both cache_size and cache_ttl
are provided."""
+ dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
+ assert dag_bag._use_cache is True
+ assert isinstance(dag_bag._dags, TTLCache)
+
+ def test_zero_cache_size_uses_unbounded_dict(self):
+ """Test that cache_size=0 uses unbounded dict (same as no caching)."""
+ dag_bag = DBDagBag(cache_size=0, cache_ttl=60)
+ assert dag_bag._use_cache is False
+ assert isinstance(dag_bag._dags, dict)
+
+ def test_clear_cache_with_caching(self):
+ """Test clear_cache() with caching enabled."""
+ dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
+
+ mock_dag = MagicMock()
+ dag_bag._dags["version_1"] = mock_dag
+ dag_bag._dags["version_2"] = mock_dag
Review Comment:
These tests introduce several `MagicMock()` instances without
`spec`/`autospec`, which makes the tests less effective (they’ll accept any
attribute and can mask real interface changes). Prefer `MagicMock(spec=...)`
(or `create_autospec`) for the DAG/model objects being exercised here.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]