This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0b7c095c9f Change Deferrable implementation for
RedshiftPauseClusterOperator to follow standard (#30853)
0b7c095c9f is described below
commit 0b7c095c9fa54981248a72659da9acdf3bf5c2c0
Author: Syed Hussaain <[email protected]>
AuthorDate: Tue May 23 11:07:37 2023 -0700
Change Deferrable implementation for RedshiftPauseClusterOperator to follow
standard (#30853)
* Change base_aws.py to support async_conn
* Add async custom waiter support in get_waiter, and base_waiter.py
* Add Deferrable mode to RedshiftCreateClusterOperator
* Add RedshiftCreateClusterTrigger and unit test
* Add README.md for writing Triggers for AMPP
* Add Deferrable mode to Redshift Pause Cluster Operator
* Add logging to deferrable waiter
* Add check for failure early
---
.../amazon/aws/operators/redshift_cluster.py | 69 +++++-----
.../amazon/aws/triggers/redshift_cluster.py | 73 ++++++++++
airflow/providers/amazon/aws/waiters/redshift.json | 30 +++++
.../amazon/aws/operators/test_redshift_cluster.py | 14 +-
.../amazon/aws/triggers/test_redshift_cluster.py | 147 ++++++++++++++++++++-
5 files changed, 291 insertions(+), 42 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py
b/airflow/providers/amazon/aws/operators/redshift_cluster.py
index 2880240b15..6b44785da6 100644
--- a/airflow/providers/amazon/aws/operators/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -17,6 +17,7 @@
from __future__ import annotations
import time
+from datetime import timedelta
from typing import TYPE_CHECKING, Any, Sequence
from airflow.exceptions import AirflowException
@@ -25,6 +26,7 @@ from airflow.providers.amazon.aws.hooks.redshift_cluster
import RedshiftHook
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterTrigger,
RedshiftCreateClusterTrigger,
+ RedshiftPauseClusterTrigger,
)
if TYPE_CHECKING:
@@ -510,7 +512,9 @@ class RedshiftPauseClusterOperator(BaseOperator):
:param cluster_identifier: id of the AWS Redshift Cluster
:param aws_conn_id: aws connection to use
- :param deferrable: Run operator in the deferrable mode. This mode requires
an additional aiobotocore>=
+ :param deferrable: Run operator in the deferrable mode
+ :param poll_interval: Time (in seconds) to wait between two consecutive
calls to check cluster state
+ :param max_attempts: Maximum number of attempts to poll the cluster
"""
template_fields: Sequence[str] = ("cluster_identifier",)
@@ -524,64 +528,57 @@ class RedshiftPauseClusterOperator(BaseOperator):
aws_conn_id: str = "aws_default",
deferrable: bool = False,
poll_interval: int = 10,
+ max_attempts: int = 15,
**kwargs,
):
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.aws_conn_id = aws_conn_id
self.deferrable = deferrable
+ self.max_attempts = max_attempts
self.poll_interval = poll_interval
- # These parameters are added to address an issue with the boto3 API
where the API
+ # These parameters are used to address an issue with the boto3 API
where the API
# prematurely reports the cluster as available to receive requests.
This causes the cluster
# to reject initial attempts to pause the cluster despite reporting
the correct state.
- self._attempts = 10
+ self._remaining_attempts = 10
self._attempt_interval = 15
def execute(self, context: Context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
+ while self._remaining_attempts >= 1:
+ try:
+
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
+ break
+ except
redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
+ self._remaining_attempts = self._remaining_attempts - 1
+ if self._remaining_attempts > 0:
+ self.log.error(
+ "Unable to pause cluster. %d attempts remaining.",
self._remaining_attempts
+ )
+ time.sleep(self._attempt_interval)
+ else:
+ raise error
if self.deferrable:
self.defer(
- timeout=self.execution_timeout,
- trigger=RedshiftClusterTrigger(
- task_id=self.task_id,
+ trigger=RedshiftPauseClusterTrigger(
+ cluster_identifier=self.cluster_identifier,
poll_interval=self.poll_interval,
+ max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
- cluster_identifier=self.cluster_identifier,
- attempts=self._attempts,
- operation_type="pause_cluster",
),
method_name="execute_complete",
+ # timeout is set to ensure that if a trigger dies, the timeout
does not restart
+ # 60 seconds is added to allow the trigger to exit gracefully
(i.e. yield TriggerEvent)
+ timeout=timedelta(seconds=self.max_attempts *
self.poll_interval + 60),
)
- else:
- while self._attempts >= 1:
- try:
-
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
- return
- except
redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
- self._attempts = self._attempts - 1
-
- if self._attempts > 0:
- self.log.error("Unable to pause cluster. %d attempts
remaining.", self._attempts)
- time.sleep(self._attempt_interval)
- else:
- raise error
- def execute_complete(self, context: Context, event: Any = None) -> None:
- """
- Callback for when the trigger fires - returns immediately.
- Relies on trigger to throw an exception, otherwise it assumes
execution was
- successful.
- """
- if event:
- if "status" in event and event["status"] == "error":
- msg = f"{event['status']}: {event['message']}"
- raise AirflowException(msg)
- elif "status" in event and event["status"] == "success":
- self.log.info("%s completed successfully.", self.task_id)
- self.log.info("Paused cluster successfully")
+ def execute_complete(self, context, event=None):
+ if event["status"] != "success":
+ raise AirflowException(f"Error pausing cluster: {event}")
else:
- raise AirflowException("No event received from trigger")
+ self.log.info("Paused cluster successfully")
+ return
class RedshiftDeleteClusterOperator(BaseOperator):
diff --git a/airflow/providers/amazon/aws/triggers/redshift_cluster.py
b/airflow/providers/amazon/aws/triggers/redshift_cluster.py
index ef19d0b5a1..879f027e35 100644
--- a/airflow/providers/amazon/aws/triggers/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/triggers/redshift_cluster.py
@@ -16,8 +16,11 @@
# under the License.
from __future__ import annotations
+import asyncio
from typing import Any, AsyncIterator
+from botocore.exceptions import WaiterError
+
from airflow.compat.functools import cached_property
from airflow.providers.amazon.aws.hooks.redshift_cluster import
RedshiftAsyncHook, RedshiftHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -137,3 +140,73 @@ class RedshiftCreateClusterTrigger(BaseTrigger):
},
)
yield TriggerEvent({"status": "success", "message": "Cluster Created"})
+
+
+class RedshiftPauseClusterTrigger(BaseTrigger):
+ """
+ Trigger for RedshiftPauseClusterOperator.
+ The trigger will asynchronously poll the boto3 API and wait for the
+ Redshift cluster to be in the `paused` state.
+
+ :param cluster_identifier: A unique identifier for the cluster.
+ :param poll_interval: The amount of time in seconds to wait between
attempts.
+ :param max_attempts: The maximum number of attempts to be made.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ """
+
+ def __init__(
+ self,
+ cluster_identifier: str,
+ poll_interval: int,
+ max_attempts: int,
+ aws_conn_id: str,
+ ):
+ self.cluster_identifier = cluster_identifier
+ self.poll_interval = poll_interval
+ self.max_attempts = max_attempts
+ self.aws_conn_id = aws_conn_id
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+
"airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftPauseClusterTrigger",
+ {
+ "cluster_identifier": self.cluster_identifier,
+ "poll_interval": str(self.poll_interval),
+ "max_attempts": str(self.max_attempts),
+ "aws_conn_id": self.aws_conn_id,
+ },
+ )
+
+ @cached_property
+ def hook(self) -> RedshiftHook:
+ return RedshiftHook(aws_conn_id=self.aws_conn_id)
+
+ async def run(self):
+ async with self.hook.async_conn as client:
+ attempt = 0
+ waiter = self.hook.get_waiter("cluster_paused", deferrable=True,
client=client)
+ while attempt < int(self.max_attempts):
+ attempt = attempt + 1
+ try:
+ await waiter.wait(
+ ClusterIdentifier=self.cluster_identifier,
+ WaiterConfig={
+ "Delay": int(self.poll_interval),
+ "MaxAttempts": 1,
+ },
+ )
+ break
+ except WaiterError as error:
+ if "terminal failure" in str(error):
+ yield TriggerEvent({"status": "failure", "message":
f"Pause Cluster Failed: {error}"})
+ break
+ self.log.info(
+ "Status of cluster is %s",
error.last_response["Clusters"][0]["ClusterStatus"]
+ )
+ await asyncio.sleep(int(self.poll_interval))
+ if attempt >= int(self.max_attempts):
+ yield TriggerEvent(
+ {"status": "failure", "message": "Pause Cluster Failed - max
attempts reached."}
+ )
+ else:
+ yield TriggerEvent({"status": "success", "message": "Cluster
paused"})
diff --git a/airflow/providers/amazon/aws/waiters/redshift.json
b/airflow/providers/amazon/aws/waiters/redshift.json
new file mode 100644
index 0000000000..587f8ce989
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/redshift.json
@@ -0,0 +1,30 @@
+{
+ "version": 2,
+ "waiters": {
+ "cluster_paused": {
+ "operation": "DescribeClusters",
+ "delay": 30,
+ "maxAttempts": 60,
+ "acceptors": [
+ {
+ "matcher": "pathAll",
+ "argument": "Clusters[].ClusterStatus",
+ "expected": "paused",
+ "state": "success"
+ },
+ {
+ "matcher": "error",
+ "argument": "Clusters[].ClusterStatus",
+ "expected": "ClusterNotFound",
+ "state": "retry"
+ },
+ {
+ "matcher": "pathAny",
+ "argument": "Clusters[].ClusterStatus",
+ "expected": "deleting",
+ "state": "failure"
+ }
+ ]
+ }
+ }
+}
diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py
b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
index fca7dafdaa..f4bb22d4b5 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
@@ -31,7 +31,10 @@ from airflow.providers.amazon.aws.operators.redshift_cluster
import (
RedshiftPauseClusterOperator,
RedshiftResumeClusterOperator,
)
-from airflow.providers.amazon.aws.triggers.redshift_cluster import
RedshiftClusterTrigger
+from airflow.providers.amazon.aws.triggers.redshift_cluster import (
+ RedshiftClusterTrigger,
+ RedshiftPauseClusterTrigger,
+)
class TestRedshiftCreateClusterOperator:
@@ -389,9 +392,10 @@ class TestPauseClusterOperator:
redshift_operator.execute(None)
assert mock_conn.pause_cluster.call_count == 10
- def test_pause_cluster_deferrable_mode(self):
+ @mock.patch.object(RedshiftHook, "get_conn")
+ def test_pause_cluster_deferrable_mode(self, mock_get_conn):
"""Test Pause cluster operator with defer when deferrable param is
true"""
-
+ mock_get_conn().pause_cluster.return_value = True
redshift_operator = RedshiftPauseClusterOperator(
task_id="task_test", cluster_identifier="test_cluster",
deferrable=True
)
@@ -400,8 +404,8 @@ class TestPauseClusterOperator:
redshift_operator.execute(context=None)
assert isinstance(
- exc.value.trigger, RedshiftClusterTrigger
- ), "Trigger is not a RedshiftClusterTrigger"
+ exc.value.trigger, RedshiftPauseClusterTrigger
+ ), "Trigger is not a RedshiftPauseClusterTrigger"
def test_pause_cluster_execute_complete_success(self):
"""Asserts that logging occurs as expected"""
diff --git a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
index 941258659e..2e7f6490d6 100644
--- a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
+++ b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
@@ -19,8 +19,13 @@ from __future__ import annotations
import sys
import pytest
+from botocore.exceptions import WaiterError
-from airflow.providers.amazon.aws.triggers.redshift_cluster import
RedshiftCreateClusterTrigger
+from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
+from airflow.providers.amazon.aws.triggers.redshift_cluster import (
+ RedshiftCreateClusterTrigger,
+ RedshiftPauseClusterTrigger,
+)
from airflow.triggers.base import TriggerEvent
if sys.version_info < (3, 8):
@@ -72,3 +77,143 @@ class TestRedshiftCreateClusterTrigger:
response = await generator.asend(None)
assert response == TriggerEvent({"status": "success", "message":
"Cluster Created"})
+
+
+class TestRedshiftPauseClusterTrigger:
+ def test_redshift_pause_cluster_trigger_serialize(self):
+ redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger(
+ cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+ poll_interval=TEST_POLL_INTERVAL,
+ max_attempts=TEST_MAX_ATTEMPT,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ )
+ class_path, args = redshift_pause_cluster_trigger.serialize()
+ assert (
+ class_path ==
"airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftPauseClusterTrigger"
+ )
+ assert args["cluster_identifier"] == TEST_CLUSTER_IDENTIFIER
+ assert args["poll_interval"] == str(TEST_POLL_INTERVAL)
+ assert args["max_attempts"] == str(TEST_MAX_ATTEMPT)
+ assert args["aws_conn_id"] == TEST_AWS_CONN_ID
+
+ @pytest.mark.asyncio
+ @async_mock.patch.object(RedshiftHook, "get_waiter")
+ @async_mock.patch.object(RedshiftHook, "async_conn")
+ async def test_redshift_pause_cluster_trigger_run(self, mock_async_conn,
mock_get_waiter):
+ mock = async_mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = mock
+
+ mock_get_waiter().wait = AsyncMock()
+
+ redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger(
+ cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+ poll_interval=TEST_POLL_INTERVAL,
+ max_attempts=TEST_MAX_ATTEMPT,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ )
+
+ generator = redshift_pause_cluster_trigger.run()
+ response = await generator.asend(None)
+
+ assert response == TriggerEvent({"status": "success", "message":
"Cluster paused"})
+
+ @pytest.mark.asyncio
+ @async_mock.patch("asyncio.sleep")
+ @async_mock.patch.object(RedshiftHook, "get_waiter")
+ @async_mock.patch.object(RedshiftHook, "async_conn")
+ async def test_redshift_pause_cluster_trigger_run_multiple_attempts(
+ self, mock_async_conn, mock_get_waiter, mock_sleep
+ ):
+ mock = async_mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = mock
+ error = WaiterError(
+ name="test_name",
+ reason="test_reason",
+ last_response={"Clusters": [{"ClusterStatus": "available"}]},
+ )
+ mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error,
error, True])
+ mock_sleep.return_value = True
+
+ redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger(
+ cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+ poll_interval=TEST_POLL_INTERVAL,
+ max_attempts=TEST_MAX_ATTEMPT,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ )
+
+ generator = redshift_pause_cluster_trigger.run()
+ response = await generator.asend(None)
+
+ assert mock_get_waiter().wait.call_count == 3
+ assert response == TriggerEvent({"status": "success", "message":
"Cluster paused"})
+
+ @pytest.mark.asyncio
+ @async_mock.patch("asyncio.sleep")
+ @async_mock.patch.object(RedshiftHook, "get_waiter")
+ @async_mock.patch.object(RedshiftHook, "async_conn")
+ async def test_redshift_pause_cluster_trigger_run_attempts_exceeded(
+ self, mock_async_conn, mock_get_waiter, mock_sleep
+ ):
+ mock = async_mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = mock
+ error = WaiterError(
+ name="test_name",
+ reason="test_reason",
+ last_response={"Clusters": [{"ClusterStatus": "available"}]},
+ )
+ mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error,
error, True])
+ mock_sleep.return_value = True
+
+ redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger(
+ cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+ poll_interval=TEST_POLL_INTERVAL,
+ max_attempts=2,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ )
+
+ generator = redshift_pause_cluster_trigger.run()
+ response = await generator.asend(None)
+
+ assert mock_get_waiter().wait.call_count == 2
+ assert response == TriggerEvent(
+ {"status": "failure", "message": "Pause Cluster Failed - max
attempts reached."}
+ )
+
+ @pytest.mark.asyncio
+ @async_mock.patch("asyncio.sleep")
+ @async_mock.patch.object(RedshiftHook, "get_waiter")
+ @async_mock.patch.object(RedshiftHook, "async_conn")
+ async def test_redshift_pause_cluster_trigger_run_attempts_failed(
+ self, mock_async_conn, mock_get_waiter, mock_sleep
+ ):
+ mock = async_mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = mock
+ error_available = WaiterError(
+ name="test_name",
+ reason="Max attempts exceeded",
+ last_response={"Clusters": [{"ClusterStatus": "available"}]},
+ )
+ error_failed = WaiterError(
+ name="test_name",
+ reason="Waiter encountered a terminal failure state:",
+ last_response={"Clusters": [{"ClusterStatus": "available"}]},
+ )
+ mock_get_waiter().wait.side_effect = AsyncMock(
+ side_effect=[error_available, error_available, error_failed]
+ )
+ mock_sleep.return_value = True
+
+ redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger(
+ cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+ poll_interval=TEST_POLL_INTERVAL,
+ max_attempts=TEST_MAX_ATTEMPT,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ )
+
+ generator = redshift_pause_cluster_trigger.run()
+ response = await generator.asend(None)
+
+ assert mock_get_waiter().wait.call_count == 3
+ assert response == TriggerEvent(
+ {"status": "failure", "message": f"Pause Cluster Failed:
{error_failed}"}
+ )