This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2bba98f109 Use Boto waiters instead of customer _await_status method 
for RDS Operators (#27410)
2bba98f109 is described below

commit 2bba98f109cc7737f4293a195e03a0cc21a624cb
Author: Hank Ehly <[email protected]>
AuthorDate: Fri Nov 18 02:02:21 2022 +0900

    Use Boto waiters instead of customer _await_status method for RDS Operators 
(#27410)
---
 airflow/providers/amazon/aws/hooks/rds.py        | 301 +++++++++++++++++++++-
 airflow/providers/amazon/aws/operators/rds.py    | 155 +++--------
 airflow/providers/amazon/aws/sensors/rds.py      |  76 ++----
 tests/providers/amazon/aws/hooks/test_rds.py     | 312 +++++++++++++++++++++++
 tests/providers/amazon/aws/operators/test_rds.py | 135 ++++++----
 tests/providers/amazon/aws/sensors/test_rds.py   |  19 +-
 6 files changed, 755 insertions(+), 243 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/rds.py 
b/airflow/providers/amazon/aws/hooks/rds.py
index cd2434282c..15790e3add 100644
--- a/airflow/providers/amazon/aws/hooks/rds.py
+++ b/airflow/providers/amazon/aws/hooks/rds.py
@@ -18,8 +18,10 @@
 """Interact with AWS RDS."""
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+import time
+from typing import TYPE_CHECKING, Callable
 
+from airflow.exceptions import AirflowException, AirflowNotFoundException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 
 if TYPE_CHECKING:
@@ -48,3 +50,300 @@ class RdsHook(AwsGenericHook["RDSClient"]):
     def __init__(self, *args, **kwargs) -> None:
         kwargs["client_type"] = "rds"
         super().__init__(*args, **kwargs)
+
+    def get_db_snapshot_state(self, snapshot_id: str) -> str:
+        """
+        Get the current state of a DB instance snapshot.
+
+        :param snapshot_id: The ID of the target DB instance snapshot
+        :return: Returns the status of the DB snapshot as a string (eg. 
"available")
+        :rtype: str
+        :raises AirflowNotFoundException: If the DB instance snapshot does not 
exist.
+        """
+        try:
+            response = 
self.conn.describe_db_snapshots(DBSnapshotIdentifier=snapshot_id)
+        except self.conn.exceptions.ClientError as e:
+            if e.response["Error"]["Code"] == "DBSnapshotNotFound":
+                raise AirflowNotFoundException(e)
+            raise e
+        return response["DBSnapshots"][0]["Status"].lower()
+
+    def wait_for_db_snapshot_state(
+        self, snapshot_id: str, target_state: str, check_interval: int = 30, 
max_attempts: int = 40
+    ) -> None:
+        """
+        Polls :py:meth:`RDS.Client.describe_db_snapshots` until the target 
state is reached.
+        An error is raised after a max number of attempts.
+
+        :param snapshot_id: The ID of the target DB instance snapshot
+        :param target_state: Wait until this state is reached
+        :param check_interval: The amount of time in seconds to wait between 
attempts
+        :param max_attempts: The maximum number of attempts to be made
+        """
+
+        def poke():
+            return self.get_db_snapshot_state(snapshot_id)
+
+        target_state = target_state.lower()
+        if target_state in ("available", "deleted", "completed"):
+            waiter = self.conn.get_waiter(f"db_snapshot_{target_state}")  # 
type: ignore
+            waiter.wait(
+                DBSnapshotIdentifier=snapshot_id,
+                WaiterConfig={"Delay": check_interval, "MaxAttempts": 
max_attempts},
+            )
+        else:
+            self._wait_for_state(poke, target_state, check_interval, 
max_attempts)
+            self.log.info("DB snapshot '%s' reached the '%s' state", 
snapshot_id, target_state)
+
+    def get_db_cluster_snapshot_state(self, snapshot_id: str) -> str:
+        """
+        Get the current state of a DB cluster snapshot.
+
+        :param snapshot_id: The ID of the target DB cluster snapshot.
+        :return: Returns the status of the DB cluster snapshot as a string 
(eg. "available")
+        :rtype: str
+        :raises AirflowNotFoundException: If the DB cluster snapshot does not 
exist.
+        """
+        try:
+            response = 
self.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=snapshot_id)
+        except self.conn.exceptions.ClientError as e:
+            if e.response["Error"]["Code"] == "DBClusterSnapshotNotFoundFault":
+                raise AirflowNotFoundException(e)
+            raise e
+        return response["DBClusterSnapshots"][0]["Status"].lower()
+
+    def wait_for_db_cluster_snapshot_state(
+        self, snapshot_id: str, target_state: str, check_interval: int = 30, 
max_attempts: int = 40
+    ) -> None:
+        """
+        Polls :py:meth:`RDS.Client.describe_db_cluster_snapshots` until the 
target state is reached.
+        An error is raised after a max number of attempts.
+
+        :param snapshot_id: The ID of the target DB cluster snapshot
+        :param target_state: Wait until this state is reached
+        :param check_interval: The amount of time in seconds to wait between 
attempts
+        :param max_attempts: The maximum number of attempts to be made
+
+        .. seealso::
+            A list of possible values for target_state:
+            
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_db_cluster_snapshots
+        """
+
+        def poke():
+            return self.get_db_cluster_snapshot_state(snapshot_id)
+
+        target_state = target_state.lower()
+        if target_state in ("available", "deleted"):
+            waiter = 
self.conn.get_waiter(f"db_cluster_snapshot_{target_state}")  # type: ignore
+            waiter.wait(
+                DBClusterSnapshotIdentifier=snapshot_id,
+                WaiterConfig={"Delay": check_interval, "MaxAttempts": 
max_attempts},
+            )
+        else:
+            self._wait_for_state(poke, target_state, check_interval, 
max_attempts)
+            self.log.info("DB cluster snapshot '%s' reached the '%s' state", 
snapshot_id, target_state)
+
+    def get_export_task_state(self, export_task_id: str) -> str:
+        """
+        Gets the current state of an RDS snapshot export to Amazon S3.
+
+        :param export_task_id: The identifier of the target snapshot export 
task.
+        :return: Returns the status of the snapshot export task as a string 
(eg. "canceled")
+        :rtype: str
+        :raises AirflowNotFoundException: If the export task does not exist.
+        """
+        try:
+            response = 
self.conn.describe_export_tasks(ExportTaskIdentifier=export_task_id)
+        except self.conn.exceptions.ClientError as e:
+            if e.response["Error"]["Code"] == "ExportTaskNotFoundFault":
+                raise AirflowNotFoundException(e)
+            raise e
+        return response["ExportTasks"][0]["Status"].lower()
+
+    def wait_for_export_task_state(
+        self, export_task_id: str, target_state: str, check_interval: int = 
30, max_attempts: int = 40
+    ) -> None:
+        """
+        Polls :py:meth:`RDS.Client.describe_export_tasks` until the target 
state is reached.
+        An error is raised after a max number of attempts.
+
+        :param export_task_id: The identifier of the target snapshot export 
task.
+        :param target_state: Wait until this state is reached
+        :param check_interval: The amount of time in seconds to wait between 
attempts
+        :param max_attempts: The maximum number of attempts to be made
+
+        .. seealso::
+            A list of possible values for target_state:
+            
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_export_tasks
+        """
+
+        def poke():
+            return self.get_export_task_state(export_task_id)
+
+        target_state = target_state.lower()
+        self._wait_for_state(poke, target_state, check_interval, max_attempts)
+        self.log.info("export task '%s' reached the '%s' state", 
export_task_id, target_state)
+
+    def get_event_subscription_state(self, subscription_name: str) -> str:
+        """
+        Gets the current state of an RDS event notification subscription.
+
+        :param subscription_name: The name of the target RDS event 
notification subscription.
+        :return: Returns the status of the event subscription as a string (eg. 
"active")
+        :rtype: str
+        :raises AirflowNotFoundException: If the event subscription does not 
exist.
+        """
+        try:
+            response = 
self.conn.describe_event_subscriptions(SubscriptionName=subscription_name)
+        except self.conn.exceptions.ClientError as e:
+            if e.response["Error"]["Code"] == "SubscriptionNotFoundFault":
+                raise AirflowNotFoundException(e)
+            raise e
+        return response["EventSubscriptionsList"][0]["Status"].lower()
+
+    def wait_for_event_subscription_state(
+        self, subscription_name: str, target_state: str, check_interval: int = 
30, max_attempts: int = 40
+    ) -> None:
+        """
+        Polls :py:meth:`RDS.Client.describe_event_subscriptions` until the 
target state is reached.
+        An error is raised after a max number of attempts.
+
+        :param subscription_name: The name of the target RDS event 
notification subscription.
+        :param target_state: Wait until this state is reached
+        :param check_interval: The amount of time in seconds to wait between 
attempts
+        :param max_attempts: The maximum number of attempts to be made
+
+        .. seealso::
+            A list of possible values for target_state:
+            
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_event_subscriptions
+        """
+
+        def poke():
+            return self.get_event_subscription_state(subscription_name)
+
+        target_state = target_state.lower()
+        self._wait_for_state(poke, target_state, check_interval, max_attempts)
+        self.log.info("event subscription '%s' reached the '%s' state", 
subscription_name, target_state)
+
+    def get_db_instance_state(self, db_instance_id: str) -> str:
+        """
+        Get the current state of a DB instance.
+
+        :param db_instance_id: The ID of the target DB instance.
+        :return: Returns the status of the DB instance as a string (eg. 
"available")
+        :rtype: str
+        :raises AirflowNotFoundException: If the DB instance does not exist.
+        """
+        try:
+            response = 
self.conn.describe_db_instances(DBInstanceIdentifier=db_instance_id)
+        except self.conn.exceptions.ClientError as e:
+            if e.response["Error"]["Code"] == "DBInstanceNotFoundFault":
+                raise AirflowNotFoundException(e)
+            raise e
+        return response["DBInstances"][0]["DBInstanceStatus"].lower()
+
+    def wait_for_db_instance_state(
+        self, db_instance_id: str, target_state: str, check_interval: int = 
30, max_attempts: int = 40
+    ) -> None:
+        """
+        Polls :py:meth:`RDS.Client.describe_db_instances` until the target 
state is reached.
+        An error is raised after a max number of attempts.
+
+        :param db_instance_id: The ID of the target DB instance.
+        :param target_state: Wait until this state is reached
+        :param check_interval: The amount of time in seconds to wait between 
attempts
+        :param max_attempts: The maximum number of attempts to be made
+
+        .. seealso::
+            For information about DB instance statuses, see Viewing DB 
instance status in the Amazon RDS
+            User Guide.
+            
https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/accessing-monitoring.html#Overview.DBInstance.Status
+        """
+
+        def poke():
+            return self.get_db_instance_state(db_instance_id)
+
+        target_state = target_state.lower()
+        if target_state in ("available", "deleted"):
+            waiter = self.conn.get_waiter(f"db_instance_{target_state}")  # 
type: ignore
+            waiter.wait(
+                DBInstanceIdentifier=db_instance_id,
+                WaiterConfig={"Delay": check_interval, "MaxAttempts": 
max_attempts},
+            )
+        else:
+            self._wait_for_state(poke, target_state, check_interval, 
max_attempts)
+            self.log.info("DB cluster snapshot '%s' reached the '%s' state", 
db_instance_id, target_state)
+
+    def get_db_cluster_state(self, db_cluster_id: str) -> str:
+        """
+        Get the current state of a DB cluster.
+
+        :param db_cluster_id: The ID of the target DB cluster.
+        :return: Returns the status of the DB cluster as a string (eg. 
"available")
+        :rtype: str
+        :raises AirflowNotFoundException: If the DB cluster does not exist.
+        """
+        try:
+            response = 
self.conn.describe_db_clusters(DBClusterIdentifier=db_cluster_id)
+        except self.conn.exceptions.ClientError as e:
+            if e.response["Error"]["Code"] == "DBClusterNotFoundFault":
+                raise AirflowNotFoundException(e)
+            raise e
+        return response["DBClusters"][0]["Status"].lower()
+
+    def wait_for_db_cluster_state(
+        self, db_cluster_id: str, target_state: str, check_interval: int = 30, 
max_attempts: int = 40
+    ) -> None:
+        """
+        Polls :py:meth:`RDS.Client.describe_db_clusters` until the target 
state is reached.
+        An error is raised after a max number of attempts.
+
+        :param db_cluster_id: The ID of the target DB cluster.
+        :param target_state: Wait until this state is reached
+        :param check_interval: The amount of time in seconds to wait between 
attempts
+        :param max_attempts: The maximum number of attempts to be made
+
+        .. seealso::
+            For information about DB instance statuses, see Viewing DB 
instance status in the Amazon RDS
+            User Guide.
+            
https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/accessing-monitoring.html#Overview.DBInstance.Status
+        """
+
+        def poke():
+            return self.get_db_cluster_state(db_cluster_id)
+
+        target_state = target_state.lower()
+        if target_state in ("available", "deleted"):
+            waiter = self.conn.get_waiter(f"db_cluster_{target_state}")  # 
type: ignore
+            waiter.wait(
+                DBClusterIdentifier=db_cluster_id,
+                WaiterConfig={"Delay": check_interval, "MaxAttempts": 
max_attempts},
+            )
+        else:
+            self._wait_for_state(poke, target_state, check_interval, 
max_attempts)
+            self.log.info("DB cluster snapshot '%s' reached the '%s' state", 
db_cluster_id, target_state)
+
+    def _wait_for_state(
+        self,
+        poke: Callable[..., str],
+        target_state: str,
+        check_interval: int,
+        max_attempts: int,
+    ) -> None:
+        """
+        Polls the poke function for the current state until it reaches the 
target_state.
+
+        :param poke: A function that returns the current state of the target 
resource as a string.
+        :param target_state: Wait until this state is reached
+        :param check_interval: The amount of time in seconds to wait between 
attempts
+        :param max_attempts: The maximum number of attempts to be made
+        """
+        state = poke()
+        tries = 1
+        while state != target_state:
+            self.log.info("Current state is %s", state)
+            if tries >= max_attempts:
+                raise AirflowException("Max attempts exceeded")
+            time.sleep(check_interval)
+            state = poke()
+            tries += 1
diff --git a/airflow/providers/amazon/aws/operators/rds.py 
b/airflow/providers/amazon/aws/operators/rds.py
index f9e23eaf91..c10e969c8b 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -18,12 +18,10 @@
 from __future__ import annotations
 
 import json
-import time
 from typing import TYPE_CHECKING, Sequence
 
 from mypy_boto3_rds.type_defs import TagTypeDef
 
-from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
@@ -45,66 +43,6 @@ class RdsBaseOperator(BaseOperator):
 
         self._await_interval = 60  # seconds
 
-    def _describe_item(self, item_type: str, item_name: str) -> list:
-        if item_type == "instance_snapshot":
-            db_snaps = 
self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=item_name)
-            return db_snaps["DBSnapshots"]
-        elif item_type == "cluster_snapshot":
-            cl_snaps = 
self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=item_name)
-            return cl_snaps["DBClusterSnapshots"]
-        elif item_type == "export_task":
-            exports = 
self.hook.conn.describe_export_tasks(ExportTaskIdentifier=item_name)
-            return exports["ExportTasks"]
-        elif item_type == "event_subscription":
-            subscriptions = 
self.hook.conn.describe_event_subscriptions(SubscriptionName=item_name)
-            return subscriptions["EventSubscriptionsList"]
-        elif item_type == "db_instance":
-            instances = 
self.hook.conn.describe_db_instances(DBInstanceIdentifier=item_name)
-            return instances["DBInstances"]
-        elif item_type == "db_cluster":
-            clusters = 
self.hook.conn.describe_db_clusters(DBClusterIdentifier=item_name)
-            return clusters["DBClusters"]
-        else:
-            raise AirflowException(f"Method for {item_type} is not 
implemented")
-
-    def _await_status(
-        self,
-        item_type: str,
-        item_name: str,
-        wait_statuses: list[str] | None = None,
-        ok_statuses: list[str] | None = None,
-        error_statuses: list[str] | None = None,
-    ) -> None:
-        """
-        Continuously gets item description from `_describe_item()` and waits 
while:
-        - status is in `wait_statuses`
-        - status not in `ok_statuses` and `error_statuses`
-        """
-        while True:
-            items = self._describe_item(item_type, item_name)
-
-            if len(items) == 0:
-                raise AirflowException(f"There is no {item_type} with 
identifier {item_name}")
-            if len(items) > 1:
-                raise AirflowException(f"There are {len(items)} {item_type} 
with identifier {item_name}")
-
-            if item_type == "db_instance":
-                status_field = "DBInstanceStatus"
-            else:
-                status_field = "Status"
-
-            if wait_statuses and items[0][status_field].lower() in 
wait_statuses:
-                time.sleep(self._await_interval)
-                continue
-            elif ok_statuses and items[0][status_field].lower() in ok_statuses:
-                break
-            elif error_statuses and items[0][status_field].lower() in 
error_statuses:
-                raise AirflowException(f"Item has error status 
({error_statuses}): {items[0]}")
-            else:
-                raise AirflowException(f"Item has uncertain status: 
{items[0]}")
-
-        return None
-
     def execute(self, context: Context) -> str:
         """Different implementations for snapshots, tasks and events"""
         raise NotImplementedError
@@ -166,8 +104,8 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
                 Tags=self.tags,
             )
             create_response = json.dumps(create_instance_snap, default=str)
-            item_type = "instance_snapshot"
-
+            if self.wait_for_completion:
+                
self.hook.wait_for_db_snapshot_state(self.db_snapshot_identifier, 
target_state="available")
         else:
             create_cluster_snap = self.hook.conn.create_db_cluster_snapshot(
                 DBClusterIdentifier=self.db_identifier,
@@ -175,15 +113,10 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
                 Tags=self.tags,
             )
             create_response = json.dumps(create_cluster_snap, default=str)
-            item_type = "cluster_snapshot"
-
-        if self.wait_for_completion:
-            self._await_status(
-                item_type,
-                self.db_snapshot_identifier,
-                wait_statuses=["creating"],
-                ok_statuses=["available"],
-            )
+            if self.wait_for_completion:
+                self.hook.wait_for_db_cluster_snapshot_state(
+                    self.db_snapshot_identifier, target_state="available"
+                )
         return create_response
 
 
@@ -270,8 +203,10 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
                 SourceRegion=self.source_region,
             )
             copy_response = json.dumps(copy_instance_snap, default=str)
-            item_type = "instance_snapshot"
-
+            if self.wait_for_completion:
+                self.hook.wait_for_db_snapshot_state(
+                    self.target_db_snapshot_identifier, 
target_state="available"
+                )
         else:
             copy_cluster_snap = self.hook.conn.copy_db_cluster_snapshot(
                 
SourceDBClusterSnapshotIdentifier=self.source_db_snapshot_identifier,
@@ -283,15 +218,10 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
                 SourceRegion=self.source_region,
             )
             copy_response = json.dumps(copy_cluster_snap, default=str)
-            item_type = "cluster_snapshot"
-
-        if self.wait_for_completion:
-            self._await_status(
-                item_type,
-                self.target_db_snapshot_identifier,
-                wait_statuses=["creating", "copying"],
-                ok_statuses=["available"],
-            )
+            if self.wait_for_completion:
+                self.hook.wait_for_db_cluster_snapshot_state(
+                    self.target_db_snapshot_identifier, 
target_state="available"
+                )
         return copy_response
 
 
@@ -314,6 +244,7 @@ class RdsDeleteDbSnapshotOperator(RdsBaseOperator):
         *,
         db_type: str,
         db_snapshot_identifier: str,
+        wait_for_completion: bool = True,
         aws_conn_id: str = "aws_default",
         **kwargs,
     ):
@@ -321,6 +252,7 @@ class RdsDeleteDbSnapshotOperator(RdsBaseOperator):
 
         self.db_type = RdsDbType(db_type)
         self.db_snapshot_identifier = db_snapshot_identifier
+        self.wait_for_completion = wait_for_completion
 
     def execute(self, context: Context) -> str:
         self.log.info("Starting to delete snapshot '%s'", 
self.db_snapshot_identifier)
@@ -330,11 +262,17 @@ class RdsDeleteDbSnapshotOperator(RdsBaseOperator):
                 DBSnapshotIdentifier=self.db_snapshot_identifier,
             )
             delete_response = json.dumps(delete_instance_snap, default=str)
+            if self.wait_for_completion:
+                
self.hook.wait_for_db_snapshot_state(self.db_snapshot_identifier, 
target_state="deleted")
         else:
             delete_cluster_snap = self.hook.conn.delete_db_cluster_snapshot(
                 DBClusterSnapshotIdentifier=self.db_snapshot_identifier,
             )
             delete_response = json.dumps(delete_cluster_snap, default=str)
+            if self.wait_for_completion:
+                self.hook.wait_for_db_cluster_snapshot_state(
+                    self.db_snapshot_identifier, target_state="deleted"
+                )
 
         return delete_response
 
@@ -406,14 +344,7 @@ class RdsStartExportTaskOperator(RdsBaseOperator):
         )
 
         if self.wait_for_completion:
-            self._await_status(
-                "export_task",
-                self.export_task_identifier,
-                wait_statuses=["starting", "in_progress"],
-                ok_statuses=["complete"],
-                error_statuses=["canceling", "canceled"],
-            )
-
+            self.hook.wait_for_export_task_state(self.export_task_identifier, 
target_state="complete")
         return json.dumps(start_export, default=str)
 
 
@@ -452,13 +383,7 @@ class RdsCancelExportTaskOperator(RdsBaseOperator):
         )
 
         if self.wait_for_completion:
-            self._await_status(
-                "export_task",
-                self.export_task_identifier,
-                wait_statuses=["canceling"],
-                ok_statuses=["canceled"],
-            )
-
+            self.hook.wait_for_export_task_state(self.export_task_identifier, 
target_state="canceled")
         return json.dumps(cancel_export, default=str)
 
 
@@ -531,13 +456,7 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
         )
 
         if self.wait_for_completion:
-            self._await_status(
-                "event_subscription",
-                self.subscription_name,
-                wait_statuses=["creating"],
-                ok_statuses=["active"],
-            )
-
+            
self.hook.wait_for_event_subscription_state(self.subscription_name, 
target_state="active")
         return json.dumps(create_subscription, default=str)
 
 
@@ -628,10 +547,7 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
         )
 
         if self.wait_for_completion:
-            self.hook.conn.get_waiter("db_instance_available").wait(
-                DBInstanceIdentifier=self.db_instance_identifier
-            )
-
+            self.hook.wait_for_db_instance_state(self.db_instance_identifier, 
target_state="available")
         return json.dumps(create_db_instance, default=str)
 
 
@@ -675,10 +591,7 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
         )
 
         if self.wait_for_completion:
-            self.hook.conn.get_waiter("db_instance_deleted").wait(
-                DBInstanceIdentifier=self.db_instance_identifier
-            )
-
+            self.hook.wait_for_db_instance_state(self.db_instance_identifier, 
target_state="deleted")
         return json.dumps(delete_db_instance, default=str)
 
 
@@ -730,9 +643,9 @@ class RdsStartDbOperator(RdsBaseOperator):
     def _wait_until_db_available(self):
         self.log.info("Waiting for DB %s to reach 'available' state", 
self.db_type.value)
         if self.db_type == RdsDbType.INSTANCE:
-            
self.hook.conn.get_waiter("db_instance_available").wait(DBInstanceIdentifier=self.db_identifier)
+            self.hook.wait_for_db_instance_state(self.db_identifier, 
target_state="available")
         else:
-            
self.hook.conn.get_waiter("db_cluster_available").wait(DBClusterIdentifier=self.db_identifier)
+            self.hook.wait_for_db_cluster_state(self.db_identifier, 
target_state="available")
 
 
 class RdsStopDbOperator(RdsBaseOperator):
@@ -797,16 +710,10 @@ class RdsStopDbOperator(RdsBaseOperator):
 
     def _wait_until_db_stopped(self):
         self.log.info("Waiting for DB %s to reach 'stopped' state", 
self.db_type.value)
-        wait_statuses = ["stopping"]
-        ok_statuses = ["stopped"]
         if self.db_type == RdsDbType.INSTANCE:
-            self._await_status(
-                "db_instance", self.db_identifier, 
wait_statuses=wait_statuses, ok_statuses=ok_statuses
-            )
+            self.hook.wait_for_db_instance_state(self.db_identifier, 
target_state="stopped")
         else:
-            self._await_status(
-                "db_cluster", self.db_identifier, wait_statuses=wait_statuses, 
ok_statuses=ok_statuses
-            )
+            self.hook.wait_for_db_cluster_state(self.db_identifier, 
target_state="stopped")
 
 
 __all__ = [
diff --git a/airflow/providers/amazon/aws/sensors/rds.py 
b/airflow/providers/amazon/aws/sensors/rds.py
index 07933809fe..731c8b5def 100644
--- a/airflow/providers/amazon/aws/sensors/rds.py
+++ b/airflow/providers/amazon/aws/sensors/rds.py
@@ -18,9 +18,7 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Sequence
 
-from botocore.exceptions import ClientError
-
-from airflow import AirflowException
+from airflow.exceptions import AirflowNotFoundException
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
 from airflow.sensors.base import BaseSensorOperator
@@ -41,40 +39,6 @@ class RdsBaseSensor(BaseSensorOperator):
         self.target_statuses: list[str] = []
         super().__init__(*args, **kwargs)
 
-    def _describe_item(self, item_type: str, item_name: str) -> list:
-        if item_type == "instance_snapshot":
-            db_snaps = 
self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=item_name)
-            return db_snaps["DBSnapshots"]
-        elif item_type == "cluster_snapshot":
-            cl_snaps = 
self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=item_name)
-            return cl_snaps["DBClusterSnapshots"]
-        elif item_type == "export_task":
-            exports = 
self.hook.conn.describe_export_tasks(ExportTaskIdentifier=item_name)
-            return exports["ExportTasks"]
-        elif item_type == "db_instance":
-            instances = 
self.hook.conn.describe_db_instances(DBInstanceIdentifier=item_name)
-            return instances["DBInstances"]
-        elif item_type == "db_cluster":
-            clusters = 
self.hook.conn.describe_db_clusters(DBClusterIdentifier=item_name)
-            return clusters["DBClusters"]
-        else:
-            raise AirflowException(f"Method for {item_type} is not 
implemented")
-
-    def _check_item(self, item_type: str, item_name: str) -> bool:
-        """Get certain item from `_describe_item()` and check its status"""
-        if item_type == "db_instance":
-            status_field = "DBInstanceStatus"
-        else:
-            status_field = "Status"
-        try:
-            items = self._describe_item(item_type, item_name)
-        except ClientError:
-            return False
-        else:
-            return bool(items) and any(
-                map(lambda status: items[0][status_field].lower() == status, 
self.target_statuses)
-            )
-
 
 class RdsSnapshotExistenceSensor(RdsBaseSensor):
     """
@@ -112,10 +76,14 @@ class RdsSnapshotExistenceSensor(RdsBaseSensor):
         self.log.info(
             "Poking for statuses : %s\nfor snapshot %s", self.target_statuses, 
self.db_snapshot_identifier
         )
-        if self.db_type.value == "instance":
-            return self._check_item(item_type="instance_snapshot", 
item_name=self.db_snapshot_identifier)
-        else:
-            return self._check_item(item_type="cluster_snapshot", 
item_name=self.db_snapshot_identifier)
+        try:
+            if self.db_type.value == "instance":
+                state = 
self.hook.get_db_snapshot_state(self.db_snapshot_identifier)
+            else:
+                state = 
self.hook.get_db_cluster_snapshot_state(self.db_snapshot_identifier)
+        except AirflowNotFoundException:
+            return False
+        return state in self.target_statuses
 
 
 class RdsExportTaskExistenceSensor(RdsBaseSensor):
@@ -158,7 +126,11 @@ class RdsExportTaskExistenceSensor(RdsBaseSensor):
         self.log.info(
             "Poking for statuses : %s\nfor export task %s", 
self.target_statuses, self.export_task_identifier
         )
-        return self._check_item(item_type="export_task", 
item_name=self.export_task_identifier)
+        try:
+            state = 
self.hook.get_export_task_state(self.export_task_identifier)
+        except AirflowNotFoundException:
+            return False
+        return state in self.target_statuses
 
 
 class RdsDbSensor(RdsBaseSensor):
@@ -184,7 +156,7 @@ class RdsDbSensor(RdsBaseSensor):
         self,
         *,
         db_identifier: str,
-        db_type: str = "instance",
+        db_type: RdsDbType | str = RdsDbType.INSTANCE,
         target_statuses: list[str] | None = None,
         aws_conn_id: str = "aws_default",
         **kwargs,
@@ -192,19 +164,21 @@ class RdsDbSensor(RdsBaseSensor):
         super().__init__(aws_conn_id=aws_conn_id, **kwargs)
         self.db_identifier = db_identifier
         self.target_statuses = target_statuses or ["available"]
-        self.db_type = RdsDbType(db_type)
+        self.db_type = db_type
 
     def poke(self, context: Context):
+        db_type = RdsDbType(self.db_type)
         self.log.info(
             "Poking for statuses : %s\nfor db instance %s", 
self.target_statuses, self.db_identifier
         )
-        item_type = self._check_item_type()
-        return self._check_item(item_type=item_type, 
item_name=self.db_identifier)
-
-    def _check_item_type(self):
-        if self.db_type == RdsDbType.CLUSTER:
-            return "db_cluster"
-        return "db_instance"
+        try:
+            if db_type == RdsDbType.INSTANCE:
+                state = self.hook.get_db_instance_state(self.db_identifier)
+            else:
+                state = self.hook.get_db_cluster_state(self.db_identifier)
+        except AirflowNotFoundException:
+            return False
+        return state in self.target_statuses
 
 
 __all__ = [
diff --git a/tests/providers/amazon/aws/hooks/test_rds.py 
b/tests/providers/amazon/aws/hooks/test_rds.py
index b7ea30b9ad..f98320a2f1 100644
--- a/tests/providers/amazon/aws/hooks/test_rds.py
+++ b/tests/providers/amazon/aws/hooks/test_rds.py
@@ -17,10 +17,119 @@
 # under the License.
 from __future__ import annotations
 
+from unittest.mock import patch
+
+import pytest
+from moto import mock_rds
+
+from airflow.exceptions import AirflowException, AirflowNotFoundException
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
 
 
[email protected]
+def rds_hook() -> RdsHook:
+    """Returns an RdsHook whose underlying connection is mocked with moto"""
+    with mock_rds():
+        yield RdsHook(aws_conn_id="aws_default", region_name="us-east-1")
+
+
[email protected]
+def db_instance_id(rds_hook: RdsHook) -> str:
+    """Creates an RDS DB instance and returns its id"""
+    response = rds_hook.conn.create_db_instance(
+        DBInstanceIdentifier="testrdshook-db-instance",
+        DBInstanceClass="db.t4g.micro",
+        Engine="postgres",
+        AllocatedStorage=20,
+        MasterUsername="testrdshook",
+        MasterUserPassword="testrdshook",
+    )
+    return response["DBInstance"]["DBInstanceIdentifier"]
+
+
[email protected]
+def db_cluster_id(rds_hook: RdsHook) -> str:
+    """Creates an RDS DB cluster and returns its id"""
+    response = rds_hook.conn.create_db_cluster(
+        DBClusterIdentifier="testrdshook-db-cluster",
+        Engine="postgres",
+        MasterUsername="testrdshook",
+        MasterUserPassword="testrdshook",
+        DBClusterInstanceClass="db.t4g.micro",
+        AllocatedStorage=20,
+    )
+    return response["DBCluster"]["DBClusterIdentifier"]
+
+
[email protected]
+def db_snapshot(rds_hook: RdsHook, db_instance_id: str) -> dict:
+    """
+    Creates a mock DB instance snapshot and returns the DBSnapshot dict from 
the boto response object.
+    
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_snapshot
+    """
+    response = rds_hook.conn.create_db_snapshot(
+        DBSnapshotIdentifier="testrdshook-db-instance-snapshot", 
DBInstanceIdentifier=db_instance_id
+    )
+    return response["DBSnapshot"]
+
+
[email protected]
+def db_snapshot_id(db_snapshot: dict) -> str:
+    return db_snapshot["DBSnapshotIdentifier"]
+
+
[email protected]
+def db_snapshot_arn(db_snapshot: dict) -> str:
+    return db_snapshot["DBSnapshotArn"]
+
+
[email protected]
+def db_cluster_snapshot(rds_hook: RdsHook, db_cluster_id: str):
+    """
+    Creates a mock DB cluster snapshot and returns the DBClusterSnapshot dict 
from the boto response object.
+    
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_cluster_snapshot
+    """
+    response = rds_hook.conn.create_db_cluster_snapshot(
+        DBClusterSnapshotIdentifier="testrdshook-db-cluster-snapshot", 
DBClusterIdentifier=db_cluster_id
+    )
+    return response["DBClusterSnapshot"]
+
+
[email protected]
+def db_cluster_snapshot_id(db_cluster_snapshot) -> str:
+    return db_cluster_snapshot["DBClusterSnapshotIdentifier"]
+
+
[email protected]
+def export_task_id(rds_hook: RdsHook, db_snapshot_arn: str) -> str:
+    response = rds_hook.conn.start_export_task(
+        ExportTaskIdentifier="testrdshook-export-task",
+        SourceArn=db_snapshot_arn,
+        S3BucketName="test",
+        IamRoleArn="test",
+        KmsKeyId="test",
+    )
+    return response["ExportTaskIdentifier"]
+
+
[email protected]
+def event_subscription_name(rds_hook: RdsHook, db_instance_id: str) -> str:
+    """Creates a mock RDS event subscription and returns its name"""
+    response = rds_hook.conn.create_event_subscription(
+        SubscriptionName="testrdshook-event-subscription",
+        SnsTopicArn="test",
+        SourceType="db-instance",
+        SourceIds=[db_instance_id],
+        Enabled=True,
+    )
+    return response["EventSubscription"]["CustSubscriptionId"]
+
+
 class TestRdsHook:
+    # For testing, set the delay between status checks to 0 so that we aren't 
sleeping during tests,
+    # and max_attempts to 1 so that we don't retry unless required.
+    waiter_args = {"check_interval": 0, "max_attempts": 1}
+
     def test_conn_attribute(self):
         hook = RdsHook(aws_conn_id="aws_default", region_name="us-east-1")
         assert hasattr(hook, "conn")
@@ -28,3 +137,206 @@ class TestRdsHook:
         conn = hook.conn
         assert conn is hook.conn  # Cached property
         assert conn is hook.get_conn()  # Same object as returned by `conn` 
property
+
+    def test_get_db_instance_state(self, rds_hook: RdsHook, db_instance_id: 
str):
+        response = 
rds_hook.conn.describe_db_instances(DBInstanceIdentifier=db_instance_id)
+        state_expected = response["DBInstances"][0]["DBInstanceStatus"]
+        state_actual = rds_hook.get_db_instance_state(db_instance_id)
+        assert state_actual == state_expected
+
+    def test_wait_for_db_instance_state_boto_waiters(self, rds_hook: RdsHook, 
db_instance_id: str):
+        """Checks that the DB instance waiter uses AWS boto waiters where 
possible"""
+        for state in ("available", "deleted"):
+            with patch.object(rds_hook.conn, "get_waiter") as mock:
+                rds_hook.wait_for_db_instance_state(db_instance_id, 
target_state=state, **self.waiter_args)
+                mock.assert_called_once_with(f"db_instance_{state}")
+                mock.return_value.wait.assert_called_once_with(
+                    DBInstanceIdentifier=db_instance_id,
+                    WaiterConfig={
+                        "Delay": self.waiter_args["check_interval"],
+                        "MaxAttempts": self.waiter_args["max_attempts"],
+                    },
+                )
+
+    def test_wait_for_db_instance_state_custom_waiter(self, rds_hook: RdsHook, 
db_instance_id: str):
+        """Checks that the DB instance waiter uses custom wait logic when AWS 
boto waiters aren't available"""
+        with patch.object(rds_hook, "_wait_for_state") as mock:
+            rds_hook.wait_for_db_instance_state(db_instance_id, 
target_state="stopped", **self.waiter_args)
+            mock.assert_called_once()
+
+        with patch.object(rds_hook, "get_db_instance_state", 
return_value="stopped") as mock:
+            rds_hook.wait_for_db_instance_state(db_instance_id, 
target_state="stopped", **self.waiter_args)
+            mock.assert_called_once_with(db_instance_id)
+
+    def test_get_db_cluster_state(self, rds_hook: RdsHook, db_cluster_id: str):
+        response = 
rds_hook.conn.describe_db_clusters(DBClusterIdentifier=db_cluster_id)
+        state_expected = response["DBClusters"][0]["Status"]
+        state_actual = rds_hook.get_db_cluster_state(db_cluster_id)
+        assert state_actual == state_expected
+
+    def test_wait_for_db_cluster_state_boto_waiters(self, rds_hook: RdsHook, 
db_cluster_id: str):
+        """Checks that the DB cluster waiter uses AWS boto waiters where 
possible"""
+        for state in ("available", "deleted"):
+            with patch.object(rds_hook.conn, "get_waiter") as mock:
+                rds_hook.wait_for_db_cluster_state(db_cluster_id, 
target_state=state, **self.waiter_args)
+                mock.assert_called_once_with(f"db_cluster_{state}")
+                mock.return_value.wait.assert_called_once_with(
+                    DBClusterIdentifier=db_cluster_id,
+                    WaiterConfig={
+                        "Delay": self.waiter_args["check_interval"],
+                        "MaxAttempts": self.waiter_args["max_attempts"],
+                    },
+                )
+
+    def test_wait_for_db_cluster_state_custom_waiter(self, rds_hook: RdsHook, 
db_cluster_id: str):
+        """Checks that the DB cluster waiter uses custom wait logic when AWS 
boto waiters aren't available"""
+        with patch.object(rds_hook, "_wait_for_state") as mock_wait_for_state:
+            rds_hook.wait_for_db_cluster_state(db_cluster_id, 
target_state="stopped", **self.waiter_args)
+            mock_wait_for_state.assert_called_once()
+
+        with patch.object(rds_hook, "get_db_cluster_state", 
return_value="stopped") as mock:
+            rds_hook.wait_for_db_cluster_state(db_cluster_id, 
target_state="stopped", **self.waiter_args)
+            mock.assert_called_once_with(db_cluster_id)
+
+    def test_get_db_snapshot_state(self, rds_hook: RdsHook, db_snapshot_id: 
str):
+        response = 
rds_hook.conn.describe_db_snapshots(DBSnapshotIdentifier=db_snapshot_id)
+        state_expected = response["DBSnapshots"][0]["Status"]
+        state_actual = rds_hook.get_db_snapshot_state(db_snapshot_id)
+        assert state_actual == state_expected
+
+    def test_get_db_snapshot_state_not_found(self, rds_hook: RdsHook):
+        with pytest.raises(AirflowNotFoundException):
+            rds_hook.get_db_snapshot_state("does_not_exist")
+
+    def test_wait_for_db_snapshot_state_boto_waiters(self, rds_hook: RdsHook, 
db_snapshot_id: str):
+        """Checks that the DB snapshot waiter uses AWS boto waiters where 
possible"""
+        for state in ("available", "deleted", "completed"):
+            with patch.object(rds_hook.conn, "get_waiter") as mock:
+                rds_hook.wait_for_db_snapshot_state(db_snapshot_id, 
target_state=state, **self.waiter_args)
+                mock.assert_called_once_with(f"db_snapshot_{state}")
+                mock.return_value.wait.assert_called_once_with(
+                    DBSnapshotIdentifier=db_snapshot_id,
+                    WaiterConfig={
+                        "Delay": self.waiter_args["check_interval"],
+                        "MaxAttempts": self.waiter_args["max_attempts"],
+                    },
+                )
+
+    def test_wait_for_db_snapshot_state_custom_waiter(self, rds_hook: RdsHook, 
db_snapshot_id: str):
+        """Checks that the DB snapshot waiter uses custom wait logic when AWS 
boto waiters aren't available"""
+        with patch.object(rds_hook, "_wait_for_state") as mock:
+            rds_hook.wait_for_db_snapshot_state(db_snapshot_id, 
target_state="canceled", **self.waiter_args)
+            mock.assert_called_once()
+
+        with patch.object(rds_hook, "get_db_snapshot_state", 
return_value="canceled") as mock:
+            rds_hook.wait_for_db_snapshot_state(db_snapshot_id, 
target_state="canceled", **self.waiter_args)
+            mock.assert_called_once_with(db_snapshot_id)
+
+    def test_get_db_cluster_snapshot_state(self, rds_hook: RdsHook, 
db_cluster_snapshot_id: str):
+        response = rds_hook.conn.describe_db_cluster_snapshots(
+            DBClusterSnapshotIdentifier=db_cluster_snapshot_id
+        )
+        state_expected = response["DBClusterSnapshots"][0]["Status"]
+        state_actual = 
rds_hook.get_db_cluster_snapshot_state(db_cluster_snapshot_id)
+        assert state_actual == state_expected
+
+    def test_get_db_cluster_snapshot_state_not_found(self, rds_hook: RdsHook):
+        with pytest.raises(AirflowNotFoundException):
+            rds_hook.get_db_cluster_snapshot_state("does_not_exist")
+
+    def test_wait_for_db_cluster_snapshot_state_boto_waiters(
+        self, rds_hook: RdsHook, db_cluster_snapshot_id: str
+    ):
+        """Checks that the DB cluster snapshot waiter uses AWS boto waiters 
where possible"""
+        for state in ("available", "deleted"):
+            with patch.object(rds_hook.conn, "get_waiter") as mock:
+                rds_hook.wait_for_db_cluster_snapshot_state(
+                    db_cluster_snapshot_id, target_state=state, 
**self.waiter_args
+                )
+                mock.assert_called_once_with(f"db_cluster_snapshot_{state}")
+                mock.return_value.wait.assert_called_once_with(
+                    DBClusterSnapshotIdentifier=db_cluster_snapshot_id,
+                    WaiterConfig={
+                        "Delay": self.waiter_args["check_interval"],
+                        "MaxAttempts": self.waiter_args["max_attempts"],
+                    },
+                )
+
+    def test_wait_for_db_cluster_snapshot_state_custom_waiter(
+        self, rds_hook: RdsHook, db_cluster_snapshot_id: str
+    ):
+        """
+        Checks that the DB cluster snapshot waiter uses custom wait logic when 
AWS boto waiters
+        aren't available
+        """
+        with patch.object(rds_hook, "_wait_for_state") as mock:
+            rds_hook.wait_for_db_cluster_snapshot_state(
+                db_cluster_snapshot_id, target_state="canceled", 
**self.waiter_args
+            )
+            mock.assert_called_once()
+
+        with patch.object(rds_hook, "get_db_cluster_snapshot_state", 
return_value="canceled") as mock:
+            rds_hook.wait_for_db_cluster_snapshot_state(
+                db_cluster_snapshot_id, target_state="canceled", 
**self.waiter_args
+            )
+            mock.assert_called_once_with(db_cluster_snapshot_id)
+
+    def test_get_export_task_state(self, rds_hook: RdsHook, export_task_id: 
str):
+        response = 
rds_hook.conn.describe_export_tasks(ExportTaskIdentifier=export_task_id)
+        state_expected = response["ExportTasks"][0]["Status"]
+        state_actual = rds_hook.get_export_task_state(export_task_id)
+        assert state_actual == state_expected
+
+    def test_get_export_task_state_not_found(self, rds_hook: RdsHook):
+        with pytest.raises(AirflowNotFoundException):
+            rds_hook.get_export_task_state("does_not_exist")
+
+    def test_wait_for_export_task_state(self, rds_hook: RdsHook, 
export_task_id: str):
+        """
+        Checks that the export task waiter uses custom wait logic (no boto 
waiters exist for this resource)
+        """
+        with patch.object(rds_hook, "_wait_for_state") as mock:
+            rds_hook.wait_for_export_task_state(export_task_id, 
target_state="complete", **self.waiter_args)
+            mock.assert_called_once()
+
+        with patch.object(rds_hook, "get_export_task_state", 
return_value="complete") as mock:
+            rds_hook.wait_for_export_task_state(export_task_id, 
target_state="complete", **self.waiter_args)
+            mock.assert_called_once_with(export_task_id)
+
+    def test_get_event_subscription_state(self, rds_hook: RdsHook, 
event_subscription_name: str):
+        response = 
rds_hook.conn.describe_event_subscriptions(SubscriptionName=event_subscription_name)
+        state_expected = response["EventSubscriptionsList"][0]["Status"]
+        state_actual = 
rds_hook.get_event_subscription_state(event_subscription_name)
+        assert state_actual == state_expected
+
+    def test_get_event_subscription_state_not_found(self, rds_hook: RdsHook):
+        with pytest.raises(AirflowNotFoundException):
+            rds_hook.get_event_subscription_state("does_not_exist")
+
+    def test_wait_for_event_subscription_state(self, rds_hook: RdsHook, 
event_subscription_name: str):
+        """
+        Checks that the event subscription waiter uses custom wait logic (no 
boto waiters
+        exist for this resource)
+        """
+        with patch.object(rds_hook, "_wait_for_state") as mock:
+            rds_hook.wait_for_event_subscription_state(
+                event_subscription_name, target_state="active", 
**self.waiter_args
+            )
+            mock.assert_called_once()
+
+        with patch.object(rds_hook, "get_event_subscription_state", 
return_value="active") as mock:
+            rds_hook.wait_for_event_subscription_state(
+                event_subscription_name, target_state="active", 
**self.waiter_args
+            )
+            mock.assert_called_once_with(event_subscription_name)
+
+    def test_wait_for_state(self, rds_hook: RdsHook):
+        def poke():
+            return "foo"
+
+        with pytest.raises(AirflowException, match="Max attempts exceeded"):
+            with patch("airflow.providers.amazon.aws.hooks.rds.time.sleep") as 
mock:
+                rds_hook._wait_for_state(poke, target_state="bar", 
check_interval=0, max_attempts=2)
+        # This next line should exist outside of the pytest.raises() context 
manager or else it won't
+        # get executed
+        mock.assert_called_once_with(0)
diff --git a/tests/providers/amazon/aws/operators/test_rds.py 
b/tests/providers/amazon/aws/operators/test_rds.py
index 4817d9dc6e..d2408fa47d 100644
--- a/tests/providers/amazon/aws/operators/test_rds.py
+++ b/tests/providers/amazon/aws/operators/test_rds.py
@@ -23,7 +23,6 @@ from unittest.mock import patch
 import pytest
 from moto import mock_rds
 
-from airflow.exceptions import AirflowException
 from airflow.models import DAG
 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
@@ -153,26 +152,6 @@ class TestBaseRdsOperator:
         assert hasattr(self.op, "hook")
         assert self.op.hook.__class__.__name__ == "RdsHook"
 
-    def test_describe_item_wrong_type(self):
-        with pytest.raises(AirflowException):
-            self.op._describe_item("database", "auth-db")
-
-    def test_await_status_error(self):
-        self.op._describe_item = lambda item_type, item_name: [{"Status": 
"error"}]
-        with pytest.raises(AirflowException):
-            self.op._await_status(
-                item_type="instance_snapshot",
-                item_name="",
-                wait_statuses=["wait"],
-                error_statuses=["error"],
-            )
-
-    def test_await_status_ok(self):
-        self.op._describe_item = lambda item_type, item_name: [{"Status": 
"ok"}]
-        self.op._await_status(
-            item_type="instance_snapshot", item_name="", 
wait_statuses=["wait"], ok_statuses=["ok"]
-        )
-
 
 class TestRdsCreateDbSnapshotOperator:
     @classmethod
@@ -207,8 +186,8 @@ class TestRdsCreateDbSnapshotOperator:
         assert len(instance_snapshots) == 1
 
     @mock_rds
-    @patch.object(RdsBaseOperator, "_await_status")
-    def test_create_db_instance_snapshot_no_wait(self, mock_await_status):
+    @patch.object(RdsHook, "wait_for_db_snapshot_state")
+    def test_create_db_instance_snapshot_no_wait(self, mock_wait):
         _create_db_instance(self.hook)
         instance_snapshot_operator = RdsCreateDbSnapshotOperator(
             task_id="test_instance_no_wait",
@@ -227,7 +206,7 @@ class TestRdsCreateDbSnapshotOperator:
 
         assert instance_snapshots
         assert len(instance_snapshots) == 1
-        assert mock_await_status.not_called()
+        mock_wait.assert_not_called()
 
     @mock_rds
     def test_create_db_cluster_snapshot(self):
@@ -250,8 +229,8 @@ class TestRdsCreateDbSnapshotOperator:
         assert len(cluster_snapshots) == 1
 
     @mock_rds
-    @patch.object(RdsBaseOperator, "_await_status")
-    def test_create_db_cluster_snapshot_no_wait(self, mock_no_wait):
+    @patch.object(RdsHook, "wait_for_db_cluster_snapshot_state")
+    def test_create_db_cluster_snapshot_no_wait(self, mock_wait):
         _create_db_cluster(self.hook)
         cluster_snapshot_operator = RdsCreateDbSnapshotOperator(
             task_id="test_cluster_no_wait",
@@ -270,7 +249,7 @@ class TestRdsCreateDbSnapshotOperator:
 
         assert cluster_snapshots
         assert len(cluster_snapshots) == 1
-        assert mock_no_wait.not_called()
+        mock_wait.assert_not_called()
 
 
 class TestRdsCopyDbSnapshotOperator:
@@ -307,7 +286,7 @@ class TestRdsCopyDbSnapshotOperator:
         assert len(instance_snapshots) == 1
 
     @mock_rds
-    @patch.object(RdsBaseOperator, "_await_status")
+    @patch.object(RdsHook, "wait_for_db_snapshot_state")
     def test_copy_db_instance_snapshot_no_wait(self, mock_await_status):
         _create_db_instance(self.hook)
         _create_db_instance_snapshot(self.hook)
@@ -328,7 +307,7 @@ class TestRdsCopyDbSnapshotOperator:
 
         assert instance_snapshots
         assert len(instance_snapshots) == 1
-        assert mock_await_status.not_called()
+        mock_await_status.assert_not_called()
 
     @mock_rds
     def test_copy_db_cluster_snapshot(self):
@@ -354,7 +333,7 @@ class TestRdsCopyDbSnapshotOperator:
         assert len(cluster_snapshots) == 1
 
     @mock_rds
-    @patch.object(RdsBaseOperator, "_await_status")
+    @patch.object(RdsHook, "wait_for_db_snapshot_state")
     def test_copy_db_cluster_snapshot_no_wait(self, mock_await_status):
         _create_db_cluster(self.hook)
         _create_db_cluster_snapshot(self.hook)
@@ -376,7 +355,7 @@ class TestRdsCopyDbSnapshotOperator:
 
         assert cluster_snapshots
         assert len(cluster_snapshots) == 1
-        assert mock_await_status.not_called()
+        mock_await_status.assert_not_called()
 
 
 class TestRdsDeleteDbSnapshotOperator:
@@ -404,10 +383,37 @@ class TestRdsDeleteDbSnapshotOperator:
             dag=self.dag,
         )
         _patch_hook_get_connection(instance_snapshot_operator.hook)
-        instance_snapshot_operator.execute(None)
+        with patch.object(instance_snapshot_operator.hook, 
"wait_for_db_snapshot_state") as mock_wait:
+            instance_snapshot_operator.execute(None)
+        mock_wait.assert_called_once_with(DB_INSTANCE_SNAPSHOT, 
target_state="deleted")
 
         with pytest.raises(self.hook.conn.exceptions.ClientError):
-            
self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=DB_CLUSTER_SNAPSHOT)
+            
self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=DB_INSTANCE_SNAPSHOT)
+
+    @mock_rds
+    def test_delete_db_instance_snapshot_no_wait(self):
+        """
+        Check that the operator does not wait for the DB instance snapshot 
delete operation to complete when
+        wait_for_completion=False
+        """
+        _create_db_instance(self.hook)
+        _create_db_instance_snapshot(self.hook)
+
+        instance_snapshot_operator = RdsDeleteDbSnapshotOperator(
+            task_id="test_delete_db_instance_snapshot_no_wait",
+            db_type="instance",
+            db_snapshot_identifier=DB_INSTANCE_SNAPSHOT,
+            aws_conn_id=AWS_CONN,
+            dag=self.dag,
+            wait_for_completion=False,
+        )
+        _patch_hook_get_connection(instance_snapshot_operator.hook)
+        with patch.object(instance_snapshot_operator.hook, 
"wait_for_db_snapshot_state") as mock_wait:
+            instance_snapshot_operator.execute(None)
+        mock_wait.assert_not_called()
+
+        with pytest.raises(self.hook.conn.exceptions.ClientError):
+            
self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=DB_INSTANCE_SNAPSHOT)
 
     @mock_rds
     def test_delete_db_cluster_snapshot(self):
@@ -422,7 +428,34 @@ class TestRdsDeleteDbSnapshotOperator:
             dag=self.dag,
         )
         _patch_hook_get_connection(cluster_snapshot_operator.hook)
-        cluster_snapshot_operator.execute(None)
+        with patch.object(cluster_snapshot_operator.hook, 
"wait_for_db_cluster_snapshot_state") as mock_wait:
+            cluster_snapshot_operator.execute(None)
+        mock_wait.assert_called_once_with(DB_CLUSTER_SNAPSHOT, 
target_state="deleted")
+
+        with pytest.raises(self.hook.conn.exceptions.ClientError):
+            
self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=DB_CLUSTER_SNAPSHOT)
+
+    @mock_rds
+    def test_delete_db_cluster_snapshot_no_wait(self):
+        """
+        Check that the operator does not wait for the DB cluster snapshot 
delete operation to complete when
+        wait_for_completion=False
+        """
+        _create_db_cluster(self.hook)
+        _create_db_cluster_snapshot(self.hook)
+
+        cluster_snapshot_operator = RdsDeleteDbSnapshotOperator(
+            task_id="test_delete_db_cluster_snapshot_no_wait",
+            db_type="cluster",
+            db_snapshot_identifier=DB_CLUSTER_SNAPSHOT,
+            aws_conn_id=AWS_CONN,
+            dag=self.dag,
+            wait_for_completion=False,
+        )
+        _patch_hook_get_connection(cluster_snapshot_operator.hook)
+        with patch.object(cluster_snapshot_operator.hook, 
"wait_for_db_cluster_snapshot_state") as mock_wait:
+            cluster_snapshot_operator.execute(None)
+        mock_wait.assert_not_called()
 
         with pytest.raises(self.hook.conn.exceptions.ClientError):
             
self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=DB_CLUSTER_SNAPSHOT)
@@ -466,7 +499,7 @@ class TestRdsStartExportTaskOperator:
         assert export_tasks[0]["Status"] == "complete"
 
     @mock_rds
-    @patch.object(RdsBaseOperator, "_await_status")
+    @patch.object(RdsHook, "wait_for_export_task_state")
     def test_start_export_task_no_wait(self, mock_await_status):
         _create_db_instance(self.hook)
         _create_db_instance_snapshot(self.hook)
@@ -491,7 +524,7 @@ class TestRdsStartExportTaskOperator:
         assert export_tasks
         assert len(export_tasks) == 1
         assert export_tasks[0]["Status"] == "complete"
-        assert mock_await_status.not_called()
+        mock_await_status.assert_not_called()
 
 
 class TestRdsCancelExportTaskOperator:
@@ -529,7 +562,7 @@ class TestRdsCancelExportTaskOperator:
         assert export_tasks[0]["Status"] == "canceled"
 
     @mock_rds
-    @patch.object(RdsBaseOperator, "_await_status")
+    @patch.object(RdsHook, "wait_for_export_task_state")
     def test_cancel_export_task_no_wait(self, mock_await_status):
         _create_db_instance(self.hook)
         _create_db_instance_snapshot(self.hook)
@@ -540,6 +573,7 @@ class TestRdsCancelExportTaskOperator:
             export_task_identifier=EXPORT_TASK_NAME,
             aws_conn_id=AWS_CONN,
             dag=self.dag,
+            wait_for_completion=False,
         )
         _patch_hook_get_connection(cancel_export_operator.hook)
         cancel_export_operator.execute(None)
@@ -550,7 +584,7 @@ class TestRdsCancelExportTaskOperator:
         assert export_tasks
         assert len(export_tasks) == 1
         assert export_tasks[0]["Status"] == "canceled"
-        assert mock_await_status.not_called()
+        mock_await_status.assert_not_called()
 
 
 class TestRdsCreateEventSubscriptionOperator:
@@ -589,7 +623,7 @@ class TestRdsCreateEventSubscriptionOperator:
         assert subscriptions[0]["Status"] == "active"
 
     @mock_rds
-    @patch.object(RdsBaseOperator, "_await_status")
+    @patch.object(RdsHook, "wait_for_event_subscription_state")
     def test_create_event_subscription_no_wait(self, mock_await_status):
         _create_db_instance(self.hook)
 
@@ -601,6 +635,7 @@ class TestRdsCreateEventSubscriptionOperator:
             source_ids=[DB_INSTANCE_NAME],
             aws_conn_id=AWS_CONN,
             dag=self.dag,
+            wait_for_completion=False,
         )
         _patch_hook_get_connection(create_subscription_operator.hook)
         create_subscription_operator.execute(None)
@@ -611,7 +646,7 @@ class TestRdsCreateEventSubscriptionOperator:
         assert subscriptions
         assert len(subscriptions) == 1
         assert subscriptions[0]["Status"] == "active"
-        assert mock_await_status.not_called()
+        mock_await_status.assert_not_called()
 
 
 class TestRdsDeleteEventSubscriptionOperator:
@@ -679,7 +714,7 @@ class TestRdsCreateDbInstanceOperator:
         assert db_instances[0]["DBInstanceStatus"] == "available"
 
     @mock_rds
-    @patch.object(RdsBaseOperator, "_await_status")
+    @patch.object(RdsHook, "wait_for_db_instance_state")
     def test_create_db_instance_no_wait(self, mock_await_status):
         create_db_instance_operator = RdsCreateDbInstanceOperator(
             task_id="test_create_db_instance_no_wait",
@@ -691,6 +726,7 @@ class TestRdsCreateDbInstanceOperator:
             },
             aws_conn_id=AWS_CONN,
             dag=self.dag,
+            wait_for_completion=False,
         )
         _patch_hook_get_connection(create_db_instance_operator.hook)
         create_db_instance_operator.execute(None)
@@ -701,7 +737,7 @@ class TestRdsCreateDbInstanceOperator:
         assert db_instances
         assert len(db_instances) == 1
         assert db_instances[0]["DBInstanceStatus"] == "available"
-        assert mock_await_status.not_called()
+        mock_await_status.assert_not_called()
 
 
 class TestRdsDeleteDbInstanceOperator:
@@ -717,7 +753,7 @@ class TestRdsDeleteDbInstanceOperator:
         del cls.hook
 
     @mock_rds
-    def test_delete_event_subscription(self):
+    def test_delete_db_instance(self):
         _create_db_instance(self.hook)
 
         delete_db_instance_operator = RdsDeleteDbInstanceOperator(
@@ -736,8 +772,8 @@ class TestRdsDeleteDbInstanceOperator:
             
self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
 
     @mock_rds
-    @patch.object(RdsBaseOperator, "_await_status")
-    def test_delete_event_subscription_no_wait(self, mock_await_status):
+    @patch.object(RdsHook, "wait_for_db_instance_state")
+    def test_delete_db_instance_no_wait(self, mock_await_status):
         _create_db_instance(self.hook)
 
         delete_db_instance_operator = RdsDeleteDbInstanceOperator(
@@ -755,7 +791,7 @@ class TestRdsDeleteDbInstanceOperator:
 
         with pytest.raises(self.hook.conn.exceptions.ClientError):
             
self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
-        assert mock_await_status.not_called()
+        mock_await_status.assert_not_called()
 
 
 class TestRdsStopDbOperator:
@@ -771,7 +807,7 @@ class TestRdsStopDbOperator:
         del cls.hook
 
     @mock_rds
-    @patch.object(RdsBaseOperator, "_await_status")
+    @patch.object(RdsHook, "wait_for_db_instance_state")
     def test_stop_db_instance(self, mock_await_status):
         _create_db_instance(self.hook)
         stop_db_instance = RdsStopDbOperator(task_id="test_stop_db_instance", 
db_identifier=DB_INSTANCE_NAME)
@@ -783,7 +819,7 @@ class TestRdsStopDbOperator:
         mock_await_status.assert_called()
 
     @mock_rds
-    @patch.object(RdsBaseOperator, "_await_status")
+    @patch.object(RdsHook, "wait_for_db_instance_state")
     def test_stop_db_instance_no_wait(self, mock_await_status):
         _create_db_instance(self.hook)
         stop_db_instance = RdsStopDbOperator(
@@ -817,7 +853,7 @@ class TestRdsStopDbOperator:
         assert len(instance_snapshots) == 1
 
     @mock_rds
-    @patch.object(RdsBaseOperator, "_await_status")
+    @patch.object(RdsHook, "wait_for_db_cluster_state")
     def test_stop_db_cluster(self, mock_await_status):
         _create_db_cluster(self.hook)
         stop_db_cluster = RdsStopDbOperator(
@@ -829,6 +865,7 @@ class TestRdsStopDbOperator:
         describe_result = self.hook.conn.describe_db_clusters(DBClusterIdentifier=DB_CLUSTER_NAME)
         status = describe_result["DBClusters"][0]["Status"]
         assert status == "stopped"
+        mock_await_status.assert_called()
 
     @mock_rds
     def test_stop_db_cluster_create_snapshot_logs_warning_message(self, caplog):
diff --git a/tests/providers/amazon/aws/sensors/test_rds.py b/tests/providers/amazon/aws/sensors/test_rds.py
index 0fca35f4fb..7db2b3cb82 100644
--- a/tests/providers/amazon/aws/sensors/test_rds.py
+++ b/tests/providers/amazon/aws/sensors/test_rds.py
@@ -16,10 +16,8 @@
 # under the License.
 from __future__ import annotations
 
-import pytest
 from moto import mock_rds
 
-from airflow.exceptions import AirflowException
 from airflow.models import DAG
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
 from airflow.providers.amazon.aws.sensors.rds import (
@@ -119,22 +117,6 @@ class TestBaseRdsSensor:
         assert hasattr(self.base_sensor, "hook")
         assert self.base_sensor.hook.__class__.__name__ == "RdsHook"
 
-    def test_describe_item_wrong_type(self):
-        with pytest.raises(AirflowException):
-            self.base_sensor._describe_item("database", "auth-db")
-
-    def test_check_item_true(self):
-        self.base_sensor._describe_item = lambda item_type, item_name: [{"Status": "available"}]
-        self.base_sensor.target_statuses = ["available", "created"]
-
-        assert self.base_sensor._check_item(item_type="instance_snapshot", item_name="")
-
-    def test_check_item_false(self):
-        self.base_sensor._describe_item = lambda item_type, item_name: [{"Status": "creating"}]
-        self.base_sensor.target_statuses = ["available", "created"]
-
-        assert not self.base_sensor._check_item(item_type="instance_snapshot", item_name="")
-
 
 class TestRdsSnapshotExistenceSensor:
     @classmethod
@@ -161,6 +143,7 @@ class TestRdsSnapshotExistenceSensor:
 
     @mock_rds
     def test_db_instance_snapshot_poke_false(self):
+        _create_db_instance(self.hook)
         op = RdsSnapshotExistenceSensor(
             task_id="test_instance_snap_false",
             db_type="instance",

Reply via email to