This is an automated email from the ASF dual-hosted git repository.

ephraimbuddy pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 92f1aa3edf4 Reload serialized Dag when a version is updated in place 
(#68558)
92f1aa3edf4 is described below

commit 92f1aa3edf4c542455e1b9db473fd786bef9cfce
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Mon Jun 15 13:31:50 2026 +0100

    Reload serialized Dag when a version is updated in place (#68558)
    
    * Reload serialized Dag when a version is updated in place
    
    DBDagBag cached deserialized DAGs keyed only by dag_version_id, so after an 
in-place version update (same id, new content for a version with no task 
instances) the scheduler served stale code until restart.
    
    Cache the dag_hash alongside each DAG and revalidate it against the 
database on cache hits, reloading when it no longer matches. Revalidation is 
throttled by [core] min_serialized_dag_update_interval: since a serialized DAG 
cannot be rewritten more often than that interval, an entry validated within 
the window is served without a DB query, and staleness stays bounded instead of 
lasting until restart. The check is a single-row lookup on the uniquely-indexed 
serialized_dag.dag_version_id column. [...]
    
    closes: #65696
    
    * Fix scheduler test reading DBDagBag cache internals after _CacheEntry 
change
---
 airflow-core/src/airflow/models/dagbag.py          |  83 +++++++++--
 airflow-core/tests/unit/jobs/test_scheduler_job.py |   2 +-
 airflow-core/tests/unit/models/test_dagbag.py      | 162 +++++++++++++++++++--
 3 files changed, 221 insertions(+), 26 deletions(-)

diff --git a/airflow-core/src/airflow/models/dagbag.py 
b/airflow-core/src/airflow/models/dagbag.py
index f3d71addaee..c4bd8eceea1 100644
--- a/airflow-core/src/airflow/models/dagbag.py
+++ b/airflow-core/src/airflow/models/dagbag.py
@@ -18,10 +18,11 @@
 from __future__ import annotations
 
 import hashlib
+import time
 from collections.abc import MutableMapping
 from contextlib import nullcontext
 from threading import RLock
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, NamedTuple
 from uuid import UUID
 
 from cachetools import LRUCache, TTLCache
@@ -29,6 +30,7 @@ from sqlalchemy import String, select
 from sqlalchemy.orm import Mapped, joinedload, mapped_column
 
 from airflow._shared.observability.metrics import stats
+from airflow.configuration import conf
 from airflow.models.base import Base, StringID
 from airflow.models.dag_version import DagVersion
 
@@ -42,6 +44,20 @@ if TYPE_CHECKING:
     from airflow.serialization.definitions.dag import SerializedDAG
 
 
+class _CacheEntry(NamedTuple):
+    """A cached deserialized DAG plus the metadata needed to detect staleness 
on lookup."""
+
+    dag: SerializedDAG
+    dag_hash: str
+    # Monotonic timestamp of the last time this entry's dag_hash was confirmed 
current against the
+    # DB. Used to throttle revalidation: a serialized DAG cannot be rewritten 
more often than
+    # [core] min_serialized_dag_update_interval, so within that window the 
cached copy is served
+    # without a DB round-trip. Because the window restarts on each confirmed 
hit and is on a
+    # different clock than the dag processor's write throttle, worst-case 
staleness is bounded to
+    # roughly one-to-two update intervals -- still bounded, vs. the previous 
unbounded-until-restart.
+    last_validated: float
+
+
 class DBDagBag:
     """
     Internal class for retrieving dags from the database.
@@ -67,9 +83,11 @@ class DBDagBag:
         :param cache_ttl: Time-to-live for cache entries in seconds. If None 
or 0, no TTL (LRU only).
         """
         self.load_op_links = load_op_links
-        self._dags: MutableMapping[UUID | str, SerializedDAG] = {}
+        self._dags: MutableMapping[UUID | str, _CacheEntry] = {}
         self._use_cache = False
 
+        self._revalidation_interval = conf.getint("core", 
"min_serialized_dag_update_interval")
+
         # Initialize bounded cache if cache_size is provided and > 0
         if cache_size and cache_size > 0:
             if cache_ttl and cache_ttl > 0:
@@ -84,27 +102,58 @@ class DBDagBag:
         self._lock: RLock | nullcontext = RLock() if self._use_cache else 
nullcontext()
 
     def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
-        """Read and optionally cache a SerializedDAG from a 
SerializedDagModel."""
+        """Read and cache a SerializedDAG (with its ``dag_hash`` for staleness 
detection)."""
         serdag.load_op_links = self.load_op_links
         dag = serdag.dag
         if not dag:
             return None
         with self._lock:
-            self._dags[serdag.dag_version_id] = dag
+            self._dags[serdag.dag_version_id] = _CacheEntry(dag, 
serdag.dag_hash, time.monotonic())
             cache_size = len(self._dags)
         if self._use_cache:
             stats.gauge("api_server.dag_bag.cache_size", cache_size, rate=0.1)
         return dag
 
+    @staticmethod
+    def _current_dag_hash(version_id: UUID | str, session: Session) -> str | 
None:
+        """Return the current ``dag_hash`` of the serialized DAG for 
``version_id``, or None."""
+        from airflow.models.serialized_dag import SerializedDagModel
+
+        return session.scalar(
+            
select(SerializedDagModel.dag_hash).where(SerializedDagModel.dag_version_id == 
version_id)
+        )
+
     def _get_dag(self, version_id: UUID | str, session: Session) -> 
SerializedDAG | None:
-        # Check cache first
         with self._lock:
-            dag = self._dags.get(version_id)
-
-        if dag:
-            if self._use_cache:
-                stats.incr("api_server.dag_bag.cache_hit")
-            return dag
+            cached = self._dags.get(version_id)
+
+        if cached is not None:
+            now = time.monotonic()
+            # A serialized DAG cannot be rewritten more often than
+            # [core] min_serialized_dag_update_interval, so an entry validated 
within that window
+            # cannot have gone stale yet -- serve it without touching the DB.
+            if now - cached.last_validated < self._revalidation_interval:
+                if self._use_cache:
+                    stats.incr("api_server.dag_bag.cache_hit")
+                return cached.dag
+            # Past the window: a version may have been updated in place (same 
dag_version_id, new
+            # content + new dag_hash) by SerializedDagModel.write_dag, so 
confirm the cached copy
+            # against the current dag_hash. That validation is a single-row 
lookup on the
+            # uniquely-indexed serialized_dag.dag_version_id column.
+            if self._current_dag_hash(version_id, session) == cached.dag_hash:
+                # Still current: restart the revalidation window so the next 
hits skip the query.
+                # (For a TTLCache this write-back also refreshes the entry's 
TTL/LRU recency, which
+                # is fine -- the entry was just re-confirmed against the DB.)
+                with self._lock:
+                    current = self._dags.get(version_id)
+                    if current is not None and current.dag_hash == 
cached.dag_hash:
+                        self._dags[version_id] = 
current._replace(last_validated=now)
+                if self._use_cache:
+                    stats.incr("api_server.dag_bag.cache_hit")
+                return cached.dag
+            # Stale (updated in place) or the version no longer exists: drop 
and reload below.
+            with self._lock:
+                self._dags.pop(version_id, None)
 
         dag_version = session.get(DagVersion, version_id, 
options=[joinedload(DagVersion.serialized_dag)])
         if not dag_version:
@@ -112,14 +161,16 @@ class DBDagBag:
         if not (serdag := dag_version.serialized_dag):
             return None
 
-        # Double-checked locking: another thread may have cached it while we 
queried DB.
-        # Only emit the miss metric after confirming no other thread cached 
it, to avoid
-        # counting a single lookup as both a miss and a hit.
+        # Double-checked locking: another thread may have cached it while we 
queried DB. Such an
+        # entry was just loaded from the DB, so it is well within its 
revalidation window and is
+        # served without an extra hash check, consistent with the policy 
above. Only emit the miss
+        # metric after confirming no other thread cached it, to avoid counting 
a single lookup as
+        # both a miss and a hit.
         if self._use_cache:
             with self._lock:
-                if dag := self._dags.get(version_id):
+                if (cached := self._dags.get(version_id)) is not None:
                     stats.incr("api_server.dag_bag.cache_hit")
-                    return dag
+                    return cached.dag
             stats.incr("api_server.dag_bag.cache_miss")
         return self._read_dag(serdag)
 
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py 
b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 67986c0297e..3ee51bb917d 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -7791,7 +7791,7 @@ class TestSchedulerJob:
         with create_session() as session:
             tis = session.scalars(select(TaskInstance)).all()
 
-        dags = self.job_runner.scheduler_dag_bag._dags.values()
+        dags = [entry.dag for entry in 
self.job_runner.scheduler_dag_bag._dags.values()]
         assert [dag.dag_id for dag in dags] == ["test_only_empty_tasks"]
         assert len(tis) == 6
         assert {
diff --git a/airflow-core/tests/unit/models/test_dagbag.py 
b/airflow-core/tests/unit/models/test_dagbag.py
index 4dfbd4d8a6b..79668d4fe54 100644
--- a/airflow-core/tests/unit/models/test_dagbag.py
+++ b/airflow-core/tests/unit/models/test_dagbag.py
@@ -24,9 +24,17 @@ import pytest
 import time_machine
 from cachetools import LRUCache, TTLCache
 
-from airflow.models.dagbag import DBDagBag
+from airflow.models.dag import DagModel
+from airflow.models.dag_version import DagVersion
+from airflow.models.dagbag import DBDagBag, _CacheEntry
+from airflow.models.dagbundle import DagBundleModel
 from airflow.models.serialized_dag import SerializedDagModel
-from airflow.serialization.serialized_objects import SerializedDAG
+from airflow.providers.standard.operators.empty import EmptyOperator
+from airflow.sdk import DAG
+from airflow.serialization.serialized_objects import LazyDeserializedDAG, 
SerializedDAG
+from airflow.utils.session import create_session
+
+from tests_common.test_utils import db
 
 pytestmark = pytest.mark.db_test
 
@@ -44,16 +52,18 @@ class TestDBDagBag:
         self.session = MagicMock()
 
     def test__read_dag_stores_and_returns_dag(self):
-        """It should store the SerializedDAG in _dags and return it."""
+        """It should store the SerializedDAG with its hash, and return it."""
         mock_dag = MagicMock(spec=SerializedDAG)
         mock_serdag = MagicMock(spec=SerializedDagModel)
         mock_serdag.dag = mock_dag
         mock_serdag.dag_version_id = "v1"
+        mock_serdag.dag_hash = "hash1"
 
         result = self.db_dag_bag._read_dag(mock_serdag)
 
         assert result == mock_dag
-        assert self.db_dag_bag._dags["v1"] == mock_dag
+        entry = self.db_dag_bag._dags["v1"]
+        assert (entry.dag, entry.dag_hash) == (mock_dag, "hash1")
         assert mock_serdag.load_op_links is True
 
     def test__read_dag_returns_none_when_no_dag(self):
@@ -68,11 +78,12 @@ class TestDBDagBag:
         assert "v1" not in self.db_dag_bag._dags
 
     def test_get_dag_fetches_from_db_on_miss(self):
-        """It should query the DB and cache the result when not in cache."""
+        """It should query the DB and cache the result (with its hash) when 
not in cache."""
         mock_dag = MagicMock(spec=SerializedDAG)
         mock_serdag = MagicMock(spec=SerializedDagModel)
         mock_serdag.dag = mock_dag
         mock_serdag.dag_version_id = "v1"
+        mock_serdag.dag_hash = "hash1"
         mock_dag_version = MagicMock()
         mock_dag_version.serialized_dag = mock_serdag
         self.session.get.return_value = mock_dag_version
@@ -81,16 +92,87 @@ class TestDBDagBag:
 
         self.session.get.assert_called_once()
         assert result == mock_dag
+        entry = self.db_dag_bag._dags["v1"]
+        assert (entry.dag, entry.dag_hash) == (mock_dag, "hash1")
+
+    def test_get_dag_serves_within_revalidation_window_without_query(self):
+        """A recently validated entry is served straight from cache with no DB 
query at all."""
+        mock_dag = MagicMock(spec=SerializedDAG)
+        # Just-validated entry, well within the (default 30s) revalidation 
window.
+        self.db_dag_bag._dags["v1"] = _CacheEntry(mock_dag, "hash1", 
time.monotonic())
+
+        result = self.db_dag_bag.get_dag("v1", session=self.session)
+
+        assert result == mock_dag
+        self.session.scalar.assert_not_called()  # no revalidation query 
inside the window
+        self.session.get.assert_not_called()
 
-    def test_get_dag_returns_cached_on_hit(self):
-        """It should return cached DAG without querying DB."""
+    def test_get_dag_serves_stale_within_window_even_if_db_changed(self):
+        """Inside the window the cached copy is served even if the DB has 
since changed.
+
+        This is the intended throttle tradeoff: staleness is bounded by the 
window, not zero.
+        """
+        stale_dag = MagicMock(spec=SerializedDAG)
+        self.db_dag_bag._dags["v1"] = _CacheEntry(stale_dag, "old_hash", 
time.monotonic())
+        # The DB has a different hash now, but we are inside the window so it 
is never consulted.
+        self.session.scalar.return_value = "new_hash"
+
+        result = self.db_dag_bag.get_dag("v1", session=self.session)
+
+        assert result == stale_dag
+        self.session.scalar.assert_not_called()
+        self.session.get.assert_not_called()
+
+    def 
test_get_dag_revalidates_after_window_and_serves_when_hash_matches(self):
+        """Past the window, a hit is revalidated by hash and served (window 
restarted)."""
         mock_dag = MagicMock(spec=SerializedDAG)
-        self.db_dag_bag._dags["v1"] = mock_dag
+        # last_validated=0.0 is far in the past, so the entry is revalidated.
+        self.db_dag_bag._dags["v1"] = _CacheEntry(mock_dag, "hash1", 0.0)
+        self.session.scalar.return_value = "hash1"
 
         result = self.db_dag_bag.get_dag("v1", session=self.session)
 
         assert result == mock_dag
+        # Validated via a cheap scalar() lookup, not a full DagVersion load.
+        self.session.scalar.assert_called_once()
         self.session.get.assert_not_called()
+        # The window is restarted so the next hit can skip the query.
+        assert self.db_dag_bag._dags["v1"].last_validated > 0.0
+
+    def test_get_dag_reloads_when_version_updated_in_place(self):
+        """A version updated in place (same id, new hash) must be reloaded, 
not served stale."""
+        stale_dag = MagicMock(spec=SerializedDAG)
+        fresh_dag = MagicMock(spec=SerializedDAG)
+        # last_validated=0.0 forces revalidation regardless of the window.
+        self.db_dag_bag._dags["v1"] = _CacheEntry(stale_dag, "old_hash", 0.0)
+        mock_serdag = MagicMock(spec=SerializedDagModel)
+        mock_serdag.dag = fresh_dag
+        mock_serdag.dag_version_id = "v1"
+        mock_serdag.dag_hash = "new_hash"
+        mock_dag_version = MagicMock()
+        mock_dag_version.serialized_dag = mock_serdag
+        self.session.get.return_value = mock_dag_version
+        # The dag_hash validation lookup returns the new hash (mismatch -> 
reload).
+        self.session.scalar.return_value = "new_hash"
+
+        result = self.db_dag_bag.get_dag("v1", session=self.session)
+
+        assert result == fresh_dag
+        self.session.get.assert_called_once()
+        entry = self.db_dag_bag._dags["v1"]
+        assert (entry.dag, entry.dag_hash) == (fresh_dag, "new_hash")
+
+    def test_get_dag_reloads_when_cached_version_deleted(self):
+        """A cached entry whose serialized row no longer exists must not be 
served."""
+        stale_dag = MagicMock(spec=SerializedDAG)
+        self.db_dag_bag._dags["v1"] = _CacheEntry(stale_dag, "old_hash", 0.0)
+        self.session.scalar.return_value = None  # validation finds no row
+        self.session.get.return_value = None  # version is gone
+
+        result = self.db_dag_bag.get_dag("v1", session=self.session)
+
+        assert result is None
+        assert "v1" not in self.db_dag_bag._dags
 
     def test_get_dag_returns_none_when_not_found(self):
         """It should return None if version_id not found in DB."""
@@ -100,6 +182,66 @@ class TestDBDagBag:
 
         assert result is None
 
+    def test_get_dag_reflects_in_place_version_update_end_to_end(self):
+        """End-to-end regression: an in-place version update must be re-read, 
not served stale.
+
+        When a DagVersion has no task instances, 
``SerializedDagModel.write_dag`` updates the
+        serialized DAG in place (same ``dag_version_id``, new content). A 
long-lived DagBag (e.g.
+        the scheduler's) must reflect the new content instead of serving the 
cached old code.
+
+        Each step uses its own session, matching the real deployment where the 
dag processor
+        writes and the scheduler reads in separate processes/sessions.
+        """
+        dag_id = "stale_cache_dag"
+        bundle_name = "testing"
+        db.clear_db_dags()
+        db.clear_db_serialized_dags()
+        db.clear_db_dag_bundles()
+
+        def make_lazy(task_ids):
+            with DAG(dag_id, schedule=None) as dag:
+                for task_id in task_ids:
+                    EmptyOperator(task_id=task_id)
+            return LazyDeserializedDAG.from_dag(dag)
+
+        # Long-lived bag, like the scheduler's process-lived 
scheduler_dag_bag. A 0s revalidation
+        # interval makes every hit revalidate, exercising the post-window 
reload path
+        # deterministically without manipulating the clock.
+        dag_bag = DBDagBag()
+        dag_bag._revalidation_interval = 0
+
+        with create_session() as session:
+            session.add(DagBundleModel(name=bundle_name))
+            session.flush()
+            session.add(DagModel(dag_id=dag_id, bundle_name=bundle_name))
+            session.flush()
+            # Version 1: a single task, no task instances yet.
+            SerializedDagModel.write_dag(make_lazy(["a"]), 
bundle_name=bundle_name, session=session)
+            session.commit()
+            version_id = DagVersion.get_latest_version(dag_id, 
session=session).id
+
+        # The scheduler loads and caches the DAG.
+        with create_session() as session:
+            assert set(dag_bag.get_dag(version_id, session=session).task_ids) 
== {"a"}
+
+        # The dag processor adds a task and re-writes. With no task instances 
on the version,
+        # write_dag updates it in place (same dag_version_id, new content + 
hash).
+        with create_session() as session:
+            did_write = SerializedDagModel.write_dag(
+                make_lazy(["a", "b"]), bundle_name=bundle_name, session=session
+            )
+            session.commit()
+            assert did_write is True
+            assert DagVersion.get_latest_version(dag_id, session=session).id 
== version_id
+
+        # The scheduler reads again: it must serve the updated DAG, not the 
stale cached one.
+        with create_session() as session:
+            assert set(dag_bag.get_dag(version_id, session=session).task_ids) 
== {"a", "b"}
+
+        db.clear_db_dags()
+        db.clear_db_serialized_dags()
+        db.clear_db_dag_bundles()
+
 
 class TestDBDagBagCache:
     """Tests for DBDagBag optional caching behavior."""
@@ -257,7 +399,9 @@ class TestDBDagBagCache:
         """Test that cache hit metric is emitted when caching is enabled."""
         dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
         mock_session = MagicMock()
-        dag_bag._dags["test_version"] = MagicMock()
+        # last_validated=0.0 forces revalidation; the hash matches, so it 
counts as a hit.
+        dag_bag._dags["test_version"] = _CacheEntry(MagicMock(), "hash1", 0.0)
+        mock_session.scalar.return_value = "hash1"
 
         dag_bag._get_dag("test_version", mock_session)
 

Reply via email to