This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 415e076761 Deferrable mode for ECS operators (#31881)
415e076761 is described below
commit 415e0767616121854b6a29b3e44387f708cdf81e
Author: Raphaël Vandon <[email protected]>
AuthorDate: Fri Jun 23 10:13:13 2023 -0700
Deferrable mode for ECS operators (#31881)
---
airflow/providers/amazon/aws/operators/ecs.py | 160 ++++++++++++++---
airflow/providers/amazon/aws/triggers/ecs.py | 198 +++++++++++++++++++++
.../providers/amazon/aws/utils/task_log_fetcher.py | 5 +-
airflow/providers/amazon/provider.yaml | 3 +
tests/providers/amazon/aws/operators/test_ecs.py | 74 +++++++-
tests/providers/amazon/aws/triggers/test_ecs.py | 123 +++++++++++++
.../amazon/aws/utils/test_task_log_fetcher.py | 2 +-
7 files changed, 532 insertions(+), 33 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/ecs.py
b/airflow/providers/amazon/aws/operators/ecs.py
index bc8c4b70d7..2c2e93af35 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -35,6 +35,11 @@ from airflow.providers.amazon.aws.hooks.ecs import (
EcsHook,
should_retry_eni,
)
+from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+from airflow.providers.amazon.aws.triggers.ecs import (
+ ClusterWaiterTrigger,
+ TaskDoneTrigger,
+)
from airflow.providers.amazon.aws.utils.task_log_fetcher import
AwsTaskLogFetcher
from airflow.utils.helpers import prune_dict
from airflow.utils.session import provide_session
@@ -67,6 +72,15 @@ class EcsBaseOperator(BaseOperator):
"""Must overwrite in child classes."""
raise NotImplementedError("Please implement execute() in subclass")
+ def _complete_exec_with_cluster_desc(self, context, event=None):
+ """To be used as trigger callback for operators that return the
cluster description."""
+ if event["status"] != "success":
+ raise AirflowException(f"Error while waiting for operation on
cluster to complete: {event}")
+ cluster_arn = event.get("arn")
+ # We cannot get the cluster definition from the waiter on success, so
we have to query it here.
+ details =
self.hook.conn.describe_clusters(clusters=[cluster_arn])["clusters"][0]
+ return details
+
class EcsCreateClusterOperator(EcsBaseOperator):
"""
@@ -84,9 +98,17 @@ class EcsCreateClusterOperator(EcsBaseOperator):
if not set then the default waiter value will be used.
:param waiter_max_attempts: The maximum number of attempts to be made,
if not set then the default waiter value will be used.
+ :param deferrable: If True, the operator will wait asynchronously for the
job to complete.
+ This implies waiting for completion. This mode requires aiobotocore
module to be installed.
+ (default: False)
"""
- template_fields: Sequence[str] = ("cluster_name", "create_cluster_kwargs",
"wait_for_completion")
+ template_fields: Sequence[str] = (
+ "cluster_name",
+ "create_cluster_kwargs",
+ "wait_for_completion",
+ "deferrable",
+ )
def __init__(
self,
@@ -94,8 +116,9 @@ class EcsCreateClusterOperator(EcsBaseOperator):
cluster_name: str,
create_cluster_kwargs: dict | None = None,
wait_for_completion: bool = True,
- waiter_delay: int | None = None,
- waiter_max_attempts: int | None = None,
+ waiter_delay: int = 15,
+ waiter_max_attempts: int = 60,
+ deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -104,6 +127,7 @@ class EcsCreateClusterOperator(EcsBaseOperator):
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
+ self.deferrable = deferrable
def execute(self, context: Context):
self.log.info(
@@ -119,6 +143,21 @@ class EcsCreateClusterOperator(EcsBaseOperator):
# In some circumstances the ECS Cluster is created immediately,
# and there is no reason to wait for completion.
self.log.info("Cluster %r in state: %r.", self.cluster_name,
cluster_state)
+ elif self.deferrable:
+ self.defer(
+ trigger=ClusterWaiterTrigger(
+ waiter_name="cluster_active",
+ cluster_arn=cluster_details["clusterArn"],
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ region=self.region,
+ ),
+ method_name="_complete_exec_with_cluster_desc",
+ # timeout is set to ensure that if a trigger dies, the timeout
does not restart
+ # 60 seconds is added to allow the trigger to exit gracefully
(i.e. yield TriggerEvent)
+ timeout=timedelta(seconds=self.waiter_max_attempts *
self.waiter_delay + 60),
+ )
elif self.wait_for_completion:
waiter = self.hook.get_waiter("cluster_active")
waiter.wait(
@@ -148,17 +187,21 @@ class EcsDeleteClusterOperator(EcsBaseOperator):
if not set then the default waiter value will be used.
:param waiter_max_attempts: The maximum number of attempts to be made,
if not set then the default waiter value will be used.
+ :param deferrable: If True, the operator will wait asynchronously for the
job to complete.
+ This implies waiting for completion. This mode requires aiobotocore
module to be installed.
+ (default: False)
"""
- template_fields: Sequence[str] = ("cluster_name", "wait_for_completion")
+ template_fields: Sequence[str] = ("cluster_name", "wait_for_completion",
"deferrable")
def __init__(
self,
*,
cluster_name: str,
wait_for_completion: bool = True,
- waiter_delay: int | None = None,
- waiter_max_attempts: int | None = None,
+ waiter_delay: int = 15,
+ waiter_max_attempts: int = 60,
+ deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -166,6 +209,7 @@ class EcsDeleteClusterOperator(EcsBaseOperator):
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
+ self.deferrable = deferrable
def execute(self, context: Context):
self.log.info("Deleting cluster %r.", self.cluster_name)
@@ -174,9 +218,24 @@ class EcsDeleteClusterOperator(EcsBaseOperator):
cluster_state = cluster_details.get("status")
if cluster_state == EcsClusterStates.INACTIVE:
- # In some circumstances the ECS Cluster is deleted immediately,
- # so there is no reason to wait for completion.
+ # if the cluster doesn't have capacity providers that are
associated with it,
+ # the deletion is instantaneous, and we don't need to wait for it.
self.log.info("Cluster %r in state: %r.", self.cluster_name,
cluster_state)
+ elif self.deferrable:
+ self.defer(
+ trigger=ClusterWaiterTrigger(
+ waiter_name="cluster_inactive",
+ cluster_arn=cluster_details["clusterArn"],
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ region=self.region,
+ ),
+ method_name="_complete_exec_with_cluster_desc",
+ # timeout is set to ensure that if a trigger dies, the timeout
does not restart
+ # 60 seconds is added to allow the trigger to exit gracefully
(i.e. yield TriggerEvent)
+ timeout=timedelta(seconds=self.waiter_max_attempts *
self.waiter_delay + 60),
+ )
elif self.wait_for_completion:
waiter = self.hook.get_waiter("cluster_inactive")
waiter.wait(
@@ -347,6 +406,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
finished.
:param awslogs_fetch_interval: the interval that the ECS task log fetcher
should wait
in between each Cloudwatch logs fetches.
+ If deferrable is set to True, that parameter is ignored and
waiter_delay is used instead.
:param quota_retry: Config if and how to retry the launch of a new ECS
task, to handle
transient errors.
:param reattach: If set to True, will check if the task previously
launched by the task_instance
@@ -361,6 +421,9 @@ class EcsRunTaskOperator(EcsBaseOperator):
if not set then the default waiter value will be used.
:param waiter_max_attempts: The maximum number of attempts to be made,
if not set then the default waiter value will be used.
+ :param deferrable: If True, the operator will wait asynchronously for the
job to complete.
+ This implies waiting for completion. This mode requires aiobotocore
module to be installed.
+ (default: False)
"""
ui_color = "#f0ede4"
@@ -384,6 +447,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
"reattach",
"number_logs_exception",
"wait_for_completion",
+ "deferrable",
)
template_fields_renderers = {
"overrides": "json",
@@ -416,8 +480,9 @@ class EcsRunTaskOperator(EcsBaseOperator):
reattach: bool = False,
number_logs_exception: int = 10,
wait_for_completion: bool = True,
- waiter_delay: int | None = None,
- waiter_max_attempts: int | None = None,
+ waiter_delay: int = 6,
+ waiter_max_attempts: int = 100,
+ deferrable: bool = False,
**kwargs,
):
super().__init__(**kwargs)
@@ -451,6 +516,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
+ self.deferrable = deferrable
if self._aws_logs_enabled() and not self.wait_for_completion:
self.log.warning(
@@ -473,7 +539,35 @@ class EcsRunTaskOperator(EcsBaseOperator):
if self.reattach:
self._try_reattach_task(context)
- self._start_wait_check_task(context)
+ self._start_wait_task(context)
+
+ self._after_execution(session)
+
+ if self.do_xcom_push and self.task_log_fetcher:
+ return self.task_log_fetcher.get_last_log_message()
+ else:
+ return None
+
+ def execute_complete(self, context, event=None):
+ if event["status"] != "success":
+ raise AirflowException(f"Error in task execution: {event}")
+ self.arn = event["task_arn"] # restore arn to its updated value,
needed for next steps
+ self._after_execution()
+ if self._aws_logs_enabled():
+ # same behavior as non-deferrable mode, return last line of logs
of the task.
+ logs_client = AwsLogsHook(aws_conn_id=self.aws_conn_id,
region_name=self.region).conn
+ one_log = logs_client.get_log_events(
+ logGroupName=self.awslogs_group,
+ logStreamName=self._get_logs_stream_name(),
+ startFromHead=False,
+ limit=1,
+ )
+ if len(one_log["events"]) > 0:
+ return one_log["events"][0]["message"]
+
+ @provide_session
+ def _after_execution(self, session=None):
+ self._check_success_task()
self.log.info("ECS Task has been successfully executed")
@@ -482,16 +576,29 @@ class EcsRunTaskOperator(EcsBaseOperator):
# as we can't reattach it anymore
self._xcom_del(session,
self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id))
- if self.do_xcom_push and self.task_log_fetcher:
- return self.task_log_fetcher.get_last_log_message()
-
- return None
-
@AwsBaseHook.retry(should_retry_eni)
- def _start_wait_check_task(self, context):
+ def _start_wait_task(self, context):
if not self.arn:
self._start_task(context)
+ if self.deferrable:
+ self.defer(
+ trigger=TaskDoneTrigger(
+ cluster=self.cluster,
+ task_arn=self.arn,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ region=self.region,
+ log_group=self.awslogs_group,
+ log_stream=self._get_logs_stream_name(),
+ ),
+ method_name="execute_complete",
+ # timeout is set to ensure that if a trigger dies, the timeout
does not restart
+ # 60 seconds is added to allow the trigger to exit gracefully
(i.e. yield TriggerEvent)
+ timeout=timedelta(seconds=self.waiter_max_attempts *
self.waiter_delay + 60),
+ )
+
if not self.wait_for_completion:
return
@@ -508,8 +615,6 @@ class EcsRunTaskOperator(EcsBaseOperator):
else:
self._wait_for_task_ended()
- self._check_success_task()
-
def _xcom_del(self, session, task_id):
session.query(XCom).filter(XCom.dag_id == self.dag_id, XCom.task_id ==
task_id).delete()
@@ -584,12 +689,10 @@ class EcsRunTaskOperator(EcsBaseOperator):
waiter.wait(
cluster=self.cluster,
tasks=[self.arn],
- WaiterConfig=prune_dict(
- {
- "Delay": self.waiter_delay,
- "MaxAttempts": self.waiter_max_attempts,
- }
- ),
+ WaiterConfig={
+ "Delay": self.waiter_delay,
+ "MaxAttempts": self.waiter_max_attempts,
+ },
)
return
@@ -597,20 +700,23 @@ class EcsRunTaskOperator(EcsBaseOperator):
def _aws_logs_enabled(self):
return self.awslogs_group and self.awslogs_stream_prefix
+ def _get_logs_stream_name(self) -> str:
+ return
f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}"
+
def _get_task_log_fetcher(self) -> AwsTaskLogFetcher:
if not self.awslogs_group:
raise ValueError("must specify awslogs_group to fetch task logs")
- log_stream_name =
f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}"
return AwsTaskLogFetcher(
aws_conn_id=self.aws_conn_id,
region_name=self.awslogs_region,
log_group=self.awslogs_group,
- log_stream_name=log_stream_name,
+ log_stream_name=self._get_logs_stream_name(),
fetch_interval=self.awslogs_fetch_interval,
logger=self.log,
)
+ @AwsBaseHook.retry(should_retry_eni)
def _check_success_task(self) -> None:
if not self.client or not self.arn:
return
diff --git a/airflow/providers/amazon/aws/triggers/ecs.py
b/airflow/providers/amazon/aws/triggers/ecs.py
new file mode 100644
index 0000000000..8ba8350588
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/ecs.py
@@ -0,0 +1,198 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import asyncio
+from typing import Any, AsyncIterator
+
+from botocore.exceptions import ClientError, WaiterError
+
+from airflow.providers.amazon.aws.hooks.ecs import EcsHook
+from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+from airflow.providers.amazon.aws.utils.task_log_fetcher import
AwsTaskLogFetcher
+from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class ClusterWaiterTrigger(BaseTrigger):
+ """
+ Polls the status of a cluster using a given waiter. Can be used to poll
for an active or inactive cluster.
+
+ :param waiter_name: Name of the waiter to use, for instance
'cluster_active' or 'cluster_inactive'
+ :param cluster_arn: ARN of the cluster to watch.
+ :param waiter_delay: The amount of time in seconds to wait between
attempts.
+ :param waiter_max_attempts: The number of times to ping for status.
+ Will fail after that many unsuccessful attempts.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ :param region: The AWS region where the cluster is located.
+ """
+
+ def __init__(
+ self,
+ waiter_name: str,
+ cluster_arn: str,
+ waiter_delay: int | None,
+ waiter_max_attempts: int | None,
+ aws_conn_id: str | None,
+ region: str | None,
+ ):
+ self.cluster_arn = cluster_arn
+ self.waiter_name = waiter_name
+ self.waiter_delay = waiter_delay if waiter_delay is not None else 15
# written like this to allow 0
+ self.attempts = waiter_max_attempts or 999999999
+ self.aws_conn_id = aws_conn_id
+ self.region = region
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+ self.__class__.__module__ + "." + self.__class__.__qualname__,
+ {
+ "waiter_name": self.waiter_name,
+ "cluster_arn": self.cluster_arn,
+ "waiter_delay": self.waiter_delay,
+ "waiter_max_attempts": self.attempts,
+ "aws_conn_id": self.aws_conn_id,
+ "region": self.region,
+ },
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ async with EcsHook(aws_conn_id=self.aws_conn_id,
region_name=self.region).async_conn as client:
+ waiter = client.get_waiter(self.waiter_name)
+ await async_wait(
+ waiter,
+ self.waiter_delay,
+ self.attempts,
+ {"clusters": [self.cluster_arn]},
+ "error when checking cluster status",
+ "Status of cluster",
+ ["clusters[].status"],
+ )
+ yield TriggerEvent({"status": "success", "arn": self.cluster_arn})
+
+
+class TaskDoneTrigger(BaseTrigger):
+ """
+ Waits for an ECS task to be done, while eventually polling logs.
+
+ :param cluster: short name or full ARN of the cluster where the task is
running.
+ :param task_arn: ARN of the task to watch.
+ :param waiter_delay: The amount of time in seconds to wait between
attempts.
+ :param waiter_max_attempts: The number of times to ping for status.
+ Will fail after that many unsuccessful attempts.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ :param region: The AWS region where the cluster is located.
+ """
+
+ def __init__(
+ self,
+ cluster: str,
+ task_arn: str,
+ waiter_delay: int,
+ waiter_max_attempts: int,
+ aws_conn_id: str | None,
+ region: str | None,
+ log_group: str | None = None,
+ log_stream: str | None = None,
+ ):
+ self.cluster = cluster
+ self.task_arn = task_arn
+
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+ self.aws_conn_id = aws_conn_id
+ self.region = region
+
+ self.log_group = log_group
+ self.log_stream = log_stream
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+ self.__class__.__module__ + "." + self.__class__.__qualname__,
+ {
+ "cluster": self.cluster,
+ "task_arn": self.task_arn,
+ "waiter_delay": self.waiter_delay,
+ "waiter_max_attempts": self.waiter_max_attempts,
+ "aws_conn_id": self.aws_conn_id,
+ "region": self.region,
+ "log_group": self.log_group,
+ "log_stream": self.log_stream,
+ },
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ # fmt: off
+ async with EcsHook(aws_conn_id=self.aws_conn_id,
region_name=self.region).async_conn as ecs_client,\
+ AwsLogsHook(aws_conn_id=self.aws_conn_id,
region_name=self.region).async_conn as logs_client:
+ # fmt: on
+ waiter = ecs_client.get_waiter("tasks_stopped")
+ logs_token = None
+ while self.waiter_max_attempts >= 1:
+ self.waiter_max_attempts = self.waiter_max_attempts - 1
+ try:
+ await waiter.wait(
+ cluster=self.cluster, tasks=[self.task_arn],
WaiterConfig={"MaxAttempts": 1}
+ )
+ break # we reach this point only if the waiter met a
success criteria
+ except WaiterError as error:
+ if "terminal failure" in str(error):
+ raise
+ self.log.info("Status of the task is %s",
error.last_response["tasks"][0]["lastStatus"])
+ await asyncio.sleep(int(self.waiter_delay))
+ finally:
+ if self.log_group and self.log_stream:
+ logs_token = await self._forward_logs(logs_client,
logs_token)
+
+ yield TriggerEvent({"status": "success", "task_arn": self.task_arn})
+
+ async def _forward_logs(self, logs_client, next_token: str | None = None)
-> str | None:
+ """
+ Reads logs from the cloudwatch stream and prints them to the task logs.
+ :return: the token to pass to the next iteration to resume where we
started.
+ """
+ while True:
+ if next_token is not None:
+ token_arg: dict[str, str] = {"nextToken": next_token}
+ else:
+ token_arg = {}
+ try:
+ response = await logs_client.get_log_events(
+ logGroupName=self.log_group,
+ logStreamName=self.log_stream,
+ startFromHead=True,
+ **token_arg,
+ )
+ except ClientError as ce:
+ if ce.response["Error"]["Code"] == "ResourceNotFoundException":
+ self.log.info(
+ "Tried to get logs from stream %s in group %s but it
didn't exist (yet). "
+ "Will try again.",
+ self.log_stream,
+ self.log_group,
+ )
+ return None
+ raise
+
+ events = response["events"]
+ for log_event in events:
+ self.log.info(AwsTaskLogFetcher.event_to_str(log_event))
+
+ if len(events) == 0 or next_token == response["nextForwardToken"]:
+ return response["nextForwardToken"]
+ next_token = response["nextForwardToken"]
diff --git a/airflow/providers/amazon/aws/utils/task_log_fetcher.py
b/airflow/providers/amazon/aws/utils/task_log_fetcher.py
index 97b43a67b2..22a5e5f2a1 100644
--- a/airflow/providers/amazon/aws/utils/task_log_fetcher.py
+++ b/airflow/providers/amazon/aws/utils/task_log_fetcher.py
@@ -62,7 +62,7 @@ class AwsTaskLogFetcher(Thread):
time.sleep(self.fetch_interval.total_seconds())
log_events = self._get_log_events(continuation_token)
for log_event in log_events:
- self.logger.info(self._event_to_str(log_event))
+ self.logger.info(self.event_to_str(log_event))
def _get_log_events(self, skip_token: AwsLogsHook.ContinuationToken | None
= None) -> Generator:
if skip_token is None:
@@ -87,7 +87,8 @@ class AwsTaskLogFetcher(Thread):
self.logger.warning("ConnectionClosedError on retrieving
Cloudwatch log events", error)
yield from ()
- def _event_to_str(self, event: dict) -> str:
+ @staticmethod
+ def event_to_str(event: dict) -> str:
event_dt = datetime.utcfromtimestamp(event["timestamp"] / 1000.0)
formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
message = event["message"]
diff --git a/airflow/providers/amazon/provider.yaml
b/airflow/providers/amazon/provider.yaml
index 3680915dc3..223bc553ef 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -532,6 +532,9 @@ triggers:
- integration-name: Amazon Elastic Kubernetes Service (EKS)
python-modules:
- airflow.providers.amazon.aws.triggers.eks
+ - integration-name: Amazon ECS
+ python-modules:
+ - airflow.providers.amazon.aws.triggers.ecs
transfers:
- source-integration-name: Amazon DynamoDB
diff --git a/tests/providers/amazon/aws/operators/test_ecs.py
b/tests/providers/amazon/aws/operators/test_ecs.py
index 8c99a02ce8..b89ea59c56 100644
--- a/tests/providers/amazon/aws/operators/test_ecs.py
+++ b/tests/providers/amazon/aws/operators/test_ecs.py
@@ -20,13 +20,14 @@ from __future__ import annotations
import sys
from copy import deepcopy
from unittest import mock
+from unittest.mock import MagicMock, PropertyMock
import boto3
import pytest
-from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning, TaskDeferred
from airflow.providers.amazon.aws.exceptions import EcsOperatorError,
EcsTaskFailToStart
-from airflow.providers.amazon.aws.hooks.ecs import EcsHook
+from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook
from airflow.providers.amazon.aws.operators.ecs import (
DEFAULT_CONN_ID,
EcsBaseOperator,
@@ -36,6 +37,7 @@ from airflow.providers.amazon.aws.operators.ecs import (
EcsRegisterTaskDefinitionOperator,
EcsRunTaskOperator,
)
+from airflow.providers.amazon.aws.triggers.ecs import TaskDoneTrigger
from airflow.providers.amazon.aws.utils.task_log_fetcher import
AwsTaskLogFetcher
from airflow.utils.types import NOTSET
@@ -186,6 +188,7 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
"reattach",
"number_logs_exception",
"wait_for_completion",
+ "deferrable",
)
@pytest.mark.parametrize(
@@ -343,7 +346,7 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
self.ecs._wait_for_task_ended()
client_mock.get_waiter.assert_called_once_with("tasks_stopped")
client_mock.get_waiter.return_value.wait.assert_called_once_with(
- cluster="c", tasks=["arn"], WaiterConfig={}
+ cluster="c", tasks=["arn"], WaiterConfig={"Delay": 6,
"MaxAttempts": 100}
)
assert sys.maxsize ==
client_mock.get_waiter.return_value.config.max_attempts
@@ -654,6 +657,31 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
self.ecs.do_xcom_push = False
assert self.ecs.execute(None) is None
+ @mock.patch.object(EcsRunTaskOperator, "client")
+ def test_with_defer(self, client_mock):
+ self.ecs.deferrable = True
+
+ client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES
+
+ with pytest.raises(TaskDeferred) as deferred:
+ self.ecs.execute(None)
+
+ assert isinstance(deferred.value.trigger, TaskDoneTrigger)
+ assert deferred.value.trigger.task_arn ==
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
+
+ @mock.patch.object(EcsRunTaskOperator, "client", new_callable=PropertyMock)
+ @mock.patch.object(EcsRunTaskOperator, "_xcom_del")
+ def test_execute_complete(self, xcom_del_mock: MagicMock, client_mock):
+ event = {"status": "success", "task_arn": "my_arn"}
+ self.ecs.reattach = True
+
+ self.ecs.execute_complete(None, event)
+
+ # task gets described to assert its success
+ client_mock().describe_tasks.assert_called_once_with(cluster="c",
tasks=["my_arn"])
+ # if reattach mode, xcom value is deleted on success
+ xcom_del_mock.assert_called_once()
+
class TestEcsCreateClusterOperator(EcsBaseTestCase):
@pytest.mark.parametrize("waiter_delay, waiter_max_attempts",
WAITERS_TEST_CASES)
@@ -680,6 +708,26 @@ class TestEcsCreateClusterOperator(EcsBaseTestCase):
mocked_waiters.wait.assert_called_once_with(clusters=mock.ANY,
WaiterConfig=expected_waiter_config)
assert result is not None
+ @mock.patch.object(EcsCreateClusterOperator, "client")
+ def test_execute_deferrable(self, mock_client: MagicMock):
+ op = EcsCreateClusterOperator(
+ task_id="task",
+ cluster_name=CLUSTER_NAME,
+ deferrable=True,
+ waiter_delay=12,
+ waiter_max_attempts=34,
+ )
+ mock_client.create_cluster.return_value = {
+ "cluster": {"status": EcsClusterStates.PROVISIONING, "clusterArn":
"my arn"}
+ }
+
+ with pytest.raises(TaskDeferred) as defer:
+ op.execute(None)
+
+ assert defer.value.trigger.cluster_arn == "my arn"
+ assert defer.value.trigger.waiter_delay == 12
+ assert defer.value.trigger.attempts == 34
+
def test_execute_immediate_create(self, patch_hook_waiters):
"""Test if cluster created during initial request."""
op = EcsCreateClusterOperator(task_id="task",
cluster_name=CLUSTER_NAME, wait_for_completion=True)
@@ -725,6 +773,26 @@ class TestEcsDeleteClusterOperator(EcsBaseTestCase):
mocked_waiters.wait.assert_called_once_with(clusters=mock.ANY,
WaiterConfig=expected_waiter_config)
assert result is not None
+ @mock.patch.object(EcsDeleteClusterOperator, "client")
+ def test_execute_deferrable(self, mock_client: MagicMock):
+ op = EcsDeleteClusterOperator(
+ task_id="task",
+ cluster_name=CLUSTER_NAME,
+ deferrable=True,
+ waiter_delay=12,
+ waiter_max_attempts=34,
+ )
+ mock_client.delete_cluster.return_value = {
+ "cluster": {"status": EcsClusterStates.DEPROVISIONING,
"clusterArn": "my arn"}
+ }
+
+ with pytest.raises(TaskDeferred) as defer:
+ op.execute(None)
+
+ assert defer.value.trigger.cluster_arn == "my arn"
+ assert defer.value.trigger.waiter_delay == 12
+ assert defer.value.trigger.attempts == 34
+
def test_execute_immediate_delete(self, patch_hook_waiters):
"""Test if cluster deleted during initial request."""
op = EcsDeleteClusterOperator(task_id="task",
cluster_name=CLUSTER_NAME, wait_for_completion=True)
diff --git a/tests/providers/amazon/aws/triggers/test_ecs.py
b/tests/providers/amazon/aws/triggers/test_ecs.py
new file mode 100644
index 0000000000..09b5decbe6
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_ecs.py
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import AsyncMock
+
+import pytest
+from botocore.exceptions import WaiterError
+
+from airflow import AirflowException
+from airflow.providers.amazon.aws.hooks.ecs import EcsHook
+from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+from airflow.providers.amazon.aws.triggers.ecs import ClusterWaiterTrigger,
TaskDoneTrigger
+from airflow.triggers.base import TriggerEvent
+
+
+class TestClusterWaiterTrigger:
+ @pytest.mark.asyncio
+ @mock.patch.object(EcsHook, "async_conn")
+ async def test_run_max_attempts(self, client_mock):
+ a_mock = mock.MagicMock()
+ client_mock.__aenter__.return_value = a_mock
+ wait_mock = AsyncMock()
+ wait_mock.side_effect = WaiterError("name", "reason", {"clusters":
[{"status": "my_status"}]})
+ a_mock.get_waiter().wait = wait_mock
+
+ max_attempts = 5
+ trigger = ClusterWaiterTrigger("my_waiter", "cluster_arn", 0,
max_attempts, None, None)
+
+ with pytest.raises(AirflowException):
+ generator = trigger.run()
+ await generator.asend(None)
+
+ assert wait_mock.call_count == max_attempts
+
+ @pytest.mark.asyncio
+ @mock.patch.object(EcsHook, "async_conn")
+ async def test_run_success(self, client_mock):
+ a_mock = mock.MagicMock()
+ client_mock.__aenter__.return_value = a_mock
+ wait_mock = AsyncMock()
+ a_mock.get_waiter().wait = wait_mock
+
+ trigger = ClusterWaiterTrigger("my_waiter", "cluster_arn", 0, 5, None,
None)
+
+ generator = trigger.run()
+ response: TriggerEvent = await generator.asend(None)
+
+ assert response.payload["status"] == "success"
+ assert response.payload["arn"] == "cluster_arn"
+
+ @pytest.mark.asyncio
+ @mock.patch.object(EcsHook, "async_conn")
+ async def test_run_error(self, client_mock):
+ a_mock = mock.MagicMock()
+ client_mock.__aenter__.return_value = a_mock
+ wait_mock = AsyncMock()
+ wait_mock.side_effect = WaiterError("terminal failure", "reason", {})
+ a_mock.get_waiter().wait = wait_mock
+
+ trigger = ClusterWaiterTrigger("my_waiter", "cluster_arn", 0, 5, None,
None)
+
+ with pytest.raises(AirflowException):
+ generator = trigger.run()
+ await generator.asend(None)
+
+
+class TestTaskDoneTrigger:
+ @pytest.mark.asyncio
+ @mock.patch.object(EcsHook, "async_conn")
+ # this mock is only necessary to avoid a "No module named 'aiobotocore'"
error in the LatestBoto CI step
+ @mock.patch.object(AwsLogsHook, "async_conn")
+ async def test_run_until_error(self, _, client_mock):
+ a_mock = mock.MagicMock()
+ client_mock.__aenter__.return_value = a_mock
+ wait_mock = AsyncMock()
+ wait_mock.side_effect = [
+ WaiterError("name", "reason", {"tasks": [{"lastStatus":
"my_status"}]}),
+ WaiterError("name", "reason", {"tasks": [{"lastStatus":
"my_status"}]}),
+ WaiterError("terminal failure", "reason", {}),
+ ]
+ a_mock.get_waiter().wait = wait_mock
+
+ trigger = TaskDoneTrigger("cluster", "task_arn", 0, 10, None, None)
+
+ with pytest.raises(WaiterError):
+ generator = trigger.run()
+ await generator.asend(None)
+
+ assert wait_mock.call_count == 3
+
+ @pytest.mark.asyncio
+ @mock.patch.object(EcsHook, "async_conn")
+ # this mock is only necessary to avoid a "No module named 'aiobotocore'"
error in the LatestBoto CI step
+ @mock.patch.object(AwsLogsHook, "async_conn")
+ async def test_run_success(self, _, client_mock):
+ a_mock = mock.MagicMock()
+ client_mock.__aenter__.return_value = a_mock
+ wait_mock = AsyncMock()
+ a_mock.get_waiter().wait = wait_mock
+
+ trigger = TaskDoneTrigger("cluster", "my_task_arn", 0, 10, None, None)
+
+ generator = trigger.run()
+ response: TriggerEvent = await generator.asend(None)
+
+ assert response.payload["status"] == "success"
+ assert response.payload["task_arn"] == "my_task_arn"
diff --git a/tests/providers/amazon/aws/utils/test_task_log_fetcher.py
b/tests/providers/amazon/aws/utils/test_task_log_fetcher.py
index dbda751cfb..a5598ebf55 100644
--- a/tests/providers/amazon/aws/utils/test_task_log_fetcher.py
+++ b/tests/providers/amazon/aws/utils/test_task_log_fetcher.py
@@ -112,7 +112,7 @@ class TestAwsTaskLogFetcher:
{"timestamp": 1617400367456, "message": "Second"},
{"timestamp": 1617400467789, "message": "Third"},
]
- assert [self.log_fetcher._event_to_str(event) for event in events] == (
+ assert [self.log_fetcher.event_to_str(event) for event in events] == (
[
"[2021-04-02 21:51:07,123] First",
"[2021-04-02 21:52:47,456] Second",