This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1455a3babb Use base aws classes in AWS CloudFormation Operators/Sensors (#36771)
1455a3babb is described below

commit 1455a3babb1bf4b890562a65610b33c0db206f69
Author: Andrey Anshin <[email protected]>
AuthorDate: Sun Jan 14 17:23:02 2024 +0400

    Use base aws classes in AWS CloudFormation Operators/Sensors (#36771)
---
 .../amazon/aws/operators/cloud_formation.py        | 63 ++++++++++--------
 .../amazon/aws/sensors/cloud_formation.py          | 55 ++++++++--------
 .../operators/cloudformation.rst                   |  5 ++
 .../amazon/aws/operators/test_cloud_formation.py   | 53 +++++++++++++++
 .../amazon/aws/sensors/test_cloud_formation.py     | 75 ++++++++++++++++++++--
 5 files changed, 194 insertions(+), 57 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/cloud_formation.py 
b/airflow/providers/amazon/aws/operators/cloud_formation.py
index c6963520f8..b24ccf05f4 100644
--- a/airflow/providers/amazon/aws/operators/cloud_formation.py
+++ b/airflow/providers/amazon/aws/operators/cloud_formation.py
@@ -15,66 +15,79 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""This module contains CloudFormation create/delete stack operators."""
+"""This module contains AWS CloudFormation create/delete stack operators."""
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, Sequence
 
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.cloud_formation import 
CloudFormationHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class CloudFormationCreateStackOperator(BaseOperator):
+class CloudFormationCreateStackOperator(AwsBaseOperator[CloudFormationHook]):
     """
-    An operator that creates a CloudFormation stack.
+    An operator that creates a AWS CloudFormation stack.
 
     .. seealso::
         For more information on how to use this operator, take a look at the 
guide:
         :ref:`howto/operator:CloudFormationCreateStackOperator`
 
     :param stack_name: stack name (templated)
-    :param cloudformation_parameters: parameters to be passed to 
CloudFormation.
-    :param aws_conn_id: aws connection to uses
+    :param cloudformation_parameters: parameters to be passed to AWS 
CloudFormation.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. 
If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("stack_name", 
"cloudformation_parameters")
-    template_ext: Sequence[str] = ()
+    aws_hook_class = CloudFormationHook
+    template_fields: Sequence[str] = aws_template_fields("stack_name", 
"cloudformation_parameters")
     ui_color = "#6b9659"
 
-    def __init__(
-        self, *, stack_name: str, cloudformation_parameters: dict, 
aws_conn_id: str = "aws_default", **kwargs
-    ):
+    def __init__(self, *, stack_name: str, cloudformation_parameters: dict, 
**kwargs):
         super().__init__(**kwargs)
         self.stack_name = stack_name
         self.cloudformation_parameters = cloudformation_parameters
-        self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context):
         self.log.info("CloudFormation parameters: %s", 
self.cloudformation_parameters)
-
-        cloudformation_hook = CloudFormationHook(aws_conn_id=self.aws_conn_id)
-        cloudformation_hook.create_stack(self.stack_name, 
self.cloudformation_parameters)
+        self.hook.create_stack(self.stack_name, self.cloudformation_parameters)
 
 
-class CloudFormationDeleteStackOperator(BaseOperator):
+class CloudFormationDeleteStackOperator(AwsBaseOperator[CloudFormationHook]):
     """
-    An operator that deletes a CloudFormation stack.
-
-    :param stack_name: stack name (templated)
-    :param cloudformation_parameters: parameters to be passed to 
CloudFormation.
+    An operator that deletes a AWS CloudFormation stack.
 
     .. seealso::
         For more information on how to use this operator, take a look at the 
guide:
         :ref:`howto/operator:CloudFormationDeleteStackOperator`
 
-    :param aws_conn_id: aws connection to uses
+    :param stack_name: stack name (templated)
+    :param cloudformation_parameters: parameters to be passed to 
CloudFormation.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. 
If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("stack_name",)
-    template_ext: Sequence[str] = ()
+    aws_hook_class = CloudFormationHook
+    template_fields: Sequence[str] = aws_template_fields("stack_name")
     ui_color = "#1d472b"
     ui_fgcolor = "#FFF"
 
@@ -93,6 +106,4 @@ class CloudFormationDeleteStackOperator(BaseOperator):
 
     def execute(self, context: Context):
         self.log.info("CloudFormation Parameters: %s", 
self.cloudformation_parameters)
-
-        cloudformation_hook = CloudFormationHook(aws_conn_id=self.aws_conn_id)
-        cloudformation_hook.delete_stack(self.stack_name, 
self.cloudformation_parameters)
+        self.hook.delete_stack(self.stack_name, self.cloudformation_parameters)
diff --git a/airflow/providers/amazon/aws/sensors/cloud_formation.py 
b/airflow/providers/amazon/aws/sensors/cloud_formation.py
index 5c6b1f2246..044ca50484 100644
--- a/airflow/providers/amazon/aws/sensors/cloud_formation.py
+++ b/airflow/providers/amazon/aws/sensors/cloud_formation.py
@@ -18,18 +18,19 @@
 """This module contains sensors for AWS CloudFormation."""
 from __future__ import annotations
 
-from functools import cached_property
 from typing import TYPE_CHECKING, Sequence
 
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 from airflow.exceptions import AirflowSkipException
 from airflow.providers.amazon.aws.hooks.cloud_formation import 
CloudFormationHook
-from airflow.sensors.base import BaseSensorOperator
 
 
-class CloudFormationCreateStackSensor(BaseSensorOperator):
+class CloudFormationCreateStackSensor(AwsBaseSensor[CloudFormationHook]):
     """
     Waits for a stack to be created successfully on AWS CloudFormation.
 
@@ -38,19 +39,25 @@ class CloudFormationCreateStackSensor(BaseSensorOperator):
         :ref:`howto/sensor:CloudFormationCreateStackSensor`
 
     :param stack_name: The name of the stack to wait for (templated)
-    :param aws_conn_id: ID of the Airflow connection where credentials and 
extra configuration are
-        stored
-    :param poke_interval: Time in seconds that the job should wait between 
each try
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. 
If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("stack_name",)
+    aws_hook_class = CloudFormationHook
+    template_fields: Sequence[str] = aws_template_fields("stack_name")
     ui_color = "#C5CAE9"
 
-    def __init__(self, *, stack_name, aws_conn_id="aws_default", 
region_name=None, **kwargs):
+    def __init__(self, *, stack_name, **kwargs):
         super().__init__(**kwargs)
         self.stack_name = stack_name
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
 
     def poke(self, context: Context):
         stack_status = self.hook.get_stack_status(self.stack_name)
@@ -65,13 +72,8 @@ class CloudFormationCreateStackSensor(BaseSensorOperator):
             raise AirflowSkipException(message)
         raise ValueError(message)
 
-    @cached_property
-    def hook(self) -> CloudFormationHook:
-        """Create and return a CloudFormationHook."""
-        return CloudFormationHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name)
 
-
-class CloudFormationDeleteStackSensor(BaseSensorOperator):
+class CloudFormationDeleteStackSensor(AwsBaseSensor[CloudFormationHook]):
     """
     Waits for a stack to be deleted successfully on AWS CloudFormation.
 
@@ -80,12 +82,20 @@ class CloudFormationDeleteStackSensor(BaseSensorOperator):
         :ref:`howto/sensor:CloudFormationDeleteStackSensor`
 
     :param stack_name: The name of the stack to wait for (templated)
-    :param aws_conn_id: ID of the Airflow connection where credentials and 
extra configuration are
-        stored
-    :param poke_interval: Time in seconds that the job should wait between 
each try
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. 
If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("stack_name",)
+    aws_hook_class = CloudFormationHook
+    template_fields: Sequence[str] = aws_template_fields("stack_name")
     ui_color = "#C5CAE9"
 
     def __init__(
@@ -113,8 +123,3 @@ class CloudFormationDeleteStackSensor(BaseSensorOperator):
         if self.soft_fail:
             raise AirflowSkipException(message)
         raise ValueError(message)
-
-    @cached_property
-    def hook(self) -> CloudFormationHook:
-        """Create and return a CloudFormationHook."""
-        return CloudFormationHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name)
diff --git a/docs/apache-airflow-providers-amazon/operators/cloudformation.rst 
b/docs/apache-airflow-providers-amazon/operators/cloudformation.rst
index 4051be0ccd..ff45efcdb6 100644
--- a/docs/apache-airflow-providers-amazon/operators/cloudformation.rst
+++ b/docs/apache-airflow-providers-amazon/operators/cloudformation.rst
@@ -31,6 +31,11 @@ Prerequisite Tasks
 
 .. include:: ../_partials/prerequisite_tasks.rst
 
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
 Operators
 ---------
 
diff --git a/tests/providers/amazon/aws/operators/test_cloud_formation.py 
b/tests/providers/amazon/aws/operators/test_cloud_formation.py
index df54596b02..071ba5c847 100644
--- a/tests/providers/amazon/aws/operators/test_cloud_formation.py
+++ b/tests/providers/amazon/aws/operators/test_cloud_formation.py
@@ -40,6 +40,35 @@ def mocked_hook_client():
 
 
 class TestCloudFormationCreateStackOperator:
+    def test_init(self):
+        op = CloudFormationCreateStackOperator(
+            task_id="cf_create_stack_init",
+            stack_name="fake-stack",
+            cloudformation_parameters={},
+            # Generic hooks parameters
+            aws_conn_id="fake-conn-id",
+            region_name="eu-west-1",
+            verify=True,
+            botocore_config={"read_timeout": 42},
+        )
+        assert op.hook.client_type == "cloudformation"
+        assert op.hook.resource_type is None
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "eu-west-1"
+        assert op.hook._verify is True
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+        op = CloudFormationCreateStackOperator(
+            task_id="cf_create_stack_init",
+            stack_name="fake-stack",
+            cloudformation_parameters={},
+        )
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
+
     def test_create_stack(self, mocked_hook_client):
         stack_name = "myStack"
         timeout = 15
@@ -60,6 +89,30 @@ class TestCloudFormationCreateStackOperator:
 
 
 class TestCloudFormationDeleteStackOperator:
+    def test_init(self):
+        op = CloudFormationDeleteStackOperator(
+            task_id="cf_delete_stack_init",
+            stack_name="fake-stack",
+            # Generic hooks parameters
+            aws_conn_id="fake-conn-id",
+            region_name="us-east-1",
+            verify=False,
+            botocore_config={"read_timeout": 42},
+        )
+        assert op.hook.client_type == "cloudformation"
+        assert op.hook.resource_type is None
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "us-east-1"
+        assert op.hook._verify is False
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+        op = CloudFormationDeleteStackOperator(task_id="cf_delete_stack_init", 
stack_name="fake-stack")
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
+
     def test_delete_stack(self, mocked_hook_client):
         stack_name = "myStackToBeDeleted"
 
diff --git a/tests/providers/amazon/aws/sensors/test_cloud_formation.py 
b/tests/providers/amazon/aws/sensors/test_cloud_formation.py
index 51b9c385f1..ca41774411 100644
--- a/tests/providers/amazon/aws/sensors/test_cloud_formation.py
+++ b/tests/providers/amazon/aws/sensors/test_cloud_formation.py
@@ -23,6 +23,7 @@ import boto3
 import pytest
 from moto import mock_cloudformation
 
+from airflow.exceptions import AirflowSkipException
 from airflow.providers.amazon.aws.sensors.cloud_formation import (
     CloudFormationCreateStackSensor,
     CloudFormationDeleteStackSensor,
@@ -40,6 +41,30 @@ class TestCloudFormationCreateStackSensor:
     def setup_method(self, method):
         self.client = boto3.client("cloudformation", region_name="us-east-1")
 
+    def test_init(self):
+        sensor = CloudFormationCreateStackSensor(
+            task_id="cf_create_stack_init",
+            stack_name="fake-stack",
+            # Generic hooks parameters
+            aws_conn_id="fake-conn-id",
+            region_name="eu-central-1",
+            verify=False,
+            botocore_config={"read_timeout": 42},
+        )
+        assert sensor.hook.client_type == "cloudformation"
+        assert sensor.hook.resource_type is None
+        assert sensor.hook.aws_conn_id == "fake-conn-id"
+        assert sensor.hook._region_name == "eu-central-1"
+        assert sensor.hook._verify is False
+        assert sensor.hook._config is not None
+        assert sensor.hook._config.read_timeout == 42
+
+        sensor = 
CloudFormationCreateStackSensor(task_id="cf_create_stack_init", 
stack_name="fake-stack")
+        assert sensor.hook.aws_conn_id == "aws_default"
+        assert sensor.hook._region_name is None
+        assert sensor.hook._verify is None
+        assert sensor.hook._config is None
+
     @mock_cloudformation
     def test_poke(self):
         self.client.create_stack(StackName="foobar", 
TemplateBody='{"Resources": {}}')
@@ -51,10 +76,17 @@ class TestCloudFormationCreateStackSensor:
         op = CloudFormationCreateStackSensor(task_id="task", stack_name="foo")
         assert not op.poke({})
 
-    def test_poke_stack_in_unsuccessful_state(self, mocked_hook_client):
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception",
+        [
+            pytest.param(True, AirflowSkipException, id="soft-fail"),
+            pytest.param(False, ValueError, id="non-soft-fail"),
+        ],
+    )
+    def test_poke_stack_in_unsuccessful_state(self, mocked_hook_client, 
soft_fail, expected_exception):
         mocked_hook_client.describe_stacks.return_value = {"Stacks": 
[{"StackStatus": "bar"}]}
-        op = CloudFormationCreateStackSensor(task_id="task", stack_name="foo")
-        with pytest.raises(ValueError, match="Stack foo in bad state: bar"):
+        op = CloudFormationCreateStackSensor(task_id="task", stack_name="foo", 
soft_fail=soft_fail)
+        with pytest.raises(expected_exception, match="Stack foo in bad state: 
bar"):
             op.poke({})
 
 
@@ -63,6 +95,30 @@ class TestCloudFormationDeleteStackSensor:
     def setup_method(self, method):
         self.client = boto3.client("cloudformation", region_name="us-east-1")
 
+    def test_init(self):
+        sensor = CloudFormationDeleteStackSensor(
+            task_id="cf_delete_stack_init",
+            stack_name="fake-stack",
+            # Generic hooks parameters
+            aws_conn_id="fake-conn-id",
+            region_name="ca-west-1",
+            verify=True,
+            botocore_config={"read_timeout": 42},
+        )
+        assert sensor.hook.client_type == "cloudformation"
+        assert sensor.hook.resource_type is None
+        assert sensor.hook.aws_conn_id == "fake-conn-id"
+        assert sensor.hook._region_name == "ca-west-1"
+        assert sensor.hook._verify is True
+        assert sensor.hook._config is not None
+        assert sensor.hook._config.read_timeout == 42
+
+        sensor = 
CloudFormationDeleteStackSensor(task_id="cf_delete_stack_init", 
stack_name="fake-stack")
+        assert sensor.hook.aws_conn_id == "aws_default"
+        assert sensor.hook._region_name is None
+        assert sensor.hook._verify is None
+        assert sensor.hook._config is None
+
     @mock_cloudformation
     def test_poke(self):
         stack_name = "foobar"
@@ -76,10 +132,17 @@ class TestCloudFormationDeleteStackSensor:
         op = CloudFormationDeleteStackSensor(task_id="task", stack_name="foo")
         assert not op.poke({})
 
-    def test_poke_stack_in_unsuccessful_state(self, mocked_hook_client):
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception",
+        [
+            pytest.param(True, AirflowSkipException, id="soft-fail"),
+            pytest.param(False, ValueError, id="non-soft-fail"),
+        ],
+    )
+    def test_poke_stack_in_unsuccessful_state(self, mocked_hook_client, 
soft_fail, expected_exception):
         mocked_hook_client.describe_stacks.return_value = {"Stacks": 
[{"StackStatus": "bar"}]}
-        op = CloudFormationDeleteStackSensor(task_id="task", stack_name="foo")
-        with pytest.raises(ValueError, match="Stack foo in bad state: bar"):
+        op = CloudFormationDeleteStackSensor(task_id="task", stack_name="foo", 
soft_fail=soft_fail)
+        with pytest.raises(expected_exception, match="Stack foo in bad state: 
bar"):
             op.poke({})
 
     @mock_cloudformation

Reply via email to