This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b7b998e221 Clean up CeleryExecutor to use workload terminology and 
typing (#63888)
4b7b998e221 is described below

commit 4b7b998e221b736c24fea045e04f62b703593b0b
Author: SameerMesiah97 <[email protected]>
AuthorDate: Fri Apr 3 21:03:45 2026 +0100

    Clean up CeleryExecutor to use workload terminology and typing (#63888)
    
    * Clean up CeleryExecutor docstrings, comments, variable names, and typing 
to align with the workload-based executor model.
    
    * Resolve typing and logging inconsistencies. Update provider.yaml and 
provider.info files
    
    ---------
    
    Co-authored-by: Sameer Mesiah <[email protected]>
---
 providers/celery/provider.yaml                     |   2 +-
 .../providers/celery/executors/celery_executor.py  | 131 +++++++++++----------
 .../celery/executors/celery_executor_utils.py      |  58 +++++----
 .../celery/executors/celery_kubernetes_executor.py |   2 +-
 .../airflow/providers/celery/get_provider_info.py  |   2 +-
 .../integration/celery/test_celery_executor.py     |  83 ++++++++-----
 .../unit/celery/executors/test_celery_executor.py  | 130 ++++++++++----------
 7 files changed, 221 insertions(+), 187 deletions(-)

diff --git a/providers/celery/provider.yaml b/providers/celery/provider.yaml
index bdbf5ab88ae..0da7df9f09a 100644
--- a/providers/celery/provider.yaml
+++ b/providers/celery/provider.yaml
@@ -298,7 +298,7 @@ config:
         default: "prefork"
       operation_timeout:
         description: |
-          The number of seconds to wait before timing out 
``send_task_to_executor`` or
+          The number of seconds to wait before timing out 
``send_workload_to_executor`` or
           ``fetch_celery_task_state`` operations.
         version_added: ~
         type: float
diff --git 
a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py 
b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
index b31eec3b061..db4df5ab7bc 100644
--- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
+++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
@@ -32,7 +32,7 @@ import time
 from collections import Counter
 from concurrent.futures import ProcessPoolExecutor
 from multiprocessing import cpu_count
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, TypeAlias, cast
 
 from celery import states as celery_states
 from deprecated import deprecated
@@ -40,7 +40,7 @@ from deprecated import deprecated
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.executors.base_executor import BaseExecutor
 from airflow.providers.celery.executors import (
-    celery_executor_utils as _celery_executor_utils,  # noqa: F401 # Needed to 
register Celery tasks at worker startup, see #63043
+    celery_executor_utils as _celery_executor_utils,  # noqa: F401 # Needed to 
register Celery tasks at worker startup, see #63043.
 )
 from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS, 
AIRFLOW_V_3_2_PLUS
 from airflow.providers.common.compat.sdk import AirflowTaskTimeout, Stats
@@ -49,18 +49,27 @@ from airflow.utils.state import TaskInstanceState
 log = logging.getLogger(__name__)
 
 
-CELERY_SEND_ERR_MSG_HEADER = "Error sending Celery task"
+CELERY_SEND_ERR_MSG_HEADER = "Error sending Celery workload"
 
 
 if TYPE_CHECKING:
     from collections.abc import Sequence
 
+    from celery.result import AsyncResult
+
     from airflow.cli.cli_config import GroupCommand
     from airflow.executors import workloads
     from airflow.models.taskinstance import TaskInstance
     from airflow.models.taskinstancekey import TaskInstanceKey
     from airflow.providers.celery.executors.celery_executor_utils import 
TaskTuple, WorkloadInCelery
 
+    if AIRFLOW_V_3_2_PLUS:
+        from airflow.executors.workloads.types import WorkloadKey as 
_WorkloadKey
+
+        WorkloadKey: TypeAlias = _WorkloadKey
+    else:
+        WorkloadKey: TypeAlias = TaskInstanceKey  # type: ignore[no-redef, 
misc]
+
 
 # PEP562
 def __getattr__(name):
@@ -84,7 +93,7 @@ class CeleryExecutor(BaseExecutor):
     """
     CeleryExecutor is recommended for production use of Airflow.
 
-    It allows distributing the execution of task instances to multiple worker 
nodes.
+    It allows distributing the execution of workloads (task instances and 
callbacks) to multiple worker nodes.
 
     Celery is a simple, flexible and reliable distributed system to process
     vast amounts of messages, while providing operations with the tools
@@ -102,7 +111,7 @@ class CeleryExecutor(BaseExecutor):
     if TYPE_CHECKING:
         if AIRFLOW_V_3_0_PLUS:
             # TODO: TaskSDK: move this type change into BaseExecutor
-            queued_tasks: dict[TaskInstanceKey, workloads.All]  # type: 
ignore[assignment]
+            queued_tasks: dict[WorkloadKey, workloads.All]  # type: 
ignore[assignment]
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -127,7 +136,7 @@ class CeleryExecutor(BaseExecutor):
 
         self.celery_app = create_celery_app(self.conf)
 
-        # Celery doesn't support bulk sending the tasks (which can become a 
bottleneck on bigger clusters)
+        # Celery doesn't support bulk sending the workloads (which can become 
a bottleneck on bigger clusters)
         # so we use a multiprocessing pool to speed this up.
         # How many worker processes are created for checking celery task state.
         self._sync_parallelism = self.conf.getint("celery", 
"SYNC_PARALLELISM", fallback=0)
@@ -136,144 +145,146 @@ class CeleryExecutor(BaseExecutor):
         from airflow.providers.celery.executors.celery_executor_utils import 
BulkStateFetcher
 
         self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism, 
celery_app=self.celery_app)
-        self.tasks = {}
-        self.task_publish_retries: Counter[TaskInstanceKey] = Counter()
-        self.task_publish_max_retries = self.conf.getint("celery", 
"task_publish_max_retries", fallback=3)
+        self.workloads: dict[WorkloadKey, AsyncResult] = {}
+        self.workload_publish_retries: Counter[WorkloadKey] = Counter()
+        self.workload_publish_max_retries = self.conf.getint("celery", 
"task_publish_max_retries", fallback=3)
 
     def start(self) -> None:
         self.log.debug("Starting Celery Executor using %s processes for 
syncing", self._sync_parallelism)
 
-    def _num_tasks_per_send_process(self, to_send_count: int) -> int:
+    def _num_workloads_per_send_process(self, to_send_count: int) -> int:
         """
-        How many Celery tasks should each worker process send.
+        How many Celery workloads should each worker process send.
 
-        :return: Number of tasks that should be sent per process
+        :return: Number of workloads that should be sent per process
         """
         return max(1, math.ceil(to_send_count / self._sync_parallelism))
 
     def _process_tasks(self, task_tuples: Sequence[TaskTuple]) -> None:
-        # Airflow V2 version
+        # Airflow V2 compatibility path — converts task tuples into 
workload-compatible tuples.
 
         task_tuples_to_send = [task_tuple[:3] + (self.team_name,) for 
task_tuple in task_tuples]
 
-        self._send_tasks(task_tuples_to_send)
+        self._send_workloads(task_tuples_to_send)
 
     def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
-        # Airflow V3 version -- have to delay imports until we know we are on 
v3
+        # Airflow V3 version -- have to delay imports until we know we are on 
v3.
         from airflow.executors.workloads import ExecuteTask
 
         if AIRFLOW_V_3_2_PLUS:
             from airflow.executors.workloads import ExecuteCallback
 
-        tasks: list[WorkloadInCelery] = []
+        workloads_to_be_sent: list[WorkloadInCelery] = []
         for workload in workloads:
             if isinstance(workload, ExecuteTask):
-                tasks.append((workload.ti.key, workload, workload.ti.queue, 
self.team_name))
+                workloads_to_be_sent.append((workload.ti.key, workload, 
workload.ti.queue, self.team_name))
             elif AIRFLOW_V_3_2_PLUS and isinstance(workload, ExecuteCallback):
-                # Use default queue for callbacks, or extract from callback 
data if available
+                # Use default queue for callbacks, or extract from callback 
data if available.
                 queue = "default"
                 if isinstance(workload.callback.data, dict) and "queue" in 
workload.callback.data:
                     queue = workload.callback.data["queue"]
-                tasks.append((workload.callback.key, workload, queue, 
self.team_name))
+                workloads_to_be_sent.append((workload.callback.key, workload, 
queue, self.team_name))
             else:
                 raise ValueError(f"{type(self)}._process_workloads cannot 
handle {type(workload)}")
 
-        self._send_tasks(tasks)
+        self._send_workloads(workloads_to_be_sent)
 
-    def _send_tasks(self, task_tuples_to_send: Sequence[WorkloadInCelery]):
+    def _send_workloads(self, workload_tuples_to_send: 
Sequence[WorkloadInCelery]):
         # Celery state queries will be stuck if we do not use one same backend
-        # for all tasks.
+        # for all workloads.
         cached_celery_backend = self.celery_app.backend
 
-        key_and_async_results = self._send_tasks_to_celery(task_tuples_to_send)
-        self.log.debug("Sent all tasks.")
+        key_and_async_results = 
self._send_workloads_to_celery(workload_tuples_to_send)
+        self.log.debug("Sent all workloads.")
         from airflow.providers.celery.executors.celery_executor_utils import 
ExceptionWithTraceback
 
         for key, _, result in key_and_async_results:
             if isinstance(result, ExceptionWithTraceback) and isinstance(
                 result.exception, AirflowTaskTimeout
             ):
-                retries = self.task_publish_retries[key]
-                if retries < self.task_publish_max_retries:
+                retries = self.workload_publish_retries[key]
+                if retries < self.workload_publish_max_retries:
                     Stats.incr("celery.task_timeout_error")
                     self.log.info(
-                        "[Try %s of %s] Task Timeout Error for Task: (%s).",
-                        self.task_publish_retries[key] + 1,
-                        self.task_publish_max_retries,
+                        "[Try %s of %s] Celery Task Timeout Error for 
Workload: (%s).",
+                        self.workload_publish_retries[key] + 1,
+                        self.workload_publish_max_retries,
                         tuple(key),
                     )
-                    self.task_publish_retries[key] = retries + 1
+                    self.workload_publish_retries[key] = retries + 1
                     continue
             if key in self.queued_tasks:
                 self.queued_tasks.pop(key)
             else:
                 self.queued_callbacks.pop(key, None)
-            self.task_publish_retries.pop(key, None)
+            self.workload_publish_retries.pop(key, None)
             if isinstance(result, ExceptionWithTraceback):
                 self.log.error("%s: %s\n%s\n", CELERY_SEND_ERR_MSG_HEADER, 
result.exception, result.traceback)
                 self.event_buffer[key] = (TaskInstanceState.FAILED, None)
             elif result is not None:
                 result.backend = cached_celery_backend
                 self.running.add(key)
-                self.tasks[key] = result
+                self.workloads[key] = result
 
-                # Store the Celery task_id in the event buffer. This will get 
"overwritten" if the task
+                # Store the Celery task_id (workload execution ID) in the 
event buffer. This will get "overwritten" if the task
                 # has another event, but that is fine, because the only other 
events are success/failed at
-                # which point we don't need the ID anymore anyway
+                # which point we don't need the ID anymore anyway.
                 self.event_buffer[key] = (TaskInstanceState.QUEUED, 
result.task_id)
 
-    def _send_tasks_to_celery(self, task_tuples_to_send: 
Sequence[WorkloadInCelery]):
-        from airflow.providers.celery.executors.celery_executor_utils import 
send_task_to_executor
+    def _send_workloads_to_celery(self, workload_tuples_to_send: 
Sequence[WorkloadInCelery]):
+        from airflow.providers.celery.executors.celery_executor_utils import 
send_workload_to_executor
 
-        if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1:
+        if len(workload_tuples_to_send) == 1 or self._sync_parallelism == 1:
             # One tuple, or max one process -> send it in the main thread.
-            return list(map(send_task_to_executor, task_tuples_to_send))
+            return list(map(send_workload_to_executor, 
workload_tuples_to_send))
 
         # Use chunks instead of a work queue to reduce context switching
-        # since tasks are roughly uniform in size
-        chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
-        num_processes = min(len(task_tuples_to_send), self._sync_parallelism)
+        # since workloads are roughly uniform in size.
+        chunksize = 
self._num_workloads_per_send_process(len(workload_tuples_to_send))
+        num_processes = min(len(workload_tuples_to_send), 
self._sync_parallelism)
 
-        # Use ProcessPoolExecutor with team_name instead of task objects to 
avoid pickling issues.
+        # Use ProcessPoolExecutor with team_name instead of workload objects 
to avoid pickling issues.
         # Subprocesses reconstruct the team-specific Celery app from the team 
name and existing config.
         with ProcessPoolExecutor(max_workers=num_processes) as send_pool:
             key_and_async_results = list(
-                send_pool.map(send_task_to_executor, task_tuples_to_send, 
chunksize=chunksize)
+                send_pool.map(send_workload_to_executor, 
workload_tuples_to_send, chunksize=chunksize)
             )
         return key_and_async_results
 
     def sync(self) -> None:
-        if not self.tasks:
-            self.log.debug("No task to query celery, skipping sync")
+        if not self.workloads:
+            self.log.debug("No workload to query celery, skipping sync")
             return
-        self.update_all_task_states()
+        self.update_all_workload_states()
 
     def debug_dump(self) -> None:
         """Debug dump; called in response to SIGUSR2 by the scheduler."""
         super().debug_dump()
         self.log.info(
-            "executor.tasks (%d)\n\t%s", len(self.tasks), 
"\n\t".join(map(repr, self.tasks.items()))
+            "executor.workloads (%d)\n\t%s",
+            len(self.workloads),
+            "\n\t".join(map(repr, self.workloads.items())),
         )
 
-    def update_all_task_states(self) -> None:
-        """Update states of the tasks."""
-        self.log.debug("Inquiring about %s celery task(s)", len(self.tasks))
-        state_and_info_by_celery_task_id = 
self.bulk_state_fetcher.get_many(self.tasks.values())
+    def update_all_workload_states(self) -> None:
+        """Update states of the workloads."""
+        self.log.debug("Inquiring about %s celery workload(s)", 
len(self.workloads))
+        state_and_info_by_celery_task_id = 
self.bulk_state_fetcher.get_many(self.workloads.values())
 
         self.log.debug("Inquiries completed.")
-        for key, async_result in list(self.tasks.items()):
+        for key, async_result in list(self.workloads.items()):
             state, info = 
state_and_info_by_celery_task_id.get(async_result.task_id)
             if state:
-                self.update_task_state(key, state, info)
+                self.update_task_state(cast("TaskInstanceKey", key), state, 
info)
 
     def change_state(
         self, key: TaskInstanceKey, state: TaskInstanceState, info=None, 
remove_running=True
     ) -> None:
         super().change_state(key, state, info, remove_running=remove_running)
-        self.tasks.pop(key, None)
+        self.workloads.pop(key, None)
 
     def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) 
-> None:
-        """Update state of a single task."""
+        """Update state of a single workload."""
         try:
             if state == celery_states.SUCCESS:
                 self.success(key, info)
@@ -288,7 +299,9 @@ class CeleryExecutor(BaseExecutor):
 
     def end(self, synchronous: bool = False) -> None:
         if synchronous:
-            while any(task.state not in celery_states.READY_STATES for task in 
self.tasks.values()):
+            while any(
+                workload.state not in celery_states.READY_STATES for workload 
in self.workloads.values()
+            ):
                 time.sleep(5)
         self.sync()
 
@@ -322,7 +335,7 @@ class CeleryExecutor(BaseExecutor):
                 not_adopted_tis.append(ti)
 
         if not celery_tasks:
-            # Nothing to adopt
+            # Nothing to adopt.
             return tis
 
         states_by_celery_task_id = self.bulk_state_fetcher.get_many(
@@ -342,7 +355,7 @@ class CeleryExecutor(BaseExecutor):
 
             # Set the correct elements of the state dicts, then update this
             # like we just queried it.
-            self.tasks[ti.key] = result
+            self.workloads[ti.key] = result
             self.running.add(ti.key)
             self.update_task_state(ti.key, state, info)
             adopted.append(f"{ti} in state {state}")
@@ -373,7 +386,7 @@ class CeleryExecutor(BaseExecutor):
         return reprs
 
     def revoke_task(self, *, ti: TaskInstance):
-        celery_async_result = self.tasks.pop(ti.key, None)
+        celery_async_result = self.workloads.pop(ti.key, None)
         if celery_async_result:
             try:
                 self.celery_app.control.revoke(celery_async_result.task_id)
diff --git 
a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py
 
b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py
index 699052b470e..c0a9e471d2f 100644
--- 
a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py
+++ 
b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py
@@ -80,7 +80,7 @@ if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstanceKey
 
     # We can't use `if AIRFLOW_V_3_0_PLUS` conditions in type checks, so 
unfortunately we just have to define
-    # the type as the union of both kinds
+    # the type as the union of both kinds.
     CommandType = Sequence[str]
 
     WorkloadInCelery: TypeAlias = tuple[WorkloadKey, workloads.All | 
CommandType, str | None, str | None]
@@ -88,7 +88,7 @@ if TYPE_CHECKING:
         WorkloadKey, CommandType, AsyncResult | "ExceptionWithTraceback"
     ]
 
-    # Deprecated alias for backward compatibility
+    # Deprecated alias for backward compatibility.
     TaskInstanceInCelery: TypeAlias = WorkloadInCelery
 
     TaskTuple = tuple[TaskInstanceKey, CommandType, str | None, Any | None]
@@ -132,10 +132,10 @@ def create_celery_app(team_conf: ExecutorConf | 
AirflowConfigParser) -> Celery:
 
     celery_app_name = team_conf.get("celery", "CELERY_APP_NAME")
 
-    # Make app name unique per team to ensure proper broker isolation
+    # Make app name unique per team to ensure proper broker isolation.
     # Each team's executor needs a distinct Celery app name to prevent
-    # tasks from being routed to the wrong broker
-    # Only do this if team_conf is an ExecutorConf with team_name (not global 
conf)
+    # tasks from being routed to the wrong broker.
+    # Only do this if team_conf is an ExecutorConf with team_name (not global 
conf).
     team_name = getattr(team_conf, "team_name", None)
     if team_name:
         celery_app_name = f"{celery_app_name}_{team_name}"
@@ -153,7 +153,7 @@ def create_celery_app(team_conf: ExecutorConf | 
AirflowConfigParser) -> Celery:
 
     celery_app = Celery(celery_app_name, config_source=config)
 
-    # Register tasks with this app
+    # Register tasks with this app.
     celery_app.task(name="execute_workload")(execute_workload)
     if not AIRFLOW_V_3_0_PLUS:
         celery_app.task(name="execute_command")(execute_command)
@@ -161,7 +161,7 @@ def create_celery_app(team_conf: ExecutorConf | 
AirflowConfigParser) -> Celery:
     return celery_app
 
 
-# Keep module-level app for backward compatibility
+# Keep module-level app for backward compatibility.
 app = _get_celery_app()
 
 
@@ -203,7 +203,7 @@ def on_celery_worker_ready(*args, **kwargs):
 
 
 # Once Celery 5.5 is out of beta, we can pass `pydantic=True` to the decorator 
and it will handle the validation
-# and deserialization for us
+# and deserialization for us.
 @app.task(name="execute_workload")
 def execute_workload(input: str) -> None:
     from celery.exceptions import Ignore
@@ -221,7 +221,7 @@ def execute_workload(input: str) -> None:
     log.info("[%s] Executing workload in Celery: %s", celery_task_id, workload)
 
     base_url = conf.get("api", "base_url", fallback="/")
-    # If it's a relative URL, use localhost:8080 as the default
+    # If it's a relative URL, use localhost:8080 as the default.
     if base_url.startswith("/"):
         base_url = f"http://localhost:8080{base_url}"
     default_execution_api_server = f"{base_url.rstrip('/')}/execution/"
@@ -285,7 +285,7 @@ if not AIRFLOW_V_3_0_PLUS:
 def _execute_in_fork(command_to_exec: CommandType, celery_task_id: str | None 
= None) -> None:
     pid = os.fork()
     if pid:
-        # In parent, wait for the child
+        # In parent, wait for the child.
         pid, ret = os.waitpid(pid, 0)
         if ret == 0:
             return
@@ -300,7 +300,7 @@ def _execute_in_fork(command_to_exec: CommandType, 
celery_task_id: str | None =
         from airflow.cli.cli_parser import get_parser
 
         parser = get_parser()
-        # [1:] - remove "airflow" from the start of the command
+        # [1:] - remove "airflow" from the start of the command.
         args = parser.parse_args(command_to_exec[1:])
         args.shut_down_logging = False
         if celery_task_id:
@@ -360,7 +360,7 @@ def send_workload_to_executor(
     workload_tuple: WorkloadInCelery,
 ) -> WorkloadInCeleryResult:
     """
-    Send workload to executor.
+    Send workload to executor (serialized and executed as a Celery task).
 
     This function is called in ProcessPoolExecutor subprocesses. To avoid 
pickling issues with
     team-specific Celery apps, we pass the team_name and reconstruct the 
Celery app here.
@@ -371,26 +371,26 @@ def send_workload_to_executor(
     # ExecutorConf wraps config access to automatically use team-specific 
config where present.
     if TYPE_CHECKING:
         _conf: ExecutorConf | AirflowConfigParser
-    # Check if Airflow version is greater than or equal to 3.2 to import 
ExecutorConf
+    # Check if Airflow version is greater than or equal to 3.2 to import 
ExecutorConf.
     if AIRFLOW_V_3_2_PLUS:
         from airflow.executors.base_executor import ExecutorConf
 
         _conf = ExecutorConf(team_name)
     else:
-        # Airflow <3.2 ExecutorConf doesn't exist (at least not with the 
required attributes), fall back to global conf
+        # Airflow <3.2 ExecutorConf doesn't exist (at least not with the 
required attributes), fall back to global conf.
         _conf = conf
-    # Create the Celery app with the correct configuration
+    # Create the Celery app with the correct configuration.
     celery_app = create_celery_app(_conf)
 
     if AIRFLOW_V_3_0_PLUS:
-        # Get the task from the app
-        task_to_run = celery_app.tasks["execute_workload"]
+        # Get the task from the app.
+        celery_task = celery_app.tasks["execute_workload"]
         if TYPE_CHECKING:
             assert isinstance(args, workloads.BaseWorkload)
         args = (args.model_dump_json(),)
     else:
-        # Get the task from the app
-        task_to_run = celery_app.tasks["execute_command"]
+        # Get the task from the app.
+        celery_task = celery_app.tasks["execute_command"]
         args = [args]  # type: ignore[list-item]
 
     # Pre-import redis.client to avoid SIGALRM interrupting module 
initialization.
@@ -400,27 +400,23 @@ def send_workload_to_executor(
     try:
         import redis.client  # noqa: F401
     except ImportError:
-        pass  # Redis not installed or not using Redis backend
+        pass  # Redis not installed or not using Redis backend.
 
     try:
         with timeout(seconds=OPERATION_TIMEOUT):
-            result = task_to_run.apply_async(args=args, queue=queue)
+            result = celery_task.apply_async(args=args, queue=queue)
     except (Exception, AirflowTaskTimeout) as e:
         exception_traceback = f"Celery Task ID: 
{key}\n{traceback.format_exc()}"
         result = ExceptionWithTraceback(e, exception_traceback)
 
     # The type is right for the version, but the type cannot be defined 
correctly for Airflow 2 and 3
-    # concurrently;
+    # concurrently.
     return key, args, result
 
 
-# Backward compatibility alias
-send_task_to_executor = send_workload_to_executor
-
-
 def fetch_celery_task_state(async_result: AsyncResult) -> tuple[str, str | 
ExceptionWithTraceback, Any]:
     """
-    Fetch and return the state of the given celery task.
+    Fetch and return the state of the given celery task (workload execution).
 
     The scope of this function is global so that it can be called by 
subprocesses in the pool.
 
@@ -434,12 +430,12 @@ def fetch_celery_task_state(async_result: AsyncResult) -> 
tuple[str, str | Excep
     try:
         import redis.client  # noqa: F401
     except ImportError:
-        pass  # Redis not installed or not using Redis backend
+        pass  # Redis not installed or not using Redis backend.
 
     try:
         with timeout(seconds=OPERATION_TIMEOUT):
-            # Accessing state property of celery task will make actual network 
request
-            # to get the current state of the task
+            # Accessing state property of celery task (workload execution) 
triggers a network request
+            # to get the current state of the task.
             info = async_result.info if hasattr(async_result, "info") else None
             return async_result.task_id, async_result.state, info
     except Exception as e:
@@ -459,7 +455,7 @@ class BulkStateFetcher(LoggingMixin):
     def __init__(self, sync_parallelism: int, celery_app: Celery | None = 
None):
         super().__init__()
         self._sync_parallelism = sync_parallelism
-        self.celery_app = celery_app or app  # Use provided app or fall back 
to module-level app
+        self.celery_app = celery_app or app  # Use provided app or fall back 
to module-level app.
 
     def _tasks_list_to_task_ids(self, async_tasks: Collection[AsyncResult]) -> 
set[str]:
         return {a.task_id for a in async_tasks}
diff --git 
a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py
 
b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py
index 687f7b75e3d..f66c153a6d7 100644
--- 
a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py
+++ 
b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py
@@ -103,7 +103,7 @@ class CeleryKubernetesExecutor(BaseExecutor):
     @property
     def queued_tasks(self) -> dict[TaskInstanceKey, Any]:
         """Return queued tasks from celery and kubernetes executor."""
-        return self.celery_executor.queued_tasks | 
self.kubernetes_executor.queued_tasks
+        return self.celery_executor.queued_tasks | 
self.kubernetes_executor.queued_tasks  # type: ignore[return-value]
 
     @queued_tasks.setter
     def queued_tasks(self, value) -> None:
diff --git a/providers/celery/src/airflow/providers/celery/get_provider_info.py 
b/providers/celery/src/airflow/providers/celery/get_provider_info.py
index 537344c7a4e..071071133bd 100644
--- a/providers/celery/src/airflow/providers/celery/get_provider_info.py
+++ b/providers/celery/src/airflow/providers/celery/get_provider_info.py
@@ -198,7 +198,7 @@ def get_provider_info():
                         "default": "prefork",
                     },
                     "operation_timeout": {
-                        "description": "The number of seconds to wait before 
timing out ``send_task_to_executor`` or\n``fetch_celery_task_state`` 
operations.\n",
+                        "description": "The number of seconds to wait before 
timing out ``send_workload_to_executor`` or\n``fetch_celery_task_state`` 
operations.\n",
                         "version_added": None,
                         "type": "float",
                         "example": None,
diff --git a/providers/celery/tests/integration/celery/test_celery_executor.py 
b/providers/celery/tests/integration/celery/test_celery_executor.py
index 3641abb6b1f..aeb769f4153 100644
--- a/providers/celery/tests/integration/celery/test_celery_executor.py
+++ b/providers/celery/tests/integration/celery/test_celery_executor.py
@@ -27,7 +27,7 @@ from datetime import datetime, timedelta
 from time import sleep
 from unittest import mock
 
-# leave this it is used by the test worker
+# Leave this it is used by the test worker.
 import celery.contrib.testing.tasks  # noqa: F401
 import pytest
 import uuid6
@@ -84,8 +84,8 @@ def _prepare_app(broker_url=None, execute=None):
     test_config = dict(celery_executor_utils.get_celery_configuration())
     test_config.update({"broker_url": broker_url})
     test_app = Celery(broker_url, config_source=test_config)
-    # Register the fake execute function with the test_app using the correct 
task name
-    # This ensures workers using test_app will execute the fake function
+    # Register the fake execute function with the test_app using the correct 
task name.
+    # This ensures workers using test_app will execute the fake function.
     test_execute = test_app.task(name=execute_name)(execute)
     patch_app = mock.patch.object(celery_executor_utils, "app", test_app)
 
@@ -95,7 +95,7 @@ def _prepare_app(broker_url=None, execute=None):
         celery_executor_utils.execute_command.__wrapped__ = execute
 
     patch_execute = mock.patch.object(celery_executor_utils, execute_name, 
test_execute)
-    # Patch factory function so CeleryExecutor instances get the test app
+    # Patch factory function so CeleryExecutor instances get the test app.
     patch_factory = mock.patch.object(celery_executor_utils, 
"create_celery_app", return_value=test_app)
 
     backend = test_app.backend
@@ -105,7 +105,7 @@ def _prepare_app(broker_url=None, execute=None):
         # race condition where it one of the subprocesses can die with "Table
         # already exists" error, because SQLA checks for which tables exist,
         # then issues a CREATE TABLE, rather than doing CREATE TABLE IF NOT
-        # EXISTS
+        # EXISTS.
         session = backend.ResultSession()
         session.close()
 
@@ -128,6 +128,17 @@ class TestCeleryExecutor:
         db.clear_db_runs()
         db.clear_db_jobs()
 
+
+def setup_dagrun_with_success_and_fail_workloads(dag_maker):
+    date = timezone.utcnow()
+    start_date = date - timedelta(days=2)
+
+    with dag_maker("test_celery_integration"):
+        BaseOperator(task_id="success", start_date=start_date)
+        BaseOperator(task_id="fail", start_date=start_date)
+
+    return dag_maker.create_dagrun(logical_date=date)
+
     @pytest.mark.flaky(reruns=5, reruns_delay=3)
     @pytest.mark.parametrize("broker_url", _prepare_test_bodies())
     @pytest.mark.parametrize(
@@ -164,19 +175,19 @@ class TestCeleryExecutor:
         from airflow.providers.celery.executors import celery_executor
 
         if AIRFLOW_V_3_0_PLUS:
-            # Airflow 3: execute_workload receives JSON string
+            # Airflow 3: execute_workload receives JSON string.
             def fake_execute(input: str) -> None:
                 """Fake execute_workload that parses JSON and fails for tasks 
with 'fail' in task_id."""
                 import json
 
                 workload_dict = json.loads(input)
-                # Check if this is a task that should fail (task_id contains 
"fail")
+                # Check if this is a workload that should fail (task_id 
contains "fail").
                 if "ti" in workload_dict and "task_id" in workload_dict["ti"]:
                     if "fail" in workload_dict["ti"]["task_id"]:
                         raise AirflowException("fail")
         else:
-            # Airflow 2: execute_command receives command list
-            def fake_execute(input: str) -> None:  # Use same parameter name 
as Airflow 3 version
+            # Airflow 2: execute_command receives command list.
+            def fake_execute(input: str) -> None:  # Use same parameter name 
as Airflow 3 version.
                 if "fail" in input:
                     raise AirflowException("fail")
 
@@ -218,6 +229,18 @@ class TestCeleryExecutor:
                         bundle_info=BundleInfo(name="test"),
                         log_path="test.log",
                     )
+                keys = [
+                    TaskInstanceKey("id", "success", "abc", 0, -1),
+                    TaskInstanceKey("id", "fail", "abc", 0, -1),
+                ]
+                dagrun = 
setup_dagrun_with_success_and_fail_workloads(dag_maker)
+                ti_success, ti_fail = dagrun.task_instances
+                for w in (
+                    workloads.ExecuteTask.make(
+                        ti=ti_success,
+                    ),
+                    workloads.ExecuteTask.make(ti=ti_fail),
+                ):
                     executor.queue_workload(w, session=None)
 
                 executor.trigger_tasks(open_slots=10)
@@ -244,7 +267,7 @@ class TestCeleryExecutor:
 
         assert executor.queued_tasks == {}
 
-    def test_error_sending_task(self):
+    def test_error_sending_workload(self):
         from airflow.providers.celery.executors import celery_executor, 
celery_executor_utils
 
         with _prepare_app():
@@ -263,8 +286,8 @@ class TestCeleryExecutor:
             executor.queued_tasks[key] = workload
             executor.task_publish_retries[key] = 1
 
-            # Mock send_task_to_executor to return an error result
-            # This simulates a failure when sending the task to Celery
+            # Mock send_workload_to_executor to return an error result.
+            # This simulates a failure when sending the workload to Celery.
             def mock_send_error(task_tuple):
                 key_from_tuple = task_tuple[0]
                 return (
@@ -277,14 +300,14 @@ class TestCeleryExecutor:
                 )
 
             with mock.patch.object(
-                celery_executor_utils, "send_task_to_executor", 
side_effect=mock_send_error
+                celery_executor_utils, "send_workload_to_executor", 
side_effect=mock_send_error
             ):
                 executor.heartbeat()
-        assert len(executor.queued_tasks) == 0, "Task should no longer be 
queued"
+        assert len(executor.queued_tasks) == 0, "Workload should no longer be 
queued"
         assert executor.event_buffer[key][0] == State.FAILED
 
-    def test_retry_on_error_sending_task(self, caplog):
-        """Test that Airflow retries publishing tasks to Celery Broker at 
least 3 times"""
+    def test_retry_on_error_sending_workload(self, caplog):
+        """Test that Airflow retries publishing workloads to Celery Broker at 
least 3 times"""
         from airflow.providers.celery.executors import celery_executor, 
celery_executor_utils
 
         with (
@@ -298,8 +321,8 @@ class TestCeleryExecutor:
             ),
         ):
             executor = celery_executor.CeleryExecutor()
-            assert executor.task_publish_retries == {}
-            assert executor.task_publish_max_retries == 3, "Assert Default Max 
Retries is 3"
+            assert executor.workload_publish_retries == {}
+            assert executor.workload_publish_max_retries == 3, "Assert Default 
Max Retries is 3"
 
             with DAG(dag_id="id"):
                 task = BashOperator(task_id="test", bash_command="true", 
start_date=datetime.now())
@@ -314,27 +337,27 @@ class TestCeleryExecutor:
             key = (task.dag.dag_id, task.task_id, ti.run_id, 0, -1)
             executor.queued_tasks[key] = workload
 
-            # Test that when heartbeat is called again, task is published 
again to Celery Queue
+            # Test that when heartbeat is called again, workload is published 
again to Celery Queue.
             executor.heartbeat()
-            assert dict(executor.task_publish_retries) == {key: 1}
-            assert len(executor.queued_tasks) == 1, "Task should remain in 
queue"
+            assert dict(executor.workload_publish_retries) == {key: 1}
+            assert len(executor.queued_tasks) == 1, "Workload should remain in 
queue"
             assert executor.event_buffer == {}
-            assert f"[Try 1 of 3] Task Timeout Error for Task: ({key})." in 
caplog.text
+            assert f"[Try 1 of 3] Celery Task Timeout Error for Workload: 
({key})." in caplog.text
 
             executor.heartbeat()
-            assert dict(executor.task_publish_retries) == {key: 2}
-            assert len(executor.queued_tasks) == 1, "Task should remain in 
queue"
+            assert dict(executor.workload_publish_retries) == {key: 2}
+            assert len(executor.queued_tasks) == 1, "Workload should remain in 
queue"
             assert executor.event_buffer == {}
-            assert f"[Try 2 of 3] Task Timeout Error for Task: ({key})." in 
caplog.text
+            assert f"[Try 2 of 3] Celery Task Timeout Error for Workload: 
({key})." in caplog.text
 
             executor.heartbeat()
-            assert dict(executor.task_publish_retries) == {key: 3}
-            assert len(executor.queued_tasks) == 1, "Task should remain in 
queue"
+            assert dict(executor.workload_publish_retries) == {key: 3}
+            assert len(executor.queued_tasks) == 1, "Workload should remain in 
queue"
             assert executor.event_buffer == {}
-            assert f"[Try 3 of 3] Task Timeout Error for Task: ({key})." in 
caplog.text
+            assert f"[Try 3 of 3] Celery Task Timeout Error for Workload: 
({key})." in caplog.text
 
             executor.heartbeat()
-            assert dict(executor.task_publish_retries) == {}
+            assert dict(executor.workload_publish_retries) == {}
             assert len(executor.queued_tasks) == 0, "Task should no longer be 
in queue"
             assert executor.event_buffer[key][0] == State.FAILED
 
@@ -389,7 +412,7 @@ class TestBulkStateFetcher:
                     ]
                 )
 
-        # Assert called - ignore order
+        # Assert called - ignore order.
         mget_args, _ = mock_mget.call_args
         assert set(mget_args[0]) == {b"celery-task-meta-456", 
b"celery-task-meta-123"}
         mock_mget.assert_called_once_with(mock.ANY)
diff --git 
a/providers/celery/tests/unit/celery/executors/test_celery_executor.py 
b/providers/celery/tests/unit/celery/executors/test_celery_executor.py
index bc668bef508..d69df9ec5e7 100644
--- a/providers/celery/tests/unit/celery/executors/test_celery_executor.py
+++ b/providers/celery/tests/unit/celery/executors/test_celery_executor.py
@@ -25,7 +25,7 @@ import sys
 from datetime import timedelta
 from unittest import mock
 
-# leave this it is used by the test worker
+# Leave this; it is used by the test worker.
 import celery.contrib.testing.tasks  # noqa: F401
 import pytest
 import time_machine
@@ -100,7 +100,7 @@ def _prepare_app(broker_url=None, execute=None):
     test_execute = test_app.task(execute)
     patch_app = mock.patch.object(celery_executor_utils, "app", test_app)
     patch_execute = mock.patch.object(celery_executor_utils, execute_name, 
test_execute)
-    # Patch factory function so CeleryExecutor instances get the test app
+    # Patch factory function so CeleryExecutor instances get the test app.
     patch_factory = mock.patch.object(celery_executor_utils, 
"create_celery_app", return_value=test_app)
 
     backend = test_app.backend
@@ -110,7 +110,7 @@ def _prepare_app(broker_url=None, execute=None):
         # race condition where one of the subprocesses can die with "Table
         # already exists" error, because SQLA checks for which tables exist,
         # then issues a CREATE TABLE, rather than doing CREATE TABLE IF NOT
-        # EXISTS
+        # EXISTS.
         session = backend.ResultSession()
         session.close()
 
@@ -118,7 +118,7 @@ def _prepare_app(broker_url=None, execute=None):
         try:
             yield test_app
         finally:
-            # Clear event loop to tear down each celery instance
+            # Clear event loop to tear down each celery instance.
             set_event_loop(None)
 
 
@@ -148,7 +148,7 @@ class TestCeleryExecutor:
         team_name = "test_team"
 
         if AIRFLOW_V_3_2_PLUS:
-            # Multi-team support with ExecutorConf requires Airflow 3.2+
+            # Multi-team support with ExecutorConf requires Airflow 3.2+.
             executor = celery_executor.CeleryExecutor(parallelism=parallelism, 
team_name=team_name)
         else:
             executor = celery_executor.CeleryExecutor(parallelism)
@@ -156,7 +156,7 @@ class TestCeleryExecutor:
         assert executor.parallelism == parallelism
 
         if AIRFLOW_V_3_2_PLUS:
-            # Multi-team support with ExecutorConf requires Airflow 3.2+
+            # Multi-team support with ExecutorConf requires Airflow 3.2+.
             assert executor.team_name == team_name
             assert executor.conf.team_name == team_name
 
@@ -167,8 +167,8 @@ class TestCeleryExecutor:
         )
         with _prepare_app():
             executor = celery_executor.CeleryExecutor()
-            executor.tasks = {"key": FakeCeleryResult()}
-            
executor.bulk_state_fetcher._get_many_using_multiprocessing(executor.tasks.values())
+            executor.workloads = {"key": FakeCeleryResult()}
+            
executor.bulk_state_fetcher._get_many_using_multiprocessing(executor.workloads.values())
         assert celery_executor_utils.CELERY_FETCH_ERR_MSG_HEADER in 
caplog.text, caplog.record_tuples
         assert FAKE_EXCEPTION_MSG in caplog.text, caplog.record_tuples
 
@@ -274,7 +274,7 @@ class TestCeleryExecutor:
 
         executor = celery_executor.CeleryExecutor()
         assert executor.running == set()
-        assert executor.tasks == {}
+        assert executor.workloads == {}
 
         not_adopted_tis = executor.try_adopt_task_instances(tis)
 
@@ -282,7 +282,7 @@ class TestCeleryExecutor:
         key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, None, 0)
         assert executor.running == {key_1, key_2}
 
-        assert executor.tasks == {key_1: AsyncResult("231"), key_2: 
AsyncResult("232")}
+        assert executor.workloads == {key_1: AsyncResult("231"), key_2: 
AsyncResult("232")}
         assert not_adopted_tis == []
 
     @pytest.fixture
@@ -317,12 +317,12 @@ class TestCeleryExecutor:
             executor = celery_executor.CeleryExecutor()
             executor.job_id = 1
             executor.running = {ti.key}
-            executor.tasks = {ti.key: AsyncResult("231")}
+            executor.workloads = {ti.key: AsyncResult("231")}
             assert executor.has_task(ti)
             with pytest.warns(AirflowProviderDeprecationWarning, 
match="cleanup_stuck_queued_tasks"):
                 executor.cleanup_stuck_queued_tasks(tis=tis)
             executor.sync()
-        assert executor.tasks == {}
+        assert executor.workloads == {}
         app.control.revoke.assert_called_once_with("231")
         mock_fail.assert_called()
         assert not executor.has_task(ti)
@@ -351,13 +351,13 @@ class TestCeleryExecutor:
             executor = celery_executor.CeleryExecutor()
             executor.job_id = 1
             executor.running = {ti.key}
-            executor.tasks = {ti.key: AsyncResult("231")}
+            executor.workloads = {ti.key: AsyncResult("231")}
             assert executor.has_task(ti)
             for ti in tis:
                 executor.revoke_task(ti=ti)
             executor.sync()
         app.control.revoke.assert_called_once_with("231")
-        assert executor.tasks == {}
+        assert executor.workloads == {}
         assert not executor.has_task(ti)
         mock_fail.assert_not_called()
 
@@ -365,18 +365,18 @@ class TestCeleryExecutor:
     def test_result_backend_sqlalchemy_engine_options(self):
         import importlib
 
-        # Scope the mock using context manager so we can clean up afterward
+        # Scope the mock using context manager so we can clean up afterward.
         with mock.patch("celery.Celery") as mock_celery:
-            # reload celery conf to apply the new config
+            # Reload celery conf to apply the new config.
             importlib.reload(default_celery)
-            # reload celery_executor_utils to recreate the celery app with new 
config
+            # Reload celery_executor_utils to recreate the celery app with new config.
             importlib.reload(celery_executor_utils)
 
             call_args = mock_celery.call_args.kwargs.get("config_source")
             assert "database_engine_options" in call_args
             assert call_args["database_engine_options"] == {"pool_recycle": 
1800}
 
-        # Clean up: reload modules with real Celery to restore clean state for 
subsequent tests
+        # Clean up: reload modules with real Celery to restore clean state for 
subsequent tests.
         importlib.reload(default_celery)
         importlib.reload(celery_executor_utils)
 
@@ -385,9 +385,9 @@ def test_operation_timeout_config():
     assert celery_executor_utils.OPERATION_TIMEOUT == 1
 
 
-class MockTask:
+class MockWorkload:
     """
-    A picklable object used to mock tasks sent to Celery. Can't use the mock 
library
+    A picklable object used to mock workloads sent to Celery. Can't use the 
mock library
     here because it's not picklable.
     """
 
@@ -414,7 +414,7 @@ def register_signals():
 
     yield
 
-    # Restore original signal handlers after test
+    # Restore original signal handlers after test.
     signal.signal(signal.SIGINT, orig_sigint)
     signal.signal(signal.SIGTERM, orig_sigterm)
     signal.signal(signal.SIGUSR2, orig_sigusr2)
@@ -422,20 +422,20 @@ def register_signals():
 
 @pytest.mark.execution_timeout(200)
 @pytest.mark.quarantined
-def test_send_tasks_to_celery_hang(register_signals):
+def test_send_workloads_to_celery_hang(register_signals):
     """
     Test that celery_executor does not hang after many runs.
     """
     executor = celery_executor.CeleryExecutor()
 
-    task = MockTask()
-    task_tuples_to_send = [(None, None, None, task) for _ in range(26)]
+    workload = MockWorkload()
+    workload_tuples_to_send = [(None, None, None, workload) for _ in range(26)]
 
     for _ in range(250):
         # This loop can hang on Linux if celery_executor does something wrong 
with
         # multiprocessing.
-        results = executor._send_tasks_to_celery(task_tuples_to_send)
-        assert results == [(None, None, 1) for _ in task_tuples_to_send]
+        results = executor._send_workloads_to_celery(workload_tuples_to_send)
+        assert results == [(None, None, 1) for _ in workload_tuples_to_send]
 
 
 @conf_vars({("celery", "result_backend"): 
"rediss://test_user:test_password@localhost:6379/0"})
@@ -445,7 +445,7 @@ def 
test_celery_executor_with_no_recommended_result_backend(caplog):
     from airflow.providers.celery.executors.default_celery import log
 
     with caplog.at_level(logging.WARNING, logger=log.name):
-        # reload celery conf to apply the new config
+        # Reload celery conf to apply the new config.
         importlib.reload(default_celery)
         assert "test_password" not in caplog.text
         assert (
@@ -458,7 +458,7 @@ def 
test_celery_executor_with_no_recommended_result_backend(caplog):
 def test_sentinel_kwargs_loaded_from_string():
     import importlib
 
-    # reload celery conf to apply the new config
+    # Reload celery conf to apply the new config.
     importlib.reload(default_celery)
     assert 
default_celery.DEFAULT_CELERY_CONFIG["broker_transport_options"]["sentinel_kwargs"]
 == {
         "service_name": "mymaster"
@@ -469,7 +469,7 @@ def test_sentinel_kwargs_loaded_from_string():
 def test_celery_task_acks_late_loaded_from_string():
     import importlib
 
-    # reload celery conf to apply the new config
+    # Reload celery conf to apply the new config.
     importlib.reload(default_celery)
     assert default_celery.DEFAULT_CELERY_CONFIG["task_acks_late"] is False
 
@@ -529,7 +529,7 @@ def 
test_visibility_timeout_not_set_for_unsupported_broker(caplog):
 def test_celery_extra_celery_config_loaded_from_string():
     import importlib
 
-    # reload celery conf to apply the new config
+    # Reload celery conf to apply the new config.
     importlib.reload(default_celery)
     assert default_celery.DEFAULT_CELERY_CONFIG["worker_max_tasks_per_child"] 
== 10
 
@@ -539,7 +539,7 @@ def 
test_result_backend_sentinel_kwargs_loaded_from_string():
     """Test that sentinel_kwargs for result backend transport options is 
correctly parsed."""
     import importlib
 
-    # reload celery conf to apply the new config
+    # Reload celery conf to apply the new config.
     importlib.reload(default_celery)
     assert "result_backend_transport_options" in 
default_celery.DEFAULT_CELERY_CONFIG
     assert 
default_celery.DEFAULT_CELERY_CONFIG["result_backend_transport_options"]["sentinel_kwargs"]
 == {
@@ -552,7 +552,7 @@ def test_result_backend_master_name_loaded():
     """Test that master_name for result backend transport options is correctly 
loaded."""
     import importlib
 
-    # reload celery conf to apply the new config
+    # Reload celery conf to apply the new config.
     importlib.reload(default_celery)
     assert "result_backend_transport_options" in 
default_celery.DEFAULT_CELERY_CONFIG
     assert (
@@ -570,7 +570,7 @@ def 
test_result_backend_transport_options_with_multiple_options():
     """Test that multiple result backend transport options are correctly 
loaded."""
     import importlib
 
-    # reload celery conf to apply the new config
+    # Reload celery conf to apply the new config.
     importlib.reload(default_celery)
     result_backend_opts = 
default_celery.DEFAULT_CELERY_CONFIG["result_backend_transport_options"]
     assert result_backend_opts["sentinel_kwargs"] == {"password": 
"redis_password"}
@@ -614,7 +614,7 @@ def test_result_backend_sentinel_full_config():
     """Test full Redis Sentinel configuration for result backend."""
     import importlib
 
-    # reload celery conf to apply the new config
+    # Reload celery conf to apply the new config.
     importlib.reload(default_celery)
 
     assert default_celery.DEFAULT_CELERY_CONFIG["result_backend"] == (
@@ -643,13 +643,13 @@ class TestMultiTeamCeleryExecutor:
             ("operators", "default_queue"): "global_queue",
         }
     )
-    def test_multi_team_isolation_and_task_routing(self, monkeypatch):
+    def test_multi_team_isolation_and_workload_routing(self, monkeypatch):
         """
-        Test multi-team executor isolation and correct task routing.
+        Test multi-team executor isolation and correct workload routing.
 
         Verifies:
         - Each executor has isolated Celery app and config
-        - Tasks are routed through team-specific apps 
(_process_tasks/_process_workloads)
+        - Workloads are routed through team-specific apps 
(_process_tasks/_process_workloads)
         - Backward compatibility with global executor
         """
         # Set up team-specific config via environment variables
@@ -658,49 +658,49 @@ class TestMultiTeamCeleryExecutor:
         monkeypatch.setenv("AIRFLOW__TEAM_B___CELERY__BROKER_URL", 
"redis://team-b:6379/0")
         monkeypatch.setenv("AIRFLOW__TEAM_B___OPERATORS__DEFAULT_QUEUE", 
"team_b_queue")
 
-        # Reload config to pick up environment variables
+        # Reload config to pick up environment variables.
         from airflow import configuration
 
         configuration.conf.read_dict({}, source="test")
 
-        # Create executors with different team configs
+        # Create executors with different team configs.
         team_a_executor = CeleryExecutor(parallelism=2, team_name="team_a")
         team_b_executor = CeleryExecutor(parallelism=3, team_name="team_b")
         global_executor = CeleryExecutor(parallelism=4)
 
-        # Each executor has its own Celery app (critical for isolation)
+        # Each executor has its own Celery app (critical for isolation).
         assert team_a_executor.celery_app is not team_b_executor.celery_app
         assert team_a_executor.celery_app is not global_executor.celery_app
 
-        # Team-specific broker URLs are used
+        # Team-specific broker URLs are used.
         assert "team-a" in team_a_executor.celery_app.conf.broker_url
         assert "team-b" in team_b_executor.celery_app.conf.broker_url
         assert "global" in global_executor.celery_app.conf.broker_url
 
-        # Team-specific queues are used
+        # Team-specific queues are used.
         assert team_a_executor.celery_app.conf.task_default_queue == 
"team_a_queue"
         assert team_b_executor.celery_app.conf.task_default_queue == 
"team_b_queue"
         assert global_executor.celery_app.conf.task_default_queue == 
"global_queue"
 
-        # Each executor has its own BulkStateFetcher with correct app
+        # Each executor has its own BulkStateFetcher with correct app.
         assert team_a_executor.bulk_state_fetcher.celery_app is 
team_a_executor.celery_app
         assert team_b_executor.bulk_state_fetcher.celery_app is 
team_b_executor.celery_app
 
-        # Executors have isolated internal state
-        assert team_a_executor.tasks is not team_b_executor.tasks
+        # Executors have isolated internal state.
+        assert team_a_executor.workloads is not team_b_executor.workloads
         assert team_a_executor.running is not team_b_executor.running
         assert team_a_executor.queued_tasks is not team_b_executor.queued_tasks
 
     @conf_vars({("celery", "broker_url"): "redis://global:6379/0"})
-    
@mock.patch("airflow.providers.celery.executors.celery_executor.CeleryExecutor._send_tasks")
-    def test_task_routing_through_team_specific_app(self, mock_send_tasks, 
monkeypatch):
+    
@mock.patch("airflow.providers.celery.executors.celery_executor.CeleryExecutor._send_workloads")
+    def test_workload_routing_through_team_specific_app(self, 
mock_send_workloads, monkeypatch):
         """
-        Test that _process_tasks and _process_workloads pass the correct 
team_name for task routing.
+        Test that _process_tasks (v2) and _process_workloads (v3) pass the 
correct team_name for task routing.
 
         With the ProcessPoolExecutor approach, we pass team_name instead of 
task objects to avoid
         pickling issues. The subprocess reconstructs the team-specific Celery 
app from the team_name.
         """
-        # Set up team A config
+        # Set up team A config.
         monkeypatch.setenv("AIRFLOW__TEAM_A___CELERY__BROKER_URL", 
"redis://team-a:6379/0")
 
         team_a_executor = CeleryExecutor(parallelism=2, team_name="team_a")
@@ -709,40 +709,42 @@ class TestMultiTeamCeleryExecutor:
             from airflow.executors.workloads import ExecuteTask
             from airflow.models.taskinstancekey import TaskInstanceKey
 
-            # Create mock workload
+            # Create mock workload.
             mock_ti = mock.Mock()
             mock_ti.key = TaskInstanceKey("dag", "task", "run", 1)
             mock_ti.queue = "test_queue"
             mock_workload = mock.Mock(spec=ExecuteTask)
             mock_workload.ti = mock_ti
 
-            # Process workload through team A executor
+            # Process workload through team A executor.
             team_a_executor._process_workloads([mock_workload])
 
-            # Verify _send_tasks received the correct team_name
-            assert mock_send_tasks.called
-            task_tuples = mock_send_tasks.call_args[0][0]
-            team_name_from_call = task_tuples[0][3]  # 4th element is now 
team_name
+            # Verify _send_workloads received the correct team_name.
+            assert mock_send_workloads.called
+            workload_tuples = mock_send_workloads.call_args[0][0]
+            team_name_from_call = workload_tuples[0][
+                3
+            ]  # 4th element is team_name (used to reconstruct Celery app in 
subprocess).
 
-            # Critical: team_name is passed so subprocess can reconstruct the 
correct app
+            # Critical: team_name is passed so subprocess can reconstruct the 
correct app.
             assert team_name_from_call == "team_a"
         else:
             from airflow.models.taskinstancekey import TaskInstanceKey
 
-            # Test V2 path with execute_command
+            # Test V2 path with execute_command.
             mock_key = TaskInstanceKey("dag", "task", "run", 1)
             mock_command = ["airflow", "tasks", "run", "dag", "task"]
             mock_queue = "test_queue"
 
-            # Process task through team A executor
+            # Process task through team A executor.
             team_a_executor._process_tasks([(mock_key, mock_command, 
mock_queue, None)])
 
-            # Verify _send_tasks received team A's execute_command task
-            assert mock_send_tasks.called
-            task_tuples = mock_send_tasks.call_args[0][0]
-            task_from_call = task_tuples[0][3]  # 4th element is the task (V2 
still uses task object)
+            # Verify _send_workloads received team A's execute_command 
workload (v2 compatibility path).
+            assert mock_send_workloads.called
+            task_tuples = mock_send_workloads.call_args[0][0]
+            task_from_call = task_tuples[0][3]  # 4th element is the task (V2 
still uses task object).
 
-            # Critical: task belongs to team A's app, not module-level app
+            # Critical: Celery task belongs to team A's app, not module-level 
app.
             assert task_from_call.app is team_a_executor.celery_app
             assert task_from_call.name == "execute_command"
 
@@ -763,7 +765,7 @@ def test_celery_tasks_registered_on_import():
         "execute_workload must be registered with the Celery app at import 
time. "
         "Workers need this to receive tasks without KeyError."
     )
-    # TODO: remove this block when min supported Airflow version is >= 3.0
+    # TODO: remove this block when min supported Airflow version is >= 3.0.
     if not AIRFLOW_V_3_0_PLUS:
         assert "execute_command" in registered_tasks, (
             "execute_command must be registered for Airflow 2.x compatibility."

Reply via email to