This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7ed791dab7 Do not return success from AWS ECS trigger after max_attempts (#32589)
7ed791dab7 is described below

commit 7ed791dab72709fbc5c9c27687a8b014c3e9906d
Author: Raphaël Vandon <[email protected]>
AuthorDate: Tue Jul 18 14:45:00 2023 -0700

    Do not return success from AWS ECS trigger after max_attempts (#32589)
---
 airflow/providers/amazon/aws/triggers/ecs.py    |  7 ++++---
 tests/providers/amazon/aws/triggers/test_ecs.py | 21 +++++++++++++++++++++
 2 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/airflow/providers/amazon/aws/triggers/ecs.py b/airflow/providers/amazon/aws/triggers/ecs.py
index 29ad22e13e..af6f72e771 100644
--- a/airflow/providers/amazon/aws/triggers/ecs.py
+++ b/airflow/providers/amazon/aws/triggers/ecs.py
@@ -22,6 +22,7 @@ from typing import Any, AsyncIterator
 
 from botocore.exceptions import ClientError, WaiterError
 
+from airflow import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 from airflow.providers.amazon.aws.hooks.ecs import EcsHook
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
@@ -170,7 +171,8 @@ class TaskDoneTrigger(BaseTrigger):
                     await waiter.wait(
                         cluster=self.cluster, tasks=[self.task_arn], 
WaiterConfig={"MaxAttempts": 1}
                     )
-                    break  # we reach this point only if the waiter met a 
success criteria
+                    # we reach this point only if the waiter met a success 
criteria
+                    yield TriggerEvent({"status": "success", "task_arn": 
self.task_arn})
                 except WaiterError as error:
                     if "terminal failure" in str(error):
                         raise
@@ -179,8 +181,7 @@ class TaskDoneTrigger(BaseTrigger):
                 finally:
                     if self.log_group and self.log_stream:
                         logs_token = await self._forward_logs(logs_client, 
logs_token)
-
-        yield TriggerEvent({"status": "success", "task_arn": self.task_arn})
+        raise AirflowException("Waiter error: max attempts reached")
 
     async def _forward_logs(self, logs_client, next_token: str | None = None) 
-> str | None:
         """
diff --git a/tests/providers/amazon/aws/triggers/test_ecs.py b/tests/providers/amazon/aws/triggers/test_ecs.py
index 551ab39a44..e897bce740 100644
--- a/tests/providers/amazon/aws/triggers/test_ecs.py
+++ b/tests/providers/amazon/aws/triggers/test_ecs.py
@@ -22,6 +22,7 @@ from unittest.mock import AsyncMock
 import pytest
 from botocore.exceptions import WaiterError
 
+from airflow import AirflowException
 from airflow.providers.amazon.aws.hooks.ecs import EcsHook
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
 from airflow.providers.amazon.aws.triggers.ecs import (
@@ -56,6 +57,26 @@ class TestTaskDoneTrigger:
 
         assert wait_mock.call_count == 3
 
+    @pytest.mark.asyncio
+    @mock.patch.object(EcsHook, "async_conn")
+    # this mock is only necessary to avoid a "No module named 'aiobotocore'" 
error in the LatestBoto CI step
+    @mock.patch.object(AwsLogsHook, "async_conn")
+    async def test_run_until_timeout(self, _, client_mock):
+        a_mock = mock.MagicMock()
+        client_mock.__aenter__.return_value = a_mock
+        wait_mock = AsyncMock()
+        wait_mock.side_effect = WaiterError("name", "reason", {"tasks": 
[{"lastStatus": "my_status"}]})
+        a_mock.get_waiter().wait = wait_mock
+
+        trigger = TaskDoneTrigger("cluster", "task_arn", 0, 10, None, None)
+
+        with pytest.raises(AirflowException) as err:
+            generator = trigger.run()
+            await generator.asend(None)
+
+        assert wait_mock.call_count == 10
+        assert "max attempts" in str(err.value)
+
     @pytest.mark.asyncio
     @mock.patch.object(EcsHook, "async_conn")
     # this mock is only necessary to avoid a "No module named 'aiobotocore'" 
error in the LatestBoto CI step

Reply via email to