This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f6bda38e20 Convert RDS Export Sample DAG to System Test (AIP-47) 
(#25205)
f6bda38e20 is described below

commit f6bda38e20c721df12e0cc88a27119fe320f2a42
Author: D. Ferruzzi <[email protected]>
AuthorDate: Thu Jul 21 20:35:48 2022 +0000

    Convert RDS Export Sample DAG to System Test (AIP-47) (#25205)
    
    * Convert RDS Export Sample DAG to System Test
    
    * PR Fixes
---
 .../amazon/aws/example_dags/example_rds_export.py  |  71 -------
 airflow/providers/amazon/aws/operators/rds.py      |  94 +++++----
 .../operators/rds.rst                              |   6 +-
 tests/providers/amazon/aws/operators/test_rds.py   | 211 +++++++++++++++++++++
 .../providers/amazon/aws/example_rds_export.py     | 188 ++++++++++++++++++
 5 files changed, 456 insertions(+), 114 deletions(-)

diff --git a/airflow/providers/amazon/aws/example_dags/example_rds_export.py 
b/airflow/providers/amazon/aws/example_dags/example_rds_export.py
deleted file mode 100644
index 1dce580491..0000000000
--- a/airflow/providers/amazon/aws/example_dags/example_rds_export.py
+++ /dev/null
@@ -1,71 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from datetime import datetime
-from os import getenv
-
-from airflow import DAG
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.operators.rds import 
RdsCancelExportTaskOperator, RdsStartExportTaskOperator
-from airflow.providers.amazon.aws.sensors.rds import 
RdsExportTaskExistenceSensor
-
-RDS_EXPORT_TASK_IDENTIFIER = getenv("RDS_EXPORT_TASK_IDENTIFIER", 
"export-task-identifier")
-RDS_EXPORT_SOURCE_ARN = getenv(
-    "RDS_EXPORT_SOURCE_ARN", "arn:aws:rds:<region>:<account 
number>:snapshot:snap-id"
-)
-BUCKET_NAME = getenv("BUCKET_NAME", "bucket-name")
-BUCKET_PREFIX = getenv("BUCKET_PREFIX", "bucket-prefix")
-ROLE_ARN = getenv("ROLE_ARN", "arn:aws:iam::<account number>:role/Role")
-KMS_KEY_ID = getenv("KMS_KEY_ID", "arn:aws:kms:<region>:<account 
number>:key/key-id")
-
-
-with DAG(
-    dag_id='example_rds_export',
-    schedule_interval=None,
-    start_date=datetime(2021, 1, 1),
-    tags=['example'],
-    catchup=False,
-) as dag:
-    # [START howto_operator_rds_start_export_task]
-    start_export = RdsStartExportTaskOperator(
-        task_id='start_export',
-        export_task_identifier=RDS_EXPORT_TASK_IDENTIFIER,
-        source_arn=RDS_EXPORT_SOURCE_ARN,
-        s3_bucket_name=BUCKET_NAME,
-        s3_prefix=BUCKET_PREFIX,
-        iam_role_arn=ROLE_ARN,
-        kms_key_id=KMS_KEY_ID,
-    )
-    # [END howto_operator_rds_start_export_task]
-
-    # [START howto_operator_rds_cancel_export]
-    cancel_export = RdsCancelExportTaskOperator(
-        task_id='cancel_export',
-        export_task_identifier=RDS_EXPORT_TASK_IDENTIFIER,
-    )
-    # [END howto_operator_rds_cancel_export]
-
-    # [START howto_sensor_rds_export_task_existence]
-    export_sensor = RdsExportTaskExistenceSensor(
-        task_id='export_sensor',
-        export_task_identifier=RDS_EXPORT_TASK_IDENTIFIER,
-        target_statuses=['canceled'],
-    )
-    # [END howto_sensor_rds_export_task_existence]
-
-    chain(start_export, cancel_export, export_sensor)
diff --git a/airflow/providers/amazon/aws/operators/rds.py 
b/airflow/providers/amazon/aws/operators/rds.py
index fe38bfed69..6787caf20a 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -70,7 +70,7 @@ class RdsBaseOperator(BaseOperator):
         error_statuses: Optional[List[str]] = None,
     ) -> None:
         """
-        Continuously gets item description from `_describe_item()` and waits 
until:
+        Continuously gets item description from `_describe_item()` and waits 
while:
         - status is in `wait_statuses`
         - status not in `ok_statuses` and `error_statuses`
         """
@@ -117,6 +117,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
     :param db_snapshot_identifier: The identifier for the DB snapshot
     :param tags: A list of tags in format `[{"Key": "something", "Value": 
"something"},]
         `USER Tagging 
<https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
+    :param wait_for_completion:  If True, waits for creation of the DB 
snapshot to complete. (default: True)
     """
 
     template_fields = ("db_snapshot_identifier", "db_identifier", "tags")
@@ -128,6 +129,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
         db_identifier: str,
         db_snapshot_identifier: str,
         tags: Optional[Sequence[TagTypeDef]] = None,
+        wait_for_completion: bool = True,
         aws_conn_id: str = "aws_conn_id",
         **kwargs,
     ):
@@ -136,6 +138,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
         self.db_identifier = db_identifier
         self.db_snapshot_identifier = db_snapshot_identifier
         self.tags = tags or []
+        self.wait_for_completion = wait_for_completion
 
     def execute(self, context: 'Context') -> str:
         self.log.info(
@@ -152,12 +155,8 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
                 Tags=self.tags,
             )
             create_response = json.dumps(create_instance_snap, default=str)
-            self._await_status(
-                'instance_snapshot',
-                self.db_snapshot_identifier,
-                wait_statuses=['creating'],
-                ok_statuses=['available'],
-            )
+            item_type = 'instance_snapshot'
+
         else:
             create_cluster_snap = self.hook.conn.create_db_cluster_snapshot(
                 DBClusterIdentifier=self.db_identifier,
@@ -165,13 +164,15 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
                 Tags=self.tags,
             )
             create_response = json.dumps(create_cluster_snap, default=str)
+            item_type = 'cluster_snapshot'
+
+        if self.wait_for_completion:
             self._await_status(
-                'cluster_snapshot',
+                item_type,
                 self.db_snapshot_identifier,
                 wait_statuses=['creating'],
                 ok_statuses=['available'],
             )
-
         return create_response
 
 
@@ -196,6 +197,7 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
     :param target_custom_availability_zone: The external custom Availability 
Zone identifier for the target
         Only when db_type='instance'
     :param source_region: The ID of the region that contains the snapshot to 
be copied
+    :param wait_for_completion:  If True, waits for snapshot copy to complete. 
(default: True)
     """
 
     template_fields = (
@@ -219,6 +221,7 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
         option_group_name: str = "",
         target_custom_availability_zone: str = "",
         source_region: str = "",
+        wait_for_completion: bool = True,
         aws_conn_id: str = "aws_default",
         **kwargs,
     ):
@@ -234,6 +237,7 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
         self.option_group_name = option_group_name
         self.target_custom_availability_zone = target_custom_availability_zone
         self.source_region = source_region
+        self.wait_for_completion = wait_for_completion
 
     def execute(self, context: 'Context') -> str:
         self.log.info(
@@ -255,12 +259,8 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
                 SourceRegion=self.source_region,
             )
             copy_response = json.dumps(copy_instance_snap, default=str)
-            self._await_status(
-                'instance_snapshot',
-                self.target_db_snapshot_identifier,
-                wait_statuses=['creating'],
-                ok_statuses=['available'],
-            )
+            item_type = 'instance_snapshot'
+
         else:
             copy_cluster_snap = self.hook.conn.copy_db_cluster_snapshot(
                 
SourceDBClusterSnapshotIdentifier=self.source_db_snapshot_identifier,
@@ -272,13 +272,15 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
                 SourceRegion=self.source_region,
             )
             copy_response = json.dumps(copy_cluster_snap, default=str)
+            item_type = 'cluster_snapshot'
+
+        if self.wait_for_completion:
             self._await_status(
-                'cluster_snapshot',
+                item_type,
                 self.target_db_snapshot_identifier,
                 wait_statuses=['copying'],
                 ok_statuses=['available'],
             )
-
         return copy_response
 
 
@@ -341,6 +343,7 @@ class RdsStartExportTaskOperator(RdsBaseOperator):
     :param kms_key_id: The ID of the Amazon Web Services KMS key to use to 
encrypt the snapshot.
     :param s3_prefix: The Amazon S3 bucket prefix to use as the file name and 
path of the exported snapshot.
     :param export_only: The data to be exported from the snapshot.
+    :param wait_for_completion:  If True, waits for the DB snapshot export to 
complete. (default: True)
     """
 
     template_fields = (
@@ -363,6 +366,7 @@ class RdsStartExportTaskOperator(RdsBaseOperator):
         kms_key_id: str,
         s3_prefix: str = '',
         export_only: Optional[List[str]] = None,
+        wait_for_completion: bool = True,
         aws_conn_id: str = "aws_default",
         **kwargs,
     ):
@@ -375,6 +379,7 @@ class RdsStartExportTaskOperator(RdsBaseOperator):
         self.kms_key_id = kms_key_id
         self.s3_prefix = s3_prefix
         self.export_only = export_only or []
+        self.wait_for_completion = wait_for_completion
 
     def execute(self, context: 'Context') -> str:
         self.log.info("Starting export task %s for snapshot %s", 
self.export_task_identifier, self.source_arn)
@@ -389,13 +394,14 @@ class RdsStartExportTaskOperator(RdsBaseOperator):
             ExportOnly=self.export_only,
         )
 
-        self._await_status(
-            'export_task',
-            self.export_task_identifier,
-            wait_statuses=['starting', 'in_progress'],
-            ok_statuses=['complete'],
-            error_statuses=['canceling', 'canceled'],
-        )
+        if self.wait_for_completion:
+            self._await_status(
+                'export_task',
+                self.export_task_identifier,
+                wait_statuses=['starting', 'in_progress'],
+                ok_statuses=['complete'],
+                error_statuses=['canceling', 'canceled'],
+            )
 
         return json.dumps(start_export, default=str)
 
@@ -409,6 +415,7 @@ class RdsCancelExportTaskOperator(RdsBaseOperator):
         :ref:`howto/operator:RdsCancelExportTaskOperator`
 
     :param export_task_identifier: The identifier of the snapshot export task 
to cancel
+    :param wait_for_completion:  If True, waits for the DB snapshot export
task to be canceled. (default: True)
     """
 
     template_fields = ("export_task_identifier",)
@@ -417,12 +424,14 @@ class RdsCancelExportTaskOperator(RdsBaseOperator):
         self,
         *,
         export_task_identifier: str,
+        wait_for_completion: bool = True,
         aws_conn_id: str = "aws_default",
         **kwargs,
     ):
         super().__init__(aws_conn_id=aws_conn_id, **kwargs)
 
         self.export_task_identifier = export_task_identifier
+        self.wait_for_completion = wait_for_completion
 
     def execute(self, context: 'Context') -> str:
         self.log.info("Canceling export task %s", self.export_task_identifier)
@@ -430,12 +439,14 @@ class RdsCancelExportTaskOperator(RdsBaseOperator):
         cancel_export = self.hook.conn.cancel_export_task(
             ExportTaskIdentifier=self.export_task_identifier,
         )
-        self._await_status(
-            'export_task',
-            self.export_task_identifier,
-            wait_statuses=['canceling'],
-            ok_statuses=['canceled'],
-        )
+
+        if self.wait_for_completion:
+            self._await_status(
+                'export_task',
+                self.export_task_identifier,
+                wait_statuses=['canceling'],
+                ok_statuses=['canceled'],
+            )
 
         return json.dumps(cancel_export, default=str)
 
@@ -458,6 +469,7 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
     :param enabled: A value that indicates whether to activate the
subscription (default: True)
     :param tags: A list of tags in format `[{"Key": "something", "Value": 
"something"},]
         `USER Tagging 
<https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
+    :param wait_for_completion:  If True, waits for creation of the 
subscription to complete. (default: True)
     """
 
     template_fields = (
@@ -479,6 +491,7 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
         source_ids: Optional[Sequence[str]] = None,
         enabled: bool = True,
         tags: Optional[Sequence[TagTypeDef]] = None,
+        wait_for_completion: bool = True,
         aws_conn_id: str = "aws_default",
         **kwargs,
     ):
@@ -491,6 +504,7 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
         self.source_ids = source_ids or []
         self.enabled = enabled
         self.tags = tags or []
+        self.wait_for_completion = wait_for_completion
 
     def execute(self, context: 'Context') -> str:
         self.log.info("Creating event subscription '%s' to '%s'", 
self.subscription_name, self.sns_topic_arn)
@@ -504,12 +518,14 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
             Enabled=self.enabled,
             Tags=self.tags,
         )
-        self._await_status(
-            'event_subscription',
-            self.subscription_name,
-            wait_statuses=['creating'],
-            ok_statuses=['active'],
-        )
+
+        if self.wait_for_completion:
+            self._await_status(
+                'event_subscription',
+                self.subscription_name,
+                wait_statuses=['creating'],
+                ok_statuses=['active'],
+            )
 
         return json.dumps(create_subscription, default=str)
 
@@ -566,8 +582,7 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
     :param rds_kwargs: Named arguments to pass to boto3 RDS client function 
``create_db_instance``
         
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_instance
     :param aws_conn_id: The Airflow connection used for AWS credentials.
-    :param wait_for_completion:  Whether or not wait for creation of the DB 
instance to
-        complete. (default: True)
+    :param wait_for_completion:  If True, waits for creation of the DB 
instance to complete. (default: True)
     """
 
     def __init__(
@@ -619,8 +634,7 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
     :param rds_kwargs: Named arguments to pass to boto3 RDS client function 
``delete_db_instance``
         
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.delete_db_instance
     :param aws_conn_id: The Airflow connection used for AWS credentials.
-    :param wait_for_completion:  Whether or not wait for deletion of the DB 
instance to
-        complete. (default: True)
+    :param wait_for_completion:  If True, waits for deletion of the DB 
instance to complete. (default: True)
     """
 
     def __init__(
diff --git a/docs/apache-airflow-providers-amazon/operators/rds.rst 
b/docs/apache-airflow-providers-amazon/operators/rds.rst
index 022b9ca8f1..c6e4614cd2 100644
--- a/docs/apache-airflow-providers-amazon/operators/rds.rst
+++ b/docs/apache-airflow-providers-amazon/operators/rds.rst
@@ -86,7 +86,7 @@ To export an Amazon RDS snapshot to Amazon S3 you can use
 
:class:`~airflow.providers.amazon.aws.operators.rds.RDSStartExportTaskOperator`.
 The provided IAM role must have access to the S3 bucket.
 
-.. exampleinclude:: 
/../../airflow/providers/amazon/aws/example_dags/example_rds_export.py
+.. exampleinclude:: 
/../../tests/system/providers/amazon/aws/example_rds_export.py
     :language: python
     :dedent: 4
     :start-after: [START howto_operator_rds_start_export_task]
@@ -101,7 +101,7 @@ To cancel an Amazon RDS export task to S3 you can use
 
:class:`~airflow.providers.amazon.aws.operators.rds.RDSCancelExportTaskOperator`.
 Any data that has already been written to the S3 bucket isn't removed.
 
-.. exampleinclude:: 
/../../airflow/providers/amazon/aws/example_dags/example_rds_export.py
+.. exampleinclude:: 
/../../tests/system/providers/amazon/aws/example_rds_export.py
     :language: python
     :dedent: 4
     :start-after: [START howto_operator_rds_cancel_export]
@@ -194,7 +194,7 @@ To wait for an Amazon RDS snapshot export task with
specific statuses you can
 
:class:`~airflow.providers.amazon.aws.sensors.rds.RdsExportTaskExistenceSensor`.
 By default, the sensor waits for the existence of a snapshot with status 
``available``.
 
-.. exampleinclude:: 
/../../airflow/providers/amazon/aws/example_dags/example_rds_export.py
+.. exampleinclude:: 
/../../tests/system/providers/amazon/aws/example_rds_export.py
     :language: python
     :dedent: 4
     :start-after: [START howto_sensor_rds_export_task_existence]
diff --git a/tests/providers/amazon/aws/operators/test_rds.py 
b/tests/providers/amazon/aws/operators/test_rds.py
index 355ff529c9..fd78b5d530 100644
--- a/tests/providers/amazon/aws/operators/test_rds.py
+++ b/tests/providers/amazon/aws/operators/test_rds.py
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+from unittest.mock import patch
 
 import pytest
 
@@ -198,6 +199,28 @@ class TestRdsCreateDbSnapshotOperator:
         assert instance_snapshots
         assert len(instance_snapshots) == 1
 
+    @mock_rds
+    @patch.object(RdsBaseOperator, '_await_status')
+    def test_create_db_instance_snapshot_no_wait(self, mock_await_status):
+        _create_db_instance(self.hook)
+        instance_snapshot_operator = RdsCreateDbSnapshotOperator(
+            task_id='test_instance_no_wait',
+            db_type='instance',
+            db_snapshot_identifier=DB_INSTANCE_SNAPSHOT,
+            db_identifier=DB_INSTANCE_NAME,
+            aws_conn_id=AWS_CONN,
+            dag=self.dag,
+            wait_for_completion=False,
+        )
+        instance_snapshot_operator.execute(None)
+
+        result = 
self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=DB_INSTANCE_SNAPSHOT)
+        instance_snapshots = result.get("DBSnapshots")
+
+        assert instance_snapshots
+        assert len(instance_snapshots) == 1
+        assert mock_await_status.not_called()
+
     @mock_rds
     def test_create_db_cluster_snapshot(self):
         _create_db_cluster(self.hook)
@@ -217,6 +240,28 @@ class TestRdsCreateDbSnapshotOperator:
         assert cluster_snapshots
         assert len(cluster_snapshots) == 1
 
+    @mock_rds
+    @patch.object(RdsBaseOperator, '_await_status')
+    def test_create_db_cluster_snapshot_no_wait(self, mock_no_wait):
+        _create_db_cluster(self.hook)
+        cluster_snapshot_operator = RdsCreateDbSnapshotOperator(
+            task_id='test_cluster_no_wait',
+            db_type='cluster',
+            db_snapshot_identifier=DB_CLUSTER_SNAPSHOT,
+            db_identifier=DB_CLUSTER_NAME,
+            aws_conn_id=AWS_CONN,
+            dag=self.dag,
+            wait_for_completion=False,
+        )
+        cluster_snapshot_operator.execute(None)
+
+        result = 
self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=DB_CLUSTER_SNAPSHOT)
+        cluster_snapshots = result.get("DBClusterSnapshots")
+
+        assert cluster_snapshots
+        assert len(cluster_snapshots) == 1
+        assert mock_no_wait.not_called()
+
 
 @pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present')
 class TestRdsCopyDbSnapshotOperator:
@@ -250,6 +295,29 @@ class TestRdsCopyDbSnapshotOperator:
         assert instance_snapshots
         assert len(instance_snapshots) == 1
 
+    @mock_rds
+    @patch.object(RdsBaseOperator, '_await_status')
+    def test_copy_db_instance_snapshot_no_wait(self, mock_await_status):
+        _create_db_instance(self.hook)
+        _create_db_instance_snapshot(self.hook)
+
+        instance_snapshot_operator = RdsCopyDbSnapshotOperator(
+            task_id='test_instance_no_wait',
+            db_type='instance',
+            source_db_snapshot_identifier=DB_INSTANCE_SNAPSHOT,
+            target_db_snapshot_identifier=DB_INSTANCE_SNAPSHOT_COPY,
+            aws_conn_id=AWS_CONN,
+            dag=self.dag,
+            wait_for_completion=False,
+        )
+        instance_snapshot_operator.execute(None)
+        result = 
self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=DB_INSTANCE_SNAPSHOT_COPY)
+        instance_snapshots = result.get("DBSnapshots")
+
+        assert instance_snapshots
+        assert len(instance_snapshots) == 1
+        assert mock_await_status.not_called()
+
     @mock_rds
     def test_copy_db_cluster_snapshot(self):
         _create_db_cluster(self.hook)
@@ -272,6 +340,30 @@ class TestRdsCopyDbSnapshotOperator:
         assert cluster_snapshots
         assert len(cluster_snapshots) == 1
 
+    @mock_rds
+    @patch.object(RdsBaseOperator, '_await_status')
+    def test_copy_db_cluster_snapshot_no_wait(self, mock_await_status):
+        _create_db_cluster(self.hook)
+        _create_db_cluster_snapshot(self.hook)
+
+        cluster_snapshot_operator = RdsCopyDbSnapshotOperator(
+            task_id='test_cluster_no_wait',
+            db_type='cluster',
+            source_db_snapshot_identifier=DB_CLUSTER_SNAPSHOT,
+            target_db_snapshot_identifier=DB_CLUSTER_SNAPSHOT_COPY,
+            aws_conn_id=AWS_CONN,
+            dag=self.dag,
+        )
+        cluster_snapshot_operator.execute(None)
+        result = self.hook.conn.describe_db_cluster_snapshots(
+            DBClusterSnapshotIdentifier=DB_CLUSTER_SNAPSHOT_COPY
+        )
+        cluster_snapshots = result.get("DBClusterSnapshots")
+
+        assert cluster_snapshots
+        assert len(cluster_snapshots) == 1
+        assert mock_await_status.not_called()
+
 
 @pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present')
 class TestRdsDeleteDbSnapshotOperator:
@@ -356,6 +448,33 @@ class TestRdsStartExportTaskOperator:
         assert len(export_tasks) == 1
         assert export_tasks[0]['Status'] == 'complete'
 
+    @mock_rds
+    @patch.object(RdsBaseOperator, '_await_status')
+    def test_start_export_task_no_wait(self, mock_await_status):
+        _create_db_instance(self.hook)
+        _create_db_instance_snapshot(self.hook)
+
+        start_export_operator = RdsStartExportTaskOperator(
+            task_id='test_start_no_wait',
+            export_task_identifier=EXPORT_TASK_NAME,
+            source_arn=EXPORT_TASK_SOURCE,
+            iam_role_arn=EXPORT_TASK_ROLE_ARN,
+            kms_key_id=EXPORT_TASK_KMS,
+            s3_bucket_name=EXPORT_TASK_BUCKET,
+            aws_conn_id=AWS_CONN,
+            dag=self.dag,
+            wait_for_completion=False,
+        )
+        start_export_operator.execute(None)
+
+        result = 
self.hook.conn.describe_export_tasks(ExportTaskIdentifier=EXPORT_TASK_NAME)
+        export_tasks = result.get("ExportTasks")
+
+        assert export_tasks
+        assert len(export_tasks) == 1
+        assert export_tasks[0]['Status'] == 'complete'
+        assert mock_await_status.not_called()
+
 
 @pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present')
 class TestRdsCancelExportTaskOperator:
@@ -390,6 +509,29 @@ class TestRdsCancelExportTaskOperator:
         assert len(export_tasks) == 1
         assert export_tasks[0]['Status'] == 'canceled'
 
+    @mock_rds
+    @patch.object(RdsBaseOperator, '_await_status')
+    def test_cancel_export_task_no_wait(self, mock_await_status):
+        _create_db_instance(self.hook)
+        _create_db_instance_snapshot(self.hook)
+        _start_export_task(self.hook)
+
+        cancel_export_operator = RdsCancelExportTaskOperator(
+            task_id='test_cancel_no_wait',
+            export_task_identifier=EXPORT_TASK_NAME,
+            aws_conn_id=AWS_CONN,
+            dag=self.dag,
+        )
+        cancel_export_operator.execute(None)
+
+        result = 
self.hook.conn.describe_export_tasks(ExportTaskIdentifier=EXPORT_TASK_NAME)
+        export_tasks = result.get("ExportTasks")
+
+        assert export_tasks
+        assert len(export_tasks) == 1
+        assert export_tasks[0]['Status'] == 'canceled'
+        assert mock_await_status.not_called()
+
 
 @pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present')
 class TestRdsCreateEventSubscriptionOperator:
@@ -425,6 +567,30 @@ class TestRdsCreateEventSubscriptionOperator:
         assert len(subscriptions) == 1
         assert subscriptions[0]['Status'] == 'active'
 
+    @mock_rds
+    @patch.object(RdsBaseOperator, '_await_status')
+    def test_create_event_subscription_no_wait(self, mock_await_status):
+        _create_db_instance(self.hook)
+
+        create_subscription_operator = RdsCreateEventSubscriptionOperator(
+            task_id='test_create_no_wait',
+            subscription_name=SUBSCRIPTION_NAME,
+            sns_topic_arn=SUBSCRIPTION_TOPIC,
+            source_type='db-instance',
+            source_ids=[DB_INSTANCE_NAME],
+            aws_conn_id=AWS_CONN,
+            dag=self.dag,
+        )
+        create_subscription_operator.execute(None)
+
+        result = 
self.hook.conn.describe_event_subscriptions(SubscriptionName=SUBSCRIPTION_NAME)
+        subscriptions = result.get("EventSubscriptionsList")
+
+        assert subscriptions
+        assert len(subscriptions) == 1
+        assert subscriptions[0]['Status'] == 'active'
+        assert mock_await_status.not_called()
+
 
 @pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present')
 class TestRdsDeleteEventSubscriptionOperator:
@@ -488,6 +654,30 @@ class TestRdsCreateDbInstanceOperator:
         assert len(db_instances) == 1
         assert db_instances[0]['DBInstanceStatus'] == 'available'
 
+    @mock_rds
+    @patch.object(RdsBaseOperator, '_await_status')
+    def test_create_db_instance_no_wait(self, mock_await_status):
+        create_db_instance_operator = RdsCreateDbInstanceOperator(
+            task_id='test_create_db_instance_no_wait',
+            db_instance_identifier=DB_INSTANCE_NAME,
+            db_instance_class="db.m5.large",
+            engine="postgres",
+            rds_kwargs={
+                "DBName": DB_INSTANCE_NAME,
+            },
+            aws_conn_id=AWS_CONN,
+            dag=self.dag,
+        )
+        create_db_instance_operator.execute(None)
+
+        result = 
self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
+        db_instances = result.get("DBInstances")
+
+        assert db_instances
+        assert len(db_instances) == 1
+        assert db_instances[0]['DBInstanceStatus'] == 'available'
+        assert mock_await_status.not_called()
+
 
 @pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present')
 class TestRdsDeleteDbInstanceOperator:
@@ -518,3 +708,24 @@ class TestRdsDeleteDbInstanceOperator:
 
         with pytest.raises(self.hook.conn.exceptions.ClientError):
             
self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
+
+    @mock_rds
+    @patch.object(RdsBaseOperator, '_await_status')
+    def test_delete_event_subscription_no_wait(self, mock_await_status):
+        _create_db_instance(self.hook)
+
+        delete_db_instance_operator = RdsDeleteDbInstanceOperator(
+            task_id='test_delete_db_instance_no_wait',
+            db_instance_identifier=DB_INSTANCE_NAME,
+            rds_kwargs={
+                "SkipFinalSnapshot": True,
+            },
+            aws_conn_id=AWS_CONN,
+            dag=self.dag,
+            wait_for_completion=False,
+        )
+        delete_db_instance_operator.execute(None)
+
+        with pytest.raises(self.hook.conn.exceptions.ClientError):
+            
self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
+        assert mock_await_status.not_called()
diff --git a/tests/system/providers/amazon/aws/example_rds_export.py 
b/tests/system/providers/amazon/aws/example_rds_export.py
new file mode 100644
index 0000000000..baed5a4bce
--- /dev/null
+++ b/tests/system/providers/amazon/aws/example_rds_export.py
@@ -0,0 +1,188 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime
+
+from airflow import DAG
+from airflow.decorators import task
+from airflow.models.baseoperator import chain
+from airflow.providers.amazon.aws.hooks.rds import RdsHook
+from airflow.providers.amazon.aws.operators.rds import (
+    RdsCancelExportTaskOperator,
+    RdsCreateDbSnapshotOperator,
+    RdsDeleteDbSnapshotOperator,
+    RdsStartExportTaskOperator,
+)
+from airflow.providers.amazon.aws.operators.s3 import S3CreateBucketOperator, 
S3DeleteBucketOperator
+from airflow.providers.amazon.aws.sensors.rds import 
RdsExportTaskExistenceSensor, RdsSnapshotExistenceSensor
+from airflow.utils.trigger_rule import TriggerRule
+from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, 
SystemTestContextBuilder
+
+DAG_ID = 'example_rds_export'
+
+# Externally fetched variables:
+KMS_KEY_ID_KEY = 'KMS_KEY_ID'
+ROLE_ARN_KEY = 'ROLE_ARN'
+
+sys_test_context_task = (
+    
SystemTestContextBuilder().add_variable(KMS_KEY_ID_KEY).add_variable(ROLE_ARN_KEY).build()
+)
+
+
+@task
+def create_rds_instance(db_name: str, instance_name: str) -> None:
+    rds_client = RdsHook().conn
+    rds_client.create_db_instance(
+        DBName=db_name,
+        DBInstanceIdentifier=instance_name,
+        AllocatedStorage=20,
+        DBInstanceClass='db.t3.micro',
+        Engine='postgres',
+        MasterUsername='username',
+        # NEVER store your production password in plaintext in a DAG like this.
+        # Use Airflow Secrets or a secret manager for this in production.
+        MasterUserPassword='rds_password',
+    )
+
+    
rds_client.get_waiter('db_instance_available').wait(DBInstanceIdentifier=instance_name)
+
+
+@task
+def get_snapshot_arn(snapshot_name: str) -> str:
+    result = 
RdsHook().conn.describe_db_snapshots(DBSnapshotIdentifier=snapshot_name)
+    return result['DBSnapshots'][0]['DBSnapshotArn']
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def delete_rds_instance(instance_name) -> None:
+    rds_client = RdsHook().get_conn()
+    rds_client.delete_db_instance(
+        DBInstanceIdentifier=instance_name,
+        SkipFinalSnapshot=True,
+    )
+
+    
rds_client.get_waiter('db_instance_deleted').wait(DBInstanceIdentifier=instance_name)
+
+
+with DAG(
+    dag_id=DAG_ID,
+    schedule_interval='@once',
+    start_date=datetime(2021, 1, 1),
+    tags=['example'],
+    catchup=False,
+) as dag:
+    test_context = sys_test_context_task()
+    env_id = test_context[ENV_ID_KEY]
+
+    bucket_name: str = f'{env_id}-bucket'
+
+    rds_db_name: str = f'{env_id}_db'
+    rds_instance_name: str = f'{env_id}-instance'
+    rds_snapshot_name: str = f'{env_id}-snapshot'
+    rds_export_task_id: str = f'{env_id}-export-task'
+
+    create_bucket = S3CreateBucketOperator(
+        task_id='create_bucket',
+        bucket_name=bucket_name,
+    )
+
+    create_snapshot = RdsCreateDbSnapshotOperator(
+        task_id='create_snapshot',
+        db_type='instance',
+        db_identifier=rds_instance_name,
+        db_snapshot_identifier=rds_snapshot_name,
+    )
+
+    await_snapshot = RdsSnapshotExistenceSensor(
+        task_id='snapshot_sensor',
+        db_type='instance',
+        db_snapshot_identifier=rds_snapshot_name,
+        target_statuses=['available'],
+    )
+
+    snapshot_arn = get_snapshot_arn(rds_snapshot_name)
+
+    # [START howto_operator_rds_start_export_task]
+    start_export = RdsStartExportTaskOperator(
+        task_id='start_export',
+        export_task_identifier=rds_export_task_id,
+        source_arn=snapshot_arn,
+        s3_bucket_name=bucket_name,
+        s3_prefix='rds-test',
+        iam_role_arn=test_context[ROLE_ARN_KEY],
+        kms_key_id=test_context[KMS_KEY_ID_KEY],
+    )
+    # [END howto_operator_rds_start_export_task]
+
+    # [START howto_operator_rds_cancel_export]
+    cancel_export = RdsCancelExportTaskOperator(
+        task_id='cancel_export',
+        export_task_identifier=rds_export_task_id,
+    )
+    # [END howto_operator_rds_cancel_export]
+
+    # [START howto_sensor_rds_export_task_existence]
+    export_sensor = RdsExportTaskExistenceSensor(
+        task_id='export_sensor',
+        export_task_identifier=rds_export_task_id,
+        target_statuses=['canceled'],
+    )
+    # [END howto_sensor_rds_export_task_existence]
+
+    delete_snapshot = RdsDeleteDbSnapshotOperator(
+        task_id='delete_snapshot',
+        trigger_rule=TriggerRule.ALL_DONE,
+        db_type='instance',
+        db_snapshot_identifier=rds_snapshot_name,
+    )
+
+    delete_bucket = S3DeleteBucketOperator(
+        task_id='delete_bucket',
+        trigger_rule=TriggerRule.ALL_DONE,
+        bucket_name=bucket_name,
+        force_delete=True,
+    )
+
+    chain(
+        # TEST SETUP
+        test_context,
+        create_bucket,
+        create_rds_instance(rds_db_name, rds_instance_name),
+        create_snapshot,
+        await_snapshot,
+        snapshot_arn,
+        # TEST BODY
+        start_export,
+        cancel_export,
+        export_sensor,
+        # TEST TEARDOWN
+        delete_snapshot,
+        delete_bucket,
+        delete_rds_instance(rds_instance_name),
+    )
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: 
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)

Reply via email to