kaxil commented on code in PR #62922:
URL: https://github.com/apache/airflow/pull/62922#discussion_r3293416741


##########
task-sdk/src/airflow/sdk/execution_time/executor.py:
##########
@@ -0,0 +1,241 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import contextvars
+import inspect
+import logging
+import time
+from asyncio import AbstractEventLoop, Semaphore
+from collections.abc import Callable, Generator
+from concurrent.futures import Future, ThreadPoolExecutor, as_completed
+from typing import TYPE_CHECKING, Any, cast
+
+from airflow.sdk import BaseAsyncOperator, BaseOperator, TaskInstanceState, 
timezone
+from airflow.sdk.bases.operator import ExecutorSafeguard
+from airflow.sdk.definitions._internal.logging_mixin import LoggingMixin
+from airflow.sdk.exceptions import AirflowRescheduleTaskInstanceException, 
TaskDeferred
+from airflow.sdk.execution_time.callback_runner import create_executable_runner
+from airflow.sdk.execution_time.context import context_get_outlet_events
+from airflow.sdk.execution_time.task_runner import (
+    RuntimeTaskInstance,
+    _execute_task,
+    _run_task_state_change_callbacks,
+)
+
+if TYPE_CHECKING:
+    from structlog.typing import FilteringBoundLogger as Logger
+
+    from airflow.sdk import Context
+    from airflow.sdk.execution_time.task_runner import IndexedTaskInstance
+
+
+def collect_futures(
+    loop: AbstractEventLoop, futures: list[Any]
+) -> Generator[Future | asyncio.futures.Future, None, None]:
+    """
+    Yield futures as they complete (sync or async).
+
+    :param loop: The asyncio event loop to use for async tasks
+    :param futures: List of Future or asyncio.futures.Future objects to collect
+    :return: Generator yielding Future or asyncio.futures.Future objects as 
they complete
+    """
+    yield from as_completed(f for f in futures if isinstance(f, Future))
+
+    async_tasks = [f for f in futures if isinstance(f, asyncio.futures.Future)]
+
+    if async_tasks:
+        for task, _ in zip(
+            async_tasks,
+            loop.run_until_complete(asyncio.gather(*async_tasks, 
return_exceptions=True)),
+        ):
+            yield task
+
+
+class ConcurrentExecutor:
+    """
+    Executes both sync and async functions concurrently.
+
+    Sync functions run in a ThreadPoolExecutor.
+    Async coroutines run on an asyncio event loop with a semaphore limit.
+    """
+
+    def __init__(self, loop: AbstractEventLoop, max_workers: int = 4):
+        self._loop = loop
+        self._semaphore = Semaphore(max_workers)
+        self._thread_pool = ThreadPoolExecutor(max_workers=max_workers)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self._thread_pool:
+            self._thread_pool.shutdown(wait=True)
+
+    def submit(self, func: Callable, *args, **kwargs):
+        if inspect.iscoroutine(func):
+            coro = func
+        elif inspect.iscoroutinefunction(func):
+            coro = func(*args, **kwargs)
+        else:
+            return self._thread_pool.submit(func, *args, **kwargs)
+
+        async def guarded():
+            async with self._semaphore:
+                return await coro
+
+        return self._loop.create_task(guarded())
+
+
+class TaskExecutor(LoggingMixin):
+    """Base class to run an operator or trigger with given task context and 
task instance."""
+
+    def __init__(
+        self,
+        task_instance: IndexedTaskInstance,
+    ):
+        super().__init__()
+        self.task_instance = task_instance
+        self._result: Any | None = None
+        self._start_time: float | None = None
+
+    @property
+    def dag_id(self) -> str:
+        return self.task_instance.dag_id
+
+    @property
+    def task_id(self) -> str:
+        return self.task_instance.task_id
+
+    @property
+    def task_index(self) -> int:
+        return self.task_instance.index
+
+    @property
+    def xcom_key(self):
+        return self.task_instance.xcom_key
+
+    @property
+    def operator(self) -> BaseOperator:
+        return self.task_instance.task
+
+    @property
+    def is_async(self) -> bool:
+        return self.task_instance.is_async
+
+    def run(self, context: Context):
+        return _execute_task(context, self.task_instance, self.log)

Review Comment:
   Calling `_execute_task` from a `ConcurrentExecutor`-submitted thread 
introduces an `os.environ` race that the existing single-process supervisor flow 
never had.
   
   Inside `_execute_task` (task_runner.py around line 2009), there's an 
`os.environ.update(airflow_context_vars)` call that sets `AIRFLOW_CTX_TASK_ID`, 
`AIRFLOW_CTX_MAP_INDEX`, etc. globally. Two `TaskExecutor.run(...)` calls 
running in parallel threads will overwrite each other's env vars, so any user 
task that reads them (BashOperator's templated env, subprocess hooks, 
third-party libs picking up `AIRFLOW_CTX_*`) sees a value belonging to a 
sibling mapped task.
   
   Options: thread-local-ify the env update in `_execute_task`, pass the 
context vars through `context` and have operators read from there, or take a 
per-thread snapshot/restore around the call site here.



##########
task-sdk/src/airflow/sdk/definitions/iterableoperator.py:
##########
@@ -0,0 +1,453 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import copy
+import os
+from collections import deque
+from collections.abc import Iterable, Mapping, Sequence
+from concurrent.futures import Future
+from itertools import chain
+from typing import TYPE_CHECKING, Any
+
+try:
+    # Python 3.12+
+    from itertools import batched  # type: ignore[attr-defined]
+except ImportError:
+    from more_itertools import batched  # type: ignore[no-redef]
+
+try:
+    # Python 3.11+
+    BaseExceptionGroup
+except NameError:
+    from exceptiongroup import BaseExceptionGroup
+
+from airflow.sdk import BaseXCom, TaskInstanceState, timezone
+from airflow.sdk.bases.operator import BaseOperator, 
DecoratedDeferredAsyncOperator, event_loop
+from airflow.sdk.definitions._internal.expandinput import 
PartitionedExpandInput
+from airflow.sdk.definitions.mappedoperator import MappedOperator
+from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg  # noqa: F401
+from airflow.sdk.exceptions import (
+    AirflowRescheduleTaskInstanceException,
+    AirflowTaskTimeout,
+    TaskDeferred,
+)
+from airflow.sdk.execution_time.executor import ConcurrentExecutor, 
TaskExecutor, collect_futures
+from airflow.sdk.execution_time.task_runner import IndexedTaskInstance
+
+if TYPE_CHECKING:
+    import jinja2
+
+    from airflow.sdk.definitions._internal.expandinput import ExpandInput
+    from airflow.sdk.definitions.context import Context
+    from airflow.sdk.execution_time.lazy_sequence import XComIterable
+
+
+class IterableOperator(BaseOperator):
+    """Object representing an iterable operator in a DAG."""
+
+    _operator: MappedOperator
+    expand_input: ExpandInput
+    partial_kwargs: dict[str, Any]
+    shallow_copy_attrs: Sequence[str] = (
+        "_operator",
+        "expand_input",
+        "partial_kwargs",
+        "_log",
+    )
+
+    def __init__(
+        self,
+        *,
+        operator: MappedOperator,
+        expand_input: ExpandInput,
+        **kwargs,
+    ):
+        super().__init__(
+            **{
+                **kwargs,
+                "task_id": operator.task_id,
+                "owner": operator.owner,
+                "email": operator.email,
+                "email_on_retry": operator.email_on_retry,
+                "email_on_failure": operator.email_on_failure,
+                "retries": 0,  # We should not retry the IterableOperator, 
only the mapped ti's should be retried
+                "retry_delay": operator.retry_delay,
+                "retry_exponential_backoff": 
operator.retry_exponential_backoff,
+                "max_retry_delay": operator.max_retry_delay,
+                "start_date": operator.start_date,
+                "end_date": operator.end_date,
+                "depends_on_past": operator.depends_on_past,
+                "ignore_first_depends_on_past": 
operator.ignore_first_depends_on_past,
+                "wait_for_past_depends_before_skipping": 
operator.wait_for_past_depends_before_skipping,
+                "wait_for_downstream": operator.wait_for_downstream,
+                "dag": operator.dag,
+                "priority_weight": operator.priority_weight,
+                "queue": operator.queue,
+                "pool": operator.pool,
+                "pool_slots": operator.pool_slots,
+                "execution_timeout": None,
+                "trigger_rule": operator.trigger_rule,
+                "resources": operator.resources,
+                "run_as_user": operator.run_as_user,
+                "map_index_template": operator.map_index_template,
+                "max_active_tis_per_dag": operator.max_active_tis_per_dag,
+                "max_active_tis_per_dagrun": 
operator.max_active_tis_per_dagrun,
+                "executor": operator.executor,
+                "executor_config": operator.executor_config,
+                "inlets": operator.inlets,
+                "outlets": operator.outlets,
+                "task_group": operator.task_group,
+                "doc": operator.doc,
+                "doc_md": operator.doc_md,
+                "doc_json": operator.doc_json,
+                "doc_yaml": operator.doc_yaml,
+                "doc_rst": operator.doc_rst,
+                "task_display_name": operator.task_display_name,
+                "allow_nested_operators": operator.allow_nested_operators,
+            }
+        )
+        self._operator = operator
+        self.expand_input = expand_input
+        self.partial_kwargs = operator.partial_kwargs or {}
+        self.partial_kwargs.pop("partition_size", None)
+        self.max_workers = self.partial_kwargs.pop("task_concurrency", None) 
or os.cpu_count() or 1
+        self._number_of_tasks: int = 0
+        XComArg.apply_upstream_relationship(self, self.expand_input.value)
+
+    @property
+    def task_type(self) -> str:
+        """@property: type of the task."""
+        return self._operator.__class__.__name__
+
+    @property
+    def timeout(self) -> float | None:
+        if self.execution_timeout:
+            return self.execution_timeout.total_seconds()
+        return None
+
+    def _do_render_template_fields(
+        self,
+        parent: Any,
+        template_fields: Iterable[str],
+        context: Context,
+        jinja_env: jinja2.Environment,
+        seen_oids: set[int],
+    ) -> None:
+        # IterableOperator doesn't need to render template fields as the 
actual operator's template fields
+        # will be rendered in the TaskExecutor when running each mapped task 
instance.
+        pass
+
+    def _get_specified_expand_input(self) -> ExpandInput:
+        return self.expand_input
+
+    def _unmap_operator(
+        self, context: Context, mapped_kwargs: Context, jinja_env: 
jinja2.Environment
+    ) -> BaseOperator:
+        from airflow.sdk.execution_time.context import 
context_update_for_unmapped
+
+        self._number_of_tasks += 1
+        unmapped_task = self._operator.unmap(mapped_kwargs)
+        # Make sure deferred operators will always raise a DeferredTask 
exception when executed
+        unmapped_task.start_from_trigger = False
+        context_update_for_unmapped(context, unmapped_task)
+
+        unmapped_task._do_render_template_fields(
+            parent=unmapped_task,
+            template_fields=self._operator.template_fields,
+            context=context,
+            jinja_env=jinja_env,
+            seen_oids=set(),
+        )
+        return unmapped_task
+
+    def _xcom_push(self, task: IndexedTaskInstance, value: Any) -> None:
+        if task.xcom_pushed:
+            self.log.debug(
+                "XCom already pushed for task_id %s with index %s",
+                task.task_id,
+                task.index,
+            )
+        else:
+            self.log.debug(
+                "Pushing XCom for task_id %s with index %s",
+                task.task_id,
+                task.index,
+            )
+
+            task.xcom_push(key=BaseXCom.XCOM_RETURN_KEY, value=value)
+
+    def _run_tasks(
+        self,
+        context: Context,
+        tasks: Iterable[IndexedTaskInstance],
+    ) -> XComIterable | None:
+        exceptions: list[BaseException] = []
+        reschedule_date = timezone.utcnow()
+        prev_futures_count = 0
+        futures: dict[Future | asyncio.futures.Future, IndexedTaskInstance] = 
{}
+        deferred_tasks: deque[IndexedTaskInstance] = deque()
+        failed_tasks: deque[IndexedTaskInstance] = deque()
+        chunked_tasks = batched(tasks, self.max_workers)
+        do_xcom_push = True
+
+        self.log.info("Running tasks with %d workers", self.max_workers)
+
+        with event_loop() as loop:
+            with ConcurrentExecutor(loop=loop, max_workers=self.max_workers) 
as executor:
+                for task in next(chunked_tasks, []):
+                    do_xcom_push = task.do_xcom_push
+                    if task.is_async:
+                        future = executor.submit(self._run_async_operator, 
context, task)
+                    else:
+                        future = executor.submit(self._run_operator, context, 
task)
+                    futures[future] = task
+
+                while futures:
+                    futures_count = len(futures)
+
+                    if futures_count != prev_futures_count:
+                        self.log.info("Number of remaining futures: %s", 
futures_count)
+                        prev_futures_count = futures_count
+
+                    for future in collect_futures(loop, list(futures.keys())):
+                        task = futures.pop(future)
+
+                        try:
+                            if isinstance(future, asyncio.futures.Future):
+                                result = future.result()
+                            else:
+                                result = future.result(timeout=self.timeout)
+
+                            self.log.debug("result: %s", result)
+
+                            if result is not None and task.do_xcom_push:
+                                self._xcom_push(
+                                    task=task,
+                                    value=result,
+                                )
+                        except TaskDeferred as task_deferred:
+                            operator = DecoratedDeferredAsyncOperator(
+                                operator=task.task, task_deferred=task_deferred
+                            )
+                            deferred_tasks.append(
+                                self._create_mapped_task(
+                                    run_id=task.run_id,
+                                    index=task.index,
+                                    map_index=task.map_index,  # type: 
ignore[arg-type]
+                                    try_number=task.try_number,
+                                    operator=operator,
+                                )
+                            )
+                        except asyncio.TimeoutError as e:
+                            self.log.warning("A timeout occurred for task_id 
%s", task.task_id)
+                            if task.next_try_number > (self.retries or 0):
+                                exceptions.append(AirflowTaskTimeout(e))
+                            else:
+                                reschedule_date = min(reschedule_date, 
task.next_retry_datetime())
+                                failed_tasks.append(task)
+                        except AirflowRescheduleTaskInstanceException as e:
+                            reschedule_date = min(reschedule_date, 
e.reschedule_date)
+                            self.log.exception(
+                                "An exception occurred for task_id %s with 
index %s, it has been rescheduled at %s",
+                                task.task_id,
+                                task.index,
+                                reschedule_date,
+                            )
+                            failed_tasks.append(e.task)
+                        except Exception as e:
+                            self.log.exception(
+                                "An exception occurred for task_id %s with 
index %s",
+                                task.task_id,
+                                task.index,
+                            )
+                            exceptions.append(e)
+
+                    if len(futures) < self.max_workers:
+                        chunked_tasks = chain(list(deferred_tasks), 
chunked_tasks)
+                        deferred_tasks.clear()
+
+                        for task in next(chunked_tasks, []):
+                            if task.is_async:
+                                future = 
executor.submit(self._run_async_operator, context, task)
+                            else:
+                                future = executor.submit(self._run_operator, 
context, task)
+                            futures[future] = task
+
+        if not failed_tasks:
+            if exceptions:
+                raise BaseExceptionGroup("Multiple sub-task failures", 
exceptions)
+            if do_xcom_push:
+                from airflow.sdk.execution_time.lazy_sequence import 
XComIterable
+
+                return XComIterable(
+                    task_id=self.task_id,
+                    dag_id=self.dag_id,
+                    run_id=context["run_id"],
+                    length=self._number_of_tasks,
+                    map_index=context["ti"].map_index,
+                )
+            return None
+
+        # If the retry time is still in the future we defer the operator so 
the worker
+        # slot is released. If the retry time has already passed we 
immediately re-run
+        # the failed tasks without deferring.
+        if reschedule_date > timezone.utcnow():
+            # TODO: This is tricky as that import doesn't exist in Task SDK
+            from airflow.providers.standard.triggers.temporal import 
DateTimeTrigger
+
+            self.defer(
+                trigger=DateTimeTrigger(reschedule_date),
+                method_name=self.execute_failed_tasks.__name__,
+                kwargs={
+                    "failed_tasks": {failed_task.index for failed_task in 
failed_tasks},
+                    "try_number": next(iter(failed_tasks)).try_number,
+                },
+            )
+
+        return self._run_tasks(context=context, tasks=list(failed_tasks))
+
+    def _run_operator(self, context: Context, task_instance: 
IndexedTaskInstance):
+        with TaskExecutor(task_instance=task_instance) as executor:
+            return executor.run(
+                context={
+                    **dict(context),

Review Comment:
   Partially addressed from the previous review. The shallow merge 
`{**dict(context), **{"ti": ..., "task_instance": ...}}` creates a fresh 
top-level dict per task, which correctly prevents `ti` / `task_instance` from 
being clobbered across threads.
   
   But the context's nested mutable values (`params`, `dag_run.conf`, 
`outlet_events`, `dag_run`, `inlet_events`, ...) are still shared by reference. 
If user code mutates any context-attached object (`context['params']['x'] = 
...`, `context['outlet_events'].add(...)`, etc.), sibling threads see it.
   
   `copy.deepcopy(context)` is the safe choice; if that's too heavyweight for 
the hot path, audit the context shape and explicitly `copy()` the mutable 
fields that user code is allowed to touch.



##########
task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py:
##########
@@ -184,6 +243,19 @@ def iter_references(self) -> Iterable[tuple[Operator, 
str]]:
             if isinstance(x, XComArg):
                 yield from x.iter_references()
 
+    def iter_values(self, context: Mapping[str, Any]) -> Iterable[Any]:
+        from airflow.sdk.definitions.xcom_arg import XComArg
+
+        resolved = {k: v.resolve(context) if isinstance(v, XComArg) else v for 
k, v in self.value.items()}
+        keys = list(resolved)
+        for items in zip(

Review Comment:
   Semantic divergence from classic dynamic task mapping is still unresolved.
   
   `DictOfListsExpandInput.iter_values` uses `zip(*values)`, so 
`expand(a=[1,2], b=[3,4])` yields 2 dicts (`{a:1,b:3}`, `{a:2,b:4}`). But the 
scheduler-side `SchedulerDictOfListsExpandInput.get_total_map_length` uses 
`functools.reduce(operator.mul, lengths)` -- a cartesian product -- so the same call 
mapped via `.expand()` produces 4 task instances.
   
   Two issues:
   1. `.iterate()` on a 2-key expand silently drops items the user expects to 
be processed.
   2. Users switching between `.expand()` and `.iterate()` get different 
cardinality with no warning or doc note.
   
   `test_dict_of_lists_expand_input_iter_values` currently locks in the zip 
behavior, which confirms the divergence is intentional, but it isn't documented 
anywhere user-visible. Options: (a) make this cartesian to match DTM, (b) 
document the zip choice prominently in the DTI docs, or (c) raise on multi-key 
expand and force the user to use `iterate_kwargs` explicitly.



##########
task-sdk/src/airflow/sdk/definitions/iterableoperator.py:
##########
@@ -0,0 +1,453 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import copy
+import os
+from collections import deque
+from collections.abc import Iterable, Mapping, Sequence
+from concurrent.futures import Future
+from itertools import chain
+from typing import TYPE_CHECKING, Any
+
+try:
+    # Python 3.12+
+    from itertools import batched  # type: ignore[attr-defined]
+except ImportError:
+    from more_itertools import batched  # type: ignore[no-redef]
+
+try:
+    # Python 3.11+
+    BaseExceptionGroup
+except NameError:
+    from exceptiongroup import BaseExceptionGroup
+
+from airflow.sdk import BaseXCom, TaskInstanceState, timezone
+from airflow.sdk.bases.operator import BaseOperator, 
DecoratedDeferredAsyncOperator, event_loop
+from airflow.sdk.definitions._internal.expandinput import 
PartitionedExpandInput
+from airflow.sdk.definitions.mappedoperator import MappedOperator
+from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg  # noqa: F401
+from airflow.sdk.exceptions import (
+    AirflowRescheduleTaskInstanceException,
+    AirflowTaskTimeout,
+    TaskDeferred,
+)
+from airflow.sdk.execution_time.executor import ConcurrentExecutor, 
TaskExecutor, collect_futures
+from airflow.sdk.execution_time.task_runner import IndexedTaskInstance
+
+if TYPE_CHECKING:
+    import jinja2
+
+    from airflow.sdk.definitions._internal.expandinput import ExpandInput
+    from airflow.sdk.definitions.context import Context
+    from airflow.sdk.execution_time.lazy_sequence import XComIterable
+
+
+class IterableOperator(BaseOperator):
+    """Object representing an iterable operator in a DAG."""
+
+    _operator: MappedOperator
+    expand_input: ExpandInput
+    partial_kwargs: dict[str, Any]
+    shallow_copy_attrs: Sequence[str] = (
+        "_operator",
+        "expand_input",
+        "partial_kwargs",
+        "_log",
+    )
+
+    def __init__(
+        self,
+        *,
+        operator: MappedOperator,
+        expand_input: ExpandInput,
+        **kwargs,
+    ):
+        super().__init__(
+            **{
+                **kwargs,
+                "task_id": operator.task_id,
+                "owner": operator.owner,
+                "email": operator.email,
+                "email_on_retry": operator.email_on_retry,
+                "email_on_failure": operator.email_on_failure,
+                "retries": 0,  # We should not retry the IterableOperator, 
only the mapped ti's should be retried
+                "retry_delay": operator.retry_delay,
+                "retry_exponential_backoff": 
operator.retry_exponential_backoff,
+                "max_retry_delay": operator.max_retry_delay,
+                "start_date": operator.start_date,
+                "end_date": operator.end_date,
+                "depends_on_past": operator.depends_on_past,
+                "ignore_first_depends_on_past": 
operator.ignore_first_depends_on_past,
+                "wait_for_past_depends_before_skipping": 
operator.wait_for_past_depends_before_skipping,
+                "wait_for_downstream": operator.wait_for_downstream,
+                "dag": operator.dag,
+                "priority_weight": operator.priority_weight,
+                "queue": operator.queue,
+                "pool": operator.pool,
+                "pool_slots": operator.pool_slots,
+                "execution_timeout": None,
+                "trigger_rule": operator.trigger_rule,
+                "resources": operator.resources,
+                "run_as_user": operator.run_as_user,
+                "map_index_template": operator.map_index_template,
+                "max_active_tis_per_dag": operator.max_active_tis_per_dag,
+                "max_active_tis_per_dagrun": 
operator.max_active_tis_per_dagrun,
+                "executor": operator.executor,
+                "executor_config": operator.executor_config,
+                "inlets": operator.inlets,
+                "outlets": operator.outlets,
+                "task_group": operator.task_group,
+                "doc": operator.doc,
+                "doc_md": operator.doc_md,
+                "doc_json": operator.doc_json,
+                "doc_yaml": operator.doc_yaml,
+                "doc_rst": operator.doc_rst,
+                "task_display_name": operator.task_display_name,
+                "allow_nested_operators": operator.allow_nested_operators,
+            }
+        )
+        self._operator = operator
+        self.expand_input = expand_input
+        self.partial_kwargs = operator.partial_kwargs or {}
+        self.partial_kwargs.pop("partition_size", None)
+        self.max_workers = self.partial_kwargs.pop("task_concurrency", None) 
or os.cpu_count() or 1
+        self._number_of_tasks: int = 0
+        XComArg.apply_upstream_relationship(self, self.expand_input.value)
+
+    @property
+    def task_type(self) -> str:
+        """@property: type of the task."""
+        return self._operator.__class__.__name__
+
+    @property
+    def timeout(self) -> float | None:
+        if self.execution_timeout:
+            return self.execution_timeout.total_seconds()
+        return None
+
+    def _do_render_template_fields(
+        self,
+        parent: Any,
+        template_fields: Iterable[str],
+        context: Context,
+        jinja_env: jinja2.Environment,
+        seen_oids: set[int],
+    ) -> None:
+        # IterableOperator doesn't need to render template fields as the 
actual operator's template fields
+        # will be rendered in the TaskExecutor when running each mapped task 
instance.
+        pass
+
+    def _get_specified_expand_input(self) -> ExpandInput:
+        return self.expand_input
+
+    def _unmap_operator(
+        self, context: Context, mapped_kwargs: Context, jinja_env: 
jinja2.Environment
+    ) -> BaseOperator:
+        from airflow.sdk.execution_time.context import 
context_update_for_unmapped
+
+        self._number_of_tasks += 1
+        unmapped_task = self._operator.unmap(mapped_kwargs)
+        # Make sure deferred operators will always raise a DeferredTask 
exception when executed
+        unmapped_task.start_from_trigger = False
+        context_update_for_unmapped(context, unmapped_task)
+
+        unmapped_task._do_render_template_fields(
+            parent=unmapped_task,
+            template_fields=self._operator.template_fields,
+            context=context,
+            jinja_env=jinja_env,
+            seen_oids=set(),
+        )
+        return unmapped_task
+
+    def _xcom_push(self, task: IndexedTaskInstance, value: Any) -> None:
+        if task.xcom_pushed:
+            self.log.debug(
+                "XCom already pushed for task_id %s with index %s",
+                task.task_id,
+                task.index,
+            )
+        else:
+            self.log.debug(
+                "Pushing XCom for task_id %s with index %s",
+                task.task_id,
+                task.index,
+            )
+
+            task.xcom_push(key=BaseXCom.XCOM_RETURN_KEY, value=value)
+
+    def _run_tasks(
+        self,
+        context: Context,
+        tasks: Iterable[IndexedTaskInstance],
+    ) -> XComIterable | None:
+        exceptions: list[BaseException] = []
+        reschedule_date = timezone.utcnow()
+        prev_futures_count = 0
+        futures: dict[Future | asyncio.futures.Future, IndexedTaskInstance] = 
{}
+        deferred_tasks: deque[IndexedTaskInstance] = deque()
+        failed_tasks: deque[IndexedTaskInstance] = deque()
+        chunked_tasks = batched(tasks, self.max_workers)
+        do_xcom_push = True
+
+        self.log.info("Running tasks with %d workers", self.max_workers)
+
+        with event_loop() as loop:
+            with ConcurrentExecutor(loop=loop, max_workers=self.max_workers) 
as executor:
+                for task in next(chunked_tasks, []):
+                    do_xcom_push = task.do_xcom_push
+                    if task.is_async:
+                        future = executor.submit(self._run_async_operator, 
context, task)
+                    else:
+                        future = executor.submit(self._run_operator, context, 
task)
+                    futures[future] = task
+
+                while futures:
+                    futures_count = len(futures)
+
+                    if futures_count != prev_futures_count:
+                        self.log.info("Number of remaining futures: %s", 
futures_count)
+                        prev_futures_count = futures_count
+
+                    for future in collect_futures(loop, list(futures.keys())):
+                        task = futures.pop(future)
+
+                        try:
+                            if isinstance(future, asyncio.futures.Future):
+                                result = future.result()
+                            else:
+                                result = future.result(timeout=self.timeout)
+
+                            self.log.debug("result: %s", result)
+
+                            if result is not None and task.do_xcom_push:
+                                self._xcom_push(
+                                    task=task,
+                                    value=result,
+                                )
+                        except TaskDeferred as task_deferred:
+                            operator = DecoratedDeferredAsyncOperator(
+                                operator=task.task, task_deferred=task_deferred
+                            )
+                            deferred_tasks.append(
+                                self._create_mapped_task(
+                                    run_id=task.run_id,
+                                    index=task.index,
+                                    map_index=task.map_index,  # type: 
ignore[arg-type]
+                                    try_number=task.try_number,
+                                    operator=operator,
+                                )
+                            )
+                        except asyncio.TimeoutError as e:
+                            self.log.warning("A timeout occurred for task_id 
%s", task.task_id)
+                            if task.next_try_number > (self.retries or 0):
+                                exceptions.append(AirflowTaskTimeout(e))
+                            else:
+                                reschedule_date = min(reschedule_date, 
task.next_retry_datetime())
+                                failed_tasks.append(task)
+                        except AirflowRescheduleTaskInstanceException as e:
+                            reschedule_date = min(reschedule_date, 
e.reschedule_date)
+                            self.log.exception(
+                                "An exception occurred for task_id %s with 
index %s, it has been rescheduled at %s",
+                                task.task_id,
+                                task.index,
+                                reschedule_date,
+                            )
+                            failed_tasks.append(e.task)
+                        except Exception as e:
+                            self.log.exception(
+                                "An exception occurred for task_id %s with 
index %s",
+                                task.task_id,
+                                task.index,
+                            )
+                            exceptions.append(e)
+
+                    if len(futures) < self.max_workers:
+                        chunked_tasks = chain(list(deferred_tasks), 
chunked_tasks)
+                        deferred_tasks.clear()
+
+                        for task in next(chunked_tasks, []):
+                            if task.is_async:
+                                future = 
executor.submit(self._run_async_operator, context, task)
+                            else:
+                                future = executor.submit(self._run_operator, 
context, task)
+                            futures[future] = task
+
+        if not failed_tasks:
+            if exceptions:
+                raise BaseExceptionGroup("Multiple sub-task failures", 
exceptions)
+            if do_xcom_push:
+                from airflow.sdk.execution_time.lazy_sequence import 
XComIterable
+
+                return XComIterable(
+                    task_id=self.task_id,
+                    dag_id=self.dag_id,
+                    run_id=context["run_id"],
+                    length=self._number_of_tasks,
+                    map_index=context["ti"].map_index,
+                )
+            return None
+
+        # If the retry time is still in the future we defer the operator so 
the worker
+        # slot is released. If the retry time has already passed we 
immediately re-run
+        # the failed tasks without deferring.
+        if reschedule_date > timezone.utcnow():
+            # TODO: This is tricky as that import doesn't exist in Task SDK
+            from airflow.providers.standard.triggers.temporal import 
DateTimeTrigger
+
+            self.defer(
+                trigger=DateTimeTrigger(reschedule_date),
+                method_name=self.execute_failed_tasks.__name__,
+                kwargs={
+                    "failed_tasks": {failed_task.index for failed_task in 
failed_tasks},
+                    "try_number": next(iter(failed_tasks)).try_number,
+                },
+            )
+
+        return self._run_tasks(context=context, tasks=list(failed_tasks))

Review Comment:
   Still unaddressed from the previous review: `return 
self._run_tasks(context=context, tasks=list(failed_tasks))` recurses on every 
retry batch. Sensor-like patterns or any workload that repeatedly hits 
`AirflowRescheduleTaskInstanceException` will blow the Python recursion limit 
(default 1000). Convert to a `while failed_tasks:` loop that reassigns 
`failed_tasks` each iteration -- the tail call here has no special meaning, 
it's just a tight retry loop.



##########
task-sdk/src/airflow/sdk/execution_time/executor.py:
##########
@@ -0,0 +1,241 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import contextvars
+import inspect
+import logging
+import time
+from asyncio import AbstractEventLoop, Semaphore
+from collections.abc import Callable, Generator
+from concurrent.futures import Future, ThreadPoolExecutor, as_completed
+from typing import TYPE_CHECKING, Any, cast
+
+from airflow.sdk import BaseAsyncOperator, BaseOperator, TaskInstanceState, 
timezone
+from airflow.sdk.bases.operator import ExecutorSafeguard
+from airflow.sdk.definitions._internal.logging_mixin import LoggingMixin
+from airflow.sdk.exceptions import AirflowRescheduleTaskInstanceException, 
TaskDeferred
+from airflow.sdk.execution_time.callback_runner import create_executable_runner
+from airflow.sdk.execution_time.context import context_get_outlet_events
+from airflow.sdk.execution_time.task_runner import (
+    RuntimeTaskInstance,
+    _execute_task,
+    _run_task_state_change_callbacks,
+)
+
+if TYPE_CHECKING:
+    from structlog.typing import FilteringBoundLogger as Logger
+
+    from airflow.sdk import Context
+    from airflow.sdk.execution_time.task_runner import IndexedTaskInstance
+
+
+def collect_futures(
+    loop: AbstractEventLoop, futures: list[Any]
+) -> Generator[Future | asyncio.futures.Future, None, None]:
+    """
+    Yield futures as they complete (sync or async).
+
+    :param loop: The asyncio event loop to use for async tasks
+    :param futures: List of Future or asyncio.futures.Future objects to collect
+    :return: Generator yielding Future or asyncio.futures.Future objects as 
they complete
+    """
+    yield from as_completed(f for f in futures if isinstance(f, Future))
+
+    async_tasks = [f for f in futures if isinstance(f, asyncio.futures.Future)]
+
+    if async_tasks:
+        for task, _ in zip(
+            async_tasks,
+            loop.run_until_complete(asyncio.gather(*async_tasks, 
return_exceptions=True)),
+        ):
+            yield task
+
+
+class ConcurrentExecutor:
+    """
+    Executes both sync and async functions concurrently.
+
+    Sync functions run in a ThreadPoolExecutor.
+    Async coroutines run on an asyncio event loop with a semaphore limit.
+    """
+
+    def __init__(self, loop: AbstractEventLoop, max_workers: int = 4):
+        self._loop = loop
+        self._semaphore = Semaphore(max_workers)
+        self._thread_pool = ThreadPoolExecutor(max_workers=max_workers)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self._thread_pool:
+            self._thread_pool.shutdown(wait=True)

Review Comment:
   `ConcurrentExecutor.__exit__` does `self._thread_pool.shutdown(wait=True)`. 
If any worker thread is hung in user code (slow HTTP, blocking C extension, a 
deadlocked DB driver), `__exit__` waits forever -- and the outer 
`future.result(timeout=...)` in `_run_tasks` can't break out because the 
context manager owns the join.
   
   Same anti-pattern as "`ThreadPoolExecutor` as context manager negates the 
per-future timeout." Use `shutdown(wait=False, cancel_futures=True)` (and 
document that in-flight tasks are abandoned, not killed -- Python can't kill 
threads), or take a configurable shutdown deadline.



##########
task-sdk/src/airflow/sdk/definitions/mappedoperator.py:
##########
@@ -322,6 +289,7 @@ class MappedOperator(AbstractOperator):
     end_date: pendulum.DateTime | None
     upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
     downstream_task_ids: set[str] = attrs.field(factory=set, init=False)
+    _apply_upstream_relationship: bool = 
attrs.field(alias="apply_upstream_relationship", default=True)

Review Comment:
   Flag-name vs behavior mismatch. The alias is `apply_upstream_relationship`, 
which reads like it only gates `XComArg.apply_upstream_relationship(...)`. But 
the `__attrs_post_init__` body (now correctly documented in the comment block) 
also gates `task_group.add(self)` and `dag.add_task(self)`.
   
   Call sites like 
`PartitionedOperator._iterate(apply_upstream_relationship=False)` read as 
"don't wire up XComArg relationships," when the actual effect is "don't 
register the operator with the DAG at all." The docstring/comment covers the 
why, but the parameter name still misleads.
   
   Consider renaming to `_register_with_dag` (or to `_is_throwaway_instance` 
with the default inverted), so the call sites read truer to what they do.



##########
task-sdk/src/airflow/sdk/definitions/iterableoperator.py:
##########
@@ -0,0 +1,453 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import copy
+import os
+from collections import deque
+from collections.abc import Iterable, Mapping, Sequence
+from concurrent.futures import Future
+from itertools import chain
+from typing import TYPE_CHECKING, Any
+
+try:
+    # Python 3.12+
+    from itertools import batched  # type: ignore[attr-defined]
+except ImportError:
+    from more_itertools import batched  # type: ignore[no-redef]
+
+try:
+    # Python 3.11+
+    BaseExceptionGroup
+except NameError:
+    from exceptiongroup import BaseExceptionGroup
+
+from airflow.sdk import BaseXCom, TaskInstanceState, timezone
+from airflow.sdk.bases.operator import BaseOperator, 
DecoratedDeferredAsyncOperator, event_loop
+from airflow.sdk.definitions._internal.expandinput import 
PartitionedExpandInput
+from airflow.sdk.definitions.mappedoperator import MappedOperator
+from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg  # noqa: F401
+from airflow.sdk.exceptions import (
+    AirflowRescheduleTaskInstanceException,
+    AirflowTaskTimeout,
+    TaskDeferred,
+)
+from airflow.sdk.execution_time.executor import ConcurrentExecutor, 
TaskExecutor, collect_futures
+from airflow.sdk.execution_time.task_runner import IndexedTaskInstance
+
+if TYPE_CHECKING:
+    import jinja2
+
+    from airflow.sdk.definitions._internal.expandinput import ExpandInput
+    from airflow.sdk.definitions.context import Context
+    from airflow.sdk.execution_time.lazy_sequence import XComIterable
+
+
+class IterableOperator(BaseOperator):
+    """Object representing an iterable operator in a DAG."""
+
+    _operator: MappedOperator
+    expand_input: ExpandInput
+    partial_kwargs: dict[str, Any]
+    shallow_copy_attrs: Sequence[str] = (
+        "_operator",
+        "expand_input",
+        "partial_kwargs",
+        "_log",
+    )
+
+    def __init__(
+        self,
+        *,
+        operator: MappedOperator,
+        expand_input: ExpandInput,
+        **kwargs,
+    ):
+        super().__init__(
+            **{
+                **kwargs,
+                "task_id": operator.task_id,
+                "owner": operator.owner,
+                "email": operator.email,
+                "email_on_retry": operator.email_on_retry,
+                "email_on_failure": operator.email_on_failure,
+                "retries": 0,  # We should not retry the IterableOperator, 
only the mapped ti's should be retried
+                "retry_delay": operator.retry_delay,
+                "retry_exponential_backoff": 
operator.retry_exponential_backoff,
+                "max_retry_delay": operator.max_retry_delay,
+                "start_date": operator.start_date,
+                "end_date": operator.end_date,
+                "depends_on_past": operator.depends_on_past,
+                "ignore_first_depends_on_past": 
operator.ignore_first_depends_on_past,
+                "wait_for_past_depends_before_skipping": 
operator.wait_for_past_depends_before_skipping,
+                "wait_for_downstream": operator.wait_for_downstream,
+                "dag": operator.dag,
+                "priority_weight": operator.priority_weight,
+                "queue": operator.queue,
+                "pool": operator.pool,
+                "pool_slots": operator.pool_slots,
+                "execution_timeout": None,

Review Comment:
   `execution_timeout` is hardcoded to `None` in the super-init kwargs, but the 
`timeout` property below reads `self.execution_timeout`:
   
   ```python
   @property
   def timeout(self) -> float | None:
       if self.execution_timeout:
           return self.execution_timeout.total_seconds()
       return None
   ```
   
   and `future.result(timeout=self.timeout)` in `_run_tasks` relies on it. Net 
effect: any `execution_timeout` set on the operator (via `partial` or via the 
wrapped `MappedOperator`) is silently dropped -- `future.result(timeout=None)` 
waits forever.
   
   Either propagate `operator.execution_timeout` to the super-init, or document 
why timeout enforcement is intentionally disabled at the `IterableOperator` 
level (and what the user should do instead).



##########
task-sdk/src/airflow/sdk/definitions/iterableoperator.py:
##########
@@ -0,0 +1,453 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import copy
+import os
+from collections import deque
+from collections.abc import Iterable, Mapping, Sequence
+from concurrent.futures import Future
+from itertools import chain
+from typing import TYPE_CHECKING, Any
+
+try:
+    # Python 3.12+
+    from itertools import batched  # type: ignore[attr-defined]
+except ImportError:
+    from more_itertools import batched  # type: ignore[no-redef]
+
+try:
+    # Python 3.11+
+    BaseExceptionGroup
+except NameError:
+    from exceptiongroup import BaseExceptionGroup
+
+from airflow.sdk import BaseXCom, TaskInstanceState, timezone
+from airflow.sdk.bases.operator import BaseOperator, 
DecoratedDeferredAsyncOperator, event_loop
+from airflow.sdk.definitions._internal.expandinput import 
PartitionedExpandInput
+from airflow.sdk.definitions.mappedoperator import MappedOperator
+from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg  # noqa: F401
+from airflow.sdk.exceptions import (
+    AirflowRescheduleTaskInstanceException,
+    AirflowTaskTimeout,
+    TaskDeferred,
+)
+from airflow.sdk.execution_time.executor import ConcurrentExecutor, 
TaskExecutor, collect_futures
+from airflow.sdk.execution_time.task_runner import IndexedTaskInstance
+
+if TYPE_CHECKING:
+    import jinja2
+
+    from airflow.sdk.definitions._internal.expandinput import ExpandInput
+    from airflow.sdk.definitions.context import Context
+    from airflow.sdk.execution_time.lazy_sequence import XComIterable
+
+
+class IterableOperator(BaseOperator):
+    """Object representing an iterable operator in a DAG."""
+
+    _operator: MappedOperator
+    expand_input: ExpandInput
+    partial_kwargs: dict[str, Any]
+    shallow_copy_attrs: Sequence[str] = (
+        "_operator",
+        "expand_input",
+        "partial_kwargs",
+        "_log",
+    )
+
+    def __init__(
+        self,
+        *,
+        operator: MappedOperator,
+        expand_input: ExpandInput,
+        **kwargs,
+    ):
+        super().__init__(
+            **{
+                **kwargs,
+                "task_id": operator.task_id,
+                "owner": operator.owner,
+                "email": operator.email,
+                "email_on_retry": operator.email_on_retry,
+                "email_on_failure": operator.email_on_failure,
+                "retries": 0,  # We should not retry the IterableOperator, 
only the mapped ti's should be retried
+                "retry_delay": operator.retry_delay,
+                "retry_exponential_backoff": 
operator.retry_exponential_backoff,
+                "max_retry_delay": operator.max_retry_delay,
+                "start_date": operator.start_date,
+                "end_date": operator.end_date,
+                "depends_on_past": operator.depends_on_past,
+                "ignore_first_depends_on_past": 
operator.ignore_first_depends_on_past,
+                "wait_for_past_depends_before_skipping": 
operator.wait_for_past_depends_before_skipping,
+                "wait_for_downstream": operator.wait_for_downstream,
+                "dag": operator.dag,
+                "priority_weight": operator.priority_weight,
+                "queue": operator.queue,
+                "pool": operator.pool,
+                "pool_slots": operator.pool_slots,
+                "execution_timeout": None,
+                "trigger_rule": operator.trigger_rule,
+                "resources": operator.resources,
+                "run_as_user": operator.run_as_user,
+                "map_index_template": operator.map_index_template,
+                "max_active_tis_per_dag": operator.max_active_tis_per_dag,
+                "max_active_tis_per_dagrun": 
operator.max_active_tis_per_dagrun,
+                "executor": operator.executor,
+                "executor_config": operator.executor_config,
+                "inlets": operator.inlets,
+                "outlets": operator.outlets,
+                "task_group": operator.task_group,
+                "doc": operator.doc,
+                "doc_md": operator.doc_md,
+                "doc_json": operator.doc_json,
+                "doc_yaml": operator.doc_yaml,
+                "doc_rst": operator.doc_rst,
+                "task_display_name": operator.task_display_name,
+                "allow_nested_operators": operator.allow_nested_operators,
+            }
+        )
+        self._operator = operator
+        self.expand_input = expand_input
+        self.partial_kwargs = operator.partial_kwargs or {}
+        self.partial_kwargs.pop("partition_size", None)
+        self.max_workers = self.partial_kwargs.pop("task_concurrency", None) 
or os.cpu_count() or 1
+        self._number_of_tasks: int = 0
+        XComArg.apply_upstream_relationship(self, self.expand_input.value)
+
+    @property
+    def task_type(self) -> str:
+        """@property: type of the task."""
+        return self._operator.__class__.__name__
+
+    @property
+    def timeout(self) -> float | None:
+        if self.execution_timeout:
+            return self.execution_timeout.total_seconds()
+        return None
+
+    def _do_render_template_fields(
+        self,
+        parent: Any,
+        template_fields: Iterable[str],
+        context: Context,
+        jinja_env: jinja2.Environment,
+        seen_oids: set[int],
+    ) -> None:
+        # IterableOperator doesn't need to render template fields as the 
actual operator's template fields
+        # will be rendered in the TaskExecutor when running each mapped task 
instance.
+        pass
+
+    def _get_specified_expand_input(self) -> ExpandInput:
+        return self.expand_input
+
+    def _unmap_operator(
+        self, context: Context, mapped_kwargs: Context, jinja_env: 
jinja2.Environment
+    ) -> BaseOperator:
+        from airflow.sdk.execution_time.context import 
context_update_for_unmapped
+
+        self._number_of_tasks += 1
+        unmapped_task = self._operator.unmap(mapped_kwargs)
+        # Make sure deferred operators will always raise a DeferredTask 
exception when executed
+        unmapped_task.start_from_trigger = False
+        context_update_for_unmapped(context, unmapped_task)
+
+        unmapped_task._do_render_template_fields(
+            parent=unmapped_task,
+            template_fields=self._operator.template_fields,
+            context=context,
+            jinja_env=jinja_env,
+            seen_oids=set(),
+        )
+        return unmapped_task
+
+    def _xcom_push(self, task: IndexedTaskInstance, value: Any) -> None:
+        if task.xcom_pushed:
+            self.log.debug(
+                "XCom already pushed for task_id %s with index %s",
+                task.task_id,
+                task.index,
+            )
+        else:
+            self.log.debug(
+                "Pushing XCom for task_id %s with index %s",
+                task.task_id,
+                task.index,
+            )
+
+            task.xcom_push(key=BaseXCom.XCOM_RETURN_KEY, value=value)
+
+    def _run_tasks(
+        self,
+        context: Context,
+        tasks: Iterable[IndexedTaskInstance],
+    ) -> XComIterable | None:
+        exceptions: list[BaseException] = []
+        reschedule_date = timezone.utcnow()
+        prev_futures_count = 0
+        futures: dict[Future | asyncio.futures.Future, IndexedTaskInstance] = 
{}
+        deferred_tasks: deque[IndexedTaskInstance] = deque()
+        failed_tasks: deque[IndexedTaskInstance] = deque()
+        chunked_tasks = batched(tasks, self.max_workers)
+        do_xcom_push = True
+
+        self.log.info("Running tasks with %d workers", self.max_workers)
+
+        with event_loop() as loop:
+            with ConcurrentExecutor(loop=loop, max_workers=self.max_workers) 
as executor:
+                for task in next(chunked_tasks, []):
+                    do_xcom_push = task.do_xcom_push
+                    if task.is_async:
+                        future = executor.submit(self._run_async_operator, 
context, task)
+                    else:
+                        future = executor.submit(self._run_operator, context, 
task)
+                    futures[future] = task
+
+                while futures:
+                    futures_count = len(futures)
+
+                    if futures_count != prev_futures_count:
+                        self.log.info("Number of remaining futures: %s", 
futures_count)
+                        prev_futures_count = futures_count
+
+                    for future in collect_futures(loop, list(futures.keys())):
+                        task = futures.pop(future)
+
+                        try:
+                            if isinstance(future, asyncio.futures.Future):
+                                result = future.result()
+                            else:
+                                result = future.result(timeout=self.timeout)
+
+                            self.log.debug("result: %s", result)
+
+                            if result is not None and task.do_xcom_push:
+                                self._xcom_push(
+                                    task=task,
+                                    value=result,
+                                )
+                        except TaskDeferred as task_deferred:
+                            operator = DecoratedDeferredAsyncOperator(
+                                operator=task.task, task_deferred=task_deferred
+                            )
+                            deferred_tasks.append(
+                                self._create_mapped_task(
+                                    run_id=task.run_id,
+                                    index=task.index,
+                                    map_index=task.map_index,  # type: 
ignore[arg-type]
+                                    try_number=task.try_number,
+                                    operator=operator,
+                                )
+                            )
+                        except asyncio.TimeoutError as e:
+                            self.log.warning("A timeout occurred for task_id 
%s", task.task_id)
+                            if task.next_try_number > (self.retries or 0):
+                                exceptions.append(AirflowTaskTimeout(e))
+                            else:
+                                reschedule_date = min(reschedule_date, 
task.next_retry_datetime())
+                                failed_tasks.append(task)
+                        except AirflowRescheduleTaskInstanceException as e:
+                            reschedule_date = min(reschedule_date, 
e.reschedule_date)
+                            self.log.exception(
+                                "An exception occurred for task_id %s with 
index %s, it has been rescheduled at %s",
+                                task.task_id,
+                                task.index,
+                                reschedule_date,
+                            )
+                            failed_tasks.append(e.task)
+                        except Exception as e:
+                            self.log.exception(
+                                "An exception occurred for task_id %s with 
index %s",
+                                task.task_id,
+                                task.index,
+                            )
+                            exceptions.append(e)
+
+                    if len(futures) < self.max_workers:
+                        chunked_tasks = chain(list(deferred_tasks), 
chunked_tasks)
+                        deferred_tasks.clear()
+
+                        for task in next(chunked_tasks, []):
+                            if task.is_async:
+                                future = 
executor.submit(self._run_async_operator, context, task)
+                            else:
+                                future = executor.submit(self._run_operator, 
context, task)
+                            futures[future] = task
+
+        if not failed_tasks:
+            if exceptions:
+                raise BaseExceptionGroup("Multiple sub-task failures", 
exceptions)
+            if do_xcom_push:
+                from airflow.sdk.execution_time.lazy_sequence import 
XComIterable
+
+                return XComIterable(
+                    task_id=self.task_id,
+                    dag_id=self.dag_id,
+                    run_id=context["run_id"],
+                    length=self._number_of_tasks,
+                    map_index=context["ti"].map_index,
+                )
+            return None
+
+        # If the retry time is still in the future we defer the operator so 
the worker
+        # slot is released. If the retry time has already passed we 
immediately re-run
+        # the failed tasks without deferring.
+        if reschedule_date > timezone.utcnow():
+            # TODO: This is tricky as that import doesn't exist in Task SDK
+            from airflow.providers.standard.triggers.temporal import 
DateTimeTrigger

Review Comment:
   Runtime `from airflow.providers.standard.triggers.temporal import 
DateTimeTrigger` inside `_run_tasks` breaks the task-sdk INV-3 contract (the 
SDK is supposed to ship without provider packages). The TODO on the line above 
acknowledges this, and the pre-commit exclusion added to 
`task-sdk/.pre-commit-config.yaml` for `check-core-imports` silently formalizes 
the violation.
   
   Options: (a) move `DateTimeTrigger` (or a thin SDK equivalent) into the SDK, 
(b) declare `apache-airflow-providers-standard` as a hard runtime dependency of 
task-sdk in `pyproject.toml`, or (c) expose a sleep/reschedule primitive in the 
SDK that the standard provider extends. Silently bypassing the import check, as 
the PR does now, is worse than any of the three.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to