kaxil commented on code in PR #60804:
URL: https://github.com/apache/airflow/pull/60804#discussion_r3083386785


##########
airflow-core/src/airflow/models/dagbag.py:
##########
@@ -39,50 +44,122 @@
 
 class DBDagBag:
     """
-    Internal class for retrieving and caching dags in the scheduler.
+    Internal class for retrieving dags from the database.
+
+    Optionally supports LRU+TTL caching when cache_size is provided.
+    The scheduler uses this without caching, while the API server can
+    enable caching via configuration.
 
     :meta private:
     """
 
-    def __init__(self, load_op_links: bool = True) -> None:
-        self._dags: dict[UUID, SerializedDagModel] = {}  # dag_version_id to 
dag
-        self.load_op_links = load_op_links
+    def __init__(
+        self,
+        load_op_links: bool = True,
+        cache_size: int | None = None,
+        cache_ttl: int | None = None,
+    ) -> None:
+        """
+        Initialize DBDagBag.
 
-    def _read_dag(self, serialized_dag_model: SerializedDagModel) -> 
SerializedDAG | None:
-        serialized_dag_model.load_op_links = self.load_op_links
-        if dag := serialized_dag_model.dag:
-            self._dags[serialized_dag_model.dag_version_id] = 
serialized_dag_model
+        :param load_op_links: Should the extra operator link be loaded when 
de-serializing the DAG?
+        :param cache_size: Size of LRU cache. If None or 0, uses unbounded 
dict (no eviction).
+        :param cache_ttl: Time-to-live for cache entries in seconds. If None 
or 0, no TTL (LRU only).
+        """
+        self.load_op_links = load_op_links
+        self._dags: MutableMapping[UUID | str, SerializedDAG] = {}
+        self._dag_models: dict[UUID | str, SerializedDagModel] = {}
+        self._use_cache = False
+
+        # Initialize bounded cache if cache_size is provided and > 0
+        if cache_size and cache_size > 0:
+            if cache_ttl and cache_ttl > 0:
+                self._dags = TTLCache(maxsize=cache_size, ttl=cache_ttl)
+            else:
+                self._dags = LRUCache(maxsize=cache_size)
+            self._use_cache = True
+
+        # Lock required for bounded caches: cachetools caches are NOT 
thread-safe
+        # (LRU reordering and TTL cleanup mutate internal linked lists).
+        # nullcontext for unbounded dict avoids lock overhead in the scheduler 
path.
+        self._lock: RLock | nullcontext = RLock() if self._use_cache else 
nullcontext()
+
+    def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
+        """Read and optionally cache a SerializedDAG from a 
SerializedDagModel."""
+        serdag.load_op_links = self.load_op_links
+        dag = serdag.dag
+        if not dag:
+            return None
+        with self._lock:
+            self._dags[serdag.dag_version_id] = dag
+            cache_size = len(self._dags)
+        if self._use_cache:
+            Stats.gauge("api_server.dag_bag.cache_size", cache_size, rate=0.1)
         return dag
 
-    def get_serialized_dag_model(self, version_id: UUID, session: Session) -> 
SerializedDagModel | None:
+    def _get_dag(self, version_id: UUID | str, session: Session) -> 
SerializedDAG | None:
+        # Check cache first
+        with self._lock:
+            dag = self._dags.get(version_id)
+
+        if dag:
+            if self._use_cache:
+                Stats.incr("api_server.dag_bag.cache_hit")
+            return dag
+
+        dag_version = session.get(DagVersion, version_id, 
options=[joinedload(DagVersion.serialized_dag)])
+        if not dag_version:
+            return None
+        if not (serdag := dag_version.serialized_dag):
+            return None
+
+        # Double-checked locking: another thread may have cached it while we 
queried DB.
+        # Only emit the miss metric after confirming no other thread cached 
it, to avoid
+        # counting a single lookup as both a miss and a hit.
+        if self._use_cache:
+            with self._lock:
+                if dag := self._dags.get(version_id):
+                    Stats.incr("api_server.dag_bag.cache_hit")
+                    return dag
+            Stats.incr("api_server.dag_bag.cache_miss")
+        return self._read_dag(serdag)
+
+    def get_dag(self, version_id: UUID | str, session: Session) -> 
SerializedDAG | None:
+        """Get a dag by its version id, using cache if enabled."""
+        return self._get_dag(version_id=version_id, session=session)
+
+    def get_serialized_dag_model(self, version_id: UUID | str, session: 
Session) -> SerializedDagModel | None:
         """
         Return the SerializedDagModel for a given dag version id.
 
-        This will first consult the in-memory cache keyed by the dag version 
id. If the
-        model is not cached, the database is queried for a corresponding 
:class:`DagVersion`
-        and its associated :class:`SerializedDagModel`.
+        Uses a separate plain dict cache (not the LRU/TTL cache, which stores
+        deserialized SerializedDAG objects). The triggerer needs the full model
+        for ``serialized_dag_model.data``.
+        """
+        if serdag := self._dag_models.get(version_id):

Review Comment:
   Good catch. I removed `_dag_models` entirely, so `get_serialized_dag_model()` 
now always queries the DB. The only production caller is the triggerer, which 
creates a fresh `DBDagBag()` per batch, so the within-batch deduplication the 
dict provided was marginal. Keeping an unbounded dict in a PR that fixes 
unbounded memory growth would have been a contradiction.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to