jason810496 commented on code in PR #60804:
URL: https://github.com/apache/airflow/pull/60804#discussion_r2711184359


##########
airflow-core/src/airflow/models/dagbag.py:
##########
@@ -37,34 +42,114 @@
     from airflow.models.serialized_dag import SerializedDagModel
     from airflow.serialization.definitions.dag import SerializedDAG
 
+log = logging.getLogger(__name__)
+
 
 class DBDagBag:
     """
-    Internal class for retrieving and caching dags in the scheduler.
+    Internal class for retrieving dags from the database.
+
+    Optionally supports LRU+TTL caching when cache_size is provided.
+    The scheduler uses this without caching, while the API server can
+    enable caching via configuration.
 
     :meta private:
     """
 
-    def __init__(self, load_op_links: bool = True) -> None:
-        self._dags: dict[str, SerializedDAG] = {}  # dag_version_id to dag
+    def __init__(
+        self,
+        load_op_links: bool = True,
+        cache_size: int | None = None,
+        cache_ttl: int | None = None,
+    ) -> None:
+        """
+        Initialize DBDagBag.
+
+        :param load_op_links: Should the extra operator link be loaded when de-serializing the DAG?
+        :param cache_size: Size of LRU cache. If None or 0, no caching is used.
+        :param cache_ttl: Time-to-live for cache entries in seconds. If None or 0, no TTL is used.
+        """
         self.load_op_links = load_op_links
+        self._cache_size = cache_size
+        self._cache_ttl = cache_ttl
+        self._disable_cache = cache_size == 0
+
+        self._lock: RLock | None = None
+        self._use_cache = False
+        self._dags: MutableMapping[str, SerializedDAG] = {}
+
+        # Initialize cache if cache_size is provided
+        if cache_size and cache_size > 0:
+            if cache_ttl and cache_ttl > 0:
+                self._dags = TTLCache(maxsize=cache_size, ttl=cache_ttl)
+            else:
+                self._dags = LRUCache(maxsize=cache_size)
+            self._lock = RLock()
+            self._use_cache = True
 
     def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
+        """Read and optionally cache a SerializedDAG from a SerializedDagModel."""
         serdag.load_op_links = self.load_op_links
-        if dag := serdag.dag:
+        dag = serdag.dag
+        if not dag or self._disable_cache:
+            return dag
+        if self._use_cache and self._lock:
+            try:
+                with self._lock:
+                    self._dags[serdag.dag_version_id] = dag
+                    Stats.gauge("api_server.dag_bag.cache_size", len(self._dags))
+            except MemoryError:
+                # Re-raise MemoryError to avoid masking OOM conditions
+                raise
+            except Exception:
+                log.warning("Failed to cache DAG %s", serdag.dag_id, exc_info=True)
+        else:
             self._dags[serdag.dag_version_id] = dag
         return dag
 
     def _get_dag(self, version_id: str, session: Session) -> SerializedDAG | None:
-        if dag := self._dags.get(version_id):
-            return dag
+        if not self._disable_cache:
+            if self._lock:
+                with self._lock:
+                    dag = self._dags.get(version_id)
+            else:
+                dag = self._dags.get(version_id)
+            if dag:
+                if self._use_cache:
+                    Stats.incr("api_server.dag_bag.cache_hit")
+                return dag
+            if self._use_cache:

Review Comment:
   It seems we could consolidate `_disable_cache` and `_use_cache` into a single variable.
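   
   For illustration, a minimal sketch of one way to consolidate them. Note the current code actually encodes three states (plain dict when `cache_size` is None, bounded cache when it is positive, fully disabled when it is 0), so a single mode value may read more clearly than two booleans. The `_CacheMode` enum is hypothetical, not part of this PR:
   ```python
   from enum import Enum
   
   class _CacheMode(Enum):
       DICT = "dict"          # default: plain dict, entries never evicted
       BOUNDED = "bounded"    # LRUCache/TTLCache with a maxsize
       DISABLED = "disabled"  # cache_size == 0: never store dags
   
   # Inside __init__, replacing _disable_cache and _use_cache:
   if cache_size == 0:
       self._cache_mode = _CacheMode.DISABLED
   elif cache_size and cache_size > 0:
       self._cache_mode = _CacheMode.BOUNDED
   else:
       self._cache_mode = _CacheMode.DICT
   ```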



##########
airflow-core/src/airflow/models/dagbag.py:
##########
@@ -37,34 +42,114 @@
     from airflow.models.serialized_dag import SerializedDagModel
     from airflow.serialization.definitions.dag import SerializedDAG
 
+log = logging.getLogger(__name__)
+
 
 class DBDagBag:
     """
-    Internal class for retrieving and caching dags in the scheduler.
+    Internal class for retrieving dags from the database.
+
+    Optionally supports LRU+TTL caching when cache_size is provided.
+    The scheduler uses this without caching, while the API server can
+    enable caching via configuration.
 
     :meta private:
     """
 
-    def __init__(self, load_op_links: bool = True) -> None:
-        self._dags: dict[str, SerializedDAG] = {}  # dag_version_id to dag
+    def __init__(
+        self,
+        load_op_links: bool = True,
+        cache_size: int | None = None,
+        cache_ttl: int | None = None,
+    ) -> None:
+        """
+        Initialize DBDagBag.
+
+        :param load_op_links: Should the extra operator link be loaded when de-serializing the DAG?
+        :param cache_size: Size of LRU cache. If None or 0, no caching is used.
+        :param cache_ttl: Time-to-live for cache entries in seconds. If None or 0, no TTL is used.
+        """
         self.load_op_links = load_op_links
+        self._cache_size = cache_size
+        self._cache_ttl = cache_ttl
+        self._disable_cache = cache_size == 0
+
+        self._lock: RLock | None = None
+        self._use_cache = False
+        self._dags: MutableMapping[str, SerializedDAG] = {}
+
+        # Initialize cache if cache_size is provided
+        if cache_size and cache_size > 0:

Review Comment:
   Not sure if it would be better to reuse the existing `_disable_cache` as the condition here?
   ```suggestion
           if not self._disable_cache:
   ```



##########
airflow-core/tests/unit/models/test_dagbag.py:
##########
@@ -16,13 +16,234 @@
 # under the License.
 from __future__ import annotations
 
+from concurrent.futures import ThreadPoolExecutor
+from unittest.mock import MagicMock, patch
+
 import pytest
+import time_machine
+from cachetools import LRUCache, TTLCache
+
+from airflow.models.dagbag import DBDagBag
 
 pytestmark = pytest.mark.db_test
 
-# This file previously contained tests for DagBag functionality, but those tests
-# have been moved to airflow-core/tests/unit/dag_processing/test_dagbag.py to match
-# the source code reorganization where DagBag moved from models to dag_processing.
-#
-# Tests for models-specific functionality (DBDagBag, DagPriorityParsingRequest, etc.)
-# would remain in this file, but currently no such tests exist.
+
+class TestDBDagBagCache:
+    """Tests for DBDagBag optional caching behavior."""
+
+    def test_no_caching_by_default(self):
+        """Test that DBDagBag uses a simple dict without caching by default."""
+        dag_bag = DBDagBag()
+        assert dag_bag._use_cache is False
+        assert isinstance(dag_bag._dags, dict)
+        assert dag_bag._lock is None
+
+    def test_lru_cache_enabled_with_cache_size(self):
+        """Test that LRU cache is enabled when cache_size is provided."""
+        dag_bag = DBDagBag(cache_size=10)
+        assert dag_bag._use_cache is True
+        assert isinstance(dag_bag._dags, LRUCache)
+        assert dag_bag._lock is not None
+
+    def test_ttl_cache_enabled_with_cache_size_and_ttl(self):
+        """Test that TTL cache is enabled when both cache_size and cache_ttl are provided."""
+        dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
+        assert dag_bag._use_cache is True
+        assert isinstance(dag_bag._dags, TTLCache)
+        assert dag_bag._lock is not None
+
+    def test_caching_disabled_with_zero_cache_size(self):
+        """Test that caching is disabled when cache_size is 0."""
+        dag_bag = DBDagBag(cache_size=0, cache_ttl=60)
+        assert dag_bag._disable_cache is True
+        assert dag_bag._use_cache is False
+        assert isinstance(dag_bag._dags, dict)
+
+    def test_clear_cache_with_caching(self):
+        """Test clear_cache() with caching enabled."""
+        dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
+
+        # Add some mock DAGs to cache
+        mock_dag = MagicMock()
+        dag_bag._dags["version_1"] = mock_dag
+        dag_bag._dags["version_2"] = mock_dag
+        assert len(dag_bag._dags) == 2
+
+        # Clear cache
+        count = dag_bag.clear_cache()
+        assert count == 2
+        assert len(dag_bag._dags) == 0
+
+    def test_clear_cache_without_caching(self):
+        """Test clear_cache() without caching enabled."""
+        dag_bag = DBDagBag()
+
+        # Add some mock DAGs
+        mock_dag = MagicMock()
+        dag_bag._dags["version_1"] = mock_dag
+        assert len(dag_bag._dags) == 1
+
+        # Clear cache
+        count = dag_bag.clear_cache()
+        assert count == 1
+        assert len(dag_bag._dags) == 0
+
+    def test_ttl_cache_expiry(self):
+        """Test that cached DAGs expire after TTL."""
+        with time_machine.travel("2025-01-01 00:00:00", tick=False):
+            dag_bag = DBDagBag(cache_size=10, cache_ttl=1)  # 1 second TTL
+
+            # Add a mock DAG to cache
+            mock_dag = MagicMock()
+            dag_bag._dags["test_version_id"] = mock_dag
+            assert "test_version_id" in dag_bag._dags
+
+        # Jump ahead beyond TTL
+        with time_machine.travel("2025-01-01 00:00:02", tick=False):
+            # Cache should have expired
+            assert dag_bag._dags.get("test_version_id") is None
+
+    def test_lru_eviction(self):
+        """Test that LRU eviction works when cache is full."""
+        dag_bag = DBDagBag(cache_size=2)
+
+        # Add 3 DAGs - first one should be evicted
+        dag_bag._dags["version_1"] = MagicMock()
+        dag_bag._dags["version_2"] = MagicMock()
+        dag_bag._dags["version_3"] = MagicMock()
+
+        # version_1 should be evicted (LRU)
+        assert dag_bag._dags.get("version_1") is None
+        assert dag_bag._dags.get("version_2") is not None
+        assert dag_bag._dags.get("version_3") is not None
+
+    def test_thread_safety_with_caching(self):
+        """Test concurrent access doesn't cause race conditions with caching enabled."""
+        dag_bag = DBDagBag(cache_size=100, cache_ttl=60)
+        errors = []
+        mock_session = MagicMock()
+
+        def make_dag_version(version_id: str) -> MagicMock:
+            serdag = MagicMock()
+            serdag.dag = MagicMock()
+            serdag.dag_version_id = version_id
+            return MagicMock(serialized_dag=serdag)
+
+        def get_dag_version(model, version_id, options=None):
+            return make_dag_version(version_id)
+
+        mock_session.get.side_effect = get_dag_version
+
+        def access_cache(i):
+            try:
+                dag_bag._get_dag(f"version_{i % 5}", mock_session)
+            except Exception as e:
+                errors.append(e)
+
+        with ThreadPoolExecutor(max_workers=10) as executor:
+            futures = [executor.submit(access_cache, i) for i in range(100)]
+            for f in futures:
+                f.result()
+
+        assert not errors
+
+    def test_read_dag_caches_with_lock(self):
+        """Test that _read_dag uses lock when caching is enabled."""
+        dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
+
+        mock_sdm = MagicMock()
+        mock_sdm.dag = MagicMock()
+        mock_sdm.dag_version_id = "test_version"
+
+        result = dag_bag._read_dag(mock_sdm)
+
+        assert result == mock_sdm.dag
+        assert "test_version" in dag_bag._dags
+
+    def test_read_dag_without_caching(self):

Review Comment:
   The behavior doesn't seem to match the names of the `test_read_dag_caches_with_lock` and `test_read_dag_without_caching` test cases: regardless of whether the cache is enabled, `self._dags[serdag.dag_version_id] = dag` is executed in `_read_dag`, so both tests end up asserting the same stored-dag outcome.
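   
   One hedged way to make the "without caching" case assert observably different behavior; the renamed test below is just a sketch built on the internals shown in this diff:
   ```python
   def test_read_dag_stores_in_plain_dict_without_caching(self):
       """Without cache_size, _read_dag still stores the dag, but in a plain dict with no lock."""
       dag_bag = DBDagBag()  # no cache_size: plain dict path
   
       mock_sdm = MagicMock()
       mock_sdm.dag = MagicMock()
       mock_sdm.dag_version_id = "test_version"
   
       result = dag_bag._read_dag(mock_sdm)
   
       assert result == mock_sdm.dag
       assert isinstance(dag_bag._dags, dict)  # not LRUCache/TTLCache
       assert dag_bag._lock is None            # the locked branch cannot have run
       assert "test_version" in dag_bag._dags  # still stored, just unbounded
   ```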



##########
airflow-core/tests/unit/api_fastapi/common/test_dagbag.py:
##########
@@ -82,3 +83,50 @@ def test_dagbag_used_as_singleton_in_dependency(self, session, dag_maker, test_c
         assert resp2.status_code == 200
 
         assert self.dagbag_call_counter["count"] == 1
+
+
+class TestCreateDagBag:
+    """Tests for create_dag_bag() function."""
+

Review Comment:
   Although not strictly necessary, we could consolidate these test methods using `pytest.mark.parametrize` with `dag_cache_size, dag_cache_ttl, expected_class`, as sketched below.



##########
airflow-core/src/airflow/models/dagbag.py:
##########
@@ -37,34 +42,114 @@
     from airflow.models.serialized_dag import SerializedDagModel
     from airflow.serialization.definitions.dag import SerializedDAG
 
+log = logging.getLogger(__name__)
+
 
 class DBDagBag:
     """
-    Internal class for retrieving and caching dags in the scheduler.
+    Internal class for retrieving dags from the database.
+
+    Optionally supports LRU+TTL caching when cache_size is provided.
+    The scheduler uses this without caching, while the API server can
+    enable caching via configuration.
 
     :meta private:
     """
 
-    def __init__(self, load_op_links: bool = True) -> None:
-        self._dags: dict[str, SerializedDAG] = {}  # dag_version_id to dag
+    def __init__(
+        self,
+        load_op_links: bool = True,
+        cache_size: int | None = None,
+        cache_ttl: int | None = None,
+    ) -> None:
+        """
+        Initialize DBDagBag.
+
+        :param load_op_links: Should the extra operator link be loaded when de-serializing the DAG?
+        :param cache_size: Size of LRU cache. If None or 0, no caching is used.
+        :param cache_ttl: Time-to-live for cache entries in seconds. If None or 0, no TTL is used.
+        """
         self.load_op_links = load_op_links
+        self._cache_size = cache_size
+        self._cache_ttl = cache_ttl
+        self._disable_cache = cache_size == 0
+
+        self._lock: RLock | None = None
+        self._use_cache = False
+        self._dags: MutableMapping[str, SerializedDAG] = {}
+
+        # Initialize cache if cache_size is provided
+        if cache_size and cache_size > 0:
+            if cache_ttl and cache_ttl > 0:
+                self._dags = TTLCache(maxsize=cache_size, ttl=cache_ttl)
+            else:
+                self._dags = LRUCache(maxsize=cache_size)
+            self._lock = RLock()
+            self._use_cache = True
 
     def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
+        """Read and optionally cache a SerializedDAG from a SerializedDagModel."""
         serdag.load_op_links = self.load_op_links
-        if dag := serdag.dag:
+        dag = serdag.dag
+        if not dag or self._disable_cache:
+            return dag
+        if self._use_cache and self._lock:
+            try:
+                with self._lock:
+                    self._dags[serdag.dag_version_id] = dag
+                    Stats.gauge("api_server.dag_bag.cache_size", len(self._dags))
+            except MemoryError:
+                # Re-raise MemoryError to avoid masking OOM conditions
+                raise
+            except Exception:
+                log.warning("Failed to cache DAG %s", serdag.dag_id, exc_info=True)
+        else:
             self._dags[serdag.dag_version_id] = dag
         return dag
 
     def _get_dag(self, version_id: str, session: Session) -> SerializedDAG | None:
-        if dag := self._dags.get(version_id):
-            return dag
+        if not self._disable_cache:
+            if self._lock:
+                with self._lock:
+                    dag = self._dags.get(version_id)
+            else:
+                dag = self._dags.get(version_id)
+            if dag:
+                if self._use_cache:
+                    Stats.incr("api_server.dag_bag.cache_hit")
+                return dag
+            if self._use_cache:
+                Stats.incr("api_server.dag_bag.cache_miss")
         dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
         if not dag_version:
             return None
         if not (serdag := dag_version.serialized_dag):
             return None
+        if self._lock and not self._disable_cache:
+            with self._lock:
+                if dag := self._dags.get(version_id):
+                    return dag

Review Comment:
   ```suggestion
   ```
   
   If I understand correctly, we already handle retrieving from the cache before fetching `dag_version.serialized_dag`, so this second lookup under the lock is redundant.
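   
   For context, a sketch of the resulting tail of `_get_dag` (assuming the method ends by delegating to `_read_dag`, which this hunk does not show). The worst case under concurrency is that two threads miss the cache and both deserialize, and the later `_read_dag` write overwrites the earlier, equivalent entry, which looks harmless:
   ```python
   # Hypothetical tail of _get_dag with the under-lock re-check removed.
   dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)])
   if not dag_version:
       return None
   if not (serdag := dag_version.serialized_dag):
       return None
   # Duplicate deserialization is the only cost of a concurrent miss;
   # the second cache write simply replaces an identical dag.
   return self._read_dag(serdag)
   ```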


