ephraimbuddy commented on code in PR #25788:
URL: https://github.com/apache/airflow/pull/25788#discussion_r949911201
##########
airflow/models/dagrun.py:
##########
@@ -1081,44 +1058,49 @@ def _create_task_instances(
# TODO[HA]: We probably need to savepoint this so we can keep the
transaction alive.
session.rollback()
- def _revise_mapped_task_indexes(
- self,
- tis: Iterable[TI],
- *,
- session: Session,
- ) -> Dict["MappedOperator", Sequence[int]]:
- """Check if the length of the mapped task instances changed at runtime
and find the missing indexes.
-
- :param tis: Task instances to check
- :param session: The session to use
- """
- from airflow.models.mappedoperator import MappedOperator
+ def _revise_mapped_task_indexes(self, task, session: Session):
+ """Check if task increased or reduced in length and handle
appropriately"""
+ from airflow.models.taskinstance import TaskInstance
+ from airflow.settings import task_instance_mutation_hook
- existing_indexes: Dict[MappedOperator, List[int]] = defaultdict(list)
- new_indexes: Dict[MappedOperator, Sequence[int]] = defaultdict(list)
- for ti in tis:
- task = ti.task
- if not isinstance(task, MappedOperator):
- continue
- # skip unexpanded tasks and also tasks that expands with literal
arguments
- if ti.map_index < 0 or task.parse_time_mapped_ti_count:
- continue
- existing_indexes[task].append(ti.map_index)
- task.run_time_mapped_ti_count.cache_clear() # type:
ignore[attr-defined]
- new_length = task.run_time_mapped_ti_count(self.run_id,
session=session) or 0
-
- if ti.map_index >= new_length:
- self.log.debug(
- "Removing task '%s' as the map_index is longer than the
resolved mapping list (%d)",
- ti,
- new_length,
- )
- ti.state = State.REMOVED
- new_indexes[task] = range(new_length)
- missing_indexes: Dict[MappedOperator, Sequence[int]] =
defaultdict(list)
- for k, v in existing_indexes.items():
- missing_indexes.update({k:
list(set(new_indexes[k]).difference(v))})
- return missing_indexes
+ task.run_time_mapped_ti_count.cache_clear()
+ total_length = (
+ task.parse_time_mapped_ti_count
+ or task.run_time_mapped_ti_count(self.run_id, session=session)
+ or 0
+ )
+ existing_tis = (
+ session.query(TaskInstance)
+ .filter(
+ TaskInstance.dag_id == self.dag_id,
+ TaskInstance.task_id == task.task_id,
+ TaskInstance.run_id == self.run_id,
+ )
+ .all()
+ )
+ existing_indexes = [i.map_index for i in existing_tis]
+ missing_tis =
set(range(total_length)).difference(set(existing_indexes))
+ removed_tis = set(existing_indexes).difference(range(total_length))
+ created_tis = []
+
+ if missing_tis:
+ for index in missing_tis:
+ ti = TaskInstance(task, run_id=self.run_id, map_index=index,
state=None)
+ self.log.debug("Expanding TIs upserted %s", ti)
+ task_instance_mutation_hook(ti)
+ ti = session.merge(ti)
+ ti.refresh_from_task(task)
+ session.flush()
+ created_tis.append(ti)
+ elif removed_tis:
+ session.query(TaskInstance).filter(
+ TaskInstance.dag_id == self.dag_id,
+ TaskInstance.task_id == task.task_id,
+ TaskInstance.run_id == self.run_id,
+ TaskInstance.map_index.in_(removed_tis),
+ ).update({TaskInstance.state: TaskInstanceState.REMOVED})
+ session.flush()
Review Comment:
If we have missing `tis`, then no `tis` were removed, so the `elif` captures it
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]