Repository: incubator-airflow
Updated Branches:
  refs/heads/v1-9-test 9e209bf30 -> bfddae724


[AIRFLOW-1695] Add RedshiftHook using boto3

Adds RedshiftHook class, allowing for management of AWS Redshift
clusters and snapshots using the boto3 library. Also adds a new test
file and unit tests for class methods.

Closes #2717 from andyxhadji/1695

(cherry picked from commit 4fb7a90b36ec1daf169a65aa4adf28a31b30fbc5)
Signed-off-by: Bolke de Bruin <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/bfddae72
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/bfddae72
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/bfddae72

Branch: refs/heads/v1-9-test
Commit: bfddae72411f7f7be67bc18d2b867ddc093d34f1
Parents: 9e209bf
Author: Andy Hadjigeorgiou <[email protected]>
Authored: Mon Oct 30 20:36:18 2017 +0100
Committer: Bolke de Bruin <[email protected]>
Committed: Mon Oct 30 20:36:31 2017 +0100

----------------------------------------------------------------------
 airflow/contrib/hooks/__init__.py         |   1 +
 airflow/contrib/hooks/aws_hook.py         |  16 ++--
 airflow/contrib/hooks/redshift_hook.py    | 100 +++++++++++++++++++++++++
 airflow/hooks/__init__.py                 |   1 -
 tests/contrib/hooks/test_redshift_hook.py |  77 +++++++++++++++++++
 5 files changed, 186 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/bfddae72/airflow/contrib/hooks/__init__.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/__init__.py 
b/airflow/contrib/hooks/__init__.py
index 2891980..6d45ace 100644
--- a/airflow/contrib/hooks/__init__.py
+++ b/airflow/contrib/hooks/__init__.py
@@ -41,6 +41,7 @@ _hooks = {
     'gcs_hook': ['GoogleCloudStorageHook'],
     'datastore_hook': ['DatastoreHook'],
     'gcp_cloudml_hook': ['CloudMLHook'],
+    'redshift_hook': ['RedshiftHook'],
     'gcp_dataproc_hook': ['DataProcHook'],
     'gcp_dataflow_hook': ['DataFlowHook'],
     'spark_submit_operator': ['SparkSubmitOperator'],

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/bfddae72/airflow/contrib/hooks/aws_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/aws_hook.py 
b/airflow/contrib/hooks/aws_hook.py
index ca2ee05..8573db3 100644
--- a/airflow/contrib/hooks/aws_hook.py
+++ b/airflow/contrib/hooks/aws_hook.py
@@ -85,28 +85,28 @@ class AwsHook(BaseHook):
         aws_access_key_id = None
         aws_secret_access_key = None
         s3_endpoint_url = None
-        
+
         if self.aws_conn_id:
             try:
                 connection_object = self.get_connection(self.aws_conn_id)
                 if connection_object.login:
                     aws_access_key_id = connection_object.login
                     aws_secret_access_key = connection_object.password
-    
+
                 elif 'aws_secret_access_key' in connection_object.extra_dejson:
                     aws_access_key_id = 
connection_object.extra_dejson['aws_access_key_id']
                     aws_secret_access_key = 
connection_object.extra_dejson['aws_secret_access_key']
-    
+
                 elif 's3_config_file' in connection_object.extra_dejson:
                     aws_access_key_id, aws_secret_access_key = \
                         
_parse_s3_config(connection_object.extra_dejson['s3_config_file'],
                                          
connection_object.extra_dejson.get('s3_config_format'))
-    
+
                 if region_name is None:
                     region_name = 
connection_object.extra_dejson.get('region_name')
-    
-                s3_endpoint_url = connection_object.extra_dejson.get('host') 
-    
+
+                s3_endpoint_url = connection_object.extra_dejson.get('host')
+
             except AirflowException:
                 # No connection found: fallback on boto3 credential strategy
                 # 
http://boto3.readthedocs.io/en/latest/guide/configuration.html
@@ -129,7 +129,7 @@ class AwsHook(BaseHook):
     def get_resource_type(self, resource_type, region_name=None):
         aws_access_key_id, aws_secret_access_key, region_name, endpoint_url = \
             self._get_credentials(region_name)
-        
+
         return boto3.resource(
             resource_type,
             region_name=region_name,

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/bfddae72/airflow/contrib/hooks/redshift_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/redshift_hook.py 
b/airflow/contrib/hooks/redshift_hook.py
new file mode 100644
index 0000000..071caf2
--- /dev/null
+++ b/airflow/contrib/hooks/redshift_hook.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from airflow.contrib.hooks.aws_hook import AwsHook
+
+class RedshiftHook(AwsHook):
+    """
+    Interact with AWS Redshift, using the boto3 library
+    """
+    def get_conn(self):
+        return self.get_client_type('redshift')
+
+    # TODO: Wrap create_cluster_snapshot
+    def cluster_status(self, cluster_identifier):
+        """
+        Return status of a cluster
+
+        :param cluster_identifier: unique identifier of a cluster whose 
properties you are requesting
+        :type cluster_identifier: str
+        """
+        # Use describe clusters
+        response = 
self.get_conn().describe_clusters(ClusterIdentifier=cluster_identifier)
+        # Possibly return error if cluster does not exist
+        return response['Clusters'][0]['ClusterStatus'] if 
response['Clusters'] else None
+
+    def delete_cluster(self, cluster_identifier, 
skip_final_cluster_snapshot=True, final_cluster_snapshot_identifier=''):
+        """
+        Delete a cluster and optionally create a snapshot
+
+        :param cluster_identifier: unique identifier of a cluster whose 
properties you are requesting
+        :type cluster_identifier: str
+        :param skip_final_cluster_snapshot: determines if a final cluster 
snapshot is made before shut-down
+        :type skip_final_cluster_snapshot: bool
+        :param final_cluster_snapshot_identifier: name of final cluster 
snapshot
+        :type final_cluster_snapshot_identifier: str
+        """
+        response = self.get_conn().delete_cluster(
+            ClusterIdentifier = cluster_identifier,
+            SkipFinalClusterSnapshot = skip_final_cluster_snapshot,
+            FinalClusterSnapshotIdentifier = final_cluster_snapshot_identifier
+        )
+        return response['Cluster'] if response['Cluster'] else None
+
+    def describe_cluster_snapshots(self, cluster_identifier):
+        """
+        Gets a list of snapshots for a cluster
+
+        :param cluster_identifier: unique identifier of a cluster whose 
properties you are requesting
+        :type cluster_identifier: str
+        """
+        response = self.get_conn().describe_cluster_snapshots(
+            ClusterIdentifier = cluster_identifier
+        )
+        if 'Snapshots' not in response:
+            return None
+        snapshots = response['Snapshots']
+        snapshots = filter(lambda x: x['Status'], snapshots)
+        snapshots.sort(key=lambda x: x['SnapshotCreateTime'], reverse=True)
+        return snapshots
+
+    def restore_from_cluster_snapshot(self, cluster_identifier, 
snapshot_identifier):
+        """
+        Restores a cluster from its snapshot
+
+        :param cluster_identifier: unique identifier of a cluster whose 
properties you are requesting
+        :type cluster_identifier: str
+        :param snapshot_identifier: unique identifier for a snapshot of a 
cluster
+        :type snapshot_identifier: str
+        """
+        response = self.get_conn().restore_from_cluster_snapshot(
+            ClusterIdentifier = cluster_identifier,
+            SnapshotIdentifier = snapshot_identifier
+        )
+        return response['Cluster'] if response['Cluster'] else None
+
+    def create_cluster_snapshot(self, snapshot_identifier, cluster_identifier):
+        """
+        Creates a snapshot of a cluster
+
+        :param snapshot_identifier: unique identifier for a snapshot of a 
cluster
+        :type snapshot_identifier: str
+        :param cluster_identifier: unique identifier of a cluster whose 
properties you are requesting
+        :type cluster_identifier: str
+        """
+        response = self.get_conn().create_cluster_snapshot(
+            SnapshotIdentifier=snapshot_identifier,
+            ClusterIdentifier=cluster_identifier,
+        )
+        return response['Snapshot'] if response['Snapshot'] else None

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/bfddae72/airflow/hooks/__init__.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/__init__.py b/airflow/hooks/__init__.py
index bb02967..6e96e2a 100644
--- a/airflow/hooks/__init__.py
+++ b/airflow/hooks/__init__.py
@@ -85,4 +85,3 @@ def _integrate_plugins():
                     "import from 'airflow.hooks.[plugin_module]' "
                     "instead. Support for direct imports will be dropped "
                     "entirely in Airflow 2.0.".format(i=hook_name))
-

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/bfddae72/tests/contrib/hooks/test_redshift_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_redshift_hook.py 
b/tests/contrib/hooks/test_redshift_hook.py
new file mode 100644
index 0000000..185be5e
--- /dev/null
+++ b/tests/contrib/hooks/test_redshift_hook.py
@@ -0,0 +1,77 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import boto3
+
+from airflow import configuration
+from airflow.contrib.hooks.redshift_hook import RedshiftHook
+from airflow.contrib.hooks.aws_hook import AwsHook
+
+try:
+    from moto import mock_redshift
+except ImportError:
+    mock_redshift = None
+
+@mock_redshift
+class TestRedshiftHook(unittest.TestCase):
+    def setUp(self):
+        configuration.load_test_config()
+        client = boto3.client('redshift', region_name='us-east-1')
+        client.create_cluster(
+            ClusterIdentifier='test_cluster',
+            NodeType='dc1.large',
+            MasterUsername='admin',
+            MasterUserPassword='mock_password'
+        )
+        client.create_cluster(
+            ClusterIdentifier='test_cluster_2',
+            NodeType='dc1.large',
+            MasterUsername='admin',
+            MasterUserPassword='mock_password'
+        )
+        if len(client.describe_clusters()['Clusters']) == 0:
+            raise ValueError('AWS not properly mocked')
+
+    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not 
present')
+    def 
test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
+        hook = AwsHook(aws_conn_id='aws_default')
+        client_from_hook = hook.get_client_type('redshift')
+
+        clusters = client_from_hook.describe_clusters()['Clusters']
+        self.assertEqual(len(clusters), 2)
+
+    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not 
present')
+    def 
test_restore_from_cluster_snapshot_returns_dict_with_cluster_data(self):
+        hook = RedshiftHook(aws_conn_id='aws_default')
+        snapshot = hook.create_cluster_snapshot('test_snapshot', 
'test_cluster')
+        self.assertEqual(hook.restore_from_cluster_snapshot('test_cluster_3', 
'test_snapshot')['ClusterIdentifier'], 'test_cluster_3')
+
+    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not 
present')
+    def test_delete_cluster_returns_a_dict_with_cluster_data(self):
+        hook = RedshiftHook(aws_conn_id='aws_default')
+
+        cluster = hook.delete_cluster('test_cluster_2')
+        self.assertNotEqual(cluster, None)
+
+    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not 
present')
+    def test_create_cluster_snapshot_returns_snapshot_data(self):
+        hook = RedshiftHook(aws_conn_id='aws_default')
+
+        snapshot = hook.create_cluster_snapshot('test_snapshot_2', 
'test_cluster')
+        self.assertNotEqual(snapshot, None)
+
+if __name__ == '__main__':
+    unittest.main()

Reply via email to