uranusjr commented on code in PR #25788:
URL: https://github.com/apache/airflow/pull/25788#discussion_r954852616


##########
airflow/models/dagrun.py:
##########
@@ -1082,44 +1058,45 @@ def _create_task_instances(
             # TODO[HA]: We probably need to savepoint this so we can keep the 
transaction alive.
             session.rollback()
 
-    def _revise_mapped_task_indexes(
-        self,
-        tis: Iterable[TI],
-        *,
-        session: Session,
-    ) -> Dict["MappedOperator", Sequence[int]]:
-        """Check if the length of the mapped task instances changed at runtime 
and find the missing indexes.
+    def _revise_mapped_task_indexes(self, task, session: Session):
+        """Check if task increased or reduced in length and handle 
appropriately"""
+        from airflow.models.taskinstance import TaskInstance
+        from airflow.settings import task_instance_mutation_hook
 
-        :param tis: Task instances to check
-        :param session: The session to use
-        """
-        from airflow.models.mappedoperator import MappedOperator
+        task.run_time_mapped_ti_count.cache_clear()
+        total_length = (
+            task.parse_time_mapped_ti_count
+            or task.run_time_mapped_ti_count(self.run_id, session=session)
+            or 0
+        )
+        query = session.query(TaskInstance.map_index).filter(
+            TaskInstance.dag_id == self.dag_id,
+            TaskInstance.task_id == task.task_id,
+            TaskInstance.run_id == self.run_id,
+        )
+        existing_indexes = {i for (i,) in query}
+        missing_indexes = 
set(range(total_length)).difference(set(existing_indexes))
+        removed_indexes = set(existing_indexes).difference(range(total_length))

Review Comment:
   ```suggestion
           removed_indexes = existing_indexes.difference(range(total_length))
   ```



##########
airflow/models/dagrun.py:
##########
@@ -1082,44 +1058,45 @@ def _create_task_instances(
             # TODO[HA]: We probably need to savepoint this so we can keep the 
transaction alive.
             session.rollback()
 
-    def _revise_mapped_task_indexes(
-        self,
-        tis: Iterable[TI],
-        *,
-        session: Session,
-    ) -> Dict["MappedOperator", Sequence[int]]:
-        """Check if the length of the mapped task instances changed at runtime 
and find the missing indexes.
+    def _revise_mapped_task_indexes(self, task, session: Session):
+        """Check if task increased or reduced in length and handle 
appropriately"""
+        from airflow.models.taskinstance import TaskInstance
+        from airflow.settings import task_instance_mutation_hook
 
-        :param tis: Task instances to check
-        :param session: The session to use
-        """
-        from airflow.models.mappedoperator import MappedOperator
+        task.run_time_mapped_ti_count.cache_clear()
+        total_length = (
+            task.parse_time_mapped_ti_count
+            or task.run_time_mapped_ti_count(self.run_id, session=session)
+            or 0
+        )
+        query = session.query(TaskInstance.map_index).filter(
+            TaskInstance.dag_id == self.dag_id,
+            TaskInstance.task_id == task.task_id,
+            TaskInstance.run_id == self.run_id,
+        )
+        existing_indexes = {i for (i,) in query}
+        missing_indexes = 
set(range(total_length)).difference(set(existing_indexes))
+        removed_indexes = set(existing_indexes).difference(range(total_length))
+        created_indexes = []

Review Comment:
   This one should still be `created_tis`—it holds task instances 😄



##########
airflow/models/dagrun.py:
##########
@@ -1082,44 +1058,45 @@ def _create_task_instances(
             # TODO[HA]: We probably need to savepoint this so we can keep the 
transaction alive.
             session.rollback()
 
-    def _revise_mapped_task_indexes(
-        self,
-        tis: Iterable[TI],
-        *,
-        session: Session,
-    ) -> Dict["MappedOperator", Sequence[int]]:
-        """Check if the length of the mapped task instances changed at runtime 
and find the missing indexes.
+    def _revise_mapped_task_indexes(self, task, session: Session):
+        """Check if task increased or reduced in length and handle 
appropriately"""
+        from airflow.models.taskinstance import TaskInstance
+        from airflow.settings import task_instance_mutation_hook
 
-        :param tis: Task instances to check
-        :param session: The session to use
-        """
-        from airflow.models.mappedoperator import MappedOperator
+        task.run_time_mapped_ti_count.cache_clear()
+        total_length = (
+            task.parse_time_mapped_ti_count
+            or task.run_time_mapped_ti_count(self.run_id, session=session)
+            or 0
+        )
+        query = session.query(TaskInstance.map_index).filter(
+            TaskInstance.dag_id == self.dag_id,
+            TaskInstance.task_id == task.task_id,
+            TaskInstance.run_id == self.run_id,
+        )
+        existing_indexes = {i for (i,) in query}
+        missing_indexes = 
set(range(total_length)).difference(set(existing_indexes))

Review Comment:
   ```suggestion
           missing_indexes = set(range(total_length)).difference(existing_indexes)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to