This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a2413cf6ca Add RdsStopDbOperator and RdsStartDbOperator (#27076)
a2413cf6ca is described below
commit a2413cf6ca8b93e491a48af11d769cd13bce8884
Author: Hank Ehly <[email protected]>
AuthorDate: Wed Oct 19 14:36:05 2022 +0900
Add RdsStopDbOperator and RdsStartDbOperator (#27076)
* Add operator classes
Co-authored-by: Vincent <[email protected]>
Co-authored-by: eladkal <[email protected]>
---
airflow/providers/amazon/aws/operators/rds.py | 147 ++++++++++++++++++++-
airflow/providers/amazon/aws/sensors/rds.py | 2 +-
.../operators/rds.rst | 29 ++++
tests/providers/amazon/aws/operators/test_rds.py | 144 ++++++++++++++++++++
.../providers/amazon/aws/example_rds_instance.py | 18 +++
5 files changed, 335 insertions(+), 5 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/rds.py b/airflow/providers/amazon/aws/operators/rds.py
index e79086c2da..aa7a3ac87e 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -46,7 +46,6 @@ class RdsBaseOperator(BaseOperator):
self._await_interval = 60 # seconds
def _describe_item(self, item_type: str, item_name: str) -> list:
-
if item_type == 'instance_snapshot':
db_snaps = self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=item_name)
return db_snaps['DBSnapshots']
@@ -59,6 +58,12 @@ class RdsBaseOperator(BaseOperator):
elif item_type == 'event_subscription':
subscriptions = self.hook.conn.describe_event_subscriptions(SubscriptionName=item_name)
return subscriptions['EventSubscriptionsList']
+ elif item_type == "db_instance":
+ instances = self.hook.conn.describe_db_instances(DBInstanceIdentifier=item_name)
+ return instances["DBInstances"]
+ elif item_type == "db_cluster":
+ clusters = self.hook.conn.describe_db_clusters(DBClusterIdentifier=item_name)
+ return clusters["DBClusters"]
else:
raise AirflowException(f"Method for {item_type} is not implemented")
@@ -83,12 +88,17 @@ class RdsBaseOperator(BaseOperator):
if len(items) > 1:
raise AirflowException(f"There are {len(items)} {item_type} with identifier {item_name}")
- if wait_statuses and items[0]['Status'].lower() in wait_statuses:
+ if item_type == "db_instance":
+ status_field = "DBInstanceStatus"
+ else:
+ status_field = "Status"
+
+ if wait_statuses and items[0][status_field].lower() in wait_statuses:
time.sleep(self._await_interval)
continue
- elif ok_statuses and items[0]['Status'].lower() in ok_statuses:
+ elif ok_statuses and items[0][status_field].lower() in ok_statuses:
break
- elif error_statuses and items[0]['Status'].lower() in error_statuses:
+ elif error_statuses and items[0][status_field].lower() in error_statuses:
raise AirflowException(f"Item has error status ({error_statuses}): {items[0]}")
else:
raise AirflowException(f"Item has uncertain status: {items[0]}")
@@ -672,6 +682,133 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
return json.dumps(delete_db_instance, default=str)
+class RdsStartDbOperator(RdsBaseOperator):
+ """
+ Starts an RDS DB instance / cluster
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:RdsStartDbOperator`
+
+ :param db_identifier: The AWS identifier of the DB to start
+ :param db_type: Type of the DB - either "instance" or "cluster" (default: "instance")
+ :param aws_conn_id: The Airflow connection used for AWS credentials. (default: "aws_default")
+ :param wait_for_completion: If True, waits for DB to start. (default: True)
+ """
+
+ template_fields = ("db_identifier", "db_type")
+
+ def __init__(
+ self,
+ *,
+ db_identifier: str,
+ db_type: RdsDbType | str = RdsDbType.INSTANCE,
+ aws_conn_id: str = "aws_default",
+ wait_for_completion: bool = True,
+ **kwargs,
+ ):
+ super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ self.db_identifier = db_identifier
+ self.db_type = db_type
+ self.wait_for_completion = wait_for_completion
+
+ def execute(self, context: Context) -> str:
+ self.db_type = RdsDbType(self.db_type)
+ start_db_response = self._start_db()
+ if self.wait_for_completion:
+ self._wait_until_db_available()
+ return json.dumps(start_db_response, default=str)
+
+ def _start_db(self):
+ self.log.info("Starting DB %s '%s'", self.db_type.value, self.db_identifier)
+ if self.db_type == RdsDbType.INSTANCE:
+ response = self.hook.conn.start_db_instance(DBInstanceIdentifier=self.db_identifier)
+ else:
+ response = self.hook.conn.start_db_cluster(DBClusterIdentifier=self.db_identifier)
+ return response
+
+ def _wait_until_db_available(self):
+ self.log.info("Waiting for DB %s to reach 'available' state", self.db_type.value)
+ if self.db_type == RdsDbType.INSTANCE:
+ self.hook.conn.get_waiter("db_instance_available").wait(DBInstanceIdentifier=self.db_identifier)
+ else:
+ self.hook.conn.get_waiter("db_cluster_available").wait(DBClusterIdentifier=self.db_identifier)
+
+
+class RdsStopDbOperator(RdsBaseOperator):
+ """
+ Stops an RDS DB instance / cluster
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:RdsStopDbOperator`
+
+ :param db_identifier: The AWS identifier of the DB to stop
+ :param db_type: Type of the DB - either "instance" or "cluster" (default: "instance")
+ :param db_snapshot_identifier: The instance identifier of the DB Snapshot to create before
+ stopping the DB instance. The default value (None) skips snapshot creation. This
+ parameter is ignored when ``db_type`` is "cluster"
+ :param aws_conn_id: The Airflow connection used for AWS credentials. (default: "aws_default")
+ :param wait_for_completion: If True, waits for DB to stop. (default: True)
+ """
+
+ template_fields = ("db_identifier", "db_snapshot_identifier", "db_type")
+
+ def __init__(
+ self,
+ *,
+ db_identifier: str,
+ db_type: RdsDbType | str = RdsDbType.INSTANCE,
+ db_snapshot_identifier: str | None = None,
+ aws_conn_id: str = "aws_default",
+ wait_for_completion: bool = True,
+ **kwargs,
+ ):
+ super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ self.db_identifier = db_identifier
+ self.db_type = db_type
+ self.db_snapshot_identifier = db_snapshot_identifier
+ self.wait_for_completion = wait_for_completion
+
+ def execute(self, context: Context) -> str:
+ self.db_type = RdsDbType(self.db_type)
+ stop_db_response = self._stop_db()
+ if self.wait_for_completion:
+ self._wait_until_db_stopped()
+ return json.dumps(stop_db_response, default=str)
+
+ def _stop_db(self):
+ self.log.info("Stopping DB %s '%s'", self.db_type.value, self.db_identifier)
+ if self.db_type == RdsDbType.INSTANCE:
+ conn_params = {"DBInstanceIdentifier": self.db_identifier}
+ # The db snapshot parameter is optional, but the AWS SDK raises an exception
+ # if passed a null value. Only set snapshot id if value is present.
+ if self.db_snapshot_identifier:
+ conn_params["DBSnapshotIdentifier"] = self.db_snapshot_identifier
+ response = self.hook.conn.stop_db_instance(**conn_params)
+ else:
+ if self.db_snapshot_identifier:
+ self.log.warning(
+ "'db_snapshot_identifier' does not apply to db clusters. "
+ "Remove it to silence this warning."
+ )
+ response = self.hook.conn.stop_db_cluster(DBClusterIdentifier=self.db_identifier)
+ return response
+
+ def _wait_until_db_stopped(self):
+ self.log.info("Waiting for DB %s to reach 'stopped' state", self.db_type.value)
+ wait_statuses = ["stopping"]
+ ok_statuses = ["stopped"]
+ if self.db_type == RdsDbType.INSTANCE:
+ self._await_status(
+ "db_instance", self.db_identifier, wait_statuses=wait_statuses, ok_statuses=ok_statuses
+ )
+ else:
+ self._await_status(
+ "db_cluster", self.db_identifier, wait_statuses=wait_statuses, ok_statuses=ok_statuses
+ )
+
+
__all__ = [
"RdsCreateDbSnapshotOperator",
"RdsCopyDbSnapshotOperator",
@@ -682,4 +819,6 @@ __all__ = [
"RdsCancelExportTaskOperator",
"RdsCreateDbInstanceOperator",
"RdsDeleteDbInstanceOperator",
+ "RdsStartDbOperator",
+ "RdsStopDbOperator",
]
diff --git a/airflow/providers/amazon/aws/sensors/rds.py b/airflow/providers/amazon/aws/sensors/rds.py
index 6240f10a91..6af1ccdc35 100644
--- a/airflow/providers/amazon/aws/sensors/rds.py
+++ b/airflow/providers/amazon/aws/sensors/rds.py
@@ -169,7 +169,7 @@ class RdsDbSensor(RdsBaseSensor):
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:RdsDbSensor`
- :param db_type: Type of the DB - either "instance" or "cluster"
+ :param db_type: Type of the DB - either "instance" or "cluster" (default: 'instance')
:param db_identifier: The AWS identifier for the DB
:param target_statuses: Target status of DB
"""
diff --git a/docs/apache-airflow-providers-amazon/operators/rds.rst b/docs/apache-airflow-providers-amazon/operators/rds.rst
index d0752a014d..47f00966bc 100644
--- a/docs/apache-airflow-providers-amazon/operators/rds.rst
+++ b/docs/apache-airflow-providers-amazon/operators/rds.rst
@@ -166,6 +166,35 @@ To delete a AWS DB instance you can use
:start-after: [START howto_operator_rds_delete_db_instance]
:end-before: [END howto_operator_rds_delete_db_instance]
+.. _howto/operator:RdsStartDbOperator:
+
+Start a database instance or cluster
+====================================
+
+To start an Amazon RDS DB instance or cluster you can use
+:class:`~airflow.providers.amazon.aws.operators.rds.RdsStartDbOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_rds_instance.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_rds_start_db]
+ :end-before: [END howto_operator_rds_start_db]
+
+
+.. _howto/operator:RdsStopDbOperator:
+
+Stop a database instance or cluster
+===================================
+
+To stop an Amazon RDS DB instance or cluster you can use
+:class:`~airflow.providers.amazon.aws.operators.rds.RdsStopDbOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_rds_instance.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_rds_stop_db]
+ :end-before: [END howto_operator_rds_stop_db]
+
Sensors
-------
diff --git a/tests/providers/amazon/aws/operators/test_rds.py b/tests/providers/amazon/aws/operators/test_rds.py
index b4b3587e50..a4092e7fc5 100644
--- a/tests/providers/amazon/aws/operators/test_rds.py
+++ b/tests/providers/amazon/aws/operators/test_rds.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations
+import logging
from unittest.mock import patch
import pytest
@@ -35,7 +36,9 @@ from airflow.providers.amazon.aws.operators.rds import (
RdsDeleteDbInstanceOperator,
RdsDeleteDbSnapshotOperator,
RdsDeleteEventSubscriptionOperator,
+ RdsStartDbOperator,
RdsStartExportTaskOperator,
+ RdsStopDbOperator,
)
from airflow.utils import timezone
@@ -767,3 +770,144 @@ class TestRdsDeleteDbInstanceOperator:
with pytest.raises(self.hook.conn.exceptions.ClientError):
self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
assert mock_await_status.not_called()
+
+
[email protected](mock_rds is None, reason="mock_rds package not present")
+class TestRdsStopDbOperator:
+ @classmethod
+ def setup_class(cls):
+ cls.dag = DAG("test_dag", default_args={"owner": "airflow", "start_date": DEFAULT_DATE})
+ cls.hook = RdsHook(aws_conn_id=AWS_CONN, region_name="us-east-1")
+ _patch_hook_get_connection(cls.hook)
+
+ @classmethod
+ def teardown_class(cls):
+ del cls.dag
+ del cls.hook
+
+ @mock_rds
+ @patch.object(RdsBaseOperator, "_await_status")
+ def test_stop_db_instance(self, mock_await_status):
+ _create_db_instance(self.hook)
+ stop_db_instance = RdsStopDbOperator(task_id="test_stop_db_instance", db_identifier=DB_INSTANCE_NAME)
+ _patch_hook_get_connection(stop_db_instance.hook)
+ stop_db_instance.execute(None)
+ result = self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
+ status = result["DBInstances"][0]["DBInstanceStatus"]
+ assert status == "stopped"
+ mock_await_status.assert_called()
+
+ @mock_rds
+ @patch.object(RdsBaseOperator, "_await_status")
+ def test_stop_db_instance_no_wait(self, mock_await_status):
+ _create_db_instance(self.hook)
+ stop_db_instance = RdsStopDbOperator(
+ task_id="test_stop_db_instance_no_wait", db_identifier=DB_INSTANCE_NAME, wait_for_completion=False
+ )
+ _patch_hook_get_connection(stop_db_instance.hook)
+ stop_db_instance.execute(None)
+ result = self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
+ status = result["DBInstances"][0]["DBInstanceStatus"]
+ assert status == "stopped"
+ mock_await_status.assert_not_called()
+
+ @mock_rds
+ def test_stop_db_instance_create_snapshot(self):
+ _create_db_instance(self.hook)
+ stop_db_instance = RdsStopDbOperator(
+ task_id="test_stop_db_instance_create_snapshot",
+ db_identifier=DB_INSTANCE_NAME,
+ db_snapshot_identifier=DB_INSTANCE_SNAPSHOT,
+ )
+ _patch_hook_get_connection(stop_db_instance.hook)
+ stop_db_instance.execute(None)
+
+ describe_result = self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
+ status = describe_result["DBInstances"][0]['DBInstanceStatus']
+ assert status == "stopped"
+
+ snapshot_result = self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=DB_INSTANCE_SNAPSHOT)
+ instance_snapshots = snapshot_result.get("DBSnapshots")
+ assert instance_snapshots
+ assert len(instance_snapshots) == 1
+
+ @mock_rds
+ @patch.object(RdsBaseOperator, "_await_status")
+ def test_stop_db_cluster(self, mock_await_status):
+ _create_db_cluster(self.hook)
+ stop_db_cluster = RdsStopDbOperator(
+ task_id="test_stop_db_cluster", db_identifier=DB_CLUSTER_NAME, db_type="cluster"
+ )
+ _patch_hook_get_connection(stop_db_cluster.hook)
+ stop_db_cluster.execute(None)
+
+ describe_result = self.hook.conn.describe_db_clusters(DBClusterIdentifier=DB_CLUSTER_NAME)
+ status = describe_result["DBClusters"][0]["Status"]
+ assert status == "stopped"
+
+ @mock_rds
+ def test_stop_db_cluster_create_snapshot_logs_warning_message(self, caplog):
+ _create_db_cluster(self.hook)
+ stop_db_cluster = RdsStopDbOperator(
+ task_id="test_stop_db_cluster",
+ db_identifier=DB_CLUSTER_NAME,
+ db_type="cluster",
+ db_snapshot_identifier=DB_CLUSTER_SNAPSHOT,
+ )
+ _patch_hook_get_connection(stop_db_cluster.hook)
+ with caplog.at_level(logging.WARNING, logger=stop_db_cluster.log.name):
+ stop_db_cluster.execute(None)
+ warning_message = (
+ "'db_snapshot_identifier' does not apply to db clusters. Remove it to silence this warning."
+ )
+ assert warning_message in caplog.text
+
+
[email protected](mock_rds is None, reason="mock_rds package not present")
+class TestRdsStartDbOperator:
+ @classmethod
+ def setup_class(cls):
+ cls.dag = DAG("test_dag", default_args={"owner": "airflow", "start_date": DEFAULT_DATE})
+ cls.hook = RdsHook(aws_conn_id=AWS_CONN, region_name="us-east-1")
+ _patch_hook_get_connection(cls.hook)
+
+ @classmethod
+ def teardown_class(cls):
+ del cls.dag
+ del cls.hook
+
+ @mock_rds
+ def test_start_db_instance(self):
+ _create_db_instance(self.hook)
+ self.hook.conn.stop_db_instance(DBInstanceIdentifier=DB_INSTANCE_NAME)
+ result_before = self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
+ status_before = result_before["DBInstances"][0]["DBInstanceStatus"]
+ assert status_before == "stopped"
+
+ start_db_instance = RdsStartDbOperator(
+ task_id="test_start_db_instance", db_identifier=DB_INSTANCE_NAME
+ )
+ _patch_hook_get_connection(start_db_instance.hook)
+ start_db_instance.execute(None)
+
+ result_after = self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
+ status_after = result_after["DBInstances"][0]["DBInstanceStatus"]
+ assert status_after == "available"
+
+ @mock_rds
+ def test_start_db_cluster(self):
+ _create_db_cluster(self.hook)
+ self.hook.conn.stop_db_cluster(DBClusterIdentifier=DB_CLUSTER_NAME)
+ result_before = self.hook.conn.describe_db_clusters(DBClusterIdentifier=DB_CLUSTER_NAME)
+ status_before = result_before["DBClusters"][0]["Status"]
+ assert status_before == "stopped"
+
+ start_db_cluster = RdsStartDbOperator(
+ task_id="test_start_db_cluster", db_identifier=DB_CLUSTER_NAME, db_type="cluster"
+ )
+ _patch_hook_get_connection(start_db_cluster.hook)
+ start_db_cluster.execute(None)
+
+ result_after = self.hook.conn.describe_db_clusters(DBClusterIdentifier=DB_CLUSTER_NAME)
+ status_after = result_after["DBClusters"][0]["Status"]
+ assert status_after == "available"
diff --git a/tests/system/providers/amazon/aws/example_rds_instance.py b/tests/system/providers/amazon/aws/example_rds_instance.py
index c27f7f99ee..b31be7c007 100644
--- a/tests/system/providers/amazon/aws/example_rds_instance.py
+++ b/tests/system/providers/amazon/aws/example_rds_instance.py
@@ -23,6 +23,8 @@ from airflow.models.baseoperator import chain
from airflow.providers.amazon.aws.operators.rds import (
RdsCreateDbInstanceOperator,
RdsDeleteDbInstanceOperator,
+ RdsStartDbOperator,
+ RdsStopDbOperator,
)
from airflow.providers.amazon.aws.sensors.rds import RdsDbSensor
from airflow.utils.trigger_rule import TriggerRule
@@ -71,6 +73,20 @@ with DAG(
)
# [END howto_sensor_rds_instance]
+ # [START howto_operator_rds_stop_db]
+ stop_db_instance = RdsStopDbOperator(
+ task_id="stop_db_instance",
+ db_identifier=rds_db_identifier,
+ )
+ # [END howto_operator_rds_stop_db]
+
+ # [START howto_operator_rds_start_db]
+ start_db_instance = RdsStartDbOperator(
+ task_id="start_db_instance",
+ db_identifier=rds_db_identifier,
+ )
+ # [END howto_operator_rds_start_db]
+
# [START howto_operator_rds_delete_db_instance]
delete_db_instance = RdsDeleteDbInstanceOperator(
task_id='delete_db_instance',
@@ -88,6 +104,8 @@ with DAG(
# TEST BODY
create_db_instance,
await_db_instance,
+ stop_db_instance,
+ start_db_instance,
delete_db_instance,
)