Repository: incubator-airflow Updated Branches: refs/heads/master caaa4a516 -> 7255589f9
[AIRFLOW-2562] Add Google Kubernetes Engine Operators Add Google Kubernetes Engine create_cluster, delete_cluster operators This allows users to use airflow to create or delete clusters in the google cloud platform Closes #3477 from Noremac201/gke_create Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/7255589f Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/7255589f Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/7255589f Branch: refs/heads/master Commit: 7255589f95aa516aa1c4d90d85cba7b7277792c6 Parents: caaa4a5 Author: Cameron Moberg <[email protected]> Authored: Fri Jun 15 20:44:29 2018 +0100 Committer: Kaxil Naik <[email protected]> Committed: Fri Jun 15 20:44:29 2018 +0100 ---------------------------------------------------------------------- airflow/contrib/hooks/gcp_container_hook.py | 217 +++++++++++++++++ .../contrib/operators/gcp_container_operator.py | 172 +++++++++++++ docs/code.rst | 3 + docs/integration.rst | 27 +++ setup.py | 1 + tests/contrib/hooks/test_gcp_container_hook.py | 240 +++++++++++++++++++ .../operators/test_gcp_container_operator.py | 125 ++++++++++ 7 files changed, 785 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7255589f/airflow/contrib/hooks/gcp_container_hook.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/hooks/gcp_container_hook.py b/airflow/contrib/hooks/gcp_container_hook.py new file mode 100644 index 0000000..d36d796 --- /dev/null +++ b/airflow/contrib/hooks/gcp_container_hook.py @@ -0,0 +1,217 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import json
+import time
+
+from airflow import AirflowException, version
+from airflow.hooks.base_hook import BaseHook
+
+from google.api_core.exceptions import AlreadyExists
+from google.api_core.gapic_v1.method import DEFAULT
+from google.cloud import container_v1, exceptions
+from google.cloud.container_v1.gapic.enums import Operation
+from google.cloud.container_v1.types import Cluster
+from google.protobuf import json_format
+from google.api_core.gapic_v1.client_info import ClientInfo
+
+# Seconds to sleep between two status checks of a long-running GKE operation.
+OPERATIONAL_POLL_INTERVAL = 15
+
+
+class GKEClusterHook(BaseHook):
+    """
+    Hook for creating, fetching and deleting Google Kubernetes Engine (GKE)
+    clusters through ``google.cloud.container_v1.ClusterManagerClient``.
+
+    NOTE(review): the client is built without explicit credentials and no
+    Airflow connection is consulted, so authentication is left to the client
+    library's defaults (presumably Application Default Credentials - confirm).
+
+    :param project_id: The Google Cloud project that owns the clusters
+    :type project_id: str
+    :param location: The Google Compute Engine zone the clusters live in
+    :type location: str
+    """
+
+    def __init__(self, project_id, location):
+        self.project_id = project_id
+        self.location = location
+
+        # Add client library info for better error tracking
+        client_info = ClientInfo(client_library_version='airflow_v' + version.version)
+        self.client = container_v1.ClusterManagerClient(client_info=client_info)
+
+    def _dict_to_proto(self, py_dict, proto):
+        """
+        Merges a python dictionary into the proto supplied, by round-tripping
+        the dictionary through JSON.
+
+        :param py_dict: The dictionary to convert
+        :type py_dict: dict
+        :param proto: The proto object to merge with dictionary
+        :type proto: protobuf
+        :return: The supplied proto, populated from the dictionary
+        :raises:
+            ParseError: On JSON parsing problems.
+        """
+        dict_json_str = json.dumps(py_dict)
+        return json_format.Parse(dict_json_str, proto)
+
+    def wait_for_operation(self, operation):
+        """
+        Given an operation, continuously fetches the status from Google Cloud until either
+        completion or an error occurring.
+
+        The first status check happens only after one poll interval has elapsed.
+
+        :param operation: The Operation to wait for
+        :type operation: google.cloud.container_v1.types.Operation
+        :return: The most recently fetched operation (the input operation itself
+            if it was already DONE)
+        :raises GoogleCloudError: if the operation reaches any status other than
+            RUNNING, PENDING or DONE
+        """
+        self.log.info("Waiting for OPERATION_NAME %s", operation.name)
+        time.sleep(OPERATIONAL_POLL_INTERVAL)
+        while operation.status != Operation.Status.DONE:
+            if operation.status in (Operation.Status.RUNNING,
+                                    Operation.Status.PENDING):
+                time.sleep(OPERATIONAL_POLL_INTERVAL)
+            else:
+                raise exceptions.GoogleCloudError(
+                    "Operation has failed with status: %s" % operation.status)
+            # To update status of operation
+            operation = self.get_operation(operation.name)
+        return operation
+
+    def get_operation(self, operation_name):
+        """
+        Fetches the operation from Google Cloud
+
+        :param operation_name: Name of operation to fetch
+        :type operation_name: str
+        :return: The new, updated operation from Google Cloud
+        """
+        return self.client.get_operation(project_id=self.project_id,
+                                         zone=self.location,
+                                         operation_id=operation_name)
+
+    def _append_label(self, cluster_proto, key, val):
+        """
+        Append labels to provided Cluster Protobuf
+
+        Labels must fit the regex [a-z]([-a-z0-9]*[a-z0-9])? (current airflow version
+        string follows semantic versioning spec: x.y.z). To fit that pattern the
+        value is lower-cased and every '.' and '+' is replaced with '-'.
+
+        :param cluster_proto: The proto to append resource_label airflow version to
+        :type cluster_proto: google.cloud.container_v1.types.Cluster
+        :param key: The key label
+        :type key: str
+        :param val: The label value; it is sanitised as described above
+        :type val: str
+        :return: The cluster proto updated with new label (modified in place)
+        """
+        val = val.replace('.', '-').replace('+', '-').lower()
+        cluster_proto.resource_labels.update({key: val})
+        return cluster_proto
+
+    def delete_cluster(self, name, retry=DEFAULT, timeout=DEFAULT):
+        """
+        Deletes the cluster, including the Kubernetes endpoint and all
+        worker nodes. Firewalls and routes that were configured during
+        cluster creation are also deleted. Other Google Compute Engine
+        resources that might be in use by the cluster (e.g. load balancer
+        resources) will not be deleted if they weren't present at the
+        initial create time.
+
+        :param name: The name of the cluster to delete
+        :type name: str
+        :param retry: Retry object used to determine when/if to retry requests.
+            If None is specified, requests will not be retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to
+            complete. Note that if retry is specified, the timeout applies to each
+            individual attempt.
+        :type timeout: float
+        :return: The full url to the delete operation if successful, else None
+            (a missing cluster is logged and treated as already deleted)
+        """
+
+        self.log.info("Deleting (project_id={}, zone={}, cluster_id={})".format(
+            self.project_id, self.location, name))
+
+        try:
+            op = self.client.delete_cluster(project_id=self.project_id,
+                                            zone=self.location,
+                                            cluster_id=name,
+                                            retry=retry,
+                                            timeout=timeout)
+            op = self.wait_for_operation(op)
+            # Returns server-defined url for the resource
+            return op.self_link
+        except exceptions.NotFound as error:
+            # Deleting a cluster that does not exist is treated as success.
+            self.log.info('Assuming Success: %s', error.message)
+            return None
+
+    def create_cluster(self, cluster, retry=DEFAULT, timeout=DEFAULT):
+        """
+        Creates a cluster, consisting of the specified number and type of Google Compute
+        Engine instances.
+
+        An ``airflow-version`` resource label is added to the cluster; when a
+        Cluster proto is passed in, that proto is modified in place.
+
+        :param cluster: A Cluster protobuf or dict. If dict is provided, it must be of
+            the same form as the protobuf message google.cloud.container_v1.types.Cluster
+        :type cluster: dict or google.cloud.container_v1.types.Cluster
+        :param retry: A retry object (google.api_core.retry.Retry) used to retry requests.
+            If None is specified, requests will not be retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to
+            complete. Note that if retry is specified, the timeout applies to each
+            individual attempt.
+        :type timeout: float
+        :return: The full url to the new, or existing, cluster
+        :raises:
+            ParseError: On JSON parsing problems when trying to convert dict
+            AirflowException: cluster is not dict type nor Cluster proto type
+        """
+
+        if isinstance(cluster, dict):
+            cluster_proto = Cluster()
+            cluster = self._dict_to_proto(py_dict=cluster, proto=cluster_proto)
+        elif not isinstance(cluster, Cluster):
+            raise AirflowException(
+                "cluster is not instance of Cluster proto or python dict")
+
+        self._append_label(cluster, 'airflow-version', 'v' + version.version)
+
+        self.log.info("Creating (project_id={}, zone={}, cluster_name={})".format(
+            self.project_id,
+            self.location,
+            cluster.name))
+        try:
+            op = self.client.create_cluster(project_id=self.project_id,
+                                            zone=self.location,
+                                            cluster=cluster,
+                                            retry=retry,
+                                            timeout=timeout)
+            op = self.wait_for_operation(op)
+
+            return op.target_link
+        except AlreadyExists as error:
+            self.log.info('Assuming Success: %s', error.message)
+            # get_cluster already returns the cluster's self_link string, so it
+            # must not be dereferenced again here.
+            return self.get_cluster(name=cluster.name)
+
+    def get_cluster(self, name, retry=DEFAULT, timeout=DEFAULT):
+        """
+        Gets details of specified cluster and returns its server-defined URL.
+
+        :param name: The name of the cluster to retrieve
+        :type name: str
+        :param retry: A retry object used to retry requests. If None is specified,
+            requests will not be retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to
+            complete. Note that if retry is specified, the timeout applies to each
+            individual attempt.
+        :type timeout: float
+        :return: The ``self_link`` (full URL) of the cluster, read from the
+            google.cloud.container_v1.types.Cluster returned by the API
+        :rtype: str
+        """
+        self.log.info("Fetching cluster (project_id={}, zone={}, cluster_name={})".format(
+            self.project_id,
+            self.location,
+            name))
+
+        return self.client.get_cluster(project_id=self.project_id,
+                                       zone=self.location,
+                                       cluster_id=name,
+                                       retry=retry,
+                                       timeout=timeout).self_link
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7255589f/airflow/contrib/operators/gcp_container_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/gcp_container_operator.py b/airflow/contrib/operators/gcp_container_operator.py
new file mode 100644
index 0000000..5648b4d
--- /dev/null
+++ b/airflow/contrib/operators/gcp_container_operator.py
@@ -0,0 +1,172 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from airflow import AirflowException
+from airflow.contrib.hooks.gcp_container_hook import GKEClusterHook
+from airflow.models import BaseOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class GKEClusterDeleteOperator(BaseOperator):
+    """
+    Deletes the cluster, including the Kubernetes endpoint and all worker nodes.
+
+    To delete a certain cluster, you must specify the ``project_id``, the ``name``
+    of the cluster, the ``location`` that the cluster is in, and the ``task_id``.
+
+    **Operator Creation**: ::
+
+        operator = GKEClusterDeleteOperator(
+            task_id='cluster_delete',
+            project_id='my-project',
+            location='cluster-location',
+            name='cluster-name')
+
+    .. seealso::
+        For more detail about deleting clusters have a look at the reference:
+        https://google-cloud-python.readthedocs.io/en/latest/container/gapic/v1/api.html#google.cloud.container_v1.ClusterManagerClient.delete_cluster
+
+    :param project_id: The Google Developers Console [project ID or project number]
+    :type project_id: str
+    :param name: The name of the resource to delete, in this case cluster name
+    :type name: str
+    :param location: The name of the Google Compute Engine zone in which the cluster
+        resides.
+    :type location: str
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform.
+        NOTE: stored and templated, but not currently passed to ``GKEClusterHook``.
+    :type gcp_conn_id: str
+    :param api_version: The api version to use.
+        NOTE: stored and templated, but not currently passed to ``GKEClusterHook``.
+    :type api_version: str
+    """
+    template_fields = ['project_id', 'gcp_conn_id', 'name', 'location', 'api_version']
+
+    @apply_defaults
+    def __init__(self,
+                 project_id,
+                 name,
+                 location,
+                 gcp_conn_id='google_cloud_default',
+                 api_version='v2',
+                 *args,
+                 **kwargs):
+        super(GKEClusterDeleteOperator, self).__init__(*args, **kwargs)
+
+        self.project_id = project_id
+        self.gcp_conn_id = gcp_conn_id
+        self.location = location
+        self.api_version = api_version
+        self.name = name
+
+    def _check_input(self):
+        """Raise AirflowException unless project_id, name and location are all set."""
+        if not all([self.project_id, self.name, self.location]):
+            self.log.error(
+                'One of (project_id, name, location) is missing or incorrect')
+            raise AirflowException('Operator has incorrect or missing input.')
+
+    def execute(self, context):
+        """
+        Validate the inputs, then delete the cluster through GKEClusterHook.
+
+        :return: The value returned by ``GKEClusterHook.delete_cluster``
+        """
+        self._check_input()
+        hook = GKEClusterHook(self.project_id, self.location)
+        delete_result = hook.delete_cluster(name=self.name)
+        return delete_result
+
+
+class GKEClusterCreateOperator(BaseOperator):
+    """
+    Create a Google Kubernetes Engine Cluster of specified dimensions.
+    The operator will wait until the cluster is created.
+
+    The **minimum** required to define a cluster to create is:
+
+    ``dict()`` ::
+
+        cluster_def = {'name': 'my-cluster-name',
+                       'initial_node_count': 1}
+
+    or
+
+    ``Cluster`` proto ::
+
+        from google.cloud.container_v1.types import Cluster
+
+        cluster_def = Cluster(name='my-cluster-name', initial_node_count=1)
+
+    **Operator Creation**: ::
+
+        operator = GKEClusterCreateOperator(
+            task_id='cluster_create',
+            project_id='my-project',
+            location='my-location',
+            body=cluster_def)
+
+    .. seealso::
+        For more detail about creating clusters have a look at the reference:
+        https://google-cloud-python.readthedocs.io/en/latest/container/gapic/v1/types.html#google.cloud.container_v1.types.Cluster
+
+    :param project_id: The Google Developers Console [project ID or project number]
+    :type project_id: str
+    :param location: The name of the Google Compute Engine zone in which the cluster
+        resides.
+    :type location: str
+    :param body: The Cluster definition to create, can be protobuf or python dict, if
+        dict it must match protobuf message Cluster. Required; it must provide at
+        least ``name`` and ``initial_node_count``.
+    :type body: dict or google.cloud.container_v1.types.Cluster
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform.
+        NOTE: stored and templated, but not currently passed to ``GKEClusterHook``.
+    :type gcp_conn_id: str
+    :param api_version: The api version to use.
+        NOTE: stored and templated, but not currently passed to ``GKEClusterHook``.
+    :type api_version: str
+    """
+    template_fields = ['project_id', 'gcp_conn_id', 'location', 'api_version', 'body']
+
+    @apply_defaults
+    def __init__(self,
+                 project_id,
+                 location,
+                 body=None,
+                 gcp_conn_id='google_cloud_default',
+                 api_version='v2',
+                 *args,
+                 **kwargs):
+        # body defaults to None rather than a shared mutable {}; both are falsy,
+        # so _check_input rejects a missing body exactly as before.
+        super(GKEClusterCreateOperator, self).__init__(*args, **kwargs)
+
+        self.project_id = project_id
+        self.gcp_conn_id = gcp_conn_id
+        self.location = location
+        self.api_version = api_version
+        self.body = body
+
+    def _check_input(self):
+        """
+        Raise AirflowException unless project_id, location and body are set and
+        body provides both ``name`` and ``initial_node_count``.
+        """
+        if all([self.project_id, self.location, self.body]):
+            if isinstance(self.body, dict):
+                if 'name' in self.body and 'initial_node_count' in self.body:
+                    # Don't throw error
+                    return
+            # If not a dict, it must be a Cluster proto; getattr keeps any other
+            # object (or a dict lacking keys above) from escaping as AttributeError.
+            elif getattr(self.body, 'name', None) and \
+                    getattr(self.body, 'initial_node_count', None):
+                return
+
+        self.log.error(
+            'One of (project_id, location, body, body[\'name\'], '
+            'body[\'initial_node_count\']) is missing or incorrect')
+        raise AirflowException('Operator has incorrect or missing input.')
+
+    def execute(self, context):
+        """
+        Validate the inputs, then create the cluster through GKEClusterHook.
+
+        :return: The value returned by ``GKEClusterHook.create_cluster``
+        """
+        self._check_input()
+        hook = GKEClusterHook(self.project_id, self.location)
+        create_op = hook.create_cluster(cluster=self.body)
+        return create_op
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7255589f/docs/code.rst
----------------------------------------------------------------------
diff --git a/docs/code.rst b/docs/code.rst
index 4ab187c..3b51484 100644
--- a/docs/code.rst
+++ b/docs/code.rst
@@ -148,6
+148,8 @@ Operators .. autoclass:: airflow.contrib.operators.emr_terminate_job_flow_operator.EmrTerminateJobFlowOperator .. autoclass:: airflow.contrib.operators.file_to_gcs.FileToGoogleCloudStorageOperator .. autoclass:: airflow.contrib.operators.file_to_wasb.FileToWasbOperator +.. autoclass:: airflow.contrib.operators.gcp_container_operator.GKEClusterCreateOperator +.. autoclass:: airflow.contrib.operators.gcp_container_operator.GKEClusterDeleteOperator .. autoclass:: airflow.contrib.operators.gcs_download_operator.GoogleCloudStorageDownloadOperator .. autoclass:: airflow.contrib.operators.gcs_list_operator.GoogleCloudStorageListOperator .. autoclass:: airflow.contrib.operators.gcs_operator.GoogleCloudStorageCreateBucketOperator @@ -375,6 +377,7 @@ Community contributed hooks .. autoclass:: airflow.contrib.hooks.ftp_hook.FTPHook .. autoclass:: airflow.contrib.hooks.ftp_hook.FTPSHook .. autoclass:: airflow.contrib.hooks.gcp_api_base_hook.GoogleCloudBaseHook +.. autoclass:: airflow.contrib.hooks.gcp_container_hook.GKEClusterHook .. autoclass:: airflow.contrib.hooks.gcp_dataflow_hook.DataFlowHook .. autoclass:: airflow.contrib.hooks.gcp_dataproc_hook.DataProcHook .. autoclass:: airflow.contrib.hooks.gcp_mlengine_hook.MLEngineHook http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7255589f/docs/integration.rst ---------------------------------------------------------------------- diff --git a/docs/integration.rst b/docs/integration.rst index 3d43685..972726b 100644 --- a/docs/integration.rst +++ b/docs/integration.rst @@ -733,3 +733,30 @@ GoogleCloudStorageHook .. 
autoclass:: airflow.contrib.hooks.gcs_hook.GoogleCloudStorageHook :members: + +Google Kubernetes Engine +'''''''''''''''''''''''' + +Google Kubernetes Engine Cluster Operators +"""""""""""""""""""""""""""""""""""""""""" + +- :ref:`GKEClusterCreateOperator` : Creates a Kubernetes Cluster in Google Cloud Platform +- :ref:`GKEClusterDeleteOperator` : Deletes a Kubernetes Cluster in Google Cloud Platform + +GKEClusterCreateOperator +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: airflow.contrib.operators.gcp_container_operator.GKEClusterCreateOperator +.. _GKEClusterCreateOperator: + +GKEClusterDeleteOperator +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: airflow.contrib.operators.gcp_container_operator.GKEClusterDeleteOperator +.. _GKEClusterDeleteOperator: + +Google Kubernetes Engine Hook +""""""""""""""""""""""""""""" + +.. autoclass:: airflow.contrib.hooks.gcp_container_hook.GKEClusterHook + :members: http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7255589f/setup.py ---------------------------------------------------------------------- diff --git a/setup.py b/setup.py index 368652a..a6ce6c9 100644 --- a/setup.py +++ b/setup.py @@ -148,6 +148,7 @@ gcp_api = [ 'google-api-python-client>=1.6.0, <2.0.0dev', 'google-auth>=1.0.0, <2.0.0dev', 'google-auth-httplib2>=0.0.1', + 'google-cloud-container>=0.1.1', 'PyOpenSSL', 'pandas-gbq' ] http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7255589f/tests/contrib/hooks/test_gcp_container_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/test_gcp_container_hook.py b/tests/contrib/hooks/test_gcp_container_hook.py new file mode 100644 index 0000000..f3705ea --- /dev/null +++ b/tests/contrib/hooks/test_gcp_container_hook.py @@ -0,0 +1,240 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import unittest
+
+from airflow import AirflowException
+from airflow.contrib.hooks.gcp_container_hook import GKEClusterHook
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+TASK_ID = 'test-gke-cluster-operator'
+CLUSTER_NAME = 'test-cluster'
+TEST_PROJECT_ID = 'test-project'
+ZONE = 'test-zone'
+
+
+class GKEClusterHookDeleteTest(unittest.TestCase):
+    def setUp(self):
+        # Bypass __init__ so no real ClusterManagerClient is constructed.
+        with mock.patch.object(GKEClusterHook, "__init__", return_value=None):
+            self.gke_hook = GKEClusterHook(None, None)
+        self.gke_hook.project_id = TEST_PROJECT_ID
+        self.gke_hook.location = ZONE
+        self.gke_hook.client = mock.Mock()
+
+    @mock.patch("airflow.contrib.hooks.gcp_container_hook.GKEClusterHook._dict_to_proto")
+    @mock.patch(
+        "airflow.contrib.hooks.gcp_container_hook.GKEClusterHook.wait_for_operation")
+    def test_delete_cluster(self, wait_mock, convert_mock):
+        retry_mock, timeout_mock = mock.Mock(), mock.Mock()
+
+        client_delete = self.gke_hook.client.delete_cluster = mock.Mock()
+
+        self.gke_hook.delete_cluster(name=CLUSTER_NAME, retry=retry_mock,
+                                     timeout=timeout_mock)
+
+        client_delete.assert_called_with(project_id=TEST_PROJECT_ID, zone=ZONE,
+                                         cluster_id=CLUSTER_NAME,
+                                         retry=retry_mock, timeout=timeout_mock)
+        wait_mock.assert_called_with(client_delete.return_value)
+        convert_mock.assert_not_called()
+
+    @mock.patch("airflow.contrib.hooks.gcp_container_hook.GKEClusterHook._dict_to_proto")
+    @mock.patch(
+        "airflow.contrib.hooks.gcp_container_hook.GKEClusterHook.wait_for_operation")
+    def test_delete_cluster_error(self, wait_mock, convert_mock):
+        # To force an error
+        self.gke_hook.client.delete_cluster.side_effect = AirflowException('400')
+
+        with self.assertRaises(AirflowException):
+            self.gke_hook.delete_cluster(None)
+        # Checked after the with-block: inside it they would never run.
+        wait_mock.assert_not_called()
+        convert_mock.assert_not_called()
+
+
+class GKEClusterHookCreateTest(unittest.TestCase):
+    def setUp(self):
+        with mock.patch.object(GKEClusterHook, "__init__", return_value=None):
+            self.gke_hook = GKEClusterHook(None, None)
+        self.gke_hook.project_id = TEST_PROJECT_ID
+        self.gke_hook.location = ZONE
+        self.gke_hook.client = mock.Mock()
+
+    @mock.patch("airflow.contrib.hooks.gcp_container_hook.GKEClusterHook._dict_to_proto")
+    @mock.patch(
+        "airflow.contrib.hooks.gcp_container_hook.GKEClusterHook.wait_for_operation")
+    def test_create_cluster_proto(self, wait_mock, convert_mock):
+        from google.cloud.container_v1.proto.cluster_service_pb2 import Cluster
+
+        mock_cluster_proto = Cluster()
+        mock_cluster_proto.name = CLUSTER_NAME
+
+        retry_mock, timeout_mock = mock.Mock(), mock.Mock()
+
+        client_create = self.gke_hook.client.create_cluster = mock.Mock()
+
+        self.gke_hook.create_cluster(mock_cluster_proto, retry=retry_mock,
+                                     timeout=timeout_mock)
+
+        client_create.assert_called_with(project_id=TEST_PROJECT_ID, zone=ZONE,
+                                         cluster=mock_cluster_proto,
+                                         retry=retry_mock, timeout=timeout_mock)
+        wait_mock.assert_called_with(client_create.return_value)
+        convert_mock.assert_not_called()
+
+    @mock.patch("airflow.contrib.hooks.gcp_container_hook.GKEClusterHook._dict_to_proto")
+    @mock.patch(
+        "airflow.contrib.hooks.gcp_container_hook.GKEClusterHook.wait_for_operation")
+    def test_create_cluster_dict(self, wait_mock, convert_mock):
+        mock_cluster_dict = {'name': CLUSTER_NAME}
+        retry_mock, timeout_mock = mock.Mock(), mock.Mock()
+
+        client_create = self.gke_hook.client.create_cluster = mock.Mock()
+        proto_mock = convert_mock.return_value = mock.Mock()
+
+        self.gke_hook.create_cluster(mock_cluster_dict, retry=retry_mock,
+                                     timeout=timeout_mock)
+
+        client_create.assert_called_with(project_id=TEST_PROJECT_ID, zone=ZONE,
+                                         cluster=proto_mock,
+                                         retry=retry_mock, timeout=timeout_mock)
+        wait_mock.assert_called_with(client_create.return_value)
+        self.assertEqual(convert_mock.call_args[1]['py_dict'], mock_cluster_dict)
+
+    @mock.patch("airflow.contrib.hooks.gcp_container_hook.GKEClusterHook._dict_to_proto")
+    @mock.patch(
+        "airflow.contrib.hooks.gcp_container_hook.GKEClusterHook.wait_for_operation")
+    def test_create_cluster_error(self, wait_mock, convert_mock):
+        # to force an error
+        mock_cluster_proto = None
+
+        with self.assertRaises(AirflowException):
+            self.gke_hook.create_cluster(mock_cluster_proto)
+        wait_mock.assert_not_called()
+        convert_mock.assert_not_called()
+
+
+class GKEClusterHookGetTest(unittest.TestCase):
+    def setUp(self):
+        with mock.patch.object(GKEClusterHook, "__init__", return_value=None):
+            self.gke_hook = GKEClusterHook(None, None)
+        self.gke_hook.project_id = TEST_PROJECT_ID
+        self.gke_hook.location = ZONE
+        self.gke_hook.client = mock.Mock()
+
+    def test_get_cluster(self):
+        retry_mock, timeout_mock = mock.Mock(), mock.Mock()
+
+        client_get = self.gke_hook.client.get_cluster = mock.Mock()
+
+        self.gke_hook.get_cluster(name=CLUSTER_NAME, retry=retry_mock,
+                                  timeout=timeout_mock)
+
+        client_get.assert_called_with(project_id=TEST_PROJECT_ID, zone=ZONE,
+                                      cluster_id=CLUSTER_NAME,
+                                      retry=retry_mock, timeout=timeout_mock)
+
+
+class GKEClusterHookTest(unittest.TestCase):
+
+    def setUp(self):
+        with mock.patch.object(GKEClusterHook, "__init__", return_value=None):
+            self.gke_hook = GKEClusterHook(None, None)
+        self.gke_hook.project_id = TEST_PROJECT_ID
+        self.gke_hook.location = ZONE
+        self.gke_hook.client = mock.Mock()
+
+    def test_get_operation(self):
+        self.gke_hook.client.get_operation = mock.Mock()
+        self.gke_hook.get_operation('TEST_OP')
+        self.gke_hook.client.get_operation.assert_called_with(project_id=TEST_PROJECT_ID,
+                                                              zone=ZONE,
+                                                              operation_id='TEST_OP')
+
+    def test_append_label(self):
+        key = 'test-key'
+        val = 'test-val'
+        mock_proto = mock.Mock()
+        self.gke_hook._append_label(mock_proto, key, val)
+        mock_proto.resource_labels.update.assert_called_with({key: val})
+
+    def test_append_label_replace(self):
+        key = 'test-key'
+        val = 'test.val+this'
+        mock_proto = mock.Mock()
+        self.gke_hook._append_label(mock_proto, key, val)
+        mock_proto.resource_labels.update.assert_called_with({key: 'test-val-this'})
+
+    @mock.patch("time.sleep")
+    def test_wait_for_response_done(self, time_mock):
+        from google.cloud.container_v1.gapic.enums import Operation
+        mock_op = mock.Mock()
+        mock_op.status = Operation.Status.DONE
+        self.gke_hook.wait_for_operation(mock_op)
+        self.assertEqual(time_mock.call_count, 1)
+
+    @mock.patch("time.sleep")
+    def test_wait_for_response_exception(self, time_mock):
+        from google.cloud.container_v1.gapic.enums import Operation
+        from google.cloud.exceptions import GoogleCloudError
+
+        mock_op = mock.Mock()
+        mock_op.status = Operation.Status.ABORTING
+
+        with self.assertRaises(GoogleCloudError):
+            self.gke_hook.wait_for_operation(mock_op)
+        # Only the initial sleep happens before the failure is detected.
+        self.assertEqual(time_mock.call_count, 1)
+
+    @mock.patch("airflow.contrib.hooks.gcp_container_hook.GKEClusterHook.get_operation")
+    @mock.patch("time.sleep")
+    def test_wait_for_response_running(self, time_mock, operation_mock):
+        from google.cloud.container_v1.gapic.enums import Operation
+
+        running_op, done_op, pending_op = mock.Mock(), mock.Mock(), mock.Mock()
+        running_op.status = Operation.Status.RUNNING
+        done_op.status = Operation.Status.DONE
+        pending_op.status = Operation.Status.PENDING
+
+        # Status goes from Running -> Pending -> Done
+        operation_mock.side_effect = [pending_op, done_op]
+        self.gke_hook.wait_for_operation(running_op)
+
+        self.assertEqual(time_mock.call_count, 3)
+        operation_mock.assert_any_call(running_op.name)
+        operation_mock.assert_any_call(pending_op.name)
+        self.assertEqual(operation_mock.call_count, 2)
+
+    @mock.patch("google.protobuf.json_format.Parse")
+    @mock.patch("json.dumps")
+    def test_dict_to_proto(self, dumps_mock, parse_mock):
+        mock_dict = {'name': 'test'}
+        mock_proto = mock.Mock()
+
+        dumps_mock.return_value = mock.Mock()
+
+        self.gke_hook._dict_to_proto(mock_dict, mock_proto)
+
+        dumps_mock.assert_called_with(mock_dict)
+        parse_mock.assert_called_with(dumps_mock.return_value, mock_proto)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/7255589f/tests/contrib/operators/test_gcp_container_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_gcp_container_operator.py b/tests/contrib/operators/test_gcp_container_operator.py
new file mode 100644
index 0000000..0f67290
--- /dev/null
+++ b/tests/contrib/operators/test_gcp_container_operator.py
@@ -0,0 +1,125 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+from airflow import AirflowException
+from airflow.contrib.operators.gcp_container_operator import GKEClusterCreateOperator, \
+    GKEClusterDeleteOperator
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+PROJECT_ID = 'test-id'
+PROJECT_LOCATION = 'test-location'
+PROJECT_TASK_ID = 'test-task-id'
+CLUSTER_NAME = 'test-cluster-name'
+
+PROJECT_BODY = {'name': 'test-name'}
+PROJECT_BODY_CREATE = {'name': 'test-name', 'initial_node_count': 1}
+
+
+class GoogleCloudPlatformContainerOperatorTest(unittest.TestCase):
+
+    @mock.patch('airflow.contrib.operators.gcp_container_operator.GKEClusterHook')
+    def test_create_execute(self, mock_hook):
+        operator = GKEClusterCreateOperator(project_id=PROJECT_ID,
+                                            location=PROJECT_LOCATION,
+                                            body=PROJECT_BODY_CREATE,
+                                            task_id=PROJECT_TASK_ID)
+
+        operator.execute(None)
+        mock_hook.return_value.create_cluster.assert_called_once_with(
+            cluster=PROJECT_BODY_CREATE)
+
+    @mock.patch('airflow.contrib.operators.gcp_container_operator.GKEClusterHook')
+    def test_create_execute_error_body(self, mock_hook):
+        with self.assertRaises(AirflowException):
+            operator = GKEClusterCreateOperator(project_id=PROJECT_ID,
+                                                location=PROJECT_LOCATION,
+                                                body=None,
+                                                task_id=PROJECT_TASK_ID)
+
+            operator.execute(None)
+        # Checked after the with-block: inside it this would never run.
+        mock_hook.return_value.create_cluster.assert_not_called()
+
+    @mock.patch('airflow.contrib.operators.gcp_container_operator.GKEClusterHook')
+    def test_create_execute_error_project_id(self, mock_hook):
+        with self.assertRaises(AirflowException):
+            operator = GKEClusterCreateOperator(location=PROJECT_LOCATION,
+                                                body=PROJECT_BODY,
+                                                task_id=PROJECT_TASK_ID)
+
+            operator.execute(None)
+        mock_hook.return_value.create_cluster.assert_not_called()
+
+    @mock.patch('airflow.contrib.operators.gcp_container_operator.GKEClusterHook')
+    def test_create_execute_error_location(self, mock_hook):
+        with self.assertRaises(AirflowException):
+            operator = GKEClusterCreateOperator(project_id=PROJECT_ID,
+                                                body=PROJECT_BODY,
+                                                task_id=PROJECT_TASK_ID)
+
+            operator.execute(None)
+        mock_hook.return_value.create_cluster.assert_not_called()
+
+    @mock.patch('airflow.contrib.operators.gcp_container_operator.GKEClusterHook')
+    def test_delete_execute(self, mock_hook):
+        operator = GKEClusterDeleteOperator(project_id=PROJECT_ID,
+                                            name=CLUSTER_NAME,
+                                            location=PROJECT_LOCATION,
+                                            task_id=PROJECT_TASK_ID)
+
+        operator.execute(None)
+        mock_hook.return_value.delete_cluster.assert_called_once_with(
+            name=CLUSTER_NAME)
+
+    @mock.patch('airflow.contrib.operators.gcp_container_operator.GKEClusterHook')
+    def test_delete_execute_error_project_id(self, mock_hook):
+        with self.assertRaises(AirflowException):
+            operator = GKEClusterDeleteOperator(location=PROJECT_LOCATION,
+                                                name=CLUSTER_NAME,
+                                                task_id=PROJECT_TASK_ID)
+            operator.execute(None)
+        mock_hook.return_value.delete_cluster.assert_not_called()
+
+    @mock.patch('airflow.contrib.operators.gcp_container_operator.GKEClusterHook')
+    def test_delete_execute_error_cluster_name(self, mock_hook):
+        with self.assertRaises(AirflowException):
+            operator = GKEClusterDeleteOperator(project_id=PROJECT_ID,
+                                                location=PROJECT_LOCATION,
+                                                task_id=PROJECT_TASK_ID)
+
+            operator.execute(None)
+        mock_hook.return_value.delete_cluster.assert_not_called()
+
+    @mock.patch('airflow.contrib.operators.gcp_container_operator.GKEClusterHook')
+    def test_delete_execute_error_location(self, mock_hook):
+        with self.assertRaises(AirflowException):
+            operator = GKEClusterDeleteOperator(project_id=PROJECT_ID,
+                                                name=CLUSTER_NAME,
+                                                task_id=PROJECT_TASK_ID)
+
+            operator.execute(None)
+        mock_hook.return_value.delete_cluster.assert_not_called()
