Repository: incubator-airflow Updated Branches: refs/heads/master d231dce37 -> 0fc45045a
[AIRFLOW-1271] Add Google CloudML Training Operator Closes #2408 from leomzhong/cloudml_training Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/0fc45045 Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/0fc45045 Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/0fc45045 Branch: refs/heads/master Commit: 0fc45045a27a0b1867410613d6c0edba820e3abf Parents: d231dce Author: Ming Zhong <[email protected]> Authored: Thu Jul 6 11:46:13 2017 -0700 Committer: Chris Riccomini <[email protected]> Committed: Thu Jul 6 11:46:13 2017 -0700 ---------------------------------------------------------------------- airflow/contrib/hooks/gcp_cloudml_hook.py | 82 ++++----- airflow/contrib/operators/cloudml_operator.py | 148 ++++++++++++++- tests/contrib/hooks/test_gcp_cloudml_hook.py | 111 +++++++++++- .../contrib/operators/test_cloudml_operator.py | 179 ++++++++++++++----- 4 files changed, 428 insertions(+), 92 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0fc45045/airflow/contrib/hooks/gcp_cloudml_hook.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/hooks/gcp_cloudml_hook.py b/airflow/contrib/hooks/gcp_cloudml_hook.py index 3af8508..6f634b2 100644 --- a/airflow/contrib/hooks/gcp_cloudml_hook.py +++ b/airflow/contrib/hooks/gcp_cloudml_hook.py @@ -62,30 +62,37 @@ class CloudMLHook(GoogleCloudBaseHook): credentials = GoogleCredentials.get_application_default() return build('ml', 'v1', credentials=credentials) - def create_job(self, project_name, job): + def create_job(self, project_name, job, use_existing_job_fn=None): """ - Creates and executes a CloudML job. - - Returns the job object if the job was created and finished - successfully, or raises an error otherwise. - - Raises: - apiclient.errors.HttpError: if the job cannot be created - successfully - - project_name is the name of the project to use, such as - 'my-project' - - job is the complete Cloud ML Job object that should be provided to the - Cloud ML API, such as - - { - 'jobId': 'my_job_id', - 'trainingInput': { - 'scaleTier': 'STANDARD_1', - ... - } - } + Launches a CloudML job and wait for it to reach a terminal state. + + :param project_name: The Google Cloud project name within which CloudML + job will be launched. + :type project_name: string + + :param job: CloudML Job object that should be provided to the CloudML + API, such as: + { + 'jobId': 'my_job_id', + 'trainingInput': { + 'scaleTier': 'STANDARD_1', + ... + } + } + :type job: dict + + :param use_existing_job_fn: In case that a CloudML job with the same + job_id already exist, this method (if provided) will decide whether + we should use this existing job, continue waiting for it to finish + and returning the job object. It should accepts a CloudML job + object, and returns a boolean value indicating whether it is OK to + reuse the existing job. If 'use_existing_job_fn' is not provided, + we by default reuse the existing CloudML job. + :type use_existing_job_fn: function + + :return: The CloudML job object if the job successfully reach a + terminal state (which might be FAILED or CANCELLED state). + :rtype: dict """ request = self._cloudml.projects().jobs().create( parent='projects/{}'.format(project_name), @@ -94,29 +101,24 @@ class CloudMLHook(GoogleCloudBaseHook): try: request.execute() - return self._wait_for_job_done(project_name, job_id) except errors.HttpError as e: + # 409 means there is an existing job with the same job ID. if e.resp.status == 409: - existing_job = self._get_job(project_name, job_id) + if use_existing_job_fn is not None: + existing_job = self._get_job(project_name, job_id) + if not use_existing_job_fn(existing_job): + logging.error( + 'Job with job_id {} already exist, but it does ' + 'not match our expectation: {}'.format( + job_id, existing_job)) + raise logging.info( - 'Job with job_id {} already exist: {}.'.format( - job_id, - existing_job)) - - if existing_job.get('predictionInput', None) == \ - job['predictionInput']: - return self._wait_for_job_done(project_name, job_id) - else: - logging.error( - 'Job with job_id {} already exists, but the ' - 'predictionInput mismatch: {}' - .format(job_id, existing_job)) - raise ValueError( - 'Found a existing job with job_id {}, but with ' - 'different predictionInput.'.format(job_id)) + 'Job with job_id {} already exist. Will waiting for it to ' + 'finish'.format(job_id)) else: logging.error('Failed to create CloudML job: {}'.format(e)) raise + return self._wait_for_job_done(project_name, job_id) def _get_job(self, project_name, job_id): """ http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0fc45045/airflow/contrib/operators/cloudml_operator.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/operators/cloudml_operator.py b/airflow/contrib/operators/cloudml_operator.py index 871cc73..3ad6f5a 100644 --- a/airflow/contrib/operators/cloudml_operator.py +++ b/airflow/contrib/operators/cloudml_operator.py @@ -18,8 +18,9 @@ import logging import re from airflow import settings -from airflow.operators import BaseOperator from airflow.contrib.hooks.gcp_cloudml_hook import CloudMLHook +from airflow.exceptions import AirflowException +from airflow.operators import BaseOperator from airflow.utils.decorators import apply_defaults from apiclient import errors @@ -239,10 +240,14 @@ class CloudMLBatchPredictionOperator(BaseOperator): def execute(self, context): hook = CloudMLHook(self.gcp_conn_id, self.delegate_to) + def check_existing_job(existing_job): + return existing_job.get('predictionInput', None) == \ + self.prediction_job_request['predictionInput'] try: finished_prediction_job = hook.create_job( self.project_id, - self.prediction_job_request) + self.prediction_job_request, + check_existing_job) except errors.HttpError: raise @@ -406,3 +411,142 @@ class CloudMLVersionOperator(BaseOperator): self._version['name']) else: raise ValueError('Unknown operation: {}'.format(self._operation)) + + +class CloudMLTrainingOperator(BaseOperator): + """ + Operator for launching a CloudML training job. + + :param project_name: The Google Cloud project name within which CloudML + training job should run. This field could be templated. + :type project_name: string + + :param job_id: A unique templated id for the submitted Google CloudML + training job. + :type job_id: string + + :param package_uris: A list of package locations for CloudML training job, + which should include the main training program + any additional + dependencies. + :type package_uris: string + + :param training_python_module: The Python module name to run within CloudML + training job after installing 'package_uris' packages. + :type training_python_module: string + + :param training_args: A list of templated command line arguments to pass to + the CloudML training program. + :type training_args: string + + :param region: The Google Compute Engine region to run the CloudML training + job in. This field could be templated. + :type region: string + + :param scale_tier: Resource tier for CloudML training job. + :type scale_tier: string + + :param gcp_conn_id: The connection ID to use when fetching connection info. + :type gcp_conn_id: string + + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: string + + :param mode: Can be one of 'DRY_RUN'/'CLOUD'. In 'DRY_RUN' mode, no real + training job will be launched, but the CloudML training job request + will be printed out. In 'CLOUD' mode, a real CloudML training job + creation request will be issued. + :type mode: string + """ + + template_fields = [ + '_project_name', + '_job_id', + '_package_uris', + '_training_python_module', + '_training_args', + '_region', + '_scale_tier', + ] + + @apply_defaults + def __init__(self, + project_name, + job_id, + package_uris, + training_python_module, + training_args, + region, + scale_tier=None, + gcp_conn_id='google_cloud_default', + delegate_to=None, + mode='PRODUCTION', + *args, + **kwargs): + super(CloudMLTrainingOperator, self).__init__(*args, **kwargs) + self._project_name = project_name + self._job_id = job_id + self._package_uris = package_uris + self._training_python_module = training_python_module + self._training_args = training_args + self._region = region + self._scale_tier = scale_tier + self._gcp_conn_id = gcp_conn_id + self._delegate_to = delegate_to + self._mode = mode + + if not self._project_name: + raise AirflowException('Google Cloud project name is required.') + if not self._job_id: + raise AirflowException( + 'An unique job id is required for Google CloudML training ' + 'job.') + if not package_uris: + raise AirflowException( + 'At least one python package is required for CloudML ' + 'Training job.') + if not training_python_module: + raise AirflowException( + 'Python module name to run after installing required ' + 'packages is required.') + if not self._region: + raise AirflowException('Google Compute Engine region is required.') + + def execute(self, context): + job_id = _normalize_cloudml_job_id(self._job_id) + training_request = { + 'jobId': job_id, + 'trainingInput': { + 'scaleTier': self._scale_tier, + 'packageUris': self._package_uris, + 'pythonModule': self._training_python_module, + 'region': self._region, + 'args': self._training_args, + } + } + + if self._mode == 'DRY_RUN': + logging.info('In dry_run mode.') + logging.info( + 'CloudML Training job request is: {}'.format(training_request)) + return + + hook = CloudMLHook( + gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to) + + # Helper method to check if the existing job's training input is the + # same as the request we get here. + def check_existing_job(existing_job): + return existing_job.get('trainingInput', None) == \ + training_request['trainingInput'] + try: + finished_training_job = hook.create_job( + self._project_name, training_request, check_existing_job) + except errors.HttpError: + raise + + if finished_training_job['state'] != 'SUCCEEDED': + logging.error('CloudML training job failed: {}'.format( + str(finished_training_job))) + raise RuntimeError(finished_training_job['errorMessage']) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0fc45045/tests/contrib/hooks/test_gcp_cloudml_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/test_gcp_cloudml_hook.py b/tests/contrib/hooks/test_gcp_cloudml_hook.py index e34e05f..53aba41 100644 --- a/tests/contrib/hooks/test_gcp_cloudml_hook.py +++ b/tests/contrib/hooks/test_gcp_cloudml_hook.py @@ -20,6 +20,7 @@ except ImportError: # python 3 from urllib.parse import urlparse, parse_qsl from airflow.contrib.hooks import gcp_cloudml_hook as hook +from apiclient import errors from apiclient.discovery import build from apiclient.http import HttpMockSequence from oauth2client.contrib.gce import HttpAccessTokenRefreshError @@ -137,8 +138,8 @@ class TestCloudMLHook(unittest.TestCase): expected_requests = [ ('{}projects/{}/models/{}/versions/{}:setDefault?alt=json'.format( - self._SERVICE_URI_PREFIX, project, model_name, version), 'POST', - '{}'), + self._SERVICE_URI_PREFIX, project, model_name, version), + 'POST', '{}'), ] with _TestCloudMLHook( @@ -175,7 +176,8 @@ class TestCloudMLHook(unittest.TestCase): self._SERVICE_URI_PREFIX, project, model_name), 'GET', None), ] + [ - ('{}projects/{}/models/{}/versions?alt=json&pageToken={}&pageSize=100'.format( + ('{}projects/{}/models/{}/versions?alt=json&pageToken={}' + '&pageSize=100'.format( self._SERVICE_URI_PREFIX, project, model_name, ix), 'GET', None) for ix in range(len(versions) - 1) ] @@ -303,6 +305,109 @@ class TestCloudMLHook(unittest.TestCase): project_name=project, job=my_job) self.assertEquals(create_job_response, my_job) + @_SKIP_IF + def test_create_cloudml_job_reuse_existing_job_by_default(self): + project = 'test-project' + job_id = 'test-job-id' + my_job = { + 'jobId': job_id, + 'foo': 4815162342, + 'state': 'SUCCEEDED', + } + response_body = json.dumps(my_job) + job_already_exist_response = ({'status': '409'}, json.dumps({})) + succeeded_response = ({'status': '200'}, response_body) + + create_job_request = ('{}projects/{}/jobs?alt=json'.format( + self._SERVICE_URI_PREFIX, project), 'POST', response_body) + ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format( + self._SERVICE_URI_PREFIX, project, job_id), 'GET', None) + expected_requests = [ + create_job_request, + ask_if_done_request, + ] + responses = [job_already_exist_response, succeeded_response] + + # By default, 'create_job' reuse the existing job. + with _TestCloudMLHook( + self, + responses=responses, + expected_requests=expected_requests) as cml_hook: + create_job_response = cml_hook.create_job( + project_name=project, job=my_job) + self.assertEquals(create_job_response, my_job) + + @_SKIP_IF + def test_create_cloudml_job_check_existing_job(self): + project = 'test-project' + job_id = 'test-job-id' + my_job = { + 'jobId': job_id, + 'foo': 4815162342, + 'state': 'SUCCEEDED', + 'someInput': { + 'input': 'someInput' + } + } + different_job = { + 'jobId': job_id, + 'foo': 4815162342, + 'state': 'SUCCEEDED', + 'someInput': { + 'input': 'someDifferentInput' + } + } + + my_job_response_body = json.dumps(my_job) + different_job_response_body = json.dumps(different_job) + job_already_exist_response = ({'status': '409'}, json.dumps({})) + different_job_response = ({'status': '200'}, + different_job_response_body) + + create_job_request = ('{}projects/{}/jobs?alt=json'.format( + self._SERVICE_URI_PREFIX, project), 'POST', my_job_response_body) + ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format( + self._SERVICE_URI_PREFIX, project, job_id), 'GET', None) + expected_requests = [ + create_job_request, + ask_if_done_request, + ] + + # Returns a different job (with different 'someInput' field) will + # cause 'create_job' request to fail. + responses = [job_already_exist_response, different_job_response] + + def check_input(existing_job): + return existing_job.get('someInput', None) == \ + my_job['someInput'] + with _TestCloudMLHook( + self, + responses=responses, + expected_requests=expected_requests) as cml_hook: + with self.assertRaises(errors.HttpError): + cml_hook.create_job( + project_name=project, job=my_job, + use_existing_job_fn=check_input) + + my_job_response = ({'status': '200'}, my_job_response_body) + expected_requests = [ + create_job_request, + ask_if_done_request, + ask_if_done_request, + ] + responses = [ + job_already_exist_response, + my_job_response, + my_job_response] + with _TestCloudMLHook( + self, + responses=responses, + expected_requests=expected_requests) as cml_hook: + create_job_response = cml_hook.create_job( + project_name=project, job=my_job, + use_existing_job_fn=check_input) + self.assertEquals(create_job_response, my_job) + if __name__ == '__main__': unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0fc45045/tests/contrib/operators/test_cloudml_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_cloudml_operator.py b/tests/contrib/operators/test_cloudml_operator.py index b76a0c6..dc8c204 100644 --- a/tests/contrib/operators/test_cloudml_operator.py +++ b/tests/contrib/operators/test_cloudml_operator.py @@ -26,41 +26,41 @@ import unittest from airflow import configuration, DAG from airflow.contrib.operators.cloudml_operator import CloudMLBatchPredictionOperator +from airflow.contrib.operators.cloudml_operator import CloudMLTrainingOperator +from mock import ANY from mock import patch DEFAULT_DATE = datetime.datetime(2017, 6, 6) -INPUT_MISSING_ORIGIN = { - 'dataFormat': 'TEXT', - 'inputPaths': ['gs://legal-bucket/fake-input-path/*'], - 'outputPath': 'gs://legal-bucket/fake-output-path', - 'region': 'us-east1', -} - -SUCCESS_MESSAGE_MISSING_INPUT = { - 'jobId': 'test_prediction', - 'predictionOutput': { - 'outputPath': 'gs://fake-output-path', - 'predictionCount': 5000, - 'errorCount': 0, - 'nodeHours': 2.78 - }, - 'state': 'SUCCEEDED' -} - -DEFAULT_ARGS = { - 'project_id': 'test-project', - 'job_id': 'test_prediction', - 'region': 'us-east1', - 'data_format': 'TEXT', - 'input_paths': ['gs://legal-bucket-dash-Capital/legal-input-path/*'], - 'output_path': 'gs://12_legal_bucket_underscore_number/legal-output-path', - 'task_id': 'test-prediction' -} - class CloudMLBatchPredictionOperatorTest(unittest.TestCase): + INPUT_MISSING_ORIGIN = { + 'dataFormat': 'TEXT', + 'inputPaths': ['gs://legal-bucket/fake-input-path/*'], + 'outputPath': 'gs://legal-bucket/fake-output-path', + 'region': 'us-east1', + } + SUCCESS_MESSAGE_MISSING_INPUT = { + 'jobId': 'test_prediction', + 'predictionOutput': { + 'outputPath': 'gs://fake-output-path', + 'predictionCount': 5000, + 'errorCount': 0, + 'nodeHours': 2.78 + }, + 'state': 'SUCCEEDED' + } + BATCH_PREDICTION_DEFAULT_ARGS = { + 'project_id': 'test-project', + 'job_id': 'test_prediction', + 'region': 'us-east1', + 'data_format': 'TEXT', + 'input_paths': ['gs://legal-bucket-dash-Capital/legal-input-path/*'], + 'output_path': + 'gs://12_legal_bucket_underscore_number/legal-output-path', + 'task_id': 'test-prediction' + } def setUp(self): super(CloudMLBatchPredictionOperatorTest, self).setUp() @@ -78,10 +78,10 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase): with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \ as mock_hook: - input_with_model = INPUT_MISSING_ORIGIN.copy() + input_with_model = self.INPUT_MISSING_ORIGIN.copy() input_with_model['modelName'] = \ 'projects/test-project/models/test_model' - success_message = SUCCESS_MESSAGE_MISSING_INPUT.copy() + success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy() success_message['predictionInput'] = input_with_model hook_instance = mock_hook.return_value @@ -104,12 +104,12 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase): prediction_output = prediction_task.execute(None) mock_hook.assert_called_with('google_cloud_default', None) - hook_instance.create_job.assert_called_with( + hook_instance.create_job.assert_called_once_with( 'test-project', { 'jobId': 'test_prediction', 'predictionInput': input_with_model - }) + }, ANY) self.assertEquals( success_message['predictionOutput'], prediction_output) @@ -118,10 +118,10 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase): with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \ as mock_hook: - input_with_version = INPUT_MISSING_ORIGIN.copy() + input_with_version = self.INPUT_MISSING_ORIGIN.copy() input_with_version['versionName'] = \ 'projects/test-project/models/test_model/versions/test_version' - success_message = SUCCESS_MESSAGE_MISSING_INPUT.copy() + success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy() success_message['predictionInput'] = input_with_version hook_instance = mock_hook.return_value @@ -132,8 +132,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase): hook_instance.create_job.return_value = success_message prediction_task = CloudMLBatchPredictionOperator( - job_id='test_prediction', - project_id='test-project', + job_id='test_prediction', project_id='test-project', region=input_with_version['region'], data_format=input_with_version['dataFormat'], input_paths=input_with_version['inputPaths'], @@ -150,7 +149,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase): { 'jobId': 'test_prediction', 'predictionInput': input_with_version - }) + }, ANY) self.assertEquals( success_message['predictionOutput'], prediction_output) @@ -159,9 +158,9 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase): with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \ as mock_hook: - input_with_uri = INPUT_MISSING_ORIGIN.copy() + input_with_uri = self.INPUT_MISSING_ORIGIN.copy() input_with_uri['uri'] = 'gs://my_bucket/my_models/savedModel' - success_message = SUCCESS_MESSAGE_MISSING_INPUT.copy() + success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy() success_message['predictionInput'] = input_with_uri hook_instance = mock_hook.return_value @@ -189,14 +188,14 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase): { 'jobId': 'test_prediction', 'predictionInput': input_with_uri - }) + }, ANY) self.assertEquals( success_message['predictionOutput'], prediction_output) def testInvalidModelOrigin(self): # Test that both uri and model is given - task_args = DEFAULT_ARGS.copy() + task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() task_args['uri'] = 'gs://fake-uri/saved_model' task_args['model_name'] = 'fake_model' with self.assertRaises(ValueError) as context: @@ -204,7 +203,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase): self.assertEquals('Ambiguous model origin.', str(context.exception)) # Test that both uri and model/version is given - task_args = DEFAULT_ARGS.copy() + task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() task_args['uri'] = 'gs://fake-uri/saved_model' task_args['model_name'] = 'fake_model' task_args['version_name'] = 'fake_version' @@ -213,7 +212,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase): self.assertEquals('Ambiguous model origin.', str(context.exception)) # Test that a version is given without a model - task_args = DEFAULT_ARGS.copy() + task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() task_args['version_name'] = 'bare_version' with self.assertRaises(ValueError) as context: CloudMLBatchPredictionOperator(**task_args).execute(None) @@ -222,7 +221,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase): str(context.exception)) # Test that none of uri, model, model/version is given - task_args = DEFAULT_ARGS.copy() + task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() with self.assertRaises(ValueError) as context: CloudMLBatchPredictionOperator(**task_args).execute(None) self.assertEquals( @@ -234,7 +233,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase): with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \ as mock_hook: - input_with_model = INPUT_MISSING_ORIGIN.copy() + input_with_model = self.INPUT_MISSING_ORIGIN.copy() input_with_model['modelName'] = \ 'projects/experimental/models/test_model' @@ -263,7 +262,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase): { 'jobId': 'test_prediction', 'predictionInput': input_with_model - }) + }, ANY) self.assertEquals(http_error_code, context.exception.resp.status) @@ -275,7 +274,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase): 'state': 'FAILED', 'errorMessage': 'A failure message' } - task_args = DEFAULT_ARGS.copy() + task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() task_args['uri'] = 'a uri' with self.assertRaises(RuntimeError) as context: @@ -284,5 +283,91 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase): self.assertEquals('A failure message', str(context.exception)) +class CloudMLTrainingOperatorTest(unittest.TestCase): + TRAINING_DEFAULT_ARGS = { + 'project_name': 'test-project', + 'job_id': 'test_training', + 'package_uris': ['gs://some-bucket/package1'], + 'training_python_module': 'trainer', + 'training_args': '--some_arg=\'aaa\'', + 'region': 'us-east1', + 'scale_tier': 'STANDARD_1', + 'task_id': 'test-training' + } + TRAINING_INPUT = { + 'jobId': 'test_training', + 'trainingInput': { + 'scaleTier': 'STANDARD_1', + 'packageUris': ['gs://some-bucket/package1'], + 'pythonModule': 'trainer', + 'args': '--some_arg=\'aaa\'', + 'region': 'us-east1' + } + } + + def testSuccessCreateTrainingJob(self): + with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \ + as mock_hook: + success_response = self.TRAINING_INPUT.copy() + success_response['state'] = 'SUCCEEDED' + hook_instance = mock_hook.return_value + hook_instance.create_job.return_value = success_response + + training_op = CloudMLTrainingOperator(**self.TRAINING_DEFAULT_ARGS) + training_op.execute(None) + + mock_hook.assert_called_with(gcp_conn_id='google_cloud_default', + delegate_to=None) + # Make sure only 'create_job' is invoked on hook instance + self.assertEquals(len(hook_instance.mock_calls), 1) + hook_instance.create_job.assert_called_with( + 'test-project', self.TRAINING_INPUT, ANY) + + def testHttpError(self): + http_error_code = 403 + with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \ + as mock_hook: + hook_instance = mock_hook.return_value + hook_instance.create_job.side_effect = errors.HttpError( + resp=httplib2.Response({ + 'status': http_error_code + }), content=b'Forbidden') + + with self.assertRaises(errors.HttpError) as context: + training_op = CloudMLTrainingOperator( + **self.TRAINING_DEFAULT_ARGS) + training_op.execute(None) + + mock_hook.assert_called_with( + gcp_conn_id='google_cloud_default', delegate_to=None) + # Make sure only 'create_job' is invoked on hook instance + self.assertEquals(len(hook_instance.mock_calls), 1) + hook_instance.create_job.assert_called_with( + 'test-project', self.TRAINING_INPUT, ANY) + self.assertEquals(http_error_code, context.exception.resp.status) + + def testFailedJobError(self): + with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \ + as mock_hook: + failure_response = self.TRAINING_INPUT.copy() + failure_response['state'] = 'FAILED' + failure_response['errorMessage'] = 'A failure message' + hook_instance = mock_hook.return_value + hook_instance.create_job.return_value = failure_response + + with self.assertRaises(RuntimeError) as context: + training_op = CloudMLTrainingOperator( + **self.TRAINING_DEFAULT_ARGS) + training_op.execute(None) + + mock_hook.assert_called_with( + gcp_conn_id='google_cloud_default', delegate_to=None) + # Make sure only 'create_job' is invoked on hook instance + self.assertEquals(len(hook_instance.mock_calls), 1) + hook_instance.create_job.assert_called_with( + 'test-project', self.TRAINING_INPUT, ANY) + self.assertEquals('A failure message', str(context.exception)) + + if __name__ == '__main__': unittest.main()
