This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch v2-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v2-10-test by this push:
new 1e498d8f41c fix schedule_downstream_tasks bug (#42582) (#43299)
1e498d8f41c is described below
commit 1e498d8f41c21ee218b9f090cf44269e703cf3c2
Author: Jarek Potiuk <[email protected]>
AuthorDate: Wed Oct 23 16:44:45 2024 +0200
fix schedule_downstream_tasks bug (#42582) (#43299)
* fix schedule_downstream_tasks bug
* remove partial_subset
* Update comment
---------
Co-authored-by: 维湘 <[email protected]>
(cherry picked from commit 3fceaa69260be80bc2123cd4664db79d96142b9f)
Co-authored-by: luoyuliuyin <[email protected]>
---
airflow/models/taskinstance.py | 18 ++++------
tests/models/test_taskinstance.py | 76 ++++++++++++++++++++++++++++++++++++++-
2 files changed, 81 insertions(+), 13 deletions(-)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 9fe438c650f..5a6de8e6ecf 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -3865,21 +3865,15 @@ class TaskInstance(Base, LoggingMixin):
assert task
assert task.dag
- # Get a partial DAG with just the specific tasks we want to
examine.
- # In order for dep checks to work correctly, we include ourself (so
- # TriggerRuleDep can check the state of the task we just executed).
- partial_dag = task.dag.partial_subset(
- task.downstream_task_ids,
- include_downstream=True,
- include_upstream=False,
- include_direct_upstream=True,
- )
-
- dag_run.dag = partial_dag
+ # Previously, this section used task.dag.partial_subset to
retrieve a partial DAG.
+        # However, this approach is unsafe as it can result in incomplete
or incorrect task execution,
+        # leading to incorrect scheduling decisions. As a result, the
operation has been removed.
+        # For more details, refer to the discussion in PR
https://github.com/apache/airflow/pull/42582.
+ dag_run.dag = task.dag
info = dag_run.task_instance_scheduling_decisions(session)
skippable_task_ids = {
- task_id for task_id in partial_dag.task_ids if task_id not in
task.downstream_task_ids
+ task_id for task_id in task.dag.task_ids if task_id not in
task.downstream_task_ids
}
schedulable_tis = [
diff --git a/tests/models/test_taskinstance.py
b/tests/models/test_taskinstance.py
index cbd38e9b390..468dc2c9300 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -79,7 +79,7 @@ from airflow.models.xcom import LazyXComSelectSequence, XCom
from airflow.notifications.basenotifier import BaseNotifier
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
-from airflow.operators.python import PythonOperator
+from airflow.operators.python import BranchPythonOperator, PythonOperator
from airflow.sensors.base import BaseSensorOperator
from airflow.sensors.python import PythonSensor
from airflow.serialization.serialized_objects import SerializedBaseOperator,
SerializedDAG
@@ -5224,6 +5224,80 @@ def
test_mapped_task_expands_in_mini_scheduler_if_upstreams_are_done(dag_maker,
assert "3 downstream tasks scheduled from follow-on schedule" in
caplog.text
[email protected]_if_database_isolation_mode
+def test_one_success_task_in_mini_scheduler_if_upstreams_are_done(dag_maker,
caplog, session):
+    """Test the mini scheduler with a one_success trigger-rule task"""
+ with dag_maker() as dag:
+ branch = BranchPythonOperator(task_id="branch",
python_callable=lambda: "task_run")
+ task_run = BashOperator(task_id="task_run", bash_command="echo 0")
+ task_skip = BashOperator(task_id="task_skip", bash_command="echo 0")
+ task_1 = BashOperator(task_id="task_1", bash_command="echo 0")
+ task_one_success = BashOperator(
+ task_id="task_one_success", bash_command="echo 0",
trigger_rule="one_success"
+ )
+ task_2 = BashOperator(task_id="task_2", bash_command="echo 0")
+
+ task_1 >> task_2
+ branch >> task_skip
+ branch >> task_run
+ task_run >> task_one_success
+ task_skip >> task_one_success
+ task_one_success >> task_2
+ task_skip >> task_2
+
+ dr = dag_maker.create_dagrun()
+
+ branch = dr.get_task_instance(task_id="branch")
+ task_1 = dr.get_task_instance(task_id="task_1")
+ task_skip = dr.get_task_instance(task_id="task_skip")
+ branch.state = State.SUCCESS
+ task_1.state = State.SUCCESS
+ task_skip.state = State.SKIPPED
+ session.merge(branch)
+ session.merge(task_1)
+ session.merge(task_skip)
+ session.commit()
+ task_1.refresh_from_task(dag.get_task("task_1"))
+ task_1.schedule_downstream_tasks(session=session)
+
+ branch = dr.get_task_instance(task_id="branch")
+ task_run = dr.get_task_instance(task_id="task_run")
+ task_skip = dr.get_task_instance(task_id="task_skip")
+ task_1 = dr.get_task_instance(task_id="task_1")
+ task_one_success = dr.get_task_instance(task_id="task_one_success")
+ task_2 = dr.get_task_instance(task_id="task_2")
+ assert branch.state == State.SUCCESS
+ assert task_run.state == State.NONE
+ assert task_skip.state == State.SKIPPED
+ assert task_1.state == State.SUCCESS
+ # task_one_success should not be scheduled
+ assert task_one_success.state == State.NONE
+ assert task_2.state == State.SKIPPED
+ assert "0 downstream tasks scheduled from follow-on schedule" in
caplog.text
+
+ task_run = dr.get_task_instance(task_id="task_run")
+ task_run.state = State.SUCCESS
+ session.merge(task_run)
+ session.commit()
+ task_run.refresh_from_task(dag.get_task("task_run"))
+ task_run.schedule_downstream_tasks(session=session)
+
+ branch = dr.get_task_instance(task_id="branch")
+ task_run = dr.get_task_instance(task_id="task_run")
+ task_skip = dr.get_task_instance(task_id="task_skip")
+ task_1 = dr.get_task_instance(task_id="task_1")
+ task_one_success = dr.get_task_instance(task_id="task_one_success")
+ task_2 = dr.get_task_instance(task_id="task_2")
+ assert branch.state == State.SUCCESS
+ assert task_run.state == State.SUCCESS
+ assert task_skip.state == State.SKIPPED
+ assert task_1.state == State.SUCCESS
+    # task_one_success should now be scheduled, since task_run succeeded
+    assert task_one_success.state == State.SCHEDULED
+ assert task_2.state == State.SKIPPED
+ assert "1 downstream tasks scheduled from follow-on schedule" in
caplog.text
+
+
@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation
mode
def
test_mini_scheduler_not_skip_mapped_downstream_until_all_upstreams_finish(dag_maker,
session):
with dag_maker(session=session):