This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 83ca61a501 Fix `RdsStopDbOperator` operator in deferrable mode (#41059)
83ca61a501 is described below
commit 83ca61a501d755669fc83b1ad9038d0ca9d600ad
Author: Vincent <[email protected]>
AuthorDate: Fri Jul 26 16:52:25 2024 -0400
Fix `RdsStopDbOperator` operator in deferrable mode (#41059)
---
airflow/providers/amazon/aws/hooks/rds.py | 6 +-
airflow/providers/amazon/aws/operators/rds.py | 37 ++--
airflow/providers/amazon/aws/waiters/rds.json | 253 +++++++++++++++++++++++
tests/providers/amazon/aws/hooks/test_rds.py | 24 +--
tests/providers/amazon/aws/operators/test_rds.py | 14 +-
5 files changed, 280 insertions(+), 54 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/rds.py
b/airflow/providers/amazon/aws/hooks/rds.py
index 8219a37757..588d78c782 100644
--- a/airflow/providers/amazon/aws/hooks/rds.py
+++ b/airflow/providers/amazon/aws/hooks/rds.py
@@ -259,7 +259,7 @@ class RdsHook(AwsGenericHook["RDSClient"]):
return self.get_db_instance_state(db_instance_id)
target_state = target_state.lower()
- if target_state in ("available", "deleted"):
+ if target_state in ("available", "deleted", "stopped"):
waiter = self.conn.get_waiter(f"db_instance_{target_state}") #
type: ignore
wait(
waiter=waiter,
@@ -272,7 +272,7 @@ class RdsHook(AwsGenericHook["RDSClient"]):
)
else:
self._wait_for_state(poke, target_state, check_interval,
max_attempts)
- self.log.info("DB cluster snapshot '%s' reached the '%s' state",
db_instance_id, target_state)
+ self.log.info("DB cluster '%s' reached the '%s' state",
db_instance_id, target_state)
def get_db_cluster_state(self, db_cluster_id: str) -> str:
"""
@@ -310,7 +310,7 @@ class RdsHook(AwsGenericHook["RDSClient"]):
return self.get_db_cluster_state(db_cluster_id)
target_state = target_state.lower()
- if target_state in ("available", "deleted"):
+ if target_state in ("available", "deleted", "stopped"):
waiter = self.conn.get_waiter(f"db_cluster_{target_state}") #
type: ignore
waiter.wait(
DBClusterIdentifier=db_cluster_id,
diff --git a/airflow/providers/amazon/aws/operators/rds.py
b/airflow/providers/amazon/aws/operators/rds.py
index a2f35b5081..f37c698d87 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -36,6 +36,7 @@ from airflow.providers.amazon.aws.utils import
validate_execute_complete_event
from airflow.providers.amazon.aws.utils.rds import RdsDbType
from airflow.providers.amazon.aws.utils.tags import format_tags
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
+from airflow.utils.helpers import prune_dict
if TYPE_CHECKING:
from mypy_boto3_rds.type_defs import TagTypeDef
@@ -782,7 +783,7 @@ class RdsStartDbOperator(RdsBaseOperator):
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
response=start_db_response,
- db_type=RdsDbType.INSTANCE,
+ db_type=self.db_type,
),
method_name="execute_complete",
)
@@ -881,12 +882,25 @@ class RdsStopDbOperator(RdsBaseOperator):
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
response=stop_db_response,
- db_type=RdsDbType.INSTANCE,
+ db_type=self.db_type,
),
method_name="execute_complete",
)
elif self.wait_for_completion:
- self._wait_until_db_stopped()
+ waiter = self.hook.get_waiter(f"db_{self.db_type.value}_stopped")
+ waiter_key = (
+ "DBInstanceIdentifier" if self.db_type == RdsDbType.INSTANCE
else "DBClusterIdentifier"
+ )
+ kwargs = {waiter_key: self.db_identifier}
+ waiter.wait(
+ WaiterConfig=prune_dict(
+ {
+ "Delay": self.waiter_delay,
+ "MaxAttempts": self.waiter_max_attempts,
+ }
+ ),
+ **kwargs,
+ )
return json.dumps(stop_db_response, default=str)
def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> str:
@@ -915,23 +929,6 @@ class RdsStopDbOperator(RdsBaseOperator):
response =
self.hook.conn.stop_db_cluster(DBClusterIdentifier=self.db_identifier)
return response
- def _wait_until_db_stopped(self):
- self.log.info("Waiting for DB %s to reach 'stopped' state",
self.db_type.value)
- if self.db_type == RdsDbType.INSTANCE:
- self.hook.wait_for_db_instance_state(
- self.db_identifier,
- target_state="stopped",
- check_interval=self.waiter_delay,
- max_attempts=self.waiter_max_attempts,
- )
- else:
- self.hook.wait_for_db_cluster_state(
- self.db_identifier,
- target_state="stopped",
- check_interval=self.waiter_delay,
- max_attempts=self.waiter_max_attempts,
- )
-
__all__ = [
"RdsCreateDbSnapshotOperator",
diff --git a/airflow/providers/amazon/aws/waiters/rds.json
b/airflow/providers/amazon/aws/waiters/rds.json
new file mode 100644
index 0000000000..78c56da53f
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/rds.json
@@ -0,0 +1,253 @@
+{
+    "version": 2,
+    "waiters": {
+        "db_instance_stopped": {
+            "operation": "DescribeDBInstances",
+            "delay": 30,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "stopped",
+                    "state": "success"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "available",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "backing-up",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "creating",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "delete-precheck",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "deleting",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "failed",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "inaccessible-encryption-credentials",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "inaccessible-encryption-credentials-recoverable",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "incompatible-network",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "incompatible-option-group",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "incompatible-parameters",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "incompatible-restore",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "insufficient-capacity",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "maintenance",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "modifying",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "rebooting",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "renaming",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "restore-error",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "starting",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "stopping",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "storage-full",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "upgrading",
+                    "state": "retry"
+                }
+            ]
+        },
+        "db_cluster_stopped": {
+            "operation": "DescribeDBClusters",
+            "delay": 30,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBClusters[].Status",
+                    "expected": "stopped",
+                    "state": "success"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBClusters[].Status",
+                    "expected": "available",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBClusters[].Status",
+                    "expected": "backing-up",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBClusters[].Status",
+                    "expected": "cloning-failed",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBClusters[].Status",
+                    "expected": "creating",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBClusters[].Status",
+                    "expected": "deleting",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBClusters[].Status",
+                    "expected": "failing-over",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBClusters[].Status",
+                    "expected": "inaccessible-encryption-credentials",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBClusters[].Status",
+                    "expected": "inaccessible-encryption-credentials-recoverable",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBClusters[].Status",
+                    "expected": "maintenance",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBClusters[].Status",
+                    "expected": "migrating",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBClusters[].Status",
+                    "expected": "migration-failed",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBClusters[].Status",
+                    "expected": "modifying",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBClusters[].Status",
+                    "expected": "renaming",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBClusters[].Status",
+                    "expected": "starting",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBClusters[].Status",
+                    "expected": "stopping",
+                    "state": "retry"
+                }
+            ]
+        }
+    }
+}
diff --git a/tests/providers/amazon/aws/hooks/test_rds.py
b/tests/providers/amazon/aws/hooks/test_rds.py
index b2668febfb..77159b4e14 100644
--- a/tests/providers/amazon/aws/hooks/test_rds.py
+++ b/tests/providers/amazon/aws/hooks/test_rds.py
@@ -150,7 +150,7 @@ class TestRdsHook:
def test_wait_for_db_instance_state_boto_waiters(self, rds_hook: RdsHook,
db_instance_id: str):
"""Checks that the DB instance waiter uses AWS boto waiters where
possible"""
- for state in ("available", "deleted"):
+ for state in ("available", "deleted", "stopped"):
with patch.object(rds_hook.conn, "get_waiter") as mock:
rds_hook.wait_for_db_instance_state(db_instance_id,
target_state=state, **self.waiter_args)
mock.assert_called_once_with(f"db_instance_{state}")
@@ -161,16 +161,6 @@ class TestRdsHook:
},
)
- def test_wait_for_db_instance_state_custom_waiter(self, rds_hook: RdsHook,
db_instance_id: str):
- """Checks that the DB instance waiter uses custom wait logic when AWS
boto waiters aren't available"""
- with patch.object(rds_hook, "_wait_for_state") as mock:
- rds_hook.wait_for_db_instance_state(db_instance_id,
target_state="stopped", **self.waiter_args)
- mock.assert_called_once()
-
- with patch.object(rds_hook, "get_db_instance_state",
return_value="stopped") as mock:
- rds_hook.wait_for_db_instance_state(db_instance_id,
target_state="stopped", **self.waiter_args)
- mock.assert_called_once_with(db_instance_id)
-
def test_get_db_cluster_state(self, rds_hook: RdsHook, db_cluster_id: str):
response =
rds_hook.conn.describe_db_clusters(DBClusterIdentifier=db_cluster_id)
state_expected = response["DBClusters"][0]["Status"]
@@ -179,7 +169,7 @@ class TestRdsHook:
def test_wait_for_db_cluster_state_boto_waiters(self, rds_hook: RdsHook,
db_cluster_id: str):
"""Checks that the DB cluster waiter uses AWS boto waiters where
possible"""
- for state in ("available", "deleted"):
+ for state in ("available", "deleted", "stopped"):
with patch.object(rds_hook.conn, "get_waiter") as mock:
rds_hook.wait_for_db_cluster_state(db_cluster_id,
target_state=state, **self.waiter_args)
mock.assert_called_once_with(f"db_cluster_{state}")
@@ -191,16 +181,6 @@ class TestRdsHook:
},
)
- def test_wait_for_db_cluster_state_custom_waiter(self, rds_hook: RdsHook,
db_cluster_id: str):
- """Checks that the DB cluster waiter uses custom wait logic when AWS
boto waiters aren't available"""
- with patch.object(rds_hook, "_wait_for_state") as mock_wait_for_state:
- rds_hook.wait_for_db_cluster_state(db_cluster_id,
target_state="stopped", **self.waiter_args)
- mock_wait_for_state.assert_called_once()
-
- with patch.object(rds_hook, "get_db_cluster_state",
return_value="stopped") as mock:
- rds_hook.wait_for_db_cluster_state(db_cluster_id,
target_state="stopped", **self.waiter_args)
- mock.assert_called_once_with(db_cluster_id)
-
def test_get_db_snapshot_state(self, rds_hook: RdsHook, db_snapshot_id:
str):
response =
rds_hook.conn.describe_db_snapshots(DBSnapshotIdentifier=db_snapshot_id)
state_expected = response["DBSnapshots"][0]["Status"]
diff --git a/tests/providers/amazon/aws/operators/test_rds.py
b/tests/providers/amazon/aws/operators/test_rds.py
index fd464019dd..651db53d42 100644
--- a/tests/providers/amazon/aws/operators/test_rds.py
+++ b/tests/providers/amazon/aws/operators/test_rds.py
@@ -813,8 +813,7 @@ class TestRdsStopDbOperator:
del cls.hook
@mock_aws
- @patch.object(RdsHook, "wait_for_db_instance_state")
- def test_stop_db_instance(self, mock_await_status):
+ def test_stop_db_instance(self):
_create_db_instance(self.hook)
stop_db_instance = RdsStopDbOperator(task_id="test_stop_db_instance",
db_identifier=DB_INSTANCE_NAME)
_patch_hook_get_connection(stop_db_instance.hook)
@@ -822,11 +821,10 @@ class TestRdsStopDbOperator:
result =
self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
status = result["DBInstances"][0]["DBInstanceStatus"]
assert status == "stopped"
- mock_await_status.assert_called()
@mock_aws
- @patch.object(RdsHook, "wait_for_db_instance_state")
- def test_stop_db_instance_no_wait(self, mock_await_status):
+ @patch.object(RdsHook, "get_waiter")
+ def test_stop_db_instance_no_wait(self, mock_get_waiter):
_create_db_instance(self.hook)
stop_db_instance = RdsStopDbOperator(
task_id="test_stop_db_instance_no_wait",
db_identifier=DB_INSTANCE_NAME, wait_for_completion=False
@@ -836,7 +834,7 @@ class TestRdsStopDbOperator:
result =
self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
status = result["DBInstances"][0]["DBInstanceStatus"]
assert status == "stopped"
- mock_await_status.assert_not_called()
+ mock_get_waiter.assert_not_called()
@mock.patch.object(RdsHook, "conn")
def test_deferred(self, conn_mock):
@@ -872,8 +870,7 @@ class TestRdsStopDbOperator:
assert len(instance_snapshots) == 1
@mock_aws
- @patch.object(RdsHook, "wait_for_db_cluster_state")
- def test_stop_db_cluster(self, mock_await_status):
+ def test_stop_db_cluster(self):
_create_db_cluster(self.hook)
stop_db_cluster = RdsStopDbOperator(
task_id="test_stop_db_cluster", db_identifier=DB_CLUSTER_NAME,
db_type="cluster"
@@ -884,7 +881,6 @@ class TestRdsStopDbOperator:
describe_result =
self.hook.conn.describe_db_clusters(DBClusterIdentifier=DB_CLUSTER_NAME)
status = describe_result["DBClusters"][0]["Status"]
assert status == "stopped"
- mock_await_status.assert_called()
@mock_aws
def test_stop_db_cluster_create_snapshot_logs_warning_message(self,
caplog):