This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 023ae241c47 Fix setup/teardown auto-inclusion when clearing or marking tasks (#68193)
023ae241c47 is described below

commit 023ae241c47af34d8a11c7d46713d41617ce14a6
Author: Haseeb Malik <[email protected]>
AuthorDate: Thu Jun 11 09:40:14 2026 -0400

    Fix setup/teardown auto-inclusion when clearing or marking tasks (#68193)
---
 airflow-core/src/airflow/api/common/mark_tasks.py  |  2 +
 .../src/airflow/serialization/definitions/dag.py   | 16 ++++++
 .../tests/unit/api/common/test_mark_tasks.py       | 36 ++++++++++++-
 airflow-core/tests/unit/models/test_cleartasks.py  | 63 ++++++++++++++++++++++
 4 files changed, 116 insertions(+), 1 deletion(-)

diff --git a/airflow-core/src/airflow/api/common/mark_tasks.py b/airflow-core/src/airflow/api/common/mark_tasks.py
index efc2a016a43..120a31a67b7 100644
--- a/airflow-core/src/airflow/api/common/mark_tasks.py
+++ b/airflow-core/src/airflow/api/common/mark_tasks.py
@@ -137,6 +137,8 @@ def find_task_relatives(
             yield task.task_id
         if downstream:
             for relative in task.get_flat_relatives(upstream=False):
+                if relative.is_teardown:
+                    continue
                 yield relative.task_id
         if upstream:
             for relative in task.get_flat_relatives(upstream=True):
diff --git a/airflow-core/src/airflow/serialization/definitions/dag.py b/airflow-core/src/airflow/serialization/definitions/dag.py
index 33e48aa72bd..ac22806eb55 100644
--- a/airflow-core/src/airflow/serialization/definitions/dag.py
+++ b/airflow-core/src/airflow/serialization/definitions/dag.py
@@ -1229,6 +1229,22 @@ class SerializedDAG:
             # Yes, having `+=` doesn't make sense, but this was the existing behaviour
             state += [TaskInstanceState.RUNNING]
 
+        if task_ids is not None:
+            plain_task_ids: set[str] = {tid[0] if isinstance(tid, tuple) else tid for tid in task_ids}
+            if plain_task_ids:
+                added_ids = (
+                    set(
+                        self.partial_subset(
+                            task_ids=plain_task_ids,
+                            include_downstream=False,
+                            include_upstream=False,
+                        ).task_dict
+                    )
+                    - plain_task_ids
+                )
+                if added_ids:
+                    task_ids = [*task_ids, *added_ids]
+
         tis_result = self._get_task_instances(
             task_ids=task_ids,
             start_date=start_date,
diff --git a/airflow-core/tests/unit/api/common/test_mark_tasks.py b/airflow-core/tests/unit/api/common/test_mark_tasks.py
index e50a4de7f15..4504ddceda0 100644
--- a/airflow-core/tests/unit/api/common/test_mark_tasks.py
+++ b/airflow-core/tests/unit/api/common/test_mark_tasks.py
@@ -21,7 +21,11 @@ from typing import TYPE_CHECKING
 import pytest
 from sqlalchemy import select
 
-from airflow.api.common.mark_tasks import set_dag_run_state_to_failed, set_dag_run_state_to_success
+from airflow.api.common.mark_tasks import (
+    find_task_relatives,
+    set_dag_run_state_to_failed,
+    set_dag_run_state_to_success,
+)
 from airflow.models.dagrun import DagRun
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.utils.state import DagRunState, State, TaskInstanceState
@@ -119,3 +123,33 @@ def test_set_dag_run_state_to_success_finished_teardown(dag_maker: DagMaker[Seri
     assert task_dict["failed"].state == TaskInstanceState.SUCCESS
     if finished_state != TaskInstanceState.SUCCESS:
         assert task_dict["teardown"].state == TaskInstanceState.SUCCESS
+
+
+def test_find_task_relatives_downstream_skips_teardowns(dag_maker: DagMaker[SerializedDAG]):
+    with dag_maker("test_find_task_relatives_downstream_skips_teardowns") as dag:
+        setup_t = EmptyOperator(task_id="setup_t").as_setup()
+        normal_t = EmptyOperator(task_id="normal_t")
+        teardown_t = EmptyOperator(task_id="teardown_t").as_teardown(setups=setup_t)
+        setup_t >> normal_t >> teardown_t
+    dag_maker.create_dagrun()
+    normal_task = dag.get_task("normal_t")
+
+    relatives = list(find_task_relatives([normal_task], downstream=True, upstream=False))
+
+    assert "normal_t" in relatives
+    assert "teardown_t" not in relatives
+
+
+def test_find_task_relatives_upstream_still_includes_setups(dag_maker: DagMaker[SerializedDAG]):
+    with dag_maker("test_find_task_relatives_upstream_still_includes_setups") as dag:
+        setup_t = EmptyOperator(task_id="setup_t").as_setup()
+        normal_t = EmptyOperator(task_id="normal_t")
+        teardown_t = EmptyOperator(task_id="teardown_t").as_teardown(setups=setup_t)
+        setup_t >> normal_t >> teardown_t
+    dag_maker.create_dagrun()
+    normal_task = dag.get_task("normal_t")
+
+    relatives = list(find_task_relatives([normal_task], downstream=False, upstream=True))
+
+    assert "normal_t" in relatives
+    assert "setup_t" in relatives
diff --git a/airflow-core/tests/unit/models/test_cleartasks.py b/airflow-core/tests/unit/models/test_cleartasks.py
index 421f4776433..9972191e1eb 100644
--- a/airflow-core/tests/unit/models/test_cleartasks.py
+++ b/airflow-core/tests/unit/models/test_cleartasks.py
@@ -950,3 +950,66 @@ class TestClearTasks:
         )
 
         assert count == 0
+
+    def test_clear_normal_task_includes_setup_and_teardown(self, dag_maker):
+        with dag_maker("test_clear_normal_task_includes_setup_and_teardown") as dag:
+            setup_t = EmptyOperator(task_id="setup_t").as_setup()
+            normal_t = EmptyOperator(task_id="normal_t")
+            teardown_t = EmptyOperator(task_id="teardown_t").as_teardown(setups=setup_t)
+            setup_t >> normal_t >> teardown_t
+        dr = dag_maker.create_dagrun()
+        for ti in dr.get_task_instances():
+            ti.set_state(TaskInstanceState.SUCCESS)
+        dag_maker.session.flush()
+
+        cleared = dag.clear(
+            dry_run=True,
+            task_ids=["normal_t"],
+            run_id=dr.run_id,
+            session=dag_maker.session,
+        )
+
+        cleared_ids = {ti.task_id for ti in cleared}
+        assert cleared_ids == {"setup_t", "normal_t", "teardown_t"}
+
+    def test_clear_setup_includes_paired_teardown(self, dag_maker):
+        with dag_maker("test_clear_setup_includes_paired_teardown") as dag:
+            setup_t = EmptyOperator(task_id="setup_t").as_setup()
+            normal_t = EmptyOperator(task_id="normal_t")
+            teardown_t = EmptyOperator(task_id="teardown_t").as_teardown(setups=setup_t)
+            setup_t >> normal_t >> teardown_t
+        dr = dag_maker.create_dagrun()
+        for ti in dr.get_task_instances():
+            ti.set_state(TaskInstanceState.SUCCESS)
+        dag_maker.session.flush()
+
+        cleared = dag.clear(
+            dry_run=True,
+            task_ids=["setup_t"],
+            run_id=dr.run_id,
+            session=dag_maker.session,
+        )
+
+        cleared_ids = {ti.task_id for ti in cleared}
+        assert cleared_ids == {"setup_t", "teardown_t"}
+
+    def test_clear_teardown_does_not_include_setup(self, dag_maker):
+        with dag_maker("test_clear_teardown_does_not_include_setup") as dag:
+            setup_t = EmptyOperator(task_id="setup_t").as_setup()
+            normal_t = EmptyOperator(task_id="normal_t")
+            teardown_t = EmptyOperator(task_id="teardown_t").as_teardown(setups=setup_t)
+            setup_t >> normal_t >> teardown_t
+        dr = dag_maker.create_dagrun()
+        for ti in dr.get_task_instances():
+            ti.set_state(TaskInstanceState.SUCCESS)
+        dag_maker.session.flush()
+
+        cleared = dag.clear(
+            dry_run=True,
+            task_ids=["teardown_t"],
+            run_id=dr.run_id,
+            session=dag_maker.session,
+        )
+
+        cleared_ids = {ti.task_id for ti in cleared}
+        assert cleared_ids == {"teardown_t"}

Reply via email to