This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 614be87 Added retry to ECS Operator (#14263)
614be87 is described below
commit 614be87b23199acd67e69677cfdb6ae4ed023b69
Author: Mark Hopson <[email protected]>
AuthorDate: Fri Mar 26 18:34:15 2021 -0400
Added retry to ECS Operator (#14263)
* Added retry to ECS Operator
* ...
* Remove airflow/www/yarn-error.log
* Update decorator to not accept any params
* ...
* ...
* ...
* lint
* Add predicate argument in retry decorator
* Add wraps and fixed test
* ...
* Remove unnecessary retry_if_permissible_error and fix lint errors
* Static check fixes
* Fix TestECSOperator.test_execute_with_failures
---
airflow/providers/amazon/aws/exceptions.py | 29 ++++++++
airflow/providers/amazon/aws/hooks/base_aws.py | 35 +++++++++-
airflow/providers/amazon/aws/operators/ecs.py | 19 +++++-
tests/providers/amazon/aws/hooks/test_base_aws.py | 80 +++++++++++++++++++++++
tests/providers/amazon/aws/operators/test_ecs.py | 13 +++-
5 files changed, 172 insertions(+), 4 deletions(-)
diff --git a/airflow/providers/amazon/aws/exceptions.py
b/airflow/providers/amazon/aws/exceptions.py
new file mode 100644
index 0000000..d0e5b54
--- /dev/null
+++ b/airflow/providers/amazon/aws/exceptions.py
@@ -0,0 +1,29 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# Note: Any AirflowException raised is expected to cause the TaskInstance
+# to be marked in an ERROR state
+
+
+class ECSOperatorError(Exception):
+ """Raise when ECS cannot handle the request."""
+
+ def __init__(self, failures: list, message: str):
+ self.failures = failures
+ self.message = message
+ super().__init__(message)
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py
b/airflow/providers/amazon/aws/hooks/base_aws.py
index 5cd694b..c1c5b1d 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -27,11 +27,13 @@ This module contains Base AWS Hook.
import configparser
import datetime
import logging
-from typing import Any, Dict, Optional, Tuple, Union
+from functools import wraps
+from typing import Any, Callable, Dict, Optional, Tuple, Union
import boto3
import botocore
import botocore.session
+import tenacity
from botocore.config import Config
from botocore.credentials import ReadOnlyCredentials
@@ -488,6 +490,37 @@ class AwsBaseHook(BaseHook):
else:
return
self.get_client_type("iam").get_role(RoleName=role)["Role"]["Arn"]
+ @staticmethod
+ def retry(should_retry: Callable[[Exception], bool]):
+ """
+ A decorator that provides a mechanism to repeat requests in response
to exceeding a temporary quota
+ limit.
+ """
+
+ def retry_decorator(fun: Callable):
+ @wraps(fun)
+ def decorator_f(self, *args, **kwargs):
+ retry_args = getattr(self, 'retry_args', None)
+ if retry_args is None:
+ return fun(self)
+ multiplier = retry_args.get('multiplier', 1)
+ min_limit = retry_args.get('min', 1)
+ max_limit = retry_args.get('max', 1)
+ stop_after_delay = retry_args.get('stop_after_delay', 10)
+ tenacity_logger = tenacity.before_log(self.log, logging.DEBUG)
if self.log else None
+ default_kwargs = {
+ 'wait': tenacity.wait_exponential(multiplier=multiplier,
max=max_limit, min=min_limit),
+ 'retry': tenacity.retry_if_exception(should_retry),
+ 'stop': tenacity.stop_after_delay(stop_after_delay),
+ 'before': tenacity_logger,
+ 'after': tenacity_logger,
+ }
+ return tenacity.retry(**default_kwargs)(fun)(self)
+
+ return decorator_f
+
+ return retry_decorator
+
def _parse_s3_config(
config_file_name: str, config_format: Optional[str] = "boto", profile:
Optional[str] = None
diff --git a/airflow/providers/amazon/aws/operators/ecs.py
b/airflow/providers/amazon/aws/operators/ecs.py
index a0e6dea..50ab958 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -25,12 +25,24 @@ from botocore.waiter import Waiter
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.exceptions import ECSOperatorError
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.typing_compat import Protocol, runtime_checkable
from airflow.utils.decorators import apply_defaults
+def should_retry(exception: Exception):
+ """Check if exception is related to ECS resource quota (CPU, MEM)."""
+ if isinstance(exception, ECSOperatorError):
+ return any(
+ quota_reason in failure['reason']
+ for quota_reason in ['RESOURCE:MEMORY', 'RESOURCE:CPU']
+ for failure in exception.failures
+ )
+ return False
+
+
@runtime_checkable
class ECSProtocol(Protocol):
"""
@@ -125,6 +137,8 @@ class ECSOperator(BaseOperator): # pylint:
disable=too-many-instance-attributes
:param reattach: If set to True, will check if a task from the same family
is already running.
If so, the operator will attach to it instead of starting a new task.
:type reattach: bool
+    :param quota_retry: Configure whether and how to retry _start_task() for transient
errors.
+ :type quota_retry: dict
"""
ui_color = '#f0ede4'
@@ -150,6 +164,7 @@ class ECSOperator(BaseOperator): # pylint:
disable=too-many-instance-attributes
awslogs_region: Optional[str] = None,
awslogs_stream_prefix: Optional[str] = None,
propagate_tags: Optional[str] = None,
+ quota_retry: Optional[dict] = None,
reattach: bool = False,
**kwargs,
):
@@ -180,6 +195,7 @@ class ECSOperator(BaseOperator): # pylint:
disable=too-many-instance-attributes
self.hook: Optional[AwsBaseHook] = None
self.client: Optional[ECSProtocol] = None
self.arn: Optional[str] = None
+ self.retry_args = quota_retry
def execute(self, context):
self.log.info(
@@ -206,6 +222,7 @@ class ECSOperator(BaseOperator): # pylint:
disable=too-many-instance-attributes
return None
+ @AwsBaseHook.retry(should_retry)
def _start_task(self):
run_opts = {
'cluster': self.cluster,
@@ -235,7 +252,7 @@ class ECSOperator(BaseOperator): # pylint:
disable=too-many-instance-attributes
failures = response['failures']
if len(failures) > 0:
- raise AirflowException(response)
+ raise ECSOperatorError(failures, response)
self.log.info('ECS Task started: %s', response)
self.arn = response['tasks'][0]['taskArn']
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py
b/tests/providers/amazon/aws/hooks/test_base_aws.py
index da8f8c8..383880d 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -21,6 +21,7 @@ import unittest
from unittest import mock
import boto3
+import pytest
from airflow.models import Connection
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -266,3 +267,82 @@ class TestAwsBaseHook(unittest.TestCase):
hook = AwsBaseHook(aws_conn_id=conn_id, client_type='s3')
# should cause no exception
hook.get_client_type('s3')
+
+
+class ThrowErrorUntilCount:
+ """Holds counter state for invoking a method several times in a row."""
+
+ def __init__(self, count, quota_retry, **kwargs):
+ self.counter = 0
+ self.count = count
+ self.retry_args = quota_retry
+ self.kwargs = kwargs
+ self.log = None
+
+ def __call__(self):
+ """
+        Raise a Forbidden error until the count threshold has been crossed.
+ Then return True.
+ """
+ if self.counter < self.count:
+ self.counter += 1
+ raise Exception()
+ return True
+
+
+def _always_true_predicate(e: Exception): # pylint: disable=unused-argument
+ return True
+
+
[email protected](_always_true_predicate)
+def _retryable_test(thing):
+ return thing()
+
+
+def _always_false_predicate(e: Exception): # pylint: disable=unused-argument
+ return False
+
+
[email protected](_always_false_predicate)
+def _non_retryable_test(thing):
+ return thing()
+
+
+class TestRetryDecorator(unittest.TestCase):  # pylint: disable=invalid-name
+ def test_do_nothing_on_non_exception(self):
+ result = _retryable_test(lambda: 42)
+ assert result, 42
+
+ def test_retry_on_exception(self):
+ quota_retry = {
+ 'stop_after_delay': 2,
+ 'multiplier': 1,
+ 'min': 1,
+ 'max': 10,
+ }
+ custom_fn = ThrowErrorUntilCount(
+ count=2,
+ quota_retry=quota_retry,
+ )
+ result = _retryable_test(custom_fn)
+ assert custom_fn.counter == 2
+ assert result
+
+ def test_no_retry_on_exception(self):
+ quota_retry = {
+ 'stop_after_delay': 2,
+ 'multiplier': 1,
+ 'min': 1,
+ 'max': 10,
+ }
+ custom_fn = ThrowErrorUntilCount(
+ count=2,
+ quota_retry=quota_retry,
+ )
+ with pytest.raises(Exception):
+ _non_retryable_test(custom_fn)
+
+ def test_raise_exception_when_no_retry_args(self):
+ custom_fn = ThrowErrorUntilCount(count=2, quota_retry=None)
+ with pytest.raises(Exception):
+ _retryable_test(custom_fn)
diff --git a/tests/providers/amazon/aws/operators/test_ecs.py
b/tests/providers/amazon/aws/operators/test_ecs.py
index 7465f0c..96717c3 100644
--- a/tests/providers/amazon/aws/operators/test_ecs.py
+++ b/tests/providers/amazon/aws/operators/test_ecs.py
@@ -26,7 +26,8 @@ import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.operators.ecs import ECSOperator
+from airflow.providers.amazon.aws.exceptions import ECSOperatorError
+from airflow.providers.amazon.aws.operators.ecs import ECSOperator,
should_retry
# fmt: off
RESPONSE_WITHOUT_FAILURES = {
@@ -145,7 +146,7 @@ class TestECSOperator(unittest.TestCase):
resp_failures['failures'].append('dummy error')
client_mock.run_task.return_value = resp_failures
- with pytest.raises(AirflowException):
+ with pytest.raises(ECSOperatorError):
self.ecs.execute(None)
self.aws_hook_mock.return_value.get_conn.assert_called_once()
@@ -326,3 +327,11 @@ class TestECSOperator(unittest.TestCase):
def test_execute_xcom_disabled(self, mock_cloudwatch_log_message):
self.ecs.do_xcom_push = False
assert self.ecs.execute(None) is None
+
+
+class TestShouldRetry(unittest.TestCase):
+ def test_return_true_on_valid_reason(self):
+ self.assertTrue(should_retry(ECSOperatorError([{'reason':
'RESOURCE:MEMORY'}], 'Foo')))
+
+ def test_return_false_on_invalid_reason(self):
+ self.assertFalse(should_retry(ECSOperatorError([{'reason':
'CLUSTER_NOT_FOUND'}], 'Foo')))