This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5f2628d36c Count mapped upstreams only if all are finished (#30641)
5f2628d36c is described below

commit 5f2628d36cb8481ee21bd79ac184fd8fdce3e47d
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Sun Apr 23 03:00:34 2023 +0800

    Count mapped upstreams only if all are finished (#30641)
    
    * Fix Pydantic TI handling in XComArg.resolve()
    
    * Count mapped upstreams only if all are finished
    
    An XComArg's get_task_map_length() should only return an integer when
    the *entire* task has finished. However, before this patch, it could
    attempt to count a mapped upstream even when some (or all!) of its
    expanded tis were still unfinished, causing its downstream to be
    expanded prematurely.
    
    This patch adds a check, run before the upstream results are counted,
    to ensure that all of the upstream's expanded tis have actually finished.
    
    * Use SQL IN to find unfinished TI instead
    
    This needs a special workaround for a NULL quirk in SQL.
---
 airflow/models/xcom_arg.py        | 22 ++++++++++++++++++++-
 tests/models/test_taskinstance.py | 41 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 62 insertions(+), 1 deletion(-)

diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index ce4fb58ffa..d8b42ba819 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -21,7 +21,7 @@ import contextlib
 import inspect
 from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, 
Union, overload
 
-from sqlalchemy import func
+from sqlalchemy import func, or_
 from sqlalchemy.orm import Session
 
 from airflow.exceptions import AirflowException, XComNotFound
@@ -33,6 +33,7 @@ from airflow.utils.edgemodifier import EdgeModifier
 from airflow.utils.mixins import ResolveMixin
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.setup_teardown import SetupTeardownContext
+from airflow.utils.state import State
 from airflow.utils.types import NOTSET, ArgNotSet
 from airflow.utils.xcom import XCOM_RETURN_KEY
 
@@ -309,11 +310,26 @@ class PlainXComArg(XComArg):
         return super().zip(*others, fillvalue=fillvalue)
 
     def get_task_map_length(self, run_id: str, *, session: Session) -> int | 
None:
+        from airflow.models.taskinstance import TaskInstance
         from airflow.models.taskmap import TaskMap
         from airflow.models.xcom import XCom
 
         task = self.operator
         if isinstance(task, MappedOperator):
+            unfinished_ti_count_query = 
session.query(func.count(TaskInstance.map_index)).filter(
+                TaskInstance.dag_id == task.dag_id,
+                TaskInstance.run_id == run_id,
+                TaskInstance.task_id == task.task_id,
+                # Special NULL treatment is needed because 'state' can be NULL:
+                # "state IN (...)" never matches a NULL state, since comparing
+                # NULL to anything yields NULL (not true) in SQL.
+                or_(
+                    TaskInstance.state.is_(None),
+                    TaskInstance.state.in_(s.value for s in State.unfinished 
if s is not None),
+                ),
+            )
+            if unfinished_ti_count_query.scalar():
+                return None  # Not all of the expanded tis are done yet.
             query = session.query(func.count(XCom.map_index)).filter(
                 XCom.dag_id == task.dag_id,
                 XCom.run_id == run_id,
@@ -332,7 +348,11 @@ class PlainXComArg(XComArg):
 
     @provide_session
     def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
+        from airflow.models.taskinstance import TaskInstance
+
         ti = context["ti"]
+        assert isinstance(ti, TaskInstance), "Wait for AIP-44 implementation 
to complete"
+
         task_id = self.operator.task_id
         map_indexes = ti.get_relevant_upstream_map_indexes(
             self.operator,
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 098a0acd04..f5ad9d30ee 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -3968,3 +3968,44 @@ def test_mapped_task_expands_in_mini_scheduler_if_upstreams_are_done(dag_maker,
         middle_ti = dr.get_task_instance(task_id="middle_task", map_index=i)
         assert middle_ti.state == State.SCHEDULED
     assert "3 downstream tasks scheduled from follow-on schedule" in 
caplog.text
+
+
+def 
test_mini_scheduler_not_skip_mapped_downstream_until_all_upstreams_finish(dag_maker,
 session):
+    with dag_maker(session=session):
+
+        @task
+        def generate() -> list[list[int]]:
+            return []
+
+        @task
+        def a_sum(numbers: list[int]) -> int:
+            return sum(numbers)
+
+        @task
+        def b_double(summed: int) -> int:
+            return summed * 2
+
+        @task
+        def c_gather(result) -> None:
+            pass
+
+        static = EmptyOperator(task_id="static")
+
+        summed = a_sum.expand(numbers=generate())
+        doubled = b_double.expand(summed=summed)
+        static >> c_gather(doubled)
+
+    dr: DagRun = dag_maker.create_dagrun()
+    tis = {(ti.task_id, ti.map_index): ti for ti in dr.task_instances}
+
+    static_ti = tis[("static", -1)]
+    static_ti.run(session=session)
+    static_ti.schedule_downstream_tasks(session=session)
+    # No tasks should be skipped yet!
+    assert not dr.get_task_instances([TaskInstanceState.SKIPPED], 
session=session)
+
+    generate_ti = tis[("generate", -1)]
+    generate_ti.run(session=session)
+    generate_ti.schedule_downstream_tasks(session=session)
+    # Now downstreams can be skipped.
+    assert dr.get_task_instances([TaskInstanceState.SKIPPED], session=session)

Reply via email to