This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 36b9c0f23c5 Fix ECS Executor compatibility with Airflow 3.x in `try_adopt_task_instances` (#62192)
36b9c0f23c5 is described below
commit 36b9c0f23c5ae52a40c8a32b783858dc4ff151db
Author: Vincent <[email protected]>
AuthorDate: Sun Feb 22 14:21:58 2026 -0500
Fix ECS Executor compatibility with Airflow 3.x in `try_adopt_task_instances` (#62192)
* Fix ECS Executor compatibility with Airflow 3.x in try_adopt_task_instances
The try_adopt_task_instances method was calling ti.command_as_list(), which
doesn't exist in Airflow 3.x due to Task SDK changes. This caused the scheduler
to crash with an AttributeError when trying to adopt orphaned ECS tasks. The fix
adds a _build_task_command helper that builds an ExecuteTask workload on
Airflow 3.x and falls back to ti.command_as_list() on Airflow 2.x.
* Fix ECS Executor compatibility with Airflow 3.x in try_adopt_task_instances
---------
Co-authored-by: iamapez <[email protected]>
---
.../amazon/aws/executors/ecs/ecs_executor.py | 47 ++++++++---
.../amazon/aws/executors/ecs/test_ecs_executor.py | 90 +++++++++++++++++++++-
2 files changed, 125 insertions(+), 12 deletions(-)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
index 3605c803cd0..ede38853165 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -502,15 +502,7 @@ class AwsEcsExecutor(BaseExecutor):
from airflow.executors.workloads import ExecuteTask
if isinstance(command[0], ExecuteTask):
- workload = command[0]
- ser_input = workload.model_dump_json()
- command = [
- "python",
- "-m",
- "airflow.sdk.execution_time.execute_workload",
- "--json-string",
- ser_input,
- ]
+ command = self._serialize_workload_to_command(command[0])
else:
raise ValueError(
f"EcsExecutor doesn't know how to handle workload of type:
{type(command[0])}"
@@ -572,6 +564,39 @@ class AwsEcsExecutor(BaseExecutor):
)
raise KeyError(f"No such container found by container name:
{self.container_name}")
+ @staticmethod
+ def _serialize_workload_to_command(workload) -> CommandType:
+ """
+ Serialize an ExecuteTask workload into a command for the Task SDK.
+
+ :param workload: ExecuteTask workload to serialize
+ :return: Command as list of strings for Task SDK execution
+ """
+ return [
+ "python",
+ "-m",
+ "airflow.sdk.execution_time.execute_workload",
+ "--json-string",
+ workload.model_dump_json(),
+ ]
+
+ def _build_task_command(self, ti: TaskInstance) -> CommandType:
+ """
+ Build task command for execution based on Airflow version.
+
+ For Airflow 3.x+, generates an ExecuteTask workload with JSON
serialization.
+ For Airflow 2.x, uses the legacy command_as_list() method.
+
+ :param ti: TaskInstance to build command for
+ :return: Command as list of strings
+ """
+ if AIRFLOW_V_3_0_PLUS:
+ from airflow.executors.workloads import ExecuteTask
+
+ workload = ExecuteTask.make(ti)
+ return self._serialize_workload_to_command(workload)
+ return ti.command_as_list()
+
def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) ->
Sequence[TaskInstance]:
"""
Adopt task instances which have an external_executor_id (the ECS task
ARN).
@@ -586,11 +611,13 @@ class AwsEcsExecutor(BaseExecutor):
for task in task_descriptions:
ti = next(ti for ti in tis if ti.external_executor_id ==
task.task_arn)
+ command = self._build_task_command(ti)
+
self.active_workers.add_task(
task,
ti.key,
ti.queue,
- ti.command_as_list(),
+ command,
ti.executor_config,
ti.try_number,
)
diff --git a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
index 04e7d2bb822..f57cd13d28a 100644
--- a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
+++ b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
@@ -26,6 +26,7 @@ from collections.abc import Callable
from functools import partial
from unittest import mock
from unittest.mock import MagicMock, patch
+from uuid import uuid4
import pytest
import yaml
@@ -1249,7 +1250,6 @@ class TestAwsEcsExecutor:
"test failure" in caplog.messages[0]
)
- @pytest.mark.skip(reason="Adopting task instances hasn't been ported over
to Airflow 3 yet")
def test_try_adopt_task_instances(self, mock_executor):
"""Test that executor can adopt orphaned task instances from a
SchedulerJob shutdown event."""
mock_executor.ecs.describe_tasks.return_value = {
@@ -1278,8 +1278,41 @@ class TestAwsEcsExecutor:
orphaned_tasks[0].external_executor_id = "001" # Matches a running
task_arn
orphaned_tasks[1].external_executor_id = "002" # Matches a running
task_arn
orphaned_tasks[2].external_executor_id = None # One orphaned task has
no external_executor_id
- for task in orphaned_tasks:
+
+ for idx, task in enumerate(orphaned_tasks):
task.try_number = 1
+ task.key = mock.Mock(spec=TaskInstanceKey)
+ task.queue = "default"
+ task.executor_config = {}
+ task.id = uuid4()
+ task.dag_version_id = uuid4()
+ task.task_id = f"task_{idx}"
+ task.dag_id = "test_dag"
+ task.run_id = "test_run"
+ task.map_index = -1
+ task.pool_slots = 1
+ task.priority_weight = 1
+ task.context_carrier = {}
+ task.queued_dttm = dt.datetime.now()
+ # Set up nested attributes for BundleInfo
+ task.dag_model = mock.Mock()
+ task.dag_model.bundle_name = "test_bundle"
+ task.dag_model.relative_fileloc = "test_dag.py"
+ task.dag_run = mock.Mock()
+ task.dag_run.bundle_version = "1.0.0"
+ task.dag_run.context_carrier = {}
+
+ # Mock command generation based on Airflow version
+ if not AIRFLOW_V_3_0_PLUS:
+ # For Airflow 2.x, command_as_list will be called
+ task.command_as_list.return_value = [
+ "airflow",
+ "tasks",
+ "run",
+ "dag",
+ f"task_{idx}",
+ "2024-01-01",
+ ]
not_adopted_tasks =
mock_executor.try_adopt_task_instances(orphaned_tasks)
@@ -1886,6 +1919,59 @@ class TestEcsExecutorConfig:
assert final_run_task_kwargs == expected_result
+ @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow
3+")
+ def test_serialize_workload_to_command(self, mock_executor):
+ """Test that _serialize_workload_to_command properly serializes an
ExecuteTask workload."""
+ from airflow.executors.workloads import ExecuteTask
+
+ workload = mock.Mock(spec=ExecuteTask)
+ ser_workload = json.dumps({"test_key": "test_value"})
+ workload.model_dump_json.return_value = ser_workload
+
+ command = mock_executor._serialize_workload_to_command(workload)
+
+ assert command == [
+ "python",
+ "-m",
+ "airflow.sdk.execution_time.execute_workload",
+ "--json-string",
+ ser_workload,
+ ]
+ workload.model_dump_json.assert_called_once()
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow
3+")
+ @mock.patch("airflow.executors.workloads.ExecuteTask")
+ def test_build_task_command_airflow3(self, mock_execute_task_class,
mock_executor):
+ """Test _build_task_command for Airflow 3.x+ using Task SDK."""
+ mock_ti = mock.Mock(spec=TaskInstance)
+ mock_workload = mock.Mock()
+ ser_workload = json.dumps({"task": "data"})
+ mock_workload.model_dump_json.return_value = ser_workload
+ mock_execute_task_class.make.return_value = mock_workload
+
+ command = mock_executor._build_task_command(mock_ti)
+
+ mock_execute_task_class.make.assert_called_once_with(mock_ti)
+ assert command == [
+ "python",
+ "-m",
+ "airflow.sdk.execution_time.execute_workload",
+ "--json-string",
+ ser_workload,
+ ]
+
+ @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 2.x")
+ def test_build_task_command_airflow2(self, mock_executor):
+ """Test _build_task_command for Airflow 2.x using command_as_list."""
+ mock_ti = mock.Mock(spec=TaskInstance)
+ expected_command = ["airflow", "tasks", "run", "dag_id", "task_id",
"execution_date"]
+ mock_ti.command_as_list.return_value = expected_command
+
+ command = mock_executor._build_task_command(mock_ti)
+
+ mock_ti.command_as_list.assert_called_once()
+ assert command == expected_command
+
def test_short_import_path(self):
from airflow.providers.amazon.aws.executors.ecs import AwsEcsExecutor
as AwsEcsExecutorShortPath