[AIRFLOW-1567][Airflow-1567] Renamed cloudml hook and operator to mlengine Closes #2567 from yk5/cmle
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/af91e2ac Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/af91e2ac Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/af91e2ac Branch: refs/heads/master Commit: af91e2ac0636685c0c1c25ddeba97f78b7009b88 Parents: 86063ba Author: Younghee Kwon <[email protected]> Authored: Wed Sep 6 09:51:17 2017 -0700 Committer: Chris Riccomini <[email protected]> Committed: Wed Sep 6 09:51:17 2017 -0700 ---------------------------------------------------------------------- airflow/contrib/hooks/gcp_cloudml_hook.py | 269 --------- airflow/contrib/hooks/gcp_mlengine_hook.py | 269 +++++++++ airflow/contrib/operators/cloudml_operator.py | 565 ------------------- .../contrib/operators/cloudml_operator_utils.py | 245 -------- .../operators/cloudml_prediction_summary.py | 177 ------ airflow/contrib/operators/mlengine_operator.py | 564 ++++++++++++++++++ .../operators/mlengine_operator_utils.py | 245 ++++++++ .../operators/mlengine_prediction_summary.py | 177 ++++++ tests/contrib/hooks/test_gcp_cloudml_hook.py | 413 -------------- tests/contrib/hooks/test_gcp_mlengine_hook.py | 413 ++++++++++++++ .../contrib/operators/test_cloudml_operator.py | 373 ------------ .../operators/test_cloudml_operator_utils.py | 183 ------ .../contrib/operators/test_mlengine_operator.py | 373 ++++++++++++ .../operators/test_mlengine_operator_utils.py | 183 ++++++ 14 files changed, 2224 insertions(+), 2225 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/hooks/gcp_cloudml_hook.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/hooks/gcp_cloudml_hook.py b/airflow/contrib/hooks/gcp_cloudml_hook.py deleted file mode 100644 index e1ff155..0000000 --- 
a/airflow/contrib/hooks/gcp_cloudml_hook.py +++ /dev/null @@ -1,269 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging -import random -import time -from airflow import settings -from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook -from apiclient.discovery import build -from apiclient import errors -from oauth2client.client import GoogleCredentials - -logging.getLogger('GoogleCloudML').setLevel(settings.LOGGING_LEVEL) - - -def _poll_with_exponential_delay(request, max_n, is_done_func, is_error_func): - - for i in range(0, max_n): - try: - response = request.execute() - if is_error_func(response): - raise ValueError( - 'The response contained an error: {}'.format(response)) - elif is_done_func(response): - logging.info('Operation is done: {}'.format(response)) - return response - else: - time.sleep((2**i) + (random.randint(0, 1000) / 1000)) - except errors.HttpError as e: - if e.resp.status != 429: - logging.info( - 'Something went wrong. 
Not retrying: {}'.format(e)) - raise - else: - time.sleep((2**i) + (random.randint(0, 1000) / 1000)) - - -class CloudMLHook(GoogleCloudBaseHook): - - def __init__(self, gcp_conn_id='google_cloud_default', delegate_to=None): - super(CloudMLHook, self).__init__(gcp_conn_id, delegate_to) - self._cloudml = self.get_conn() - - def get_conn(self): - """ - Returns a Google CloudML service object. - """ - credentials = GoogleCredentials.get_application_default() - return build('ml', 'v1', credentials=credentials) - - def create_job(self, project_id, job, use_existing_job_fn=None): - """ - Launches a CloudML job and wait for it to reach a terminal state. - - :param project_id: The Google Cloud project id within which CloudML - job will be launched. - :type project_id: string - - :param job: CloudML Job object that should be provided to the CloudML - API, such as: - { - 'jobId': 'my_job_id', - 'trainingInput': { - 'scaleTier': 'STANDARD_1', - ... - } - } - :type job: dict - - :param use_existing_job_fn: In case that a CloudML job with the same - job_id already exist, this method (if provided) will decide whether - we should use this existing job, continue waiting for it to finish - and returning the job object. It should accepts a CloudML job - object, and returns a boolean value indicating whether it is OK to - reuse the existing job. If 'use_existing_job_fn' is not provided, - we by default reuse the existing CloudML job. - :type use_existing_job_fn: function - - :return: The CloudML job object if the job successfully reach a - terminal state (which might be FAILED or CANCELLED state). - :rtype: dict - """ - request = self._cloudml.projects().jobs().create( - parent='projects/{}'.format(project_id), - body=job) - job_id = job['jobId'] - - try: - request.execute() - except errors.HttpError as e: - # 409 means there is an existing job with the same job ID. 
- if e.resp.status == 409: - if use_existing_job_fn is not None: - existing_job = self._get_job(project_id, job_id) - if not use_existing_job_fn(existing_job): - logging.error( - 'Job with job_id {} already exist, but it does ' - 'not match our expectation: {}'.format( - job_id, existing_job)) - raise - logging.info( - 'Job with job_id {} already exist. Will waiting for it to ' - 'finish'.format(job_id)) - else: - logging.error('Failed to create CloudML job: {}'.format(e)) - raise - return self._wait_for_job_done(project_id, job_id) - - def _get_job(self, project_id, job_id): - """ - Gets a CloudML job based on the job name. - - :return: CloudML job object if succeed. - :rtype: dict - - Raises: - apiclient.errors.HttpError: if HTTP error is returned from server - """ - job_name = 'projects/{}/jobs/{}'.format(project_id, job_id) - request = self._cloudml.projects().jobs().get(name=job_name) - while True: - try: - return request.execute() - except errors.HttpError as e: - if e.resp.status == 429: - # polling after 30 seconds when quota failure occurs - time.sleep(30) - else: - logging.error('Failed to get CloudML job: {}'.format(e)) - raise - - def _wait_for_job_done(self, project_id, job_id, interval=30): - """ - Waits for the Job to reach a terminal state. - - This method will periodically check the job state until the job reach - a terminal state. - - Raises: - apiclient.errors.HttpError: if HTTP error is returned when getting - the job - """ - assert interval > 0 - while True: - job = self._get_job(project_id, job_id) - if job['state'] in ['SUCCEEDED', 'FAILED', 'CANCELLED']: - return job - time.sleep(interval) - - def create_version(self, project_id, model_name, version_spec): - """ - Creates the Version on Cloud ML. - - Returns the operation if the version was created successfully and - raises an error otherwise. 
- """ - parent_name = 'projects/{}/models/{}'.format(project_id, model_name) - create_request = self._cloudml.projects().models().versions().create( - parent=parent_name, body=version_spec) - response = create_request.execute() - get_request = self._cloudml.projects().operations().get( - name=response['name']) - - return _poll_with_exponential_delay( - request=get_request, - max_n=9, - is_done_func=lambda resp: resp.get('done', False), - is_error_func=lambda resp: resp.get('error', None) is not None) - - def set_default_version(self, project_id, model_name, version_name): - """ - Sets a version to be the default. Blocks until finished. - """ - full_version_name = 'projects/{}/models/{}/versions/{}'.format( - project_id, model_name, version_name) - request = self._cloudml.projects().models().versions().setDefault( - name=full_version_name, body={}) - - try: - response = request.execute() - logging.info( - 'Successfully set version: {} to default'.format(response)) - return response - except errors.HttpError as e: - logging.error('Something went wrong: {}'.format(e)) - raise - - def list_versions(self, project_id, model_name): - """ - Lists all available versions of a model. Blocks until finished. 
- """ - result = [] - full_parent_name = 'projects/{}/models/{}'.format( - project_id, model_name) - request = self._cloudml.projects().models().versions().list( - parent=full_parent_name, pageSize=100) - - response = request.execute() - next_page_token = response.get('nextPageToken', None) - result.extend(response.get('versions', [])) - while next_page_token is not None: - next_request = self._cloudml.projects().models().versions().list( - parent=full_parent_name, - pageToken=next_page_token, - pageSize=100) - response = next_request.execute() - next_page_token = response.get('nextPageToken', None) - result.extend(response.get('versions', [])) - time.sleep(5) - return result - - def delete_version(self, project_id, model_name, version_name): - """ - Deletes the given version of a model. Blocks until finished. - """ - full_name = 'projects/{}/models/{}/versions/{}'.format( - project_id, model_name, version_name) - delete_request = self._cloudml.projects().models().versions().delete( - name=full_name) - response = delete_request.execute() - get_request = self._cloudml.projects().operations().get( - name=response['name']) - - return _poll_with_exponential_delay( - request=get_request, - max_n=9, - is_done_func=lambda resp: resp.get('done', False), - is_error_func=lambda resp: resp.get('error', None) is not None) - - def create_model(self, project_id, model): - """ - Create a Model. Blocks until finished. - """ - assert model['name'] is not None and model['name'] is not '' - project = 'projects/{}'.format(project_id) - - request = self._cloudml.projects().models().create( - parent=project, body=model) - return request.execute() - - def get_model(self, project_id, model_name): - """ - Gets a Model. Blocks until finished. 
- """ - assert model_name is not None and model_name is not '' - full_model_name = 'projects/{}/models/{}'.format( - project_id, model_name) - request = self._cloudml.projects().models().get(name=full_model_name) - try: - return request.execute() - except errors.HttpError as e: - if e.resp.status == 404: - logging.error('Model was not found: {}'.format(e)) - return None - raise http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/hooks/gcp_mlengine_hook.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/hooks/gcp_mlengine_hook.py b/airflow/contrib/hooks/gcp_mlengine_hook.py new file mode 100644 index 0000000..47d9700 --- /dev/null +++ b/airflow/contrib/hooks/gcp_mlengine_hook.py @@ -0,0 +1,269 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import logging +import random +import time +from airflow import settings +from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook +from apiclient.discovery import build +from apiclient import errors +from oauth2client.client import GoogleCredentials + +logging.getLogger('GoogleCloudMLEngine').setLevel(settings.LOGGING_LEVEL) + + +def _poll_with_exponential_delay(request, max_n, is_done_func, is_error_func): + + for i in range(0, max_n): + try: + response = request.execute() + if is_error_func(response): + raise ValueError( + 'The response contained an error: {}'.format(response)) + elif is_done_func(response): + logging.info('Operation is done: {}'.format(response)) + return response + else: + time.sleep((2**i) + (random.randint(0, 1000) / 1000)) + except errors.HttpError as e: + if e.resp.status != 429: + logging.info( + 'Something went wrong. Not retrying: {}'.format(e)) + raise + else: + time.sleep((2**i) + (random.randint(0, 1000) / 1000)) + + +class MLEngineHook(GoogleCloudBaseHook): + + def __init__(self, gcp_conn_id='google_cloud_default', delegate_to=None): + super(MLEngineHook, self).__init__(gcp_conn_id, delegate_to) + self._mlengine = self.get_conn() + + def get_conn(self): + """ + Returns a Google MLEngine service object. + """ + credentials = GoogleCredentials.get_application_default() + return build('ml', 'v1', credentials=credentials) + + def create_job(self, project_id, job, use_existing_job_fn=None): + """ + Launches a MLEngine job and wait for it to reach a terminal state. + + :param project_id: The Google Cloud project id within which MLEngine + job will be launched. + :type project_id: string + + :param job: MLEngine Job object that should be provided to the MLEngine + API, such as: + { + 'jobId': 'my_job_id', + 'trainingInput': { + 'scaleTier': 'STANDARD_1', + ... 
+ } + } + :type job: dict + + :param use_existing_job_fn: In case that a MLEngine job with the same + job_id already exist, this method (if provided) will decide whether + we should use this existing job, continue waiting for it to finish + and returning the job object. It should accepts a MLEngine job + object, and returns a boolean value indicating whether it is OK to + reuse the existing job. If 'use_existing_job_fn' is not provided, + we by default reuse the existing MLEngine job. + :type use_existing_job_fn: function + + :return: The MLEngine job object if the job successfully reach a + terminal state (which might be FAILED or CANCELLED state). + :rtype: dict + """ + request = self._mlengine.projects().jobs().create( + parent='projects/{}'.format(project_id), + body=job) + job_id = job['jobId'] + + try: + request.execute() + except errors.HttpError as e: + # 409 means there is an existing job with the same job ID. + if e.resp.status == 409: + if use_existing_job_fn is not None: + existing_job = self._get_job(project_id, job_id) + if not use_existing_job_fn(existing_job): + logging.error( + 'Job with job_id {} already exist, but it does ' + 'not match our expectation: {}'.format( + job_id, existing_job)) + raise + logging.info( + 'Job with job_id {} already exist. Will waiting for it to ' + 'finish'.format(job_id)) + else: + logging.error('Failed to create MLEngine job: {}'.format(e)) + raise + return self._wait_for_job_done(project_id, job_id) + + def _get_job(self, project_id, job_id): + """ + Gets a MLEngine job based on the job name. + + :return: MLEngine job object if succeed. 
+ :rtype: dict + + Raises: + apiclient.errors.HttpError: if HTTP error is returned from server + """ + job_name = 'projects/{}/jobs/{}'.format(project_id, job_id) + request = self._mlengine.projects().jobs().get(name=job_name) + while True: + try: + return request.execute() + except errors.HttpError as e: + if e.resp.status == 429: + # polling after 30 seconds when quota failure occurs + time.sleep(30) + else: + logging.error('Failed to get MLEngine job: {}'.format(e)) + raise + + def _wait_for_job_done(self, project_id, job_id, interval=30): + """ + Waits for the Job to reach a terminal state. + + This method will periodically check the job state until the job reach + a terminal state. + + Raises: + apiclient.errors.HttpError: if HTTP error is returned when getting + the job + """ + assert interval > 0 + while True: + job = self._get_job(project_id, job_id) + if job['state'] in ['SUCCEEDED', 'FAILED', 'CANCELLED']: + return job + time.sleep(interval) + + def create_version(self, project_id, model_name, version_spec): + """ + Creates the Version on Google Cloud ML Engine. + + Returns the operation if the version was created successfully and + raises an error otherwise. + """ + parent_name = 'projects/{}/models/{}'.format(project_id, model_name) + create_request = self._mlengine.projects().models().versions().create( + parent=parent_name, body=version_spec) + response = create_request.execute() + get_request = self._mlengine.projects().operations().get( + name=response['name']) + + return _poll_with_exponential_delay( + request=get_request, + max_n=9, + is_done_func=lambda resp: resp.get('done', False), + is_error_func=lambda resp: resp.get('error', None) is not None) + + def set_default_version(self, project_id, model_name, version_name): + """ + Sets a version to be the default. Blocks until finished. 
+ """ + full_version_name = 'projects/{}/models/{}/versions/{}'.format( + project_id, model_name, version_name) + request = self._mlengine.projects().models().versions().setDefault( + name=full_version_name, body={}) + + try: + response = request.execute() + logging.info( + 'Successfully set version: {} to default'.format(response)) + return response + except errors.HttpError as e: + logging.error('Something went wrong: {}'.format(e)) + raise + + def list_versions(self, project_id, model_name): + """ + Lists all available versions of a model. Blocks until finished. + """ + result = [] + full_parent_name = 'projects/{}/models/{}'.format( + project_id, model_name) + request = self._mlengine.projects().models().versions().list( + parent=full_parent_name, pageSize=100) + + response = request.execute() + next_page_token = response.get('nextPageToken', None) + result.extend(response.get('versions', [])) + while next_page_token is not None: + next_request = self._mlengine.projects().models().versions().list( + parent=full_parent_name, + pageToken=next_page_token, + pageSize=100) + response = next_request.execute() + next_page_token = response.get('nextPageToken', None) + result.extend(response.get('versions', [])) + time.sleep(5) + return result + + def delete_version(self, project_id, model_name, version_name): + """ + Deletes the given version of a model. Blocks until finished. + """ + full_name = 'projects/{}/models/{}/versions/{}'.format( + project_id, model_name, version_name) + delete_request = self._mlengine.projects().models().versions().delete( + name=full_name) + response = delete_request.execute() + get_request = self._mlengine.projects().operations().get( + name=response['name']) + + return _poll_with_exponential_delay( + request=get_request, + max_n=9, + is_done_func=lambda resp: resp.get('done', False), + is_error_func=lambda resp: resp.get('error', None) is not None) + + def create_model(self, project_id, model): + """ + Create a Model. 
Blocks until finished. + """ + assert model['name'] is not None and model['name'] is not '' + project = 'projects/{}'.format(project_id) + + request = self._mlengine.projects().models().create( + parent=project, body=model) + return request.execute() + + def get_model(self, project_id, model_name): + """ + Gets a Model. Blocks until finished. + """ + assert model_name is not None and model_name is not '' + full_model_name = 'projects/{}/models/{}'.format( + project_id, model_name) + request = self._mlengine.projects().models().get(name=full_model_name) + try: + return request.execute() + except errors.HttpError as e: + if e.resp.status == 404: + logging.error('Model was not found: {}'.format(e)) + return None + raise http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/operators/cloudml_operator.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/operators/cloudml_operator.py b/airflow/contrib/operators/cloudml_operator.py deleted file mode 100644 index 6bdd516..0000000 --- a/airflow/contrib/operators/cloudml_operator.py +++ /dev/null @@ -1,565 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the 'License'); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -import re - -from airflow import settings -from airflow.contrib.hooks.gcp_cloudml_hook import CloudMLHook -from airflow.exceptions import AirflowException -from airflow.operators import BaseOperator -from airflow.utils.decorators import apply_defaults -from apiclient import errors - - -logging.getLogger('GoogleCloudML').setLevel(settings.LOGGING_LEVEL) - - -def _create_prediction_input(project_id, - region, - data_format, - input_paths, - output_path, - model_name=None, - version_name=None, - uri=None, - max_worker_count=None, - runtime_version=None): - """ - Create the batch prediction input from the given parameters. - - Args: - A subset of arguments documented in __init__ method of class - CloudMLBatchPredictionOperator - - Returns: - A dictionary representing the predictionInput object as documented - in https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs. - - Raises: - ValueError: if a unique model/version origin cannot be determined. - """ - - prediction_input = { - 'dataFormat': data_format, - 'inputPaths': input_paths, - 'outputPath': output_path, - 'region': region - } - - if uri: - if model_name or version_name: - logging.error( - 'Ambiguous model origin: Both uri and model/version name are ' - 'provided.') - raise ValueError('Ambiguous model origin.') - prediction_input['uri'] = uri - elif model_name: - origin_name = 'projects/{}/models/{}'.format(project_id, model_name) - if not version_name: - prediction_input['modelName'] = origin_name - else: - prediction_input['versionName'] = \ - origin_name + '/versions/{}'.format(version_name) - else: - logging.error( - 'Missing model origin: Batch prediction expects a model, ' - 'a model & version combination, or a URI to savedModel.') - raise ValueError('Missing model origin.') - - if max_worker_count: - prediction_input['maxWorkerCount'] = max_worker_count - if runtime_version: - prediction_input['runtimeVersion'] = runtime_version - - return prediction_input - - -def 
_normalize_cloudml_job_id(job_id): - """ - Replaces invalid CloudML job_id characters with '_'. - - This also adds a leading 'z' in case job_id starts with an invalid - character. - - Args: - job_id: A job_id str that may have invalid characters. - - Returns: - A valid job_id representation. - """ - match = re.search(r'\d', job_id) - if match and match.start() is 0: - job_id = 'z_{}'.format(job_id) - return re.sub('[^0-9a-zA-Z]+', '_', job_id) - - -class CloudMLBatchPredictionOperator(BaseOperator): - """ - Start a Cloud ML prediction job. - - NOTE: For model origin, users should consider exactly one from the - three options below: - 1. Populate 'uri' field only, which should be a GCS location that - points to a tensorflow savedModel directory. - 2. Populate 'model_name' field only, which refers to an existing - model, and the default version of the model will be used. - 3. Populate both 'model_name' and 'version_name' fields, which - refers to a specific version of a specific model. - - In options 2 and 3, both model and version name should contain the - minimal identifier. For instance, call - CloudMLBatchPredictionOperator( - ..., - model_name='my_model', - version_name='my_version', - ...) - if the desired model version is - "projects/my_project/models/my_model/versions/my_version". - - - :param project_id: The Google Cloud project name where the - prediction job is submitted. - :type project_id: string - - :param job_id: A unique id for the prediction job on Google Cloud - ML Engine. - :type job_id: string - - :param data_format: The format of the input data. - It will default to 'DATA_FORMAT_UNSPECIFIED' if is not provided - or is not one of ["TEXT", "TF_RECORD", "TF_RECORD_GZIP"]. - :type data_format: string - - :param input_paths: A list of GCS paths of input data for batch - prediction. Accepting wildcard operator *, but only at the end. - :type input_paths: list of string - - :param output_path: The GCS path where the prediction results are - written to. 
- :type output_path: string - - :param region: The Google Compute Engine region to run the - prediction job in.: - :type region: string - - :param model_name: The Google Cloud ML model to use for prediction. - If version_name is not provided, the default version of this - model will be used. - Should not be None if version_name is provided. - Should be None if uri is provided. - :type model_name: string - - :param version_name: The Google Cloud ML model version to use for - prediction. - Should be None if uri is provided. - :type version_name: string - - :param uri: The GCS path of the saved model to use for prediction. - Should be None if model_name is provided. - It should be a GCS path pointing to a tensorflow SavedModel. - :type uri: string - - :param max_worker_count: The maximum number of workers to be used - for parallel processing. Defaults to 10 if not specified. - :type max_worker_count: int - - :param runtime_version: The Google Cloud ML runtime version to use - for batch prediction. - :type runtime_version: string - - :param gcp_conn_id: The connection ID used for connection to Google - Cloud Platform. - :type gcp_conn_id: string - - :param delegate_to: The account to impersonate, if any. - For this to work, the service account making the request must - have doamin-wide delegation enabled. - :type delegate_to: string - - Raises: - ValueError: if a unique model/version origin cannot be determined. 
- """ - - template_fields = [ - "prediction_job_request", - ] - - @apply_defaults - def __init__(self, - project_id, - job_id, - region, - data_format, - input_paths, - output_path, - model_name=None, - version_name=None, - uri=None, - max_worker_count=None, - runtime_version=None, - gcp_conn_id='google_cloud_default', - delegate_to=None, - *args, - **kwargs): - super(CloudMLBatchPredictionOperator, self).__init__(*args, **kwargs) - - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - self.delegate_to = delegate_to - - try: - prediction_input = _create_prediction_input( - project_id, region, data_format, input_paths, output_path, - model_name, version_name, uri, max_worker_count, - runtime_version) - except ValueError as e: - logging.error( - 'Cannot create batch prediction job request due to: {}' - .format(str(e))) - raise - - self.prediction_job_request = { - 'jobId': _normalize_cloudml_job_id(job_id), - 'predictionInput': prediction_input - } - - def execute(self, context): - hook = CloudMLHook(self.gcp_conn_id, self.delegate_to) - - def check_existing_job(existing_job): - return existing_job.get('predictionInput', None) == \ - self.prediction_job_request['predictionInput'] - try: - finished_prediction_job = hook.create_job( - self.project_id, - self.prediction_job_request, - check_existing_job) - except errors.HttpError: - raise - - if finished_prediction_job['state'] != 'SUCCEEDED': - logging.error( - 'Batch prediction job failed: %s', - str(finished_prediction_job)) - raise RuntimeError(finished_prediction_job['errorMessage']) - - return finished_prediction_job['predictionOutput'] - - -class CloudMLModelOperator(BaseOperator): - """ - Operator for managing a Google Cloud ML model. - - :param model: A dictionary containing the information about the model. - If the `operation` is `create`, then the `model` parameter should - contain all the information about this model such as `name`. 
- - If the `operation` is `get`, the `model` parameter - should contain the `name` of the model. - :type model: dict - - :param project_id: The Google Cloud project name to which CloudML - model belongs. - :type project_id: string - - :param gcp_conn_id: The connection ID to use when fetching connection info. - :type gcp_conn_id: string - - :param operation: The operation to perform. Available operations are: - 'create': Creates a new model as provided by the `model` parameter. - 'get': Gets a particular model where the name is specified in `model`. - - :param delegate_to: The account to impersonate, if any. - For this to work, the service account making the request must have - domain-wide delegation enabled. - :type delegate_to: string - """ - - template_fields = [ - '_model', - '_model_name', - ] - - @apply_defaults - def __init__(self, - project_id, - model, - gcp_conn_id='google_cloud_default', - operation='create', - delegate_to=None, - *args, - **kwargs): - super(CloudMLModelOperator, self).__init__(*args, **kwargs) - self._model = model - self._operation = operation - self._gcp_conn_id = gcp_conn_id - self._delegate_to = delegate_to - self._project_id = project_id - - def execute(self, context): - hook = CloudMLHook( - gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to) - if self._operation == 'create': - hook.create_model(self._project_id, self._model) - elif self._operation == 'get': - hook.get_model(self._project_id, self._model['name']) - else: - raise ValueError('Unknown operation: {}'.format(self._operation)) - - -class CloudMLVersionOperator(BaseOperator): - """ - Operator for managing a Google Cloud ML version. - - :param model_name: The name of the Google Cloud ML model that the version - belongs to. - :type model_name: string - - :param project_id: The Google Cloud project name to which CloudML - model belongs. - :type project_id: string - - :param version: A dictionary containing the information about the version. 
- If the `operation` is `create`, `version` should contain all the - information about this version such as name, and deploymentUrl. - If the `operation` is `get` or `delete`, the `version` parameter - should contain the `name` of the version. - If it is None, the only `operation` possible would be `list`. - :type version: dict - - :param version_name: A name to use for the version being operated upon. If - not None and the `version` argument is None or does not have a value for - the `name` key, then this will be populated in the payload for the - `name` key. - :type version_name: string - - :param gcp_conn_id: The connection ID to use when fetching connection info. - :type gcp_conn_id: string - - :param operation: The operation to perform. Available operations are: - 'create': Creates a new version in the model specified by `model_name`, - in which case the `version` parameter should contain all the - information to create that version - (e.g. `name`, `deploymentUrl`). - 'get': Gets full information of a particular version in the model - specified by `model_name`. - The name of the version should be specified in the `version` - parameter. - - 'list': Lists all available versions of the model specified - by `model_name`. - - 'delete': Deletes the version specified in `version` parameter from the - model specified by `model_name`). - The name of the version should be specified in the `version` - parameter. - :type operation: string - - :param delegate_to: The account to impersonate, if any. - For this to work, the service account making the request must have - domain-wide delegation enabled. 
- :type delegate_to: string - """ - - template_fields = [ - '_model_name', - '_version', - '_version_name', - ] - - @apply_defaults - def __init__(self, - model_name, - project_id, - version=None, - version_name=None, - gcp_conn_id='google_cloud_default', - operation='create', - delegate_to=None, - *args, - **kwargs): - - super(CloudMLVersionOperator, self).__init__(*args, **kwargs) - self._model_name = model_name - self._version = version or {} - self._version_name = version_name - self._gcp_conn_id = gcp_conn_id - self._delegate_to = delegate_to - self._project_id = project_id - self._operation = operation - - def execute(self, context): - if 'name' not in self._version: - self._version['name'] = self._version_name - - hook = CloudMLHook( - gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to) - - if self._operation == 'create': - assert self._version is not None - return hook.create_version(self._project_id, self._model_name, - self._version) - elif self._operation == 'set_default': - return hook.set_default_version( - self._project_id, self._model_name, - self._version['name']) - elif self._operation == 'list': - return hook.list_versions(self._project_id, self._model_name) - elif self._operation == 'delete': - return hook.delete_version(self._project_id, self._model_name, - self._version['name']) - else: - raise ValueError('Unknown operation: {}'.format(self._operation)) - - -class CloudMLTrainingOperator(BaseOperator): - """ - Operator for launching a CloudML training job. - - :param project_id: The Google Cloud project name within which CloudML - training job should run. This field could be templated. - :type project_id: string - - :param job_id: A unique templated id for the submitted Google CloudML - training job. - :type job_id: string - - :param package_uris: A list of package locations for CloudML training job, - which should include the main training program + any additional - dependencies. 
- :type package_uris: string - - :param training_python_module: The Python module name to run within CloudML - training job after installing 'package_uris' packages. - :type training_python_module: string - - :param training_args: A list of templated command line arguments to pass to - the CloudML training program. - :type training_args: string - - :param region: The Google Compute Engine region to run the CloudML training - job in. This field could be templated. - :type region: string - - :param scale_tier: Resource tier for CloudML training job. - :type scale_tier: string - - :param gcp_conn_id: The connection ID to use when fetching connection info. - :type gcp_conn_id: string - - :param delegate_to: The account to impersonate, if any. - For this to work, the service account making the request must have - domain-wide delegation enabled. - :type delegate_to: string - - :param mode: Can be one of 'DRY_RUN'/'CLOUD'. In 'DRY_RUN' mode, no real - training job will be launched, but the CloudML training job request - will be printed out. In 'CLOUD' mode, a real CloudML training job - creation request will be issued. 
- :type mode: string - """ - - template_fields = [ - '_project_id', - '_job_id', - '_package_uris', - '_training_python_module', - '_training_args', - '_region', - '_scale_tier', - ] - - @apply_defaults - def __init__(self, - project_id, - job_id, - package_uris, - training_python_module, - training_args, - region, - scale_tier=None, - gcp_conn_id='google_cloud_default', - delegate_to=None, - mode='PRODUCTION', - *args, - **kwargs): - super(CloudMLTrainingOperator, self).__init__(*args, **kwargs) - self._project_id = project_id - self._job_id = job_id - self._package_uris = package_uris - self._training_python_module = training_python_module - self._training_args = training_args - self._region = region - self._scale_tier = scale_tier - self._gcp_conn_id = gcp_conn_id - self._delegate_to = delegate_to - self._mode = mode - - if not self._project_id: - raise AirflowException('Google Cloud project id is required.') - if not self._job_id: - raise AirflowException( - 'An unique job id is required for Google CloudML training ' - 'job.') - if not package_uris: - raise AirflowException( - 'At least one python package is required for CloudML ' - 'Training job.') - if not training_python_module: - raise AirflowException( - 'Python module name to run after installing required ' - 'packages is required.') - if not self._region: - raise AirflowException('Google Compute Engine region is required.') - - def execute(self, context): - job_id = _normalize_cloudml_job_id(self._job_id) - training_request = { - 'jobId': job_id, - 'trainingInput': { - 'scaleTier': self._scale_tier, - 'packageUris': self._package_uris, - 'pythonModule': self._training_python_module, - 'region': self._region, - 'args': self._training_args, - } - } - - if self._mode == 'DRY_RUN': - logging.info('In dry_run mode.') - logging.info( - 'CloudML Training job request is: {}'.format(training_request)) - return - - hook = CloudMLHook( - gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to) - - # Helper 
method to check if the existing job's training input is the - # same as the request we get here. - def check_existing_job(existing_job): - return existing_job.get('trainingInput', None) == \ - training_request['trainingInput'] - try: - finished_training_job = hook.create_job( - self._project_id, training_request, check_existing_job) - except errors.HttpError: - raise - - if finished_training_job['state'] != 'SUCCEEDED': - logging.error('CloudML training job failed: {}'.format( - str(finished_training_job))) - raise RuntimeError(finished_training_job['errorMessage']) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/operators/cloudml_operator_utils.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/operators/cloudml_operator_utils.py b/airflow/contrib/operators/cloudml_operator_utils.py deleted file mode 100644 index 81cd54f..0000000 --- a/airflow/contrib/operators/cloudml_operator_utils.py +++ /dev/null @@ -1,245 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the 'License'); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import base64 -import json -import os -import re - -import dill - -from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook -from airflow.contrib.operators.cloudml_operator import CloudMLBatchPredictionOperator -from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator -from airflow.exceptions import AirflowException -from airflow.operators.python_operator import PythonOperator -from six.moves.urllib.parse import urlsplit - -def create_evaluate_ops(task_prefix, - data_format, - input_paths, - prediction_path, - metric_fn_and_keys, - validate_fn, - batch_prediction_job_id=None, - project_id=None, - region=None, - dataflow_options=None, - model_uri=None, - model_name=None, - version_name=None, - dag=None): - """ - Creates Operators needed for model evaluation and returns. - - It gets prediction over inputs via Cloud ML Engine BatchPrediction API by - calling CloudMLBatchPredictionOperator, then summarize and validate - the result via Cloud Dataflow using DataFlowPythonOperator. - - For details and pricing about Batch prediction, please refer to the website - https://cloud.google.com/ml-engine/docs/how-tos/batch-predict - and for Cloud Dataflow, https://cloud.google.com/dataflow/docs/ - - It returns three chained operators for prediction, summary, and validation, - named as <prefix>-prediction, <prefix>-summary, and <prefix>-validation, - respectively. - (<prefix> should contain only alphanumeric characters or hyphen.) - - The upstream and downstream can be set accordingly like: - pred, _, val = create_evaluate_ops(...) - pred.set_upstream(upstream_op) - ... - downstream_op.set_upstream(val) - - Callers will provide two python callables, metric_fn and validate_fn, in - order to customize the evaluation behavior as they wish. - - metric_fn receives a dictionary per instance derived from json in the - batch prediction result. The keys might vary depending on the model. - It should return a tuple of metrics. 
- - validation_fn receives a dictionary of the averaged metrics that metric_fn - generated over all instances. - The key/value of the dictionary matches to what's given by - metric_fn_and_keys arg. - The dictionary contains an additional metric, 'count' to represent the - total number of instances received for evaluation. - The function would raise an exception to mark the task as failed, in a - case the validation result is not okay to proceed (i.e. to set the trained - version as default). - - Typical examples are like this: - - def get_metric_fn_and_keys(): - import math # imports should be outside of the metric_fn below. - def error_and_squared_error(inst): - label = float(inst['input_label']) - classes = float(inst['classes']) # 0 or 1 - err = abs(classes-label) - squared_err = math.pow(classes-label, 2) - return (err, squared_err) # returns a tuple. - return error_and_squared_error, ['err', 'mse'] # key order must match. - - def validate_err_and_count(summary): - if summary['err'] > 0.2: - raise ValueError('Too high err>0.2; summary=%s' % summary) - if summary['mse'] > 0.05: - raise ValueError('Too high mse>0.05; summary=%s' % summary) - if summary['count'] < 1000: - raise ValueError('Too few instances<1000; summary=%s' % summary) - return summary - - For the details on the other BatchPrediction-related arguments (project_id, - job_id, region, data_format, input_paths, prediction_path, model_uri), - please refer to CloudMLBatchPredictionOperator too. - - :param task_prefix: a prefix for the tasks. Only alphanumeric characters and - hyphen are allowed (no underscores), since this will be used as dataflow - job name, which doesn't allow other characters. - :type task_prefix: string - - :param data_format: either of 'TEXT', 'TF_RECORD', 'TF_RECORD_GZIP' - :type data_format: string - - :param input_paths: a list of input paths to be sent to BatchPrediction. - :type input_paths: list of strings - - :param prediction_path: GCS path to put the prediction results in. 
- :type prediction_path: string - - :param metric_fn_and_keys: a tuple of metric_fn and metric_keys: - - metric_fn is a function that accepts a dictionary (for an instance), - and returns a tuple of metric(s) that it calculates. - - metric_keys is a list of strings to denote the key of each metric. - :type metric_fn_and_keys: tuple of a function and a list of strings - - :param validate_fn: a function to validate whether the averaged metric(s) is - good enough to push the model. - :type validate_fn: function - - :param batch_prediction_job_id: the id to use for the Cloud ML Batch - prediction job. Passed directly to the CloudMLBatchPredictionOperator as - the job_id argument. - :type batch_prediction_job_id: string - - :param project_id: the Google Cloud Platform project id in which to execute - Cloud ML Batch Prediction and Dataflow jobs. If None, then the `dag`'s - `default_args['project_id']` will be used. - :type project_id: string - - :param region: the Google Cloud Platform region in which to execute Cloud ML - Batch Prediction and Dataflow jobs. If None, then the `dag`'s - `default_args['region']` will be used. - :type region: string - - :param dataflow_options: options to run Dataflow jobs. If None, then the - `dag`'s `default_args['dataflow_default_options']` will be used. - :type dataflow_options: dictionary - - :param model_uri: GCS path of the model exported by Tensorflow using - tensorflow.estimator.export_savedmodel(). It cannot be used with - model_name or version_name below. See CloudMLBatchPredictionOperator for - more detail. - :type model_uri: string - - :param model_name: Used to indicate a model to use for prediction. Can be - used in combination with version_name, but cannot be used together with - model_uri. See CloudMLBatchPredictionOperator for more detail. If None, - then the `dag`'s `default_args['model_name']` will be used. 
- :type model_name: string - - :param version_name: Used to indicate a model version to use for prediciton, - in combination with model_name. Cannot be used together with model_uri. - See CloudMLBatchPredictionOperator for more detail. If None, then the - `dag`'s `default_args['version_name']` will be used. - :type version_name: string - - :param dag: The `DAG` to use for all Operators. - :type dag: airflow.DAG - - :returns: a tuple of three operators, (prediction, summary, validation) - :rtype: tuple(DataFlowPythonOperator, DataFlowPythonOperator, - PythonOperator) - """ - - # Verify that task_prefix doesn't have any special characters except hyphen - # '-', which is the only allowed non-alphanumeric character by Dataflow. - if not re.match(r"^[a-zA-Z][-A-Za-z0-9]*$", task_prefix): - raise AirflowException( - "Malformed task_id for DataFlowPythonOperator (only alphanumeric " - "and hyphens are allowed but got: " + task_prefix) - - metric_fn, metric_keys = metric_fn_and_keys - if not callable(metric_fn): - raise AirflowException("`metric_fn` param must be callable.") - if not callable(validate_fn): - raise AirflowException("`validate_fn` param must be callable.") - - if dag is not None and dag.default_args is not None: - default_args = dag.default_args - project_id = project_id or default_args.get('project_id') - region = region or default_args.get('region') - model_name = model_name or default_args.get('model_name') - version_name = version_name or default_args.get('version_name') - dataflow_options = dataflow_options or \ - default_args.get('dataflow_default_options') - - evaluate_prediction = CloudMLBatchPredictionOperator( - task_id=(task_prefix + "-prediction"), - project_id=project_id, - job_id=batch_prediction_job_id, - region=region, - data_format=data_format, - input_paths=input_paths, - output_path=prediction_path, - uri=model_uri, - model_name=model_name, - version_name=version_name, - dag=dag) - - metric_fn_encoded = 
base64.b64encode(dill.dumps(metric_fn, recurse=True)) - evaluate_summary = DataFlowPythonOperator( - task_id=(task_prefix + "-summary"), - py_options=["-m"], - py_file="airflow.contrib.operators.cloudml_prediction_summary", - dataflow_default_options=dataflow_options, - options={ - "prediction_path": prediction_path, - "metric_fn_encoded": metric_fn_encoded, - "metric_keys": ','.join(metric_keys) - }, - dag=dag) - evaluate_summary.set_upstream(evaluate_prediction) - - def apply_validate_fn(*args, **kwargs): - prediction_path = kwargs["templates_dict"]["prediction_path"] - scheme, bucket, obj, _, _ = urlsplit(prediction_path) - if scheme != "gs" or not bucket or not obj: - raise ValueError("Wrong format prediction_path: %s", - prediction_path) - summary = os.path.join(obj.strip("/"), - "prediction.summary.json") - gcs_hook = GoogleCloudStorageHook() - summary = json.loads(gcs_hook.download(bucket, summary)) - return validate_fn(summary) - - evaluate_validation = PythonOperator( - task_id=(task_prefix + "-validation"), - python_callable=apply_validate_fn, - provide_context=True, - templates_dict={"prediction_path": prediction_path}, - dag=dag) - evaluate_validation.set_upstream(evaluate_summary) - - return evaluate_prediction, evaluate_summary, evaluate_validation http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/operators/cloudml_prediction_summary.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/operators/cloudml_prediction_summary.py b/airflow/contrib/operators/cloudml_prediction_summary.py deleted file mode 100644 index 3128dc3..0000000 --- a/airflow/contrib/operators/cloudml_prediction_summary.py +++ /dev/null @@ -1,177 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. 
-# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the 'License'); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""A template called by DataFlowPythonOperator to summarize BatchPrediction. - -It accepts a user function to calculate the metric(s) per instance in -the prediction results, then aggregates to output as a summary. - -Args: - --prediction_path: - The GCS folder that contains BatchPrediction results, containing - prediction.results-NNNNN-of-NNNNN files in the json format. - Output will be also stored in this folder, as 'prediction.summary.json'. - - --metric_fn_encoded: - An encoded function that calculates and returns a tuple of metric(s) - for a given instance (as a dictionary). It should be encoded - via base64.b64encode(dill.dumps(fn, recurse=True)). - - --metric_keys: - A comma-separated key(s) of the aggregated metric(s) in the summary - output. The order and the size of the keys must match to the output - of metric_fn. - The summary will have an additional key, 'count', to represent the - total number of instances, so the keys shouldn't include 'count'. - -# Usage example: -def get_metric_fn(): - import math # all imports must be outside of the function to be passed. 
- def metric_fn(inst): - label = float(inst["input_label"]) - classes = float(inst["classes"]) - prediction = float(inst["scores"][1]) - log_loss = math.log(1 + math.exp( - -(label * 2 - 1) * math.log(prediction / (1 - prediction)))) - squared_err = (classes-label)**2 - return (log_loss, squared_err) - return metric_fn -metric_fn_encoded = base64.b64encode(dill.dumps(get_metric_fn(), recurse=True)) - -airflow.contrib.operators.DataFlowPythonOperator( - task_id="summary-prediction", - py_options=["-m"], - py_file="airflow.contrib.operators.cloudml_prediction_summary", - options={ - "prediction_path": prediction_path, - "metric_fn_encoded": metric_fn_encoded, - "metric_keys": "log_loss,mse" - }, - dataflow_default_options={ - "project": "xxx", "region": "us-east1", - "staging_location": "gs://yy", "temp_location": "gs://zz", - }) - >> dag - -# When the input file is like the following: -{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]} -{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]} -{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]} -{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]} - -# The output file will be: -{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25} - -# To test outside of the dag: -subprocess.check_call(["python", - "-m", - "airflow.contrib.operators.cloudml_prediction_summary", - "--prediction_path=gs://...", - "--metric_fn_encoded=" + metric_fn_encoded, - "--metric_keys=log_loss,mse", - "--runner=DataflowRunner", - "--staging_location=gs://...", - "--temp_location=gs://...", - ]) - -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import base64 -import json -import logging -import os - -import apache_beam as beam -import dill - - -class JsonCoder(object): - def encode(self, x): - return json.dumps(x) - - def decode(self, x): - return json.loads(x) - - [email protected]_fn -def MakeSummary(pcoll, metric_fn, metric_keys): # 
pylint: disable=invalid-name - return ( - pcoll - | "ApplyMetricFnPerInstance" >> beam.Map(metric_fn) - | "PairWith1" >> beam.Map(lambda tup: tup + (1,)) - | "SumTuple" >> beam.CombineGlobally(beam.combiners.TupleCombineFn( - *([sum] * (len(metric_keys) + 1)))) - | "AverageAndMakeDict" >> beam.Map( - lambda tup: dict( - [(name, tup[i]/tup[-1]) for i, name in enumerate(metric_keys)] + - [("count", tup[-1])]))) - - -def run(argv=None): - parser = argparse.ArgumentParser() - parser.add_argument( - "--prediction_path", required=True, - help=( - "The GCS folder that contains BatchPrediction results, containing " - "prediction.results-NNNNN-of-NNNNN files in the json format. " - "Output will be also stored in this folder, as a file" - "'prediction.summary.json'.")) - parser.add_argument( - "--metric_fn_encoded", required=True, - help=( - "An encoded function that calculates and returns a tuple of " - "metric(s) for a given instance (as a dictionary). It should be " - "encoded via base64.b64encode(dill.dumps(fn, recurse=True)).")) - parser.add_argument( - "--metric_keys", required=True, - help=( - "A comma-separated keys of the aggregated metric(s) in the summary " - "output. The order and the size of the keys must match to the " - "output of metric_fn. 
The summary will have an additional key, " - "'count', to represent the total number of instances, so this flag " - "shouldn't include 'count'.")) - known_args, pipeline_args = parser.parse_known_args(argv) - - metric_fn = dill.loads(base64.b64decode(known_args.metric_fn_encoded)) - if not callable(metric_fn): - raise ValueError("--metric_fn_encoded must be an encoded callable.") - metric_keys = known_args.metric_keys.split(",") - - with beam.Pipeline( - options=beam.pipeline.PipelineOptions(pipeline_args)) as p: - # This is apache-beam ptransform's convention - # pylint: disable=no-value-for-parameter - _ = (p - | "ReadPredictionResult" >> beam.io.ReadFromText( - os.path.join(known_args.prediction_path, - "prediction.results-*-of-*"), - coder=JsonCoder()) - | "Summary" >> MakeSummary(metric_fn, metric_keys) - | "Write" >> beam.io.WriteToText( - os.path.join(known_args.prediction_path, - "prediction.summary.json"), - shard_name_template='', # without trailing -NNNNN-of-NNNNN. - coder=JsonCoder())) - # pylint: enable=no-value-for-parameter - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - run() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/operators/mlengine_operator.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/operators/mlengine_operator.py b/airflow/contrib/operators/mlengine_operator.py new file mode 100644 index 0000000..7476825 --- /dev/null +++ b/airflow/contrib/operators/mlengine_operator.py @@ -0,0 +1,564 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the 'License'); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +from airflow import settings +from airflow.contrib.hooks.gcp_mlengine_hook import MLEngineHook +from airflow.exceptions import AirflowException +from airflow.operators import BaseOperator +from airflow.utils.decorators import apply_defaults +from apiclient import errors + + +logging.getLogger('GoogleCloudMLEngine').setLevel(settings.LOGGING_LEVEL) + + +def _create_prediction_input(project_id, + region, + data_format, + input_paths, + output_path, + model_name=None, + version_name=None, + uri=None, + max_worker_count=None, + runtime_version=None): + """ + Create the batch prediction input from the given parameters. + + Args: + A subset of arguments documented in __init__ method of class + MLEngineBatchPredictionOperator + + Returns: + A dictionary representing the predictionInput object as documented + in https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs. + + Raises: + ValueError: if a unique model/version origin cannot be determined. 
+ """ + + prediction_input = { + 'dataFormat': data_format, + 'inputPaths': input_paths, + 'outputPath': output_path, + 'region': region + } + + if uri: + if model_name or version_name: + logging.error( + 'Ambiguous model origin: Both uri and model/version name are ' + 'provided.') + raise ValueError('Ambiguous model origin.') + prediction_input['uri'] = uri + elif model_name: + origin_name = 'projects/{}/models/{}'.format(project_id, model_name) + if not version_name: + prediction_input['modelName'] = origin_name + else: + prediction_input['versionName'] = \ + origin_name + '/versions/{}'.format(version_name) + else: + logging.error( + 'Missing model origin: Batch prediction expects a model, ' + 'a model & version combination, or a URI to savedModel.') + raise ValueError('Missing model origin.') + + if max_worker_count: + prediction_input['maxWorkerCount'] = max_worker_count + if runtime_version: + prediction_input['runtimeVersion'] = runtime_version + + return prediction_input + + +def _normalize_mlengine_job_id(job_id): + """ + Replaces invalid MLEngine job_id characters with '_'. + + This also adds a leading 'z' in case job_id starts with an invalid + character. + + Args: + job_id: A job_id str that may have invalid characters. + + Returns: + A valid job_id representation. + """ + match = re.search(r'\d', job_id) + if match and match.start() is 0: + job_id = 'z_{}'.format(job_id) + return re.sub('[^0-9a-zA-Z]+', '_', job_id) + + +class MLEngineBatchPredictionOperator(BaseOperator): + """ + Start a Google Cloud ML Engine prediction job. + + NOTE: For model origin, users should consider exactly one from the + three options below: + 1. Populate 'uri' field only, which should be a GCS location that + points to a tensorflow savedModel directory. + 2. Populate 'model_name' field only, which refers to an existing + model, and the default version of the model will be used. + 3. 
Populate both 'model_name' and 'version_name' fields, which + refers to a specific version of a specific model. + + In options 2 and 3, both model and version name should contain the + minimal identifier. For instance, call + MLEngineBatchPredictionOperator( + ..., + model_name='my_model', + version_name='my_version', + ...) + if the desired model version is + "projects/my_project/models/my_model/versions/my_version". + + + :param project_id: The Google Cloud project name where the + prediction job is submitted. + :type project_id: string + + :param job_id: A unique id for the prediction job on Google Cloud + ML Engine. + :type job_id: string + + :param data_format: The format of the input data. + It will default to 'DATA_FORMAT_UNSPECIFIED' if is not provided + or is not one of ["TEXT", "TF_RECORD", "TF_RECORD_GZIP"]. + :type data_format: string + + :param input_paths: A list of GCS paths of input data for batch + prediction. Accepting wildcard operator *, but only at the end. + :type input_paths: list of string + + :param output_path: The GCS path where the prediction results are + written to. + :type output_path: string + + :param region: The Google Compute Engine region to run the + prediction job in.: + :type region: string + + :param model_name: The Google Cloud ML Engine model to use for prediction. + If version_name is not provided, the default version of this + model will be used. + Should not be None if version_name is provided. + Should be None if uri is provided. + :type model_name: string + + :param version_name: The Google Cloud ML Engine model version to use for + prediction. + Should be None if uri is provided. + :type version_name: string + + :param uri: The GCS path of the saved model to use for prediction. + Should be None if model_name is provided. + It should be a GCS path pointing to a tensorflow SavedModel. + :type uri: string + + :param max_worker_count: The maximum number of workers to be used + for parallel processing. 
Defaults to 10 if not specified. + :type max_worker_count: int + + :param runtime_version: The Google Cloud ML Engine runtime version to use + for batch prediction. + :type runtime_version: string + + :param gcp_conn_id: The connection ID used for connection to Google + Cloud Platform. + :type gcp_conn_id: string + + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must + have doamin-wide delegation enabled. + :type delegate_to: string + + Raises: + ValueError: if a unique model/version origin cannot be determined. + """ + + template_fields = [ + "prediction_job_request", + ] + + @apply_defaults + def __init__(self, + project_id, + job_id, + region, + data_format, + input_paths, + output_path, + model_name=None, + version_name=None, + uri=None, + max_worker_count=None, + runtime_version=None, + gcp_conn_id='google_cloud_default', + delegate_to=None, + *args, + **kwargs): + super(MLEngineBatchPredictionOperator, self).__init__(*args, **kwargs) + + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to + + try: + prediction_input = _create_prediction_input( + project_id, region, data_format, input_paths, output_path, + model_name, version_name, uri, max_worker_count, + runtime_version) + except ValueError as e: + logging.error( + 'Cannot create batch prediction job request due to: {}' + .format(str(e))) + raise + + self.prediction_job_request = { + 'jobId': _normalize_mlengine_job_id(job_id), + 'predictionInput': prediction_input + } + + def execute(self, context): + hook = MLEngineHook(self.gcp_conn_id, self.delegate_to) + + def check_existing_job(existing_job): + return existing_job.get('predictionInput', None) == \ + self.prediction_job_request['predictionInput'] + try: + finished_prediction_job = hook.create_job( + self.project_id, + self.prediction_job_request, + check_existing_job) + except errors.HttpError: + raise + + if finished_prediction_job['state'] 
class MLEngineModelOperator(BaseOperator):
    """
    Operator for managing a Google Cloud ML Engine model.

    :param project_id: The Google Cloud project name to which MLEngine
        model belongs.
    :type project_id: string

    :param model: A dictionary containing the information about the model.
        If the `operation` is `create`, then the `model` parameter should
        contain all the information about this model such as `name`.

        If the `operation` is `get`, the `model` parameter
        should contain the `name` of the model.
    :type model: dict

    :param operation: The operation to perform. Available operations are:
        'create': Creates a new model as provided by the `model` parameter.
        'get': Gets a particular model where the name is specified in `model`.
    :type operation: string

    :param gcp_conn_id: The connection ID to use when fetching connection info.
    :type gcp_conn_id: string

    :param delegate_to: The account to impersonate, if any.
        For this to work, the service account making the request must have
        domain-wide delegation enabled.
    :type delegate_to: string

    Raises:
        ValueError: from `execute`, if `operation` is neither 'create' nor
            'get'.
    """

    template_fields = [
        '_model',
    ]

    @apply_defaults
    def __init__(self,
                 project_id,
                 model,
                 operation='create',
                 gcp_conn_id='google_cloud_default',
                 delegate_to=None,
                 *args,
                 **kwargs):
        super(MLEngineModelOperator, self).__init__(*args, **kwargs)
        self._project_id = project_id
        self._model = model
        self._operation = operation
        self._gcp_conn_id = gcp_conn_id
        self._delegate_to = delegate_to

    def execute(self, context):
        """Run the configured operation and return the hook's result."""
        hook = MLEngineHook(
            gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
        if self._operation == 'create':
            return hook.create_model(self._project_id, self._model)
        elif self._operation == 'get':
            return hook.get_model(self._project_id, self._model['name'])
        else:
            raise ValueError('Unknown operation: {}'.format(self._operation))


class MLEngineVersionOperator(BaseOperator):
    """
    Operator for managing a Google Cloud ML Engine version.

    :param project_id: The Google Cloud project name to which MLEngine
        model belongs.
    :type project_id: string

    :param model_name: The name of the Google Cloud ML Engine model that the
        version belongs to.
    :type model_name: string

    :param version_name: A name to use for the version being operated upon.
        If not None and the `version` argument is None or does not have a
        `name` key, then this will be populated in the payload for the
        `name` key.
    :type version_name: string

    :param version: A dictionary containing the information about the version.
        If the `operation` is `create`, `version` should contain all the
        information about this version such as name, and deploymentUri.
        If the `operation` is `set_default` or `delete`, the `version`
        parameter should contain the `name` of the version (unless
        `version_name` is given instead).
        If both `version` and `version_name` are None, the only `operation`
        possible is `list`.
    :type version: dict

    :param operation: The operation to perform. Available operations are:
        'create': Creates a new version in the model specified by
            `model_name`, in which case the `version` parameter should
            contain all the information to create that version
            (e.g. `name`, `deploymentUri`).

        'set_default': Sets the version named in `version` (or by
            `version_name`) as the default version of the model specified
            by `model_name`.

        'list': Lists all available versions of the model specified
            by `model_name`.

        'delete': Deletes the version specified in `version` parameter from
            the model specified by `model_name`.
            The name of the version should be specified in the `version`
            parameter.
    :type operation: string

    :param gcp_conn_id: The connection ID to use when fetching connection info.
    :type gcp_conn_id: string

    :param delegate_to: The account to impersonate, if any.
        For this to work, the service account making the request must have
        domain-wide delegation enabled.
    :type delegate_to: string

    Raises:
        ValueError: from `execute`, if `operation` is unknown, if `create` is
            requested without any version information, or if `set_default` /
            `delete` is requested without a version name.
    """

    template_fields = [
        '_model_name',
        '_version_name',
        '_version',
    ]

    @apply_defaults
    def __init__(self,
                 project_id,
                 model_name,
                 version_name=None,
                 version=None,
                 operation='create',
                 gcp_conn_id='google_cloud_default',
                 delegate_to=None,
                 *args,
                 **kwargs):

        super(MLEngineVersionOperator, self).__init__(*args, **kwargs)
        self._project_id = project_id
        self._model_name = model_name
        self._version_name = version_name
        # Normalise a missing payload to a dict so `execute` can index it.
        self._version = version or {}
        self._operation = operation
        self._gcp_conn_id = gcp_conn_id
        self._delegate_to = delegate_to

    def _require_version_name(self):
        """Return the version name, raising ValueError if none was given."""
        name = self._version.get('name')
        if not name:
            raise ValueError(
                'A version name is required for operation: {}'.format(
                    self._operation))
        return name

    def execute(self, context):
        """Run the configured operation and return the hook's result."""
        # Fall back to `version_name` only when one was actually supplied;
        # otherwise a `name: None` entry would leak into the payload.
        if 'name' not in self._version and self._version_name is not None:
            self._version['name'] = self._version_name

        hook = MLEngineHook(
            gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)

        if self._operation == 'create':
            # Explicit check instead of `assert`, which is stripped under -O
            # and could never fire because the payload defaults to `{}`.
            if not self._version:
                raise ValueError(
                    'Version information (`version` or `version_name`) is '
                    'required to create a version.')
            return hook.create_version(self._project_id, self._model_name,
                                       self._version)
        elif self._operation == 'set_default':
            return hook.set_default_version(
                self._project_id, self._model_name,
                self._require_version_name())
        elif self._operation == 'list':
            return hook.list_versions(self._project_id, self._model_name)
        elif self._operation == 'delete':
            return hook.delete_version(self._project_id, self._model_name,
                                       self._require_version_name())
        else:
            raise ValueError('Unknown operation: {}'.format(self._operation))


class MLEngineTrainingOperator(BaseOperator):
    """
    Operator for launching a MLEngine training job.

    :param project_id: The Google Cloud project name within which MLEngine
        training job should run. This field could be templated.
    :type project_id: string

    :param job_id: A unique templated id for the submitted Google MLEngine
        training job.
    :type job_id: string

    :param package_uris: A list of package locations for MLEngine training job,
        which should include the main training program + any additional
        dependencies.
    :type package_uris: list of string

    :param training_python_module: The Python module name to run within
        MLEngine training job after installing 'package_uris' packages.
    :type training_python_module: string

    :param training_args: A list of templated command line arguments to pass
        to the MLEngine training program.
    :type training_args: list of string

    :param region: The Google Compute Engine region to run the MLEngine
        training job in. This field could be templated.
    :type region: string

    :param scale_tier: Resource tier for MLEngine training job.
    :type scale_tier: string

    :param gcp_conn_id: The connection ID to use when fetching connection info.
    :type gcp_conn_id: string

    :param delegate_to: The account to impersonate, if any.
        For this to work, the service account making the request must have
        domain-wide delegation enabled.
    :type delegate_to: string

    :param mode: In 'DRY_RUN' mode, no real training job will be launched,
        but the MLEngine training job request will be logged. Any other
        value (including the default, 'PRODUCTION') issues a real MLEngine
        training job creation request.
    :type mode: string

    Raises:
        AirflowException: from the constructor, if `project_id`, `job_id`,
            `package_uris`, `training_python_module` or `region` is empty.
        RuntimeError: from `execute`, if the job returned by the hook is not
            in the 'SUCCEEDED' state.
    """

    template_fields = [
        '_project_id',
        '_job_id',
        '_package_uris',
        '_training_python_module',
        '_training_args',
        '_region',
        '_scale_tier',
    ]

    @apply_defaults
    def __init__(self,
                 project_id,
                 job_id,
                 package_uris,
                 training_python_module,
                 training_args,
                 region,
                 scale_tier=None,
                 gcp_conn_id='google_cloud_default',
                 delegate_to=None,
                 mode='PRODUCTION',
                 *args,
                 **kwargs):
        super(MLEngineTrainingOperator, self).__init__(*args, **kwargs)
        self._project_id = project_id
        self._job_id = job_id
        self._package_uris = package_uris
        self._training_python_module = training_python_module
        self._training_args = training_args
        self._region = region
        self._scale_tier = scale_tier
        self._gcp_conn_id = gcp_conn_id
        self._delegate_to = delegate_to
        self._mode = mode

        # Fail fast at DAG-definition time on missing required values.
        if not self._project_id:
            raise AirflowException('Google Cloud project id is required.')
        if not self._job_id:
            raise AirflowException(
                'An unique job id is required for Google MLEngine training '
                'job.')
        if not self._package_uris:
            raise AirflowException(
                'At least one python package is required for MLEngine '
                'Training job.')
        if not self._training_python_module:
            raise AirflowException(
                'Python module name to run after installing required '
                'packages is required.')
        if not self._region:
            raise AirflowException('Google Compute Engine region is required.')

    def execute(self, context):
        """Build the training request and submit it unless in dry-run mode."""
        job_id = _normalize_mlengine_job_id(self._job_id)
        training_request = {
            'jobId': job_id,
            'trainingInput': {
                'scaleTier': self._scale_tier,
                'packageUris': self._package_uris,
                'pythonModule': self._training_python_module,
                'region': self._region,
                'args': self._training_args,
            }
        }

        if self._mode == 'DRY_RUN':
            logging.info('In dry_run mode.')
            logging.info(
                'MLEngine Training job request is: %s', training_request)
            return

        hook = MLEngineHook(
            gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)

        # Helper method to check if the existing job's training input is the
        # same as the request we get here.
        def check_existing_job(existing_job):
            return existing_job.get('trainingInput', None) == \
                training_request['trainingInput']

        # Errors raised by the hook propagate unchanged to the caller.
        finished_training_job = hook.create_job(
            self._project_id, training_request, check_existing_job)

        state = finished_training_job['state']
        if state != 'SUCCEEDED':
            logging.error(
                'MLEngine training job failed: %s',
                str(finished_training_job))
            # `errorMessage` may be absent; do not let a KeyError hide the
            # real failure.
            raise RuntimeError(finished_training_job.get(
                'errorMessage',
                'MLEngine training job {} finished in state {}'.format(
                    job_id, state)))
