This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e05849de0b4 fix(task_instance): Ignore NotFullyPopulated if the task
might be triggered due to trigger rule setup (#57474)
e05849de0b4 is described below
commit e05849de0b452a3c92efee3e389c2bed1cd4fc95
Author: Wei Lee <[email protected]>
AuthorDate: Tue Nov 4 14:19:45 2025 +0800
fix(task_instance): Ignore NotFullyPopulated if the task might be triggered
due to trigger rule setup (#57474)
---
.../execution_api/routes/task_instances.py | 45 ++++---
.../versions/head/test_task_instances.py | 135 ++++++++++++++++++++-
2 files changed, 162 insertions(+), 18 deletions(-)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index c0d40b3a77c..dab8ac18429 100644
---
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -66,14 +66,14 @@ from airflow.models.trigger import Trigger
from airflow.models.xcom import XComModel
from airflow.sdk.definitions._internal.expandinput import NotFullyPopulated
from airflow.sdk.definitions.asset import Asset, AssetUniqueKey
+from airflow.serialization.serialized_objects import SerializedDAG
+from airflow.task.trigger_rule import TriggerRule
from airflow.utils.state import DagRunState, TaskInstanceState
if TYPE_CHECKING:
from sqlalchemy.sql.dml import Update
from airflow.models.expandinput import SchedulerExpandInput
- from airflow.models.mappedoperator import MappedOperator
- from airflow.serialization.serialized_objects import SerializedBaseOperator
router = VersionedAPIRouter()
@@ -254,7 +254,11 @@ def ti_run(
if dag := dag_bag.get_dag_for_run(dag_run=dr, session=session):
upstream_map_indexes = dict(
- _get_upstream_map_indexes(dag.get_task(ti.task_id),
ti.map_index, ti.run_id, session=session)
+ _get_upstream_map_indexes(
+ serialized_dag=dag,
+ ti=ti,
+ session=session,
+ )
)
else:
upstream_map_indexes = None
@@ -285,30 +289,41 @@ def ti_run(
def _get_upstream_map_indexes(
- task: MappedOperator | SerializedBaseOperator, ti_map_index: int, run_id:
str, session: SessionDep
+ *,
+ serialized_dag: SerializedDAG,
+ ti: TI,
+ session: SessionDep,
) -> Iterator[tuple[str, int | list[int] | None]]:
- task_mapped_group = task.get_closest_mapped_task_group()
+ task = serialized_dag.get_task(ti.task_id)
for upstream_task in task.upstream_list:
- upstream_mapped_group = upstream_task.get_closest_mapped_task_group()
map_indexes: int | list[int] | None
- if upstream_mapped_group is None:
+ if (upstream_mapped_group :=
upstream_task.get_closest_mapped_task_group()) is None:
# regular tasks or non-mapped task groups
map_indexes = None
- elif task_mapped_group == upstream_mapped_group:
+ elif task.get_closest_mapped_task_group() == upstream_mapped_group:
# tasks in the same mapped task group hierarchy
- map_indexes = ti_map_index
+ map_indexes = ti.map_index
else:
# tasks not in the same mapped task group
# the upstream mapped task group should combine the return xcom as
a list and return it
- mapped_ti_count: int
+ mapped_ti_count: int | None = None
+
try:
- # for cases that does not need to resolve xcom
+ # First try: without resolving XCom
mapped_ti_count =
upstream_mapped_group.get_parse_time_mapped_ti_count()
except NotFullyPopulated:
- # for cases that needs to resolve xcom to get the correct count
- mapped_ti_count = cast(
- "SchedulerExpandInput", upstream_mapped_group._expand_input
- ).get_total_map_length(run_id, session=session)
+ # Second try: resolve XCom for correct count
+ try:
+ expand_input = cast("SchedulerExpandInput",
upstream_mapped_group._expand_input)
+ mapped_ti_count =
expand_input.get_total_map_length(ti.run_id, session=session)
+ except NotFullyPopulated:
+ # For these trigger rules, unresolved map indexes are
acceptable.
+ # The success of the upstream task is not the main reason
for triggering the current task.
+ # Therefore, whether the upstream task is fully populated
can be ignored.
+ if task.trigger_rule != TriggerRule.ALL_SUCCESS:
+ mapped_ti_count = None
+
+ # Compute map indexes if we have a valid count
map_indexes = list(range(mapped_ti_count)) if mapped_ti_count is
not None else None
yield upstream_task.task_id, map_indexes
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index fe8dedba200..3efedecdbe9 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -18,6 +18,7 @@
from __future__ import annotations
from datetime import datetime
+from typing import TYPE_CHECKING
from unittest import mock
from uuid import uuid4
@@ -25,16 +26,18 @@ import pytest
import uuid6
from sqlalchemy import select, update
from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm import Session
from airflow._shared.timezones import timezone
from airflow.api_fastapi.auth.tokens import JWTValidator
from airflow.api_fastapi.execution_api.app import lifespan
+from airflow.exceptions import AirflowSkipException
from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger
from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent,
AssetModel
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.providers.standard.operators.empty import EmptyOperator
-from airflow.sdk import Asset, TaskGroup, task, task_group
+from airflow.sdk import Asset, TaskGroup, TriggerRule, task, task_group
from airflow.utils.state import DagRunState, State, TaskInstanceState,
TerminalTIState
from tests_common.test_utils.db import (
@@ -45,6 +48,11 @@ from tests_common.test_utils.db import (
clear_rendered_ti_fields,
)
+if TYPE_CHECKING:
+ from airflow.sdk.api.client import Client
+
+ from tests_common.pytest_plugin import DagMaker
+
pytestmark = pytest.mark.db_test
@@ -377,7 +385,7 @@ class TestTIRunState:
f"but got {upstream_map_indexes}"
)
- def test_dynamic_task_mapping_with_xcom(self, client, dag_maker,
create_task_instance, session, run_task):
+ def test_dynamic_task_mapping_with_xcom(self, client: Client, dag_maker:
DagMaker, session: Session):
"""
Test that the Task Instance upstream_map_indexes is correctly fetched
when to running the Task Instances with xcom
"""
@@ -409,7 +417,6 @@ class TestTIRunState:
# Simulate task_1 execution to produce TaskMap.
(ti_1,) = decision.schedulable_tis
- # ti_1 = dr.get_task_instance(task_id="task_1")
ti_1.state = TaskInstanceState.SUCCESS
session.add(TaskMap.from_task_instance_xcom(ti_1, [0, 1]))
session.flush()
@@ -436,6 +443,128 @@ class TestTIRunState:
)
assert response.json()["upstream_map_indexes"] == {"tg.task_2": [0, 1,
2, 3, 4, 5]}
+ def test_dynamic_task_mapping_with_all_success_trigger_rule(self,
dag_maker: DagMaker, session: Session):
+ """
+ Test that the Task Instance upstream_map_indexes is not populated but
+ the downstream task should not be run.
+ """
+
+ with dag_maker(session=session, serialized=True):
+
+ @task
+ def task_1():
+ raise AirflowSkipException()
+
+ @task_group
+ def tg(x):
+ @task
+ def task_2():
+ raise AirflowSkipException()
+
+ task_2()
+
+ @task(trigger_rule=TriggerRule.ALL_SUCCESS)
+ def task_3():
+ pass
+
+ @task
+ def task_4():
+ pass
+
+ tg.expand(x=task_1()) >> [task_3(), task_4()]
+
+ dr = dag_maker.create_dagrun()
+
+ decision = dr.task_instance_scheduling_decisions(session=session)
+
+ # Simulate task_1 skipped
+ (ti_1,) = decision.schedulable_tis
+ ti_1.state = TaskInstanceState.SKIPPED
+ session.flush()
+
+ # Now task_2 in mapped task group is not expanded and also skipped.
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ for ti in decision.schedulable_tis:
+ ti.state = TaskInstanceState.SKIPPED
+ session.flush()
+
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ assert decision.schedulable_tis == []
+
+ @pytest.mark.parametrize(
+ "trigger_rule",
+ [
+ TriggerRule.ALL_DONE,
+ TriggerRule.ALL_DONE_SETUP_SUCCESS,
+ TriggerRule.NONE_FAILED,
+ TriggerRule.ALL_SKIPPED,
+ ],
+ )
+ def test_dynamic_task_mapping_with_non_all_success_trigger_rule(
+ self, client: Client, dag_maker: DagMaker, session: Session,
trigger_rule: TriggerRule
+ ):
+ """
+ Test that the Task Instance upstream_map_indexes is not populated but
+ the downstream task should still be run due to trigger rule.
+ """
+
+ with dag_maker(session=session, serialized=True):
+
+ @task
+ def task_1():
+ raise AirflowSkipException()
+
+ @task_group
+ def tg(x):
+ @task
+ def task_2():
+ raise AirflowSkipException()
+
+ task_2()
+
+ @task(trigger_rule=trigger_rule)
+ def task_3():
+ pass
+
+ @task
+ def task_4():
+ pass
+
+ tg.expand(x=task_1()) >> [task_3(), task_4()]
+
+ dr = dag_maker.create_dagrun()
+
+ decision = dr.task_instance_scheduling_decisions(session=session)
+
+ # Simulate task_1 skipped
+ (ti_1,) = decision.schedulable_tis
+ ti_1.state = TaskInstanceState.SKIPPED
+ session.flush()
+
+ # Now task_2 in mapped task group is not expanded and also skipped.
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ for ti in decision.schedulable_tis:
+ ti.state = TaskInstanceState.SKIPPED
+ session.flush()
+
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ # only task_3 is schedulable
+ (task_3_ti,) = decision.schedulable_tis
+ assert task_3_ti.task_id == "task_3"
+ task_3_ti.set_state(State.QUEUED)
+
+ response = client.patch(
+ f"/execution/task-instances/{task_3_ti.id}/run",
+ json={
+ "state": "running",
+ "hostname": "random-hostname",
+ "unixname": "random-unixname",
+ "pid": 100,
+ "start_date": "2024-09-30T12:00:00Z",
+ },
+ )
+ assert response.json()["upstream_map_indexes"] == {"tg.task_2": None}
+
def test_next_kwargs_still_encoded(self, client, session,
create_task_instance, time_machine):
instant_str = "2024-09-30T12:00:00Z"
instant = timezone.parse(instant_str)