Repository: incubator-airflow Updated Branches: refs/heads/master 9fd0beaac -> b6d363104
[AIRFLOW-1401] Standardize cloud ml operator arguments Standardize on project_id, to be consistent with other cloud operators, better-supporting default arguments. This is one of multiple commits that will be required to resolve AIRFLOW-1401. Closes #2439 from peterjdolan/cloudml_project_id Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/b6d36310 Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/b6d36310 Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/b6d36310 Branch: refs/heads/master Commit: b6d3631043ceb896dd1f8b7ade84751a284770b0 Parents: 9fd0bea Author: Peter Dolan <[email protected]> Authored: Thu Jul 13 14:33:32 2017 -0700 Committer: Alex Guziel <[email protected]> Committed: Thu Jul 13 14:33:32 2017 -0700 ---------------------------------------------------------------------- airflow/contrib/hooks/gcp_cloudml_hook.py | 44 +++++++++--------- airflow/contrib/operators/cloudml_operator.py | 47 ++++++++++---------- tests/contrib/hooks/test_gcp_cloudml_hook.py | 20 ++++----- .../contrib/operators/test_cloudml_operator.py | 2 +- 4 files changed, 57 insertions(+), 56 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b6d36310/airflow/contrib/hooks/gcp_cloudml_hook.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/hooks/gcp_cloudml_hook.py b/airflow/contrib/hooks/gcp_cloudml_hook.py index 6f634b2..e1ff155 100644 --- a/airflow/contrib/hooks/gcp_cloudml_hook.py +++ b/airflow/contrib/hooks/gcp_cloudml_hook.py @@ -62,13 +62,13 @@ class CloudMLHook(GoogleCloudBaseHook): credentials = GoogleCredentials.get_application_default() return build('ml', 'v1', credentials=credentials) - def create_job(self, project_name, job, use_existing_job_fn=None): + def create_job(self, project_id, job, use_existing_job_fn=None): """ Launches a CloudML job and wait for it to reach a terminal state. - :param project_name: The Google Cloud project name within which CloudML + :param project_id: The Google Cloud project id within which CloudML job will be launched. - :type project_name: string + :type project_id: string :param job: CloudML Job object that should be provided to the CloudML API, such as: @@ -95,7 +95,7 @@ class CloudMLHook(GoogleCloudBaseHook): :rtype: dict """ request = self._cloudml.projects().jobs().create( - parent='projects/{}'.format(project_name), + parent='projects/{}'.format(project_id), body=job) job_id = job['jobId'] @@ -105,7 +105,7 @@ class CloudMLHook(GoogleCloudBaseHook): # 409 means there is an existing job with the same job ID. if e.resp.status == 409: if use_existing_job_fn is not None: - existing_job = self._get_job(project_name, job_id) + existing_job = self._get_job(project_id, job_id) if not use_existing_job_fn(existing_job): logging.error( 'Job with job_id {} already exist, but it does ' @@ -118,9 +118,9 @@ class CloudMLHook(GoogleCloudBaseHook): else: logging.error('Failed to create CloudML job: {}'.format(e)) raise - return self._wait_for_job_done(project_name, job_id) + return self._wait_for_job_done(project_id, job_id) - def _get_job(self, project_name, job_id): + def _get_job(self, project_id, job_id): """ Gets a CloudML job based on the job name. @@ -130,7 +130,7 @@ class CloudMLHook(GoogleCloudBaseHook): Raises: apiclient.errors.HttpError: if HTTP error is returned from server """ - job_name = 'projects/{}/jobs/{}'.format(project_name, job_id) + job_name = 'projects/{}/jobs/{}'.format(project_id, job_id) request = self._cloudml.projects().jobs().get(name=job_name) while True: try: @@ -143,7 +143,7 @@ class CloudMLHook(GoogleCloudBaseHook): logging.error('Failed to get CloudML job: {}'.format(e)) raise - def _wait_for_job_done(self, project_name, job_id, interval=30): + def _wait_for_job_done(self, project_id, job_id, interval=30): """ Waits for the Job to reach a terminal state. @@ -156,19 +156,19 @@ class CloudMLHook(GoogleCloudBaseHook): """ assert interval > 0 while True: - job = self._get_job(project_name, job_id) + job = self._get_job(project_id, job_id) if job['state'] in ['SUCCEEDED', 'FAILED', 'CANCELLED']: return job time.sleep(interval) - def create_version(self, project_name, model_name, version_spec): + def create_version(self, project_id, model_name, version_spec): """ Creates the Version on Cloud ML. Returns the operation if the version was created successfully and raises an error otherwise. """ - parent_name = 'projects/{}/models/{}'.format(project_name, model_name) + parent_name = 'projects/{}/models/{}'.format(project_id, model_name) create_request = self._cloudml.projects().models().versions().create( parent=parent_name, body=version_spec) response = create_request.execute() @@ -181,12 +181,12 @@ class CloudMLHook(GoogleCloudBaseHook): is_done_func=lambda resp: resp.get('done', False), is_error_func=lambda resp: resp.get('error', None) is not None) - def set_default_version(self, project_name, model_name, version_name): + def set_default_version(self, project_id, model_name, version_name): """ Sets a version to be the default. Blocks until finished. """ full_version_name = 'projects/{}/models/{}/versions/{}'.format( - project_name, model_name, version_name) + project_id, model_name, version_name) request = self._cloudml.projects().models().versions().setDefault( name=full_version_name, body={}) @@ -199,13 +199,13 @@ class CloudMLHook(GoogleCloudBaseHook): logging.error('Something went wrong: {}'.format(e)) raise - def list_versions(self, project_name, model_name): + def list_versions(self, project_id, model_name): """ Lists all available versions of a model. Blocks until finished. """ result = [] full_parent_name = 'projects/{}/models/{}'.format( - project_name, model_name) + project_id, model_name) request = self._cloudml.projects().models().versions().list( parent=full_parent_name, pageSize=100) @@ -223,12 +223,12 @@ class CloudMLHook(GoogleCloudBaseHook): time.sleep(5) return result - def delete_version(self, project_name, model_name, version_name): + def delete_version(self, project_id, model_name, version_name): """ Deletes the given version of a model. Blocks until finished. """ full_name = 'projects/{}/models/{}/versions/{}'.format( - project_name, model_name, version_name) + project_id, model_name, version_name) delete_request = self._cloudml.projects().models().versions().delete( name=full_name) response = delete_request.execute() @@ -241,24 +241,24 @@ class CloudMLHook(GoogleCloudBaseHook): is_done_func=lambda resp: resp.get('done', False), is_error_func=lambda resp: resp.get('error', None) is not None) - def create_model(self, project_name, model): + def create_model(self, project_id, model): """ Create a Model. Blocks until finished. """ assert model['name'] is not None and model['name'] is not '' - project = 'projects/{}'.format(project_name) + project = 'projects/{}'.format(project_id) request = self._cloudml.projects().models().create( parent=project, body=model) return request.execute() - def get_model(self, project_name, model_name): + def get_model(self, project_id, model_name): """ Gets a Model. Blocks until finished. """ assert model_name is not None and model_name is not '' full_model_name = 'projects/{}/models/{}'.format( - project_name, model_name) + project_id, model_name) request = self._cloudml.projects().models().get(name=full_model_name) try: return request.execute() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b6d36310/airflow/contrib/operators/cloudml_operator.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/operators/cloudml_operator.py b/airflow/contrib/operators/cloudml_operator.py index 3ad6f5a..34b2e83 100644 --- a/airflow/contrib/operators/cloudml_operator.py +++ b/airflow/contrib/operators/cloudml_operator.py @@ -272,9 +272,9 @@ class CloudMLModelOperator(BaseOperator): should contain the `name` of the model. :type model: dict - :param project_name: The Google Cloud project name to which CloudML + :param project_id: The Google Cloud project name to which CloudML model belongs. - :type project_name: string + :type project_id: string :param gcp_conn_id: The connection ID to use when fetching connection info. :type gcp_conn_id: string @@ -291,12 +291,13 @@ class CloudMLModelOperator(BaseOperator): template_fields = [ '_model', + '_model_name', ] @apply_defaults def __init__(self, + project_id, model, - project_name, gcp_conn_id='google_cloud_default', operation='create', delegate_to=None, @@ -307,15 +308,15 @@ class CloudMLModelOperator(BaseOperator): self._operation = operation self._gcp_conn_id = gcp_conn_id self._delegate_to = delegate_to - self._project_name = project_name + self._project_id = project_id def execute(self, context): hook = CloudMLHook( gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to) if self._operation == 'create': - hook.create_model(self._project_name, self._model) + hook.create_model(self._project_id, self._model) elif self._operation == 'get': - hook.get_model(self._project_name, self._model['name']) + hook.get_model(self._project_id, self._model['name']) else: raise ValueError('Unknown operation: {}'.format(self._operation)) @@ -328,9 +329,9 @@ class CloudMLVersionOperator(BaseOperator): belongs to. :type model_name: string - :param project_name: The Google Cloud project name to which CloudML + :param project_id: The Google Cloud project name to which CloudML model belongs. - :type project_name: string + :type project_id: string :param version: A dictionary containing the information about the version. If the `operation` is `create`, `version` should contain all the @@ -376,8 +377,8 @@ class CloudMLVersionOperator(BaseOperator): @apply_defaults def __init__(self, model_name, - project_name, - version=None, + project_id, + version, gcp_conn_id='google_cloud_default', operation='create', delegate_to=None, @@ -389,7 +390,7 @@ class CloudMLVersionOperator(BaseOperator): self._version = version self._gcp_conn_id = gcp_conn_id self._delegate_to = delegate_to - self._project_name = project_name + self._project_id = project_id self._operation = operation def execute(self, context): @@ -398,16 +399,16 @@ class CloudMLVersionOperator(BaseOperator): if self._operation == 'create': assert self._version is not None - return hook.create_version(self._project_name, self._model_name, + return hook.create_version(self._project_id, self._model_name, self._version) elif self._operation == 'set_default': return hook.set_default_version( - self._project_name, self._model_name, + self._project_id, self._model_name, self._version['name']) elif self._operation == 'list': - return hook.list_versions(self._project_name, self._model_name) + return hook.list_versions(self._project_id, self._model_name) elif self._operation == 'delete': - return hook.delete_version(self._project_name, self._model_name, + return hook.delete_version(self._project_id, self._model_name, self._version['name']) else: raise ValueError('Unknown operation: {}'.format(self._operation)) @@ -417,9 +418,9 @@ class CloudMLTrainingOperator(BaseOperator): """ Operator for launching a CloudML training job. - :param project_name: The Google Cloud project name within which CloudML + :param project_id: The Google Cloud project name within which CloudML training job should run. This field could be templated. - :type project_name: string + :type project_id: string :param job_id: A unique templated id for the submitted Google CloudML training job. @@ -461,7 +462,7 @@ class CloudMLTrainingOperator(BaseOperator): """ template_fields = [ - '_project_name', + '_project_id', '_job_id', '_package_uris', '_training_python_module', @@ -472,7 +473,7 @@ class CloudMLTrainingOperator(BaseOperator): @apply_defaults def __init__(self, - project_name, + project_id, job_id, package_uris, training_python_module, @@ -485,7 +486,7 @@ class CloudMLTrainingOperator(BaseOperator): *args, **kwargs): super(CloudMLTrainingOperator, self).__init__(*args, **kwargs) - self._project_name = project_name + self._project_id = project_id self._job_id = job_id self._package_uris = package_uris self._training_python_module = training_python_module @@ -496,8 +497,8 @@ class CloudMLTrainingOperator(BaseOperator): self._delegate_to = delegate_to self._mode = mode - if not self._project_name: - raise AirflowException('Google Cloud project name is required.') + if not self._project_id: + raise AirflowException('Google Cloud project id is required.') if not self._job_id: raise AirflowException( 'An unique job id is required for Google CloudML training ' @@ -542,7 +543,7 @@ class CloudMLTrainingOperator(BaseOperator): training_request['trainingInput'] try: finished_training_job = hook.create_job( - self._project_name, training_request, check_existing_job) + self._project_id, training_request, check_existing_job) except errors.HttpError: raise http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b6d36310/tests/contrib/hooks/test_gcp_cloudml_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/test_gcp_cloudml_hook.py b/tests/contrib/hooks/test_gcp_cloudml_hook.py index 53aba41..f56018d 100644 --- a/tests/contrib/hooks/test_gcp_cloudml_hook.py +++ b/tests/contrib/hooks/test_gcp_cloudml_hook.py @@ -121,7 +121,7 @@ class TestCloudMLHook(unittest.TestCase): responses=[succeeded_response] * 2, expected_requests=expected_requests) as cml_hook: create_version_response = cml_hook.create_version( - project_name=project, model_name=model_name, + project_id=project, model_name=model_name, version_spec=version) self.assertEquals(create_version_response, response_body) @@ -147,7 +147,7 @@ class TestCloudMLHook(unittest.TestCase): responses=[succeeded_response], expected_requests=expected_requests) as cml_hook: set_default_version_response = cml_hook.set_default_version( - project_name=project, model_name=model_name, + project_id=project, model_name=model_name, version_name=version) self.assertEquals(set_default_version_response, response_body) @@ -187,7 +187,7 @@ class TestCloudMLHook(unittest.TestCase): responses=responses, expected_requests=expected_requests) as cml_hook: list_versions_response = cml_hook.list_versions( - project_name=project, model_name=model_name) + project_id=project, model_name=model_name) self.assertEquals(list_versions_response, versions) @_SKIP_IF @@ -220,7 +220,7 @@ class TestCloudMLHook(unittest.TestCase): responses=[not_done_response, succeeded_response], expected_requests=expected_requests) as cml_hook: delete_version_response = cml_hook.delete_version( - project_name=project, model_name=model_name, + project_id=project, model_name=model_name, version_name=version) self.assertEquals(delete_version_response, done_response_body) @@ -245,7 +245,7 @@ class TestCloudMLHook(unittest.TestCase): responses=[succeeded_response], expected_requests=expected_requests) as cml_hook: create_model_response = cml_hook.create_model( - project_name=project, model=model) + project_id=project, model=model) self.assertEquals(create_model_response, response_body) @_SKIP_IF @@ -266,7 +266,7 @@ class TestCloudMLHook(unittest.TestCase): responses=[succeeded_response], expected_requests=expected_requests) as cml_hook: get_model_response = cml_hook.get_model( - project_name=project, model_name=model_name) + project_id=project, model_name=model_name) self.assertEquals(get_model_response, response_body) @_SKIP_IF @@ -302,7 +302,7 @@ class TestCloudMLHook(unittest.TestCase): responses=responses, expected_requests=expected_requests) as cml_hook: create_job_response = cml_hook.create_job( - project_name=project, job=my_job) + project_id=project, job=my_job) self.assertEquals(create_job_response, my_job) @_SKIP_IF @@ -334,7 +334,7 @@ class TestCloudMLHook(unittest.TestCase): responses=responses, expected_requests=expected_requests) as cml_hook: create_job_response = cml_hook.create_job( - project_name=project, job=my_job) + project_id=project, job=my_job) self.assertEquals(create_job_response, my_job) @_SKIP_IF @@ -386,7 +386,7 @@ class TestCloudMLHook(unittest.TestCase): expected_requests=expected_requests) as cml_hook: with self.assertRaises(errors.HttpError): cml_hook.create_job( - project_name=project, job=my_job, + project_id=project, job=my_job, use_existing_job_fn=check_input) my_job_response = ({'status': '200'}, my_job_response_body) @@ -404,7 +404,7 @@ class TestCloudMLHook(unittest.TestCase): responses=responses, expected_requests=expected_requests) as cml_hook: create_job_response = cml_hook.create_job( - project_name=project, job=my_job, + project_id=project, job=my_job, use_existing_job_fn=check_input) self.assertEquals(create_job_response, my_job) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b6d36310/tests/contrib/operators/test_cloudml_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_cloudml_operator.py b/tests/contrib/operators/test_cloudml_operator.py index dc8c204..dc2366e 100644 --- a/tests/contrib/operators/test_cloudml_operator.py +++ b/tests/contrib/operators/test_cloudml_operator.py @@ -285,7 +285,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase): class CloudMLTrainingOperatorTest(unittest.TestCase): TRAINING_DEFAULT_ARGS = { - 'project_name': 'test-project', + 'project_id': 'test-project', 'job_id': 'test_training', 'package_uris': ['gs://some-bucket/package1'], 'training_python_module': 'trainer',
