This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a2413cf6ca Add RdsStopDbOperator and RdsStartDbOperator (#27076)
a2413cf6ca is described below

commit a2413cf6ca8b93e491a48af11d769cd13bce8884
Author: Hank Ehly <[email protected]>
AuthorDate: Wed Oct 19 14:36:05 2022 +0900

    Add RdsStopDbOperator and RdsStartDbOperator (#27076)
    
    * Add operator classes
    
    Co-authored-by: Vincent <[email protected]>
    
    Co-authored-by: eladkal <[email protected]>
---
 airflow/providers/amazon/aws/operators/rds.py      | 147 ++++++++++++++++++++-
 airflow/providers/amazon/aws/sensors/rds.py        |   2 +-
 .../operators/rds.rst                              |  29 ++++
 tests/providers/amazon/aws/operators/test_rds.py   | 144 ++++++++++++++++++++
 .../providers/amazon/aws/example_rds_instance.py   |  18 +++
 5 files changed, 335 insertions(+), 5 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/rds.py 
b/airflow/providers/amazon/aws/operators/rds.py
index e79086c2da..aa7a3ac87e 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -46,7 +46,6 @@ class RdsBaseOperator(BaseOperator):
         self._await_interval = 60  # seconds
 
     def _describe_item(self, item_type: str, item_name: str) -> list:
-
         if item_type == 'instance_snapshot':
             db_snaps = 
self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=item_name)
             return db_snaps['DBSnapshots']
@@ -59,6 +58,12 @@ class RdsBaseOperator(BaseOperator):
         elif item_type == 'event_subscription':
             subscriptions = 
self.hook.conn.describe_event_subscriptions(SubscriptionName=item_name)
             return subscriptions['EventSubscriptionsList']
+        elif item_type == "db_instance":
+            instances = 
self.hook.conn.describe_db_instances(DBInstanceIdentifier=item_name)
+            return instances["DBInstances"]
+        elif item_type == "db_cluster":
+            clusters = 
self.hook.conn.describe_db_clusters(DBClusterIdentifier=item_name)
+            return clusters["DBClusters"]
         else:
             raise AirflowException(f"Method for {item_type} is not 
implemented")
 
@@ -83,12 +88,17 @@ class RdsBaseOperator(BaseOperator):
             if len(items) > 1:
                 raise AirflowException(f"There are {len(items)} {item_type} 
with identifier {item_name}")
 
-            if wait_statuses and items[0]['Status'].lower() in wait_statuses:
+            if item_type == "db_instance":
+                status_field = "DBInstanceStatus"
+            else:
+                status_field = "Status"
+
+            if wait_statuses and items[0][status_field].lower() in 
wait_statuses:
                 time.sleep(self._await_interval)
                 continue
-            elif ok_statuses and items[0]['Status'].lower() in ok_statuses:
+            elif ok_statuses and items[0][status_field].lower() in ok_statuses:
                 break
-            elif error_statuses and items[0]['Status'].lower() in 
error_statuses:
+            elif error_statuses and items[0][status_field].lower() in 
error_statuses:
                 raise AirflowException(f"Item has error status 
({error_statuses}): {items[0]}")
             else:
                 raise AirflowException(f"Item has uncertain status: 
{items[0]}")
@@ -672,6 +682,133 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
         return json.dumps(delete_db_instance, default=str)
 
 
+class RdsStartDbOperator(RdsBaseOperator):
+    """
+    Starts an RDS DB instance / cluster
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:RdsStartDbOperator`
+
+    :param db_identifier: The AWS identifier of the DB to start
+    :param db_type: Type of the DB - either "instance" or "cluster" (default: 
"instance")
+    :param aws_conn_id: The Airflow connection used for AWS credentials. 
(default: "aws_default")
+    :param wait_for_completion:  If True, waits for DB to start. (default: 
True)
+    """
+
+    template_fields = ("db_identifier", "db_type")
+
+    def __init__(
+        self,
+        *,
+        db_identifier: str,
+        db_type: RdsDbType | str = RdsDbType.INSTANCE,
+        aws_conn_id: str = "aws_default",
+        wait_for_completion: bool = True,
+        **kwargs,
+    ):
+        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        self.db_identifier = db_identifier
+        self.db_type = db_type
+        self.wait_for_completion = wait_for_completion
+
+    def execute(self, context: Context) -> str:
+        self.db_type = RdsDbType(self.db_type)
+        start_db_response = self._start_db()
+        if self.wait_for_completion:
+            self._wait_until_db_available()
+        return json.dumps(start_db_response, default=str)
+
+    def _start_db(self):
+        self.log.info("Starting DB %s '%s'", self.db_type.value, 
self.db_identifier)
+        if self.db_type == RdsDbType.INSTANCE:
+            response = 
self.hook.conn.start_db_instance(DBInstanceIdentifier=self.db_identifier)
+        else:
+            response = 
self.hook.conn.start_db_cluster(DBClusterIdentifier=self.db_identifier)
+        return response
+
+    def _wait_until_db_available(self):
+        self.log.info("Waiting for DB %s to reach 'available' state", 
self.db_type.value)
+        if self.db_type == RdsDbType.INSTANCE:
+            
self.hook.conn.get_waiter("db_instance_available").wait(DBInstanceIdentifier=self.db_identifier)
+        else:
+            
self.hook.conn.get_waiter("db_cluster_available").wait(DBClusterIdentifier=self.db_identifier)
+
+
+class RdsStopDbOperator(RdsBaseOperator):
+    """
+    Stops an RDS DB instance / cluster
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:RdsStopDbOperator`
+
+    :param db_identifier: The AWS identifier of the DB to stop
+    :param db_type: Type of the DB - either "instance" or "cluster" (default: 
"instance")
+    :param db_snapshot_identifier: The instance identifier of the DB Snapshot 
to create before
+        stopping the DB instance. The default value (None) skips snapshot 
creation. This
+        parameter is ignored when ``db_type`` is "cluster"
+    :param aws_conn_id: The Airflow connection used for AWS credentials. 
(default: "aws_default")
+    :param wait_for_completion:  If True, waits for DB to stop. (default: True)
+    """
+
+    template_fields = ("db_identifier", "db_snapshot_identifier", "db_type")
+
+    def __init__(
+        self,
+        *,
+        db_identifier: str,
+        db_type: RdsDbType | str = RdsDbType.INSTANCE,
+        db_snapshot_identifier: str | None = None,
+        aws_conn_id: str = "aws_default",
+        wait_for_completion: bool = True,
+        **kwargs,
+    ):
+        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        self.db_identifier = db_identifier
+        self.db_type = db_type
+        self.db_snapshot_identifier = db_snapshot_identifier
+        self.wait_for_completion = wait_for_completion
+
+    def execute(self, context: Context) -> str:
+        self.db_type = RdsDbType(self.db_type)
+        stop_db_response = self._stop_db()
+        if self.wait_for_completion:
+            self._wait_until_db_stopped()
+        return json.dumps(stop_db_response, default=str)
+
+    def _stop_db(self):
+        self.log.info("Stopping DB %s '%s'", self.db_type.value, 
self.db_identifier)
+        if self.db_type == RdsDbType.INSTANCE:
+            conn_params = {"DBInstanceIdentifier": self.db_identifier}
+            # The db snapshot parameter is optional, but the AWS SDK raises an 
exception
+            # if passed a null value. Only set snapshot id if value is present.
+            if self.db_snapshot_identifier:
+                conn_params["DBSnapshotIdentifier"] = 
self.db_snapshot_identifier
+            response = self.hook.conn.stop_db_instance(**conn_params)
+        else:
+            if self.db_snapshot_identifier:
+                self.log.warning(
+                    "'db_snapshot_identifier' does not apply to db clusters. "
+                    "Remove it to silence this warning."
+                )
+            response = 
self.hook.conn.stop_db_cluster(DBClusterIdentifier=self.db_identifier)
+        return response
+
+    def _wait_until_db_stopped(self):
+        self.log.info("Waiting for DB %s to reach 'stopped' state", 
self.db_type.value)
+        wait_statuses = ["stopping"]
+        ok_statuses = ["stopped"]
+        if self.db_type == RdsDbType.INSTANCE:
+            self._await_status(
+                "db_instance", self.db_identifier, 
wait_statuses=wait_statuses, ok_statuses=ok_statuses
+            )
+        else:
+            self._await_status(
+                "db_cluster", self.db_identifier, wait_statuses=wait_statuses, 
ok_statuses=ok_statuses
+            )
+
+
 __all__ = [
     "RdsCreateDbSnapshotOperator",
     "RdsCopyDbSnapshotOperator",
@@ -682,4 +819,6 @@ __all__ = [
     "RdsCancelExportTaskOperator",
     "RdsCreateDbInstanceOperator",
     "RdsDeleteDbInstanceOperator",
+    "RdsStartDbOperator",
+    "RdsStopDbOperator",
 ]
diff --git a/airflow/providers/amazon/aws/sensors/rds.py 
b/airflow/providers/amazon/aws/sensors/rds.py
index 6240f10a91..6af1ccdc35 100644
--- a/airflow/providers/amazon/aws/sensors/rds.py
+++ b/airflow/providers/amazon/aws/sensors/rds.py
@@ -169,7 +169,7 @@ class RdsDbSensor(RdsBaseSensor):
         For more information on how to use this sensor, take a look at the 
guide:
         :ref:`howto/sensor:RdsDbSensor`
 
-    :param db_type: Type of the DB - either "instance" or "cluster"
+    :param db_type: Type of the DB - either "instance" or "cluster" (default: 
'instance')
     :param db_identifier: The AWS identifier for the DB
     :param target_statuses: Target status of DB
     """
diff --git a/docs/apache-airflow-providers-amazon/operators/rds.rst 
b/docs/apache-airflow-providers-amazon/operators/rds.rst
index d0752a014d..47f00966bc 100644
--- a/docs/apache-airflow-providers-amazon/operators/rds.rst
+++ b/docs/apache-airflow-providers-amazon/operators/rds.rst
@@ -166,6 +166,35 @@ To delete an AWS DB instance you can use
     :start-after: [START howto_operator_rds_delete_db_instance]
     :end-before: [END howto_operator_rds_delete_db_instance]
 
+.. _howto/operator:RdsStartDbOperator:
+
+Start a database instance or cluster
+====================================
+
+To start an Amazon RDS DB instance or cluster you can use
+:class:`~airflow.providers.amazon.aws.operators.rds.RdsStartDbOperator`.
+
+.. exampleinclude:: 
/../../tests/system/providers/amazon/aws/example_rds_instance.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_rds_start_db]
+    :end-before: [END howto_operator_rds_start_db]
+
+
+.. _howto/operator:RdsStopDbOperator:
+
+Stop a database instance or cluster
+===================================
+
+To stop an Amazon RDS DB instance or cluster you can use
+:class:`~airflow.providers.amazon.aws.operators.rds.RdsStopDbOperator`.
+
+.. exampleinclude:: 
/../../tests/system/providers/amazon/aws/example_rds_instance.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_rds_stop_db]
+    :end-before: [END howto_operator_rds_stop_db]
+
 Sensors
 -------
 
diff --git a/tests/providers/amazon/aws/operators/test_rds.py 
b/tests/providers/amazon/aws/operators/test_rds.py
index b4b3587e50..a4092e7fc5 100644
--- a/tests/providers/amazon/aws/operators/test_rds.py
+++ b/tests/providers/amazon/aws/operators/test_rds.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import logging
 from unittest.mock import patch
 
 import pytest
@@ -35,7 +36,9 @@ from airflow.providers.amazon.aws.operators.rds import (
     RdsDeleteDbInstanceOperator,
     RdsDeleteDbSnapshotOperator,
     RdsDeleteEventSubscriptionOperator,
+    RdsStartDbOperator,
     RdsStartExportTaskOperator,
+    RdsStopDbOperator,
 )
 from airflow.utils import timezone
 
@@ -767,3 +770,144 @@ class TestRdsDeleteDbInstanceOperator:
         with pytest.raises(self.hook.conn.exceptions.ClientError):
             
self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
         assert mock_await_status.not_called()
+
+
[email protected](mock_rds is None, reason="mock_rds package not present")
+class TestRdsStopDbOperator:
+    @classmethod
+    def setup_class(cls):
+        cls.dag = DAG("test_dag", default_args={"owner": "airflow", 
"start_date": DEFAULT_DATE})
+        cls.hook = RdsHook(aws_conn_id=AWS_CONN, region_name="us-east-1")
+        _patch_hook_get_connection(cls.hook)
+
+    @classmethod
+    def teardown_class(cls):
+        del cls.dag
+        del cls.hook
+
+    @mock_rds
+    @patch.object(RdsBaseOperator, "_await_status")
+    def test_stop_db_instance(self, mock_await_status):
+        _create_db_instance(self.hook)
+        stop_db_instance = RdsStopDbOperator(task_id="test_stop_db_instance", 
db_identifier=DB_INSTANCE_NAME)
+        _patch_hook_get_connection(stop_db_instance.hook)
+        stop_db_instance.execute(None)
+        result = 
self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
+        status = result["DBInstances"][0]["DBInstanceStatus"]
+        assert status == "stopped"
+        mock_await_status.assert_called()
+
+    @mock_rds
+    @patch.object(RdsBaseOperator, "_await_status")
+    def test_stop_db_instance_no_wait(self, mock_await_status):
+        _create_db_instance(self.hook)
+        stop_db_instance = RdsStopDbOperator(
+            task_id="test_stop_db_instance_no_wait", 
db_identifier=DB_INSTANCE_NAME, wait_for_completion=False
+        )
+        _patch_hook_get_connection(stop_db_instance.hook)
+        stop_db_instance.execute(None)
+        result = 
self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
+        status = result["DBInstances"][0]["DBInstanceStatus"]
+        assert status == "stopped"
+        mock_await_status.assert_not_called()
+
+    @mock_rds
+    def test_stop_db_instance_create_snapshot(self):
+        _create_db_instance(self.hook)
+        stop_db_instance = RdsStopDbOperator(
+            task_id="test_stop_db_instance_create_snapshot",
+            db_identifier=DB_INSTANCE_NAME,
+            db_snapshot_identifier=DB_INSTANCE_SNAPSHOT,
+        )
+        _patch_hook_get_connection(stop_db_instance.hook)
+        stop_db_instance.execute(None)
+
+        describe_result = 
self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
+        status = describe_result["DBInstances"][0]['DBInstanceStatus']
+        assert status == "stopped"
+
+        snapshot_result = 
self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=DB_INSTANCE_SNAPSHOT)
+        instance_snapshots = snapshot_result.get("DBSnapshots")
+        assert instance_snapshots
+        assert len(instance_snapshots) == 1
+
+    @mock_rds
+    @patch.object(RdsBaseOperator, "_await_status")
+    def test_stop_db_cluster(self, mock_await_status):
+        _create_db_cluster(self.hook)
+        stop_db_cluster = RdsStopDbOperator(
+            task_id="test_stop_db_cluster", db_identifier=DB_CLUSTER_NAME, 
db_type="cluster"
+        )
+        _patch_hook_get_connection(stop_db_cluster.hook)
+        stop_db_cluster.execute(None)
+
+        describe_result = 
self.hook.conn.describe_db_clusters(DBClusterIdentifier=DB_CLUSTER_NAME)
+        status = describe_result["DBClusters"][0]["Status"]
+        assert status == "stopped"
+
+    @mock_rds
+    def test_stop_db_cluster_create_snapshot_logs_warning_message(self, 
caplog):
+        _create_db_cluster(self.hook)
+        stop_db_cluster = RdsStopDbOperator(
+            task_id="test_stop_db_cluster",
+            db_identifier=DB_CLUSTER_NAME,
+            db_type="cluster",
+            db_snapshot_identifier=DB_CLUSTER_SNAPSHOT,
+        )
+        _patch_hook_get_connection(stop_db_cluster.hook)
+        with caplog.at_level(logging.WARNING, logger=stop_db_cluster.log.name):
+            stop_db_cluster.execute(None)
+        warning_message = (
+            "'db_snapshot_identifier' does not apply to db clusters. Remove it 
to silence this warning."
+        )
+        assert warning_message in caplog.text
+
+
[email protected](mock_rds is None, reason="mock_rds package not present")
+class TestRdsStartDbOperator:
+    @classmethod
+    def setup_class(cls):
+        cls.dag = DAG("test_dag", default_args={"owner": "airflow", 
"start_date": DEFAULT_DATE})
+        cls.hook = RdsHook(aws_conn_id=AWS_CONN, region_name="us-east-1")
+        _patch_hook_get_connection(cls.hook)
+
+    @classmethod
+    def teardown_class(cls):
+        del cls.dag
+        del cls.hook
+
+    @mock_rds
+    def test_start_db_instance(self):
+        _create_db_instance(self.hook)
+        self.hook.conn.stop_db_instance(DBInstanceIdentifier=DB_INSTANCE_NAME)
+        result_before = 
self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
+        status_before = result_before["DBInstances"][0]["DBInstanceStatus"]
+        assert status_before == "stopped"
+
+        start_db_instance = RdsStartDbOperator(
+            task_id="test_start_db_instance", db_identifier=DB_INSTANCE_NAME
+        )
+        _patch_hook_get_connection(start_db_instance.hook)
+        start_db_instance.execute(None)
+
+        result_after = 
self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
+        status_after = result_after["DBInstances"][0]["DBInstanceStatus"]
+        assert status_after == "available"
+
+    @mock_rds
+    def test_start_db_cluster(self):
+        _create_db_cluster(self.hook)
+        self.hook.conn.stop_db_cluster(DBClusterIdentifier=DB_CLUSTER_NAME)
+        result_before = 
self.hook.conn.describe_db_clusters(DBClusterIdentifier=DB_CLUSTER_NAME)
+        status_before = result_before["DBClusters"][0]["Status"]
+        assert status_before == "stopped"
+
+        start_db_cluster = RdsStartDbOperator(
+            task_id="test_start_db_cluster", db_identifier=DB_CLUSTER_NAME, 
db_type="cluster"
+        )
+        _patch_hook_get_connection(start_db_cluster.hook)
+        start_db_cluster.execute(None)
+
+        result_after = 
self.hook.conn.describe_db_clusters(DBClusterIdentifier=DB_CLUSTER_NAME)
+        status_after = result_after["DBClusters"][0]["Status"]
+        assert status_after == "available"
diff --git a/tests/system/providers/amazon/aws/example_rds_instance.py 
b/tests/system/providers/amazon/aws/example_rds_instance.py
index c27f7f99ee..b31be7c007 100644
--- a/tests/system/providers/amazon/aws/example_rds_instance.py
+++ b/tests/system/providers/amazon/aws/example_rds_instance.py
@@ -23,6 +23,8 @@ from airflow.models.baseoperator import chain
 from airflow.providers.amazon.aws.operators.rds import (
     RdsCreateDbInstanceOperator,
     RdsDeleteDbInstanceOperator,
+    RdsStartDbOperator,
+    RdsStopDbOperator,
 )
 from airflow.providers.amazon.aws.sensors.rds import RdsDbSensor
 from airflow.utils.trigger_rule import TriggerRule
@@ -71,6 +73,20 @@ with DAG(
     )
     # [END howto_sensor_rds_instance]
 
+    # [START howto_operator_rds_stop_db]
+    stop_db_instance = RdsStopDbOperator(
+        task_id="stop_db_instance",
+        db_identifier=rds_db_identifier,
+    )
+    # [END howto_operator_rds_stop_db]
+
+    # [START howto_operator_rds_start_db]
+    start_db_instance = RdsStartDbOperator(
+        task_id="start_db_instance",
+        db_identifier=rds_db_identifier,
+    )
+    # [END howto_operator_rds_start_db]
+
     # [START howto_operator_rds_delete_db_instance]
     delete_db_instance = RdsDeleteDbInstanceOperator(
         task_id='delete_db_instance',
@@ -88,6 +104,8 @@ with DAG(
         # TEST BODY
         create_db_instance,
         await_db_instance,
+        stop_db_instance,
+        start_db_instance,
         delete_db_instance,
     )
 

Reply via email to