This is an automated email from the ASF dual-hosted git repository.
Lee-W pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new def50a90211 feat(dag_command.py): change to use bulk clear (#68280)
def50a90211 is described below
commit def50a9021171afe9f65d021fdebb24f6309f300
Author: PoAn Yang <[email protected]>
AuthorDate: Fri Jun 12 19:50:32 2026 +0900
feat(dag_command.py): change to use bulk clear (#68280)
---
.../src/airflow/cli/commands/dag_command.py | 53 +++++++--
.../tests/unit/cli/commands/test_dag_command.py | 122 ++++++++++++++++++++-
2 files changed, 165 insertions(+), 10 deletions(-)
diff --git a/airflow-core/src/airflow/cli/commands/dag_command.py b/airflow-core/src/airflow/cli/commands/dag_command.py
index 1d5e4237576..580a61a1b4f 100644
--- a/airflow-core/src/airflow/cli/commands/dag_command.py
+++ b/airflow-core/src/airflow/cli/commands/dag_command.py
@@ -46,6 +46,7 @@ from airflow.jobs.job import Job
from airflow.models import DagModel, DagRun, TaskInstance
from airflow.models.errors import ParseImportError
from airflow.models.serialized_dag import SerializedDagModel
+from airflow.models.taskinstance import clear_task_instances
from airflow.timetables.base import TimeRestriction
from airflow.utils import cli as cli_utils
from airflow.utils.cli import (
@@ -55,10 +56,10 @@ from airflow.utils.cli import (
validate_dag_bundle_arg,
)
from airflow.utils.dot_renderer import render_dag, render_dag_dependencies
-from airflow.utils.helpers import ask_yesno
+from airflow.utils.helpers import ask_yesno, chunks
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.state import DagRunState
+from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.types import DagRunType
if TYPE_CHECKING:
@@ -75,6 +76,9 @@ DAG_DETAIL_FIELDS = {*DAGResponse.model_fields, *DAGResponse.model_computed_fiel
log = logging.getLogger(__name__)
+# Chunk size for bulk delete.
+_RUN_CHUNK_SIZE = 500
+
@cli_utils.action_cli
@deprecated_for_airflowctl("airflowctl dags trigger")
@@ -189,15 +193,46 @@ def dag_clear(args, *, session: Session = NEW_SESSION) -> None:
print("Cancelled, nothing was cleared.")
return
+ cleared = _bulk_clear_runs(
+ args.dag_id,
+ run_ids,
+ only_failed=args.only_failed,
+ only_running=args.only_running,
+ session=session,
+ )
+ print(f"Cleared {cleared} task instance(s) across {len(run_ids)} Dag run(s).")
+
+
+def _bulk_clear_runs(
+ dag_id: str,
+ run_ids: list[str],
+ only_failed: bool,
+ only_running: bool,
+ session: Session,
+) -> int:
+ """Clear task instances for the given run_ids in chunks instead of one transaction per run."""
+ state_filter: list[TaskInstanceState] = []
+ if only_failed:
+ state_filter += [TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED]
+ if only_running:
+ state_filter += [TaskInstanceState.RUNNING]
+
cleared = 0
- for run_id in run_ids:
- cleared += dag.clear(
- run_id=run_id,
- only_failed=args.only_failed,
- only_running=args.only_running,
- session=session,
+ for chunk_run_ids in chunks(run_ids, _RUN_CHUNK_SIZE):
+ ti_query = select(TaskInstance).where(
+ TaskInstance.dag_id == dag_id,
+ TaskInstance.run_id.in_(chunk_run_ids),
)
- print(f"Cleared {cleared} task instance(s) across {len(run_ids)} Dag run(s).")
+ if state_filter:
+ ti_query = ti_query.where(TaskInstance.state.in_(state_filter))
+ tis = session.scalars(ti_query).all()
+ if not tis:
+ continue
+ clear_task_instances(list(tis), session=session)
+ session.flush()
+ cleared += len(tis)
+
+ return cleared
@cli_utils.action_cli
diff --git a/airflow-core/tests/unit/cli/commands/test_dag_command.py b/airflow-core/tests/unit/cli/commands/test_dag_command.py
index 2dafed8f15e..0c9f2647f0f 100644
--- a/airflow-core/tests/unit/cli/commands/test_dag_command.py
+++ b/airflow-core/tests/unit/cli/commands/test_dag_command.py
@@ -42,7 +42,7 @@ from airflow.exceptions import AirflowException
from airflow.models import DagModel, DagRun
from airflow.models.dagbag import DBDagBag
from airflow.models.serialized_dag import SerializedDagModel
-from airflow.models.taskinstance import TaskInstance
+from airflow.models.taskinstance import TaskInstance, clear_task_instances
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
from airflow.sdk import DAG, Asset, BaseOperator, CronPartitionTimetable, PartitionedAssetTimetable, task
@@ -1066,6 +1066,13 @@ class TestCliDagsClear:
for row in session.scalars(select(DagRun).where(DagRun.dag_id == self.DAG_ID)).all()
}
+ def _get_run_clear_numbers(self):
+ with create_session() as session:
+ return {
+ row.run_id: row.clear_number
+ for row in session.scalars(select(DagRun).where(DagRun.dag_id == self.DAG_ID)).all()
+ }
+
def test_requires_a_selector(self, parser):
args = parser.parse_args(["dags", "clear", self.DAG_ID, "--yes"])
with pytest.raises(SystemExit, match="One of --run-id, --partition-key"):
@@ -1758,6 +1765,119 @@ class TestCliDagsClear:
assert states["asset_2026_04_15"] == DagRunState.SUCCESS
assert states["asset_non_part"] == DagRunState.SUCCESS
+ @pytest.mark.usefixtures("seeded_partitioned_runs")
+ @pytest.mark.parametrize(
+ ("chunk_size", "expected_calls"),
+ [
+ pytest.param(500, 1, id="single-chunk"),
+ pytest.param(2, 2, id="multiple-chunks"),
+ ],
+ )
+ def test_clears_each_matching_run_once_across_chunks(self, parser, chunk_size, expected_calls):
+ """Every matching run is cleared exactly once, however run_ids split into chunks.
+
+ clear_task_instances is called once per chunk (not once per run), every matching
+ run is re-queued, and each run's clear_number advances by exactly 1 — proving a
+ run's TIs are never split across chunks.
+ """
+ call_count = 0
+
+ def counting_clear(tis, session, **kwargs):
+ nonlocal call_count
+ call_count += 1
+ return clear_task_instances(tis, session, **kwargs)
+
+ args = parser.parse_args(
+ [
+ "dags",
+ "clear",
+ self.DAG_ID,
+ "--partition-date-start",
+ "2026-03-08T00:00:00",
+ "--partition-date-end",
+ "2026-03-14T00:00:00",
+ "--yes",
+ ]
+ )
+ with (
+ mock.patch.object(dag_command, "_RUN_CHUNK_SIZE", chunk_size),
+ mock.patch(
+ "airflow.cli.commands.dag_command.clear_task_instances",
+ side_effect=counting_clear,
+ ),
+ ):
+ dag_command.dag_clear(args)
+
+ assert call_count == expected_calls
+
+ states = self._get_run_states()
+ assert states["part_2026_03_08"] == DagRunState.QUEUED
+ assert states["part_2026_03_10"] == DagRunState.QUEUED
+ assert states["part_2026_03_14"] == DagRunState.QUEUED
+ assert states["non_partitioned"] == DagRunState.SUCCESS
+
+ clear_numbers = self._get_run_clear_numbers()
+ assert clear_numbers["part_2026_03_08"] == 1
+ assert clear_numbers["part_2026_03_10"] == 1
+ assert clear_numbers["part_2026_03_14"] == 1
+ assert clear_numbers["non_partitioned"] == 0
+
+ @pytest.mark.usefixtures("seeded_partitioned_runs")
+ def test_does_not_clear_runs_of_other_dags(self, parser, dag_maker):
+ """A run_id collision across DAGs must not clear the other DAG's task instances."""
+ other_dag_id = "test_dags_clear_other_dag"
+ with dag_maker(
+ other_dag_id,
+ schedule=CronPartitionTimetable("0 0 * * *", timezone=pendulum.UTC),
+ start_date=datetime(2026, 3, 1, tzinfo=pendulum.UTC),
+ catchup=True,
+ serialized=True,
+ ):
+ EmptyOperator(task_id="t1")
+ # Same run_id and partition_date as a run cleared below, but a different DAG.
+ dag_maker.create_dagrun(
+ run_id="part_2026_03_08",
+ state=DagRunState.SUCCESS,
+ logical_date=None,
+ partition_date=datetime(2026, 3, 8, tzinfo=pendulum.UTC),
+ partition_key="2026-03-08T00:00:00",
+ )
+ dag_maker.sync_dagbag_to_db()
+ # If dag_id is not filtered, clearing the other DAG would reset this TI to None.
+ with create_session() as session:
+ session.execute(
+ TaskInstance.__table__.update()
+ .where(TaskInstance.dag_id == other_dag_id)
+ .values(state=TaskInstanceState.SUCCESS)
+ )
+
+ args = parser.parse_args(
+ [
+ "dags",
+ "clear",
+ self.DAG_ID,
+ "--partition-date-start",
+ "2026-03-08T00:00:00",
+ "--partition-date-end",
+ "2026-03-14T00:00:00",
+ "--yes",
+ ]
+ )
+ dag_command.dag_clear(args)
+
+ # The target DAG's same-named run must be cleared.
+ assert self._get_run_states()["part_2026_03_08"] == DagRunState.QUEUED
+
+ # The other DAG's same-named run must be left untouched.
+ with create_session() as session:
+ other_run = session.scalars(
+ select(DagRun).where(DagRun.dag_id == other_dag_id, DagRun.run_id == "part_2026_03_08")
+ ).one()
+ assert other_run.state == DagRunState.SUCCESS
+ assert other_run.clear_number == 0
+ other_ti = session.scalars(select(TaskInstance).where(TaskInstance.dag_id == other_dag_id)).one()
+ assert other_ti.state == TaskInstanceState.SUCCESS
+
class TestDagDetailsIsBackfillable:
"""Tests for the is_backfillable computation in _get_dagbag_dag_details."""