This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 326680fd977 Refactor bundle refresh persistence into overridable
get/update methods (#63835)
326680fd977 is described below
commit 326680fd9775eeb47ba8dd824a100da56dc1d8b9
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Wed Mar 25 15:53:37 2026 +0100
Refactor bundle refresh persistence into overridable get/update methods
(#63835)
* Refactor bundle refresh persistence into overridable get/update methods
Replace inline DagBundleModel mutations in _refresh_dag_bundles() with
get_bundle_state() and update_bundle_state() instance methods on
DagFileProcessorManager, isolating all DB access for bundle state
behind a clean override seam.
* fixup! Refactor bundle refresh persistence into overridable get/update
methods
* Address reviews
---
airflow-core/src/airflow/dag_processing/manager.py | 165 +++++++++-----
.../tests/unit/dag_processing/test_manager.py | 247 +++++++++++++++++++++
2 files changed, 354 insertions(+), 58 deletions(-)
diff --git a/airflow-core/src/airflow/dag_processing/manager.py
b/airflow-core/src/airflow/dag_processing/manager.py
index 7d4e20adfa8..d1194127599 100644
--- a/airflow-core/src/airflow/dag_processing/manager.py
+++ b/airflow-core/src/airflow/dag_processing/manager.py
@@ -94,6 +94,13 @@ class DagParsingStat(NamedTuple):
all_files_processed: bool
+class BundleState(NamedTuple):
+ """Persisted refresh state for a DAG bundle."""
+
+ last_refreshed: datetime | None
+ version: str | None
+
+
@attrs.define
class DagFileStat:
"""Information about single processing of one file."""
@@ -591,6 +598,43 @@ class DagFileProcessorManager(LoggingMixin):
self._add_files_to_queue([file_info], mode="front")
Stats.incr("dag_processing.other_callback_count")
+ @provide_session
+ def get_bundle_state(self, bundle_name: str, *, session: Session =
NEW_SESSION) -> BundleState | None:
+ """
+ Return the persisted refresh state for a bundle.
+
+ Returns ``None`` if the bundle has no database record.
+ """
+ row = session.scalar(
+ select(DagBundleModel)
+ .where(DagBundleModel.name == bundle_name)
+ .options(load_only(DagBundleModel.last_refreshed,
DagBundleModel.version))
+ )
+ if row is None:
+ return None
+ return BundleState(last_refreshed=row.last_refreshed,
version=row.version)
+
+ @provide_session
+ def update_bundle_state(
+ self,
+ bundle_name: str,
+ *,
+ last_refreshed: datetime,
+ version: str | None,
+ session: Session = NEW_SESSION,
+ ) -> None:
+ """
+ Persist the post-refresh state for a bundle.
+
+ Always updates ``last_refreshed``. Updates ``version`` only when
``version`` is not
+ ``None`` — pass ``None`` to leave the stored version unchanged (e.g.
for non-versioned
+ bundles or when the version did not change after a refresh).
+ """
+ values: dict[str, Any] = {"last_refreshed": last_refreshed}
+ if version is not None:
+ values["version"] = version
+ session.execute(update(DagBundleModel).where(DagBundleModel.name ==
bundle_name).values(**values))
+
def _refresh_dag_bundles(self, known_files: dict[str, set[DagFileInfo]]):
"""Refresh DAG bundles, if required."""
now = timezone.utcnow()
@@ -619,69 +663,74 @@ class DagFileProcessorManager(LoggingMixin):
self.log.exception("Error initializing bundle %s: %s",
bundle.name, e)
continue
# TODO: AIP-66 test to make sure we get a fresh record from the db
and it's not cached
- with create_session() as session:
- bundle_model = session.get(DagBundleModel, bundle.name)
- if bundle_model is None:
- self.log.warning("Bundle model not found for %s",
bundle.name)
- continue
- elapsed_time_since_refresh = (
- now - (bundle_model.last_refreshed or utc_epoch())
- ).total_seconds()
- if bundle.supports_versioning:
- # we will also check the version of the bundle to see if
another DAG processor has seen
- # a new version
- pre_refresh_version = (
- self._bundle_versions.get(bundle.name) or
bundle.get_current_version()
- )
- current_version_matches_db = pre_refresh_version ==
bundle_model.version
- else:
- # With no versioning, it always "matches"
- current_version_matches_db = True
-
- previously_seen = bundle.name in self._bundle_versions
- if self.should_skip_refresh(
- bundle=bundle,
- elapsed_time_since_refresh=elapsed_time_since_refresh,
- current_version_matches_db=current_version_matches_db,
- previously_seen=previously_seen,
- ):
- self.log.info("Not time to refresh bundle %s", bundle.name)
- continue
-
- self.log.info("Refreshing bundle %s", bundle.name)
-
- try:
- bundle.refresh()
- any_refreshed = True
- except Exception:
- self.log.exception("Error refreshing bundle %s",
bundle.name)
- continue
-
- bundle_model.last_refreshed = now
- self._force_refresh_bundles.discard(bundle.name)
+ try:
+ bundle_state = self.get_bundle_state(bundle.name)
+ except Exception:
+ self.log.exception("Error fetching state for bundle %s",
bundle.name)
+ continue
+ if bundle_state is None:
+ self.log.warning("Bundle model not found for %s", bundle.name)
+ continue
+ elapsed_time_since_refresh = (now - (bundle_state.last_refreshed
or utc_epoch())).total_seconds()
+ if bundle.supports_versioning:
+ # we will also check the version of the bundle to see if
another DAG processor has seen
+ # a new version
+ pre_refresh_version = self._bundle_versions.get(bundle.name)
or bundle.get_current_version()
+ current_version_matches_db = pre_refresh_version ==
bundle_state.version
+ else:
+ # With no versioning, it always "matches"
+ current_version_matches_db = True
+
+ previously_seen = bundle.name in self._bundle_versions
+ if self.should_skip_refresh(
+ bundle=bundle,
+ elapsed_time_since_refresh=elapsed_time_since_refresh,
+ current_version_matches_db=current_version_matches_db,
+ previously_seen=previously_seen,
+ ):
+ self.log.info("Not time to refresh bundle %s", bundle.name)
+ continue
- if bundle.supports_versioning:
- # We can short-circuit the rest of this if (1) bundle was
seen before by
- # this dag processor and (2) the version of the bundle did
not change
- # after refreshing it
- version_after_refresh = bundle.get_current_version()
- if previously_seen and pre_refresh_version ==
version_after_refresh:
- self.log.debug(
- "Bundle %s version not changed after refresh: %s",
- bundle.name,
- version_after_refresh,
- )
- continue
+ self.log.info("Refreshing bundle %s", bundle.name)
- bundle_model.version = version_after_refresh
+ try:
+ bundle.refresh()
+ any_refreshed = True
+ except Exception:
+ self.log.exception("Error refreshing bundle %s", bundle.name)
+ continue
- self.log.info(
- "Version changed for %s, new version: %s",
bundle.name, version_after_refresh
+ self._force_refresh_bundles.discard(bundle.name)
+
+ if bundle.supports_versioning:
+ # We can short-circuit the rest of this if (1) bundle was seen
before by
+ # this dag processor and (2) the version of the bundle did not
change
+ # after refreshing it
+ version_after_refresh = bundle.get_current_version()
+ if previously_seen and pre_refresh_version ==
version_after_refresh:
+ self.log.debug(
+ "Bundle %s version not changed after refresh: %s",
+ bundle.name,
+ version_after_refresh,
)
- else:
- version_after_refresh = None
+ try:
+ self.update_bundle_state(bundle.name,
last_refreshed=now, version=None)
+ except Exception:
+ self.log.exception("Error persisting state for bundle
%s", bundle.name)
+ continue
- self._bundle_versions[bundle.name] = version_after_refresh
+ self.log.info("Version changed for %s, new version: %s",
bundle.name, version_after_refresh)
+ else:
+ version_after_refresh = None
+
+ # Persistence failure must not skip file scanning (bundle is
already refreshed locally).
+ # _bundle_versions is only advanced on success to stay consistent
with the DB.
+ try:
+ self.update_bundle_state(bundle.name, last_refreshed=now,
version=version_after_refresh)
+ except Exception:
+ self.log.exception("Error persisting state for bundle %s",
bundle.name)
+ else:
+ self._bundle_versions[bundle.name] = version_after_refresh
found_files = {
DagFileInfo(rel_path=p, bundle_name=bundle.name,
bundle_path=bundle.path)
diff --git a/airflow-core/tests/unit/dag_processing/test_manager.py
b/airflow-core/tests/unit/dag_processing/test_manager.py
index e728ce85524..ea867ad1fec 100644
--- a/airflow-core/tests/unit/dag_processing/test_manager.py
+++ b/airflow-core/tests/unit/dag_processing/test_manager.py
@@ -42,9 +42,11 @@ from uuid6 import uuid7
from airflow._shared.timezones import timezone
from airflow.callbacks.callback_requests import DagCallbackRequest
+from airflow.dag_processing.bundles.base import BaseDagBundle
from airflow.dag_processing.bundles.manager import DagBundlesManager
from airflow.dag_processing.dagbag import DagBag
from airflow.dag_processing.manager import (
+ BundleState,
DagFileInfo,
DagFileProcessorManager,
DagFileStat,
@@ -1838,3 +1840,248 @@ class TestDagFileProcessorManager:
dag_path.touch() # make the loop run faster
gauge_values.clear()
+
+ # --- get_bundle_state / update_bundle_state ---
+
+ def test_get_bundle_state_returns_none_for_missing_bundle(self):
+ manager = DagFileProcessorManager(max_runs=1)
+ assert manager.get_bundle_state("nonexistent_bundle") is None
+
+ def test_get_bundle_state_returns_correct_state(self, session):
+ bundle_name = "test_state_bundle"
+ refreshed_at = timezone.datetime(2024, 1, 15, 12, 0, 0)
+ model = DagBundleModel(name=bundle_name, version="v1")
+ model.last_refreshed = refreshed_at
+ session.add(model)
+ session.commit()
+
+ manager = DagFileProcessorManager(max_runs=1)
+ state = manager.get_bundle_state(bundle_name)
+
+ assert state == BundleState(last_refreshed=refreshed_at, version="v1")
+
+ def test_get_bundle_state_null_fields(self, session):
+ bundle_name = "test_null_state_bundle"
+ session.add(DagBundleModel(name=bundle_name))
+ session.commit()
+
+ manager = DagFileProcessorManager(max_runs=1)
+ state = manager.get_bundle_state(bundle_name)
+
+ assert state == BundleState(last_refreshed=None, version=None)
+
+ def test_update_bundle_state_sets_last_refreshed(self, session):
+ bundle_name = "test_update_bundle"
+ session.add(DagBundleModel(name=bundle_name))
+ session.commit()
+
+ refreshed_at = timezone.datetime(2024, 6, 1, 8, 0, 0)
+ manager = DagFileProcessorManager(max_runs=1)
+ manager.update_bundle_state(bundle_name, last_refreshed=refreshed_at,
version=None)
+
+ session.expire_all()
+ model = session.get(DagBundleModel, bundle_name)
+ assert model.last_refreshed == refreshed_at
+ assert model.version is None
+
+ def test_update_bundle_state_sets_version(self, session):
+ bundle_name = "test_update_version_bundle"
+ session.add(DagBundleModel(name=bundle_name))
+ session.commit()
+
+ refreshed_at = timezone.datetime(2024, 6, 1, 8, 0, 0)
+ manager = DagFileProcessorManager(max_runs=1)
+ manager.update_bundle_state(bundle_name, last_refreshed=refreshed_at,
version="abc123")
+
+ session.expire_all()
+ model = session.get(DagBundleModel, bundle_name)
+ assert model.last_refreshed == refreshed_at
+ assert model.version == "abc123"
+
+ def test_update_bundle_state_does_not_overwrite_version_when_none(self,
session):
+ bundle_name = "test_preserve_version_bundle"
+ session.add(DagBundleModel(name=bundle_name, version="keep_me"))
+ session.commit()
+
+ refreshed_at = timezone.datetime(2024, 6, 1, 8, 0, 0)
+ manager = DagFileProcessorManager(max_runs=1)
+ manager.update_bundle_state(bundle_name, last_refreshed=refreshed_at,
version=None)
+
+ session.expire_all()
+ model = session.get(DagBundleModel, bundle_name)
+ assert model.last_refreshed == refreshed_at
+ assert model.version == "keep_me"
+
+ def _make_refresh_bundle(self, *, supports_versioning=False,
current_version=None):
+ bundle = MagicMock(spec=BaseDagBundle)
+ bundle.name = "mock_bundle"
+ bundle.refresh_interval = 0
+ bundle.supports_versioning = supports_versioning
+ bundle.is_initialized = True
+ bundle.path = Path("/dev/null")
+ bundle.get_current_version.return_value = current_version
+ return bundle
+
+ def _refresh_with_mocked_state(self, manager, bundle, initial_state):
+ """Run _refresh_dag_bundles with get/update_bundle_state mocked out.
+
+ Returns the two MagicMock objects for post-call assertions. MagicMock
retains its
+ call records after the ``with`` block exits (un-patching only restores
the original
+ attribute; it does not clear the mock's recorded calls), so callers
can assert on
+ them normally after this method returns.
+ """
+ manager._dag_bundles = [bundle]
+ manager._force_refresh_bundles = set()
+ mock_get = mock.patch.object(manager, "get_bundle_state",
return_value=initial_state)
+ mock_update = mock.patch.object(manager, "update_bundle_state")
+ with (
+ mock_get as patched_get,
+ mock_update as patched_update,
+ mock.patch.object(manager, "_find_files_in_bundle",
return_value=[]),
+ mock.patch.object(manager, "deactivate_deleted_dags"),
+ mock.patch.object(manager, "clear_orphaned_import_errors"),
+ mock.patch.object(manager, "handle_removed_files"),
+ mock.patch.object(manager, "_resort_file_queue"),
+ mock.patch.object(manager, "_add_new_files_to_queue"),
+ ):
+ manager._refresh_dag_bundles({})
+ return patched_get, patched_update
+
+ def test_refresh_dag_bundles_non_versioned_calls_update_bundle_state(self):
+ """Non-versioned bundle: update_bundle_state called with
version=None."""
+ manager = DagFileProcessorManager(max_runs=1)
+ bundle = self._make_refresh_bundle(supports_versioning=False)
+
+ mock_get, mock_update = self._refresh_with_mocked_state(
+ manager, bundle, BundleState(last_refreshed=None, version=None)
+ )
+
+ mock_get.assert_called_once_with("mock_bundle")
+ mock_update.assert_called_once_with("mock_bundle",
last_refreshed=mock.ANY, version=None)
+ assert manager._bundle_versions["mock_bundle"] is None
+
+ def
test_refresh_dag_bundles_versioned_version_changed_calls_update_bundle_state(self):
+ """Versioned bundle with new version: update_bundle_state called with
the new version."""
+ manager = DagFileProcessorManager(max_runs=1)
+ bundle = self._make_refresh_bundle(supports_versioning=True,
current_version="v2")
+ # Pre-populate _bundle_versions so previously_seen=True; the DB version still
matches ("v1"), but the bundle reports "v2" after refresh
+ manager._bundle_versions["mock_bundle"] = "v1"
+
+ mock_get, mock_update = self._refresh_with_mocked_state(
+ manager, bundle, BundleState(last_refreshed=None, version="v1")
+ )
+
+ mock_get.assert_called_once_with("mock_bundle")
+ mock_update.assert_called_once_with("mock_bundle",
last_refreshed=mock.ANY, version="v2")
+ assert manager._bundle_versions["mock_bundle"] == "v2"
+
+ def
test_refresh_dag_bundles_versioned_version_unchanged_calls_update_bundle_state(self):
+ """Versioned bundle with unchanged version: update_bundle_state called
with version=None."""
+ manager = DagFileProcessorManager(max_runs=1)
+ bundle = self._make_refresh_bundle(supports_versioning=True,
current_version="v1")
+ # Pre-populate _bundle_versions so previously_seen=True and version
matches
+ manager._bundle_versions["mock_bundle"] = "v1"
+
+ mock_get, mock_update = self._refresh_with_mocked_state(
+ manager, bundle, BundleState(last_refreshed=None, version="v1")
+ )
+
+ mock_get.assert_called_once_with("mock_bundle")
+ # version=None because version did not change — last_refreshed still
updated
+ mock_update.assert_called_once_with("mock_bundle",
last_refreshed=mock.ANY, version=None)
+ # _bundle_versions NOT updated for unchanged-version early-continue
path
+ assert manager._bundle_versions["mock_bundle"] == "v1"
+
+ def
test_refresh_dag_bundles_versioned_version_unchanged_persist_failure(self):
+ """Short-circuit path: if update_bundle_state raises, the bundle is
skipped without
+ populating known_files (version didn't change, so no file scanning
needed)."""
+ manager = DagFileProcessorManager(max_runs=1)
+ bundle = self._make_refresh_bundle(supports_versioning=True,
current_version="v1")
+ manager._bundle_versions["mock_bundle"] = "v1"
+ manager._dag_bundles = [bundle]
+ manager._force_refresh_bundles = set()
+
+ known_files: dict[str, set[DagFileInfo]] = {}
+ with (
+ mock.patch.object(
+ manager, "get_bundle_state",
return_value=BundleState(last_refreshed=None, version="v1")
+ ),
+ mock.patch.object(manager, "update_bundle_state",
side_effect=Exception("DB error")),
+ mock.patch.object(manager, "_find_files_in_bundle",
return_value=[]) as mock_find,
+ mock.patch.object(manager, "deactivate_deleted_dags"),
+ mock.patch.object(manager, "clear_orphaned_import_errors"),
+ mock.patch.object(manager, "handle_removed_files"),
+ mock.patch.object(manager, "_resort_file_queue"),
+ mock.patch.object(manager, "_add_new_files_to_queue"),
+ ):
+ manager._refresh_dag_bundles(known_files)
+
+ bundle.refresh.assert_called_once()
+ # Short-circuit continues to next bundle — no file scanning
+ mock_find.assert_not_called()
+ assert "mock_bundle" not in known_files
+ # _bundle_versions unchanged
+ assert manager._bundle_versions["mock_bundle"] == "v1"
+
+ def
test_refresh_dag_bundles_versioned_first_seen_skips_short_circuit(self):
+ """Versioned bundle seen for the first time: short-circuit is skipped
even if versions match.
+
+ previously_seen=False means ``previously_seen and ...`` is False, so
the bundle always
+ goes through the full update path on first encounter regardless of
version equality.
+ """
+ manager = DagFileProcessorManager(max_runs=1)
+ # current_version matches what's already in the DB state
+ bundle = self._make_refresh_bundle(supports_versioning=True,
current_version="v1")
+ # _bundle_versions is empty → previously_seen=False
+
+ mock_get, mock_update = self._refresh_with_mocked_state(
+ manager, bundle, BundleState(last_refreshed=None, version="v1")
+ )
+
+ mock_get.assert_called_once_with("mock_bundle")
+ # full update called with the actual version, not short-circuited to
version=None
+ mock_update.assert_called_once_with("mock_bundle",
last_refreshed=mock.ANY, version="v1")
+ assert manager._bundle_versions["mock_bundle"] == "v1"
+
+ def test_refresh_dag_bundles_get_bundle_state_failure_skips_bundle(self):
+ """A failure in get_bundle_state() logs and skips the bundle without
aborting the loop."""
+ manager = DagFileProcessorManager(max_runs=1)
+ bundle = self._make_refresh_bundle()
+ manager._dag_bundles = [bundle]
+
+ with mock.patch.object(manager, "get_bundle_state",
side_effect=Exception("API error")):
+ manager._refresh_dag_bundles({})
+
+ bundle.refresh.assert_not_called()
+
+ def
test_refresh_dag_bundles_update_bundle_state_failure_still_scans_files(self):
+ """A failure in update_bundle_state() logs but does not skip file
scanning.
+
+ The bundle was already refreshed, so known_files must still be
populated to prevent
+ the end-of-method cleanup from treating the bundle's files as removed.
+ """
+ manager = DagFileProcessorManager(max_runs=1)
+ bundle = self._make_refresh_bundle()
+ manager._dag_bundles = [bundle]
+
+ known_files: dict[str, set[DagFileInfo]] = {}
+ with (
+ mock.patch.object(
+ manager, "get_bundle_state",
return_value=BundleState(last_refreshed=None, version=None)
+ ),
+ mock.patch.object(manager, "update_bundle_state",
side_effect=Exception("API error")),
+ mock.patch.object(manager, "_find_files_in_bundle",
return_value=[]),
+ mock.patch.object(manager, "deactivate_deleted_dags"),
+ mock.patch.object(manager, "clear_orphaned_import_errors"),
+ mock.patch.object(manager, "handle_removed_files"),
+ mock.patch.object(manager, "_resort_file_queue"),
+ mock.patch.object(manager, "_add_new_files_to_queue"),
+ ):
+ manager._refresh_dag_bundles(known_files)
+
+ bundle.refresh.assert_called_once()
+ # known_files must be populated so cleanup doesn't purge this bundle's
files
+ assert "mock_bundle" in known_files
+ # _bundle_versions must NOT advance — DB still holds the old version,
so the next
+ # iteration will see a version mismatch and re-refresh rather than
skip incorrectly
+ assert "mock_bundle" not in manager._bundle_versions