hankehly commented on code in PR #27410:
URL: https://github.com/apache/airflow/pull/27410#discussion_r1017102069


##########
airflow/providers/amazon/aws/hooks/rds.py:
##########
@@ -48,3 +50,294 @@ class RdsHook(AwsGenericHook["RDSClient"]):
     def __init__(self, *args, **kwargs) -> None:
         kwargs["client_type"] = "rds"
         super().__init__(*args, **kwargs)
+
+    def get_db_snapshot_state(self, snapshot_id: str) -> str:
+        """
+        Get the current state of a DB instance snapshot.
+
+        :param snapshot_id: The ID of the target DB instance snapshot
+        :return: Returns the status of the DB snapshot as a string (eg. 
"available")
+        :rtype: str
+        :raises AirflowNotFoundException: If the DB instance snapshot does not 
exist.
+        """
+        try:
+            response = 
self.conn.describe_db_snapshots(DBSnapshotIdentifier=snapshot_id)
+        except self.conn.exceptions.DBSnapshotNotFoundFault as e:
+            raise AirflowNotFoundException(e)
+        return response["DBSnapshots"][0]["Status"].lower()
+
+    def wait_for_db_snapshot_state(
+        self, snapshot_id: str, target_state: str, check_interval: int = 30, 
max_attempts: int = 40
+    ) -> None:
+        """
+        Polls :py:meth:`RDS.Client.describe_db_snapshots` until the target 
state is reached.
+        An error is raised after a max number of attempts.
+
+        :param snapshot_id: The ID of the target DB instance snapshot
+        :param target_state: Wait until this state is reached
+        :param check_interval: The amount of time in seconds to wait between 
attempts
+        :param max_attempts: The maximum number of attempts to be made
+        """
+
+        def poke():
+            return self.get_db_snapshot_state(snapshot_id)
+
+        target_state = target_state.lower()
+        if target_state in ("available", "deleted", "completed"):
+            waiter = self.conn.get_waiter(f"db_snapshot_{target_state}")  # 
type: ignore
+            waiter.wait(
+                DBSnapshotIdentifier=snapshot_id,
+                WaiterConfig={"Delay": check_interval, "MaxAttempts": 
max_attempts},
+            )
+        else:
+            self._wait_for_state(poke, target_state, check_interval, 
max_attempts)
+            self.log.info("DB snapshot '%s' reached the '%s' state" % 
(snapshot_id, target_state))
+
+    def get_db_cluster_snapshot_state(self, snapshot_id: str) -> str:
+        """
+        Get the current state of a DB cluster snapshot.
+
+        :param snapshot_id: The ID of the target DB cluster.
+        :return: Returns the status of the DB cluster snapshot as a string 
(eg. "available")
+        :rtype: str
+        :raises AirflowNotFoundException: If the DB cluster snapshot does not 
exist.
+        """
+        try:
+            response = 
self.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=snapshot_id)
+        except self.conn.exceptions.DBClusterSnapshotNotFoundFault as e:
+            raise AirflowNotFoundException(e)
+        return response["DBClusterSnapshots"][0]["Status"].lower()
+
+    def wait_for_db_cluster_snapshot_state(
+        self, snapshot_id: str, target_state: str, check_interval: int = 30, 
max_attempts: int = 40
+    ) -> None:
+        """
+        Polls :py:meth:`RDS.Client.describe_db_cluster_snapshots` until the 
target state is reached.
+        An error is raised after a max number of attempts.
+
+        :param snapshot_id: The ID of the target DB cluster snapshot
+        :param target_state: Wait until this state is reached
+        :param check_interval: The amount of time in seconds to wait between 
attempts
+        :param max_attempts: The maximum number of attempts to be made
+
+        .. seealso::
+            A list of possible values for target_state:
+            
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_db_cluster_snapshots
+        """
+
+        def poke():
+            return self.get_db_cluster_snapshot_state(snapshot_id)
+
+        target_state = target_state.lower()
+        if target_state in ("available", "deleted"):
+            waiter = 
self.conn.get_waiter(f"db_cluster_snapshot_{target_state}")  # type: ignore
+            waiter.wait(
+                DBClusterSnapshotIdentifier=snapshot_id,
+                WaiterConfig={"Delay": check_interval, "MaxAttempts": 
max_attempts},
+            )
+        else:
+            self._wait_for_state(poke, target_state, check_interval, 
max_attempts)
+            self.log.info("DB cluster snapshot '%s' reached the '%s' state" % 
(snapshot_id, target_state))
+
+    def get_export_task_state(self, export_task_id: str) -> str:
+        """
+        Gets the current state of an RDS snapshot export to Amazon S3.
+
+        :param export_task_id: The identifier of the target snapshot export 
task.
+        :return: Returns the status of the snapshot export task as a string 
(eg. "canceled")
+        :rtype: str
+        :raises AirflowNotFoundException: If the export task does not exist.
+        """
+        try:
+            response = 
self.conn.describe_export_tasks(ExportTaskIdentifier=export_task_id)
+        # The RDS botocore documentation states that describe_export_tasks 
raises an exception of type
+        # ExportTaskNotFoundFault when the export task does not exist, but 
unit tests show that a generic
+        # ClientError is raised instead.

Review Comment:
   Thanks for mentioning this. I did the following:
   * Added another test case for each wait method to check that the appropriate 
"not found" error is raised
   * Replaced all related `try..except` blocks with the pattern the 
documentation recommends (catching `ClientError` and checking its error code)
   
   The documentation states that only _some_ service exceptions are exposed 
via dot-notation on the client's `exceptions` property, whereas _all_ service 
exceptions can be identified by catching `ClientError` and inspecting 
`.response["Error"]["Code"]`.
   > Additionally, you can also access some of the dynamic service-side 
exceptions from the client’s exception property.
   
   ... and further down the page:
   > Catching exceptions through ClientError and parsing for error codes is 
still the best way to catch all service-side exceptions and errors.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to