vincbeck commented on code in PR #27410:
URL: https://github.com/apache/airflow/pull/27410#discussion_r1015618809


##########
airflow/providers/amazon/aws/hooks/rds.py:
##########
@@ -48,3 +50,294 @@ class RdsHook(AwsGenericHook["RDSClient"]):
     def __init__(self, *args, **kwargs) -> None:
         kwargs["client_type"] = "rds"
         super().__init__(*args, **kwargs)
+
+    def get_db_snapshot_state(self, snapshot_id: str) -> str:
+        """
+        Get the current state of a DB instance snapshot.
+
+        :param snapshot_id: The ID of the target DB instance snapshot
+        :return: Returns the status of the DB snapshot as a string (eg. 
"available")
+        :rtype: str
+        :raises AirflowNotFoundException: If the DB instance snapshot does not 
exist.
+        """
+        try:
+            response = 
self.conn.describe_db_snapshots(DBSnapshotIdentifier=snapshot_id)
+        except self.conn.exceptions.DBSnapshotNotFoundFault as e:
+            raise AirflowNotFoundException(e)
+        return response["DBSnapshots"][0]["Status"].lower()
+
+    def wait_for_db_snapshot_state(
+        self, snapshot_id: str, target_state: str, check_interval: int = 30, 
max_attempts: int = 40
+    ) -> None:
+        """
+        Polls :py:meth:`RDS.Client.describe_db_snapshots` until the target 
state is reached.
+        An error is raised after a max number of attempts.
+
+        :param snapshot_id: The ID of the target DB instance snapshot
+        :param target_state: Wait until this state is reached
+        :param check_interval: The amount of time in seconds to wait between 
attempts
+        :param max_attempts: The maximum number of attempts to be made
+        """
+
+        def poke():
+            return self.get_db_snapshot_state(snapshot_id)
+
+        target_state = target_state.lower()
+        if target_state in ("available", "deleted", "completed"):
+            waiter = self.conn.get_waiter(f"db_snapshot_{target_state}")  # 
type: ignore
+            waiter.wait(
+                DBSnapshotIdentifier=snapshot_id,
+                WaiterConfig={"Delay": check_interval, "MaxAttempts": 
max_attempts},
+            )
+        else:
+            self._wait_for_state(poke, target_state, check_interval, 
max_attempts)
+            self.log.info("DB snapshot '%s' reached the '%s' state" % 
(snapshot_id, target_state))
+
+    def get_db_cluster_snapshot_state(self, snapshot_id: str) -> str:
+        """
+        Get the current state of a DB cluster snapshot.
+
+        :param snapshot_id: The ID of the target DB cluster.
+        :return: Returns the status of the DB cluster snapshot as a string 
(eg. "available")
+        :rtype: str
+        :raises AirflowNotFoundException: If the DB cluster snapshot does not 
exist.
+        """
+        try:
+            response = 
self.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=snapshot_id)
+        except self.conn.exceptions.DBClusterSnapshotNotFoundFault as e:
+            raise AirflowNotFoundException(e)
+        return response["DBClusterSnapshots"][0]["Status"].lower()
+
+    def wait_for_db_cluster_snapshot_state(
+        self, snapshot_id: str, target_state: str, check_interval: int = 30, 
max_attempts: int = 40
+    ) -> None:
+        """
+        Polls :py:meth:`RDS.Client.describe_db_cluster_snapshots` until the 
target state is reached.
+        An error is raised after a max number of attempts.
+
+        :param snapshot_id: The ID of the target DB cluster snapshot
+        :param target_state: Wait until this state is reached
+        :param check_interval: The amount of time in seconds to wait between 
attempts
+        :param max_attempts: The maximum number of attempts to be made
+
+        .. seealso::
+            A list of possible values for target_state:
+            
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_db_cluster_snapshots
+        """
+
+        def poke():
+            return self.get_db_cluster_snapshot_state(snapshot_id)
+
+        target_state = target_state.lower()
+        if target_state in ("available", "deleted"):
+            waiter = 
self.conn.get_waiter(f"db_cluster_snapshot_{target_state}")  # type: ignore
+            waiter.wait(
+                DBClusterSnapshotIdentifier=snapshot_id,
+                WaiterConfig={"Delay": check_interval, "MaxAttempts": 
max_attempts},
+            )
+        else:
+            self._wait_for_state(poke, target_state, check_interval, 
max_attempts)
+            self.log.info("DB cluster snapshot '%s' reached the '%s' state" % 
(snapshot_id, target_state))
+
+    def get_export_task_state(self, export_task_id: str) -> str:
+        """
+        Gets the current state of an RDS snapshot export to Amazon S3.
+
+        :param export_task_id: The identifier of the target snapshot export 
task.
+        :return: Returns the status of the snapshot export task as a string 
(eg. "canceled")
+        :rtype: str
+        :raises AirflowNotFoundException: If the export task does not exist.
+        """
+        try:
+            response = 
self.conn.describe_export_tasks(ExportTaskIdentifier=export_task_id)
+        # The RDS botocore documentation states that describe_export_tasks 
raises an exception of type
+        # ExportTaskNotFoundFault when the export task does not exist, but 
unit tests show that a generic
+        # ClientError is raised instead.

Review Comment:
   This is boto3 behavior. Every service-specific exception is wrapped inside 
`ClientError`. [See documentation 
here](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/error-handling.html#aws-service-exceptions).
 I would then double-check that all the service exceptions you are catching are 
really thrown by boto3 (e.g. `DBSnapshotNotFoundFault`, 
`DBClusterSnapshotNotFoundFault`, ...). They really look like service-specific 
exceptions. One proposal as well is to create a common function we could use 
across the amazon provider package to get the actual exception. The code 
`.response["Error"]["Code"]` is widely used



##########
airflow/providers/amazon/aws/operators/rds.py:
##########
@@ -314,13 +244,15 @@ def __init__(
         *,
         db_type: str,
         db_snapshot_identifier: str,
+        wait_for_completion: bool = True,

Review Comment:
   Nice! Love that! Could you add unit tests associated with this new feature?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to