Repository: incubator-airflow
Updated Branches:
  refs/heads/master 3475faf6b -> 65e7025f3


[AIRFLOW-1774] Allow consistent templating of arguments in MLEngineBatchPredictionOperator

Fix a minor typo and a wrong non-default
assignment

Fix one more typo

Adapt tests to new error messages and fix another
typo

Fix exception type in utils operator test class

Improve cleansing of invalid training and prediction job names

Closes #2746 from wileeam/ml-engine-prediction-job-normalization

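For context, a minimal usage sketch (not part of this commit) of what the
change enables: every constructor argument listed in template_fields in the
patch below is rendered by Jinja before execute() runs. The DAG, project,
bucket and model names here are made up for illustration:

    import datetime

    from airflow import DAG
    from airflow.contrib.operators.mlengine_operator import \
        MLEngineBatchPredictionOperator

    dag = DAG('example_mlengine_prediction',               # hypothetical DAG
              start_date=datetime.datetime(2018, 1, 1),
              schedule_interval='@daily')

    predict = MLEngineBatchPredictionOperator(
        task_id='batch_prediction',
        project_id='my-project',                           # hypothetical project
        job_id='prediction_{{ ds_nodash }}',               # templated job id
        region='us-central1',
        data_format='TEXT',
        input_paths=['gs://my-bucket/input/{{ ds }}/*'],   # hypothetical bucket
        output_path='gs://my-bucket/output/{{ ds }}/',
        model_name='my_model',                             # hypothetical model
        dag=dag)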

Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/65e7025f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/65e7025f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/65e7025f

Branch: refs/heads/master
Commit: 65e7025f3af378dfa825eb04d005ec1f7a422cde
Parents: 3475faf
Author: Guillermo Rodríguez Cano <wsch...@gmail.com>
Authored: Wed Apr 11 11:57:21 2018 +0200
Committer: Fokko Driesprong <fokkodriespr...@godatadriven.com>
Committed: Wed Apr 11 11:57:21 2018 +0200

----------------------------------------------------------------------
 airflow/contrib/operators/mlengine_operator.py  | 205 ++++++++++---------
 .../contrib/operators/test_mlengine_operator.py |  88 ++++----
 .../operators/test_mlengine_operator_utils.py   |   6 +-
 3 files changed, 156 insertions(+), 143 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/65e7025f/airflow/contrib/operators/mlengine_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/mlengine_operator.py b/airflow/contrib/operators/mlengine_operator.py
index 0d033d3..3dd63f2 100644
--- a/airflow/contrib/operators/mlengine_operator.py
+++ b/airflow/contrib/operators/mlengine_operator.py
@@ -15,77 +15,17 @@
 # limitations under the License.
 import re
 
-from airflow import settings
+from apiclient import errors
+
 from airflow.contrib.hooks.gcp_mlengine_hook import MLEngineHook
 from airflow.exceptions import AirflowException
 from airflow.operators import BaseOperator
 from airflow.utils.decorators import apply_defaults
-from apiclient import errors
-
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 log = LoggingMixin().log
 
 
-def _create_prediction_input(project_id,
-                             region,
-                             data_format,
-                             input_paths,
-                             output_path,
-                             model_name=None,
-                             version_name=None,
-                             uri=None,
-                             max_worker_count=None,
-                             runtime_version=None):
-    """
-    Create the batch prediction input from the given parameters.
-
-    Args:
-        A subset of arguments documented in __init__ method of class
-        MLEngineBatchPredictionOperator
-
-    Returns:
-        A dictionary representing the predictionInput object as documented
-        in https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs.
-
-    Raises:
-        ValueError: if a unique model/version origin cannot be determined.
-    """
-    prediction_input = {
-        'dataFormat': data_format,
-        'inputPaths': input_paths,
-        'outputPath': output_path,
-        'region': region
-    }
-
-    if uri:
-        if model_name or version_name:
-            log.error(
-                'Ambiguous model origin: Both uri and model/version name are provided.'
-            )
-            raise ValueError('Ambiguous model origin.')
-        prediction_input['uri'] = uri
-    elif model_name:
-        origin_name = 'projects/{}/models/{}'.format(project_id, model_name)
-        if not version_name:
-            prediction_input['modelName'] = origin_name
-        else:
-            prediction_input['versionName'] = \
-                origin_name + '/versions/{}'.format(version_name)
-    else:
-        log.error(
-            'Missing model origin: Batch prediction expects a model, '
-            'a model & version combination, or a URI to savedModel.')
-        raise ValueError('Missing model origin.')
-
-    if max_worker_count:
-        prediction_input['maxWorkerCount'] = max_worker_count
-    if runtime_version:
-        prediction_input['runtimeVersion'] = runtime_version
-
-    return prediction_input
-
-
 def _normalize_mlengine_job_id(job_id):
     """
     Replaces invalid MLEngine job_id characters with '_'.
@@ -99,10 +39,27 @@ def _normalize_mlengine_job_id(job_id):
     Returns:
         A valid job_id representation.
     """
-    match = re.search(r'\d', job_id)
+
+    # Add a prefix when a job_id starts with a digit or a template
+    match = re.search(r'\d|\{{2}', job_id)
     if match and match.start() is 0:
-        job_id = 'z_{}'.format(job_id)
-    return re.sub('[^0-9a-zA-Z]+', '_', job_id)
+        job = 'z_{}'.format(job_id)
+    else:
+        job = job_id
+
+    # Clean up 'bad' characters except templates
+    tracker = 0
+    cleansed_job_id = ''
+    for m in re.finditer(r'\{{2}.+?\}{2}', job):
+        cleansed_job_id += re.sub(r'[^0-9a-zA-Z]+', '_',
+                                  job[tracker:m.start()])
+        cleansed_job_id += job[m.start():m.end()]
+        tracker = m.end()
+
+    # Clean up last substring or the full string if no templates
+    cleansed_job_id += re.sub(r'[^0-9a-zA-Z]+', '_', job[tracker:])
+
+    return cleansed_job_id
 
 
 class MLEngineBatchPredictionOperator(BaseOperator):
@@ -132,6 +89,8 @@ class MLEngineBatchPredictionOperator(BaseOperator):
     if the desired model version is
     "projects/my_project/models/my_model/versions/my_version".
 
+    See https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs
+    for further documentation on the parameters.
 
     :param project_id: The Google Cloud project name where the
         prediction job is submitted.
@@ -197,7 +156,14 @@ class MLEngineBatchPredictionOperator(BaseOperator):
     """
 
     template_fields = [
-        "prediction_job_request",
+        '_project_id',
+        '_job_id',
+        '_region',
+        '_input_paths',
+        '_output_path',
+        '_model_name',
+        '_version_name',
+        '_uri',
     ]
 
     @apply_defaults
@@ -219,45 +185,91 @@ class MLEngineBatchPredictionOperator(BaseOperator):
                  **kwargs):
         super(MLEngineBatchPredictionOperator, self).__init__(*args, **kwargs)
 
-        self.project_id = project_id
-        self.gcp_conn_id = gcp_conn_id
-        self.delegate_to = delegate_to
+        self._project_id = project_id
+        self._job_id = job_id
+        self._region = region
+        self._data_format = data_format
+        self._input_paths = input_paths
+        self._output_path = output_path
+        self._model_name = model_name
+        self._version_name = version_name
+        self._uri = uri
+        self._max_worker_count = max_worker_count
+        self._runtime_version = runtime_version
+        self._gcp_conn_id = gcp_conn_id
+        self._delegate_to = delegate_to
 
-        try:
-            prediction_input = _create_prediction_input(
-                project_id, region, data_format, input_paths, output_path,
-                model_name, version_name, uri, max_worker_count,
-                runtime_version)
-        except ValueError as e:
-            self.log.error(
-                'Cannot create batch prediction job request due to: %s',
-                e
-            )
-            raise
+        if not self._project_id:
+            raise AirflowException('Google Cloud project id is required.')
+        if not self._job_id:
+            raise AirflowException(
+                'A unique job id is required for Google MLEngine prediction '
+                'job.')
 
-        self.prediction_job_request = {
-            'jobId': _normalize_mlengine_job_id(job_id),
-            'predictionInput': prediction_input
-        }
+        if self._uri:
+            if self._model_name or self._version_name:
+                raise AirflowException('Ambiguous model origin: Both uri and '
+                                       'model/version name are provided.')
+
+        if self._version_name and not self._model_name:
+            raise AirflowException(
+                'Missing model: Batch prediction expects '
+                'a model name when a version name is provided.')
+
+        if not (self._uri or self._model_name):
+            raise AirflowException(
+                'Missing model origin: Batch prediction expects a model, '
+                'a model & version combination, or a URI to a savedModel.')
 
     def execute(self, context):
-        hook = MLEngineHook(self.gcp_conn_id, self.delegate_to)
+        job_id = _normalize_mlengine_job_id(self._job_id)
+        prediction_request = {
+            'jobId': job_id,
+            'predictionInput': {
+                'dataFormat': self._data_format,
+                'inputPaths': self._input_paths,
+                'outputPath': self._output_path,
+                'region': self._region
+            }
+        }
 
+        if self._uri:
+            prediction_request['predictionInput']['uri'] = self._uri
+        elif self._model_name:
+            origin_name = 'projects/{}/models/{}'.format(
+                self._project_id, self._model_name)
+            if not self._version_name:
+                prediction_request['predictionInput'][
+                    'modelName'] = origin_name
+            else:
+                prediction_request['predictionInput']['versionName'] = \
+                    origin_name + '/versions/{}'.format(self._version_name)
+
+        if self._max_worker_count:
+            prediction_request['predictionInput'][
+                'maxWorkerCount'] = self._max_worker_count
+
+        if self._runtime_version:
+            prediction_request['predictionInput'][
+                'runtimeVersion'] = self._runtime_version
+
+        hook = MLEngineHook(self._gcp_conn_id, self._delegate_to)
+
+        # Helper method to check if the existing job's prediction input is the
+        # same as the request we get here.
         def check_existing_job(existing_job):
             return existing_job.get('predictionInput', None) == \
-                self.prediction_job_request['predictionInput']
+                prediction_request['predictionInput']
+
         try:
             finished_prediction_job = hook.create_job(
-                self.project_id,
-                self.prediction_job_request,
-                check_existing_job)
+                self._project_id, prediction_request, check_existing_job)
         except errors.HttpError:
             raise
 
         if finished_prediction_job['state'] != 'SUCCEEDED':
-            self.log.error(
-                'Batch prediction job failed: %s',
-                str(finished_prediction_job))
+            self.log.error('MLEngine batch prediction job failed: {}'.format(
+                str(finished_prediction_job)))
             raise RuntimeError(finished_prediction_job['errorMessage'])
 
         return finished_prediction_job['predictionOutput']
@@ -419,9 +431,8 @@ class MLEngineVersionOperator(BaseOperator):
             return hook.create_version(self._project_id, self._model_name,
                                        self._version)
         elif self._operation == 'set_default':
-            return hook.set_default_version(
-                self._project_id, self._model_name,
-                self._version['name'])
+            return hook.set_default_version(self._project_id, self._model_name,
+                                            self._version['name'])
         elif self._operation == 'list':
             return hook.list_versions(self._project_id, self._model_name)
         elif self._operation == 'delete':
@@ -546,7 +557,8 @@ class MLEngineTrainingOperator(BaseOperator):
 
         if self._mode == 'DRY_RUN':
             self.log.info('In dry_run mode.')
-            self.log.info('MLEngine Training job request is: {}'.format(training_request))
+            self.log.info('MLEngine Training job request is: {}'.format(
+                training_request))
             return
 
         hook = MLEngineHook(
@@ -557,6 +569,7 @@ class MLEngineTrainingOperator(BaseOperator):
         def check_existing_job(existing_job):
             return existing_job.get('trainingInput', None) == \
                 training_request['trainingInput']
+
         try:
             finished_training_job = hook.create_job(
                 self._project_id, training_request, check_existing_job)

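As a quick illustration of the job id cleansing introduced above (the sample
ids are made up): characters MLEngine does not accept are collapsed into '_',
Jinja template blocks are left untouched, and an id that starts with a digit
or a template gets a 'z_' prefix so the rendered id still starts with a letter:

    from airflow.contrib.operators.mlengine_operator import \
        _normalize_mlengine_job_id

    _normalize_mlengine_job_id('my-job.2018')
    # -> 'my_job_2018'
    _normalize_mlengine_job_id('2018_prediction')
    # -> 'z_2018_prediction'
    _normalize_mlengine_job_id('{{ ds_nodash }}_prediction')
    # -> 'z_{{ ds_nodash }}_prediction'
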
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/65e7025f/tests/contrib/operators/test_mlengine_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_mlengine_operator.py b/tests/contrib/operators/test_mlengine_operator.py
index 75b46a0..2766e5d 100644
--- a/tests/contrib/operators/test_mlengine_operator.py
+++ b/tests/contrib/operators/test_mlengine_operator.py
@@ -15,21 +15,19 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function
 
 import datetime
-from apiclient import errors
-import httplib2
 import unittest
 
-from airflow import configuration, DAG
-from airflow.contrib.operators.mlengine_operator import MLEngineBatchPredictionOperator
-from airflow.contrib.operators.mlengine_operator import MLEngineTrainingOperator
+import httplib2
+from apiclient import errors
+from mock import ANY, patch
 
-from mock import ANY
-from mock import patch
+from airflow import DAG, configuration
+from airflow.contrib.operators.mlengine_operator import (MLEngineBatchPredictionOperator,
+                                                         MLEngineTrainingOperator)
+from airflow.exceptions import AirflowException
 
 DEFAULT_DATE = datetime.datetime(2017, 6, 6)
 
@@ -58,7 +56,7 @@ class MLEngineBatchPredictionOperatorTest(unittest.TestCase):
         'data_format': 'TEXT',
         'input_paths': ['gs://legal-bucket-dash-Capital/legal-input-path/*'],
         'output_path':
-            'gs://12_legal_bucket_underscore_number/legal-output-path',
+        'gs://12_legal_bucket_underscore_number/legal-output-path',
         'task_id': 'test-prediction'
     }
 
@@ -105,14 +103,12 @@ class MLEngineBatchPredictionOperatorTest(unittest.TestCase):
 
             mock_hook.assert_called_with('google_cloud_default', None)
             hook_instance.create_job.assert_called_once_with(
-                'test-project',
-                {
+                'test-project', {
                     'jobId': 'test_prediction',
                     'predictionInput': input_with_model
                 }, ANY)
-            self.assertEquals(
-                success_message['predictionOutput'],
-                prediction_output)
+            self.assertEquals(success_message['predictionOutput'],
+                              prediction_output)
 
     def testSuccessWithVersion(self):
         with patch('airflow.contrib.operators.mlengine_operator.MLEngineHook') \
@@ -132,7 +128,8 @@ class MLEngineBatchPredictionOperatorTest(unittest.TestCase):
             hook_instance.create_job.return_value = success_message
 
             prediction_task = MLEngineBatchPredictionOperator(
-                job_id='test_prediction', project_id='test-project',
+                job_id='test_prediction',
+                project_id='test-project',
                 region=input_with_version['region'],
                 data_format=input_with_version['dataFormat'],
                 input_paths=input_with_version['inputPaths'],
@@ -145,14 +142,12 @@ class MLEngineBatchPredictionOperatorTest(unittest.TestCase):
 
             mock_hook.assert_called_with('google_cloud_default', None)
             hook_instance.create_job.assert_called_with(
-                'test-project',
-                {
+                'test-project', {
                     'jobId': 'test_prediction',
                     'predictionInput': input_with_version
                 }, ANY)
-            self.assertEquals(
-                success_message['predictionOutput'],
-                prediction_output)
+            self.assertEquals(success_message['predictionOutput'],
+                              prediction_output)
 
     def testSuccessWithURI(self):
         with patch('airflow.contrib.operators.mlengine_operator.MLEngineHook') \
@@ -184,48 +179,51 @@ class MLEngineBatchPredictionOperatorTest(unittest.TestCase):
 
             mock_hook.assert_called_with('google_cloud_default', None)
             hook_instance.create_job.assert_called_with(
-                'test-project',
-                {
+                'test-project', {
                     'jobId': 'test_prediction',
                     'predictionInput': input_with_uri
                 }, ANY)
-            self.assertEquals(
-                success_message['predictionOutput'],
-                prediction_output)
+            self.assertEquals(success_message['predictionOutput'],
+                              prediction_output)
 
     def testInvalidModelOrigin(self):
         # Test that both uri and model is given
         task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
         task_args['uri'] = 'gs://fake-uri/saved_model'
         task_args['model_name'] = 'fake_model'
-        with self.assertRaises(ValueError) as context:
+        with self.assertRaises(AirflowException) as context:
             MLEngineBatchPredictionOperator(**task_args).execute(None)
-        self.assertEquals('Ambiguous model origin.', str(context.exception))
+        self.assertEquals('Ambiguous model origin: Both uri and '
+                          'model/version name are provided.',
+                          str(context.exception))
 
         # Test that both uri and model/version is given
         task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
         task_args['uri'] = 'gs://fake-uri/saved_model'
         task_args['model_name'] = 'fake_model'
         task_args['version_name'] = 'fake_version'
-        with self.assertRaises(ValueError) as context:
+        with self.assertRaises(AirflowException) as context:
             MLEngineBatchPredictionOperator(**task_args).execute(None)
-        self.assertEquals('Ambiguous model origin.', str(context.exception))
+        self.assertEquals('Ambiguous model origin: Both uri and '
+                          'model/version name are provided.',
+                          str(context.exception))
 
         # Test that a version is given without a model
         task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
         task_args['version_name'] = 'bare_version'
-        with self.assertRaises(ValueError) as context:
+        with self.assertRaises(AirflowException) as context:
             MLEngineBatchPredictionOperator(**task_args).execute(None)
-        self.assertEquals(
-            'Missing model origin.',
-            str(context.exception))
+        self.assertEquals('Missing model: Batch prediction expects a model '
+                          'name when a version name is provided.',
+                          str(context.exception))
 
         # Test that none of uri, model, model/version is given
         task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
-        with self.assertRaises(ValueError) as context:
+        with self.assertRaises(AirflowException) as context:
             MLEngineBatchPredictionOperator(**task_args).execute(None)
         self.assertEquals(
-            'Missing model origin.',
+            'Missing model origin: Batch prediction expects a '
+            'model, a model & version combination, or a URI to a savedModel.',
             str(context.exception))
 
     def testHttpError(self):
@@ -241,7 +239,8 @@ class MLEngineBatchPredictionOperatorTest(unittest.TestCase):
             hook_instance.create_job.side_effect = errors.HttpError(
                 resp=httplib2.Response({
                     'status': http_error_code
-                }), content=b'Forbidden')
+                }),
+                content=b'Forbidden')
 
             with self.assertRaises(errors.HttpError) as context:
                 prediction_task = MLEngineBatchPredictionOperator(
@@ -258,8 +257,7 @@ class MLEngineBatchPredictionOperatorTest(unittest.TestCase):
 
                 mock_hook.assert_called_with('google_cloud_default', None)
                 hook_instance.create_job.assert_called_with(
-                    'test-project',
-                    {
+                    'test-project', {
                         'jobId': 'test_prediction',
                         'predictionInput': input_with_model
                     }, ANY)
@@ -313,11 +311,12 @@ class MLEngineTrainingOperatorTest(unittest.TestCase):
             hook_instance = mock_hook.return_value
             hook_instance.create_job.return_value = success_response
 
-            training_op = MLEngineTrainingOperator(**self.TRAINING_DEFAULT_ARGS)
+            training_op = MLEngineTrainingOperator(
+                **self.TRAINING_DEFAULT_ARGS)
             training_op.execute(None)
 
-            mock_hook.assert_called_with(gcp_conn_id='google_cloud_default',
-                                         delegate_to=None)
+            mock_hook.assert_called_with(
+                gcp_conn_id='google_cloud_default', delegate_to=None)
             # Make sure only 'create_job' is invoked on hook instance
             self.assertEquals(len(hook_instance.mock_calls), 1)
             hook_instance.create_job.assert_called_with(
@@ -331,7 +330,8 @@ class MLEngineTrainingOperatorTest(unittest.TestCase):
             hook_instance.create_job.side_effect = errors.HttpError(
                 resp=httplib2.Response({
                     'status': http_error_code
-                }), content=b'Forbidden')
+                }),
+                content=b'Forbidden')
 
             with self.assertRaises(errors.HttpError) as context:
                 training_op = MLEngineTrainingOperator(

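For reference, a small sketch of the behaviour the updated tests assert
(assuming a test Airflow configuration; the argument values are made up):
the model-origin validation now happens in the constructor and raises
AirflowException instead of ValueError:

    from airflow.contrib.operators.mlengine_operator import \
        MLEngineBatchPredictionOperator
    from airflow.exceptions import AirflowException

    try:
        # Neither uri nor model_name is given, so the operator is rejected
        # at construction time.
        MLEngineBatchPredictionOperator(
            task_id='invalid_prediction',
            project_id='my-project',              # hypothetical project
            job_id='some_prediction',
            region='us-central1',
            data_format='TEXT',
            input_paths=['gs://my-bucket/in/*'],  # hypothetical bucket
            output_path='gs://my-bucket/out/')
    except AirflowException as exc:
        print(exc)  # Missing model origin: Batch prediction expects a model, ...
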
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/65e7025f/tests/contrib/operators/test_mlengine_operator_utils.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_mlengine_operator_utils.py b/tests/contrib/operators/test_mlengine_operator_utils.py
index c8f6fb5..0cb106d 100644
--- a/tests/contrib/operators/test_mlengine_operator_utils.py
+++ b/tests/contrib/operators/test_mlengine_operator_utils.py
@@ -158,14 +158,14 @@ class CreateEvaluateOpsTest(unittest.TestCase):
             'dag': dag,
         }
 
-        with self.assertRaisesRegexp(ValueError, 'Missing model origin'):
+        with self.assertRaisesRegexp(AirflowException, 'Missing model origin'):
             _ = create_evaluate_ops(**other_params_but_models)
 
-        with self.assertRaisesRegexp(ValueError, 'Ambiguous model origin'):
+        with self.assertRaisesRegexp(AirflowException, 'Ambiguous model origin'):
             _ = create_evaluate_ops(model_uri='abc', model_name='cde',
                                     **other_params_but_models)
 
-        with self.assertRaisesRegexp(ValueError, 'Ambiguous model origin'):
+        with self.assertRaisesRegexp(AirflowException, 'Ambiguous model origin'):
             _ = create_evaluate_ops(model_uri='abc', version_name='vvv',
                                     **other_params_but_models)
 
