This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 326680fd977 Refactor bundle refresh persistence into overridable get/update methods (#63835)
326680fd977 is described below

commit 326680fd9775eeb47ba8dd824a100da56dc1d8b9
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Wed Mar 25 15:53:37 2026 +0100

    Refactor bundle refresh persistence into overridable get/update methods (#63835)
    
    * Refactor bundle refresh persistence into overridable get/update methods
    
    Replace inline DagBundleModel mutations in _refresh_dag_bundles() with
    get_bundle_state() and update_bundle_state() instance methods on
    DagFileProcessorManager, isolating all DB access for bundle state
    behind a clean override seam.
    
    * fixup! Refactor bundle refresh persistence into overridable get/update methods
    
    * Address reviews
---
 airflow-core/src/airflow/dag_processing/manager.py | 165 +++++++++-----
 .../tests/unit/dag_processing/test_manager.py      | 247 +++++++++++++++++++++
 2 files changed, 354 insertions(+), 58 deletions(-)

diff --git a/airflow-core/src/airflow/dag_processing/manager.py 
b/airflow-core/src/airflow/dag_processing/manager.py
index 7d4e20adfa8..d1194127599 100644
--- a/airflow-core/src/airflow/dag_processing/manager.py
+++ b/airflow-core/src/airflow/dag_processing/manager.py
@@ -94,6 +94,13 @@ class DagParsingStat(NamedTuple):
     all_files_processed: bool
 
 
+class BundleState(NamedTuple):
+    """Persisted refresh state for a DAG bundle."""
+
+    last_refreshed: datetime | None
+    version: str | None
+
+
 @attrs.define
 class DagFileStat:
     """Information about single processing of one file."""
@@ -591,6 +598,43 @@ class DagFileProcessorManager(LoggingMixin):
         self._add_files_to_queue([file_info], mode="front")
         Stats.incr("dag_processing.other_callback_count")
 
+    @provide_session
+    def get_bundle_state(self, bundle_name: str, *, session: Session = 
NEW_SESSION) -> BundleState | None:
+        """
+        Return the persisted refresh state for a bundle.
+
+        Returns ``None`` if the bundle has no database record.
+        """
+        row = session.scalar(
+            select(DagBundleModel)
+            .where(DagBundleModel.name == bundle_name)
+            .options(load_only(DagBundleModel.last_refreshed, 
DagBundleModel.version))
+        )
+        if row is None:
+            return None
+        return BundleState(last_refreshed=row.last_refreshed, 
version=row.version)
+
+    @provide_session
+    def update_bundle_state(
+        self,
+        bundle_name: str,
+        *,
+        last_refreshed: datetime,
+        version: str | None,
+        session: Session = NEW_SESSION,
+    ) -> None:
+        """
+        Persist the post-refresh state for a bundle.
+
+        Always updates ``last_refreshed``. Updates ``version`` only when 
``version`` is not
+        ``None`` — pass ``None`` to leave the stored version unchanged (e.g. 
for non-versioned
+        bundles or when the version did not change after a refresh).
+        """
+        values: dict[str, Any] = {"last_refreshed": last_refreshed}
+        if version is not None:
+            values["version"] = version
+        session.execute(update(DagBundleModel).where(DagBundleModel.name == 
bundle_name).values(**values))
+
     def _refresh_dag_bundles(self, known_files: dict[str, set[DagFileInfo]]):
         """Refresh DAG bundles, if required."""
         now = timezone.utcnow()
@@ -619,69 +663,74 @@ class DagFileProcessorManager(LoggingMixin):
                     self.log.exception("Error initializing bundle %s: %s", 
bundle.name, e)
                     continue
             # TODO: AIP-66 test to make sure we get a fresh record from the db 
and it's not cached
-            with create_session() as session:
-                bundle_model = session.get(DagBundleModel, bundle.name)
-                if bundle_model is None:
-                    self.log.warning("Bundle model not found for %s", 
bundle.name)
-                    continue
-                elapsed_time_since_refresh = (
-                    now - (bundle_model.last_refreshed or utc_epoch())
-                ).total_seconds()
-                if bundle.supports_versioning:
-                    # we will also check the version of the bundle to see if 
another DAG processor has seen
-                    # a new version
-                    pre_refresh_version = (
-                        self._bundle_versions.get(bundle.name) or 
bundle.get_current_version()
-                    )
-                    current_version_matches_db = pre_refresh_version == 
bundle_model.version
-                else:
-                    # With no versioning, it always "matches"
-                    current_version_matches_db = True
-
-                previously_seen = bundle.name in self._bundle_versions
-                if self.should_skip_refresh(
-                    bundle=bundle,
-                    elapsed_time_since_refresh=elapsed_time_since_refresh,
-                    current_version_matches_db=current_version_matches_db,
-                    previously_seen=previously_seen,
-                ):
-                    self.log.info("Not time to refresh bundle %s", bundle.name)
-                    continue
-
-                self.log.info("Refreshing bundle %s", bundle.name)
-
-                try:
-                    bundle.refresh()
-                    any_refreshed = True
-                except Exception:
-                    self.log.exception("Error refreshing bundle %s", 
bundle.name)
-                    continue
-
-                bundle_model.last_refreshed = now
-                self._force_refresh_bundles.discard(bundle.name)
+            try:
+                bundle_state = self.get_bundle_state(bundle.name)
+            except Exception:
+                self.log.exception("Error fetching state for bundle %s", 
bundle.name)
+                continue
+            if bundle_state is None:
+                self.log.warning("Bundle model not found for %s", bundle.name)
+                continue
+            elapsed_time_since_refresh = (now - (bundle_state.last_refreshed 
or utc_epoch())).total_seconds()
+            if bundle.supports_versioning:
+                # we will also check the version of the bundle to see if 
another DAG processor has seen
+                # a new version
+                pre_refresh_version = self._bundle_versions.get(bundle.name) 
or bundle.get_current_version()
+                current_version_matches_db = pre_refresh_version == 
bundle_state.version
+            else:
+                # With no versioning, it always "matches"
+                current_version_matches_db = True
+
+            previously_seen = bundle.name in self._bundle_versions
+            if self.should_skip_refresh(
+                bundle=bundle,
+                elapsed_time_since_refresh=elapsed_time_since_refresh,
+                current_version_matches_db=current_version_matches_db,
+                previously_seen=previously_seen,
+            ):
+                self.log.info("Not time to refresh bundle %s", bundle.name)
+                continue
 
-                if bundle.supports_versioning:
-                    # We can short-circuit the rest of this if (1) bundle was 
seen before by
-                    # this dag processor and (2) the version of the bundle did 
not change
-                    # after refreshing it
-                    version_after_refresh = bundle.get_current_version()
-                    if previously_seen and pre_refresh_version == 
version_after_refresh:
-                        self.log.debug(
-                            "Bundle %s version not changed after refresh: %s",
-                            bundle.name,
-                            version_after_refresh,
-                        )
-                        continue
+            self.log.info("Refreshing bundle %s", bundle.name)
 
-                    bundle_model.version = version_after_refresh
+            try:
+                bundle.refresh()
+                any_refreshed = True
+            except Exception:
+                self.log.exception("Error refreshing bundle %s", bundle.name)
+                continue
 
-                    self.log.info(
-                        "Version changed for %s, new version: %s", 
bundle.name, version_after_refresh
+            self._force_refresh_bundles.discard(bundle.name)
+
+            if bundle.supports_versioning:
+                # We can short-circuit the rest of this if (1) bundle was seen 
before by
+                # this dag processor and (2) the version of the bundle did not 
change
+                # after refreshing it
+                version_after_refresh = bundle.get_current_version()
+                if previously_seen and pre_refresh_version == 
version_after_refresh:
+                    self.log.debug(
+                        "Bundle %s version not changed after refresh: %s",
+                        bundle.name,
+                        version_after_refresh,
                     )
-                else:
-                    version_after_refresh = None
+                    try:
+                        self.update_bundle_state(bundle.name, 
last_refreshed=now, version=None)
+                    except Exception:
+                        self.log.exception("Error persisting state for bundle 
%s", bundle.name)
+                    continue
 
-            self._bundle_versions[bundle.name] = version_after_refresh
+                self.log.info("Version changed for %s, new version: %s", 
bundle.name, version_after_refresh)
+            else:
+                version_after_refresh = None
+
+            # Persistence failure must not skip file scanning (bundle is 
already refreshed locally).
+            # _bundle_versions is only advanced on success to stay consistent 
with the DB.
+            try:
+                self.update_bundle_state(bundle.name, last_refreshed=now, 
version=version_after_refresh)
+            except Exception:
+                self.log.exception("Error persisting state for bundle %s", 
bundle.name)
+            else:
+                self._bundle_versions[bundle.name] = version_after_refresh
 
             found_files = {
                 DagFileInfo(rel_path=p, bundle_name=bundle.name, 
bundle_path=bundle.path)
diff --git a/airflow-core/tests/unit/dag_processing/test_manager.py 
b/airflow-core/tests/unit/dag_processing/test_manager.py
index e728ce85524..ea867ad1fec 100644
--- a/airflow-core/tests/unit/dag_processing/test_manager.py
+++ b/airflow-core/tests/unit/dag_processing/test_manager.py
@@ -42,9 +42,11 @@ from uuid6 import uuid7
 
 from airflow._shared.timezones import timezone
 from airflow.callbacks.callback_requests import DagCallbackRequest
+from airflow.dag_processing.bundles.base import BaseDagBundle
 from airflow.dag_processing.bundles.manager import DagBundlesManager
 from airflow.dag_processing.dagbag import DagBag
 from airflow.dag_processing.manager import (
+    BundleState,
     DagFileInfo,
     DagFileProcessorManager,
     DagFileStat,
@@ -1838,3 +1840,248 @@ class TestDagFileProcessorManager:
 
                 dag_path.touch()  # make the loop run faster
                 gauge_values.clear()
+
+    # --- get_bundle_state / update_bundle_state ---
+
+    def test_get_bundle_state_returns_none_for_missing_bundle(self):
+        manager = DagFileProcessorManager(max_runs=1)
+        assert manager.get_bundle_state("nonexistent_bundle") is None
+
+    def test_get_bundle_state_returns_correct_state(self, session):
+        bundle_name = "test_state_bundle"
+        refreshed_at = timezone.datetime(2024, 1, 15, 12, 0, 0)
+        model = DagBundleModel(name=bundle_name, version="v1")
+        model.last_refreshed = refreshed_at
+        session.add(model)
+        session.commit()
+
+        manager = DagFileProcessorManager(max_runs=1)
+        state = manager.get_bundle_state(bundle_name)
+
+        assert state == BundleState(last_refreshed=refreshed_at, version="v1")
+
+    def test_get_bundle_state_null_fields(self, session):
+        bundle_name = "test_null_state_bundle"
+        session.add(DagBundleModel(name=bundle_name))
+        session.commit()
+
+        manager = DagFileProcessorManager(max_runs=1)
+        state = manager.get_bundle_state(bundle_name)
+
+        assert state == BundleState(last_refreshed=None, version=None)
+
+    def test_update_bundle_state_sets_last_refreshed(self, session):
+        bundle_name = "test_update_bundle"
+        session.add(DagBundleModel(name=bundle_name))
+        session.commit()
+
+        refreshed_at = timezone.datetime(2024, 6, 1, 8, 0, 0)
+        manager = DagFileProcessorManager(max_runs=1)
+        manager.update_bundle_state(bundle_name, last_refreshed=refreshed_at, 
version=None)
+
+        session.expire_all()
+        model = session.get(DagBundleModel, bundle_name)
+        assert model.last_refreshed == refreshed_at
+        assert model.version is None
+
+    def test_update_bundle_state_sets_version(self, session):
+        bundle_name = "test_update_version_bundle"
+        session.add(DagBundleModel(name=bundle_name))
+        session.commit()
+
+        refreshed_at = timezone.datetime(2024, 6, 1, 8, 0, 0)
+        manager = DagFileProcessorManager(max_runs=1)
+        manager.update_bundle_state(bundle_name, last_refreshed=refreshed_at, 
version="abc123")
+
+        session.expire_all()
+        model = session.get(DagBundleModel, bundle_name)
+        assert model.last_refreshed == refreshed_at
+        assert model.version == "abc123"
+
+    def test_update_bundle_state_does_not_overwrite_version_when_none(self, 
session):
+        bundle_name = "test_preserve_version_bundle"
+        session.add(DagBundleModel(name=bundle_name, version="keep_me"))
+        session.commit()
+
+        refreshed_at = timezone.datetime(2024, 6, 1, 8, 0, 0)
+        manager = DagFileProcessorManager(max_runs=1)
+        manager.update_bundle_state(bundle_name, last_refreshed=refreshed_at, 
version=None)
+
+        session.expire_all()
+        model = session.get(DagBundleModel, bundle_name)
+        assert model.last_refreshed == refreshed_at
+        assert model.version == "keep_me"
+
+    def _make_refresh_bundle(self, *, supports_versioning=False, 
current_version=None):
+        bundle = MagicMock(spec=BaseDagBundle)
+        bundle.name = "mock_bundle"
+        bundle.refresh_interval = 0
+        bundle.supports_versioning = supports_versioning
+        bundle.is_initialized = True
+        bundle.path = Path("/dev/null")
+        bundle.get_current_version.return_value = current_version
+        return bundle
+
+    def _refresh_with_mocked_state(self, manager, bundle, initial_state):
+        """Run _refresh_dag_bundles with get/update_bundle_state mocked out.
+
+        Returns the two MagicMock objects for post-call assertions. MagicMock 
retains its
+        call records after the ``with`` block exits (un-patching only restores 
the original
+        attribute; it does not clear the mock's recorded calls), so callers 
can assert on
+        them normally after this method returns.
+        """
+        manager._dag_bundles = [bundle]
+        manager._force_refresh_bundles = set()
+        mock_get = mock.patch.object(manager, "get_bundle_state", 
return_value=initial_state)
+        mock_update = mock.patch.object(manager, "update_bundle_state")
+        with (
+            mock_get as patched_get,
+            mock_update as patched_update,
+            mock.patch.object(manager, "_find_files_in_bundle", 
return_value=[]),
+            mock.patch.object(manager, "deactivate_deleted_dags"),
+            mock.patch.object(manager, "clear_orphaned_import_errors"),
+            mock.patch.object(manager, "handle_removed_files"),
+            mock.patch.object(manager, "_resort_file_queue"),
+            mock.patch.object(manager, "_add_new_files_to_queue"),
+        ):
+            manager._refresh_dag_bundles({})
+        return patched_get, patched_update
+
+    def test_refresh_dag_bundles_non_versioned_calls_update_bundle_state(self):
+        """Non-versioned bundle: update_bundle_state called with 
version=None."""
+        manager = DagFileProcessorManager(max_runs=1)
+        bundle = self._make_refresh_bundle(supports_versioning=False)
+
+        mock_get, mock_update = self._refresh_with_mocked_state(
+            manager, bundle, BundleState(last_refreshed=None, version=None)
+        )
+
+        mock_get.assert_called_once_with("mock_bundle")
+        mock_update.assert_called_once_with("mock_bundle", 
last_refreshed=mock.ANY, version=None)
+        assert manager._bundle_versions["mock_bundle"] is None
+
+    def 
test_refresh_dag_bundles_versioned_version_changed_calls_update_bundle_state(self):
+        """Versioned bundle with new version: update_bundle_state called with 
the new version."""
+        manager = DagFileProcessorManager(max_runs=1)
+        bundle = self._make_refresh_bundle(supports_versioning=True, 
current_version="v2")
+        # Pre-populate _bundle_versions so previously_seen=True and current DB 
version differs
+        manager._bundle_versions["mock_bundle"] = "v1"
+
+        mock_get, mock_update = self._refresh_with_mocked_state(
+            manager, bundle, BundleState(last_refreshed=None, version="v1")
+        )
+
+        mock_get.assert_called_once_with("mock_bundle")
+        mock_update.assert_called_once_with("mock_bundle", 
last_refreshed=mock.ANY, version="v2")
+        assert manager._bundle_versions["mock_bundle"] == "v2"
+
+    def 
test_refresh_dag_bundles_versioned_version_unchanged_calls_update_bundle_state(self):
+        """Versioned bundle with unchanged version: update_bundle_state called 
with version=None."""
+        manager = DagFileProcessorManager(max_runs=1)
+        bundle = self._make_refresh_bundle(supports_versioning=True, 
current_version="v1")
+        # Pre-populate _bundle_versions so previously_seen=True and version 
matches
+        manager._bundle_versions["mock_bundle"] = "v1"
+
+        mock_get, mock_update = self._refresh_with_mocked_state(
+            manager, bundle, BundleState(last_refreshed=None, version="v1")
+        )
+
+        mock_get.assert_called_once_with("mock_bundle")
+        # version=None because version did not change — last_refreshed still 
updated
+        mock_update.assert_called_once_with("mock_bundle", 
last_refreshed=mock.ANY, version=None)
+        # _bundle_versions NOT updated for unchanged-version early-continue 
path
+        assert manager._bundle_versions["mock_bundle"] == "v1"
+
+    def 
test_refresh_dag_bundles_versioned_version_unchanged_persist_failure(self):
+        """Short-circuit path: if update_bundle_state raises, the bundle is 
skipped without
+        populating known_files (version didn't change, so no file scanning 
needed)."""
+        manager = DagFileProcessorManager(max_runs=1)
+        bundle = self._make_refresh_bundle(supports_versioning=True, 
current_version="v1")
+        manager._bundle_versions["mock_bundle"] = "v1"
+        manager._dag_bundles = [bundle]
+        manager._force_refresh_bundles = set()
+
+        known_files: dict[str, set[DagFileInfo]] = {}
+        with (
+            mock.patch.object(
+                manager, "get_bundle_state", 
return_value=BundleState(last_refreshed=None, version="v1")
+            ),
+            mock.patch.object(manager, "update_bundle_state", 
side_effect=Exception("DB error")),
+            mock.patch.object(manager, "_find_files_in_bundle", 
return_value=[]) as mock_find,
+            mock.patch.object(manager, "deactivate_deleted_dags"),
+            mock.patch.object(manager, "clear_orphaned_import_errors"),
+            mock.patch.object(manager, "handle_removed_files"),
+            mock.patch.object(manager, "_resort_file_queue"),
+            mock.patch.object(manager, "_add_new_files_to_queue"),
+        ):
+            manager._refresh_dag_bundles(known_files)
+
+        bundle.refresh.assert_called_once()
+        # Short-circuit continues to next bundle — no file scanning
+        mock_find.assert_not_called()
+        assert "mock_bundle" not in known_files
+        # _bundle_versions unchanged
+        assert manager._bundle_versions["mock_bundle"] == "v1"
+
+    def 
test_refresh_dag_bundles_versioned_first_seen_skips_short_circuit(self):
+        """Versioned bundle seen for the first time: short-circuit is skipped 
even if versions match.
+
+        previously_seen=False means ``previously_seen and ...`` is False, so 
the bundle always
+        goes through the full update path on first encounter regardless of 
version equality.
+        """
+        manager = DagFileProcessorManager(max_runs=1)
+        # current_version matches what's already in the DB state
+        bundle = self._make_refresh_bundle(supports_versioning=True, 
current_version="v1")
+        # _bundle_versions is empty → previously_seen=False
+
+        mock_get, mock_update = self._refresh_with_mocked_state(
+            manager, bundle, BundleState(last_refreshed=None, version="v1")
+        )
+
+        mock_get.assert_called_once_with("mock_bundle")
+        # full update called with the actual version, not short-circuited to 
version=None
+        mock_update.assert_called_once_with("mock_bundle", 
last_refreshed=mock.ANY, version="v1")
+        assert manager._bundle_versions["mock_bundle"] == "v1"
+
+    def test_refresh_dag_bundles_get_bundle_state_failure_skips_bundle(self):
+        """A failure in get_bundle_state() logs and skips the bundle without 
aborting the loop."""
+        manager = DagFileProcessorManager(max_runs=1)
+        bundle = self._make_refresh_bundle()
+        manager._dag_bundles = [bundle]
+
+        with mock.patch.object(manager, "get_bundle_state", 
side_effect=Exception("API error")):
+            manager._refresh_dag_bundles({})
+
+        bundle.refresh.assert_not_called()
+
+    def 
test_refresh_dag_bundles_update_bundle_state_failure_still_scans_files(self):
+        """A failure in update_bundle_state() logs but does not skip file 
scanning.
+
+        The bundle was already refreshed, so known_files must still be 
populated to prevent
+        the end-of-method cleanup from treating the bundle's files as removed.
+        """
+        manager = DagFileProcessorManager(max_runs=1)
+        bundle = self._make_refresh_bundle()
+        manager._dag_bundles = [bundle]
+
+        known_files: dict[str, set[DagFileInfo]] = {}
+        with (
+            mock.patch.object(
+                manager, "get_bundle_state", 
return_value=BundleState(last_refreshed=None, version=None)
+            ),
+            mock.patch.object(manager, "update_bundle_state", 
side_effect=Exception("API error")),
+            mock.patch.object(manager, "_find_files_in_bundle", 
return_value=[]),
+            mock.patch.object(manager, "deactivate_deleted_dags"),
+            mock.patch.object(manager, "clear_orphaned_import_errors"),
+            mock.patch.object(manager, "handle_removed_files"),
+            mock.patch.object(manager, "_resort_file_queue"),
+            mock.patch.object(manager, "_add_new_files_to_queue"),
+        ):
+            manager._refresh_dag_bundles(known_files)
+
+        bundle.refresh.assert_called_once()
+        # known_files must be populated so cleanup doesn't purge this bundle's 
files
+        assert "mock_bundle" in known_files
+        # _bundle_versions must NOT advance — DB still holds the old version, 
so the next
+        # iteration will see a version mismatch and re-refresh rather than 
skip incorrectly
+        assert "mock_bundle" not in manager._bundle_versions

Reply via email to