This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b1b69af88f Add deferrable mode to `RdsCreateDbInstanceOperator` and `RdsDeleteDbInstanceOperator` (#32171)
b1b69af88f is described below

commit b1b69af88f9e24db2d1f003435d8ee8cdb6933b0
Author: Syed Hussaain <[email protected]>
AuthorDate: Wed Jun 28 13:06:05 2023 -0700

    Add deferrable mode to `RdsCreateDbInstanceOperator` and `RdsDeleteDbInstanceOperator` (#32171)
    
    * RDS Create/Delete DB Instance Deferrable mode
---
 airflow/providers/amazon/aws/hooks/rds.py          |  12 +-
 airflow/providers/amazon/aws/operators/rds.py      | 100 ++++++++++-
 airflow/providers/amazon/aws/triggers/rds.py       |  89 ++++++++++
 airflow/providers/amazon/provider.yaml             |   3 +
 .../operators/rds.rst                              |   2 +
 tests/providers/amazon/aws/hooks/test_rds.py       |   1 -
 tests/providers/amazon/aws/triggers/test_rds.py    | 184 +++++++++++++++++++++
 7 files changed, 381 insertions(+), 10 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/rds.py 
b/airflow/providers/amazon/aws/hooks/rds.py
index df0f7af0e5..0a1dd798da 100644
--- a/airflow/providers/amazon/aws/hooks/rds.py
+++ b/airflow/providers/amazon/aws/hooks/rds.py
@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Callable
 
 from airflow.exceptions import AirflowException, AirflowNotFoundException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
 
 if TYPE_CHECKING:
     from mypy_boto3_rds import RDSClient  # noqa
@@ -265,9 +266,14 @@ class RdsHook(AwsGenericHook["RDSClient"]):
         target_state = target_state.lower()
         if target_state in ("available", "deleted"):
             waiter = self.conn.get_waiter(f"db_instance_{target_state}")  # 
type: ignore
-            waiter.wait(
-                DBInstanceIdentifier=db_instance_id,
-                WaiterConfig={"Delay": check_interval, "MaxAttempts": 
max_attempts},
+            wait(
+                waiter=waiter,
+                waiter_delay=check_interval,
+                waiter_max_attempts=max_attempts,
+                args={"DBInstanceIdentifier": db_instance_id},
+                failure_message=f"Rdb DB instance failed to reach state 
{target_state}",
+                status_message="Rds DB instance state is",
+                status_args=["DBInstances[0].DBInstanceStatus"],
             )
         else:
             self._wait_for_state(poke, target_state, check_interval, 
max_attempts)
diff --git a/airflow/providers/amazon/aws/operators/rds.py 
b/airflow/providers/amazon/aws/operators/rds.py
index 5f7d423d4c..440d895afa 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -18,14 +18,18 @@
 from __future__ import annotations
 
 import json
+from datetime import timedelta
 from typing import TYPE_CHECKING, Sequence
 
 from mypy_boto3_rds.type_defs import TagTypeDef
 
+from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
+from airflow.providers.amazon.aws.triggers.rds import RdsDbInstanceTrigger
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
 from airflow.providers.amazon.aws.utils.tags import format_tags
+from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -38,8 +42,8 @@ class RdsBaseOperator(BaseOperator):
     ui_fgcolor = "#ffffff"
 
     def __init__(self, *args, aws_conn_id: str = "aws_conn_id", hook_params: 
dict | None = None, **kwargs):
-        hook_params = hook_params or {}
-        self.hook = RdsHook(aws_conn_id=aws_conn_id, **hook_params)
+        self.hook_params = hook_params or {}
+        self.hook = RdsHook(aws_conn_id=aws_conn_id, **self.hook_params)
         super().__init__(*args, **kwargs)
 
         self._await_interval = 60  # seconds
@@ -522,6 +526,11 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
         
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_instance
     :param aws_conn_id: The Airflow connection used for AWS credentials.
     :param wait_for_completion:  If True, waits for creation of the DB 
instance to complete. (default: True)
+    :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB instance state
+    :param waiter_max_attempts: The maximum number of attempts to check DB instance state
+    :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
     """
 
     template_fields = ("db_instance_identifier", "db_instance_class", 
"engine", "rds_kwargs")
@@ -535,6 +544,9 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
         rds_kwargs: dict | None = None,
         aws_conn_id: str = "aws_default",
         wait_for_completion: bool = True,
+        deferrable: bool = False,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
         **kwargs,
     ):
         super().__init__(aws_conn_id=aws_conn_id, **kwargs)
@@ -543,7 +555,11 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
         self.db_instance_class = db_instance_class
         self.engine = engine
         self.rds_kwargs = rds_kwargs or {}
-        self.wait_for_completion = wait_for_completion
+        self.wait_for_completion = False if deferrable else wait_for_completion
+        self.deferrable = deferrable
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context) -> str:
         self.log.info("Creating new DB instance %s", 
self.db_instance_identifier)
@@ -554,11 +570,41 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
             Engine=self.engine,
             **self.rds_kwargs,
         )
+        if self.deferrable:
+            self.defer(
+                trigger=RdsDbInstanceTrigger(
+                    db_instance_identifier=self.db_instance_identifier,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    hook_params=self.hook_params,
+                    waiter_name="db_instance_available",
+                    # ignoring type because create_db_instance is a dict
+                    response=create_db_instance,  # type: ignore[arg-type]
+                ),
+                method_name="execute_complete",
+                timeout=timedelta(seconds=self.waiter_delay * 
self.waiter_max_attempts),
+            )
 
         if self.wait_for_completion:
-            self.hook.wait_for_db_instance_state(self.db_instance_identifier, 
target_state="available")
+            waiter = self.hook.conn.get_waiter("db_instance_available")
+            wait(
+                waiter=waiter,
+                waiter_delay=self.waiter_delay,
+                waiter_max_attempts=self.waiter_max_attempts,
+                args={"DBInstanceIdentifier": self.db_instance_identifier},
+                failure_message="DB instance creation failed",
+                status_message="DB Instance status is",
+                status_args=["DBInstances[0].DBInstanceStatus"],
+            )
         return json.dumps(create_db_instance, default=str)
 
+    def execute_complete(self, context, event=None) -> str:
+        """Handle the RdsDbInstanceTrigger event: raise on failure, else return its response as JSON."""
+        if event is None or event.get("status") != "success":
+            raise AirflowException(f"DB instance creation failed: {event}")
+        return json.dumps(event["response"], default=str)
+
 
 class RdsDeleteDbInstanceOperator(RdsBaseOperator):
     """
@@ -573,6 +619,11 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
         
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.delete_db_instance
     :param aws_conn_id: The Airflow connection used for AWS credentials.
     :param wait_for_completion:  If True, waits for deletion of the DB 
instance to complete. (default: True)
+    :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB instance state
+    :param waiter_max_attempts: The maximum number of attempts to check DB instance state
+    :param deferrable: If True, the operator will wait asynchronously for the DB instance to be deleted.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
     """
 
     template_fields = ("db_instance_identifier", "rds_kwargs")
@@ -584,12 +635,19 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
         rds_kwargs: dict | None = None,
         aws_conn_id: str = "aws_default",
         wait_for_completion: bool = True,
+        deferrable: bool = False,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
         **kwargs,
     ):
         super().__init__(aws_conn_id=aws_conn_id, **kwargs)
         self.db_instance_identifier = db_instance_identifier
         self.rds_kwargs = rds_kwargs or {}
-        self.wait_for_completion = wait_for_completion
+        self.wait_for_completion = False if deferrable else wait_for_completion
+        self.deferrable = deferrable
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context) -> str:
         self.log.info("Deleting DB instance %s", self.db_instance_identifier)
@@ -598,11 +656,41 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
             DBInstanceIdentifier=self.db_instance_identifier,
             **self.rds_kwargs,
         )
+        if self.deferrable:
+            self.defer(
+                trigger=RdsDbInstanceTrigger(
+                    db_instance_identifier=self.db_instance_identifier,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    hook_params=self.hook_params,
+                    waiter_name="db_instance_deleted",
+                    # ignoring type because delete_db_instance is a dict
+                    response=delete_db_instance,  # type: ignore[arg-type]
+                ),
+                method_name="execute_complete",
+                timeout=timedelta(seconds=self.waiter_delay * 
self.waiter_max_attempts),
+            )
 
         if self.wait_for_completion:
-            self.hook.wait_for_db_instance_state(self.db_instance_identifier, 
target_state="deleted")
+            waiter = self.hook.conn.get_waiter("db_instance_deleted")
+            wait(
+                waiter=waiter,
+                waiter_delay=self.waiter_delay,
+                waiter_max_attempts=self.waiter_max_attempts,
+                args={"DBInstanceIdentifier": self.db_instance_identifier},
+                failure_message="DB instance deletion failed",
+                status_message="DB Instance status is",
+                status_args=["DBInstances[0].DBInstanceStatus"],
+            )
         return json.dumps(delete_db_instance, default=str)
 
+    def execute_complete(self, context, event=None) -> str:
+        """Handle the RdsDbInstanceTrigger event: raise on failure, else return its response as JSON."""
+        if event is None or event.get("status") != "success":
+            raise AirflowException(f"DB instance deletion failed: {event}")
+        return json.dumps(event["response"], default=str)
+
 
 class RdsStartDbOperator(RdsBaseOperator):
     """
diff --git a/airflow/providers/amazon/aws/triggers/rds.py 
b/airflow/providers/amazon/aws/triggers/rds.py
new file mode 100644
index 0000000000..0897f764be
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/rds.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any
+
+from airflow.providers.amazon.aws.hooks.rds import RdsHook
+from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class RdsDbInstanceTrigger(BaseTrigger):
+    """
+    Trigger for RdsCreateDbInstanceOperator and RdsDeleteDbInstanceOperator.
+
+    The trigger asynchronously polls the RDS waiter named by ``waiter_name`` and emits a
+    success event once the DB instance reaches the state that waiter checks for.
+
+    :param waiter_name: Name of the waiter to use, for instance 'db_instance_available'
+        or 'db_instance_deleted'.
+    :param db_instance_identifier: The DB instance identifier for the DB instance to be polled.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param hook_params: The parameters to pass to the RdsHook.
+    :param response: The create/delete API response, returned unchanged in the success TriggerEvent.
+    """
+
+    def __init__(
+        self,
+        waiter_name: str,
+        db_instance_identifier: str,
+        waiter_delay: int,
+        waiter_max_attempts: int,
+        aws_conn_id: str,
+        hook_params: dict[str, Any],
+        response: dict[str, Any],
+    ):
+        self.db_instance_identifier = db_instance_identifier
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.aws_conn_id = aws_conn_id
+        self.hook_params = hook_params
+        self.waiter_name = waiter_name
+        self.response = response
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            # dynamically generate the fully qualified name of the class
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "db_instance_identifier": self.db_instance_identifier,
+                "waiter_delay": str(self.waiter_delay),  # run() converts back with int()
+                "waiter_max_attempts": str(self.waiter_max_attempts),  # run() converts back with int()
+                "aws_conn_id": self.aws_conn_id,
+                "hook_params": self.hook_params,
+                "waiter_name": self.waiter_name,
+                "response": self.response,
+            },
+        )
+
+    async def run(self):
+        self.hook = RdsHook(aws_conn_id=self.aws_conn_id, **self.hook_params)
+        async with self.hook.async_conn as client:
+            waiter = client.get_waiter(self.waiter_name)
+            await async_wait(  # raises on a terminal failure state or when attempts run out
+                waiter=waiter,
+                waiter_delay=int(self.waiter_delay),
+                waiter_max_attempts=int(self.waiter_max_attempts),
+                args={"DBInstanceIdentifier": self.db_instance_identifier},
+                failure_message="Error checking DB Instance status",
+                status_message="DB instance status is",
+                status_args=["DBInstances[0].DBInstanceStatus"],
+            )
+        yield TriggerEvent({"status": "success", "response": self.response})
diff --git a/airflow/providers/amazon/provider.yaml 
b/airflow/providers/amazon/provider.yaml
index bf2391d9a8..0c74094e81 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -546,6 +546,9 @@ triggers:
   - integration-name: Amazon ECS
     python-modules:
       - airflow.providers.amazon.aws.triggers.ecs
+  - integration-name: Amazon RDS
+    python-modules:
+      - airflow.providers.amazon.aws.triggers.rds
 
 transfers:
   - source-integration-name: Amazon DynamoDB
diff --git a/docs/apache-airflow-providers-amazon/operators/rds.rst 
b/docs/apache-airflow-providers-amazon/operators/rds.rst
index bca9c64af1..e27bbc2d2f 100644
--- a/docs/apache-airflow-providers-amazon/operators/rds.rst
+++ b/docs/apache-airflow-providers-amazon/operators/rds.rst
@@ -145,6 +145,7 @@ Create a database instance
 
 To create a AWS DB instance you can use
 
:class:`~airflow.providers.amazon.aws.operators.rds.RdsCreateDbInstanceOperator`.
+You can also run this operator in deferrable mode by setting the ``deferrable`` parameter to ``True``.
 
 .. exampleinclude:: 
/../../tests/system/providers/amazon/aws/example_rds_instance.py
     :language: python
@@ -159,6 +160,7 @@ Delete a database instance
 
 To delete a AWS DB instance you can use
 
:class:`~airflow.providers.amazon.aws.operators.rds.RDSDeleteDbInstanceOperator`.
+You can also run this operator in deferrable mode by setting the ``deferrable`` parameter to ``True``.
 
 .. exampleinclude:: 
/../../tests/system/providers/amazon/aws/example_rds_instance.py
     :language: python
diff --git a/tests/providers/amazon/aws/hooks/test_rds.py 
b/tests/providers/amazon/aws/hooks/test_rds.py
index f98320a2f1..a34e43c944 100644
--- a/tests/providers/amazon/aws/hooks/test_rds.py
+++ b/tests/providers/amazon/aws/hooks/test_rds.py
@@ -153,7 +153,6 @@ class TestRdsHook:
                 mock.return_value.wait.assert_called_once_with(
                     DBInstanceIdentifier=db_instance_id,
                     WaiterConfig={
-                        "Delay": self.waiter_args["check_interval"],
                         "MaxAttempts": self.waiter_args["max_attempts"],
                     },
                 )
diff --git a/tests/providers/amazon/aws/triggers/test_rds.py 
b/tests/providers/amazon/aws/triggers/test_rds.py
new file mode 100644
index 0000000000..5ae64b83c3
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_rds.py
@@ -0,0 +1,184 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import AsyncMock
+
+import pytest
+from botocore.exceptions import WaiterError
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.rds import RdsHook
+from airflow.providers.amazon.aws.triggers.rds import RdsDbInstanceTrigger
+from airflow.triggers.base import TriggerEvent
+
+TEST_DB_INSTANCE_IDENTIFIER = "test-db-instance-identifier"
+TEST_WAITER_DELAY = 10
+TEST_WAITER_MAX_ATTEMPTS = 10
+TEST_AWS_CONN_ID = "test-aws-id"
+TEST_RESPONSE = {
+    "DBInstance": {
+        "DBInstanceIdentifier": "test-db-instance-identifier",
+        "DBInstanceStatus": "test-db-instance-status",
+    }
+}
+
+
+class TestRdsDbInstanceTrigger:
+    def test_rds_db_instance_trigger_serialize(self):
+        rds_db_instance_trigger = RdsDbInstanceTrigger(
+            waiter_name="test-waiter",
+            db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+            aws_conn_id=TEST_AWS_CONN_ID,
+            hook_params={},
+            response=TEST_RESPONSE,
+        )
+        class_path, args = rds_db_instance_trigger.serialize()
+
+        assert class_path == "airflow.providers.amazon.aws.triggers.rds.RdsDbInstanceTrigger"
+        assert args["waiter_name"] == "test-waiter"
+        assert args["db_instance_identifier"] == TEST_DB_INSTANCE_IDENTIFIER
+        assert args["waiter_delay"] == str(TEST_WAITER_DELAY)  # ints are serialized as strings
+        assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS)
+        assert args["aws_conn_id"] == TEST_AWS_CONN_ID
+        assert args["hook_params"] == {}
+        assert args["response"] == TEST_RESPONSE
+
+    @pytest.mark.asyncio
+    @mock.patch.object(RdsHook, "async_conn")
+    async def test_rds_db_instance_trigger_run(self, mock_async_conn):
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+
+        a_mock.get_waiter().wait = AsyncMock()  # the first poll succeeds
+
+        rds_db_instance_trigger = RdsDbInstanceTrigger(
+            waiter_name="test-waiter",
+            db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+            aws_conn_id=TEST_AWS_CONN_ID,
+            hook_params={},
+            response=TEST_RESPONSE,
+        )
+
+        generator = rds_db_instance_trigger.run()
+        response = await generator.asend(None)
+
+        assert response == TriggerEvent({"status": "success", "response": TEST_RESPONSE})
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    @mock.patch.object(RdsHook, "async_conn")
+    async def test_rds_db_instance_trigger_run_multiple_attempts(self, mock_async_conn, mock_sleep):
+        mock_sleep.return_value = True
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+        error = WaiterError(  # non-terminal: the instance is still CREATING
+            name="test_name",
+            reason="test_reason",
+            last_response={"DBInstances": [{"DBInstanceStatus": "CREATING"}]},
+        )
+        a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, error, True])
+
+        rds_db_instance_trigger = RdsDbInstanceTrigger(
+            waiter_name="test-waiter",
+            db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+            aws_conn_id=TEST_AWS_CONN_ID,
+            hook_params={},
+            response=TEST_RESPONSE,
+        )
+
+        generator = rds_db_instance_trigger.run()
+        response = await generator.asend(None)
+        assert a_mock.get_waiter().wait.call_count == 4
+
+        assert response == TriggerEvent({"status": "success", "response": TEST_RESPONSE})
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    @mock.patch.object(RdsHook, "async_conn")
+    async def test_rds_db_instance_trigger_run_attempts_exceeded(self, mock_async_conn, mock_sleep):
+        mock_sleep.return_value = True
+
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+        error = WaiterError(
+            name="test_name",
+            reason="test_reason",
+            last_response={"DBInstances": [{"DBInstanceStatus": "CREATING"}]},
+        )
+        a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, error, True])
+
+        rds_db_instance_trigger = RdsDbInstanceTrigger(
+            waiter_name="test-waiter",
+            db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=2,  # runs out before the waiter would succeed
+            aws_conn_id=TEST_AWS_CONN_ID,
+            hook_params={},
+            response=TEST_RESPONSE,
+        )
+
+        with pytest.raises(AirflowException) as exc:
+            generator = rds_db_instance_trigger.run()
+            await generator.asend(None)
+
+        assert "Waiter error: max attempts reached" in str(exc.value)
+        assert a_mock.get_waiter().wait.call_count == 2
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    @mock.patch.object(RdsHook, "async_conn")
+    async def test_rds_db_instance_trigger_run_attempts_failed(self, mock_async_conn, mock_sleep):
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+
+        error_creating = WaiterError(
+            name="test_name",
+            reason="test_reason",
+            last_response={"DBInstances": [{"DBInstanceStatus": "CREATING"}]},
+        )
+
+        error_failed = WaiterError(  # terminal failure: polling stops at the third call
+            name="test_name",
+            reason="Waiter encountered a terminal failure state:",
+            last_response={"DBInstances": [{"DBInstanceStatus": "FAILED"}]},
+        )
+        a_mock.get_waiter().wait = AsyncMock(side_effect=[error_creating, error_creating, error_failed])
+        mock_sleep.return_value = True
+
+        rds_db_instance_trigger = RdsDbInstanceTrigger(
+            waiter_name="test-waiter",
+            db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+            aws_conn_id=TEST_AWS_CONN_ID,
+            hook_params={},
+            response=TEST_RESPONSE,
+        )
+
+        with pytest.raises(AirflowException) as exc:
+            generator = rds_db_instance_trigger.run()
+            await generator.asend(None)
+        assert "Error checking DB Instance status" in str(exc.value)
+        assert a_mock.get_waiter().wait.call_count == 3

Reply via email to