This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a31db1fb010 Fix duplicate task execution when running multiple 
schedulers (#60330)
a31db1fb010 is described below

commit a31db1fb010c7c2c5306e00c2e10cd69e662bf31
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Thu Mar 19 05:29:53 2026 +0100

    Fix duplicate task execution when running multiple schedulers (#60330)
    
    In HA, two scheduler processes can race to schedule the same
    TaskInstance. Previously DagRun.schedule_tis() updated rows by ti.id
    alone, so a scheduler could increment try_number and transition
    state even after another scheduler had already advanced the TI (e.g. to
    SCHEDULED/QUEUED), resulting in duplicate attempts being queued.
    
    This change makes scheduling idempotent under HA races by:
    - Guarding schedule_tis() DB updates so that they apply only when the TI
    is still in a schedulable state (derived from SCHEDULEABLE_STATES, with
    NULL handled explicitly).
    
    - Using a single CASE (next_try_number) so reschedules
    (UP_FOR_RESCHEDULE) do not start a new try, and applying this
    consistently to both normal scheduling and the EmptyOperator fast-path.
    The CASE uses TI.id (not TI.state) to avoid MySQL SET left-to-right
    evaluation issues.
    
    Adds regression tests covering:
    - TI already queued by another scheduler.
    - EmptyOperator fast-path blocked when TI is already QUEUED/RUNNING.
    - UP_FOR_RESCHEDULE scheduling keeps try_number unchanged.
    - Only one "scheduler" update succeeds when competing.
    
    Closes: #57618
    
    Co-authored-by: Kaxil Naik <[email protected]>
---
 airflow-core/src/airflow/models/dagrun.py     |  37 +++--
 airflow-core/tests/unit/models/test_dagrun.py | 210 +++++++++++++++++++++++++-
 2 files changed, 235 insertions(+), 12 deletions(-)

diff --git a/airflow-core/src/airflow/models/dagrun.py 
b/airflow-core/src/airflow/models/dagrun.py
index 23fdfb72564..1a608b6f0c5 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -1956,11 +1956,14 @@ class DagRun(Base, LoggingMixin):
         # tasks using EmptyOperator and without on_execute_callback / 
on_success_callback
         empty_ti_ids: list[UUID] = []
         schedulable_ti_ids: list[UUID] = []
+        reschedule_ti_ids: set[UUID] = set()
         debug_try_number_check = self.log.isEnabledFor(logging.DEBUG)
         expected_try_number_by_ti_id: dict[UUID, tuple[int, int, str | None]] 
= {}
         for ti in schedulable_tis:
             if ti.is_schedulable:
                 schedulable_ti_ids.append(ti.id)
+                if ti.state == TaskInstanceState.UP_FOR_RESCHEDULE:
+                    reschedule_ti_ids.add(ti.id)
                 if debug_try_number_check:
                     expected_try_number_by_ti_id[ti.id] = (
                         ti.try_number
@@ -1990,7 +1993,25 @@ class DagRun(Base, LoggingMixin):
                 empty_ti_ids.append(ti.id)
 
         count = 0
-
+        # Don't only check if the TI.id is in id_chunk
+        # but also check if the TI.state is in the schedulable states.
+        # Plus, a scheduled empty operator should not be scheduled again.
+        non_null_schedulable_states = tuple(s for s in SCHEDULEABLE_STATES if 
s is not None)
+        schedulable_state_clause = or_(
+            TI.state.is_(None),
+            TI.state.in_(non_null_schedulable_states),
+        )
+        # Use TI.id (not TI.state) in the CASE to decide try_number. MySQL 
evaluates
+        # SET left-to-right, so referencing TI.state here would see the 
already-updated
+        # value if state is assigned first. TI.id is never modified in the SET 
clause.
+        next_try_number = (
+            case(
+                (TI.id.in_(reschedule_ti_ids), TI.try_number),
+                else_=TI.try_number + 1,
+            )
+            if reschedule_ti_ids
+            else TI.try_number + 1
+        )
         if schedulable_ti_ids:
             schedulable_ti_ids_chunks = chunks(
                 schedulable_ti_ids, max_tis_per_query or 
len(schedulable_ti_ids)
@@ -1998,17 +2019,11 @@ class DagRun(Base, LoggingMixin):
             for id_chunk in schedulable_ti_ids_chunks:
                 result = session.execute(
                     update(TI)
-                    .where(TI.id.in_(id_chunk))
+                    .where(TI.id.in_(id_chunk), schedulable_state_clause)
                     .values(
                         state=TaskInstanceState.SCHEDULED,
                         scheduled_dttm=timezone.utcnow(),
-                        try_number=case(
-                            (
-                                or_(TI.state.is_(None), TI.state != 
TaskInstanceState.UP_FOR_RESCHEDULE),
-                                TI.try_number + 1,
-                            ),
-                            else_=TI.try_number,
-                        ),
+                        try_number=next_try_number,
                     )
                     .execution_options(synchronize_session=False)
                 )
@@ -2052,13 +2067,13 @@ class DagRun(Base, LoggingMixin):
             for id_chunk in dummy_ti_ids_chunks:
                 result = session.execute(
                     update(TI)
-                    .where(TI.id.in_(id_chunk))
+                    .where(TI.id.in_(id_chunk), schedulable_state_clause)
                     .values(
                         state=TaskInstanceState.SUCCESS,
                         start_date=timezone.utcnow(),
                         end_date=timezone.utcnow(),
                         duration=0,
-                        try_number=TI.try_number + 1,
+                        try_number=next_try_number,
                     )
                     .execution_options(
                         synchronize_session=False,
diff --git a/airflow-core/tests/unit/models/test_dagrun.py 
b/airflow-core/tests/unit/models/test_dagrun.py
index 2bed00cf75c..850357b68c9 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -28,7 +28,11 @@ from unittest.mock import ANY, call
 import pendulum
 import pytest
 from opentelemetry.sdk.trace import TracerProvider
-from sqlalchemy import func, select
+from sqlalchemy import (
+    func,
+    select,
+    update,
+)
 from sqlalchemy.orm import joinedload
 
 from airflow import settings
@@ -55,6 +59,7 @@ from airflow.serialization.serialized_objects import 
LazyDeserializedDAG
 from airflow.settings import get_policy_plugin_manager
 from airflow.task.trigger_rule import TriggerRule
 from airflow.triggers.base import StartTriggerArgs
+from airflow.utils.session import create_session
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunTriggeredByType, DagRunType
 
@@ -2046,6 +2051,209 @@ def test_schedule_tis_map_index(dag_maker, session):
     assert ti2.state == TaskInstanceState.SUCCESS
 
 
+def 
test_schedule_tis_does_not_increment_try_number_if_ti_already_queued_by_other_scheduler(
+    dag_maker, session
+):
+    with dag_maker(session=session) as dag:
+        BashOperator(task_id="task", bash_command="echo 1")
+
+    dr = dag_maker.create_dagrun(session=session)
+    ti = dr.get_task_instance("task", session=session)
+    assert ti is not None
+    ti.refresh_from_task(dag.get_task("task"))
+    assert ti.state is None
+
+    # The stale scheduler picks try 1.
+    ti.try_number = 1
+    session.flush()
+    session.commit()
+
+    # Another scheduler already queued the TI in DB (same try).
+    with create_session() as other_session:
+        filter_for_tis = TI.filter_for_tis([ti])
+        assert filter_for_tis is not None
+        other_session.execute(
+            update(TI)
+            .where(filter_for_tis)
+            .values(
+                state=TaskInstanceState.QUEUED,
+                try_number=1,
+            )
+            .execution_options(synchronize_session=False)
+        )
+
+    # This stale scheduler still has a stale TI object; schedule_tis must be a 
no-op.
+    assert dr.schedule_tis((ti,), session=session) == 0
+
+    refreshed_ti = session.scalar(
+        select(TI).where(
+            TI.dag_id == ti.dag_id,
+            TI.task_id == ti.task_id,
+            TI.run_id == ti.run_id,
+            TI.map_index == ti.map_index,
+        )
+    )
+    assert refreshed_ti.state == TaskInstanceState.QUEUED
+    assert refreshed_ti.try_number == 1
+
+
+def 
test_schedule_tis_empty_operator_does_not_short_circuit_if_ti_already_queued(dag_maker,
 session):
+    with dag_maker(session=session) as dag:
+        EmptyOperator(task_id="empty_task")
+
+    dr = dag_maker.create_dagrun(session=session)
+    ti = dr.get_task_instance("empty_task", session=session)
+    ti.refresh_from_task(dag.get_task("empty_task"))
+    assert ti.state is None
+
+    # Stale scheduler picks TI
+    ti.try_number = 1
+    session.flush()
+    session.commit()
+
+    # Another scheduler already queued it.
+    with create_session() as other_session:
+        filter_for_tis = TI.filter_for_tis([ti])
+        assert filter_for_tis is not None
+        other_session.execute(
+            update(TI)
+            .where(filter_for_tis)
+            .values(
+                state=TaskInstanceState.QUEUED,
+                try_number=1,
+            )
+            .execution_options(synchronize_session=False)
+        )
+
+    # no shortcircuit
+    assert dr.schedule_tis((ti,), session=session) == 0
+
+    refreshed_ti = session.scalar(
+        select(TI).where(
+            TI.dag_id == ti.dag_id,
+            TI.task_id == ti.task_id,
+            TI.run_id == ti.run_id,
+            TI.map_index == ti.map_index,
+        )
+    )
+    assert refreshed_ti is not None
+    assert refreshed_ti.state == TaskInstanceState.QUEUED
+    assert refreshed_ti.try_number == 1
+
+
+def 
test_schedule_tis_up_for_reschedule_does_not_increment_try_number(dag_maker, 
session):
+    with dag_maker(session=session) as dag:
+        BashOperator(task_id="task", bash_command="echo 1")
+
+    dr = dag_maker.create_dagrun(session=session)
+    ti = dr.get_task_instance("task", session=session)
+    ti.refresh_from_task(dag.get_task("task"))
+
+    ti.state = TaskInstanceState.UP_FOR_RESCHEDULE
+    ti.try_number = 3
+    session.commit()
+
+    assert dr.schedule_tis((ti,), session=session) == 1
+    session.commit()
+
+    # schedule_tis uses synchronize_session=False, so the session may still 
hold a stale instance.
+    # Expire the identity map so the SELECT reflects the DB row.
+    session.expire_all()
+    refreshed_ti = session.scalar(
+        select(TI).where(
+            TI.dag_id == ti.dag_id,
+            TI.task_id == ti.task_id,
+            TI.run_id == ti.run_id,
+            TI.map_index == ti.map_index,
+        )
+    )
+    assert refreshed_ti.state == TaskInstanceState.SCHEDULED
+    assert refreshed_ti.try_number == 3
+
+
+def test_schedule_tis_empty_operator_is_noop_if_ti_already_running(dag_maker, 
session):
+    with dag_maker(session=session) as dag:
+        EmptyOperator(task_id="empty_task")
+
+    dr = dag_maker.create_dagrun(session=session)
+    ti = dr.get_task_instance("empty_task", session=session)
+    ti.refresh_from_task(dag.get_task("empty_task"))
+
+    ti.try_number = 3
+    session.commit()
+
+    with create_session() as other_session:
+        filter_for_tis = TI.filter_for_tis([ti])
+        assert filter_for_tis is not None
+        other_session.execute(
+            update(TI)
+            .where(filter_for_tis)
+            .values(
+                state=TaskInstanceState.RUNNING,
+                try_number=3,
+            )
+            .execution_options(synchronize_session=False)
+        )
+
+    assert dr.schedule_tis((ti,), session=session) == 0
+
+    refreshed_ti = session.scalar(
+        select(TI).where(
+            TI.dag_id == ti.dag_id,
+            TI.task_id == ti.task_id,
+            TI.run_id == ti.run_id,
+            TI.map_index == ti.map_index,
+        )
+    )
+    assert refreshed_ti.state == TaskInstanceState.RUNNING
+    assert refreshed_ti.try_number == 3
+
+
+def 
test_schedule_tis_only_one_scheduler_update_succeeds_when_competing(dag_maker, 
session):
+    with dag_maker(session=session) as dag:
+        BashOperator(task_id="task", bash_command="echo 1")
+
+    dr = dag_maker.create_dagrun(session=session)
+    ti = dr.get_task_instance("task", session=session)
+    ti.refresh_from_task(dag.get_task("task"))
+    assert ti.state is None
+
+    ti.try_number = 0
+    session.commit()
+
+    # Scheduler B loads TI *before* Scheduler A commits — both see state=None.
+    with create_session() as scheduler_b_session:
+        ti_b = scheduler_b_session.scalar(
+            select(TI).where(
+                TI.dag_id == ti.dag_id,
+                TI.task_id == ti.task_id,
+                TI.run_id == ti.run_id,
+                TI.map_index == ti.map_index,
+            )
+        )
+        assert ti_b is not None
+        assert ti_b.state is None
+
+        # Scheduler A schedules first.
+        assert dr.schedule_tis((ti,), session=session) == 1
+        session.commit()
+
+        # Scheduler B tries with its stale TI object; should be a no-op.
+        assert dr.schedule_tis((ti_b,), session=scheduler_b_session) == 0
+
+    session.expire_all()
+    refreshed_ti = session.scalar(
+        select(TI).where(
+            TI.dag_id == ti.dag_id,
+            TI.task_id == ti.task_id,
+            TI.run_id == ti.run_id,
+            TI.map_index == ti.map_index,
+        )
+    )
+    assert refreshed_ti.state == TaskInstanceState.SCHEDULED
+    assert refreshed_ti.try_number == 1
+
+
 @pytest.mark.xfail(reason="We can't keep this behaviour with remote workers 
where scheduler can't reach xcom")
 @pytest.mark.need_serialized_dag
 def test_schedule_tis_start_trigger(dag_maker, session):

Reply via email to