This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 44741f354c Change Deferrable implementation for
RedshiftResumeClusterOperator to follow standard (#30864)
44741f354c is described below
commit 44741f354cf8b7113099b0e0f147b58a55ecc5d3
Author: Syed Hussaain <[email protected]>
AuthorDate: Wed May 24 16:35:32 2023 -0700
Change Deferrable implementation for RedshiftResumeClusterOperator to
follow standard (#30864)
* Add Deferrable mode to Redshift Resume Cluster Operator
* Add wait for completion to RedshiftResumeClusterOperator
---------
Co-authored-by: Raphaël Vandon <[email protected]>
---
.../amazon/aws/operators/redshift_cluster.py | 84 ++++++------
.../amazon/aws/triggers/redshift_cluster.py | 72 +++++++++++
airflow/providers/amazon/aws/waiters/redshift.json | 25 ++++
.../amazon/aws/operators/test_redshift_cluster.py | 61 +++++----
.../amazon/aws/triggers/test_redshift_cluster.py | 142 +++++++++++++++++++++
5 files changed, 318 insertions(+), 66 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py
b/airflow/providers/amazon/aws/operators/redshift_cluster.py
index 94ce2e58c2..abf0e68f39 100644
--- a/airflow/providers/amazon/aws/operators/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -24,10 +24,10 @@ from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
- RedshiftClusterTrigger,
RedshiftCreateClusterSnapshotTrigger,
RedshiftCreateClusterTrigger,
RedshiftPauseClusterTrigger,
+ RedshiftResumeClusterTrigger,
)
if TYPE_CHECKING:
@@ -452,8 +452,11 @@ class RedshiftResumeClusterOperator(BaseOperator):
:param cluster_identifier: Unique identifier of the AWS Redshift cluster
:param aws_conn_id: The Airflow connection used for AWS credentials.
The default connection id is ``aws_default``
- :param deferrable: Run operator in deferrable mode
:param poll_interval: Time (in seconds) to wait between two consecutive
calls to check cluster state
+ :param max_attempts: The maximum number of attempts to check the state of
the cluster.
+ :param wait_for_completion: If True, the operator will wait for the
cluster to be in the
+ `resumed` state. Default is False.
+ :param deferrable: If True, the operator will run as a deferrable operator.
"""
template_fields: Sequence[str] = ("cluster_identifier",)
@@ -465,66 +468,71 @@ class RedshiftResumeClusterOperator(BaseOperator):
*,
cluster_identifier: str,
aws_conn_id: str = "aws_default",
+ wait_for_completion: bool = False,
deferrable: bool = False,
poll_interval: int = 10,
+ max_attempts: int = 10,
**kwargs,
):
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.aws_conn_id = aws_conn_id
+ self.wait_for_completion = wait_for_completion
self.deferrable = deferrable
+ self.max_attempts = max_attempts
self.poll_interval = poll_interval
- # These parameters are added to address an issue with the boto3 API
where the API
+ # These parameters are used to address an issue with the boto3 API
where the API
# prematurely reports the cluster as available to receive requests.
This causes the cluster
# to reject initial attempts to resume the cluster despite reporting
the correct state.
- self._attempts = 10
+ self._remaining_attempts = 10
self._attempt_interval = 15
def execute(self, context: Context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
+ self.log.info("Starting resume cluster")
+ while self._remaining_attempts >= 1:
+ try:
+
redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
+ break
+ except
redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
+ self._remaining_attempts = self._remaining_attempts - 1
+ if self._remaining_attempts > 0:
+ self.log.error(
+ "Unable to resume cluster. %d attempts remaining.",
self._remaining_attempts
+ )
+ time.sleep(self._attempt_interval)
+ else:
+ raise error
if self.deferrable:
self.defer(
- timeout=self.execution_timeout,
- trigger=RedshiftClusterTrigger(
- task_id=self.task_id,
+ trigger=RedshiftResumeClusterTrigger(
+ cluster_identifier=self.cluster_identifier,
poll_interval=self.poll_interval,
+ max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
- cluster_identifier=self.cluster_identifier,
- attempts=self._attempts,
- operation_type="resume_cluster",
),
method_name="execute_complete",
+ # timeout is set to ensure that if a trigger dies, the timeout
does not restart
+ # 60 seconds is added to allow the trigger to exit gracefully
(i.e. yield TriggerEvent)
+ timeout=timedelta(seconds=self.max_attempts *
self.poll_interval + 60),
)
+ if self.wait_for_completion:
+ waiter = redshift_hook.get_waiter("cluster_resumed")
+ waiter.wait(
+ ClusterIdentifier=self.cluster_identifier,
+ WaiterConfig={
+ "Delay": self.poll_interval,
+ "MaxAttempts": self.max_attempts,
+ },
+ )
+
+ def execute_complete(self, context, event=None):
+ if event["status"] != "success":
+ raise AirflowException(f"Error resuming cluster: {event}")
else:
- while self._attempts >= 1:
- try:
-
redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
- return
- except
redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
- self._attempts = self._attempts - 1
-
- if self._attempts > 0:
- self.log.error("Unable to resume cluster. %d attempts
remaining.", self._attempts)
- time.sleep(self._attempt_interval)
- else:
- raise error
-
- def execute_complete(self, context: Context, event: Any = None) -> None:
- """
- Callback for when the trigger fires - returns immediately.
- Relies on trigger to throw an exception, otherwise it assumes
execution was
- successful.
- """
- if event:
- if "status" in event and event["status"] == "error":
- msg = f"{event['status']}: {event['message']}"
- raise AirflowException(msg)
- elif "status" in event and event["status"] == "success":
- self.log.info("%s completed successfully.", self.task_id)
- self.log.info("Resumed cluster successfully")
- else:
- raise AirflowException("No event received from trigger")
+ self.log.info("Resumed cluster successfully")
+ return
class RedshiftPauseClusterOperator(BaseOperator):
diff --git a/airflow/providers/amazon/aws/triggers/redshift_cluster.py
b/airflow/providers/amazon/aws/triggers/redshift_cluster.py
index 06f008d695..3224350e54 100644
--- a/airflow/providers/amazon/aws/triggers/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/triggers/redshift_cluster.py
@@ -285,3 +285,75 @@ class RedshiftCreateClusterSnapshotTrigger(BaseTrigger):
)
else:
yield TriggerEvent({"status": "success", "message": "Cluster
Snapshot Created"})
+
+
+class RedshiftResumeClusterTrigger(BaseTrigger):
+ """
+ Trigger for RedshiftResumeClusterOperator.
+ The trigger will asynchronously poll the boto3 API and wait for the
+ Redshift cluster to be in the `available` state.
+
+ :param cluster_identifier: A unique identifier for the cluster.
+ :param poll_interval: The amount of time in seconds to wait between
attempts.
+ :param max_attempts: The maximum number of attempts to be made.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ """
+
+ def __init__(
+ self,
+ cluster_identifier: str,
+ poll_interval: int,
+ max_attempts: int,
+ aws_conn_id: str,
+ ):
+ self.cluster_identifier = cluster_identifier
+ self.poll_interval = poll_interval
+ self.max_attempts = max_attempts
+ self.aws_conn_id = aws_conn_id
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+
"airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftResumeClusterTrigger",
+ {
+ "cluster_identifier": self.cluster_identifier,
+ "poll_interval": str(self.poll_interval),
+ "max_attempts": str(self.max_attempts),
+ "aws_conn_id": self.aws_conn_id,
+ },
+ )
+
+ @cached_property
+ def hook(self) -> RedshiftHook:
+ return RedshiftHook(aws_conn_id=self.aws_conn_id)
+
+ async def run(self):
+ async with self.hook.async_conn as client:
+ attempt = 0
+ waiter = self.hook.get_waiter("cluster_resumed", deferrable=True,
client=client)
+ while attempt < int(self.max_attempts):
+ attempt = attempt + 1
+ try:
+ await waiter.wait(
+ ClusterIdentifier=self.cluster_identifier,
+ WaiterConfig={
+ "Delay": int(self.poll_interval),
+ "MaxAttempts": 1,
+ },
+ )
+ break
+ except WaiterError as error:
+ if "terminal failure" in str(error):
+ yield TriggerEvent(
+ {"status": "failure", "message": f"Resume Cluster
Failed: {error}"}
+ )
+ break
+ self.log.info(
+ "Status of cluster is %s",
error.last_response["Clusters"][0]["ClusterStatus"]
+ )
+ await asyncio.sleep(int(self.poll_interval))
+ if attempt >= int(self.max_attempts):
+ yield TriggerEvent(
+ {"status": "failure", "message": "Resume Cluster Failed - max
attempts reached."}
+ )
+ else:
+ yield TriggerEvent({"status": "success", "message": "Cluster
resumed"})
diff --git a/airflow/providers/amazon/aws/waiters/redshift.json
b/airflow/providers/amazon/aws/waiters/redshift.json
index 587f8ce989..8165eb3fc4 100644
--- a/airflow/providers/amazon/aws/waiters/redshift.json
+++ b/airflow/providers/amazon/aws/waiters/redshift.json
@@ -25,6 +25,31 @@
"state": "failure"
}
]
+ },
+ "cluster_resumed": {
+ "operation": "DescribeClusters",
+ "delay": 30,
+ "maxAttempts": 60,
+ "acceptors": [
+ {
+ "matcher": "pathAll",
+ "argument": "Clusters[].ClusterStatus",
+ "expected": "available",
+ "state": "success"
+ },
+ {
+ "matcher": "error",
+ "argument": "Clusters[].ClusterStatus",
+ "expected": "ClusterNotFound",
+ "state": "retry"
+ },
+ {
+ "matcher": "pathAny",
+ "argument": "Clusters[].ClusterStatus",
+ "expected": "deleting",
+ "state": "failure"
+ }
+ ]
}
}
}
diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py
b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
index 82f9b19f2c..28b8d28642 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
@@ -32,9 +32,9 @@ from airflow.providers.amazon.aws.operators.redshift_cluster
import (
RedshiftResumeClusterOperator,
)
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
- RedshiftClusterTrigger,
RedshiftCreateClusterSnapshotTrigger,
RedshiftPauseClusterTrigger,
+ RedshiftResumeClusterTrigger,
)
@@ -264,7 +264,7 @@ class TestResumeClusterOperator:
assert redshift_operator.cluster_identifier == "test_cluster"
assert redshift_operator.aws_conn_id == "aws_conn_test"
-
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
+ @mock.patch.object(RedshiftHook, "get_conn")
def test_resume_cluster_is_called_when_cluster_is_paused(self,
mock_get_conn):
redshift_operator = RedshiftResumeClusterOperator(
task_id="task_test", cluster_identifier="test_cluster",
aws_conn_id="aws_conn_test"
@@ -272,7 +272,7 @@ class TestResumeClusterOperator:
redshift_operator.execute(None)
mock_get_conn.return_value.resume_cluster.assert_called_once_with(ClusterIdentifier="test_cluster")
-
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
+ @mock.patch.object(RedshiftHook, "conn")
@mock.patch("time.sleep", return_value=None)
def test_resume_cluster_multiple_attempts(self, mock_sleep, mock_conn):
exception =
boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
@@ -288,7 +288,7 @@ class TestResumeClusterOperator:
redshift_operator.execute(None)
assert mock_conn.resume_cluster.call_count == 3
-
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
+ @mock.patch.object(RedshiftHook, "conn")
@mock.patch("time.sleep", return_value=None)
def test_resume_cluster_multiple_attempts_fail(self, mock_sleep,
mock_conn):
exception =
boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
@@ -306,16 +306,10 @@ class TestResumeClusterOperator:
redshift_operator.execute(None)
assert mock_conn.resume_cluster.call_count == 10
-
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
-
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.resume_cluster")
-
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.get_client_async")
- def test_resume_cluster(self, mock_async_client,
mock_async_resume_cluster, mock_sync_cluster_status):
- """Test Resume cluster operator run"""
- mock_sync_cluster_status.return_value = "paused"
- mock_async_client.return_value.resume_cluster.return_value = {
- "Cluster": {"ClusterIdentifier": "test_cluster", "ClusterStatus":
"resuming"}
- }
- mock_async_resume_cluster.return_value = {"status": "success",
"cluster_state": "available"}
+ @mock.patch.object(RedshiftHook, "conn")
+ def test_resume_cluster_deferrable(self, mock_conn):
+ """Test Resume cluster operator deferrable"""
+ mock_conn.resume_cluster.return_value = True
redshift_operator = RedshiftResumeClusterOperator(
task_id="task_test",
@@ -328,22 +322,33 @@ class TestResumeClusterOperator:
redshift_operator.execute({})
assert isinstance(
- exc.value.trigger, RedshiftClusterTrigger
- ), "Trigger is not a RedshiftClusterTrigger"
+ exc.value.trigger, RedshiftResumeClusterTrigger
+ ), "Trigger is not a RedshiftResumeClusterTrigger"
-
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
-
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.resume_cluster")
-
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.get_client_async")
- def test_resume_cluster_failure(
- self, mock_async_client, mock_async_resume_cluster,
mock_sync_cluster_statue
- ):
- """Test Resume cluster operator Failure"""
- mock_sync_cluster_statue.return_value = "paused"
- mock_async_client.return_value.resume_cluster.return_value = {
- "Cluster": {"ClusterIdentifier": "test_cluster", "ClusterStatus":
"resuming"}
- }
- mock_async_resume_cluster.return_value = {"status": "success",
"cluster_state": "available"}
+ @mock.patch.object(RedshiftHook, "get_waiter")
+ @mock.patch.object(RedshiftHook, "conn")
+ def test_resume_cluster_wait_for_completion(self, mock_conn,
mock_get_waiter):
+ """Test Resume cluster operator wait for complettion"""
+ mock_conn.resume_cluster.return_value = True
+ mock_get_waiter().wait.return_value = None
+
+ redshift_operator = RedshiftResumeClusterOperator(
+ task_id="task_test",
+ cluster_identifier="test_cluster",
+ aws_conn_id="aws_conn_test",
+ wait_for_completion=True,
+ )
+ redshift_operator.execute(None)
+
mock_conn.resume_cluster.assert_called_once_with(ClusterIdentifier="test_cluster")
+ mock_get_waiter.assert_called_with("cluster_resumed")
+ assert mock_get_waiter.call_count == 2
+ mock_get_waiter().wait.assert_called_once_with(
+ ClusterIdentifier="test_cluster", WaiterConfig={"Delay": 10,
"MaxAttempts": 10}
+ )
+
+ def test_resume_cluster_failure(self):
+ """Test Resume cluster operator Failure"""
redshift_operator = RedshiftResumeClusterOperator(
task_id="task_test",
cluster_identifier="test_cluster",
diff --git a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
index b79286a093..92379b3ac3 100644
--- a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
+++ b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
@@ -26,6 +26,7 @@ from airflow.providers.amazon.aws.triggers.redshift_cluster
import (
RedshiftCreateClusterSnapshotTrigger,
RedshiftCreateClusterTrigger,
RedshiftPauseClusterTrigger,
+ RedshiftResumeClusterTrigger,
)
from airflow.triggers.base import TriggerEvent
@@ -364,3 +365,144 @@ class TestRedshiftCreateClusterSnapshotTrigger:
assert response == TriggerEvent(
{"status": "failure", "message": f"Create Cluster Snapshot Failed:
{error_failed}"}
)
+
+
+class TestRedshiftResumeClusterTrigger:
+ def test_redshift_resume_cluster_trigger_serialize(self):
+ redshift_resume_cluster_trigger = RedshiftResumeClusterTrigger(
+ cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+ poll_interval=TEST_POLL_INTERVAL,
+ max_attempts=TEST_MAX_ATTEMPT,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ )
+ class_path, args = redshift_resume_cluster_trigger.serialize()
+ assert (
+ class_path
+ ==
"airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftResumeClusterTrigger"
+ )
+ assert args["cluster_identifier"] == TEST_CLUSTER_IDENTIFIER
+ assert args["poll_interval"] == str(TEST_POLL_INTERVAL)
+ assert args["max_attempts"] == str(TEST_MAX_ATTEMPT)
+ assert args["aws_conn_id"] == TEST_AWS_CONN_ID
+
+ @pytest.mark.asyncio
+ @async_mock.patch.object(RedshiftHook, "get_waiter")
+ @async_mock.patch.object(RedshiftHook, "async_conn")
+ async def test_redshift_resume_cluster_trigger_run(self, mock_async_conn,
mock_get_waiter):
+ mock = async_mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = mock
+
+ mock_get_waiter().wait = AsyncMock()
+
+ redshift_resume_cluster_trigger = RedshiftResumeClusterTrigger(
+ cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+ poll_interval=TEST_POLL_INTERVAL,
+ max_attempts=TEST_MAX_ATTEMPT,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ )
+
+ generator = redshift_resume_cluster_trigger.run()
+ response = await generator.asend(None)
+
+ assert response == TriggerEvent({"status": "success", "message":
"Cluster resumed"})
+
+ @pytest.mark.asyncio
+ @async_mock.patch("asyncio.sleep")
+ @async_mock.patch.object(RedshiftHook, "get_waiter")
+ @async_mock.patch.object(RedshiftHook, "async_conn")
+ async def test_redshift_resume_cluster_trigger_run_multiple_attempts(
+ self, mock_async_conn, mock_get_waiter, mock_sleep
+ ):
+ mock = async_mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = mock
+ error = WaiterError(
+ name="test_name",
+ reason="test_reason",
+ last_response={"Clusters": [{"ClusterStatus": "available"}]},
+ )
+ mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error,
error, True])
+ mock_sleep.return_value = True
+
+ redshift_resume_cluster_trigger = RedshiftResumeClusterTrigger(
+ cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+ poll_interval=TEST_POLL_INTERVAL,
+ max_attempts=TEST_MAX_ATTEMPT,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ )
+
+ generator = redshift_resume_cluster_trigger.run()
+ response = await generator.asend(None)
+
+ assert mock_get_waiter().wait.call_count == 3
+ assert response == TriggerEvent({"status": "success", "message":
"Cluster resumed"})
+
+ @pytest.mark.asyncio
+ @async_mock.patch("asyncio.sleep")
+ @async_mock.patch.object(RedshiftHook, "get_waiter")
+ @async_mock.patch.object(RedshiftHook, "async_conn")
+ async def test_redshift_resume_cluster_trigger_run_attempts_exceeded(
+ self, mock_async_conn, mock_get_waiter, mock_sleep
+ ):
+ mock = async_mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = mock
+ error = WaiterError(
+ name="test_name",
+ reason="test_reason",
+ last_response={"Clusters": [{"ClusterStatus": "available"}]},
+ )
+ mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error,
error, True])
+ mock_sleep.return_value = True
+
+ redshift_resume_cluster_trigger = RedshiftResumeClusterTrigger(
+ cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+ poll_interval=TEST_POLL_INTERVAL,
+ max_attempts=2,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ )
+
+ generator = redshift_resume_cluster_trigger.run()
+ response = await generator.asend(None)
+
+ assert mock_get_waiter().wait.call_count == 2
+ assert response == TriggerEvent(
+ {"status": "failure", "message": "Resume Cluster Failed - max
attempts reached."}
+ )
+
+ @pytest.mark.asyncio
+ @async_mock.patch("asyncio.sleep")
+ @async_mock.patch.object(RedshiftHook, "get_waiter")
+ @async_mock.patch.object(RedshiftHook, "async_conn")
+ async def test_redshift_resume_cluster_trigger_run_attempts_failed(
+ self, mock_async_conn, mock_get_waiter, mock_sleep
+ ):
+ mock = async_mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = mock
+ error_available = WaiterError(
+ name="test_name",
+ reason="Max attempts exceeded",
+ last_response={"Clusters": [{"ClusterStatus": "available"}]},
+ )
+ error_failed = WaiterError(
+ name="test_name",
+ reason="Waiter encountered a terminal failure state:",
+ last_response={"Clusters": [{"ClusterStatus": "available"}]},
+ )
+ mock_get_waiter().wait.side_effect = AsyncMock(
+ side_effect=[error_available, error_available, error_failed]
+ )
+ mock_sleep.return_value = True
+
+ redshift_resume_cluster_trigger = RedshiftResumeClusterTrigger(
+ cluster_identifier=TEST_CLUSTER_IDENTIFIER,
+ poll_interval=TEST_POLL_INTERVAL,
+ max_attempts=TEST_MAX_ATTEMPT,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ )
+
+ generator = redshift_resume_cluster_trigger.run()
+ response = await generator.asend(None)
+
+ assert mock_get_waiter().wait.call_count == 3
+ assert response == TriggerEvent(
+ {"status": "failure", "message": f"Resume Cluster Failed:
{error_failed}"}
+ )