kaxil commented on code in PR #62922: URL: https://github.com/apache/airflow/pull/62922#discussion_r3293416741
########## task-sdk/src/airflow/sdk/execution_time/executor.py: ########## @@ -0,0 +1,241 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +import contextvars +import inspect +import logging +import time +from asyncio import AbstractEventLoop, Semaphore +from collections.abc import Callable, Generator +from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from typing import TYPE_CHECKING, Any, cast + +from airflow.sdk import BaseAsyncOperator, BaseOperator, TaskInstanceState, timezone +from airflow.sdk.bases.operator import ExecutorSafeguard +from airflow.sdk.definitions._internal.logging_mixin import LoggingMixin +from airflow.sdk.exceptions import AirflowRescheduleTaskInstanceException, TaskDeferred +from airflow.sdk.execution_time.callback_runner import create_executable_runner +from airflow.sdk.execution_time.context import context_get_outlet_events +from airflow.sdk.execution_time.task_runner import ( + RuntimeTaskInstance, + _execute_task, + _run_task_state_change_callbacks, +) + +if TYPE_CHECKING: + from structlog.typing import FilteringBoundLogger as Logger + + from airflow.sdk import Context + from airflow.sdk.execution_time.task_runner import 
IndexedTaskInstance + + +def collect_futures( + loop: AbstractEventLoop, futures: list[Any] +) -> Generator[Future | asyncio.futures.Future, None, None]: + """ + Yield futures as they complete (sync or async). + + :param loop: The asyncio event loop to use for async tasks + :param futures: List of Future or asyncio.futures.Future objects to collect + :return: Generator yielding Future or asyncio.futures.Future objects as they complete + """ + yield from as_completed(f for f in futures if isinstance(f, Future)) + + async_tasks = [f for f in futures if isinstance(f, asyncio.futures.Future)] + + if async_tasks: + for task, _ in zip( + async_tasks, + loop.run_until_complete(asyncio.gather(*async_tasks, return_exceptions=True)), + ): + yield task + + +class ConcurrentExecutor: + """ + Executes both sync and async functions concurrently. + + Sync functions run in a ThreadPoolExecutor. + Async coroutines run on an asyncio event loop with a semaphore limit. + """ + + def __init__(self, loop: AbstractEventLoop, max_workers: int = 4): + self._loop = loop + self._semaphore = Semaphore(max_workers) + self._thread_pool = ThreadPoolExecutor(max_workers=max_workers) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._thread_pool: + self._thread_pool.shutdown(wait=True) + + def submit(self, func: Callable, *args, **kwargs): + if inspect.iscoroutine(func): + coro = func + elif inspect.iscoroutinefunction(func): + coro = func(*args, **kwargs) + else: + return self._thread_pool.submit(func, *args, **kwargs) + + async def guarded(): + async with self._semaphore: + return await coro + + return self._loop.create_task(guarded()) + + +class TaskExecutor(LoggingMixin): + """Base class to run an operator or trigger with given task context and task instance.""" + + def __init__( + self, + task_instance: IndexedTaskInstance, + ): + super().__init__() + self.task_instance = task_instance + self._result: Any | None = None + self._start_time: 
float | None = None + + @property + def dag_id(self) -> str: + return self.task_instance.dag_id + + @property + def task_id(self) -> str: + return self.task_instance.task_id + + @property + def task_index(self) -> int: + return self.task_instance.index + + @property + def xcom_key(self): + return self.task_instance.xcom_key + + @property + def operator(self) -> BaseOperator: + return self.task_instance.task + + @property + def is_async(self) -> bool: + return self.task_instance.is_async + + def run(self, context: Context): + return _execute_task(context, self.task_instance, self.log) Review Comment: Calling `_execute_task` from a `ConcurrentExecutor`-submitted thread introduces an `os.environ` race that the existing single-process supervisor flow never had. Inside `_execute_task` (task_runner.py around line 2009), there's an `os.environ.update(airflow_context_vars)` call that sets `AIRFLOW_CTX_TASK_ID`, `AIRFLOW_CTX_MAP_INDEX`, etc. globally. Two `TaskExecutor.run(...)` calls running in parallel threads will overwrite each other's env vars, so any user task that reads them (BashOperator's templated env, subprocess hooks, third-party libs picking up `AIRFLOW_CTX_*`) can see a value belonging to a sibling mapped task. Options: thread-local-ify the env update in `_execute_task`, pass the context vars through `context` and have operators read from there, or take a per-thread snapshot/restore around the call site here. Keep in mind that `os.environ` itself is process-wide (and is what child processes inherit), so a thread-local copy only helps if the values are then passed explicitly (e.g. via `env=` to subprocesses), and a snapshot/restore around the call still races unless it is synchronized -- while holding a lock for the whole task run would serialize the mapped tasks. ########## task-sdk/src/airflow/sdk/definitions/iterableoperator.py: ########## @@ -0,0 +1,453 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +import copy +import os +from collections import deque +from collections.abc import Iterable, Mapping, Sequence +from concurrent.futures import Future +from itertools import chain +from typing import TYPE_CHECKING, Any + +try: + # Python 3.12+ + from itertools import batched # type: ignore[attr-defined] +except ImportError: + from more_itertools import batched # type: ignore[no-redef] + +try: + # Python 3.11+ + BaseExceptionGroup +except NameError: + from exceptiongroup import BaseExceptionGroup + +from airflow.sdk import BaseXCom, TaskInstanceState, timezone +from airflow.sdk.bases.operator import BaseOperator, DecoratedDeferredAsyncOperator, event_loop +from airflow.sdk.definitions._internal.expandinput import PartitionedExpandInput +from airflow.sdk.definitions.mappedoperator import MappedOperator +from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg # noqa: F401 +from airflow.sdk.exceptions import ( + AirflowRescheduleTaskInstanceException, + AirflowTaskTimeout, + TaskDeferred, +) +from airflow.sdk.execution_time.executor import ConcurrentExecutor, TaskExecutor, collect_futures +from airflow.sdk.execution_time.task_runner import IndexedTaskInstance + +if TYPE_CHECKING: + import jinja2 + + from airflow.sdk.definitions._internal.expandinput import ExpandInput + from airflow.sdk.definitions.context import Context + from airflow.sdk.execution_time.lazy_sequence import XComIterable + + +class IterableOperator(BaseOperator): + """Object representing an iterable operator in a DAG.""" + + 
_operator: MappedOperator + expand_input: ExpandInput + partial_kwargs: dict[str, Any] + shallow_copy_attrs: Sequence[str] = ( + "_operator", + "expand_input", + "partial_kwargs", + "_log", + ) + + def __init__( + self, + *, + operator: MappedOperator, + expand_input: ExpandInput, + **kwargs, + ): + super().__init__( + **{ + **kwargs, + "task_id": operator.task_id, + "owner": operator.owner, + "email": operator.email, + "email_on_retry": operator.email_on_retry, + "email_on_failure": operator.email_on_failure, + "retries": 0, # We should not retry the IterableOperator, only the mapped ti's should be retried + "retry_delay": operator.retry_delay, + "retry_exponential_backoff": operator.retry_exponential_backoff, + "max_retry_delay": operator.max_retry_delay, + "start_date": operator.start_date, + "end_date": operator.end_date, + "depends_on_past": operator.depends_on_past, + "ignore_first_depends_on_past": operator.ignore_first_depends_on_past, + "wait_for_past_depends_before_skipping": operator.wait_for_past_depends_before_skipping, + "wait_for_downstream": operator.wait_for_downstream, + "dag": operator.dag, + "priority_weight": operator.priority_weight, + "queue": operator.queue, + "pool": operator.pool, + "pool_slots": operator.pool_slots, + "execution_timeout": None, + "trigger_rule": operator.trigger_rule, + "resources": operator.resources, + "run_as_user": operator.run_as_user, + "map_index_template": operator.map_index_template, + "max_active_tis_per_dag": operator.max_active_tis_per_dag, + "max_active_tis_per_dagrun": operator.max_active_tis_per_dagrun, + "executor": operator.executor, + "executor_config": operator.executor_config, + "inlets": operator.inlets, + "outlets": operator.outlets, + "task_group": operator.task_group, + "doc": operator.doc, + "doc_md": operator.doc_md, + "doc_json": operator.doc_json, + "doc_yaml": operator.doc_yaml, + "doc_rst": operator.doc_rst, + "task_display_name": operator.task_display_name, + "allow_nested_operators": 
operator.allow_nested_operators, + } + ) + self._operator = operator + self.expand_input = expand_input + self.partial_kwargs = operator.partial_kwargs or {} + self.partial_kwargs.pop("partition_size", None) + self.max_workers = self.partial_kwargs.pop("task_concurrency", None) or os.cpu_count() or 1 + self._number_of_tasks: int = 0 + XComArg.apply_upstream_relationship(self, self.expand_input.value) + + @property + def task_type(self) -> str: + """@property: type of the task.""" + return self._operator.__class__.__name__ + + @property + def timeout(self) -> float | None: + if self.execution_timeout: + return self.execution_timeout.total_seconds() + return None + + def _do_render_template_fields( + self, + parent: Any, + template_fields: Iterable[str], + context: Context, + jinja_env: jinja2.Environment, + seen_oids: set[int], + ) -> None: + # IterableOperator doesn't need to render template fields as the actual operator's template fields + # will be rendered in the TaskExecutor when running each mapped task instance. 
+ pass + + def _get_specified_expand_input(self) -> ExpandInput: + return self.expand_input + + def _unmap_operator( + self, context: Context, mapped_kwargs: Context, jinja_env: jinja2.Environment + ) -> BaseOperator: + from airflow.sdk.execution_time.context import context_update_for_unmapped + + self._number_of_tasks += 1 + unmapped_task = self._operator.unmap(mapped_kwargs) + # Make sure deferred operators will always raise a DeferredTask exception when executed + unmapped_task.start_from_trigger = False + context_update_for_unmapped(context, unmapped_task) + + unmapped_task._do_render_template_fields( + parent=unmapped_task, + template_fields=self._operator.template_fields, + context=context, + jinja_env=jinja_env, + seen_oids=set(), + ) + return unmapped_task + + def _xcom_push(self, task: IndexedTaskInstance, value: Any) -> None: + if task.xcom_pushed: + self.log.debug( + "XCom already pushed for task_id %s with index %s", + task.task_id, + task.index, + ) + else: + self.log.debug( + "Pushing XCom for task_id %s with index %s", + task.task_id, + task.index, + ) + + task.xcom_push(key=BaseXCom.XCOM_RETURN_KEY, value=value) + + def _run_tasks( + self, + context: Context, + tasks: Iterable[IndexedTaskInstance], + ) -> XComIterable | None: + exceptions: list[BaseException] = [] + reschedule_date = timezone.utcnow() + prev_futures_count = 0 + futures: dict[Future | asyncio.futures.Future, IndexedTaskInstance] = {} + deferred_tasks: deque[IndexedTaskInstance] = deque() + failed_tasks: deque[IndexedTaskInstance] = deque() + chunked_tasks = batched(tasks, self.max_workers) + do_xcom_push = True + + self.log.info("Running tasks with %d workers", self.max_workers) + + with event_loop() as loop: + with ConcurrentExecutor(loop=loop, max_workers=self.max_workers) as executor: + for task in next(chunked_tasks, []): + do_xcom_push = task.do_xcom_push + if task.is_async: + future = executor.submit(self._run_async_operator, context, task) + else: + future = 
executor.submit(self._run_operator, context, task) + futures[future] = task + + while futures: + futures_count = len(futures) + + if futures_count != prev_futures_count: + self.log.info("Number of remaining futures: %s", futures_count) + prev_futures_count = futures_count + + for future in collect_futures(loop, list(futures.keys())): + task = futures.pop(future) + + try: + if isinstance(future, asyncio.futures.Future): + result = future.result() + else: + result = future.result(timeout=self.timeout) + + self.log.debug("result: %s", result) + + if result is not None and task.do_xcom_push: + self._xcom_push( + task=task, + value=result, + ) + except TaskDeferred as task_deferred: + operator = DecoratedDeferredAsyncOperator( + operator=task.task, task_deferred=task_deferred + ) + deferred_tasks.append( + self._create_mapped_task( + run_id=task.run_id, + index=task.index, + map_index=task.map_index, # type: ignore[arg-type] + try_number=task.try_number, + operator=operator, + ) + ) + except asyncio.TimeoutError as e: + self.log.warning("A timeout occurred for task_id %s", task.task_id) + if task.next_try_number > (self.retries or 0): + exceptions.append(AirflowTaskTimeout(e)) + else: + reschedule_date = min(reschedule_date, task.next_retry_datetime()) + failed_tasks.append(task) + except AirflowRescheduleTaskInstanceException as e: + reschedule_date = min(reschedule_date, e.reschedule_date) + self.log.exception( + "An exception occurred for task_id %s with index %s, it has been rescheduled at %s", + task.task_id, + task.index, + reschedule_date, + ) + failed_tasks.append(e.task) + except Exception as e: + self.log.exception( + "An exception occurred for task_id %s with index %s", + task.task_id, + task.index, + ) + exceptions.append(e) + + if len(futures) < self.max_workers: + chunked_tasks = chain(list(deferred_tasks), chunked_tasks) + deferred_tasks.clear() + + for task in next(chunked_tasks, []): + if task.is_async: + future = 
executor.submit(self._run_async_operator, context, task) + else: + future = executor.submit(self._run_operator, context, task) + futures[future] = task + + if not failed_tasks: + if exceptions: + raise BaseExceptionGroup("Multiple sub-task failures", exceptions) + if do_xcom_push: + from airflow.sdk.execution_time.lazy_sequence import XComIterable + + return XComIterable( + task_id=self.task_id, + dag_id=self.dag_id, + run_id=context["run_id"], + length=self._number_of_tasks, + map_index=context["ti"].map_index, + ) + return None + + # If the retry time is still in the future we defer the operator so the worker + # slot is released. If the retry time has already passed we immediately re-run + # the failed tasks without deferring. + if reschedule_date > timezone.utcnow(): + # TODO: This is tricky as that import doesn't exist in Task SDK + from airflow.providers.standard.triggers.temporal import DateTimeTrigger + + self.defer( + trigger=DateTimeTrigger(reschedule_date), + method_name=self.execute_failed_tasks.__name__, + kwargs={ + "failed_tasks": {failed_task.index for failed_task in failed_tasks}, + "try_number": next(iter(failed_tasks)).try_number, + }, + ) + + return self._run_tasks(context=context, tasks=list(failed_tasks)) + + def _run_operator(self, context: Context, task_instance: IndexedTaskInstance): + with TaskExecutor(task_instance=task_instance) as executor: + return executor.run( + context={ + **dict(context), Review Comment: Partially addressed from the previous review. The shallow merge `{**dict(context), **{"ti": ..., "task_instance": ...}}` creates a fresh top-level dict per task, which correctly prevents `ti` / `task_instance` from being clobbered across threads. But the context's nested mutable values (`params`, `dag_run.conf`, `outlet_events`, `dag_run`, `inlet_events`, ...) are still shared by reference. 
If user code mutates any context-attached object (`context['params']['x'] = ...`, `context['outlet_events'].add(...)`, etc.), sibling threads see it. `copy.deepcopy(context)` is the safest choice (though it can be expensive and may raise on values that cannot be copied, e.g. objects holding locks or sockets); if that's too heavyweight for the hot path, audit the context shape and explicitly `copy()` the mutable fields that user code is allowed to touch. ########## task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py: ########## @@ -184,6 +243,19 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]: if isinstance(x, XComArg): yield from x.iter_references() + def iter_values(self, context: Mapping[str, Any]) -> Iterable[Any]: + from airflow.sdk.definitions.xcom_arg import XComArg + + resolved = {k: v.resolve(context) if isinstance(v, XComArg) else v for k, v in self.value.items()} + keys = list(resolved) + for items in zip( Review Comment: Semantic divergence from classic dynamic task mapping is still unresolved. `DictOfListsExpandInput.iter_values` uses `zip(*values)`, so `expand(a=[1,2], b=[3,4])` yields 2 dicts (`{a:1,b:3}`, `{a:2,b:4}`). But the scheduler-side `SchedulerDictOfListsExpandInput.get_total_map_length` uses `functools.reduce(operator.mul, lengths)` -- cartesian -- so the same call mapped via `.expand()` produces 4 task instances. Two issues: 1. `.iterate()` on a multi-key expand silently processes fewer combinations than the user expects (and, if the lists differ in length, `zip` truncates to the shortest one, dropping items outright). 2. Users switching between `.expand()` and `.iterate()` get different cardinality with no warning or doc note. `test_dict_of_lists_expand_input_iter_values` currently locks in the zip behavior, which suggests the divergence is intentional, but it isn't documented anywhere user-visible. Options: (a) make this cartesian to match dynamic task mapping (`.expand()`), (b) document the zip choice prominently in the DTI (`.iterate()`) docs, or (c) raise on multi-key expand and force the user to use `iterate_kwargs` explicitly.
########## task-sdk/src/airflow/sdk/definitions/iterableoperator.py: ########## @@ -0,0 +1,453 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +import copy +import os +from collections import deque +from collections.abc import Iterable, Mapping, Sequence +from concurrent.futures import Future +from itertools import chain +from typing import TYPE_CHECKING, Any + +try: + # Python 3.12+ + from itertools import batched # type: ignore[attr-defined] +except ImportError: + from more_itertools import batched # type: ignore[no-redef] + +try: + # Python 3.11+ + BaseExceptionGroup +except NameError: + from exceptiongroup import BaseExceptionGroup + +from airflow.sdk import BaseXCom, TaskInstanceState, timezone +from airflow.sdk.bases.operator import BaseOperator, DecoratedDeferredAsyncOperator, event_loop +from airflow.sdk.definitions._internal.expandinput import PartitionedExpandInput +from airflow.sdk.definitions.mappedoperator import MappedOperator +from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg # noqa: F401 +from airflow.sdk.exceptions import ( + AirflowRescheduleTaskInstanceException, + AirflowTaskTimeout, + TaskDeferred, +) +from 
airflow.sdk.execution_time.executor import ConcurrentExecutor, TaskExecutor, collect_futures +from airflow.sdk.execution_time.task_runner import IndexedTaskInstance + +if TYPE_CHECKING: + import jinja2 + + from airflow.sdk.definitions._internal.expandinput import ExpandInput + from airflow.sdk.definitions.context import Context + from airflow.sdk.execution_time.lazy_sequence import XComIterable + + +class IterableOperator(BaseOperator): + """Object representing an iterable operator in a DAG.""" + + _operator: MappedOperator + expand_input: ExpandInput + partial_kwargs: dict[str, Any] + shallow_copy_attrs: Sequence[str] = ( + "_operator", + "expand_input", + "partial_kwargs", + "_log", + ) + + def __init__( + self, + *, + operator: MappedOperator, + expand_input: ExpandInput, + **kwargs, + ): + super().__init__( + **{ + **kwargs, + "task_id": operator.task_id, + "owner": operator.owner, + "email": operator.email, + "email_on_retry": operator.email_on_retry, + "email_on_failure": operator.email_on_failure, + "retries": 0, # We should not retry the IterableOperator, only the mapped ti's should be retried + "retry_delay": operator.retry_delay, + "retry_exponential_backoff": operator.retry_exponential_backoff, + "max_retry_delay": operator.max_retry_delay, + "start_date": operator.start_date, + "end_date": operator.end_date, + "depends_on_past": operator.depends_on_past, + "ignore_first_depends_on_past": operator.ignore_first_depends_on_past, + "wait_for_past_depends_before_skipping": operator.wait_for_past_depends_before_skipping, + "wait_for_downstream": operator.wait_for_downstream, + "dag": operator.dag, + "priority_weight": operator.priority_weight, + "queue": operator.queue, + "pool": operator.pool, + "pool_slots": operator.pool_slots, + "execution_timeout": None, + "trigger_rule": operator.trigger_rule, + "resources": operator.resources, + "run_as_user": operator.run_as_user, + "map_index_template": operator.map_index_template, + "max_active_tis_per_dag": 
operator.max_active_tis_per_dag, + "max_active_tis_per_dagrun": operator.max_active_tis_per_dagrun, + "executor": operator.executor, + "executor_config": operator.executor_config, + "inlets": operator.inlets, + "outlets": operator.outlets, + "task_group": operator.task_group, + "doc": operator.doc, + "doc_md": operator.doc_md, + "doc_json": operator.doc_json, + "doc_yaml": operator.doc_yaml, + "doc_rst": operator.doc_rst, + "task_display_name": operator.task_display_name, + "allow_nested_operators": operator.allow_nested_operators, + } + ) + self._operator = operator + self.expand_input = expand_input + self.partial_kwargs = operator.partial_kwargs or {} + self.partial_kwargs.pop("partition_size", None) + self.max_workers = self.partial_kwargs.pop("task_concurrency", None) or os.cpu_count() or 1 + self._number_of_tasks: int = 0 + XComArg.apply_upstream_relationship(self, self.expand_input.value) + + @property + def task_type(self) -> str: + """@property: type of the task.""" + return self._operator.__class__.__name__ + + @property + def timeout(self) -> float | None: + if self.execution_timeout: + return self.execution_timeout.total_seconds() + return None + + def _do_render_template_fields( + self, + parent: Any, + template_fields: Iterable[str], + context: Context, + jinja_env: jinja2.Environment, + seen_oids: set[int], + ) -> None: + # IterableOperator doesn't need to render template fields as the actual operator's template fields + # will be rendered in the TaskExecutor when running each mapped task instance. 
+ pass + + def _get_specified_expand_input(self) -> ExpandInput: + return self.expand_input + + def _unmap_operator( + self, context: Context, mapped_kwargs: Context, jinja_env: jinja2.Environment + ) -> BaseOperator: + from airflow.sdk.execution_time.context import context_update_for_unmapped + + self._number_of_tasks += 1 + unmapped_task = self._operator.unmap(mapped_kwargs) + # Make sure deferred operators will always raise a DeferredTask exception when executed + unmapped_task.start_from_trigger = False + context_update_for_unmapped(context, unmapped_task) + + unmapped_task._do_render_template_fields( + parent=unmapped_task, + template_fields=self._operator.template_fields, + context=context, + jinja_env=jinja_env, + seen_oids=set(), + ) + return unmapped_task + + def _xcom_push(self, task: IndexedTaskInstance, value: Any) -> None: + if task.xcom_pushed: + self.log.debug( + "XCom already pushed for task_id %s with index %s", + task.task_id, + task.index, + ) + else: + self.log.debug( + "Pushing XCom for task_id %s with index %s", + task.task_id, + task.index, + ) + + task.xcom_push(key=BaseXCom.XCOM_RETURN_KEY, value=value) + + def _run_tasks( + self, + context: Context, + tasks: Iterable[IndexedTaskInstance], + ) -> XComIterable | None: + exceptions: list[BaseException] = [] + reschedule_date = timezone.utcnow() + prev_futures_count = 0 + futures: dict[Future | asyncio.futures.Future, IndexedTaskInstance] = {} + deferred_tasks: deque[IndexedTaskInstance] = deque() + failed_tasks: deque[IndexedTaskInstance] = deque() + chunked_tasks = batched(tasks, self.max_workers) + do_xcom_push = True + + self.log.info("Running tasks with %d workers", self.max_workers) + + with event_loop() as loop: + with ConcurrentExecutor(loop=loop, max_workers=self.max_workers) as executor: + for task in next(chunked_tasks, []): + do_xcom_push = task.do_xcom_push + if task.is_async: + future = executor.submit(self._run_async_operator, context, task) + else: + future = 
executor.submit(self._run_operator, context, task) + futures[future] = task + + while futures: + futures_count = len(futures) + + if futures_count != prev_futures_count: + self.log.info("Number of remaining futures: %s", futures_count) + prev_futures_count = futures_count + + for future in collect_futures(loop, list(futures.keys())): + task = futures.pop(future) + + try: + if isinstance(future, asyncio.futures.Future): + result = future.result() + else: + result = future.result(timeout=self.timeout) + + self.log.debug("result: %s", result) + + if result is not None and task.do_xcom_push: + self._xcom_push( + task=task, + value=result, + ) + except TaskDeferred as task_deferred: + operator = DecoratedDeferredAsyncOperator( + operator=task.task, task_deferred=task_deferred + ) + deferred_tasks.append( + self._create_mapped_task( + run_id=task.run_id, + index=task.index, + map_index=task.map_index, # type: ignore[arg-type] + try_number=task.try_number, + operator=operator, + ) + ) + except asyncio.TimeoutError as e: + self.log.warning("A timeout occurred for task_id %s", task.task_id) + if task.next_try_number > (self.retries or 0): + exceptions.append(AirflowTaskTimeout(e)) + else: + reschedule_date = min(reschedule_date, task.next_retry_datetime()) + failed_tasks.append(task) + except AirflowRescheduleTaskInstanceException as e: + reschedule_date = min(reschedule_date, e.reschedule_date) + self.log.exception( + "An exception occurred for task_id %s with index %s, it has been rescheduled at %s", + task.task_id, + task.index, + reschedule_date, + ) + failed_tasks.append(e.task) + except Exception as e: + self.log.exception( + "An exception occurred for task_id %s with index %s", + task.task_id, + task.index, + ) + exceptions.append(e) + + if len(futures) < self.max_workers: + chunked_tasks = chain(list(deferred_tasks), chunked_tasks) + deferred_tasks.clear() + + for task in next(chunked_tasks, []): + if task.is_async: + future = 
executor.submit(self._run_async_operator, context, task) + else: + future = executor.submit(self._run_operator, context, task) + futures[future] = task + + if not failed_tasks: + if exceptions: + raise BaseExceptionGroup("Multiple sub-task failures", exceptions) + if do_xcom_push: + from airflow.sdk.execution_time.lazy_sequence import XComIterable + + return XComIterable( + task_id=self.task_id, + dag_id=self.dag_id, + run_id=context["run_id"], + length=self._number_of_tasks, + map_index=context["ti"].map_index, + ) + return None + + # If the retry time is still in the future we defer the operator so the worker + # slot is released. If the retry time has already passed we immediately re-run + # the failed tasks without deferring. + if reschedule_date > timezone.utcnow(): + # TODO: This is tricky as that import doesn't exist in Task SDK + from airflow.providers.standard.triggers.temporal import DateTimeTrigger + + self.defer( + trigger=DateTimeTrigger(reschedule_date), + method_name=self.execute_failed_tasks.__name__, + kwargs={ + "failed_tasks": {failed_task.index for failed_task in failed_tasks}, + "try_number": next(iter(failed_tasks)).try_number, + }, + ) + + return self._run_tasks(context=context, tasks=list(failed_tasks)) Review Comment: Still unaddressed from the previous review: `return self._run_tasks(context=context, tasks=list(failed_tasks))` recurses on every retry batch. Python does not eliminate tail calls, so each batch adds stack frames, and sensor-like patterns or any workload that repeatedly hits `AirflowRescheduleTaskInstanceException` can exceed the Python recursion limit (default 1000) and fail with `RecursionError`. Convert it to a `while failed_tasks:` loop that reassigns `failed_tasks` on each iteration -- the tail call here has no special meaning; it's just a tight retry loop. It is also tighter than intended: `reschedule_date` starts at `timezone.utcnow()` and is only ever lowered via `min(...)`, so `reschedule_date > timezone.utcnow()` is effectively never true -- the `self.defer(...)` branch is unreachable, and failed tasks are re-run immediately instead of after their retry delay. ########## task-sdk/src/airflow/sdk/execution_time/executor.py: ########## @@ -0,0 +1,241 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +import contextvars +import inspect +import logging +import time +from asyncio import AbstractEventLoop, Semaphore +from collections.abc import Callable, Generator +from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from typing import TYPE_CHECKING, Any, cast + +from airflow.sdk import BaseAsyncOperator, BaseOperator, TaskInstanceState, timezone +from airflow.sdk.bases.operator import ExecutorSafeguard +from airflow.sdk.definitions._internal.logging_mixin import LoggingMixin +from airflow.sdk.exceptions import AirflowRescheduleTaskInstanceException, TaskDeferred +from airflow.sdk.execution_time.callback_runner import create_executable_runner +from airflow.sdk.execution_time.context import context_get_outlet_events +from airflow.sdk.execution_time.task_runner import ( + RuntimeTaskInstance, + _execute_task, + _run_task_state_change_callbacks, +) + +if TYPE_CHECKING: + from structlog.typing import FilteringBoundLogger as Logger + + from airflow.sdk import Context + from airflow.sdk.execution_time.task_runner import IndexedTaskInstance + + +def collect_futures( + loop: AbstractEventLoop, futures: list[Any] +) -> Generator[Future | asyncio.futures.Future, None, None]: + """ + Yield futures as they complete (sync or 
async). + + :param loop: The asyncio event loop to use for async tasks + :param futures: List of Future or asyncio.futures.Future objects to collect + :return: Generator yielding Future or asyncio.futures.Future objects as they complete + """ + yield from as_completed(f for f in futures if isinstance(f, Future)) + + async_tasks = [f for f in futures if isinstance(f, asyncio.futures.Future)] + + if async_tasks: + for task, _ in zip( + async_tasks, + loop.run_until_complete(asyncio.gather(*async_tasks, return_exceptions=True)), + ): + yield task + + +class ConcurrentExecutor: + """ + Executes both sync and async functions concurrently. + + Sync functions run in a ThreadPoolExecutor. + Async coroutines run on an asyncio event loop with a semaphore limit. + """ + + def __init__(self, loop: AbstractEventLoop, max_workers: int = 4): + self._loop = loop + self._semaphore = Semaphore(max_workers) + self._thread_pool = ThreadPoolExecutor(max_workers=max_workers) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._thread_pool: + self._thread_pool.shutdown(wait=True) Review Comment: `ConcurrentExecutor.__exit__` does `self._thread_pool.shutdown(wait=True)`. If any worker thread is hung in user code (slow HTTP, blocking C extension, a deadlocked DB driver), `__exit__` waits forever -- and the outer `future.result(timeout=...)` in `_run_tasks` can't break out because the context manager owns the join. This is the same anti-pattern as using `ThreadPoolExecutor` as a context manager, which negates the per-future timeout. Use `shutdown(wait=False, cancel_futures=True)` (and document that in-flight tasks are abandoned, not killed -- Python can't kill threads), or take a configurable shutdown deadline.
########## task-sdk/src/airflow/sdk/definitions/mappedoperator.py: ########## @@ -322,6 +289,7 @@ class MappedOperator(AbstractOperator): end_date: pendulum.DateTime | None upstream_task_ids: set[str] = attrs.field(factory=set, init=False) downstream_task_ids: set[str] = attrs.field(factory=set, init=False) + _apply_upstream_relationship: bool = attrs.field(alias="apply_upstream_relationship", default=True) Review Comment: Flag-name vs behavior mismatch. The alias is `apply_upstream_relationship`, which reads like it only gates `XComArg.apply_upstream_relationship(...)`. But the `__attrs_post_init__` body (now correctly documented in the comment block) also gates `task_group.add(self)` and `dag.add_task(self)`. Call sites like `PartitionedOperator._iterate(apply_upstream_relationship=False)` read as "don't wire up XComArg relationships," when the actual effect is "don't register the operator with the DAG at all." The docstring/comment covers the why, but the parameter name still misleads. Consider renaming to `_register_with_dag` (keeping the `True` default), or to `_is_throwaway_instance` with the default inverted to `False`, so the call sites read closer to what they actually do. ########## task-sdk/src/airflow/sdk/definitions/iterableoperator.py: ########## @@ -0,0 +1,453 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied.
See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +import copy +import os +from collections import deque +from collections.abc import Iterable, Mapping, Sequence +from concurrent.futures import Future +from itertools import chain +from typing import TYPE_CHECKING, Any + +try: + # Python 3.12+ + from itertools import batched # type: ignore[attr-defined] +except ImportError: + from more_itertools import batched # type: ignore[no-redef] + +try: + # Python 3.11+ + BaseExceptionGroup +except NameError: + from exceptiongroup import BaseExceptionGroup + +from airflow.sdk import BaseXCom, TaskInstanceState, timezone +from airflow.sdk.bases.operator import BaseOperator, DecoratedDeferredAsyncOperator, event_loop +from airflow.sdk.definitions._internal.expandinput import PartitionedExpandInput +from airflow.sdk.definitions.mappedoperator import MappedOperator +from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg # noqa: F401 +from airflow.sdk.exceptions import ( + AirflowRescheduleTaskInstanceException, + AirflowTaskTimeout, + TaskDeferred, +) +from airflow.sdk.execution_time.executor import ConcurrentExecutor, TaskExecutor, collect_futures +from airflow.sdk.execution_time.task_runner import IndexedTaskInstance + +if TYPE_CHECKING: + import jinja2 + + from airflow.sdk.definitions._internal.expandinput import ExpandInput + from airflow.sdk.definitions.context import Context + from airflow.sdk.execution_time.lazy_sequence import XComIterable + + +class IterableOperator(BaseOperator): + """Object representing an iterable operator in a DAG.""" + + _operator: MappedOperator + expand_input: ExpandInput + partial_kwargs: dict[str, Any] + shallow_copy_attrs: Sequence[str] = ( + "_operator", + "expand_input", + "partial_kwargs", + "_log", + ) + + def __init__( + self, + *, + operator: MappedOperator, + expand_input: ExpandInput, + **kwargs, + ): + super().__init__( 
+ **{ + **kwargs, + "task_id": operator.task_id, + "owner": operator.owner, + "email": operator.email, + "email_on_retry": operator.email_on_retry, + "email_on_failure": operator.email_on_failure, + "retries": 0, # We should not retry the IterableOperator, only the mapped ti's should be retried + "retry_delay": operator.retry_delay, + "retry_exponential_backoff": operator.retry_exponential_backoff, + "max_retry_delay": operator.max_retry_delay, + "start_date": operator.start_date, + "end_date": operator.end_date, + "depends_on_past": operator.depends_on_past, + "ignore_first_depends_on_past": operator.ignore_first_depends_on_past, + "wait_for_past_depends_before_skipping": operator.wait_for_past_depends_before_skipping, + "wait_for_downstream": operator.wait_for_downstream, + "dag": operator.dag, + "priority_weight": operator.priority_weight, + "queue": operator.queue, + "pool": operator.pool, + "pool_slots": operator.pool_slots, + "execution_timeout": None, Review Comment: `execution_timeout` is hardcoded to `None` in the super-init kwargs, but the `timeout` property below reads `self.execution_timeout`: ```python @property def timeout(self) -> float | None: if self.execution_timeout: return self.execution_timeout.total_seconds() return None ``` and `future.result(timeout=self.timeout)` in `_run_tasks` relies on it. Net effect: any `execution_timeout` set on the operator (via `partial` or via the wrapped `MappedOperator`) is silently dropped -- `future.result(timeout=None)` waits forever. Either propagate `operator.execution_timeout` to the super-init, or document why timeout enforcement is intentionally disabled at the `IterableOperator` level (and what the user should do instead). ########## task-sdk/src/airflow/sdk/definitions/iterableoperator.py: ########## @@ -0,0 +1,453 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +import copy +import os +from collections import deque +from collections.abc import Iterable, Mapping, Sequence +from concurrent.futures import Future +from itertools import chain +from typing import TYPE_CHECKING, Any + +try: + # Python 3.12+ + from itertools import batched # type: ignore[attr-defined] +except ImportError: + from more_itertools import batched # type: ignore[no-redef] + +try: + # Python 3.11+ + BaseExceptionGroup +except NameError: + from exceptiongroup import BaseExceptionGroup + +from airflow.sdk import BaseXCom, TaskInstanceState, timezone +from airflow.sdk.bases.operator import BaseOperator, DecoratedDeferredAsyncOperator, event_loop +from airflow.sdk.definitions._internal.expandinput import PartitionedExpandInput +from airflow.sdk.definitions.mappedoperator import MappedOperator +from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg # noqa: F401 +from airflow.sdk.exceptions import ( + AirflowRescheduleTaskInstanceException, + AirflowTaskTimeout, + TaskDeferred, +) +from airflow.sdk.execution_time.executor import ConcurrentExecutor, TaskExecutor, collect_futures +from airflow.sdk.execution_time.task_runner import IndexedTaskInstance + +if TYPE_CHECKING: + import jinja2 + + from 
airflow.sdk.definitions._internal.expandinput import ExpandInput + from airflow.sdk.definitions.context import Context + from airflow.sdk.execution_time.lazy_sequence import XComIterable + + +class IterableOperator(BaseOperator): + """Object representing an iterable operator in a DAG.""" + + _operator: MappedOperator + expand_input: ExpandInput + partial_kwargs: dict[str, Any] + shallow_copy_attrs: Sequence[str] = ( + "_operator", + "expand_input", + "partial_kwargs", + "_log", + ) + + def __init__( + self, + *, + operator: MappedOperator, + expand_input: ExpandInput, + **kwargs, + ): + super().__init__( + **{ + **kwargs, + "task_id": operator.task_id, + "owner": operator.owner, + "email": operator.email, + "email_on_retry": operator.email_on_retry, + "email_on_failure": operator.email_on_failure, + "retries": 0, # We should not retry the IterableOperator, only the mapped ti's should be retried + "retry_delay": operator.retry_delay, + "retry_exponential_backoff": operator.retry_exponential_backoff, + "max_retry_delay": operator.max_retry_delay, + "start_date": operator.start_date, + "end_date": operator.end_date, + "depends_on_past": operator.depends_on_past, + "ignore_first_depends_on_past": operator.ignore_first_depends_on_past, + "wait_for_past_depends_before_skipping": operator.wait_for_past_depends_before_skipping, + "wait_for_downstream": operator.wait_for_downstream, + "dag": operator.dag, + "priority_weight": operator.priority_weight, + "queue": operator.queue, + "pool": operator.pool, + "pool_slots": operator.pool_slots, + "execution_timeout": None, + "trigger_rule": operator.trigger_rule, + "resources": operator.resources, + "run_as_user": operator.run_as_user, + "map_index_template": operator.map_index_template, + "max_active_tis_per_dag": operator.max_active_tis_per_dag, + "max_active_tis_per_dagrun": operator.max_active_tis_per_dagrun, + "executor": operator.executor, + "executor_config": operator.executor_config, + "inlets": operator.inlets, + 
"outlets": operator.outlets, + "task_group": operator.task_group, + "doc": operator.doc, + "doc_md": operator.doc_md, + "doc_json": operator.doc_json, + "doc_yaml": operator.doc_yaml, + "doc_rst": operator.doc_rst, + "task_display_name": operator.task_display_name, + "allow_nested_operators": operator.allow_nested_operators, + } + ) + self._operator = operator + self.expand_input = expand_input + self.partial_kwargs = operator.partial_kwargs or {} + self.partial_kwargs.pop("partition_size", None) + self.max_workers = self.partial_kwargs.pop("task_concurrency", None) or os.cpu_count() or 1 + self._number_of_tasks: int = 0 + XComArg.apply_upstream_relationship(self, self.expand_input.value) + + @property + def task_type(self) -> str: + """@property: type of the task.""" + return self._operator.__class__.__name__ + + @property + def timeout(self) -> float | None: + if self.execution_timeout: + return self.execution_timeout.total_seconds() + return None + + def _do_render_template_fields( + self, + parent: Any, + template_fields: Iterable[str], + context: Context, + jinja_env: jinja2.Environment, + seen_oids: set[int], + ) -> None: + # IterableOperator doesn't need to render template fields as the actual operator's template fields + # will be rendered in the TaskExecutor when running each mapped task instance. 
+ pass + + def _get_specified_expand_input(self) -> ExpandInput: + return self.expand_input + + def _unmap_operator( + self, context: Context, mapped_kwargs: Context, jinja_env: jinja2.Environment + ) -> BaseOperator: + from airflow.sdk.execution_time.context import context_update_for_unmapped + + self._number_of_tasks += 1 + unmapped_task = self._operator.unmap(mapped_kwargs) + # Make sure deferred operators will always raise a DeferredTask exception when executed + unmapped_task.start_from_trigger = False + context_update_for_unmapped(context, unmapped_task) + + unmapped_task._do_render_template_fields( + parent=unmapped_task, + template_fields=self._operator.template_fields, + context=context, + jinja_env=jinja_env, + seen_oids=set(), + ) + return unmapped_task + + def _xcom_push(self, task: IndexedTaskInstance, value: Any) -> None: + if task.xcom_pushed: + self.log.debug( + "XCom already pushed for task_id %s with index %s", + task.task_id, + task.index, + ) + else: + self.log.debug( + "Pushing XCom for task_id %s with index %s", + task.task_id, + task.index, + ) + + task.xcom_push(key=BaseXCom.XCOM_RETURN_KEY, value=value) + + def _run_tasks( + self, + context: Context, + tasks: Iterable[IndexedTaskInstance], + ) -> XComIterable | None: + exceptions: list[BaseException] = [] + reschedule_date = timezone.utcnow() + prev_futures_count = 0 + futures: dict[Future | asyncio.futures.Future, IndexedTaskInstance] = {} + deferred_tasks: deque[IndexedTaskInstance] = deque() + failed_tasks: deque[IndexedTaskInstance] = deque() + chunked_tasks = batched(tasks, self.max_workers) + do_xcom_push = True + + self.log.info("Running tasks with %d workers", self.max_workers) + + with event_loop() as loop: + with ConcurrentExecutor(loop=loop, max_workers=self.max_workers) as executor: + for task in next(chunked_tasks, []): + do_xcom_push = task.do_xcom_push + if task.is_async: + future = executor.submit(self._run_async_operator, context, task) + else: + future = 
executor.submit(self._run_operator, context, task) + futures[future] = task + + while futures: + futures_count = len(futures) + + if futures_count != prev_futures_count: + self.log.info("Number of remaining futures: %s", futures_count) + prev_futures_count = futures_count + + for future in collect_futures(loop, list(futures.keys())): + task = futures.pop(future) + + try: + if isinstance(future, asyncio.futures.Future): + result = future.result() + else: + result = future.result(timeout=self.timeout) + + self.log.debug("result: %s", result) + + if result is not None and task.do_xcom_push: + self._xcom_push( + task=task, + value=result, + ) + except TaskDeferred as task_deferred: + operator = DecoratedDeferredAsyncOperator( + operator=task.task, task_deferred=task_deferred + ) + deferred_tasks.append( + self._create_mapped_task( + run_id=task.run_id, + index=task.index, + map_index=task.map_index, # type: ignore[arg-type] + try_number=task.try_number, + operator=operator, + ) + ) + except asyncio.TimeoutError as e: + self.log.warning("A timeout occurred for task_id %s", task.task_id) + if task.next_try_number > (self.retries or 0): + exceptions.append(AirflowTaskTimeout(e)) + else: + reschedule_date = min(reschedule_date, task.next_retry_datetime()) + failed_tasks.append(task) + except AirflowRescheduleTaskInstanceException as e: + reschedule_date = min(reschedule_date, e.reschedule_date) + self.log.exception( + "An exception occurred for task_id %s with index %s, it has been rescheduled at %s", + task.task_id, + task.index, + reschedule_date, + ) + failed_tasks.append(e.task) + except Exception as e: + self.log.exception( + "An exception occurred for task_id %s with index %s", + task.task_id, + task.index, + ) + exceptions.append(e) + + if len(futures) < self.max_workers: + chunked_tasks = chain(list(deferred_tasks), chunked_tasks) + deferred_tasks.clear() + + for task in next(chunked_tasks, []): + if task.is_async: + future = 
executor.submit(self._run_async_operator, context, task) + else: + future = executor.submit(self._run_operator, context, task) + futures[future] = task + + if not failed_tasks: + if exceptions: + raise BaseExceptionGroup("Multiple sub-task failures", exceptions) + if do_xcom_push: + from airflow.sdk.execution_time.lazy_sequence import XComIterable + + return XComIterable( + task_id=self.task_id, + dag_id=self.dag_id, + run_id=context["run_id"], + length=self._number_of_tasks, + map_index=context["ti"].map_index, + ) + return None + + # If the retry time is still in the future we defer the operator so the worker + # slot is released. If the retry time has already passed we immediately re-run + # the failed tasks without deferring. + if reschedule_date > timezone.utcnow(): + # TODO: This is tricky as that import doesn't exist in Task SDK + from airflow.providers.standard.triggers.temporal import DateTimeTrigger Review Comment: Runtime `from airflow.providers.standard.triggers.temporal import DateTimeTrigger` inside `_run_tasks` breaks the task-sdk INV-3 contract (the SDK is supposed to ship without provider packages). The TODO on the line above acknowledges this, and the pre-commit exclusion added to `task-sdk/.pre-commit-config.yaml` for `check-core-imports` silently formalizes the violation. Options: (a) move `DateTimeTrigger` (or a thin SDK equivalent) into the SDK, (b) declare `apache-airflow-providers-standard` as a hard runtime dependency of task-sdk in `pyproject.toml`, or (c) expose a sleep/reschedule primitive in the SDK that the standard provider extends. Silently bypassing the import check is worse than any of these three options. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
