This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 73d87945e9 Use base aws classes in Amazon ECS Operators/Sensors/Triggers (#36393)
73d87945e9 is described below
commit 73d87945e9a78195278d0a4b495483062ddc9b35
Author: Jayce Slesar <[email protected]>
AuthorDate: Tue Dec 26 13:36:36 2023 -0500
Use base aws classes in Amazon ECS Operators/Sensors/Triggers (#36393)
* Use base aws classes in Amazon ECS Operators/Sensors/Triggers
* remove redundant init and wrapper for the hook that essentially did nothing
* split the hook/client test: check the hook directly and mock only the hook property for the client check
* change test to properly mock hook
* format
---
airflow/providers/amazon/aws/operators/ecs.py | 19 ++++---------
airflow/providers/amazon/aws/sensors/ecs.py | 23 +++++-----------
airflow/providers/amazon/aws/triggers/ecs.py | 4 +++
.../operators/ecs.rst | 4 +++
tests/providers/amazon/aws/operators/test_ecs.py | 31 ++++++++++------------
tests/providers/amazon/aws/sensors/test_ecs.py | 3 +--
6 files changed, 35 insertions(+), 49 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/ecs.py
b/airflow/providers/amazon/aws/operators/ecs.py
index 7d5d1d0583..f043a076e6 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -25,17 +25,18 @@ from typing import TYPE_CHECKING, Sequence
from airflow.configuration import conf
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
-from airflow.models import BaseOperator
from airflow.providers.amazon.aws.exceptions import EcsOperatorError,
EcsTaskFailToStart
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook,
should_retry_eni
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.ecs import (
ClusterActiveTrigger,
ClusterInactiveTrigger,
TaskDoneTrigger,
)
from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.providers.amazon.aws.utils.task_log_fetcher import
AwsTaskLogFetcher
from airflow.utils.helpers import prune_dict
@@ -45,21 +46,11 @@ if TYPE_CHECKING:
from airflow.models import TaskInstance
from airflow.utils.context import Context
-DEFAULT_CONN_ID = "aws_default"
-
-class EcsBaseOperator(BaseOperator):
+class EcsBaseOperator(AwsBaseOperator[EcsHook]):
"""This is the base operator for all Elastic Container Service
operators."""
- def __init__(self, *, aws_conn_id: str | None = DEFAULT_CONN_ID, region:
str | None = None, **kwargs):
- self.aws_conn_id = aws_conn_id
- self.region = region
- super().__init__(**kwargs)
-
- @cached_property
- def hook(self) -> EcsHook:
- """Create and return an EcsHook."""
- return EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
+ aws_hook_class = EcsHook
@cached_property
def client(self) -> boto3.client:
@@ -101,7 +92,7 @@ class EcsCreateClusterOperator(EcsBaseOperator):
(default: False)
"""
- template_fields: Sequence[str] = (
+ template_fields: Sequence[str] = aws_template_fields(
"cluster_name",
"create_cluster_kwargs",
"wait_for_completion",
diff --git a/airflow/providers/amazon/aws/sensors/ecs.py
b/airflow/providers/amazon/aws/sensors/ecs.py
index ecb1b92ee0..02a212fbde 100644
--- a/airflow/providers/amazon/aws/sensors/ecs.py
+++ b/airflow/providers/amazon/aws/sensors/ecs.py
@@ -26,15 +26,14 @@ from airflow.providers.amazon.aws.hooks.ecs import (
EcsTaskDefinitionStates,
EcsTaskStates,
)
-from airflow.sensors.base import BaseSensorOperator
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
if TYPE_CHECKING:
import boto3
from airflow.utils.context import Context
-DEFAULT_CONN_ID: str = "aws_default"
-
def _check_failed(current_state, target_state, failure_states, soft_fail:
bool) -> None:
if (current_state != target_state) and (current_state in failure_states):
@@ -45,18 +44,10 @@ def _check_failed(current_state, target_state,
failure_states, soft_fail: bool)
raise AirflowException(message)
-class EcsBaseSensor(BaseSensorOperator):
+class EcsBaseSensor(AwsBaseSensor[EcsHook]):
"""Contains general sensor behavior for Elastic Container Service."""
- def __init__(self, *, aws_conn_id: str | None = DEFAULT_CONN_ID, region:
str | None = None, **kwargs):
- self.aws_conn_id = aws_conn_id
- self.region = region
- super().__init__(**kwargs)
-
- @cached_property
- def hook(self) -> EcsHook:
- """Create and return an EcsHook."""
- return EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
+ aws_hook_class = EcsHook
@cached_property
def client(self) -> boto3.client:
@@ -78,7 +69,7 @@ class EcsClusterStateSensor(EcsBaseSensor):
Success State. (Default: "FAILED" or "INACTIVE")
"""
- template_fields: Sequence[str] = ("cluster_name", "target_state",
"failure_states")
+ template_fields: Sequence[str] = aws_template_fields("cluster_name",
"target_state", "failure_states")
def __init__(
self,
@@ -116,7 +107,7 @@ class EcsTaskDefinitionStateSensor(EcsBaseSensor):
:param target_state: Success state to watch for. (Default: "ACTIVE")
"""
- template_fields: Sequence[str] = ("task_definition", "target_state",
"failure_states")
+ template_fields: Sequence[str] = aws_template_fields("task_definition",
"target_state", "failure_states")
def __init__(
self,
@@ -162,7 +153,7 @@ class EcsTaskStateSensor(EcsBaseSensor):
the Success State. (Default: "STOPPED")
"""
- template_fields: Sequence[str] = ("cluster", "task", "target_state",
"failure_states")
+ template_fields: Sequence[str] = aws_template_fields("cluster", "task",
"target_state", "failure_states")
def __init__(
self,
diff --git a/airflow/providers/amazon/aws/triggers/ecs.py
b/airflow/providers/amazon/aws/triggers/ecs.py
index 2d4b68b98f..1177aa657a 100644
--- a/airflow/providers/amazon/aws/triggers/ecs.py
+++ b/airflow/providers/amazon/aws/triggers/ecs.py
@@ -52,6 +52,7 @@ class ClusterActiveTrigger(AwsBaseWaiterTrigger):
waiter_max_attempts: int,
aws_conn_id: str | None,
region_name: str | None = None,
+ **kwargs,
):
super().__init__(
serialized_fields={"cluster_arn": cluster_arn},
@@ -66,6 +67,7 @@ class ClusterActiveTrigger(AwsBaseWaiterTrigger):
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
region_name=region_name,
+ **kwargs,
)
def hook(self) -> AwsGenericHook:
@@ -91,6 +93,7 @@ class ClusterInactiveTrigger(AwsBaseWaiterTrigger):
waiter_max_attempts: int,
aws_conn_id: str | None,
region_name: str | None = None,
+ **kwargs,
):
super().__init__(
serialized_fields={"cluster_arn": cluster_arn},
@@ -104,6 +107,7 @@ class ClusterInactiveTrigger(AwsBaseWaiterTrigger):
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
region_name=region_name,
+ **kwargs,
)
def hook(self) -> AwsGenericHook:
diff --git a/docs/apache-airflow-providers-amazon/operators/ecs.rst
b/docs/apache-airflow-providers-amazon/operators/ecs.rst
index e6b4385d36..6e8e3f5409 100644
--- a/docs/apache-airflow-providers-amazon/operators/ecs.rst
+++ b/docs/apache-airflow-providers-amazon/operators/ecs.rst
@@ -30,6 +30,10 @@ Prerequisite Tasks
.. include:: ../_partials/prerequisite_tasks.rst
+Generic Parameters
+------------------
+.. include:: ../_partials/generic_parameters.rst
+
Operators
---------
diff --git a/tests/providers/amazon/aws/operators/test_ecs.py
b/tests/providers/amazon/aws/operators/test_ecs.py
index f402a60436..828b069690 100644
--- a/tests/providers/amazon/aws/operators/test_ecs.py
+++ b/tests/providers/amazon/aws/operators/test_ecs.py
@@ -28,7 +28,6 @@ from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarni
from airflow.providers.amazon.aws.exceptions import EcsOperatorError,
EcsTaskFailToStart
from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook
from airflow.providers.amazon.aws.operators.ecs import (
- DEFAULT_CONN_ID,
EcsBaseOperator,
EcsCreateClusterOperator,
EcsDeleteClusterOperator,
@@ -112,30 +111,28 @@ class TestEcsBaseOperator(EcsBaseTestCase):
op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
op = EcsBaseOperator(task_id="test_ecs_base", **op_kw)
- assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET
else DEFAULT_CONN_ID)
+ assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET
else "aws_default")
assert op.region == (region_name if region_name is not NOTSET else
None)
- @mock.patch("airflow.providers.amazon.aws.operators.ecs.EcsHook")
@pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"])
@pytest.mark.parametrize("region_name", [None, NOTSET, "ca-central-1"])
- def test_hook_and_client(self, mock_ecs_hook_cls, aws_conn_id,
region_name):
- """Test initialize ``EcsHook`` and ``boto3.client``."""
- mock_ecs_hook = mock_ecs_hook_cls.return_value
- mock_conn = mock.MagicMock()
- type(mock_ecs_hook).conn = mock.PropertyMock(return_value=mock_conn)
-
+ def test_initialise_operator_hook(self, aws_conn_id, region_name):
+ """Test initialize operator."""
op_kw = {"aws_conn_id": aws_conn_id, "region": region_name}
op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
- op = EcsBaseOperator(task_id="test_ecs_base_hook_client", **op_kw)
+ op = EcsBaseOperator(task_id="test_ecs_base", **op_kw)
+
+ assert op.hook.aws_conn_id == (aws_conn_id if aws_conn_id is not
NOTSET else "aws_default")
+ assert op.hook.region_name == (region_name if region_name is not
NOTSET else None)
- hook = op.hook
- assert op.hook is hook
- mock_ecs_hook_cls.assert_called_once_with(aws_conn_id=op.aws_conn_id,
region_name=op.region)
+ with mock.patch.object(EcsBaseOperator, "hook",
new_callable=mock.PropertyMock) as m:
+ mocked_hook = mock.MagicMock(name="MockHook")
+ mocked_client = mock.MagicMock(name="Mocklient")
+ mocked_hook.conn = mocked_client
+ m.return_value = mocked_hook
- client = op.client
- mock_ecs_hook_cls.assert_called_once_with(aws_conn_id=op.aws_conn_id,
region_name=op.region)
- assert client == mock_conn
- assert op.client is client
+ assert op.client == mocked_client
+ m.assert_called_once()
class TestEcsRunTaskOperator(EcsBaseTestCase):
diff --git a/tests/providers/amazon/aws/sensors/test_ecs.py
b/tests/providers/amazon/aws/sensors/test_ecs.py
index 46f7dbd4fd..d69cbee11f 100644
--- a/tests/providers/amazon/aws/sensors/test_ecs.py
+++ b/tests/providers/amazon/aws/sensors/test_ecs.py
@@ -26,7 +26,6 @@ from slugify import slugify
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.sensors.ecs import (
- DEFAULT_CONN_ID,
EcsBaseSensor,
EcsClusterStates,
EcsClusterStateSensor,
@@ -79,7 +78,7 @@ class TestEcsBaseSensor(EcsBaseTestCase):
op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
op = EcsBaseSensor(task_id="test_ecs_base", **op_kw)
- assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET
else DEFAULT_CONN_ID)
+ assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET
else "aws_default")
assert op.region == (region_name if region_name is not NOTSET else
None)
@pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"])