This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1204920d2b0 Revert "Add support for async callables in PythonOperator (#59087)" (#60266)
1204920d2b0 is described below
commit 1204920d2b08bcd687490c7303dd26689a357b43
Author: David Blain <[email protected]>
AuthorDate: Thu Jan 8 16:23:22 2026 +0100
Revert "Add support for async callables in PythonOperator (#59087)" (#60266)
This reverts commit 9cab6fb7ef4e494dcdf0d9a77171d8001888ca0a.
---
.../core_api/routes/public/test_task_instances.py | 58 +++---
.../providers/common/compat/standard/operators.py | 212 ---------------------
.../providers/common/compat/version_compat.py | 2 -
providers/standard/docs/operators/python.rst | 28 ---
.../example_dags/example_python_decorator.py | 18 --
.../example_dags/example_python_operator.py | 19 --
.../airflow/providers/standard/operators/python.py | 66 +------
.../tests/unit/standard/decorators/test_python.py | 36 +---
.../tests/unit/standard/operators/test_python.py | 157 +--------------
task-sdk/docs/api.rst | 4 +-
task-sdk/docs/index.rst | 1 -
task-sdk/src/airflow/sdk/__init__.py | 10 +-
task-sdk/src/airflow/sdk/__init__.pyi | 2 -
task-sdk/src/airflow/sdk/bases/decorator.py | 53 +-----
task-sdk/src/airflow/sdk/bases/operator.py | 55 +-----
.../sdk/definitions/_internal/abstractoperator.py | 4 -
.../airflow/sdk/execution_time/callback_runner.py | 67 +------
task-sdk/src/airflow/sdk/execution_time/comms.py | 72 ++-----
18 files changed, 64 insertions(+), 800 deletions(-)
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
index 6a426d8e418..63601f5cb02 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -218,7 +218,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -376,7 +376,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -440,7 +440,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -496,7 +496,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -616,7 +616,7 @@ class TestGetMappedTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -1408,7 +1408,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
False,
"/dags/~/dagRuns/~/taskInstances",
{"dag_id_pattern": "example_python_operator"},
- 14, # Based on test failure - example_python_operator creates 14 task instances
+ 9, # Based on test failure - example_python_operator creates 9 task instances
3,
id="test dag_id_pattern exact match",
),
@@ -1417,7 +1417,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
False,
"/dags/~/dagRuns/~/taskInstances",
{"dag_id_pattern": "example_%"},
- 22, # Based on test failure - both DAGs together create 22 task instances
+ 17, # Based on test failure - both DAGs together create 17 task instances
3,
id="test dag_id_pattern wildcard prefix",
),
@@ -1931,8 +1931,8 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
[
pytest.param(
{"dag_ids": ["example_python_operator", "example_skip_dag"]},
- 22,
- 22,
+ 17,
+ 17,
id="with dag filter",
),
],
@@ -2041,7 +2041,7 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
assert len(response_batch2.json()["task_instances"]) > 0
# Match
- ti_count = 10
+ ti_count = 9
assert response_batch1.json()["total_entries"] == response_batch2.json()["total_entries"] == ti_count
assert (num_entries_batch1 + num_entries_batch2) == ti_count
assert response_batch1 != response_batch2
@@ -2080,7 +2080,7 @@ class TestGetTaskInstanceTry(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -2127,7 +2127,7 @@ class TestGetTaskInstanceTry(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -2205,7 +2205,7 @@ class TestGetTaskInstanceTry(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -2278,7 +2278,7 @@ class TestGetTaskInstanceTry(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -2326,7 +2326,7 @@ class TestGetTaskInstanceTry(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -3243,7 +3243,7 @@ class TestPostClearTaskInstances(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -3615,7 +3615,7 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -3653,7 +3653,7 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -3801,7 +3801,7 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -3885,7 +3885,7 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -3923,7 +3923,7 @@ class TestGetTaskInstanceTries(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4165,7 +4165,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4439,7 +4439,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4573,7 +4573,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4634,7 +4634,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4713,7 +4713,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4794,7 +4794,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4912,7 +4912,7 @@ class TestPatchTaskInstanceDryRun(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -5198,7 +5198,7 @@ class TestPatchTaskInstanceDryRun(TestTaskInstanceEndpoint):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
- "priority_weight": 14,
+ "priority_weight": 9,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
diff --git a/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py b/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
index abe3308149d..6b77db3e4a9 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
@@ -17,230 +17,18 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
-
from airflow.providers.common.compat._compat_utils import create_module_getattr
-from airflow.providers.common.compat.version_compat import (
- AIRFLOW_V_3_0_PLUS,
- AIRFLOW_V_3_1_PLUS,
- AIRFLOW_V_3_2_PLUS,
-)
_IMPORT_MAP: dict[str, str | tuple[str, ...]] = {
# Re-export from sdk (which handles Airflow 2.x/3.x fallbacks)
- "AsyncExecutionCallableRunner": "airflow.providers.common.compat.sdk",
"BaseOperator": "airflow.providers.common.compat.sdk",
- "BaseAsyncOperator": "airflow.providers.common.compat.sdk",
- "create_async_executable_runner": "airflow.providers.common.compat.sdk",
"get_current_context": "airflow.providers.common.compat.sdk",
- "is_async_callable": "airflow.providers.common.compat.sdk",
# Standard provider items with direct fallbacks
"PythonOperator": ("airflow.providers.standard.operators.python", "airflow.operators.python"),
"ShortCircuitOperator": ("airflow.providers.standard.operators.python", "airflow.operators.python"),
"_SERIALIZERS": ("airflow.providers.standard.operators.python", "airflow.operators.python"),
}
-if TYPE_CHECKING:
- from structlog.typing import FilteringBoundLogger as Logger
-
- from airflow.sdk.bases.decorator import is_async_callable
- from airflow.sdk.bases.operator import BaseAsyncOperator
- from airflow.sdk.execution_time.callback_runner import (
- AsyncExecutionCallableRunner,
- create_async_executable_runner,
- )
- from airflow.sdk.types import OutletEventAccessorsProtocol
-elif AIRFLOW_V_3_2_PLUS:
- from airflow.sdk.bases.decorator import is_async_callable
- from airflow.sdk.bases.operator import BaseAsyncOperator
- from airflow.sdk.execution_time.callback_runner import (
- AsyncExecutionCallableRunner,
- create_async_executable_runner,
- )
-else:
- import asyncio
- import contextlib
- import inspect
- import logging
- from asyncio import AbstractEventLoop
- from collections.abc import AsyncIterator, Awaitable, Callable, Generator
- from contextlib import suppress
- from functools import partial
- from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, cast
-
- from typing_extensions import ParamSpec
-
- if AIRFLOW_V_3_0_PLUS:
- from airflow.sdk import BaseOperator
- from airflow.sdk.bases.decorator import _TaskDecorator
- from airflow.sdk.definitions.asset.metadata import Metadata
- from airflow.sdk.definitions.mappedoperator import OperatorPartial
- else:
- from airflow.datasets.metadata import Metadata
- from airflow.decorators.base import _TaskDecorator
- from airflow.models import BaseOperator
- from airflow.models.mappedoperator import OperatorPartial
-
- P = ParamSpec("P")
- R = TypeVar("R")
-
- @contextlib.contextmanager
- def event_loop() -> Generator[AbstractEventLoop]:
- new_event_loop = False
- loop = None
- try:
- try:
- loop = asyncio.get_event_loop()
- if loop.is_closed():
- raise RuntimeError
- except RuntimeError:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- new_event_loop = True
- yield loop
- finally:
- if new_event_loop and loop is not None:
- with contextlib.suppress(AttributeError):
- loop.close()
- asyncio.set_event_loop(None)
-
- def unwrap_partial(fn):
- while isinstance(fn, partial):
- fn = fn.func
- return fn
-
- def unwrap_callable(func):
- # Airflow-specific unwrap
- if isinstance(func, (_TaskDecorator, OperatorPartial)):
- func = getattr(func, "function", getattr(func, "_func", func))
-
- # Unwrap functools.partial
- func = unwrap_partial(func)
-
- # Unwrap @functools.wraps chains
- with suppress(Exception):
- func = inspect.unwrap(func)
-
- return func
-
- def is_async_callable(func):
- """Detect if a callable (possibly wrapped) is an async function."""
- func = unwrap_callable(func)
-
- if not callable(func):
- return False
-
- # Direct async function
- if inspect.iscoroutinefunction(func):
- return True
-
- # Callable object with async __call__
- if not inspect.isfunction(func):
- call = type(func).__call__ # Bandit-safe
- with suppress(Exception):
- call = inspect.unwrap(call)
- if inspect.iscoroutinefunction(call):
- return True
-
- return False
-
- class _AsyncExecutionCallableRunner(Generic[P, R]):
- @staticmethod
- async def run(*args: P.args, **kwargs: P.kwargs) -> R: ... # type: ignore[empty-body]
-
- class AsyncExecutionCallableRunner(Protocol):
- def __call__(
- self,
- func: Callable[P, R],
- outlet_events: OutletEventAccessorsProtocol,
- *,
- logger: logging.Logger | Logger,
- ) -> _AsyncExecutionCallableRunner[P, R]: ...
-
- def create_async_executable_runner(
- func: Callable[P, Awaitable[R] | AsyncIterator],
- outlet_events: OutletEventAccessorsProtocol,
- *,
- logger: logging.Logger | Logger,
- ) -> _AsyncExecutionCallableRunner[P, R]:
- """
- Run an async execution callable against a task context and given arguments.
-
- If the callable is a simple function, this simply calls it with the supplied
- arguments (including the context). If the callable is a generator function,
- the generator is exhausted here, with the yielded values getting fed back
- into the task context automatically for execution.
-
- This convoluted implementation of inner class with closure is so *all*
- arguments passed to ``run()`` can be forwarded to the wrapped function. This
- is particularly important for the argument "self", which some use cases
- need to receive. This is not possible if this is implemented as a normal
- class, where "self" needs to point to the runner object, not the object
- bounded to the inner callable.
-
- :meta private:
- """
-
- class _AsyncExecutionCallableRunnerImpl(_AsyncExecutionCallableRunner):
- @staticmethod
- async def run(*args: P.args, **kwargs: P.kwargs) -> R:
- if not inspect.isasyncgenfunction(func):
- result = cast("Awaitable[R]", func(*args, **kwargs))
- return await result
-
- results: list[Any] = []
-
- async for result in func(*args, **kwargs):
- if isinstance(result, Metadata):
- outlet_events[result.asset].extra.update(result.extra)
- if result.alias:
- outlet_events[result.alias].add(result.asset, extra=result.extra)
-
- results.append(result)
-
- return cast("R", results)
-
- return cast("_AsyncExecutionCallableRunner[P, R]", _AsyncExecutionCallableRunnerImpl)
-
- class BaseAsyncOperator(BaseOperator):
- """
- Base class for async-capable operators.
-
- As opposed to deferred operators which are executed on the triggerer, async operators are executed
- on the worker.
- """
-
- @property
- def is_async(self) -> bool:
- return True
-
- if not AIRFLOW_V_3_1_PLUS:
-
- @property
- def xcom_push(self) -> bool:
- return self.do_xcom_push
-
- @xcom_push.setter
- def xcom_push(self, value: bool):
- self.do_xcom_push = value
-
- async def aexecute(self, context):
- """Async version of execute(). Subclasses should implement this."""
- raise NotImplementedError()
-
- def execute(self, context):
- """Run `aexecute()` inside an event loop."""
- with event_loop() as loop:
- if self.execution_timeout:
- return loop.run_until_complete(
- asyncio.wait_for(
- self.aexecute(context),
- timeout=self.execution_timeout.total_seconds(),
- )
- )
- return loop.run_until_complete(self.aexecute(context))
-
-
__getattr__ = create_module_getattr(import_map=_IMPORT_MAP)
__all__ = sorted(_IMPORT_MAP.keys())
diff --git a/providers/common/compat/src/airflow/providers/common/compat/version_compat.py b/providers/common/compat/src/airflow/providers/common/compat/version_compat.py
index e3fd1e55f14..4142937bd2a 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/version_compat.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/version_compat.py
@@ -34,7 +34,6 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
AIRFLOW_V_3_0_PLUS: bool = get_base_airflow_version_tuple() >= (3, 0, 0)
AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)
-AIRFLOW_V_3_2_PLUS: bool = get_base_airflow_version_tuple() >= (3, 2, 0)
# BaseOperator removed from version_compat to avoid circular imports
# Import it directly in files that need it instead
@@ -42,5 +41,4 @@ AIRFLOW_V_3_2_PLUS: bool = get_base_airflow_version_tuple() >= (3, 2, 0)
__all__ = [
"AIRFLOW_V_3_0_PLUS",
"AIRFLOW_V_3_1_PLUS",
- "AIRFLOW_V_3_2_PLUS",
]
diff --git a/providers/standard/docs/operators/python.rst b/providers/standard/docs/operators/python.rst
index 091412509d8..2e5e63ea437 100644
--- a/providers/standard/docs/operators/python.rst
+++ b/providers/standard/docs/operators/python.rst
@@ -72,34 +72,6 @@ Pass extra arguments to the ``@task`` decorated function as you would with a nor
:start-after: [START howto_operator_python_kwargs]
:end-before: [END howto_operator_python_kwargs]
-Async Python functions
-^^^^^^^^^^^^^^^^^^^^^^
-
-From Airflow 3.2 onward, async Python callables are now also supported out of the box.
-This means we don't need to cope with the event loop and allows us to easily invoke async Python code and async
-Airflow hooks which are not always available through deferred operators.
-As opposed to deferred operators which are executed on the triggerer, async operators are executed on the workers.
-
-.. tab-set::
-
- .. tab-item:: @task
- :sync: taskflow
-
- .. exampleinclude:: /../src/airflow/providers/standard/example_dags/example_python_decorator.py
- :language: python
- :dedent: 4
- :start-after: [START howto_async_operator_python_kwargs]
- :end-before: [END howto_async_operator_python_kwargs]
-
- .. tab-item:: PythonOperator
- :sync: operator
-
- .. exampleinclude:: /../src/airflow/providers/standard/example_dags/example_python_operator.py
- :language: python
- :dedent: 4
- :start-after: [START howto_async_operator_python_kwargs]
- :end-before: [END howto_async_operator_python_kwargs]
-
Templating
^^^^^^^^^^
diff --git a/providers/standard/src/airflow/providers/standard/example_dags/example_python_decorator.py b/providers/standard/src/airflow/providers/standard/example_dags/example_python_decorator.py
index 578a09f574f..ac9938d92ea 100644
--- a/providers/standard/src/airflow/providers/standard/example_dags/example_python_decorator.py
+++ b/providers/standard/src/airflow/providers/standard/example_dags/example_python_decorator.py
@@ -22,7 +22,6 @@ virtual environment.
from __future__ import annotations
-import asyncio
import logging
import sys
import time
@@ -65,7 +64,6 @@ def example_python_decorator():
# [START howto_operator_python_kwargs]
# Generate 5 sleeping tasks, sleeping from 0.0 to 0.4 seconds respectively
- # Asynchronous callables are natively supported since Airflow 3.2+
@task
def my_sleeping_function(random_base):
"""This is a function that will run within the DAG execution"""
@@ -77,22 +75,6 @@ def example_python_decorator():
run_this >> log_the_sql >> sleeping_task
# [END howto_operator_python_kwargs]
- # [START howto_async_operator_python_kwargs]
- # Generate 5 sleeping tasks, sleeping from 0.0 to 0.4 seconds respectively
- # Asynchronous callables are natively supported since Airflow 3.2+
- @task
- async def my_async_sleeping_function(random_base):
- """This is a function that will run within the DAG execution"""
- await asyncio.sleep(random_base)
-
- for i in range(5):
- async_sleeping_task = my_async_sleeping_function.override(task_id=f"async_sleep_for_{i}")(
- random_base=i / 10
- )
-
- run_this >> log_the_sql >> async_sleeping_task
- # [END howto_async_operator_python_kwargs]
-
# [START howto_operator_python_venv]
@task.virtualenv(
task_id="virtualenv_python", requirements=["colorama==0.4.0"], system_site_packages=False
diff --git a/providers/standard/src/airflow/providers/standard/example_dags/example_python_operator.py b/providers/standard/src/airflow/providers/standard/example_dags/example_python_operator.py
index 064ac042025..18aa8f207e3 100644
--- a/providers/standard/src/airflow/providers/standard/example_dags/example_python_operator.py
+++ b/providers/standard/src/airflow/providers/standard/example_dags/example_python_operator.py
@@ -22,7 +22,6 @@ within a virtual environment.
from __future__ import annotations
-import asyncio
import logging
import sys
import time
@@ -77,7 +76,6 @@ with DAG(
# [START howto_operator_python_kwargs]
# Generate 5 sleeping tasks, sleeping from 0.0 to 0.4 seconds respectively
- # Asynchronous callables are natively supported since Airflow 3.2+
def my_sleeping_function(random_base):
"""This is a function that will run within the DAG execution"""
time.sleep(random_base)
@@ -90,23 +88,6 @@ with DAG(
run_this >> log_the_sql >> sleeping_task
# [END howto_operator_python_kwargs]
- # [START howto_async_operator_python_kwargs]
- # Generate 5 sleeping tasks, sleeping from 0.0 to 0.4 seconds respectively
- # Asynchronous callables are natively supported since Airflow 3.2+
- async def my_async_sleeping_function(random_base):
- """This is a function that will run within the DAG execution"""
- await asyncio.sleep(random_base)
-
- for i in range(5):
- async_sleeping_task = PythonOperator(
- task_id=f"async_sleep_for_{i}",
- python_callable=my_async_sleeping_function,
- op_kwargs={"random_base": i / 10},
- )
-
- run_this >> log_the_sql >> async_sleeping_task
- # [END howto_async_operator_python_kwargs]
-
# [START howto_operator_python_venv]
def callable_virtualenv():
"""
diff --git a/providers/standard/src/airflow/providers/standard/operators/python.py b/providers/standard/src/airflow/providers/standard/operators/python.py
index 42fa7a063a4..ac8862f2923 100644
--- a/providers/standard/src/airflow/providers/standard/operators/python.py
+++ b/providers/standard/src/airflow/providers/standard/operators/python.py
@@ -48,18 +48,13 @@ from airflow.exceptions import (
)
from airflow.models.variable import Variable
from airflow.providers.common.compat.sdk import AirflowException, AirflowSkipException, context_merge
-from airflow.providers.common.compat.standard.operators import (
- AsyncExecutionCallableRunner,
- BaseAsyncOperator,
- is_async_callable,
-)
from airflow.providers.standard.hooks.package_index import PackageIndexHook
from airflow.providers.standard.utils.python_virtualenv import (
_execute_in_subprocess,
prepare_virtualenv,
write_python_script,
)
-from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator
from airflow.utils import hashlib_wrapper
from airflow.utils.file import get_unique_dag_module_name
from airflow.utils.operator_helpers import KeywordParameters
@@ -120,9 +115,9 @@ class _PythonVersionInfo(NamedTuple):
return cls(*_parse_version_info(result.strip()))
-class PythonOperator(BaseAsyncOperator):
+class PythonOperator(BaseOperator):
"""
- Base class for all Python operators.
+ Executes a Python callable.
.. seealso::
For more information on how to use this operator, take a look at the guide:
@@ -197,14 +192,7 @@ class PythonOperator(BaseAsyncOperator):
self.template_ext = templates_exts
self.show_return_value_in_logs = show_return_value_in_logs
- @property
- def is_async(self) -> bool:
- return is_async_callable(self.python_callable)
-
- def execute(self, context) -> Any:
- if self.is_async:
- return BaseAsyncOperator.execute(self, context)
-
+ def execute(self, context: Context) -> Any:
context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
self.op_kwargs = self.determine_kwargs(context)
@@ -231,40 +219,6 @@ class PythonOperator(BaseAsyncOperator):
return return_value
- async def aexecute(self, context):
- """Async version of execute(). Subclasses should implement this."""
- context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
- self.op_kwargs = self.determine_kwargs(context)
-
- # This needs to be lazy because subclasses may implement execute_callable
- # by running a separate process that can't use the eager result.
- def __prepare_execution() -> tuple[AsyncExecutionCallableRunner, OutletEventAccessorsProtocol] | None:
- from airflow.providers.common.compat.standard.operators import (
- create_async_executable_runner,
- )
-
- if AIRFLOW_V_3_0_PLUS:
- from airflow.sdk.execution_time.context import (
- context_get_outlet_events,
- )
- else:
- from airflow.utils.context import context_get_outlet_events # type: ignore
-
- return (
- cast("AsyncExecutionCallableRunner", create_async_executable_runner),
- context_get_outlet_events(context),
- )
-
- self.__prepare_execution = __prepare_execution
-
- return_value = await self.aexecute_callable()
- if self.show_return_value_in_logs:
- self.log.info("Done. Returned value was: %s", return_value)
- else:
- self.log.info("Done. Returned value not shown")
-
- return return_value
-
def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
return KeywordParameters.determine(self.python_callable, self.op_args, context).unpacking()
@@ -282,18 +236,6 @@ class PythonOperator(BaseAsyncOperator):
runner = create_execution_runner(self.python_callable, asset_events, logger=self.log)
return runner.run(*self.op_args, **self.op_kwargs)
- async def aexecute_callable(self) -> Any:
- """
- Call the python callable with the given arguments.
-
- :return: the return value of the call.
- """
- if (execution_preparation := self.__prepare_execution()) is None:
- return await self.python_callable(*self.op_args, **self.op_kwargs)
- create_execution_runner, asset_events = execution_preparation
- runner = create_execution_runner(self.python_callable, asset_events, logger=self.log)
- return await runner.run(*self.op_args, **self.op_kwargs)
-
class BranchPythonOperator(BaseBranchOperator, PythonOperator):
"""
diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py b/providers/standard/tests/unit/standard/decorators/test_python.py
index d2aa38ec544..32b2fc2615d 100644
--- a/providers/standard/tests/unit/standard/decorators/test_python.py
+++ b/providers/standard/tests/unit/standard/decorators/test_python.py
@@ -37,16 +37,8 @@ from tests_common.test_utils.version_compat import (
from unit.standard.operators.test_python import BasePythonTest
if AIRFLOW_V_3_0_PLUS:
- from airflow.sdk import (
- DAG,
- BaseOperator,
- TaskGroup,
- XComArg,
- setup,
- task as task_decorator,
- teardown,
- )
- from airflow.sdk.bases.decorator import DecoratedMappedOperator, _TaskDecorator
+ from airflow.sdk import DAG, BaseOperator, TaskGroup, XComArg, setup, task as task_decorator, teardown
+ from airflow.sdk.bases.decorator import DecoratedMappedOperator
from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput
else:
from airflow.decorators import ( # type: ignore[attr-defined,no-redef]
@@ -54,7 +46,7 @@ else:
task as task_decorator,
teardown,
)
- from airflow.decorators.base import DecoratedMappedOperator, _TaskDecorator # type: ignore[no-redef]
+ from airflow.decorators.base import DecoratedMappedOperator # type: ignore[no-redef]
from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef]
from airflow.models.dag import DAG # type: ignore[assignment,no-redef]
from airflow.models.expandinput import DictOfListsExpandInput # type: ignore[attr-defined,no-redef]
@@ -666,9 +658,9 @@ class TestAirflowTaskDecorator(BasePythonTest):
hello.override(pool="my_pool", priority_weight=i)()
weights = []
- for _task in self.dag_non_serialized.tasks:
- assert _task.pool == "my_pool"
- weights.append(_task.priority_weight)
+ for task in self.dag_non_serialized.tasks:
+ assert task.pool == "my_pool"
+ weights.append(task.priority_weight)
assert weights == [0, 1, 2]
def test_python_callable_args_work_as_well_as_baseoperator_args(self, dag_maker):
@@ -1150,19 +1142,3 @@ def test_teardown_trigger_rule_override_behavior(dag_maker, session):
my_teardown()
assert work_task.operator.trigger_rule == TriggerRule.ONE_SUCCESS
assert setup_task.operator.trigger_rule == TriggerRule.ONE_SUCCESS
-
-
-async def async_fn():
- return 42
-
-
-def test_python_task():
- from airflow.providers.standard.decorators.python import _PythonDecoratedOperator, python_task
-
- decorator = python_task(async_fn)
-
- assert isinstance(decorator, _TaskDecorator)
- assert decorator.function == async_fn
- assert decorator.operator_class == _PythonDecoratedOperator
- assert not decorator.multiple_outputs
- assert decorator.kwargs == {"task_id": "async_fn"}
diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py
index e9f9babab8a..a59c33b29dc 100644
--- a/providers/standard/tests/unit/standard/operators/test_python.py
+++ b/providers/standard/tests/unit/standard/operators/test_python.py
@@ -17,9 +17,7 @@
# under the License.
from __future__ import annotations
-import asyncio
import copy
-import functools
import logging
import os
import pickle
@@ -45,8 +43,7 @@ from slugify import slugify
from airflow.exceptions import AirflowProviderDeprecationWarning, DeserializingResultError
from airflow.models.connection import Connection
from airflow.models.taskinstance import TaskInstance, clear_task_instances
-from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, task
-from airflow.providers.common.compat.standard.operators import is_async_callable
+from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import (
BranchExternalPythonOperator,
@@ -76,9 +73,11 @@ from tests_common.test_utils.version_compat import (
)
if AIRFLOW_V_3_0_PLUS:
+ from airflow.sdk import BaseOperator
from airflow.sdk.execution_time.context import set_current_context
from airflow.serialization.serialized_objects import LazyDeserializedDAG
else:
+ from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef]
from airflow.models.taskinstance import set_current_context # type: ignore[attr-defined,no-redef]
if TYPE_CHECKING:
@@ -2466,18 +2465,6 @@ class TestShortCircuitWithTeardown:
assert set(actual_skipped) == {op3}
-class TestPythonAsyncOperator(TestPythonOperator):
- def test_run_async_task(self, caplog):
- caplog.set_level(logging.INFO, logger=LOGGER_NAME)
-
- async def say_hello(name: str) -> str:
- await asyncio.sleep(1)
- return f"Hello {name}!"
-
- self.run_as_task(say_hello, op_kwargs={"name": "world"}, show_return_value_in_logs=True)
- assert "Done. Returned value was: Hello world!" in caplog.messages
-
-
@pytest.mark.parametrize(
("text_input", "expected_tuple"),
[
@@ -2534,141 +2521,3 @@ def test_python_version_info(mocker):
assert result.releaselevel == sys.version_info.releaselevel
assert result.serial == sys.version_info.serial
assert list(result) == list(sys.version_info)
-
-
-def simple_decorator(fn):
- @functools.wraps(fn)
- def wrapper(*args, **kwargs):
- return fn(*args, **kwargs)
-
- return wrapper
-
-
-def decorator_without_wraps(fn):
- def wrapper(*args, **kwargs):
- return fn(*args, **kwargs)
-
- return wrapper
-
-
-async def async_fn():
- return 42
-
-
-def sync_fn():
- return 42
-
-
-@simple_decorator
-async def wrapped_async_fn():
- return 42
-
-
-@simple_decorator
-def wrapped_sync_fn():
- return 42
-
-
-@decorator_without_wraps
-async def wrapped_async_fn_no_wraps():
- return 42
-
-
-@simple_decorator
-@simple_decorator
-async def multi_wrapped_async_fn():
- return 42
-
-
-async def async_with_args(x, y):
- return x + y
-
-
-def sync_with_args(x, y):
- return x + y
-
-
-class AsyncCallable:
- async def __call__(self):
- return 42
-
-
-class SyncCallable:
- def __call__(self):
- return 42
-
-
-class WrappedAsyncCallable:
- @simple_decorator
- async def __call__(self):
- return 42
-
-
-class TestAsyncCallable:
- def test_plain_async_function(self):
- assert is_async_callable(async_fn)
-
- def test_plain_sync_function(self):
- assert not is_async_callable(sync_fn)
-
- def test_wrapped_async_function_with_wraps(self):
- assert is_async_callable(wrapped_async_fn)
-
- def test_wrapped_sync_function_with_wraps(self):
- assert not is_async_callable(wrapped_sync_fn)
-
- def test_wrapped_async_function_without_wraps(self):
- """
- Without functools.wraps, inspect.unwrap cannot recover the coroutine.
- This documents expected behavior.
- """
- assert not is_async_callable(wrapped_async_fn_no_wraps)
-
- def test_multi_wrapped_async_function(self):
- assert is_async_callable(multi_wrapped_async_fn)
-
- def test_partial_async_function(self):
- fn = functools.partial(async_with_args, 1)
- assert is_async_callable(fn)
-
- def test_partial_sync_function(self):
- fn = functools.partial(sync_with_args, 1)
- assert not is_async_callable(fn)
-
- def test_nested_partial_async_function(self):
- fn = functools.partial(
- functools.partial(async_with_args, 1),
- 2,
- )
- assert is_async_callable(fn)
-
- def test_async_callable_class(self):
- assert is_async_callable(AsyncCallable())
-
- def test_sync_callable_class(self):
- assert not is_async_callable(SyncCallable())
-
- def test_wrapped_async_callable_class(self):
- assert is_async_callable(WrappedAsyncCallable())
-
- def test_partial_callable_class(self):
- fn = functools.partial(AsyncCallable())
- assert is_async_callable(fn)
-
- @pytest.mark.parametrize("value", [None, 42, "string", object()])
- def test_non_callable(self, value):
- assert not is_async_callable(value)
-
- def test_task_decorator_async_function(self):
- @task
- async def async_task_fn():
- return 42
-
- assert is_async_callable(async_task_fn)
-
- def test_task_decorator_sync_function(self):
- @task
- def sync_task_fn():
- return 42
-
- assert not is_async_callable(sync_task_fn)
diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst
index 6e573e67ecc..9c4d2b880b4 100644
--- a/task-sdk/docs/api.rst
+++ b/task-sdk/docs/api.rst
@@ -62,8 +62,6 @@ Task Decorators:
Bases
-----
-.. autoapiclass:: airflow.sdk.BaseAsyncOperator
-
.. autoapiclass:: airflow.sdk.BaseOperator
.. autoapiclass:: airflow.sdk.BaseSensorOperator
@@ -178,7 +176,7 @@ Everything else
.. autoapimodule:: airflow.sdk
:members:
:special-members: __version__
- :exclude-members: BaseAsyncOperator, BaseOperator, DAG, dag, asset, Asset,
AssetAlias, AssetAll, AssetAny, AssetWatcher, TaskGroup, XComArg,
get_current_context, get_parsing_context
+ :exclude-members: BaseOperator, DAG, dag, asset, Asset, AssetAlias,
AssetAll, AssetAny, AssetWatcher, TaskGroup, XComArg, get_current_context,
get_parsing_context
:undoc-members:
:imported-members:
:no-index:
diff --git a/task-sdk/docs/index.rst b/task-sdk/docs/index.rst
index f3258ea8243..819f637676b 100644
--- a/task-sdk/docs/index.rst
+++ b/task-sdk/docs/index.rst
@@ -78,7 +78,6 @@ Why use ``airflow.sdk``?
**Classes**
- :class:`airflow.sdk.Asset`
-- :class:`airflow.sdk.BaseAsyncOperator`
- :class:`airflow.sdk.BaseHook`
- :class:`airflow.sdk.BaseNotifier`
- :class:`airflow.sdk.BaseOperator`
diff --git a/task-sdk/src/airflow/sdk/__init__.py
b/task-sdk/src/airflow/sdk/__init__.py
index 034a6379430..4dbda282d08 100644
--- a/task-sdk/src/airflow/sdk/__init__.py
+++ b/task-sdk/src/airflow/sdk/__init__.py
@@ -26,7 +26,6 @@ __all__ = [
"AssetAny",
"AssetOrTimeSchedule",
"AssetWatcher",
- "BaseAsyncOperator",
"BaseHook",
"BaseNotifier",
"BaseOperator",
@@ -77,13 +76,7 @@ if TYPE_CHECKING:
from airflow.sdk.api.datamodels._generated import DagRunState,
TaskInstanceState, TriggerRule, WeightRule
from airflow.sdk.bases.hook import BaseHook
from airflow.sdk.bases.notifier import BaseNotifier
- from airflow.sdk.bases.operator import (
- BaseAsyncOperator,
- BaseOperator,
- chain,
- chain_linear,
- cross_downstream,
- )
+ from airflow.sdk.bases.operator import BaseOperator, chain, chain_linear,
cross_downstream
from airflow.sdk.bases.operatorlink import BaseOperatorLink
from airflow.sdk.bases.sensor import BaseSensorOperator, PokeReturnValue
from airflow.sdk.configuration import AirflowSDKConfigParser
@@ -124,7 +117,6 @@ __lazy_imports: dict[str, str] = {
"AssetAny": ".definitions.asset",
"AssetOrTimeSchedule": ".definitions.timetables.assets",
"AssetWatcher": ".definitions.asset",
- "BaseAsyncOperator": ".bases.operator",
"BaseHook": ".bases.hook",
"BaseNotifier": ".bases.notifier",
"BaseOperator": ".bases.operator",
diff --git a/task-sdk/src/airflow/sdk/__init__.pyi
b/task-sdk/src/airflow/sdk/__init__.pyi
index b035f49226c..eede7ff806a 100644
--- a/task-sdk/src/airflow/sdk/__init__.pyi
+++ b/task-sdk/src/airflow/sdk/__init__.pyi
@@ -24,7 +24,6 @@ from airflow.sdk.api.datamodels._generated import (
from airflow.sdk.bases.hook import BaseHook as BaseHook
from airflow.sdk.bases.notifier import BaseNotifier as BaseNotifier
from airflow.sdk.bases.operator import (
- BaseAsyncOperator as BaseAsyncOperator,
BaseOperator as BaseOperator,
chain as chain,
chain_linear as chain_linear,
@@ -84,7 +83,6 @@ __all__ = [
"AssetAny",
"AssetOrTimeSchedule",
"AssetWatcher",
- "BaseAsyncOperator",
"BaseHook",
"BaseNotifier",
"BaseOperator",
diff --git a/task-sdk/src/airflow/sdk/bases/decorator.py
b/task-sdk/src/airflow/sdk/bases/decorator.py
index 96768bc4505..bde898c1696 100644
--- a/task-sdk/src/airflow/sdk/bases/decorator.py
+++ b/task-sdk/src/airflow/sdk/bases/decorator.py
@@ -22,8 +22,7 @@ import re
import textwrap
import warnings
from collections.abc import Callable, Collection, Iterator, Mapping, Sequence
-from contextlib import suppress
-from functools import cached_property, partial, update_wrapper
+from functools import cached_property, update_wrapper
from typing import TYPE_CHECKING, Any, ClassVar, Generic, ParamSpec, Protocol,
TypeVar, cast, overload
import attr
@@ -150,52 +149,6 @@ def get_unique_task_id(
return f"{core}__{max(_find_id_suffixes(dag)) + 1}"
-def unwrap_partial(fn):
- while isinstance(fn, partial):
- fn = fn.func
- return fn
-
-
-def unwrap_callable(func):
- from airflow.sdk.bases.decorator import _TaskDecorator
- from airflow.sdk.definitions.mappedoperator import OperatorPartial
-
- # Airflow-specific unwrap
- if isinstance(func, (_TaskDecorator, OperatorPartial)):
- func = getattr(func, "function", getattr(func, "_func", func))
-
- # Unwrap functools.partial
- func = unwrap_partial(func)
-
- # Unwrap @functools.wraps chains
- with suppress(Exception):
- func = inspect.unwrap(func)
-
- return func
-
-
-def is_async_callable(func):
- """Detect if a callable (possibly wrapped) is an async function."""
- func = unwrap_callable(func)
-
- if not callable(func):
- return False
-
- # Direct async function
- if inspect.iscoroutinefunction(func):
- return True
-
- # Callable object with async __call__
- if not inspect.isfunction(func):
- call = type(func).__call__ # Bandit-safe
- with suppress(Exception):
- call = inspect.unwrap(call)
- if inspect.iscoroutinefunction(call):
- return True
-
- return False
-
-
class DecoratedOperator(BaseOperator):
"""
Wraps a Python callable and captures args/kwargs when called for execution.
@@ -290,10 +243,6 @@ class DecoratedOperator(BaseOperator):
self.op_kwargs = op_kwargs
super().__init__(task_id=task_id, **kwargs_to_upstream, **kwargs)
- @property
- def is_async(self) -> bool:
- return is_async_callable(self.python_callable)
-
def execute(self, context: Context):
# todo make this more generic (move to prepare_lineage) so it deals
with non taskflow operators
# as well
diff --git a/task-sdk/src/airflow/sdk/bases/operator.py
b/task-sdk/src/airflow/sdk/bases/operator.py
index 5f07a188362..0c97df00ef0 100644
--- a/task-sdk/src/airflow/sdk/bases/operator.py
+++ b/task-sdk/src/airflow/sdk/bases/operator.py
@@ -18,15 +18,13 @@
from __future__ import annotations
import abc
-import asyncio
import collections.abc
import contextlib
import copy
import inspect
import sys
import warnings
-from asyncio import AbstractEventLoop
-from collections.abc import Callable, Collection, Generator, Iterable,
Mapping, Sequence
+from collections.abc import Callable, Collection, Iterable, Mapping, Sequence
from contextvars import ContextVar
from dataclasses import dataclass, field
from datetime import datetime, timedelta
@@ -103,7 +101,6 @@ if TYPE_CHECKING:
TaskPostExecuteHook = Callable[[Context, Any], None]
__all__ = [
- "BaseAsyncOperator",
"BaseOperator",
"chain",
"chain_linear",
@@ -199,27 +196,6 @@ def coerce_resources(resources: dict[str, Any] | None) ->
Resources | None:
return Resources(**resources)
[email protected]
-def event_loop() -> Generator[AbstractEventLoop]:
- new_event_loop = False
- loop = None
- try:
- try:
- loop = asyncio.get_event_loop()
- if loop.is_closed():
- raise RuntimeError
- except RuntimeError:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- new_event_loop = True
- yield loop
- finally:
- if new_event_loop and loop is not None:
- with contextlib.suppress(AttributeError):
- loop.close()
- asyncio.set_event_loop(None)
-
-
class _PartialDescriptor:
"""A descriptor that guards against ``.partial`` being called on Task
objects."""
@@ -1694,35 +1670,6 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
return bool(self.on_skipped_callback)
-class BaseAsyncOperator(BaseOperator):
- """
- Base class for async-capable operators.
-
- As opposed to deferred operators which are executed on the triggerer,
async operators are executed
- on the worker.
- """
-
- @property
- def is_async(self) -> bool:
- return True
-
- async def aexecute(self, context):
- """Async version of execute(). Subclasses should implement this."""
- raise NotImplementedError()
-
- def execute(self, context):
- """Run `aexecute()` inside an event loop."""
- with event_loop() as loop:
- if self.execution_timeout:
- return loop.run_until_complete(
- asyncio.wait_for(
- self.aexecute(context),
- timeout=self.execution_timeout.total_seconds(),
- )
- )
- return loop.run_until_complete(self.aexecute(context))
-
-
def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
r"""
Given a number of tasks, builds a dependency chain.
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py
b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py
index e7e5ebe8b9a..6c99a72b220 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py
@@ -137,10 +137,6 @@ class AbstractOperator(Templater, DAGNode):
)
)
- @property
- def is_async(self) -> bool:
- return False
-
@property
def task_type(self) -> str:
raise NotImplementedError()
diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_runner.py
b/task-sdk/src/airflow/sdk/execution_time/callback_runner.py
index 322e4bc9780..316c3d38e99 100644
--- a/task-sdk/src/airflow/sdk/execution_time/callback_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/callback_runner.py
@@ -20,8 +20,8 @@ from __future__ import annotations
import inspect
import logging
-from collections.abc import AsyncIterator, Awaitable, Callable
-from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, cast
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
from typing_extensions import ParamSpec
@@ -39,11 +39,6 @@ class _ExecutionCallableRunner(Generic[P, R]):
def run(*args: P.args, **kwargs: P.kwargs) -> R: ... # type:
ignore[empty-body]
-class _AsyncExecutionCallableRunner(Generic[P, R]):
- @staticmethod
- async def run(*args: P.args, **kwargs: P.kwargs) -> R: ... # type:
ignore[empty-body]
-
-
class ExecutionCallableRunner(Protocol):
def __call__(
self,
@@ -54,16 +49,6 @@ class ExecutionCallableRunner(Protocol):
) -> _ExecutionCallableRunner[P, R]: ...
-class AsyncExecutionCallableRunner(Protocol):
- def __call__(
- self,
- func: Callable[P, R],
- outlet_events: OutletEventAccessorsProtocol,
- *,
- logger: logging.Logger | Logger,
- ) -> _AsyncExecutionCallableRunner[P, R]: ...
-
-
def create_executable_runner(
func: Callable[P, R],
outlet_events: OutletEventAccessorsProtocol,
@@ -124,51 +109,3 @@ def create_executable_runner(
return result # noqa: F821 # Ruff is not smart enough to know
this is always set in _run().
return cast("_ExecutionCallableRunner[P, R]", _ExecutionCallableRunnerImpl)
-
-
-def create_async_executable_runner(
- func: Callable[P, Awaitable[R] | AsyncIterator],
- outlet_events: OutletEventAccessorsProtocol,
- *,
- logger: logging.Logger | Logger,
-) -> _AsyncExecutionCallableRunner[P, R]:
- """
- Run an async execution callable against a task context and given arguments.
-
- If the callable is a simple function, this simply calls it with the
supplied
- arguments (including the context). If the callable is a generator function,
- the generator is exhausted here, with the yielded values getting fed back
- into the task context automatically for execution.
-
- This convoluted implementation of inner class with closure is so *all*
- arguments passed to ``run()`` can be forwarded to the wrapped function.
This
- is particularly important for the argument "self", which some use cases
- need to receive. This is not possible if this is implemented as a normal
- class, where "self" needs to point to the runner object, not the object
- bounded to the inner callable.
-
- :meta private:
- """
-
- class _AsyncExecutionCallableRunnerImpl(_AsyncExecutionCallableRunner):
- @staticmethod
- async def run(*args: P.args, **kwargs: P.kwargs) -> R:
- from airflow.sdk.definitions.asset.metadata import Metadata
-
- if not inspect.isasyncgenfunction(func):
- result = cast("Awaitable[R]", func(*args, **kwargs))
- return await result
-
- results: list[Any] = []
-
- async for result in func(*args, **kwargs):
- if isinstance(result, Metadata):
- outlet_events[result.asset].extra.update(result.extra)
- if result.alias:
- outlet_events[result.alias].add(result.asset,
extra=result.extra)
-
- results.append(result)
-
- return cast("R", results)
-
- return cast("_AsyncExecutionCallableRunner[P, R]",
_AsyncExecutionCallableRunnerImpl)
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index 15755e640d9..52a96d0b665 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -48,9 +48,7 @@ Execution API server is because:
from __future__ import annotations
-import asyncio
import itertools
-import threading
from collections.abc import Iterator
from datetime import datetime
from functools import cached_property
@@ -187,69 +185,31 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
err_decoder: TypeAdapter[ErrorResponse] = attrs.field(factory=lambda:
TypeAdapter(ToTask), repr=False)
- # Threading lock for sync operations
- _thread_lock: threading.Lock = attrs.field(factory=threading.Lock,
repr=False)
- # Async lock for async operations
- _async_lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False)
-
def send(self, msg: SendMsgType) -> ReceiveMsgType | None:
"""Send a request to the parent and block until the response is
received."""
frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
frame_bytes = frame.as_bytes()
- # We must make sure sockets aren't intermixed between sync and async
calls,
- # thus we need a dual locking mechanism to ensure that.
- with self._thread_lock:
- self.socket.sendall(frame_bytes)
- if isinstance(msg, ResendLoggingFD):
- if recv_fds is None:
- return None
- # We need special handling here! The server can't send us the
fd number, as the number on the
- # supervisor will be different to in this process, so we have
to mutate the message ourselves here.
- frame, fds = self._read_frame(maxfds=1)
- resp = self._from_frame(frame)
- if TYPE_CHECKING:
- assert isinstance(resp, SentFDs)
- resp.fds = fds
- # Since we know this is an expliclt SendFDs, and since this
class is generic SendFDs might not
- # always be in the return type union
- return resp # type: ignore[return-value]
+ self.socket.sendall(frame_bytes)
+ if isinstance(msg, ResendLoggingFD):
+ if recv_fds is None:
+ return None
+ # We need special handling here! The server can't send us the fd
number, as the number on the
+ # supervisor will be different to in this process, so we have to
mutate the message ourselves here.
+ frame, fds = self._read_frame(maxfds=1)
+ resp = self._from_frame(frame)
+ if TYPE_CHECKING:
+ assert isinstance(resp, SentFDs)
+ resp.fds = fds
+ # Since we know this is an expliclt SendFDs, and since this class
is generic SendFDs might not
+ # always be in the return type union
+ return resp # type: ignore[return-value]
return self._get_response()
async def asend(self, msg: SendMsgType) -> ReceiveMsgType | None:
- """
- Send a request to the parent without blocking.
-
- Uses async lock for coroutine safety and thread lock for socket safety.
- """
- frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
- frame_bytes = frame.as_bytes()
-
- async with self._async_lock:
- # Acquire the threading lock without blocking the event loop
- loop = asyncio.get_running_loop()
- await loop.run_in_executor(None, self._thread_lock.acquire)
- try:
- # Async write to socket
- await loop.sock_sendall(self.socket, frame_bytes)
-
- if isinstance(msg, ResendLoggingFD):
- if recv_fds is None:
- return None
- # Blocking read in a thread
- frame, fds = await asyncio.to_thread(self._read_frame,
maxfds=1)
- resp = self._from_frame(frame)
- if TYPE_CHECKING:
- assert isinstance(resp, SentFDs)
- resp.fds = fds
- return resp # type: ignore[return-value]
-
- # Normal blocking read in a thread
- frame = await asyncio.to_thread(self._read_frame)
- return self._from_frame(frame)
- finally:
- self._thread_lock.release()
+ """Send a request to the parent without blocking."""
+ raise NotImplementedError
@overload
def _read_frame(self, maxfds: None = None) -> _ResponseFrame: ...