This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7fe573317e Fix AWS Redshift operators and sensors (#41191)
7fe573317e is described below
commit 7fe573317eb630c2d176329c599d6fbbb30f4378
Author: Vincent <[email protected]>
AuthorDate: Thu Aug 1 13:54:14 2024 -0400
Fix AWS Redshift operators and sensors (#41191)
---
.../amazon/aws/operators/redshift_cluster.py | 118 ++++++++++++---------
.../amazon/aws/operators/test_redshift_cluster.py | 25 ++++-
.../providers/amazon/aws/example_redshift.py | 4 +-
3 files changed, 91 insertions(+), 56 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py
index f31b7db09a..4666445f96 100644
--- a/airflow/providers/amazon/aws/operators/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -32,6 +32,7 @@ from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftResumeClusterTrigger,
)
from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.utils.helpers import prune_dict
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -507,8 +508,8 @@ class RedshiftResumeClusterOperator(BaseOperator):
aws_conn_id: str | None = "aws_default",
wait_for_completion: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
- poll_interval: int = 10,
- max_attempts: int = 10,
+ poll_interval: int = 30,
+ max_attempts: int = 30,
**kwargs,
):
super().__init__(**kwargs)
@@ -542,38 +543,38 @@ class RedshiftResumeClusterOperator(BaseOperator):
else:
raise error
- if self.deferrable:
- cluster_state =
redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
- if cluster_state == "available":
- self.log.info("Resumed cluster successfully")
- elif cluster_state == "deleting":
- raise AirflowException(
- "Unable to resume cluster since cluster is currently in
status: %s", cluster_state
- )
+ if self.wait_for_completion:
+ if self.deferrable:
+ cluster_state =
redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
+ if cluster_state == "available":
+ self.log.info("Resumed cluster successfully")
+ elif cluster_state == "deleting":
+ raise AirflowException(
+ "Unable to resume cluster since cluster is currently
in status: %s", cluster_state
+ )
+ else:
+ self.defer(
+ trigger=RedshiftResumeClusterTrigger(
+ cluster_identifier=self.cluster_identifier,
+ waiter_delay=self.poll_interval,
+ waiter_max_attempts=self.max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="execute_complete",
+ # timeout is set to ensure that if a trigger dies, the
timeout does not restart
+ # 60 seconds is added to allow the trigger to exit
gracefully (i.e. yield TriggerEvent)
+ timeout=timedelta(seconds=self.max_attempts *
self.poll_interval + 60),
+ )
else:
- self.defer(
- trigger=RedshiftResumeClusterTrigger(
- cluster_identifier=self.cluster_identifier,
- waiter_delay=self.poll_interval,
- waiter_max_attempts=self.max_attempts,
- aws_conn_id=self.aws_conn_id,
- ),
- method_name="execute_complete",
- # timeout is set to ensure that if a trigger dies, the
timeout does not restart
- # 60 seconds is added to allow the trigger to exit
gracefully (i.e. yield TriggerEvent)
- timeout=timedelta(seconds=self.max_attempts *
self.poll_interval + 60),
+ waiter = redshift_hook.get_waiter("cluster_resumed")
+ waiter.wait(
+ ClusterIdentifier=self.cluster_identifier,
+ WaiterConfig={
+ "Delay": self.poll_interval,
+ "MaxAttempts": self.max_attempts,
+ },
)
- if self.wait_for_completion:
- waiter = redshift_hook.get_waiter("cluster_resumed")
- waiter.wait(
- ClusterIdentifier=self.cluster_identifier,
- WaiterConfig={
- "Delay": self.poll_interval,
- "MaxAttempts": self.max_attempts,
- },
- )
-
def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> None:
event = validate_execute_complete_event(event)
@@ -596,6 +597,7 @@ class RedshiftPauseClusterOperator(BaseOperator):
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
+ :param wait_for_completion: If True, waits for the cluster to be paused. (default: False)
:param deferrable: Run operator in the deferrable mode
:param poll_interval: Time (in seconds) to wait between two consecutive calls to check cluster state
:param max_attempts: Maximum number of attempts to poll the cluster
@@ -610,14 +612,16 @@ class RedshiftPauseClusterOperator(BaseOperator):
*,
cluster_identifier: str,
aws_conn_id: str | None = "aws_default",
+ wait_for_completion: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
- poll_interval: int = 10,
- max_attempts: int = 15,
+ poll_interval: int = 30,
+ max_attempts: int = 30,
**kwargs,
):
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.aws_conn_id = aws_conn_id
+ self.wait_for_completion = wait_for_completion
self.deferrable = deferrable
self.max_attempts = max_attempts
self.poll_interval = poll_interval
@@ -643,26 +647,38 @@ class RedshiftPauseClusterOperator(BaseOperator):
time.sleep(self._attempt_interval)
else:
raise error
- if self.deferrable:
- cluster_state =
redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
- if cluster_state == "paused":
- self.log.info("Paused cluster successfully")
- elif cluster_state == "deleting":
- raise AirflowException(
- f"Unable to pause cluster since cluster is currently in
status: {cluster_state}"
- )
+ if self.wait_for_completion:
+ if self.deferrable:
+ cluster_state =
redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
+ if cluster_state == "paused":
+ self.log.info("Paused cluster successfully")
+ elif cluster_state == "deleting":
+ raise AirflowException(
+ f"Unable to pause cluster since cluster is currently
in status: {cluster_state}"
+ )
+ else:
+ self.defer(
+ trigger=RedshiftPauseClusterTrigger(
+ cluster_identifier=self.cluster_identifier,
+ waiter_delay=self.poll_interval,
+ waiter_max_attempts=self.max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="execute_complete",
+ # timeout is set to ensure that if a trigger dies, the
timeout does not restart
+ # 60 seconds is added to allow the trigger to exit
gracefully (i.e. yield TriggerEvent)
+ timeout=timedelta(seconds=self.max_attempts *
self.poll_interval + 60),
+ )
else:
- self.defer(
- trigger=RedshiftPauseClusterTrigger(
- cluster_identifier=self.cluster_identifier,
- waiter_delay=self.poll_interval,
- waiter_max_attempts=self.max_attempts,
- aws_conn_id=self.aws_conn_id,
+ waiter = redshift_hook.get_waiter("cluster_paused")
+ waiter.wait(
+ ClusterIdentifier=self.cluster_identifier,
+ WaiterConfig=prune_dict(
+ {
+ "Delay": self.poll_interval,
+ "MaxAttempts": self.max_attempts,
+ }
),
- method_name="execute_complete",
- # timeout is set to ensure that if a trigger dies, the
timeout does not restart
- # 60 seconds is added to allow the trigger to exit
gracefully (i.e. yield TriggerEvent)
- timeout=timedelta(seconds=self.max_attempts *
self.poll_interval + 60),
)
def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> None:
diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
index c537f4ce24..f6c960b7bd 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
@@ -17,6 +17,7 @@
from __future__ import annotations
from unittest import mock
+from unittest.mock import Mock
import boto3
import pytest
@@ -318,6 +319,7 @@ class TestResumeClusterOperator:
task_id="task_test",
cluster_identifier="test_cluster",
aws_conn_id="aws_conn_test",
+ wait_for_completion=True,
deferrable=True,
)
@@ -340,6 +342,7 @@ class TestResumeClusterOperator:
task_id="task_test",
cluster_identifier="test_cluster",
aws_conn_id="aws_conn_test",
+ wait_for_completion=True,
deferrable=True,
)
@@ -366,7 +369,7 @@ class TestResumeClusterOperator:
mock_get_waiter.assert_called_with("cluster_resumed")
assert mock_get_waiter.call_count == 2
mock_get_waiter().wait.assert_called_once_with(
- ClusterIdentifier="test_cluster", WaiterConfig={"Delay": 10,
"MaxAttempts": 10}
+ ClusterIdentifier="test_cluster", WaiterConfig={"Delay": 30,
"MaxAttempts": 30}
)
def test_resume_cluster_failure(self):
@@ -437,6 +440,22 @@ class TestPauseClusterOperator:
redshift_operator.execute(None)
assert mock_conn.pause_cluster.call_count == 10
+ @mock.patch.object(RedshiftHook, "get_waiter")
+ @mock.patch.object(RedshiftHook, "get_conn")
+ def test_pause_cluster_wait_for_completion(self, mock_get_conn, mock_get_waiter):
+ """Test Pause cluster operator with defer when deferrable param is true"""
+ mock_get_conn.return_value.pause_cluster.return_value = True
+ waiter = Mock()
+ mock_get_waiter.return_value = waiter
+
+ redshift_operator = RedshiftPauseClusterOperator(
+ task_id="task_test", cluster_identifier="test_cluster",
wait_for_completion=True
+ )
+
+ redshift_operator.execute(context=None)
+
+ waiter.wait.assert_called_once()
+
@mock.patch.object(RedshiftHook, "cluster_status")
@mock.patch.object(RedshiftHook, "get_conn")
def test_pause_cluster_deferrable_mode(self, mock_get_conn,
mock_cluster_status):
@@ -445,7 +464,7 @@ class TestPauseClusterOperator:
mock_cluster_status.return_value = "available"
redshift_operator = RedshiftPauseClusterOperator(
- task_id="task_test", cluster_identifier="test_cluster",
deferrable=True
+ task_id="task_test", cluster_identifier="test_cluster",
wait_for_completion=True, deferrable=True
)
with pytest.raises(TaskDeferred) as exc:
@@ -466,7 +485,7 @@ class TestPauseClusterOperator:
mock_cluster_status.return_value = "deleting"
redshift_operator = RedshiftPauseClusterOperator(
- task_id="task_test", cluster_identifier="test_cluster",
deferrable=True
+ task_id="task_test", cluster_identifier="test_cluster",
wait_for_completion=True, deferrable=True
)
with pytest.raises(AirflowException):
diff --git a/tests/system/providers/amazon/aws/example_redshift.py b/tests/system/providers/amazon/aws/example_redshift.py
index cc88811bef..cc92076dcb 100644
--- a/tests/system/providers/amazon/aws/example_redshift.py
+++ b/tests/system/providers/amazon/aws/example_redshift.py
@@ -121,7 +121,7 @@ with DAG(
task_id="wait_cluster_paused",
cluster_identifier=redshift_cluster_identifier,
target_status="paused",
- poke_interval=15,
+ poke_interval=30,
timeout=60 * 30,
)
@@ -136,7 +136,7 @@ with DAG(
task_id="wait_cluster_available_after_resume",
cluster_identifier=redshift_cluster_identifier,
target_status="available",
- poke_interval=15,
+ poke_interval=30,
timeout=60 * 30,
)