This is an automated email from the ASF dual-hosted git repository.
o-nikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f2af887699a Add ExecuteCallback support to AWS ECS Executor (#63657)
f2af887699a is described below
commit f2af887699a42ab88a1c3213cd963b84133543b8
Author: Shivam Rastogi <[email protected]>
AuthorDate: Thu May 21 14:05:41 2026 -0400
Add ExecuteCallback support to AWS ECS Executor (#63657)
* Add ExecuteCallback support to AWS ECS Executor
Enables the ECS executor to dispatch ExecuteCallback workloads (deadline
alerts) alongside regular ExecuteTask workloads. Builds on #65392 which
widened BaseExecutor signatures to accept WorkloadKey.
- supports_callbacks = True (gated on AIRFLOW_V_3_3_PLUS)
- Widen key types to WorkloadKey throughout EcsQueuedTask /
EcsTaskCollection
- Branch _process_workloads on ExecuteTask vs ExecuteCallback
- Add AIRFLOW_V_3_3_PLUS to version_compat.py
- Unit tests for queueing, processing, serialization, sync, mixed keys
* Rename task-named methods/attrs and fix older-Airflow compat
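Renames (mirrors the merged Lambda callback PR; executor-internal surface.
Described here as a straight rename with no shim, although the diff below
keeps deprecated aliases for pending_tasks, sync_running_tasks and
attempt_task_runs):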
sync_running_tasks -> sync_running_workloads
attempt_task_runs -> attempt_workload_runs
pending_tasks (attr) -> pending_workloads
__update_running_task -> __update_running_workload
__handle_failed_task -> __handle_failed_workload
Fix CI on older Airflow compat tests:
- Restore the queue_workload() override. The Airflow 3.3+ BaseExecutor routes
ExecuteCallback natively, but the pre-3.3 one raises ValueError for anything
that is not an ExecuteTask. The override works on both.
- Import AIRFLOW_V_3_3_PLUS from tests_common (main bumped to 3.3).
check-airflow-v-imports-in-tests hook disallows provider-internal
version_compat imports from test files.
* Apply suggestions from code review
Co-authored-by: D. Ferruzzi <[email protected]>
* Widen CommandType and pre-declare loop locals to drop type ignores
Mirror the Lambda executor pattern: widen CommandType to
Sequence[str] | Sequence[ExecuteTask | ExecuteCallback] on Airflow 3.3+,
and pre-declare queue/key/command at the top of the _process_workloads
loop so mypy doesn't infer narrow types from the first if-branch.
Removes five # type: ignore comments that previously covered the
signature mismatch and cross-branch reassignment.
* Fix docs autoapi duplicate-object warnings
Two adjustments to silence Sphinx duplicate-object warnings between the
Lambda and ECS executors that block the docs build:
- Add `:sphinx-autoapi-skip:` docstring on ECS `_process_workloads` so
autoapi skips this private method.
- Split the inner union in ECS `CommandType` so its rendered TypeAlias
text differs from Lambda's, avoiding a duplicate. The split form is
also stricter (disallows mixed task+callback lists, which our code
never produces).
* Add trailing period to sphinx-autoapi-skip docstring for D415
* Use TaskInstanceKey and CallbackKey in test_collection_mixed_key_types
Switch from Mock(spec=tuple) / bare-string callback key to
Mock(spec=TaskInstanceKey) and a real CallbackKey(...) instance,
matching the pattern established in #67268 for the rest of the
ECS/Lambda/Batch executor tests.
CallbackKey became a frozen dataclass in #66973 and no longer
accepts bare strings; this test was added in this PR so it was
missed by the #67268 cleanup sweep.
---------
Co-authored-by: D. Ferruzzi <[email protected]>
---
.../amazon/aws/executors/ecs/ecs_executor.py | 176 +++++++++-----
.../providers/amazon/aws/executors/ecs/utils.py | 61 +++--
.../amazon/aws/executors/ecs/test_ecs_executor.py | 261 +++++++++++++++------
3 files changed, 344 insertions(+), 154 deletions(-)
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
index c35ba5c0fa2..644fbb29ed6 100644
---
a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
+++
b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -24,13 +24,15 @@ Each Airflow task gets delegated out to an Amazon ECS Task.
from __future__ import annotations
import time
+import warnings
from collections import defaultdict, deque
from collections.abc import Sequence
from copy import deepcopy
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, TypeAlias
from botocore.exceptions import ClientError, NoCredentialsError
+from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.executors.base_executor import BaseExecutor
from airflow.providers.amazon.aws.executors.ecs.boto_schema import
BotoDescribeTasksSchema, BotoRunTaskSchema
from airflow.providers.amazon.aws.executors.ecs.utils import (
@@ -46,7 +48,7 @@ from
airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry impo
exponential_backoff_retry,
)
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
-from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_3_PLUS
from airflow.providers.common.compat.sdk import AirflowException, Stats,
timezone
from airflow.utils.helpers import merge_dicts
from airflow.utils.state import State
@@ -61,6 +63,13 @@ if TYPE_CHECKING:
ExecutorConfigType,
)
+ if AIRFLOW_V_3_3_PLUS:
+ from airflow.executors.workloads.types import WorkloadKey as
_EcsWorkloadKey
+
+ WorkloadKey: TypeAlias = _EcsWorkloadKey
+ else:
+ WorkloadKey: TypeAlias = TaskInstanceKey # type: ignore[no-redef,
misc]
+
INVALID_CREDENTIALS_EXCEPTIONS = [
"ExpiredTokenException",
"InvalidClientTokenId",
@@ -92,6 +101,9 @@ class AwsEcsExecutor(BaseExecutor):
supports_multi_team: bool = True
+ if AIRFLOW_V_3_3_PLUS:
+ supports_callbacks: bool = True
+
# AWS limits the maximum number of ARNs in the describe_tasks function.
DESCRIBE_TASKS_BATCH_SIZE = 99
@@ -103,7 +115,7 @@ class AwsEcsExecutor(BaseExecutor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.active_workers: EcsTaskCollection = EcsTaskCollection()
- self.pending_tasks: deque = deque()
+ self.pending_workloads: deque = deque()
# Check if self has the ExecutorConf set on the self.conf attribute,
and if not, set it to the global
# configuration object. This allows the changes to be backwards
compatible with older versions of
@@ -131,30 +143,46 @@ class AwsEcsExecutor(BaseExecutor):
fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS],
)
+ # TODO: Remove this once the minimum supported version is 3.3+, and defer
to BaseExecutor.queue_workload.
def queue_workload(self, workload: workloads.All, session: Session | None)
-> None:
from airflow.executors import workloads
- if not isinstance(workload, workloads.ExecuteTask):
- raise RuntimeError(f"{type(self)} cannot handle workloads of type
{type(workload)}")
- ti = workload.ti
- self.queued_tasks[ti.key] = workload
+ if isinstance(workload, workloads.ExecuteTask):
+ self.queued_tasks[workload.ti.key] = workload
+ return
+ if AIRFLOW_V_3_3_PLUS and isinstance(workload,
workloads.ExecuteCallback):
+ self.queued_callbacks[workload.callback.key] = workload
+ return
+ raise RuntimeError(f"{type(self)} cannot handle workloads of type
{type(workload)}")
- def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
- from airflow.executors.workloads import ExecuteTask
+ def _process_workloads(self, workload_items: Sequence[workloads.All]) ->
None:
+ """:sphinx-autoapi-skip:."""
+ from airflow.executors import workloads
- # Airflow V3 version
- for w in workloads:
- if not isinstance(w, ExecuteTask):
- raise RuntimeError(f"{type(self)} cannot handle workloads of
type {type(w)}")
+ for workload in workload_items:
+ queue: str | None
+ key: WorkloadKey
+ command: CommandType
+ if isinstance(workload, workloads.ExecuteTask):
+ command = [workload]
+ key = workload.ti.key
+ queue = workload.ti.queue
+ executor_config = workload.ti.executor_config or {}
- command = [w]
- key = w.ti.key
- queue = w.ti.queue
- executor_config = w.ti.executor_config or {}
+ del self.queued_tasks[key]
+ self.execute_async(key=key, command=command, queue=queue,
executor_config=executor_config)
+ self.running.add(key)
- del self.queued_tasks[key]
- self.execute_async(key=key, command=command, queue=queue,
executor_config=executor_config) # type: ignore[arg-type]
- self.running.add(key)
+ elif AIRFLOW_V_3_3_PLUS and isinstance(workload,
workloads.ExecuteCallback):
+ command = [workload]
+ key = workload.callback.key
+
+ del self.queued_callbacks[key]
+ self.execute_async(key=key, command=command, queue=None)
+ self.running.add(key)
+
+ else:
+ raise RuntimeError(f"{type(self)} cannot handle workloads of
type {type(workload)}")
def start(self):
"""Call this when the Executor is run for the first time by the
scheduler."""
@@ -246,8 +274,8 @@ class AwsEcsExecutor(BaseExecutor):
if not self.IS_BOTO_CONNECTION_HEALTHY:
return
try:
- self.sync_running_tasks()
- self.attempt_task_runs()
+ self.sync_running_workloads()
+ self.attempt_workload_runs()
except (ClientError, NoCredentialsError) as error:
error_code = error.response["Error"]["Code"]
if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
@@ -261,11 +289,11 @@ class AwsEcsExecutor(BaseExecutor):
# up and kill the scheduler process
self.log.exception("Failed to sync %s", self.__class__.__name__)
- def sync_running_tasks(self):
- """Check and update state on all running tasks."""
+ def sync_running_workloads(self):
+ """Check and update state on all running workloads (tasks and
callbacks)."""
all_task_arns = self.active_workers.get_all_arns()
if not all_task_arns:
- self.log.debug("No active Airflow tasks, skipping sync.")
+ self.log.debug("No active Airflow workloads, skipping sync.")
return
describe_tasks_response = self.__describe_tasks(all_task_arns)
@@ -274,13 +302,13 @@ class AwsEcsExecutor(BaseExecutor):
if describe_tasks_response["failures"]:
for failure in describe_tasks_response["failures"]:
- self.__handle_failed_task(failure["arn"], failure["reason"])
+ self.__handle_failed_workload(failure["arn"],
failure["reason"])
updated_tasks = describe_tasks_response["tasks"]
for task in updated_tasks:
- self.__update_running_task(task)
+ self.__update_running_workload(task)
- def __update_running_task(self, task):
+ def __update_running_workload(self, task):
self.active_workers.update_task(task)
# Get state of current task.
task_state = task.get_task_state()
@@ -289,10 +317,10 @@ class AwsEcsExecutor(BaseExecutor):
# Mark finished tasks as either a success/failure.
if task_state == State.FAILED or task_state == State.REMOVED:
self.__log_container_failures(task_arn=task.task_arn)
- self.__handle_failed_task(task.task_arn, task.stopped_reason)
+ self.__handle_failed_workload(task.task_arn, task.stopped_reason)
elif task_state == State.SUCCESS:
self.log.debug(
- "Airflow task %s marked as %s after running on ECS Task (arn)
%s",
+ "Airflow workload %s marked as %s after running on ECS Task
(arn) %s",
task_key,
task_state,
task.task_arn,
@@ -329,13 +357,13 @@ class AwsEcsExecutor(BaseExecutor):
"The ECS task failed due to the following containers
failing:\n%s", "\n".join(reasons)
)
- def __handle_failed_task(self, task_arn: str, reason: str):
+ def __handle_failed_workload(self, task_arn: str, reason: str):
"""
- If an API failure occurs, the task is rescheduled.
+ If an API failure occurs, the workload is rescheduled.
This function will determine whether the task has been attempted the
appropriate number
of times, and determine whether the task should be marked failed or
not. The task will
- be removed active_workers, and marked as FAILED, or set into
pending_tasks depending on
+ be removed from active_workers, and marked as FAILED, or set into
pending_workloads depending on
how many times it has been retried.
"""
task_key = self.active_workers.arn_to_key[task_arn]
@@ -346,14 +374,14 @@ class AwsEcsExecutor(BaseExecutor):
failure_count = self.active_workers.failure_count_by_key(task_key)
if int(failure_count) < int(self.max_run_task_attempts):
self.log.warning(
- "Airflow task %s failed due to %s. Failure %s out of %s
occurred on %s. Rescheduling.",
+ "Airflow workload %s failed due to %s. Failure %s out of %s
occurred on %s. Rescheduling.",
task_key,
reason,
failure_count,
self.max_run_task_attempts,
task_arn,
)
- self.pending_tasks.append(
+ self.pending_workloads.append(
EcsQueuedTask(
task_key,
task_cmd,
@@ -365,16 +393,16 @@ class AwsEcsExecutor(BaseExecutor):
)
else:
self.log.error(
- "Airflow task %s has failed a maximum of %s times. Marking as
failed",
+ "Airflow workload %s has failed a maximum of %s times. Marking
as failed",
task_key,
failure_count,
)
self.fail(task_key)
self.active_workers.pop_by_key(task_key)
- def attempt_task_runs(self):
+ def attempt_workload_runs(self):
"""
- Take tasks from the pending_tasks queue, and attempts to find an
instance to run it on.
+ Take tasks from the pending_workloads queue and attempt to find an
instance to run each one on.
If the launch type is EC2, this will attempt to place tasks on empty
EC2 instances. If
there are no EC2 instances available, no task is placed and this
function will be
@@ -382,10 +410,10 @@ class AwsEcsExecutor(BaseExecutor):
If the launch type is FARGATE, this will run the tasks on new AWS
Fargate instances.
"""
- queue_len = len(self.pending_tasks)
+ queue_len = len(self.pending_workloads)
failure_reasons = defaultdict(int)
for _ in range(queue_len):
- ecs_task = self.pending_tasks.popleft()
+ ecs_task = self.pending_workloads.popleft()
task_key = ecs_task.key
cmd = ecs_task.command
queue = ecs_task.queue
@@ -393,17 +421,17 @@ class AwsEcsExecutor(BaseExecutor):
attempt_number = ecs_task.attempt_number
failure_reasons = []
if timezone.utcnow() < ecs_task.next_attempt_time:
- self.pending_tasks.append(ecs_task)
+ self.pending_workloads.append(ecs_task)
continue
try:
run_task_response = self._run_task(task_key, cmd, queue,
exec_config)
except NoCredentialsError:
- self.pending_tasks.append(ecs_task)
+ self.pending_workloads.append(ecs_task)
raise
except ClientError as e:
error_code = e.response["Error"]["Code"]
if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
- self.pending_tasks.append(ecs_task)
+ self.pending_workloads.append(ecs_task)
raise
failure_reasons.append(str(e))
except Exception as e:
@@ -426,11 +454,11 @@ class AwsEcsExecutor(BaseExecutor):
ecs_task.next_attempt_time = timezone.utcnow() +
calculate_next_attempt_delay(
attempt_number
)
- self.pending_tasks.append(ecs_task)
+ self.pending_workloads.append(ecs_task)
else:
reasons_str = ", ".join(failure_reasons)
self.log.error(
- "ECS task %s has failed a maximum of %s times. Marking
as failed. Reasons: %s",
+ "ECS workload %s has failed a maximum of %s times.
Marking as failed. Reasons: %s",
task_key,
attempt_number,
reasons_str,
@@ -460,7 +488,11 @@ class AwsEcsExecutor(BaseExecutor):
self.running_state(task_key, task.task_arn)
def _run_task(
- self, task_id: TaskInstanceKey, cmd: CommandType, queue: str,
exec_config: ExecutorConfigType
+ self,
+ task_id: WorkloadKey,
+ cmd: CommandType,
+ queue: str | None,
+ exec_config: ExecutorConfigType,
):
"""
Run a queued-up Airflow task.
@@ -475,7 +507,11 @@ class AwsEcsExecutor(BaseExecutor):
return run_task_response
def _run_task_kwargs(
- self, task_id: TaskInstanceKey, cmd: CommandType, queue: str,
exec_config: ExecutorConfigType
+ self,
+ task_id: WorkloadKey,
+ cmd: CommandType,
+ queue: str | None,
+ exec_config: ExecutorConfigType,
) -> dict:
"""
Update the Airflow command by modifying container overrides for
task-specific kwargs.
@@ -494,21 +530,23 @@ class AwsEcsExecutor(BaseExecutor):
return run_task_kwargs
- def execute_async(self, key: TaskInstanceKey, command: CommandType,
queue=None, executor_config=None):
- """Save the task to be executed in the next sync by inserting the
commands into a queue."""
+ def execute_async(self, key: WorkloadKey, command: CommandType,
queue=None, executor_config=None):
+ """Save the workload to be executed in the next sync by inserting the
commands into a queue."""
if executor_config and ("name" in executor_config or "command" in
executor_config):
raise ValueError('Executor Config should never override "name" or
"command"')
if len(command) == 1:
- from airflow.executors.workloads import ExecuteTask
+ from airflow.executors import workloads
- if isinstance(command[0], ExecuteTask):
+ if isinstance(command[0], workloads.ExecuteTask) or (
+ AIRFLOW_V_3_3_PLUS and isinstance(command[0],
workloads.ExecuteCallback)
+ ):
command = self._serialize_workload_to_command(command[0])
else:
raise ValueError(
f"EcsExecutor doesn't know how to handle workload of type:
{type(command[0])}"
)
- self.pending_tasks.append(
+ self.pending_workloads.append(
EcsQueuedTask(key, command, queue, executor_config or {}, 1,
timezone.utcnow())
)
@@ -567,9 +605,9 @@ class AwsEcsExecutor(BaseExecutor):
@staticmethod
def _serialize_workload_to_command(workload) -> CommandType:
"""
- Serialize an ExecuteTask workload into a command for the Task SDK.
+ Serialize a workload into a command for the Task SDK.
- :param workload: ExecuteTask workload to serialize
+ :param workload: ExecuteTask or ExecuteCallback workload to serialize
:return: Command as list of strings for Task SDK execution
"""
return [
@@ -634,3 +672,33 @@ class AwsEcsExecutor(BaseExecutor):
not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
return not_adopted_tis
+
+ # ── Back-compat shims for renamed methods/attrs ────────────────────────
+
+ @property
+ def pending_tasks(self) -> deque:
+ """Use pending_workloads as pending_tasks is deprecated."""
+ warnings.warn(
+ "pending_tasks is deprecated, use pending_workloads instead.",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ return self.pending_workloads
+
+ def sync_running_tasks(self):
+ """Use sync_running_workloads as sync_running_tasks is deprecated."""
+ warnings.warn(
+ "sync_running_tasks is deprecated, use sync_running_workloads
instead.",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ return self.sync_running_workloads()
+
+ def attempt_task_runs(self):
+ """Use attempt_workload_runs as attempt_task_runs is deprecated."""
+ warnings.warn(
+ "attempt_task_runs is deprecated, use attempt_workload_runs
instead.",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ return self.attempt_workload_runs()
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/utils.py
b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/utils.py
index f8d7f580621..fa36377abfd 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/utils.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/utils.py
@@ -27,17 +27,26 @@ import datetime
from collections import defaultdict
from collections.abc import Callable, Sequence
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, TypeAlias
from inflection import camelize
from airflow.providers.amazon.aws.executors.utils.base_config_keys import
BaseConfigKeys
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_3_PLUS
from airflow.utils.state import State
if TYPE_CHECKING:
- from airflow.models.taskinstance import TaskInstanceKey
+ if AIRFLOW_V_3_3_PLUS:
+ from airflow.executors.workloads.types import WorkloadKey
+ else:
+ from airflow.models.taskinstance import TaskInstanceKey as WorkloadKey
# type: ignore[assignment]
-CommandType = Sequence[str]
+if AIRFLOW_V_3_3_PLUS:
+ from airflow.executors.workloads import ExecuteCallback, ExecuteTask
+
+ CommandType: TypeAlias = Sequence[str] | Sequence[ExecuteTask] |
Sequence[ExecuteCallback]
+else:
+ CommandType: TypeAlias = Sequence[str] # type: ignore[no-redef, misc]
ExecutorConfigFunctionType = Callable[[CommandType], dict]
ExecutorConfigType = dict[str, Any]
@@ -57,11 +66,11 @@ CONFIG_DEFAULTS = {
@dataclass
class EcsQueuedTask:
- """Represents an ECS task that is queued. The task will be run in the next
heartbeat."""
+ """Represents a queued ECS workload (task or callback). The workload will
be run in the next heartbeat."""
- key: TaskInstanceKey
+ key: WorkloadKey
command: CommandType
- queue: str
+ queue: str | None
executor_config: ExecutorConfigType
attempt_number: int
next_attempt_time: datetime.datetime
@@ -72,7 +81,7 @@ class EcsTaskInfo:
"""Contains information about a currently running ECS task."""
cmd: CommandType
- queue: str
+ queue: str | None
config: ExecutorConfigType
@@ -156,20 +165,20 @@ class EcsExecutorTask:
class EcsTaskCollection:
- """A five-way dictionary between Airflow task ids, Airflow cmds, ECS ARNs,
and ECS task objects."""
+ """A five-way dictionary between Airflow workload keys, commands, ECS
ARNs, and ECS task objects."""
def __init__(self):
- self.key_to_arn: dict[TaskInstanceKey, str] = {}
- self.arn_to_key: dict[str, TaskInstanceKey] = {}
+ self.key_to_arn: dict[WorkloadKey, str] = {}
+ self.arn_to_key: dict[str, WorkloadKey] = {}
self.tasks: dict[str, EcsExecutorTask] = {}
- self.key_to_failure_counts: dict[TaskInstanceKey, int] =
defaultdict(int)
- self.key_to_task_info: dict[TaskInstanceKey, EcsTaskInfo] = {}
+ self.key_to_failure_counts: dict[WorkloadKey, int] = defaultdict(int)
+ self.key_to_task_info: dict[WorkloadKey, EcsTaskInfo] = {}
def add_task(
self,
task: EcsExecutorTask,
- airflow_task_key: TaskInstanceKey,
- queue: str,
+ airflow_task_key: WorkloadKey,
+ queue: str | None,
airflow_cmd: CommandType,
exec_config: ExecutorConfigType,
attempt_number: int,
@@ -186,8 +195,8 @@ class EcsTaskCollection:
"""Update the state of the given task based on task ARN."""
self.tasks[task.task_arn] = task
- def task_by_key(self, task_key: TaskInstanceKey) -> EcsExecutorTask:
- """Get a task by Airflow Instance Key."""
+ def task_by_key(self, task_key: WorkloadKey) -> EcsExecutorTask:
+ """Get a task by Airflow workload key."""
arn = self.key_to_arn[task_key]
return self.task_by_arn(arn)
@@ -195,8 +204,8 @@ class EcsTaskCollection:
"""Get a task by AWS ARN."""
return self.tasks[arn]
- def pop_by_key(self, task_key: TaskInstanceKey) -> EcsExecutorTask:
- """Delete task from collection based off of Airflow Task Instance
Key."""
+ def pop_by_key(self, task_key: WorkloadKey) -> EcsExecutorTask:
+ """Delete task from collection based off of Airflow workload key."""
arn = self.key_to_arn[task_key]
task = self.tasks[arn]
del self.key_to_arn[task_key]
@@ -211,20 +220,20 @@ class EcsTaskCollection:
"""Get all AWS ARNs in collection."""
return list(self.key_to_arn.values())
- def get_all_task_keys(self) -> list[TaskInstanceKey]:
- """Get all Airflow Task Keys in collection."""
+ def get_all_task_keys(self) -> list[WorkloadKey]:
+ """Get all Airflow workload keys in collection."""
return list(self.key_to_arn.keys())
- def failure_count_by_key(self, task_key: TaskInstanceKey) -> int:
- """Get the number of times a task has failed given an Airflow Task
Key."""
+ def failure_count_by_key(self, task_key: WorkloadKey) -> int:
+ """Get the number of times a workload has failed given an Airflow
workload key."""
return self.key_to_failure_counts[task_key]
- def increment_failure_count(self, task_key: TaskInstanceKey):
- """Increment the failure counter given an Airflow Task Key."""
+ def increment_failure_count(self, task_key: WorkloadKey):
+ """Increment the failure counter given an Airflow workload key."""
self.key_to_failure_counts[task_key] += 1
- def info_by_key(self, task_key: TaskInstanceKey) -> EcsTaskInfo:
- """Get the Airflow Command given an Airflow task key."""
+ def info_by_key(self, task_key: WorkloadKey) -> EcsTaskInfo:
+ """Get the task info given an Airflow workload key."""
return self.key_to_task_info[task_key]
def __getitem__(self, value):
diff --git
a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
index ca2d1e255b2..f350c884981 100644
--- a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
+++ b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
@@ -60,7 +60,7 @@ from airflow.version import version as airflow_version_str
from tests_common import RUNNING_TESTS_AGAINST_AIRFLOW_PACKAGES
from tests_common.test_utils.config import conf_vars
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_3_PLUS
airflow_version = VersionInfo(*map(int, airflow_version_str.split(".")[:3]))
@@ -399,11 +399,11 @@ class TestAwsEcsExecutor:
"failures": [],
}
- assert len(mock_executor.pending_tasks) == 0
+ assert len(mock_executor.pending_workloads) == 0
mock_executor.execute_async(airflow_key, mock_cmd)
- assert len(mock_executor.pending_tasks) == 1
+ assert len(mock_executor.pending_workloads) == 1
- mock_executor.attempt_task_runs()
+ mock_executor.attempt_workload_runs()
mock_executor.ecs.run_task.assert_called_once()
# Task is stored in active worker.
@@ -442,14 +442,14 @@ class TestAwsEcsExecutor:
}
assert mock_executor.queued_tasks[workload.ti.key] == workload
- assert len(mock_executor.pending_tasks) == 0
+ assert len(mock_executor.pending_workloads) == 0
assert len(mock_executor.running) == 0
mock_executor._process_workloads([workload])
assert len(mock_executor.queued_tasks) == 0
assert len(mock_executor.running) == 1
assert workload.ti.key in mock_executor.running
- assert len(mock_executor.pending_tasks) == 1
- assert mock_executor.pending_tasks[0].command == [
+ assert len(mock_executor.pending_workloads) == 1
+ assert mock_executor.pending_workloads[0].command == [
"python",
"-m",
"airflow.sdk.execution_time.execute_workload",
@@ -457,9 +457,9 @@ class TestAwsEcsExecutor:
'{"test_key": "test_value"}',
]
- mock_executor.attempt_task_runs()
+ mock_executor.attempt_workload_runs()
mock_executor.ecs.run_task.assert_called_once()
- assert len(mock_executor.pending_tasks) == 0
+ assert len(mock_executor.pending_workloads) == 0
mock_executor.ecs.run_task.assert_called_once_with(
cluster="some-cluster",
count=1,
@@ -524,13 +524,13 @@ class TestAwsEcsExecutor:
# Fail 2 times
for _ in range(expected_retry_count):
- mock_executor.attempt_task_runs()
+ mock_executor.attempt_workload_runs()
# Task is not stored in active workers.
assert len(mock_executor.active_workers) == 0
# Pass in last attempt
- mock_executor.attempt_task_runs()
- assert len(mock_executor.pending_tasks) == 0
+ mock_executor.attempt_workload_runs()
+ assert len(mock_executor.pending_workloads) == 0
assert ARN1 in mock_executor.active_workers.get_all_arns()
assert mock_backoff.call_count == expected_retry_count
for attempt_number in range(1, expected_retry_count):
@@ -543,7 +543,7 @@ class TestAwsEcsExecutor:
# No matter what, don't schedule until run_task becomes successful.
for _ in range(int(mock_executor.max_run_task_attempts) * 2):
- mock_executor.attempt_task_runs()
+ mock_executor.attempt_workload_runs()
# Task is not stored in active workers.
assert len(mock_executor.active_workers) == 0
@@ -560,12 +560,12 @@ class TestAwsEcsExecutor:
# No matter what, don't schedule until run_task becomes successful.
for _ in range(int(mock_executor.max_run_task_attempts) * 2):
- mock_executor.attempt_task_runs()
+ mock_executor.attempt_workload_runs()
# Task is not stored in active workers.
assert len(mock_executor.active_workers) == 0
@mock.patch.object(ecs_executor, "calculate_next_attempt_delay",
return_value=dt.timedelta(seconds=0))
- def test_attempt_task_runs_attempts_when_tasks_fail(self, _,
mock_executor):
+ def test_attempt_workload_runs_attempts_when_tasks_fail(self, _,
mock_executor):
"""
Test case when all tasks fail to run.
@@ -586,36 +586,36 @@ class TestAwsEcsExecutor:
mock_executor.execute_async(airflow_keys[0], commands[0])
mock_executor.execute_async(airflow_keys[1], commands[1])
- assert len(mock_executor.pending_tasks) == 2
+ assert len(mock_executor.pending_workloads) == 2
assert len(mock_executor.active_workers.get_all_arns()) == 0
mock_executor.ecs.run_task.side_effect = failures
- mock_executor.attempt_task_runs()
+ mock_executor.attempt_workload_runs()
for i in range(2):
RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] =
commands[i]
assert mock_executor.ecs.run_task.call_args_list[i].kwargs ==
RUN_TASK_KWARGS
- assert len(mock_executor.pending_tasks) == 2
+ assert len(mock_executor.pending_workloads) == 2
assert len(mock_executor.active_workers.get_all_arns()) == 0
mock_executor.ecs.run_task.call_args_list.clear()
mock_executor.ecs.run_task.side_effect = failures
- mock_executor.attempt_task_runs()
+ mock_executor.attempt_workload_runs()
for i in range(2):
RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] =
commands[i]
assert mock_executor.ecs.run_task.call_args_list[i].kwargs ==
RUN_TASK_KWARGS
- assert len(mock_executor.pending_tasks) == 2
+ assert len(mock_executor.pending_workloads) == 2
assert len(mock_executor.active_workers.get_all_arns()) == 0
mock_executor.ecs.run_task.call_args_list.clear()
mock_executor.ecs.run_task.side_effect = failures
- mock_executor.attempt_task_runs()
+ mock_executor.attempt_workload_runs()
assert len(mock_executor.active_workers.get_all_arns()) == 0
- assert len(mock_executor.pending_tasks) == 0
+ assert len(mock_executor.pending_workloads) == 0
if airflow_version >= (2, 10, 0):
events = [(x.event, x.task_id, x.try_number) for x in
mock_executor._task_event_logs]
@@ -625,7 +625,7 @@ class TestAwsEcsExecutor:
]
@mock.patch.object(ecs_executor, "calculate_next_attempt_delay",
return_value=dt.timedelta(seconds=0))
- def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _,
mock_executor):
+ def test_attempt_workload_runs_attempts_when_some_tasks_fal(self, _,
mock_executor):
"""
Test case when one task fail to run, and a new task gets queued.
@@ -654,16 +654,16 @@ class TestAwsEcsExecutor:
mock_executor.execute_async(airflow_keys[0], airflow_commands[0])
mock_executor.execute_async(airflow_keys[1], airflow_commands[1])
- assert len(mock_executor.pending_tasks) == 2
+ assert len(mock_executor.pending_workloads) == 2
mock_executor.ecs.run_task.side_effect = responses
- mock_executor.attempt_task_runs()
+ mock_executor.attempt_workload_runs()
for i in range(2):
RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] =
airflow_commands[i]
assert mock_executor.ecs.run_task.call_args_list[i].kwargs ==
RUN_TASK_KWARGS
- assert len(mock_executor.pending_tasks) == 1
+ assert len(mock_executor.pending_workloads) == 1
assert len(mock_executor.active_workers.get_all_arns()) == 1
mock_executor.ecs.run_task.call_args_list.clear()
@@ -673,29 +673,29 @@ class TestAwsEcsExecutor:
airflow_commands[1] = _generate_mock_cmd()
mock_executor.execute_async(airflow_keys[1], airflow_commands[1])
- assert len(mock_executor.pending_tasks) == 2
+ assert len(mock_executor.pending_workloads) == 2
# assert that the order of pending tasks is preserved i.e. the first
task is 1st etc.
- assert mock_executor.pending_tasks[0].key == airflow_keys[0]
- assert mock_executor.pending_tasks[0].command == airflow_commands[0]
+ assert mock_executor.pending_workloads[0].key == airflow_keys[0]
+ assert mock_executor.pending_workloads[0].command ==
airflow_commands[0]
task["taskArn"] = ARN2
success_response = {"tasks": [task], "failures": []}
responses = [Exception("Failure 1"), success_response]
mock_executor.ecs.run_task.side_effect = responses
- mock_executor.attempt_task_runs()
+ mock_executor.attempt_workload_runs()
for i in range(2):
RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] =
airflow_commands[i]
assert mock_executor.ecs.run_task.call_args_list[i].kwargs ==
RUN_TASK_KWARGS
- assert len(mock_executor.pending_tasks) == 1
+ assert len(mock_executor.pending_workloads) == 1
assert len(mock_executor.active_workers.get_all_arns()) == 2
mock_executor.ecs.run_task.call_args_list.clear()
responses = [Exception("Failure 1")]
mock_executor.ecs.run_task.side_effect = responses
- mock_executor.attempt_task_runs()
+ mock_executor.attempt_workload_runs()
RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] =
airflow_commands[0]
assert mock_executor.ecs.run_task.call_args_list[0].kwargs ==
RUN_TASK_KWARGS
@@ -718,7 +718,7 @@ class TestAwsEcsExecutor:
mock_executor.execute_async(airflow_keys[0], airflow_commands[0])
mock_executor.execute_async(airflow_keys[1], airflow_commands[1])
- assert len(mock_executor.pending_tasks) == 2
+ assert len(mock_executor.pending_workloads) == 2
caplog.set_level("WARNING")
describe_tasks = [
@@ -770,19 +770,19 @@ class TestAwsEcsExecutor:
]
mock_executor.ecs.describe_tasks.side_effect = [{"tasks":
describe_tasks, "failures": []}]
- mock_executor.attempt_task_runs()
+ mock_executor.attempt_workload_runs()
for i in range(2):
RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] =
airflow_commands[i]
assert mock_executor.ecs.run_task.call_args_list[i].kwargs ==
RUN_TASK_KWARGS
- assert len(mock_executor.pending_tasks) == 0
+ assert len(mock_executor.pending_workloads) == 0
assert len(mock_executor.active_workers.get_all_arns()) == 2
- mock_executor.sync_running_tasks()
+ mock_executor.sync_running_workloads()
for i in range(2):
assert (
- f"Airflow task {airflow_keys[i]} failed due to
{describe_tasks[i]['stoppedReason']}. Failure 1 out of 2"
+ f"Airflow workload {airflow_keys[i]} failed due to
{describe_tasks[i]['stoppedReason']}. Failure 1 out of 2"
in caplog.messages[i]
)
@@ -795,18 +795,18 @@ class TestAwsEcsExecutor:
]
mock_executor.ecs.describe_tasks.side_effect = [{"tasks":
describe_tasks, "failures": []}]
- mock_executor.attempt_task_runs()
+ mock_executor.attempt_workload_runs()
- mock_executor.attempt_task_runs()
+ mock_executor.attempt_workload_runs()
for i in range(2):
RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] =
airflow_commands[i]
assert mock_executor.ecs.run_task.call_args_list[i].kwargs ==
RUN_TASK_KWARGS
- mock_executor.sync_running_tasks()
+ mock_executor.sync_running_workloads()
for i in range(2):
assert (
- f"Airflow task {airflow_keys[i]} has failed a maximum of 2
times. Marking as failed"
+ f"Airflow workload {airflow_keys[i]} has failed a maximum of 2
times. Marking as failed"
in caplog.messages[i]
)
@@ -816,7 +816,7 @@ class TestAwsEcsExecutor:
"""Test sync from end-to-end."""
self._mock_sync(mock_executor)
- mock_executor.sync_running_tasks()
+ mock_executor.sync_running_workloads()
mock_executor.ecs.describe_tasks.assert_called_once()
# Task is not stored in active workers.
@@ -831,7 +831,7 @@ class TestAwsEcsExecutor:
def test_sync_short_circuits_with_no_arns(self, _, success_mock,
fail_mock, mock_executor):
self._mock_sync(mock_executor)
- mock_executor.sync_running_tasks()
+ mock_executor.sync_running_workloads()
mock_executor.ecs.describe_tasks.assert_not_called()
fail_mock.assert_not_called()
@@ -860,7 +860,7 @@ class TestAwsEcsExecutor:
mock_executor.max_run_task_attempts = "1"
self._mock_sync(mock_executor, expected_state=State.REMOVED,
set_task_state=State.REMOVED)
- mock_executor.sync_running_tasks()
+ mock_executor.sync_running_workloads()
# Task is not stored in active workers.
assert len(mock_executor.active_workers) == 0
@@ -886,11 +886,11 @@ class TestAwsEcsExecutor:
task_key = mock_airflow_key()
mock_executor.execute_async(task_key, mock_cmd)
for _ in range(2):
- assert len(mock_executor.pending_tasks) == 1
- keys = [task.key for task in mock_executor.pending_tasks]
+ assert len(mock_executor.pending_workloads) == 1
+ keys = [task.key for task in mock_executor.pending_workloads]
assert task_key in keys
- mock_executor.attempt_task_runs()
- assert len(mock_executor.pending_tasks) == 1
+ mock_executor.attempt_workload_runs()
+ assert len(mock_executor.pending_workloads) == 1
mock_executor.ecs.run_task.return_value = {
"tasks": [
@@ -903,8 +903,8 @@ class TestAwsEcsExecutor:
],
"failures": [],
}
- mock_executor.attempt_task_runs()
- assert len(mock_executor.pending_tasks) == 0
+ mock_executor.attempt_workload_runs()
+ assert len(mock_executor.pending_workloads) == 0
assert ARN1 in mock_executor.active_workers.get_all_arns()
mock_executor.ecs.describe_tasks.return_value = {
@@ -914,19 +914,19 @@ class TestAwsEcsExecutor:
],
}
- # Call sync_running_tasks and attempt_task_runs 2 times with failures.
+ # Call sync_running_workloads and attempt_workload_runs 2 times with
failures.
for _ in range(2):
- mock_executor.sync_running_tasks()
+ mock_executor.sync_running_workloads()
# Ensure task gets removed from active_workers.
assert ARN1 not in mock_executor.active_workers.get_all_arns()
- # Ensure task gets back on the pending_tasks queue
- assert len(mock_executor.pending_tasks) == 1
- keys = [task.key for task in mock_executor.pending_tasks]
+ # Ensure task gets back on the pending_workloads queue
+ assert len(mock_executor.pending_workloads) == 1
+ keys = [task.key for task in mock_executor.pending_workloads]
assert task_key in keys
- mock_executor.attempt_task_runs()
- assert len(mock_executor.pending_tasks) == 0
+ mock_executor.attempt_workload_runs()
+ assert len(mock_executor.pending_workloads) == 0
assert ARN1 in mock_executor.active_workers.get_all_arns()
# Task is neither failed nor succeeded.
@@ -940,7 +940,7 @@ class TestAwsEcsExecutor:
# 2 run_task failures + 2 describe_task failures = 4 failures
# Last call should fail the task.
- mock_executor.sync_running_tasks()
+ mock_executor.sync_running_workloads()
assert ARN1 not in mock_executor.active_workers.get_all_arns()
fail_mock.assert_called()
success_mock.assert_not_called()
@@ -960,7 +960,7 @@ class TestAwsEcsExecutor:
"""Test what happens when ECS sync fails for certain tasks
repeatedly."""
airflow_key = TaskInstanceKey("dag", "task", "run", 1, -1)
mock_executor.execute_async(airflow_key, mock_cmd)
- assert len(mock_executor.pending_tasks) == 1
+ assert len(mock_executor.pending_workloads) == 1
run_task_ret_val = {
"taskArn": ARN1,
@@ -982,35 +982,35 @@ class TestAwsEcsExecutor:
],
}
mock_executor.ecs.describe_tasks.return_value =
describe_tasks_ret_value
- mock_executor.attempt_task_runs()
- assert len(mock_executor.pending_tasks) == 0
+ mock_executor.attempt_workload_runs()
+ assert len(mock_executor.pending_workloads) == 0
assert len(mock_executor.active_workers.get_all_arns()) == 1
task_key = mock_executor.active_workers.arn_to_key[ARN1]
# Call Sync 2 times with failures. The task can only fail
max_run_task_attempts times.
for check_count in range(1, int(mock_executor.max_run_task_attempts)):
- mock_executor.sync_running_tasks()
+ mock_executor.sync_running_workloads()
assert mock_executor.ecs.describe_tasks.call_count == check_count
# Ensure task gets removed from active_workers.
assert ARN1 not in mock_executor.active_workers.get_all_arns()
- # Ensure task gets back on the pending_tasks queue
- assert len(mock_executor.pending_tasks) == 1
- keys = [task.key for task in mock_executor.pending_tasks]
+ # Ensure task gets back on the pending_workloads queue
+ assert len(mock_executor.pending_workloads) == 1
+ keys = [task.key for task in mock_executor.pending_workloads]
assert task_key in keys
# Task is neither failed nor succeeded.
fail_mock.assert_not_called()
success_mock.assert_not_called()
- mock_executor.attempt_task_runs()
+ mock_executor.attempt_workload_runs()
- assert len(mock_executor.pending_tasks) == 0
+ assert len(mock_executor.pending_workloads) == 0
assert len(mock_executor.active_workers.get_all_arns()) == 1
assert ARN1 in mock_executor.active_workers.get_all_arns()
task_key = mock_executor.active_workers.arn_to_key[ARN1]
# Last call should fail the task.
- mock_executor.sync_running_tasks()
+ mock_executor.sync_running_workloads()
assert ARN1 not in mock_executor.active_workers.get_all_arns()
fail_mock.assert_called()
success_mock.assert_not_called()
@@ -1104,7 +1104,7 @@ class TestAwsEcsExecutor:
with pytest.raises(ValueError, match='Executor Config should never
override "name" or "command"'):
mock_executor.execute_async(mock_airflow_key, mock_cmd,
executor_config=bad_config)
- assert len(mock_executor.pending_tasks) == 0
+ assert len(mock_executor.pending_workloads) == 0
@mock.patch.object(ecs_executor_config, "build_task_kwargs")
def test_container_not_found(self, mock_build_task_kwargs, mock_executor):
@@ -1118,7 +1118,7 @@ class TestAwsEcsExecutor:
'"overrides[containerOverrides][containers][x][command]"'
)
)
- assert len(mock_executor.pending_tasks) == 0
+ assert len(mock_executor.pending_workloads) == 0
def _mock_sync(
self,
@@ -1127,7 +1127,7 @@ class TestAwsEcsExecutor:
set_task_state=TaskInstanceState.RUNNING,
) -> None:
"""Mock ECS to the expected state."""
- executor.pending_tasks.clear()
+ executor.pending_workloads.clear()
self._add_mock_task(executor, ARN1, set_task_state)
response_task_json = {
@@ -1194,7 +1194,7 @@ class TestAwsEcsExecutor:
],
}
mock_executor.ecs.describe_tasks.return_value = {"tasks":
[test_response_task_json], "failures": []}
- mock_executor.sync_running_tasks()
+ mock_executor.sync_running_workloads()
if expected_status != State.REMOVED:
assert mock_executor.active_workers.tasks["arn1"].get_task_state()
== expected_status
# The task is not removed from active_workers in these states
@@ -1223,7 +1223,7 @@ class TestAwsEcsExecutor:
)
mock_success_function = patcher.start()
mock_executor.ecs.describe_tasks.return_value = {"tasks":
[test_response_task_json], "failures": []}
- mock_executor.sync_running_tasks()
+ mock_executor.sync_running_workloads()
assert len(mock_executor.active_workers) == 0
mock_success_function.assert_called_once()
@@ -1252,7 +1252,7 @@ class TestAwsEcsExecutor:
)
mock_failed_function = patcher.start()
mock_executor.ecs.describe_tasks.return_value = {"tasks":
[test_response_task_json], "failures": []}
- mock_executor.sync_running_tasks()
+ mock_executor.sync_running_workloads()
assert len(mock_executor.active_workers) == 0
mock_failed_function.assert_called_once()
assert (
@@ -1986,3 +1986,116 @@ class TestEcsExecutorConfig:
from airflow.providers.amazon.aws.executors.ecs import AwsEcsExecutor
as AwsEcsExecutorShortPath
assert AwsEcsExecutor is AwsEcsExecutorShortPath
+
+
+class TestEcsExecutorCallbackSupport:
+ """Tests for ExecuteCallback support in the ECS Executor."""
+
+ @pytest.fixture
+ def callback_workload(self):
+ """Create a mock ExecuteCallback workload for testing."""
+ from airflow.executors.workloads import ExecuteCallback
+ from airflow.executors.workloads.base import BundleInfo
+ from airflow.executors.workloads.callback import CallbackDTO,
CallbackFetchMethod
+
+ callback_data = CallbackDTO(
+ id="12345678-1234-5678-1234-567812345678",
+ fetch_method=CallbackFetchMethod.IMPORT_PATH,
+ data={"path": "test.module.alert_func", "kwargs": {}},
+ )
+ return ExecuteCallback(
+ callback=callback_data,
+ dag_rel_path="test.py",
+ bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+ token="test_token",
+ log_path="test.log",
+ )
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow
3.3+")
+ def test_supports_callbacks_attribute(self, mock_executor):
+ """Verify that the ECS executor declares callback support."""
+ assert mock_executor.supports_callbacks is True
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow
3.3+")
+ def test_queue_callback_workload(self, mock_executor, callback_workload):
+ """Test that queue_workload correctly stores ExecuteCallback in
queued_callbacks."""
+ mock_executor.queue_workload(callback_workload, session=None)
+
+ assert len(mock_executor.queued_callbacks) == 1
+ assert callback_workload.callback.key in mock_executor.queued_callbacks
+ assert mock_executor.queued_callbacks[callback_workload.callback.key]
is callback_workload
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow
3.3+")
+ def test_process_callback_workload(self, mock_executor, callback_workload):
+ """Test that _process_workloads handles ExecuteCallback correctly."""
+ callback_key = callback_workload.callback.key
+ mock_executor.queued_callbacks[callback_key] = callback_workload
+
+ mock_executor._process_workloads([callback_workload])
+
+ # Callback should be removed from queued_callbacks
+ assert callback_key not in mock_executor.queued_callbacks
+ # Callback should be added to running set
+ assert callback_key in mock_executor.running
+ # Callback should be added to pending_workloads for execution
+ assert len(mock_executor.pending_workloads) == 1
+ queued = mock_executor.pending_workloads[0]
+ assert queued.key == callback_key
+ assert queued.queue is None
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow
3.3+")
+ def test_execute_async_callback_workload(self, mock_executor,
callback_workload):
+ """Test that execute_async serializes ExecuteCallback workloads
correctly."""
+ callback_key = callback_workload.callback.key
+ mock_executor.execute_async(key=callback_key,
command=[callback_workload], queue=None)
+
+ assert len(mock_executor.pending_workloads) == 1
+ queued = mock_executor.pending_workloads[0]
+ assert queued.key == callback_key
+ # Command should be serialized to the execute_workload entrypoint
+ assert queued.command[0] == "python"
+ assert queued.command[2] ==
"airflow.sdk.execution_time.execute_workload"
+ assert queued.command[3] == "--json-string"
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow
3.3+")
+ def test_callback_sync_running_success(self, mock_executor,
callback_workload):
+ """Test that sync_running_workloads correctly handles successful
callback ECS tasks."""
+ callback_key = callback_workload.callback.key
+ ecs_task = mock_task(ARN1, State.SUCCESS)
+ mock_cmd = _generate_mock_cmd()
+ mock_executor.active_workers.add_task(ecs_task, callback_key, None,
mock_cmd, {}, 1)
+
+ mock_executor.ecs.describe_tasks.return_value = {
+ "tasks": [
+ {
+ "taskArn": ARN1,
+ "lastStatus": "STOPPED",
+ "desiredStatus": "STOPPED",
+ "containers": [{"name": "container-name", "exitCode": 0,
"lastStatus": "STOPPED"}],
+ "startedAt": "2024-01-01T00:00:00Z",
+ }
+ ],
+ "failures": [],
+ }
+
+ mock_executor.sync_running_workloads()
+
+ # Callback should be removed from active workers after success
+ assert len(mock_executor.active_workers) == 0
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow
3.3+")
+ def test_collection_mixed_key_types(self):
+ """Test that EcsTaskCollection works with both TaskInstanceKey and
CallbackKey workload keys."""
+ from airflow.models.callback import CallbackKey
+
+ collection = EcsTaskCollection()
+ mock_cmd = _generate_mock_cmd()
+ task_key = mock.Mock(spec=TaskInstanceKey)
+ callback_key = CallbackKey("12345678-1234-5678-1234-567812345678")
+
+ collection.add_task(mock_task(ARN1), task_key, "default", mock_cmd,
{}, 1)
+ collection.add_task(mock_task(ARN2), callback_key, None, mock_cmd, {},
1)
+
+ assert len(collection) == 2
+ assert collection.key_to_arn[task_key] == ARN1
+ assert collection.key_to_arn[callback_key] == ARN2