Repository: incubator-airflow Updated Branches: refs/heads/v1-9-test 9e209bf30 -> bfddae724
[AIRFLOW-1695] Add RedshiftHook using boto3 Adds RedshiftHook class, allowing for management of AWS Redshift clusters and snapshots using boto3 library. Also adds new test file and unit tests for class methods. Closes #2717 from andyxhadji/1695 (cherry picked from commit 4fb7a90b36ec1daf169a65aa4adf28a31b30fbc5) Signed-off-by: Bolke de Bruin <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/bfddae72 Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/bfddae72 Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/bfddae72 Branch: refs/heads/v1-9-test Commit: bfddae72411f7f7be67bc18d2b867ddc093d34f1 Parents: 9e209bf Author: Andy Hadjigeorgiou <[email protected]> Authored: Mon Oct 30 20:36:18 2017 +0100 Committer: Bolke de Bruin <[email protected]> Committed: Mon Oct 30 20:36:31 2017 +0100 ---------------------------------------------------------------------- airflow/contrib/hooks/__init__.py | 1 + airflow/contrib/hooks/aws_hook.py | 16 ++-- airflow/contrib/hooks/redshift_hook.py | 100 +++++++++++++++++++++++++ airflow/hooks/__init__.py | 1 - tests/contrib/hooks/test_redshift_hook.py | 77 +++++++++++++++++++ 5 files changed, 186 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/bfddae72/airflow/contrib/hooks/__init__.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/hooks/__init__.py b/airflow/contrib/hooks/__init__.py index 2891980..6d45ace 100644 --- a/airflow/contrib/hooks/__init__.py +++ b/airflow/contrib/hooks/__init__.py @@ -41,6 +41,7 @@ _hooks = { 'gcs_hook': ['GoogleCloudStorageHook'], 'datastore_hook': ['DatastoreHook'], 'gcp_cloudml_hook': ['CloudMLHook'], + 'redshift_hook': ['RedshiftHook'], 'gcp_dataproc_hook': ['DataProcHook'], 'gcp_dataflow_hook': 
['DataFlowHook'], 'spark_submit_operator': ['SparkSubmitOperator'], http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/bfddae72/airflow/contrib/hooks/aws_hook.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/hooks/aws_hook.py b/airflow/contrib/hooks/aws_hook.py index ca2ee05..8573db3 100644 --- a/airflow/contrib/hooks/aws_hook.py +++ b/airflow/contrib/hooks/aws_hook.py @@ -85,28 +85,28 @@ class AwsHook(BaseHook): aws_access_key_id = None aws_secret_access_key = None s3_endpoint_url = None - + if self.aws_conn_id: try: connection_object = self.get_connection(self.aws_conn_id) if connection_object.login: aws_access_key_id = connection_object.login aws_secret_access_key = connection_object.password - + elif 'aws_secret_access_key' in connection_object.extra_dejson: aws_access_key_id = connection_object.extra_dejson['aws_access_key_id'] aws_secret_access_key = connection_object.extra_dejson['aws_secret_access_key'] - + elif 's3_config_file' in connection_object.extra_dejson: aws_access_key_id, aws_secret_access_key = \ _parse_s3_config(connection_object.extra_dejson['s3_config_file'], connection_object.extra_dejson.get('s3_config_format')) - + if region_name is None: region_name = connection_object.extra_dejson.get('region_name') - - s3_endpoint_url = connection_object.extra_dejson.get('host') - + + s3_endpoint_url = connection_object.extra_dejson.get('host') + except AirflowException: # No connection found: fallback on boto3 credential strategy # http://boto3.readthedocs.io/en/latest/guide/configuration.html @@ -129,7 +129,7 @@ class AwsHook(BaseHook): def get_resource_type(self, resource_type, region_name=None): aws_access_key_id, aws_secret_access_key, region_name, endpoint_url = \ self._get_credentials(region_name) - + return boto3.resource( resource_type, region_name=region_name, http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/bfddae72/airflow/contrib/hooks/redshift_hook.py 
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# File: airflow/contrib/hooks/redshift_hook.py
# (new file mode 100644, index 0000000..071caf2)

from airflow.contrib.hooks.aws_hook import AwsHook


class RedshiftHook(AwsHook):
    """
    Interact with AWS Redshift, using the boto3 library
    """

    def get_conn(self):
        """
        Return a boto3 client for the ``redshift`` service.
        """
        return self.get_client_type('redshift')

    def cluster_status(self, cluster_identifier):
        """
        Return status of a cluster

        :param cluster_identifier: unique identifier of the cluster whose
            status is requested
        :type cluster_identifier: str
        :return: the ``ClusterStatus`` value of the first cluster in the
            response, or ``None`` when the response lists no clusters
        """
        response = self.get_conn().describe_clusters(
            ClusterIdentifier=cluster_identifier
        )
        # NOTE(review): for an unknown identifier boto3 may raise instead of
        # returning an empty list -- confirm against the boto3 documentation.
        clusters = response.get('Clusters')
        return clusters[0]['ClusterStatus'] if clusters else None

    def delete_cluster(self, cluster_identifier,
                       skip_final_cluster_snapshot=True,
                       final_cluster_snapshot_identifier=''):
        """
        Delete a cluster and optionally create a snapshot

        :param cluster_identifier: unique identifier of the cluster to delete
        :type cluster_identifier: str
        :param skip_final_cluster_snapshot: determines if a final cluster
            snapshot is made before shut-down
        :type skip_final_cluster_snapshot: bool
        :param final_cluster_snapshot_identifier: name of final cluster
            snapshot
        :type final_cluster_snapshot_identifier: str
        :return: the ``Cluster`` entry of the response, or ``None`` when it
            is missing or empty
        """
        response = self.get_conn().delete_cluster(
            ClusterIdentifier=cluster_identifier,
            SkipFinalClusterSnapshot=skip_final_cluster_snapshot,
            FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier
        )
        return response.get('Cluster') or None

    def describe_cluster_snapshots(self, cluster_identifier):
        """
        Gets a list of snapshots for a cluster

        :param cluster_identifier: unique identifier of the cluster whose
            snapshots are requested
        :type cluster_identifier: str
        :return: list of snapshots that have a truthy ``Status``, ordered by
            ``SnapshotCreateTime`` with the most recent first, or ``None``
            when the response has no ``Snapshots`` key
        """
        response = self.get_conn().describe_cluster_snapshots(
            ClusterIdentifier=cluster_identifier
        )
        if 'Snapshots' not in response:
            return None
        # Build a real list: on Python 3 ``filter`` returns an iterator,
        # which has no ``sort`` method.
        snapshots = [
            snapshot for snapshot in response['Snapshots']
            if snapshot['Status']
        ]
        snapshots.sort(
            key=lambda snapshot: snapshot['SnapshotCreateTime'],
            reverse=True
        )
        return snapshots

    def restore_from_cluster_snapshot(self, cluster_identifier,
                                      snapshot_identifier):
        """
        Restores a cluster from its snapshot

        :param cluster_identifier: unique identifier passed as
            ``ClusterIdentifier`` for the restored cluster
        :type cluster_identifier: str
        :param snapshot_identifier: unique identifier for a snapshot of a
            cluster
        :type snapshot_identifier: str
        :return: the ``Cluster`` entry of the response, or ``None`` when it
            is missing or empty
        """
        response = self.get_conn().restore_from_cluster_snapshot(
            ClusterIdentifier=cluster_identifier,
            SnapshotIdentifier=snapshot_identifier
        )
        return response.get('Cluster') or None

    def create_cluster_snapshot(self, snapshot_identifier,
                                cluster_identifier):
        """
        Creates a snapshot of a cluster

        :param snapshot_identifier: unique identifier for a snapshot of a
            cluster
        :type snapshot_identifier: str
        :param cluster_identifier: unique identifier of the cluster to
            snapshot
        :type cluster_identifier: str
        :return: the ``Snapshot`` entry of the response, or ``None`` when it
            is missing or empty
        """
        response = self.get_conn().create_cluster_snapshot(
            SnapshotIdentifier=snapshot_identifier,
            ClusterIdentifier=cluster_identifier,
        )
        return response.get('Snapshot') or None


# ----------------------------------------------------------------------
# Remaining diff metadata from the original message, kept for reference:
#
# airflow/hooks/__init__.py (index bb02967..6e96e2a, hunk @@ -85,4 +85,3 @@
# in _integrate_plugins): the trailing blank line at the end of the file
# is removed.
#
# tests/contrib/hooks/test_redshift_hook.py
# (new file mode 100644, index 0000000..185be5e), starting with the same
# Apache License 2.0 header as above.
# ----------------------------------------------------------------------
# File: tests/contrib/hooks/test_redshift_hook.py

import unittest
import boto3

from airflow import configuration
from airflow.contrib.hooks.redshift_hook import RedshiftHook
from airflow.contrib.hooks.aws_hook import AwsHook

try:
    from moto import mock_redshift
except ImportError:
    mock_redshift = None


def _mock_redshift_if_available(test_case):
    """
    Wrap ``test_case`` with ``mock_redshift`` when moto could be imported.

    Applying ``@mock_redshift`` directly would call ``None`` when moto is
    missing and fail at import time, before any skip decorator is evaluated.
    """
    if mock_redshift is None:
        return test_case
    return mock_redshift(test_case)


@unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
@_mock_redshift_if_available
class TestRedshiftHook(unittest.TestCase):
    """
    Tests for RedshiftHook, run against a moto-mocked Redshift service.
    """

    def setUp(self):
        """
        Load the test configuration and create two mocked clusters.
        """
        configuration.load_test_config()
        client = boto3.client('redshift', region_name='us-east-1')
        client.create_cluster(
            ClusterIdentifier='test_cluster',
            NodeType='dc1.large',
            MasterUsername='admin',
            MasterUserPassword='mock_password'
        )
        client.create_cluster(
            ClusterIdentifier='test_cluster_2',
            NodeType='dc1.large',
            MasterUsername='admin',
            MasterUserPassword='mock_password'
        )
        # Sanity check that the clusters just created are visible.
        if not client.describe_clusters()['Clusters']:
            raise ValueError('AWS not properly mocked')

    def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
        hook = AwsHook(aws_conn_id='aws_default')
        client_from_hook = hook.get_client_type('redshift')

        clusters = client_from_hook.describe_clusters()['Clusters']
        self.assertEqual(len(clusters), 2)

    def test_restore_from_cluster_snapshot_returns_dict_with_cluster_data(self):
        hook = RedshiftHook(aws_conn_id='aws_default')
        # The snapshot must exist before a cluster can be restored from it.
        hook.create_cluster_snapshot('test_snapshot', 'test_cluster')

        restored = hook.restore_from_cluster_snapshot(
            'test_cluster_3', 'test_snapshot'
        )
        self.assertEqual(restored['ClusterIdentifier'], 'test_cluster_3')

    def test_delete_cluster_returns_a_dict_with_cluster_data(self):
        hook = RedshiftHook(aws_conn_id='aws_default')

        cluster = hook.delete_cluster('test_cluster_2')
        self.assertIsNotNone(cluster)

    def test_create_cluster_snapshot_returns_snapshot_data(self):
        hook = RedshiftHook(aws_conn_id='aws_default')

        snapshot = hook.create_cluster_snapshot(
            'test_snapshot_2', 'test_cluster'
        )
        self.assertIsNotNone(snapshot)


if __name__ == '__main__':
    unittest.main()
