This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 73d9352225 Make sure we can get out of a faulty scheduler state (#27834)
73d9352225 is described below

commit 73d9352225bcc1f086b63f1c767d25b2d7c4c221
Author: Stijn De Haes <[email protected]>
AuthorDate: Mon Dec 5 02:12:36 2022 +0100

    Make sure we can get out of a faulty scheduler state (#27834)
    
    * Make sure we can get out of a faulty scheduler state
    
    This PR fixes the case where the database is in a faulty state: both the
    unmapped task instance and the mapped task instances exist at the same time.
    
    So we have instances with map_index [-1, 0, 1].
    The -1 task instances should be removed in this case.
---
 airflow/models/abstractoperator.py  | 36 ++++++++++++++++++--------
 airflow/models/dagrun.py            |  9 ++++---
 tests/models/test_dagrun.py         | 43 ++++++++++++++++++++++++++++++-
 tests/models/test_mappedoperator.py | 51 +++++++++++++++++++++++++++++++++++++
 4 files changed, 123 insertions(+), 16 deletions(-)

diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index d5d6ad082f..ba0a8954ae 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -484,7 +484,6 @@ class AbstractOperator(LoggingMixin, DAGNode):
                 # are not done yet, so the task can't fail yet.
                 if not self.dag or not self.dag.partial:
                     unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
-                indexes_to_map: Iterable[int] = ()
             elif total_length < 1:
                 # If the upstream maps this to a zero-length value, simply mark
                 # the unmapped task instance as SKIPPED (if needed).
@@ -494,18 +493,33 @@ class AbstractOperator(LoggingMixin, DAGNode):
                     total_length,
                 )
                 unmapped_ti.state = TaskInstanceState.SKIPPED
-                indexes_to_map = ()
             else:
-                # Otherwise convert this into the first mapped index, and 
create
-                # TaskInstance for other indexes.
-                unmapped_ti.map_index = 0
-                self.log.debug("Updated in place to become %s", unmapped_ti)
-                all_expanded_tis.append(unmapped_ti)
-                indexes_to_map = range(1, total_length)
-            state = unmapped_ti.state
-        elif not total_length:
+                zero_index_ti_exists = (
+                    session.query(TaskInstance)
+                    .filter(
+                        TaskInstance.dag_id == self.dag_id,
+                        TaskInstance.task_id == self.task_id,
+                        TaskInstance.run_id == run_id,
+                        TaskInstance.map_index == 0,
+                    )
+                    .count()
+                    > 0
+                )
+                if not zero_index_ti_exists:
+                    # Otherwise convert this into the first mapped index, and 
create
+                    # TaskInstance for other indexes.
+                    unmapped_ti.map_index = 0
+                    self.log.debug("Updated in place to become %s", 
unmapped_ti)
+                    all_expanded_tis.append(unmapped_ti)
+                    session.flush()
+                else:
+                    self.log.debug("Deleting the original task instance: %s", 
unmapped_ti)
+                    session.delete(unmapped_ti)
+                state = unmapped_ti.state
+
+        if total_length is None or total_length < 1:
             # Nothing to fixup.
-            indexes_to_map = ()
+            indexes_to_map: Iterable[int] = ()
         else:
             # Only create "missing" ones.
             current_max_mapping = (
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index f3fc068b04..c601193b27 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -769,7 +769,8 @@ class DagRun(Base, LoggingMixin):
             """Try to expand the ti, if needed.
 
             If the ti needs expansion, newly created task instances are
-            returned. The original ti is modified in-place and assigned the
+            returned as well as the original ti.
+            The original ti is also modified in-place and assigned the
             ``map_index`` of 0.
 
             If the ti does not need expansion, either because the task is not
@@ -782,8 +783,7 @@ class DagRun(Base, LoggingMixin):
             except NotMapped:  # Not a mapped task, nothing needed.
                 return None
             if expanded_tis:
-                assert expanded_tis[0] is ti
-                return expanded_tis[1:]
+                return expanded_tis
             return ()
 
         # Check dependencies.
@@ -799,12 +799,13 @@ class DagRun(Base, LoggingMixin):
             # in the scheduler to ensure that the mapped task is correctly
             # expanded before executed. Also see _revise_map_indexes_if_mapped
             # docstring for additional information.
+            new_tis = None
             if schedulable.map_index < 0:
                 new_tis = _expand_mapped_task_if_needed(schedulable)
                 if new_tis is not None:
                     additional_tis.extend(new_tis)
                     expansion_happened = True
-            if schedulable.state in SCHEDULEABLE_STATES:
+            if new_tis is None and schedulable.state in SCHEDULEABLE_STATES:
                 
ready_tis.extend(self._revise_map_indexes_if_mapped(schedulable.task, 
session=session))
                 ready_tis.append(schedulable)
 
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index a3e2a50652..c08b548870 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -29,7 +29,15 @@ from sqlalchemy.orm.session import Session
 from airflow import settings
 from airflow.callbacks.callback_requests import DagCallbackRequest
 from airflow.decorators import task, task_group
-from airflow.models import DAG, DagBag, DagModel, DagRun, TaskInstance as TI, 
clear_task_instances
+from airflow.models import (
+    DAG,
+    DagBag,
+    DagModel,
+    DagRun,
+    TaskInstance,
+    TaskInstance as TI,
+    clear_task_instances,
+)
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.taskmap import TaskMap
 from airflow.operators.empty import EmptyOperator
@@ -1285,6 +1293,39 @@ def test_mapped_literal_length_reduction_at_runtime_adds_removed_state(dag_maker
     ]
 
 
+def test_mapped_literal_faulty_state_in_db(dag_maker, session):
+    """
+    This test tries to recreate a faulty state in the database and checks if 
we can recover from it.
+    The state that happens is that there exists mapped task instances and the 
unmapped task instance.
+    So we have instances with map_index [-1, 0, 1]. The -1 task instances 
should be removed in this case.
+    """
+
+    with dag_maker(session=session) as dag:
+
+        @task
+        def task_1():
+            return [1, 2]
+
+        @task
+        def task_2(arg2):
+            ...
+
+        task_2.expand(arg2=task_1())
+
+    dr = dag_maker.create_dagrun()
+    ti = dr.get_task_instance(task_id="task_1")
+    ti.run()
+    decision = dr.task_instance_scheduling_decisions()
+    assert len(decision.schedulable_tis) == 2
+
+    # We insert a faulty record
+    session.add(TaskInstance(dag.get_task("task_2"), dr.execution_date, 
dr.run_id))
+    session.flush()
+
+    decision = dr.task_instance_scheduling_decisions()
+    assert len(decision.schedulable_tis) == 2
+
+
 def 
test_mapped_literal_length_with_no_change_at_runtime_doesnt_call_verify_integrity(dag_maker,
 session):
     """
     Test that when there's no change to mapped task indexes at runtime, the 
dagrun.verify_integrity
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 1998563d70..036a12fac4 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -228,6 +228,57 @@ def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expec
     assert indices == expected
 
 
+def test_expand_mapped_task_failed_state_in_db(dag_maker, session):
+    """
+    This test tries to recreate a faulty state in the database and checks if 
we can recover from it.
+    The state that happens is that there exists mapped task instances and the 
unmapped task instance.
+    So we have instances with map_index [-1, 0, 1]. The -1 task instances 
should be removed in this case.
+    """
+    literal = [1, 2]
+    with dag_maker(session=session):
+        task1 = BaseOperator(task_id="op1")
+        mapped = 
MockOperator.partial(task_id="task_2").expand(arg2=task1.output)
+
+    dr = dag_maker.create_dagrun()
+
+    session.add(
+        TaskMap(
+            dag_id=dr.dag_id,
+            task_id=task1.task_id,
+            run_id=dr.run_id,
+            map_index=-1,
+            length=len(literal),
+            keys=None,
+        )
+    )
+
+    for index in range(2):
+        # Give the existing TIs a state to make sure we don't change them
+        ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, 
state=TaskInstanceState.SUCCESS)
+        session.add(ti)
+    session.flush()
+
+    indices = (
+        session.query(TaskInstance.map_index, TaskInstance.state)
+        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, 
run_id=dr.run_id)
+        .order_by(TaskInstance.map_index)
+        .all()
+    )
+    # Make sure we have the faulty state in the database
+    assert indices == [(-1, None), (0, "success"), (1, "success")]
+
+    mapped.expand_mapped_task(dr.run_id, session=session)
+
+    indices = (
+        session.query(TaskInstance.map_index, TaskInstance.state)
+        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, 
run_id=dr.run_id)
+        .order_by(TaskInstance.map_index)
+        .all()
+    )
+    # The -1 index should be cleaned up
+    assert indices == [(0, "success"), (1, "success")]
+
+
 def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session):
     with dag_maker(session=session):
         task1 = BaseOperator(task_id="op1")

Reply via email to