This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0b7c095c9f Change Deferrable implementation for RedshiftPauseClusterOperator to follow standard (#30853)
0b7c095c9f is described below

commit 0b7c095c9fa54981248a72659da9acdf3bf5c2c0
Author: Syed Hussaain <[email protected]>
AuthorDate: Tue May 23 11:07:37 2023 -0700

    Change Deferrable implementation for RedshiftPauseClusterOperator to follow standard (#30853)
    
    * Change base_aws.py to support async_conn
    * Add async custom waiter support in get_waiter, and base_waiter.py
    * Add Deferrable mode to RedshiftCreateClusterOperator
    * Add RedshiftCreateClusterTrigger and unit test
    * Add README.md for writing Triggers for AMPP
    * Add Deferrable mode to Redshift Pause Cluster Operator
    * Add logging to deferrable waiter
    * Add check for failure early
---
 .../amazon/aws/operators/redshift_cluster.py       |  69 +++++-----
 .../amazon/aws/triggers/redshift_cluster.py        |  73 ++++++++++
 airflow/providers/amazon/aws/waiters/redshift.json |  30 +++++
 .../amazon/aws/operators/test_redshift_cluster.py  |  14 +-
 .../amazon/aws/triggers/test_redshift_cluster.py   | 147 ++++++++++++++++++++-
 5 files changed, 291 insertions(+), 42 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py
index 2880240b15..6b44785da6 100644
--- a/airflow/providers/amazon/aws/operators/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import time
+from datetime import timedelta
 from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.exceptions import AirflowException
@@ -25,6 +26,7 @@ from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
 from airflow.providers.amazon.aws.triggers.redshift_cluster import (
     RedshiftClusterTrigger,
     RedshiftCreateClusterTrigger,
+    RedshiftPauseClusterTrigger,
 )
 
 if TYPE_CHECKING:
@@ -510,7 +512,9 @@ class RedshiftPauseClusterOperator(BaseOperator):
 
     :param cluster_identifier: id of the AWS Redshift Cluster
     :param aws_conn_id: aws connection to use
-    :param deferrable: Run operator in the deferrable mode. This mode requires 
an additional aiobotocore>=
+    :param deferrable: Run operator in the deferrable mode
+    :param poll_interval: Time (in seconds) to wait between two consecutive 
calls to check cluster state
+    :param max_attempts: Maximum number of attempts to poll the cluster
     """
 
     template_fields: Sequence[str] = ("cluster_identifier",)
@@ -524,64 +528,57 @@ class RedshiftPauseClusterOperator(BaseOperator):
         aws_conn_id: str = "aws_default",
         deferrable: bool = False,
         poll_interval: int = 10,
+        max_attempts: int = 15,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.cluster_identifier = cluster_identifier
         self.aws_conn_id = aws_conn_id
         self.deferrable = deferrable
+        self.max_attempts = max_attempts
         self.poll_interval = poll_interval
-        # These parameters are added to address an issue with the boto3 API 
where the API
+        # These parameters are used to address an issue with the boto3 API 
where the API
         # prematurely reports the cluster as available to receive requests. 
This causes the cluster
         # to reject initial attempts to pause the cluster despite reporting 
the correct state.
-        self._attempts = 10
+        self._remaining_attempts = 10
         self._attempt_interval = 15
 
     def execute(self, context: Context):
         redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
+        while self._remaining_attempts >= 1:
+            try:
+                
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
+                break
+            except 
redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
+                self._remaining_attempts = self._remaining_attempts - 1
 
+                if self._remaining_attempts > 0:
+                    self.log.error(
+                        "Unable to pause cluster. %d attempts remaining.", 
self._remaining_attempts
+                    )
+                    time.sleep(self._attempt_interval)
+                else:
+                    raise error
         if self.deferrable:
             self.defer(
-                timeout=self.execution_timeout,
-                trigger=RedshiftClusterTrigger(
-                    task_id=self.task_id,
+                trigger=RedshiftPauseClusterTrigger(
+                    cluster_identifier=self.cluster_identifier,
                     poll_interval=self.poll_interval,
+                    max_attempts=self.max_attempts,
                     aws_conn_id=self.aws_conn_id,
-                    cluster_identifier=self.cluster_identifier,
-                    attempts=self._attempts,
-                    operation_type="pause_cluster",
                 ),
                 method_name="execute_complete",
+                # timeout is set to ensure that if a trigger dies, the timeout 
does not restart
+                # 60 seconds is added to allow the trigger to exit gracefully 
(i.e. yield TriggerEvent)
+                timeout=timedelta(seconds=self.max_attempts * 
self.poll_interval + 60),
             )
-        else:
-            while self._attempts >= 1:
-                try:
-                    
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
-                    return
-                except 
redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
-                    self._attempts = self._attempts - 1
-
-                    if self._attempts > 0:
-                        self.log.error("Unable to pause cluster. %d attempts 
remaining.", self._attempts)
-                        time.sleep(self._attempt_interval)
-                    else:
-                        raise error
 
-    def execute_complete(self, context: Context, event: Any = None) -> None:
-        """
-        Callback for when the trigger fires - returns immediately.
-        Relies on trigger to throw an exception, otherwise it assumes 
execution was
-        successful.
-        """
-        if event:
-            if "status" in event and event["status"] == "error":
-                msg = f"{event['status']}: {event['message']}"
-                raise AirflowException(msg)
-            elif "status" in event and event["status"] == "success":
-                self.log.info("%s completed successfully.", self.task_id)
-                self.log.info("Paused cluster successfully")
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error pausing cluster: {event}")
         else:
-            raise AirflowException("No event received from trigger")
+            self.log.info("Paused cluster successfully")
+        return
 
 
 class RedshiftDeleteClusterOperator(BaseOperator):
diff --git a/airflow/providers/amazon/aws/triggers/redshift_cluster.py b/airflow/providers/amazon/aws/triggers/redshift_cluster.py
index ef19d0b5a1..879f027e35 100644
--- a/airflow/providers/amazon/aws/triggers/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/triggers/redshift_cluster.py
@@ -16,8 +16,11 @@
 # under the License.
 from __future__ import annotations
 
+import asyncio
 from typing import Any, AsyncIterator
 
+from botocore.exceptions import WaiterError
+
 from airflow.compat.functools import cached_property
 from airflow.providers.amazon.aws.hooks.redshift_cluster import 
RedshiftAsyncHook, RedshiftHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -137,3 +140,73 @@ class RedshiftCreateClusterTrigger(BaseTrigger):
                 },
             )
         yield TriggerEvent({"status": "success", "message": "Cluster Created"})
+
+
+class RedshiftPauseClusterTrigger(BaseTrigger):
+    """
+    Trigger for RedshiftPauseClusterOperator.
+    The trigger will asynchronously poll the boto3 API and wait for the
+    Redshift cluster to be in the `paused` state.
+
+    :param cluster_identifier:  A unique identifier for the cluster.
+    :param poll_interval: The amount of time in seconds to wait between 
attempts.
+    :param max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        cluster_identifier: str,
+        poll_interval: int,
+        max_attempts: int,
+        aws_conn_id: str,
+    ):
+        self.cluster_identifier = cluster_identifier
+        self.poll_interval = poll_interval
+        self.max_attempts = max_attempts
+        self.aws_conn_id = aws_conn_id
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            
"airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftPauseClusterTrigger",
+            {
+                "cluster_identifier": self.cluster_identifier,
+                "poll_interval": str(self.poll_interval),
+                "max_attempts": str(self.max_attempts),
+                "aws_conn_id": self.aws_conn_id,
+            },
+        )
+
+    @cached_property
+    def hook(self) -> RedshiftHook:
+        return RedshiftHook(aws_conn_id=self.aws_conn_id)
+
+    async def run(self):
+        async with self.hook.async_conn as client:
+            attempt = 0
+            waiter = self.hook.get_waiter("cluster_paused", deferrable=True, 
client=client)
+            while attempt < int(self.max_attempts):
+                attempt = attempt + 1
+                try:
+                    await waiter.wait(
+                        ClusterIdentifier=self.cluster_identifier,
+                        WaiterConfig={
+                            "Delay": int(self.poll_interval),
+                            "MaxAttempts": 1,
+                        },
+                    )
+                    break
+                except WaiterError as error:
+                    if "terminal failure" in str(error):
+                        yield TriggerEvent({"status": "failure", "message": 
f"Pause Cluster Failed: {error}"})
+                        break
+                    self.log.info(
+                        "Status of cluster is %s", 
error.last_response["Clusters"][0]["ClusterStatus"]
+                    )
+                    await asyncio.sleep(int(self.poll_interval))
+        if attempt >= int(self.max_attempts):
+            yield TriggerEvent(
+                {"status": "failure", "message": "Pause Cluster Failed - max 
attempts reached."}
+            )
+        else:
+            yield TriggerEvent({"status": "success", "message": "Cluster 
paused"})
diff --git a/airflow/providers/amazon/aws/waiters/redshift.json b/airflow/providers/amazon/aws/waiters/redshift.json
new file mode 100644
index 0000000000..587f8ce989
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/redshift.json
@@ -0,0 +1,30 @@
+{
+    "version": 2,
+    "waiters": {
+        "cluster_paused": {
+            "operation": "DescribeClusters",
+            "delay": 30,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "matcher": "pathAll",
+                    "argument": "Clusters[].ClusterStatus",
+                    "expected": "paused",
+                    "state": "success"
+                },
+                {
+                    "matcher": "error",
+                    "argument": "Clusters[].ClusterStatus",
+                    "expected": "ClusterNotFound",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAny",
+                    "argument": "Clusters[].ClusterStatus",
+                    "expected": "deleting",
+                    "state": "failure"
+                }
+            ]
+        }
+    }
+}
diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
index fca7dafdaa..f4bb22d4b5 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
@@ -31,7 +31,10 @@ from airflow.providers.amazon.aws.operators.redshift_cluster 
import (
     RedshiftPauseClusterOperator,
     RedshiftResumeClusterOperator,
 )
-from airflow.providers.amazon.aws.triggers.redshift_cluster import 
RedshiftClusterTrigger
+from airflow.providers.amazon.aws.triggers.redshift_cluster import (
+    RedshiftClusterTrigger,
+    RedshiftPauseClusterTrigger,
+)
 
 
 class TestRedshiftCreateClusterOperator:
@@ -389,9 +392,10 @@ class TestPauseClusterOperator:
             redshift_operator.execute(None)
         assert mock_conn.pause_cluster.call_count == 10
 
-    def test_pause_cluster_deferrable_mode(self):
+    @mock.patch.object(RedshiftHook, "get_conn")
+    def test_pause_cluster_deferrable_mode(self, mock_get_conn):
         """Test Pause cluster operator with defer when deferrable param is 
true"""
-
+        mock_get_conn().pause_cluster.return_value = True
         redshift_operator = RedshiftPauseClusterOperator(
             task_id="task_test", cluster_identifier="test_cluster", 
deferrable=True
         )
@@ -400,8 +404,8 @@ class TestPauseClusterOperator:
             redshift_operator.execute(context=None)
 
         assert isinstance(
-            exc.value.trigger, RedshiftClusterTrigger
-        ), "Trigger is not a RedshiftClusterTrigger"
+            exc.value.trigger, RedshiftPauseClusterTrigger
+        ), "Trigger is not a RedshiftPauseClusterTrigger"
 
     def test_pause_cluster_execute_complete_success(self):
         """Asserts that logging occurs as expected"""
diff --git a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
index 941258659e..2e7f6490d6 100644
--- a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
+++ b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
@@ -19,8 +19,13 @@ from __future__ import annotations
 import sys
 
 import pytest
+from botocore.exceptions import WaiterError
 
-from airflow.providers.amazon.aws.triggers.redshift_cluster import 
RedshiftCreateClusterTrigger
+from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
+from airflow.providers.amazon.aws.triggers.redshift_cluster import (
+    RedshiftCreateClusterTrigger,
+    RedshiftPauseClusterTrigger,
+)
 from airflow.triggers.base import TriggerEvent
 
 if sys.version_info < (3, 8):
@@ -72,3 +77,143 @@ class TestRedshiftCreateClusterTrigger:
         response = await generator.asend(None)
 
         assert response == TriggerEvent({"status": "success", "message": 
"Cluster Created"})
+
+
+class TestRedshiftPauseClusterTrigger:
+    def test_redshift_pause_cluster_trigger_serialize(self):
+        redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger(
+            cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=TEST_MAX_ATTEMPT,
+            aws_conn_id=TEST_AWS_CONN_ID,
+        )
+        class_path, args = redshift_pause_cluster_trigger.serialize()
+        assert (
+            class_path == 
"airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftPauseClusterTrigger"
+        )
+        assert args["cluster_identifier"] == TEST_CLUSTER_IDENTIFIER
+        assert args["poll_interval"] == str(TEST_POLL_INTERVAL)
+        assert args["max_attempts"] == str(TEST_MAX_ATTEMPT)
+        assert args["aws_conn_id"] == TEST_AWS_CONN_ID
+
+    @pytest.mark.asyncio
+    @async_mock.patch.object(RedshiftHook, "get_waiter")
+    @async_mock.patch.object(RedshiftHook, "async_conn")
+    async def test_redshift_pause_cluster_trigger_run(self, mock_async_conn, 
mock_get_waiter):
+        mock = async_mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = mock
+
+        mock_get_waiter().wait = AsyncMock()
+
+        redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger(
+            cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=TEST_MAX_ATTEMPT,
+            aws_conn_id=TEST_AWS_CONN_ID,
+        )
+
+        generator = redshift_pause_cluster_trigger.run()
+        response = await generator.asend(None)
+
+        assert response == TriggerEvent({"status": "success", "message": 
"Cluster paused"})
+
+    @pytest.mark.asyncio
+    @async_mock.patch("asyncio.sleep")
+    @async_mock.patch.object(RedshiftHook, "get_waiter")
+    @async_mock.patch.object(RedshiftHook, "async_conn")
+    async def test_redshift_pause_cluster_trigger_run_multiple_attempts(
+        self, mock_async_conn, mock_get_waiter, mock_sleep
+    ):
+        mock = async_mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = mock
+        error = WaiterError(
+            name="test_name",
+            reason="test_reason",
+            last_response={"Clusters": [{"ClusterStatus": "available"}]},
+        )
+        mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, 
error, True])
+        mock_sleep.return_value = True
+
+        redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger(
+            cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=TEST_MAX_ATTEMPT,
+            aws_conn_id=TEST_AWS_CONN_ID,
+        )
+
+        generator = redshift_pause_cluster_trigger.run()
+        response = await generator.asend(None)
+
+        assert mock_get_waiter().wait.call_count == 3
+        assert response == TriggerEvent({"status": "success", "message": 
"Cluster paused"})
+
+    @pytest.mark.asyncio
+    @async_mock.patch("asyncio.sleep")
+    @async_mock.patch.object(RedshiftHook, "get_waiter")
+    @async_mock.patch.object(RedshiftHook, "async_conn")
+    async def test_redshift_pause_cluster_trigger_run_attempts_exceeded(
+        self, mock_async_conn, mock_get_waiter, mock_sleep
+    ):
+        mock = async_mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = mock
+        error = WaiterError(
+            name="test_name",
+            reason="test_reason",
+            last_response={"Clusters": [{"ClusterStatus": "available"}]},
+        )
+        mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, 
error, True])
+        mock_sleep.return_value = True
+
+        redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger(
+            cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=2,
+            aws_conn_id=TEST_AWS_CONN_ID,
+        )
+
+        generator = redshift_pause_cluster_trigger.run()
+        response = await generator.asend(None)
+
+        assert mock_get_waiter().wait.call_count == 2
+        assert response == TriggerEvent(
+            {"status": "failure", "message": "Pause Cluster Failed - max 
attempts reached."}
+        )
+
+    @pytest.mark.asyncio
+    @async_mock.patch("asyncio.sleep")
+    @async_mock.patch.object(RedshiftHook, "get_waiter")
+    @async_mock.patch.object(RedshiftHook, "async_conn")
+    async def test_redshift_pause_cluster_trigger_run_attempts_failed(
+        self, mock_async_conn, mock_get_waiter, mock_sleep
+    ):
+        mock = async_mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = mock
+        error_available = WaiterError(
+            name="test_name",
+            reason="Max attempts exceeded",
+            last_response={"Clusters": [{"ClusterStatus": "available"}]},
+        )
+        error_failed = WaiterError(
+            name="test_name",
+            reason="Waiter encountered a terminal failure state:",
+            last_response={"Clusters": [{"ClusterStatus": "available"}]},
+        )
+        mock_get_waiter().wait.side_effect = AsyncMock(
+            side_effect=[error_available, error_available, error_failed]
+        )
+        mock_sleep.return_value = True
+
+        redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger(
+            cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+            poll_interval=TEST_POLL_INTERVAL,
+            max_attempts=TEST_MAX_ATTEMPT,
+            aws_conn_id=TEST_AWS_CONN_ID,
+        )
+
+        generator = redshift_pause_cluster_trigger.run()
+        response = await generator.asend(None)
+
+        assert mock_get_waiter().wait.call_count == 3
+        assert response == TriggerEvent(
+            {"status": "failure", "message": f"Pause Cluster Failed: 
{error_failed}"}
+        )

Reply via email to