Repository: incubator-airflow
Updated Branches:
  refs/heads/master caaa4a516 -> 7255589f9


[AIRFLOW-2562] Add Google Kubernetes Engine Operators

Add Google Kubernetes Engine create_cluster and
delete_cluster operators.
This allows users to use Airflow to create or
delete clusters in the
Google Cloud Platform.

Closes #3477 from Noremac201/gke_create


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/7255589f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/7255589f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/7255589f

Branch: refs/heads/master
Commit: 7255589f95aa516aa1c4d90d85cba7b7277792c6
Parents: caaa4a5
Author: Cameron Moberg <[email protected]>
Authored: Fri Jun 15 20:44:29 2018 +0100
Committer: Kaxil Naik <[email protected]>
Committed: Fri Jun 15 20:44:29 2018 +0100

----------------------------------------------------------------------
 airflow/contrib/hooks/gcp_container_hook.py     | 217 +++++++++++++++++
 .../contrib/operators/gcp_container_operator.py | 172 +++++++++++++
 docs/code.rst                                   |   3 +
 docs/integration.rst                            |  27 +++
 setup.py                                        |   1 +
 tests/contrib/hooks/test_gcp_container_hook.py  | 240 +++++++++++++++++++
 .../operators/test_gcp_container_operator.py    | 125 ++++++++++
 7 files changed, 785 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7255589f/airflow/contrib/hooks/gcp_container_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/gcp_container_hook.py 
b/airflow/contrib/hooks/gcp_container_hook.py
new file mode 100644
index 0000000..d36d796
--- /dev/null
+++ b/airflow/contrib/hooks/gcp_container_hook.py
@@ -0,0 +1,217 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import json
+import time
+
+from airflow import AirflowException, version
+from airflow.hooks.base_hook import BaseHook
+
+from google.api_core.exceptions import AlreadyExists
+from google.api_core.gapic_v1.method import DEFAULT
+from google.cloud import container_v1, exceptions
+from google.cloud.container_v1.gapic.enums import Operation
+from google.cloud.container_v1.types import Cluster
+from google.protobuf import json_format
+from google.api_core.gapic_v1.client_info import ClientInfo
+
+OPERATIONAL_POLL_INTERVAL = 15
+
+
+class GKEClusterHook(BaseHook):
+
+    def __init__(self, project_id, location):
+        """
+        Stores the target project and location and builds the cluster manager client
+        :param project_id: The project id passed to every client call
+        :type project_id: str
+        :param location: The location passed as ``zone`` to every client call
+        :type location: str
+        """
+        self.project_id, self.location = project_id, location
+
+        # Tag requests with the running airflow version for better error tracking
+        library_version = 'airflow_v' + version.version
+        self.client = container_v1.ClusterManagerClient(
+            client_info=ClientInfo(client_library_version=library_version))
+
+    def _dict_to_proto(self, py_dict, proto):
+        """
+        Merges a python dictionary into the supplied proto message
+
+        The dictionary is serialised to JSON and then parsed into ``proto``.
+        :param py_dict: The dictionary to convert
+        :type py_dict: dict
+        :param proto: The proto object to merge with dictionary
+        :type proto: protobuf
+        :return: The value returned by ``json_format.Parse`` for ``proto``
+        :raises:
+            ParseError: On JSON parsing problems.
+        """
+        return json_format.Parse(json.dumps(py_dict), proto)
+
+    def wait_for_operation(self, operation):
+        """
+        Polls Google Cloud for the status of an operation until it is done
+
+        Sleeps ``OPERATIONAL_POLL_INTERVAL`` seconds once before the first status
+        check and once more before every re-fetch of a running or pending operation.
+        :param operation: The operation to wait for
+        :type operation: An object exposing ``name`` and ``status``, as returned by
+            the cluster manager client
+        :return: The operation (re-fetched if it was not already done) with
+            status DONE
+        :raises:
+            GoogleCloudError: If the status is not DONE, RUNNING or PENDING
+        """
+        in_progress = (Operation.Status.RUNNING, Operation.Status.PENDING)
+
+        self.log.info("Waiting for OPERATION_NAME %s" % operation.name)
+        time.sleep(OPERATIONAL_POLL_INTERVAL)
+        while operation.status != Operation.Status.DONE:
+            if operation.status not in in_progress:
+                raise exceptions.GoogleCloudError(
+                    "Operation has failed with status: %s" % operation.status)
+            time.sleep(OPERATIONAL_POLL_INTERVAL)
+            # To update status of operation
+            operation = self.get_operation(operation.name)
+        return operation
+
+    def get_operation(self, operation_name):
+        """
+        Looks up an operation in the hook's project and location
+        :param operation_name: Name of operation to fetch
+        :type operation_name: str
+        :return: The operation returned by the client's ``get_operation`` call
+        """
+        lookup = dict(project_id=self.project_id,
+                      zone=self.location,
+                      operation_id=operation_name)
+        return self.client.get_operation(**lookup)
+
+    def _append_label(self, cluster_proto, key, val):
+        """
+        Adds a single resource label to the provided Cluster Protobuf
+
+        Labels must fit the regex [a-z]([-a-z0-9]*[a-z0-9])? so every '.' and '+'
+        in the value is replaced with '-' (the current airflow version string
+        follows the semantic versioning spec: x.y.z).
+        :param cluster_proto: The proto whose resource_labels are updated
+        :type cluster_proto: google.cloud.container_v1.types.Cluster
+        :param key: The key label
+        :type key: str
+        :param val: The label value, sanitised before it is stored
+        :type val: str
+        :return: The cluster proto updated with new label
+        """
+        for forbidden in ('.', '+'):
+            val = val.replace(forbidden, '-')
+        cluster_proto.resource_labels.update({key: val})
+        return cluster_proto
+
+    def delete_cluster(self, name, retry=DEFAULT, timeout=DEFAULT):
+        """
+        Deletes the cluster, including the Kubernetes endpoint and all
+        worker nodes, and waits for the delete operation to finish.
+        Firewalls and routes that were configured during cluster creation
+        are also deleted. Other Google Compute Engine resources that might
+        be in use by the cluster (e.g. load balancer resources) will not be
+        deleted if they weren't present at the initial create time.
+
+        :param name: The name of the cluster to delete
+        :type name: str
+        :param retry: Retry object used to determine when/if to retry requests.
+            If None is specified, requests will not be retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to
+            complete. Note that if retry is specified, the timeout applies to each
+            individual attempt.
+        :type timeout: float
+        :return: The full url to the delete operation if successful, else None
+            (a NotFound error is logged and treated as success)
+        """
+
+        self.log.info("Deleting (project_id={}, zone={}, cluster_id={})".format(
+            self.project_id, self.location, name))
+
+        try:
+            delete_op = self.wait_for_operation(
+                self.client.delete_cluster(project_id=self.project_id,
+                                           zone=self.location,
+                                           cluster_id=name,
+                                           retry=retry,
+                                           timeout=timeout))
+        except exceptions.NotFound as error:
+            self.log.info('Assuming Success: ' + error.message)
+            return None
+
+        # Returns server-defined url for the resource
+        return delete_op.self_link
+
+    def create_cluster(self, cluster, retry=DEFAULT, timeout=DEFAULT):
+        """
+        Creates a cluster, consisting of the specified number and type of Google
+        Compute Engine instances, and waits for the create operation to finish.
+
+        The cluster is labelled with the running airflow version before the
+        request is sent.
+
+        :param cluster: A Cluster protobuf or dict. If dict is provided, it must be of
+            the same form as the protobuf message google.cloud.container_v1.types.Cluster
+        :type cluster: dict or google.cloud.container_v1.types.Cluster
+        :param retry: A retry object (google.api_core.retry.Retry) used to retry
+            requests. If None is specified, requests will not be retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to
+            complete. Note that if retry is specified, the timeout applies to each
+            individual attempt.
+        :type timeout: float
+        :return: The full url to the new, or existing, cluster
+        :raises:
+            ParseError: On JSON parsing problems when trying to convert dict
+            AirflowException: cluster is not dict type nor Cluster proto type
+        """
+
+        if isinstance(cluster, dict):
+            cluster_proto = Cluster()
+            cluster = self._dict_to_proto(py_dict=cluster, proto=cluster_proto)
+        elif not isinstance(cluster, Cluster):
+            raise AirflowException(
+                "cluster is not instance of Cluster proto or python dict")
+
+        self._append_label(cluster, 'airflow-version', 'v' + version.version)
+
+        self.log.info("Creating (project_id={}, zone={}, cluster_name={})".format(
+            self.project_id,
+            self.location,
+            cluster.name))
+        try:
+            op = self.client.create_cluster(project_id=self.project_id,
+                                            zone=self.location,
+                                            cluster=cluster,
+                                            retry=retry,
+                                            timeout=timeout)
+            op = self.wait_for_operation(op)
+
+            return op.target_link
+        except AlreadyExists as error:
+            self.log.info('Assuming Success: ' + error.message)
+            # get_cluster already returns the existing cluster's self_link, so
+            # it must not be dereferenced a second time here
+            return self.get_cluster(name=cluster.name)
+
+    def get_cluster(self, name, retry=DEFAULT, timeout=DEFAULT):
+        """
+        Fetches the specified cluster and returns its self link
+        :param name: The name of the cluster to retrieve
+        :type name: str
+        :param retry: A retry object used to retry requests. If None is specified,
+            requests will not be retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to
+            complete. Note that if retry is specified, the timeout applies to each
+            individual attempt.
+        :type timeout: float
+        :return: The ``self_link`` of the cluster returned by the client
+        """
+        self.log.info("Fetching cluster (project_id={}, zone={}, cluster_name={})".format(
+            self.project_id,
+            self.location,
+            name))
+
+        cluster = self.client.get_cluster(project_id=self.project_id,
+                                          zone=self.location,
+                                          cluster_id=name,
+                                          retry=retry,
+                                          timeout=timeout)
+        return cluster.self_link

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7255589f/airflow/contrib/operators/gcp_container_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/gcp_container_operator.py 
b/airflow/contrib/operators/gcp_container_operator.py
new file mode 100644
index 0000000..5648b4d
--- /dev/null
+++ b/airflow/contrib/operators/gcp_container_operator.py
@@ -0,0 +1,172 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from airflow import AirflowException
+from airflow.contrib.hooks.gcp_container_hook import GKEClusterHook
+from airflow.models import BaseOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class GKEClusterDeleteOperator(BaseOperator):
+    template_fields = ['project_id', 'gcp_conn_id', 'name', 'location', 'api_version']
+
+    @apply_defaults
+    def __init__(self,
+                 project_id,
+                 name,
+                 location,
+                 gcp_conn_id='google_cloud_default',
+                 api_version='v2',
+                 *args,
+                 **kwargs):
+        """
+        Deletes the cluster, including the Kubernetes endpoint and all worker nodes.
+
+        To delete a certain cluster, you must specify the ``project_id``, the
+        ``name`` of the cluster, the ``location`` that the cluster is in, and the
+        ``task_id``.
+
+        **Operator Creation**: ::
+
+            operator = GKEClusterDeleteOperator(
+                        task_id='cluster_delete',
+                        project_id='my-project',
+                        location='cluster-location',
+                        name='cluster-name')
+
+        .. seealso::
+            For more detail about deleting clusters have a look at the reference:
+            https://google-cloud-python.readthedocs.io/en/latest/container/gapic/v1/api.html#google.cloud.container_v1.ClusterManagerClient.delete_cluster
+
+        :param project_id: The Google Developers Console [project ID or project number]
+        :type project_id: str
+        :param name: The name of the resource to delete, in this case cluster name
+        :type name: str
+        :param location: The name of the Google Compute Engine zone in which the
+            cluster resides.
+        :type location: str
+        :param gcp_conn_id: The connection ID to use connecting to Google Cloud
+            Platform. Stored on the operator; execute does not pass it to the hook.
+        :type gcp_conn_id: str
+        :param api_version: The api version to use. Stored on the operator; execute
+            does not pass it to the hook.
+        :type api_version: str
+        """
+        super(GKEClusterDeleteOperator, self).__init__(*args, **kwargs)
+
+        self.project_id = project_id
+        self.name = name
+        self.location = location
+        self.gcp_conn_id = gcp_conn_id
+        self.api_version = api_version
+
+    def _check_input(self):
+        """
+        Raises AirflowException unless project_id, name and location are all truthy
+        """
+        required = [self.project_id, self.name, self.location]
+        if all(required):
+            return
+
+        self.log.error(
+            'One of (project_id, name, location) is missing or incorrect')
+        raise AirflowException('Operator has incorrect or missing input.')
+
+    def execute(self, context):
+        """
+        Validates the input, then deletes the cluster through GKEClusterHook
+        :return: Whatever the hook's delete_cluster returns
+        """
+        self._check_input()
+        gke_hook = GKEClusterHook(self.project_id, self.location)
+        return gke_hook.delete_cluster(name=self.name)
+
+
+class GKEClusterCreateOperator(BaseOperator):
+    template_fields = ['project_id', 'gcp_conn_id', 'location', 'api_version', 'body']
+
+    @apply_defaults
+    def __init__(self,
+                 project_id,
+                 location,
+                 body=None,
+                 gcp_conn_id='google_cloud_default',
+                 api_version='v2',
+                 *args,
+                 **kwargs):
+        """
+        Create a Google Kubernetes Engine Cluster of specified dimensions
+        The operator will wait until the cluster is created.
+
+        The **minimum** required to define a cluster to create is:
+
+        ``dict()`` ::
+
+            cluster_def = {'name': 'my-cluster-name',
+                           'initial_node_count': 1}
+
+        or
+
+        ``Cluster`` proto ::
+
+            from google.cloud.container_v1.types import Cluster
+
+            cluster_def = Cluster(name='my-cluster-name', initial_node_count=1)
+
+        **Operator Creation**: ::
+
+            operator = GKEClusterCreateOperator(
+                        task_id='cluster_create',
+                        project_id='my-project',
+                        location='my-location',
+                        body=cluster_def)
+
+        .. seealso::
+            For more detail about creating clusters have a look at the reference:
+            https://google-cloud-python.readthedocs.io/en/latest/container/gapic/v1/types.html#google.cloud.container_v1.types.Cluster
+
+        :param project_id: The Google Developers Console [project ID or project number]
+        :type project_id: str
+        :param location: The name of the Google Compute Engine zone in which the
+            cluster resides.
+        :type location: str
+        :param body: The Cluster definition to create, can be protobuf or python
+            dict, if dict it must match protobuf message Cluster. Defaults to an
+            empty dict, which fails validation in execute.
+        :type body: dict or google.cloud.container_v1.types.Cluster
+        :param gcp_conn_id: The connection ID to use connecting to Google Cloud
+            Platform. Stored on the operator; execute does not pass it to the hook.
+        :type gcp_conn_id: str
+        :param api_version: The api version to use. Stored on the operator; execute
+            does not pass it to the hook.
+        :type api_version: str
+        """
+        super(GKEClusterCreateOperator, self).__init__(*args, **kwargs)
+
+        self.project_id = project_id
+        self.gcp_conn_id = gcp_conn_id
+        self.location = location
+        self.api_version = api_version
+        # A fresh dict per instance avoids sharing one mutable default
+        self.body = body if body is not None else {}
+
+    def _check_input(self):
+        """
+        Validates that project_id, location and a usable body were supplied
+
+        A dict body must contain the keys 'name' and 'initial_node_count'; any
+        other body must expose truthy ``name`` and ``initial_node_count``
+        attributes.
+        :raises:
+            AirflowException: If any of the above is missing
+        """
+        if all([self.project_id, self.location, self.body]):
+            if isinstance(self.body, dict):
+                # An incomplete dict must fall through to the error below
+                # instead of being probed for attributes it does not have
+                if 'name' in self.body and 'initial_node_count' in self.body:
+                    return
+            elif getattr(self.body, 'name', None) and \
+                    getattr(self.body, 'initial_node_count', None):
+                return
+
+        self.log.error(
+            'One of (project_id, location, body, body[\'name\'], '
+            'body[\'initial_node_count\']) is missing or incorrect')
+        raise AirflowException('Operator has incorrect or missing input.')
+
+    def execute(self, context):
+        """
+        Validates the input, then creates the cluster through GKEClusterHook
+        :return: Whatever the hook's create_cluster returns
+        """
+        self._check_input()
+        hook = GKEClusterHook(self.project_id, self.location)
+        create_op = hook.create_cluster(cluster=self.body)
+        return create_op

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7255589f/docs/code.rst
----------------------------------------------------------------------
diff --git a/docs/code.rst b/docs/code.rst
index 4ab187c..3b51484 100644
--- a/docs/code.rst
+++ b/docs/code.rst
@@ -148,6 +148,8 @@ Operators
 .. autoclass:: 
airflow.contrib.operators.emr_terminate_job_flow_operator.EmrTerminateJobFlowOperator
 .. autoclass:: 
airflow.contrib.operators.file_to_gcs.FileToGoogleCloudStorageOperator
 .. autoclass:: airflow.contrib.operators.file_to_wasb.FileToWasbOperator
+.. autoclass:: 
airflow.contrib.operators.gcp_container_operator.GKEClusterCreateOperator
+.. autoclass:: 
airflow.contrib.operators.gcp_container_operator.GKEClusterDeleteOperator
 .. autoclass:: 
airflow.contrib.operators.gcs_download_operator.GoogleCloudStorageDownloadOperator
 .. autoclass:: 
airflow.contrib.operators.gcs_list_operator.GoogleCloudStorageListOperator
 .. autoclass:: 
airflow.contrib.operators.gcs_operator.GoogleCloudStorageCreateBucketOperator
@@ -375,6 +377,7 @@ Community contributed hooks
 .. autoclass:: airflow.contrib.hooks.ftp_hook.FTPHook
 .. autoclass:: airflow.contrib.hooks.ftp_hook.FTPSHook
 .. autoclass:: airflow.contrib.hooks.gcp_api_base_hook.GoogleCloudBaseHook
+.. autoclass:: airflow.contrib.hooks.gcp_container_hook.GKEClusterHook
 .. autoclass:: airflow.contrib.hooks.gcp_dataflow_hook.DataFlowHook
 .. autoclass:: airflow.contrib.hooks.gcp_dataproc_hook.DataProcHook
 .. autoclass:: airflow.contrib.hooks.gcp_mlengine_hook.MLEngineHook

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7255589f/docs/integration.rst
----------------------------------------------------------------------
diff --git a/docs/integration.rst b/docs/integration.rst
index 3d43685..972726b 100644
--- a/docs/integration.rst
+++ b/docs/integration.rst
@@ -733,3 +733,30 @@ GoogleCloudStorageHook
 
 .. autoclass:: airflow.contrib.hooks.gcs_hook.GoogleCloudStorageHook
     :members:
+
+Google Kubernetes Engine
+''''''''''''''''''''''''
+
+Google Kubernetes Engine Cluster Operators
+""""""""""""""""""""""""""""""""""""""""""
+
+- :ref:`GKEClusterCreateOperator` : Creates a Kubernetes Cluster in Google 
Cloud Platform
+- :ref:`GKEClusterDeleteOperator` : Deletes a Kubernetes Cluster in Google 
Cloud Platform
+
+.. _GKEClusterCreateOperator:
+
+GKEClusterCreateOperator
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. autoclass:: airflow.contrib.operators.gcp_container_operator.GKEClusterCreateOperator
+
+.. _GKEClusterDeleteOperator:
+
+GKEClusterDeleteOperator
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. autoclass:: airflow.contrib.operators.gcp_container_operator.GKEClusterDeleteOperator
+
+Google Kubernetes Engine Hook
+"""""""""""""""""""""""""""""
+
+.. autoclass:: airflow.contrib.hooks.gcp_container_hook.GKEClusterHook
+    :members:

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7255589f/setup.py
----------------------------------------------------------------------
diff --git a/setup.py b/setup.py
index 368652a..a6ce6c9 100644
--- a/setup.py
+++ b/setup.py
@@ -148,6 +148,7 @@ gcp_api = [
     'google-api-python-client>=1.6.0, <2.0.0dev',
     'google-auth>=1.0.0, <2.0.0dev',
     'google-auth-httplib2>=0.0.1',
+    'google-cloud-container>=0.1.1',
     'PyOpenSSL',
     'pandas-gbq'
 ]

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7255589f/tests/contrib/hooks/test_gcp_container_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_container_hook.py 
b/tests/contrib/hooks/test_gcp_container_hook.py
new file mode 100644
index 0000000..f3705ea
--- /dev/null
+++ b/tests/contrib/hooks/test_gcp_container_hook.py
@@ -0,0 +1,240 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import unittest
+
+from airflow import AirflowException
+from airflow.contrib.hooks.gcp_container_hook import GKEClusterHook
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+TASK_ID = 'test-gke-cluster-operator'
+CLUSTER_NAME = 'test-cluster'
+TEST_PROJECT_ID = 'test-project'
+ZONE = 'test-zone'
+
+
+class GKEClusterHookDeleteTest(unittest.TestCase):
+    """Tests GKEClusterHook.delete_cluster against a mocked client."""
+
+    def setUp(self):
+        # Bypass the real constructor so no Google client is created
+        with mock.patch.object(GKEClusterHook, "__init__", return_value=None):
+            self.gke_hook = GKEClusterHook(None, None)
+            self.gke_hook.project_id = TEST_PROJECT_ID
+            self.gke_hook.location = ZONE
+            self.gke_hook.client = mock.Mock()
+
+    @mock.patch("airflow.contrib.hooks.gcp_container_hook.GKEClusterHook._dict_to_proto")
+    @mock.patch(
+        "airflow.contrib.hooks.gcp_container_hook.GKEClusterHook.wait_for_operation")
+    def test_delete_cluster(self, wait_mock, convert_mock):
+        retry_mock, timeout_mock = mock.Mock(), mock.Mock()
+
+        client_delete = self.gke_hook.client.delete_cluster = mock.Mock()
+
+        self.gke_hook.delete_cluster(name=CLUSTER_NAME, retry=retry_mock,
+                                     timeout=timeout_mock)
+
+        client_delete.assert_called_with(project_id=TEST_PROJECT_ID, zone=ZONE,
+                                         cluster_id=CLUSTER_NAME,
+                                         retry=retry_mock, timeout=timeout_mock)
+        wait_mock.assert_called_with(client_delete.return_value)
+        convert_mock.assert_not_called()
+
+    @mock.patch("airflow.contrib.hooks.gcp_container_hook.GKEClusterHook._dict_to_proto")
+    @mock.patch(
+        "airflow.contrib.hooks.gcp_container_hook.GKEClusterHook.wait_for_operation")
+    def test_delete_cluster_error(self, wait_mock, convert_mock):
+        # To force an error
+        self.gke_hook.client.delete_cluster.side_effect = AirflowException('400')
+
+        with self.assertRaises(AirflowException):
+            self.gke_hook.delete_cluster(None)
+
+        # Checked outside the context manager: statements placed after the
+        # raising call inside the ``with`` block would never execute
+        wait_mock.assert_not_called()
+        convert_mock.assert_not_called()
+
+
+class GKEClusterHookCreateTest(unittest.TestCase):
+    """Tests GKEClusterHook.create_cluster against a mocked client."""
+
+    def setUp(self):
+        # Bypass the real constructor so no Google client is created
+        with mock.patch.object(GKEClusterHook, "__init__", return_value=None):
+            self.gke_hook = GKEClusterHook(None, None)
+            self.gke_hook.project_id = TEST_PROJECT_ID
+            self.gke_hook.location = ZONE
+            self.gke_hook.client = mock.Mock()
+
+    @mock.patch("airflow.contrib.hooks.gcp_container_hook.GKEClusterHook._dict_to_proto")
+    @mock.patch(
+        "airflow.contrib.hooks.gcp_container_hook.GKEClusterHook.wait_for_operation")
+    def test_create_cluster_proto(self, wait_mock, convert_mock):
+        from google.cloud.container_v1.proto.cluster_service_pb2 import Cluster
+
+        mock_cluster_proto = Cluster()
+        mock_cluster_proto.name = CLUSTER_NAME
+
+        retry_mock, timeout_mock = mock.Mock(), mock.Mock()
+
+        client_create = self.gke_hook.client.create_cluster = mock.Mock()
+
+        self.gke_hook.create_cluster(mock_cluster_proto, retry=retry_mock,
+                                     timeout=timeout_mock)
+
+        client_create.assert_called_with(project_id=TEST_PROJECT_ID, zone=ZONE,
+                                         cluster=mock_cluster_proto,
+                                         retry=retry_mock, timeout=timeout_mock)
+        wait_mock.assert_called_with(client_create.return_value)
+        convert_mock.assert_not_called()
+
+    @mock.patch("airflow.contrib.hooks.gcp_container_hook.GKEClusterHook._dict_to_proto")
+    @mock.patch(
+        "airflow.contrib.hooks.gcp_container_hook.GKEClusterHook.wait_for_operation")
+    def test_create_cluster_dict(self, wait_mock, convert_mock):
+        # A dict definition must be converted to a proto before the client call
+        mock_cluster_dict = {'name': CLUSTER_NAME}
+        retry_mock, timeout_mock = mock.Mock(), mock.Mock()
+
+        client_create = self.gke_hook.client.create_cluster = mock.Mock()
+        proto_mock = convert_mock.return_value = mock.Mock()
+
+        self.gke_hook.create_cluster(mock_cluster_dict, retry=retry_mock,
+                                     timeout=timeout_mock)
+
+        client_create.assert_called_with(project_id=TEST_PROJECT_ID, zone=ZONE,
+                                         cluster=proto_mock,
+                                         retry=retry_mock, timeout=timeout_mock)
+        wait_mock.assert_called_with(client_create.return_value)
+        self.assertEqual(convert_mock.call_args[1]['py_dict'], mock_cluster_dict)
+
+    @mock.patch("airflow.contrib.hooks.gcp_container_hook.GKEClusterHook._dict_to_proto")
+    @mock.patch(
+        "airflow.contrib.hooks.gcp_container_hook.GKEClusterHook.wait_for_operation")
+    def test_create_cluster_error(self, wait_mock, convert_mock):
+        # to force an error
+        mock_cluster_proto = None
+
+        with self.assertRaises(AirflowException):
+            self.gke_hook.create_cluster(mock_cluster_proto)
+
+        # Checked outside the context manager: statements placed after the
+        # raising call inside the ``with`` block would never execute
+        wait_mock.assert_not_called()
+        convert_mock.assert_not_called()
+
+
+class GKEClusterHookGetTest(unittest.TestCase):
+    """Tests GKEClusterHook.get_cluster against a mocked client."""
+
+    def setUp(self):
+        # Bypass the real constructor so no Google client is created
+        with mock.patch.object(GKEClusterHook, "__init__", return_value=None):
+            self.gke_hook = GKEClusterHook(None, None)
+            self.gke_hook.project_id = TEST_PROJECT_ID
+            self.gke_hook.location = ZONE
+            self.gke_hook.client = mock.Mock()
+
+    def test_get_cluster(self):
+        retry_mock, timeout_mock = mock.Mock(), mock.Mock()
+
+        client_get = self.gke_hook.client.get_cluster = mock.Mock()
+
+        result = self.gke_hook.get_cluster(name=CLUSTER_NAME, retry=retry_mock,
+                                           timeout=timeout_mock)
+
+        client_get.assert_called_with(project_id=TEST_PROJECT_ID, zone=ZONE,
+                                      cluster_id=CLUSTER_NAME,
+                                      retry=retry_mock, timeout=timeout_mock)
+        # Pin the return value: the hook hands back the cluster's self link
+        self.assertEqual(result, client_get.return_value.self_link)
+
+
+class GKEClusterHookTest(unittest.TestCase):
+
+    def setUp(self):
+        with mock.patch.object(GKEClusterHook, "__init__", return_value=None):
+            self.gke_hook = GKEClusterHook(None, None, None)
+            self.gke_hook.project_id = TEST_PROJECT_ID
+            self.gke_hook.location = ZONE
+            self.gke_hook.client = mock.Mock()
+
+    def test_get_operation(self):
+        self.gke_hook.client.get_operation = mock.Mock()
+        self.gke_hook.get_operation('TEST_OP')
+        
self.gke_hook.client.get_operation.assert_called_with(project_id=TEST_PROJECT_ID,
+                                                              zone=ZONE,
+                                                              
operation_id='TEST_OP')
+
+    def test_append_label(self):
+        key = 'test-key'
+        val = 'test-val'
+        mock_proto = mock.Mock()
+        self.gke_hook._append_label(mock_proto, key, val)
+        mock_proto.resource_labels.update.assert_called_with({key: val})
+
+    def test_append_label_replace(self):
+        key = 'test-key'
+        val = 'test.val+this'
+        mock_proto = mock.Mock()
+        self.gke_hook._append_label(mock_proto, key, val)
+        mock_proto.resource_labels.update.assert_called_with({key: 
'test-val-this'})
+
+    @mock.patch("time.sleep")
+    def test_wait_for_response_done(self, time_mock):
+        from google.cloud.container_v1.gapic.enums import Operation
+        mock_op = mock.Mock()
+        mock_op.status = Operation.Status.DONE
+        self.gke_hook.wait_for_operation(mock_op)
+        self.assertEqual(time_mock.call_count, 1)
+
+    @mock.patch("time.sleep")
+    def test_wait_for_response_exception(self, time_mock):
+        from google.cloud.container_v1.gapic.enums import Operation
+        from google.cloud.exceptions import GoogleCloudError
+
+        mock_op = mock.Mock()
+        mock_op.status = Operation.Status.ABORTING
+
+        with self.assertRaises(GoogleCloudError):
+            self.gke_hook.wait_for_operation(mock_op)
+            self.assertEqual(time_mock.call_count, 1)
+
+    @mock.patch("airflow.contrib.hooks.gcp_container_hook.GKEClusterHook.get_operation")
+    @mock.patch("time.sleep")
+    def test_wait_for_response_running(self, time_mock, operation_mock):
+        from google.cloud.container_v1.gapic.enums import Operation
+
+        running_op = mock.Mock(status=Operation.Status.RUNNING)
+        pending_op = mock.Mock(status=Operation.Status.PENDING)
+        done_op = mock.Mock(status=Operation.Status.DONE)
+
+        # The operation starts out RUNNING; the two polls that follow report
+        # PENDING and then DONE.
+        operation_mock.side_effect = [pending_op, done_op]
+        self.gke_hook.wait_for_operation(running_op)
+
+        self.assertEqual(time_mock.call_count, 3)
+        for polled_op in (running_op, pending_op):
+            operation_mock.assert_any_call(polled_op.name)
+        self.assertEqual(operation_mock.call_count, 2)
+
+    @mock.patch("google.protobuf.json_format.Parse")
+    @mock.patch("json.dumps")
+    def test_dict_to_proto(self, dumps_mock, parse_mock):
+        # The dict is expected to be serialised with json.dumps, and that
+        # result handed to json_format.Parse together with the target proto.
+        source_dict = {'name': 'test'}
+        target_proto = mock.Mock()
+        serialized = mock.Mock()
+        dumps_mock.return_value = serialized
+
+        self.gke_hook._dict_to_proto(source_dict, target_proto)
+
+        dumps_mock.assert_called_with(source_dict)
+        parse_mock.assert_called_with(serialized, target_proto)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7255589f/tests/contrib/operators/test_gcp_container_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_gcp_container_operator.py b/tests/contrib/operators/test_gcp_container_operator.py
new file mode 100644
index 0000000..0f67290
--- /dev/null
+++ b/tests/contrib/operators/test_gcp_container_operator.py
@@ -0,0 +1,125 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+from airflow import AirflowException
+from airflow.contrib.operators.gcp_container_operator import GKEClusterCreateOperator, \
+    GKEClusterDeleteOperator
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+# Keyword arguments shared by the operator tests below.
+PROJECT_ID = 'test-id'
+PROJECT_LOCATION = 'test-location'
+PROJECT_TASK_ID = 'test-task-id'
+CLUSTER_NAME = 'test-cluster-name'
+
+# Cluster definitions passed as the ``body`` argument of the create operator.
+PROJECT_BODY = {
+    'name': 'test-name',
+}
+PROJECT_BODY_CREATE = {
+    'name': 'test-name',
+    'initial_node_count': 1,
+}
+
+
+class GoogleCloudPlatformContainerOperatorTest(unittest.TestCase):
+    # Tests for GKEClusterCreateOperator and GKEClusterDeleteOperator.
+    # GKEClusterHook is replaced by a mock in every test, so only the
+    # operator's own argument handling and its calls on the hook are checked.
+    #
+    # In the error tests the "hook was not used" assertion is placed after
+    # the ``assertRaises`` block: anything written after the raising statement
+    # inside that block would never be executed.
+
+    @mock.patch('airflow.contrib.operators.gcp_container_operator.GKEClusterHook')
+    def test_create_execute(self, mock_hook):
+        # A fully specified create operator passes its body to the hook.
+        operator = GKEClusterCreateOperator(project_id=PROJECT_ID,
+                                            location=PROJECT_LOCATION,
+                                            body=PROJECT_BODY_CREATE,
+                                            task_id=PROJECT_TASK_ID)
+
+        operator.execute(None)
+        mock_hook.return_value.create_cluster.assert_called_once_with(
+            cluster=PROJECT_BODY_CREATE)
+
+    @mock.patch('airflow.contrib.operators.gcp_container_operator.GKEClusterHook')
+    def test_create_execute_error_body(self, mock_hook):
+        # body=None must be rejected with an AirflowException.
+        with self.assertRaises(AirflowException):
+            operator = GKEClusterCreateOperator(project_id=PROJECT_ID,
+                                                location=PROJECT_LOCATION,
+                                                body=None,
+                                                task_id=PROJECT_TASK_ID)
+            operator.execute(None)
+
+        mock_hook.return_value.create_cluster.assert_not_called()
+
+    @mock.patch('airflow.contrib.operators.gcp_container_operator.GKEClusterHook')
+    def test_create_execute_error_project_id(self, mock_hook):
+        # Omitting project_id must be rejected with an AirflowException.
+        with self.assertRaises(AirflowException):
+            operator = GKEClusterCreateOperator(location=PROJECT_LOCATION,
+                                                body=PROJECT_BODY,
+                                                task_id=PROJECT_TASK_ID)
+            operator.execute(None)
+
+        mock_hook.return_value.create_cluster.assert_not_called()
+
+    @mock.patch('airflow.contrib.operators.gcp_container_operator.GKEClusterHook')
+    def test_create_execute_error_location(self, mock_hook):
+        # Omitting location must be rejected with an AirflowException.
+        with self.assertRaises(AirflowException):
+            operator = GKEClusterCreateOperator(project_id=PROJECT_ID,
+                                                body=PROJECT_BODY,
+                                                task_id=PROJECT_TASK_ID)
+            operator.execute(None)
+
+        mock_hook.return_value.create_cluster.assert_not_called()
+
+    @mock.patch('airflow.contrib.operators.gcp_container_operator.GKEClusterHook')
+    def test_delete_execute(self, mock_hook):
+        # A fully specified delete operator passes the cluster name to the hook.
+        operator = GKEClusterDeleteOperator(project_id=PROJECT_ID,
+                                            name=CLUSTER_NAME,
+                                            location=PROJECT_LOCATION,
+                                            task_id=PROJECT_TASK_ID)
+
+        operator.execute(None)
+        mock_hook.return_value.delete_cluster.assert_called_once_with(
+            name=CLUSTER_NAME)
+
+    @mock.patch('airflow.contrib.operators.gcp_container_operator.GKEClusterHook')
+    def test_delete_execute_error_project_id(self, mock_hook):
+        # Omitting project_id must be rejected with an AirflowException.
+        with self.assertRaises(AirflowException):
+            operator = GKEClusterDeleteOperator(location=PROJECT_LOCATION,
+                                                name=CLUSTER_NAME,
+                                                task_id=PROJECT_TASK_ID)
+            operator.execute(None)
+
+        mock_hook.return_value.delete_cluster.assert_not_called()
+
+    @mock.patch('airflow.contrib.operators.gcp_container_operator.GKEClusterHook')
+    def test_delete_execute_error_cluster_name(self, mock_hook):
+        # Omitting the cluster name must be rejected with an AirflowException.
+        with self.assertRaises(AirflowException):
+            operator = GKEClusterDeleteOperator(project_id=PROJECT_ID,
+                                                location=PROJECT_LOCATION,
+                                                task_id=PROJECT_TASK_ID)
+            operator.execute(None)
+
+        mock_hook.return_value.delete_cluster.assert_not_called()
+
+    @mock.patch('airflow.contrib.operators.gcp_container_operator.GKEClusterHook')
+    def test_delete_execute_error_location(self, mock_hook):
+        # Omitting location must be rejected with an AirflowException.
+        with self.assertRaises(AirflowException):
+            operator = GKEClusterDeleteOperator(project_id=PROJECT_ID,
+                                                name=CLUSTER_NAME,
+                                                task_id=PROJECT_TASK_ID)
+            operator.execute(None)
+
+        mock_hook.return_value.delete_cluster.assert_not_called()


Reply via email to