kaxil commented on code in PR #62922: URL: https://github.com/apache/airflow/pull/62922#discussion_r2890440469
########## task-sdk/src/airflow/sdk/bases/iterableoperator.py: ########## @@ -0,0 +1,432 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +import logging +import os +import time +from collections import deque +from collections.abc import Iterable, Sequence +from concurrent.futures import Future +from math import ceil +from time import sleep +from typing import TYPE_CHECKING, Any + +from more_itertools import ichunked + +from airflow.exceptions import ( Review Comment: `more-itertools` is not a dependency of `task-sdk` — it's only in `providers/common/sql/pyproject.toml`. This will be an `ImportError` at runtime. Either add it to task-sdk's dependencies or replace `ichunked` with `itertools.batched` (Python 3.12+) or a simple chunking generator. 
########## task-sdk/src/airflow/sdk/definitions/mappedoperator.py: ########## @@ -336,19 +364,20 @@ def __repr__(self): return f"<Mapped({self.task_type}): {self.task_id}>" def __attrs_post_init__(self): - from airflow.sdk.definitions.xcom_arg import XComArg - - if self.get_closest_mapped_task_group() is not None: - raise NotImplementedError("operator expansion in an expanded task group is not yet supported") - - if self.task_group: - self.task_group.add(self) - if self.dag: - self.dag.add_task(self) - XComArg.apply_upstream_relationship(self, self._get_specified_expand_input().value) - for k, v in self.partial_kwargs.items(): - if k in self.template_fields: - XComArg.apply_upstream_relationship(self, v) + if self._apply_upstream_relationship: + from airflow.sdk.definitions.xcom_arg import XComArg + + if self.get_closest_mapped_task_group() is not None: + raise NotImplementedError("operator expansion in an expanded task group is not yet supported") + + if self.task_group: + self.task_group.add(self) + if self.dag: Review Comment: When `apply_upstream_relationship=False`, this skips *everything* — not just upstream relationships, but also task group registration (`self.task_group.add(self)`) and DAG registration (`self.dag.add_task(self)`). That means `IterableOperator` creates `MappedOperator` instances that are invisible to the DAG, which breaks serialization, the UI, and dependency tracking. Should we separate these concerns? The flag name suggests it only controls upstream relationships, but it actually controls all post-init behavior. 
########## task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py: ########## @@ -184,7 +228,22 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]: if isinstance(x, XComArg): yield from x.iter_references() - def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]: + def iter_values(self, context: Context) -> Iterable[Any]: + def resolve(value: Any) -> Any: + if isinstance(value, XComArg): + return value.iter_values(context=context) + return value + Review Comment: This `iter_values` doesn't produce the right output for `expand()` semantics. For `expand(a=[1, 2], b=[3, 4])`, this yields `{"a": 1}, {"a": 2}, {"b": 3}, {"b": 4}` — 4 separate single-key dicts. But dynamic task mapping produces the cartesian product: `{"a": 1, "b": 3}, {"a": 1, "b": 4}, {"a": 2, "b": 3}, {"a": 2, "b": 4}` — 4 dicts with all keys combined. Tasks receiving only `{"a": 1}` without `b` would fail with missing arguments. ########## task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py: ########## @@ -79,8 +79,45 @@ def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArg return isinstance(v, (MappedArgument, XComArg)) +class ExpandInput(ABC, ResolveMixin): Review Comment: Changing `ExpandInput` from a Union type alias to an ABC is a significant refactor. A couple issues: 1. `DecoratedExpandInput` and `MappedArgument` inherit from `ExpandInput` but don't set `EXPAND_INPUT_TYPE`. The serializer in `airflow-core/src/airflow/serialization/encoders.py` (`encode_expand_input`) accesses `var.EXPAND_INPUT_TYPE` — this will crash with `AttributeError`. 2. `MappedArgument` previously inherited from `ResolveMixin`, now from `ExpandInput`. Is it really an expand input? It's a stand-in for task-group-mapping arguments — making it an `ExpandInput` subclass seems like it conflates two different concepts. 
########## task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py: ########## @@ -166,6 +166,60 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]: return XCom.deserialize_value(_XComWrapper(msg.root)) [email protected] +class XComIterable(Sequence): + """An iterable that lazily fetches XCom values one by one instead of loading all at once.""" Review Comment: `XComIterable` implements both `Sequence` (via inheritance) and `Iterator` (via `__iter__` returning `self` + `__next__`). A `Sequence.__iter__` should return a fresh iterator on each call, not reuse the object itself — otherwise you can't iterate twice concurrently, and multiple consumers sharing the same instance will interfere with each other. Consider having `__iter__` return a separate iterator object, similar to how `LazyXComSequence` uses `LazyXComIterator`. ########## task-sdk/src/airflow/sdk/bases/iterableoperator.py: ########## @@ -0,0 +1,432 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import asyncio +import logging +import os +import time +from collections import deque +from collections.abc import Iterable, Sequence +from concurrent.futures import Future +from math import ceil +from time import sleep +from typing import TYPE_CHECKING, Any + +from more_itertools import ichunked + +from airflow.exceptions import ( + AirflowTaskTimeout, + TaskDeferred, +) +from airflow.sdk import timezone +from airflow.sdk.bases.operator import BaseOperator, DecoratedDeferredAsyncOperator, event_loop +from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg # noqa: F401 +from airflow.sdk.exceptions import AirflowRescheduleTaskInstanceException +from airflow.sdk.execution_time.executor import HybridExecutor, _execute_async_task, collect_futures +from airflow.sdk.execution_time.lazy_sequence import XComIterable +from airflow.sdk.execution_time.task_runner import MappedTaskInstance, RuntimeTaskInstance, _execute_task +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.state import TaskInstanceState + +if TYPE_CHECKING: + from airflow.sdk.definitions._internal.expandinput import ExpandInput + from airflow.sdk.definitions.context import Context + from airflow.sdk.definitions.mappedoperator import MappedOperator + + +class TaskExecutor(LoggingMixin): + """Base class to run an operator or trigger with given task context and task instance.""" + + def __init__( + self, + task_instance: RuntimeTaskInstance, + ): + super().__init__() + self._task_instance = task_instance + self._result: Any | None = None + self._start_time: float | None = None + + @property + def task_instance(self) -> RuntimeTaskInstance: + return self._task_instance + + @property + def dag_id(self) -> str: + return self.task_instance.dag_id + + @property + def task_id(self) -> str: + return self.task_instance.task_id + + @property + def task_index(self) -> int: + return self.task_instance.map_index + + @property + def key(self): + 
return self.task_instance.xcom_key + + @property + def operator(self) -> BaseOperator: + return self.task_instance.task + + @property + def is_async(self) -> bool: + return self.task_instance.is_async + + def run(self, context: Context): + return _execute_task(context, self.task_instance, self.log) + + async def arun(self, context: Context): + return await _execute_async_task(context, self.task_instance, self.log) + + def __enter__(self): + self._start_time = time.monotonic() + + if self.log.isEnabledFor(logging.INFO): + self.log.info( + "Attempting running task %s of %s for %s with map_index %s in %s mode.", + self.task_instance.try_number, + self.operator.retries, + self.task_instance.task_id, + self.task_index, + "async" if self.is_async else "sync", + ) + return self + + async def __aenter__(self): + return self.__enter__() + + def __exit__(self, exc_type, exc_value, traceback): Review Comment: `self._start_time` is initialized to `None` and only set in `__enter__`. If `__exit__` gets called without a matching `__enter__` (edge case, but possible), `time.monotonic() - self._start_time` will be `TypeError`. ########## task-sdk/src/airflow/sdk/bases/operator.py: ########## @@ -1657,7 +1662,14 @@ def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, raise TaskDeferralTimeout(error) raise TaskDeferralError(error) # Grab the callable off the Operator/Task and add in any kwargs - execute_callable = getattr(self, next_method) + return getattr(self, next_method) + Review Comment: `next_callable` takes 3 positional args after `self`: `next_method`, `next_kwargs`, `context`. But `DecoratedDeferredAsyncOperator.aexecute` calls it with only 2: ```python next_method = self._operator.next_callable( self._task_deferred.method_name, self._task_deferred.kwargs, ) ``` That will raise `TypeError` for the missing `context` argument. Also — the original `resume_execution` handled `next_kwargs is None` defaulting before the `__fail__` check. 
Now `next_callable` receives the raw kwargs without that defaulting, so if `DecoratedDeferredAsyncOperator` passes `None` kwargs, `next_kwargs.get("traceback")` will `AttributeError`. ########## task-sdk/src/airflow/sdk/bases/iterableoperator.py: ########## @@ -0,0 +1,432 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import asyncio +import logging +import os +import time +from collections import deque +from collections.abc import Iterable, Sequence +from concurrent.futures import Future +from math import ceil +from time import sleep +from typing import TYPE_CHECKING, Any + +from more_itertools import ichunked + +from airflow.exceptions import ( + AirflowTaskTimeout, + TaskDeferred, +) +from airflow.sdk import timezone +from airflow.sdk.bases.operator import BaseOperator, DecoratedDeferredAsyncOperator, event_loop +from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg # noqa: F401 +from airflow.sdk.exceptions import AirflowRescheduleTaskInstanceException +from airflow.sdk.execution_time.executor import HybridExecutor, _execute_async_task, collect_futures +from airflow.sdk.execution_time.lazy_sequence import XComIterable +from airflow.sdk.execution_time.task_runner import MappedTaskInstance, RuntimeTaskInstance, _execute_task +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.state import TaskInstanceState + +if TYPE_CHECKING: + from airflow.sdk.definitions._internal.expandinput import ExpandInput + from airflow.sdk.definitions.context import Context + from airflow.sdk.definitions.mappedoperator import MappedOperator + + +class TaskExecutor(LoggingMixin): + """Base class to run an operator or trigger with given task context and task instance.""" + + def __init__( + self, + task_instance: RuntimeTaskInstance, + ): + super().__init__() + self._task_instance = task_instance + self._result: Any | None = None + self._start_time: float | None = None + + @property + def task_instance(self) -> RuntimeTaskInstance: + return self._task_instance + + @property + def dag_id(self) -> str: + return self.task_instance.dag_id + + @property + def task_id(self) -> str: + return self.task_instance.task_id + + @property + def task_index(self) -> int: + return self.task_instance.map_index + + @property + def key(self): + 
return self.task_instance.xcom_key + + @property + def operator(self) -> BaseOperator: + return self.task_instance.task + + @property + def is_async(self) -> bool: + return self.task_instance.is_async + + def run(self, context: Context): + return _execute_task(context, self.task_instance, self.log) + + async def arun(self, context: Context): + return await _execute_async_task(context, self.task_instance, self.log) + + def __enter__(self): + self._start_time = time.monotonic() + + if self.log.isEnabledFor(logging.INFO): + self.log.info( + "Attempting running task %s of %s for %s with map_index %s in %s mode.", + self.task_instance.try_number, + self.operator.retries, + self.task_instance.task_id, + self.task_index, + "async" if self.is_async else "sync", + ) + return self + + async def __aenter__(self): + return self.__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + elapsed = time.monotonic() - self._start_time + + if exc_value: + if not isinstance(exc_value, TaskDeferred): + if self.task_instance.next_try_number > self.task_instance.max_tries: + self.log.error( + "Task instance %s for %s failed after %s attempts in %.2f seconds due to: %s", + self.task_index, + self.task_instance.task_id, + self.task_instance.max_tries, + elapsed, + exc_value, + ) + self.task_instance.state = TaskInstanceState.FAILED + raise exc_value + self.task_instance.try_number += 1 + self.task_instance.end_date = timezone.utcnow() + self.task_instance.state = TaskInstanceState.UP_FOR_RESCHEDULE + raise AirflowRescheduleTaskInstanceException(task=self.task_instance) + raise exc_value + + self.task_instance.state = TaskInstanceState.SUCCESS + if self.log.isEnabledFor(logging.INFO): + self.log.info( + "Task instance %s for %s finished successfully in %s attempts in %.2f seconds", + self.task_index, + self.task_instance.task_id, + self.task_instance.next_try_number, + elapsed, + ) + + async def __aexit__(self, exc_type, exc_value, traceback): + self.__exit__(exc_type, 
exc_value, traceback) + + +class IterableOperator(BaseOperator): + """Object representing an iterable operator in a DAG.""" + + _operator: MappedOperator + expand_input: ExpandInput + partial_kwargs: dict[str, Any] + # each operator should override this class attr for shallow copy attrs. + shallow_copy_attrs: Sequence[str] = ( + "_operator", + "expand_input", + "partial_kwargs", + "_log", + ) + + def __init__( + self, + *, + operator: MappedOperator, + expand_input: ExpandInput, + **kwargs, + ): + super().__init__( + **{ + **kwargs, + "task_id": operator.task_id, + "owner": operator.owner, + "email": operator.email, + "email_on_retry": operator.email_on_retry, + "email_on_failure": operator.email_on_failure, + "retries": 0, # We should not retry the IterableOperator, only the mapped ti's should be retried + "retry_delay": operator.retry_delay, + "retry_exponential_backoff": operator.retry_exponential_backoff, + "max_retry_delay": operator.max_retry_delay, + "start_date": operator.start_date, + "end_date": operator.end_date, + "depends_on_past": operator.depends_on_past, + "ignore_first_depends_on_past": operator.ignore_first_depends_on_past, + "wait_for_past_depends_before_skipping": operator.wait_for_past_depends_before_skipping, + "wait_for_downstream": operator.wait_for_downstream, + "dag": operator.dag, + "priority_weight": operator.priority_weight, + "queue": operator.queue, + "pool": operator.pool, + "pool_slots": operator.pool_slots, + "execution_timeout": None, + "trigger_rule": operator.trigger_rule, + "resources": operator.resources, + "run_as_user": operator.run_as_user, + "map_index_template": operator.map_index_template, + "max_active_tis_per_dag": operator.max_active_tis_per_dag, + "max_active_tis_per_dagrun": operator.max_active_tis_per_dagrun, + "executor": operator.executor, + "executor_config": operator.executor_config, + "inlets": operator.inlets, + "outlets": operator.outlets, + "task_group": operator.task_group, + "doc": operator.doc, + 
"doc_md": operator.doc_md, + "doc_json": operator.doc_json, + "doc_yaml": operator.doc_yaml, + "doc_rst": operator.doc_rst, + "task_display_name": operator.task_display_name, + "allow_nested_operators": operator.allow_nested_operators, + } + ) + self._operator = operator + self.expand_input = expand_input + self.partial_kwargs = operator.partial_kwargs or {} + self._number_of_tasks: int = 0 + XComArg.apply_upstream_relationship(self, self.expand_input.value) + + @property + def task_type(self) -> str: + """@property: type of the task.""" + return self._operator.__class__.__name__ + + @property + def max_workers(self): + return self.max_active_tis_per_dag or os.cpu_count() or 1 + + @property + def timeout(self) -> float | None: + if self.execution_timeout: + return self.execution_timeout.total_seconds() + return None + + def _get_specified_expand_input(self) -> ExpandInput: + return self.expand_input + + def _unmap_operator(self, mapped_kwargs: dict): + self._number_of_tasks += 1 + return self._operator.unmap(mapped_kwargs) + + def _xcom_push(self, context: Context, task: RuntimeTaskInstance, value: Any) -> None: + self.log.debug("Pushing XCom %s", task.map_index) + + context["ti"].xcom_push(key=task.xcom_key, value=value) + + def _run_tasks( + self, + context: Context, + tasks: Iterable[RuntimeTaskInstance], + ) -> None: + exception: BaseException | None = None + reschedule_date = timezone.utcnow() + prev_futures_count = 0 + futures: dict[Future, RuntimeTaskInstance] = {} + failed_tasks: deque[RuntimeTaskInstance] = deque() + chunked_tasks = ichunked(tasks, self.max_workers) + + self.log.info("Running tasks with %d workers", self.max_workers) + + with event_loop() as loop: + with HybridExecutor(loop=loop, max_workers=self.max_workers) as executor: + for task in next(chunked_tasks, []): + if task.is_async: + future = executor.submit(self._run_async_operator, context, task) + else: + future = executor.submit(self._run_operator, context, task) + futures[future] = task 
+ + while futures: + futures_count = len(futures) + + if futures_count != prev_futures_count: + self.log.info("Number of remaining futures: %s", futures_count) + prev_futures_count = futures_count + + ready_futures = False + + for future in collect_futures(loop, futures.keys()): + task = futures.pop(future) + ready_futures = True + + try: + if isinstance(future, asyncio.futures.Future): + result = future.result() + else: + result = future.result(timeout=self.timeout) + + self.log.debug("result: %s", result) + + if result and task.task.do_xcom_push: + self._xcom_push( + context=context, + task=task, + value=result, + ) + except TaskDeferred as task_deferred: + operator = DecoratedDeferredAsyncOperator( + operator=task.task, task_deferred=task_deferred + ) + failed_tasks.append( + self._create_mapped_task(task.run_id, task.map_index, operator) + ) + except asyncio.TimeoutError as e: + self.log.warning("A timeout occurred for task_id %s", task.task_id) + if task.next_try_number > self.retries: + exception = AirflowTaskTimeout(e) + else: + reschedule_date = min(reschedule_date, task.next_retry_datetime()) + failed_tasks.append(task) + except AirflowRescheduleTaskInstanceException as e: + reschedule_date = min(reschedule_date, e.reschedule_date) + self.log.warning( + "An exception occurred for task_id %s with map_index %s, it has been rescheduled at %s", + task.task_id, + task.map_index, + reschedule_date, + ) + failed_tasks.append(e.task) + except Exception as e: + self.log.error( + "An exception occurred for task_id %s with map_index %s", + task.task_id, + task.map_index, + ) + exception = e + + if len(futures) < self.max_workers: + for task in next(chunked_tasks, []): + if task.is_async: + future = executor.submit(self._run_async_operator, context, task) + else: + future = executor.submit(self._run_operator, context, task) + futures[future] = task + elif not ready_futures and futures: + sleep(len(futures) * 0.1) + + if not failed_tasks: + if exception: + raise 
exception + if self.do_xcom_push: + return XComIterable( + task_id=self.task_id, + dag_id=self.dag_id, + run_id=context["run_id"], + length=self._number_of_tasks, + ) + + now = timezone.utcnow() + + # Calculate delay before the next retry + if reschedule_date > now: + delay_seconds = ceil((reschedule_date - now).total_seconds()) + + self.log.info( + "Attempting to run %s failed tasks within %s seconds...", + len(failed_tasks), + delay_seconds, + ) + + sleep(delay_seconds) + + return self._run_tasks(context=context, tasks=list(failed_tasks)) + + def _run_operator(self, context: Context, task_instance: RuntimeTaskInstance): + with TaskExecutor(task_instance=task_instance) as executor: + return executor.run( + context={ Review Comment: `_run_operator` creates a new dict merging `context` with task-specific overrides, but the base `context` dict is shared across all threads. If any task modifies values in `context` during execution (which tasks commonly do), that's a race condition. Same issue with `_run_async_operator` below. Probably need to deep-copy `context` per task. ########## task-sdk/src/airflow/sdk/bases/iterableoperator.py: ########## @@ -0,0 +1,432 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import asyncio +import logging +import os +import time +from collections import deque +from collections.abc import Iterable, Sequence +from concurrent.futures import Future +from math import ceil +from time import sleep +from typing import TYPE_CHECKING, Any + +from more_itertools import ichunked + +from airflow.exceptions import ( + AirflowTaskTimeout, + TaskDeferred, +) +from airflow.sdk import timezone +from airflow.sdk.bases.operator import BaseOperator, DecoratedDeferredAsyncOperator, event_loop +from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg # noqa: F401 +from airflow.sdk.exceptions import AirflowRescheduleTaskInstanceException +from airflow.sdk.execution_time.executor import HybridExecutor, _execute_async_task, collect_futures +from airflow.sdk.execution_time.lazy_sequence import XComIterable +from airflow.sdk.execution_time.task_runner import MappedTaskInstance, RuntimeTaskInstance, _execute_task +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.state import TaskInstanceState + +if TYPE_CHECKING: + from airflow.sdk.definitions._internal.expandinput import ExpandInput + from airflow.sdk.definitions.context import Context + from airflow.sdk.definitions.mappedoperator import MappedOperator + + +class TaskExecutor(LoggingMixin): + """Base class to run an operator or trigger with given task context and task instance.""" + + def __init__( + self, + task_instance: RuntimeTaskInstance, + ): + super().__init__() + self._task_instance = task_instance + self._result: Any | None = None + self._start_time: float | None = None + + @property + def task_instance(self) -> RuntimeTaskInstance: + return self._task_instance + + @property + def dag_id(self) -> str: + return self.task_instance.dag_id + + @property + def task_id(self) -> str: + return self.task_instance.task_id + + @property + def task_index(self) -> int: + return self.task_instance.map_index + + @property + def key(self): + 
return self.task_instance.xcom_key + + @property + def operator(self) -> BaseOperator: + return self.task_instance.task + + @property + def is_async(self) -> bool: + return self.task_instance.is_async + + def run(self, context: Context): + return _execute_task(context, self.task_instance, self.log) + + async def arun(self, context: Context): + return await _execute_async_task(context, self.task_instance, self.log) + + def __enter__(self): + self._start_time = time.monotonic() + + if self.log.isEnabledFor(logging.INFO): + self.log.info( + "Attempting running task %s of %s for %s with map_index %s in %s mode.", + self.task_instance.try_number, + self.operator.retries, + self.task_instance.task_id, + self.task_index, + "async" if self.is_async else "sync", + ) + return self + + async def __aenter__(self): + return self.__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + elapsed = time.monotonic() - self._start_time + + if exc_value: + if not isinstance(exc_value, TaskDeferred): + if self.task_instance.next_try_number > self.task_instance.max_tries: + self.log.error( + "Task instance %s for %s failed after %s attempts in %.2f seconds due to: %s", + self.task_index, + self.task_instance.task_id, + self.task_instance.max_tries, + elapsed, + exc_value, + ) + self.task_instance.state = TaskInstanceState.FAILED + raise exc_value + self.task_instance.try_number += 1 + self.task_instance.end_date = timezone.utcnow() + self.task_instance.state = TaskInstanceState.UP_FOR_RESCHEDULE + raise AirflowRescheduleTaskInstanceException(task=self.task_instance) + raise exc_value + + self.task_instance.state = TaskInstanceState.SUCCESS + if self.log.isEnabledFor(logging.INFO): + self.log.info( + "Task instance %s for %s finished successfully in %s attempts in %.2f seconds", + self.task_index, + self.task_instance.task_id, + self.task_instance.next_try_number, + elapsed, + ) + + async def __aexit__(self, exc_type, exc_value, traceback): + self.__exit__(exc_type, 
exc_value, traceback) + + +class IterableOperator(BaseOperator): + """Object representing an iterable operator in a DAG.""" + + _operator: MappedOperator + expand_input: ExpandInput + partial_kwargs: dict[str, Any] + # each operator should override this class attr for shallow copy attrs. + shallow_copy_attrs: Sequence[str] = ( + "_operator", + "expand_input", + "partial_kwargs", + "_log", + ) + + def __init__( + self, + *, + operator: MappedOperator, + expand_input: ExpandInput, + **kwargs, + ): + super().__init__( + **{ + **kwargs, + "task_id": operator.task_id, + "owner": operator.owner, + "email": operator.email, + "email_on_retry": operator.email_on_retry, + "email_on_failure": operator.email_on_failure, + "retries": 0, # We should not retry the IterableOperator, only the mapped ti's should be retried + "retry_delay": operator.retry_delay, + "retry_exponential_backoff": operator.retry_exponential_backoff, + "max_retry_delay": operator.max_retry_delay, + "start_date": operator.start_date, + "end_date": operator.end_date, + "depends_on_past": operator.depends_on_past, + "ignore_first_depends_on_past": operator.ignore_first_depends_on_past, + "wait_for_past_depends_before_skipping": operator.wait_for_past_depends_before_skipping, + "wait_for_downstream": operator.wait_for_downstream, + "dag": operator.dag, + "priority_weight": operator.priority_weight, + "queue": operator.queue, + "pool": operator.pool, + "pool_slots": operator.pool_slots, + "execution_timeout": None, + "trigger_rule": operator.trigger_rule, + "resources": operator.resources, + "run_as_user": operator.run_as_user, + "map_index_template": operator.map_index_template, + "max_active_tis_per_dag": operator.max_active_tis_per_dag, + "max_active_tis_per_dagrun": operator.max_active_tis_per_dagrun, + "executor": operator.executor, + "executor_config": operator.executor_config, + "inlets": operator.inlets, + "outlets": operator.outlets, + "task_group": operator.task_group, + "doc": operator.doc, + 
"doc_md": operator.doc_md, + "doc_json": operator.doc_json, + "doc_yaml": operator.doc_yaml, + "doc_rst": operator.doc_rst, + "task_display_name": operator.task_display_name, + "allow_nested_operators": operator.allow_nested_operators, + } + ) + self._operator = operator + self.expand_input = expand_input + self.partial_kwargs = operator.partial_kwargs or {} + self._number_of_tasks: int = 0 + XComArg.apply_upstream_relationship(self, self.expand_input.value) + + @property + def task_type(self) -> str: + """@property: type of the task.""" + return self._operator.__class__.__name__ + + @property + def max_workers(self): + return self.max_active_tis_per_dag or os.cpu_count() or 1 + + @property + def timeout(self) -> float | None: + if self.execution_timeout: + return self.execution_timeout.total_seconds() + return None + + def _get_specified_expand_input(self) -> ExpandInput: + return self.expand_input + + def _unmap_operator(self, mapped_kwargs: dict): + self._number_of_tasks += 1 + return self._operator.unmap(mapped_kwargs) + + def _xcom_push(self, context: Context, task: RuntimeTaskInstance, value: Any) -> None: + self.log.debug("Pushing XCom %s", task.map_index) + + context["ti"].xcom_push(key=task.xcom_key, value=value) + + def _run_tasks( + self, + context: Context, + tasks: Iterable[RuntimeTaskInstance], + ) -> None: + exception: BaseException | None = None + reschedule_date = timezone.utcnow() + prev_futures_count = 0 + futures: dict[Future, RuntimeTaskInstance] = {} + failed_tasks: deque[RuntimeTaskInstance] = deque() + chunked_tasks = ichunked(tasks, self.max_workers) + + self.log.info("Running tasks with %d workers", self.max_workers) + + with event_loop() as loop: + with HybridExecutor(loop=loop, max_workers=self.max_workers) as executor: + for task in next(chunked_tasks, []): + if task.is_async: + future = executor.submit(self._run_async_operator, context, task) + else: + future = executor.submit(self._run_operator, context, task) + futures[future] = task 
+ + while futures: + futures_count = len(futures) + + if futures_count != prev_futures_count: + self.log.info("Number of remaining futures: %s", futures_count) + prev_futures_count = futures_count + + ready_futures = False + + for future in collect_futures(loop, futures.keys()): + task = futures.pop(future) + ready_futures = True + + try: + if isinstance(future, asyncio.futures.Future): + result = future.result() + else: + result = future.result(timeout=self.timeout) + + self.log.debug("result: %s", result) + + if result and task.task.do_xcom_push: + self._xcom_push( + context=context, + task=task, + value=result, + ) + except TaskDeferred as task_deferred: + operator = DecoratedDeferredAsyncOperator( + operator=task.task, task_deferred=task_deferred + ) + failed_tasks.append( + self._create_mapped_task(task.run_id, task.map_index, operator) + ) + except asyncio.TimeoutError as e: + self.log.warning("A timeout occurred for task_id %s", task.task_id) + if task.next_try_number > self.retries: + exception = AirflowTaskTimeout(e) + else: + reschedule_date = min(reschedule_date, task.next_retry_datetime()) + failed_tasks.append(task) + except AirflowRescheduleTaskInstanceException as e: + reschedule_date = min(reschedule_date, e.reschedule_date) + self.log.warning( + "An exception occurred for task_id %s with map_index %s, it has been rescheduled at %s", + task.task_id, + task.map_index, + reschedule_date, + ) + failed_tasks.append(e.task) + except Exception as e: + self.log.error( + "An exception occurred for task_id %s with map_index %s", + task.task_id, + task.map_index, + ) + exception = e + + if len(futures) < self.max_workers: + for task in next(chunked_tasks, []): + if task.is_async: + future = executor.submit(self._run_async_operator, context, task) + else: + future = executor.submit(self._run_operator, context, task) + futures[future] = task + elif not ready_futures and futures: + sleep(len(futures) * 0.1) + + if not failed_tasks: + if exception: + raise 
exception + if self.do_xcom_push: + return XComIterable( + task_id=self.task_id, + dag_id=self.dag_id, + run_id=context["run_id"], + length=self._number_of_tasks, + ) + + now = timezone.utcnow() + + # Calculate delay before the next retry + if reschedule_date > now: + delay_seconds = ceil((reschedule_date - now).total_seconds()) + + self.log.info( + "Attempting to run %s failed tasks within %s seconds...", + len(failed_tasks), + delay_seconds, + ) + + sleep(delay_seconds) + Review Comment: `_run_tasks` calls itself recursively for retries: ```python return self._run_tasks(context=context, tasks=list(failed_tasks)) ``` If tasks keep failing (e.g., a sensor-like pattern with many retries), this will hit Python's recursion limit. Should be a loop instead. ########## task-sdk/src/airflow/sdk/execution_time/executor.py: ########## @@ -0,0 +1,141 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import asyncio +import contextvars +import inspect +import os +from asyncio import AbstractEventLoop, Semaphore +from collections.abc import Callable +from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from logging import Logger +from typing import TYPE_CHECKING, Any + +from airflow.sdk import BaseAsyncOperator, BaseOperator +from airflow.sdk.bases.operator import ExecutorSafeguard +from airflow.sdk.execution_time.callback_runner import create_executable_runner +from airflow.sdk.execution_time.context import ( + context_get_outlet_events, + context_to_airflow_vars, +) +from airflow.sdk.execution_time.task_runner import ( + RuntimeTaskInstance, + _run_task_state_change_callbacks, +) + +if TYPE_CHECKING: + from airflow.sdk import Context + + +def collect_futures(loop: AbstractEventLoop, futures: list[Any]): + """Yield futures as they complete (sync or async).""" + yield from as_completed(f for f in futures if isinstance(f, Future)) + + async_tasks = [f for f in futures if isinstance(f, asyncio.Task)] + + if async_tasks: + for task, _ in zip( + async_tasks, + loop.run_until_complete(asyncio.gather(*async_tasks, return_exceptions=True)), + ): + yield task + + return [] + Review Comment: `os.environ.update(airflow_context_vars)` — `os.environ` is process-global. Multiple tasks running concurrently in threads via `HybridExecutor` will clobber each other's environment variables. ########## task-sdk/src/airflow/sdk/bases/iterableoperator.py: ########## @@ -0,0 +1,432 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +import logging +import os +import time +from collections import deque +from collections.abc import Iterable, Sequence +from concurrent.futures import Future +from math import ceil +from time import sleep +from typing import TYPE_CHECKING, Any + +from more_itertools import ichunked + +from airflow.exceptions import ( + AirflowTaskTimeout, + TaskDeferred, +) +from airflow.sdk import timezone +from airflow.sdk.bases.operator import BaseOperator, DecoratedDeferredAsyncOperator, event_loop +from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg # noqa: F401 +from airflow.sdk.exceptions import AirflowRescheduleTaskInstanceException +from airflow.sdk.execution_time.executor import HybridExecutor, _execute_async_task, collect_futures +from airflow.sdk.execution_time.lazy_sequence import XComIterable +from airflow.sdk.execution_time.task_runner import MappedTaskInstance, RuntimeTaskInstance, _execute_task +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.state import TaskInstanceState + +if TYPE_CHECKING: + from airflow.sdk.definitions._internal.expandinput import ExpandInput + from airflow.sdk.definitions.context import Context + from airflow.sdk.definitions.mappedoperator import MappedOperator + + +class TaskExecutor(LoggingMixin): + """Base class to run an operator or trigger with given task context and task instance.""" + + def __init__( + self, + task_instance: RuntimeTaskInstance, + ): + super().__init__() + self._task_instance = task_instance + self._result: Any | 
None = None + self._start_time: float | None = None + + @property + def task_instance(self) -> RuntimeTaskInstance: + return self._task_instance + + @property + def dag_id(self) -> str: + return self.task_instance.dag_id + + @property + def task_id(self) -> str: + return self.task_instance.task_id + + @property + def task_index(self) -> int: + return self.task_instance.map_index + + @property + def key(self): + return self.task_instance.xcom_key + + @property + def operator(self) -> BaseOperator: + return self.task_instance.task + + @property + def is_async(self) -> bool: + return self.task_instance.is_async + + def run(self, context: Context): + return _execute_task(context, self.task_instance, self.log) + + async def arun(self, context: Context): + return await _execute_async_task(context, self.task_instance, self.log) + + def __enter__(self): + self._start_time = time.monotonic() + + if self.log.isEnabledFor(logging.INFO): + self.log.info( + "Attempting running task %s of %s for %s with map_index %s in %s mode.", + self.task_instance.try_number, + self.operator.retries, + self.task_instance.task_id, + self.task_index, + "async" if self.is_async else "sync", + ) + return self + + async def __aenter__(self): + return self.__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + elapsed = time.monotonic() - self._start_time + + if exc_value: + if not isinstance(exc_value, TaskDeferred): + if self.task_instance.next_try_number > self.task_instance.max_tries: + self.log.error( + "Task instance %s for %s failed after %s attempts in %.2f seconds due to: %s", + self.task_index, + self.task_instance.task_id, + self.task_instance.max_tries, + elapsed, + exc_value, + ) + self.task_instance.state = TaskInstanceState.FAILED + raise exc_value + self.task_instance.try_number += 1 + self.task_instance.end_date = timezone.utcnow() + self.task_instance.state = TaskInstanceState.UP_FOR_RESCHEDULE + raise AirflowRescheduleTaskInstanceException(task=self.task_instance) 
+ raise exc_value + + self.task_instance.state = TaskInstanceState.SUCCESS + if self.log.isEnabledFor(logging.INFO): + self.log.info( + "Task instance %s for %s finished successfully in %s attempts in %.2f seconds", + self.task_index, + self.task_instance.task_id, + self.task_instance.next_try_number, + elapsed, + ) + + async def __aexit__(self, exc_type, exc_value, traceback): + self.__exit__(exc_type, exc_value, traceback) + + +class IterableOperator(BaseOperator): + """Object representing an iterable operator in a DAG.""" + + _operator: MappedOperator + expand_input: ExpandInput + partial_kwargs: dict[str, Any] + # each operator should override this class attr for shallow copy attrs. + shallow_copy_attrs: Sequence[str] = ( + "_operator", + "expand_input", + "partial_kwargs", + "_log", + ) + + def __init__( + self, + *, + operator: MappedOperator, + expand_input: ExpandInput, + **kwargs, + ): + super().__init__( + **{ + **kwargs, + "task_id": operator.task_id, + "owner": operator.owner, + "email": operator.email, + "email_on_retry": operator.email_on_retry, + "email_on_failure": operator.email_on_failure, + "retries": 0, # We should not retry the IterableOperator, only the mapped ti's should be retried + "retry_delay": operator.retry_delay, + "retry_exponential_backoff": operator.retry_exponential_backoff, + "max_retry_delay": operator.max_retry_delay, + "start_date": operator.start_date, + "end_date": operator.end_date, + "depends_on_past": operator.depends_on_past, + "ignore_first_depends_on_past": operator.ignore_first_depends_on_past, + "wait_for_past_depends_before_skipping": operator.wait_for_past_depends_before_skipping, + "wait_for_downstream": operator.wait_for_downstream, + "dag": operator.dag, + "priority_weight": operator.priority_weight, + "queue": operator.queue, + "pool": operator.pool, + "pool_slots": operator.pool_slots, + "execution_timeout": None, + "trigger_rule": operator.trigger_rule, + "resources": operator.resources, + "run_as_user": 
operator.run_as_user, + "map_index_template": operator.map_index_template, + "max_active_tis_per_dag": operator.max_active_tis_per_dag, + "max_active_tis_per_dagrun": operator.max_active_tis_per_dagrun, + "executor": operator.executor, + "executor_config": operator.executor_config, + "inlets": operator.inlets, + "outlets": operator.outlets, + "task_group": operator.task_group, + "doc": operator.doc, + "doc_md": operator.doc_md, + "doc_json": operator.doc_json, + "doc_yaml": operator.doc_yaml, + "doc_rst": operator.doc_rst, + "task_display_name": operator.task_display_name, + "allow_nested_operators": operator.allow_nested_operators, + } + ) + self._operator = operator + self.expand_input = expand_input + self.partial_kwargs = operator.partial_kwargs or {} + self._number_of_tasks: int = 0 + XComArg.apply_upstream_relationship(self, self.expand_input.value) + + @property + def task_type(self) -> str: + """@property: type of the task.""" + return self._operator.__class__.__name__ + + @property + def max_workers(self): + return self.max_active_tis_per_dag or os.cpu_count() or 1 + + @property + def timeout(self) -> float | None: + if self.execution_timeout: + return self.execution_timeout.total_seconds() + return None + + def _get_specified_expand_input(self) -> ExpandInput: + return self.expand_input + + def _unmap_operator(self, mapped_kwargs: dict): + self._number_of_tasks += 1 + return self._operator.unmap(mapped_kwargs) + + def _xcom_push(self, context: Context, task: RuntimeTaskInstance, value: Any) -> None: + self.log.debug("Pushing XCom %s", task.map_index) + + context["ti"].xcom_push(key=task.xcom_key, value=value) + + def _run_tasks( + self, + context: Context, + tasks: Iterable[RuntimeTaskInstance], + ) -> None: + exception: BaseException | None = None + reschedule_date = timezone.utcnow() + prev_futures_count = 0 + futures: dict[Future, RuntimeTaskInstance] = {} + failed_tasks: deque[RuntimeTaskInstance] = deque() + chunked_tasks = ichunked(tasks, 
self.max_workers) + + self.log.info("Running tasks with %d workers", self.max_workers) + + with event_loop() as loop: + with HybridExecutor(loop=loop, max_workers=self.max_workers) as executor: + for task in next(chunked_tasks, []): + if task.is_async: + future = executor.submit(self._run_async_operator, context, task) + else: + future = executor.submit(self._run_operator, context, task) + futures[future] = task + + while futures: + futures_count = len(futures) + + if futures_count != prev_futures_count: + self.log.info("Number of remaining futures: %s", futures_count) + prev_futures_count = futures_count + + ready_futures = False + + for future in collect_futures(loop, futures.keys()): + task = futures.pop(future) + ready_futures = True + + try: + if isinstance(future, asyncio.futures.Future): + result = future.result() + else: + result = future.result(timeout=self.timeout) + + self.log.debug("result: %s", result) + + if result and task.task.do_xcom_push: + self._xcom_push( + context=context, + task=task, + value=result, + ) + except TaskDeferred as task_deferred: + operator = DecoratedDeferredAsyncOperator( + operator=task.task, task_deferred=task_deferred + ) + failed_tasks.append( + self._create_mapped_task(task.run_id, task.map_index, operator) + ) + except asyncio.TimeoutError as e: + self.log.warning("A timeout occurred for task_id %s", task.task_id) + if task.next_try_number > self.retries: + exception = AirflowTaskTimeout(e) + else: + reschedule_date = min(reschedule_date, task.next_retry_datetime()) + failed_tasks.append(task) + except AirflowRescheduleTaskInstanceException as e: + reschedule_date = min(reschedule_date, e.reschedule_date) + self.log.warning( + "An exception occurred for task_id %s with map_index %s, it has been rescheduled at %s", + task.task_id, + task.map_index, + reschedule_date, + ) + failed_tasks.append(e.task) + except Exception as e: + self.log.error( + "An exception occurred for task_id %s with map_index %s", + task.task_id, + 
task.map_index, + ) + exception = e + + if len(futures) < self.max_workers: + for task in next(chunked_tasks, []): + if task.is_async: + future = executor.submit(self._run_async_operator, context, task) + else: + future = executor.submit(self._run_operator, context, task) + futures[future] = task + elif not ready_futures and futures: Review Comment: A couple concerns with `sleep()` here: 1. `sleep(len(futures) * 0.1)` blocks the worker, preventing heartbeats. With many futures this could be a significant pause. 2. The retry delay `sleep(delay_seconds)` below blocks the entire worker process. The scheduler might consider it dead and kill it. What's the reasoning behind the `0.1` multiplier per future? ########## task-sdk/src/airflow/sdk/bases/iterableoperator.py: ########## @@ -0,0 +1,432 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import asyncio +import logging +import os +import time +from collections import deque +from collections.abc import Iterable, Sequence +from concurrent.futures import Future +from math import ceil +from time import sleep +from typing import TYPE_CHECKING, Any + +from more_itertools import ichunked + +from airflow.exceptions import ( + AirflowTaskTimeout, + TaskDeferred, +) +from airflow.sdk import timezone +from airflow.sdk.bases.operator import BaseOperator, DecoratedDeferredAsyncOperator, event_loop +from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg # noqa: F401 +from airflow.sdk.exceptions import AirflowRescheduleTaskInstanceException +from airflow.sdk.execution_time.executor import HybridExecutor, _execute_async_task, collect_futures +from airflow.sdk.execution_time.lazy_sequence import XComIterable +from airflow.sdk.execution_time.task_runner import MappedTaskInstance, RuntimeTaskInstance, _execute_task +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.state import TaskInstanceState + +if TYPE_CHECKING: + from airflow.sdk.definitions._internal.expandinput import ExpandInput + from airflow.sdk.definitions.context import Context + from airflow.sdk.definitions.mappedoperator import MappedOperator + + +class TaskExecutor(LoggingMixin): + """Base class to run an operator or trigger with given task context and task instance.""" + + def __init__( + self, + task_instance: RuntimeTaskInstance, + ): + super().__init__() + self._task_instance = task_instance + self._result: Any | None = None + self._start_time: float | None = None + + @property + def task_instance(self) -> RuntimeTaskInstance: + return self._task_instance + + @property + def dag_id(self) -> str: + return self.task_instance.dag_id + + @property + def task_id(self) -> str: + return self.task_instance.task_id + + @property + def task_index(self) -> int: + return self.task_instance.map_index + + @property + def key(self): + 
return self.task_instance.xcom_key + + @property + def operator(self) -> BaseOperator: + return self.task_instance.task + + @property + def is_async(self) -> bool: + return self.task_instance.is_async + + def run(self, context: Context): + return _execute_task(context, self.task_instance, self.log) + + async def arun(self, context: Context): + return await _execute_async_task(context, self.task_instance, self.log) + + def __enter__(self): + self._start_time = time.monotonic() + + if self.log.isEnabledFor(logging.INFO): + self.log.info( + "Attempting running task %s of %s for %s with map_index %s in %s mode.", + self.task_instance.try_number, + self.operator.retries, + self.task_instance.task_id, + self.task_index, + "async" if self.is_async else "sync", + ) + return self + + async def __aenter__(self): + return self.__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + elapsed = time.monotonic() - self._start_time + + if exc_value: + if not isinstance(exc_value, TaskDeferred): + if self.task_instance.next_try_number > self.task_instance.max_tries: + self.log.error( + "Task instance %s for %s failed after %s attempts in %.2f seconds due to: %s", + self.task_index, + self.task_instance.task_id, + self.task_instance.max_tries, + elapsed, + exc_value, + ) + self.task_instance.state = TaskInstanceState.FAILED + raise exc_value + self.task_instance.try_number += 1 + self.task_instance.end_date = timezone.utcnow() + self.task_instance.state = TaskInstanceState.UP_FOR_RESCHEDULE + raise AirflowRescheduleTaskInstanceException(task=self.task_instance) + raise exc_value + + self.task_instance.state = TaskInstanceState.SUCCESS + if self.log.isEnabledFor(logging.INFO): + self.log.info( + "Task instance %s for %s finished successfully in %s attempts in %.2f seconds", + self.task_index, + self.task_instance.task_id, + self.task_instance.next_try_number, + elapsed, + ) + + async def __aexit__(self, exc_type, exc_value, traceback): + self.__exit__(exc_type, 
exc_value, traceback) + + +class IterableOperator(BaseOperator): + """Object representing an iterable operator in a DAG.""" + + _operator: MappedOperator + expand_input: ExpandInput + partial_kwargs: dict[str, Any] + # each operator should override this class attr for shallow copy attrs. + shallow_copy_attrs: Sequence[str] = ( + "_operator", + "expand_input", + "partial_kwargs", + "_log", + ) + + def __init__( + self, + *, + operator: MappedOperator, + expand_input: ExpandInput, + **kwargs, + ): + super().__init__( + **{ + **kwargs, + "task_id": operator.task_id, + "owner": operator.owner, + "email": operator.email, + "email_on_retry": operator.email_on_retry, + "email_on_failure": operator.email_on_failure, + "retries": 0, # We should not retry the IterableOperator, only the mapped ti's should be retried + "retry_delay": operator.retry_delay, + "retry_exponential_backoff": operator.retry_exponential_backoff, + "max_retry_delay": operator.max_retry_delay, + "start_date": operator.start_date, + "end_date": operator.end_date, + "depends_on_past": operator.depends_on_past, + "ignore_first_depends_on_past": operator.ignore_first_depends_on_past, + "wait_for_past_depends_before_skipping": operator.wait_for_past_depends_before_skipping, + "wait_for_downstream": operator.wait_for_downstream, + "dag": operator.dag, + "priority_weight": operator.priority_weight, + "queue": operator.queue, + "pool": operator.pool, + "pool_slots": operator.pool_slots, + "execution_timeout": None, + "trigger_rule": operator.trigger_rule, + "resources": operator.resources, + "run_as_user": operator.run_as_user, + "map_index_template": operator.map_index_template, + "max_active_tis_per_dag": operator.max_active_tis_per_dag, + "max_active_tis_per_dagrun": operator.max_active_tis_per_dagrun, + "executor": operator.executor, + "executor_config": operator.executor_config, + "inlets": operator.inlets, + "outlets": operator.outlets, + "task_group": operator.task_group, + "doc": operator.doc, + 
"doc_md": operator.doc_md, + "doc_json": operator.doc_json, + "doc_yaml": operator.doc_yaml, + "doc_rst": operator.doc_rst, + "task_display_name": operator.task_display_name, + "allow_nested_operators": operator.allow_nested_operators, + } + ) + self._operator = operator + self.expand_input = expand_input + self.partial_kwargs = operator.partial_kwargs or {} + self._number_of_tasks: int = 0 + XComArg.apply_upstream_relationship(self, self.expand_input.value) + + @property + def task_type(self) -> str: + """@property: type of the task.""" + return self._operator.__class__.__name__ + + @property + def max_workers(self): + return self.max_active_tis_per_dag or os.cpu_count() or 1 + + @property + def timeout(self) -> float | None: + if self.execution_timeout: + return self.execution_timeout.total_seconds() + return None + + def _get_specified_expand_input(self) -> ExpandInput: + return self.expand_input + + def _unmap_operator(self, mapped_kwargs: dict): + self._number_of_tasks += 1 + return self._operator.unmap(mapped_kwargs) + + def _xcom_push(self, context: Context, task: RuntimeTaskInstance, value: Any) -> None: + self.log.debug("Pushing XCom %s", task.map_index) + + context["ti"].xcom_push(key=task.xcom_key, value=value) + + def _run_tasks( + self, + context: Context, + tasks: Iterable[RuntimeTaskInstance], + ) -> None: + exception: BaseException | None = None + reschedule_date = timezone.utcnow() + prev_futures_count = 0 + futures: dict[Future, RuntimeTaskInstance] = {} + failed_tasks: deque[RuntimeTaskInstance] = deque() + chunked_tasks = ichunked(tasks, self.max_workers) + + self.log.info("Running tasks with %d workers", self.max_workers) + + with event_loop() as loop: + with HybridExecutor(loop=loop, max_workers=self.max_workers) as executor: + for task in next(chunked_tasks, []): + if task.is_async: + future = executor.submit(self._run_async_operator, context, task) + else: + future = executor.submit(self._run_operator, context, task) + futures[future] = task 
+ + while futures: + futures_count = len(futures) + + if futures_count != prev_futures_count: + self.log.info("Number of remaining futures: %s", futures_count) + prev_futures_count = futures_count + + ready_futures = False + + for future in collect_futures(loop, futures.keys()): + task = futures.pop(future) + ready_futures = True + + try: + if isinstance(future, asyncio.futures.Future): + result = future.result() + else: + result = future.result(timeout=self.timeout) + + self.log.debug("result: %s", result) + + if result and task.task.do_xcom_push: + self._xcom_push( + context=context, + task=task, + value=result, + ) + except TaskDeferred as task_deferred: + operator = DecoratedDeferredAsyncOperator( + operator=task.task, task_deferred=task_deferred + ) + failed_tasks.append( + self._create_mapped_task(task.run_id, task.map_index, operator) + ) + except asyncio.TimeoutError as e: + self.log.warning("A timeout occurred for task_id %s", task.task_id) + if task.next_try_number > self.retries: + exception = AirflowTaskTimeout(e) + else: + reschedule_date = min(reschedule_date, task.next_retry_datetime()) + failed_tasks.append(task) + except AirflowRescheduleTaskInstanceException as e: + reschedule_date = min(reschedule_date, e.reschedule_date) + self.log.warning( + "An exception occurred for task_id %s with map_index %s, it has been rescheduled at %s", + task.task_id, + task.map_index, + reschedule_date, + ) + failed_tasks.append(e.task) + except Exception as e: + self.log.error( + "An exception occurred for task_id %s with map_index %s", + task.task_id, + task.map_index, + ) Review Comment: When multiple tasks fail, `exception = e` overwrites the previous exception each time. Only the last failure gets raised — all prior failures are silently lost. Makes debugging really hard when multiple sub-tasks fail. Consider collecting all exceptions and raising an `ExceptionGroup` or at least logging each one. -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected]. For queries about this service, please contact Infrastructure at: [email protected].
