http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/zendesk_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/zendesk_hook.py b/tests/contrib/hooks/zendesk_hook.py deleted file mode 100644 index 66b8e6b..0000000 --- a/tests/contrib/hooks/zendesk_hook.py +++ /dev/null @@ -1,90 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -from unittest.mock import Mock, patch -from plugins.hooks.zendesk_hook import ZendeskHook -from zdesk import RateLimitError -from pytest import raises - - -@patch("plugins.hooks.zendesk_hook.time") -@patch("plugins.hooks.zendesk_hook.Zendesk") -def test_sleeps_for_correct_interval(_, mocked_time): - sleep_time = 10 - - # To break out of the otherwise infinite tries - mocked_time.sleep = Mock(side_effect=ValueError) - conn_mock = Mock() - mock_response = Mock() - mock_response.headers.get.return_value = sleep_time - conn_mock.call = Mock( - side_effect=RateLimitError(msg="some message", code="some code", - response=mock_response)) - - zendesk_hook = ZendeskHook("conn_id") - zendesk_hook.get_conn = Mock(return_value=conn_mock) - - with raises(ValueError): - zendesk_hook.call("some_path", get_all_pages=False) - mocked_time.sleep.assert_called_with(sleep_time) - - -@patch("plugins.hooks.zendesk_hook.Zendesk") -def test_returns_single_page_if_get_all_pages_false(_): - zendesk_hook = ZendeskHook("conn_id") - mock_connection = Mock() - mock_connection.host = "some_host" - zendesk_hook.get_connection = Mock(return_value=mock_connection) - zendesk_hook.get_conn() - - mock_conn = Mock() - mock_call = Mock( - return_value={'next_page': 'https://some_host/something', 'path': []}) - mock_conn.call = mock_call - zendesk_hook.get_conn = Mock(return_value=mock_conn) - zendesk_hook.call("path", get_all_pages=False) - mock_call.assert_called_once_with("path", None) - - -@patch("plugins.hooks.zendesk_hook.Zendesk") -def test_returns_multiple_pages_if_get_all_pages_true(_): - zendesk_hook = ZendeskHook("conn_id") - mock_connection = Mock() - mock_connection.host = "some_host" - zendesk_hook.get_connection = Mock(return_value=mock_connection) - zendesk_hook.get_conn() - - mock_conn = Mock() - mock_call = Mock( - return_value={'next_page': 'https://some_host/something', 'path': []}) - mock_conn.call = mock_call - zendesk_hook.get_conn = Mock(return_value=mock_conn) - zendesk_hook.call("path", get_all_pages=True) - assert mock_call.call_count == 2 - - -@patch("plugins.hooks.zendesk_hook.Zendesk") -def test_zdesk_is_inited_correctly(mock_zendesk): - conn_mock = Mock() - conn_mock.host = "conn_host" - conn_mock.login = "conn_login" - conn_mock.password = "conn_pass" - - zendesk_hook = ZendeskHook("conn_id") - zendesk_hook.get_connection = Mock(return_value=conn_mock) - zendesk_hook.get_conn() - mock_zendesk.assert_called_with('https://conn_host', 'conn_login', - 'conn_pass', True)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/__init__.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/__init__.py b/tests/contrib/operators/__init__.py index 6e38bea..cdd2147 100644 --- a/tests/contrib/operators/__init__.py +++ b/tests/contrib/operators/__init__.py @@ -13,6 +13,3 @@ # limitations under the License. # -from __future__ import absolute_import -from .ssh_execute_operator import * -from .fs_operator import * http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/databricks_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/databricks_operator.py b/tests/contrib/operators/databricks_operator.py deleted file mode 100644 index aab47fa..0000000 --- a/tests/contrib/operators/databricks_operator.py +++ /dev/null @@ -1,185 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import unittest - -from airflow.contrib.hooks.databricks_hook import RunState -from airflow.contrib.operators.databricks_operator import DatabricksSubmitRunOperator -from airflow.exceptions import AirflowException - -try: - from unittest import mock -except ImportError: - try: - import mock - except ImportError: - mock = None - -TASK_ID = 'databricks-operator' -DEFAULT_CONN_ID = 'databricks_default' -NOTEBOOK_TASK = { - 'notebook_path': '/test' -} -SPARK_JAR_TASK = { - 'main_class_name': 'com.databricks.Test' -} -NEW_CLUSTER = { - 'spark_version': '2.0.x-scala2.10', - 'node_type_id': 'development-node', - 'num_workers': 1 -} -EXISTING_CLUSTER_ID = 'existing-cluster-id' -RUN_NAME = 'run-name' -RUN_ID = 1 - - -class DatabricksSubmitRunOperatorTest(unittest.TestCase): - def test_init_with_named_parameters(self): - """ - Test the initializer with the named parameters. - """ - op = DatabricksSubmitRunOperator(task_id=TASK_ID, new_cluster=NEW_CLUSTER, notebook_task=NOTEBOOK_TASK) - expected = { - 'new_cluster': NEW_CLUSTER, - 'notebook_task': NOTEBOOK_TASK, - 'run_name': TASK_ID - } - self.assertDictEqual(expected, op.json) - - def test_init_with_json(self): - """ - Test the initializer with json data. - """ - json = { - 'new_cluster': NEW_CLUSTER, - 'notebook_task': NOTEBOOK_TASK - } - op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) - expected = { - 'new_cluster': NEW_CLUSTER, - 'notebook_task': NOTEBOOK_TASK, - 'run_name': TASK_ID - } - self.assertDictEqual(expected, op.json) - - def test_init_with_specified_run_name(self): - """ - Test the initializer with a specified run_name. - """ - json = { - 'new_cluster': NEW_CLUSTER, - 'notebook_task': NOTEBOOK_TASK, - 'run_name': RUN_NAME - } - op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) - expected = { - 'new_cluster': NEW_CLUSTER, - 'notebook_task': NOTEBOOK_TASK, - 'run_name': RUN_NAME - } - self.assertDictEqual(expected, op.json) - - def test_init_with_merging(self): - """ - Test the initializer when json and other named parameters are both - provided. The named parameters should override top level keys in the - json dict. - """ - override_new_cluster = {'workers': 999} - json = { - 'new_cluster': NEW_CLUSTER, - 'notebook_task': NOTEBOOK_TASK, - } - op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json, new_cluster=override_new_cluster) - expected = { - 'new_cluster': override_new_cluster, - 'notebook_task': NOTEBOOK_TASK, - 'run_name': TASK_ID, - } - self.assertDictEqual(expected, op.json) - - @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook') - def test_exec_success(self, db_mock_class): - """ - Test the execute function in case where the run is successful. - """ - run = { - 'new_cluster': NEW_CLUSTER, - 'notebook_task': NOTEBOOK_TASK, - } - op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run) - db_mock = db_mock_class.return_value - db_mock.submit_run.return_value = 1 - db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '') - - op.execute(None) - - expected = { - 'new_cluster': NEW_CLUSTER, - 'notebook_task': NOTEBOOK_TASK, - 'run_name': TASK_ID - } - db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit) - db_mock.submit_run.assert_called_once_with(expected) - db_mock.get_run_page_url.assert_called_once_with(RUN_ID) - db_mock.get_run_state.assert_called_once_with(RUN_ID) - self.assertEquals(RUN_ID, op.run_id) - - @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook') - def test_exec_failure(self, db_mock_class): - """ - Test the execute function in case where the run failed. - """ - run = { - 'new_cluster': NEW_CLUSTER, - 'notebook_task': NOTEBOOK_TASK, - } - op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run) - db_mock = db_mock_class.return_value - db_mock.submit_run.return_value = 1 - db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED', '') - - with self.assertRaises(AirflowException): - op.execute(None) - - expected = { - 'new_cluster': NEW_CLUSTER, - 'notebook_task': NOTEBOOK_TASK, - 'run_name': TASK_ID, - } - db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit) - db_mock.submit_run.assert_called_once_with(expected) - db_mock.get_run_page_url.assert_called_once_with(RUN_ID) - db_mock.get_run_state.assert_called_once_with(RUN_ID) - self.assertEquals(RUN_ID, op.run_id) - - @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook') - def test_on_kill(self, db_mock_class): - run = { - 'new_cluster': NEW_CLUSTER, - 'notebook_task': NOTEBOOK_TASK, - } - op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run) - db_mock = db_mock_class.return_value - op.run_id = RUN_ID - - op.on_kill() - - db_mock.cancel_run.assert_called_once_with(RUN_ID) - http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/dataflow_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/dataflow_operator.py b/tests/contrib/operators/dataflow_operator.py deleted file mode 100644 index 7455a45..0000000 --- a/tests/contrib/operators/dataflow_operator.py +++ /dev/null @@ -1,82 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import unittest - -from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator - -try: - from unittest import mock -except ImportError: - try: - import mock - except ImportError: - mock = None - - -TASK_ID = 'test-python-dataflow' -PY_FILE = 'gs://my-bucket/my-object.py' -PY_OPTIONS = ['-m'] -DEFAULT_OPTIONS = { - 'project': 'test', - 'stagingLocation': 'gs://test/staging' -} -ADDITIONAL_OPTIONS = { - 'output': 'gs://test/output' -} -GCS_HOOK_STRING = 'airflow.contrib.operators.dataflow_operator.{}' - - -class DataFlowPythonOperatorTest(unittest.TestCase): - - def setUp(self): - self.dataflow = DataFlowPythonOperator( - task_id=TASK_ID, - py_file=PY_FILE, - py_options=PY_OPTIONS, - dataflow_default_options=DEFAULT_OPTIONS, - options=ADDITIONAL_OPTIONS) - - def test_init(self): - """Test DataFlowPythonOperator instance is properly initialized.""" - self.assertEqual(self.dataflow.task_id, TASK_ID) - self.assertEqual(self.dataflow.py_file, PY_FILE) - self.assertEqual(self.dataflow.py_options, PY_OPTIONS) - self.assertEqual(self.dataflow.dataflow_default_options, - DEFAULT_OPTIONS) - self.assertEqual(self.dataflow.options, - ADDITIONAL_OPTIONS) - - @mock.patch('airflow.contrib.operators.dataflow_operator.DataFlowHook') - @mock.patch(GCS_HOOK_STRING.format('GoogleCloudStorageHook')) - def test_exec(self, gcs_hook, dataflow_mock): - """Test DataFlowHook is created and the right args are passed to - start_python_workflow. - - """ - start_python_hook = dataflow_mock.return_value.start_python_dataflow - gcs_download_hook = gcs_hook.return_value.download - self.dataflow.execute(None) - self.assertTrue(dataflow_mock.called) - expected_options = { - 'project': 'test', - 'staging_location': 'gs://test/staging', - 'output': 'gs://test/output' - } - gcs_download_hook.assert_called_once_with( - 'my-bucket', 'my-object.py', mock.ANY) - start_python_hook.assert_called_once_with(TASK_ID, expected_options, - mock.ANY, PY_OPTIONS) - self.assertTrue(self.dataflow.py_file.startswith('/tmp/dataflow')) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/ecs_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/ecs_operator.py b/tests/contrib/operators/ecs_operator.py deleted file mode 100644 index 5a593a6..0000000 --- a/tests/contrib/operators/ecs_operator.py +++ /dev/null @@ -1,207 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import sys -import unittest -from copy import deepcopy - -from airflow import configuration -from airflow.exceptions import AirflowException -from airflow.contrib.operators.ecs_operator import ECSOperator - -try: - from unittest import mock -except ImportError: - try: - import mock - except ImportError: - mock = None - - -RESPONSE_WITHOUT_FAILURES = { - "failures": [], - "tasks": [ - { - "containers": [ - { - "containerArn": "arn:aws:ecs:us-east-1:012345678910:container/e1ed7aac-d9b2-4315-8726-d2432bf11868", - "lastStatus": "PENDING", - "name": "wordpress", - "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55" - } - ], - "desiredStatus": "RUNNING", - "lastStatus": "PENDING", - "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55", - "taskDefinitionArn": "arn:aws:ecs:us-east-1:012345678910:task-definition/hello_world:11" - } - ] -} - - -class TestECSOperator(unittest.TestCase): - - @mock.patch('airflow.contrib.operators.ecs_operator.AwsHook') - def setUp(self, aws_hook_mock): - configuration.load_test_config() - - self.aws_hook_mock = aws_hook_mock - self.ecs = ECSOperator( - task_id='task', - task_definition='t', - cluster='c', - overrides={}, - aws_conn_id=None, - region_name='eu-west-1') - - def test_init(self): - - self.assertEqual(self.ecs.region_name, 'eu-west-1') - self.assertEqual(self.ecs.task_definition, 't') - self.assertEqual(self.ecs.aws_conn_id, None) - self.assertEqual(self.ecs.cluster, 'c') - self.assertEqual(self.ecs.overrides, {}) - self.assertEqual(self.ecs.hook, self.aws_hook_mock.return_value) - - self.aws_hook_mock.assert_called_once_with(aws_conn_id=None) - - def test_template_fields_overrides(self): - self.assertEqual(self.ecs.template_fields, ('overrides',)) - - @mock.patch.object(ECSOperator, '_wait_for_task_ended') - @mock.patch.object(ECSOperator, '_check_success_task') - def test_execute_without_failures(self, check_mock, wait_mock): - - client_mock = self.aws_hook_mock.return_value.get_client_type.return_value - client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES - - self.ecs.execute(None) - - self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('ecs', region_name='eu-west-1') - client_mock.run_task.assert_called_once_with( - cluster='c', - overrides={}, - startedBy='Airflow', - taskDefinition='t' - ) - - wait_mock.assert_called_once_with() - check_mock.assert_called_once_with() - self.assertEqual(self.ecs.arn, 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55') - - def test_execute_with_failures(self): - - client_mock = self.aws_hook_mock.return_value.get_client_type.return_value - resp_failures = deepcopy(RESPONSE_WITHOUT_FAILURES) - resp_failures['failures'].append('dummy error') - client_mock.run_task.return_value = resp_failures - - with self.assertRaises(AirflowException): - self.ecs.execute(None) - - self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('ecs', region_name='eu-west-1') - client_mock.run_task.assert_called_once_with( - cluster='c', - overrides={}, - startedBy='Airflow', - taskDefinition='t' - ) - - def test_wait_end_tasks(self): - - client_mock = mock.Mock() - self.ecs.arn = 'arn' - self.ecs.client = client_mock - - self.ecs._wait_for_task_ended() - client_mock.get_waiter.assert_called_once_with('tasks_stopped') - client_mock.get_waiter.return_value.wait.assert_called_once_with(cluster='c', tasks=['arn']) - self.assertEquals(sys.maxint, client_mock.get_waiter.return_value.config.max_attempts) - - def test_check_success_tasks_raises(self): - client_mock = mock.Mock() - self.ecs.arn = 'arn' - self.ecs.client = client_mock - - client_mock.describe_tasks.return_value = { - 'tasks': [{ - 'containers': [{ - 'name': 'foo', - 'lastStatus': 'STOPPED', - 'exitCode': 1 - }] - }] - } - with self.assertRaises(Exception) as e: - self.ecs._check_success_task() - - self.assertEquals(str(e.exception), "This task is not in success state {'containers': [{'lastStatus': 'STOPPED', 'name': 'foo', 'exitCode': 1}]}") - client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) - - def test_check_success_tasks_raises_pending(self): - client_mock = mock.Mock() - self.ecs.client = client_mock - self.ecs.arn = 'arn' - client_mock.describe_tasks.return_value = { - 'tasks': [{ - 'containers': [{ - 'name': 'container-name', - 'lastStatus': 'PENDING' - }] - }] - } - with self.assertRaises(Exception) as e: - self.ecs._check_success_task() - self.assertEquals(str(e.exception), "This task is still pending {'containers': [{'lastStatus': 'PENDING', 'name': 'container-name'}]}") - client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) - - def test_check_success_tasks_raises_mutliple(self): - client_mock = mock.Mock() - self.ecs.client = client_mock - self.ecs.arn = 'arn' - client_mock.describe_tasks.return_value = { - 'tasks': [{ - 'containers': [{ - 'name': 'foo', - 'exitCode': 1 - }, { - 'name': 'bar', - 'lastStatus': 'STOPPED', - 'exitCode': 0 - }] - }] - } - self.ecs._check_success_task() - client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) - - def test_check_success_task_not_raises(self): - client_mock = mock.Mock() - self.ecs.client = client_mock - self.ecs.arn = 'arn' - client_mock.describe_tasks.return_value = { - 'tasks': [{ - 'containers': [{ - 'name': 'container-name', - 'lastStatus': 'STOPPED', - 'exitCode': 0 - }] - }] - } - self.ecs._check_success_task() - client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) - - -if __name__ == '__main__': - unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/emr_add_steps_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/emr_add_steps_operator.py b/tests/contrib/operators/emr_add_steps_operator.py deleted file mode 100644 index 37f9a4c..0000000 --- a/tests/contrib/operators/emr_add_steps_operator.py +++ /dev/null @@ -1,53 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from mock import MagicMock, patch - -from airflow import configuration -from airflow.contrib.operators.emr_add_steps_operator import EmrAddStepsOperator - -ADD_STEPS_SUCCESS_RETURN = { - 'ResponseMetadata': { - 'HTTPStatusCode': 200 - }, - 'StepIds': ['s-2LH3R5GW3A53T'] -} - - -class TestEmrAddStepsOperator(unittest.TestCase): - def setUp(self): - configuration.load_test_config() - - # Mock out the emr_client (moto has incorrect response) - mock_emr_client = MagicMock() - mock_emr_client.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN - - # Mock out the emr_client creator - self.boto3_client_mock = MagicMock(return_value=mock_emr_client) - - - def test_execute_adds_steps_to_the_job_flow_and_returns_step_ids(self): - with patch('boto3.client', self.boto3_client_mock): - - operator = EmrAddStepsOperator( - task_id='test_task', - job_flow_id='j-8989898989', - aws_conn_id='aws_default' - ) - - self.assertEqual(operator.execute(None), ['s-2LH3R5GW3A53T']) - -if __name__ == '__main__': - unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/emr_create_job_flow_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/emr_create_job_flow_operator.py b/tests/contrib/operators/emr_create_job_flow_operator.py deleted file mode 100644 index 4aa4cd2..0000000 --- a/tests/contrib/operators/emr_create_job_flow_operator.py +++ /dev/null @@ -1,53 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import unittest -from mock import MagicMock, patch - -from airflow import configuration -from airflow.contrib.operators.emr_create_job_flow_operator import EmrCreateJobFlowOperator - -RUN_JOB_FLOW_SUCCESS_RETURN = { - 'ResponseMetadata': { - 'HTTPStatusCode': 200 - }, - 'JobFlowId': 'j-8989898989' -} - -class TestEmrCreateJobFlowOperator(unittest.TestCase): - def setUp(self): - configuration.load_test_config() - - # Mock out the emr_client (moto has incorrect response) - mock_emr_client = MagicMock() - mock_emr_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN - - # Mock out the emr_client creator - self.boto3_client_mock = MagicMock(return_value=mock_emr_client) - - - def test_execute_uses_the_emr_config_to_create_a_cluster_and_returns_job_id(self): - with patch('boto3.client', self.boto3_client_mock): - - operator = EmrCreateJobFlowOperator( - task_id='test_task', - aws_conn_id='aws_default', - emr_conn_id='emr_default' - ) - - self.assertEqual(operator.execute(None), 'j-8989898989') - -if __name__ == '__main__': - unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/emr_terminate_job_flow_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/emr_terminate_job_flow_operator.py b/tests/contrib/operators/emr_terminate_job_flow_operator.py deleted file mode 100644 index 94c0124..0000000 --- a/tests/contrib/operators/emr_terminate_job_flow_operator.py +++ /dev/null @@ -1,52 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from mock import MagicMock, patch - -from airflow import configuration -from airflow.contrib.operators.emr_terminate_job_flow_operator import EmrTerminateJobFlowOperator - -TERMINATE_SUCCESS_RETURN = { - 'ResponseMetadata': { - 'HTTPStatusCode': 200 - } -} - - -class TestEmrTerminateJobFlowOperator(unittest.TestCase): - def setUp(self): - configuration.load_test_config() - - # Mock out the emr_client (moto has incorrect response) - mock_emr_client = MagicMock() - mock_emr_client.terminate_job_flows.return_value = TERMINATE_SUCCESS_RETURN - - # Mock out the emr_client creator - self.boto3_client_mock = MagicMock(return_value=mock_emr_client) - - - def test_execute_terminates_the_job_flow_and_does_not_error(self): - with patch('boto3.client', self.boto3_client_mock): - - operator = EmrTerminateJobFlowOperator( - task_id='test_task', - job_flow_id='j-8989898989', - aws_conn_id='aws_default' - ) - - operator.execute(None) - -if __name__ == '__main__': - unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/fs_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/fs_operator.py b/tests/contrib/operators/fs_operator.py deleted file mode 100644 index f990157..0000000 --- a/tests/contrib/operators/fs_operator.py +++ /dev/null @@ -1,64 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import unittest -from datetime import datetime - -from airflow import configuration -from airflow.settings import Session -from airflow import models, DAG -from airflow.contrib.operators.fs_operator import FileSensor - -TEST_DAG_ID = 'unit_tests' -DEFAULT_DATE = datetime(2015, 1, 1) -configuration.load_test_config() - - -def reset(dag_id=TEST_DAG_ID): - session = Session() - tis = session.query(models.TaskInstance).filter_by(dag_id=dag_id) - tis.delete() - session.commit() - session.close() - -reset() - -class FileSensorTest(unittest.TestCase): - def setUp(self): - configuration.load_test_config() - from airflow.contrib.hooks.fs_hook import FSHook - hook = FSHook() - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE, - 'provide_context': True - } - dag = DAG(TEST_DAG_ID+'test_schedule_dag_once', default_args=args) - dag.schedule_interval = '@once' - self.hook = hook - self.dag = dag - - def test_simple(self): - task = FileSensor( - task_id="test", - filepath="etc/hosts", - fs_conn_id='fs_default', - _hook=self.hook, - dag=self.dag, - ) - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - -if __name__ == '__main__': - unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/hipchat_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/hipchat_operator.py b/tests/contrib/operators/hipchat_operator.py deleted file mode 100644 index 65a2edd..0000000 --- a/tests/contrib/operators/hipchat_operator.py +++ /dev/null @@ -1,74 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -import requests - -from airflow.contrib.operators.hipchat_operator import \ - HipChatAPISendRoomNotificationOperator -from airflow.exceptions import AirflowException -from airflow import configuration - -try: - from unittest import mock -except ImportError: - try: - import mock - except ImportError: - mock = None - - -class HipChatOperatorTest(unittest.TestCase): - def setUp(self): - configuration.load_test_config() - - @unittest.skipIf(mock is None, 'mock package not present') - @mock.patch('requests.request') - def test_execute(self, request_mock): - resp = requests.Response() - resp.status_code = 200 - request_mock.return_value = resp - - operator = HipChatAPISendRoomNotificationOperator( - task_id='test_hipchat_success', - owner = 'airflow', - token='abc123', - room_id='room_id', - message='hello world!' - ) - - operator.execute(None) - - @unittest.skipIf(mock is None, 'mock package not present') - @mock.patch('requests.request') - def test_execute_error_response(self, request_mock): - resp = requests.Response() - resp.status_code = 404 - resp.reason = 'Not Found' - request_mock.return_value = resp - - operator = HipChatAPISendRoomNotificationOperator( - task_id='test_hipchat_failure', - owner='airflow', - token='abc123', - room_id='room_id', - message='hello world!' - ) - - with self.assertRaises(AirflowException): - operator.execute(None) - - -if __name__ == '__main__': - unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/jira_operator_test.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/jira_operator_test.py b/tests/contrib/operators/jira_operator_test.py deleted file mode 100644 index 6d615df..0000000 --- a/tests/contrib/operators/jira_operator_test.py +++ /dev/null @@ -1,101 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import unittest -import datetime -from mock import Mock -from mock import patch - -from airflow import DAG, configuration -from airflow.contrib.operators.jira_operator import JiraOperator -from airflow import models -from airflow.utils import db - -DEFAULT_DATE = datetime.datetime(2017, 1, 1) -jira_client_mock = Mock( - name="jira_client_for_test" -) - -minimal_test_ticket = { - "id": "911539", - "self": "https://sandbox.localhost/jira/rest/api/2/issue/911539", - "key": "TEST-1226", - "fields": { - "labels": [ - "test-label-1", - "test-label-2" - ], - "description": "this is a test description", - } -} - - -class TestJiraOperator(unittest.TestCase): - def setUp(self): - configuration.load_test_config() - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } - dag = DAG('test_dag_id', default_args=args) - self.dag = dag - db.merge_conn( - models.Connection( - conn_id='jira_default', conn_type='jira', - host='https://localhost/jira/', port=443, - extra='{"verify": "False", "project": "AIRFLOW"}')) - - @patch("airflow.contrib.hooks.jira_hook.JIRA", - autospec=True, return_value=jira_client_mock) - def test_issue_search(self, jira_mock): - jql_str = 'issuekey=TEST-1226' - jira_mock.return_value.search_issues.return_value = minimal_test_ticket - - jira_ticket_search_operator = JiraOperator(task_id='search-ticket-test', - jira_method="search_issues", - jira_method_args={ - 'jql_str': jql_str, - 'maxResults': '1' - }, - dag=self.dag) - - jira_ticket_search_operator.run(start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, ignore_ti_state=True) - - self.assertTrue(jira_mock.called) - self.assertTrue(jira_mock.return_value.search_issues.called) - - @patch("airflow.contrib.hooks.jira_hook.JIRA", - autospec=True, return_value=jira_client_mock) - def test_update_issue(self, jira_mock): - jira_mock.return_value.add_comment.return_value = True - - add_comment_operator = JiraOperator(task_id='add_comment_test', - jira_method="add_comment", - jira_method_args={ - 'issue': minimal_test_ticket.get("key"), - 'body': 'this is test comment' - }, - dag=self.dag) - - add_comment_operator.run(start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, ignore_ti_state=True) - - self.assertTrue(jira_mock.called) - self.assertTrue(jira_mock.return_value.add_comment.called) - - -if __name__ == '__main__': - unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/spark_submit_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/spark_submit_operator.py b/tests/contrib/operators/spark_submit_operator.py deleted file mode 100644 index 4e2afb2..0000000 --- a/tests/contrib/operators/spark_submit_operator.py +++ /dev/null @@ -1,81 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import unittest -import datetime - -from airflow import DAG, configuration -from airflow.contrib.operators.spark_submit_operator import SparkSubmitOperator - -DEFAULT_DATE = datetime.datetime(2017, 1, 1) - - -class TestSparkSubmitOperator(unittest.TestCase): - _config = { - 'conf': { - 'parquet.compression': 'SNAPPY' - }, - 'files': 'hive-site.xml', - 'py_files': 'sample_library.py', - 'jars': 'parquet.jar', - 'executor_cores': 4, - 'executor_memory': '22g', - 'keytab': 'privileged_user.keytab', - 'principal': 'user/[email protected]', - 'name': 'spark-job', - 'num_executors': 10, - 'verbose': True, - 'application': 'test_application.py', - 'driver_memory': '3g', - 'java_class': 'com.foo.bar.AppMain' - } - - def setUp(self): - configuration.load_test_config() - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } - self.dag = DAG('test_dag_id', default_args=args) - - def test_execute(self, conn_id='spark_default'): - operator = SparkSubmitOperator( - task_id='spark_submit_job', - dag=self.dag, - **self._config - ) - - self.assertEqual(conn_id, operator._conn_id) - - self.assertEqual(self._config['application'], operator._application) - self.assertEqual(self._config['conf'], operator._conf) - self.assertEqual(self._config['files'], operator._files) - self.assertEqual(self._config['py_files'], operator._py_files) - self.assertEqual(self._config['jars'], operator._jars) - self.assertEqual(self._config['executor_cores'], operator._executor_cores) - self.assertEqual(self._config['executor_memory'], operator._executor_memory) - self.assertEqual(self._config['keytab'], operator._keytab) - self.assertEqual(self._config['principal'], operator._principal) - self.assertEqual(self._config['name'], operator._name) - self.assertEqual(self._config['num_executors'], operator._num_executors) - self.assertEqual(self._config['verbose'], operator._verbose) - self.assertEqual(self._config['java_class'], operator._java_class) - self.assertEqual(self._config['driver_memory'], operator._driver_memory) - - - - -if __name__ == '__main__': - unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/sqoop_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/sqoop_operator.py b/tests/contrib/operators/sqoop_operator.py deleted file mode 100644 index a46dc93..0000000 --- a/tests/contrib/operators/sqoop_operator.py +++ /dev/null @@ -1,93 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import datetime -import unittest - -from airflow import DAG, configuration -from airflow.contrib.operators.sqoop_operator import SqoopOperator - - -class TestSqoopOperator(unittest.TestCase): - _config = { - 'cmd_type': 'export', - 'table': 'target_table', - 'query': 'SELECT * FROM schema.table', - 'target_dir': '/path/on/hdfs/to/import', - 'append': True, - 'file_type': 'avro', - 'columns': 'a,b,c', - 'num_mappers': 22, - 'split_by': 'id', - 'export_dir': '/path/on/hdfs/to/export', - 'input_null_string': '\n', - 'input_null_non_string': '\t', - 'staging_table': 'target_table_staging', - 'clear_staging_table': True, - 'enclosed_by': '"', - 'escaped_by': '\\', - 'input_fields_terminated_by': '|', - 'input_lines_terminated_by': '\n', - 'input_optionally_enclosed_by': '"', - 'batch': True, - 'relaxed_isolation': True, - 'direct': True, - 'driver': 'com.microsoft.jdbc.sqlserver.SQLServerDriver', - 'properties': { - 'mapred.map.max.attempts': '1' - } - } - - def setUp(self): - configuration.load_test_config() - args = { - 'owner': 'airflow', - 'start_date': datetime.datetime(2017, 1, 1) - } - self.dag = DAG('test_dag_id', default_args=args) - - def test_execute(self, conn_id='sqoop_default'): - operator = SqoopOperator( - task_id='sqoop_job', - dag=self.dag, - **self._config - ) - - self.assertEqual(conn_id, operator.conn_id) - - self.assertEqual(self._config['cmd_type'], operator.cmd_type) - self.assertEqual(self._config['table'], operator.table) - self.assertEqual(self._config['target_dir'], operator.target_dir) - self.assertEqual(self._config['append'], operator.append) - self.assertEqual(self._config['file_type'], operator.file_type) - self.assertEqual(self._config['num_mappers'], operator.num_mappers) - self.assertEqual(self._config['split_by'], operator.split_by) - self.assertEqual(self._config['input_null_string'], - operator.input_null_string) - self.assertEqual(self._config['input_null_non_string'], - operator.input_null_non_string) - self.assertEqual(self._config['staging_table'], operator.staging_table) - self.assertEqual(self._config['clear_staging_table'], - operator.clear_staging_table) - self.assertEqual(self._config['batch'], operator.batch) - self.assertEqual(self._config['relaxed_isolation'], - operator.relaxed_isolation) - self.assertEqual(self._config['direct'], operator.direct) - self.assertEqual(self._config['driver'], operator.driver) - self.assertEqual(self._config['properties'], operator.properties) - - -if __name__ == '__main__': - unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/ssh_execute_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/ssh_execute_operator.py b/tests/contrib/operators/ssh_execute_operator.py deleted file mode 100644 index ef8162c..0000000 --- a/tests/contrib/operators/ssh_execute_operator.py +++ /dev/null @@ -1,79 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -import os -from datetime import datetime - -from airflow import configuration -from airflow.settings import Session -from airflow import models, DAG -from airflow.contrib.operators.ssh_execute_operator import SSHExecuteOperator - - -TEST_DAG_ID = 'unit_tests' -DEFAULT_DATE = datetime(2015, 1, 1) -configuration.load_test_config() - - -def reset(dag_id=TEST_DAG_ID): - session = Session() - tis = session.query(models.TaskInstance).filter_by(dag_id=dag_id) - tis.delete() - session.commit() - session.close() - -reset() - - -class SSHExecuteOperatorTest(unittest.TestCase): - def setUp(self): - configuration.load_test_config() - from airflow.contrib.hooks.ssh_hook import SSHHook - hook = SSHHook() - hook.no_host_key_check = True - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE, - 'provide_context': True - } - dag = DAG(TEST_DAG_ID+'test_schedule_dag_once', default_args=args) - dag.schedule_interval = '@once' - self.hook = hook - self.dag = dag - - def test_simple(self): - task = SSHExecuteOperator( - task_id="test", - bash_command="echo airflow", - ssh_hook=self.hook, - dag=self.dag, - ) - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - - def test_with_env(self): - test_env = os.environ.copy() - test_env['AIRFLOW_test'] = "test" - task = SSHExecuteOperator( - task_id="test", - bash_command="echo $AIRFLOW_HOME", - ssh_hook=self.hook, - env=test_env, - dag=self.dag, - ) - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - - -if __name__ == '__main__': - unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_databricks_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_databricks_operator.py b/tests/contrib/operators/test_databricks_operator.py new file mode 100644 index 0000000..aab47fa --- /dev/null +++ b/tests/contrib/operators/test_databricks_operator.py @@ -0,0 +1,185 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from airflow.contrib.hooks.databricks_hook import RunState +from airflow.contrib.operators.databricks_operator import DatabricksSubmitRunOperator +from airflow.exceptions import AirflowException + +try: + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None + +TASK_ID = 'databricks-operator' +DEFAULT_CONN_ID = 'databricks_default' +NOTEBOOK_TASK = { + 'notebook_path': '/test' +} +SPARK_JAR_TASK = { + 'main_class_name': 'com.databricks.Test' +} +NEW_CLUSTER = { + 'spark_version': '2.0.x-scala2.10', + 'node_type_id': 'development-node', + 'num_workers': 1 +} +EXISTING_CLUSTER_ID = 'existing-cluster-id' +RUN_NAME = 'run-name' +RUN_ID = 1 + + +class DatabricksSubmitRunOperatorTest(unittest.TestCase): + def test_init_with_named_parameters(self): + """ + Test the initializer with the named parameters. + """ + op = DatabricksSubmitRunOperator(task_id=TASK_ID, new_cluster=NEW_CLUSTER, notebook_task=NOTEBOOK_TASK) + expected = { + 'new_cluster': NEW_CLUSTER, + 'notebook_task': NOTEBOOK_TASK, + 'run_name': TASK_ID + } + self.assertDictEqual(expected, op.json) + + def test_init_with_json(self): + """ + Test the initializer with json data. + """ + json = { + 'new_cluster': NEW_CLUSTER, + 'notebook_task': NOTEBOOK_TASK + } + op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) + expected = { + 'new_cluster': NEW_CLUSTER, + 'notebook_task': NOTEBOOK_TASK, + 'run_name': TASK_ID + } + self.assertDictEqual(expected, op.json) + + def test_init_with_specified_run_name(self): + """ + Test the initializer with a specified run_name. + """ + json = { + 'new_cluster': NEW_CLUSTER, + 'notebook_task': NOTEBOOK_TASK, + 'run_name': RUN_NAME + } + op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) + expected = { + 'new_cluster': NEW_CLUSTER, + 'notebook_task': NOTEBOOK_TASK, + 'run_name': RUN_NAME + } + self.assertDictEqual(expected, op.json) + + def test_init_with_merging(self): + """ + Test the initializer when json and other named parameters are both + provided. The named parameters should override top level keys in the + json dict. + """ + override_new_cluster = {'workers': 999} + json = { + 'new_cluster': NEW_CLUSTER, + 'notebook_task': NOTEBOOK_TASK, + } + op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json, new_cluster=override_new_cluster) + expected = { + 'new_cluster': override_new_cluster, + 'notebook_task': NOTEBOOK_TASK, + 'run_name': TASK_ID, + } + self.assertDictEqual(expected, op.json) + + @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook') + def test_exec_success(self, db_mock_class): + """ + Test the execute function in case where the run is successful. + """ + run = { + 'new_cluster': NEW_CLUSTER, + 'notebook_task': NOTEBOOK_TASK, + } + op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run) + db_mock = db_mock_class.return_value + db_mock.submit_run.return_value = 1 + db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '') + + op.execute(None) + + expected = { + 'new_cluster': NEW_CLUSTER, + 'notebook_task': NOTEBOOK_TASK, + 'run_name': TASK_ID + } + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit) + db_mock.submit_run.assert_called_once_with(expected) + db_mock.get_run_page_url.assert_called_once_with(RUN_ID) + db_mock.get_run_state.assert_called_once_with(RUN_ID) + self.assertEquals(RUN_ID, op.run_id) + + @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook') + def test_exec_failure(self, db_mock_class): + """ + Test the execute function in case where the run failed. + """ + run = { + 'new_cluster': NEW_CLUSTER, + 'notebook_task': NOTEBOOK_TASK, + } + op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run) + db_mock = db_mock_class.return_value + db_mock.submit_run.return_value = 1 + db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED', '') + + with self.assertRaises(AirflowException): + op.execute(None) + + expected = { + 'new_cluster': NEW_CLUSTER, + 'notebook_task': NOTEBOOK_TASK, + 'run_name': TASK_ID, + } + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit) + db_mock.submit_run.assert_called_once_with(expected) + db_mock.get_run_page_url.assert_called_once_with(RUN_ID) + db_mock.get_run_state.assert_called_once_with(RUN_ID) + self.assertEquals(RUN_ID, op.run_id) + + @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook') + def test_on_kill(self, db_mock_class): + run = { + 'new_cluster': NEW_CLUSTER, + 'notebook_task': NOTEBOOK_TASK, + } + op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run) + db_mock = db_mock_class.return_value + op.run_id = RUN_ID + + op.on_kill() + + db_mock.cancel_run.assert_called_once_with(RUN_ID) + http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_dataflow_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_dataflow_operator.py b/tests/contrib/operators/test_dataflow_operator.py new file mode 100644 index 0000000..0423616 --- /dev/null +++ b/tests/contrib/operators/test_dataflow_operator.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator + +try: + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None + + +TASK_ID = 'test-python-dataflow' +PY_FILE = 'gs://my-bucket/my-object.py' +PY_OPTIONS = ['-m'] +DEFAULT_OPTIONS = { + 'project': 'test', + 'stagingLocation': 'gs://test/staging' +} +ADDITIONAL_OPTIONS = { + 'output': 'gs://test/output' +} +GCS_HOOK_STRING = 'airflow.contrib.operators.dataflow_operator.{}' + + +class DataFlowPythonOperatorTest(unittest.TestCase): + + def setUp(self): + self.dataflow = DataFlowPythonOperator( + task_id=TASK_ID, + py_file=PY_FILE, + py_options=PY_OPTIONS, + dataflow_default_options=DEFAULT_OPTIONS, + options=ADDITIONAL_OPTIONS) + + def test_init(self): + """Test DataFlowPythonOperator instance is properly initialized.""" + self.assertEqual(self.dataflow.task_id, TASK_ID) + self.assertEqual(self.dataflow.py_file, PY_FILE) + self.assertEqual(self.dataflow.py_options, PY_OPTIONS) + self.assertEqual(self.dataflow.dataflow_default_options, + DEFAULT_OPTIONS) + self.assertEqual(self.dataflow.options, + ADDITIONAL_OPTIONS) + + @mock.patch('airflow.contrib.operators.dataflow_operator.DataFlowHook') + @mock.patch(GCS_HOOK_STRING.format('GoogleCloudBucketHelper')) + def test_exec(self, gcs_hook, dataflow_mock): + """Test DataFlowHook is created and the right args are passed to + start_python_workflow. + + """ + start_python_hook = dataflow_mock.return_value.start_python_dataflow + gcs_download_hook = gcs_hook.return_value.google_cloud_to_local + self.dataflow.execute(None) + self.assertTrue(dataflow_mock.called) + expected_options = { + 'project': 'test', + 'staging_location': 'gs://test/staging', + 'output': 'gs://test/output' + } + gcs_download_hook.assert_called_once_with(PY_FILE) + start_python_hook.assert_called_once_with(TASK_ID, expected_options, + mock.ANY, PY_OPTIONS) + self.assertTrue(self.dataflow.py_file.startswith('/tmp/dataflow')) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_ecs_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_ecs_operator.py b/tests/contrib/operators/test_ecs_operator.py new file mode 100644 index 0000000..80dedd3 --- /dev/null +++ b/tests/contrib/operators/test_ecs_operator.py @@ -0,0 +1,214 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import unittest +from copy import deepcopy + +from airflow import configuration +from airflow.exceptions import AirflowException +from airflow.contrib.operators.ecs_operator import ECSOperator + +try: + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None + + +RESPONSE_WITHOUT_FAILURES = { + "failures": [], + "tasks": [ + { + "containers": [ + { + "containerArn": "arn:aws:ecs:us-east-1:012345678910:container/e1ed7aac-d9b2-4315-8726-d2432bf11868", + "lastStatus": "PENDING", + "name": "wordpress", + "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55" + } + ], + "desiredStatus": "RUNNING", + "lastStatus": "PENDING", + "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55", + "taskDefinitionArn": "arn:aws:ecs:us-east-1:012345678910:task-definition/hello_world:11" + } + ] +} + + +class TestECSOperator(unittest.TestCase): + + @mock.patch('airflow.contrib.operators.ecs_operator.AwsHook') + def setUp(self, aws_hook_mock): + configuration.load_test_config() + + self.aws_hook_mock = aws_hook_mock + self.ecs = ECSOperator( + task_id='task', + task_definition='t', + cluster='c', + overrides={}, + aws_conn_id=None, + region_name='eu-west-1') + + def test_init(self): + + self.assertEqual(self.ecs.region_name, 'eu-west-1') + self.assertEqual(self.ecs.task_definition, 't') + self.assertEqual(self.ecs.aws_conn_id, None) + self.assertEqual(self.ecs.cluster, 'c') + self.assertEqual(self.ecs.overrides, {}) + self.assertEqual(self.ecs.hook, self.aws_hook_mock.return_value) + + self.aws_hook_mock.assert_called_once_with(aws_conn_id=None) + + def test_template_fields_overrides(self): + self.assertEqual(self.ecs.template_fields, ('overrides',)) + + @mock.patch.object(ECSOperator, '_wait_for_task_ended') + @mock.patch.object(ECSOperator, '_check_success_task') + def test_execute_without_failures(self, check_mock, wait_mock): + + client_mock = self.aws_hook_mock.return_value.get_client_type.return_value + client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES + + self.ecs.execute(None) + + self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('ecs', region_name='eu-west-1') + client_mock.run_task.assert_called_once_with( + cluster='c', + overrides={}, + startedBy=mock.ANY, # Can by 'airflow' or 'Airflow' + taskDefinition='t' + ) + + wait_mock.assert_called_once_with() + check_mock.assert_called_once_with() + self.assertEqual(self.ecs.arn, 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55') + + def test_execute_with_failures(self): + + client_mock = self.aws_hook_mock.return_value.get_client_type.return_value + resp_failures = deepcopy(RESPONSE_WITHOUT_FAILURES) + resp_failures['failures'].append('dummy error') + client_mock.run_task.return_value = resp_failures + + with self.assertRaises(AirflowException): + self.ecs.execute(None) + + self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('ecs', region_name='eu-west-1') + client_mock.run_task.assert_called_once_with( + cluster='c', + overrides={}, + startedBy=mock.ANY, # Can by 'airflow' or 'Airflow' + taskDefinition='t' + ) + + def test_wait_end_tasks(self): + + client_mock = mock.Mock() + self.ecs.arn = 'arn' + self.ecs.client = client_mock + + self.ecs._wait_for_task_ended() + client_mock.get_waiter.assert_called_once_with('tasks_stopped') + client_mock.get_waiter.return_value.wait.assert_called_once_with(cluster='c', tasks=['arn']) + self.assertEquals(sys.maxsize, client_mock.get_waiter.return_value.config.max_attempts) + + def test_check_success_tasks_raises(self): + client_mock = mock.Mock() + self.ecs.arn = 'arn' + self.ecs.client = client_mock + + client_mock.describe_tasks.return_value = { + 'tasks': [{ + 'containers': [{ + 'name': 'foo', + 'lastStatus': 'STOPPED', + 'exitCode': 1 + }] + }] + } + with self.assertRaises(Exception) as e: + self.ecs._check_success_task() + + # Ordering of str(dict) is not guaranteed. + self.assertIn("This task is not in success state ", str(e.exception)) + self.assertIn("'name': 'foo'", str(e.exception)) + self.assertIn("'lastStatus': 'STOPPED'", str(e.exception)) + self.assertIn("'exitCode': 1", str(e.exception)) + client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) + + def test_check_success_tasks_raises_pending(self): + client_mock = mock.Mock() + self.ecs.client = client_mock + self.ecs.arn = 'arn' + client_mock.describe_tasks.return_value = { + 'tasks': [{ + 'containers': [{ + 'name': 'container-name', + 'lastStatus': 'PENDING' + }] + }] + } + with self.assertRaises(Exception) as e: + self.ecs._check_success_task() + # Ordering of str(dict) is not guaranteed. + self.assertIn("This task is still pending ", str(e.exception)) + self.assertIn("'name': 'container-name'", str(e.exception)) + self.assertIn("'lastStatus': 'PENDING'", str(e.exception)) + client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) + + def test_check_success_tasks_raises_mutliple(self): + client_mock = mock.Mock() + self.ecs.client = client_mock + self.ecs.arn = 'arn' + client_mock.describe_tasks.return_value = { + 'tasks': [{ + 'containers': [{ + 'name': 'foo', + 'exitCode': 1 + }, { + 'name': 'bar', + 'lastStatus': 'STOPPED', + 'exitCode': 0 + }] + }] + } + self.ecs._check_success_task() + client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) + + def test_check_success_task_not_raises(self): + client_mock = mock.Mock() + self.ecs.client = client_mock + self.ecs.arn = 'arn' + client_mock.describe_tasks.return_value = { + 'tasks': [{ + 'containers': [{ + 'name': 'container-name', + 'lastStatus': 'STOPPED', + 'exitCode': 0 + }] + }] + } + self.ecs._check_success_task() + client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) + + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_emr_add_steps_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_emr_add_steps_operator.py b/tests/contrib/operators/test_emr_add_steps_operator.py new file mode 100644 index 0000000..37f9a4c --- /dev/null +++ b/tests/contrib/operators/test_emr_add_steps_operator.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from mock import MagicMock, patch + +from airflow import configuration +from airflow.contrib.operators.emr_add_steps_operator import EmrAddStepsOperator + +ADD_STEPS_SUCCESS_RETURN = { + 'ResponseMetadata': { + 'HTTPStatusCode': 200 + }, + 'StepIds': ['s-2LH3R5GW3A53T'] +} + + +class TestEmrAddStepsOperator(unittest.TestCase): + def setUp(self): + configuration.load_test_config() + + # Mock out the emr_client (moto has incorrect response) + mock_emr_client = MagicMock() + mock_emr_client.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN + + # Mock out the emr_client creator + self.boto3_client_mock = MagicMock(return_value=mock_emr_client) + + + def test_execute_adds_steps_to_the_job_flow_and_returns_step_ids(self): + with patch('boto3.client', self.boto3_client_mock): + + operator = EmrAddStepsOperator( + task_id='test_task', + job_flow_id='j-8989898989', + aws_conn_id='aws_default' + ) + + self.assertEqual(operator.execute(None), ['s-2LH3R5GW3A53T']) + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_emr_create_job_flow_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_emr_create_job_flow_operator.py b/tests/contrib/operators/test_emr_create_job_flow_operator.py new file mode 100644 index 0000000..4aa4cd2 --- /dev/null +++ b/tests/contrib/operators/test_emr_create_job_flow_operator.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +from mock import MagicMock, patch + +from airflow import configuration +from airflow.contrib.operators.emr_create_job_flow_operator import EmrCreateJobFlowOperator + +RUN_JOB_FLOW_SUCCESS_RETURN = { + 'ResponseMetadata': { + 'HTTPStatusCode': 200 + }, + 'JobFlowId': 'j-8989898989' +} + +class TestEmrCreateJobFlowOperator(unittest.TestCase): + def setUp(self): + configuration.load_test_config() + + # Mock out the emr_client (moto has incorrect response) + mock_emr_client = MagicMock() + mock_emr_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN + + # Mock out the emr_client creator + self.boto3_client_mock = MagicMock(return_value=mock_emr_client) + + + def test_execute_uses_the_emr_config_to_create_a_cluster_and_returns_job_id(self): + with patch('boto3.client', self.boto3_client_mock): + + operator = EmrCreateJobFlowOperator( + task_id='test_task', + aws_conn_id='aws_default', + emr_conn_id='emr_default' + ) + + self.assertEqual(operator.execute(None), 'j-8989898989') + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_emr_terminate_job_flow_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_emr_terminate_job_flow_operator.py b/tests/contrib/operators/test_emr_terminate_job_flow_operator.py new file mode 100644 index 0000000..94c0124 --- /dev/null +++ b/tests/contrib/operators/test_emr_terminate_job_flow_operator.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from mock import MagicMock, patch + +from airflow import configuration +from airflow.contrib.operators.emr_terminate_job_flow_operator import EmrTerminateJobFlowOperator + +TERMINATE_SUCCESS_RETURN = { + 'ResponseMetadata': { + 'HTTPStatusCode': 200 + } +} + + +class TestEmrTerminateJobFlowOperator(unittest.TestCase): + def setUp(self): + configuration.load_test_config() + + # Mock out the emr_client (moto has incorrect response) + mock_emr_client = MagicMock() + mock_emr_client.terminate_job_flows.return_value = TERMINATE_SUCCESS_RETURN + + # Mock out the emr_client creator + self.boto3_client_mock = MagicMock(return_value=mock_emr_client) + + + def test_execute_terminates_the_job_flow_and_does_not_error(self): + with patch('boto3.client', self.boto3_client_mock): + + operator = EmrTerminateJobFlowOperator( + task_id='test_task', + job_flow_id='j-8989898989', + aws_conn_id='aws_default' + ) + + operator.execute(None) + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_fs_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_fs_operator.py b/tests/contrib/operators/test_fs_operator.py new file mode 100644 index 0000000..f990157 --- /dev/null +++ b/tests/contrib/operators/test_fs_operator.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +from datetime import datetime + +from airflow import configuration +from airflow.settings import Session +from airflow import models, DAG +from airflow.contrib.operators.fs_operator import FileSensor + +TEST_DAG_ID = 'unit_tests' +DEFAULT_DATE = datetime(2015, 1, 1) +configuration.load_test_config() + + +def reset(dag_id=TEST_DAG_ID): + session = Session() + tis = session.query(models.TaskInstance).filter_by(dag_id=dag_id) + tis.delete() + session.commit() + session.close() + +reset() + +class FileSensorTest(unittest.TestCase): + def setUp(self): + configuration.load_test_config() + from airflow.contrib.hooks.fs_hook import FSHook + hook = FSHook() + args = { + 'owner': 'airflow', + 'start_date': DEFAULT_DATE, + 'provide_context': True + } + dag = DAG(TEST_DAG_ID+'test_schedule_dag_once', default_args=args) + dag.schedule_interval = '@once' + self.hook = hook + self.dag = dag + + def test_simple(self): + task = FileSensor( + task_id="test", + filepath="etc/hosts", + fs_conn_id='fs_default', + _hook=self.hook, + dag=self.dag, + ) + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_hipchat_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_hipchat_operator.py b/tests/contrib/operators/test_hipchat_operator.py new file mode 100644 index 0000000..65a2edd --- /dev/null +++ b/tests/contrib/operators/test_hipchat_operator.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import requests + +from airflow.contrib.operators.hipchat_operator import \ + HipChatAPISendRoomNotificationOperator +from airflow.exceptions import AirflowException +from airflow import configuration + +try: + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None + + +class HipChatOperatorTest(unittest.TestCase): + def setUp(self): + configuration.load_test_config() + + @unittest.skipIf(mock is None, 'mock package not present') + @mock.patch('requests.request') + def test_execute(self, request_mock): + resp = requests.Response() + resp.status_code = 200 + request_mock.return_value = resp + + operator = HipChatAPISendRoomNotificationOperator( + task_id='test_hipchat_success', + owner = 'airflow', + token='abc123', + room_id='room_id', + message='hello world!' + ) + + operator.execute(None) + + @unittest.skipIf(mock is None, 'mock package not present') + @mock.patch('requests.request') + def test_execute_error_response(self, request_mock): + resp = requests.Response() + resp.status_code = 404 + resp.reason = 'Not Found' + request_mock.return_value = resp + + operator = HipChatAPISendRoomNotificationOperator( + task_id='test_hipchat_failure', + owner='airflow', + token='abc123', + room_id='room_id', + message='hello world!' + ) + + with self.assertRaises(AirflowException): + operator.execute(None) + + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_jira_operator_test.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_jira_operator_test.py b/tests/contrib/operators/test_jira_operator_test.py new file mode 100644 index 0000000..6d615df --- /dev/null +++ b/tests/contrib/operators/test_jira_operator_test.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import datetime +from mock import Mock +from mock import patch + +from airflow import DAG, configuration +from airflow.contrib.operators.jira_operator import JiraOperator +from airflow import models +from airflow.utils import db + +DEFAULT_DATE = datetime.datetime(2017, 1, 1) +jira_client_mock = Mock( + name="jira_client_for_test" +) + +minimal_test_ticket = { + "id": "911539", + "self": "https://sandbox.localhost/jira/rest/api/2/issue/911539", + "key": "TEST-1226", + "fields": { + "labels": [ + "test-label-1", + "test-label-2" + ], + "description": "this is a test description", + } +} + + +class TestJiraOperator(unittest.TestCase): + def setUp(self): + configuration.load_test_config() + args = { + 'owner': 'airflow', + 'start_date': DEFAULT_DATE + } + dag = DAG('test_dag_id', default_args=args) + self.dag = dag + db.merge_conn( + models.Connection( + conn_id='jira_default', conn_type='jira', + host='https://localhost/jira/', port=443, + extra='{"verify": "False", "project": "AIRFLOW"}')) + + @patch("airflow.contrib.hooks.jira_hook.JIRA", + autospec=True, return_value=jira_client_mock) + def test_issue_search(self, jira_mock): + jql_str = 'issuekey=TEST-1226' + jira_mock.return_value.search_issues.return_value = minimal_test_ticket + + jira_ticket_search_operator = JiraOperator(task_id='search-ticket-test', + jira_method="search_issues", + jira_method_args={ + 'jql_str': jql_str, + 'maxResults': '1' + }, + dag=self.dag) + + jira_ticket_search_operator.run(start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE, ignore_ti_state=True) + + self.assertTrue(jira_mock.called) + self.assertTrue(jira_mock.return_value.search_issues.called) + + @patch("airflow.contrib.hooks.jira_hook.JIRA", + autospec=True, return_value=jira_client_mock) + def test_update_issue(self, jira_mock): + jira_mock.return_value.add_comment.return_value = True + + add_comment_operator = JiraOperator(task_id='add_comment_test', + jira_method="add_comment", + jira_method_args={ + 'issue': minimal_test_ticket.get("key"), + 'body': 'this is test comment' + }, + dag=self.dag) + + add_comment_operator.run(start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE, ignore_ti_state=True) + + self.assertTrue(jira_mock.called) + self.assertTrue(jira_mock.return_value.add_comment.called) + + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_spark_submit_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_spark_submit_operator.py b/tests/contrib/operators/test_spark_submit_operator.py new file mode 100644 index 0000000..3c11dbb --- /dev/null +++ b/tests/contrib/operators/test_spark_submit_operator.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import datetime +import sys + +from airflow import DAG, configuration +from airflow.contrib.operators.spark_submit_operator import SparkSubmitOperator + +DEFAULT_DATE = datetime.datetime(2017, 1, 1) + + +class TestSparkSubmitOperator(unittest.TestCase): + + _config = { + 'conf': { + 'parquet.compression': 'SNAPPY' + }, + 'files': 'hive-site.xml', + 'py_files': 'sample_library.py', + 'jars': 'parquet.jar', + 'executor_cores': 4, + 'executor_memory': '22g', + 'keytab': 'privileged_user.keytab', + 'principal': 'user/[email protected]', + 'name': 'spark-job', + 'num_executors': 10, + 'verbose': True, + 'application': 'test_application.py', + 'driver_memory': '3g', + 'java_class': 'com.foo.bar.AppMain' + } + + def setUp(self): + + if sys.version_info[0] == 3: + raise unittest.SkipTest('TestSparkSubmitOperator won\'t work with ' + 'python3. No need to test anything here') + + configuration.load_test_config() + args = { + 'owner': 'airflow', + 'start_date': DEFAULT_DATE + } + self.dag = DAG('test_dag_id', default_args=args) + + def test_execute(self, conn_id='spark_default'): + operator = SparkSubmitOperator( + task_id='spark_submit_job', + dag=self.dag, + **self._config + ) + + self.assertEqual(conn_id, operator._conn_id) + + self.assertEqual(self._config['application'], operator._application) + self.assertEqual(self._config['conf'], operator._conf) + self.assertEqual(self._config['files'], operator._files) + self.assertEqual(self._config['py_files'], operator._py_files) + self.assertEqual(self._config['jars'], operator._jars) + self.assertEqual(self._config['executor_cores'], operator._executor_cores) + self.assertEqual(self._config['executor_memory'], operator._executor_memory) + self.assertEqual(self._config['keytab'], operator._keytab) + self.assertEqual(self._config['principal'], operator._principal) + self.assertEqual(self._config['name'], operator._name) + self.assertEqual(self._config['num_executors'], operator._num_executors) + self.assertEqual(self._config['verbose'], operator._verbose) + self.assertEqual(self._config['java_class'], operator._java_class) + self.assertEqual(self._config['driver_memory'], operator._driver_memory) + + + + +if __name__ == '__main__': + unittest.main()
