This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e05849de0b4 fix(task_instance): Ignore NotFullyPopulated if the task might be triggered due to trigger rule setup (#57474)
e05849de0b4 is described below

commit e05849de0b452a3c92efee3e389c2bed1cd4fc95
Author: Wei Lee <[email protected]>
AuthorDate: Tue Nov 4 14:19:45 2025 +0800

    fix(task_instance): Ignore NotFullyPopulated if the task might be triggered due to trigger rule setup (#57474)
---
 .../execution_api/routes/task_instances.py         |  45 ++++---
 .../versions/head/test_task_instances.py           | 135 ++++++++++++++++++++-
 2 files changed, 162 insertions(+), 18 deletions(-)
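
The DAG shape this change targets is one where an upstream mapped task group
never expands (its expand input is skipped), yet a downstream task may still run
because of its trigger rule. A minimal sketch of that shape follows; it is not
part of this commit, the DAG and task names are illustrative, and it assumes the
Airflow 3 Task SDK exports (dag, task, task_group, TriggerRule) that the new
tests below also use:

    from airflow.exceptions import AirflowSkipException
    from airflow.sdk import TriggerRule, dag, task, task_group


    @dag
    def unexpanded_mapped_group_example():
        @task
        def task_1():
            # Skipping here means the mapped group below never expands,
            # so its task-instance count cannot be fully populated.
            raise AirflowSkipException()

        @task_group
        def tg(x):
            @task
            def task_2():
                pass

            task_2()

        @task(trigger_rule=TriggerRule.ALL_DONE)
        def task_3():
            # May still run even though tg never expanded; its
            # upstream_map_indexes entry for "tg.task_2" is simply None.
            pass

        tg.expand(x=task_1()) >> task_3()


    unexpanded_mapped_group_example()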

diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index c0d40b3a77c..dab8ac18429 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -66,14 +66,14 @@ from airflow.models.trigger import Trigger
 from airflow.models.xcom import XComModel
 from airflow.sdk.definitions._internal.expandinput import NotFullyPopulated
 from airflow.sdk.definitions.asset import Asset, AssetUniqueKey
+from airflow.serialization.serialized_objects import SerializedDAG
+from airflow.task.trigger_rule import TriggerRule
 from airflow.utils.state import DagRunState, TaskInstanceState
 
 if TYPE_CHECKING:
     from sqlalchemy.sql.dml import Update
 
     from airflow.models.expandinput import SchedulerExpandInput
-    from airflow.models.mappedoperator import MappedOperator
-    from airflow.serialization.serialized_objects import SerializedBaseOperator
 
 router = VersionedAPIRouter()
 
@@ -254,7 +254,11 @@ def ti_run(
 
         if dag := dag_bag.get_dag_for_run(dag_run=dr, session=session):
             upstream_map_indexes = dict(
-                _get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index, ti.run_id, session=session)
+                _get_upstream_map_indexes(
+                    serialized_dag=dag,
+                    ti=ti,
+                    session=session,
+                )
             )
         else:
             upstream_map_indexes = None
@@ -285,30 +289,41 @@ def ti_run(
 
 
 def _get_upstream_map_indexes(
-    task: MappedOperator | SerializedBaseOperator, ti_map_index: int, run_id: str, session: SessionDep
+    *,
+    serialized_dag: SerializedDAG,
+    ti: TI,
+    session: SessionDep,
 ) -> Iterator[tuple[str, int | list[int] | None]]:
-    task_mapped_group = task.get_closest_mapped_task_group()
+    task = serialized_dag.get_task(ti.task_id)
     for upstream_task in task.upstream_list:
-        upstream_mapped_group = upstream_task.get_closest_mapped_task_group()
         map_indexes: int | list[int] | None
-        if upstream_mapped_group is None:
+        if (upstream_mapped_group := upstream_task.get_closest_mapped_task_group()) is None:
             # regular tasks or non-mapped task groups
             map_indexes = None
-        elif task_mapped_group == upstream_mapped_group:
+        elif task.get_closest_mapped_task_group() == upstream_mapped_group:
             # tasks in the same mapped task group hierarchy
-            map_indexes = ti_map_index
+            map_indexes = ti.map_index
         else:
             # tasks not in the same mapped task group
             # the upstream mapped task group should combine the return xcom as a list and return it
-            mapped_ti_count: int
+            mapped_ti_count: int | None = None
+
             try:
-                # for cases that does not need to resolve xcom
+                # First try: without resolving XCom
                 mapped_ti_count = upstream_mapped_group.get_parse_time_mapped_ti_count()
             except NotFullyPopulated:
-                # for cases that needs to resolve xcom to get the correct count
-                mapped_ti_count = cast(
-                    "SchedulerExpandInput", upstream_mapped_group._expand_input
-                ).get_total_map_length(run_id, session=session)
+                # Second try: resolve XCom for correct count
+                try:
+                    expand_input = cast("SchedulerExpandInput", upstream_mapped_group._expand_input)
+                    mapped_ti_count = expand_input.get_total_map_length(ti.run_id, session=session)
+                except NotFullyPopulated:
+                    # For these trigger rules, unresolved map indexes are acceptable.
+                    # The success of the upstream task is not the main reason for triggering the current task.
+                    # Therefore, whether the upstream task is fully populated can be ignored.
+                    if task.trigger_rule != TriggerRule.ALL_SUCCESS:
+                        mapped_ti_count = None
+
+            # Compute map indexes if we have a valid count
             map_indexes = list(range(mapped_ti_count)) if mapped_ti_count is not None else None
 
         yield upstream_task.task_id, map_indexes
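
Read as plain Python rather than in diff form, the new fallback in
_get_upstream_map_indexes amounts to the following condensed sketch; the helper
name and parameter list here are illustrative, and the real logic lives inline
in the route above:

    from airflow.sdk.definitions._internal.expandinput import NotFullyPopulated
    from airflow.task.trigger_rule import TriggerRule


    def _mapped_indexes_for_upstream_group(upstream_group, trigger_rule, run_id, session):
        mapped_ti_count = None
        try:
            # First try: parse-time count, available when the expand input is static.
            mapped_ti_count = upstream_group.get_parse_time_mapped_ti_count()
        except NotFullyPopulated:
            try:
                # Second try: resolve XCom at runtime to get the real map length.
                mapped_ti_count = upstream_group._expand_input.get_total_map_length(
                    run_id, session=session
                )
            except NotFullyPopulated:
                # Third case: the group never expanded (e.g. its input was skipped).
                # If the downstream trigger rule does not require every upstream
                # to succeed, unresolved indexes are acceptable and stay None.
                if trigger_rule != TriggerRule.ALL_SUCCESS:
                    mapped_ti_count = None
        return list(range(mapped_ti_count)) if mapped_ti_count is not None else None
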
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index fe8dedba200..3efedecdbe9 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 from datetime import datetime
+from typing import TYPE_CHECKING
 from unittest import mock
 from uuid import uuid4
 
@@ -25,16 +26,18 @@ import pytest
 import uuid6
 from sqlalchemy import select, update
 from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm import Session
 
 from airflow._shared.timezones import timezone
 from airflow.api_fastapi.auth.tokens import JWTValidator
 from airflow.api_fastapi.execution_api.app import lifespan
+from airflow.exceptions import AirflowSkipException
 from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger
 from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, AssetModel
 from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskinstancehistory import TaskInstanceHistory
 from airflow.providers.standard.operators.empty import EmptyOperator
-from airflow.sdk import Asset, TaskGroup, task, task_group
+from airflow.sdk import Asset, TaskGroup, TriggerRule, task, task_group
 from airflow.utils.state import DagRunState, State, TaskInstanceState, TerminalTIState
 
 from tests_common.test_utils.db import (
@@ -45,6 +48,11 @@ from tests_common.test_utils.db import (
     clear_rendered_ti_fields,
 )
 
+if TYPE_CHECKING:
+    from airflow.sdk.api.client import Client
+
+    from tests_common.pytest_plugin import DagMaker
+
 pytestmark = pytest.mark.db_test
 
 
@@ -377,7 +385,7 @@ class TestTIRunState:
                 f"but got {upstream_map_indexes}"
             )
 
-    def test_dynamic_task_mapping_with_xcom(self, client, dag_maker, create_task_instance, session, run_task):
+    def test_dynamic_task_mapping_with_xcom(self, client: Client, dag_maker: DagMaker, session: Session):
         """
         Test that the Task Instance upstream_map_indexes is correctly fetched when running the Task Instances with xcom
         """
@@ -409,7 +417,6 @@ class TestTIRunState:
 
         # Simulate task_1 execution to produce TaskMap.
         (ti_1,) = decision.schedulable_tis
-        # ti_1 = dr.get_task_instance(task_id="task_1")
         ti_1.state = TaskInstanceState.SUCCESS
         session.add(TaskMap.from_task_instance_xcom(ti_1, [0, 1]))
         session.flush()
@@ -436,6 +443,128 @@ class TestTIRunState:
         )
         assert response.json()["upstream_map_indexes"] == {"tg.task_2": [0, 1, 2, 3, 4, 5]}
 
+    def test_dynamic_task_mapping_with_all_success_trigger_rule(self, dag_maker: DagMaker, session: Session):
+        """
+        Test that when the Task Instance upstream_map_indexes is not populated,
+        the downstream task with the all_success trigger rule should not be run.
+        """
+
+        with dag_maker(session=session, serialized=True):
+
+            @task
+            def task_1():
+                raise AirflowSkipException()
+
+            @task_group
+            def tg(x):
+                @task
+                def task_2():
+                    raise AirflowSkipException()
+
+                task_2()
+
+            @task(trigger_rule=TriggerRule.ALL_SUCCESS)
+            def task_3():
+                pass
+
+            @task
+            def task_4():
+                pass
+
+            tg.expand(x=task_1()) >> [task_3(), task_4()]
+
+        dr = dag_maker.create_dagrun()
+
+        decision = dr.task_instance_scheduling_decisions(session=session)
+
+        # Simulate task_1 skipped
+        (ti_1,) = decision.schedulable_tis
+        ti_1.state = TaskInstanceState.SKIPPED
+        session.flush()
+
+        # Now task_2 in mapped task group is not expanded and also skipped.
+        decision = dr.task_instance_scheduling_decisions(session=session)
+        for ti in decision.schedulable_tis:
+            ti.state = TaskInstanceState.SKIPPED
+        session.flush()
+
+        decision = dr.task_instance_scheduling_decisions(session=session)
+        assert decision.schedulable_tis == []
+
+    @pytest.mark.parametrize(
+        "trigger_rule",
+        [
+            TriggerRule.ALL_DONE,
+            TriggerRule.ALL_DONE_SETUP_SUCCESS,
+            TriggerRule.NONE_FAILED,
+            TriggerRule.ALL_SKIPPED,
+        ],
+    )
+    def test_dynamic_task_mapping_with_non_all_success_trigger_rule(
+        self, client: Client, dag_maker: DagMaker, session: Session, trigger_rule: TriggerRule
+    ):
+        """
+        Test that when the Task Instance upstream_map_indexes is not populated,
+        the downstream task should still be run due to its trigger rule.
+        """
+
+        with dag_maker(session=session, serialized=True):
+
+            @task
+            def task_1():
+                raise AirflowSkipException()
+
+            @task_group
+            def tg(x):
+                @task
+                def task_2():
+                    raise AirflowSkipException()
+
+                task_2()
+
+            @task(trigger_rule=trigger_rule)
+            def task_3():
+                pass
+
+            @task
+            def task_4():
+                pass
+
+            tg.expand(x=task_1()) >> [task_3(), task_4()]
+
+        dr = dag_maker.create_dagrun()
+
+        decision = dr.task_instance_scheduling_decisions(session=session)
+
+        # Simulate task_1 skipped
+        (ti_1,) = decision.schedulable_tis
+        ti_1.state = TaskInstanceState.SKIPPED
+        session.flush()
+
+        # Now task_2 in mapped task group is not expanded and also skipped.
+        decision = dr.task_instance_scheduling_decisions(session=session)
+        for ti in decision.schedulable_tis:
+            ti.state = TaskInstanceState.SKIPPED
+        session.flush()
+
+        decision = dr.task_instance_scheduling_decisions(session=session)
+        # only task_3 is schedulable
+        (task_3_ti,) = decision.schedulable_tis
+        assert task_3_ti.task_id == "task_3"
+        task_3_ti.set_state(State.QUEUED)
+
+        response = client.patch(
+            f"/execution/task-instances/{task_3_ti.id}/run",
+            json={
+                "state": "running",
+                "hostname": "random-hostname",
+                "unixname": "random-unixname",
+                "pid": 100,
+                "start_date": "2024-09-30T12:00:00Z",
+            },
+        )
+        assert response.json()["upstream_map_indexes"] == {"tg.task_2": None}
+
     def test_next_kwargs_still_encoded(self, client, session, create_task_instance, time_machine):
         instant_str = "2024-09-30T12:00:00Z"
         instant = timezone.parse(instant_str)
