This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2bba98f109 Use Boto waiters instead of custom _await_status method
for RDS Operators (#27410)
2bba98f109 is described below
commit 2bba98f109cc7737f4293a195e03a0cc21a624cb
Author: Hank Ehly <[email protected]>
AuthorDate: Fri Nov 18 02:02:21 2022 +0900
Use Boto waiters instead of custom _await_status method for RDS Operators
(#27410)
---
airflow/providers/amazon/aws/hooks/rds.py | 301 +++++++++++++++++++++-
airflow/providers/amazon/aws/operators/rds.py | 155 +++--------
airflow/providers/amazon/aws/sensors/rds.py | 76 ++----
tests/providers/amazon/aws/hooks/test_rds.py | 312 +++++++++++++++++++++++
tests/providers/amazon/aws/operators/test_rds.py | 135 ++++++----
tests/providers/amazon/aws/sensors/test_rds.py | 19 +-
6 files changed, 755 insertions(+), 243 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/rds.py
b/airflow/providers/amazon/aws/hooks/rds.py
index cd2434282c..15790e3add 100644
--- a/airflow/providers/amazon/aws/hooks/rds.py
+++ b/airflow/providers/amazon/aws/hooks/rds.py
@@ -18,8 +18,10 @@
"""Interact with AWS RDS."""
from __future__ import annotations
-from typing import TYPE_CHECKING
+import time
+from typing import TYPE_CHECKING, Callable
+from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
if TYPE_CHECKING:
@@ -48,3 +50,300 @@ class RdsHook(AwsGenericHook["RDSClient"]):
def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = "rds"
super().__init__(*args, **kwargs)
+
+ def get_db_snapshot_state(self, snapshot_id: str) -> str:
+ """
+ Get the current state of a DB instance snapshot.
+
+ :param snapshot_id: The ID of the target DB instance snapshot
+ :return: Returns the status of the DB snapshot as a string (eg.
"available")
+ :rtype: str
+ :raises AirflowNotFoundException: If the DB instance snapshot does not
exist.
+ """
+ try:
+ response =
self.conn.describe_db_snapshots(DBSnapshotIdentifier=snapshot_id)
+ except self.conn.exceptions.ClientError as e:
+ if e.response["Error"]["Code"] == "DBSnapshotNotFound":
+ raise AirflowNotFoundException(e)
+ raise e
+ return response["DBSnapshots"][0]["Status"].lower()
+
+ def wait_for_db_snapshot_state(
+ self, snapshot_id: str, target_state: str, check_interval: int = 30,
max_attempts: int = 40
+ ) -> None:
+ """
+ Polls :py:meth:`RDS.Client.describe_db_snapshots` until the target
state is reached.
+ An error is raised after a max number of attempts.
+
+ :param snapshot_id: The ID of the target DB instance snapshot
+ :param target_state: Wait until this state is reached
+ :param check_interval: The amount of time in seconds to wait between
attempts
+ :param max_attempts: The maximum number of attempts to be made
+ """
+
+ def poke():
+ return self.get_db_snapshot_state(snapshot_id)
+
+ target_state = target_state.lower()
+ if target_state in ("available", "deleted", "completed"):
+ waiter = self.conn.get_waiter(f"db_snapshot_{target_state}") #
type: ignore
+ waiter.wait(
+ DBSnapshotIdentifier=snapshot_id,
+ WaiterConfig={"Delay": check_interval, "MaxAttempts":
max_attempts},
+ )
+ else:
+ self._wait_for_state(poke, target_state, check_interval,
max_attempts)
+ self.log.info("DB snapshot '%s' reached the '%s' state",
snapshot_id, target_state)
+
+ def get_db_cluster_snapshot_state(self, snapshot_id: str) -> str:
+ """
+ Get the current state of a DB cluster snapshot.
+
+ :param snapshot_id: The ID of the target DB cluster.
+ :return: Returns the status of the DB cluster snapshot as a string
(eg. "available")
+ :rtype: str
+ :raises AirflowNotFoundException: If the DB cluster snapshot does not
exist.
+ """
+ try:
+ response =
self.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=snapshot_id)
+ except self.conn.exceptions.ClientError as e:
+ if e.response["Error"]["Code"] == "DBClusterSnapshotNotFoundFault":
+ raise AirflowNotFoundException(e)
+ raise e
+ return response["DBClusterSnapshots"][0]["Status"].lower()
+
+ def wait_for_db_cluster_snapshot_state(
+ self, snapshot_id: str, target_state: str, check_interval: int = 30,
max_attempts: int = 40
+ ) -> None:
+ """
+ Polls :py:meth:`RDS.Client.describe_db_cluster_snapshots` until the
target state is reached.
+ An error is raised after a max number of attempts.
+
+ :param snapshot_id: The ID of the target DB cluster snapshot
+ :param target_state: Wait until this state is reached
+ :param check_interval: The amount of time in seconds to wait between
attempts
+ :param max_attempts: The maximum number of attempts to be made
+
+ .. seealso::
+ A list of possible values for target_state:
+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_db_cluster_snapshots
+ """
+
+ def poke():
+ return self.get_db_cluster_snapshot_state(snapshot_id)
+
+ target_state = target_state.lower()
+ if target_state in ("available", "deleted"):
+ waiter =
self.conn.get_waiter(f"db_cluster_snapshot_{target_state}") # type: ignore
+ waiter.wait(
+ DBClusterSnapshotIdentifier=snapshot_id,
+ WaiterConfig={"Delay": check_interval, "MaxAttempts":
max_attempts},
+ )
+ else:
+ self._wait_for_state(poke, target_state, check_interval,
max_attempts)
+ self.log.info("DB cluster snapshot '%s' reached the '%s' state",
snapshot_id, target_state)
+
+ def get_export_task_state(self, export_task_id: str) -> str:
+ """
+ Gets the current state of an RDS snapshot export to Amazon S3.
+
+ :param export_task_id: The identifier of the target snapshot export
task.
+ :return: Returns the status of the snapshot export task as a string
(eg. "canceled")
+ :rtype: str
+ :raises AirflowNotFoundException: If the export task does not exist.
+ """
+ try:
+ response =
self.conn.describe_export_tasks(ExportTaskIdentifier=export_task_id)
+ except self.conn.exceptions.ClientError as e:
+ if e.response["Error"]["Code"] == "ExportTaskNotFoundFault":
+ raise AirflowNotFoundException(e)
+ raise e
+ return response["ExportTasks"][0]["Status"].lower()
+
+ def wait_for_export_task_state(
+ self, export_task_id: str, target_state: str, check_interval: int =
30, max_attempts: int = 40
+ ) -> None:
+ """
+ Polls :py:meth:`RDS.Client.describe_export_tasks` until the target
state is reached.
+ An error is raised after a max number of attempts.
+
+ :param export_task_id: The identifier of the target snapshot export
task.
+ :param target_state: Wait until this state is reached
+ :param check_interval: The amount of time in seconds to wait between
attempts
+ :param max_attempts: The maximum number of attempts to be made
+
+ .. seealso::
+ A list of possible values for target_state:
+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_export_tasks
+ """
+
+ def poke():
+ return self.get_export_task_state(export_task_id)
+
+ target_state = target_state.lower()
+ self._wait_for_state(poke, target_state, check_interval, max_attempts)
+ self.log.info("export task '%s' reached the '%s' state",
export_task_id, target_state)
+
+ def get_event_subscription_state(self, subscription_name: str) -> str:
+ """
+        Gets the current state of an RDS event notification subscription.
+
+ :param subscription_name: The name of the target RDS event
notification subscription.
+ :return: Returns the status of the event subscription as a string (eg.
"active")
+ :rtype: str
+ :raises AirflowNotFoundException: If the event subscription does not
exist.
+ """
+ try:
+ response =
self.conn.describe_event_subscriptions(SubscriptionName=subscription_name)
+ except self.conn.exceptions.ClientError as e:
+ if e.response["Error"]["Code"] == "SubscriptionNotFoundFault":
+ raise AirflowNotFoundException(e)
+ raise e
+ return response["EventSubscriptionsList"][0]["Status"].lower()
+
+ def wait_for_event_subscription_state(
+ self, subscription_name: str, target_state: str, check_interval: int =
30, max_attempts: int = 40
+ ) -> None:
+ """
+ Polls :py:meth:`RDS.Client.describe_event_subscriptions` until the
target state is reached.
+ An error is raised after a max number of attempts.
+
+ :param subscription_name: The name of the target RDS event
notification subscription.
+ :param target_state: Wait until this state is reached
+ :param check_interval: The amount of time in seconds to wait between
attempts
+ :param max_attempts: The maximum number of attempts to be made
+
+ .. seealso::
+ A list of possible values for target_state:
+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_event_subscriptions
+ """
+
+ def poke():
+ return self.get_event_subscription_state(subscription_name)
+
+ target_state = target_state.lower()
+ self._wait_for_state(poke, target_state, check_interval, max_attempts)
+ self.log.info("event subscription '%s' reached the '%s' state",
subscription_name, target_state)
+
+ def get_db_instance_state(self, db_instance_id: str) -> str:
+ """
+ Get the current state of a DB instance.
+
+        :param db_instance_id: The ID of the target DB instance.
+ :return: Returns the status of the DB instance as a string (eg.
"available")
+ :rtype: str
+ :raises AirflowNotFoundException: If the DB instance does not exist.
+ """
+ try:
+ response =
self.conn.describe_db_instances(DBInstanceIdentifier=db_instance_id)
+ except self.conn.exceptions.ClientError as e:
+ if e.response["Error"]["Code"] == "DBInstanceNotFoundFault":
+ raise AirflowNotFoundException(e)
+ raise e
+ return response["DBInstances"][0]["DBInstanceStatus"].lower()
+
+ def wait_for_db_instance_state(
+ self, db_instance_id: str, target_state: str, check_interval: int =
30, max_attempts: int = 40
+ ) -> None:
+ """
+ Polls :py:meth:`RDS.Client.describe_db_instances` until the target
state is reached.
+ An error is raised after a max number of attempts.
+
+ :param db_instance_id: The ID of the target DB instance.
+ :param target_state: Wait until this state is reached
+ :param check_interval: The amount of time in seconds to wait between
attempts
+ :param max_attempts: The maximum number of attempts to be made
+
+ .. seealso::
+ For information about DB instance statuses, see Viewing DB
instance status in the Amazon RDS
+ User Guide.
+
https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/accessing-monitoring.html#Overview.DBInstance.Status
+ """
+
+ def poke():
+ return self.get_db_instance_state(db_instance_id)
+
+ target_state = target_state.lower()
+ if target_state in ("available", "deleted"):
+ waiter = self.conn.get_waiter(f"db_instance_{target_state}") #
type: ignore
+ waiter.wait(
+ DBInstanceIdentifier=db_instance_id,
+ WaiterConfig={"Delay": check_interval, "MaxAttempts":
max_attempts},
+ )
+ else:
+ self._wait_for_state(poke, target_state, check_interval,
max_attempts)
+ self.log.info("DB cluster snapshot '%s' reached the '%s' state",
db_instance_id, target_state)
+
+ def get_db_cluster_state(self, db_cluster_id: str) -> str:
+ """
+ Get the current state of a DB cluster.
+
+        :param db_cluster_id: The ID of the target DB cluster.
+ :return: Returns the status of the DB cluster as a string (eg.
"available")
+ :rtype: str
+ :raises AirflowNotFoundException: If the DB cluster does not exist.
+ """
+ try:
+ response =
self.conn.describe_db_clusters(DBClusterIdentifier=db_cluster_id)
+ except self.conn.exceptions.ClientError as e:
+ if e.response["Error"]["Code"] == "DBClusterNotFoundFault":
+ raise AirflowNotFoundException(e)
+ raise e
+ return response["DBClusters"][0]["Status"].lower()
+
+ def wait_for_db_cluster_state(
+ self, db_cluster_id: str, target_state: str, check_interval: int = 30,
max_attempts: int = 40
+ ) -> None:
+ """
+ Polls :py:meth:`RDS.Client.describe_db_clusters` until the target
state is reached.
+ An error is raised after a max number of attempts.
+
+ :param db_cluster_id: The ID of the target DB cluster.
+ :param target_state: Wait until this state is reached
+ :param check_interval: The amount of time in seconds to wait between
attempts
+ :param max_attempts: The maximum number of attempts to be made
+
+ .. seealso::
+ For information about DB instance statuses, see Viewing DB
instance status in the Amazon RDS
+ User Guide.
+
https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/accessing-monitoring.html#Overview.DBInstance.Status
+ """
+
+ def poke():
+ return self.get_db_cluster_state(db_cluster_id)
+
+ target_state = target_state.lower()
+ if target_state in ("available", "deleted"):
+ waiter = self.conn.get_waiter(f"db_cluster_{target_state}") #
type: ignore
+ waiter.wait(
+ DBClusterIdentifier=db_cluster_id,
+ WaiterConfig={"Delay": check_interval, "MaxAttempts":
max_attempts},
+ )
+ else:
+ self._wait_for_state(poke, target_state, check_interval,
max_attempts)
+ self.log.info("DB cluster snapshot '%s' reached the '%s' state",
db_cluster_id, target_state)
+
+ def _wait_for_state(
+ self,
+ poke: Callable[..., str],
+ target_state: str,
+ check_interval: int,
+ max_attempts: int,
+ ) -> None:
+ """
+ Polls the poke function for the current state until it reaches the
target_state.
+
+ :param poke: A function that returns the current state of the target
resource as a string.
+ :param target_state: Wait until this state is reached
+ :param check_interval: The amount of time in seconds to wait between
attempts
+ :param max_attempts: The maximum number of attempts to be made
+ """
+ state = poke()
+ tries = 1
+ while state != target_state:
+ self.log.info("Current state is %s", state)
+ if tries >= max_attempts:
+ raise AirflowException("Max attempts exceeded")
+ time.sleep(check_interval)
+ state = poke()
+ tries += 1
diff --git a/airflow/providers/amazon/aws/operators/rds.py
b/airflow/providers/amazon/aws/operators/rds.py
index f9e23eaf91..c10e969c8b 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -18,12 +18,10 @@
from __future__ import annotations
import json
-import time
from typing import TYPE_CHECKING, Sequence
from mypy_boto3_rds.type_defs import TagTypeDef
-from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.rds import RdsHook
from airflow.providers.amazon.aws.utils.rds import RdsDbType
@@ -45,66 +43,6 @@ class RdsBaseOperator(BaseOperator):
self._await_interval = 60 # seconds
- def _describe_item(self, item_type: str, item_name: str) -> list:
- if item_type == "instance_snapshot":
- db_snaps =
self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=item_name)
- return db_snaps["DBSnapshots"]
- elif item_type == "cluster_snapshot":
- cl_snaps =
self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=item_name)
- return cl_snaps["DBClusterSnapshots"]
- elif item_type == "export_task":
- exports =
self.hook.conn.describe_export_tasks(ExportTaskIdentifier=item_name)
- return exports["ExportTasks"]
- elif item_type == "event_subscription":
- subscriptions =
self.hook.conn.describe_event_subscriptions(SubscriptionName=item_name)
- return subscriptions["EventSubscriptionsList"]
- elif item_type == "db_instance":
- instances =
self.hook.conn.describe_db_instances(DBInstanceIdentifier=item_name)
- return instances["DBInstances"]
- elif item_type == "db_cluster":
- clusters =
self.hook.conn.describe_db_clusters(DBClusterIdentifier=item_name)
- return clusters["DBClusters"]
- else:
- raise AirflowException(f"Method for {item_type} is not
implemented")
-
- def _await_status(
- self,
- item_type: str,
- item_name: str,
- wait_statuses: list[str] | None = None,
- ok_statuses: list[str] | None = None,
- error_statuses: list[str] | None = None,
- ) -> None:
- """
- Continuously gets item description from `_describe_item()` and waits
while:
- - status is in `wait_statuses`
- - status not in `ok_statuses` and `error_statuses`
- """
- while True:
- items = self._describe_item(item_type, item_name)
-
- if len(items) == 0:
- raise AirflowException(f"There is no {item_type} with
identifier {item_name}")
- if len(items) > 1:
- raise AirflowException(f"There are {len(items)} {item_type}
with identifier {item_name}")
-
- if item_type == "db_instance":
- status_field = "DBInstanceStatus"
- else:
- status_field = "Status"
-
- if wait_statuses and items[0][status_field].lower() in
wait_statuses:
- time.sleep(self._await_interval)
- continue
- elif ok_statuses and items[0][status_field].lower() in ok_statuses:
- break
- elif error_statuses and items[0][status_field].lower() in
error_statuses:
- raise AirflowException(f"Item has error status
({error_statuses}): {items[0]}")
- else:
- raise AirflowException(f"Item has uncertain status:
{items[0]}")
-
- return None
-
def execute(self, context: Context) -> str:
"""Different implementations for snapshots, tasks and events"""
raise NotImplementedError
@@ -166,8 +104,8 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
Tags=self.tags,
)
create_response = json.dumps(create_instance_snap, default=str)
- item_type = "instance_snapshot"
-
+ if self.wait_for_completion:
+
self.hook.wait_for_db_snapshot_state(self.db_snapshot_identifier,
target_state="available")
else:
create_cluster_snap = self.hook.conn.create_db_cluster_snapshot(
DBClusterIdentifier=self.db_identifier,
@@ -175,15 +113,10 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
Tags=self.tags,
)
create_response = json.dumps(create_cluster_snap, default=str)
- item_type = "cluster_snapshot"
-
- if self.wait_for_completion:
- self._await_status(
- item_type,
- self.db_snapshot_identifier,
- wait_statuses=["creating"],
- ok_statuses=["available"],
- )
+ if self.wait_for_completion:
+ self.hook.wait_for_db_cluster_snapshot_state(
+ self.db_snapshot_identifier, target_state="available"
+ )
return create_response
@@ -270,8 +203,10 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
SourceRegion=self.source_region,
)
copy_response = json.dumps(copy_instance_snap, default=str)
- item_type = "instance_snapshot"
-
+ if self.wait_for_completion:
+ self.hook.wait_for_db_snapshot_state(
+ self.target_db_snapshot_identifier,
target_state="available"
+ )
else:
copy_cluster_snap = self.hook.conn.copy_db_cluster_snapshot(
SourceDBClusterSnapshotIdentifier=self.source_db_snapshot_identifier,
@@ -283,15 +218,10 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
SourceRegion=self.source_region,
)
copy_response = json.dumps(copy_cluster_snap, default=str)
- item_type = "cluster_snapshot"
-
- if self.wait_for_completion:
- self._await_status(
- item_type,
- self.target_db_snapshot_identifier,
- wait_statuses=["creating", "copying"],
- ok_statuses=["available"],
- )
+ if self.wait_for_completion:
+ self.hook.wait_for_db_cluster_snapshot_state(
+ self.target_db_snapshot_identifier,
target_state="available"
+ )
return copy_response
@@ -314,6 +244,7 @@ class RdsDeleteDbSnapshotOperator(RdsBaseOperator):
*,
db_type: str,
db_snapshot_identifier: str,
+ wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
**kwargs,
):
@@ -321,6 +252,7 @@ class RdsDeleteDbSnapshotOperator(RdsBaseOperator):
self.db_type = RdsDbType(db_type)
self.db_snapshot_identifier = db_snapshot_identifier
+ self.wait_for_completion = wait_for_completion
def execute(self, context: Context) -> str:
self.log.info("Starting to delete snapshot '%s'",
self.db_snapshot_identifier)
@@ -330,11 +262,17 @@ class RdsDeleteDbSnapshotOperator(RdsBaseOperator):
DBSnapshotIdentifier=self.db_snapshot_identifier,
)
delete_response = json.dumps(delete_instance_snap, default=str)
+ if self.wait_for_completion:
+
self.hook.wait_for_db_snapshot_state(self.db_snapshot_identifier,
target_state="deleted")
else:
delete_cluster_snap = self.hook.conn.delete_db_cluster_snapshot(
DBClusterSnapshotIdentifier=self.db_snapshot_identifier,
)
delete_response = json.dumps(delete_cluster_snap, default=str)
+ if self.wait_for_completion:
+ self.hook.wait_for_db_cluster_snapshot_state(
+ self.db_snapshot_identifier, target_state="deleted"
+ )
return delete_response
@@ -406,14 +344,7 @@ class RdsStartExportTaskOperator(RdsBaseOperator):
)
if self.wait_for_completion:
- self._await_status(
- "export_task",
- self.export_task_identifier,
- wait_statuses=["starting", "in_progress"],
- ok_statuses=["complete"],
- error_statuses=["canceling", "canceled"],
- )
-
+ self.hook.wait_for_export_task_state(self.export_task_identifier,
target_state="complete")
return json.dumps(start_export, default=str)
@@ -452,13 +383,7 @@ class RdsCancelExportTaskOperator(RdsBaseOperator):
)
if self.wait_for_completion:
- self._await_status(
- "export_task",
- self.export_task_identifier,
- wait_statuses=["canceling"],
- ok_statuses=["canceled"],
- )
-
+ self.hook.wait_for_export_task_state(self.export_task_identifier,
target_state="canceled")
return json.dumps(cancel_export, default=str)
@@ -531,13 +456,7 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
)
if self.wait_for_completion:
- self._await_status(
- "event_subscription",
- self.subscription_name,
- wait_statuses=["creating"],
- ok_statuses=["active"],
- )
-
+
self.hook.wait_for_event_subscription_state(self.subscription_name,
target_state="active")
return json.dumps(create_subscription, default=str)
@@ -628,10 +547,7 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
)
if self.wait_for_completion:
- self.hook.conn.get_waiter("db_instance_available").wait(
- DBInstanceIdentifier=self.db_instance_identifier
- )
-
+ self.hook.wait_for_db_instance_state(self.db_instance_identifier,
target_state="available")
return json.dumps(create_db_instance, default=str)
@@ -675,10 +591,7 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
)
if self.wait_for_completion:
- self.hook.conn.get_waiter("db_instance_deleted").wait(
- DBInstanceIdentifier=self.db_instance_identifier
- )
-
+ self.hook.wait_for_db_instance_state(self.db_instance_identifier,
target_state="deleted")
return json.dumps(delete_db_instance, default=str)
@@ -730,9 +643,9 @@ class RdsStartDbOperator(RdsBaseOperator):
def _wait_until_db_available(self):
self.log.info("Waiting for DB %s to reach 'available' state",
self.db_type.value)
if self.db_type == RdsDbType.INSTANCE:
-
self.hook.conn.get_waiter("db_instance_available").wait(DBInstanceIdentifier=self.db_identifier)
+ self.hook.wait_for_db_instance_state(self.db_identifier,
target_state="available")
else:
-
self.hook.conn.get_waiter("db_cluster_available").wait(DBClusterIdentifier=self.db_identifier)
+ self.hook.wait_for_db_cluster_state(self.db_identifier,
target_state="available")
class RdsStopDbOperator(RdsBaseOperator):
@@ -797,16 +710,10 @@ class RdsStopDbOperator(RdsBaseOperator):
def _wait_until_db_stopped(self):
self.log.info("Waiting for DB %s to reach 'stopped' state",
self.db_type.value)
- wait_statuses = ["stopping"]
- ok_statuses = ["stopped"]
if self.db_type == RdsDbType.INSTANCE:
- self._await_status(
- "db_instance", self.db_identifier,
wait_statuses=wait_statuses, ok_statuses=ok_statuses
- )
+ self.hook.wait_for_db_instance_state(self.db_identifier,
target_state="stopped")
else:
- self._await_status(
- "db_cluster", self.db_identifier, wait_statuses=wait_statuses,
ok_statuses=ok_statuses
- )
+ self.hook.wait_for_db_cluster_state(self.db_identifier,
target_state="stopped")
__all__ = [
diff --git a/airflow/providers/amazon/aws/sensors/rds.py
b/airflow/providers/amazon/aws/sensors/rds.py
index 07933809fe..731c8b5def 100644
--- a/airflow/providers/amazon/aws/sensors/rds.py
+++ b/airflow/providers/amazon/aws/sensors/rds.py
@@ -18,9 +18,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Sequence
-from botocore.exceptions import ClientError
-
-from airflow import AirflowException
+from airflow.exceptions import AirflowNotFoundException
from airflow.providers.amazon.aws.hooks.rds import RdsHook
from airflow.providers.amazon.aws.utils.rds import RdsDbType
from airflow.sensors.base import BaseSensorOperator
@@ -41,40 +39,6 @@ class RdsBaseSensor(BaseSensorOperator):
self.target_statuses: list[str] = []
super().__init__(*args, **kwargs)
- def _describe_item(self, item_type: str, item_name: str) -> list:
- if item_type == "instance_snapshot":
- db_snaps =
self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=item_name)
- return db_snaps["DBSnapshots"]
- elif item_type == "cluster_snapshot":
- cl_snaps =
self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=item_name)
- return cl_snaps["DBClusterSnapshots"]
- elif item_type == "export_task":
- exports =
self.hook.conn.describe_export_tasks(ExportTaskIdentifier=item_name)
- return exports["ExportTasks"]
- elif item_type == "db_instance":
- instances =
self.hook.conn.describe_db_instances(DBInstanceIdentifier=item_name)
- return instances["DBInstances"]
- elif item_type == "db_cluster":
- clusters =
self.hook.conn.describe_db_clusters(DBClusterIdentifier=item_name)
- return clusters["DBClusters"]
- else:
- raise AirflowException(f"Method for {item_type} is not
implemented")
-
- def _check_item(self, item_type: str, item_name: str) -> bool:
- """Get certain item from `_describe_item()` and check its status"""
- if item_type == "db_instance":
- status_field = "DBInstanceStatus"
- else:
- status_field = "Status"
- try:
- items = self._describe_item(item_type, item_name)
- except ClientError:
- return False
- else:
- return bool(items) and any(
- map(lambda status: items[0][status_field].lower() == status,
self.target_statuses)
- )
-
class RdsSnapshotExistenceSensor(RdsBaseSensor):
"""
@@ -112,10 +76,14 @@ class RdsSnapshotExistenceSensor(RdsBaseSensor):
self.log.info(
"Poking for statuses : %s\nfor snapshot %s", self.target_statuses,
self.db_snapshot_identifier
)
- if self.db_type.value == "instance":
- return self._check_item(item_type="instance_snapshot",
item_name=self.db_snapshot_identifier)
- else:
- return self._check_item(item_type="cluster_snapshot",
item_name=self.db_snapshot_identifier)
+ try:
+ if self.db_type.value == "instance":
+ state =
self.hook.get_db_snapshot_state(self.db_snapshot_identifier)
+ else:
+ state =
self.hook.get_db_cluster_snapshot_state(self.db_snapshot_identifier)
+ except AirflowNotFoundException:
+ return False
+ return state in self.target_statuses
class RdsExportTaskExistenceSensor(RdsBaseSensor):
@@ -158,7 +126,11 @@ class RdsExportTaskExistenceSensor(RdsBaseSensor):
self.log.info(
"Poking for statuses : %s\nfor export task %s",
self.target_statuses, self.export_task_identifier
)
- return self._check_item(item_type="export_task",
item_name=self.export_task_identifier)
+ try:
+ state =
self.hook.get_export_task_state(self.export_task_identifier)
+ except AirflowNotFoundException:
+ return False
+ return state in self.target_statuses
class RdsDbSensor(RdsBaseSensor):
@@ -184,7 +156,7 @@ class RdsDbSensor(RdsBaseSensor):
self,
*,
db_identifier: str,
- db_type: str = "instance",
+ db_type: RdsDbType | str = RdsDbType.INSTANCE,
target_statuses: list[str] | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
@@ -192,19 +164,21 @@ class RdsDbSensor(RdsBaseSensor):
super().__init__(aws_conn_id=aws_conn_id, **kwargs)
self.db_identifier = db_identifier
self.target_statuses = target_statuses or ["available"]
- self.db_type = RdsDbType(db_type)
+ self.db_type = db_type
def poke(self, context: Context):
+ db_type = RdsDbType(self.db_type)
self.log.info(
"Poking for statuses : %s\nfor db instance %s",
self.target_statuses, self.db_identifier
)
- item_type = self._check_item_type()
- return self._check_item(item_type=item_type,
item_name=self.db_identifier)
-
- def _check_item_type(self):
- if self.db_type == RdsDbType.CLUSTER:
- return "db_cluster"
- return "db_instance"
+ try:
+ if db_type == RdsDbType.INSTANCE:
+ state = self.hook.get_db_instance_state(self.db_identifier)
+ else:
+ state = self.hook.get_db_cluster_state(self.db_identifier)
+ except AirflowNotFoundException:
+ return False
+ return state in self.target_statuses
__all__ = [
diff --git a/tests/providers/amazon/aws/hooks/test_rds.py
b/tests/providers/amazon/aws/hooks/test_rds.py
index b7ea30b9ad..f98320a2f1 100644
--- a/tests/providers/amazon/aws/hooks/test_rds.py
+++ b/tests/providers/amazon/aws/hooks/test_rds.py
@@ -17,10 +17,119 @@
# under the License.
from __future__ import annotations
+from unittest.mock import patch
+
+import pytest
+from moto import mock_rds
+
+from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.amazon.aws.hooks.rds import RdsHook
[email protected]
+def rds_hook() -> RdsHook:
+ """Returns an RdsHook whose underlying connection is mocked with moto"""
+ with mock_rds():
+ yield RdsHook(aws_conn_id="aws_default", region_name="us-east-1")
+
+
[email protected]
+def db_instance_id(rds_hook: RdsHook) -> str:
+ """Creates an RDS DB instance and returns its id"""
+ response = rds_hook.conn.create_db_instance(
+ DBInstanceIdentifier="testrdshook-db-instance",
+ DBInstanceClass="db.t4g.micro",
+ Engine="postgres",
+ AllocatedStorage=20,
+ MasterUsername="testrdshook",
+ MasterUserPassword="testrdshook",
+ )
+ return response["DBInstance"]["DBInstanceIdentifier"]
+
+
[email protected]
+def db_cluster_id(rds_hook: RdsHook) -> str:
+ """Creates an RDS DB cluster and returns its id"""
+ response = rds_hook.conn.create_db_cluster(
+ DBClusterIdentifier="testrdshook-db-cluster",
+ Engine="postgres",
+ MasterUsername="testrdshook",
+ MasterUserPassword="testrdshook",
+ DBClusterInstanceClass="db.t4g.micro",
+ AllocatedStorage=20,
+ )
+ return response["DBCluster"]["DBClusterIdentifier"]
+
+
[email protected]
+def db_snapshot(rds_hook: RdsHook, db_instance_id: str) -> dict:
+ """
+ Creates a mock DB instance snapshot and returns the DBSnapshot dict from
the boto response object.
+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_snapshot
+ """
+ response = rds_hook.conn.create_db_snapshot(
+ DBSnapshotIdentifier="testrdshook-db-instance-snapshot",
DBInstanceIdentifier=db_instance_id
+ )
+ return response["DBSnapshot"]
+
+
[email protected]
+def db_snapshot_id(db_snapshot: dict) -> str:
+ return db_snapshot["DBSnapshotIdentifier"]
+
+
[email protected]
+def db_snapshot_arn(db_snapshot: dict) -> str:
+ return db_snapshot["DBSnapshotArn"]
+
+
[email protected]
+def db_cluster_snapshot(rds_hook: RdsHook, db_cluster_id: str):
+ """
+ Creates a mock DB cluster snapshot and returns the DBClusterSnapshot dict
from the boto response object.
+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_cluster_snapshot
+ """
+ response = rds_hook.conn.create_db_cluster_snapshot(
+ DBClusterSnapshotIdentifier="testrdshook-db-cluster-snapshot",
DBClusterIdentifier=db_cluster_id
+ )
+ return response["DBClusterSnapshot"]
+
+
[email protected]
+def db_cluster_snapshot_id(db_cluster_snapshot) -> str:
+ return db_cluster_snapshot["DBClusterSnapshotIdentifier"]
+
+
[email protected]
+def export_task_id(rds_hook: RdsHook, db_snapshot_arn: str) -> str:
+ response = rds_hook.conn.start_export_task(
+ ExportTaskIdentifier="testrdshook-export-task",
+ SourceArn=db_snapshot_arn,
+ S3BucketName="test",
+ IamRoleArn="test",
+ KmsKeyId="test",
+ )
+ return response["ExportTaskIdentifier"]
+
+
[email protected]
+def event_subscription_name(rds_hook: RdsHook, db_instance_id: str) -> str:
+    """Creates a mock RDS event subscription and returns its name"""
+ response = rds_hook.conn.create_event_subscription(
+ SubscriptionName="testrdshook-event-subscription",
+ SnsTopicArn="test",
+ SourceType="db-instance",
+ SourceIds=[db_instance_id],
+ Enabled=True,
+ )
+ return response["EventSubscription"]["CustSubscriptionId"]
+
+
class TestRdsHook:
    """Tests for RdsHook state getters and waiters, backed by moto-mocked RDS resources."""

    # For testing, set the delay between status checks to 0 so that we aren't sleeping during tests,
    # and max_attempts to 1 so that we don't retry unless required.
    waiter_args = {"check_interval": 0, "max_attempts": 1}

    def test_conn_attribute(self):
        # The hook should expose the boto client as a cached `conn` property.
        hook = RdsHook(aws_conn_id="aws_default", region_name="us-east-1")
        assert hasattr(hook, "conn")
        # NOTE(review): a diff hunk boundary (@@ -28,3 +137,206 @@) falls here and may elide
        # intervening lines of this test — confirm against the full file.
        conn = hook.conn
        assert conn is hook.conn  # Cached property
        assert conn is hook.get_conn()  # Same object as returned by `conn` property

    def test_get_db_instance_state(self, rds_hook: RdsHook, db_instance_id: str):
        """get_db_instance_state() reports the same status as describe_db_instances."""
        response = rds_hook.conn.describe_db_instances(DBInstanceIdentifier=db_instance_id)
        state_expected = response["DBInstances"][0]["DBInstanceStatus"]
        state_actual = rds_hook.get_db_instance_state(db_instance_id)
        assert state_actual == state_expected

    def test_wait_for_db_instance_state_boto_waiters(self, rds_hook: RdsHook, db_instance_id: str):
        """Checks that the DB instance waiter uses AWS boto waiters where possible"""
        for state in ("available", "deleted"):
            with patch.object(rds_hook.conn, "get_waiter") as mock:
                rds_hook.wait_for_db_instance_state(db_instance_id, target_state=state, **self.waiter_args)
                # The hook must delegate to the named boto waiter and forward the waiter config.
                mock.assert_called_once_with(f"db_instance_{state}")
                mock.return_value.wait.assert_called_once_with(
                    DBInstanceIdentifier=db_instance_id,
                    WaiterConfig={
                        "Delay": self.waiter_args["check_interval"],
                        "MaxAttempts": self.waiter_args["max_attempts"],
                    },
                )

    def test_wait_for_db_instance_state_custom_waiter(self, rds_hook: RdsHook, db_instance_id: str):
        """Checks that the DB instance waiter uses custom wait logic when AWS boto waiters aren't available"""
        # "stopped" has no boto waiter, so the hook's _wait_for_state fallback must be used.
        with patch.object(rds_hook, "_wait_for_state") as mock:
            rds_hook.wait_for_db_instance_state(db_instance_id, target_state="stopped", **self.waiter_args)
            mock.assert_called_once()

        with patch.object(rds_hook, "get_db_instance_state", return_value="stopped") as mock:
            rds_hook.wait_for_db_instance_state(db_instance_id, target_state="stopped", **self.waiter_args)
            mock.assert_called_once_with(db_instance_id)

    def test_get_db_cluster_state(self, rds_hook: RdsHook, db_cluster_id: str):
        """get_db_cluster_state() reports the same status as describe_db_clusters."""
        response = rds_hook.conn.describe_db_clusters(DBClusterIdentifier=db_cluster_id)
        state_expected = response["DBClusters"][0]["Status"]
        state_actual = rds_hook.get_db_cluster_state(db_cluster_id)
        assert state_actual == state_expected

    def test_wait_for_db_cluster_state_boto_waiters(self, rds_hook: RdsHook, db_cluster_id: str):
        """Checks that the DB cluster waiter uses AWS boto waiters where possible"""
        for state in ("available", "deleted"):
            with patch.object(rds_hook.conn, "get_waiter") as mock:
                rds_hook.wait_for_db_cluster_state(db_cluster_id, target_state=state, **self.waiter_args)
                mock.assert_called_once_with(f"db_cluster_{state}")
                mock.return_value.wait.assert_called_once_with(
                    DBClusterIdentifier=db_cluster_id,
                    WaiterConfig={
                        "Delay": self.waiter_args["check_interval"],
                        "MaxAttempts": self.waiter_args["max_attempts"],
                    },
                )

    def test_wait_for_db_cluster_state_custom_waiter(self, rds_hook: RdsHook, db_cluster_id: str):
        """Checks that the DB cluster waiter uses custom wait logic when AWS boto waiters aren't available"""
        with patch.object(rds_hook, "_wait_for_state") as mock_wait_for_state:
            rds_hook.wait_for_db_cluster_state(db_cluster_id, target_state="stopped", **self.waiter_args)
            mock_wait_for_state.assert_called_once()

        with patch.object(rds_hook, "get_db_cluster_state", return_value="stopped") as mock:
            rds_hook.wait_for_db_cluster_state(db_cluster_id, target_state="stopped", **self.waiter_args)
            mock.assert_called_once_with(db_cluster_id)

    def test_get_db_snapshot_state(self, rds_hook: RdsHook, db_snapshot_id: str):
        """get_db_snapshot_state() reports the same status as describe_db_snapshots."""
        response = rds_hook.conn.describe_db_snapshots(DBSnapshotIdentifier=db_snapshot_id)
        state_expected = response["DBSnapshots"][0]["Status"]
        state_actual = rds_hook.get_db_snapshot_state(db_snapshot_id)
        assert state_actual == state_expected

    def test_get_db_snapshot_state_not_found(self, rds_hook: RdsHook):
        """An unknown snapshot id raises AirflowNotFoundException rather than a boto error."""
        with pytest.raises(AirflowNotFoundException):
            rds_hook.get_db_snapshot_state("does_not_exist")

    def test_wait_for_db_snapshot_state_boto_waiters(self, rds_hook: RdsHook, db_snapshot_id: str):
        """Checks that the DB snapshot waiter uses AWS boto waiters where possible"""
        for state in ("available", "deleted", "completed"):
            with patch.object(rds_hook.conn, "get_waiter") as mock:
                rds_hook.wait_for_db_snapshot_state(db_snapshot_id, target_state=state, **self.waiter_args)
                mock.assert_called_once_with(f"db_snapshot_{state}")
                mock.return_value.wait.assert_called_once_with(
                    DBSnapshotIdentifier=db_snapshot_id,
                    WaiterConfig={
                        "Delay": self.waiter_args["check_interval"],
                        "MaxAttempts": self.waiter_args["max_attempts"],
                    },
                )

    def test_wait_for_db_snapshot_state_custom_waiter(self, rds_hook: RdsHook, db_snapshot_id: str):
        """Checks that the DB snapshot waiter uses custom wait logic when AWS boto waiters aren't available"""
        # "canceled" has no boto waiter, so the custom _wait_for_state path must be taken.
        with patch.object(rds_hook, "_wait_for_state") as mock:
            rds_hook.wait_for_db_snapshot_state(db_snapshot_id, target_state="canceled", **self.waiter_args)
            mock.assert_called_once()

        with patch.object(rds_hook, "get_db_snapshot_state", return_value="canceled") as mock:
            rds_hook.wait_for_db_snapshot_state(db_snapshot_id, target_state="canceled", **self.waiter_args)
            mock.assert_called_once_with(db_snapshot_id)

    def test_get_db_cluster_snapshot_state(self, rds_hook: RdsHook, db_cluster_snapshot_id: str):
        """get_db_cluster_snapshot_state() reports the same status as describe_db_cluster_snapshots."""
        response = rds_hook.conn.describe_db_cluster_snapshots(
            DBClusterSnapshotIdentifier=db_cluster_snapshot_id
        )
        state_expected = response["DBClusterSnapshots"][0]["Status"]
        state_actual = rds_hook.get_db_cluster_snapshot_state(db_cluster_snapshot_id)
        assert state_actual == state_expected

    def test_get_db_cluster_snapshot_state_not_found(self, rds_hook: RdsHook):
        """An unknown cluster snapshot id raises AirflowNotFoundException."""
        with pytest.raises(AirflowNotFoundException):
            rds_hook.get_db_cluster_snapshot_state("does_not_exist")

    def test_wait_for_db_cluster_snapshot_state_boto_waiters(
        self, rds_hook: RdsHook, db_cluster_snapshot_id: str
    ):
        """Checks that the DB cluster snapshot waiter uses AWS boto waiters where possible"""
        for state in ("available", "deleted"):
            with patch.object(rds_hook.conn, "get_waiter") as mock:
                rds_hook.wait_for_db_cluster_snapshot_state(
                    db_cluster_snapshot_id, target_state=state, **self.waiter_args
                )
                mock.assert_called_once_with(f"db_cluster_snapshot_{state}")
                mock.return_value.wait.assert_called_once_with(
                    DBClusterSnapshotIdentifier=db_cluster_snapshot_id,
                    WaiterConfig={
                        "Delay": self.waiter_args["check_interval"],
                        "MaxAttempts": self.waiter_args["max_attempts"],
                    },
                )

    def test_wait_for_db_cluster_snapshot_state_custom_waiter(
        self, rds_hook: RdsHook, db_cluster_snapshot_id: str
    ):
        """
        Checks that the DB cluster snapshot waiter uses custom wait logic when AWS boto waiters
        aren't available
        """
        with patch.object(rds_hook, "_wait_for_state") as mock:
            rds_hook.wait_for_db_cluster_snapshot_state(
                db_cluster_snapshot_id, target_state="canceled", **self.waiter_args
            )
            mock.assert_called_once()

        with patch.object(rds_hook, "get_db_cluster_snapshot_state", return_value="canceled") as mock:
            rds_hook.wait_for_db_cluster_snapshot_state(
                db_cluster_snapshot_id, target_state="canceled", **self.waiter_args
            )
            mock.assert_called_once_with(db_cluster_snapshot_id)

    def test_get_export_task_state(self, rds_hook: RdsHook, export_task_id: str):
        """get_export_task_state() reports the same status as describe_export_tasks."""
        response = rds_hook.conn.describe_export_tasks(ExportTaskIdentifier=export_task_id)
        state_expected = response["ExportTasks"][0]["Status"]
        state_actual = rds_hook.get_export_task_state(export_task_id)
        assert state_actual == state_expected

    def test_get_export_task_state_not_found(self, rds_hook: RdsHook):
        """An unknown export task id raises AirflowNotFoundException."""
        with pytest.raises(AirflowNotFoundException):
            rds_hook.get_export_task_state("does_not_exist")

    def test_wait_for_export_task_state(self, rds_hook: RdsHook, export_task_id: str):
        """
        Checks that the export task waiter uses custom wait logic (no boto waiters exist for this resource)
        """
        with patch.object(rds_hook, "_wait_for_state") as mock:
            rds_hook.wait_for_export_task_state(export_task_id, target_state="complete", **self.waiter_args)
            mock.assert_called_once()

        with patch.object(rds_hook, "get_export_task_state", return_value="complete") as mock:
            rds_hook.wait_for_export_task_state(export_task_id, target_state="complete", **self.waiter_args)
            mock.assert_called_once_with(export_task_id)

    def test_get_event_subscription_state(self, rds_hook: RdsHook, event_subscription_name: str):
        """get_event_subscription_state() reports the same status as describe_event_subscriptions."""
        response = rds_hook.conn.describe_event_subscriptions(SubscriptionName=event_subscription_name)
        state_expected = response["EventSubscriptionsList"][0]["Status"]
        state_actual = rds_hook.get_event_subscription_state(event_subscription_name)
        assert state_actual == state_expected

    def test_get_event_subscription_state_not_found(self, rds_hook: RdsHook):
        """An unknown subscription name raises AirflowNotFoundException."""
        with pytest.raises(AirflowNotFoundException):
            rds_hook.get_event_subscription_state("does_not_exist")

    def test_wait_for_event_subscription_state(self, rds_hook: RdsHook, event_subscription_name: str):
        """
        Checks that the event subscription waiter uses custom wait logic (no boto waiters
        exist for this resource)
        """
        with patch.object(rds_hook, "_wait_for_state") as mock:
            rds_hook.wait_for_event_subscription_state(
                event_subscription_name, target_state="active", **self.waiter_args
            )
            mock.assert_called_once()

        with patch.object(rds_hook, "get_event_subscription_state", return_value="active") as mock:
            rds_hook.wait_for_event_subscription_state(
                event_subscription_name, target_state="active", **self.waiter_args
            )
            mock.assert_called_once_with(event_subscription_name)

    def test_wait_for_state(self, rds_hook: RdsHook):
        """_wait_for_state() sleeps between polls and raises after max_attempts without reaching the target."""

        def poke():
            # Never returns the target state, forcing the waiter to exhaust its attempts.
            return "foo"

        with pytest.raises(AirflowException, match="Max attempts exceeded"):
            with patch("airflow.providers.amazon.aws.hooks.rds.time.sleep") as mock:
                rds_hook._wait_for_state(poke, target_state="bar", check_interval=0, max_attempts=2)
        # This next line should exist outside of the pytest.raises() context manager or else it won't
        # get executed
        mock.assert_called_once_with(0)
diff --git a/tests/providers/amazon/aws/operators/test_rds.py
b/tests/providers/amazon/aws/operators/test_rds.py
index 4817d9dc6e..d2408fa47d 100644
--- a/tests/providers/amazon/aws/operators/test_rds.py
+++ b/tests/providers/amazon/aws/operators/test_rds.py
@@ -23,7 +23,6 @@ from unittest.mock import patch
import pytest
from moto import mock_rds
-from airflow.exceptions import AirflowException
from airflow.models import DAG
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.hooks.rds import RdsHook
@@ -153,26 +152,6 @@ class TestBaseRdsOperator:
assert hasattr(self.op, "hook")
assert self.op.hook.__class__.__name__ == "RdsHook"
- def test_describe_item_wrong_type(self):
- with pytest.raises(AirflowException):
- self.op._describe_item("database", "auth-db")
-
- def test_await_status_error(self):
- self.op._describe_item = lambda item_type, item_name: [{"Status":
"error"}]
- with pytest.raises(AirflowException):
- self.op._await_status(
- item_type="instance_snapshot",
- item_name="",
- wait_statuses=["wait"],
- error_statuses=["error"],
- )
-
- def test_await_status_ok(self):
- self.op._describe_item = lambda item_type, item_name: [{"Status":
"ok"}]
- self.op._await_status(
- item_type="instance_snapshot", item_name="",
wait_statuses=["wait"], ok_statuses=["ok"]
- )
-
class TestRdsCreateDbSnapshotOperator:
@classmethod
@@ -207,8 +186,8 @@ class TestRdsCreateDbSnapshotOperator:
assert len(instance_snapshots) == 1
@mock_rds
- @patch.object(RdsBaseOperator, "_await_status")
- def test_create_db_instance_snapshot_no_wait(self, mock_await_status):
+ @patch.object(RdsHook, "wait_for_db_snapshot_state")
+ def test_create_db_instance_snapshot_no_wait(self, mock_wait):
_create_db_instance(self.hook)
instance_snapshot_operator = RdsCreateDbSnapshotOperator(
task_id="test_instance_no_wait",
@@ -227,7 +206,7 @@ class TestRdsCreateDbSnapshotOperator:
assert instance_snapshots
assert len(instance_snapshots) == 1
- assert mock_await_status.not_called()
+ mock_wait.assert_not_called()
@mock_rds
def test_create_db_cluster_snapshot(self):
@@ -250,8 +229,8 @@ class TestRdsCreateDbSnapshotOperator:
assert len(cluster_snapshots) == 1
@mock_rds
- @patch.object(RdsBaseOperator, "_await_status")
- def test_create_db_cluster_snapshot_no_wait(self, mock_no_wait):
+ @patch.object(RdsHook, "wait_for_db_cluster_snapshot_state")
+ def test_create_db_cluster_snapshot_no_wait(self, mock_wait):
_create_db_cluster(self.hook)
cluster_snapshot_operator = RdsCreateDbSnapshotOperator(
task_id="test_cluster_no_wait",
@@ -270,7 +249,7 @@ class TestRdsCreateDbSnapshotOperator:
assert cluster_snapshots
assert len(cluster_snapshots) == 1
- assert mock_no_wait.not_called()
+ mock_wait.assert_not_called()
class TestRdsCopyDbSnapshotOperator:
@@ -307,7 +286,7 @@ class TestRdsCopyDbSnapshotOperator:
assert len(instance_snapshots) == 1
@mock_rds
- @patch.object(RdsBaseOperator, "_await_status")
+ @patch.object(RdsHook, "wait_for_db_snapshot_state")
def test_copy_db_instance_snapshot_no_wait(self, mock_await_status):
_create_db_instance(self.hook)
_create_db_instance_snapshot(self.hook)
@@ -328,7 +307,7 @@ class TestRdsCopyDbSnapshotOperator:
assert instance_snapshots
assert len(instance_snapshots) == 1
- assert mock_await_status.not_called()
+ mock_await_status.assert_not_called()
@mock_rds
def test_copy_db_cluster_snapshot(self):
@@ -354,7 +333,7 @@ class TestRdsCopyDbSnapshotOperator:
assert len(cluster_snapshots) == 1
@mock_rds
- @patch.object(RdsBaseOperator, "_await_status")
+ @patch.object(RdsHook, "wait_for_db_snapshot_state")
def test_copy_db_cluster_snapshot_no_wait(self, mock_await_status):
_create_db_cluster(self.hook)
_create_db_cluster_snapshot(self.hook)
@@ -376,7 +355,7 @@ class TestRdsCopyDbSnapshotOperator:
assert cluster_snapshots
assert len(cluster_snapshots) == 1
- assert mock_await_status.not_called()
+ mock_await_status.assert_not_called()
class TestRdsDeleteDbSnapshotOperator:
@@ -404,10 +383,37 @@ class TestRdsDeleteDbSnapshotOperator:
dag=self.dag,
)
_patch_hook_get_connection(instance_snapshot_operator.hook)
- instance_snapshot_operator.execute(None)
+ with patch.object(instance_snapshot_operator.hook,
"wait_for_db_snapshot_state") as mock_wait:
+ instance_snapshot_operator.execute(None)
+ mock_wait.assert_called_once_with(DB_INSTANCE_SNAPSHOT,
target_state="deleted")
with pytest.raises(self.hook.conn.exceptions.ClientError):
-
self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=DB_CLUSTER_SNAPSHOT)
+
self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=DB_INSTANCE_SNAPSHOT)
+
    @mock_rds
    def test_delete_db_instance_snapshot_no_wait(self):
        """
        Check that the operator does not wait for the DB instance snapshot delete operation to complete when
        wait_for_completion=False
        """
        _create_db_instance(self.hook)
        _create_db_instance_snapshot(self.hook)

        instance_snapshot_operator = RdsDeleteDbSnapshotOperator(
            task_id="test_delete_db_instance_snapshot_no_wait",
            db_type="instance",
            db_snapshot_identifier=DB_INSTANCE_SNAPSHOT,
            aws_conn_id=AWS_CONN,
            dag=self.dag,
            wait_for_completion=False,
        )
        _patch_hook_get_connection(instance_snapshot_operator.hook)
        # The hook's waiter must not be invoked when wait_for_completion=False.
        with patch.object(instance_snapshot_operator.hook, "wait_for_db_snapshot_state") as mock_wait:
            instance_snapshot_operator.execute(None)
            mock_wait.assert_not_called()

        # The snapshot should still have been deleted, so describing it raises ClientError.
        with pytest.raises(self.hook.conn.exceptions.ClientError):
            self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=DB_INSTANCE_SNAPSHOT)
@mock_rds
def test_delete_db_cluster_snapshot(self):
@@ -422,7 +428,34 @@ class TestRdsDeleteDbSnapshotOperator:
dag=self.dag,
)
_patch_hook_get_connection(cluster_snapshot_operator.hook)
- cluster_snapshot_operator.execute(None)
+ with patch.object(cluster_snapshot_operator.hook,
"wait_for_db_cluster_snapshot_state") as mock_wait:
+ cluster_snapshot_operator.execute(None)
+ mock_wait.assert_called_once_with(DB_CLUSTER_SNAPSHOT,
target_state="deleted")
+
+ with pytest.raises(self.hook.conn.exceptions.ClientError):
+
self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=DB_CLUSTER_SNAPSHOT)
+
    @mock_rds
    def test_delete_db_cluster_snapshot_no_wait(self):
        """
        Check that the operator does not wait for the DB cluster snapshot delete operation to complete when
        wait_for_completion=False
        """
        _create_db_cluster(self.hook)
        _create_db_cluster_snapshot(self.hook)

        cluster_snapshot_operator = RdsDeleteDbSnapshotOperator(
            task_id="test_delete_db_cluster_snapshot_no_wait",
            db_type="cluster",
            db_snapshot_identifier=DB_CLUSTER_SNAPSHOT,
            aws_conn_id=AWS_CONN,
            dag=self.dag,
            wait_for_completion=False,
        )
        _patch_hook_get_connection(cluster_snapshot_operator.hook)
        # The hook's waiter must not be invoked when wait_for_completion=False.
        with patch.object(cluster_snapshot_operator.hook, "wait_for_db_cluster_snapshot_state") as mock_wait:
            cluster_snapshot_operator.execute(None)
            mock_wait.assert_not_called()
        # The snapshot should still have been deleted, so describing it raises ClientError.
        with pytest.raises(self.hook.conn.exceptions.ClientError):
            self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=DB_CLUSTER_SNAPSHOT)
@@ -466,7 +499,7 @@ class TestRdsStartExportTaskOperator:
assert export_tasks[0]["Status"] == "complete"
@mock_rds
- @patch.object(RdsBaseOperator, "_await_status")
+ @patch.object(RdsHook, "wait_for_export_task_state")
def test_start_export_task_no_wait(self, mock_await_status):
_create_db_instance(self.hook)
_create_db_instance_snapshot(self.hook)
@@ -491,7 +524,7 @@ class TestRdsStartExportTaskOperator:
assert export_tasks
assert len(export_tasks) == 1
assert export_tasks[0]["Status"] == "complete"
- assert mock_await_status.not_called()
+ mock_await_status.assert_not_called()
class TestRdsCancelExportTaskOperator:
@@ -529,7 +562,7 @@ class TestRdsCancelExportTaskOperator:
assert export_tasks[0]["Status"] == "canceled"
@mock_rds
- @patch.object(RdsBaseOperator, "_await_status")
+ @patch.object(RdsHook, "wait_for_export_task_state")
def test_cancel_export_task_no_wait(self, mock_await_status):
_create_db_instance(self.hook)
_create_db_instance_snapshot(self.hook)
@@ -540,6 +573,7 @@ class TestRdsCancelExportTaskOperator:
export_task_identifier=EXPORT_TASK_NAME,
aws_conn_id=AWS_CONN,
dag=self.dag,
+ wait_for_completion=False,
)
_patch_hook_get_connection(cancel_export_operator.hook)
cancel_export_operator.execute(None)
@@ -550,7 +584,7 @@ class TestRdsCancelExportTaskOperator:
assert export_tasks
assert len(export_tasks) == 1
assert export_tasks[0]["Status"] == "canceled"
- assert mock_await_status.not_called()
+ mock_await_status.assert_not_called()
class TestRdsCreateEventSubscriptionOperator:
@@ -589,7 +623,7 @@ class TestRdsCreateEventSubscriptionOperator:
assert subscriptions[0]["Status"] == "active"
@mock_rds
- @patch.object(RdsBaseOperator, "_await_status")
+ @patch.object(RdsHook, "wait_for_event_subscription_state")
def test_create_event_subscription_no_wait(self, mock_await_status):
_create_db_instance(self.hook)
@@ -601,6 +635,7 @@ class TestRdsCreateEventSubscriptionOperator:
source_ids=[DB_INSTANCE_NAME],
aws_conn_id=AWS_CONN,
dag=self.dag,
+ wait_for_completion=False,
)
_patch_hook_get_connection(create_subscription_operator.hook)
create_subscription_operator.execute(None)
@@ -611,7 +646,7 @@ class TestRdsCreateEventSubscriptionOperator:
assert subscriptions
assert len(subscriptions) == 1
assert subscriptions[0]["Status"] == "active"
- assert mock_await_status.not_called()
+ mock_await_status.assert_not_called()
class TestRdsDeleteEventSubscriptionOperator:
@@ -679,7 +714,7 @@ class TestRdsCreateDbInstanceOperator:
assert db_instances[0]["DBInstanceStatus"] == "available"
@mock_rds
- @patch.object(RdsBaseOperator, "_await_status")
+ @patch.object(RdsHook, "wait_for_db_instance_state")
def test_create_db_instance_no_wait(self, mock_await_status):
create_db_instance_operator = RdsCreateDbInstanceOperator(
task_id="test_create_db_instance_no_wait",
@@ -691,6 +726,7 @@ class TestRdsCreateDbInstanceOperator:
},
aws_conn_id=AWS_CONN,
dag=self.dag,
+ wait_for_completion=False,
)
_patch_hook_get_connection(create_db_instance_operator.hook)
create_db_instance_operator.execute(None)
@@ -701,7 +737,7 @@ class TestRdsCreateDbInstanceOperator:
assert db_instances
assert len(db_instances) == 1
assert db_instances[0]["DBInstanceStatus"] == "available"
- assert mock_await_status.not_called()
+ mock_await_status.assert_not_called()
class TestRdsDeleteDbInstanceOperator:
@@ -717,7 +753,7 @@ class TestRdsDeleteDbInstanceOperator:
del cls.hook
@mock_rds
- def test_delete_event_subscription(self):
+ def test_delete_db_instance(self):
_create_db_instance(self.hook)
delete_db_instance_operator = RdsDeleteDbInstanceOperator(
@@ -736,8 +772,8 @@ class TestRdsDeleteDbInstanceOperator:
self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
@mock_rds
- @patch.object(RdsBaseOperator, "_await_status")
- def test_delete_event_subscription_no_wait(self, mock_await_status):
+ @patch.object(RdsHook, "wait_for_db_instance_state")
+ def test_delete_db_instance_no_wait(self, mock_await_status):
_create_db_instance(self.hook)
delete_db_instance_operator = RdsDeleteDbInstanceOperator(
@@ -755,7 +791,7 @@ class TestRdsDeleteDbInstanceOperator:
with pytest.raises(self.hook.conn.exceptions.ClientError):
self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
- assert mock_await_status.not_called()
+ mock_await_status.assert_not_called()
class TestRdsStopDbOperator:
@@ -771,7 +807,7 @@ class TestRdsStopDbOperator:
del cls.hook
@mock_rds
- @patch.object(RdsBaseOperator, "_await_status")
+ @patch.object(RdsHook, "wait_for_db_instance_state")
def test_stop_db_instance(self, mock_await_status):
_create_db_instance(self.hook)
stop_db_instance = RdsStopDbOperator(task_id="test_stop_db_instance",
db_identifier=DB_INSTANCE_NAME)
@@ -783,7 +819,7 @@ class TestRdsStopDbOperator:
mock_await_status.assert_called()
@mock_rds
- @patch.object(RdsBaseOperator, "_await_status")
+ @patch.object(RdsHook, "wait_for_db_instance_state")
def test_stop_db_instance_no_wait(self, mock_await_status):
_create_db_instance(self.hook)
stop_db_instance = RdsStopDbOperator(
@@ -817,7 +853,7 @@ class TestRdsStopDbOperator:
assert len(instance_snapshots) == 1
@mock_rds
- @patch.object(RdsBaseOperator, "_await_status")
+ @patch.object(RdsHook, "wait_for_db_cluster_state")
def test_stop_db_cluster(self, mock_await_status):
_create_db_cluster(self.hook)
stop_db_cluster = RdsStopDbOperator(
@@ -829,6 +865,7 @@ class TestRdsStopDbOperator:
describe_result =
self.hook.conn.describe_db_clusters(DBClusterIdentifier=DB_CLUSTER_NAME)
status = describe_result["DBClusters"][0]["Status"]
assert status == "stopped"
+ mock_await_status.assert_called()
@mock_rds
def test_stop_db_cluster_create_snapshot_logs_warning_message(self,
caplog):
diff --git a/tests/providers/amazon/aws/sensors/test_rds.py
b/tests/providers/amazon/aws/sensors/test_rds.py
index 0fca35f4fb..7db2b3cb82 100644
--- a/tests/providers/amazon/aws/sensors/test_rds.py
+++ b/tests/providers/amazon/aws/sensors/test_rds.py
@@ -16,10 +16,8 @@
# under the License.
from __future__ import annotations
-import pytest
from moto import mock_rds
-from airflow.exceptions import AirflowException
from airflow.models import DAG
from airflow.providers.amazon.aws.hooks.rds import RdsHook
from airflow.providers.amazon.aws.sensors.rds import (
@@ -119,22 +117,6 @@ class TestBaseRdsSensor:
assert hasattr(self.base_sensor, "hook")
assert self.base_sensor.hook.__class__.__name__ == "RdsHook"
- def test_describe_item_wrong_type(self):
- with pytest.raises(AirflowException):
- self.base_sensor._describe_item("database", "auth-db")
-
- def test_check_item_true(self):
- self.base_sensor._describe_item = lambda item_type, item_name:
[{"Status": "available"}]
- self.base_sensor.target_statuses = ["available", "created"]
-
- assert self.base_sensor._check_item(item_type="instance_snapshot",
item_name="")
-
- def test_check_item_false(self):
- self.base_sensor._describe_item = lambda item_type, item_name:
[{"Status": "creating"}]
- self.base_sensor.target_statuses = ["available", "created"]
-
- assert not self.base_sensor._check_item(item_type="instance_snapshot",
item_name="")
-
class TestRdsSnapshotExistenceSensor:
@classmethod
@@ -161,6 +143,7 @@ class TestRdsSnapshotExistenceSensor:
@mock_rds
def test_db_instance_snapshot_poke_false(self):
+ _create_db_instance(self.hook)
op = RdsSnapshotExistenceSensor(
task_id="test_instance_snap_false",
db_type="instance",