kaxil commented on code in PR #60804:
URL: https://github.com/apache/airflow/pull/60804#discussion_r3067301823
##########
airflow-core/src/airflow/api_fastapi/common/dagbag.py:
##########
@@ -16,21 +16,42 @@
# under the License.
from __future__ import annotations
+import logging
from typing import TYPE_CHECKING, Annotated
from fastapi import Depends, HTTPException, Request, status
from sqlalchemy.orm import Session
+from airflow.configuration import conf
from airflow.models.dagbag import DBDagBag
if TYPE_CHECKING:
from airflow.models.dagrun import DagRun
from airflow.serialization.definitions.dag import SerializedDAG
+log = logging.getLogger(__name__)
+
def create_dag_bag() -> DBDagBag:
- """Create DagBag to retrieve DAGs from the database."""
- return DBDagBag()
+ """Create DagBag with configurable LRU+TTL caching for API server usage."""
+ cache_size = conf.getint("api", "dag_cache_size", fallback=64)
+ cache_ttl_config = conf.getint("api", "dag_cache_ttl", fallback=3600)
+
+ if cache_size < 0:
+ log.warning("dag_cache_size must be >= 0, disabling cache")
+ cache_size = 0
+ if cache_ttl_config < 0:
+ log.warning("dag_cache_ttl must be >= 0, disabling TTL")
+ cache_ttl_config = 0
+
+ # Disable caching if cache_size is 0
+ if cache_size <= 0:
+ return DBDagBag(cache_size=0)
+
+ # Disable TTL if cache_ttl is 0
Review Comment:
Fixed -- config and code comments now say "unbounded dict (no eviction)"
instead of "disable caching".
##########
airflow-core/src/airflow/config_templates/config.yml:
##########
@@ -1606,6 +1606,30 @@ api:
type: string
example: ~
default: "False"
+ dag_cache_size:
+ description: |
+ Size of the LRU cache for SerializedDAG objects in the API server.
+ Set to 0 to disable caching. The cache is keyed by DAG version ID,
+ so lookups by DAG ID (e.g., viewing a DAG's details) always query
+ the database for the latest version, but the deserialized result is
+ cached for subsequent version-specific lookups.
Review Comment:
Fixed -- now reads "Set to 0 to use an unbounded dict (no eviction, pre-3.2
behavior)".
##########
airflow-core/src/airflow/models/dagbag.py:
##########
@@ -39,50 +44,122 @@
class DBDagBag:
"""
- Internal class for retrieving and caching dags in the scheduler.
+ Internal class for retrieving dags from the database.
+
+ Optionally supports LRU+TTL caching when cache_size is provided.
+ The scheduler uses this without caching, while the API server can
+ enable caching via configuration.
:meta private:
"""
- def __init__(self, load_op_links: bool = True) -> None:
- self._dags: dict[UUID, SerializedDagModel] = {} # dag_version_id to
dag
- self.load_op_links = load_op_links
+ def __init__(
+ self,
+ load_op_links: bool = True,
+ cache_size: int | None = None,
+ cache_ttl: int | None = None,
+ ) -> None:
+ """
+ Initialize DBDagBag.
- def _read_dag(self, serialized_dag_model: SerializedDagModel) ->
SerializedDAG | None:
- serialized_dag_model.load_op_links = self.load_op_links
- if dag := serialized_dag_model.dag:
- self._dags[serialized_dag_model.dag_version_id] =
serialized_dag_model
+ :param load_op_links: Should the extra operator link be loaded when
de-serializing the DAG?
+ :param cache_size: Size of LRU cache. If None or 0, uses unbounded
dict (no eviction).
+ :param cache_ttl: Time-to-live for cache entries in seconds. If None
or 0, no TTL (LRU only).
+ """
+ self.load_op_links = load_op_links
+ self._dags: MutableMapping[UUID | str, SerializedDAG] = {}
+ self._dag_models: dict[UUID | str, SerializedDagModel] = {}
+ self._use_cache = False
+
+ # Initialize bounded cache if cache_size is provided and > 0
+ if cache_size and cache_size > 0:
+ if cache_ttl and cache_ttl > 0:
+ self._dags = TTLCache(maxsize=cache_size, ttl=cache_ttl)
+ else:
+ self._dags = LRUCache(maxsize=cache_size)
+ self._use_cache = True
+
+ # Lock required for bounded caches: cachetools caches are NOT
thread-safe
+ # (LRU reordering and TTL cleanup mutate internal linked lists).
+ # nullcontext for unbounded dict avoids lock overhead in the scheduler
path.
+ self._lock: RLock | nullcontext = RLock() if self._use_cache else
nullcontext()
+
+ def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
+ """Read and optionally cache a SerializedDAG from a
SerializedDagModel."""
+ serdag.load_op_links = self.load_op_links
+ dag = serdag.dag
+ if not dag:
+ return None
+ with self._lock:
+ self._dags[serdag.dag_version_id] = dag
+ cache_size = len(self._dags)
+ if self._use_cache:
+ Stats.gauge("api_server.dag_bag.cache_size", cache_size, rate=0.1)
return dag
- def get_serialized_dag_model(self, version_id: UUID, session: Session) ->
SerializedDagModel | None:
+ def _get_dag(self, version_id: UUID | str, session: Session) ->
SerializedDAG | None:
+ # Check cache first
+ with self._lock:
+ dag = self._dags.get(version_id)
+
+ if dag:
+ if self._use_cache:
+ Stats.incr("api_server.dag_bag.cache_hit")
+ return dag
+
+ dag_version = session.get(DagVersion, version_id,
options=[joinedload(DagVersion.serialized_dag)])
+ if not dag_version:
+ return None
+ if not (serdag := dag_version.serialized_dag):
+ return None
+
+ # Double-checked locking: another thread may have cached it while we
queried DB.
+ # Only emit the miss metric after confirming no other thread cached
it, to avoid
+ # counting a single lookup as both a miss and a hit.
+ if self._use_cache:
+ with self._lock:
+ if dag := self._dags.get(version_id):
+ Stats.incr("api_server.dag_bag.cache_hit")
+ return dag
+ Stats.incr("api_server.dag_bag.cache_miss")
+ return self._read_dag(serdag)
Review Comment:
Correct -- the double-checked locking is a soft optimization, not a mutex
around the DB query. Multiple threads can still issue concurrent DB queries for
the same version_id. The lock only deduplicates the cache write so the
last-arriving thread can find it cached instead of deserializing again. A
per-key lock or "pending" sentinel would fully deduplicate DB queries but adds
complexity that isn't warranted for this use case.
##########
airflow-core/src/airflow/models/dagbag.py:
##########
@@ -39,50 +44,122 @@
class DBDagBag:
"""
- Internal class for retrieving and caching dags in the scheduler.
+ Internal class for retrieving dags from the database.
+
+ Optionally supports LRU+TTL caching when cache_size is provided.
+ The scheduler uses this without caching, while the API server can
+ enable caching via configuration.
:meta private:
"""
- def __init__(self, load_op_links: bool = True) -> None:
- self._dags: dict[UUID, SerializedDagModel] = {} # dag_version_id to
dag
- self.load_op_links = load_op_links
+ def __init__(
+ self,
+ load_op_links: bool = True,
+ cache_size: int | None = None,
+ cache_ttl: int | None = None,
+ ) -> None:
+ """
+ Initialize DBDagBag.
- def _read_dag(self, serialized_dag_model: SerializedDagModel) ->
SerializedDAG | None:
- serialized_dag_model.load_op_links = self.load_op_links
- if dag := serialized_dag_model.dag:
- self._dags[serialized_dag_model.dag_version_id] =
serialized_dag_model
+ :param load_op_links: Should the extra operator link be loaded when
de-serializing the DAG?
+ :param cache_size: Size of LRU cache. If None or 0, uses unbounded
dict (no eviction).
+ :param cache_ttl: Time-to-live for cache entries in seconds. If None
or 0, no TTL (LRU only).
+ """
+ self.load_op_links = load_op_links
+ self._dags: MutableMapping[UUID | str, SerializedDAG] = {}
+ self._dag_models: dict[UUID | str, SerializedDagModel] = {}
+ self._use_cache = False
+
+ # Initialize bounded cache if cache_size is provided and > 0
+ if cache_size and cache_size > 0:
+ if cache_ttl and cache_ttl > 0:
+ self._dags = TTLCache(maxsize=cache_size, ttl=cache_ttl)
+ else:
+ self._dags = LRUCache(maxsize=cache_size)
+ self._use_cache = True
+
+ # Lock required for bounded caches: cachetools caches are NOT
thread-safe
+ # (LRU reordering and TTL cleanup mutate internal linked lists).
+ # nullcontext for unbounded dict avoids lock overhead in the scheduler
path.
+ self._lock: RLock | nullcontext = RLock() if self._use_cache else
nullcontext()
+
Review Comment:
Fair point. Both RLock and nullcontext() are context managers, so it works at
runtime. A more precise annotation would be `AbstractContextManager`, but the
current form is readable and mypy doesn't flag it. Leaving as-is for now.
##########
airflow-core/tests/unit/models/test_dagbag.py:
##########
@@ -63,45 +67,239 @@ def test__read_dag_returns_none_when_no_dag(self):
assert result is None
assert "v1" not in self.db_dag_bag._dags
- def test_get_serialized_dag_model(self):
- """It should return the cached SerializedDagModel if already loaded."""
+ def test_get_dag_fetches_from_db_on_miss(self):
+ """It should query the DB and cache the result when not in cache."""
+ mock_dag = MagicMock(spec=SerializedDAG)
mock_serdag = MagicMock(spec=SerializedDagModel)
+ mock_serdag.dag = mock_dag
mock_serdag.dag_version_id = "v1"
mock_dag_version = MagicMock()
mock_dag_version.serialized_dag = mock_serdag
self.session.get.return_value = mock_dag_version
- self.db_dag_bag.get_serialized_dag_model("v1", session=self.session)
- result = self.db_dag_bag.get_serialized_dag_model("v1",
session=self.session)
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
- assert result == mock_serdag
self.session.get.assert_called_once()
+ assert result == mock_dag
- def test_get_serialized_dag_model_returns_none_when_not_found(self):
+ def test_get_dag_returns_cached_on_hit(self):
+ """It should return cached DAG without querying DB."""
+ mock_dag = MagicMock(spec=SerializedDAG)
+ self.db_dag_bag._dags["v1"] = mock_dag
+
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
+
+ assert result == mock_dag
+ self.session.get.assert_not_called()
+
+ def test_get_dag_returns_none_when_not_found(self):
"""It should return None if version_id not found in DB."""
self.session.get.return_value = None
- result = self.db_dag_bag.get_serialized_dag_model("v1",
session=self.session)
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
assert result is None
- def test_get_dag_calls_get_dag_model_and__read_dag(self):
- """It should call get_dag_model and then _read_dag."""
+
+class TestDBDagBagCache:
+ """Tests for DBDagBag optional caching behavior."""
+
+ def test_no_caching_by_default(self):
+ """Test that DBDagBag uses a simple dict without caching by default."""
+ dag_bag = DBDagBag()
+ assert dag_bag._use_cache is False
+ assert isinstance(dag_bag._dags, dict)
+
Review Comment:
The test validates that the default constructor produces an unbounded dict with
no LRU/TTL caching. The name "no_caching_by_default" is shorthand for "no bounded
caching" -- the unbounded dict path is the pre-existing behavior, not a caching
feature. I think the meaning is clear in context.
##########
airflow-core/tests/unit/models/test_dagbag.py:
##########
@@ -63,45 +67,239 @@ def test__read_dag_returns_none_when_no_dag(self):
assert result is None
assert "v1" not in self.db_dag_bag._dags
- def test_get_serialized_dag_model(self):
- """It should return the cached SerializedDagModel if already loaded."""
+ def test_get_dag_fetches_from_db_on_miss(self):
+ """It should query the DB and cache the result when not in cache."""
+ mock_dag = MagicMock(spec=SerializedDAG)
mock_serdag = MagicMock(spec=SerializedDagModel)
+ mock_serdag.dag = mock_dag
mock_serdag.dag_version_id = "v1"
mock_dag_version = MagicMock()
mock_dag_version.serialized_dag = mock_serdag
self.session.get.return_value = mock_dag_version
- self.db_dag_bag.get_serialized_dag_model("v1", session=self.session)
- result = self.db_dag_bag.get_serialized_dag_model("v1",
session=self.session)
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
- assert result == mock_serdag
self.session.get.assert_called_once()
+ assert result == mock_dag
- def test_get_serialized_dag_model_returns_none_when_not_found(self):
+ def test_get_dag_returns_cached_on_hit(self):
+ """It should return cached DAG without querying DB."""
+ mock_dag = MagicMock(spec=SerializedDAG)
+ self.db_dag_bag._dags["v1"] = mock_dag
+
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
+
+ assert result == mock_dag
+ self.session.get.assert_not_called()
+
+ def test_get_dag_returns_none_when_not_found(self):
"""It should return None if version_id not found in DB."""
self.session.get.return_value = None
- result = self.db_dag_bag.get_serialized_dag_model("v1",
session=self.session)
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
assert result is None
- def test_get_dag_calls_get_dag_model_and__read_dag(self):
- """It should call get_dag_model and then _read_dag."""
+
+class TestDBDagBagCache:
+ """Tests for DBDagBag optional caching behavior."""
+
+ def test_no_caching_by_default(self):
+ """Test that DBDagBag uses a simple dict without caching by default."""
+ dag_bag = DBDagBag()
+ assert dag_bag._use_cache is False
+ assert isinstance(dag_bag._dags, dict)
+
+ def test_lru_cache_enabled_with_cache_size(self):
+ """Test that LRU cache is enabled when cache_size is provided."""
+ dag_bag = DBDagBag(cache_size=10)
+ assert dag_bag._use_cache is True
+ assert isinstance(dag_bag._dags, LRUCache)
+
+ def test_ttl_cache_enabled_with_cache_size_and_ttl(self):
+ """Test that TTL cache is enabled when both cache_size and cache_ttl
are provided."""
+ dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
+ assert dag_bag._use_cache is True
+ assert isinstance(dag_bag._dags, TTLCache)
+
+ def test_zero_cache_size_uses_unbounded_dict(self):
+ """Test that cache_size=0 uses unbounded dict (same as no caching)."""
+ dag_bag = DBDagBag(cache_size=0, cache_ttl=60)
+ assert dag_bag._use_cache is False
+ assert isinstance(dag_bag._dags, dict)
+
+ def test_clear_cache_with_caching(self):
+ """Test clear_cache() with caching enabled."""
+ dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
+
+ mock_dag = MagicMock()
+ dag_bag._dags["version_1"] = mock_dag
+ dag_bag._dags["version_2"] = mock_dag
Review Comment:
Good point. The thread-safety test uses unspecced mocks because it's testing
lock contention, not interface correctness. The other tests in the file do use
spec. We could tighten this in a follow-up.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]