Repository: incubator-airflow
Updated Branches:
  refs/heads/v1-9-test [created] 78124a2a6


[AIRFLOW-1273] Add Google Cloud ML version and model operators

https://issues.apache.org/jira/browse/AIRFLOW-1273

Closes #2379 from N3da/master


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/265b293a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/265b293a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/265b293a

Branch: refs/heads/v1-9-test
Commit: 265b293a7860d070458b6984594138bfae1fa5b7
Parents: e870a8e
Author: Neda Mirian <ne...@google.com>
Authored: Tue Jun 27 09:39:00 2017 -0700
Committer: Chris Riccomini <criccom...@apache.org>
Committed: Tue Jun 27 09:39:18 2017 -0700

----------------------------------------------------------------------
 airflow/contrib/hooks/__init__.py             |   1 +
 airflow/contrib/hooks/gcp_cloudml_hook.py     | 167 ++++++++++++++
 airflow/contrib/operators/cloudml_operator.py | 178 ++++++++++++++
 airflow/utils/db.py                           |   4 +
 tests/contrib/hooks/test_gcp_cloudml_hook.py  | 255 +++++++++++++++++++++
 5 files changed, 605 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/265b293a/airflow/contrib/hooks/__init__.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/__init__.py 
b/airflow/contrib/hooks/__init__.py
index 182a49f..4941314 100644
--- a/airflow/contrib/hooks/__init__.py
+++ b/airflow/contrib/hooks/__init__.py
@@ -40,6 +40,7 @@ _hooks = {
     'qubole_hook': ['QuboleHook'],
     'gcs_hook': ['GoogleCloudStorageHook'],
     'datastore_hook': ['DatastoreHook'],
+    'gcp_cloudml_hook': ['CloudMLHook'],
     'gcp_dataproc_hook': ['DataProcHook'],
     'gcp_dataflow_hook': ['DataFlowHook'],
     'spark_submit_operator': ['SparkSubmitOperator'],

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/265b293a/airflow/contrib/hooks/gcp_cloudml_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/gcp_cloudml_hook.py 
b/airflow/contrib/hooks/gcp_cloudml_hook.py
new file mode 100644
index 0000000..e722b2a
--- /dev/null
+++ b/airflow/contrib/hooks/gcp_cloudml_hook.py
@@ -0,0 +1,167 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import random
+import time
+from airflow import settings
+from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
+from apiclient.discovery import build
+from apiclient import errors
+from oauth2client.client import GoogleCredentials
+
+logging.getLogger('GoogleCloudML').setLevel(settings.LOGGING_LEVEL)
+
+
+def _poll_with_exponential_delay(request, max_n, is_done_func, is_error_func):
+
+    for i in range(0, max_n):
+        try:
+            response = request.execute()
+            if is_error_func(response):
+                raise ValueError('The response contained an error: 
{}'.format(response))
+            elif is_done_func(response):
+                logging.info('Operation is done: {}'.format(response))
+                return response
+            else:
+                time.sleep((2**i) + (random.randint(0, 1000) / 1000))
+        except errors.HttpError as e:
+            if e.resp.status != 429:
+                logging.info('Something went wrong. Not retrying: 
{}'.format(e))
+                raise e
+            else:
+                time.sleep((2**i) + (random.randint(0, 1000) / 1000))
+
+
class CloudMLHook(GoogleCloudBaseHook):
    """
    Hook for the Google Cloud ML Engine REST API (`ml`, version `v1`).

    Every method takes the bare project name (without the `projects/`
    prefix) and builds the fully-qualified resource names itself.
    """

    def __init__(self, gcp_conn_id='google_cloud_default', delegate_to=None):
        super(CloudMLHook, self).__init__(gcp_conn_id, delegate_to)
        self._cloudml = self.get_conn()

    def get_conn(self):
        """
        Returns a Google CloudML service object.
        """
        # NOTE(review): authentication uses the application-default
        # credentials; `gcp_conn_id` / `delegate_to` are not consulted here.
        # Confirm whether that is intended.
        credentials = GoogleCredentials.get_application_default()
        return build('ml', 'v1', credentials=credentials)

    def _wait_for_operation(self, operation):
        """
        Poll a long-running operation until it is done and return its final
        state. `operation` is the dict returned by the initiating request.
        """
        get_request = self._cloudml.projects().operations().get(
            name=operation['name'])
        return _poll_with_exponential_delay(
            request=get_request,
            max_n=9,
            is_done_func=lambda resp: resp.get('done', False),
            is_error_func=lambda resp: resp.get('error', None) is not None)

    def create_version(self, project_name, model_name, version_spec):
        """
        Creates the Version on Cloud ML and blocks until the create
        operation finishes.

        :param project_name: Google Cloud project name.
        :param model_name: Name of the model the version belongs to.
        :param version_spec: Request body describing the version.
        :return: The final state of the create operation.
        :raises ValueError: If the polled operation reports an error.
        """
        parent_name = 'projects/{}/models/{}'.format(project_name, model_name)
        create_request = self._cloudml.projects().models().versions().create(
            parent=parent_name, body=version_spec)
        response = create_request.execute()
        return self._wait_for_operation(response)

    def set_default_version(self, project_name, model_name, version_name):
        """
        Sets a version to be the default. Blocks until finished.

        :return: The API response of the setDefault call.
        :raises errors.HttpError: If the API call fails (logged, then
            re-raised).
        """
        full_version_name = 'projects/{}/models/{}/versions/{}'.format(
            project_name, model_name, version_name)
        request = self._cloudml.projects().models().versions().setDefault(
            name=full_version_name, body={})

        try:
            response = request.execute()
        except errors.HttpError as e:
            logging.error('Something went wrong: {}'.format(e))
            raise
        logging.info('Successfully set version: {} to default'.format(response))
        return response

    def list_versions(self, project_name, model_name):
        """
        Lists all available versions of a model. Blocks until finished.

        Follows `nextPageToken` until the API stops returning one, and
        waits 5 seconds between page requests.

        :return: A list with the `versions` entries of every page.
        """
        full_parent_name = 'projects/{}/models/{}'.format(
            project_name, model_name)
        result = []
        page_token = None
        while True:
            list_kwargs = {'parent': full_parent_name, 'pageSize': 100}
            # The first request carries no pageToken.
            if page_token is not None:
                list_kwargs['pageToken'] = page_token
            request = self._cloudml.projects().models().versions().list(
                **list_kwargs)
            response = request.execute()
            result.extend(response.get('versions', []))
            page_token = response.get('nextPageToken', None)
            if page_token is None:
                return result
            # Pause between pages; presumably to stay under API rate
            # limits. NOTE(review): confirm the 5s value is needed.
            time.sleep(5)

    def delete_version(self, project_name, model_name, version_name):
        """
        Deletes the given version of a model and blocks until the delete
        operation finishes.

        :return: The final state of the delete operation.
        :raises ValueError: If the polled operation reports an error.
        """
        full_name = 'projects/{}/models/{}/versions/{}'.format(
            project_name, model_name, version_name)
        delete_request = self._cloudml.projects().models().versions().delete(
            name=full_name)
        response = delete_request.execute()
        return self._wait_for_operation(response)

    def create_model(self, project_name, model):
        """
        Create a Model. Blocks until finished.

        :param model: Request body for the model; must contain a non-empty
            `name`.
        :return: The API response of the create call.
        :raises ValueError: If `model` has no non-empty `name`.
        """
        # Explicit check instead of `assert ... is not ''`: identity
        # comparison with a string literal is unreliable, and asserts are
        # stripped under `python -O`.
        if not model.get('name'):
            raise ValueError('The model must have a non-empty `name`.')
        project = 'projects/{}'.format(project_name)

        request = self._cloudml.projects().models().create(
            parent=project, body=model)
        return request.execute()

    def get_model(self, project_name, model_name):
        """
        Gets a Model. Blocks until finished.

        :return: The model resource, or None if the API answers 404.
        :raises ValueError: If `model_name` is empty or None.
        :raises errors.HttpError: For HTTP errors other than 404.
        """
        if not model_name:
            raise ValueError('`model_name` must be a non-empty string.')
        full_model_name = 'projects/{}/models/{}'.format(
            project_name, model_name)
        request = self._cloudml.projects().models().get(name=full_model_name)
        try:
            return request.execute()
        except errors.HttpError as e:
            if e.resp.status == 404:
                logging.error('Model was not found: {}'.format(e))
                return None
            raise

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/265b293a/airflow/contrib/operators/cloudml_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_operator.py 
b/airflow/contrib/operators/cloudml_operator.py
new file mode 100644
index 0000000..b0b6e91
--- /dev/null
+++ b/airflow/contrib/operators/cloudml_operator.py
@@ -0,0 +1,178 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the 'License'); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+from airflow import settings
+from airflow.contrib.hooks.gcp_cloudml_hook import CloudMLHook
+from airflow.operators import BaseOperator
+from airflow.utils.decorators import apply_defaults
+
+logging.getLogger('GoogleCloudML').setLevel(settings.LOGGING_LEVEL)
+
+
class CloudMLVersionOperator(BaseOperator):
    """
    Operator for managing a Google Cloud ML version.

    :param model_name: The name of the Google Cloud ML model that the version
        belongs to.
    :type model_name: string

    :param project_name: The Google Cloud project name to which CloudML
        model belongs.
    :type project_name: string

    :param version: A dictionary containing the information about the version.
        If the `operation` is `create`, `version` should contain all the
        information about this version such as name, and deploymentUri.
        If the `operation` is `set_default` or `delete`, the `version`
        parameter should contain the `name` of the version.
        If it is None, the only `operation` possible would be `list`.
    :type version: dict

    :param gcp_conn_id: The connection ID to use when fetching connection info.
    :type gcp_conn_id: string

    :param operation: The operation to perform. Available operations are:
        'create': Creates a new version in the model specified by `model_name`,
            in which case the `version` parameter should contain all the
            information to create that version
            (e.g. `name`, `deploymentUri`).
        'set_default': Makes the version named in the `version` parameter
            the default version of the model specified by `model_name`.
        'list': Lists all available versions of the model specified
            by `model_name`.
        'delete': Deletes the version specified in `version` parameter from the
            model specified by `model_name`.
            The name of the version should be specified in the `version`
            parameter.
    :type operation: string

    :param delegate_to: The account to impersonate, if any.
        For this to work, the service account making the request must have
        domain-wide delegation enabled.
    :type delegate_to: string
    """

    template_fields = [
        '_model_name',
        '_version',
    ]

    @apply_defaults
    def __init__(self,
                 model_name,
                 project_name,
                 version=None,
                 gcp_conn_id='google_cloud_default',
                 operation='create',
                 delegate_to=None,
                 *args,
                 **kwargs):

        super(CloudMLVersionOperator, self).__init__(*args, **kwargs)
        self._model_name = model_name
        self._version = version
        self._gcp_conn_id = gcp_conn_id
        self._delegate_to = delegate_to
        self._project_name = project_name
        self._operation = operation

    def execute(self, context):
        """
        Run the configured `operation` and return the hook's response.

        :raises ValueError: If `operation` is unknown, if `version` is None
            for an operation other than `list`, or if `version` has no
            `name` for `set_default` / `delete`.
        """
        operation = self._operation
        # Validate everything before building the API client.
        if operation not in ('create', 'set_default', 'list', 'delete'):
            raise ValueError('Unknown operation: {}'.format(operation))
        if operation != 'list' and self._version is None:
            raise ValueError(
                'The `version` parameter is required for operation: '
                '{}'.format(operation))
        if (operation in ('set_default', 'delete') and
                not self._version.get('name')):
            raise ValueError(
                'The `version` parameter must contain a `name` for '
                'operation: {}'.format(operation))

        hook = CloudMLHook(
            gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)

        if operation == 'create':
            return hook.create_version(
                self._project_name, self._model_name, self._version)
        if operation == 'set_default':
            return hook.set_default_version(
                self._project_name, self._model_name, self._version['name'])
        if operation == 'list':
            return hook.list_versions(self._project_name, self._model_name)
        # Only 'delete' remains after the validation above.
        return hook.delete_version(
            self._project_name, self._model_name, self._version['name'])
+
+
class CloudMLModelOperator(BaseOperator):
    """
    Operator for managing a Google Cloud ML model.

    :param model: A dictionary containing the information about the model.
        If the `operation` is `create`, then the `model` parameter should
        contain all the information about this model such as `name`.

        If the `operation` is `get`, the `model` parameter
        should contain the `name` of the model.
    :type model: dict

    :param project_name: The Google Cloud project name to which CloudML
        model belongs.
    :type project_name: string

    :param gcp_conn_id: The connection ID to use when fetching connection info.
    :type gcp_conn_id: string

    :param operation: The operation to perform. Available operations are:
        'create': Creates a new model as provided by the `model` parameter.
        'get': Gets a particular model where the name is specified in `model`.
    :type operation: string

    :param delegate_to: The account to impersonate, if any.
        For this to work, the service account making the request must have
        domain-wide delegation enabled.
    :type delegate_to: string
    """

    template_fields = [
        '_model',
    ]

    @apply_defaults
    def __init__(self,
                 model,
                 project_name,
                 gcp_conn_id='google_cloud_default',
                 operation='create',
                 delegate_to=None,
                 *args,
                 **kwargs):
        super(CloudMLModelOperator, self).__init__(*args, **kwargs)
        self._model = model
        self._operation = operation
        self._gcp_conn_id = gcp_conn_id
        self._delegate_to = delegate_to
        self._project_name = project_name

    def execute(self, context):
        """
        Run the configured `operation` and return the hook's response.

        For `get`, the response is None when the hook reports that the model
        was not found.

        :raises ValueError: If `operation` is unknown.
        """
        # Reject unknown operations before building the API client.
        if self._operation not in ('create', 'get'):
            raise ValueError('Unknown operation: {}'.format(self._operation))
        hook = CloudMLHook(
            gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
        # Return the result (previously discarded), matching
        # CloudMLVersionOperator.execute.
        if self._operation == 'create':
            return hook.create_model(self._project_name, self._model)
        return hook.get_model(self._project_name, self._model['name'])

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/265b293a/airflow/utils/db.py
----------------------------------------------------------------------
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 54254f6..04b1512 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -130,6 +130,10 @@ def initdb():
             schema='hive', port=3400))
     merge_conn(
         models.Connection(
+            conn_id='google_cloud_default', conn_type='google_cloud_platform',
+            schema='default',))
+    merge_conn(
+        models.Connection(
             conn_id='hive_cli_default', conn_type='hive_cli',
             schema='default',))
     merge_conn(

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/265b293a/tests/contrib/hooks/test_gcp_cloudml_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_cloudml_hook.py 
b/tests/contrib/hooks/test_gcp_cloudml_hook.py
new file mode 100644
index 0000000..aa50e69
--- /dev/null
+++ b/tests/contrib/hooks/test_gcp_cloudml_hook.py
@@ -0,0 +1,255 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import mock
+import unittest
+try: # python 2
+    from urlparse import urlparse, parse_qsl
+except ImportError: #python 3
+    from urllib.parse import urlparse, parse_qsl
+
+from airflow.contrib.hooks import gcp_cloudml_hook as hook
+from apiclient.discovery import build
+from apiclient.http import HttpMockSequence
+from oauth2client.contrib.gce import HttpAccessTokenRefreshError
+
# Probe once whether a real CloudML client can be built in this environment;
# the tests below are skipped when it cannot. Building the client reads the
# application-default credentials (and presumably needs network access for
# API discovery), so a missing-credentials or connectivity failure must mean
# "skip" rather than an import-time crash of this test module.
cml_available = True
try:
    hook.CloudMLHook().get_conn()
except Exception:  # broad on purpose: any failure here only means "skip"
    cml_available = False
+
+
+class _TestCloudMLHook(object):
+
+    def __init__(self, test_cls, responses, expected_requests):
+        """
+        Init method.
+
+        Usage example:
+        with _TestCloudMLHook(self, responses, expected_requests) as hook:
+            self.run_my_test(hook)
+
+        Args:
+          test_cls: The caller's instance used for test communication.
+          responses: A list of (dict_response, response_content) tuples.
+          expected_requests: A list of (uri, http_method, body) tuples.
+        """
+
+        self._test_cls = test_cls
+        self._responses = responses
+        self._expected_requests = [
+            self._normalize_requests_for_comparison(x[0], x[1], x[2]) for x in 
expected_requests]
+        self._actual_requests = []
+
+    def _normalize_requests_for_comparison(self, uri, http_method, body):
+        parts = urlparse(uri)
+        return (parts._replace(query=set(parse_qsl(parts.query))), 
http_method, body)
+
+    def __enter__(self):
+        http = HttpMockSequence(self._responses)
+        native_request_method = http.request
+
+        # Collecting requests to validate at __exit__.
+        def _request_wrapper(*args, **kwargs):
+            self._actual_requests.append(args + (kwargs['body'],))
+            return native_request_method(*args, **kwargs)
+
+        http.request = _request_wrapper
+        service_mock = build('ml', 'v1', http=http)
+        with mock.patch.object(
+                hook.CloudMLHook, 'get_conn', return_value=service_mock):
+            return hook.CloudMLHook()
+
+    def __exit__(self, *args):
+        # Propogating exceptions here since assert will silence them.
+        if any(args):
+            return None
+        self._test_cls.assertEquals(
+            [self._normalize_requests_for_comparison(x[0], x[1], x[2]) for x 
in self._actual_requests], self._expected_requests)
+
+
class TestCloudMLHook(unittest.TestCase):
    """Request/response tests for CloudMLHook against a mocked ml/v1 API."""

    _SKIP_IF = unittest.skipIf(not cml_available,
                               'CloudML is not available to run tests')
    _SERVICE_URI_PREFIX = 'https://ml.googleapis.com/v1/'

    def setUp(self):
        # The hook sleeps between operation polls and between list pages
        # (list_versions waits 5s per extra page); stub time.sleep so the
        # tests do not wait in real time.
        sleep_patcher = mock.patch.object(hook.time, 'sleep')
        sleep_patcher.start()
        self.addCleanup(sleep_patcher.stop)

    @_SKIP_IF
    def test_create_version(self):
        project = 'test-project'
        model_name = 'test-model'
        version = 'test-version'
        operation_name = 'projects/{}/operations/test-operation'.format(
            project)

        response_body = {'name': operation_name, 'done': True}
        succeeded_response = ({'status': '200'}, json.dumps(response_body))

        expected_requests = [
            ('{}projects/{}/models/{}/versions?alt=json'.format(
                self._SERVICE_URI_PREFIX, project, model_name), 'POST',
             '"{}"'.format(version)),
            ('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name),
             'GET', None),
        ]

        with _TestCloudMLHook(
                self,
                responses=[succeeded_response] * 2,
                expected_requests=expected_requests) as cml_hook:
            create_version_response = cml_hook.create_version(
                project_name=project, model_name=model_name,
                version_spec=version)
            self.assertEqual(create_version_response, response_body)

    @_SKIP_IF
    def test_set_default_version(self):
        project = 'test-project'
        model_name = 'test-model'
        version = 'test-version'
        operation_name = 'projects/{}/operations/test-operation'.format(
            project)

        response_body = {'name': operation_name, 'done': True}
        succeeded_response = ({'status': '200'}, json.dumps(response_body))

        expected_requests = [
            ('{}projects/{}/models/{}/versions/{}:setDefault?alt=json'.format(
                self._SERVICE_URI_PREFIX, project, model_name, version),
             'POST', '{}'),
        ]

        with _TestCloudMLHook(
                self,
                responses=[succeeded_response],
                expected_requests=expected_requests) as cml_hook:
            set_default_version_response = cml_hook.set_default_version(
                project_name=project, model_name=model_name,
                version_name=version)
            self.assertEqual(set_default_version_response, response_body)

    @_SKIP_IF
    def test_list_versions(self):
        project = 'test-project'
        model_name = 'test-model'
        operation_name = 'projects/{}/operations/test-operation'.format(
            project)

        # This test returns the versions one at a time.
        versions = ['ver_{}'.format(ix) for ix in range(3)]

        response_bodies = [
            {'name': operation_name, 'nextPageToken': ix, 'versions': [ver]}
            for ix, ver in enumerate(versions)]
        # The last page has no nextPageToken, which ends the pagination.
        response_bodies[-1].pop('nextPageToken')
        responses = [({'status': '200'}, json.dumps(body))
                     for body in response_bodies]

        expected_requests = [
            ('{}projects/{}/models/{}/versions?alt=json&pageSize=100'.format(
                self._SERVICE_URI_PREFIX, project, model_name), 'GET',
             None),
        ] + [
            ('{}projects/{}/models/{}/versions?alt=json&pageToken={}'
             '&pageSize=100'.format(
                 self._SERVICE_URI_PREFIX, project, model_name, ix), 'GET',
             None) for ix in range(len(versions) - 1)
        ]

        with _TestCloudMLHook(
                self,
                responses=responses,
                expected_requests=expected_requests) as cml_hook:
            list_versions_response = cml_hook.list_versions(
                project_name=project, model_name=model_name)
            self.assertEqual(list_versions_response, versions)

    @_SKIP_IF
    def test_delete_version(self):
        project = 'test-project'
        model_name = 'test-model'
        version = 'test-version'
        operation_name = 'projects/{}/operations/test-operation'.format(
            project)

        not_done_response_body = {'name': operation_name, 'done': False}
        done_response_body = {'name': operation_name, 'done': True}
        not_done_response = (
            {'status': '200'}, json.dumps(not_done_response_body))
        succeeded_response = (
            {'status': '200'}, json.dumps(done_response_body))

        expected_requests = [
            ('{}projects/{}/models/{}/versions/{}?alt=json'.format(
                self._SERVICE_URI_PREFIX, project, model_name, version),
             'DELETE', None),
            ('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name),
             'GET', None),
        ]

        with _TestCloudMLHook(
                self,
                responses=[not_done_response, succeeded_response],
                expected_requests=expected_requests) as cml_hook:
            delete_version_response = cml_hook.delete_version(
                project_name=project, model_name=model_name,
                version_name=version)
            self.assertEqual(delete_version_response, done_response_body)

    @_SKIP_IF
    def test_create_model(self):
        project = 'test-project'
        model_name = 'test-model'
        model = {
            'name': model_name,
        }
        response_body = {}
        succeeded_response = ({'status': '200'}, json.dumps(response_body))

        expected_requests = [
            ('{}projects/{}/models?alt=json'.format(
                self._SERVICE_URI_PREFIX, project), 'POST',
             json.dumps(model)),
        ]

        with _TestCloudMLHook(
                self,
                responses=[succeeded_response],
                expected_requests=expected_requests) as cml_hook:
            create_model_response = cml_hook.create_model(
                project_name=project, model=model)
            self.assertEqual(create_model_response, response_body)

    @_SKIP_IF
    def test_get_model(self):
        project = 'test-project'
        model_name = 'test-model'
        response_body = {'model': model_name}
        succeeded_response = ({'status': '200'}, json.dumps(response_body))

        expected_requests = [
            ('{}projects/{}/models/{}?alt=json'.format(
                self._SERVICE_URI_PREFIX, project, model_name), 'GET',
             None),
        ]

        with _TestCloudMLHook(
                self,
                responses=[succeeded_response],
                expected_requests=expected_requests) as cml_hook:
            get_model_response = cml_hook.get_model(
                project_name=project, model_name=model_name)
            self.assertEqual(get_model_response, response_body)
+
+
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()

Reply via email to