This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new eb60174dbda Add new arguments to db_clean to explicitly include or exclude DAGs (#56663)
eb60174dbda is described below

commit eb60174dbda2767f78eae346f9d546d00e75c0a5
Author: Matt Usifer <[email protected]>
AuthorDate: Sat Jan 3 20:37:26 2026 -0500

    Add new arguments to db_clean to explicitly include or exclude DAGs (#56663)
    
    * Specify DAG IDs to include or exclude from db_clean
    
    * Fix _confirm_delete
    
    * Add documentation
    
    * Fix tables without a dag_id_column, and keep last for dag_version
    
    * Formatting, add more tests
    
    * Fix deprecated sqlalchemy usage
---
 airflow-core/docs/howto/usage-cli.rst              |  2 +
 airflow-core/src/airflow/cli/cli_config.py         | 14 ++++
 .../src/airflow/cli/commands/db_command.py         |  2 +
 airflow-core/src/airflow/utils/db_cleanup.py       | 88 ++++++++++++++++++----
 .../tests/unit/cli/commands/test_db_command.py     | 78 +++++++++++++++++++
 airflow-core/tests/unit/utils/test_db_cleanup.py   | 66 ++++++++++++++++
 6 files changed, 235 insertions(+), 15 deletions(-)

diff --git a/airflow-core/docs/howto/usage-cli.rst b/airflow-core/docs/howto/usage-cli.rst
index e954947d502..b80faf89b3e 100644
--- a/airflow-core/docs/howto/usage-cli.rst
+++ b/airflow-core/docs/howto/usage-cli.rst
@@ -213,6 +213,8 @@ The ``db clean`` command works by deleting from each table the records older tha
 
 You can optionally provide a list of tables to perform deletes on. If no list of tables is supplied, all tables will be included.
 
+You can filter cleanup to specific DAGs using ``--dag-ids`` (comma-separated list), or exclude specific DAGs using ``--exclude-dag-ids`` (comma-separated list). These options allow you to target or avoid cleanup for particular DAGs.
+
 You can use the ``--dry-run`` option to print the row counts in the primary tables to be cleaned.
 
 By default, ``db clean`` will archive purged rows in tables of the form ``_airflow_deleted__<table>__<timestamp>``.  If you don't want the data preserved in this way, you may supply argument ``--skip-archive``.
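
For example, with this change applied the new flags combine with the existing ones like so (an illustrative sketch; the dag ids and timestamp are hypothetical):

    airflow db clean --clean-before-timestamp '2025-01-01' --dag-ids dag1,dag2 --dry-run
    airflow db clean --clean-before-timestamp '2025-01-01' --exclude-dag-ids dag3 --skip-archive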
diff --git a/airflow-core/src/airflow/cli/cli_config.py b/airflow-core/src/airflow/cli/cli_config.py
index 7d8e8611b96..0a89a7b1668 100644
--- a/airflow-core/src/airflow/cli/cli_config.py
+++ b/airflow-core/src/airflow/cli/cli_config.py
@@ -520,6 +520,18 @@ ARG_DB_BATCH_SIZE = Arg(
         "Lower values reduce long-running locks but increase the number of 
batches."
     ),
 )
+ARG_DAG_IDS = Arg(
+    ("--dag-ids",),
+    default=None,
+    help="Only cleanup data related to the given dag_id",
+    type=string_list_type,
+)
+ARG_EXCLUDE_DAG_IDS = Arg(
+    ("--exclude-dag-ids",),
+    default=None,
+    help="Avoid cleaning up data related to the given dag_ids",
+    type=string_list_type,
+)
 
 # pool
 ARG_POOL_NAME = Arg(("pool",), metavar="NAME", help="Pool name")
@@ -1527,6 +1539,8 @@ DB_COMMANDS = (
             ARG_YES,
             ARG_DB_SKIP_ARCHIVE,
             ARG_DB_BATCH_SIZE,
+            ARG_DAG_IDS,
+            ARG_EXCLUDE_DAG_IDS,
         ),
     ),
     ActionCommand(
diff --git a/airflow-core/src/airflow/cli/commands/db_command.py b/airflow-core/src/airflow/cli/commands/db_command.py
index f9ec7365f57..638dc1c74d6 100644
--- a/airflow-core/src/airflow/cli/commands/db_command.py
+++ b/airflow-core/src/airflow/cli/commands/db_command.py
@@ -301,6 +301,8 @@ def cleanup_tables(args):
         confirm=not args.yes,
         skip_archive=args.skip_archive,
         batch_size=args.batch_size,
+        dag_ids=args.dag_ids,
+        exclude_dag_ids=args.exclude_dag_ids,
     )
 
 
diff --git a/airflow-core/src/airflow/utils/db_cleanup.py b/airflow-core/src/airflow/utils/db_cleanup.py
index 558146fc84d..8b6ce79a772 100644
--- a/airflow-core/src/airflow/utils/db_cleanup.py
+++ b/airflow-core/src/airflow/utils/db_cleanup.py
@@ -79,6 +79,7 @@ class _TableConfig:
     table_name: str
     recency_column_name: str
     extra_columns: list[str] | None = None
+    dag_id_column_name: str | None = None
     keep_last: bool = False
     keep_last_filters: Any | None = None
     keep_last_group_by: Any | None = None
@@ -89,9 +90,19 @@ class _TableConfig:
 
     def __post_init__(self):
         self.recency_column = column(self.recency_column_name)
-        self.orm_model: Base = table(
-            self.table_name, *[column(x) for x in self.extra_columns or []], self.recency_column
-        )
+        if self.dag_id_column_name is None:
+            self.dag_id_column = None
+            self.orm_model: Base = table(
+                self.table_name, *[column(x) for x in self.extra_columns or []], self.recency_column
+            )
+        else:
+            self.dag_id_column = column(self.dag_id_column_name)
+            self.orm_model: Base = table(
+                self.table_name,
+                *[column(x) for x in self.extra_columns or []],
+                self.dag_id_column,
+                self.recency_column,
+            )
 
     def __lt__(self, other):
         return self.table_name < other.table_name
@@ -101,6 +112,7 @@ class _TableConfig:
         return {
             "table": self.orm_model.name,
             "recency_column": str(self.recency_column),
+            "dag_id_column": str(self.dag_id_column),
             "keep_last": self.keep_last,
             "keep_last_filters": [str(x) for x in self.keep_last_filters] if 
self.keep_last_filters else None,
             "keep_last_group_by": str(self.keep_last_group_by),
@@ -108,34 +120,39 @@ class _TableConfig:
 
 
 config_list: list[_TableConfig] = [
-    _TableConfig(table_name="job", recency_column_name="latest_heartbeat"),
+    _TableConfig(table_name="job", recency_column_name="latest_heartbeat", 
dag_id_column_name="dag_id"),
     _TableConfig(
         table_name="dag",
         recency_column_name="last_parsed_time",
         dependent_tables=["dag_version", "deadline"],
+        dag_id_column_name="dag_id",
     ),
     _TableConfig(
         table_name="dag_run",
         recency_column_name="start_date",
+        dag_id_column_name="dag_id",
         extra_columns=["dag_id", "run_type"],
         keep_last=True,
         keep_last_filters=[column("run_type") != DagRunType.MANUAL],
         keep_last_group_by=["dag_id"],
         dependent_tables=["task_instance", "deadline"],
     ),
-    _TableConfig(table_name="asset_event", recency_column_name="timestamp"),
+    _TableConfig(table_name="asset_event", recency_column_name="timestamp", 
dag_id_column_name="dag_id"),
     _TableConfig(table_name="import_error", recency_column_name="timestamp"),
-    _TableConfig(table_name="log", recency_column_name="dttm"),
-    _TableConfig(table_name="sla_miss", recency_column_name="timestamp"),
+    _TableConfig(table_name="log", recency_column_name="dttm", 
dag_id_column_name="dag_id"),
+    _TableConfig(table_name="sla_miss", recency_column_name="timestamp", 
dag_id_column_name="dag_id"),
     _TableConfig(
         table_name="task_instance",
         recency_column_name="start_date",
         dependent_tables=["task_instance_history", "xcom"],
+        dag_id_column_name="dag_id",
+    ),
+    _TableConfig(
+        table_name="task_instance_history", recency_column_name="start_date", 
dag_id_column_name="dag_id"
     ),
-    _TableConfig(table_name="task_instance_history", 
recency_column_name="start_date"),
-    _TableConfig(table_name="task_reschedule", 
recency_column_name="start_date"),
-    _TableConfig(table_name="xcom", recency_column_name="timestamp"),
-    _TableConfig(table_name="_xcom_archive", recency_column_name="timestamp"),
+    _TableConfig(table_name="task_reschedule", 
recency_column_name="start_date", dag_id_column_name="dag_id"),
+    _TableConfig(table_name="xcom", recency_column_name="timestamp", 
dag_id_column_name="dag_id"),
+    _TableConfig(table_name="_xcom_archive", recency_column_name="timestamp", 
dag_id_column_name="dag_id"),
     _TableConfig(table_name="callback_request", 
recency_column_name="created_at"),
     _TableConfig(table_name="celery_taskmeta", 
recency_column_name="date_done"),
     _TableConfig(table_name="celery_tasksetmeta", 
recency_column_name="date_done"),
@@ -148,8 +165,11 @@ config_list: list[_TableConfig] = [
         table_name="dag_version",
         recency_column_name="created_at",
         dependent_tables=["task_instance", "dag_run"],
+        dag_id_column_name="dag_id",
+        keep_last=True,
+        keep_last_group_by=["dag_id"],
     ),
-    _TableConfig(table_name="deadline", recency_column_name="deadline_time"),
+    _TableConfig(table_name="deadline", recency_column_name="deadline_time", 
dag_id_column_name="dag_id"),
 ]
 
 # We need to have `fallback="database"` because this is executed at top level code and provider configuration
@@ -308,6 +328,9 @@ def _build_query(
     keep_last_group_by,
     clean_before_timestamp: DateTime,
     session: Session,
+    dag_id_column=None,
+    dag_ids: list[str] | None = None,
+    exclude_dag_ids: list[str] | None = None,
     **kwargs,
 ) -> Query:
     base_table_alias = "base"
@@ -315,6 +338,15 @@ def _build_query(
     query = session.query(base_table).with_entities(text(f"{base_table_alias}.*"))
     base_table_recency_col = base_table.c[recency_column.name]
     conditions = [base_table_recency_col < clean_before_timestamp]
+
+    if (dag_ids or exclude_dag_ids) and dag_id_column is not None:
+        base_table_dag_id_col = base_table.c[dag_id_column.name]
+
+        if dag_ids:
+            conditions.append(base_table_dag_id_col.in_(dag_ids))
+        if exclude_dag_ids:
+            conditions.append(base_table_dag_id_col.not_in(exclude_dag_ids))
+
     if keep_last:
         max_date_col_name = "max_date_per_group"
         group_by_columns: list[Any] = [column(x) for x in keep_last_group_by]
@@ -345,6 +377,9 @@ def _cleanup_table(
     keep_last_filters,
     keep_last_group_by,
     clean_before_timestamp: DateTime,
+    dag_id_column=None,
+    dag_ids=None,
+    exclude_dag_ids=None,
     dry_run: bool = True,
     verbose: bool = False,
     skip_archive: bool = False,
@@ -358,6 +393,9 @@ def _cleanup_table(
     query = _build_query(
         orm_model=orm_model,
         recency_column=recency_column,
+        dag_id_column=dag_id_column,
+        dag_ids=dag_ids,
+        exclude_dag_ids=exclude_dag_ids,
         keep_last=keep_last,
         keep_last_filters=keep_last_filters,
         keep_last_group_by=keep_last_group_by,
@@ -380,10 +418,18 @@ def _cleanup_table(
     session.commit()
 
 
-def _confirm_delete(*, date: DateTime, tables: list[str]) -> None:
+def _confirm_delete(
+    *,
+    date: DateTime,
+    tables: list[str],
+    dag_ids: list[str] | None = None,
+    exclude_dag_ids: list[str] | None = None,
+) -> None:
     for_tables = f" for tables {tables!r}" if tables else ""
+    for_dags = f" for the following dags: {dag_ids!r}" if dag_ids else ""
+    excluding_dags = f" excluding the following dags: {exclude_dag_ids!r}" if exclude_dag_ids else ""
     question = (
-        f"You have requested that we purge all data prior to {date}{for_tables}.\n"
+        f"You have requested that we purge all data prior to {date}{for_tables}{for_dags}{excluding_dags}.\n"
         f"This is irreversible.  Consider backing up the tables first and / or doing a dry run "
         f"with option --dry-run.\n"
         f"Enter 'delete rows' (without quotes) to proceed."
@@ -493,6 +539,8 @@ def run_cleanup(
     *,
     clean_before_timestamp: DateTime,
     table_names: list[str] | None = None,
+    dag_ids: list[str] | None = None,
+    exclude_dag_ids: list[str] | None = None,
     dry_run: bool = False,
     verbose: bool = False,
     confirm: bool = True,
@@ -513,6 +561,9 @@ def run_cleanup(
     :param clean_before_timestamp: The timestamp before which data should be purged
     :param table_names: Optional. List of table names to perform maintenance on.  If list not provided,
         will perform maintenance on all tables.
+    :param dag_ids: Optional. List of dag ids to perform maintenance on.  If list not provided,
+        will perform maintenance on all dags.
+    :param exclude_dag_ids: Optional. List of dag ids to exclude from maintenance.
     :param dry_run: If true, print rows meeting deletion criteria
     :param verbose: If true, may provide more detailed output.
     :param confirm: Require user input to confirm before processing deletions.
@@ -532,7 +583,12 @@ def run_cleanup(
         )
         _print_config(configs=effective_config_dict)
     if not dry_run and confirm:
-        _confirm_delete(date=clean_before_timestamp, tables=sorted(effective_table_names))
+        _confirm_delete(
+            date=clean_before_timestamp,
+            tables=sorted(effective_table_names),
+            dag_ids=dag_ids,
+            exclude_dag_ids=exclude_dag_ids,
+        )
     existing_tables = reflect_tables(tables=None, session=session).tables
 
     for table_name, table_config in effective_config_dict.items():
@@ -540,6 +596,8 @@ def run_cleanup(
             with _suppress_with_logging(table_name, session):
                 _cleanup_table(
                     clean_before_timestamp=clean_before_timestamp,
+                    dag_ids=dag_ids,
+                    exclude_dag_ids=exclude_dag_ids,
                     dry_run=dry_run,
                     verbose=verbose,
                     **table_config.__dict__,
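
Note: the two filters are ANDed in _build_query, so when both are given a row is purged only if its dag_id is in dag_ids and not in exclude_dag_ids. A minimal Python sketch of that reduction (the dag ids mirror the include_and_exclude test case below; purely illustrative):

    dag_ids, exclude_dag_ids = {"dag1", "dag2"}, {"dag2"}
    purged = {d for d in {"dag1", "dag2", "dag3"} if d in dag_ids and d not in exclude_dag_ids}
    assert purged == {"dag1"}  # dag2 and dag3 survive the cleanup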
diff --git a/airflow-core/tests/unit/cli/commands/test_db_command.py b/airflow-core/tests/unit/cli/commands/test_db_command.py
index fcd18c308b1..013f1bbb1c8 100644
--- a/airflow-core/tests/unit/cli/commands/test_db_command.py
+++ b/airflow-core/tests/unit/cli/commands/test_db_command.py
@@ -701,6 +701,8 @@ class TestCLIDBClean:
             db_command.cleanup_tables(args)
         run_cleanup_mock.assert_called_once_with(
             table_names=None,
+            dag_ids=None,
+            exclude_dag_ids=None,
             dry_run=False,
             clean_before_timestamp=pendulum.parse(timestamp, tz=timezone),
             verbose=False,
@@ -722,6 +724,8 @@ class TestCLIDBClean:
 
         run_cleanup_mock.assert_called_once_with(
             table_names=None,
+            dag_ids=None,
+            exclude_dag_ids=None,
             dry_run=False,
             clean_before_timestamp=pendulum.parse(timestamp),
             verbose=False,
@@ -749,6 +753,8 @@ class TestCLIDBClean:
 
         run_cleanup_mock.assert_called_once_with(
             table_names=None,
+            dag_ids=None,
+            exclude_dag_ids=None,
             dry_run=False,
             clean_before_timestamp=pendulum.parse("2021-01-01 00:00:00Z"),
             verbose=False,
@@ -776,6 +782,8 @@ class TestCLIDBClean:
 
         run_cleanup_mock.assert_called_once_with(
             table_names=None,
+            dag_ids=None,
+            exclude_dag_ids=None,
             dry_run=False,
             clean_before_timestamp=pendulum.parse("2021-01-01 00:00:00Z"),
             verbose=False,
@@ -803,6 +811,8 @@ class TestCLIDBClean:
 
         run_cleanup_mock.assert_called_once_with(
             table_names=None,
+            dag_ids=None,
+            exclude_dag_ids=None,
             dry_run=expected,
             clean_before_timestamp=pendulum.parse("2021-01-01 00:00:00Z"),
             verbose=False,
@@ -832,6 +842,8 @@ class TestCLIDBClean:
 
         run_cleanup_mock.assert_called_once_with(
             table_names=expected,
+            dag_ids=None,
+            exclude_dag_ids=None,
             dry_run=False,
             clean_before_timestamp=pendulum.parse("2021-01-01 00:00:00Z"),
             verbose=False,
@@ -859,6 +871,8 @@ class TestCLIDBClean:
 
         run_cleanup_mock.assert_called_once_with(
             table_names=None,
+            dag_ids=None,
+            exclude_dag_ids=None,
             dry_run=False,
             clean_before_timestamp=pendulum.parse("2021-01-01 00:00:00Z"),
             verbose=expected,
@@ -886,6 +900,8 @@ class TestCLIDBClean:
 
         run_cleanup_mock.assert_called_once_with(
             table_names=None,
+            dag_ids=None,
+            exclude_dag_ids=None,
             dry_run=False,
             clean_before_timestamp=pendulum.parse("2021-01-01 00:00:00Z"),
             verbose=False,
@@ -894,6 +910,68 @@ class TestCLIDBClean:
             batch_size=expected,
         )
 
+    @pytest.mark.parametrize(
+        ("extra_args", "expected"), [(["--dag-ids", "dag1, dag2"], ["dag1", 
"dag2"]), ([], None)]
+    )
+    @patch("airflow.cli.commands.db_command.run_cleanup")
+    def test_dag_ids(self, run_cleanup_mock, extra_args, expected):
+        """
+        When dag_ids are included in the args then dag_ids should be passed in, or None otherwise
+        """
+        args = self.parser.parse_args(
+            [
+                "db",
+                "clean",
+                "--clean-before-timestamp",
+                "2021-01-01",
+                *extra_args,
+            ]
+        )
+        db_command.cleanup_tables(args)
+
+        run_cleanup_mock.assert_called_once_with(
+            table_names=None,
+            dry_run=False,
+            dag_ids=expected,
+            exclude_dag_ids=None,
+            clean_before_timestamp=pendulum.parse("2021-01-01 00:00:00Z"),
+            verbose=False,
+            confirm=True,
+            skip_archive=False,
+            batch_size=None,
+        )
+
+    @pytest.mark.parametrize(
+        ("extra_args", "expected"), [(["--exclude-dag-ids", "dag1, dag2"], 
["dag1", "dag2"]), ([], None)]
+    )
+    @patch("airflow.cli.commands.db_command.run_cleanup")
+    def test_exclude_dag_ids(self, run_cleanup_mock, extra_args, expected):
+        """
+        When exclude_dag_ids are included in the args then exclude_dag_ids should be passed in, or None otherwise
+        """
+        args = self.parser.parse_args(
+            [
+                "db",
+                "clean",
+                "--clean-before-timestamp",
+                "2021-01-01",
+                *extra_args,
+            ]
+        )
+        db_command.cleanup_tables(args)
+
+        run_cleanup_mock.assert_called_once_with(
+            table_names=None,
+            dry_run=False,
+            dag_ids=None,
+            exclude_dag_ids=expected,
+            clean_before_timestamp=pendulum.parse("2021-01-01 00:00:00Z"),
+            verbose=False,
+            confirm=True,
+            skip_archive=False,
+            batch_size=None,
+        )
+
     @patch("airflow.cli.commands.db_command.export_archived_records")
     @patch("airflow.cli.commands.db_command.os.path.isdir", return_value=True)
     def test_export_archived_records(self, os_mock, export_archived_mock):
diff --git a/airflow-core/tests/unit/utils/test_db_cleanup.py b/airflow-core/tests/unit/utils/test_db_cleanup.py
index 867453103c2..7b0d12532cc 100644
--- a/airflow-core/tests/unit/utils/test_db_cleanup.py
+++ b/airflow-core/tests/unit/utils/test_db_cleanup.py
@@ -355,6 +355,72 @@ class TestDBCleanup:
                 f"Expected archive tables not found: {expected_archived - 
actual_archived}"
             )
 
+    @pytest.mark.parametrize(
+        ("dag_ids", "exclude_dag_ids", "expected_remaining_dag_ids"),
+        [
+            pytest.param(["dag1"], None, {"dag2", "dag3"}, 
id="include_single_dag"),
+            pytest.param(["dag1", "dag2"], None, {"dag3"}, 
id="include_multiple_dags"),
+            pytest.param(None, ["dag3"], {"dag3"}, id="exclude_single_dag"),
+            pytest.param(None, ["dag2", "dag3"], {"dag2", "dag3"}, 
id="exclude_multiple_dags"),
+            pytest.param(["dag1", "dag2"], ["dag2"], {"dag2", "dag3"}, 
id="include_and_exclude"),
+            pytest.param(None, None, set(), id="no_filtering_all_deleted"),
+        ],
+    )
+    def test_cleanup_with_dag_id_filtering(self, dag_ids, exclude_dag_ids, expected_remaining_dag_ids):
+        """
+        Verify that dag_ids and exclude_dag_ids parameters correctly include/exclude
+        specific DAGs during cleanup
+        """
+        base_date = pendulum.DateTime(2022, 1, 1, tzinfo=pendulum.timezone("UTC"))
+
+        with create_session() as session:
+            bundle_name = "testing"
+            session.add(DagBundleModel(name=bundle_name))
+            session.flush()
+
+            for dag_id in ["dag1", "dag2", "dag3"]:
+                dag = DAG(dag_id=dag_id)
+                dm = DagModel(dag_id=dag_id, bundle_name=bundle_name)
+                session.add(dm)
+                SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=bundle_name)
+                dag_version = DagVersion.get_latest_version(dag.dag_id)
+
+                start_date = base_date
+                dag_run = DagRun(
+                    dag.dag_id,
+                    run_id=f"{dag_id}_run",
+                    run_type=DagRunType.MANUAL,
+                    start_date=start_date,
+                )
+                ti = TaskInstance(
+                    PythonOperator(task_id="dummy-task", python_callable=print),
+                    run_id=dag_run.run_id,
+                    dag_version_id=dag_version.id,
+                )
+                ti.dag_id = dag.dag_id
+                ti.start_date = start_date
+                session.add(dag_run)
+                session.add(ti)
+            session.commit()
+
+            clean_before_date = base_date.add(days=10)
+            run_cleanup(
+                clean_before_timestamp=clean_before_date,
+                table_names=["task_instance"],
+                dag_ids=dag_ids,
+                exclude_dag_ids=exclude_dag_ids,
+                dry_run=False,
+                confirm=False,
+                session=session,
+            )
+
+            remaining_tis = session.scalars(select(TaskInstance)).all()
+            remaining_dag_ids = {ti.dag_id for ti in remaining_tis}
+
+            assert remaining_dag_ids == expected_remaining_dag_ids, (
+                f"Expected {expected_remaining_dag_ids} to remain, but got 
{remaining_dag_ids}"
+            )
+
     @pytest.mark.parametrize(
         ("skip_archive", "expected_archives"),
         [pytest.param(True, 0, id="skip_archive"), pytest.param(False, 1, id="do_archive")],

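The same filtering is available programmatically through run_cleanup. A minimal sketch, assuming an initialized Airflow metadata database (the table names and dag ids here are hypothetical):

    import pendulum

    from airflow.utils.db_cleanup import run_cleanup

    run_cleanup(
        clean_before_timestamp=pendulum.datetime(2025, 1, 1, tz="UTC"),
        table_names=["task_instance", "dag_run"],  # optional; defaults to all configured tables
        dag_ids=["dag1", "dag2"],      # purge only rows belonging to these dags
        exclude_dag_ids=None,          # or pass a list here to skip specific dags instead
        dry_run=True,                  # report row counts without deleting
        confirm=False,                 # suppress the interactive "delete rows" prompt
    )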