This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 73d87945e9 Use base aws classes in Amazon ECS Operators/Sensors/Triggers (#36393)
73d87945e9 is described below

commit 73d87945e9a78195278d0a4b495483062ddc9b35
Author: Jayce Slesar <[email protected]>
AuthorDate: Tue Dec 26 13:36:36 2023 -0500

    Use base aws classes in Amazon ECS Operators/Sensors/Triggers (#36393)
    
    * Use base aws classes in Amazon ECS Operators/Sensors/Triggers
    
    * Remove the redundant __init__ and hook wrapper, which only duplicated what the AWS base classes already provide
    
    * Split out half of the test
    
    * Change the test to properly mock the hook
    
    * Format
---
 airflow/providers/amazon/aws/operators/ecs.py      | 19 ++++---------
 airflow/providers/amazon/aws/sensors/ecs.py        | 23 +++++-----------
 airflow/providers/amazon/aws/triggers/ecs.py       |  4 +++
 .../operators/ecs.rst                              |  4 +++
 tests/providers/amazon/aws/operators/test_ecs.py   | 31 ++++++++++------------
 tests/providers/amazon/aws/sensors/test_ecs.py     |  3 +--
 6 files changed, 35 insertions(+), 49 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/ecs.py 
b/airflow/providers/amazon/aws/operators/ecs.py
index 7d5d1d0583..f043a076e6 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -25,17 +25,18 @@ from typing import TYPE_CHECKING, Sequence
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.exceptions import EcsOperatorError, 
EcsTaskFailToStart
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, 
should_retry_eni
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.ecs import (
     ClusterActiveTrigger,
     ClusterInactiveTrigger,
     TaskDoneTrigger,
 )
 from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.providers.amazon.aws.utils.task_log_fetcher import 
AwsTaskLogFetcher
 from airflow.utils.helpers import prune_dict
 
@@ -45,21 +46,11 @@ if TYPE_CHECKING:
     from airflow.models import TaskInstance
     from airflow.utils.context import Context
 
-DEFAULT_CONN_ID = "aws_default"
 
-
-class EcsBaseOperator(BaseOperator):
+class EcsBaseOperator(AwsBaseOperator[EcsHook]):
     """This is the base operator for all Elastic Container Service 
operators."""
 
-    def __init__(self, *, aws_conn_id: str | None = DEFAULT_CONN_ID, region: 
str | None = None, **kwargs):
-        self.aws_conn_id = aws_conn_id
-        self.region = region
-        super().__init__(**kwargs)
-
-    @cached_property
-    def hook(self) -> EcsHook:
-        """Create and return an EcsHook."""
-        return EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
+    aws_hook_class = EcsHook
 
     @cached_property
     def client(self) -> boto3.client:
@@ -101,7 +92,7 @@ class EcsCreateClusterOperator(EcsBaseOperator):
         (default: False)
     """
 
-    template_fields: Sequence[str] = (
+    template_fields: Sequence[str] = aws_template_fields(
         "cluster_name",
         "create_cluster_kwargs",
         "wait_for_completion",
diff --git a/airflow/providers/amazon/aws/sensors/ecs.py 
b/airflow/providers/amazon/aws/sensors/ecs.py
index ecb1b92ee0..02a212fbde 100644
--- a/airflow/providers/amazon/aws/sensors/ecs.py
+++ b/airflow/providers/amazon/aws/sensors/ecs.py
@@ -26,15 +26,14 @@ from airflow.providers.amazon.aws.hooks.ecs import (
     EcsTaskDefinitionStates,
     EcsTaskStates,
 )
-from airflow.sensors.base import BaseSensorOperator
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     import boto3
 
     from airflow.utils.context import Context
 
-DEFAULT_CONN_ID: str = "aws_default"
-
 
 def _check_failed(current_state, target_state, failure_states, soft_fail: 
bool) -> None:
     if (current_state != target_state) and (current_state in failure_states):
@@ -45,18 +44,10 @@ def _check_failed(current_state, target_state, 
failure_states, soft_fail: bool)
         raise AirflowException(message)
 
 
-class EcsBaseSensor(BaseSensorOperator):
+class EcsBaseSensor(AwsBaseSensor[EcsHook]):
     """Contains general sensor behavior for Elastic Container Service."""
 
-    def __init__(self, *, aws_conn_id: str | None = DEFAULT_CONN_ID, region: 
str | None = None, **kwargs):
-        self.aws_conn_id = aws_conn_id
-        self.region = region
-        super().__init__(**kwargs)
-
-    @cached_property
-    def hook(self) -> EcsHook:
-        """Create and return an EcsHook."""
-        return EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
+    aws_hook_class = EcsHook
 
     @cached_property
     def client(self) -> boto3.client:
@@ -78,7 +69,7 @@ class EcsClusterStateSensor(EcsBaseSensor):
          Success State. (Default: "FAILED" or "INACTIVE")
     """
 
-    template_fields: Sequence[str] = ("cluster_name", "target_state", 
"failure_states")
+    template_fields: Sequence[str] = aws_template_fields("cluster_name", 
"target_state", "failure_states")
 
     def __init__(
         self,
@@ -116,7 +107,7 @@ class EcsTaskDefinitionStateSensor(EcsBaseSensor):
     :param target_state: Success state to watch for. (Default: "ACTIVE")
     """
 
-    template_fields: Sequence[str] = ("task_definition", "target_state", 
"failure_states")
+    template_fields: Sequence[str] = aws_template_fields("task_definition", 
"target_state", "failure_states")
 
     def __init__(
         self,
@@ -162,7 +153,7 @@ class EcsTaskStateSensor(EcsBaseSensor):
          the Success State. (Default: "STOPPED")
     """
 
-    template_fields: Sequence[str] = ("cluster", "task", "target_state", 
"failure_states")
+    template_fields: Sequence[str] = aws_template_fields("cluster", "task", 
"target_state", "failure_states")
 
     def __init__(
         self,
diff --git a/airflow/providers/amazon/aws/triggers/ecs.py 
b/airflow/providers/amazon/aws/triggers/ecs.py
index 2d4b68b98f..1177aa657a 100644
--- a/airflow/providers/amazon/aws/triggers/ecs.py
+++ b/airflow/providers/amazon/aws/triggers/ecs.py
@@ -52,6 +52,7 @@ class ClusterActiveTrigger(AwsBaseWaiterTrigger):
         waiter_max_attempts: int,
         aws_conn_id: str | None,
         region_name: str | None = None,
+        **kwargs,
     ):
         super().__init__(
             serialized_fields={"cluster_arn": cluster_arn},
@@ -66,6 +67,7 @@ class ClusterActiveTrigger(AwsBaseWaiterTrigger):
             waiter_max_attempts=waiter_max_attempts,
             aws_conn_id=aws_conn_id,
             region_name=region_name,
+            **kwargs,
         )
 
     def hook(self) -> AwsGenericHook:
@@ -91,6 +93,7 @@ class ClusterInactiveTrigger(AwsBaseWaiterTrigger):
         waiter_max_attempts: int,
         aws_conn_id: str | None,
         region_name: str | None = None,
+        **kwargs,
     ):
         super().__init__(
             serialized_fields={"cluster_arn": cluster_arn},
@@ -104,6 +107,7 @@ class ClusterInactiveTrigger(AwsBaseWaiterTrigger):
             waiter_max_attempts=waiter_max_attempts,
             aws_conn_id=aws_conn_id,
             region_name=region_name,
+            **kwargs,
         )
 
     def hook(self) -> AwsGenericHook:
diff --git a/docs/apache-airflow-providers-amazon/operators/ecs.rst 
b/docs/apache-airflow-providers-amazon/operators/ecs.rst
index e6b4385d36..6e8e3f5409 100644
--- a/docs/apache-airflow-providers-amazon/operators/ecs.rst
+++ b/docs/apache-airflow-providers-amazon/operators/ecs.rst
@@ -30,6 +30,10 @@ Prerequisite Tasks
 
 .. include:: ../_partials/prerequisite_tasks.rst
 
+Generic Parameters
+------------------
+.. include:: ../_partials/generic_parameters.rst
+
 Operators
 ---------
 
diff --git a/tests/providers/amazon/aws/operators/test_ecs.py 
b/tests/providers/amazon/aws/operators/test_ecs.py
index f402a60436..828b069690 100644
--- a/tests/providers/amazon/aws/operators/test_ecs.py
+++ b/tests/providers/amazon/aws/operators/test_ecs.py
@@ -28,7 +28,6 @@ from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarni
 from airflow.providers.amazon.aws.exceptions import EcsOperatorError, 
EcsTaskFailToStart
 from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook
 from airflow.providers.amazon.aws.operators.ecs import (
-    DEFAULT_CONN_ID,
     EcsBaseOperator,
     EcsCreateClusterOperator,
     EcsDeleteClusterOperator,
@@ -112,30 +111,28 @@ class TestEcsBaseOperator(EcsBaseTestCase):
         op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
         op = EcsBaseOperator(task_id="test_ecs_base", **op_kw)
 
-        assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET 
else DEFAULT_CONN_ID)
+        assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET 
else "aws_default")
         assert op.region == (region_name if region_name is not NOTSET else 
None)
 
-    @mock.patch("airflow.providers.amazon.aws.operators.ecs.EcsHook")
     @pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"])
     @pytest.mark.parametrize("region_name", [None, NOTSET, "ca-central-1"])
-    def test_hook_and_client(self, mock_ecs_hook_cls, aws_conn_id, 
region_name):
-        """Test initialize ``EcsHook`` and ``boto3.client``."""
-        mock_ecs_hook = mock_ecs_hook_cls.return_value
-        mock_conn = mock.MagicMock()
-        type(mock_ecs_hook).conn = mock.PropertyMock(return_value=mock_conn)
-
+    def test_initialise_operator_hook(self, aws_conn_id, region_name):
+        """Test initialize operator."""
         op_kw = {"aws_conn_id": aws_conn_id, "region": region_name}
         op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
-        op = EcsBaseOperator(task_id="test_ecs_base_hook_client", **op_kw)
+        op = EcsBaseOperator(task_id="test_ecs_base", **op_kw)
+
+        assert op.hook.aws_conn_id == (aws_conn_id if aws_conn_id is not 
NOTSET else "aws_default")
+        assert op.hook.region_name == (region_name if region_name is not 
NOTSET else None)
 
-        hook = op.hook
-        assert op.hook is hook
-        mock_ecs_hook_cls.assert_called_once_with(aws_conn_id=op.aws_conn_id, 
region_name=op.region)
+        with mock.patch.object(EcsBaseOperator, "hook", 
new_callable=mock.PropertyMock) as m:
+            mocked_hook = mock.MagicMock(name="MockHook")
+            mocked_client = mock.MagicMock(name="Mocklient")
+            mocked_hook.conn = mocked_client
+            m.return_value = mocked_hook
 
-        client = op.client
-        mock_ecs_hook_cls.assert_called_once_with(aws_conn_id=op.aws_conn_id, 
region_name=op.region)
-        assert client == mock_conn
-        assert op.client is client
+            assert op.client == mocked_client
+            m.assert_called_once()
 
 
 class TestEcsRunTaskOperator(EcsBaseTestCase):
diff --git a/tests/providers/amazon/aws/sensors/test_ecs.py 
b/tests/providers/amazon/aws/sensors/test_ecs.py
index 46f7dbd4fd..d69cbee11f 100644
--- a/tests/providers/amazon/aws/sensors/test_ecs.py
+++ b/tests/providers/amazon/aws/sensors/test_ecs.py
@@ -26,7 +26,6 @@ from slugify import slugify
 
 from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.providers.amazon.aws.sensors.ecs import (
-    DEFAULT_CONN_ID,
     EcsBaseSensor,
     EcsClusterStates,
     EcsClusterStateSensor,
@@ -79,7 +78,7 @@ class TestEcsBaseSensor(EcsBaseTestCase):
         op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
         op = EcsBaseSensor(task_id="test_ecs_base", **op_kw)
 
-        assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET 
else DEFAULT_CONN_ID)
+        assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET 
else "aws_default")
         assert op.region == (region_name if region_name is not NOTSET else 
None)
 
     @pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"])

Reply via email to