This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b1b69af88f Add deferrable mode to `RdsCreateDbInstanceOperator` and
`RdsDeleteDbInstanceOperator` (#32171)
b1b69af88f is described below
commit b1b69af88f9e24db2d1f003435d8ee8cdb6933b0
Author: Syed Hussaain <[email protected]>
AuthorDate: Wed Jun 28 13:06:05 2023 -0700
Add deferrable mode to `RdsCreateDbInstanceOperator` and
`RdsDeleteDbInstanceOperator` (#32171)
* RDS Create/Delete DB Instance Deferrable mode
---
airflow/providers/amazon/aws/hooks/rds.py | 12 +-
airflow/providers/amazon/aws/operators/rds.py | 100 ++++++++++-
airflow/providers/amazon/aws/triggers/rds.py | 89 ++++++++++
airflow/providers/amazon/provider.yaml | 3 +
.../operators/rds.rst | 2 +
tests/providers/amazon/aws/hooks/test_rds.py | 1 -
tests/providers/amazon/aws/triggers/test_rds.py | 184 +++++++++++++++++++++
7 files changed, 381 insertions(+), 10 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/rds.py
b/airflow/providers/amazon/aws/hooks/rds.py
index df0f7af0e5..0a1dd798da 100644
--- a/airflow/providers/amazon/aws/hooks/rds.py
+++ b/airflow/providers/amazon/aws/hooks/rds.py
@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Callable
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
if TYPE_CHECKING:
from mypy_boto3_rds import RDSClient # noqa
@@ -265,9 +266,14 @@ class RdsHook(AwsGenericHook["RDSClient"]):
target_state = target_state.lower()
if target_state in ("available", "deleted"):
waiter = self.conn.get_waiter(f"db_instance_{target_state}") #
type: ignore
- waiter.wait(
- DBInstanceIdentifier=db_instance_id,
- WaiterConfig={"Delay": check_interval, "MaxAttempts":
max_attempts},
+ wait(
+ waiter=waiter,
+ waiter_delay=check_interval,
+ waiter_max_attempts=max_attempts,
+ args={"DBInstanceIdentifier": db_instance_id},
+            failure_message=f"Rds DB instance failed to reach state
{target_state}",
+ status_message="Rds DB instance state is",
+ status_args=["DBInstances[0].DBInstanceStatus"],
)
else:
self._wait_for_state(poke, target_state, check_interval,
max_attempts)
diff --git a/airflow/providers/amazon/aws/operators/rds.py
b/airflow/providers/amazon/aws/operators/rds.py
index 5f7d423d4c..440d895afa 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -18,14 +18,18 @@
from __future__ import annotations
import json
+from datetime import timedelta
from typing import TYPE_CHECKING, Sequence
from mypy_boto3_rds.type_defs import TagTypeDef
+from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.rds import RdsHook
+from airflow.providers.amazon.aws.triggers.rds import RdsDbInstanceTrigger
from airflow.providers.amazon.aws.utils.rds import RdsDbType
from airflow.providers.amazon.aws.utils.tags import format_tags
+from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -38,8 +42,8 @@ class RdsBaseOperator(BaseOperator):
ui_fgcolor = "#ffffff"
def __init__(self, *args, aws_conn_id: str = "aws_conn_id", hook_params:
dict | None = None, **kwargs):
- hook_params = hook_params or {}
- self.hook = RdsHook(aws_conn_id=aws_conn_id, **hook_params)
+ self.hook_params = hook_params or {}
+ self.hook = RdsHook(aws_conn_id=aws_conn_id, **self.hook_params)
super().__init__(*args, **kwargs)
self._await_interval = 60 # seconds
@@ -522,6 +526,11 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_instance
:param aws_conn_id: The Airflow connection used for AWS credentials.
:param wait_for_completion: If True, waits for creation of the DB
instance to complete. (default: True)
+ :param waiter_delay: Time (in seconds) to wait between two consecutive
calls to check DB instance state
+ :param waiter_max_attempts: The maximum number of attempts to check DB
instance state
+ :param deferrable: If True, the operator will wait asynchronously for the
DB instance to be created.
+ This implies waiting for completion. This mode requires aiobotocore
module to be installed.
+ (default: False)
"""
template_fields = ("db_instance_identifier", "db_instance_class",
"engine", "rds_kwargs")
@@ -535,6 +544,9 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
rds_kwargs: dict | None = None,
aws_conn_id: str = "aws_default",
wait_for_completion: bool = True,
+ deferrable: bool = False,
+ waiter_delay: int = 30,
+ waiter_max_attempts: int = 60,
**kwargs,
):
super().__init__(aws_conn_id=aws_conn_id, **kwargs)
@@ -543,7 +555,11 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
self.db_instance_class = db_instance_class
self.engine = engine
self.rds_kwargs = rds_kwargs or {}
- self.wait_for_completion = wait_for_completion
+ self.wait_for_completion = False if deferrable else wait_for_completion
+ self.deferrable = deferrable
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+ self.aws_conn_id = aws_conn_id
def execute(self, context: Context) -> str:
self.log.info("Creating new DB instance %s",
self.db_instance_identifier)
@@ -554,11 +570,41 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
Engine=self.engine,
**self.rds_kwargs,
)
+ if self.deferrable:
+ self.defer(
+ trigger=RdsDbInstanceTrigger(
+ db_instance_identifier=self.db_instance_identifier,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ hook_params=self.hook_params,
+ waiter_name="db_instance_available",
+ # ignoring type because create_db_instance is a dict
+ response=create_db_instance, # type: ignore[arg-type]
+ ),
+ method_name="execute_complete",
+ timeout=timedelta(seconds=self.waiter_delay *
self.waiter_max_attempts),
+ )
if self.wait_for_completion:
- self.hook.wait_for_db_instance_state(self.db_instance_identifier,
target_state="available")
+ waiter = self.hook.conn.get_waiter("db_instance_available")
+ wait(
+ waiter=waiter,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ args={"DBInstanceIdentifier": self.db_instance_identifier},
+ failure_message="DB instance creation failed",
+ status_message="DB Instance status is",
+ status_args=["DBInstances[0].DBInstanceStatus"],
+ )
return json.dumps(create_db_instance, default=str)
+ def execute_complete(self, context, event=None) -> str:
+ if event["status"] != "success":
+ raise AirflowException(f"DB instance creation failed: {event}")
+ else:
+ return json.dumps(event["response"], default=str)
+
class RdsDeleteDbInstanceOperator(RdsBaseOperator):
"""
@@ -573,6 +619,11 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.delete_db_instance
:param aws_conn_id: The Airflow connection used for AWS credentials.
:param wait_for_completion: If True, waits for deletion of the DB
instance to complete. (default: True)
+ :param waiter_delay: Time (in seconds) to wait between two consecutive
calls to check DB instance state
+ :param waiter_max_attempts: The maximum number of attempts to check DB
instance state
+    :param deferrable: If True, the operator will wait asynchronously for the
DB instance to be deleted.
+ This implies waiting for completion. This mode requires aiobotocore
module to be installed.
+ (default: False)
"""
template_fields = ("db_instance_identifier", "rds_kwargs")
@@ -584,12 +635,19 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
rds_kwargs: dict | None = None,
aws_conn_id: str = "aws_default",
wait_for_completion: bool = True,
+ deferrable: bool = False,
+ waiter_delay: int = 30,
+ waiter_max_attempts: int = 60,
**kwargs,
):
super().__init__(aws_conn_id=aws_conn_id, **kwargs)
self.db_instance_identifier = db_instance_identifier
self.rds_kwargs = rds_kwargs or {}
- self.wait_for_completion = wait_for_completion
+ self.wait_for_completion = False if deferrable else wait_for_completion
+ self.deferrable = deferrable
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+ self.aws_conn_id = aws_conn_id
def execute(self, context: Context) -> str:
self.log.info("Deleting DB instance %s", self.db_instance_identifier)
@@ -598,11 +656,41 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
DBInstanceIdentifier=self.db_instance_identifier,
**self.rds_kwargs,
)
+ if self.deferrable:
+ self.defer(
+ trigger=RdsDbInstanceTrigger(
+ db_instance_identifier=self.db_instance_identifier,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ hook_params=self.hook_params,
+ waiter_name="db_instance_deleted",
+ # ignoring type because delete_db_instance is a dict
+ response=delete_db_instance, # type: ignore[arg-type]
+ ),
+ method_name="execute_complete",
+ timeout=timedelta(seconds=self.waiter_delay *
self.waiter_max_attempts),
+ )
if self.wait_for_completion:
- self.hook.wait_for_db_instance_state(self.db_instance_identifier,
target_state="deleted")
+ waiter = self.hook.conn.get_waiter("db_instance_deleted")
+ wait(
+ waiter=waiter,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ args={"DBInstanceIdentifier": self.db_instance_identifier},
+ failure_message="DB instance deletion failed",
+ status_message="DB Instance status is",
+ status_args=["DBInstances[0].DBInstanceStatus"],
+ )
return json.dumps(delete_db_instance, default=str)
+ def execute_complete(self, context, event=None) -> str:
+ if event["status"] != "success":
+ raise AirflowException(f"DB instance deletion failed: {event}")
+ else:
+ return json.dumps(event["response"], default=str)
+
class RdsStartDbOperator(RdsBaseOperator):
"""
diff --git a/airflow/providers/amazon/aws/triggers/rds.py
b/airflow/providers/amazon/aws/triggers/rds.py
new file mode 100644
index 0000000000..0897f764be
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/rds.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any
+
+from airflow.providers.amazon.aws.hooks.rds import RdsHook
+from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class RdsDbInstanceTrigger(BaseTrigger):
+ """
+ Trigger for RdsCreateDbInstanceOperator and RdsDeleteDbInstanceOperator.
+
+ The trigger will asynchronously poll the boto3 API and wait for the
+ DB instance to be in the state specified by the waiter.
+
+ :param waiter_name: Name of the waiter to use, for instance
'db_instance_available'
+ or 'db_instance_deleted'.
+ :param db_instance_identifier: The DB instance identifier for the DB
instance to be polled.
+ :param waiter_delay: The amount of time in seconds to wait between
attempts.
+ :param waiter_max_attempts: The maximum number of attempts to be made.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ :param hook_params: The parameters to pass to the RdsHook.
+ :param response: The response from the RdsHook, to be passed back to the
operator.
+ """
+
+ def __init__(
+ self,
+ waiter_name: str,
+ db_instance_identifier: str,
+ waiter_delay: int,
+ waiter_max_attempts: int,
+ aws_conn_id: str,
+ hook_params: dict[str, Any],
+ response: dict[str, Any],
+ ):
+ self.db_instance_identifier = db_instance_identifier
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+ self.aws_conn_id = aws_conn_id
+ self.hook_params = hook_params
+ self.waiter_name = waiter_name
+ self.response = response
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+ # dynamically generate the fully qualified name of the class
+ self.__class__.__module__ + "." + self.__class__.__qualname__,
+ {
+ "db_instance_identifier": self.db_instance_identifier,
+ "waiter_delay": str(self.waiter_delay),
+ "waiter_max_attempts": str(self.waiter_max_attempts),
+ "aws_conn_id": self.aws_conn_id,
+ "hook_params": self.hook_params,
+ "waiter_name": self.waiter_name,
+ "response": self.response,
+ },
+ )
+
+ async def run(self):
+ self.hook = RdsHook(aws_conn_id=self.aws_conn_id, **self.hook_params)
+ async with self.hook.async_conn as client:
+ waiter = client.get_waiter(self.waiter_name)
+ await async_wait(
+ waiter=waiter,
+ waiter_delay=int(self.waiter_delay),
+ waiter_max_attempts=int(self.waiter_max_attempts),
+ args={"DBInstanceIdentifier": self.db_instance_identifier},
+ failure_message="Error checking DB Instance status",
+ status_message="DB instance status is",
+ status_args=["DBInstances[0].DBInstanceStatus"],
+ )
+ yield TriggerEvent({"status": "success", "response": self.response})
diff --git a/airflow/providers/amazon/provider.yaml
b/airflow/providers/amazon/provider.yaml
index bf2391d9a8..0c74094e81 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -546,6 +546,9 @@ triggers:
- integration-name: Amazon ECS
python-modules:
- airflow.providers.amazon.aws.triggers.ecs
+ - integration-name: Amazon RDS
+ python-modules:
+ - airflow.providers.amazon.aws.triggers.rds
transfers:
- source-integration-name: Amazon DynamoDB
diff --git a/docs/apache-airflow-providers-amazon/operators/rds.rst
b/docs/apache-airflow-providers-amazon/operators/rds.rst
index bca9c64af1..e27bbc2d2f 100644
--- a/docs/apache-airflow-providers-amazon/operators/rds.rst
+++ b/docs/apache-airflow-providers-amazon/operators/rds.rst
@@ -145,6 +145,7 @@ Create a database instance
To create a AWS DB instance you can use
:class:`~airflow.providers.amazon.aws.operators.rds.RdsCreateDbInstanceOperator`.
+You can also run this operator in deferrable mode by setting ``deferrable``
param to ``True``.
.. exampleinclude::
/../../tests/system/providers/amazon/aws/example_rds_instance.py
:language: python
@@ -159,6 +160,7 @@ Delete a database instance
To delete a AWS DB instance you can use
:class:`~airflow.providers.amazon.aws.operators.rds.RDSDeleteDbInstanceOperator`.
+You can also run this operator in deferrable mode by setting ``deferrable``
param to ``True``.
.. exampleinclude::
/../../tests/system/providers/amazon/aws/example_rds_instance.py
:language: python
diff --git a/tests/providers/amazon/aws/hooks/test_rds.py
b/tests/providers/amazon/aws/hooks/test_rds.py
index f98320a2f1..a34e43c944 100644
--- a/tests/providers/amazon/aws/hooks/test_rds.py
+++ b/tests/providers/amazon/aws/hooks/test_rds.py
@@ -153,7 +153,6 @@ class TestRdsHook:
mock.return_value.wait.assert_called_once_with(
DBInstanceIdentifier=db_instance_id,
WaiterConfig={
- "Delay": self.waiter_args["check_interval"],
"MaxAttempts": self.waiter_args["max_attempts"],
},
)
diff --git a/tests/providers/amazon/aws/triggers/test_rds.py
b/tests/providers/amazon/aws/triggers/test_rds.py
new file mode 100644
index 0000000000..5ae64b83c3
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_rds.py
@@ -0,0 +1,184 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import AsyncMock
+
+import pytest
+from botocore.exceptions import WaiterError
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.rds import RdsHook
+from airflow.providers.amazon.aws.triggers.rds import RdsDbInstanceTrigger
+from airflow.triggers.base import TriggerEvent
+
+TEST_DB_INSTANCE_IDENTIFIER = "test-db-instance-identifier"
+TEST_WAITER_DELAY = 10
+TEST_WAITER_MAX_ATTEMPTS = 10
+TEST_AWS_CONN_ID = "test-aws-id"
+TEST_RESPONSE = {
+ "DBInstance": {
+ "DBInstanceIdentifier": "test-db-instance-identifier",
+ "DBInstanceStatus": "test-db-instance-status",
+ }
+}
+
+
+class TestRdsDbInstanceTrigger:
+ def test_rds_db_instance_trigger_serialize(self):
+ rds_db_instance_trigger = RdsDbInstanceTrigger(
+ waiter_name="test-waiter",
+ db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ hook_params={},
+ response=TEST_RESPONSE,
+ )
+ class_path, args = rds_db_instance_trigger.serialize()
+
+ assert class_path ==
"airflow.providers.amazon.aws.triggers.rds.RdsDbInstanceTrigger"
+ assert args["waiter_name"] == "test-waiter"
+ assert args["db_instance_identifier"] == TEST_DB_INSTANCE_IDENTIFIER
+ assert args["waiter_delay"] == str(TEST_WAITER_DELAY)
+ assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS)
+ assert args["aws_conn_id"] == TEST_AWS_CONN_ID
+ assert args["hook_params"] == {}
+ assert args["response"] == TEST_RESPONSE
+
+ @pytest.mark.asyncio
+ @mock.patch.object(RdsHook, "async_conn")
+ async def test_rds_db_instance_trigger_run(self, mock_async_conn):
+ a_mock = mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = a_mock
+
+ a_mock.get_waiter().wait = AsyncMock()
+
+ rds_db_instance_trigger = RdsDbInstanceTrigger(
+ waiter_name="test-waiter",
+ db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ hook_params={},
+ response=TEST_RESPONSE,
+ )
+
+ generator = rds_db_instance_trigger.run()
+ response = await generator.asend(None)
+
+ assert response == TriggerEvent({"status": "success", "response":
TEST_RESPONSE})
+
+ @pytest.mark.asyncio
+ @mock.patch("asyncio.sleep")
+ @mock.patch.object(RdsHook, "async_conn")
+ async def test_rds_db_instance_trigger_run_multiple_attempts(self,
mock_async_conn, mock_sleep):
+ mock_sleep.return_value = True
+ a_mock = mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = a_mock
+ error = WaiterError(
+ name="test_name",
+ reason="test_reason",
+ last_response={"DBInstances": [{"DBInstanceStatus": "CREATING"}]},
+ )
+ a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, error,
True])
+
+ rds_db_instance_trigger = RdsDbInstanceTrigger(
+ waiter_name="test-waiter",
+ db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ hook_params={},
+ response=TEST_RESPONSE,
+ )
+
+ generator = rds_db_instance_trigger.run()
+ response = await generator.asend(None)
+ assert a_mock.get_waiter().wait.call_count == 4
+
+ assert response == TriggerEvent({"status": "success", "response":
TEST_RESPONSE})
+
+ @pytest.mark.asyncio
+ @mock.patch("asyncio.sleep")
+ @mock.patch.object(RdsHook, "async_conn")
+ async def test_rds_db_instance_trigger_run_attempts_exceeded(self,
mock_async_conn, mock_sleep):
+ mock_sleep.return_value = True
+
+ a_mock = mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = a_mock
+ error = WaiterError(
+ name="test_name",
+ reason="test_reason",
+ last_response={"DBInstances": [{"DBInstanceStatus": "CREATING"}]},
+ )
+ a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, error,
True])
+
+ rds_db_instance_trigger = RdsDbInstanceTrigger(
+ waiter_name="test-waiter",
+ db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=2,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ hook_params={},
+ response=TEST_RESPONSE,
+ )
+
+ with pytest.raises(AirflowException) as exc:
+ generator = rds_db_instance_trigger.run()
+ await generator.asend(None)
+
+ assert "Waiter error: max attempts reached" in str(exc.value)
+ assert a_mock.get_waiter().wait.call_count == 2
+
+ @pytest.mark.asyncio
+ @mock.patch("asyncio.sleep")
+ @mock.patch.object(RdsHook, "async_conn")
+ async def test_rds_db_instance_trigger_run_attempts_failed(self,
mock_async_conn, mock_sleep):
+ a_mock = mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = a_mock
+
+ error_creating = WaiterError(
+ name="test_name",
+ reason="test_reason",
+ last_response={"DBInstances": [{"DBInstanceStatus": "CREATING"}]},
+ )
+
+ error_failed = WaiterError(
+ name="test_name",
+ reason="Waiter encountered a terminal failure state:",
+ last_response={"DBInstances": [{"DBInstanceStatus": "FAILED"}]},
+ )
+ a_mock.get_waiter().wait = AsyncMock(side_effect=[error_creating,
error_creating, error_failed])
+ mock_sleep.return_value = True
+
+ rds_db_instance_trigger = RdsDbInstanceTrigger(
+ waiter_name="test-waiter",
+ db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ hook_params={},
+ response=TEST_RESPONSE,
+ )
+
+ with pytest.raises(AirflowException) as exc:
+ generator = rds_db_instance_trigger.run()
+ await generator.asend(None)
+ assert "Error checking DB Instance status" in str(exc.value)
+ assert a_mock.get_waiter().wait.call_count == 3