This is an automated email from the ASF dual-hosted git repository.

Lee-W pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new def50a90211 feat(dag_command.py): change to use bulk clear (#68280)
def50a90211 is described below

commit def50a9021171afe9f65d021fdebb24f6309f300
Author: PoAn Yang <[email protected]>
AuthorDate: Fri Jun 12 19:50:32 2026 +0900

    feat(dag_command.py): change to use bulk clear (#68280)
---
 .../src/airflow/cli/commands/dag_command.py        |  53 +++++++--
 .../tests/unit/cli/commands/test_dag_command.py    | 122 ++++++++++++++++++++-
 2 files changed, 165 insertions(+), 10 deletions(-)

diff --git a/airflow-core/src/airflow/cli/commands/dag_command.py 
b/airflow-core/src/airflow/cli/commands/dag_command.py
index 1d5e4237576..580a61a1b4f 100644
--- a/airflow-core/src/airflow/cli/commands/dag_command.py
+++ b/airflow-core/src/airflow/cli/commands/dag_command.py
@@ -46,6 +46,7 @@ from airflow.jobs.job import Job
 from airflow.models import DagModel, DagRun, TaskInstance
 from airflow.models.errors import ParseImportError
 from airflow.models.serialized_dag import SerializedDagModel
+from airflow.models.taskinstance import clear_task_instances
 from airflow.timetables.base import TimeRestriction
 from airflow.utils import cli as cli_utils
 from airflow.utils.cli import (
@@ -55,10 +56,10 @@ from airflow.utils.cli import (
     validate_dag_bundle_arg,
 )
 from airflow.utils.dot_renderer import render_dag, render_dag_dependencies
-from airflow.utils.helpers import ask_yesno
+from airflow.utils.helpers import ask_yesno, chunks
 from airflow.utils.providers_configuration_loader import 
providers_configuration_loaded
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.state import DagRunState
+from airflow.utils.state import DagRunState, TaskInstanceState
 from airflow.utils.types import DagRunType
 
 if TYPE_CHECKING:
@@ -75,6 +76,9 @@ DAG_DETAIL_FIELDS = {*DAGResponse.model_fields, 
*DAGResponse.model_computed_fiel
 
 log = logging.getLogger(__name__)
 
+# Max number of run_ids per batch when bulk-clearing task instances.
+_RUN_CHUNK_SIZE = 500
+
 
 @cli_utils.action_cli
 @deprecated_for_airflowctl("airflowctl dags trigger")
@@ -189,15 +193,46 @@ def dag_clear(args, *, session: Session = NEW_SESSION) -> 
None:
             print("Cancelled, nothing was cleared.")
             return
 
+    cleared = _bulk_clear_runs(
+        args.dag_id,
+        run_ids,
+        only_failed=args.only_failed,
+        only_running=args.only_running,
+        session=session,
+    )
+    print(f"Cleared {cleared} task instance(s) across {len(run_ids)} Dag 
run(s).")
+
+
+def _bulk_clear_runs(
+    dag_id: str,
+    run_ids: list[str],
+    only_failed: bool,
+    only_running: bool,
+    session: Session,
+) -> int:
+    """Clear dag_id's task instances for run_ids, _RUN_CHUNK_SIZE runs per query; 
return the count."""
+    state_filter: list[TaskInstanceState] = []
+    if only_failed:
+        state_filter += [TaskInstanceState.FAILED, 
TaskInstanceState.UPSTREAM_FAILED]
+    if only_running:
+        state_filter += [TaskInstanceState.RUNNING]
+
     cleared = 0
-    for run_id in run_ids:
-        cleared += dag.clear(
-            run_id=run_id,
-            only_failed=args.only_failed,
-            only_running=args.only_running,
-            session=session,
+    for chunk_run_ids in chunks(run_ids, _RUN_CHUNK_SIZE):
+        ti_query = select(TaskInstance).where(
+            TaskInstance.dag_id == dag_id,
+            TaskInstance.run_id.in_(chunk_run_ids),
         )
-    print(f"Cleared {cleared} task instance(s) across {len(run_ids)} Dag 
run(s).")
+        if state_filter:
+            ti_query = ti_query.where(TaskInstance.state.in_(state_filter))
+        tis = session.scalars(ti_query).all()
+        if not tis:
+            continue
+        clear_task_instances(list(tis), session=session)
+        session.flush()
+        cleared += len(tis)
+
+    return cleared
 
 
 @cli_utils.action_cli
diff --git a/airflow-core/tests/unit/cli/commands/test_dag_command.py 
b/airflow-core/tests/unit/cli/commands/test_dag_command.py
index 2dafed8f15e..0c9f2647f0f 100644
--- a/airflow-core/tests/unit/cli/commands/test_dag_command.py
+++ b/airflow-core/tests/unit/cli/commands/test_dag_command.py
@@ -42,7 +42,7 @@ from airflow.exceptions import AirflowException
 from airflow.models import DagModel, DagRun
 from airflow.models.dagbag import DBDagBag
 from airflow.models.serialized_dag import SerializedDagModel
-from airflow.models.taskinstance import TaskInstance
+from airflow.models.taskinstance import TaskInstance, clear_task_instances
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.providers.standard.triggers.temporal import DateTimeTrigger, 
TimeDeltaTrigger
 from airflow.sdk import DAG, Asset, BaseOperator, CronPartitionTimetable, 
PartitionedAssetTimetable, task
@@ -1066,6 +1066,13 @@ class TestCliDagsClear:
                 for row in session.scalars(select(DagRun).where(DagRun.dag_id 
== self.DAG_ID)).all()
             }
 
+    def _get_run_clear_numbers(self):
+        with create_session() as session:
+            return {
+                row.run_id: row.clear_number
+                for row in session.scalars(select(DagRun).where(DagRun.dag_id 
== self.DAG_ID)).all()
+            }
+
     def test_requires_a_selector(self, parser):
         args = parser.parse_args(["dags", "clear", self.DAG_ID, "--yes"])
         with pytest.raises(SystemExit, match="One of --run-id, 
--partition-key"):
@@ -1758,6 +1765,119 @@ class TestCliDagsClear:
         assert states["asset_2026_04_15"] == DagRunState.SUCCESS
         assert states["asset_non_part"] == DagRunState.SUCCESS
 
+    @pytest.mark.usefixtures("seeded_partitioned_runs")
+    @pytest.mark.parametrize(
+        ("chunk_size", "expected_calls"),
+        [
+            pytest.param(500, 1, id="single-chunk"),
+            pytest.param(2, 2, id="multiple-chunks"),
+        ],
+    )
+    def test_clears_each_matching_run_once_across_chunks(self, parser, 
chunk_size, expected_calls):
+        """Every matching run is cleared exactly once, however run_ids split 
into chunks.
+
+        clear_task_instances is called once per chunk (not once per run), 
every matching
+        run is re-queued, and each run's clear_number advances by exactly 1 — 
proving a
+        run's TIs are never split across chunks.
+        """
+        call_count = 0
+
+        def counting_clear(tis, session, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            return clear_task_instances(tis, session, **kwargs)
+
+        args = parser.parse_args(
+            [
+                "dags",
+                "clear",
+                self.DAG_ID,
+                "--partition-date-start",
+                "2026-03-08T00:00:00",
+                "--partition-date-end",
+                "2026-03-14T00:00:00",
+                "--yes",
+            ]
+        )
+        with (
+            mock.patch.object(dag_command, "_RUN_CHUNK_SIZE", chunk_size),
+            mock.patch(
+                "airflow.cli.commands.dag_command.clear_task_instances",
+                side_effect=counting_clear,
+            ),
+        ):
+            dag_command.dag_clear(args)
+
+        assert call_count == expected_calls
+
+        states = self._get_run_states()
+        assert states["part_2026_03_08"] == DagRunState.QUEUED
+        assert states["part_2026_03_10"] == DagRunState.QUEUED
+        assert states["part_2026_03_14"] == DagRunState.QUEUED
+        assert states["non_partitioned"] == DagRunState.SUCCESS
+
+        clear_numbers = self._get_run_clear_numbers()
+        assert clear_numbers["part_2026_03_08"] == 1
+        assert clear_numbers["part_2026_03_10"] == 1
+        assert clear_numbers["part_2026_03_14"] == 1
+        assert clear_numbers["non_partitioned"] == 0
+
+    @pytest.mark.usefixtures("seeded_partitioned_runs")
+    def test_does_not_clear_runs_of_other_dags(self, parser, dag_maker):
+        """A run_id collision across DAGs must not clear the other DAG's task 
instances."""
+        other_dag_id = "test_dags_clear_other_dag"
+        with dag_maker(
+            other_dag_id,
+            schedule=CronPartitionTimetable("0 0 * * *", 
timezone=pendulum.UTC),
+            start_date=datetime(2026, 3, 1, tzinfo=pendulum.UTC),
+            catchup=True,
+            serialized=True,
+        ):
+            EmptyOperator(task_id="t1")
+        # Same run_id and partition_date as a run cleared below, but a 
different DAG.
+        dag_maker.create_dagrun(
+            run_id="part_2026_03_08",
+            state=DagRunState.SUCCESS,
+            logical_date=None,
+            partition_date=datetime(2026, 3, 8, tzinfo=pendulum.UTC),
+            partition_key="2026-03-08T00:00:00",
+        )
+        dag_maker.sync_dagbag_to_db()
+        # Seed the other DAG's TI as SUCCESS; if the query ignored dag_id, the clear 
would reset it to None.
+        with create_session() as session:
+            session.execute(
+                TaskInstance.__table__.update()
+                .where(TaskInstance.dag_id == other_dag_id)
+                .values(state=TaskInstanceState.SUCCESS)
+            )
+
+        args = parser.parse_args(
+            [
+                "dags",
+                "clear",
+                self.DAG_ID,
+                "--partition-date-start",
+                "2026-03-08T00:00:00",
+                "--partition-date-end",
+                "2026-03-14T00:00:00",
+                "--yes",
+            ]
+        )
+        dag_command.dag_clear(args)
+
+        # The target DAG's same-named run must be cleared.
+        assert self._get_run_states()["part_2026_03_08"] == DagRunState.QUEUED
+
+        # The other DAG's same-named run must be left untouched.
+        with create_session() as session:
+            other_run = session.scalars(
+                select(DagRun).where(DagRun.dag_id == other_dag_id, 
DagRun.run_id == "part_2026_03_08")
+            ).one()
+            assert other_run.state == DagRunState.SUCCESS
+            assert other_run.clear_number == 0
+            other_ti = 
session.scalars(select(TaskInstance).where(TaskInstance.dag_id == 
other_dag_id)).one()
+            assert other_ti.state == TaskInstanceState.SUCCESS
+
 
 class TestDagDetailsIsBackfillable:
     """Tests for the is_backfillable computation in _get_dagbag_dag_details."""

Reply via email to