This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b9c84eb663 add deferrable mode to rds start & stop DB (#32437)
b9c84eb663 is described below

commit b9c84eb6639e825ed951c08e477411bf52dfc437
Author: Raphaël Vandon <[email protected]>
AuthorDate: Wed Jul 19 10:37:32 2023 -0700

    add deferrable mode to rds start & stop DB (#32437)
    
    * remove tests on obsolete trigger
    
    * add delay & attempts to params
    
    ---------
    
    Co-authored-by: Wei Lee <[email protected]>
---
 airflow/providers/amazon/aws/operators/rds.py    | 168 +++++++++++++------
 airflow/providers/amazon/aws/triggers/rds.py     | 166 ++++++++++++++++++-
 airflow/providers/amazon/aws/utils/rds.py        |   4 +-
 docs/spelling_wordlist.txt                       |   1 +
 tests/providers/amazon/aws/operators/test_rds.py |  29 ++++
 tests/providers/amazon/aws/triggers/test_rds.py  | 201 ++++++-----------------
 6 files changed, 363 insertions(+), 206 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/rds.py 
b/airflow/providers/amazon/aws/operators/rds.py
index c037251643..7a544a8c00 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -20,15 +20,20 @@ from __future__ import annotations
 import json
 import warnings
 from datetime import timedelta
-from typing import TYPE_CHECKING, Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Sequence
 
 from mypy_boto3_rds.type_defs import TagTypeDef
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
-from airflow.providers.amazon.aws.triggers.rds import RdsDbInstanceTrigger
+from airflow.providers.amazon.aws.triggers.rds import (
+    RdsDbAvailableTrigger,
+    RdsDbDeletedTrigger,
+    RdsDbStoppedTrigger,
+)
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
 from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
@@ -62,13 +67,17 @@ class RdsBaseOperator(BaseOperator):
                 AirflowProviderDeprecationWarning,
                 stacklevel=3,  # 2 is in the operator's init, 3 is in the user code creating the operator
             )
-        hook_params = hook_params or {}
-        self.region_name = region_name or hook_params.pop("region_name", None)
-        self.hook = RdsHook(aws_conn_id=aws_conn_id, region_name=self.region_name, **(hook_params))
+        self.hook_params = dict(hook_params or {})
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name or self.hook_params.pop("region_name", None)
         super().__init__(*args, **kwargs)
 
         self._await_interval = 60  # seconds
 
+    @cached_property
+    def hook(self) -> RdsHook:
+        return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, **self.hook_params)
+
     def execute(self, context: Context) -> str:
         """Different implementations for snapshots, tasks and events."""
         raise NotImplementedError
     def execute(self, context: Context) -> str:
         """Different implementations for snapshots, tasks and events."""
         raise NotImplementedError
@@ -106,10 +115,9 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
         db_snapshot_identifier: str,
         tags: Sequence[TagTypeDef] | dict | None = None,
         wait_for_completion: bool = True,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
         self.db_type = RdsDbType(db_type)
         self.db_identifier = db_identifier
         self.db_snapshot_identifier = db_snapshot_identifier
@@ -194,10 +202,9 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
         target_custom_availability_zone: str = "",
         source_region: str = "",
         wait_for_completion: bool = True,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
 
         self.db_type = RdsDbType(db_type)
         self.source_db_snapshot_identifier = source_db_snapshot_identifier
@@ -274,10 +281,9 @@ class RdsDeleteDbSnapshotOperator(RdsBaseOperator):
         db_type: str,
         db_snapshot_identifier: str,
         wait_for_completion: bool = True,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
 
         self.db_type = RdsDbType(db_type)
         self.db_snapshot_identifier = db_snapshot_identifier
@@ -345,10 +351,9 @@ class RdsStartExportTaskOperator(RdsBaseOperator):
         s3_prefix: str = "",
         export_only: list[str] | None = None,
         wait_for_completion: bool = True,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
 
         self.export_task_identifier = export_task_identifier
         self.source_arn = source_arn
@@ -397,10 +402,9 @@ class RdsCancelExportTaskOperator(RdsBaseOperator):
         export_task_identifier: str,
         wait_for_completion: bool = True,
         check_interval: int = 30,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
 
         self.export_task_identifier = export_task_identifier
         self.wait_for_completion = wait_for_completion
@@ -461,10 +465,9 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
         enabled: bool = True,
         tags: Sequence[TagTypeDef] | dict | None = None,
         wait_for_completion: bool = True,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
 
         self.subscription_name = subscription_name
         self.sns_topic_arn = sns_topic_arn
@@ -511,10 +514,9 @@ class RdsDeleteEventSubscriptionOperator(RdsBaseOperator):
         self,
         *,
         subscription_name: str,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
 
         self.subscription_name = subscription_name
 
@@ -545,7 +547,6 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
     :param engine: The name of the database engine to be used for this instance
     :param rds_kwargs: Named arguments to pass to boto3 RDS client function 
``create_db_instance``
         
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_instance
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
     :param wait_for_completion:  If True, waits for creation of the DB 
instance to complete. (default: True)
     :param waiter_delay: Time (in seconds) to wait between two consecutive 
calls to check DB instance state
     :param waiter_max_attempts: The maximum number of attempts to check DB 
instance state
@@ -563,14 +564,13 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
         db_instance_class: str,
         engine: str,
         rds_kwargs: dict | None = None,
-        aws_conn_id: str = "aws_default",
         wait_for_completion: bool = True,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
         waiter_delay: int = 30,
         waiter_max_attempts: int = 60,
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
 
         self.db_instance_identifier = db_instance_identifier
         self.db_instance_class = db_instance_class
@@ -580,7 +580,6 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
         self.deferrable = deferrable
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
-        self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context) -> str:
         self.log.info("Creating new DB instance %s", 
self.db_instance_identifier)
@@ -593,15 +592,15 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
         )
         if self.deferrable:
             self.defer(
-                trigger=RdsDbInstanceTrigger(
-                    db_instance_identifier=self.db_instance_identifier,
+                trigger=RdsDbAvailableTrigger(
+                    db_identifier=self.db_instance_identifier,
                     waiter_delay=self.waiter_delay,
                     waiter_max_attempts=self.waiter_max_attempts,
                     aws_conn_id=self.aws_conn_id,
                     region_name=self.region_name,
-                    waiter_name="db_instance_available",
                     # ignoring type because create_db_instance is a dict
                     response=create_db_instance,  # type: ignore[arg-type]
+                    db_type=RdsDbType.INSTANCE,
                 ),
                 method_name="execute_complete",
                 timeout=timedelta(seconds=self.waiter_delay * 
self.waiter_max_attempts),
@@ -638,7 +637,6 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
     :param db_instance_identifier: The DB instance identifier for the DB 
instance to be deleted
     :param rds_kwargs: Named arguments to pass to boto3 RDS client function 
``delete_db_instance``
         
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.delete_db_instance
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
     :param wait_for_completion:  If True, waits for deletion of the DB 
instance to complete. (default: True)
     :param waiter_delay: Time (in seconds) to wait between two consecutive 
calls to check DB instance state
     :param waiter_max_attempts: The maximum number of attempts to check DB 
instance state
@@ -654,21 +652,19 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
         *,
         db_instance_identifier: str,
         rds_kwargs: dict | None = None,
-        aws_conn_id: str = "aws_default",
         wait_for_completion: bool = True,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
         waiter_delay: int = 30,
         waiter_max_attempts: int = 60,
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
         self.db_instance_identifier = db_instance_identifier
         self.rds_kwargs = rds_kwargs or {}
         self.wait_for_completion = False if deferrable else wait_for_completion
         self.deferrable = deferrable
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
-        self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context) -> str:
         self.log.info("Deleting DB instance %s", self.db_instance_identifier)
@@ -679,15 +675,15 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
         )
         if self.deferrable:
             self.defer(
-                trigger=RdsDbInstanceTrigger(
-                    db_instance_identifier=self.db_instance_identifier,
+                trigger=RdsDbDeletedTrigger(
+                    db_identifier=self.db_instance_identifier,
                     waiter_delay=self.waiter_delay,
                     waiter_max_attempts=self.waiter_max_attempts,
                     aws_conn_id=self.aws_conn_id,
                     region_name=self.region_name,
-                    waiter_name="db_instance_deleted",
                     # ignoring type because delete_db_instance is a dict
                     response=delete_db_instance,  # type: ignore[arg-type]
+                    db_type=RdsDbType.INSTANCE,
                 ),
                 method_name="execute_complete",
                 timeout=timedelta(seconds=self.waiter_delay * 
self.waiter_max_attempts),
@@ -723,8 +719,11 @@ class RdsStartDbOperator(RdsBaseOperator):
 
     :param db_identifier: The AWS identifier of the DB to start
     :param db_type: Type of the DB - either "instance" or "cluster" (default: "instance")
-    :param aws_conn_id: The Airflow connection used for AWS credentials. (default: "aws_default")
     :param wait_for_completion:  If True, waits for DB to start. (default: True)
+    :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB state
+    :param waiter_max_attempts: The maximum number of attempts to check DB state
+    :param deferrable: If True, the operator will wait asynchronously for the DB to be started.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
     """
 
     template_fields = ("db_identifier", "db_type")
@@ -734,26 +733,52 @@ class RdsStartDbOperator(RdsBaseOperator):
         *,
         db_identifier: str,
         db_type: RdsDbType | str = RdsDbType.INSTANCE,
-        aws_conn_id: str = "aws_default",
         wait_for_completion: bool = True,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 40,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
         self.db_identifier = db_identifier
         self.db_type = db_type
         self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
 
     def execute(self, context: Context) -> str:
         self.db_type = RdsDbType(self.db_type)
-        start_db_response = self._start_db()
-        if self.wait_for_completion:
+        start_db_response: dict[str, Any] = self._start_db()
+        if self.deferrable:
+            self.defer(
+                trigger=RdsDbAvailableTrigger(
+                    db_identifier=self.db_identifier,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                    response=start_db_response,
+                    db_type=self.db_type,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
             self._wait_until_db_available()
         return json.dumps(start_db_response, default=str)
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Failed to start DB: {event}")
+        else:
+            return json.dumps(event["response"], default=str)
+
     def _start_db(self):
         self.log.info("Starting DB %s '%s'", self.db_type.value, self.db_identifier)
         if self.db_type == RdsDbType.INSTANCE:
-            response = self.hook.conn.start_db_instance(DBInstanceIdentifier=self.db_identifier)
+            response = self.hook.conn.start_db_instance(
+                DBInstanceIdentifier=self.db_identifier,
+            )
         else:
             response = self.hook.conn.start_db_cluster(DBClusterIdentifier=self.db_identifier)
         return response
@@ -761,9 +786,19 @@ class RdsStartDbOperator(RdsBaseOperator):
     def _wait_until_db_available(self):
         self.log.info("Waiting for DB %s to reach 'available' state", self.db_type.value)
         if self.db_type == RdsDbType.INSTANCE:
-            self.hook.wait_for_db_instance_state(self.db_identifier, target_state="available")
+            self.hook.wait_for_db_instance_state(
+                self.db_identifier,
+                target_state="available",
+                check_interval=self.waiter_delay,
+                max_attempts=self.waiter_max_attempts,
+            )
         else:
-            self.hook.wait_for_db_cluster_state(self.db_identifier, target_state="available")
+            self.hook.wait_for_db_cluster_state(
+                self.db_identifier,
+                target_state="available",
+                check_interval=self.waiter_delay,
+                max_attempts=self.waiter_max_attempts,
+            )
 
 
 class RdsStopDbOperator(RdsBaseOperator):
@@ -779,8 +814,11 @@ class RdsStopDbOperator(RdsBaseOperator):
     :param db_snapshot_identifier: The instance identifier of the DB Snapshot to create before
         stopping the DB instance. The default value (None) skips snapshot creation. This
         parameter is ignored when ``db_type`` is "cluster"
-    :param aws_conn_id: The Airflow connection used for AWS credentials. (default: "aws_default")
     :param wait_for_completion:  If True, waits for DB to stop. (default: True)
+    :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB state
+    :param waiter_max_attempts: The maximum number of attempts to check DB state
+    :param deferrable: If True, the operator will wait asynchronously for the DB to be stopped.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
     """
 
     template_fields = ("db_identifier", "db_snapshot_identifier", "db_type")
@@ -791,23 +829,47 @@ class RdsStopDbOperator(RdsBaseOperator):
         db_identifier: str,
         db_type: RdsDbType | str = RdsDbType.INSTANCE,
         db_snapshot_identifier: str | None = None,
-        aws_conn_id: str = "aws_default",
         wait_for_completion: bool = True,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 40,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
         self.db_identifier = db_identifier
         self.db_type = db_type
         self.db_snapshot_identifier = db_snapshot_identifier
         self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
 
     def execute(self, context: Context) -> str:
         self.db_type = RdsDbType(self.db_type)
-        stop_db_response = self._stop_db()
-        if self.wait_for_completion:
+        stop_db_response: dict[str, Any] = self._stop_db()
+        if self.deferrable:
+            self.defer(
+                trigger=RdsDbStoppedTrigger(
+                    db_identifier=self.db_identifier,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                    response=stop_db_response,
+                    db_type=self.db_type,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
             self._wait_until_db_stopped()
         return json.dumps(stop_db_response, default=str)
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Failed to stop DB: {event}")
+        else:
+            return json.dumps(event["response"], default=str)
+
     def _stop_db(self):
         self.log.info("Stopping DB %s '%s'", self.db_type.value, self.db_identifier)
         if self.db_type == RdsDbType.INSTANCE:
@@ -829,9 +891,19 @@ class RdsStopDbOperator(RdsBaseOperator):
     def _wait_until_db_stopped(self):
         self.log.info("Waiting for DB %s to reach 'stopped' state", self.db_type.value)
         if self.db_type == RdsDbType.INSTANCE:
-            self.hook.wait_for_db_instance_state(self.db_identifier, target_state="stopped")
+            self.hook.wait_for_db_instance_state(
+                self.db_identifier,
+                target_state="stopped",
+                check_interval=self.waiter_delay,
+                max_attempts=self.waiter_max_attempts,
+            )
         else:
-            self.hook.wait_for_db_cluster_state(self.db_identifier, target_state="stopped")
+            self.hook.wait_for_db_cluster_state(
+                self.db_identifier,
+                target_state="stopped",
+                check_interval=self.waiter_delay,
+                max_attempts=self.waiter_max_attempts,
+            )
 
 
 __all__ = [
diff --git a/airflow/providers/amazon/aws/triggers/rds.py 
b/airflow/providers/amazon/aws/triggers/rds.py
index 0551d67591..b8f20b9c9b 100644
--- a/airflow/providers/amazon/aws/triggers/rds.py
+++ b/airflow/providers/amazon/aws/triggers/rds.py
@@ -16,19 +16,21 @@
 # under the License.
 from __future__ import annotations
 
+import warnings
 from typing import Any
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+from airflow.providers.amazon.aws.utils.rds import RdsDbType
 from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
 class RdsDbInstanceTrigger(BaseTrigger):
     """
-    Trigger for RdsCreateDbInstanceOperator and RdsDeleteDbInstanceOperator.
-
-    The trigger will asynchronously poll the boto3 API and wait for the
-    DB instance to be in the state specified by the waiter.
+    Deprecated Trigger for RDS operations. Do not use.
 
     :param waiter_name: Name of the waiter to use, for instance 
'db_instance_available'
         or 'db_instance_deleted'.
@@ -36,7 +38,7 @@ class RdsDbInstanceTrigger(BaseTrigger):
     :param waiter_delay: The amount of time in seconds to wait between 
attempts.
     :param waiter_max_attempts: The maximum number of attempts to be made.
     :param aws_conn_id: The Airflow connection used for AWS credentials.
-    :param hook_params: The parameters to pass to the RdsHook.
+    :param region_name: AWS region where the DB is located, if different from 
the default one.
     :param response: The response from the RdsHook, to be passed back to the 
operator.
     """
 
@@ -50,6 +52,12 @@ class RdsDbInstanceTrigger(BaseTrigger):
         region_name: str | None,
         response: dict[str, Any],
     ):
+        warnings.warn(
+            "This trigger is deprecated, please use the other RDS triggers "
+            "such as RdsDbDeletedTrigger, RdsDbStoppedTrigger or 
RdsDbAvailableTrigger",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
         self.db_instance_identifier = db_instance_identifier
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
@@ -87,3 +95,160 @@ class RdsDbInstanceTrigger(BaseTrigger):
                 status_args=["DBInstances[0].DBInstanceStatus"],
             )
         yield TriggerEvent({"status": "success", "response": self.response})
+
+
+_waiter_arg = {
+    RdsDbType.INSTANCE: "DBInstanceIdentifier",
+    RdsDbType.CLUSTER: "DBClusterIdentifier",
+}
+_status_paths = {
+    RdsDbType.INSTANCE: ["DBInstances[].DBInstanceStatus", "DBInstances[].StatusInfos"],
+    RdsDbType.CLUSTER: ["DBClusters[].Status"],
+}
+
+
+class RdsDbAvailableTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger to wait asynchronously for a DB instance or cluster to be available.
+
+    :param db_identifier: The DB identifier for the DB instance or cluster to be polled.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: AWS region where the DB is located, if different from the default one.
+    :param response: The response from the RdsHook, to be passed back to the operator.
+    :param db_type: The type of DB: instance or cluster, as an ``RdsDbType`` or its string value.
+    """
+
+    def __init__(
+        self,
+        db_identifier: str,
+        waiter_delay: int,
+        waiter_max_attempts: int,
+        aws_conn_id: str,
+        region_name: str | None,
+        response: dict[str, Any],
+        db_type: RdsDbType | str,
+    ) -> None:
+        # The trigger is re-created from ``serialized_fields``: store only the plain string
+        # value there (not the enum) and accept either form when constructing.
+        db_type = RdsDbType(db_type)
+        super().__init__(
+            serialized_fields={
+                "db_identifier": db_identifier,
+                "response": response,
+                "db_type": db_type.value,
+            },
+            waiter_name=f"db_{db_type.value}_available",
+            waiter_args={_waiter_arg[db_type]: db_identifier},
+            failure_message="Error while waiting for DB to be available",
+            status_message="DB initialization in progress",
+            status_queries=_status_paths[db_type],
+            return_key="response",
+            return_value=response,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            region_name=region_name,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+
+
+class RdsDbDeletedTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger to wait asynchronously for a DB instance or cluster to be deleted.
+
+    :param db_identifier: The DB identifier for the DB instance or cluster to be polled.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: AWS region where the DB is located, if different from the default one.
+    :param response: The response from the RdsHook, to be passed back to the operator.
+    :param db_type: The type of DB: instance or cluster, as an ``RdsDbType`` or its string value.
+    """
+
+    def __init__(
+        self,
+        db_identifier: str,
+        waiter_delay: int,
+        waiter_max_attempts: int,
+        aws_conn_id: str,
+        region_name: str | None,
+        response: dict[str, Any],
+        db_type: RdsDbType | str,
+    ) -> None:
+        # The trigger is re-created from ``serialized_fields``: store only the plain string
+        # value there (not the enum) and accept either form when constructing.
+        db_type = RdsDbType(db_type)
+        super().__init__(
+            serialized_fields={
+                "db_identifier": db_identifier,
+                "response": response,
+                "db_type": db_type.value,
+            },
+            waiter_name=f"db_{db_type.value}_deleted",
+            waiter_args={_waiter_arg[db_type]: db_identifier},
+            failure_message="Error while deleting DB",
+            status_message="DB deletion in progress",
+            status_queries=_status_paths[db_type],
+            return_key="response",
+            return_value=response,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            region_name=region_name,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+
+
+class RdsDbStoppedTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger to wait asynchronously for a DB instance or cluster to be stopped.
+
+    :param db_identifier: The DB identifier for the DB instance or cluster to be polled.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: AWS region where the DB is located, if different from the default one.
+    :param response: The response from the RdsHook, to be passed back to the operator.
+    :param db_type: The type of DB: instance or cluster, as an ``RdsDbType`` or its string value.
+    """
+
+    def __init__(
+        self,
+        db_identifier: str,
+        waiter_delay: int,
+        waiter_max_attempts: int,
+        aws_conn_id: str,
+        region_name: str | None,
+        response: dict[str, Any],
+        db_type: RdsDbType | str,
+    ) -> None:
+        # The trigger is re-created from ``serialized_fields``: store only the plain string
+        # value there (not the enum) and accept either form when constructing.
+        db_type = RdsDbType(db_type)
+        super().__init__(
+            serialized_fields={
+                "db_identifier": db_identifier,
+                "response": response,
+                "db_type": db_type.value,
+            },
+            waiter_name=f"db_{db_type.value}_stopped",
+            waiter_args={_waiter_arg[db_type]: db_identifier},
+            failure_message="Error while stopping DB",
+            status_message="DB is being stopped",
+            status_queries=_status_paths[db_type],
+            return_key="response",
+            return_value=response,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            region_name=region_name,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
diff --git a/airflow/providers/amazon/aws/utils/rds.py 
b/airflow/providers/amazon/aws/utils/rds.py
index 1ba511c4b4..eacbf0b63a 100644
--- a/airflow/providers/amazon/aws/utils/rds.py
+++ b/airflow/providers/amazon/aws/utils/rds.py
@@ -22,5 +22,5 @@ from enum import Enum
 class RdsDbType(Enum):
     """Only available types for the RDS."""
 
-    INSTANCE: str = "instance"
-    CLUSTER: str = "cluster"
+    INSTANCE = "instance"
+    CLUSTER = "cluster"
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 0975656c75..ad3b7cd8f8 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1208,6 +1208,7 @@ RBAC
 rbac
 rc
 rdbms
+RDS
 readfp
 Readme
 readme
diff --git a/tests/providers/amazon/aws/operators/test_rds.py 
b/tests/providers/amazon/aws/operators/test_rds.py
index d2408fa47d..47122d55c0 100644
--- a/tests/providers/amazon/aws/operators/test_rds.py
+++ b/tests/providers/amazon/aws/operators/test_rds.py
@@ -18,11 +18,13 @@
 from __future__ import annotations
 
 import logging
+from unittest import mock
 from unittest.mock import patch
 
 import pytest
 from moto import mock_rds
 
+from airflow.exceptions import TaskDeferred
 from airflow.models import DAG
 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
@@ -40,6 +42,7 @@ from airflow.providers.amazon.aws.operators.rds import (
     RdsStartExportTaskOperator,
     RdsStopDbOperator,
 )
+from airflow.providers.amazon.aws.triggers.rds import RdsDbAvailableTrigger, 
RdsDbStoppedTrigger
 from airflow.utils import timezone
 
 DEFAULT_DATE = timezone.datetime(2019, 1, 1)
@@ -832,6 +835,19 @@ class TestRdsStopDbOperator:
         assert status == "stopped"
         mock_await_status.assert_not_called()
 
+    @mock.patch.object(RdsHook, "conn")
+    def test_deferred(self, conn_mock):
+        op = RdsStopDbOperator(
+            task_id="test_stop_db_instance_deferred",
+            db_identifier=DB_INSTANCE_NAME,
+            deferrable=True,
+        )
+
+        with pytest.raises(TaskDeferred) as defer:
+            op.execute({})
+
+        assert isinstance(defer.value.trigger, RdsDbStoppedTrigger)
+
     @mock_rds
     def test_stop_db_instance_create_snapshot(self):
         _create_db_instance(self.hook)
@@ -932,3 +948,16 @@ class TestRdsStartDbOperator:
         result_after = 
self.hook.conn.describe_db_clusters(DBClusterIdentifier=DB_CLUSTER_NAME)
         status_after = result_after["DBClusters"][0]["Status"]
         assert status_after == "available"
+
+    @mock.patch.object(RdsHook, "conn")
+    def test_deferred(self, conn_mock):
+        op = RdsStartDbOperator(
+            task_id="test_stop_db_instance_no_wait",
+            db_identifier=DB_INSTANCE_NAME,
+            deferrable=True,
+        )
+
+        with pytest.raises(TaskDeferred) as defer:
+            op.execute({})
+
+        assert isinstance(defer.value.trigger, RdsDbAvailableTrigger)
diff --git a/tests/providers/amazon/aws/triggers/test_rds.py b/tests/providers/amazon/aws/triggers/test_rds.py
index 9c518c8eee..57db41a5e0 100644
--- a/tests/providers/amazon/aws/triggers/test_rds.py
+++ b/tests/providers/amazon/aws/triggers/test_rds.py
@@ -16,16 +16,14 @@
 # under the License.
 from __future__ import annotations
 
-from unittest import mock
-from unittest.mock import AsyncMock
-
 import pytest
-from botocore.exceptions import WaiterError
 
-from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.rds import RdsHook
-from airflow.providers.amazon.aws.triggers.rds import RdsDbInstanceTrigger
-from airflow.triggers.base import TriggerEvent
+from airflow.providers.amazon.aws.triggers.rds import (
+    RdsDbAvailableTrigger,
+    RdsDbDeletedTrigger,
+    RdsDbStoppedTrigger,
+)
+from airflow.providers.amazon.aws.utils.rds import RdsDbType
 
 TEST_DB_INSTANCE_IDENTIFIER = "test-db-instance-identifier"
 TEST_WAITER_DELAY = 10
@@ -40,146 +38,47 @@ TEST_RESPONSE = {
 }
 
 
-class TestRdsDbInstanceTrigger:
-    def test_rds_db_instance_trigger_serialize(self):
-        rds_db_instance_trigger = RdsDbInstanceTrigger(
-            waiter_name="test-waiter",
-            db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
-            waiter_delay=TEST_WAITER_DELAY,
-            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
-            aws_conn_id=TEST_AWS_CONN_ID,
-            region_name=TEST_REGION,
-            response=TEST_RESPONSE,
-        )
-        class_path, args = rds_db_instance_trigger.serialize()
-
-        assert class_path == 
"airflow.providers.amazon.aws.triggers.rds.RdsDbInstanceTrigger"
-        assert args["waiter_name"] == "test-waiter"
-        assert args["db_instance_identifier"] == TEST_DB_INSTANCE_IDENTIFIER
-        assert args["waiter_delay"] == str(TEST_WAITER_DELAY)
-        assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS)
-        assert args["aws_conn_id"] == TEST_AWS_CONN_ID
-        assert args["region_name"] == TEST_REGION
-        assert args["response"] == TEST_RESPONSE
-
-    @pytest.mark.asyncio
-    @mock.patch.object(RdsHook, "async_conn")
-    async def test_rds_db_instance_trigger_run(self, mock_async_conn):
-        a_mock = mock.MagicMock()
-        mock_async_conn.__aenter__.return_value = a_mock
-
-        a_mock.get_waiter().wait = AsyncMock()
-
-        rds_db_instance_trigger = RdsDbInstanceTrigger(
-            waiter_name="test-waiter",
-            db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
-            waiter_delay=TEST_WAITER_DELAY,
-            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
-            aws_conn_id=TEST_AWS_CONN_ID,
-            region_name=TEST_REGION,
-            response=TEST_RESPONSE,
-        )
-
-        generator = rds_db_instance_trigger.run()
-        response = await generator.asend(None)
-
-        assert response == TriggerEvent({"status": "success", "response": 
TEST_RESPONSE})
-
-    @pytest.mark.asyncio
-    @mock.patch("asyncio.sleep")
-    @mock.patch.object(RdsHook, "async_conn")
-    async def test_rds_db_instance_trigger_run_multiple_attempts(self, 
mock_async_conn, mock_sleep):
-        mock_sleep.return_value = True
-        a_mock = mock.MagicMock()
-        mock_async_conn.__aenter__.return_value = a_mock
-        error = WaiterError(
-            name="test_name",
-            reason="test_reason",
-            last_response={"DBInstances": [{"DBInstanceStatus": "CREATING"}]},
-        )
-        a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, error, 
True])
-
-        rds_db_instance_trigger = RdsDbInstanceTrigger(
-            waiter_name="test-waiter",
-            db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
-            waiter_delay=TEST_WAITER_DELAY,
-            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
-            aws_conn_id=TEST_AWS_CONN_ID,
-            region_name=TEST_REGION,
-            response=TEST_RESPONSE,
-        )
-
-        generator = rds_db_instance_trigger.run()
-        response = await generator.asend(None)
-        assert a_mock.get_waiter().wait.call_count == 4
-
-        assert response == TriggerEvent({"status": "success", "response": 
TEST_RESPONSE})
-
-    @pytest.mark.asyncio
-    @mock.patch("asyncio.sleep")
-    @mock.patch.object(RdsHook, "async_conn")
-    async def test_rds_db_instance_trigger_run_attempts_exceeded(self, 
mock_async_conn, mock_sleep):
-        mock_sleep.return_value = True
-
-        a_mock = mock.MagicMock()
-        mock_async_conn.__aenter__.return_value = a_mock
-        error = WaiterError(
-            name="test_name",
-            reason="test_reason",
-            last_response={"DBInstances": [{"DBInstanceStatus": "CREATING"}]},
-        )
-        a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, error, 
True])
-
-        rds_db_instance_trigger = RdsDbInstanceTrigger(
-            waiter_name="test-waiter",
-            db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
-            waiter_delay=TEST_WAITER_DELAY,
-            waiter_max_attempts=2,
-            aws_conn_id=TEST_AWS_CONN_ID,
-            region_name=TEST_REGION,
-            response=TEST_RESPONSE,
-        )
-
-        with pytest.raises(AirflowException) as exc:
-            generator = rds_db_instance_trigger.run()
-            await generator.asend(None)
-
-        assert "Waiter error: max attempts reached" in str(exc.value)
-        assert a_mock.get_waiter().wait.call_count == 2
-
-    @pytest.mark.asyncio
-    @mock.patch("asyncio.sleep")
-    @mock.patch.object(RdsHook, "async_conn")
-    async def test_rds_db_instance_trigger_run_attempts_failed(self, 
mock_async_conn, mock_sleep):
-        a_mock = mock.MagicMock()
-        mock_async_conn.__aenter__.return_value = a_mock
-
-        error_creating = WaiterError(
-            name="test_name",
-            reason="test_reason",
-            last_response={"DBInstances": [{"DBInstanceStatus": "CREATING"}]},
-        )
-
-        error_failed = WaiterError(
-            name="test_name",
-            reason="Waiter encountered a terminal failure state:",
-            last_response={"DBInstances": [{"DBInstanceStatus": "FAILED"}]},
-        )
-        a_mock.get_waiter().wait = AsyncMock(side_effect=[error_creating, 
error_creating, error_failed])
-        mock_sleep.return_value = True
-
-        rds_db_instance_trigger = RdsDbInstanceTrigger(
-            waiter_name="test-waiter",
-            db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
-            waiter_delay=TEST_WAITER_DELAY,
-            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
-            aws_conn_id=TEST_AWS_CONN_ID,
-            region_name=TEST_REGION,
-            response=TEST_RESPONSE,
-        )
-
-        with pytest.raises(AirflowException) as exc:
-            generator = rds_db_instance_trigger.run()
-            await generator.asend(None)
-        assert "Error checking DB Instance status" in str(exc.value)
-        assert a_mock.get_waiter().wait.call_count == 3
+class TestRdsTriggers:
+    @pytest.mark.parametrize(
+        "trigger",
+        [
+            RdsDbAvailableTrigger(
+                db_identifier=TEST_DB_INSTANCE_IDENTIFIER,
+                waiter_delay=TEST_WAITER_DELAY,
+                waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+                aws_conn_id=TEST_AWS_CONN_ID,
+                region_name=TEST_REGION,
+                response=TEST_RESPONSE,
+                db_type=RdsDbType.INSTANCE,
+            ),
+            RdsDbDeletedTrigger(
+                db_identifier=TEST_DB_INSTANCE_IDENTIFIER,
+                waiter_delay=TEST_WAITER_DELAY,
+                waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+                aws_conn_id=TEST_AWS_CONN_ID,
+                region_name=TEST_REGION,
+                response=TEST_RESPONSE,
+                db_type=RdsDbType.INSTANCE,
+            ),
+            RdsDbStoppedTrigger(
+                db_identifier=TEST_DB_INSTANCE_IDENTIFIER,
+                waiter_delay=TEST_WAITER_DELAY,
+                waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+                aws_conn_id=TEST_AWS_CONN_ID,
+                region_name=TEST_REGION,
+                response=TEST_RESPONSE,
+                db_type=RdsDbType.INSTANCE,
+            ),
+        ],
+    )
+    def test_serialize_recreate(self, trigger):
+        class_path, args = trigger.serialize()
+
+        class_name = class_path.split(".")[-1]
+        clazz = globals()[class_name]
+        instance = clazz(**args)
+
+        class_path2, args2 = instance.serialize()
+
+        assert class_path == class_path2
+        assert args == args2


Reply via email to