This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 10da1a02f30 Implement stale dag bundle cleanup (#46503)
10da1a02f30 is described below
commit 10da1a02f30713e1a29a34f881334b25f498f017
Author: Daniel Standish <[email protected]>
AuthorDate: Fri Feb 28 18:25:54 2025 -0800
Implement stale dag bundle cleanup (#46503)
On a shared worker, dag bundles will accumulate until the worker dies,
unless we do something about it.
Here we delete old dag bundle copies that exist locally on the worker.
There are three controls that prevent deletion of dag bundles that may still
be needed.
stale_bundle_cleanup_min_versions is a setting in the dag processor section.
The cleanup process will always retain at least this many bundle versions.
stale_bundle_cleanup_age_threshold is another setting in the dag processor
section. The cleanup process will always retain versions that were last
checked out more recently than this threshold.
While a task is using a bundle version, its tracking file is locked, which
prevents that version from being deleted.
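To illustrate the locking control, here is a minimal stdlib-only sketch of
the handshake (the file path is illustrative, not Airflow's actual layout):
the task holds a shared flock on the version's tracking file while it runs,
and cleanup only deletes a version if it can take an exclusive, non-blocking
flock on that same file.

    import fcntl

    tracking_file = "/tmp/example-bundle-version"  # illustrative path

    # Task side: hold a shared lock while the task runs.
    task_fh = open(tracking_file, "a")
    fcntl.flock(task_fh, fcntl.LOCK_SH)

    # Cleanup side: try a non-blocking exclusive lock. If any task still
    # holds the shared lock, this raises BlockingIOError and deletion is skipped.
    cleanup_fh = open(tracking_file, "a")
    try:
        fcntl.flock(cleanup_fh, fcntl.LOCK_EX | fcntl.LOCK_NB)
        print("no readers; safe to delete this bundle version")
        fcntl.flock(cleanup_fh, fcntl.LOCK_UN)
    except BlockingIOError:
        print("version still in use; skipping")
    finally:
        cleanup_fh.close()

    fcntl.flock(task_fh, fcntl.LOCK_UN)
    task_fh.close()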
---
airflow/config_templates/config.yml | 31 +++
airflow/dag_processing/bundles/base.py | 297 ++++++++++++++++++++-
airflow/dag_processing/bundles/git.py | 49 +++-
airflow/executors/workloads.py | 9 +-
airflow/jobs/scheduler_job_runner.py | 13 +
airflow/models/taskinstance.py | 4 +-
.../airflow/providers/celery/cli/celery_command.py | 43 ++-
.../tests/unit/celery/cli/test_celery_command.py | 19 ++
.../unit/google/cloud/operators/test_dataproc.py | 2 +-
.../tests/unit/openlineage/utils/test_spark.py | 2 +-
.../src/airflow/sdk/execution_time/task_runner.py | 12 +-
task_sdk/tests/execution_time/test_supervisor.py | 12 +-
tests/dag_processing/bundles/test_base.py | 172 +++++++++++-
tests/dag_processing/bundles/test_git.py | 7 +-
.../cli/commands/test_celery_command.py | 17 +-
15 files changed, 638 insertions(+), 51 deletions(-)
diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 48f142a2875..6a57fc70987 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -2697,3 +2697,34 @@ dag_processor:
type: integer
example: ~
default: "5"
+ stale_bundle_cleanup_interval:
+ description: |
+ On shared workers, bundle copies accumulate in local storage as tasks run
+ and the version of the bundle changes.
+ This setting is the interval in seconds between checks for these stale bundles.
+ Bundles older than `stale_bundle_cleanup_age_threshold` may be removed,
+ but we always keep `stale_bundle_cleanup_min_versions` versions locally.
+ Set to 0 or negative to disable.
+ version_added: ~
+ type: integer
+ example: ~
+ default: "1800"
+ stale_bundle_cleanup_age_threshold:
+ description: |
+ Bundle versions used more recently than this threshold will not be removed.
+ Recency of use is determined by when the task began running on the worker;
+ that age is compared with this setting, given as a time delta in seconds.
+ version_added: ~
+ type: integer
+ example: ~
+ default: "21600"
+ stale_bundle_cleanup_min_versions:
+ description: |
+ Minimum number of local bundle versions to retain on disk.
+ Local bundle versions older than `stale_bundle_cleanup_age_threshold` will
+ only be deleted if we have more than `stale_bundle_cleanup_min_versions`
+ versions accumulated on the worker.
+ version_added: ~
+ type: integer
+ example: ~
+ default: "10"
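For reference, with the defaults above these knobs would appear in the
[dag_processor] section of airflow.cfg roughly as follows (a sketch, not
generated output):

    [dag_processor]
    stale_bundle_cleanup_interval = 1800
    stale_bundle_cleanup_age_threshold = 21600
    stale_bundle_cleanup_min_versions = 10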
diff --git a/airflow/dag_processing/bundles/base.py b/airflow/dag_processing/bundles/base.py
index 3d86f398c73..8a4aeb18f6d 100644
--- a/airflow/dag_processing/bundles/base.py
+++ b/airflow/dag_processing/bundles/base.py
@@ -18,12 +18,210 @@
from __future__ import annotations
import fcntl
+import logging
+import os
+import shutil
import tempfile
from abc import ABC, abstractmethod
from contextlib import contextmanager
+from dataclasses import dataclass, field
+from datetime import timedelta
+from fcntl import LOCK_SH, LOCK_UN, flock
+from operator import attrgetter
from pathlib import Path
+from typing import TYPE_CHECKING
+
+import pendulum
+from pendulum.parsing import ParserError
from airflow.configuration import conf
+from airflow.dag_processing.bundles.manager import DagBundlesManager
+
+if TYPE_CHECKING:
+ from pendulum import DateTime
+
+log = logging.getLogger(__name__)
+
+
+def get_bundle_storage_root_path():
+ if configured_location := conf.get("dag_processor", "dag_bundle_storage_path", fallback=None):
+ return Path(configured_location)
+ else:
+ return Path(tempfile.gettempdir(), "airflow", "dag_bundles")
+
+
+STALE_BUNDLE_TRACKING_FOLDER = get_bundle_storage_root_path() / "_tracking"
+
+
+def get_bundle_tracking_dir(bundle_name: str) -> Path:
+ return STALE_BUNDLE_TRACKING_FOLDER / bundle_name
+
+
+def get_bundle_tracking_file(bundle_name: str, version: str) -> Path:
+ tracking_dir = get_bundle_tracking_dir(bundle_name=bundle_name)
+ return Path(tracking_dir, version)
+
+
+def get_bundle_base_folder(bundle_name: str) -> Path:
+ return get_bundle_storage_root_path() / bundle_name
+
+
+def get_bundle_versions_base_folder(bundle_name: str) -> Path:
+ return get_bundle_base_folder(bundle_name=bundle_name) / "versions"
+
+
+def get_bundle_version_path(bundle_name: str, version: str) -> Path:
+ base_folder = get_bundle_versions_base_folder(bundle_name=bundle_name)
+ return base_folder / version
+
+
+@dataclass(frozen=True)
+class TrackedBundleVersionInfo:
+ """
+ Internal info class for stale bundle cleanup.
+
+ :meta private:
+ """
+
+ lock_file_path: Path
+ version: str = field(compare=False)
+ dt: DateTime = field(compare=False)
+
+
+class BundleUsageTrackingManager:
+ """
+ Utility helper for removing stale bundles.
+
+ :meta private:
+ """
+
+ def _parse_dt(self, val) -> DateTime | None:
+ try:
+ return pendulum.parse(val)
+ except ParserError:
+ return None
+
+ @staticmethod
+ def _filter_for_min_versions(val: list[TrackedBundleVersionInfo]) -> list[TrackedBundleVersionInfo]:
+ min_versions_to_keep = conf.getint(
+ section="dag_processor",
+ key="stale_bundle_cleanup_min_versions",
+ )
+ return sorted(val, key=attrgetter("dt"), reverse=True)[min_versions_to_keep:]
+
+ @staticmethod
+ def _filter_for_recency(val: list[TrackedBundleVersionInfo]) -> list[TrackedBundleVersionInfo]:
+ age_threshold = conf.getint(
+ section="dag_processor",
+ key="stale_bundle_cleanup_age_threshold",
+ )
+ ret = []
+ now = pendulum.now(tz=pendulum.UTC)
+ cutoff = now - timedelta(seconds=age_threshold)
+ for item in val:
+ if item.dt < cutoff:
+ ret.append(item)
+ return ret
+
+ def _find_all_tracking_files(self, bundle_name) -> list[TrackedBundleVersionInfo] | None:
+ tracking_dir = get_bundle_tracking_dir(bundle_name=bundle_name)
+ found: list[TrackedBundleVersionInfo] = []
+ if not tracking_dir.exists():
+ log.debug("bundle usage tracking directory does not exist.
tracking_dir=%s", tracking_dir)
+ return None
+ for file in tracking_dir.iterdir():
+ log.debug("found bundle tracking file, path=%s", file)
+ version = file.name
+ dt_str = file.read_text()
+ dt = self._parse_dt(val=dt_str)
+ if not dt:
+ log.error(
+ "could not parse val as datetime bundle_name=%s val=%s
version=%s",
+ bundle_name,
+ dt_str,
+ version,
+ )
+ continue
+ found.append(TrackedBundleVersionInfo(lock_file_path=file, version=version, dt=dt))
+ return found
+
+ @staticmethod
+ def _remove_stale_bundle(bundle_name: str, info: TrackedBundleVersionInfo) -> None:
+ bundle_version_path = get_bundle_version_path(
+ bundle_name=bundle_name,
+ version=info.version,
+ )
+
+ def log_info(msg):
+ log.info(
+ "%s bundle_name=%s bundle_version=%s bundle_path=%s
lock_file=%s",
+ msg,
+ bundle_name,
+ info.version,
+ bundle_version_path,
+ info.lock_file_path,
+ )
+
+ try:
+ log_info("removing stale bundle.")
+ with open(info.lock_file_path, "a") as f:
+ flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB) # exclusive lock, do not wait
+ # remove the actual bundle copy
+ shutil.rmtree(bundle_version_path)
+ # remove the lock file
+ os.remove(info.lock_file_path)
+ except BlockingIOError:
+ log_info("could not obtain lock. stale bundle will not be
removed.")
+ return
+
+ def _find_candidates(self, found):
+ """Remove the recently used bundles."""
+ candidates = self._filter_for_min_versions(found)
+ candidates = self._filter_for_recency(candidates)
+ if log.isEnabledFor(level=logging.DEBUG):
+ self._debug_candidates(candidates, found)
+ return candidates
+
+ @staticmethod
+ def _debug_candidates(candidates, found):
+ recently_used = list(set(found).difference(candidates))
+ if candidates:
+ log.debug(
+ "found removal candidates. candidates=%s recently_used=%s",
+ candidates,
+ recently_used,
+ )
+ else:
+ log.debug(
+ "no removal candidates found. candidates=%s recently_used=%s",
+ candidates,
+ recently_used,
+ )
+
+ def _remove_stale_bundle_versions_for_bundle(self, bundle_name: str):
+ log.info("checking bundle for stale versions. bundle_name=%s",
bundle_name)
+ found = self._find_all_tracking_files(bundle_name=bundle_name)
+ if not found:
+ return
+ candidates = self._find_candidates(found)
+ for info in candidates:
+ self._remove_stale_bundle(bundle_name=bundle_name, info=info)
+
+ def remove_stale_bundle_versions(self):
+ """
+ Remove bundles that are not in use and have not been used for some time.
+
+ We will keep the last N used bundles, and bundles last used within X time.
+
+ This isn't really necessary on worker types that don't share storage
+ with other processes.
+ """
+ log.info("checking for stale bundle versions locally")
+ bundles = list(DagBundlesManager().get_all_dag_bundles())
+ for bundle in bundles:
+ if not bundle.supports_versioning:
+ continue
+
+ self._remove_stale_bundle_versions_for_bundle(bundle_name=bundle.name)
class BaseDagBundle(ABC):
@@ -48,6 +246,7 @@ class BaseDagBundle(ABC):
"""
supports_versioning: bool = False
+
_locked: bool = False
def __init__(
@@ -62,6 +261,12 @@ class BaseDagBundle(ABC):
self.refresh_interval = refresh_interval
self.is_initialized: bool = False
+ self.base_dir = get_bundle_base_folder(bundle_name=self.name)
+ """Base directory for all bundle files for this bundle."""
+
+ self.versions_dir = get_bundle_versions_base_folder(bundle_name=self.name)
+ """Where bundle versions are stored locally for this bundle."""
+
def initialize(self) -> None:
"""
Initialize the bundle.
@@ -77,17 +282,6 @@ class BaseDagBundle(ABC):
"""
self.is_initialized = True
- @property
- def _dag_bundle_root_storage_path(self) -> Path:
- """
- Where bundles can store DAGs on disk (if local disk is required).
-
- This is the root bundle storage path, common to all bundles. Each bundle should use a subdirectory of this path.
- """
- if configured_location := conf.get("dag_processor", "dag_bundle_storage_path", fallback=None):
- return Path(configured_location)
- return Path(tempfile.gettempdir(), "airflow", "dag_bundles")
-
@property
@abstractmethod
def path(self) -> Path:
@@ -137,9 +331,10 @@ class BaseDagBundle(ABC):
yield
return
- lock_dir_path = self._dag_bundle_root_storage_path / "_locks"
+ lock_dir_path = get_bundle_storage_root_path() / "_locks"
lock_dir_path.mkdir(parents=True, exist_ok=True)
lock_file_path = lock_dir_path / f"{self.name}.lock"
+
with open(lock_file_path, "w") as lock_file:
# Exclusive lock - blocks until it is available
fcntl.flock(lock_file, fcntl.LOCK_EX)
@@ -147,5 +342,81 @@ class BaseDagBundle(ABC):
self._locked = True
yield
finally:
- fcntl.flock(lock_file, fcntl.LOCK_UN)
+ fcntl.flock(lock_file, LOCK_UN)
self._locked = False
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}(name={self.name})"
+
+
+class BundleVersionLock:
+ """
+ Lock version of bundle when in use to prevent deletion.
+
+ :meta private:
+ """
+
+ def __init__(self, bundle_name, bundle_version, **kwargs):
+ super().__init__(**kwargs)
+ self.lock_file = None
+ self.bundle_name = bundle_name
+ self.version = bundle_version
+ self.lock_file_path: Path | None = None
+ if self.version:
+ self.lock_file_path = get_bundle_tracking_file(
+ bundle_name=self.bundle_name,
+ version=self.version,
+ )
+
+ def _log_exc(self, msg):
+ log.exception(
+ "% name=%s version=%s lock_file=%s",
+ msg,
+ self.bundle_name,
+ self.version,
+ self.lock_file_path,
+ )
+
+ def _update_version_file(self):
+ """Create a version file containing last-used timestamp."""
+ if TYPE_CHECKING:
+ assert self.lock_file_path
+ self.lock_file_path.parent.mkdir(parents=True, exist_ok=True)
+
+ with tempfile.TemporaryDirectory() as td:
+ temp_file = Path(td, self.lock_file_path.name)
+ now = pendulum.now(tz=pendulum.UTC)
+ temp_file.write_text(now.isoformat())
+ os.replace(temp_file, self.lock_file_path)
+
+ def acquire(self):
+ if not self.version:
+ return
+ if self.lock_file:
+ return
+ self._update_version_file()
+ if TYPE_CHECKING:
+ assert self.lock_file_path
+ self.lock_file = open(self.lock_file_path)
+ flock(self.lock_file, LOCK_SH)
+
+ def release(self):
+ if self.lock_file:
+ flock(self.lock_file, LOCK_UN)
+ self.lock_file.close()
+ self.lock_file = None
+
+ def __enter__(self):
+ # wrapping in try except here is just extra cautious since this is in task execution path
+ try:
+ self.acquire()
+ except Exception:
+ self._log_exc("error when attempting to acquire lock")
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ # wrapping in try except here is just extra cautious since this is in task execution path
+ try:
+ self.release()
+ except Exception:
+ self._log_exc("error when attempting to release lock")
diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py
index 6fd7d815191..8daa96c21c9 100644
--- a/airflow/dag_processing/bundles/git.py
+++ b/airflow/dag_processing/bundles/git.py
@@ -19,21 +19,24 @@ from __future__ import annotations
import contextlib
import json
+import logging
import os
import tempfile
-from typing import TYPE_CHECKING, Any
+from pathlib import Path
+from typing import Any
from urllib.parse import urlparse
from git import Repo
from git.exc import BadName, GitCommandError, NoSuchPathError
-from airflow.dag_processing.bundles.base import BaseDagBundle
+from airflow.dag_processing.bundles.base import (
+ BaseDagBundle,
+)
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.utils.log.logging_mixin import LoggingMixin
-if TYPE_CHECKING:
- from pathlib import Path
+log = logging.getLogger(__name__)
class GitHook(BaseHook):
@@ -143,15 +146,39 @@ class GitDagBundle(BaseDagBundle, LoggingMixin):
super().__init__(**kwargs)
self.tracking_ref = tracking_ref
self.subdir = subdir
- self.bare_repo_path = self._dag_bundle_root_storage_path / "git" / self.name
- self.repo_path = (
- self._dag_bundle_root_storage_path / "git" / (self.name + f"+{self.version or self.tracking_ref}")
- )
+ self.bare_repo_path = self.base_dir / "bare"
+ if self.version:
+ self.repo_path = self.versions_dir / self.version
+ else:
+ self.repo_path = self.base_dir / "tracking_repo"
self.git_conn_id = git_conn_id
self.repo_url = repo_url
+
+ def log_debug(msg, **kwargs):
+ if not log.isEnabledFor(logging.DEBUG):
+ return
+ # ugly; replace when structlog implemented
+ context = dict(
+ bundle_name=self.name,
+ version=self.version,
+ bare_repo_path=self.bare_repo_path,
+ repo_path=self.repo_path,
+ versions_path=self.versions_dir,
+ git_conn_id=self.git_conn_id,
+ repo_url=self.repo_url,
+ **kwargs,
+ )
+
+ for k, v in context.items():
+ msg += f" {k}='{v}'"
+ log.debug(msg)
+
+ self._log_debug = log_debug
+ log_debug("bundle configured")
try:
self.hook = GitHook(git_conn_id=self.git_conn_id, repo_url=self.repo_url)
self.repo_url = self.hook.repo_url
+ log_debug("repo_url updated from hook", repo_url=self.repo_url)
except AirflowException as e:
self.log.warning("Could not create GitHook for connection %s :
%s", self.git_conn_id, e)
@@ -163,10 +190,11 @@ class GitDagBundle(BaseDagBundle, LoggingMixin):
self._clone_repo_if_required()
self.repo.git.checkout(self.tracking_ref)
+ self._log_debug("bundle initialize", version=self.version)
if self.version:
if not self._has_version(self.repo, self.version):
self.repo.remotes.origin.fetch()
- self.repo.head.set_reference(self.repo.commit(self.version))
+ self.repo.head.set_reference(str(self.repo.commit(self.version)))
self.repo.head.reset(index=True, working_tree=True)
else:
self.refresh()
@@ -188,7 +216,8 @@ class GitDagBundle(BaseDagBundle, LoggingMixin):
except NoSuchPathError as e:
# Protection should the bare repo be removed manually
raise AirflowException("Repository path: %s not found", self.bare_repo_path) from e
-
+ else:
+ self._log_debug("repo exists", repo_path=self.repo_path)
self.repo = Repo(self.repo_path)
def _clone_bare_repo_if_required(self) -> None:
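With the path changes above, a versioned git bundle's local layout looks
roughly like this (a sketch; the root shown is the default tempdir location
and the names are illustrative):

    /tmp/airflow/dag_bundles/
        _locks/                       # per-bundle initialization lock files
        _tracking/<bundle>/<version>  # last-used timestamps (flock targets)
        <bundle>/
            bare/                     # bare clone fetched from the remote
            tracking_repo/            # checkout tracking the ref
            versions/<version>/       # one checkout per pinned version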
diff --git a/airflow/executors/workloads.py b/airflow/executors/workloads.py
index 3be2ad22d23..00ca97405a0 100644
--- a/airflow/executors/workloads.py
+++ b/airflow/executors/workloads.py
@@ -103,9 +103,14 @@ class ExecuteTask(BaseWorkload):
name=ti.dag_model.bundle_name,
version=ti.dag_run.bundle_version,
)
- path = dag_rel_path or Path(ti.dag_run.dag_model.relative_fileloc)
fname = log_filename_template_renderer()(ti=ti)
- return cls(ti=ser_ti, dag_rel_path=path, token="", log_path=fname, bundle_info=bundle_info)
+ return cls(
+ ti=ser_ti,
+ dag_rel_path=dag_rel_path or Path(ti.dag_model.relative_fileloc),
+ token="",
+ log_path=fname,
+ bundle_info=bundle_info,
+ )
class RunTrigger(BaseModel):
diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index 1ad779fdca7..0afd8904d30 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -41,6 +41,7 @@ from sqlalchemy.sql import expression
from airflow import settings
from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest
from airflow.configuration import conf
+from airflow.dag_processing.bundles.base import BundleUsageTrackingManager
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.executors import workloads
from airflow.executors.base_executor import BaseExecutor
@@ -1028,6 +1029,18 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
self._cleanup_stale_dags,
)
+ if any(x.is_local for x in self.job.executors):
+ bundle_cleanup_mgr = BundleUsageTrackingManager()
+ check_interval = conf.getint(
+ section="dag_processor",
+ key="stale_bundle_cleanup_interval",
+ )
+ if check_interval > 0:
+ timers.call_regular_interval(
+ delay=check_interval,
+ action=bundle_cleanup_mgr.remove_stale_bundle_versions,
+ )
+
for loop_count in itertools.count(start=1):
with (
Trace.start_span(span_name="scheduler_job_loop", component="SchedulerJobRunner") as span,
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 1d29bdd8c7a..227157f3864 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2856,7 +2856,9 @@ class TaskInstance(Base, LoggingMixin):
self.log.error("Received SIGTERM. Terminating subprocesses.")
self.log.error("Stacktrace: \n%s",
"".join(traceback.format_stack()))
self.task.on_kill()
- raise AirflowTaskTerminated("Task received SIGTERM signal")
+ raise AirflowTaskTerminated(
+ f"Task received SIGTERM signal {self.task_id=} {self.dag_id=}
{self.run_id=} {self.map_index=}"
+ )
signal.signal(signal.SIGTERM, signal_handler)
diff --git a/providers/celery/src/airflow/providers/celery/cli/celery_command.py b/providers/celery/src/airflow/providers/celery/cli/celery_command.py
index b935523c89d..464cd830184 100644
--- a/providers/celery/src/airflow/providers/celery/cli/celery_command.py
+++ b/providers/celery/src/airflow/providers/celery/cli/celery_command.py
@@ -21,7 +21,8 @@ from __future__ import annotations
import logging
import sys
-from contextlib import contextmanager
+import time
+from contextlib import contextmanager, suppress
from multiprocessing import Process
import psutil
@@ -33,6 +34,7 @@ from lockfile.pidlockfile import read_pid_from_pidfile, remove_existing_pidfile
from airflow import settings
from airflow.configuration import conf
+from airflow.exceptions import AirflowConfigException
from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils import cli as cli_utils
from airflow.utils.cli import setup_locations
@@ -40,6 +42,8 @@ from airflow.utils.serve_logs import serve_logs
WORKER_PROCESS_NAME = "worker"
+log = logging.getLogger(__name__)
+
def _run_command_with_daemon_option(*args, **kwargs):
try:
@@ -117,6 +121,41 @@ def _serve_logs(skip_serve_logs: bool = False):
sub_proc.terminate()
+@contextmanager
+def _run_stale_bundle_cleanup():
+ """Start stale bundle cleanup sub-process."""
+ check_interval = None
+ with suppress(AirflowConfigException): # remove when min airflow version >= 3.0
+ check_interval = conf.getint(
+ section="dag_processor",
+ key="stale_bundle_cleanup_interval",
+ )
+ if not check_interval or check_interval <= 0 or not AIRFLOW_V_3_0_PLUS:
+ # do not start bundle cleanup process
+ try:
+ yield
+ finally:
+ return
+ from airflow.dag_processing.bundles.base import BundleUsageTrackingManager
+
+ log.info("starting stale bundle cleanup process")
+ sub_proc = None
+
+ def bundle_cleanup_main():
+ mgr = BundleUsageTrackingManager()
+ while True:
+ time.sleep(check_interval)
+ mgr.remove_stale_bundle_versions()
+
+ try:
+ sub_proc = Process(target=bundle_cleanup_main)
+ sub_proc.start()
+ yield
+ finally:
+ if sub_proc:
+ sub_proc.terminate()
+
+
@after_setup_logger.connect()
@_providers_configuration_loaded
def logger_setup_handler(logger, **kwargs):
@@ -231,7 +270,7 @@ def worker(args):
)
def run_celery_worker():
- with _serve_logs(skip_serve_logs):
+ with _serve_logs(skip_serve_logs), _run_stale_bundle_cleanup():
celery_app.worker_main(options)
if args.umask:
diff --git a/providers/celery/tests/unit/celery/cli/test_celery_command.py b/providers/celery/tests/unit/celery/cli/test_celery_command.py
index bed88668386..93b8606aa54 100644
--- a/providers/celery/tests/unit/celery/cli/test_celery_command.py
+++ b/providers/celery/tests/unit/celery/cli/test_celery_command.py
@@ -21,6 +21,7 @@ import importlib
import os
from argparse import Namespace
from unittest import mock
+from unittest.mock import patch
import pytest
import sqlalchemy
@@ -30,6 +31,7 @@ from airflow.cli import cli_parser
from airflow.configuration import conf
from airflow.executors import executor_loader
from airflow.providers.celery.cli import celery_command
+from airflow.providers.celery.cli.celery_command import _run_stale_bundle_cleanup
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
@@ -37,6 +39,7 @@ from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_
pytestmark = pytest.mark.db_test
+@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
class TestWorkerPrecheck:
@mock.patch("airflow.settings.validate_session")
def test_error(self, mock_validate_session):
@@ -68,6 +71,7 @@ class TestWorkerPrecheck:
@pytest.mark.backend("mysql", "postgres")
+@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
class TestCeleryStopCommand:
@classmethod
def setup_class(cls):
@@ -149,6 +153,7 @@ class TestCeleryStopCommand:
@pytest.mark.backend("mysql", "postgres")
+@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
class TestWorkerStart:
@classmethod
def setup_class(cls):
@@ -209,6 +214,7 @@ class TestWorkerStart:
@pytest.mark.backend("mysql", "postgres")
+@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
class TestWorkerFailure:
@classmethod
def setup_class(cls):
@@ -228,6 +234,7 @@ class TestWorkerFailure:
@pytest.mark.backend("mysql", "postgres")
+@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
class TestFlowerCommand:
@classmethod
def setup_class(cls):
@@ -379,3 +386,15 @@ class TestFlowerCommand:
self, mock_celery_app, mock_daemon, mock_setup_locations, mock_pid_file
):
self._test_run_command_daemon(mock_celery_app, mock_daemon, mock_setup_locations, mock_pid_file)
+
+
+@patch("airflow.providers.celery.cli.celery_command.Process")
+@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Doesn't apply to pre-3.0")
+def test_stale_bundle_cleanup(mock_process):
+ mock_process.__bool__.return_value = True
+ with _run_stale_bundle_cleanup():
+ ...
+ calls = mock_process.call_args_list
+ assert len(calls) == 1
+ actual = [x.kwargs["target"] for x in calls]
+ assert actual[0].__name__ == "bundle_cleanup_main"
diff --git a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
index 22d22bd8848..f15de603053 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
@@ -1707,7 +1707,7 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
kafka_config = KafkaConfig(
topic="my_topic",
config={
- "bootstrap.servers": "localhost:9092,another.host:9092",
+ "bootstrap.servers":
"test-kafka-xfgup:10009,another.host-ge7h0:100010",
"acks": "all",
"retries": "3",
},
diff --git a/providers/openlineage/tests/unit/openlineage/utils/test_spark.py b/providers/openlineage/tests/unit/openlineage/utils/test_spark.py
index 49ab0d13528..c4073da4699 100644
--- a/providers/openlineage/tests/unit/openlineage/utils/test_spark.py
+++ b/providers/openlineage/tests/unit/openlineage/utils/test_spark.py
@@ -92,7 +92,7 @@ def test_get_transport_information_as_spark_properties_unsupported_transport_typ
kafka_config = KafkaConfig(
topic="my_topic",
config={
- "bootstrap.servers": "localhost:9092,another.host:9092",
+ "bootstrap.servers":
"test-kafka-hm0fo:10011,another.host-uuj0l:10012",
"acks": "all",
"retries": "3",
},
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index 0e5ce160dc5..b561ee1e395 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -33,6 +33,7 @@ import lazy_object_proxy
import structlog
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter
+from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock
from airflow.dag_processing.bundles.manager import DagBundlesManager
from airflow.listeners.listener import get_listener_manager
from airflow.sdk.api.datamodels._generated import (
@@ -93,6 +94,7 @@ class RuntimeTaskInstance(TaskInstance):
model_config = ConfigDict(arbitrary_types_allowed=True)
task: BaseOperator
+ bundle_instance: BaseDagBundle
_ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None
"""The Task Instance context from the API server, if any."""
@@ -414,6 +416,7 @@ def parse(what: StartupDetails) -> RuntimeTaskInstance:
return RuntimeTaskInstance.model_construct(
**what.ti.model_dump(exclude_unset=True),
task=task,
+ bundle_instance=bundle_instance,
_ti_context_from_server=what.ti_context,
max_tries=what.ti_context.max_tries,
start_date=what.start_date,
@@ -802,10 +805,15 @@ def main():
# TODO: add an exception here, it causes an oof of a stack trace!
global SUPERVISOR_COMMS
SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](input=sys.stdin)
+
try:
ti, log = startup()
- state, msg, error = run(ti, log)
- finalize(ti, state, log, error)
+ with BundleVersionLock(
+ bundle_name=ti.bundle_instance.name,
+ bundle_version=ti.bundle_instance.version,
+ ):
+ state, msg, error = run(ti, log)
+ finalize(ti, state, log, error)
except KeyboardInterrupt:
log = structlog.get_logger(logger_name="task")
log.exception("Ctrl-c hit")
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index f09fe826161..49745c73fa7 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -26,6 +26,7 @@ import signal
import sys
from io import BytesIO
from operator import attrgetter
+from pathlib import Path
from time import sleep
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
@@ -33,6 +34,7 @@ from unittest.mock import MagicMock, patch
import httpx
import psutil
import pytest
+import tenacity
from pytest_unordered import unordered
from uuid6 import uuid7
@@ -74,6 +76,7 @@ from task_sdk.tests.execution_time.test_task_runner import FAKE_BUNDLE
if TYPE_CHECKING:
import kgb
+log = logging.getLogger(__name__)
TI_ID = uuid7()
@@ -666,6 +669,11 @@ class TestListenerOvertime:
),
],
)
+ @tenacity.retry(
+ stop=tenacity.stop_after_attempt(5),
+ retry=tenacity.retry_if_exception_type(AssertionError),
+ before=tenacity.before_log(log, logging.INFO),
+ )
def test_overtime_slow_listener_instance(
self,
dag_id,
@@ -681,9 +689,7 @@ class TestListenerOvertime:
monkeypatch.setattr(ActivitySubprocess, "TASK_OVERTIME_THRESHOLD", overtime_threshold)
"""Test running a simple DAG in a subprocess and capturing the output."""
-
get_listener_manager().add_listener(listener)
- dagfile_path = test_dags_dir
ti = TaskInstance(
id=uuid7(),
task_id=task_id,
@@ -695,7 +701,7 @@ class TestListenerOvertime:
with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir,
bundle_info.name)):
exit_code = supervise(
ti=ti,
- dag_rel_path=dagfile_path,
+ dag_rel_path=Path("super_basic_run.py"),
token="",
server="",
dry_run=True,
diff --git a/tests/dag_processing/bundles/test_base.py b/tests/dag_processing/bundles/test_base.py
index 6eaf07cccbd..167c0e18f91 100644
--- a/tests/dag_processing/bundles/test_base.py
+++ b/tests/dag_processing/bundles/test_base.py
@@ -18,18 +18,31 @@
from __future__ import annotations
import fcntl
+import logging
import tempfile
+import threading
+import time
+from datetime import timedelta
from pathlib import Path
+from unittest.mock import patch
import pytest
+import time_machine
-from airflow.dag_processing.bundles.base import BaseDagBundle
-from airflow.dag_processing.bundles.local import LocalDagBundle
+from airflow.dag_processing.bundles.base import (
+ BaseDagBundle,
+ BundleUsageTrackingManager,
+ BundleVersionLock,
+ get_bundle_storage_root_path,
+)
+from airflow.utils import timezone as tz
from tests_common.test_utils.config import conf_vars
pytestmark = pytest.mark.db_test
+log = logging.getLogger(__name__)
+
@pytest.fixture(autouse=True)
def bundle_temp_dir(tmp_path):
@@ -37,10 +50,16 @@ def bundle_temp_dir(tmp_path):
yield tmp_path
-def test_default_dag_storage_path():
- with conf_vars({("dag_processor", "dag_bundle_storage_path"): ""}):
- bundle = LocalDagBundle(name="test", path="/hello")
- assert bundle._dag_bundle_root_storage_path == Path(tempfile.gettempdir(), "airflow", "dag_bundles")
+@pytest.mark.parametrize(
+ "val, expected",
+ [
+ ("/blah", Path("/blah")),
+ ("", Path(tempfile.gettempdir(), "airflow", "dag_bundles")),
+ ],
+)
+def test_default_dag_storage_path(val, expected):
+ with conf_vars({("dag_processor", "dag_bundle_storage_path"): val}):
+ assert get_bundle_storage_root_path() == expected
class BasicBundle(BaseDagBundle):
@@ -56,14 +75,13 @@ class BasicBundle(BaseDagBundle):
def test_dag_bundle_root_storage_path():
with conf_vars({("dag_processor", "dag_bundle_storage_path"): None}):
- bundle = BasicBundle(name="test")
- assert bundle._dag_bundle_root_storage_path == Path(tempfile.gettempdir(), "airflow", "dag_bundles")
+ assert get_bundle_storage_root_path() == Path(tempfile.gettempdir(), "airflow", "dag_bundles")
def test_lock_acquisition():
"""Test that the lock context manager sets _locked and locks a lock
file."""
bundle = BasicBundle(name="locktest")
- lock_dir = bundle._dag_bundle_root_storage_path / "_locks"
+ lock_dir = get_bundle_storage_root_path() / "_locks"
lock_file = lock_dir / f"{bundle.name}.lock"
assert not bundle._locked
@@ -97,7 +115,7 @@ def test_lock_acquisition():
def test_lock_exception_handling():
"""Test that exceptions within the lock context manager still release the
lock."""
bundle = BasicBundle(name="locktest")
- lock_dir = bundle._dag_bundle_root_storage_path / "_locks"
+ lock_dir = get_bundle_storage_root_path() / "_locks"
lock_file = lock_dir / f"{bundle.name}.lock"
try:
@@ -117,3 +135,137 @@ def test_lock_exception_handling():
except OSError:
acquired = False
assert acquired
+
+
+class LockTestHelper:
+ def __init__(self, num, **kwargs):
+ super().__init__(**kwargs)
+ self.num = num
+ self.stop = None
+ self.did_lock = None
+ self.locker: BundleVersionLock
+
+ def lock_the_file(self):
+ self.locker = BundleVersionLock(
+ bundle_name="abc",
+ bundle_version="this",
+ )
+ with self.locker:
+ self.did_lock = True
+ idx = 0
+ while not self.stop:
+ idx += 1
+ time.sleep(0.2)
+ log.info("sleeping: idx=%s num=%s", idx, self.num)
+ log.info("exit")
+
+
+class TestBundleVersionLock:
+ def test_that_shared_lock_doesnt_block_shared_lock(self):
+ """Verify that two things can lock file at same time."""
+
+ lth1 = LockTestHelper(1)
+ t1 = threading.Thread(target=lth1.lock_the_file)
+ lth2 = LockTestHelper(2)
+ t2 = threading.Thread(target=lth2.lock_the_file)
+ t1.start()
+ time.sleep(0.1)
+ assert lth1.did_lock is True
+ t2.start()
+ time.sleep(0.1)
+ assert lth2.did_lock is True
+ lth1.stop = True
+ lth2.stop = True
+ t1.join()
+ t2.join()
+
+ def test_that_shared_lock_blocks_ex_lock(self):
+ """Test that exclusive lock is impossible when in bundle lock
context."""
+ lth1 = LockTestHelper(1)
+ t1 = threading.Thread(target=lth1.lock_the_file)
+ t1.start()
+ time.sleep(0.1)
+ assert lth1.did_lock is True
+ with open(lth1.locker.lock_file_path, "a") as f:
+ fcntl.flock(f, fcntl.LOCK_SH)
+ fcntl.flock(f, fcntl.LOCK_UN)
+ with pytest.raises(BlockingIOError): # <-- this is the important part
+ fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB)
+ lth1.stop = True
+ t1.join()
+
+ def test_that_no_version_is_noop(self):
+ with BundleVersionLock(
+ bundle_name="Yer face",
+ bundle_version=None,
+ ) as b:
+ log.info("this is fine")
+ assert b.lock_file_path is None
+ assert b.lock_file is None
+
+
+class FakeBundle(BaseDagBundle):
+ @property
+ def path(self) -> Path:
+ assert self.version
+ return self.versions_dir / self.version
+
+ def get_current_version(self) -> str | None: ...
+ def refresh(self) -> None: ...
+
+
+class TestBundleUsageTrackingManager:
+ @pytest.mark.parametrize(
+ "threshold_hours, min_versions, when_hours, expected_remaining",
+ [
+ (3, 0, 3, 5),
+ (3, 0, 6, 2),
+ (10, 0, 3, 5),
+ (10, 0, 6, 5),
+ (0, 0, 3, 2), # two of them are in future
+ (0, 0, 6, 0), # all of them are in past
+ (0, 5, 3, 5), # keep all no matter what
+ (0, 5, 6, 5), # keep all no matter what
+ (0, 4, 3, 4), # keep 4 no matter what
+ (0, 4, 6, 4), # keep 4 no matter what
+ ],
+ )
+ @patch("airflow.dag_processing.bundles.base.get_bundle_tracking_dir")
+ def test_that_stale_bundles_are_removed(
+ self, mock_get_dir, threshold_hours, min_versions, when_hours, expected_remaining
+ ):
+ age_threshold = threshold_hours * 60 * 60
+ with (
+ conf_vars(
+ {
+ ("dag_processor", "stale_bundle_cleanup_age_threshold"):
str(age_threshold),
+ ("dag_processor", "stale_bundle_cleanup_min_versions"):
str(min_versions),
+ }
+ ),
+ tempfile.TemporaryDirectory() as td,
+ ):
+ bundle_tracking_dir = Path(td)
+ mock_get_dir.return_value = bundle_tracking_dir
+ h0 = tz.datetime(2025, 1, 1, 0)
+ bundle_name = "abc"
+ for num in range(5):
+ with time_machine.travel(h0 + timedelta(hours=num), tick=False):
+ version = f"hour-{num}"
+ b = FakeBundle(version=version, name=bundle_name)
+ b.path.mkdir(exist_ok=True, parents=True)
+ with BundleVersionLock(
+ bundle_name=bundle_name,
+ bundle_version=version,
+ ):
+ print(version)
+ lock_files = list(bundle_tracking_dir.iterdir())
+ assert len(lock_files) == 5
+ bundle_folders = list(b.versions_dir.iterdir())
+ assert len(bundle_folders) == 5
+ num += 1
+ with time_machine.travel(h0 + timedelta(hours=when_hours), tick=False):
+ BundleUsageTrackingManager()._remove_stale_bundle_versions_for_bundle(bundle_name=bundle_name)
+ lock_files = list(bundle_tracking_dir.iterdir())
+ assert len(lock_files) == expected_remaining
+ bundle_folders = list(b.versions_dir.iterdir())
+ assert len(bundle_folders) == expected_remaining
diff --git a/tests/dag_processing/bundles/test_git.py b/tests/dag_processing/bundles/test_git.py
index 42a96a24f95..98e7eccf0a4 100644
--- a/tests/dag_processing/bundles/test_git.py
+++ b/tests/dag_processing/bundles/test_git.py
@@ -25,6 +25,7 @@ import pytest
from git import Repo
from git.exc import GitCommandError, NoSuchPathError
+from airflow.dag_processing.bundles.base import get_bundle_storage_root_path
from airflow.dag_processing.bundles.git import GitDagBundle, GitHook
from airflow.exceptions import AirflowException
from airflow.models import Connection
@@ -230,10 +231,10 @@ class TestGitDagBundle:
def test_supports_versioning(self):
assert GitDagBundle.supports_versioning is True
- def test_uses_dag_bundle_root_storage_path(self, git_repo):
- repo_path, repo = git_repo
+ def test_uses_dag_bundle_root_storage_path(self):
bundle = GitDagBundle(name="test", tracking_ref=GIT_DEFAULT_BRANCH)
- assert str(bundle._dag_bundle_root_storage_path) in str(bundle.path)
+ base = get_bundle_storage_root_path()
+ assert bundle.path.is_relative_to(base)
def test_repo_url_overrides_connection_host_when_provided(self):
bundle = GitDagBundle(name="test", tracking_ref=GIT_DEFAULT_BRANCH, repo_url="/some/other/repo")
diff --git a/tests/integration/cli/commands/test_celery_command.py b/tests/integration/cli/commands/test_celery_command.py
index db798ba8b71..f9436e72ded 100644
--- a/tests/integration/cli/commands/test_celery_command.py
+++ b/tests/integration/cli/commands/test_celery_command.py
@@ -55,14 +55,25 @@ class TestWorkerServeLogs:
mock_process.assert_called()
@conf_vars({("core", "executor"): "CeleryExecutor"})
- def test_skip_serve_logs_on_worker_start(self):
+ @pytest.mark.parametrize(
+ "skip, expected",
+ [
+ (True, ["bundle_cleanup_main"]),
+ (False, ["serve_logs", "bundle_cleanup_main"]),
+ ],
+ )
+ def test_skip_serve_logs_on_worker_start(self, skip, expected):
with (
mock.patch("airflow.providers.celery.cli.celery_command.Process")
as mock_popen,
mock.patch("airflow.providers.celery.executors.celery_executor.app"),
):
- args = self.parser.parse_args(["celery", "worker", "--concurrency", "1", "--skip-serve-logs"])
+ args = ["celery", "worker", "--concurrency", "1"]
+ if skip:
+ args.append("--skip-serve-logs")
+ args = self.parser.parse_args(args)
with mock.patch("celery.platforms.check_privileges") as
mock_privil:
mock_privil.return_value = 0
celery_command.worker(args)
- mock_popen.assert_not_called()
+ targets = [x.kwargs["target"].__name__ for x in mock_popen.call_args_list]
+ assert targets == expected