This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4b7b998e221 Clean up CeleryExecutor to use workload terminology and
typing (#63888)
4b7b998e221 is described below
commit 4b7b998e221b736c24fea045e04f62b703593b0b
Author: SameerMesiah97 <[email protected]>
AuthorDate: Fri Apr 3 21:03:45 2026 +0100
Clean up CeleryExecutor to use workload terminology and typing (#63888)
* Clean up CeleryExecutor docstrings, comments, variable names, and typing
to align with the workload-based executor model.
* Resolve typing and logging inconsistencies. Update provider.yaml and
provider.info files
---------
Co-authored-by: Sameer Mesiah <[email protected]>
---
providers/celery/provider.yaml | 2 +-
.../providers/celery/executors/celery_executor.py | 131 +++++++++++----------
.../celery/executors/celery_executor_utils.py | 58 +++++----
.../celery/executors/celery_kubernetes_executor.py | 2 +-
.../airflow/providers/celery/get_provider_info.py | 2 +-
.../integration/celery/test_celery_executor.py | 83 ++++++++-----
.../unit/celery/executors/test_celery_executor.py | 130 ++++++++++----------
7 files changed, 221 insertions(+), 187 deletions(-)
diff --git a/providers/celery/provider.yaml b/providers/celery/provider.yaml
index bdbf5ab88ae..0da7df9f09a 100644
--- a/providers/celery/provider.yaml
+++ b/providers/celery/provider.yaml
@@ -298,7 +298,7 @@ config:
default: "prefork"
operation_timeout:
description: |
- The number of seconds to wait before timing out
``send_task_to_executor`` or
+ The number of seconds to wait before timing out
``send_workload_to_executor`` or
``fetch_celery_task_state`` operations.
version_added: ~
type: float
diff --git
a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
index b31eec3b061..db4df5ab7bc 100644
--- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
+++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
@@ -32,7 +32,7 @@ import time
from collections import Counter
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, TypeAlias, cast
from celery import states as celery_states
from deprecated import deprecated
@@ -40,7 +40,7 @@ from deprecated import deprecated
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.executors.base_executor import BaseExecutor
from airflow.providers.celery.executors import (
- celery_executor_utils as _celery_executor_utils, # noqa: F401 # Needed to
register Celery tasks at worker startup, see #63043
+ celery_executor_utils as _celery_executor_utils, # noqa: F401 # Needed to
register Celery tasks at worker startup, see #63043.
)
from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_2_PLUS
from airflow.providers.common.compat.sdk import AirflowTaskTimeout, Stats
@@ -49,18 +49,27 @@ from airflow.utils.state import TaskInstanceState
log = logging.getLogger(__name__)
-CELERY_SEND_ERR_MSG_HEADER = "Error sending Celery task"
+CELERY_SEND_ERR_MSG_HEADER = "Error sending Celery workload"
if TYPE_CHECKING:
from collections.abc import Sequence
+ from celery.result import AsyncResult
+
from airflow.cli.cli_config import GroupCommand
from airflow.executors import workloads
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.providers.celery.executors.celery_executor_utils import
TaskTuple, WorkloadInCelery
+ if AIRFLOW_V_3_2_PLUS:
+ from airflow.executors.workloads.types import WorkloadKey as
_WorkloadKey
+
+ WorkloadKey: TypeAlias = _WorkloadKey
+ else:
+ WorkloadKey: TypeAlias = TaskInstanceKey # type: ignore[no-redef,
misc]
+
# PEP562
def __getattr__(name):
@@ -84,7 +93,7 @@ class CeleryExecutor(BaseExecutor):
"""
CeleryExecutor is recommended for production use of Airflow.
- It allows distributing the execution of task instances to multiple worker
nodes.
+ It allows distributing the execution of workloads (task instances and
callbacks) to multiple worker nodes.
Celery is a simple, flexible and reliable distributed system to process
vast amounts of messages, while providing operations with the tools
@@ -102,7 +111,7 @@ class CeleryExecutor(BaseExecutor):
if TYPE_CHECKING:
if AIRFLOW_V_3_0_PLUS:
# TODO: TaskSDK: move this type change into BaseExecutor
- queued_tasks: dict[TaskInstanceKey, workloads.All] # type:
ignore[assignment]
+ queued_tasks: dict[WorkloadKey, workloads.All] # type:
ignore[assignment]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -127,7 +136,7 @@ class CeleryExecutor(BaseExecutor):
self.celery_app = create_celery_app(self.conf)
- # Celery doesn't support bulk sending the tasks (which can become a
bottleneck on bigger clusters)
+ # Celery doesn't support bulk sending the workloads (which can become
a bottleneck on bigger clusters)
# so we use a multiprocessing pool to speed this up.
# How many worker processes are created for checking celery task state.
self._sync_parallelism = self.conf.getint("celery",
"SYNC_PARALLELISM", fallback=0)
@@ -136,144 +145,146 @@ class CeleryExecutor(BaseExecutor):
from airflow.providers.celery.executors.celery_executor_utils import
BulkStateFetcher
self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism,
celery_app=self.celery_app)
- self.tasks = {}
- self.task_publish_retries: Counter[TaskInstanceKey] = Counter()
- self.task_publish_max_retries = self.conf.getint("celery",
"task_publish_max_retries", fallback=3)
+ self.workloads: dict[WorkloadKey, AsyncResult] = {}
+ self.workload_publish_retries: Counter[WorkloadKey] = Counter()
+ self.workload_publish_max_retries = self.conf.getint("celery",
"task_publish_max_retries", fallback=3)
def start(self) -> None:
self.log.debug("Starting Celery Executor using %s processes for
syncing", self._sync_parallelism)
- def _num_tasks_per_send_process(self, to_send_count: int) -> int:
+ def _num_workloads_per_send_process(self, to_send_count: int) -> int:
"""
- How many Celery tasks should each worker process send.
+ How many Celery workloads should each worker process send.
- :return: Number of tasks that should be sent per process
+ :return: Number of workloads that should be sent per process
"""
return max(1, math.ceil(to_send_count / self._sync_parallelism))
def _process_tasks(self, task_tuples: Sequence[TaskTuple]) -> None:
- # Airflow V2 version
+ # Airflow V2 compatibility path — converts task tuples into
workload-compatible tuples.
task_tuples_to_send = [task_tuple[:3] + (self.team_name,) for
task_tuple in task_tuples]
- self._send_tasks(task_tuples_to_send)
+ self._send_workloads(task_tuples_to_send)
def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
- # Airflow V3 version -- have to delay imports until we know we are on
v3
+ # Airflow V3 version -- have to delay imports until we know we are on
v3.
from airflow.executors.workloads import ExecuteTask
if AIRFLOW_V_3_2_PLUS:
from airflow.executors.workloads import ExecuteCallback
- tasks: list[WorkloadInCelery] = []
+ workloads_to_be_sent: list[WorkloadInCelery] = []
for workload in workloads:
if isinstance(workload, ExecuteTask):
- tasks.append((workload.ti.key, workload, workload.ti.queue,
self.team_name))
+ workloads_to_be_sent.append((workload.ti.key, workload,
workload.ti.queue, self.team_name))
elif AIRFLOW_V_3_2_PLUS and isinstance(workload, ExecuteCallback):
- # Use default queue for callbacks, or extract from callback
data if available
+ # Use default queue for callbacks, or extract from callback
data if available.
queue = "default"
if isinstance(workload.callback.data, dict) and "queue" in
workload.callback.data:
queue = workload.callback.data["queue"]
- tasks.append((workload.callback.key, workload, queue,
self.team_name))
+ workloads_to_be_sent.append((workload.callback.key, workload,
queue, self.team_name))
else:
raise ValueError(f"{type(self)}._process_workloads cannot
handle {type(workload)}")
- self._send_tasks(tasks)
+ self._send_workloads(workloads_to_be_sent)
- def _send_tasks(self, task_tuples_to_send: Sequence[WorkloadInCelery]):
+ def _send_workloads(self, workload_tuples_to_send:
Sequence[WorkloadInCelery]):
# Celery state queries will be stuck if we do not use one same backend
- # for all tasks.
+ # for all workloads.
cached_celery_backend = self.celery_app.backend
- key_and_async_results = self._send_tasks_to_celery(task_tuples_to_send)
- self.log.debug("Sent all tasks.")
+ key_and_async_results =
self._send_workloads_to_celery(workload_tuples_to_send)
+ self.log.debug("Sent all workloads.")
from airflow.providers.celery.executors.celery_executor_utils import
ExceptionWithTraceback
for key, _, result in key_and_async_results:
if isinstance(result, ExceptionWithTraceback) and isinstance(
result.exception, AirflowTaskTimeout
):
- retries = self.task_publish_retries[key]
- if retries < self.task_publish_max_retries:
+ retries = self.workload_publish_retries[key]
+ if retries < self.workload_publish_max_retries:
Stats.incr("celery.task_timeout_error")
self.log.info(
- "[Try %s of %s] Task Timeout Error for Task: (%s).",
- self.task_publish_retries[key] + 1,
- self.task_publish_max_retries,
+ "[Try %s of %s] Celery Task Timeout Error for
Workload: (%s).",
+ self.workload_publish_retries[key] + 1,
+ self.workload_publish_max_retries,
tuple(key),
)
- self.task_publish_retries[key] = retries + 1
+ self.workload_publish_retries[key] = retries + 1
continue
if key in self.queued_tasks:
self.queued_tasks.pop(key)
else:
self.queued_callbacks.pop(key, None)
- self.task_publish_retries.pop(key, None)
+ self.workload_publish_retries.pop(key, None)
if isinstance(result, ExceptionWithTraceback):
self.log.error("%s: %s\n%s\n", CELERY_SEND_ERR_MSG_HEADER,
result.exception, result.traceback)
self.event_buffer[key] = (TaskInstanceState.FAILED, None)
elif result is not None:
result.backend = cached_celery_backend
self.running.add(key)
- self.tasks[key] = result
+ self.workloads[key] = result
- # Store the Celery task_id in the event buffer. This will get
"overwritten" if the task
+ # Store the Celery task_id (workload execution ID) in the
event buffer. This will get "overwritten" if the task
# has another event, but that is fine, because the only other
events are success/failed at
- # which point we don't need the ID anymore anyway
+ # which point we don't need the ID anymore anyway.
self.event_buffer[key] = (TaskInstanceState.QUEUED,
result.task_id)
- def _send_tasks_to_celery(self, task_tuples_to_send:
Sequence[WorkloadInCelery]):
- from airflow.providers.celery.executors.celery_executor_utils import
send_task_to_executor
+ def _send_workloads_to_celery(self, workload_tuples_to_send:
Sequence[WorkloadInCelery]):
+ from airflow.providers.celery.executors.celery_executor_utils import
send_workload_to_executor
- if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1:
+ if len(workload_tuples_to_send) == 1 or self._sync_parallelism == 1:
# One tuple, or max one process -> send it in the main thread.
- return list(map(send_task_to_executor, task_tuples_to_send))
+ return list(map(send_workload_to_executor,
workload_tuples_to_send))
# Use chunks instead of a work queue to reduce context switching
- # since tasks are roughly uniform in size
- chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
- num_processes = min(len(task_tuples_to_send), self._sync_parallelism)
+ # since workloads are roughly uniform in size.
+ chunksize =
self._num_workloads_per_send_process(len(workload_tuples_to_send))
+ num_processes = min(len(workload_tuples_to_send),
self._sync_parallelism)
- # Use ProcessPoolExecutor with team_name instead of task objects to
avoid pickling issues.
+ # Use ProcessPoolExecutor with team_name instead of workload objects
to avoid pickling issues.
# Subprocesses reconstruct the team-specific Celery app from the team
name and existing config.
with ProcessPoolExecutor(max_workers=num_processes) as send_pool:
key_and_async_results = list(
- send_pool.map(send_task_to_executor, task_tuples_to_send,
chunksize=chunksize)
+ send_pool.map(send_workload_to_executor,
workload_tuples_to_send, chunksize=chunksize)
)
return key_and_async_results
def sync(self) -> None:
- if not self.tasks:
- self.log.debug("No task to query celery, skipping sync")
+ if not self.workloads:
+ self.log.debug("No workload to query celery, skipping sync")
return
- self.update_all_task_states()
+ self.update_all_workload_states()
def debug_dump(self) -> None:
"""Debug dump; called in response to SIGUSR2 by the scheduler."""
super().debug_dump()
self.log.info(
- "executor.tasks (%d)\n\t%s", len(self.tasks),
"\n\t".join(map(repr, self.tasks.items()))
+ "executor.workloads (%d)\n\t%s",
+ len(self.workloads),
+ "\n\t".join(map(repr, self.workloads.items())),
)
- def update_all_task_states(self) -> None:
- """Update states of the tasks."""
- self.log.debug("Inquiring about %s celery task(s)", len(self.tasks))
- state_and_info_by_celery_task_id =
self.bulk_state_fetcher.get_many(self.tasks.values())
+ def update_all_workload_states(self) -> None:
+ """Update states of the workloads."""
+ self.log.debug("Inquiring about %s celery workload(s)",
len(self.workloads))
+ state_and_info_by_celery_task_id =
self.bulk_state_fetcher.get_many(self.workloads.values())
self.log.debug("Inquiries completed.")
- for key, async_result in list(self.tasks.items()):
+ for key, async_result in list(self.workloads.items()):
state, info =
state_and_info_by_celery_task_id.get(async_result.task_id)
if state:
- self.update_task_state(key, state, info)
+ self.update_task_state(cast("TaskInstanceKey", key), state,
info)
def change_state(
self, key: TaskInstanceKey, state: TaskInstanceState, info=None,
remove_running=True
) -> None:
super().change_state(key, state, info, remove_running=remove_running)
- self.tasks.pop(key, None)
+ self.workloads.pop(key, None)
def update_task_state(self, key: TaskInstanceKey, state: str, info: Any)
-> None:
- """Update state of a single task."""
+ """Update state of a single workload."""
try:
if state == celery_states.SUCCESS:
self.success(key, info)
@@ -288,7 +299,9 @@ class CeleryExecutor(BaseExecutor):
def end(self, synchronous: bool = False) -> None:
if synchronous:
- while any(task.state not in celery_states.READY_STATES for task in
self.tasks.values()):
+ while any(
+ workload.state not in celery_states.READY_STATES for workload
in self.workloads.values()
+ ):
time.sleep(5)
self.sync()
@@ -322,7 +335,7 @@ class CeleryExecutor(BaseExecutor):
not_adopted_tis.append(ti)
if not celery_tasks:
- # Nothing to adopt
+ # Nothing to adopt.
return tis
states_by_celery_task_id = self.bulk_state_fetcher.get_many(
@@ -342,7 +355,7 @@ class CeleryExecutor(BaseExecutor):
# Set the correct elements of the state dicts, then update this
# like we just queried it.
- self.tasks[ti.key] = result
+ self.workloads[ti.key] = result
self.running.add(ti.key)
self.update_task_state(ti.key, state, info)
adopted.append(f"{ti} in state {state}")
@@ -373,7 +386,7 @@ class CeleryExecutor(BaseExecutor):
return reprs
def revoke_task(self, *, ti: TaskInstance):
- celery_async_result = self.tasks.pop(ti.key, None)
+ celery_async_result = self.workloads.pop(ti.key, None)
if celery_async_result:
try:
self.celery_app.control.revoke(celery_async_result.task_id)
diff --git
a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py
b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py
index 699052b470e..c0a9e471d2f 100644
---
a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py
+++
b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py
@@ -80,7 +80,7 @@ if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey
# We can't use `if AIRFLOW_V_3_0_PLUS` conditions in type checks, so
unfortunately we just have to define
- # the type as the union of both kinds
+ # the type as the union of both kinds.
CommandType = Sequence[str]
WorkloadInCelery: TypeAlias = tuple[WorkloadKey, workloads.All |
CommandType, str | None, str | None]
@@ -88,7 +88,7 @@ if TYPE_CHECKING:
WorkloadKey, CommandType, AsyncResult | "ExceptionWithTraceback"
]
- # Deprecated alias for backward compatibility
+ # Deprecated alias for backward compatibility.
TaskInstanceInCelery: TypeAlias = WorkloadInCelery
TaskTuple = tuple[TaskInstanceKey, CommandType, str | None, Any | None]
@@ -132,10 +132,10 @@ def create_celery_app(team_conf: ExecutorConf |
AirflowConfigParser) -> Celery:
celery_app_name = team_conf.get("celery", "CELERY_APP_NAME")
- # Make app name unique per team to ensure proper broker isolation
+ # Make app name unique per team to ensure proper broker isolation.
# Each team's executor needs a distinct Celery app name to prevent
- # tasks from being routed to the wrong broker
- # Only do this if team_conf is an ExecutorConf with team_name (not global
conf)
+ # tasks from being routed to the wrong broker.
+ # Only do this if team_conf is an ExecutorConf with team_name (not global
conf).
team_name = getattr(team_conf, "team_name", None)
if team_name:
celery_app_name = f"{celery_app_name}_{team_name}"
@@ -153,7 +153,7 @@ def create_celery_app(team_conf: ExecutorConf |
AirflowConfigParser) -> Celery:
celery_app = Celery(celery_app_name, config_source=config)
- # Register tasks with this app
+ # Register tasks with this app.
celery_app.task(name="execute_workload")(execute_workload)
if not AIRFLOW_V_3_0_PLUS:
celery_app.task(name="execute_command")(execute_command)
@@ -161,7 +161,7 @@ def create_celery_app(team_conf: ExecutorConf |
AirflowConfigParser) -> Celery:
return celery_app
-# Keep module-level app for backward compatibility
+# Keep module-level app for backward compatibility.
app = _get_celery_app()
@@ -203,7 +203,7 @@ def on_celery_worker_ready(*args, **kwargs):
# Once Celery 5.5 is out of beta, we can pass `pydantic=True` to the decorator
and it will handle the validation
-# and deserialization for us
+# and deserialization for us.
@app.task(name="execute_workload")
def execute_workload(input: str) -> None:
from celery.exceptions import Ignore
@@ -221,7 +221,7 @@ def execute_workload(input: str) -> None:
log.info("[%s] Executing workload in Celery: %s", celery_task_id, workload)
base_url = conf.get("api", "base_url", fallback="/")
- # If it's a relative URL, use localhost:8080 as the default
+ # If it's a relative URL, use localhost:8080 as the default.
if base_url.startswith("/"):
base_url = f"http://localhost:8080{base_url}"
default_execution_api_server = f"{base_url.rstrip('/')}/execution/"
@@ -285,7 +285,7 @@ if not AIRFLOW_V_3_0_PLUS:
def _execute_in_fork(command_to_exec: CommandType, celery_task_id: str | None
= None) -> None:
pid = os.fork()
if pid:
- # In parent, wait for the child
+ # In parent, wait for the child.
pid, ret = os.waitpid(pid, 0)
if ret == 0:
return
@@ -300,7 +300,7 @@ def _execute_in_fork(command_to_exec: CommandType,
celery_task_id: str | None =
from airflow.cli.cli_parser import get_parser
parser = get_parser()
- # [1:] - remove "airflow" from the start of the command
+ # [1:] - remove "airflow" from the start of the command.
args = parser.parse_args(command_to_exec[1:])
args.shut_down_logging = False
if celery_task_id:
@@ -360,7 +360,7 @@ def send_workload_to_executor(
workload_tuple: WorkloadInCelery,
) -> WorkloadInCeleryResult:
"""
- Send workload to executor.
+ Send workload to executor (serialized and executed as a Celery task).
This function is called in ProcessPoolExecutor subprocesses. To avoid
pickling issues with
team-specific Celery apps, we pass the team_name and reconstruct the
Celery app here.
@@ -371,26 +371,26 @@ def send_workload_to_executor(
# ExecutorConf wraps config access to automatically use team-specific
config where present.
if TYPE_CHECKING:
_conf: ExecutorConf | AirflowConfigParser
- # Check if Airflow version is greater than or equal to 3.2 to import
ExecutorConf
+ # Check if Airflow version is greater than or equal to 3.2 to import
ExecutorConf.
if AIRFLOW_V_3_2_PLUS:
from airflow.executors.base_executor import ExecutorConf
_conf = ExecutorConf(team_name)
else:
- # Airflow <3.2 ExecutorConf doesn't exist (at least not with the
required attributes), fall back to global conf
+ # Airflow <3.2 ExecutorConf doesn't exist (at least not with the
required attributes), fall back to global conf.
_conf = conf
- # Create the Celery app with the correct configuration
+ # Create the Celery app with the correct configuration.
celery_app = create_celery_app(_conf)
if AIRFLOW_V_3_0_PLUS:
- # Get the task from the app
- task_to_run = celery_app.tasks["execute_workload"]
+ # Get the task from the app.
+ celery_task = celery_app.tasks["execute_workload"]
if TYPE_CHECKING:
assert isinstance(args, workloads.BaseWorkload)
args = (args.model_dump_json(),)
else:
- # Get the task from the app
- task_to_run = celery_app.tasks["execute_command"]
+ # Get the task from the app.
+ celery_task = celery_app.tasks["execute_command"]
args = [args] # type: ignore[list-item]
# Pre-import redis.client to avoid SIGALRM interrupting module
initialization.
@@ -400,27 +400,23 @@ def send_workload_to_executor(
try:
import redis.client # noqa: F401
except ImportError:
- pass # Redis not installed or not using Redis backend
+ pass # Redis not installed or not using Redis backend.
try:
with timeout(seconds=OPERATION_TIMEOUT):
- result = task_to_run.apply_async(args=args, queue=queue)
+ result = celery_task.apply_async(args=args, queue=queue)
except (Exception, AirflowTaskTimeout) as e:
exception_traceback = f"Celery Task ID:
{key}\n{traceback.format_exc()}"
result = ExceptionWithTraceback(e, exception_traceback)
# The type is right for the version, but the type cannot be defined
correctly for Airflow 2 and 3
- # concurrently;
+ # concurrently.
return key, args, result
-# Backward compatibility alias
-send_task_to_executor = send_workload_to_executor
-
-
def fetch_celery_task_state(async_result: AsyncResult) -> tuple[str, str |
ExceptionWithTraceback, Any]:
"""
- Fetch and return the state of the given celery task.
+ Fetch and return the state of the given celery task (workload execution).
The scope of this function is global so that it can be called by
subprocesses in the pool.
@@ -434,12 +430,12 @@ def fetch_celery_task_state(async_result: AsyncResult) ->
tuple[str, str | Excep
try:
import redis.client # noqa: F401
except ImportError:
- pass # Redis not installed or not using Redis backend
+ pass # Redis not installed or not using Redis backend.
try:
with timeout(seconds=OPERATION_TIMEOUT):
- # Accessing state property of celery task will make actual network
request
- # to get the current state of the task
+ # Accessing state property of celery task (workload execution)
triggers a network request
+ # to get the current state of the task.
info = async_result.info if hasattr(async_result, "info") else None
return async_result.task_id, async_result.state, info
except Exception as e:
@@ -459,7 +455,7 @@ class BulkStateFetcher(LoggingMixin):
def __init__(self, sync_parallelism: int, celery_app: Celery | None =
None):
super().__init__()
self._sync_parallelism = sync_parallelism
- self.celery_app = celery_app or app # Use provided app or fall back
to module-level app
+ self.celery_app = celery_app or app # Use provided app or fall back
to module-level app.
def _tasks_list_to_task_ids(self, async_tasks: Collection[AsyncResult]) ->
set[str]:
return {a.task_id for a in async_tasks}
diff --git
a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py
b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py
index 687f7b75e3d..f66c153a6d7 100644
---
a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py
+++
b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py
@@ -103,7 +103,7 @@ class CeleryKubernetesExecutor(BaseExecutor):
@property
def queued_tasks(self) -> dict[TaskInstanceKey, Any]:
"""Return queued tasks from celery and kubernetes executor."""
- return self.celery_executor.queued_tasks |
self.kubernetes_executor.queued_tasks
+ return self.celery_executor.queued_tasks |
self.kubernetes_executor.queued_tasks # type: ignore[return-value]
@queued_tasks.setter
def queued_tasks(self, value) -> None:
diff --git a/providers/celery/src/airflow/providers/celery/get_provider_info.py
b/providers/celery/src/airflow/providers/celery/get_provider_info.py
index 537344c7a4e..071071133bd 100644
--- a/providers/celery/src/airflow/providers/celery/get_provider_info.py
+++ b/providers/celery/src/airflow/providers/celery/get_provider_info.py
@@ -198,7 +198,7 @@ def get_provider_info():
"default": "prefork",
},
"operation_timeout": {
- "description": "The number of seconds to wait before
timing out ``send_task_to_executor`` or\n``fetch_celery_task_state``
operations.\n",
+ "description": "The number of seconds to wait before
timing out ``send_workload_to_executor`` or\n``fetch_celery_task_state``
operations.\n",
"version_added": None,
"type": "float",
"example": None,
diff --git a/providers/celery/tests/integration/celery/test_celery_executor.py
b/providers/celery/tests/integration/celery/test_celery_executor.py
index 3641abb6b1f..aeb769f4153 100644
--- a/providers/celery/tests/integration/celery/test_celery_executor.py
+++ b/providers/celery/tests/integration/celery/test_celery_executor.py
@@ -27,7 +27,7 @@ from datetime import datetime, timedelta
from time import sleep
from unittest import mock
-# leave this it is used by the test worker
+# Leave this; it is used by the test worker.
import celery.contrib.testing.tasks # noqa: F401
import pytest
import uuid6
@@ -84,8 +84,8 @@ def _prepare_app(broker_url=None, execute=None):
test_config = dict(celery_executor_utils.get_celery_configuration())
test_config.update({"broker_url": broker_url})
test_app = Celery(broker_url, config_source=test_config)
- # Register the fake execute function with the test_app using the correct
task name
- # This ensures workers using test_app will execute the fake function
+ # Register the fake execute function with the test_app using the correct
task name.
+ # This ensures workers using test_app will execute the fake function.
test_execute = test_app.task(name=execute_name)(execute)
patch_app = mock.patch.object(celery_executor_utils, "app", test_app)
@@ -95,7 +95,7 @@ def _prepare_app(broker_url=None, execute=None):
celery_executor_utils.execute_command.__wrapped__ = execute
patch_execute = mock.patch.object(celery_executor_utils, execute_name,
test_execute)
- # Patch factory function so CeleryExecutor instances get the test app
+ # Patch factory function so CeleryExecutor instances get the test app.
patch_factory = mock.patch.object(celery_executor_utils,
"create_celery_app", return_value=test_app)
backend = test_app.backend
@@ -105,7 +105,7 @@ def _prepare_app(broker_url=None, execute=None):
# race condition where one of the subprocesses can die with "Table
# already exists" error, because SQLA checks for which tables exist,
# then issues a CREATE TABLE, rather than doing CREATE TABLE IF NOT
- # EXISTS
+ # EXISTS.
session = backend.ResultSession()
session.close()
@@ -128,6 +128,17 @@ class TestCeleryExecutor:
db.clear_db_runs()
db.clear_db_jobs()
+
+def setup_dagrun_with_success_and_fail_workloads(dag_maker):
+ date = timezone.utcnow()
+ start_date = date - timedelta(days=2)
+
+ with dag_maker("test_celery_integration"):
+ BaseOperator(task_id="success", start_date=start_date)
+ BaseOperator(task_id="fail", start_date=start_date)
+
+ return dag_maker.create_dagrun(logical_date=date)
+
@pytest.mark.flaky(reruns=5, reruns_delay=3)
@pytest.mark.parametrize("broker_url", _prepare_test_bodies())
@pytest.mark.parametrize(
@@ -164,19 +175,19 @@ class TestCeleryExecutor:
from airflow.providers.celery.executors import celery_executor
if AIRFLOW_V_3_0_PLUS:
- # Airflow 3: execute_workload receives JSON string
+ # Airflow 3: execute_workload receives JSON string.
def fake_execute(input: str) -> None:
"""Fake execute_workload that parses JSON and fails for tasks
with 'fail' in task_id."""
import json
workload_dict = json.loads(input)
- # Check if this is a task that should fail (task_id contains
"fail")
+ # Check if this is a workload that should fail (task_id
contains "fail").
if "ti" in workload_dict and "task_id" in workload_dict["ti"]:
if "fail" in workload_dict["ti"]["task_id"]:
raise AirflowException("fail")
else:
- # Airflow 2: execute_command receives command list
- def fake_execute(input: str) -> None: # Use same parameter name
as Airflow 3 version
+ # Airflow 2: execute_command receives command list.
+ def fake_execute(input: str) -> None: # Use same parameter name
as Airflow 3 version.
if "fail" in input:
raise AirflowException("fail")
@@ -218,6 +229,18 @@ class TestCeleryExecutor:
bundle_info=BundleInfo(name="test"),
log_path="test.log",
)
+ keys = [
+ TaskInstanceKey("id", "success", "abc", 0, -1),
+ TaskInstanceKey("id", "fail", "abc", 0, -1),
+ ]
+ dagrun =
setup_dagrun_with_success_and_fail_workloads(dag_maker)
+ ti_success, ti_fail = dagrun.task_instances
+ for w in (
+ workloads.ExecuteTask.make(
+ ti=ti_success,
+ ),
+ workloads.ExecuteTask.make(ti=ti_fail),
+ ):
executor.queue_workload(w, session=None)
executor.trigger_tasks(open_slots=10)
@@ -244,7 +267,7 @@ class TestCeleryExecutor:
assert executor.queued_tasks == {}
- def test_error_sending_task(self):
+ def test_error_sending_workload(self):
from airflow.providers.celery.executors import celery_executor,
celery_executor_utils
with _prepare_app():
@@ -263,8 +286,8 @@ class TestCeleryExecutor:
executor.queued_tasks[key] = workload
executor.task_publish_retries[key] = 1
- # Mock send_task_to_executor to return an error result
- # This simulates a failure when sending the task to Celery
+ # Mock send_workload_to_executor to return an error result.
+ # This simulates a failure when sending the workload to Celery.
def mock_send_error(task_tuple):
key_from_tuple = task_tuple[0]
return (
@@ -277,14 +300,14 @@ class TestCeleryExecutor:
)
with mock.patch.object(
- celery_executor_utils, "send_task_to_executor",
side_effect=mock_send_error
+ celery_executor_utils, "send_workload_to_executor",
side_effect=mock_send_error
):
executor.heartbeat()
- assert len(executor.queued_tasks) == 0, "Task should no longer be
queued"
+ assert len(executor.queued_tasks) == 0, "Workload should no longer be
queued"
assert executor.event_buffer[key][0] == State.FAILED
- def test_retry_on_error_sending_task(self, caplog):
- """Test that Airflow retries publishing tasks to Celery Broker at
least 3 times"""
+ def test_retry_on_error_sending_workload(self, caplog):
+ """Test that Airflow retries publishing workloads to Celery Broker at
least 3 times"""
from airflow.providers.celery.executors import celery_executor,
celery_executor_utils
with (
@@ -298,8 +321,8 @@ class TestCeleryExecutor:
),
):
executor = celery_executor.CeleryExecutor()
- assert executor.task_publish_retries == {}
- assert executor.task_publish_max_retries == 3, "Assert Default Max
Retries is 3"
+ assert executor.workload_publish_retries == {}
+ assert executor.workload_publish_max_retries == 3, "Assert Default
Max Retries is 3"
with DAG(dag_id="id"):
task = BashOperator(task_id="test", bash_command="true",
start_date=datetime.now())
@@ -314,27 +337,27 @@ class TestCeleryExecutor:
key = (task.dag.dag_id, task.task_id, ti.run_id, 0, -1)
executor.queued_tasks[key] = workload
- # Test that when heartbeat is called again, task is published
again to Celery Queue
+ # Test that when heartbeat is called again, workload is published
again to Celery Queue.
executor.heartbeat()
- assert dict(executor.task_publish_retries) == {key: 1}
- assert len(executor.queued_tasks) == 1, "Task should remain in
queue"
+ assert dict(executor.workload_publish_retries) == {key: 1}
+ assert len(executor.queued_tasks) == 1, "Workload should remain in
queue"
assert executor.event_buffer == {}
- assert f"[Try 1 of 3] Task Timeout Error for Task: ({key})." in
caplog.text
+ assert f"[Try 1 of 3] Celery Task Timeout Error for Workload:
({key})." in caplog.text
executor.heartbeat()
- assert dict(executor.task_publish_retries) == {key: 2}
- assert len(executor.queued_tasks) == 1, "Task should remain in
queue"
+ assert dict(executor.workload_publish_retries) == {key: 2}
+ assert len(executor.queued_tasks) == 1, "Workload should remain in
queue"
assert executor.event_buffer == {}
- assert f"[Try 2 of 3] Task Timeout Error for Task: ({key})." in
caplog.text
+ assert f"[Try 2 of 3] Celery Task Timeout Error for Workload:
({key})." in caplog.text
executor.heartbeat()
- assert dict(executor.task_publish_retries) == {key: 3}
- assert len(executor.queued_tasks) == 1, "Task should remain in
queue"
+ assert dict(executor.workload_publish_retries) == {key: 3}
+ assert len(executor.queued_tasks) == 1, "Workload should remain in
queue"
assert executor.event_buffer == {}
- assert f"[Try 3 of 3] Task Timeout Error for Task: ({key})." in
caplog.text
+ assert f"[Try 3 of 3] Celery Task Timeout Error for Workload:
({key})." in caplog.text
executor.heartbeat()
- assert dict(executor.task_publish_retries) == {}
+ assert dict(executor.workload_publish_retries) == {}
assert len(executor.queued_tasks) == 0, "Task should no longer be
in queue"
assert executor.event_buffer[key][0] == State.FAILED
@@ -389,7 +412,7 @@ class TestBulkStateFetcher:
]
)
- # Assert called - ignore order
+ # Assert called - ignore order.
mget_args, _ = mock_mget.call_args
assert set(mget_args[0]) == {b"celery-task-meta-456",
b"celery-task-meta-123"}
mock_mget.assert_called_once_with(mock.ANY)
diff --git
a/providers/celery/tests/unit/celery/executors/test_celery_executor.py
b/providers/celery/tests/unit/celery/executors/test_celery_executor.py
index bc668bef508..d69df9ec5e7 100644
--- a/providers/celery/tests/unit/celery/executors/test_celery_executor.py
+++ b/providers/celery/tests/unit/celery/executors/test_celery_executor.py
@@ -25,7 +25,7 @@ import sys
from datetime import timedelta
from unittest import mock
-# leave this it is used by the test worker
+# Leave this it is used by the test worker.
import celery.contrib.testing.tasks # noqa: F401
import pytest
import time_machine
@@ -100,7 +100,7 @@ def _prepare_app(broker_url=None, execute=None):
test_execute = test_app.task(execute)
patch_app = mock.patch.object(celery_executor_utils, "app", test_app)
patch_execute = mock.patch.object(celery_executor_utils, execute_name,
test_execute)
- # Patch factory function so CeleryExecutor instances get the test app
+ # Patch factory function so CeleryExecutor instances get the test app.
patch_factory = mock.patch.object(celery_executor_utils,
"create_celery_app", return_value=test_app)
backend = test_app.backend
@@ -110,7 +110,7 @@ def _prepare_app(broker_url=None, execute=None):
# race condition where it one of the subprocesses can die with "Table
# already exists" error, because SQLA checks for which tables exist,
# then issues a CREATE TABLE, rather than doing CREATE TABLE IF NOT
- # EXISTS
+ # EXISTS.
session = backend.ResultSession()
session.close()
@@ -118,7 +118,7 @@ def _prepare_app(broker_url=None, execute=None):
try:
yield test_app
finally:
- # Clear event loop to tear down each celery instance
+ # Clear event loop to tear down each celery instance.
set_event_loop(None)
@@ -148,7 +148,7 @@ class TestCeleryExecutor:
team_name = "test_team"
if AIRFLOW_V_3_2_PLUS:
- # Multi-team support with ExecutorConf requires Airflow 3.2+
+ # Multi-team support with ExecutorConf requires Airflow 3.2+.
executor = celery_executor.CeleryExecutor(parallelism=parallelism,
team_name=team_name)
else:
executor = celery_executor.CeleryExecutor(parallelism)
@@ -156,7 +156,7 @@ class TestCeleryExecutor:
assert executor.parallelism == parallelism
if AIRFLOW_V_3_2_PLUS:
- # Multi-team support with ExecutorConf requires Airflow 3.2+
+ # Multi-team support with ExecutorConf requires Airflow 3.2+.
assert executor.team_name == team_name
assert executor.conf.team_name == team_name
@@ -167,8 +167,8 @@ class TestCeleryExecutor:
)
with _prepare_app():
executor = celery_executor.CeleryExecutor()
- executor.tasks = {"key": FakeCeleryResult()}
-
executor.bulk_state_fetcher._get_many_using_multiprocessing(executor.tasks.values())
+ executor.workloads = {"key": FakeCeleryResult()}
+
executor.bulk_state_fetcher._get_many_using_multiprocessing(executor.workloads.values())
assert celery_executor_utils.CELERY_FETCH_ERR_MSG_HEADER in
caplog.text, caplog.record_tuples
assert FAKE_EXCEPTION_MSG in caplog.text, caplog.record_tuples
@@ -274,7 +274,7 @@ class TestCeleryExecutor:
executor = celery_executor.CeleryExecutor()
assert executor.running == set()
- assert executor.tasks == {}
+ assert executor.workloads == {}
not_adopted_tis = executor.try_adopt_task_instances(tis)
@@ -282,7 +282,7 @@ class TestCeleryExecutor:
key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, None, 0)
assert executor.running == {key_1, key_2}
- assert executor.tasks == {key_1: AsyncResult("231"), key_2:
AsyncResult("232")}
+ assert executor.workloads == {key_1: AsyncResult("231"), key_2:
AsyncResult("232")}
assert not_adopted_tis == []
@pytest.fixture
@@ -317,12 +317,12 @@ class TestCeleryExecutor:
executor = celery_executor.CeleryExecutor()
executor.job_id = 1
executor.running = {ti.key}
- executor.tasks = {ti.key: AsyncResult("231")}
+ executor.workloads = {ti.key: AsyncResult("231")}
assert executor.has_task(ti)
with pytest.warns(AirflowProviderDeprecationWarning,
match="cleanup_stuck_queued_tasks"):
executor.cleanup_stuck_queued_tasks(tis=tis)
executor.sync()
- assert executor.tasks == {}
+ assert executor.workloads == {}
app.control.revoke.assert_called_once_with("231")
mock_fail.assert_called()
assert not executor.has_task(ti)
@@ -351,13 +351,13 @@ class TestCeleryExecutor:
executor = celery_executor.CeleryExecutor()
executor.job_id = 1
executor.running = {ti.key}
- executor.tasks = {ti.key: AsyncResult("231")}
+ executor.workloads = {ti.key: AsyncResult("231")}
assert executor.has_task(ti)
for ti in tis:
executor.revoke_task(ti=ti)
executor.sync()
app.control.revoke.assert_called_once_with("231")
- assert executor.tasks == {}
+ assert executor.workloads == {}
assert not executor.has_task(ti)
mock_fail.assert_not_called()
@@ -365,18 +365,18 @@ class TestCeleryExecutor:
def test_result_backend_sqlalchemy_engine_options(self):
import importlib
- # Scope the mock using context manager so we can clean up afterward
+ # Scope the mock using context manager so we can clean up afterward.
with mock.patch("celery.Celery") as mock_celery:
- # reload celery conf to apply the new config
+ # Reload celery conf to apply the new config.
importlib.reload(default_celery)
- # reload celery_executor_utils to recreate the celery app with new
config
+ # reload celery_executor_utils to recreate the celery app with new
config.
importlib.reload(celery_executor_utils)
call_args = mock_celery.call_args.kwargs.get("config_source")
assert "database_engine_options" in call_args
assert call_args["database_engine_options"] == {"pool_recycle":
1800}
- # Clean up: reload modules with real Celery to restore clean state for
subsequent tests
+ # Clean up: reload modules with real Celery to restore clean state for
subsequent tests.
importlib.reload(default_celery)
importlib.reload(celery_executor_utils)
@@ -385,9 +385,9 @@ def test_operation_timeout_config():
assert celery_executor_utils.OPERATION_TIMEOUT == 1
-class MockTask:
+class MockWorkload:
"""
- A picklable object used to mock tasks sent to Celery. Can't use the mock
library
+ A picklable object used to mock workloads sent to Celery. Can't use the
mock library
here because it's not picklable.
"""
@@ -414,7 +414,7 @@ def register_signals():
yield
- # Restore original signal handlers after test
+ # Restore original signal handlers after test.
signal.signal(signal.SIGINT, orig_sigint)
signal.signal(signal.SIGTERM, orig_sigterm)
signal.signal(signal.SIGUSR2, orig_sigusr2)
@@ -422,20 +422,20 @@ def register_signals():
@pytest.mark.execution_timeout(200)
@pytest.mark.quarantined
-def test_send_tasks_to_celery_hang(register_signals):
+def test_send_workloads_to_celery_hang(register_signals):
"""
Test that celery_executor does not hang after many runs.
"""
executor = celery_executor.CeleryExecutor()
- task = MockTask()
- task_tuples_to_send = [(None, None, None, task) for _ in range(26)]
+ workload = MockWorkload()
+ workload_tuples_to_send = [(None, None, None, workload) for _ in range(26)]
for _ in range(250):
# This loop can hang on Linux if celery_executor does something wrong
with
# multiprocessing.
- results = executor._send_tasks_to_celery(task_tuples_to_send)
- assert results == [(None, None, 1) for _ in task_tuples_to_send]
+ results = executor._send_workloads_to_celery(workload_tuples_to_send)
+ assert results == [(None, None, 1) for _ in workload_tuples_to_send]
@conf_vars({("celery", "result_backend"):
"rediss://test_user:test_password@localhost:6379/0"})
@@ -445,7 +445,7 @@ def
test_celery_executor_with_no_recommended_result_backend(caplog):
from airflow.providers.celery.executors.default_celery import log
with caplog.at_level(logging.WARNING, logger=log.name):
- # reload celery conf to apply the new config
+ # Reload celery conf to apply the new config.
importlib.reload(default_celery)
assert "test_password" not in caplog.text
assert (
@@ -458,7 +458,7 @@ def
test_celery_executor_with_no_recommended_result_backend(caplog):
def test_sentinel_kwargs_loaded_from_string():
import importlib
- # reload celery conf to apply the new config
+ # Reload celery conf to apply the new config.
importlib.reload(default_celery)
assert
default_celery.DEFAULT_CELERY_CONFIG["broker_transport_options"]["sentinel_kwargs"]
== {
"service_name": "mymaster"
@@ -469,7 +469,7 @@ def test_sentinel_kwargs_loaded_from_string():
def test_celery_task_acks_late_loaded_from_string():
import importlib
- # reload celery conf to apply the new config
+ # Reload celery conf to apply the new config.
importlib.reload(default_celery)
assert default_celery.DEFAULT_CELERY_CONFIG["task_acks_late"] is False
@@ -529,7 +529,7 @@ def
test_visibility_timeout_not_set_for_unsupported_broker(caplog):
def test_celery_extra_celery_config_loaded_from_string():
import importlib
- # reload celery conf to apply the new config
+ # Reload celery conf to apply the new config.
importlib.reload(default_celery)
assert default_celery.DEFAULT_CELERY_CONFIG["worker_max_tasks_per_child"]
== 10
@@ -539,7 +539,7 @@ def
test_result_backend_sentinel_kwargs_loaded_from_string():
"""Test that sentinel_kwargs for result backend transport options is
correctly parsed."""
import importlib
- # reload celery conf to apply the new config
+ # Reload celery conf to apply the new config.
importlib.reload(default_celery)
assert "result_backend_transport_options" in
default_celery.DEFAULT_CELERY_CONFIG
assert
default_celery.DEFAULT_CELERY_CONFIG["result_backend_transport_options"]["sentinel_kwargs"]
== {
@@ -552,7 +552,7 @@ def test_result_backend_master_name_loaded():
"""Test that master_name for result backend transport options is correctly
loaded."""
import importlib
- # reload celery conf to apply the new config
+ # Reload celery conf to apply the new config.
importlib.reload(default_celery)
assert "result_backend_transport_options" in
default_celery.DEFAULT_CELERY_CONFIG
assert (
@@ -570,7 +570,7 @@ def
test_result_backend_transport_options_with_multiple_options():
"""Test that multiple result backend transport options are correctly
loaded."""
import importlib
- # reload celery conf to apply the new config
+ # Reload celery conf to apply the new config.
importlib.reload(default_celery)
result_backend_opts =
default_celery.DEFAULT_CELERY_CONFIG["result_backend_transport_options"]
assert result_backend_opts["sentinel_kwargs"] == {"password":
"redis_password"}
@@ -614,7 +614,7 @@ def test_result_backend_sentinel_full_config():
"""Test full Redis Sentinel configuration for result backend."""
import importlib
- # reload celery conf to apply the new config
+ # Reload celery conf to apply the new config.
importlib.reload(default_celery)
assert default_celery.DEFAULT_CELERY_CONFIG["result_backend"] == (
@@ -643,13 +643,13 @@ class TestMultiTeamCeleryExecutor:
("operators", "default_queue"): "global_queue",
}
)
- def test_multi_team_isolation_and_task_routing(self, monkeypatch):
+ def test_multi_team_isolation_and_workload_routing(self, monkeypatch):
"""
- Test multi-team executor isolation and correct task routing.
+ Test multi-team executor isolation and correct workload routing.
Verifies:
- Each executor has isolated Celery app and config
- - Tasks are routed through team-specific apps
(_process_tasks/_process_workloads)
+ - Workloads are routed through team-specific apps
(_process_tasks/_process_workloads)
- Backward compatibility with global executor
"""
# Set up team-specific config via environment variables
@@ -658,49 +658,49 @@ class TestMultiTeamCeleryExecutor:
monkeypatch.setenv("AIRFLOW__TEAM_B___CELERY__BROKER_URL",
"redis://team-b:6379/0")
monkeypatch.setenv("AIRFLOW__TEAM_B___OPERATORS__DEFAULT_QUEUE",
"team_b_queue")
- # Reload config to pick up environment variables
+ # Reload config to pick up environment variables.
from airflow import configuration
configuration.conf.read_dict({}, source="test")
- # Create executors with different team configs
+ # Create executors with different team configs.
team_a_executor = CeleryExecutor(parallelism=2, team_name="team_a")
team_b_executor = CeleryExecutor(parallelism=3, team_name="team_b")
global_executor = CeleryExecutor(parallelism=4)
- # Each executor has its own Celery app (critical for isolation)
+ # Each executor has its own Celery app (critical for isolation).
assert team_a_executor.celery_app is not team_b_executor.celery_app
assert team_a_executor.celery_app is not global_executor.celery_app
- # Team-specific broker URLs are used
+ # Team-specific broker URLs are used.
assert "team-a" in team_a_executor.celery_app.conf.broker_url
assert "team-b" in team_b_executor.celery_app.conf.broker_url
assert "global" in global_executor.celery_app.conf.broker_url
- # Team-specific queues are used
+ # Team-specific queues are used.
assert team_a_executor.celery_app.conf.task_default_queue ==
"team_a_queue"
assert team_b_executor.celery_app.conf.task_default_queue ==
"team_b_queue"
assert global_executor.celery_app.conf.task_default_queue ==
"global_queue"
- # Each executor has its own BulkStateFetcher with correct app
+ # Each executor has its own BulkStateFetcher with correct app.
assert team_a_executor.bulk_state_fetcher.celery_app is
team_a_executor.celery_app
assert team_b_executor.bulk_state_fetcher.celery_app is
team_b_executor.celery_app
- # Executors have isolated internal state
- assert team_a_executor.tasks is not team_b_executor.tasks
+ # Executors have isolated internal state.
+ assert team_a_executor.workloads is not team_b_executor.workloads
assert team_a_executor.running is not team_b_executor.running
assert team_a_executor.queued_tasks is not team_b_executor.queued_tasks
@conf_vars({("celery", "broker_url"): "redis://global:6379/0"})
-
@mock.patch("airflow.providers.celery.executors.celery_executor.CeleryExecutor._send_tasks")
- def test_task_routing_through_team_specific_app(self, mock_send_tasks,
monkeypatch):
+
@mock.patch("airflow.providers.celery.executors.celery_executor.CeleryExecutor._send_workloads")
+ def test_workload_routing_through_team_specific_app(self,
mock_send_workloads, monkeypatch):
"""
- Test that _process_tasks and _process_workloads pass the correct
team_name for task routing.
+ Test that _process_tasks (v2) and _process_workloads (v3) pass the
correct team_name for task routing.
With the ProcessPoolExecutor approach, we pass team_name instead of
task objects to avoid
pickling issues. The subprocess reconstructs the team-specific Celery
app from the team_name.
"""
- # Set up team A config
+ # Set up team A config.
monkeypatch.setenv("AIRFLOW__TEAM_A___CELERY__BROKER_URL",
"redis://team-a:6379/0")
team_a_executor = CeleryExecutor(parallelism=2, team_name="team_a")
@@ -709,40 +709,42 @@ class TestMultiTeamCeleryExecutor:
from airflow.executors.workloads import ExecuteTask
from airflow.models.taskinstancekey import TaskInstanceKey
- # Create mock workload
+ # Create mock workload.
mock_ti = mock.Mock()
mock_ti.key = TaskInstanceKey("dag", "task", "run", 1)
mock_ti.queue = "test_queue"
mock_workload = mock.Mock(spec=ExecuteTask)
mock_workload.ti = mock_ti
- # Process workload through team A executor
+ # Process workload through team A executor.
team_a_executor._process_workloads([mock_workload])
- # Verify _send_tasks received the correct team_name
- assert mock_send_tasks.called
- task_tuples = mock_send_tasks.call_args[0][0]
- team_name_from_call = task_tuples[0][3] # 4th element is now
team_name
+ # Verify _send_workloads received the correct team_name.
+ assert mock_send_workloads.called
+ workload_tuples = mock_send_workloads.call_args[0][0]
+ team_name_from_call = workload_tuples[0][
+ 3
+ ] # 4th element is team_name (used to reconstruct Celery app in
subprocess).
- # Critical: team_name is passed so subprocess can reconstruct the
correct app
+ # Critical: team_name is passed so subprocess can reconstruct the
correct app.
assert team_name_from_call == "team_a"
else:
from airflow.models.taskinstancekey import TaskInstanceKey
- # Test V2 path with execute_command
+ # Test V2 path with execute_command.
mock_key = TaskInstanceKey("dag", "task", "run", 1)
mock_command = ["airflow", "tasks", "run", "dag", "task"]
mock_queue = "test_queue"
- # Process task through team A executor
+ # Process task through team A executor.
team_a_executor._process_tasks([(mock_key, mock_command,
mock_queue, None)])
- # Verify _send_tasks received team A's execute_command task
- assert mock_send_tasks.called
- task_tuples = mock_send_tasks.call_args[0][0]
- task_from_call = task_tuples[0][3] # 4th element is the task (V2
still uses task object)
+ # Verify _send_workloads received team A's execute_command
workload (v2 compatibility path).
+ assert mock_send_workloads.called
+ task_tuples = mock_send_workloads.call_args[0][0]
+ task_from_call = task_tuples[0][3] # 4th element is the task (V2
still uses task object).
- # Critical: task belongs to team A's app, not module-level app
+ # Critical: Celery task belongs to team A's app, not module-level
app.
assert task_from_call.app is team_a_executor.celery_app
assert task_from_call.name == "execute_command"
@@ -763,7 +765,7 @@ def test_celery_tasks_registered_on_import():
"execute_workload must be registered with the Celery app at import
time. "
"Workers need this to receive tasks without KeyError."
)
- # TODO: remove this block when min supported Airflow version is >= 3.0
+ # TODO: remove this block when min supported Airflow version is >= 3.0.
if not AIRFLOW_V_3_0_PLUS:
assert "execute_command" in registered_tasks, (
"execute_command must be registered for Airflow 2.x compatibility."