This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 614be87  Added retry to ECS Operator (#14263)
614be87 is described below

commit 614be87b23199acd67e69677cfdb6ae4ed023b69
Author: Mark Hopson <[email protected]>
AuthorDate: Fri Mar 26 18:34:15 2021 -0400

    Added retry to ECS Operator (#14263)
    
    * Added retry to ECS Operator
    
    * ...
    
    * Remove airflow/www/yarn-error.log
    
    * Update decorator to not accept any params
    
    * ...
    
    * ...
    
    * ...
    
    * lint
    
    * Add predicate argument in retry decorator
    
    * Add wraps and fixed test
    
    * ...
    
    * Remove unnecessary retry_if_permissible_error and fix lint errors
    
    * Static check fixes
    
    * Fix TestECSOperator.test_execute_with_failures
---
 airflow/providers/amazon/aws/exceptions.py        | 29 ++++++++
 airflow/providers/amazon/aws/hooks/base_aws.py    | 35 +++++++++-
 airflow/providers/amazon/aws/operators/ecs.py     | 19 +++++-
 tests/providers/amazon/aws/hooks/test_base_aws.py | 80 +++++++++++++++++++++++
 tests/providers/amazon/aws/operators/test_ecs.py  | 13 +++-
 5 files changed, 172 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/amazon/aws/exceptions.py 
b/airflow/providers/amazon/aws/exceptions.py
new file mode 100644
index 0000000..d0e5b54
--- /dev/null
+++ b/airflow/providers/amazon/aws/exceptions.py
@@ -0,0 +1,29 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# Note: Any AirflowException raised is expected to cause the TaskInstance
+#       to be marked in an ERROR state
+
+
+class ECSOperatorError(Exception):
+    """Raise when ECS cannot handle the request."""
+
+    def __init__(self, failures: list, message: str):
+        self.failures = failures
+        self.message = message
+        super().__init__(message)
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py 
b/airflow/providers/amazon/aws/hooks/base_aws.py
index 5cd694b..c1c5b1d 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -27,11 +27,13 @@ This module contains Base AWS Hook.
 import configparser
 import datetime
 import logging
-from typing import Any, Dict, Optional, Tuple, Union
+from functools import wraps
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import boto3
 import botocore
 import botocore.session
+import tenacity
 from botocore.config import Config
 from botocore.credentials import ReadOnlyCredentials
 
@@ -488,6 +490,37 @@ class AwsBaseHook(BaseHook):
         else:
             return 
self.get_client_type("iam").get_role(RoleName=role)["Role"]["Arn"]
 
+    @staticmethod
+    def retry(should_retry: Callable[[Exception], bool]):
+        """
+        A decorator that provides a mechanism to repeat requests in response 
to exceeding a temporary quota
+        limit.
+        """
+
+        def retry_decorator(fun: Callable):
+            @wraps(fun)
+            def decorator_f(self, *args, **kwargs):
+                retry_args = getattr(self, 'retry_args', None)
+                if retry_args is None:
+                    return fun(self)
+                multiplier = retry_args.get('multiplier', 1)
+                min_limit = retry_args.get('min', 1)
+                max_limit = retry_args.get('max', 1)
+                stop_after_delay = retry_args.get('stop_after_delay', 10)
+                tenacity_logger = tenacity.before_log(self.log, logging.DEBUG) 
if self.log else None
+                default_kwargs = {
+                    'wait': tenacity.wait_exponential(multiplier=multiplier, 
max=max_limit, min=min_limit),
+                    'retry': tenacity.retry_if_exception(should_retry),
+                    'stop': tenacity.stop_after_delay(stop_after_delay),
+                    'before': tenacity_logger,
+                    'after': tenacity_logger,
+                }
+                return tenacity.retry(**default_kwargs)(fun)(self)
+
+            return decorator_f
+
+        return retry_decorator
+
 
 def _parse_s3_config(
     config_file_name: str, config_format: Optional[str] = "boto", profile: 
Optional[str] = None
diff --git a/airflow/providers/amazon/aws/operators/ecs.py 
b/airflow/providers/amazon/aws/operators/ecs.py
index a0e6dea..50ab958 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -25,12 +25,24 @@ from botocore.waiter import Waiter
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.exceptions import ECSOperatorError
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
 from airflow.typing_compat import Protocol, runtime_checkable
 from airflow.utils.decorators import apply_defaults
 
 
+def should_retry(exception: Exception):
+    """Check if exception is related to ECS resource quota (CPU, MEM)."""
+    if isinstance(exception, ECSOperatorError):
+        return any(
+            quota_reason in failure['reason']
+            for quota_reason in ['RESOURCE:MEMORY', 'RESOURCE:CPU']
+            for failure in exception.failures
+        )
+    return False
+
+
 @runtime_checkable
 class ECSProtocol(Protocol):
     """
@@ -125,6 +137,8 @@ class ECSOperator(BaseOperator):  # pylint: 
disable=too-many-instance-attributes
     :param reattach: If set to True, will check if a task from the same family 
is already running.
         If so, the operator will attach to it instead of starting a new task.
     :type reattach: bool
+    :param quota_retry: Config if and how to retry _start_task() for transient 
errors.
+    :type quota_retry: dict
     """
 
     ui_color = '#f0ede4'
@@ -150,6 +164,7 @@ class ECSOperator(BaseOperator):  # pylint: 
disable=too-many-instance-attributes
         awslogs_region: Optional[str] = None,
         awslogs_stream_prefix: Optional[str] = None,
         propagate_tags: Optional[str] = None,
+        quota_retry: Optional[dict] = None,
         reattach: bool = False,
         **kwargs,
     ):
@@ -180,6 +195,7 @@ class ECSOperator(BaseOperator):  # pylint: 
disable=too-many-instance-attributes
         self.hook: Optional[AwsBaseHook] = None
         self.client: Optional[ECSProtocol] = None
         self.arn: Optional[str] = None
+        self.retry_args = quota_retry
 
     def execute(self, context):
         self.log.info(
@@ -206,6 +222,7 @@ class ECSOperator(BaseOperator):  # pylint: 
disable=too-many-instance-attributes
 
         return None
 
+    @AwsBaseHook.retry(should_retry)
     def _start_task(self):
         run_opts = {
             'cluster': self.cluster,
@@ -235,7 +252,7 @@ class ECSOperator(BaseOperator):  # pylint: 
disable=too-many-instance-attributes
 
         failures = response['failures']
         if len(failures) > 0:
-            raise AirflowException(response)
+            raise ECSOperatorError(failures, response)
         self.log.info('ECS Task started: %s', response)
 
         self.arn = response['tasks'][0]['taskArn']
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py 
b/tests/providers/amazon/aws/hooks/test_base_aws.py
index da8f8c8..383880d 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -21,6 +21,7 @@ import unittest
 from unittest import mock
 
 import boto3
+import pytest
 
 from airflow.models import Connection
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -266,3 +267,82 @@ class TestAwsBaseHook(unittest.TestCase):
             hook = AwsBaseHook(aws_conn_id=conn_id, client_type='s3')
             # should cause no exception
             hook.get_client_type('s3')
+
+
+class ThrowErrorUntilCount:
+    """Holds counter state for invoking a method several times in a row."""
+
+    def __init__(self, count, quota_retry, **kwargs):
+        self.counter = 0
+        self.count = count
+        self.retry_args = quota_retry
+        self.kwargs = kwargs
+        self.log = None
+
+    def __call__(self):
+        """
+        Raise an Exception until the count threshold has been crossed.
+        Then return True.
+        """
+        if self.counter < self.count:
+            self.counter += 1
+            raise Exception()
+        return True
+
+
+def _always_true_predicate(e: Exception):  # pylint: disable=unused-argument
+    return True
+
+
[email protected](_always_true_predicate)
+def _retryable_test(thing):
+    return thing()
+
+
+def _always_false_predicate(e: Exception):  # pylint: disable=unused-argument
+    return False
+
+
[email protected](_always_false_predicate)
+def _non_retryable_test(thing):
+    return thing()
+
+
+class TestRetryDecorator(unittest.TestCase):  # pylint: disable=invalid-name
+    def test_do_nothing_on_non_exception(self):
+        result = _retryable_test(lambda: 42)
+        assert result, 42
+
+    def test_retry_on_exception(self):
+        quota_retry = {
+            'stop_after_delay': 2,
+            'multiplier': 1,
+            'min': 1,
+            'max': 10,
+        }
+        custom_fn = ThrowErrorUntilCount(
+            count=2,
+            quota_retry=quota_retry,
+        )
+        result = _retryable_test(custom_fn)
+        assert custom_fn.counter == 2
+        assert result
+
+    def test_no_retry_on_exception(self):
+        quota_retry = {
+            'stop_after_delay': 2,
+            'multiplier': 1,
+            'min': 1,
+            'max': 10,
+        }
+        custom_fn = ThrowErrorUntilCount(
+            count=2,
+            quota_retry=quota_retry,
+        )
+        with pytest.raises(Exception):
+            _non_retryable_test(custom_fn)
+
+    def test_raise_exception_when_no_retry_args(self):
+        custom_fn = ThrowErrorUntilCount(count=2, quota_retry=None)
+        with pytest.raises(Exception):
+            _retryable_test(custom_fn)
diff --git a/tests/providers/amazon/aws/operators/test_ecs.py 
b/tests/providers/amazon/aws/operators/test_ecs.py
index 7465f0c..96717c3 100644
--- a/tests/providers/amazon/aws/operators/test_ecs.py
+++ b/tests/providers/amazon/aws/operators/test_ecs.py
@@ -26,7 +26,8 @@ import pytest
 from parameterized import parameterized
 
 from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.operators.ecs import ECSOperator
+from airflow.providers.amazon.aws.exceptions import ECSOperatorError
+from airflow.providers.amazon.aws.operators.ecs import ECSOperator, 
should_retry
 
 # fmt: off
 RESPONSE_WITHOUT_FAILURES = {
@@ -145,7 +146,7 @@ class TestECSOperator(unittest.TestCase):
         resp_failures['failures'].append('dummy error')
         client_mock.run_task.return_value = resp_failures
 
-        with pytest.raises(AirflowException):
+        with pytest.raises(ECSOperatorError):
             self.ecs.execute(None)
 
         self.aws_hook_mock.return_value.get_conn.assert_called_once()
@@ -326,3 +327,11 @@ class TestECSOperator(unittest.TestCase):
     def test_execute_xcom_disabled(self, mock_cloudwatch_log_message):
         self.ecs.do_xcom_push = False
         assert self.ecs.execute(None) is None
+
+
+class TestShouldRetry(unittest.TestCase):
+    def test_return_true_on_valid_reason(self):
+        self.assertTrue(should_retry(ECSOperatorError([{'reason': 
'RESOURCE:MEMORY'}], 'Foo')))
+
+    def test_return_false_on_invalid_reason(self):
+        self.assertFalse(should_retry(ECSOperatorError([{'reason': 
'CLUSTER_NOT_FOUND'}], 'Foo')))

Reply via email to