This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7fe573317e Fix AWS Redshift operators and sensors (#41191)
7fe573317e is described below

commit 7fe573317eb630c2d176329c599d6fbbb30f4378
Author: Vincent <[email protected]>
AuthorDate: Thu Aug 1 13:54:14 2024 -0400

    Fix AWS Redshift operators and sensors (#41191)
---
 .../amazon/aws/operators/redshift_cluster.py       | 118 ++++++++++++---------
 .../amazon/aws/operators/test_redshift_cluster.py  |  25 ++++-
 .../providers/amazon/aws/example_redshift.py       |   4 +-
 3 files changed, 91 insertions(+), 56 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py 
b/airflow/providers/amazon/aws/operators/redshift_cluster.py
index f31b7db09a..4666445f96 100644
--- a/airflow/providers/amazon/aws/operators/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -32,6 +32,7 @@ from airflow.providers.amazon.aws.triggers.redshift_cluster 
import (
     RedshiftResumeClusterTrigger,
 )
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.utils.helpers import prune_dict
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -507,8 +508,8 @@ class RedshiftResumeClusterOperator(BaseOperator):
         aws_conn_id: str | None = "aws_default",
         wait_for_completion: bool = False,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
-        poll_interval: int = 10,
-        max_attempts: int = 10,
+        poll_interval: int = 30,
+        max_attempts: int = 30,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -542,38 +543,38 @@ class RedshiftResumeClusterOperator(BaseOperator):
                 else:
                     raise error
 
-        if self.deferrable:
-            cluster_state = 
redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
-            if cluster_state == "available":
-                self.log.info("Resumed cluster successfully")
-            elif cluster_state == "deleting":
-                raise AirflowException(
-                    "Unable to resume cluster since cluster is currently in 
status: %s", cluster_state
-                )
+        if self.wait_for_completion:
+            if self.deferrable:
+                cluster_state = 
redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
+                if cluster_state == "available":
+                    self.log.info("Resumed cluster successfully")
+                elif cluster_state == "deleting":
+                    raise AirflowException(
+                        "Unable to resume cluster since cluster is currently 
in status: %s", cluster_state
+                    )
+                else:
+                    self.defer(
+                        trigger=RedshiftResumeClusterTrigger(
+                            cluster_identifier=self.cluster_identifier,
+                            waiter_delay=self.poll_interval,
+                            waiter_max_attempts=self.max_attempts,
+                            aws_conn_id=self.aws_conn_id,
+                        ),
+                        method_name="execute_complete",
+                        # timeout is set to ensure that if a trigger dies, the 
timeout does not restart
+                        # 60 seconds is added to allow the trigger to exit 
gracefully (i.e. yield TriggerEvent)
+                        timeout=timedelta(seconds=self.max_attempts * 
self.poll_interval + 60),
+                    )
             else:
-                self.defer(
-                    trigger=RedshiftResumeClusterTrigger(
-                        cluster_identifier=self.cluster_identifier,
-                        waiter_delay=self.poll_interval,
-                        waiter_max_attempts=self.max_attempts,
-                        aws_conn_id=self.aws_conn_id,
-                    ),
-                    method_name="execute_complete",
-                    # timeout is set to ensure that if a trigger dies, the 
timeout does not restart
-                    # 60 seconds is added to allow the trigger to exit 
gracefully (i.e. yield TriggerEvent)
-                    timeout=timedelta(seconds=self.max_attempts * 
self.poll_interval + 60),
+                waiter = redshift_hook.get_waiter("cluster_resumed")
+                waiter.wait(
+                    ClusterIdentifier=self.cluster_identifier,
+                    WaiterConfig={
+                        "Delay": self.poll_interval,
+                        "MaxAttempts": self.max_attempts,
+                    },
                 )
 
-        if self.wait_for_completion:
-            waiter = redshift_hook.get_waiter("cluster_resumed")
-            waiter.wait(
-                ClusterIdentifier=self.cluster_identifier,
-                WaiterConfig={
-                    "Delay": self.poll_interval,
-                    "MaxAttempts": self.max_attempts,
-                },
-            )
-
     def execute_complete(self, context: Context, event: dict[str, Any] | None 
= None) -> None:
         event = validate_execute_complete_event(event)
 
@@ -596,6 +597,7 @@ class RedshiftPauseClusterOperator(BaseOperator):
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
+    :param wait_for_completion: If True, waits for the cluster to be paused. 
(default: False)
     :param deferrable: Run operator in the deferrable mode
     :param poll_interval: Time (in seconds) to wait between two consecutive 
calls to check cluster state
     :param max_attempts: Maximum number of attempts to poll the cluster
@@ -610,14 +612,16 @@ class RedshiftPauseClusterOperator(BaseOperator):
         *,
         cluster_identifier: str,
         aws_conn_id: str | None = "aws_default",
+        wait_for_completion: bool = False,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
-        poll_interval: int = 10,
-        max_attempts: int = 15,
+        poll_interval: int = 30,
+        max_attempts: int = 30,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.cluster_identifier = cluster_identifier
         self.aws_conn_id = aws_conn_id
+        self.wait_for_completion = wait_for_completion
         self.deferrable = deferrable
         self.max_attempts = max_attempts
         self.poll_interval = poll_interval
@@ -643,26 +647,38 @@ class RedshiftPauseClusterOperator(BaseOperator):
                     time.sleep(self._attempt_interval)
                 else:
                     raise error
-        if self.deferrable:
-            cluster_state = 
redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
-            if cluster_state == "paused":
-                self.log.info("Paused cluster successfully")
-            elif cluster_state == "deleting":
-                raise AirflowException(
-                    f"Unable to pause cluster since cluster is currently in 
status: {cluster_state}"
-                )
+        if self.wait_for_completion:
+            if self.deferrable:
+                cluster_state = 
redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
+                if cluster_state == "paused":
+                    self.log.info("Paused cluster successfully")
+                elif cluster_state == "deleting":
+                    raise AirflowException(
+                        f"Unable to pause cluster since cluster is currently 
in status: {cluster_state}"
+                    )
+                else:
+                    self.defer(
+                        trigger=RedshiftPauseClusterTrigger(
+                            cluster_identifier=self.cluster_identifier,
+                            waiter_delay=self.poll_interval,
+                            waiter_max_attempts=self.max_attempts,
+                            aws_conn_id=self.aws_conn_id,
+                        ),
+                        method_name="execute_complete",
+                        # timeout is set to ensure that if a trigger dies, the 
timeout does not restart
+                        # 60 seconds is added to allow the trigger to exit 
gracefully (i.e. yield TriggerEvent)
+                        timeout=timedelta(seconds=self.max_attempts * 
self.poll_interval + 60),
+                    )
             else:
-                self.defer(
-                    trigger=RedshiftPauseClusterTrigger(
-                        cluster_identifier=self.cluster_identifier,
-                        waiter_delay=self.poll_interval,
-                        waiter_max_attempts=self.max_attempts,
-                        aws_conn_id=self.aws_conn_id,
+                waiter = redshift_hook.get_waiter("cluster_paused")
+                waiter.wait(
+                    ClusterIdentifier=self.cluster_identifier,
+                    WaiterConfig=prune_dict(
+                        {
+                            "Delay": self.poll_interval,
+                            "MaxAttempts": self.max_attempts,
+                        }
                     ),
-                    method_name="execute_complete",
-                    # timeout is set to ensure that if a trigger dies, the 
timeout does not restart
-                    # 60 seconds is added to allow the trigger to exit 
gracefully (i.e. yield TriggerEvent)
-                    timeout=timedelta(seconds=self.max_attempts * 
self.poll_interval + 60),
                 )
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None 
= None) -> None:
diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py 
b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
index c537f4ce24..f6c960b7bd 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 from unittest import mock
+from unittest.mock import Mock
 
 import boto3
 import pytest
@@ -318,6 +319,7 @@ class TestResumeClusterOperator:
             task_id="task_test",
             cluster_identifier="test_cluster",
             aws_conn_id="aws_conn_test",
+            wait_for_completion=True,
             deferrable=True,
         )
 
@@ -340,6 +342,7 @@ class TestResumeClusterOperator:
             task_id="task_test",
             cluster_identifier="test_cluster",
             aws_conn_id="aws_conn_test",
+            wait_for_completion=True,
             deferrable=True,
         )
 
@@ -366,7 +369,7 @@ class TestResumeClusterOperator:
         mock_get_waiter.assert_called_with("cluster_resumed")
         assert mock_get_waiter.call_count == 2
         mock_get_waiter().wait.assert_called_once_with(
-            ClusterIdentifier="test_cluster", WaiterConfig={"Delay": 10, 
"MaxAttempts": 10}
+            ClusterIdentifier="test_cluster", WaiterConfig={"Delay": 30, 
"MaxAttempts": 30}
         )
 
     def test_resume_cluster_failure(self):
@@ -437,6 +440,22 @@ class TestPauseClusterOperator:
             redshift_operator.execute(None)
         assert mock_conn.pause_cluster.call_count == 10
 
+    @mock.patch.object(RedshiftHook, "get_waiter")
+    @mock.patch.object(RedshiftHook, "get_conn")
+    def test_pause_cluster_wait_for_completion(self, mock_get_conn, 
mock_get_waiter):
+        """Test Pause cluster operator waits for completion when 
wait_for_completion param is true"""
+        mock_get_conn.return_value.pause_cluster.return_value = True
+        waiter = Mock()
+        mock_get_waiter.return_value = waiter
+
+        redshift_operator = RedshiftPauseClusterOperator(
+            task_id="task_test", cluster_identifier="test_cluster", 
wait_for_completion=True
+        )
+
+        redshift_operator.execute(context=None)
+
+        waiter.wait.assert_called_once()
+
     @mock.patch.object(RedshiftHook, "cluster_status")
     @mock.patch.object(RedshiftHook, "get_conn")
     def test_pause_cluster_deferrable_mode(self, mock_get_conn, 
mock_cluster_status):
@@ -445,7 +464,7 @@ class TestPauseClusterOperator:
         mock_cluster_status.return_value = "available"
 
         redshift_operator = RedshiftPauseClusterOperator(
-            task_id="task_test", cluster_identifier="test_cluster", 
deferrable=True
+            task_id="task_test", cluster_identifier="test_cluster", 
wait_for_completion=True, deferrable=True
         )
 
         with pytest.raises(TaskDeferred) as exc:
@@ -466,7 +485,7 @@ class TestPauseClusterOperator:
         mock_cluster_status.return_value = "deleting"
 
         redshift_operator = RedshiftPauseClusterOperator(
-            task_id="task_test", cluster_identifier="test_cluster", 
deferrable=True
+            task_id="task_test", cluster_identifier="test_cluster", 
wait_for_completion=True, deferrable=True
         )
 
         with pytest.raises(AirflowException):
diff --git a/tests/system/providers/amazon/aws/example_redshift.py 
b/tests/system/providers/amazon/aws/example_redshift.py
index cc88811bef..cc92076dcb 100644
--- a/tests/system/providers/amazon/aws/example_redshift.py
+++ b/tests/system/providers/amazon/aws/example_redshift.py
@@ -121,7 +121,7 @@ with DAG(
         task_id="wait_cluster_paused",
         cluster_identifier=redshift_cluster_identifier,
         target_status="paused",
-        poke_interval=15,
+        poke_interval=30,
         timeout=60 * 30,
     )
 
@@ -136,7 +136,7 @@ with DAG(
         task_id="wait_cluster_available_after_resume",
         cluster_identifier=redshift_cluster_identifier,
         target_status="available",
-        poke_interval=15,
+        poke_interval=30,
         timeout=60 * 30,
     )
 

Reply via email to