dstandish commented on code in PR #27491:
URL: https://github.com/apache/airflow/pull/27491#discussion_r1024769307


##########
airflow/models/dagrun.py:
##########
@@ -1080,43 +1096,50 @@ def _create_task_instances(
             # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
             session.rollback()
 
-    def _revise_mapped_task_indexes(self, task: MappedOperator, session: Session) -> Iterable[TI]:
-        """Check if task increased or reduced in length and handle appropriately"""
+    def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) -> Iterator[TI]:
+        """Check if task increased or reduced in length and handle appropriately.
+
+        Currently missing tis are created and returned if possible. Expansion

Review Comment:
   ```suggestion
           Missing tis are created and returned if possible. Expansion
   ```
   
   "Currently missing" made me question whether you meant "currently, missing 
tis are..." or "currently missing tis are...". Just saying "missing tis" does 
the job I think.



##########
airflow/models/abstractoperator.py:
##########
@@ -422,6 +424,121 @@ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
         counts = (g.get_mapped_ti_count(run_id, session=session) for g in mapped_task_groups)
         return functools.reduce(operator.mul, counts)
 
+    def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
+        """Create the mapped task instances for mapped task.
+
+        :raise NotMapped: If this task does not need expansion.
+        :return: The newly created mapped task instances (if any) in ascending
+            order by map index, and the maximum map index value.
+        """
+        from sqlalchemy import func, or_
+
+        from airflow.models.baseoperator import BaseOperator
+        from airflow.models.mappedoperator import MappedOperator
+        from airflow.models.taskinstance import TaskInstance
+        from airflow.settings import task_instance_mutation_hook
+
+        if not isinstance(self, (BaseOperator, MappedOperator)):
+            raise RuntimeError(f"cannot expand unrecognized operator type {type(self).__name__}")
+
+        try:
+            total_length: int | None = self.get_mapped_ti_count(run_id, session=session)
+        except NotFullyPopulated as e:
+            # It's possible that the upstream tasks are not yet done, but we
+            # don't have upstream of upstreams in partial DAGs (possible in the
+            # mini-scheduler), so we ignore this exception.
+            if not self.dag or not self.dag.partial:
+                self.log.error(
+                    "Cannot expand %r for run %s; missing upstream values: %s",
+                    self,
+                    run_id,
+                    sorted(e.missing),
+                )
+            total_length = None
+
+        state: TaskInstanceState | None = None
+        unmapped_ti: TaskInstance | None = (
+            session.query(TaskInstance)
+            .filter(
+                TaskInstance.dag_id == self.dag_id,
+                TaskInstance.task_id == self.task_id,
+                TaskInstance.run_id == run_id,
+                TaskInstance.map_index == -1,
+                or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
+            )
+            .one_or_none()
+        )
+
+        all_expanded_tis: list[TaskInstance] = []
+
+        if unmapped_ti:
+            # The unmapped task instance still exists and is unfinished, i.e. we
+            # haven't tried to run it before.
+            if total_length is None:
+                # If the DAG is partial, it's likely that the upstream tasks
+                # are not done yet, so the task can't fail yet.

Review Comment:
   ```
                   # are not done yet, so the task can't fail yet.
   ```
   
   do you mean "so we can't fail the task yet"?
   
   also... a little unclear, to me anyway, what is the subject of this 
comment...
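   
   e.g. maybe something like this (just a rough sketch of the wording I had in mind, not necessarily the right phrasing):
   
   ```
                   # If the DAG is partial, the upstream tasks are likely not
                   # done yet, so we can't fail this task yet.
   ```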



##########
airflow/models/abstractoperator.py:
##########
@@ -422,6 +424,121 @@ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
         counts = (g.get_mapped_ti_count(run_id, session=session) for g in mapped_task_groups)
         return functools.reduce(operator.mul, counts)
 
+    def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
+        """Create the mapped task instances for mapped task.
+
+        :raise NotMapped: If this task does not need expansion.
+        :return: The newly created mapped task instances (if any) in ascending
+            order by map index, and the maximum map index value.
+        """
+        from sqlalchemy import func, or_
+
+        from airflow.models.baseoperator import BaseOperator
+        from airflow.models.mappedoperator import MappedOperator
+        from airflow.models.taskinstance import TaskInstance
+        from airflow.settings import task_instance_mutation_hook
+
+        if not isinstance(self, (BaseOperator, MappedOperator)):
+            raise RuntimeError(f"cannot expand unrecognized operator type {type(self).__name__}")
+
+        try:
+            total_length: int | None = self.get_mapped_ti_count(run_id, session=session)
+        except NotFullyPopulated as e:
+            # It's possible that the upstream tasks are not yet done, but we
+            # don't have upstream of upstreams in partial DAGs (possible in the

Review Comment:
   i don't understand what this means:
   
   ```
   # It's possible that the upstream tasks are not yet done, but we
   # don't have upstream of upstreams in partial DAGs
   ```



##########
airflow/models/xcom_arg.py:
##########
@@ -196,7 +191,10 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
     def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
         """Pull XCom value.
 
-        This should only be called during ``op.execute()`` in respectable context.
+        This should only be called during ``op.execute()`` in respectable
+        context. Note that although the ``ResolveMixin`` parent mixin also has a

Review Comment:
   `respectable context` could probably benefit from clearer wording. though i 
realize this was already here.  



##########
airflow/models/dagrun.py:
##########
@@ -1080,43 +1096,50 @@ def _create_task_instances(
             # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
             session.rollback()
 
-    def _revise_mapped_task_indexes(self, task: MappedOperator, session: Session) -> Iterable[TI]:
-        """Check if task increased or reduced in length and handle appropriately"""
+    def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) -> Iterator[TI]:
+        """Check if task increased or reduced in length and handle appropriately.
+
+        Currently missing tis are created and returned if possible. Expansion
+        only happens if depended upstreams are all ready; if not all of them

Review Comment:
   ```suggestion
           only happens if all upstreams are ready; if not all of them
   ```



##########
airflow/models/dagrun.py:
##########
@@ -730,30 +729,47 @@ def _get_ready_tis(
             finished_tis=finished_tis,
         )
 
+        def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
+            """Try to expand the ti, if needed.
+
+            If the ti needs expansion, newly created task instances are
+            returned. The original ti is modified in-place and assigned the
+            ``map_index`` of 0.
+
+            If the ti does not need expansion, either because the task is not
+            mapped, or has already been expanded, *None* is returned.
+            """
+            if ti.map_index >= 0:  # Already expanded, we're good.
+                return None
+            try:
+                expanded_tis, _ = ti.task.expand_mapped_task(self.run_id, session=session)
+            except NotMapped:  # Not a mapped task, nothing needed.
+                return None
+            if expanded_tis:
+                assert expanded_tis[0] is ti

Review Comment:
   don't usually see asserts in airflow.  How come using one here? e.g. as 
opposed to raising
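   
   e.g. a rough sketch of what I had in mind (the exact exception type and message are just placeholders):
   
   ```
               if expanded_tis:
                   # fail explicitly instead of asserting, so the check isn't stripped under python -O
                   if expanded_tis[0] is not ti:
                       raise RuntimeError("expanded TIs did not start with the original unmapped ti")
   ```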



##########
airflow/models/abstractoperator.py:
##########
@@ -422,6 +424,121 @@ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
         counts = (g.get_mapped_ti_count(run_id, session=session) for g in mapped_task_groups)
         return functools.reduce(operator.mul, counts)
 
+    def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
+        """Create the mapped task instances for mapped task.
+
+        :raise NotMapped: If this task does not need expansion.
+        :return: The newly created mapped task instances (if any) in ascending
+            order by map index, and the maximum map index value.
+        """
+        from sqlalchemy import func, or_
+
+        from airflow.models.baseoperator import BaseOperator
+        from airflow.models.mappedoperator import MappedOperator
+        from airflow.models.taskinstance import TaskInstance
+        from airflow.settings import task_instance_mutation_hook
+
+        if not isinstance(self, (BaseOperator, MappedOperator)):
+            raise RuntimeError(f"cannot expand unrecognized operator type {type(self).__name__}")
+
+        try:
+            total_length: int | None = self.get_mapped_ti_count(run_id, session=session)
+        except NotFullyPopulated as e:
+            # It's possible that the upstream tasks are not yet done, but we
+            # don't have upstream of upstreams in partial DAGs (possible in the
+            # mini-scheduler), so we ignore this exception.
+            if not self.dag or not self.dag.partial:
+                self.log.error(
+                    "Cannot expand %r for run %s; missing upstream values: %s",
+                    self,
+                    run_id,
+                    sorted(e.missing),
+                )
+            total_length = None
+
+        state: TaskInstanceState | None = None
+        unmapped_ti: TaskInstance | None = (
+            session.query(TaskInstance)
+            .filter(
+                TaskInstance.dag_id == self.dag_id,
+                TaskInstance.task_id == self.task_id,
+                TaskInstance.run_id == run_id,
+                TaskInstance.map_index == -1,
+                or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
+            )
+            .one_or_none()
+        )
+
+        all_expanded_tis: list[TaskInstance] = []
+
+        if unmapped_ti:
+            # The unmapped task instance still exists and is unfinished, i.e. we
+            # haven't tried to run it before.
+            if total_length is None:
+                # If the DAG is partial, it's likely that the upstream tasks
+                # are not done yet, so the task can't fail yet.
+                if not self.dag or not self.dag.partial:
+                    unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
+                indexes_to_map: Iterable[int] = ()
+            elif total_length < 1:
+                # If the upstream maps this to a zero-length value, simply mark
+                # the unmapped task instance as SKIPPED (if needed).
+                self.log.info(
+                    "Marking %s as SKIPPED since the map has %d values to 
expand",
+                    unmapped_ti,
+                    total_length,
+                )
+                unmapped_ti.state = TaskInstanceState.SKIPPED
+                indexes_to_map = ()
+            else:
+                # Otherwise convert this into the first mapped index, and create
+                # TaskInstance for other indexes.
+                unmapped_ti.map_index = 0
+                self.log.debug("Updated in place to become %s", unmapped_ti)
+                all_expanded_tis.append(unmapped_ti)
+                indexes_to_map = range(1, total_length)
+            state = unmapped_ti.state
+        elif not total_length:
+            # Nothing to fixup.
+            indexes_to_map = ()
+        else:
+            # Only create "missing" ones.
+            current_max_mapping = (
+                session.query(func.max(TaskInstance.map_index))
+                .filter(
+                    TaskInstance.dag_id == self.dag_id,
+                    TaskInstance.task_id == self.task_id,
+                    TaskInstance.run_id == run_id,
+                )
+                .scalar()
+            )
+            indexes_to_map = range(current_max_mapping + 1, total_length)
+
+        for index in indexes_to_map:
+            # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
+            ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)
+            self.log.debug("Expanding TIs upserted %s", ti)
+            task_instance_mutation_hook(ti)
+            ti = session.merge(ti)
+            ti.refresh_from_task(self)  # session.merge() loses task information.
+            all_expanded_tis.append(ti)
+
+        # Coerce the None case to 0 -- these two are almost treated identically,
+        # except the unmapped ti (if exists) is marked to different states.
+        total_expanded_ti_count = total_length or 0
+
+        # Set to "REMOVED" any (old) TaskInstances with map indices greater
+        # than the current map value

Review Comment:
   ```suggestion
           # than the current number of mapped TIs
   ```
   
   or something?
   
   "current map value" read like "current map index"



##########
airflow/models/abstractoperator.py:
##########
@@ -422,6 +424,121 @@ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
         counts = (g.get_mapped_ti_count(run_id, session=session) for g in mapped_task_groups)
         return functools.reduce(operator.mul, counts)
 
+    def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
+        """Create the mapped task instances for mapped task.
+
+        :raise NotMapped: If this task does not need expansion.
+        :return: The newly created mapped task instances (if any) in ascending
+            order by map index, and the maximum map index value.
+        """
+        from sqlalchemy import func, or_
+
+        from airflow.models.baseoperator import BaseOperator
+        from airflow.models.mappedoperator import MappedOperator
+        from airflow.models.taskinstance import TaskInstance
+        from airflow.settings import task_instance_mutation_hook
+
+        if not isinstance(self, (BaseOperator, MappedOperator)):
+            raise RuntimeError(f"cannot expand unrecognized operator type {type(self).__name__}")
+
+        try:
+            total_length: int | None = self.get_mapped_ti_count(run_id, session=session)
+        except NotFullyPopulated as e:
+            # It's possible that the upstream tasks are not yet done, but we
+            # don't have upstream of upstreams in partial DAGs (possible in the
+            # mini-scheduler), so we ignore this exception.
+            if not self.dag or not self.dag.partial:
+                self.log.error(
+                    "Cannot expand %r for run %s; missing upstream values: %s",
+                    self,
+                    run_id,
+                    sorted(e.missing),
+                )
+            total_length = None
+
+        state: TaskInstanceState | None = None
+        unmapped_ti: TaskInstance | None = (
+            session.query(TaskInstance)
+            .filter(
+                TaskInstance.dag_id == self.dag_id,
+                TaskInstance.task_id == self.task_id,
+                TaskInstance.run_id == run_id,
+                TaskInstance.map_index == -1,
+                or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
+            )
+            .one_or_none()
+        )
+
+        all_expanded_tis: list[TaskInstance] = []
+
+        if unmapped_ti:
+            # The unmapped task instance still exists and is unfinished, i.e. we
+            # haven't tried to run it before.

Review Comment:
   > i.e. we haven't tried to run it before.
   
   is that true? what if this is a retry?


