This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7ed791dab7 Do not return success from AWS ECS trigger after max_attempts (#32589)
7ed791dab7 is described below
commit 7ed791dab72709fbc5c9c27687a8b014c3e9906d
Author: Raphaël Vandon <[email protected]>
AuthorDate: Tue Jul 18 14:45:00 2023 -0700
Do not return success from AWS ECS trigger after max_attempts (#32589)
---
airflow/providers/amazon/aws/triggers/ecs.py | 7 ++++---
tests/providers/amazon/aws/triggers/test_ecs.py | 21 +++++++++++++++++++++
2 files changed, 25 insertions(+), 3 deletions(-)
diff --git a/airflow/providers/amazon/aws/triggers/ecs.py b/airflow/providers/amazon/aws/triggers/ecs.py
index 29ad22e13e..af6f72e771 100644
--- a/airflow/providers/amazon/aws/triggers/ecs.py
+++ b/airflow/providers/amazon/aws/triggers/ecs.py
@@ -22,6 +22,7 @@ from typing import Any, AsyncIterator
from botocore.exceptions import ClientError, WaiterError
+from airflow import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
@@ -170,7 +171,8 @@ class TaskDoneTrigger(BaseTrigger):
await waiter.wait(
cluster=self.cluster, tasks=[self.task_arn],
WaiterConfig={"MaxAttempts": 1}
)
- break # we reach this point only if the waiter met a
success criteria
+ # we reach this point only if the waiter met a success
criteria
+ yield TriggerEvent({"status": "success", "task_arn":
self.task_arn})
except WaiterError as error:
if "terminal failure" in str(error):
raise
@@ -179,8 +181,7 @@ class TaskDoneTrigger(BaseTrigger):
finally:
if self.log_group and self.log_stream:
logs_token = await self._forward_logs(logs_client,
logs_token)
-
- yield TriggerEvent({"status": "success", "task_arn": self.task_arn})
+ raise AirflowException("Waiter error: max attempts reached")
async def _forward_logs(self, logs_client, next_token: str | None = None)
-> str | None:
"""
diff --git a/tests/providers/amazon/aws/triggers/test_ecs.py b/tests/providers/amazon/aws/triggers/test_ecs.py
index 551ab39a44..e897bce740 100644
--- a/tests/providers/amazon/aws/triggers/test_ecs.py
+++ b/tests/providers/amazon/aws/triggers/test_ecs.py
@@ -22,6 +22,7 @@ from unittest.mock import AsyncMock
import pytest
from botocore.exceptions import WaiterError
+from airflow import AirflowException
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.triggers.ecs import (
@@ -56,6 +57,26 @@ class TestTaskDoneTrigger:
assert wait_mock.call_count == 3
+ @pytest.mark.asyncio
+ @mock.patch.object(EcsHook, "async_conn")
+ # this mock is only necessary to avoid a "No module named 'aiobotocore'"
error in the LatestBoto CI step
+ @mock.patch.object(AwsLogsHook, "async_conn")
+ async def test_run_until_timeout(self, _, client_mock):
+ a_mock = mock.MagicMock()
+ client_mock.__aenter__.return_value = a_mock
+ wait_mock = AsyncMock()
+ wait_mock.side_effect = WaiterError("name", "reason", {"tasks":
[{"lastStatus": "my_status"}]})
+ a_mock.get_waiter().wait = wait_mock
+
+ trigger = TaskDoneTrigger("cluster", "task_arn", 0, 10, None, None)
+
+ with pytest.raises(AirflowException) as err:
+ generator = trigger.run()
+ await generator.asend(None)
+
+ assert wait_mock.call_count == 10
+ assert "max attempts" in str(err.value)
+
@pytest.mark.asyncio
@mock.patch.object(EcsHook, "async_conn")
# this mock is only necessary to avoid a "No module named 'aiobotocore'"
error in the LatestBoto CI step