This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b9c84eb663 add deferrable mode to rds start & stop DB (#32437)
b9c84eb663 is described below
commit b9c84eb6639e825ed951c08e477411bf52dfc437
Author: Raphaël Vandon <[email protected]>
AuthorDate: Wed Jul 19 10:37:32 2023 -0700
add deferrable mode to rds start & stop DB (#32437)
* remove tests on obsolete trigger
* add delay & attempts to params
---------
Co-authored-by: Wei Lee <[email protected]>
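For readers skimming the change: the new mode is opt-in through the operators' ``deferrable`` flag. A minimal usage sketch (task id and DB identifier are illustrative, not taken from this commit):

    from airflow.providers.amazon.aws.operators.rds import RdsStartDbOperator

    # Hands the wait off to the triggerer instead of blocking a worker slot.
    start_db = RdsStartDbOperator(
        task_id="start_db",
        db_identifier="my-db-instance",  # hypothetical DB identifier
        db_type="instance",
        deferrable=True,                 # new in this commit
        waiter_delay=30,                 # seconds between state checks
        waiter_max_attempts=40,          # give up after this many checks
    )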
---
airflow/providers/amazon/aws/operators/rds.py | 168 +++++++++++++------
airflow/providers/amazon/aws/triggers/rds.py | 166 ++++++++++++++++++-
airflow/providers/amazon/aws/utils/rds.py | 4 +-
docs/spelling_wordlist.txt | 1 +
tests/providers/amazon/aws/operators/test_rds.py | 29 ++++
tests/providers/amazon/aws/triggers/test_rds.py | 201 ++++++-----------------
6 files changed, 363 insertions(+), 206 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/rds.py b/airflow/providers/amazon/aws/operators/rds.py
index c037251643..7a544a8c00 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -20,15 +20,20 @@ from __future__ import annotations
import json
import warnings
from datetime import timedelta
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
+from functools import cached_property
from mypy_boto3_rds.type_defs import TagTypeDef
from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.rds import RdsHook
-from airflow.providers.amazon.aws.triggers.rds import RdsDbInstanceTrigger
+from airflow.providers.amazon.aws.triggers.rds import (
+ RdsDbAvailableTrigger,
+ RdsDbDeletedTrigger,
+ RdsDbStoppedTrigger,
+)
from airflow.providers.amazon.aws.utils.rds import RdsDbType
from airflow.providers.amazon.aws.utils.tags import format_tags
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
@@ -62,13 +67,17 @@ class RdsBaseOperator(BaseOperator):
AirflowProviderDeprecationWarning,
stacklevel=3, # 2 is in the operator's init, 3 is in the user code creating the operator
)
- hook_params = hook_params or {}
- self.region_name = region_name or hook_params.pop("region_name", None)
- self.hook = RdsHook(aws_conn_id=aws_conn_id, region_name=self.region_name, **(hook_params))
+ self.hook_params = hook_params or {}
+ self.aws_conn_id = aws_conn_id
+ self.region_name = region_name or self.hook_params.pop("region_name", None)
super().__init__(*args, **kwargs)
self._await_interval = 60 # seconds
+ @cached_property
+ def hook(self) -> RdsHook:
+ return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, **self.hook_params)
+
def execute(self, context: Context) -> str:
"""Different implementations for snapshots, tasks and events."""
raise NotImplementedError
@@ -106,10 +115,9 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
db_snapshot_identifier: str,
tags: Sequence[TagTypeDef] | dict | None = None,
wait_for_completion: bool = True,
- aws_conn_id: str = "aws_default",
**kwargs,
):
- super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ super().__init__(**kwargs)
self.db_type = RdsDbType(db_type)
self.db_identifier = db_identifier
self.db_snapshot_identifier = db_snapshot_identifier
@@ -194,10 +202,9 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
target_custom_availability_zone: str = "",
source_region: str = "",
wait_for_completion: bool = True,
- aws_conn_id: str = "aws_default",
**kwargs,
):
- super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ super().__init__(**kwargs)
self.db_type = RdsDbType(db_type)
self.source_db_snapshot_identifier = source_db_snapshot_identifier
@@ -274,10 +281,9 @@ class RdsDeleteDbSnapshotOperator(RdsBaseOperator):
db_type: str,
db_snapshot_identifier: str,
wait_for_completion: bool = True,
- aws_conn_id: str = "aws_default",
**kwargs,
):
- super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ super().__init__(**kwargs)
self.db_type = RdsDbType(db_type)
self.db_snapshot_identifier = db_snapshot_identifier
@@ -345,10 +351,9 @@ class RdsStartExportTaskOperator(RdsBaseOperator):
s3_prefix: str = "",
export_only: list[str] | None = None,
wait_for_completion: bool = True,
- aws_conn_id: str = "aws_default",
**kwargs,
):
- super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ super().__init__(**kwargs)
self.export_task_identifier = export_task_identifier
self.source_arn = source_arn
@@ -397,10 +402,9 @@ class RdsCancelExportTaskOperator(RdsBaseOperator):
export_task_identifier: str,
wait_for_completion: bool = True,
check_interval: int = 30,
- aws_conn_id: str = "aws_default",
**kwargs,
):
- super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ super().__init__(**kwargs)
self.export_task_identifier = export_task_identifier
self.wait_for_completion = wait_for_completion
@@ -461,10 +465,9 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
enabled: bool = True,
tags: Sequence[TagTypeDef] | dict | None = None,
wait_for_completion: bool = True,
- aws_conn_id: str = "aws_default",
**kwargs,
):
- super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ super().__init__(**kwargs)
self.subscription_name = subscription_name
self.sns_topic_arn = sns_topic_arn
@@ -511,10 +514,9 @@ class RdsDeleteEventSubscriptionOperator(RdsBaseOperator):
self,
*,
subscription_name: str,
- aws_conn_id: str = "aws_default",
**kwargs,
):
- super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ super().__init__(**kwargs)
self.subscription_name = subscription_name
@@ -545,7 +547,6 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
:param engine: The name of the database engine to be used for this instance
:param rds_kwargs: Named arguments to pass to boto3 RDS client function ``create_db_instance``
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_instance
- :param aws_conn_id: The Airflow connection used for AWS credentials.
:param wait_for_completion: If True, waits for creation of the DB instance to complete. (default: True)
:param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB instance state
:param waiter_max_attempts: The maximum number of attempts to check DB instance state
@@ -563,14 +564,13 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
db_instance_class: str,
engine: str,
rds_kwargs: dict | None = None,
- aws_conn_id: str = "aws_default",
wait_for_completion: bool = True,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
waiter_delay: int = 30,
waiter_max_attempts: int = 60,
**kwargs,
):
- super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ super().__init__(**kwargs)
self.db_instance_identifier = db_instance_identifier
self.db_instance_class = db_instance_class
@@ -580,7 +580,6 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
self.deferrable = deferrable
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
- self.aws_conn_id = aws_conn_id
def execute(self, context: Context) -> str:
self.log.info("Creating new DB instance %s",
self.db_instance_identifier)
@@ -593,15 +592,15 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
)
if self.deferrable:
self.defer(
- trigger=RdsDbInstanceTrigger(
- db_instance_identifier=self.db_instance_identifier,
+ trigger=RdsDbAvailableTrigger(
+ db_identifier=self.db_instance_identifier,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
- waiter_name="db_instance_available",
# ignoring type because create_db_instance is a dict
response=create_db_instance, # type: ignore[arg-type]
+ db_type=RdsDbType.INSTANCE,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.waiter_delay * self.waiter_max_attempts),
@@ -638,7 +637,6 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
:param db_instance_identifier: The DB instance identifier for the DB instance to be deleted
:param rds_kwargs: Named arguments to pass to boto3 RDS client function ``delete_db_instance``
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.delete_db_instance
- :param aws_conn_id: The Airflow connection used for AWS credentials.
:param wait_for_completion: If True, waits for deletion of the DB instance to complete. (default: True)
:param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB instance state
:param waiter_max_attempts: The maximum number of attempts to check DB instance state
@@ -654,21 +652,19 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
*,
db_instance_identifier: str,
rds_kwargs: dict | None = None,
- aws_conn_id: str = "aws_default",
wait_for_completion: bool = True,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
waiter_delay: int = 30,
waiter_max_attempts: int = 60,
**kwargs,
):
- super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ super().__init__(**kwargs)
self.db_instance_identifier = db_instance_identifier
self.rds_kwargs = rds_kwargs or {}
self.wait_for_completion = False if deferrable else wait_for_completion
self.deferrable = deferrable
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
- self.aws_conn_id = aws_conn_id
def execute(self, context: Context) -> str:
self.log.info("Deleting DB instance %s", self.db_instance_identifier)
@@ -679,15 +675,15 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
)
if self.deferrable:
self.defer(
- trigger=RdsDbInstanceTrigger(
- db_instance_identifier=self.db_instance_identifier,
+ trigger=RdsDbDeletedTrigger(
+ db_identifier=self.db_instance_identifier,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
- waiter_name="db_instance_deleted",
# ignoring type because delete_db_instance is a dict
response=delete_db_instance, # type: ignore[arg-type]
+ db_type=RdsDbType.INSTANCE,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.waiter_delay * self.waiter_max_attempts),
@@ -723,8 +719,11 @@ class RdsStartDbOperator(RdsBaseOperator):
:param db_identifier: The AWS identifier of the DB to start
:param db_type: Type of the DB - either "instance" or "cluster" (default: "instance")
- :param aws_conn_id: The Airflow connection used for AWS credentials. (default: "aws_default")
:param wait_for_completion: If True, waits for DB to start. (default: True)
+ :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB instance state
+ :param waiter_max_attempts: The maximum number of attempts to check DB instance state
+ :param deferrable: If True, the operator will wait asynchronously for the DB instance to be started.
+ This implies waiting for completion. This mode requires aiobotocore module to be installed.
"""
template_fields = ("db_identifier", "db_type")
@@ -734,26 +733,52 @@ class RdsStartDbOperator(RdsBaseOperator):
*,
db_identifier: str,
db_type: RdsDbType | str = RdsDbType.INSTANCE,
- aws_conn_id: str = "aws_default",
wait_for_completion: bool = True,
+ waiter_delay: int = 30,
+ waiter_max_attempts: int = 40,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
- super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ super().__init__(**kwargs)
self.db_identifier = db_identifier
self.db_type = db_type
self.wait_for_completion = wait_for_completion
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+ self.deferrable = deferrable
def execute(self, context: Context) -> str:
self.db_type = RdsDbType(self.db_type)
- start_db_response = self._start_db()
- if self.wait_for_completion:
+ start_db_response: dict[str, Any] = self._start_db()
+ if self.deferrable:
+ self.defer(
+ trigger=RdsDbAvailableTrigger(
+ db_identifier=self.db_identifier,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ response=start_db_response,
+ db_type=RdsDbType.INSTANCE,
+ ),
+ method_name="execute_complete",
+ )
+ elif self.wait_for_completion:
self._wait_until_db_available()
return json.dumps(start_db_response, default=str)
+ def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+ if event is None or event["status"] != "success":
+ raise AirflowException(f"Failed to start DB: {event}")
+ else:
+ return json.dumps(event["response"], default=str)
+
def _start_db(self):
self.log.info("Starting DB %s '%s'", self.db_type.value,
self.db_identifier)
if self.db_type == RdsDbType.INSTANCE:
- response = self.hook.conn.start_db_instance(DBInstanceIdentifier=self.db_identifier)
+ response = self.hook.conn.start_db_instance(
+ DBInstanceIdentifier=self.db_identifier,
+ )
else:
response = self.hook.conn.start_db_cluster(DBClusterIdentifier=self.db_identifier)
return response
@@ -761,9 +786,19 @@ class RdsStartDbOperator(RdsBaseOperator):
def _wait_until_db_available(self):
self.log.info("Waiting for DB %s to reach 'available' state",
self.db_type.value)
if self.db_type == RdsDbType.INSTANCE:
- self.hook.wait_for_db_instance_state(self.db_identifier, target_state="available")
+ self.hook.wait_for_db_instance_state(
+ self.db_identifier,
+ target_state="available",
+ check_interval=self.waiter_delay,
+ max_attempts=self.waiter_max_attempts,
+ )
else:
- self.hook.wait_for_db_cluster_state(self.db_identifier, target_state="available")
+ self.hook.wait_for_db_cluster_state(
+ self.db_identifier,
+ target_state="available",
+ check_interval=self.waiter_delay,
+ max_attempts=self.waiter_max_attempts,
+ )
class RdsStopDbOperator(RdsBaseOperator):
@@ -779,8 +814,11 @@ class RdsStopDbOperator(RdsBaseOperator):
:param db_snapshot_identifier: The instance identifier of the DB Snapshot to create before
stopping the DB instance. The default value (None) skips snapshot creation. This
parameter is ignored when ``db_type`` is "cluster"
- :param aws_conn_id: The Airflow connection used for AWS credentials. (default: "aws_default")
:param wait_for_completion: If True, waits for DB to stop. (default: True)
+ :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB instance state
+ :param waiter_max_attempts: The maximum number of attempts to check DB instance state
+ :param deferrable: If True, the operator will wait asynchronously for the DB instance to be stopped.
+ This implies waiting for completion. This mode requires aiobotocore module to be installed.
"""
template_fields = ("db_identifier", "db_snapshot_identifier", "db_type")
@@ -791,23 +829,47 @@ class RdsStopDbOperator(RdsBaseOperator):
db_identifier: str,
db_type: RdsDbType | str = RdsDbType.INSTANCE,
db_snapshot_identifier: str | None = None,
- aws_conn_id: str = "aws_default",
wait_for_completion: bool = True,
+ waiter_delay: int = 30,
+ waiter_max_attempts: int = 40,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
- super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ super().__init__(**kwargs)
self.db_identifier = db_identifier
self.db_type = db_type
self.db_snapshot_identifier = db_snapshot_identifier
self.wait_for_completion = wait_for_completion
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+ self.deferrable = deferrable
def execute(self, context: Context) -> str:
self.db_type = RdsDbType(self.db_type)
- stop_db_response = self._stop_db()
- if self.wait_for_completion:
+ stop_db_response: dict[str, Any] = self._stop_db()
+ if self.deferrable:
+ self.defer(
+ trigger=RdsDbStoppedTrigger(
+ db_identifier=self.db_identifier,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ response=stop_db_response,
+ db_type=RdsDbType.INSTANCE,
+ ),
+ method_name="execute_complete",
+ )
+ elif self.wait_for_completion:
self._wait_until_db_stopped()
return json.dumps(stop_db_response, default=str)
+ def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+ if event is None or event["status"] != "success":
+ raise AirflowException(f"Failed to stop DB: {event}")
+ else:
+ return json.dumps(event["response"], default=str)
+
def _stop_db(self):
self.log.info("Stopping DB %s '%s'", self.db_type.value,
self.db_identifier)
if self.db_type == RdsDbType.INSTANCE:
@@ -829,9 +891,19 @@ class RdsStopDbOperator(RdsBaseOperator):
def _wait_until_db_stopped(self):
self.log.info("Waiting for DB %s to reach 'stopped' state",
self.db_type.value)
if self.db_type == RdsDbType.INSTANCE:
- self.hook.wait_for_db_instance_state(self.db_identifier, target_state="stopped")
+ self.hook.wait_for_db_instance_state(
+ self.db_identifier,
+ target_state="stopped",
+ check_interval=self.waiter_delay,
+ max_attempts=self.waiter_max_attempts,
+ )
else:
- self.hook.wait_for_db_cluster_state(self.db_identifier, target_state="stopped")
+ self.hook.wait_for_db_cluster_state(
+ self.db_identifier,
+ target_state="stopped",
+ check_interval=self.waiter_delay,
+ max_attempts=self.waiter_max_attempts,
+ )
__all__ = [
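The operator changes above all follow the same deferral handshake: execute() issues the boto3 call, then either defers to a trigger or falls back to the synchronous waiter. A condensed sketch of that flow, assuming the trigger yields the {"status": "success", "response": ...} event shown in this diff:

    # Sketch only; mirrors RdsStartDbOperator above, not a new implementation.
    def execute(self, context):
        response = self._start_db()              # boto3 call returns immediately
        if self.deferrable:
            # Suspend the task; the triggerer polls asynchronously and then
            # invokes execute_complete() with the TriggerEvent payload.
            self.defer(trigger=..., method_name="execute_complete")
        elif self.wait_for_completion:
            self._wait_until_db_available()      # synchronous boto3 waiter
        return json.dumps(response, default=str)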
diff --git a/airflow/providers/amazon/aws/triggers/rds.py b/airflow/providers/amazon/aws/triggers/rds.py
index 0551d67591..b8f20b9c9b 100644
--- a/airflow/providers/amazon/aws/triggers/rds.py
+++ b/airflow/providers/amazon/aws/triggers/rds.py
@@ -16,19 +16,21 @@
# under the License.
from __future__ import annotations
+import warnings
from typing import Any
+from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.hooks.rds import RdsHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+from airflow.providers.amazon.aws.utils.rds import RdsDbType
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent
class RdsDbInstanceTrigger(BaseTrigger):
"""
- Trigger for RdsCreateDbInstanceOperator and RdsDeleteDbInstanceOperator.
-
- The trigger will asynchronously poll the boto3 API and wait for the
- DB instance to be in the state specified by the waiter.
+ Deprecated Trigger for RDS operations. Do not use.
:param waiter_name: Name of the waiter to use, for instance 'db_instance_available'
or 'db_instance_deleted'.
@@ -36,7 +38,7 @@ class RdsDbInstanceTrigger(BaseTrigger):
:param waiter_delay: The amount of time in seconds to wait between attempts.
:param waiter_max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
- :param hook_params: The parameters to pass to the RdsHook.
+ :param region_name: AWS region where the DB is located, if different from the default one.
:param response: The response from the RdsHook, to be passed back to the operator.
"""
@@ -50,6 +52,12 @@ class RdsDbInstanceTrigger(BaseTrigger):
region_name: str | None,
response: dict[str, Any],
):
+ warnings.warn(
+ "This trigger is deprecated, please use the other RDS triggers "
+ "such as RdsDbDeletedTrigger, RdsDbStoppedTrigger or
RdsDbAvailableTrigger",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
self.db_instance_identifier = db_instance_identifier
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
@@ -87,3 +95,151 @@ class RdsDbInstanceTrigger(BaseTrigger):
status_args=["DBInstances[0].DBInstanceStatus"],
)
yield TriggerEvent({"status": "success", "response": self.response})
+
+
+_waiter_arg = {
+ RdsDbType.INSTANCE: "DBInstanceIdentifier",
+ RdsDbType.CLUSTER: "DBClusterIdentifier",
+}
+_status_paths = {
+ RdsDbType.INSTANCE: ["DBInstances[].DBInstanceStatus", "DBInstances[].StatusInfos"],
+ RdsDbType.CLUSTER: ["DBClusters[].Status"],
+}
+
+
+class RdsDbAvailableTrigger(AwsBaseWaiterTrigger):
+ """
+ Trigger to wait asynchronously for a DB instance or cluster to be available.
+
+ :param db_identifier: The DB identifier for the DB instance or cluster to be polled.
+ :param waiter_delay: The amount of time in seconds to wait between attempts.
+ :param waiter_max_attempts: The maximum number of attempts to be made.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ :param region_name: AWS region where the DB is located, if different from the default one.
+ :param response: The response from the RdsHook, to be passed back to the operator.
+ :param db_type: The type of DB: instance or cluster.
+ """
+
+ def __init__(
+ self,
+ db_identifier: str,
+ waiter_delay: int,
+ waiter_max_attempts: int,
+ aws_conn_id: str,
+ region_name: str | None,
+ response: dict[str, Any],
+ db_type: RdsDbType,
+ ) -> None:
+ super().__init__(
+ serialized_fields={
+ "db_identifier": db_identifier,
+ "response": response,
+ "db_type": db_type,
+ },
+ waiter_name=f"db_{db_type.value}_available",
+ waiter_args={_waiter_arg[db_type]: db_identifier},
+ failure_message="Error while waiting for DB to be available",
+ status_message="DB initialization in progress",
+ status_queries=_status_paths[db_type],
+ return_key="response",
+ return_value=response,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ aws_conn_id=aws_conn_id,
+ region_name=region_name,
+ )
+
+ def hook(self) -> AwsGenericHook:
+ return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+
+
+class RdsDbDeletedTrigger(AwsBaseWaiterTrigger):
+ """
+ Trigger to wait asynchronously for a DB instance or cluster to be deleted.
+
+ :param db_identifier: The DB identifier for the DB instance or cluster to be polled.
+ :param waiter_delay: The amount of time in seconds to wait between attempts.
+ :param waiter_max_attempts: The maximum number of attempts to be made.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ :param region_name: AWS region where the DB is located, if different from the default one.
+ :param response: The response from the RdsHook, to be passed back to the operator.
+ :param db_type: The type of DB: instance or cluster.
+ """
+
+ def __init__(
+ self,
+ db_identifier: str,
+ waiter_delay: int,
+ waiter_max_attempts: int,
+ aws_conn_id: str,
+ region_name: str | None,
+ response: dict[str, Any],
+ db_type: RdsDbType,
+ ) -> None:
+ super().__init__(
+ serialized_fields={
+ "db_identifier": db_identifier,
+ "response": response,
+ "db_type": db_type,
+ },
+ waiter_name=f"db_{db_type.value}_deleted",
+ waiter_args={_waiter_arg[db_type]: db_identifier},
+ failure_message="Error while deleting DB",
+ status_message="DB deletion in progress",
+ status_queries=_status_paths[db_type],
+ return_key="response",
+ return_value=response,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ aws_conn_id=aws_conn_id,
+ region_name=region_name,
+ )
+
+ def hook(self) -> AwsGenericHook:
+ return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+
+
+class RdsDbStoppedTrigger(AwsBaseWaiterTrigger):
+ """
+ Trigger to wait asynchronously for a DB instance or cluster to be stopped.
+
+ :param db_identifier: The DB identifier for the DB instance or cluster to be polled.
+ :param waiter_delay: The amount of time in seconds to wait between attempts.
+ :param waiter_max_attempts: The maximum number of attempts to be made.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ :param region_name: AWS region where the DB is located, if different from the default one.
+ :param response: The response from the RdsHook, to be passed back to the operator.
+ :param db_type: The type of DB: instance or cluster.
+ """
+
+ def __init__(
+ self,
+ db_identifier: str,
+ waiter_delay: int,
+ waiter_max_attempts: int,
+ aws_conn_id: str,
+ region_name: str | None,
+ response: dict[str, Any],
+ db_type: RdsDbType,
+ ) -> None:
+ super().__init__(
+ serialized_fields={
+ "db_identifier": db_identifier,
+ "response": response,
+ "db_type": db_type,
+ },
+ waiter_name=f"db_{db_type.value}_stopped",
+ waiter_args={_waiter_arg[db_type]: db_identifier},
+ failure_message="Error while stopping DB",
+ status_message="DB is being stopped",
+ status_queries=_status_paths[db_type],
+ return_key="response",
+ return_value=response,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ aws_conn_id=aws_conn_id,
+ region_name=region_name,
+ )
+
+ def hook(self) -> AwsGenericHook:
+ return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
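All three new triggers are thin AwsBaseWaiterTrigger subclasses: they pick a waiter name from the db_type and leave polling, retries and event emission to the base class. Assuming the base class rebuilds triggers from serialized_fields plus its own standard fields (the new test at the end of this diff exercises exactly that), the round trip looks roughly like:

    trigger = RdsDbStoppedTrigger(
        db_identifier="my-db",           # hypothetical identifier
        waiter_delay=30,
        waiter_max_attempts=40,
        aws_conn_id="aws_default",
        region_name=None,
        response={},
        db_type=RdsDbType.INSTANCE,
    )
    class_path, kwargs = trigger.serialize()
    rebuilt = RdsDbStoppedTrigger(**kwargs)  # what the triggerer process does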
diff --git a/airflow/providers/amazon/aws/utils/rds.py b/airflow/providers/amazon/aws/utils/rds.py
index 1ba511c4b4..eacbf0b63a 100644
--- a/airflow/providers/amazon/aws/utils/rds.py
+++ b/airflow/providers/amazon/aws/utils/rds.py
@@ -22,5 +22,5 @@ from enum import Enum
class RdsDbType(Enum):
"""Only available types for the RDS."""
- INSTANCE: str = "instance"
- CLUSTER: str = "cluster"
+ INSTANCE = "instance"
+ CLUSTER = "cluster"
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 0975656c75..ad3b7cd8f8 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1208,6 +1208,7 @@ RBAC
rbac
rc
rdbms
+RDS
readfp
Readme
readme
diff --git a/tests/providers/amazon/aws/operators/test_rds.py b/tests/providers/amazon/aws/operators/test_rds.py
index d2408fa47d..47122d55c0 100644
--- a/tests/providers/amazon/aws/operators/test_rds.py
+++ b/tests/providers/amazon/aws/operators/test_rds.py
@@ -18,11 +18,13 @@
from __future__ import annotations
import logging
+from unittest import mock
from unittest.mock import patch
import pytest
from moto import mock_rds
+from airflow.exceptions import TaskDeferred
from airflow.models import DAG
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.hooks.rds import RdsHook
@@ -40,6 +42,7 @@ from airflow.providers.amazon.aws.operators.rds import (
RdsStartExportTaskOperator,
RdsStopDbOperator,
)
+from airflow.providers.amazon.aws.triggers.rds import RdsDbAvailableTrigger, RdsDbStoppedTrigger
from airflow.utils import timezone
DEFAULT_DATE = timezone.datetime(2019, 1, 1)
@@ -832,6 +835,19 @@ class TestRdsStopDbOperator:
assert status == "stopped"
mock_await_status.assert_not_called()
+ @mock.patch.object(RdsHook, "conn")
+ def test_deferred(self, conn_mock):
+ op = RdsStopDbOperator(
+ task_id="test_stop_db_instance_no_wait",
+ db_identifier=DB_INSTANCE_NAME,
+ deferrable=True,
+ )
+
+ with pytest.raises(TaskDeferred) as defer:
+ op.execute({})
+
+ assert isinstance(defer.value.trigger, RdsDbStoppedTrigger)
+
@mock_rds
def test_stop_db_instance_create_snapshot(self):
_create_db_instance(self.hook)
@@ -932,3 +948,16 @@ class TestRdsStartDbOperator:
result_after = self.hook.conn.describe_db_clusters(DBClusterIdentifier=DB_CLUSTER_NAME)
status_after = result_after["DBClusters"][0]["Status"]
assert status_after == "available"
+
+ @mock.patch.object(RdsHook, "conn")
+ def test_deferred(self, conn_mock):
+ op = RdsStartDbOperator(
+ task_id="test_stop_db_instance_no_wait",
+ db_identifier=DB_INSTANCE_NAME,
+ deferrable=True,
+ )
+
+ with pytest.raises(TaskDeferred) as defer:
+ op.execute({})
+
+ assert isinstance(defer.value.trigger, RdsDbAvailableTrigger)
diff --git a/tests/providers/amazon/aws/triggers/test_rds.py b/tests/providers/amazon/aws/triggers/test_rds.py
index 9c518c8eee..57db41a5e0 100644
--- a/tests/providers/amazon/aws/triggers/test_rds.py
+++ b/tests/providers/amazon/aws/triggers/test_rds.py
@@ -16,16 +16,14 @@
# under the License.
from __future__ import annotations
-from unittest import mock
-from unittest.mock import AsyncMock
-
import pytest
-from botocore.exceptions import WaiterError
-from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.rds import RdsHook
-from airflow.providers.amazon.aws.triggers.rds import RdsDbInstanceTrigger
-from airflow.triggers.base import TriggerEvent
+from airflow.providers.amazon.aws.triggers.rds import (
+ RdsDbAvailableTrigger,
+ RdsDbDeletedTrigger,
+ RdsDbStoppedTrigger,
+)
+from airflow.providers.amazon.aws.utils.rds import RdsDbType
TEST_DB_INSTANCE_IDENTIFIER = "test-db-instance-identifier"
TEST_WAITER_DELAY = 10
@@ -40,146 +38,47 @@ TEST_RESPONSE = {
}
-class TestRdsDbInstanceTrigger:
- def test_rds_db_instance_trigger_serialize(self):
- rds_db_instance_trigger = RdsDbInstanceTrigger(
- waiter_name="test-waiter",
- db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
- aws_conn_id=TEST_AWS_CONN_ID,
- region_name=TEST_REGION,
- response=TEST_RESPONSE,
- )
- class_path, args = rds_db_instance_trigger.serialize()
-
- assert class_path == "airflow.providers.amazon.aws.triggers.rds.RdsDbInstanceTrigger"
- assert args["waiter_name"] == "test-waiter"
- assert args["db_instance_identifier"] == TEST_DB_INSTANCE_IDENTIFIER
- assert args["waiter_delay"] == str(TEST_WAITER_DELAY)
- assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS)
- assert args["aws_conn_id"] == TEST_AWS_CONN_ID
- assert args["region_name"] == TEST_REGION
- assert args["response"] == TEST_RESPONSE
-
- @pytest.mark.asyncio
- @mock.patch.object(RdsHook, "async_conn")
- async def test_rds_db_instance_trigger_run(self, mock_async_conn):
- a_mock = mock.MagicMock()
- mock_async_conn.__aenter__.return_value = a_mock
-
- a_mock.get_waiter().wait = AsyncMock()
-
- rds_db_instance_trigger = RdsDbInstanceTrigger(
- waiter_name="test-waiter",
- db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
- aws_conn_id=TEST_AWS_CONN_ID,
- region_name=TEST_REGION,
- response=TEST_RESPONSE,
- )
-
- generator = rds_db_instance_trigger.run()
- response = await generator.asend(None)
-
- assert response == TriggerEvent({"status": "success", "response": TEST_RESPONSE})
-
- @pytest.mark.asyncio
- @mock.patch("asyncio.sleep")
- @mock.patch.object(RdsHook, "async_conn")
- async def test_rds_db_instance_trigger_run_multiple_attempts(self, mock_async_conn, mock_sleep):
- mock_sleep.return_value = True
- a_mock = mock.MagicMock()
- mock_async_conn.__aenter__.return_value = a_mock
- error = WaiterError(
- name="test_name",
- reason="test_reason",
- last_response={"DBInstances": [{"DBInstanceStatus": "CREATING"}]},
- )
- a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, error, True])
-
- rds_db_instance_trigger = RdsDbInstanceTrigger(
- waiter_name="test-waiter",
- db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
- aws_conn_id=TEST_AWS_CONN_ID,
- region_name=TEST_REGION,
- response=TEST_RESPONSE,
- )
-
- generator = rds_db_instance_trigger.run()
- response = await generator.asend(None)
- assert a_mock.get_waiter().wait.call_count == 4
-
- assert response == TriggerEvent({"status": "success", "response": TEST_RESPONSE})
-
- @pytest.mark.asyncio
- @mock.patch("asyncio.sleep")
- @mock.patch.object(RdsHook, "async_conn")
- async def test_rds_db_instance_trigger_run_attempts_exceeded(self, mock_async_conn, mock_sleep):
- mock_sleep.return_value = True
-
- a_mock = mock.MagicMock()
- mock_async_conn.__aenter__.return_value = a_mock
- error = WaiterError(
- name="test_name",
- reason="test_reason",
- last_response={"DBInstances": [{"DBInstanceStatus": "CREATING"}]},
- )
- a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, error, True])
-
- rds_db_instance_trigger = RdsDbInstanceTrigger(
- waiter_name="test-waiter",
- db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=2,
- aws_conn_id=TEST_AWS_CONN_ID,
- region_name=TEST_REGION,
- response=TEST_RESPONSE,
- )
-
- with pytest.raises(AirflowException) as exc:
- generator = rds_db_instance_trigger.run()
- await generator.asend(None)
-
- assert "Waiter error: max attempts reached" in str(exc.value)
- assert a_mock.get_waiter().wait.call_count == 2
-
- @pytest.mark.asyncio
- @mock.patch("asyncio.sleep")
- @mock.patch.object(RdsHook, "async_conn")
- async def test_rds_db_instance_trigger_run_attempts_failed(self, mock_async_conn, mock_sleep):
- a_mock = mock.MagicMock()
- mock_async_conn.__aenter__.return_value = a_mock
-
- error_creating = WaiterError(
- name="test_name",
- reason="test_reason",
- last_response={"DBInstances": [{"DBInstanceStatus": "CREATING"}]},
- )
-
- error_failed = WaiterError(
- name="test_name",
- reason="Waiter encountered a terminal failure state:",
- last_response={"DBInstances": [{"DBInstanceStatus": "FAILED"}]},
- )
- a_mock.get_waiter().wait = AsyncMock(side_effect=[error_creating, error_creating, error_failed])
- mock_sleep.return_value = True
-
- rds_db_instance_trigger = RdsDbInstanceTrigger(
- waiter_name="test-waiter",
- db_instance_identifier=TEST_DB_INSTANCE_IDENTIFIER,
- waiter_delay=TEST_WAITER_DELAY,
- waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
- aws_conn_id=TEST_AWS_CONN_ID,
- region_name=TEST_REGION,
- response=TEST_RESPONSE,
- )
-
- with pytest.raises(AirflowException) as exc:
- generator = rds_db_instance_trigger.run()
- await generator.asend(None)
- assert "Error checking DB Instance status" in str(exc.value)
- assert a_mock.get_waiter().wait.call_count == 3
+class TestRdsTriggers:
+ @pytest.mark.parametrize(
+ "trigger",
+ [
+ RdsDbAvailableTrigger(
+ db_identifier=TEST_DB_INSTANCE_IDENTIFIER,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ region_name=TEST_REGION,
+ response=TEST_RESPONSE,
+ db_type=RdsDbType.INSTANCE,
+ ),
+ RdsDbDeletedTrigger(
+ db_identifier=TEST_DB_INSTANCE_IDENTIFIER,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ region_name=TEST_REGION,
+ response=TEST_RESPONSE,
+ db_type=RdsDbType.INSTANCE,
+ ),
+ RdsDbStoppedTrigger(
+ db_identifier=TEST_DB_INSTANCE_IDENTIFIER,
+ waiter_delay=TEST_WAITER_DELAY,
+ waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ region_name=TEST_REGION,
+ response=TEST_RESPONSE,
+ db_type=RdsDbType.INSTANCE,
+ ),
+ ],
+ )
+ def test_serialize_recreate(self, trigger):
+ class_path, args = trigger.serialize()
+
+ class_name = class_path.split(".")[-1]
+ clazz = globals()[class_name]
+ instance = clazz(**args)
+
+ class_path2, args2 = instance.serialize()
+
+ assert class_path == class_path2
+ assert args == args2