This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 43652d5bbc3 Update Amazon RDS Operators and Sensors to inherit AWS Base classes (#48872)
43652d5bbc3 is described below

commit 43652d5bbc3735868b05233aebb30c48c8d2f5c9
Author: ellisms <[email protected]>
AuthorDate: Mon Apr 7 11:41:14 2025 -0400

    Update Amazon RDS Operators and Sensors to inherit AWS Base classes (#48872)
    
    * Update RDS operator to inherit AwsBaseOperator
    
    ---------
    
    Co-authored-by: mse139 <[email protected]>
---
 providers/amazon/docs/operators/rds.rst            |   5 +
 .../airflow/providers/amazon/aws/operators/rds.py  | 101 +++++++++++++++++----
 .../airflow/providers/amazon/aws/sensors/rds.py    |  43 +++++----
 .../tests/unit/amazon/aws/operators/test_rds.py    |  18 +++-
 .../tests/unit/amazon/aws/sensors/test_rds.py      |  30 ++++++
 5 files changed, 158 insertions(+), 39 deletions(-)

diff --git a/providers/amazon/docs/operators/rds.rst b/providers/amazon/docs/operators/rds.rst
index 841dd90a6a3..916c61d8548 100644
--- a/providers/amazon/docs/operators/rds.rst
+++ b/providers/amazon/docs/operators/rds.rst
@@ -29,6 +29,11 @@ Prerequisite Tasks
 
 .. include:: ../_partials/prerequisite_tasks.rst
 
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
 Operators
 ---------
 
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/rds.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/rds.py
index 18f227d85ae..c7841a08ce1 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/rds.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/rds.py
@@ -20,19 +20,19 @@ from __future__ import annotations
 import json
 from collections.abc import Sequence
 from datetime import timedelta
-from functools import cached_property
 from typing import TYPE_CHECKING, Any
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.rds import (
     RdsDbAvailableTrigger,
     RdsDbDeletedTrigger,
     RdsDbStoppedTrigger,
 )
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
 from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
@@ -44,9 +44,10 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class RdsBaseOperator(BaseOperator):
+class RdsBaseOperator(AwsBaseOperator[RdsHook]):
     """Base operator that implements common functions for all operators."""
 
+    aws_hook_class = RdsHook
     ui_color = "#eeaa88"
     ui_fgcolor = "#ffffff"
 
@@ -63,10 +64,6 @@ class RdsBaseOperator(BaseOperator):
 
         self._await_interval = 60  # seconds
 
-    @cached_property
-    def hook(self) -> RdsHook:
-        return RdsHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name)
-
     def execute(self, context: Context) -> str:
         """Different implementations for snapshots, tasks and events."""
         raise NotImplementedError
@@ -92,9 +89,19 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
     :param tags: A dictionary of tags or a list of tags in format `[{"Key": 
"...", "Value": "..."},]`
         `USER Tagging 
<https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
     :param wait_for_completion:  If True, waits for creation of the DB 
snapshot to complete. (default: True)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+         If this is ``None`` or empty then the default boto3 behaviour is 
used. If
+         running Airflow in a distributed manner and aws_conn_id is None or
+         empty, then default boto3 configuration would be used (and must be
+         maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = ("db_snapshot_identifier", "db_identifier", "tags")
+    template_fields = aws_template_fields("db_snapshot_identifier", 
"db_identifier", "tags")
 
     def __init__(
         self,
@@ -167,9 +174,14 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
         Only when db_type='instance'
     :param source_region: The ID of the region that contains the snapshot to 
be copied
     :param wait_for_completion:  If True, waits for snapshot copy to complete. 
(default: True)
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = (
+    template_fields = aws_template_fields(
         "source_db_snapshot_identifier",
         "target_db_snapshot_identifier",
         "tags",
@@ -260,9 +272,16 @@ class RdsDeleteDbSnapshotOperator(RdsBaseOperator):
 
     :param db_type: Type of the DB - either "instance" or "cluster"
     :param db_snapshot_identifier: The identifier for the DB instance or DB 
cluster snapshot
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = ("db_snapshot_identifier",)
+    template_fields = aws_template_fields(
+        "db_snapshot_identifier",
+    )
 
     def __init__(
         self,
@@ -319,9 +338,14 @@ class RdsStartExportTaskOperator(RdsBaseOperator):
     :param wait_for_completion:  If True, waits for the DB snapshot export to 
complete. (default: True)
     :param waiter_interval: The number of seconds to wait before checking the 
export status. (default: 30)
     :param waiter_max_attempts: The number of attempts to make before failing. 
(default: 40)
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = (
+    template_fields = aws_template_fields(
         "export_task_identifier",
         "source_arn",
         "s3_bucket_name",
@@ -394,9 +418,16 @@ class RdsCancelExportTaskOperator(RdsBaseOperator):
     :param wait_for_completion:  If True, waits for DB snapshot export to 
cancel. (default: True)
     :param check_interval: The amount of time in seconds to wait between 
attempts
     :param max_attempts: The maximum number of attempts to be made
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = ("export_task_identifier",)
+    template_fields = aws_template_fields(
+        "export_task_identifier",
+    )
 
     def __init__(
         self,
@@ -450,9 +481,14 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
     :param tags: A dictionary of tags or a list of tags in format `[{"Key": 
"...", "Value": "..."},]`
         `USER Tagging 
<https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
     :param wait_for_completion:  If True, waits for creation of the 
subscription to complete. (default: True)
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = (
+    template_fields = aws_template_fields(
         "subscription_name",
         "sns_topic_arn",
         "source_type",
@@ -513,9 +549,16 @@ class RdsDeleteEventSubscriptionOperator(RdsBaseOperator):
         :ref:`howto/operator:RdsDeleteEventSubscriptionOperator`
 
     :param subscription_name: The name of the RDS event notification 
subscription you want to delete
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = ("subscription_name",)
+    template_fields = aws_template_fields(
+        "subscription_name",
+    )
 
     def __init__(
         self,
@@ -560,9 +603,16 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
     :param deferrable: If True, the operator will wait asynchronously for the 
DB instance to be created.
         This implies waiting for completion. This mode requires aiobotocore 
module to be installed.
         (default: False)
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = ("db_instance_identifier", "db_instance_class", 
"engine", "rds_kwargs")
+    template_fields = aws_template_fields(
+        "db_instance_identifier", "db_instance_class", "engine", "rds_kwargs"
+    )
 
     def __init__(
         self,
@@ -652,9 +702,14 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
     :param deferrable: If True, the operator will wait asynchronously for the 
DB instance to be created.
         This implies waiting for completion. This mode requires aiobotocore 
module to be installed.
         (default: False)
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = ("db_instance_identifier", "rds_kwargs")
+    template_fields = aws_template_fields("db_instance_identifier", 
"rds_kwargs")
 
     def __init__(
         self,
@@ -735,9 +790,14 @@ class RdsStartDbOperator(RdsBaseOperator):
     :param waiter_max_attempts: The maximum number of attempts to check DB 
instance state
     :param deferrable: If True, the operator will wait asynchronously for the 
DB instance to be created.
         This implies waiting for completion. This mode requires aiobotocore 
module to be installed.
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = ("db_identifier", "db_type")
+    template_fields = aws_template_fields("db_identifier", "db_type")
 
     def __init__(
         self,
@@ -832,9 +892,14 @@ class RdsStopDbOperator(RdsBaseOperator):
     :param waiter_max_attempts: The maximum number of attempts to check DB 
instance state
     :param deferrable: If True, the operator will wait asynchronously for the 
DB instance to be created.
         This implies waiting for completion. This mode requires aiobotocore 
module to be installed.
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = ("db_identifier", "db_snapshot_identifier", "db_type")
+    template_fields = aws_template_fields("db_identifier", 
"db_snapshot_identifier", "db_type")
 
     def __init__(
         self,
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/rds.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/rds.py
index 8e290e61a87..03a170d29a1 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/rds.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/rds.py
@@ -17,36 +17,30 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
-from functools import cached_property
 from typing import TYPE_CHECKING
 
 from airflow.exceptions import AirflowException, AirflowNotFoundException
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
-from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class RdsBaseSensor(BaseSensorOperator):
+class RdsBaseSensor(AwsBaseSensor[RdsHook]):
     """Base operator that implements common functions for all sensors."""
 
+    aws_hook_class = RdsHook
     ui_color = "#ddbb77"
     ui_fgcolor = "#ffffff"
 
-    def __init__(
-        self, *args, aws_conn_id: str | None = "aws_conn_id", hook_params: 
dict | None = None, **kwargs
-    ):
+    def __init__(self, *args, hook_params: dict | None = None, **kwargs):
         self.hook_params = hook_params or {}
-        self.aws_conn_id = aws_conn_id
         self.target_statuses: list[str] = []
         super().__init__(*args, **kwargs)
 
-    @cached_property
-    def hook(self):
-        return RdsHook(aws_conn_id=self.aws_conn_id, **self.hook_params)
-
 
 class RdsSnapshotExistenceSensor(RdsBaseSensor):
     """
@@ -59,9 +53,19 @@ class RdsSnapshotExistenceSensor(RdsBaseSensor):
     :param db_type: Type of the DB - either "instance" or "cluster"
     :param db_snapshot_identifier: The identifier for the DB snapshot
     :param target_statuses: Target status of snapshot
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+         If this is ``None`` or empty then the default boto3 behaviour is 
used. If
+         running Airflow in a distributed manner and aws_conn_id is None or
+         empty, then default boto3 configuration would be used (and must be
+         maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = (
+    template_fields: Sequence[str] = aws_template_fields(
         "db_snapshot_identifier",
         "target_statuses",
     )
@@ -72,10 +76,9 @@ class RdsSnapshotExistenceSensor(RdsBaseSensor):
         db_type: str,
         db_snapshot_identifier: str,
         target_statuses: list[str] | None = None,
-        aws_conn_id: str | None = "aws_conn_id",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
         self.db_type = RdsDbType(db_type)
         self.db_snapshot_identifier = db_snapshot_identifier
         self.target_statuses = target_statuses or ["available"]
@@ -107,7 +110,9 @@ class RdsExportTaskExistenceSensor(RdsBaseSensor):
     :param error_statuses: Target error status of export task to fail the 
sensor
     """
 
-    template_fields: Sequence[str] = ("export_task_identifier", 
"target_statuses", "error_statuses")
+    template_fields: Sequence[str] = aws_template_fields(
+        "export_task_identifier", "target_statuses", "error_statuses"
+    )
 
     def __init__(
         self,
@@ -115,10 +120,9 @@ class RdsExportTaskExistenceSensor(RdsBaseSensor):
         export_task_identifier: str,
         target_statuses: list[str] | None = None,
         error_statuses: list[str] | None = None,
-        aws_conn_id: str | None = "aws_default",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
 
         self.export_task_identifier = export_task_identifier
         self.target_statuses = target_statuses or [
@@ -159,7 +163,7 @@ class RdsDbSensor(RdsBaseSensor):
     :param target_statuses: Target status of DB
     """
 
-    template_fields: Sequence[str] = (
+    template_fields: Sequence[str] = aws_template_fields(
         "db_identifier",
         "db_type",
         "target_statuses",
@@ -171,10 +175,9 @@ class RdsDbSensor(RdsBaseSensor):
         db_identifier: str,
         db_type: RdsDbType | str = RdsDbType.INSTANCE,
         target_statuses: list[str] | None = None,
-        aws_conn_id: str | None = "aws_default",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
         self.db_identifier = db_identifier
         self.target_statuses = target_statuses or ["available"]
         self.db_type = db_type
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py b/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py
index c6537c9391b..6c22c75e8af 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py
@@ -54,6 +54,8 @@ DEFAULT_DATE = timezone.datetime(2019, 1, 1)
 
 AWS_CONN = "amazon_default"
 
+REGION = "us-east-1"
+
 DB_INSTANCE_NAME = "my-db-instance"
 DB_CLUSTER_NAME = "my-db-cluster"
 
@@ -282,6 +284,7 @@ class TestRdsCreateDbSnapshotOperator:
             db_snapshot_identifier=DB_INSTANCE_SNAPSHOT,
             db_identifier=DB_INSTANCE_NAME,
             aws_conn_id=AWS_CONN,
+            region_name=REGION,
         )
         validate_template_fields(operator)
 
@@ -410,6 +413,7 @@ class TestRdsCopyDbSnapshotOperator:
             source_db_snapshot_identifier=DB_CLUSTER_SNAPSHOT,
             target_db_snapshot_identifier=DB_CLUSTER_SNAPSHOT_COPY,
             aws_conn_id=AWS_CONN,
+            region_name=REGION,
         )
         validate_template_fields(operator)
 
@@ -527,6 +531,7 @@ class TestRdsDeleteDbSnapshotOperator:
             db_snapshot_identifier=DB_CLUSTER_SNAPSHOT,
             aws_conn_id=AWS_CONN,
             wait_for_completion=False,
+            region_name=REGION,
         )
         validate_template_fields(operator)
 
@@ -610,6 +615,7 @@ class TestRdsStartExportTaskOperator:
             s3_bucket_name=EXPORT_TASK_BUCKET,
             aws_conn_id=AWS_CONN,
             wait_for_completion=False,
+            region_name=REGION,
         )
         validate_template_fields(operator)
 
@@ -682,6 +688,7 @@ class TestRdsCancelExportTaskOperator:
             task_id="test_cancel",
             export_task_identifier=EXPORT_TASK_NAME,
             aws_conn_id=AWS_CONN,
+            region_name=REGION,
         )
         validate_template_fields(operator)
 
@@ -759,6 +766,7 @@ class TestRdsCreateEventSubscriptionOperator:
             source_type="db-instance",
             source_ids=[DB_INSTANCE_NAME],
             aws_conn_id=AWS_CONN,
+            region_name=REGION,
         )
         validate_template_fields(operator)
 
@@ -800,6 +808,7 @@ class TestRdsDeleteEventSubscriptionOperator:
             task_id="test_delete",
             subscription_name=SUBSCRIPTION_NAME,
             aws_conn_id=AWS_CONN,
+            region_name=REGION,
         )
         validate_template_fields(operator)
 
@@ -879,6 +888,7 @@ class TestRdsCreateDbInstanceOperator:
                 "DBName": DB_INSTANCE_NAME,
             },
             aws_conn_id=AWS_CONN,
+            region_name=REGION,
         )
         validate_template_fields(operator)
 
@@ -949,6 +959,7 @@ class TestRdsDeleteDbInstanceOperator:
             },
             aws_conn_id=AWS_CONN,
             wait_for_completion=False,
+            region_name=REGION,
         )
         validate_template_fields(operator)
 
@@ -1062,6 +1073,7 @@ class TestRdsStopDbOperator:
             db_identifier=DB_CLUSTER_NAME,
             db_type="cluster",
             db_snapshot_identifier=DB_CLUSTER_SNAPSHOT,
+            region_name=REGION,
         )
         validate_template_fields(operator)
 
@@ -1133,6 +1145,10 @@ class TestRdsStartDbOperator:
 
     def test_template_fields(self):
         operator = RdsStartDbOperator(
-            task_id="test_start_db_cluster", db_identifier=DB_CLUSTER_NAME, 
db_type="cluster"
+            region_name=REGION,
+            aws_conn_id=AWS_CONN,
+            task_id="test_start_db_cluster",
+            db_identifier=DB_CLUSTER_NAME,
+            db_type="cluster",
         )
         validate_template_fields(operator)
diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py
index b585ca21c7b..3bdd8f673bf 100644
--- a/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py
+++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py
@@ -33,6 +33,8 @@ from airflow.providers.amazon.aws.sensors.rds import (
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
 from airflow.utils import timezone
 
+from unit.amazon.aws.utils.test_template_fields import validate_template_fields
+
 DEFAULT_DATE = timezone.datetime(2019, 1, 1)
 
 AWS_CONN = "aws_default"
@@ -146,6 +148,16 @@ class TestRdsSnapshotExistenceSensor:
         del cls.dag
         del cls.hook
 
+    def test_template_fields(self):
+        sensor = RdsSnapshotExistenceSensor(
+            task_id="test_template_fields",
+            db_type="instance",
+            db_snapshot_identifier=DB_INSTANCE_SNAPSHOT,
+            aws_conn_id=AWS_CONN,
+            region_name="us-east-1",
+        )
+        validate_template_fields(sensor)
+
     @mock_aws
     def test_db_instance_snapshot_poke_true(self):
         _create_db_instance_snapshot(self.hook)
@@ -209,6 +221,15 @@ class TestRdsExportTaskExistenceSensor:
         del cls.dag
         del cls.hook
 
+    def test_template_fields(self):
+        sensor = RdsExportTaskExistenceSensor(
+            task_id="test_template_fields",
+            export_task_identifier=EXPORT_TASK_NAME,
+            aws_conn_id=AWS_CONN,
+            region_name="us-east-1",
+        )
+        validate_template_fields(sensor)
+
     @mock_aws
     def test_export_task_poke_true(self):
         _create_db_instance_snapshot(self.hook)
@@ -264,6 +285,15 @@ class TestRdsDbSensor:
         del cls.dag
         del cls.hook
 
+    def test_template_fields(self):
+        sensor = RdsDbSensor(
+            task_id="test_template_fields",
+            db_identifier=DB_INSTANCE_NAME,
+            aws_conn_id=AWS_CONN,
+            region_name="us-east-1",
+        )
+        validate_template_fields(sensor)
+
     @mock_aws
     def test_poke_true_instance(self):
         """

Reply via email to