Repository: incubator-airflow Updated Branches: refs/heads/master 41490f9c4 -> 98197d956
[AIRFLOW-345] Add contrib ECSOperator

Closes #1894 from poulainv/ecs_operator


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/98197d95
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/98197d95
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/98197d95

Branch: refs/heads/master
Commit: 98197d95681abaae0ec8f928e0147a8b32132ecb
Parents: 41490f9
Author: Vincent Poulain <[email protected]>
Authored: Wed Nov 23 10:49:57 2016 -0800
Committer: Siddharth Anand <[email protected]>
Committed: Wed Nov 23 10:49:57 2016 -0800

----------------------------------------------------------------------
 airflow/contrib/hooks/aws_hook.py         |  27 +++-
 airflow/contrib/operators/ecs_operator.py | 127 +++++++++++++++
 docs/code.rst                             |   1 +
 tests/contrib/operators/ecs_operator.py   | 207 +++++++++++++++++++++++++
 4 files changed, 356 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/98197d95/airflow/contrib/hooks/aws_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/aws_hook.py b/airflow/contrib/hooks/aws_hook.py
index 37a02ee..3eced28 100644
--- a/airflow/contrib/hooks/aws_hook.py
+++ b/airflow/contrib/hooks/aws_hook.py
@@ -12,24 +12,39 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
 import boto3
+
+from airflow.exceptions import AirflowException
 from airflow.hooks.base_hook import BaseHook
 
 
 class AwsHook(BaseHook):
     """
     Interact with AWS.
-
     This class is a thin wrapper around the boto3 python library.
     """
     def __init__(self, aws_conn_id='aws_default'):
         self.aws_conn_id = aws_conn_id
 
-    def get_client_type(self, client_type):
-        connection_object = self.get_connection(self.aws_conn_id)
+    def get_client_type(self, client_type, region_name=None):
+        """Build a boto3 client of type client_type (e.g. 'ecs')."""
+        try:
+            connection_object = self.get_connection(self.aws_conn_id)
+        except AirflowException:
+            # No connection found: fallback on boto3 credential strategy
+            # http://boto3.readthedocs.io/en/latest/guide/configuration.html
+            aws_access_key_id = None
+            aws_secret_access_key = None
+        else:
+            aws_access_key_id = connection_object.login
+            aws_secret_access_key = connection_object.password
+            if region_name is None:  # an explicit region_name argument wins
+                region_name = connection_object.extra_dejson.get('region_name')
+
         return boto3.client(
             client_type,
-            region_name=connection_object.extra_dejson.get('region_name'),
-            aws_access_key_id=connection_object.login,
-            aws_secret_access_key=connection_object.password,
+            region_name=region_name,
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key
         )
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/98197d95/airflow/contrib/operators/ecs_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/ecs_operator.py b/airflow/contrib/operators/ecs_operator.py
new file mode 100644
index 0000000..7415d32
--- /dev/null
+++ b/airflow/contrib/operators/ecs_operator.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import logging
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.utils import apply_defaults
+
+from airflow.contrib.hooks.aws_hook import AwsHook
+
+
+class ECSOperator(BaseOperator):
+    """
+    Execute a task on AWS EC2 Container Service
+
+    :param task_definition: the task definition name on EC2 Container Service
+    :type task_definition: str
+    :param cluster: the cluster name on EC2 Container Service
+    :type cluster: str
+    :param overrides: the same parameter that boto3 will receive: http://boto3.readthedocs.org/en/latest/reference/services/ecs.html#ECS.Client.run_task
+    :type overrides: dict
+    :param aws_conn_id: connection id of AWS credentials / region name. If None, boto3's default credential strategy is used (http://boto3.readthedocs.io/en/latest/guide/configuration.html).
+    :type aws_conn_id: str
+    :param region_name: region name to use in AWS Hook. Overrides the region_name in the connection (if provided)
+    :type region_name: str
+    """
+
+    ui_color = '#f0ede4'
+    client = None  # boto3 ECS client, created in execute()
+    arn = None  # ARN of the task started by execute()
+    template_fields = ('overrides',)
+
+    @apply_defaults
+    def __init__(self, task_definition, cluster, overrides,
+                 aws_conn_id=None, region_name=None, **kwargs):
+        super(ECSOperator, self).__init__(**kwargs)
+
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.task_definition = task_definition
+        self.cluster = cluster
+        self.overrides = overrides
+
+        self.hook = self.get_hook()
+
+    def execute(self, context):
+        """Run the task on ECS, wait until it stops, then check its result."""
+        logging.info('Running ECS Task - Task definition: {} - on cluster {}'.format(
+            self.task_definition,
+            self.cluster
+        ))
+        logging.info('ECSOperator overrides: {}'.format(self.overrides))
+
+        self.client = self.hook.get_client_type(
+            'ecs',
+            region_name=self.region_name
+        )
+
+        response = self.client.run_task(
+            cluster=self.cluster,
+            taskDefinition=self.task_definition,
+            overrides=self.overrides,
+            startedBy=self.owner
+        )
+
+        failures = response['failures']
+        if failures:
+            raise AirflowException(response)
+        logging.info('ECS Task started: {}'.format(response))
+
+        # Only the first started task is tracked, waited on and checked.
+        self.arn = response['tasks'][0]['taskArn']
+        self._wait_for_task_ended()
+
+        self._check_success_task()
+        logging.info('ECS Task has been successfully executed: {}'.format(response))
+
+    def _wait_for_task_ended(self):
+        """Block on the 'tasks_stopped' waiter for the started task."""
+        waiter = self.client.get_waiter('tasks_stopped')
+        waiter.config.max_attempts = sys.maxsize  # timeout is managed by airflow
+        waiter.wait(cluster=self.cluster, tasks=[self.arn])
+
+    def _check_success_task(self):
+        """Describe the task and raise AirflowException if it did not succeed."""
+        response = self.client.describe_tasks(
+            cluster=self.cluster, tasks=[self.arn])
+        logging.info('ECS Task stopped, check status: {}'.format(response))
+
+        if response.get('failures'):
+            raise AirflowException(response)
+
+        for task in response['tasks']:
+            for container in task['containers']:
+                # .get(): no exitCode on a STOPPED container counts as failure
+                if container.get('lastStatus') == 'STOPPED' and container.get('exitCode') != 0:
+                    raise AirflowException('This task is not in success state {}'.format(task))
+                elif container.get('lastStatus') == 'PENDING':
+                    raise AirflowException('This task is still pending {}'.format(task))
+                elif 'error' in container.get('reason', '').lower():
+                    raise AirflowException('This containers encounter an error during launching : {}'.format(container.get('reason', '').lower()))
+
+    def get_hook(self):
+        """Create the AwsHook for this operator's aws_conn_id."""
+        return AwsHook(aws_conn_id=self.aws_conn_id)
+
+    def on_kill(self):
+        """Stop the ECS task started by execute(), if there is one."""
+        if self.client is None or self.arn is None:
+            return
+        response = self.client.stop_task(
+            cluster=self.cluster,
+            task=self.arn,
+            reason='Task killed by the user')
+        logging.info(response)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/98197d95/docs/code.rst
----------------------------------------------------------------------
diff --git a/docs/code.rst b/docs/code.rst
index 8548120..0e1993e 100644
--- a/docs/code.rst
+++ b/docs/code.rst
@@ -97,6 +97,7 @@ Community-contributed Operators
 
 .. autoclass:: airflow.contrib.operators.bigquery_operator.BigQueryOperator
 .. autoclass:: airflow.contrib.operators.bigquery_to_gcs.BigQueryToCloudStorageOperator
+.. autoclass:: airflow.contrib.operators.ecs_operator.ECSOperator
 .. autoclass:: airflow.contrib.operators.gcs_download_operator.GoogleCloudStorageDownloadOperator
 .. autoclass:: airflow.contrib.operators.QuboleOperator
 .. autoclass:: airflow.contrib.operators.hipchat_operator.HipChatAPIOperator
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/98197d95/tests/contrib/operators/ecs_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/ecs_operator.py b/tests/contrib/operators/ecs_operator.py
new file mode 100644
index 0000000..5a593a6
--- /dev/null
+++ b/tests/contrib/operators/ecs_operator.py
@@ -0,0 +1,207 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+import unittest
+from copy import deepcopy
+
+from airflow import configuration
+from airflow.exceptions import AirflowException
+from airflow.contrib.operators.ecs_operator import ECSOperator
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+
+RESPONSE_WITHOUT_FAILURES = {
+    "failures": [],
+    "tasks": [
+        {
+            "containers": [
+                {
+                    "containerArn": "arn:aws:ecs:us-east-1:012345678910:container/e1ed7aac-d9b2-4315-8726-d2432bf11868",
+                    "lastStatus": "PENDING",
+                    "name": "wordpress",
+                    "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55"
+                }
+            ],
+            "desiredStatus": "RUNNING",
+            "lastStatus": "PENDING",
+            "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55",
+            "taskDefinitionArn": "arn:aws:ecs:us-east-1:012345678910:task-definition/hello_world:11"
+        }
+    ]
+}
+
+
+class TestECSOperator(unittest.TestCase):
+
+    @mock.patch('airflow.contrib.operators.ecs_operator.AwsHook')
+    def setUp(self, aws_hook_mock):
+        configuration.load_test_config()
+
+        self.aws_hook_mock = aws_hook_mock
+        self.ecs = ECSOperator(
+            task_id='task',
+            task_definition='t',
+            cluster='c',
+            overrides={},
+            aws_conn_id=None,
+            region_name='eu-west-1')
+
+    def test_init(self):
+        """The constructor stores its arguments and builds the AwsHook."""
+        self.assertEqual(self.ecs.region_name, 'eu-west-1')
+        self.assertEqual(self.ecs.task_definition, 't')
+        self.assertEqual(self.ecs.aws_conn_id, None)
+        self.assertEqual(self.ecs.cluster, 'c')
+        self.assertEqual(self.ecs.overrides, {})
+        self.assertEqual(self.ecs.hook, self.aws_hook_mock.return_value)
+
+        self.aws_hook_mock.assert_called_once_with(aws_conn_id=None)
+
+    def test_template_fields_overrides(self):
+        self.assertEqual(self.ecs.template_fields, ('overrides',))
+
+    @mock.patch.object(ECSOperator, '_wait_for_task_ended')
+    @mock.patch.object(ECSOperator, '_check_success_task')
+    def test_execute_without_failures(self, check_mock, wait_mock):
+        """execute() starts the task, then waits for it and checks it."""
+        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
+        client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES
+
+        self.ecs.execute(None)
+
+        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('ecs', region_name='eu-west-1')
+        client_mock.run_task.assert_called_once_with(
+            cluster='c',
+            overrides={},
+            startedBy='Airflow',
+            taskDefinition='t'
+        )
+
+        wait_mock.assert_called_once_with()
+        check_mock.assert_called_once_with()
+        self.assertEqual(self.ecs.arn, 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55')
+
+    def test_execute_with_failures(self):
+        """execute() raises AirflowException when run_task reports failures."""
+        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
+        resp_failures = deepcopy(RESPONSE_WITHOUT_FAILURES)
+        resp_failures['failures'].append('dummy error')
+        client_mock.run_task.return_value = resp_failures
+
+        with self.assertRaises(AirflowException):
+            self.ecs.execute(None)
+
+        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('ecs', region_name='eu-west-1')
+        client_mock.run_task.assert_called_once_with(
+            cluster='c',
+            overrides={},
+            startedBy='Airflow',
+            taskDefinition='t'
+        )
+
+    def test_wait_end_tasks(self):
+        """The tasks_stopped waiter is used with an unbounded attempt count."""
+        client_mock = mock.Mock()
+        self.ecs.arn = 'arn'
+        self.ecs.client = client_mock
+
+        self.ecs._wait_for_task_ended()
+        client_mock.get_waiter.assert_called_once_with('tasks_stopped')
+        client_mock.get_waiter.return_value.wait.assert_called_once_with(cluster='c', tasks=['arn'])
+        self.assertEqual(sys.maxsize, client_mock.get_waiter.return_value.config.max_attempts)
+
+    def test_check_success_tasks_raises(self):
+        client_mock = mock.Mock()
+        self.ecs.arn = 'arn'
+        self.ecs.client = client_mock
+        client_mock.describe_tasks.return_value = {
+            'tasks': [{
+                'containers': [{
+                    'name': 'foo',
+                    'lastStatus': 'STOPPED',
+                    'exitCode': 1
+                }]
+            }]
+        }
+        with self.assertRaises(AirflowException) as e:
+            self.ecs._check_success_task()
+        # Dict-repr key order varies across Python versions: match substrings.
+        self.assertIn('This task is not in success state', str(e.exception))
+        self.assertIn("'exitCode': 1", str(e.exception))
+        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])
+
+    def test_check_success_tasks_raises_pending(self):
+        client_mock = mock.Mock()
+        self.ecs.client = client_mock
+        self.ecs.arn = 'arn'
+        client_mock.describe_tasks.return_value = {
+            'tasks': [{
+                'containers': [{
+                    'name': 'container-name',
+                    'lastStatus': 'PENDING'
+                }]
+            }]
+        }
+        with self.assertRaises(AirflowException) as e:
+            self.ecs._check_success_task()
+        self.assertIn('This task is still pending', str(e.exception))
+        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])
+
+    def test_check_success_tasks_multiple_not_raises(self):
+        client_mock = mock.Mock()
+        self.ecs.client = client_mock
+        self.ecs.arn = 'arn'
+        client_mock.describe_tasks.return_value = {
+            'tasks': [{
+                'containers': [{
+                    'name': 'foo',
+                    'exitCode': 1
+                }, {
+                    'name': 'bar',
+                    'lastStatus': 'STOPPED',
+                    'exitCode': 0
+                }]
+            }]
+        }
+        self.ecs._check_success_task()
+        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])
+
+    def test_check_success_task_not_raises(self):
+        client_mock = mock.Mock()
+        self.ecs.client = client_mock
+        self.ecs.arn = 'arn'
+        client_mock.describe_tasks.return_value = {
+            'tasks': [{
+                'containers': [{
+                    'name': 'container-name',
+                    'lastStatus': 'STOPPED',
+                    'exitCode': 0
+                }]
+            }]
+        }
+        self.ecs._check_success_task()
+        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])
+
+
+if __name__ == '__main__':
+    unittest.main()
