This is an automated email from the ASF dual-hosted git repository.
ephraimbuddy pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 92f1aa3edf4 Reload serialized Dag when a version is updated in place
(#68558)
92f1aa3edf4 is described below
commit 92f1aa3edf4c542455e1b9db473fd786bef9cfce
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Mon Jun 15 13:31:50 2026 +0100
Reload serialized Dag when a version is updated in place (#68558)
* Reload serialized Dag when a version is updated in place
DBDagBag cached deserialized DAGs keyed only by dag_version_id, so after an
in-place version update (same id, new content for a version with no task
instances) the scheduler served stale code until restart.
Cache the dag_hash alongside each DAG and revalidate it against the
database on cache hits, reloading when it no longer matches. Revalidation is
throttled by [core] min_serialized_dag_update_interval: since a serialized DAG
cannot be rewritten more often than that interval, an entry validated within
the window is served without a DB query, and staleness stays bounded instead of
lasting until restart. The check is a single-row lookup on the uniquely-indexed
serialized_dag.dag_version_id column. [...]
closes: #65696
* Fix scheduler test reading DBDagBag cache internals after _CacheEntry
change
---
airflow-core/src/airflow/models/dagbag.py | 83 +++++++++--
airflow-core/tests/unit/jobs/test_scheduler_job.py | 2 +-
airflow-core/tests/unit/models/test_dagbag.py | 162 +++++++++++++++++++--
3 files changed, 221 insertions(+), 26 deletions(-)
diff --git a/airflow-core/src/airflow/models/dagbag.py
b/airflow-core/src/airflow/models/dagbag.py
index f3d71addaee..c4bd8eceea1 100644
--- a/airflow-core/src/airflow/models/dagbag.py
+++ b/airflow-core/src/airflow/models/dagbag.py
@@ -18,10 +18,11 @@
from __future__ import annotations
import hashlib
+import time
from collections.abc import MutableMapping
from contextlib import nullcontext
from threading import RLock
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, NamedTuple
from uuid import UUID
from cachetools import LRUCache, TTLCache
@@ -29,6 +30,7 @@ from sqlalchemy import String, select
from sqlalchemy.orm import Mapped, joinedload, mapped_column
from airflow._shared.observability.metrics import stats
+from airflow.configuration import conf
from airflow.models.base import Base, StringID
from airflow.models.dag_version import DagVersion
@@ -42,6 +44,20 @@ if TYPE_CHECKING:
from airflow.serialization.definitions.dag import SerializedDAG
+class _CacheEntry(NamedTuple):
+ """A cached deserialized DAG plus the metadata needed to detect staleness
on lookup."""
+
+ dag: SerializedDAG
+ dag_hash: str
+ # Monotonic timestamp of the last time this entry's dag_hash was confirmed
current against the
+ # DB. Used to throttle revalidation: a serialized DAG cannot be rewritten
more often than
+ # [core] min_serialized_dag_update_interval, so within that window the
cached copy is served
+ # without a DB round-trip. Because the window restarts on each confirmed
hit and is on a
+ # different clock than the dag processor's write throttle, worst-case
staleness is bounded to
+ # roughly one-to-two update intervals -- still bounded, vs. the previous
unbounded-until-restart.
+ last_validated: float
+
+
class DBDagBag:
"""
Internal class for retrieving dags from the database.
@@ -67,9 +83,11 @@ class DBDagBag:
:param cache_ttl: Time-to-live for cache entries in seconds. If None
or 0, no TTL (LRU only).
"""
self.load_op_links = load_op_links
- self._dags: MutableMapping[UUID | str, SerializedDAG] = {}
+ self._dags: MutableMapping[UUID | str, _CacheEntry] = {}
self._use_cache = False
+ self._revalidation_interval = conf.getint("core",
"min_serialized_dag_update_interval")
+
# Initialize bounded cache if cache_size is provided and > 0
if cache_size and cache_size > 0:
if cache_ttl and cache_ttl > 0:
@@ -84,27 +102,58 @@ class DBDagBag:
self._lock: RLock | nullcontext = RLock() if self._use_cache else
nullcontext()
def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
- """Read and optionally cache a SerializedDAG from a
SerializedDagModel."""
+ """Read and cache a SerializedDAG (with its ``dag_hash`` for staleness
detection)."""
serdag.load_op_links = self.load_op_links
dag = serdag.dag
if not dag:
return None
with self._lock:
- self._dags[serdag.dag_version_id] = dag
+ self._dags[serdag.dag_version_id] = _CacheEntry(dag,
serdag.dag_hash, time.monotonic())
cache_size = len(self._dags)
if self._use_cache:
stats.gauge("api_server.dag_bag.cache_size", cache_size, rate=0.1)
return dag
+ @staticmethod
+ def _current_dag_hash(version_id: UUID | str, session: Session) -> str |
None:
+ """Return the current ``dag_hash`` of the serialized DAG for
``version_id``, or None."""
+ from airflow.models.serialized_dag import SerializedDagModel
+
+ return session.scalar(
+
select(SerializedDagModel.dag_hash).where(SerializedDagModel.dag_version_id ==
version_id)
+ )
+
def _get_dag(self, version_id: UUID | str, session: Session) ->
SerializedDAG | None:
- # Check cache first
with self._lock:
- dag = self._dags.get(version_id)
-
- if dag:
- if self._use_cache:
- stats.incr("api_server.dag_bag.cache_hit")
- return dag
+ cached = self._dags.get(version_id)
+
+ if cached is not None:
+ now = time.monotonic()
+ # A serialized DAG cannot be rewritten more often than
+ # [core] min_serialized_dag_update_interval, so an entry validated
within that window
+ # cannot have gone stale yet -- serve it without touching the DB.
+ if now - cached.last_validated < self._revalidation_interval:
+ if self._use_cache:
+ stats.incr("api_server.dag_bag.cache_hit")
+ return cached.dag
+ # Past the window: a version may have been updated in place (same
dag_version_id, new
+ # content + new dag_hash) by SerializedDagModel.write_dag, so
confirm the cached copy
+ # against the current dag_hash. That validation is a single-row
lookup on the
+ # uniquely-indexed serialized_dag.dag_version_id column.
+ if self._current_dag_hash(version_id, session) == cached.dag_hash:
+ # Still current: restart the revalidation window so the next
hits skip the query.
+ # (For a TTLCache this write-back also refreshes the entry's
TTL/LRU recency, which
+ # is fine -- the entry was just re-confirmed against the DB.)
+ with self._lock:
+ current = self._dags.get(version_id)
+ if current is not None and current.dag_hash ==
cached.dag_hash:
+ self._dags[version_id] =
current._replace(last_validated=now)
+ if self._use_cache:
+ stats.incr("api_server.dag_bag.cache_hit")
+ return cached.dag
+ # Stale (updated in place) or the version no longer exists: drop
and reload below.
+ with self._lock:
+ self._dags.pop(version_id, None)
dag_version = session.get(DagVersion, version_id,
options=[joinedload(DagVersion.serialized_dag)])
if not dag_version:
@@ -112,14 +161,16 @@ class DBDagBag:
if not (serdag := dag_version.serialized_dag):
return None
- # Double-checked locking: another thread may have cached it while we
queried DB.
- # Only emit the miss metric after confirming no other thread cached
it, to avoid
- # counting a single lookup as both a miss and a hit.
+ # Double-checked locking: another thread may have cached it while we
queried DB. Such an
+ # entry was just loaded from the DB, so it is well within its
revalidation window and is
+ # served without an extra hash check, consistent with the policy
above. Only emit the miss
+ # metric after confirming no other thread cached it, to avoid counting
a single lookup as
+ # both a miss and a hit.
if self._use_cache:
with self._lock:
- if dag := self._dags.get(version_id):
+ if (cached := self._dags.get(version_id)) is not None:
stats.incr("api_server.dag_bag.cache_hit")
- return dag
+ return cached.dag
stats.incr("api_server.dag_bag.cache_miss")
return self._read_dag(serdag)
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py
b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 67986c0297e..3ee51bb917d 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -7791,7 +7791,7 @@ class TestSchedulerJob:
with create_session() as session:
tis = session.scalars(select(TaskInstance)).all()
- dags = self.job_runner.scheduler_dag_bag._dags.values()
+ dags = [entry.dag for entry in
self.job_runner.scheduler_dag_bag._dags.values()]
assert [dag.dag_id for dag in dags] == ["test_only_empty_tasks"]
assert len(tis) == 6
assert {
diff --git a/airflow-core/tests/unit/models/test_dagbag.py
b/airflow-core/tests/unit/models/test_dagbag.py
index 4dfbd4d8a6b..79668d4fe54 100644
--- a/airflow-core/tests/unit/models/test_dagbag.py
+++ b/airflow-core/tests/unit/models/test_dagbag.py
@@ -24,9 +24,17 @@ import pytest
import time_machine
from cachetools import LRUCache, TTLCache
-from airflow.models.dagbag import DBDagBag
+from airflow.models.dag import DagModel
+from airflow.models.dag_version import DagVersion
+from airflow.models.dagbag import DBDagBag, _CacheEntry
+from airflow.models.dagbundle import DagBundleModel
from airflow.models.serialized_dag import SerializedDagModel
-from airflow.serialization.serialized_objects import SerializedDAG
+from airflow.providers.standard.operators.empty import EmptyOperator
+from airflow.sdk import DAG
+from airflow.serialization.serialized_objects import LazyDeserializedDAG,
SerializedDAG
+from airflow.utils.session import create_session
+
+from tests_common.test_utils import db
pytestmark = pytest.mark.db_test
@@ -44,16 +52,18 @@ class TestDBDagBag:
self.session = MagicMock()
def test__read_dag_stores_and_returns_dag(self):
- """It should store the SerializedDAG in _dags and return it."""
+ """It should store the SerializedDAG with its hash, and return it."""
mock_dag = MagicMock(spec=SerializedDAG)
mock_serdag = MagicMock(spec=SerializedDagModel)
mock_serdag.dag = mock_dag
mock_serdag.dag_version_id = "v1"
+ mock_serdag.dag_hash = "hash1"
result = self.db_dag_bag._read_dag(mock_serdag)
assert result == mock_dag
- assert self.db_dag_bag._dags["v1"] == mock_dag
+ entry = self.db_dag_bag._dags["v1"]
+ assert (entry.dag, entry.dag_hash) == (mock_dag, "hash1")
assert mock_serdag.load_op_links is True
def test__read_dag_returns_none_when_no_dag(self):
@@ -68,11 +78,12 @@ class TestDBDagBag:
assert "v1" not in self.db_dag_bag._dags
def test_get_dag_fetches_from_db_on_miss(self):
- """It should query the DB and cache the result when not in cache."""
+ """It should query the DB and cache the result (with its hash) when
not in cache."""
mock_dag = MagicMock(spec=SerializedDAG)
mock_serdag = MagicMock(spec=SerializedDagModel)
mock_serdag.dag = mock_dag
mock_serdag.dag_version_id = "v1"
+ mock_serdag.dag_hash = "hash1"
mock_dag_version = MagicMock()
mock_dag_version.serialized_dag = mock_serdag
self.session.get.return_value = mock_dag_version
@@ -81,16 +92,87 @@ class TestDBDagBag:
self.session.get.assert_called_once()
assert result == mock_dag
+ entry = self.db_dag_bag._dags["v1"]
+ assert (entry.dag, entry.dag_hash) == (mock_dag, "hash1")
+
+ def test_get_dag_serves_within_revalidation_window_without_query(self):
+ """A recently validated entry is served straight from cache with no DB
query at all."""
+ mock_dag = MagicMock(spec=SerializedDAG)
+ # Just-validated entry, well within the (default 30s) revalidation
window.
+ self.db_dag_bag._dags["v1"] = _CacheEntry(mock_dag, "hash1",
time.monotonic())
+
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
+
+ assert result == mock_dag
+ self.session.scalar.assert_not_called() # no revalidation query
inside the window
+ self.session.get.assert_not_called()
- def test_get_dag_returns_cached_on_hit(self):
- """It should return cached DAG without querying DB."""
+ def test_get_dag_serves_stale_within_window_even_if_db_changed(self):
+ """Inside the window the cached copy is served even if the DB has
since changed.
+
+ This is the intended throttle tradeoff: staleness is bounded by the
window, not zero.
+ """
+ stale_dag = MagicMock(spec=SerializedDAG)
+ self.db_dag_bag._dags["v1"] = _CacheEntry(stale_dag, "old_hash",
time.monotonic())
+ # The DB has a different hash now, but we are inside the window so it
is never consulted.
+ self.session.scalar.return_value = "new_hash"
+
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
+
+ assert result == stale_dag
+ self.session.scalar.assert_not_called()
+ self.session.get.assert_not_called()
+
+ def
test_get_dag_revalidates_after_window_and_serves_when_hash_matches(self):
+ """Past the window, a hit is revalidated by hash and served (window
restarted)."""
mock_dag = MagicMock(spec=SerializedDAG)
- self.db_dag_bag._dags["v1"] = mock_dag
+ # last_validated=0.0 is far in the past, so the entry is revalidated.
+ self.db_dag_bag._dags["v1"] = _CacheEntry(mock_dag, "hash1", 0.0)
+ self.session.scalar.return_value = "hash1"
result = self.db_dag_bag.get_dag("v1", session=self.session)
assert result == mock_dag
+ # Validated via a cheap scalar() lookup, not a full DagVersion load.
+ self.session.scalar.assert_called_once()
self.session.get.assert_not_called()
+ # The window is restarted so the next hit can skip the query.
+ assert self.db_dag_bag._dags["v1"].last_validated > 0.0
+
+ def test_get_dag_reloads_when_version_updated_in_place(self):
+ """A version updated in place (same id, new hash) must be reloaded,
not served stale."""
+ stale_dag = MagicMock(spec=SerializedDAG)
+ fresh_dag = MagicMock(spec=SerializedDAG)
+ # last_validated=0.0 forces revalidation regardless of the window.
+ self.db_dag_bag._dags["v1"] = _CacheEntry(stale_dag, "old_hash", 0.0)
+ mock_serdag = MagicMock(spec=SerializedDagModel)
+ mock_serdag.dag = fresh_dag
+ mock_serdag.dag_version_id = "v1"
+ mock_serdag.dag_hash = "new_hash"
+ mock_dag_version = MagicMock()
+ mock_dag_version.serialized_dag = mock_serdag
+ self.session.get.return_value = mock_dag_version
+ # The dag_hash validation lookup returns the new hash (mismatch ->
reload).
+ self.session.scalar.return_value = "new_hash"
+
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
+
+ assert result == fresh_dag
+ self.session.get.assert_called_once()
+ entry = self.db_dag_bag._dags["v1"]
+ assert (entry.dag, entry.dag_hash) == (fresh_dag, "new_hash")
+
+ def test_get_dag_reloads_when_cached_version_deleted(self):
+ """A cached entry whose serialized row no longer exists must not be
served."""
+ stale_dag = MagicMock(spec=SerializedDAG)
+ self.db_dag_bag._dags["v1"] = _CacheEntry(stale_dag, "old_hash", 0.0)
+ self.session.scalar.return_value = None # validation finds no row
+ self.session.get.return_value = None # version is gone
+
+ result = self.db_dag_bag.get_dag("v1", session=self.session)
+
+ assert result is None
+ assert "v1" not in self.db_dag_bag._dags
def test_get_dag_returns_none_when_not_found(self):
"""It should return None if version_id not found in DB."""
@@ -100,6 +182,66 @@ class TestDBDagBag:
assert result is None
+ def test_get_dag_reflects_in_place_version_update_end_to_end(self):
+ """End-to-end regression: an in-place version update must be re-read,
not served stale.
+
+ When a DagVersion has no task instances,
``SerializedDagModel.write_dag`` updates the
+ serialized DAG in place (same ``dag_version_id``, new content). A
long-lived DagBag (e.g.
+ the scheduler's) must reflect the new content instead of serving the
cached old code.
+
+ Each step uses its own session, matching the real deployment where the
dag processor
+ writes and the scheduler reads in separate processes/sessions.
+ """
+ dag_id = "stale_cache_dag"
+ bundle_name = "testing"
+ db.clear_db_dags()
+ db.clear_db_serialized_dags()
+ db.clear_db_dag_bundles()
+
+ def make_lazy(task_ids):
+ with DAG(dag_id, schedule=None) as dag:
+ for task_id in task_ids:
+ EmptyOperator(task_id=task_id)
+ return LazyDeserializedDAG.from_dag(dag)
+
+ # Long-lived bag, like the scheduler's process-lived
scheduler_dag_bag. A 0s revalidation
+ # interval makes every hit revalidate, exercising the post-window
reload path
+ # deterministically without manipulating the clock.
+ dag_bag = DBDagBag()
+ dag_bag._revalidation_interval = 0
+
+ with create_session() as session:
+ session.add(DagBundleModel(name=bundle_name))
+ session.flush()
+ session.add(DagModel(dag_id=dag_id, bundle_name=bundle_name))
+ session.flush()
+ # Version 1: a single task, no task instances yet.
+ SerializedDagModel.write_dag(make_lazy(["a"]),
bundle_name=bundle_name, session=session)
+ session.commit()
+ version_id = DagVersion.get_latest_version(dag_id,
session=session).id
+
+ # The scheduler loads and caches the DAG.
+ with create_session() as session:
+ assert set(dag_bag.get_dag(version_id, session=session).task_ids)
== {"a"}
+
+ # The dag processor adds a task and re-writes. With no task instances
on the version,
+ # write_dag updates it in place (same dag_version_id, new content +
hash).
+ with create_session() as session:
+ did_write = SerializedDagModel.write_dag(
+ make_lazy(["a", "b"]), bundle_name=bundle_name, session=session
+ )
+ session.commit()
+ assert did_write is True
+ assert DagVersion.get_latest_version(dag_id, session=session).id
== version_id
+
+ # The scheduler reads again: it must serve the updated DAG, not the
stale cached one.
+ with create_session() as session:
+ assert set(dag_bag.get_dag(version_id, session=session).task_ids)
== {"a", "b"}
+
+ db.clear_db_dags()
+ db.clear_db_serialized_dags()
+ db.clear_db_dag_bundles()
+
class TestDBDagBagCache:
"""Tests for DBDagBag optional caching behavior."""
@@ -257,7 +399,9 @@ class TestDBDagBagCache:
"""Test that cache hit metric is emitted when caching is enabled."""
dag_bag = DBDagBag(cache_size=10, cache_ttl=60)
mock_session = MagicMock()
- dag_bag._dags["test_version"] = MagicMock()
+ # last_validated=0.0 forces revalidation; the hash matches, so it
counts as a hit.
+ dag_bag._dags["test_version"] = _CacheEntry(MagicMock(), "hash1", 0.0)
+ mock_session.scalar.return_value = "hash1"
dag_bag._get_dag("test_version", mock_session)