kaxil commented on code in PR #59087:
URL: https://github.com/apache/airflow/pull/59087#discussion_r2670517691


##########
providers/common/compat/src/airflow/providers/common/compat/standard/operators.py:
##########
@@ -17,18 +17,230 @@
 
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 from airflow.providers.common.compat._compat_utils import create_module_getattr
+from airflow.providers.common.compat.version_compat import (
+    AIRFLOW_V_3_0_PLUS,
+    AIRFLOW_V_3_1_PLUS,
+    AIRFLOW_V_3_2_PLUS,
+)
 
 _IMPORT_MAP: dict[str, str | tuple[str, ...]] = {
     # Names re-exported from the compat sdk module (which handles the Airflow 2.x/3.x fallbacks itself)
+    "AsyncExecutionCallableRunner": "airflow.providers.common.compat.sdk",
     "BaseOperator": "airflow.providers.common.compat.sdk",
+    "BaseAsyncOperator": "airflow.providers.common.compat.sdk",
+    "create_async_executable_runner": "airflow.providers.common.compat.sdk",
     "get_current_context": "airflow.providers.common.compat.sdk",
+    "is_async_callable": "airflow.providers.common.compat.sdk",
     # Standard provider names: the provider module is listed first, then the legacy airflow.operators.python path
     "PythonOperator": ("airflow.providers.standard.operators.python", 
"airflow.operators.python"),
     "ShortCircuitOperator": ("airflow.providers.standard.operators.python", 
"airflow.operators.python"),
     "_SERIALIZERS": ("airflow.providers.standard.operators.python", 
"airflow.operators.python"),
 }
 
+if TYPE_CHECKING:
+    from structlog.typing import FilteringBoundLogger as Logger
+
+    from airflow.sdk.bases.decorator import is_async_callable
+    from airflow.sdk.bases.operator import BaseAsyncOperator
+    from airflow.sdk.execution_time.callback_runner import (
+        AsyncExecutionCallableRunner,
+        create_async_executable_runner,
+    )
+    from airflow.sdk.types import OutletEventAccessorsProtocol
+elif AIRFLOW_V_3_2_PLUS:
+    from airflow.sdk.bases.decorator import is_async_callable
+    from airflow.sdk.bases.operator import BaseAsyncOperator
+    from airflow.sdk.execution_time.callback_runner import (
+        AsyncExecutionCallableRunner,
+        create_async_executable_runner,
+    )
+else:
+    import asyncio
+    import contextlib
+    import inspect
+    import logging
+    from asyncio import AbstractEventLoop
+    from collections.abc import AsyncIterator, Awaitable, Callable, Generator
+    from contextlib import suppress
+    from functools import partial
+    from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, cast
+
+    from typing_extensions import ParamSpec
+
+    if AIRFLOW_V_3_0_PLUS:
+        from airflow.sdk import BaseOperator
+        from airflow.sdk.bases.decorator import _TaskDecorator
+        from airflow.sdk.definitions.asset.metadata import Metadata
+        from airflow.sdk.definitions.mappedoperator import OperatorPartial
+    else:
+        from airflow.datasets.metadata import Metadata
+        from airflow.decorators.base import _TaskDecorator
+        from airflow.models import BaseOperator
+        from airflow.models.mappedoperator import OperatorPartial
+
+    P = ParamSpec("P")
+    R = TypeVar("R")
+
+    @contextlib.contextmanager
+    def event_loop() -> Generator[AbstractEventLoop]:
+        new_event_loop = False
+        loop = None
+        try:
+            try:
+                loop = asyncio.get_event_loop()
+                if loop.is_closed():
+                    raise RuntimeError
+            except RuntimeError:
+                loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(loop)
+                new_event_loop = True
+            yield loop
+        finally:
+            if new_event_loop and loop is not None:
+                with contextlib.suppress(AttributeError):
+                    loop.close()
+                    asyncio.set_event_loop(None)
+
+    def unwrap_partial(fn):
+        while isinstance(fn, partial):
+            fn = fn.func
+        return fn
+
+    def unwrap_callable(func):
+        # Airflow-specific unwrap
+        if isinstance(func, (_TaskDecorator, OperatorPartial)):
+            func = getattr(func, "function", getattr(func, "_func", func))
+
+        # Unwrap functools.partial
+        func = unwrap_partial(func)
+
+        # Unwrap @functools.wraps chains
+        with suppress(Exception):
+            func = inspect.unwrap(func)
+
+        return func
+
+    def is_async_callable(func):
+        """Detect if a callable (possibly wrapped) is an async function."""
+        func = unwrap_callable(func)
+
+        if not callable(func):
+            return False
+
+        # Direct async function
+        if inspect.iscoroutinefunction(func):
+            return True
+
+        # Callable object with async __call__
+        if not inspect.isfunction(func):
+            call = type(func).__call__  # Bandit-safe
+            with suppress(Exception):
+                call = inspect.unwrap(call)
+            if inspect.iscoroutinefunction(call):
+                return True
+
+        return False
+
+    class _AsyncExecutionCallableRunner(Generic[P, R]):
+        @staticmethod
+        async def run(*args: P.args, **kwargs: P.kwargs) -> R: ...  # type: 
ignore[empty-body]
+
+    class AsyncExecutionCallableRunner(Protocol):
+        def __call__(
+            self,
+            func: Callable[P, R],
+            outlet_events: OutletEventAccessorsProtocol,
+            *,
+            logger: logging.Logger | Logger,
+        ) -> _AsyncExecutionCallableRunner[P, R]: ...
+
+    def create_async_executable_runner(
+        func: Callable[P, Awaitable[R] | AsyncIterator],
+        outlet_events: OutletEventAccessorsProtocol,
+        *,
+        logger: logging.Logger | logging.Logger,

Review Comment:
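   Duplicate types in this union: `logging.Logger | logging.Logger` repeats the same type. Should this be `logging.Logger | Logger`, matching the `AsyncExecutionCallableRunner` protocol above?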



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to