This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d3028ada36 Fix reducing mapped length of a mapped task at runtime
after a clear (#25531)
d3028ada36 is described below
commit d3028ada36a43a0d549d22c280fb16d868b90b6d
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Fri Aug 5 10:30:51 2022 +0100
Fix reducing mapped length of a mapped task at runtime after a clear
(#25531)
The previous fix on task immutability after a run did not cover the case where
the task was removed at runtime when the literal is dynamic. This PR addresses it.
---
airflow/models/dagrun.py | 12 +++++++--
tests/models/test_dagrun.py | 64 +++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 74 insertions(+), 2 deletions(-)
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index fde2aa7685..e902f687a0 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -736,7 +736,7 @@ class DagRun(Base, LoggingMixin):
yield ti
tis = list(_filter_tis_and_exclude_removed(self.get_dag(), tis))
- missing_indexes = self._find_missing_task_indexes(tis, session=session)
+ missing_indexes = self._revise_mapped_task_indexes(tis,
session=session)
if missing_indexes:
self.verify_integrity(missing_indexes=missing_indexes,
session=session)
@@ -1158,7 +1158,7 @@ class DagRun(Base, LoggingMixin):
# TODO[HA]: We probably need to savepoint this so we can keep the
transaction alive.
session.rollback()
- def _find_missing_task_indexes(
+ def _revise_mapped_task_indexes(
self,
tis: Iterable[TI],
*,
@@ -1183,6 +1183,14 @@ class DagRun(Base, LoggingMixin):
existing_indexes[task].append(ti.map_index)
task.run_time_mapped_ti_count.cache_clear() # type:
ignore[attr-defined]
new_length = task.run_time_mapped_ti_count(self.run_id,
session=session) or 0
+
+ if ti.map_index >= new_length:
+ self.log.debug(
+ "Removing task '%s' as the map_index is longer than the
resolved mapping list (%d)",
+ ti,
+ new_length,
+ )
+ ti.state = State.REMOVED
new_indexes[task] = range(new_length)
missing_indexes: Dict[MappedOperator, Sequence[int]] =
defaultdict(list)
for k, v in existing_indexes.items():
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 8110240d4c..7f1bf8dfaa 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -1227,6 +1227,70 @@ def
test_mapped_length_increase_at_runtime_adds_additional_tis(dag_maker, sessio
]
+def
test_mapped_literal_length_reduction_at_runtime_adds_removed_state(dag_maker,
session):
+ """
+ Test that when the length of mapped literal reduces at runtime, the
missing task instances
+ are marked as removed
+ """
+ from airflow.models import Variable
+
+ Variable.set(key='arg1', value=[1, 2, 3])
+
+ @task
+ def task_1():
+ return Variable.get('arg1', deserialize_json=True)
+
+ with dag_maker(session=session) as dag:
+
+ @task
+ def task_2(arg2):
+ ...
+
+ task_2.expand(arg2=task_1())
+
+ dr = dag_maker.create_dagrun()
+ ti = dr.get_task_instance(task_id='task_1')
+ ti.run()
+ dr.task_instance_scheduling_decisions()
+ tis = dr.get_task_instances()
+ indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+ assert sorted(indices) == [
+ (0, State.NONE),
+ (1, State.NONE),
+ (2, State.NONE),
+ ]
+
+ # Now "clear" and "reduce" the length of literal
+ dag.clear()
+ Variable.set(key='arg1', value=[1, 2])
+
+ with dag:
+ task_2.expand(arg2=task_1()).operator
+
+ # At this point, we need to test that the change works on the serialized
+ # DAG (which is what the scheduler operates on)
+ serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+ dr.dag = serialized_dag
+
+ # Run the first task again to get the new lengths
+ ti = dr.get_task_instance(task_id='task_1')
+ task1 = dag.get_task('task_1')
+ ti.refresh_from_task(task1)
+ ti.run()
+
+ # this would be called by the localtask job
+ dr.task_instance_scheduling_decisions()
+ tis = dr.get_task_instances()
+
+ indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+ assert sorted(indices) == [
+ (0, State.NONE),
+ (1, State.NONE),
+ (2, TaskInstanceState.REMOVED),
+ ]
+
+
@pytest.mark.need_serialized_dag
def test_mapped_mixed__literal_not_expanded_at_create(dag_maker, session):
literal = [1, 2, 3, 4]