dabla commented on code in PR #62922: URL: https://github.com/apache/airflow/pull/62922#discussion_r2958727610
########## task-sdk/src/airflow/sdk/definitions/iterableoperator.py: ########## @@ -0,0 +1,358 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +import os +from collections import deque +from collections.abc import Iterable, Sequence +from concurrent.futures import Future +from itertools import chain +from typing import TYPE_CHECKING, Any + +try: + # Python 3.12+ + from itertools import batched # type: ignore[attr-defined] +except ImportError: + from more_itertools import batched # type: ignore[no-redef] + +try: + # Python 3.11+ + BaseExceptionGroup +except NameError: + from exceptiongroup import BaseExceptionGroup + +from airflow.sdk import TaskInstanceState, timezone +from airflow.sdk.bases.operator import BaseOperator, DecoratedDeferredAsyncOperator, event_loop +from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg # noqa: F401 +from airflow.sdk.exceptions import ( + AirflowRescheduleTaskInstanceException, + AirflowTaskTimeout, + TaskDeferred, +) +from airflow.sdk.execution_time.context import context_to_airflow_vars +from airflow.sdk.execution_time.executor import HybridExecutor, TaskExecutor, collect_futures +from airflow.sdk.execution_time.lazy_sequence import XComIterable
+from airflow.sdk.execution_time.task_runner import MappedTaskInstance + +if TYPE_CHECKING: + from airflow.sdk.definitions._internal.expandinput import ExpandInput + from airflow.sdk.definitions.context import Context + from airflow.sdk.definitions.mappedoperator import MappedOperator + + +class IterableOperator(BaseOperator): + """Object representing an iterable operator in a DAG.""" + + _operator: MappedOperator + expand_input: ExpandInput + partial_kwargs: dict[str, Any] + # each operator should override this class attr for shallow copy attrs. + shallow_copy_attrs: Sequence[str] = ( + "_operator", + "expand_input", + "partial_kwargs", + "_log", + ) + + def __init__( + self, + *, + operator: MappedOperator, + expand_input: ExpandInput, + task_concurrency: int | None = None, + **kwargs, + ): + super().__init__( + **{ + **kwargs, + "task_id": operator.task_id, + "owner": operator.owner, + "email": operator.email, + "email_on_retry": operator.email_on_retry, + "email_on_failure": operator.email_on_failure, + "retries": 0, # We should not retry the IterableOperator, only the mapped ti's should be retried + "retry_delay": operator.retry_delay, + "retry_exponential_backoff": operator.retry_exponential_backoff, + "max_retry_delay": operator.max_retry_delay, + "start_date": operator.start_date, + "end_date": operator.end_date, + "depends_on_past": operator.depends_on_past, + "ignore_first_depends_on_past": operator.ignore_first_depends_on_past, + "wait_for_past_depends_before_skipping": operator.wait_for_past_depends_before_skipping, + "wait_for_downstream": operator.wait_for_downstream, + "dag": operator.dag, + "priority_weight": operator.priority_weight, + "queue": operator.queue, + "pool": operator.pool, + "pool_slots": operator.pool_slots, + "execution_timeout": None, + "trigger_rule": operator.trigger_rule, + "resources": operator.resources, + "run_as_user": operator.run_as_user, + "map_index_template": operator.map_index_template, + 
"max_active_tis_per_dag": operator.max_active_tis_per_dag, + "max_active_tis_per_dagrun": operator.max_active_tis_per_dagrun, + "executor": operator.executor, + "executor_config": operator.executor_config, + "inlets": operator.inlets, + "outlets": operator.outlets, + "task_group": operator.task_group, + "doc": operator.doc, + "doc_md": operator.doc_md, + "doc_json": operator.doc_json, + "doc_yaml": operator.doc_yaml, + "doc_rst": operator.doc_rst, + "task_display_name": operator.task_display_name, + "allow_nested_operators": operator.allow_nested_operators, + } + ) + self._operator = operator + self.expand_input = expand_input + self.partial_kwargs = operator.partial_kwargs or {} + self.max_workers = task_concurrency or os.cpu_count() or 1 + self._number_of_tasks: int = 0 + XComArg.apply_upstream_relationship(self, self.expand_input.value) + + @property + def task_type(self) -> str: + """@property: type of the task.""" + return self._operator.__class__.__name__ + + @property + def timeout(self) -> float | None: + if self.execution_timeout: + return self.execution_timeout.total_seconds() + return None + + def _get_specified_expand_input(self) -> ExpandInput: + return self.expand_input + + def _unmap_operator(self, mapped_kwargs: Context): + self._number_of_tasks += 1 + return self._operator.unmap(mapped_kwargs) + + def _xcom_push(self, context: Context, task: MappedTaskInstance, value: Any) -> None: + self.log.debug("Pushing XCom %s", task.xcom_key) + + context["ti"].xcom_push(key=task.xcom_key, value=value) + + def _run_tasks( + self, + context: Context, + tasks: Iterable[MappedTaskInstance], + ) -> XComIterable | None: + exceptions: list[BaseException] = [] + reschedule_date = timezone.utcnow() + prev_futures_count = 0 + futures: dict[Future, MappedTaskInstance] = {} + deferred_tasks: deque[MappedTaskInstance] = deque() + failed_tasks: deque[MappedTaskInstance] = deque() + chunked_tasks = batched(tasks, self.max_workers) + do_xcom_push = True + + 
self.log.info("Running tasks with %d workers", self.max_workers) + + # Export context in os.environ to make it available for operators to use. + airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) + os.environ.update(airflow_context_vars) + Review Comment: I think this is actually not needed here, as this is already done by Airflow before the task that is going to do the DTI. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
