This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v2-10-test by this push:
     new 1e498d8f41c fix schedule_downstream_tasks bug (#42582) (#43299)
1e498d8f41c is described below

commit 1e498d8f41c21ee218b9f090cf44269e703cf3c2
Author: Jarek Potiuk <[email protected]>
AuthorDate: Wed Oct 23 16:44:45 2024 +0200

    fix schedule_downstream_tasks bug (#42582) (#43299)
    
    * fix schedule_downstream_tasks bug
    
    * remove partial_subset
    
    * Update comment
    
    ---------
    
    Co-authored-by: 维湘 <[email protected]>
    (cherry picked from commit 3fceaa69260be80bc2123cd4664db79d96142b9f)
    
    Co-authored-by: luoyuliuyin <[email protected]>
---
 airflow/models/taskinstance.py    | 18 ++++------
 tests/models/test_taskinstance.py | 76 ++++++++++++++++++++++++++++++++++++++-
 2 files changed, 81 insertions(+), 13 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 9fe438c650f..5a6de8e6ecf 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -3865,21 +3865,15 @@ class TaskInstance(Base, LoggingMixin):
                 assert task
                 assert task.dag
 
-            # Get a partial DAG with just the specific tasks we want to 
examine.
-            # In order for dep checks to work correctly, we include ourself (so
-            # TriggerRuleDep can check the state of the task we just executed).
-            partial_dag = task.dag.partial_subset(
-                task.downstream_task_ids,
-                include_downstream=True,
-                include_upstream=False,
-                include_direct_upstream=True,
-            )
-
-            dag_run.dag = partial_dag
+            # Previously, this section used task.dag.partial_subset to 
retrieve a partial DAG.
+            # However, this approach is unsafe as it can result in incomplete 
or incorrect task execution,
+            # leading to potential bad cases. As a result, the operation has 
been removed.
+            # For more details, refer to the discussion in PR 
#[https://github.com/apache/airflow/pull/42582].
+            dag_run.dag = task.dag
             info = dag_run.task_instance_scheduling_decisions(session)
 
             skippable_task_ids = {
-                task_id for task_id in partial_dag.task_ids if task_id not in 
task.downstream_task_ids
+                task_id for task_id in task.dag.task_ids if task_id not in 
task.downstream_task_ids
             }
 
             schedulable_tis = [
diff --git a/tests/models/test_taskinstance.py 
b/tests/models/test_taskinstance.py
index cbd38e9b390..468dc2c9300 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -79,7 +79,7 @@ from airflow.models.xcom import LazyXComSelectSequence, XCom
 from airflow.notifications.basenotifier import BaseNotifier
 from airflow.operators.bash import BashOperator
 from airflow.operators.empty import EmptyOperator
-from airflow.operators.python import PythonOperator
+from airflow.operators.python import BranchPythonOperator, PythonOperator
 from airflow.sensors.base import BaseSensorOperator
 from airflow.sensors.python import PythonSensor
 from airflow.serialization.serialized_objects import SerializedBaseOperator, 
SerializedDAG
@@ -5224,6 +5224,80 @@ def 
test_mapped_task_expands_in_mini_scheduler_if_upstreams_are_done(dag_maker,
     assert "3 downstream tasks scheduled from follow-on schedule" in 
caplog.text
 
 
[email protected]_if_database_isolation_mode
+def test_one_success_task_in_mini_scheduler_if_upstreams_are_done(dag_maker, 
caplog, session):
+    """Test the mini scheduler with a one_success task"""
+    with dag_maker() as dag:
+        branch = BranchPythonOperator(task_id="branch", 
python_callable=lambda: "task_run")
+        task_run = BashOperator(task_id="task_run", bash_command="echo 0")
+        task_skip = BashOperator(task_id="task_skip", bash_command="echo 0")
+        task_1 = BashOperator(task_id="task_1", bash_command="echo 0")
+        task_one_success = BashOperator(
+            task_id="task_one_success", bash_command="echo 0", 
trigger_rule="one_success"
+        )
+        task_2 = BashOperator(task_id="task_2", bash_command="echo 0")
+
+        task_1 >> task_2
+        branch >> task_skip
+        branch >> task_run
+        task_run >> task_one_success
+        task_skip >> task_one_success
+        task_one_success >> task_2
+        task_skip >> task_2
+
+    dr = dag_maker.create_dagrun()
+
+    branch = dr.get_task_instance(task_id="branch")
+    task_1 = dr.get_task_instance(task_id="task_1")
+    task_skip = dr.get_task_instance(task_id="task_skip")
+    branch.state = State.SUCCESS
+    task_1.state = State.SUCCESS
+    task_skip.state = State.SKIPPED
+    session.merge(branch)
+    session.merge(task_1)
+    session.merge(task_skip)
+    session.commit()
+    task_1.refresh_from_task(dag.get_task("task_1"))
+    task_1.schedule_downstream_tasks(session=session)
+
+    branch = dr.get_task_instance(task_id="branch")
+    task_run = dr.get_task_instance(task_id="task_run")
+    task_skip = dr.get_task_instance(task_id="task_skip")
+    task_1 = dr.get_task_instance(task_id="task_1")
+    task_one_success = dr.get_task_instance(task_id="task_one_success")
+    task_2 = dr.get_task_instance(task_id="task_2")
+    assert branch.state == State.SUCCESS
+    assert task_run.state == State.NONE
+    assert task_skip.state == State.SKIPPED
+    assert task_1.state == State.SUCCESS
+    # task_one_success should not be scheduled
+    assert task_one_success.state == State.NONE
+    assert task_2.state == State.SKIPPED
+    assert "0 downstream tasks scheduled from follow-on schedule" in 
caplog.text
+
+    task_run = dr.get_task_instance(task_id="task_run")
+    task_run.state = State.SUCCESS
+    session.merge(task_run)
+    session.commit()
+    task_run.refresh_from_task(dag.get_task("task_run"))
+    task_run.schedule_downstream_tasks(session=session)
+
+    branch = dr.get_task_instance(task_id="branch")
+    task_run = dr.get_task_instance(task_id="task_run")
+    task_skip = dr.get_task_instance(task_id="task_skip")
+    task_1 = dr.get_task_instance(task_id="task_1")
+    task_one_success = dr.get_task_instance(task_id="task_one_success")
+    task_2 = dr.get_task_instance(task_id="task_2")
+    assert branch.state == State.SUCCESS
+    assert task_run.state == State.SUCCESS
+    assert task_skip.state == State.SKIPPED
+    assert task_1.state == State.SUCCESS
+    # task_one_success should now be scheduled
+    assert task_one_success.state == State.SCHEDULED
+    assert task_2.state == State.SKIPPED
+    assert "1 downstream tasks scheduled from follow-on schedule" in 
caplog.text
+
+
 @pytest.mark.skip_if_database_isolation_mode  # Does not work in db isolation 
mode
 def 
test_mini_scheduler_not_skip_mapped_downstream_until_all_upstreams_finish(dag_maker,
 session):
     with dag_maker(session=session):

Reply via email to