This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5f2628d36c Count mapped upstreams only if all are finished (#30641)
5f2628d36c is described below
commit 5f2628d36cb8481ee21bd79ac184fd8fdce3e47d
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Sun Apr 23 03:00:34 2023 +0800
Count mapped upstreams only if all are finished (#30641)
* Fix Pydantic TI handling in XComArg.resolve()
* Count mapped upstreams only if all are finished
An XComArg's get_task_map_length() should only return an integer when
the *entire* task has finished. However, before this patch, it may
attempt to count a mapped upstream even when some (or all!) of its
expanded tis are still unfinished, resulting in its downstream being
expanded prematurely.
This patch adds an additional check before we count upstream results to
ensure all the upstreams are actually finished.
* Use SQL IN to find unfinished TI instead
This needs a special workaround for a NULL quirk in SQL.
---
airflow/models/xcom_arg.py | 22 ++++++++++++++++++++-
tests/models/test_taskinstance.py | 41 +++++++++++++++++++++++++++++++++++++++
2 files changed, 62 insertions(+), 1 deletion(-)
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index ce4fb58ffa..d8b42ba819 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -21,7 +21,7 @@ import contextlib
import inspect
from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence,
Union, overload
-from sqlalchemy import func
+from sqlalchemy import func, or_
from sqlalchemy.orm import Session
from airflow.exceptions import AirflowException, XComNotFound
@@ -33,6 +33,7 @@ from airflow.utils.edgemodifier import EdgeModifier
from airflow.utils.mixins import ResolveMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.setup_teardown import SetupTeardownContext
+from airflow.utils.state import State
from airflow.utils.types import NOTSET, ArgNotSet
from airflow.utils.xcom import XCOM_RETURN_KEY
@@ -309,11 +310,26 @@ class PlainXComArg(XComArg):
return super().zip(*others, fillvalue=fillvalue)
def get_task_map_length(self, run_id: str, *, session: Session) -> int |
None:
+ from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XCom
task = self.operator
if isinstance(task, MappedOperator):
+ unfinished_ti_count_query =
session.query(func.count(TaskInstance.map_index)).filter(
+ TaskInstance.dag_id == task.dag_id,
+ TaskInstance.run_id == run_id,
+ TaskInstance.task_id == task.task_id,
+ # Special NULL treatment is needed because 'state' can be NULL.
+ # The "IN" part would produce "NULL NOT IN ..." and eventually
+ # "NULL = NULL", which is a big no-no in SQL.
+ or_(
+ TaskInstance.state.is_(None),
+ TaskInstance.state.in_(s.value for s in State.unfinished
if s is not None),
+ ),
+ )
+ if unfinished_ti_count_query.scalar():
+ return None # Not all of the expanded tis are done yet.
query = session.query(func.count(XCom.map_index)).filter(
XCom.dag_id == task.dag_id,
XCom.run_id == run_id,
@@ -332,7 +348,11 @@ class PlainXComArg(XComArg):
@provide_session
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
+ from airflow.models.taskinstance import TaskInstance
+
ti = context["ti"]
+ assert isinstance(ti, TaskInstance), "Wait for AIP-44 implementation
to complete"
+
task_id = self.operator.task_id
map_indexes = ti.get_relevant_upstream_map_indexes(
self.operator,
diff --git a/tests/models/test_taskinstance.py
b/tests/models/test_taskinstance.py
index 098a0acd04..f5ad9d30ee 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -3968,3 +3968,44 @@ def
test_mapped_task_expands_in_mini_scheduler_if_upstreams_are_done(dag_maker,
middle_ti = dr.get_task_instance(task_id="middle_task", map_index=i)
assert middle_ti.state == State.SCHEDULED
assert "3 downstream tasks scheduled from follow-on schedule" in
caplog.text
+
+
+def
test_mini_scheduler_not_skip_mapped_downstream_until_all_upstreams_finish(dag_maker,
session):
+ with dag_maker(session=session):
+
+ @task
+ def generate() -> list[list[int]]:
+ return []
+
+ @task
+ def a_sum(numbers: list[int]) -> int:
+ return sum(numbers)
+
+ @task
+ def b_double(summed: int) -> int:
+ return summed * 2
+
+ @task
+ def c_gather(result) -> None:
+ pass
+
+ static = EmptyOperator(task_id="static")
+
+ summed = a_sum.expand(numbers=generate())
+ doubled = b_double.expand(summed=summed)
+ static >> c_gather(doubled)
+
+ dr: DagRun = dag_maker.create_dagrun()
+ tis = {(ti.task_id, ti.map_index): ti for ti in dr.task_instances}
+
+ static_ti = tis[("static", -1)]
+ static_ti.run(session=session)
+ static_ti.schedule_downstream_tasks(session=session)
+ # No tasks should be skipped yet!
+ assert not dr.get_task_instances([TaskInstanceState.SKIPPED],
session=session)
+
+ generate_ti = tis[("generate", -1)]
+ generate_ti.run(session=session)
+ generate_ti.schedule_downstream_tasks(session=session)
+ # Now downstreams can be skipped.
+ assert dr.get_task_instances([TaskInstanceState.SKIPPED], session=session)