Repository: incubator-airflow Updated Branches: refs/heads/master e88ecff6a -> 194d1d6e5
[AIRFLOW-1359] Add Google CloudML utils for model evaluation Closes #2407 from yk5/evaluate Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/194d1d6e Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/194d1d6e Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/194d1d6e Branch: refs/heads/master Commit: 194d1d6e5b89918f22267ae6a86455a0acc771df Parents: e88ecff Author: Younghee Kwon <[email protected]> Authored: Thu Jul 13 17:06:06 2017 -0700 Committer: Chris Riccomini <[email protected]> Committed: Thu Jul 13 17:06:56 2017 -0700 ---------------------------------------------------------------------- .../contrib/operators/cloudml_operator_utils.py | 223 +++++++++++++++++++ .../operators/cloudml_prediction_summary.py | 177 +++++++++++++++ .../operators/test_cloudml_operator_utils.py | 179 +++++++++++++++ 3 files changed, 579 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/194d1d6e/airflow/contrib/operators/cloudml_operator_utils.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/operators/cloudml_operator_utils.py b/airflow/contrib/operators/cloudml_operator_utils.py new file mode 100644 index 0000000..f4abb32 --- /dev/null +++ b/airflow/contrib/operators/cloudml_operator_utils.py @@ -0,0 +1,223 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the 'License'); you may not use this file except in compliance with +# the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import json
import os
import re
try:  # python 2
    from urlparse import urlsplit
except ImportError:  # python 3
    from urllib.parse import urlsplit

import dill

from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
from airflow.contrib.operators.cloudml_operator import CloudMLBatchPredictionOperator
from airflow.contrib.operators.cloudml_operator import _normalize_cloudml_job_id
from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator
from airflow.exceptions import AirflowException
from airflow.operators.python_operator import PythonOperator


def create_evaluate_ops(task_prefix,
                        project_id,
                        job_id,
                        region,
                        data_format,
                        input_paths,
                        prediction_path,
                        metric_fn_and_keys,
                        validate_fn,
                        dataflow_options,
                        model_uri=None,
                        model_name=None,
                        version_name=None,
                        dag=None):
    """
    Creates the operators needed for model evaluation and returns them.

    It gets predictions over the inputs via the Cloud ML Engine
    BatchPrediction API by calling CloudMLBatchPredictionOperator, summarizes
    the result via Cloud Dataflow using DataFlowPythonOperator, and finally
    validates the summary with a PythonOperator.

    For details and pricing about Batch prediction, please refer to the website
    https://cloud.google.com/ml-engine/docs/how-tos/batch-predict
    and for Cloud Dataflow, https://cloud.google.com/dataflow/docs/

    It returns three chained operators for prediction, summary, and validation,
    named as <prefix>-prediction, <prefix>-summary, and <prefix>-validation,
    respectively.
    (<prefix> must start with a letter and contain only alphanumeric
    characters or hyphens.)

    The upstream and downstream can be set accordingly like:
      pred, _, val = create_evaluate_ops(...)
      pred.set_upstream(upstream_op)
      ...
      downstream_op.set_upstream(val)

    Callers will provide two python callables, metric_fn and validate_fn, in
    order to customize the evaluation behavior as they wish.
    - metric_fn receives a dictionary per instance derived from json in the
      batch prediction result. The keys might vary depending on the model.
      It should return a tuple of metrics.
    - validate_fn receives a dictionary of the averaged metrics that metric_fn
      generated over all instances.
      The key/value of the dictionary matches to what's given by
      metric_fn_and_keys arg.
      The dictionary contains an additional metric, 'count' to represent the
      total number of instances received for evaluation.
      The function would raise an exception to mark the task as failed, in a
      case the validation result is not okay to proceed (i.e. to set the trained
      version as default).

    Typical examples are like this:

    def get_metric_fn_and_keys():
        import math  # imports should be outside of the metric_fn below.
        def error_and_squared_error(inst):
            label = float(inst['input_label'])
            classes = float(inst['classes'])  # 0 or 1
            err = abs(classes-label)
            squared_err = math.pow(classes-label, 2)
            return (err, squared_err)  # returns a tuple.
        return error_and_squared_error, ['err', 'mse']  # key order must match.

    def validate_err_and_count(summary):
        if summary['err'] > 0.2:
            raise ValueError('Too high err>0.2; summary=%s' % summary)
        if summary['mse'] > 0.05:
            raise ValueError('Too high mse>0.05; summary=%s' % summary)
        if summary['count'] < 1000:
            raise ValueError('Too few instances<1000; summary=%s' % summary)
        return summary

    For the details on the other BatchPrediction-related arguments (project_id,
    job_id, region, data_format, input_paths, prediction_path, model_uri),
    please refer to CloudMLBatchPredictionOperator too.

    :param task_prefix: a prefix for the tasks. Only alphanumeric characters and
        hyphen are allowed (no underscores), and it must start with a letter,
        since this will be used as dataflow job name, which doesn't allow
        other characters.
    :type task_prefix: string

    :param project_id: the Google Cloud project id, passed to
        CloudMLBatchPredictionOperator.
    :type project_id: string

    :param job_id: the BatchPrediction job id; it is normalized with
        _normalize_cloudml_job_id before being passed on.
    :type job_id: string

    :param region: the region, passed to CloudMLBatchPredictionOperator.
    :type region: string

    :param model_uri: GCS path of the model exported by Tensorflow using
        tensorflow.estimator.export_savedmodel(). It cannot be used with
        model_name or version_name below. See CloudMLBatchPredictionOperator for
        more detail.
    :type model_uri: string

    :param model_name: Used to indicate a model to use for prediction. Can be
        used in combination with version_name, but cannot be used together with
        model_uri. See CloudMLBatchPredictionOperator for more detail.
    :type model_name: string

    :param version_name: Used to indicate a model version to use for prediction,
        in combination with model_name. Cannot be used together with model_uri.
        See CloudMLBatchPredictionOperator for more detail.
    :type version_name: string

    :param data_format: either of 'TEXT', 'TF_RECORD', 'TF_RECORD_GZIP'
    :type data_format: string

    :param input_paths: a list of input paths to be sent to BatchPrediction.
    :type input_paths: list of strings

    :param prediction_path: GCS path to put the prediction results in. It must
        look like gs://<bucket>/<object-prefix>; the validation task reads the
        summary back from <prediction_path>/prediction.summary.json and raises
        ValueError at run time if the path is not of that form.
    :type prediction_path: string

    :param metric_fn_and_keys: a tuple of metric_fn and metric_keys:
        - metric_fn is a function that accepts a dictionary (for an instance),
          and returns a tuple of metric(s) that it calculates.
        - metric_keys is a list of strings to denote the key of each metric.
    :type metric_fn_and_keys: tuple of a function and a list of strings

    :param validate_fn: a function to validate whether the averaged metric(s) is
        good enough to push the model.
    :type validate_fn: function

    :param dataflow_options: options to run Dataflow jobs.
    :type dataflow_options: dictionary

    :param dag: the DAG the three operators are attached to.
    :type dag: airflow.DAG

    :returns: a tuple of three operators, (prediction, summary, validation)
    :rtype: tuple(CloudMLBatchPredictionOperator, DataFlowPythonOperator,
                  PythonOperator)
    :raises AirflowException: if task_prefix is malformed, or if metric_fn or
        validate_fn is not callable.
    """

    # Verify that task_prefix doesn't have any special characters except hyphen
    # '-', which is the only allowed non-alphanumeric character by Dataflow.
    if not re.match(r"^[a-zA-Z][-A-Za-z0-9]*$", task_prefix):
        raise AirflowException(
            "Malformed task_id for DataFlowPythonOperator (only alphanumeric "
            "and hyphens are allowed, starting with a letter) but got: " +
            task_prefix)

    metric_fn, metric_keys = metric_fn_and_keys
    if not callable(metric_fn):
        raise AirflowException("`metric_fn` param must be callable.")
    if not callable(validate_fn):
        raise AirflowException("`validate_fn` param must be callable.")

    evaluate_prediction = CloudMLBatchPredictionOperator(
        task_id=(task_prefix + "-prediction"),
        project_id=project_id,
        job_id=_normalize_cloudml_job_id(job_id),
        region=region,
        data_format=data_format,
        input_paths=input_paths,
        output_path=prediction_path,
        uri=model_uri,
        model_name=model_name,
        version_name=version_name,
        dag=dag)

    # The metric function travels to the Dataflow job as a command-line flag,
    # so it is serialized with dill and base64-encoded here.
    # NOTE(review): on Python 3 b64encode returns bytes rather than str; confirm
    # the Dataflow hook accepts a bytes option value before running on py3.
    metric_fn_encoded = base64.b64encode(dill.dumps(metric_fn, recurse=True))
    evaluate_summary = DataFlowPythonOperator(
        task_id=(task_prefix + "-summary"),
        py_options=["-m"],
        py_file="airflow.contrib.operators.cloudml_prediction_summary",
        dataflow_default_options=dataflow_options,
        options={
            "prediction_path": prediction_path,
            "metric_fn_encoded": metric_fn_encoded,
            "metric_keys": ','.join(metric_keys)
        },
        dag=dag)
    # TODO: "options" is not template_field of DataFlowPythonOperator (not sure
    # if intended or by mistake); consider fixing in the DataFlowPythonOperator.
    # Assign a new per-instance list rather than appending in place:
    # template_fields is declared at class level on Airflow operators, so an
    # in-place append would mutate a list shared by every instance (and, when
    # inherited from BaseOperator, by other operator classes too), adding a
    # duplicate entry on each call of this function.
    if "options" not in evaluate_summary.template_fields:
        evaluate_summary.template_fields = (
            list(evaluate_summary.template_fields) + ["options"])
    evaluate_summary.set_upstream(evaluate_prediction)

    def apply_validate_fn(*args, **kwargs):
        """Downloads the summary JSON from GCS and runs validate_fn on it."""
        # templates_dict is rendered by Airflow before the callable runs.
        rendered_path = kwargs["templates_dict"]["prediction_path"]
        scheme, bucket, obj, _, _ = urlsplit(rendered_path)
        if scheme != "gs" or not bucket or not obj:
            raise ValueError("Wrong format prediction_path: %s" %
                             rendered_path)
        summary_object = os.path.join(obj.strip("/"),
                                      "prediction.summary.json")
        gcs_hook = GoogleCloudStorageHook()
        summary = json.loads(gcs_hook.download(bucket, summary_object))
        return validate_fn(summary)

    evaluate_validation = PythonOperator(
        task_id=(task_prefix + "-validation"),
        python_callable=apply_validate_fn,
        provide_context=True,
        templates_dict={"prediction_path": prediction_path},
        dag=dag)
    evaluate_validation.set_upstream(evaluate_summary)

    return evaluate_prediction, evaluate_summary, evaluate_validation

# http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/194d1d6e/airflow/contrib/operators/cloudml_prediction_summary.py
# ----------------------------------------------------------------------
# diff --git a/airflow/contrib/operators/cloudml_prediction_summary.py b/airflow/contrib/operators/cloudml_prediction_summary.py
# new file mode 100644
# index 0000000..3128dc3
# --- /dev/null
# +++ b/airflow/contrib/operators/cloudml_prediction_summary.py
# @@ -0,0 +1,177 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the 'License'); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A template called by DataFlowPythonOperator to summarize BatchPrediction. + +It accepts a user function to calculate the metric(s) per instance in +the prediction results, then aggregates to output as a summary. + +Args: + --prediction_path: + The GCS folder that contains BatchPrediction results, containing + prediction.results-NNNNN-of-NNNNN files in the json format. + Output will be also stored in this folder, as 'prediction.summary.json'. + + --metric_fn_encoded: + An encoded function that calculates and returns a tuple of metric(s) + for a given instance (as a dictionary). It should be encoded + via base64.b64encode(dill.dumps(fn, recurse=True)). + + --metric_keys: + A comma-separated key(s) of the aggregated metric(s) in the summary + output. The order and the size of the keys must match to the output + of metric_fn. + The summary will have an additional key, 'count', to represent the + total number of instances, so the keys shouldn't include 'count'. + +# Usage example: +def get_metric_fn(): + import math # all imports must be outside of the function to be passed. 
+ def metric_fn(inst): + label = float(inst["input_label"]) + classes = float(inst["classes"]) + prediction = float(inst["scores"][1]) + log_loss = math.log(1 + math.exp( + -(label * 2 - 1) * math.log(prediction / (1 - prediction)))) + squared_err = (classes-label)**2 + return (log_loss, squared_err) + return metric_fn +metric_fn_encoded = base64.b64encode(dill.dumps(get_metric_fn(), recurse=True)) + +airflow.contrib.operators.DataFlowPythonOperator( + task_id="summary-prediction", + py_options=["-m"], + py_file="airflow.contrib.operators.cloudml_prediction_summary", + options={ + "prediction_path": prediction_path, + "metric_fn_encoded": metric_fn_encoded, + "metric_keys": "log_loss,mse" + }, + dataflow_default_options={ + "project": "xxx", "region": "us-east1", + "staging_location": "gs://yy", "temp_location": "gs://zz", + }) + >> dag + +# When the input file is like the following: +{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]} +{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]} +{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]} +{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]} + +# The output file will be: +{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25} + +# To test outside of the dag: +subprocess.check_call(["python", + "-m", + "airflow.contrib.operators.cloudml_prediction_summary", + "--prediction_path=gs://...", + "--metric_fn_encoded=" + metric_fn_encoded, + "--metric_keys=log_loss,mse", + "--runner=DataflowRunner", + "--staging_location=gs://...", + "--temp_location=gs://...", + ]) + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import base64 +import json +import logging +import os + +import apache_beam as beam +import dill + + +class JsonCoder(object): + def encode(self, x): + return json.dumps(x) + + def decode(self, x): + return json.loads(x) + + [email protected]_fn +def MakeSummary(pcoll, metric_fn, metric_keys): # 
pylint: disable=invalid-name + return ( + pcoll + | "ApplyMetricFnPerInstance" >> beam.Map(metric_fn) + | "PairWith1" >> beam.Map(lambda tup: tup + (1,)) + | "SumTuple" >> beam.CombineGlobally(beam.combiners.TupleCombineFn( + *([sum] * (len(metric_keys) + 1)))) + | "AverageAndMakeDict" >> beam.Map( + lambda tup: dict( + [(name, tup[i]/tup[-1]) for i, name in enumerate(metric_keys)] + + [("count", tup[-1])]))) + + +def run(argv=None): + parser = argparse.ArgumentParser() + parser.add_argument( + "--prediction_path", required=True, + help=( + "The GCS folder that contains BatchPrediction results, containing " + "prediction.results-NNNNN-of-NNNNN files in the json format. " + "Output will be also stored in this folder, as a file" + "'prediction.summary.json'.")) + parser.add_argument( + "--metric_fn_encoded", required=True, + help=( + "An encoded function that calculates and returns a tuple of " + "metric(s) for a given instance (as a dictionary). It should be " + "encoded via base64.b64encode(dill.dumps(fn, recurse=True)).")) + parser.add_argument( + "--metric_keys", required=True, + help=( + "A comma-separated keys of the aggregated metric(s) in the summary " + "output. The order and the size of the keys must match to the " + "output of metric_fn. 
The summary will have an additional key, " + "'count', to represent the total number of instances, so this flag " + "shouldn't include 'count'.")) + known_args, pipeline_args = parser.parse_known_args(argv) + + metric_fn = dill.loads(base64.b64decode(known_args.metric_fn_encoded)) + if not callable(metric_fn): + raise ValueError("--metric_fn_encoded must be an encoded callable.") + metric_keys = known_args.metric_keys.split(",") + + with beam.Pipeline( + options=beam.pipeline.PipelineOptions(pipeline_args)) as p: + # This is apache-beam ptransform's convention + # pylint: disable=no-value-for-parameter + _ = (p + | "ReadPredictionResult" >> beam.io.ReadFromText( + os.path.join(known_args.prediction_path, + "prediction.results-*-of-*"), + coder=JsonCoder()) + | "Summary" >> MakeSummary(metric_fn, metric_keys) + | "Write" >> beam.io.WriteToText( + os.path.join(known_args.prediction_path, + "prediction.summary.json"), + shard_name_template='', # without trailing -NNNNN-of-NNNNN. + coder=JsonCoder())) + # pylint: enable=no-value-for-parameter + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + run() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/194d1d6e/tests/contrib/operators/test_cloudml_operator_utils.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_cloudml_operator_utils.py b/tests/contrib/operators/test_cloudml_operator_utils.py new file mode 100644 index 0000000..91a9f77 --- /dev/null +++ b/tests/contrib/operators/test_cloudml_operator_utils.py @@ -0,0 +1,179 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import datetime +import unittest + +from airflow import configuration, DAG +from airflow.contrib.operators import cloudml_operator_utils +from airflow.contrib.operators.cloudml_operator_utils import create_evaluate_ops +from airflow.exceptions import AirflowException + +from mock import ANY +from mock import patch + +DEFAULT_DATE = datetime.datetime(2017, 6, 6) + + +class CreateEvaluateOpsTest(unittest.TestCase): + + INPUT_MISSING_ORIGIN = { + 'dataFormat': 'TEXT', + 'inputPaths': ['gs://legal-bucket/fake-input-path/*'], + 'outputPath': 'gs://legal-bucket/fake-output-path', + 'region': 'us-east1', + } + SUCCESS_MESSAGE_MISSING_INPUT = { + 'jobId': 'eval_test_prediction', + 'predictionOutput': { + 'outputPath': 'gs://fake-output-path', + 'predictionCount': 5000, + 'errorCount': 0, + 'nodeHours': 2.78 + }, + 'state': 'SUCCEEDED' + } + + def setUp(self): + super(CreateEvaluateOpsTest, self).setUp() + configuration.load_test_config() + self.dag = DAG( + 'test_dag', + default_args={ + 'owner': 'airflow', + 'start_date': DEFAULT_DATE, + 'end_date': DEFAULT_DATE, + }, + schedule_interval='@daily') + self.metric_fn = lambda x: (0.1,) + self.metric_fn_encoded = cloudml_operator_utils.base64.b64encode( + cloudml_operator_utils.dill.dumps(self.metric_fn, recurse=True)) + + + def testSuccessfulRun(self): + input_with_model = self.INPUT_MISSING_ORIGIN.copy() + input_with_model['modelName'] = ( + 'projects/test-project/models/test_model') + + pred, 
summary, validate = create_evaluate_ops( + task_prefix='eval-test', + project_id='test-project', + job_id='eval-test-prediction', + region=input_with_model['region'], + data_format=input_with_model['dataFormat'], + input_paths=input_with_model['inputPaths'], + prediction_path=input_with_model['outputPath'], + model_name=input_with_model['modelName'].split('/')[-1], + metric_fn_and_keys=(self.metric_fn, ['err']), + validate_fn=(lambda x: 'err=%.1f' % x['err']), + dataflow_options=None, + dag=self.dag) + + with patch('airflow.contrib.operators.cloudml_operator.' + 'CloudMLHook') as mock_cloudml_hook: + + success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy() + success_message['predictionInput'] = input_with_model + hook_instance = mock_cloudml_hook.return_value + hook_instance.create_job.return_value = success_message + result = pred.execute(None) + mock_cloudml_hook.assert_called_with('google_cloud_default', None) + hook_instance.create_job.assert_called_once_with( + 'test-project', + { + 'jobId': 'eval_test_prediction', + 'predictionInput': input_with_model + }, ANY) + self.assertEqual(success_message['predictionOutput'], result) + + with patch('airflow.contrib.operators.dataflow_operator.' + 'DataFlowHook') as mock_dataflow_hook: + + hook_instance = mock_dataflow_hook.return_value + hook_instance.start_python_dataflow.return_value = None + summary.execute(None) + mock_dataflow_hook.assert_called_with( + gcp_conn_id='google_cloud_default', delegate_to=None) + hook_instance.start_python_dataflow.assert_called_once_with( + 'eval-test-summary', + { + 'prediction_path': 'gs://legal-bucket/fake-output-path', + 'metric_keys': 'err', + 'metric_fn_encoded': self.metric_fn_encoded, + }, + 'airflow.contrib.operators.cloudml_prediction_summary', + ['-m']) + + with patch('airflow.contrib.operators.cloudml_operator_utils.' 
+ 'GoogleCloudStorageHook') as mock_gcs_hook: + + hook_instance = mock_gcs_hook.return_value + hook_instance.download.return_value = '{"err": 0.9, "count": 9}' + result = validate.execute({}) + hook_instance.download.assert_called_once_with( + 'legal-bucket', 'fake-output-path/prediction.summary.json') + self.assertEqual('err=0.9', result) + + def testFailures(self): + input_with_model = self.INPUT_MISSING_ORIGIN.copy() + input_with_model['modelName'] = ( + 'projects/test-project/models/test_model') + + other_params_but_models = { + 'task_prefix': 'eval-test', + 'project_id': 'test-project', + 'job_id': 'eval-test-prediction', + 'region': input_with_model['region'], + 'data_format': input_with_model['dataFormat'], + 'input_paths': input_with_model['inputPaths'], + 'prediction_path': input_with_model['outputPath'], + 'metric_fn_and_keys': (self.metric_fn, ['err']), + 'validate_fn': (lambda x: 'err=%.1f' % x['err']), + 'dataflow_options': None, + 'dag': self.dag, + } + + with self.assertRaisesRegexp(ValueError, 'Missing model origin'): + _ = create_evaluate_ops(**other_params_but_models) + + with self.assertRaisesRegexp(ValueError, 'Ambiguous model origin'): + _ = create_evaluate_ops(model_uri='abc', model_name='cde', + **other_params_but_models) + + with self.assertRaisesRegexp(ValueError, 'Ambiguous model origin'): + _ = create_evaluate_ops(model_uri='abc', version_name='vvv', + **other_params_but_models) + + with self.assertRaisesRegexp(AirflowException, + '`metric_fn` param must be callable'): + params = other_params_but_models.copy() + params['metric_fn_and_keys'] = (None, ['abc']) + _ = create_evaluate_ops(model_uri='gs://blah', **params) + + with self.assertRaisesRegexp(AirflowException, + '`validate_fn` param must be callable'): + params = other_params_but_models.copy() + params['validate_fn'] = None + _ = create_evaluate_ops(model_uri='gs://blah', **params) + + +if __name__ == '__main__': + unittest.main()
