This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch v3-1-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 8a3cd09aefac99d0edc790e7aa7bbd1c392095fe Author: Zhen-Lun (Kevin) Hong <[email protected]> AuthorDate: Fri Oct 3 15:56:37 2025 +0800 fix: allow mapped tasks to accept zero-length inputs on rerun (#56162) * fix: allow mapped tasks to accept zero-length inputs on rerun * fix: add test for rerun args of different length * chore: revise comments to align with the changes * chore: add comments before the task state check * fix: replace legacy query syntax (cherry picked from commit 6d3e841d0ceb2bfd7d5d65d6ec1a455d59436b91) --- airflow-core/src/airflow/models/dagrun.py | 7 +++- airflow-core/tests/unit/models/test_dagrun.py | 55 ++++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index ec7dff2ba71..562fdc5e9fc 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -1541,7 +1541,12 @@ class DagRun(Base, LoggingMixin): ) ) revised_map_index_task_ids.add(schedulable.task.task_id) - ready_tis.append(schedulable) + + # _revise_map_indexes_if_mapped might mark the current task as REMOVED + # after calculating mapped task length, so we need to re-check + # the task state to ensure it's still schedulable + if schedulable.state in SCHEDULEABLE_STATES: + ready_tis.append(schedulable) # Check if any ti changed state tis_filter = TI.filter_for_tis(old_states) diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index 73284f67356..47bf91d4687 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -43,7 +43,7 @@ from airflow.models.taskreschedule import TaskReschedule from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator, ShortCircuitOperator 
-from airflow.sdk import DAG, BaseOperator, setup, task, task_group, teardown +from airflow.sdk import DAG, BaseOperator, get_current_context, setup, task, task_group, teardown from airflow.sdk.definitions.deadline import AsyncCallback, DeadlineAlert, DeadlineReference from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG from airflow.stats import Stats @@ -2247,6 +2247,59 @@ def test_mapped_task_group_expands(dag_maker, session): } +@pytest.mark.parametrize("rerun_length", [0, 1, 2, 3]) +def test_mapped_task_rerun_with_different_length_of_args(session, dag_maker, rerun_length): + @task + def generate_mapping_args(): + context = get_current_context() + if context["ti"].try_number == 0: + args = [i for i in range(2)] + else: + args = [i for i in range(rerun_length)] + return args + + @task + def mapped_print_value(arg): + return arg + + with dag_maker(session=session): + args = generate_mapping_args() + mapped_print_value.expand(arg=args) + + # First Run + dr = dag_maker.create_dagrun() + dag_maker.run_ti("generate_mapping_args", dr) + + decision = dr.task_instance_scheduling_decisions(session=session) + for ti in decision.schedulable_tis: + dag_maker.run_ti(ti.task_id, dr, map_index=ti.map_index) + + clear_task_instances(dr.get_task_instances(), session=session) + + # Second Run + ti = dr.get_task_instance(task_id="generate_mapping_args", session=session) + ti.try_number += 1 + session.merge(ti) + dag_maker.run_ti("generate_mapping_args", dr) + + # Check if the new mapped task instances are correctly scheduled + decision = dr.task_instance_scheduling_decisions(session=session) + assert len(decision.schedulable_tis) == rerun_length + assert all([ti.task_id == "mapped_print_value" for ti in decision.schedulable_tis]) + + # Check if mapped task rerun successfully + for ti in decision.schedulable_tis: + dag_maker.run_ti(ti.task_id, dr, map_index=ti.map_index) + query = select(TI).where( + TI.dag_id == dr.dag_id, + TI.run_id == dr.run_id, 
TI.task_id == "mapped_print_value", + TI.state == TaskInstanceState.SUCCESS, + ) + success_tis = session.execute(query).all() + assert len(success_tis) == rerun_length + + def test_operator_mapped_task_group_receives_value(dag_maker, session): with dag_maker(session=session):
