hankehly commented on code in PR #27410:
URL: https://github.com/apache/airflow/pull/27410#discussion_r1009703650


##########
airflow/providers/amazon/aws/hooks/rds.py:
##########
@@ -48,3 +50,294 @@ class RdsHook(AwsGenericHook["RDSClient"]):
     def __init__(self, *args, **kwargs) -> None:
         kwargs["client_type"] = "rds"
         super().__init__(*args, **kwargs)
+
+    def get_db_snapshot_state(self, snapshot_id: str) -> str:
+        """
+        Get the current state of a DB instance snapshot.
+
+        :param snapshot_id: The ID of the target DB instance snapshot
+        :return: Returns the status of the DB snapshot as a string (e.g. "available")
+        :rtype: str
+        :raises AirflowNotFoundException: If the DB instance snapshot does not exist.
+        """
+        try:
+            response = self.conn.describe_db_snapshots(DBSnapshotIdentifier=snapshot_id)
+        except self.conn.exceptions.DBSnapshotNotFoundFault as e:
+            raise AirflowNotFoundException(e)
+        return response["DBSnapshots"][0]["Status"].lower()
+
+    def wait_for_db_snapshot_state(
+        self, snapshot_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
+    ) -> None:
+        """
+        Polls :py:meth:`RDS.Client.describe_db_snapshots` until the target state is reached.
+        An error is raised after a max number of attempts.
+
+        :param snapshot_id: The ID of the target DB instance snapshot
+        :param target_state: Wait until this state is reached
+        :param check_interval: The amount of time in seconds to wait between attempts
+        :param max_attempts: The maximum number of attempts to be made
+        """
+
+        def poke():
+            return self.get_db_snapshot_state(snapshot_id)
+
+        target_state = target_state.lower()
+        if target_state in ("available", "deleted", "completed"):
+            waiter = self.conn.get_waiter(f"db_snapshot_{target_state}")  # type: ignore

Review Comment:
   I see `no overload exists for ...` mypy warnings here because I'm building the waiter name dynamically. The `if` statement one line above ensures we never request a waiter that doesn't exist, so I suppress the warning with `# type: ignore`. If someone has a more appropriate solution, please let me know.
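One way to drop the `# type: ignore` is to keep the waiter names as string literals instead of building them with an f-string, so the type checker sees a `Literal` type it can match against the stub overloads. A minimal sketch, assuming the `mypy-boto3-rds` stubs type `get_waiter` with one overload per literal waiter name (which the `no overload exists` message suggests); `SnapshotWaiterName`, `SNAPSHOT_WAITER_NAMES`, and `snapshot_waiter_name` are hypothetical names, not part of the PR:

```python
from typing import Dict, Literal, Optional

# Hypothetical alias for the three RDS snapshot waiters the PR's `if` allows.
SnapshotWaiterName = Literal[
    "db_snapshot_available",
    "db_snapshot_deleted",
    "db_snapshot_completed",
]

# Explicit state -> waiter-name table. Every value is a string literal, so a
# type checker infers the Literal alias rather than a plain `str`.
SNAPSHOT_WAITER_NAMES: Dict[str, SnapshotWaiterName] = {
    "available": "db_snapshot_available",
    "deleted": "db_snapshot_deleted",
    "completed": "db_snapshot_completed",
}


def snapshot_waiter_name(target_state: str) -> Optional[SnapshotWaiterName]:
    """Return the waiter name for target_state, or None if no waiter exists."""
    return SNAPSHOT_WAITER_NAMES.get(target_state.lower())
```

Inside `wait_for_db_snapshot_state`, the hook would then call `snapshot_waiter_name(target_state)` and, when the result is not `None`, pass it to `self.conn.get_waiter(...)`. Whether mypy accepts a union of literals against the overloads depends on the mypy and stub versions in use, so it is worth checking locally; the fallback is an explicit `if`/`elif` with one literal `get_waiter(...)` call per state.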



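The diff excerpt ends before the branch that handles states without a built-in waiter, but the docstring says the method polls `describe_db_snapshots` every `check_interval` seconds and gives up after `max_attempts`. A generic sketch of such a bounded poll loop, assuming a zero-argument callable like the `poke()` helper in the diff; `poll_until_state` is a hypothetical name, and `RuntimeError` stands in for whatever Airflow exception the PR actually raises:

```python
import time
from typing import Callable


def poll_until_state(
    get_state: Callable[[], str],
    target_state: str,
    check_interval: float = 30,
    max_attempts: int = 40,
) -> None:
    """Call get_state() until it reports target_state, or give up.

    Sleeps check_interval seconds between calls and raises RuntimeError once
    max_attempts calls have been made without reaching target_state.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    target_state = target_state.lower()
    state = ""
    for attempt in range(1, max_attempts + 1):
        state = get_state().lower()
        if state == target_state:
            return
        # Skip the sleep after the final attempt; we are about to raise anyway.
        if attempt < max_attempts:
            time.sleep(check_interval)
    raise RuntimeError(
        f"state {target_state!r} not reached after {max_attempts} attempts "
        f"(last state: {state!r})"
    )
```

In the hook this would be called roughly as `poll_until_state(poke, target_state, check_interval, max_attempts)`.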
-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
