[AIRFLOW-1094] Run unit tests under contrib in Travis

Rename all unit tests under tests/contrib to start with test_* and fix broken unit tests so that they run for the Python 2 and 3 builds.
Closes #2234 from hgrif/AIRFLOW-1094 Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/219c5064 Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/219c5064 Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/219c5064 Branch: refs/heads/master Commit: 219c5064142c66cf8f051455199f2dda9b164584 Parents: 74c1ce2 Author: Henk Griffioen <[email protected]> Authored: Mon Apr 17 10:04:29 2017 +0200 Committer: Bolke de Bruin <[email protected]> Committed: Mon Apr 17 10:04:36 2017 +0200 ---------------------------------------------------------------------- airflow/contrib/operators/ecs_operator.py | 2 +- airflow/hooks/__init__.py | 1 + airflow/hooks/zendesk_hook.py | 2 +- scripts/ci/requirements.txt | 5 + tests/contrib/hooks/aws_hook.py | 47 ---- tests/contrib/hooks/bigquery_hook.py | 139 ---------- tests/contrib/hooks/databricks_hook.py | 226 ----------------- tests/contrib/hooks/emr_hook.py | 53 ---- tests/contrib/hooks/gcp_dataflow_hook.py | 56 ----- tests/contrib/hooks/spark_submit_hook.py | 185 -------------- tests/contrib/hooks/sqoop_hook.py | 219 ---------------- tests/contrib/hooks/test_aws_hook.py | 47 ++++ tests/contrib/hooks/test_bigquery_hook.py | 139 ++++++++++ tests/contrib/hooks/test_databricks_hook.py | 226 +++++++++++++++++ tests/contrib/hooks/test_emr_hook.py | 53 ++++ tests/contrib/hooks/test_gcp_dataflow_hook.py | 56 +++++ tests/contrib/hooks/test_spark_submit_hook.py | 197 +++++++++++++++ tests/contrib/hooks/test_sqoop_hook.py | 218 ++++++++++++++++ tests/contrib/hooks/test_zendesk_hook.py | 89 +++++++ tests/contrib/hooks/zendesk_hook.py | 90 ------- tests/contrib/operators/__init__.py | 3 - tests/contrib/operators/databricks_operator.py | 185 -------------- tests/contrib/operators/dataflow_operator.py | 82 ------ tests/contrib/operators/ecs_operator.py | 207 --------------- .../contrib/operators/emr_add_steps_operator.py | 53 ---- 
.../operators/emr_create_job_flow_operator.py | 53 ---- .../emr_terminate_job_flow_operator.py | 52 ---- tests/contrib/operators/fs_operator.py | 64 ----- tests/contrib/operators/hipchat_operator.py | 74 ------ tests/contrib/operators/jira_operator_test.py | 101 -------- .../contrib/operators/spark_submit_operator.py | 81 ------ tests/contrib/operators/sqoop_operator.py | 93 ------- tests/contrib/operators/ssh_execute_operator.py | 79 ------ .../operators/test_databricks_operator.py | 185 ++++++++++++++ .../contrib/operators/test_dataflow_operator.py | 81 ++++++ tests/contrib/operators/test_ecs_operator.py | 214 ++++++++++++++++ .../operators/test_emr_add_steps_operator.py | 53 ++++ .../test_emr_create_job_flow_operator.py | 53 ++++ .../test_emr_terminate_job_flow_operator.py | 52 ++++ tests/contrib/operators/test_fs_operator.py | 64 +++++ .../contrib/operators/test_hipchat_operator.py | 74 ++++++ .../operators/test_jira_operator_test.py | 101 ++++++++ .../operators/test_spark_submit_operator.py | 88 +++++++ tests/contrib/operators/test_sqoop_operator.py | 93 +++++++ .../operators/test_ssh_execute_operator.py | 95 +++++++ tests/contrib/sensors/datadog_sensor.py | 91 ------- tests/contrib/sensors/emr_base_sensor.py | 126 ---------- tests/contrib/sensors/emr_job_flow_sensor.py | 123 --------- tests/contrib/sensors/emr_step_sensor.py | 119 --------- tests/contrib/sensors/ftp_sensor.py | 66 ----- tests/contrib/sensors/hdfs_sensors.py | 251 ------------------- tests/contrib/sensors/jira_sensor_test.py | 85 ------- tests/contrib/sensors/redis_sensor.py | 64 ----- tests/contrib/sensors/test_datadog_sensor.py | 106 ++++++++ tests/contrib/sensors/test_emr_base_sensor.py | 126 ++++++++++ .../contrib/sensors/test_emr_job_flow_sensor.py | 123 +++++++++ tests/contrib/sensors/test_emr_step_sensor.py | 119 +++++++++ tests/contrib/sensors/test_ftp_sensor.py | 66 +++++ tests/contrib/sensors/test_hdfs_sensors.py | 251 +++++++++++++++++++ 
tests/contrib/sensors/test_jira_sensor_test.py | 85 +++++++ tests/contrib/sensors/test_redis_sensor.py | 64 +++++ 61 files changed, 3126 insertions(+), 3069 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/airflow/contrib/operators/ecs_operator.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/operators/ecs_operator.py b/airflow/contrib/operators/ecs_operator.py index df02c4e..11f8c94 100644 --- a/airflow/contrib/operators/ecs_operator.py +++ b/airflow/contrib/operators/ecs_operator.py @@ -89,7 +89,7 @@ class ECSOperator(BaseOperator): def _wait_for_task_ended(self): waiter = self.client.get_waiter('tasks_stopped') - waiter.config.max_attempts = sys.maxint # timeout is managed by airflow + waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow waiter.wait( cluster=self.cluster, tasks=[self.arn] http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/airflow/hooks/__init__.py ---------------------------------------------------------------------- diff --git a/airflow/hooks/__init__.py b/airflow/hooks/__init__.py index cc09f5a..bb02967 100644 --- a/airflow/hooks/__init__.py +++ b/airflow/hooks/__init__.py @@ -48,6 +48,7 @@ _hooks = { 'samba_hook': ['SambaHook'], 'sqlite_hook': ['SqliteHook'], 'S3_hook': ['S3Hook'], + 'zendesk_hook': ['ZendeskHook'], 'http_hook': ['HttpHook'], 'druid_hook': ['DruidHook'], 'jdbc_hook': ['JdbcHook'], http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/airflow/hooks/zendesk_hook.py ---------------------------------------------------------------------- diff --git a/airflow/hooks/zendesk_hook.py b/airflow/hooks/zendesk_hook.py index 438597f..907d1e8 100644 --- a/airflow/hooks/zendesk_hook.py +++ b/airflow/hooks/zendesk_hook.py @@ -21,7 +21,7 @@ A hook to talk to Zendesk import logging import time from zdesk import Zendesk, 
RateLimitError, ZendeskError -from airflow.hooks import BaseHook +from airflow.hooks.base_hook import BaseHook class ZendeskHook(BaseHook): http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/scripts/ci/requirements.txt ---------------------------------------------------------------------- diff --git a/scripts/ci/requirements.txt b/scripts/ci/requirements.txt index 1905398..751c13f 100644 --- a/scripts/ci/requirements.txt +++ b/scripts/ci/requirements.txt @@ -3,6 +3,7 @@ azure-storage>=0.34.0 bcrypt bleach boto +boto3 celery cgroupspy chartkick @@ -11,6 +12,7 @@ coverage coveralls croniter cryptography +datadog dill distributed docker-py @@ -25,6 +27,7 @@ Flask-WTF flower freezegun future +google-api-python-client>=1.5.0,<1.6.0 gunicorn hdfs hive-thrift-py @@ -37,6 +40,7 @@ ldap3 lxml markdown mock +moto mysqlclient nose nose-exclude @@ -69,3 +73,4 @@ statsd thrift thrift_sasl unicodecsv +zdesk http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/aws_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/aws_hook.py b/tests/contrib/hooks/aws_hook.py deleted file mode 100644 index 6f13e58..0000000 --- a/tests/contrib/hooks/aws_hook.py +++ /dev/null @@ -1,47 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -import unittest -import boto3 - -from airflow import configuration -from airflow.contrib.hooks.aws_hook import AwsHook - - -try: - from moto import mock_emr -except ImportError: - mock_emr = None - - -class TestAwsHook(unittest.TestCase): - @mock_emr - def setUp(self): - configuration.load_test_config() - - @unittest.skipIf(mock_emr is None, 'mock_emr package not present') - @mock_emr - def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self): - client = boto3.client('emr', region_name='us-east-1') - if len(client.list_clusters()['Clusters']): - raise ValueError('AWS not properly mocked') - - hook = AwsHook(aws_conn_id='aws_default') - client_from_hook = hook.get_client_type('emr') - - self.assertEqual(client_from_hook.list_clusters()['Clusters'], []) - -if __name__ == '__main__': - unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/bigquery_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/bigquery_hook.py b/tests/contrib/hooks/bigquery_hook.py deleted file mode 100644 index 68856f8..0000000 --- a/tests/contrib/hooks/bigquery_hook.py +++ /dev/null @@ -1,139 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -import unittest - -from airflow.contrib.hooks import bigquery_hook as hook - - -class TestBigQueryTableSplitter(unittest.TestCase): - def test_internal_need_default_project(self): - with self.assertRaises(Exception) as context: - hook._split_tablename('dataset.table', None) - - self.assertIn('INTERNAL: No default project is specified', - str(context.exception), "") - - def test_split_dataset_table(self): - project, dataset, table = hook._split_tablename('dataset.table', - 'project') - self.assertEqual("project", project) - self.assertEqual("dataset", dataset) - self.assertEqual("table", table) - - def test_split_project_dataset_table(self): - project, dataset, table = hook._split_tablename('alternative:dataset.table', - 'project') - self.assertEqual("alternative", project) - self.assertEqual("dataset", dataset) - self.assertEqual("table", table) - - def test_sql_split_project_dataset_table(self): - project, dataset, table = hook._split_tablename('alternative.dataset.table', - 'project') - self.assertEqual("alternative", project) - self.assertEqual("dataset", dataset) - self.assertEqual("table", table) - - def test_invalid_syntax_column_double_project(self): - with self.assertRaises(Exception) as context: - hook._split_tablename('alt1:alt.dataset.table', - 'project') - - self.assertIn('Use either : or . 
to specify project', - str(context.exception), "") - self.assertFalse('Format exception for' in str(context.exception)) - - def test_invalid_syntax_double_column(self): - with self.assertRaises(Exception) as context: - hook._split_tablename('alt1:alt:dataset.table', - 'project') - - self.assertIn('Expect format of (<project:)<dataset>.<table>', - str(context.exception), "") - self.assertFalse('Format exception for' in str(context.exception)) - - def test_invalid_syntax_tiple_dot(self): - with self.assertRaises(Exception) as context: - hook._split_tablename('alt1.alt.dataset.table', - 'project') - - self.assertIn('Expect format of (<project.|<project:)<dataset>.<table>', - str(context.exception), "") - self.assertFalse('Format exception for' in str(context.exception)) - - def test_invalid_syntax_column_double_project_var(self): - with self.assertRaises(Exception) as context: - hook._split_tablename('alt1:alt.dataset.table', - 'project', 'var_x') - - self.assertIn('Use either : or . to specify project', - str(context.exception), "") - self.assertIn('Format exception for var_x:', - str(context.exception), "") - - def test_invalid_syntax_double_column_var(self): - with self.assertRaises(Exception) as context: - hook._split_tablename('alt1:alt:dataset.table', - 'project', 'var_x') - - self.assertIn('Expect format of (<project:)<dataset>.<table>', - str(context.exception), "") - self.assertIn('Format exception for var_x:', - str(context.exception), "") - - def test_invalid_syntax_tiple_dot_var(self): - with self.assertRaises(Exception) as context: - hook._split_tablename('alt1.alt.dataset.table', - 'project', 'var_x') - - self.assertIn('Expect format of (<project.|<project:)<dataset>.<table>', - str(context.exception), "") - self.assertIn('Format exception for var_x:', - str(context.exception), "") - -class TestBigQueryHookSourceFormat(unittest.TestCase): - def test_invalid_source_format(self): - with self.assertRaises(Exception) as context: - 
hook.BigQueryBaseCursor("test", "test").run_load("test.test", "test_schema.json", ["test_data.json"], source_format="json") - - # since we passed 'json' in, and it's not valid, make sure it's present in the error string. - self.assertIn("json", str(context.exception)) - - -class TestBigQueryBaseCursor(unittest.TestCase): - def test_invalid_schema_update_options(self): - with self.assertRaises(Exception) as context: - hook.BigQueryBaseCursor("test", "test").run_load( - "test.test", - "test_schema.json", - ["test_data.json"], - schema_update_options=["THIS IS NOT VALID"] - ) - self.assertIn("THIS IS NOT VALID", str(context.exception)) - - def test_invalid_schema_update_and_write_disposition(self): - with self.assertRaises(Exception) as context: - hook.BigQueryBaseCursor("test", "test").run_load( - "test.test", - "test_schema.json", - ["test_data.json"], - schema_update_options=['ALLOW_FIELD_ADDITION'], - write_disposition='WRITE_EMPTY' - ) - self.assertIn("schema_update_options is only", str(context.exception)) - -if __name__ == '__main__': - unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/databricks_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/databricks_hook.py b/tests/contrib/hooks/databricks_hook.py deleted file mode 100644 index 6c789f9..0000000 --- a/tests/contrib/hooks/databricks_hook.py +++ /dev/null @@ -1,226 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# - -import unittest - -from airflow import __version__ -from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT -from airflow.exceptions import AirflowException -from airflow.models import Connection -from airflow.utils import db -from requests import exceptions as requests_exceptions - -try: - from unittest import mock -except ImportError: - try: - import mock - except ImportError: - mock = None - -TASK_ID = 'databricks-operator' -DEFAULT_CONN_ID = 'databricks_default' -NOTEBOOK_TASK = { - 'notebook_path': '/test' -} -NEW_CLUSTER = { - 'spark_version': '2.0.x-scala2.10', - 'node_type_id': 'r3.xlarge', - 'num_workers': 1 -} -RUN_ID = 1 -HOST = 'xx.cloud.databricks.com' -HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com' -LOGIN = 'login' -PASSWORD = 'password' -USER_AGENT_HEADER = {'user-agent': 'airflow-{v}'.format(v=__version__)} -RUN_PAGE_URL = 'https://XX.cloud.databricks.com/#jobs/1/runs/1' -LIFE_CYCLE_STATE = 'PENDING' -STATE_MESSAGE = 'Waiting for cluster' -GET_RUN_RESPONSE = { - 'run_page_url': RUN_PAGE_URL, - 'state': { - 'life_cycle_state': LIFE_CYCLE_STATE, - 'state_message': STATE_MESSAGE - } -} -RESULT_STATE = None - - -def submit_run_endpoint(host): - """ - Utility function to generate the submit run endpoint given the host. - """ - return 'https://{}/api/2.0/jobs/runs/submit'.format(host) - - -def get_run_endpoint(host): - """ - Utility function to generate the get run endpoint given the host. - """ - return 'https://{}/api/2.0/jobs/runs/get'.format(host) - -def cancel_run_endpoint(host): - """ - Utility function to generate the get run endpoint given the host. - """ - return 'https://{}/api/2.0/jobs/runs/cancel'.format(host) - -class DatabricksHookTest(unittest.TestCase): - """ - Tests for DatabricksHook. 
- """ - @db.provide_session - def setUp(self, session=None): - conn = session.query(Connection) \ - .filter(Connection.conn_id == DEFAULT_CONN_ID) \ - .first() - conn.host = HOST - conn.login = LOGIN - conn.password = PASSWORD - session.commit() - - self.hook = DatabricksHook() - - def test_parse_host_with_proper_host(self): - host = self.hook._parse_host(HOST) - self.assertEquals(host, HOST) - - def test_parse_host_with_scheme(self): - host = self.hook._parse_host(HOST_WITH_SCHEME) - self.assertEquals(host, HOST) - - def test_init_bad_retry_limit(self): - with self.assertRaises(AssertionError): - DatabricksHook(retry_limit = 0) - - @mock.patch('airflow.contrib.hooks.databricks_hook.logging') - @mock.patch('airflow.contrib.hooks.databricks_hook.requests') - def test_do_api_call_with_error_retry(self, mock_requests, mock_logging): - for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]: - mock_requests.reset_mock() - mock_logging.reset_mock() - mock_requests.post.side_effect = exception() - - with self.assertRaises(AirflowException): - self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) - - self.assertEquals(len(mock_logging.error.mock_calls), self.hook.retry_limit) - - @mock.patch('airflow.contrib.hooks.databricks_hook.requests') - def test_do_api_call_with_bad_status_code(self, mock_requests): - mock_requests.codes.ok = 200 - status_code_mock = mock.PropertyMock(return_value=500) - type(mock_requests.post.return_value).status_code = status_code_mock - with self.assertRaises(AirflowException): - self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) - - @mock.patch('airflow.contrib.hooks.databricks_hook.requests') - def test_submit_run(self, mock_requests): - mock_requests.codes.ok = 200 - mock_requests.post.return_value.json.return_value = {'run_id': '1'} - status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.post.return_value).status_code = status_code_mock - json = { - 'notebook_task': NOTEBOOK_TASK, - 'new_cluster': 
NEW_CLUSTER - } - run_id = self.hook.submit_run(json) - - self.assertEquals(run_id, '1') - mock_requests.post.assert_called_once_with( - submit_run_endpoint(HOST), - json={ - 'notebook_task': NOTEBOOK_TASK, - 'new_cluster': NEW_CLUSTER, - }, - auth=(LOGIN, PASSWORD), - headers=USER_AGENT_HEADER, - timeout=self.hook.timeout_seconds) - - @mock.patch('airflow.contrib.hooks.databricks_hook.requests') - def test_get_run_page_url(self, mock_requests): - mock_requests.codes.ok = 200 - mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE - status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.get.return_value).status_code = status_code_mock - - run_page_url = self.hook.get_run_page_url(RUN_ID) - - self.assertEquals(run_page_url, RUN_PAGE_URL) - mock_requests.get.assert_called_once_with( - get_run_endpoint(HOST), - json={'run_id': RUN_ID}, - auth=(LOGIN, PASSWORD), - headers=USER_AGENT_HEADER, - timeout=self.hook.timeout_seconds) - - @mock.patch('airflow.contrib.hooks.databricks_hook.requests') - def test_get_run_state(self, mock_requests): - mock_requests.codes.ok = 200 - mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE - status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.get.return_value).status_code = status_code_mock - - run_state = self.hook.get_run_state(RUN_ID) - - self.assertEquals(run_state, RunState( - LIFE_CYCLE_STATE, - RESULT_STATE, - STATE_MESSAGE)) - mock_requests.get.assert_called_once_with( - get_run_endpoint(HOST), - json={'run_id': RUN_ID}, - auth=(LOGIN, PASSWORD), - headers=USER_AGENT_HEADER, - timeout=self.hook.timeout_seconds) - - @mock.patch('airflow.contrib.hooks.databricks_hook.requests') - def test_cancel_run(self, mock_requests): - mock_requests.codes.ok = 200 - mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE - status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.post.return_value).status_code = status_code_mock - - 
self.hook.cancel_run(RUN_ID) - - mock_requests.post.assert_called_once_with( - cancel_run_endpoint(HOST), - json={'run_id': RUN_ID}, - auth=(LOGIN, PASSWORD), - headers=USER_AGENT_HEADER, - timeout=self.hook.timeout_seconds) - -class RunStateTest(unittest.TestCase): - def test_is_terminal_true(self): - terminal_states = ['TERMINATED', 'SKIPPED', 'INTERNAL_ERROR'] - for state in terminal_states: - run_state = RunState(state, '', '') - self.assertTrue(run_state.is_terminal) - - def test_is_terminal_false(self): - non_terminal_states = ['PENDING', 'RUNNING', 'TERMINATING'] - for state in non_terminal_states: - run_state = RunState(state, '', '') - self.assertFalse(run_state.is_terminal) - - def test_is_terminal_with_nonexistent_life_cycle_state(self): - run_state = RunState('blah', '', '') - with self.assertRaises(AirflowException): - run_state.is_terminal - - def test_is_successful(self): - run_state = RunState('TERMINATED', 'SUCCESS', '') - self.assertTrue(run_state.is_successful) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/emr_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/emr_hook.py b/tests/contrib/hooks/emr_hook.py deleted file mode 100644 index 119df99..0000000 --- a/tests/contrib/hooks/emr_hook.py +++ /dev/null @@ -1,53 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -import unittest -import boto3 - -from airflow import configuration -from airflow.contrib.hooks.emr_hook import EmrHook - - -try: - from moto import mock_emr -except ImportError: - mock_emr = None - - -class TestEmrHook(unittest.TestCase): - @mock_emr - def setUp(self): - configuration.load_test_config() - - @unittest.skipIf(mock_emr is None, 'mock_emr package not present') - @mock_emr - def test_get_conn_returns_a_boto3_connection(self): - hook = EmrHook(aws_conn_id='aws_default') - self.assertIsNotNone(hook.get_conn().list_clusters()) - - @unittest.skipIf(mock_emr is None, 'mock_emr package not present') - @mock_emr - def test_create_job_flow_uses_the_emr_config_to_create_a_cluster(self): - client = boto3.client('emr', region_name='us-east-1') - if len(client.list_clusters()['Clusters']): - raise ValueError('AWS not properly mocked') - - hook = EmrHook(aws_conn_id='aws_default', emr_conn_id='emr_default') - cluster = hook.create_job_flow({'Name': 'test_cluster'}) - - self.assertEqual(client.list_clusters()['Clusters'][0]['Id'], cluster['JobFlowId']) - -if __name__ == '__main__': - unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/gcp_dataflow_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/gcp_dataflow_hook.py b/tests/contrib/hooks/gcp_dataflow_hook.py deleted file mode 100644 index 797d40c..0000000 --- a/tests/contrib/hooks/gcp_dataflow_hook.py +++ /dev/null @@ -1,56 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import unittest -from airflow.contrib.hooks.gcp_dataflow_hook import DataFlowHook - -try: - from unittest import mock -except ImportError: - try: - import mock - except ImportError: - mock = None - - -TASK_ID = 'test-python-dataflow' -PY_FILE = 'apache_beam.examples.wordcount' -PY_OPTIONS = ['-m'] -OPTIONS = { - 'project': 'test', - 'staging_location': 'gs://test/staging' -} -BASE_STRING = 'airflow.contrib.hooks.gcp_api_base_hook.{}' -DATAFLOW_STRING = 'airflow.contrib.hooks.gcp_dataflow_hook.{}' - - -def mock_init(self, gcp_conn_id, delegate_to=None): - pass - - -class DataFlowHookTest(unittest.TestCase): - - def setUp(self): - with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__'), - new=mock_init): - self.dataflow_hook = DataFlowHook(gcp_conn_id='test') - - @mock.patch(DATAFLOW_STRING.format('DataFlowHook._start_dataflow')) - def test_start_python_dataflow(self, internal_dataflow_mock): - self.dataflow_hook.start_python_dataflow( - task_id=TASK_ID, variables=OPTIONS, - dataflow=PY_FILE, py_options=PY_OPTIONS) - internal_dataflow_mock.assert_called_once_with( - TASK_ID, OPTIONS, PY_FILE, mock.ANY, ['python'] + PY_OPTIONS) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/spark_submit_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/spark_submit_hook.py b/tests/contrib/hooks/spark_submit_hook.py deleted file mode 100644 index 8f514c2..0000000 --- a/tests/contrib/hooks/spark_submit_hook.py +++ /dev/null @@ -1,185 +0,0 @@ -# -*- coding: 
utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import os -import unittest - -from airflow import configuration, models -from airflow.utils import db -from airflow.exceptions import AirflowException -from airflow.contrib.hooks.spark_submit_hook import SparkSubmitHook - - -class TestSparkSubmitHook(unittest.TestCase): - _spark_job_file = 'test_application.py' - _config = { - 'conf': { - 'parquet.compression': 'SNAPPY' - }, - 'conn_id': 'default_spark', - 'files': 'hive-site.xml', - 'py_files': 'sample_library.py', - 'jars': 'parquet.jar', - 'executor_cores': 4, - 'executor_memory': '22g', - 'keytab': 'privileged_user.keytab', - 'principal': 'user/[email protected]', - 'name': 'spark-job', - 'num_executors': 10, - 'verbose': True, - 'driver_memory': '3g', - 'java_class': 'com.foo.bar.AppMain' - } - - def setUp(self): - configuration.load_test_config() - db.merge_conn( - models.Connection( - conn_id='spark_yarn_cluster', conn_type='spark', - host='yarn://yarn-master', extra='{"queue": "root.etl", "deploy-mode": "cluster"}') - ) - db.merge_conn( - models.Connection( - conn_id='spark_default_mesos', conn_type='spark', - host='mesos://host', port=5050) - ) - - db.merge_conn( - models.Connection( - conn_id='spark_home_set', conn_type='spark', - host='yarn://yarn-master', - extra='{"spark-home": "/opt/myspark"}') - ) - - db.merge_conn( - models.Connection( - conn_id='spark_home_not_set', conn_type='spark', - host='yarn://yarn-master') - ) - - def 
test_build_command(self): - hook = SparkSubmitHook(**self._config) - - # The subprocess requires an array but we build the cmd by joining on a space - cmd = ' '.join(hook._build_command(self._spark_job_file)) - - # Check if the URL gets build properly and everything exists. - assert self._spark_job_file in cmd - - # Check all the parameters - assert "--files {}".format(self._config['files']) in cmd - assert "--py-files {}".format(self._config['py_files']) in cmd - assert "--jars {}".format(self._config['jars']) in cmd - assert "--executor-cores {}".format(self._config['executor_cores']) in cmd - assert "--executor-memory {}".format(self._config['executor_memory']) in cmd - assert "--keytab {}".format(self._config['keytab']) in cmd - assert "--principal {}".format(self._config['principal']) in cmd - assert "--name {}".format(self._config['name']) in cmd - assert "--num-executors {}".format(self._config['num_executors']) in cmd - assert "--class {}".format(self._config['java_class']) in cmd - assert "--driver-memory {}".format(self._config['driver_memory']) in cmd - - # Check if all config settings are there - for k in self._config['conf']: - assert "--conf {0}={1}".format(k, self._config['conf'][k]) in cmd - - if self._config['verbose']: - assert "--verbose" in cmd - - def test_submit(self): - hook = SparkSubmitHook() - - # We don't have spark-submit available, and this is hard to mock, so just accept - # an exception for now. 
- with self.assertRaises(AirflowException): - hook.submit(self._spark_job_file) - - def test_resolve_connection(self): - - # Default to the standard yarn connection because conn_id does not exists - hook = SparkSubmitHook(conn_id='') - self.assertEqual(hook._resolve_connection(), ('yarn', None, None, None)) - assert "--master yarn" in ' '.join(hook._build_command(self._spark_job_file)) - - # Default to the standard yarn connection - hook = SparkSubmitHook(conn_id='spark_default') - self.assertEqual( - hook._resolve_connection(), - ('yarn', 'root.default', None, None) - ) - cmd = ' '.join(hook._build_command(self._spark_job_file)) - assert "--master yarn" in cmd - assert "--queue root.default" in cmd - - # Connect to a mesos master - hook = SparkSubmitHook(conn_id='spark_default_mesos') - self.assertEqual( - hook._resolve_connection(), - ('mesos://host:5050', None, None, None) - ) - - cmd = ' '.join(hook._build_command(self._spark_job_file)) - assert "--master mesos://host:5050" in cmd - - # Set specific queue and deploy mode - hook = SparkSubmitHook(conn_id='spark_yarn_cluster') - self.assertEqual( - hook._resolve_connection(), - ('yarn://yarn-master', 'root.etl', 'cluster', None) - ) - - cmd = ' '.join(hook._build_command(self._spark_job_file)) - assert "--master yarn://yarn-master" in cmd - assert "--queue root.etl" in cmd - assert "--deploy-mode cluster" in cmd - - # Set the spark home - hook = SparkSubmitHook(conn_id='spark_home_set') - self.assertEqual( - hook._resolve_connection(), - ('yarn://yarn-master', None, None, '/opt/myspark') - ) - - cmd = ' '.join(hook._build_command(self._spark_job_file)) - assert cmd.startswith('/opt/myspark/bin/spark-submit') - - # Spark home not set - hook = SparkSubmitHook(conn_id='spark_home_not_set') - self.assertEqual( - hook._resolve_connection(), - ('yarn://yarn-master', None, None, None) - ) - - cmd = ' '.join(hook._build_command(self._spark_job_file)) - assert cmd.startswith('spark-submit') - - def test_process_log(self): 
- # Must select yarn connection - hook = SparkSubmitHook(conn_id='spark_yarn_cluster') - - log_lines = [ - 'SPARK_MAJOR_VERSION is set to 2, using Spark2', - 'WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable', - 'WARN DomainSocketFactory: The short-circuit local reads feature cannot be used because libhadoop cannot be loaded.', - 'INFO Client: Requesting a new application from cluster with 10 NodeManagers', - 'INFO Client: Submitting application application_1486558679801_1820 to ResourceManager' - ] - - hook._process_log(log_lines) - - assert hook._yarn_application_id == 'application_1486558679801_1820' - - -if __name__ == '__main__': - unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/sqoop_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/sqoop_hook.py b/tests/contrib/hooks/sqoop_hook.py deleted file mode 100644 index 1d85e43..0000000 --- a/tests/contrib/hooks/sqoop_hook.py +++ /dev/null @@ -1,219 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -import json -import unittest -from exceptions import OSError - -from airflow import configuration, models -from airflow.contrib.hooks.sqoop_hook import SqoopHook -from airflow.utils import db - - -class TestSqoopHook(unittest.TestCase): - _config = { - 'conn_id': 'sqoop_test', - 'num_mappers': 22, - 'verbose': True, - 'properties': { - 'mapred.map.max.attempts': '1' - } - } - _config_export = { - 'table': 'domino.export_data_to', - 'export_dir': '/hdfs/data/to/be/exported', - 'input_null_string': '\n', - 'input_null_non_string': '\t', - 'staging_table': 'database.staging', - 'clear_staging_table': True, - 'enclosed_by': '"', - 'escaped_by': '\\', - 'input_fields_terminated_by': '|', - 'input_lines_terminated_by': '\n', - 'input_optionally_enclosed_by': '"', - 'batch': True, - 'relaxed_isolation': True - } - _config_import = { - 'target_dir': '/hdfs/data/target/location', - 'append': True, - 'file_type': 'parquet', - 'split_by': '\n', - 'direct': True, - 'driver': 'com.microsoft.jdbc.sqlserver.SQLServerDriver' - } - - _config_json = { - 'namenode': 'http://0.0.0.0:50070/', - 'job_tracker': 'http://0.0.0.0:50030/', - 'libjars': '/path/to/jars', - 'files': '/path/to/files', - 'archives': '/path/to/archives' - } - - def setUp(self): - configuration.load_test_config() - db.merge_conn( - models.Connection( - conn_id='sqoop_test', conn_type='sqoop', - host='rmdbs', port=5050, extra=json.dumps(self._config_json) - ) - ) - - def test_popen(self): - hook = SqoopHook(**self._config) - - # Should go well - hook.Popen(['ls']) - - # Should give an exception - with self.assertRaises(OSError): - hook.Popen('exit 1') - - def test_submit(self): - hook = SqoopHook(**self._config) - - cmd = ' '.join(hook._prepare_command()) - - # Check if the config has been extracted from the json - if self._config_json['namenode']: - assert "-fs {}".format(self._config_json['namenode']) in cmd - - if self._config_json['job_tracker']: - assert "-jt {}".format(self._config_json['job_tracker']) 
in cmd - - if self._config_json['libjars']: - assert "-libjars {}".format(self._config_json['libjars']) in cmd - - if self._config_json['files']: - assert "-files {}".format(self._config_json['files']) in cmd - - if self._config_json['archives']: - assert "-archives {}".format(self._config_json['archives']) in cmd - - # Check the regulator stuff passed by the default constructor - if self._config['verbose']: - assert "--verbose" in cmd - - if self._config['num_mappers']: - assert "--num-mappers {}".format( - self._config['num_mappers']) in cmd - - print(self._config['properties']) - for key, value in self._config['properties'].items(): - assert "-D {}={}".format(key, value) in cmd - - # We don't have the sqoop binary available, and this is hard to mock, - # so just accept an exception for now. - with self.assertRaises(OSError): - hook.export_table(**self._config_export) - - with self.assertRaises(OSError): - hook.import_table(table='schema.table', - target_dir='/sqoop/example/path') - - with self.assertRaises(OSError): - hook.import_query(query='SELECT * FROM sometable', - target_dir='/sqoop/example/path') - - def test_export_cmd(self): - hook = SqoopHook() - - # The subprocess requires an array but we build the cmd by joining on a space - cmd = ' '.join( - hook._export_cmd( - self._config_export['table'], - self._config_export['export_dir'], - input_null_string=self._config_export['input_null_string'], - input_null_non_string=self._config_export[ - 'input_null_non_string'], - staging_table=self._config_export['staging_table'], - clear_staging_table=self._config_export['clear_staging_table'], - enclosed_by=self._config_export['enclosed_by'], - escaped_by=self._config_export['escaped_by'], - input_fields_terminated_by=self._config_export[ - 'input_fields_terminated_by'], - input_lines_terminated_by=self._config_export[ - 'input_lines_terminated_by'], - input_optionally_enclosed_by=self._config_export[ - 'input_optionally_enclosed_by'], - 
batch=self._config_export['batch'], - relaxed_isolation=self._config_export['relaxed_isolation']) - ) - - assert "--input-null-string {}".format( - self._config_export['input_null_string']) in cmd - assert "--input-null-non-string {}".format( - self._config_export['input_null_non_string']) in cmd - assert "--staging-table {}".format( - self._config_export['staging_table']) in cmd - assert "--enclosed-by {}".format( - self._config_export['enclosed_by']) in cmd - assert "--escaped-by {}".format( - self._config_export['escaped_by']) in cmd - assert "--input-fields-terminated-by {}".format( - self._config_export['input_fields_terminated_by']) in cmd - assert "--input-lines-terminated-by {}".format( - self._config_export['input_lines_terminated_by']) in cmd - assert "--input-optionally-enclosed-by {}".format( - self._config_export['input_optionally_enclosed_by']) in cmd - - if self._config_export['clear_staging_table']: - assert "--clear-staging-table" in cmd - - if self._config_export['batch']: - assert "--batch" in cmd - - if self._config_export['relaxed_isolation']: - assert "--relaxed-isolation" in cmd - - def test_import_cmd(self): - hook = SqoopHook() - - # The subprocess requires an array but we build the cmd by joining on a space - cmd = ' '.join( - hook._import_cmd(self._config_import['target_dir'], - append=self._config_import['append'], - file_type=self._config_import['file_type'], - split_by=self._config_import['split_by'], - direct=self._config_import['direct'], - driver=self._config_import['driver']) - ) - - if self._config_import['append']: - assert '--append' in cmd - - if self._config_import['direct']: - assert '--direct' in cmd - - assert '--target-dir {}'.format( - self._config_import['target_dir']) in cmd - - assert '--driver {}'.format(self._config_import['driver']) in cmd - assert '--split-by {}'.format(self._config_import['split_by']) in cmd - - def test_get_export_format_argument(self): - hook = SqoopHook() - assert "--as-avrodatafile" in 
hook._get_export_format_argument('avro') - assert "--as-parquetfile" in hook._get_export_format_argument( - 'parquet') - assert "--as-sequencefile" in hook._get_export_format_argument( - 'sequence') - assert "--as-textfile" in hook._get_export_format_argument('text') - assert "--as-textfile" in hook._get_export_format_argument('unknown') - - -if __name__ == '__main__': - unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_aws_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/test_aws_hook.py b/tests/contrib/hooks/test_aws_hook.py new file mode 100644 index 0000000..6f13e58 --- /dev/null +++ b/tests/contrib/hooks/test_aws_hook.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest +import boto3 + +from airflow import configuration +from airflow.contrib.hooks.aws_hook import AwsHook + + +try: + from moto import mock_emr +except ImportError: + mock_emr = None + + +class TestAwsHook(unittest.TestCase): + @mock_emr + def setUp(self): + configuration.load_test_config() + + @unittest.skipIf(mock_emr is None, 'mock_emr package not present') + @mock_emr + def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self): + client = boto3.client('emr', region_name='us-east-1') + if len(client.list_clusters()['Clusters']): + raise ValueError('AWS not properly mocked') + + hook = AwsHook(aws_conn_id='aws_default') + client_from_hook = hook.get_client_type('emr') + + self.assertEqual(client_from_hook.list_clusters()['Clusters'], []) + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_bigquery_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/test_bigquery_hook.py b/tests/contrib/hooks/test_bigquery_hook.py new file mode 100644 index 0000000..0adffc5 --- /dev/null +++ b/tests/contrib/hooks/test_bigquery_hook.py @@ -0,0 +1,139 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest + +from airflow.contrib.hooks import bigquery_hook as hook + + +class TestBigQueryTableSplitter(unittest.TestCase): + def test_internal_need_default_project(self): + with self.assertRaises(Exception) as context: + hook._split_tablename('dataset.table', None) + + self.assertIn('INTERNAL: No default project is specified', + str(context.exception), "") + + def test_split_dataset_table(self): + project, dataset, table = hook._split_tablename('dataset.table', + 'project') + self.assertEqual("project", project) + self.assertEqual("dataset", dataset) + self.assertEqual("table", table) + + def test_split_project_dataset_table(self): + project, dataset, table = hook._split_tablename('alternative:dataset.table', + 'project') + self.assertEqual("alternative", project) + self.assertEqual("dataset", dataset) + self.assertEqual("table", table) + + def test_sql_split_project_dataset_table(self): + project, dataset, table = hook._split_tablename('alternative.dataset.table', + 'project') + self.assertEqual("alternative", project) + self.assertEqual("dataset", dataset) + self.assertEqual("table", table) + + def test_invalid_syntax_column_double_project(self): + with self.assertRaises(Exception) as context: + hook._split_tablename('alt1:alt.dataset.table', + 'project') + + self.assertIn('Use either : or . 
to specify project', + str(context.exception), "") + self.assertFalse('Format exception for' in str(context.exception)) + + def test_invalid_syntax_double_column(self): + with self.assertRaises(Exception) as context: + hook._split_tablename('alt1:alt:dataset.table', + 'project') + + self.assertIn('Expect format of (<project:)<dataset>.<table>', + str(context.exception), "") + self.assertFalse('Format exception for' in str(context.exception)) + + def test_invalid_syntax_tiple_dot(self): + with self.assertRaises(Exception) as context: + hook._split_tablename('alt1.alt.dataset.table', + 'project') + + self.assertIn('Expect format of (<project.|<project:)<dataset>.<table>', + str(context.exception), "") + self.assertFalse('Format exception for' in str(context.exception)) + + def test_invalid_syntax_column_double_project_var(self): + with self.assertRaises(Exception) as context: + hook._split_tablename('alt1:alt.dataset.table', + 'project', 'var_x') + + self.assertIn('Use either : or . to specify project', + str(context.exception), "") + self.assertIn('Format exception for var_x:', + str(context.exception), "") + + def test_invalid_syntax_double_column_var(self): + with self.assertRaises(Exception) as context: + hook._split_tablename('alt1:alt:dataset.table', + 'project', 'var_x') + + self.assertIn('Expect format of (<project:)<dataset>.<table>', + str(context.exception), "") + self.assertIn('Format exception for var_x:', + str(context.exception), "") + + def test_invalid_syntax_tiple_dot_var(self): + with self.assertRaises(Exception) as context: + hook._split_tablename('alt1.alt.dataset.table', + 'project', 'var_x') + + self.assertIn('Expect format of (<project.|<project:)<dataset>.<table>', + str(context.exception), "") + self.assertIn('Format exception for var_x:', + str(context.exception), "") + +class TestBigQueryHookSourceFormat(unittest.TestCase): + def test_invalid_source_format(self): + with self.assertRaises(Exception) as context: + 
hook.BigQueryBaseCursor("test", "test").run_load("test.test", "test_schema.json", ["test_data.json"], source_format="json") + + # since we passed 'json' in, and it's not valid, make sure it's present in the error string. + self.assertIn("JSON", str(context.exception)) + + +class TestBigQueryBaseCursor(unittest.TestCase): + def test_invalid_schema_update_options(self): + with self.assertRaises(Exception) as context: + hook.BigQueryBaseCursor("test", "test").run_load( + "test.test", + "test_schema.json", + ["test_data.json"], + schema_update_options=["THIS IS NOT VALID"] + ) + self.assertIn("THIS IS NOT VALID", str(context.exception)) + + def test_invalid_schema_update_and_write_disposition(self): + with self.assertRaises(Exception) as context: + hook.BigQueryBaseCursor("test", "test").run_load( + "test.test", + "test_schema.json", + ["test_data.json"], + schema_update_options=['ALLOW_FIELD_ADDITION'], + write_disposition='WRITE_EMPTY' + ) + self.assertIn("schema_update_options is only", str(context.exception)) + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_databricks_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py new file mode 100644 index 0000000..6c789f9 --- /dev/null +++ b/tests/contrib/hooks/test_databricks_hook.py @@ -0,0 +1,226 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from airflow import __version__ +from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT +from airflow.exceptions import AirflowException +from airflow.models import Connection +from airflow.utils import db +from requests import exceptions as requests_exceptions + +try: + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None + +TASK_ID = 'databricks-operator' +DEFAULT_CONN_ID = 'databricks_default' +NOTEBOOK_TASK = { + 'notebook_path': '/test' +} +NEW_CLUSTER = { + 'spark_version': '2.0.x-scala2.10', + 'node_type_id': 'r3.xlarge', + 'num_workers': 1 +} +RUN_ID = 1 +HOST = 'xx.cloud.databricks.com' +HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com' +LOGIN = 'login' +PASSWORD = 'password' +USER_AGENT_HEADER = {'user-agent': 'airflow-{v}'.format(v=__version__)} +RUN_PAGE_URL = 'https://XX.cloud.databricks.com/#jobs/1/runs/1' +LIFE_CYCLE_STATE = 'PENDING' +STATE_MESSAGE = 'Waiting for cluster' +GET_RUN_RESPONSE = { + 'run_page_url': RUN_PAGE_URL, + 'state': { + 'life_cycle_state': LIFE_CYCLE_STATE, + 'state_message': STATE_MESSAGE + } +} +RESULT_STATE = None + + +def submit_run_endpoint(host): + """ + Utility function to generate the submit run endpoint given the host. + """ + return 'https://{}/api/2.0/jobs/runs/submit'.format(host) + + +def get_run_endpoint(host): + """ + Utility function to generate the get run endpoint given the host. + """ + return 'https://{}/api/2.0/jobs/runs/get'.format(host) + +def cancel_run_endpoint(host): + """ + Utility function to generate the get run endpoint given the host. + """ + return 'https://{}/api/2.0/jobs/runs/cancel'.format(host) + +class DatabricksHookTest(unittest.TestCase): + """ + Tests for DatabricksHook. 
+ """ + @db.provide_session + def setUp(self, session=None): + conn = session.query(Connection) \ + .filter(Connection.conn_id == DEFAULT_CONN_ID) \ + .first() + conn.host = HOST + conn.login = LOGIN + conn.password = PASSWORD + session.commit() + + self.hook = DatabricksHook() + + def test_parse_host_with_proper_host(self): + host = self.hook._parse_host(HOST) + self.assertEquals(host, HOST) + + def test_parse_host_with_scheme(self): + host = self.hook._parse_host(HOST_WITH_SCHEME) + self.assertEquals(host, HOST) + + def test_init_bad_retry_limit(self): + with self.assertRaises(AssertionError): + DatabricksHook(retry_limit = 0) + + @mock.patch('airflow.contrib.hooks.databricks_hook.logging') + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + def test_do_api_call_with_error_retry(self, mock_requests, mock_logging): + for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]: + mock_requests.reset_mock() + mock_logging.reset_mock() + mock_requests.post.side_effect = exception() + + with self.assertRaises(AirflowException): + self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + + self.assertEquals(len(mock_logging.error.mock_calls), self.hook.retry_limit) + + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + def test_do_api_call_with_bad_status_code(self, mock_requests): + mock_requests.codes.ok = 200 + status_code_mock = mock.PropertyMock(return_value=500) + type(mock_requests.post.return_value).status_code = status_code_mock + with self.assertRaises(AirflowException): + self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + def test_submit_run(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.post.return_value.json.return_value = {'run_id': '1'} + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.post.return_value).status_code = status_code_mock + json = { + 'notebook_task': NOTEBOOK_TASK, + 'new_cluster': 
NEW_CLUSTER + } + run_id = self.hook.submit_run(json) + + self.assertEquals(run_id, '1') + mock_requests.post.assert_called_once_with( + submit_run_endpoint(HOST), + json={ + 'notebook_task': NOTEBOOK_TASK, + 'new_cluster': NEW_CLUSTER, + }, + auth=(LOGIN, PASSWORD), + headers=USER_AGENT_HEADER, + timeout=self.hook.timeout_seconds) + + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + def test_get_run_page_url(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.get.return_value).status_code = status_code_mock + + run_page_url = self.hook.get_run_page_url(RUN_ID) + + self.assertEquals(run_page_url, RUN_PAGE_URL) + mock_requests.get.assert_called_once_with( + get_run_endpoint(HOST), + json={'run_id': RUN_ID}, + auth=(LOGIN, PASSWORD), + headers=USER_AGENT_HEADER, + timeout=self.hook.timeout_seconds) + + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + def test_get_run_state(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.get.return_value).status_code = status_code_mock + + run_state = self.hook.get_run_state(RUN_ID) + + self.assertEquals(run_state, RunState( + LIFE_CYCLE_STATE, + RESULT_STATE, + STATE_MESSAGE)) + mock_requests.get.assert_called_once_with( + get_run_endpoint(HOST), + json={'run_id': RUN_ID}, + auth=(LOGIN, PASSWORD), + headers=USER_AGENT_HEADER, + timeout=self.hook.timeout_seconds) + + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + def test_cancel_run(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.post.return_value).status_code = status_code_mock + + 
self.hook.cancel_run(RUN_ID) + + mock_requests.post.assert_called_once_with( + cancel_run_endpoint(HOST), + json={'run_id': RUN_ID}, + auth=(LOGIN, PASSWORD), + headers=USER_AGENT_HEADER, + timeout=self.hook.timeout_seconds) + +class RunStateTest(unittest.TestCase): + def test_is_terminal_true(self): + terminal_states = ['TERMINATED', 'SKIPPED', 'INTERNAL_ERROR'] + for state in terminal_states: + run_state = RunState(state, '', '') + self.assertTrue(run_state.is_terminal) + + def test_is_terminal_false(self): + non_terminal_states = ['PENDING', 'RUNNING', 'TERMINATING'] + for state in non_terminal_states: + run_state = RunState(state, '', '') + self.assertFalse(run_state.is_terminal) + + def test_is_terminal_with_nonexistent_life_cycle_state(self): + run_state = RunState('blah', '', '') + with self.assertRaises(AirflowException): + run_state.is_terminal + + def test_is_successful(self): + run_state = RunState('TERMINATED', 'SUCCESS', '') + self.assertTrue(run_state.is_successful) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_emr_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/test_emr_hook.py b/tests/contrib/hooks/test_emr_hook.py new file mode 100644 index 0000000..119df99 --- /dev/null +++ b/tests/contrib/hooks/test_emr_hook.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest +import boto3 + +from airflow import configuration +from airflow.contrib.hooks.emr_hook import EmrHook + + +try: + from moto import mock_emr +except ImportError: + mock_emr = None + + +class TestEmrHook(unittest.TestCase): + @mock_emr + def setUp(self): + configuration.load_test_config() + + @unittest.skipIf(mock_emr is None, 'mock_emr package not present') + @mock_emr + def test_get_conn_returns_a_boto3_connection(self): + hook = EmrHook(aws_conn_id='aws_default') + self.assertIsNotNone(hook.get_conn().list_clusters()) + + @unittest.skipIf(mock_emr is None, 'mock_emr package not present') + @mock_emr + def test_create_job_flow_uses_the_emr_config_to_create_a_cluster(self): + client = boto3.client('emr', region_name='us-east-1') + if len(client.list_clusters()['Clusters']): + raise ValueError('AWS not properly mocked') + + hook = EmrHook(aws_conn_id='aws_default', emr_conn_id='emr_default') + cluster = hook.create_job_flow({'Name': 'test_cluster'}) + + self.assertEqual(client.list_clusters()['Clusters'][0]['Id'], cluster['JobFlowId']) + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_gcp_dataflow_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/test_gcp_dataflow_hook.py b/tests/contrib/hooks/test_gcp_dataflow_hook.py new file mode 100644 index 0000000..797d40c --- /dev/null +++ b/tests/contrib/hooks/test_gcp_dataflow_hook.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +from airflow.contrib.hooks.gcp_dataflow_hook import DataFlowHook + +try: + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None + + +TASK_ID = 'test-python-dataflow' +PY_FILE = 'apache_beam.examples.wordcount' +PY_OPTIONS = ['-m'] +OPTIONS = { + 'project': 'test', + 'staging_location': 'gs://test/staging' +} +BASE_STRING = 'airflow.contrib.hooks.gcp_api_base_hook.{}' +DATAFLOW_STRING = 'airflow.contrib.hooks.gcp_dataflow_hook.{}' + + +def mock_init(self, gcp_conn_id, delegate_to=None): + pass + + +class DataFlowHookTest(unittest.TestCase): + + def setUp(self): + with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__'), + new=mock_init): + self.dataflow_hook = DataFlowHook(gcp_conn_id='test') + + @mock.patch(DATAFLOW_STRING.format('DataFlowHook._start_dataflow')) + def test_start_python_dataflow(self, internal_dataflow_mock): + self.dataflow_hook.start_python_dataflow( + task_id=TASK_ID, variables=OPTIONS, + dataflow=PY_FILE, py_options=PY_OPTIONS) + internal_dataflow_mock.assert_called_once_with( + TASK_ID, OPTIONS, PY_FILE, mock.ANY, ['python'] + PY_OPTIONS) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_spark_submit_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/test_spark_submit_hook.py b/tests/contrib/hooks/test_spark_submit_hook.py new file mode 100644 index 0000000..24315fa --- /dev/null +++ b/tests/contrib/hooks/test_spark_submit_hook.py @@ -0,0 +1,197 @@ 
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import unittest
from io import StringIO

import mock

from airflow import configuration, models
from airflow.utils import db
from airflow.contrib.hooks.spark_submit_hook import SparkSubmitHook


class TestSparkSubmitHook(unittest.TestCase):
    """Tests for SparkSubmitHook: command building, connection resolution
    and parsing of the spark-submit log output.
    """

    _spark_job_file = 'test_application.py'
    _config = {
        'conf': {
            'parquet.compression': 'SNAPPY'
        },
        'conn_id': 'default_spark',
        'files': 'hive-site.xml',
        'py_files': 'sample_library.py',
        'jars': 'parquet.jar',
        'executor_cores': 4,
        'executor_memory': '22g',
        'keytab': 'privileged_user.keytab',
        # NOTE(review): this value looks like it was rewritten by the mail
        # archive's e-mail obfuscation -- confirm against the repository.
        # The test only checks that the value is echoed after --principal.
        'principal': 'user/[email protected]',
        'name': 'spark-job',
        'num_executors': 10,
        'verbose': True,
        'driver_memory': '3g',
        'java_class': 'com.foo.bar.AppMain'
    }

    def setUp(self):
        """Skip on Python 3, then register the Spark connections the tests use."""
        if sys.version_info[0] == 3:
            raise unittest.SkipTest('TestSparkSubmitHook won\'t work with '
                                    'python3. No need to test anything here')

        configuration.load_test_config()
        db.merge_conn(
            models.Connection(
                conn_id='spark_yarn_cluster', conn_type='spark',
                host='yarn://yarn-master',
                extra='{"queue": "root.etl", "deploy-mode": "cluster"}')
        )
        db.merge_conn(
            models.Connection(
                conn_id='spark_default_mesos', conn_type='spark',
                host='mesos://host', port=5050)
        )
        db.merge_conn(
            models.Connection(
                conn_id='spark_home_set', conn_type='spark',
                host='yarn://yarn-master',
                extra='{"spark-home": "/opt/myspark"}')
        )
        db.merge_conn(
            models.Connection(
                conn_id='spark_home_not_set', conn_type='spark',
                host='yarn://yarn-master')
        )

    def _joined_command(self, hook):
        """Return the hook's spark-submit command as one space-joined string.

        The hook builds an argument list (as subprocess requires); joining it
        lets the tests check for "--flag value" substrings.
        """
        return ' '.join(hook._build_command(self._spark_job_file))

    def test_build_command(self):
        """Every configured option must show up in the built command."""
        hook = SparkSubmitHook(**self._config)
        cmd = self._joined_command(hook)

        # Check that the application file is part of the command.
        self.assertIn(self._spark_job_file, cmd)

        # Command-line flag -> key in _config holding the expected value.
        expected_flags = (
            ('--files', 'files'),
            ('--py-files', 'py_files'),
            ('--jars', 'jars'),
            ('--executor-cores', 'executor_cores'),
            ('--executor-memory', 'executor_memory'),
            ('--keytab', 'keytab'),
            ('--principal', 'principal'),
            ('--name', 'name'),
            ('--num-executors', 'num_executors'),
            ('--class', 'java_class'),
            ('--driver-memory', 'driver_memory'),
        )
        for flag, key in expected_flags:
            self.assertIn('{} {}'.format(flag, self._config[key]), cmd)

        # Check that all --conf settings are there.
        for key, value in self._config['conf'].items():
            self.assertIn('--conf {0}={1}'.format(key, value), cmd)

        if self._config['verbose']:
            self.assertIn('--verbose', cmd)

    @mock.patch('airflow.contrib.hooks.spark_submit_hook.subprocess')
    def test_submit(self, mock_process):
        """submit() must complete against a mocked subprocess module."""
        # We don't have spark-submit available, and this is hard to mock, so
        # let's just use this simple mock.
        # NOTE(review): nothing is asserted here; the test only verifies that
        # submit() does not raise.
        mock_popen = mock_process.Popen.return_value
        mock_popen.stdout = StringIO(u'stdout')
        mock_popen.stderr = StringIO(u'stderr')
        mock_popen.returncode = None
        mock_popen.communicate.return_value = ['extra stdout', 'extra stderr']
        hook = SparkSubmitHook()
        hook.submit(self._spark_job_file)

    def test_resolve_connection(self):
        """Connection settings must be resolved and reflected in the command."""
        # Default to the standard yarn connection because conn_id does not
        # exist.
        hook = SparkSubmitHook(conn_id='')
        self.assertEqual(hook._resolve_connection(),
                         ('yarn', None, None, None))
        self.assertIn('--master yarn', self._joined_command(hook))

        # Default to the standard yarn connection.
        hook = SparkSubmitHook(conn_id='spark_default')
        self.assertEqual(
            hook._resolve_connection(),
            ('yarn', 'root.default', None, None)
        )
        cmd = self._joined_command(hook)
        self.assertIn('--master yarn', cmd)
        self.assertIn('--queue root.default', cmd)

        # Connect to a mesos master.
        hook = SparkSubmitHook(conn_id='spark_default_mesos')
        self.assertEqual(
            hook._resolve_connection(),
            ('mesos://host:5050', None, None, None)
        )
        cmd = self._joined_command(hook)
        self.assertIn('--master mesos://host:5050', cmd)

        # Set specific queue and deploy mode.
        hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
        self.assertEqual(
            hook._resolve_connection(),
            ('yarn://yarn-master', 'root.etl', 'cluster', None)
        )
        cmd = self._joined_command(hook)
        self.assertIn('--master yarn://yarn-master', cmd)
        self.assertIn('--queue root.etl', cmd)
        self.assertIn('--deploy-mode cluster', cmd)

        # Set the spark home.
        hook = SparkSubmitHook(conn_id='spark_home_set')
        self.assertEqual(
            hook._resolve_connection(),
            ('yarn://yarn-master', None, None, '/opt/myspark')
        )
        cmd = self._joined_command(hook)
        self.assertTrue(cmd.startswith('/opt/myspark/bin/spark-submit'))

        # Spark home not set.
        hook = SparkSubmitHook(conn_id='spark_home_not_set')
        self.assertEqual(
            hook._resolve_connection(),
            ('yarn://yarn-master', None, None, None)
        )
        cmd = self._joined_command(hook)
        self.assertTrue(cmd.startswith('spark-submit'))

    def test_process_log(self):
        """The YARN application id must be extracted from the submit log."""
        # Must select yarn connection
        hook = SparkSubmitHook(conn_id='spark_yarn_cluster')

        log_lines = [
            'SPARK_MAJOR_VERSION is set to 2, using Spark2',
            'WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable',
            'WARN DomainSocketFactory: The short-circuit local reads feature cannot be used because libhadoop cannot be loaded.',
            'INFO Client: Requesting a new application from cluster with 10 NodeManagers',
            'INFO Client: Submitting application application_1486558679801_1820 to ResourceManager'
        ]

        hook._process_log(log_lines)

        self.assertEqual(hook._yarn_application_id,
                         'application_1486558679801_1820')


if __name__ == '__main__':
    unittest.main()

# ----------------------------------------------------------------------
# Next file in the archived patch (diff separator kept verbatim as comments):
# http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_sqoop_hook.py
# ----------------------------------------------------------------------
# diff --git a/tests/contrib/hooks/test_sqoop_hook.py b/tests/contrib/hooks/test_sqoop_hook.py
# new file mode 100644
# index 0000000..ca8033b
# --- /dev/null
# +++ b/tests/contrib/hooks/test_sqoop_hook.py
# @@ -0,0 +1,218 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
import unittest

from airflow import configuration, models
from airflow.contrib.hooks.sqoop_hook import SqoopHook
from airflow.utils import db


class TestSqoopHook(unittest.TestCase):
    """Tests for SqoopHook: command preparation plus the export/import
    command builders.
    """

    # Keyword arguments for the SqoopHook constructor.
    _config = {
        'conn_id': 'sqoop_test',
        'num_mappers': 22,
        'verbose': True,
        'properties': {
            'mapred.map.max.attempts': '1'
        }
    }
    # Arguments for export_table() / _export_cmd().
    _config_export = {
        'table': 'domino.export_data_to',
        'export_dir': '/hdfs/data/to/be/exported',
        'input_null_string': '\n',
        'input_null_non_string': '\t',
        'staging_table': 'database.staging',
        'clear_staging_table': True,
        'enclosed_by': '"',
        'escaped_by': '\\',
        'input_fields_terminated_by': '|',
        'input_lines_terminated_by': '\n',
        'input_optionally_enclosed_by': '"',
        'batch': True,
        'relaxed_isolation': True
    }
    # Arguments for _import_cmd().
    _config_import = {
        'target_dir': '/hdfs/data/target/location',
        'append': True,
        'file_type': 'parquet',
        'split_by': '\n',
        'direct': True,
        'driver': 'com.microsoft.jdbc.sqlserver.SQLServerDriver'
    }

    # Stored as the JSON "extra" of the test connection (see setUp).
    _config_json = {
        'namenode': 'http://0.0.0.0:50070/',
        'job_tracker': 'http://0.0.0.0:50030/',
        'libjars': '/path/to/jars',
        'files': '/path/to/files',
        'archives': '/path/to/archives'
    }

    def setUp(self):
        """Load the test configuration and register the sqoop_test connection."""
        configuration.load_test_config()
        db.merge_conn(
            models.Connection(
                conn_id='sqoop_test', conn_type='sqoop',
                host='rmdbs', port=5050, extra=json.dumps(self._config_json)
            )
        )

    def test_popen(self):
        """Popen must run a valid command and raise OSError for a bad one."""
        hook = SqoopHook(**self._config)

        # Should go well
        hook.Popen(['ls'])

        # Should give an exception
        with self.assertRaises(OSError):
            hook.Popen('exit 1')

    def test_submit(self):
        """Connection extras and constructor options must reach the command."""
        hook = SqoopHook(**self._config)

        cmd = ' '.join(hook._prepare_command())

        # Check if the config has been extracted from the json.
        # Command-line flag -> key in the connection's JSON extra.
        extra_flags = (
            ('-fs', 'namenode'),
            ('-jt', 'job_tracker'),
            ('-libjars', 'libjars'),
            ('-files', 'files'),
            ('-archives', 'archives'),
        )
        for flag, key in extra_flags:
            value = self._config_json[key]
            if value:
                self.assertIn('{} {}'.format(flag, value), cmd)

        # Check the regular stuff passed by the default constructor
        if self._config['verbose']:
            self.assertIn('--verbose', cmd)

        if self._config['num_mappers']:
            self.assertIn(
                '--num-mappers {}'.format(self._config['num_mappers']), cmd)

        for key, value in self._config['properties'].items():
            self.assertIn('-D {}={}'.format(key, value), cmd)

        # We don't have the sqoop binary available, and this is hard to mock,
        # so just accept an exception for now.
        with self.assertRaises(OSError):
            hook.export_table(**self._config_export)

        with self.assertRaises(OSError):
            hook.import_table(table='schema.table',
                              target_dir='/sqoop/example/path')

        with self.assertRaises(OSError):
            hook.import_query(query='SELECT * FROM sometable',
                              target_dir='/sqoop/example/path')

    def test_export_cmd(self):
        """All export options must be rendered into the export command."""
        hook = SqoopHook()

        # table and export_dir go in positionally; every other entry of
        # _config_export is passed by keyword.
        options = dict(self._config_export)
        table = options.pop('table')
        export_dir = options.pop('export_dir')

        # The subprocess requires an array but we build the cmd by joining on
        # a space.
        cmd = ' '.join(hook._export_cmd(table, export_dir, **options))

        # Options rendered as "--flag value".
        valued_options = (
            ('--input-null-string', 'input_null_string'),
            ('--input-null-non-string', 'input_null_non_string'),
            ('--staging-table', 'staging_table'),
            ('--enclosed-by', 'enclosed_by'),
            ('--escaped-by', 'escaped_by'),
            ('--input-fields-terminated-by', 'input_fields_terminated_by'),
            ('--input-lines-terminated-by', 'input_lines_terminated_by'),
            ('--input-optionally-enclosed-by',
             'input_optionally_enclosed_by'),
        )
        for flag, key in valued_options:
            self.assertIn(
                '{} {}'.format(flag, self._config_export[key]), cmd)

        # Boolean switches rendered as a bare flag when enabled.
        switches = (
            ('--clear-staging-table', 'clear_staging_table'),
            ('--batch', 'batch'),
            ('--relaxed-isolation', 'relaxed_isolation'),
        )
        for flag, key in switches:
            if self._config_export[key]:
                self.assertIn(flag, cmd)

    def test_import_cmd(self):
        """All import options must be rendered into the import command."""
        hook = SqoopHook()

        # target_dir goes in positionally; the rest is passed by keyword.
        options = dict(self._config_import)
        target_dir = options.pop('target_dir')

        # The subprocess requires an array but we build the cmd by joining on
        # a space.
        cmd = ' '.join(hook._import_cmd(target_dir, **options))

        if self._config_import['append']:
            self.assertIn('--append', cmd)

        if self._config_import['direct']:
            self.assertIn('--direct', cmd)

        self.assertIn(
            '--target-dir {}'.format(self._config_import['target_dir']), cmd)
        self.assertIn(
            '--driver {}'.format(self._config_import['driver']), cmd)
        self.assertIn(
            '--split-by {}'.format(self._config_import['split_by']), cmd)

    def test_get_export_format_argument(self):
        """Each file type maps to its flag; unknown types map to text."""
        hook = SqoopHook()
        expected = (
            ('avro', '--as-avrodatafile'),
            ('parquet', '--as-parquetfile'),
            ('sequence', '--as-sequencefile'),
            ('text', '--as-textfile'),
            ('unknown', '--as-textfile'),
        )
        for file_type, flag in expected:
            self.assertIn(flag, hook._get_export_format_argument(file_type))


if __name__ == '__main__':
    unittest.main()

# ----------------------------------------------------------------------
# Next file in the archived patch (diff separator kept verbatim as comments):
# http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_zendesk_hook.py
# ----------------------------------------------------------------------
# diff --git a/tests/contrib/hooks/test_zendesk_hook.py b/tests/contrib/hooks/test_zendesk_hook.py
# new file mode 100644
# index 0000000..7751a2b
# --- /dev/null
# +++ b/tests/contrib/hooks/test_zendesk_hook.py
# @@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

import mock

from airflow.hooks.zendesk_hook import ZendeskHook
from zdesk import RateLimitError


class TestZendeskHook(unittest.TestCase):
    """Tests for ZendeskHook: rate-limit handling, paging and client setup."""

    def _hook_with_paged_call(self):
        """Build a hook whose connection always reports a next page.

        Must be called while ``airflow.hooks.zendesk_hook.Zendesk`` is
        patched, because it invokes the real ``get_conn()`` once before
        replacing it with a mock.

        :return: tuple of (hook, mock used as the connection's ``call``)
        """
        hook = ZendeskHook("conn_id")
        connection = mock.Mock()
        connection.host = "some_host"
        hook.get_connection = mock.Mock(return_value=connection)
        hook.get_conn()

        call_mock = mock.Mock(
            return_value={'next_page': 'https://some_host/something',
                          'path': []})
        conn = mock.Mock()
        conn.call = call_mock
        hook.get_conn = mock.Mock(return_value=conn)
        return hook, call_mock

    @mock.patch("airflow.hooks.zendesk_hook.time")
    def test_sleeps_for_correct_interval(self, mocked_time):
        """On a rate-limit error the hook must sleep for the advertised time."""
        sleep_time = 10
        # To break out of the otherwise infinite tries
        mocked_time.sleep = mock.Mock(side_effect=ValueError, return_value=3)
        conn_mock = mock.Mock()
        mock_response = mock.Mock()
        mock_response.headers.get.return_value = sleep_time
        conn_mock.call = mock.Mock(
            side_effect=RateLimitError(msg="some message", code="some code",
                                       response=mock_response))

        zendesk_hook = ZendeskHook("conn_id")
        zendesk_hook.get_conn = mock.Mock(return_value=conn_mock)

        with self.assertRaises(ValueError):
            zendesk_hook.call("some_path", get_all_pages=False)
        # This check has to sit after the ``with`` block: a statement that
        # follows the raising call inside the block would never execute.
        mocked_time.sleep.assert_called_with(sleep_time)

    @mock.patch("airflow.hooks.zendesk_hook.Zendesk")
    def test_returns_single_page_if_get_all_pages_false(self, _):
        """With get_all_pages=False the API must be called exactly once."""
        zendesk_hook, mock_call = self._hook_with_paged_call()
        zendesk_hook.call("path", get_all_pages=False)
        mock_call.assert_called_once_with("path", None)

    @mock.patch("airflow.hooks.zendesk_hook.Zendesk")
    def test_returns_multiple_pages_if_get_all_pages_true(self, _):
        """With get_all_pages=True the next page must be fetched as well."""
        zendesk_hook, mock_call = self._hook_with_paged_call()
        zendesk_hook.call("path", get_all_pages=True)
        self.assertEqual(mock_call.call_count, 2)

    @mock.patch("airflow.hooks.zendesk_hook.Zendesk")
    def test_zdesk_is_inited_correctly(self, mock_zendesk):
        """get_conn() must build the client from the connection's fields."""
        conn_mock = mock.Mock()
        conn_mock.host = "conn_host"
        conn_mock.login = "conn_login"
        conn_mock.password = "conn_pass"

        zendesk_hook = ZendeskHook("conn_id")
        zendesk_hook.get_connection = mock.Mock(return_value=conn_mock)
        zendesk_hook.get_conn()
        mock_zendesk.assert_called_with('https://conn_host', 'conn_login',
                                        'conn_pass', True)
