Repository: incubator-airflow
Updated Branches:
  refs/heads/master 65e7025f3 -> d9bf1edd4


[AIRFLOW-2291] Add optional params to ML Engine

This commit adds three optional parameters (`runtime_version`,
`python_version` and `job_dir`) to the `MLEngineTrainingOperator`, as well
as a new unit test for the `MLEngineVersionOperator`.

Closes #3202 from dlebech/ml-engine-python-version


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/d9bf1edd
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/d9bf1edd
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/d9bf1edd

Branch: refs/heads/master
Commit: d9bf1edd44af847aa1f9ecd9d797dfacb806cc00
Parents: 65e7025
Author: David Volquartz Lebech <da...@lebech.info>
Authored: Wed Apr 11 14:15:06 2018 +0200
Committer: Fokko Driesprong <fokkodriespr...@godatadriven.com>
Committed: Wed Apr 11 14:15:06 2018 +0200

----------------------------------------------------------------------
 airflow/contrib/operators/mlengine_operator.py  | 28 +++++++++
 .../contrib/operators/test_mlengine_operator.py | 63 +++++++++++++++++++-
 2 files changed, 90 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/d9bf1edd/airflow/contrib/operators/mlengine_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/mlengine_operator.py 
b/airflow/contrib/operators/mlengine_operator.py
index 3dd63f2..a6f186b 100644
--- a/airflow/contrib/operators/mlengine_operator.py
+++ b/airflow/contrib/operators/mlengine_operator.py
@@ -474,6 +474,16 @@ class MLEngineTrainingOperator(BaseOperator):
     :param scale_tier: Resource tier for MLEngine training job.
     :type scale_tier: string
 
+    :param runtime_version: The Google Cloud ML runtime version to use for 
training.
+    :type runtime_version: string
+
+    :param python_version: The version of Python used in training.
+    :type python_version: string
+
+    :param job_dir: A Google Cloud Storage path in which to store training
+        outputs and other data needed for training.
+    :type job_dir: string
+
     :param gcp_conn_id: The connection ID to use when fetching connection info.
     :type gcp_conn_id: string
 
@@ -497,6 +507,9 @@ class MLEngineTrainingOperator(BaseOperator):
         '_training_args',
         '_region',
         '_scale_tier',
+        '_runtime_version',
+        '_python_version',
+        '_job_dir'
     ]
 
     @apply_defaults
@@ -508,6 +521,9 @@ class MLEngineTrainingOperator(BaseOperator):
                  training_args,
                  region,
                  scale_tier=None,
+                 runtime_version=None,
+                 python_version=None,
+                 job_dir=None,
                  gcp_conn_id='google_cloud_default',
                  delegate_to=None,
                  mode='PRODUCTION',
@@ -521,6 +537,9 @@ class MLEngineTrainingOperator(BaseOperator):
         self._training_args = training_args
         self._region = region
         self._scale_tier = scale_tier
+        self._runtime_version = runtime_version
+        self._python_version = python_version
+        self._job_dir = job_dir
         self._gcp_conn_id = gcp_conn_id
         self._delegate_to = delegate_to
         self._mode = mode
@@ -555,6 +574,15 @@ class MLEngineTrainingOperator(BaseOperator):
             }
         }
 
+        if self._runtime_version:
+            training_request['trainingInput']['runtimeVersion'] = 
self._runtime_version
+
+        if self._python_version:
+            training_request['trainingInput']['pythonVersion'] = 
self._python_version
+
+        if self._job_dir:
+            training_request['trainingInput']['jobDir'] = self._job_dir
+
         if self._mode == 'DRY_RUN':
             self.log.info('In dry_run mode.')
             self.log.info('MLEngine Training job request is: {}'.format(

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/d9bf1edd/tests/contrib/operators/test_mlengine_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_mlengine_operator.py 
b/tests/contrib/operators/test_mlengine_operator.py
index 2766e5d..9a8d56c 100644
--- a/tests/contrib/operators/test_mlengine_operator.py
+++ b/tests/contrib/operators/test_mlengine_operator.py
@@ -17,6 +17,7 @@
 
 from __future__ import absolute_import, division, print_function
 
+import copy
 import datetime
 import unittest
 
@@ -26,7 +27,8 @@ from mock import ANY, patch
 
 from airflow import DAG, configuration
 from airflow.contrib.operators.mlengine_operator import 
(MLEngineBatchPredictionOperator,
-                                                         
MLEngineTrainingOperator)
+                                                         
MLEngineTrainingOperator,
+                                                         
MLEngineVersionOperator)
 from airflow.exceptions import AirflowException
 
 DEFAULT_DATE = datetime.datetime(2017, 6, 6)
@@ -322,6 +324,33 @@ class MLEngineTrainingOperatorTest(unittest.TestCase):
             hook_instance.create_job.assert_called_with(
                 'test-project', self.TRAINING_INPUT, ANY)
 
+    def testSuccessCreateTrainingJobWithOptionalArgs(self):
+        training_input = copy.deepcopy(self.TRAINING_INPUT)
+        training_input['trainingInput']['runtimeVersion'] = '1.6'
+        training_input['trainingInput']['pythonVersion'] = '3.5'
+        training_input['trainingInput']['jobDir'] = 
'gs://some-bucket/jobs/test_training'
+
+        with patch('airflow.contrib.operators.mlengine_operator.MLEngineHook') 
\
+                as mock_hook:
+            success_response = self.TRAINING_INPUT.copy()
+            success_response['state'] = 'SUCCEEDED'
+            hook_instance = mock_hook.return_value
+            hook_instance.create_job.return_value = success_response
+
+            training_op = MLEngineTrainingOperator(
+                runtime_version='1.6',
+                python_version='3.5',
+                job_dir='gs://some-bucket/jobs/test_training',
+                **self.TRAINING_DEFAULT_ARGS)
+            training_op.execute(None)
+
+            mock_hook.assert_called_with(gcp_conn_id='google_cloud_default',
+                                         delegate_to=None)
+            # Make sure only 'create_job' is invoked on hook instance
+            self.assertEquals(len(hook_instance.mock_calls), 1)
+            hook_instance.create_job.assert_called_with(
+                'test-project', training_input, ANY)
+
     def testHttpError(self):
         http_error_code = 403
         with patch('airflow.contrib.operators.mlengine_operator.MLEngineHook') 
\
@@ -369,5 +398,37 @@ class MLEngineTrainingOperatorTest(unittest.TestCase):
             self.assertEquals('A failure message', str(context.exception))
 
 
+class MLEngineVersionOperatorTest(unittest.TestCase):
+    VERSION_DEFAULT_ARGS = {
+        'project_id': 'test-project',
+        'model_name': 'test-model',
+        'task_id': 'test-version'
+    }
+    VERSION_INPUT = {
+        'name': 'v1',
+        'deploymentUri': 'gs://some-bucket/jobs/test_training/model.pb',
+        'runtimeVersion': '1.6'
+    }
+
+    def testSuccessCreateVersion(self):
+        with patch('airflow.contrib.operators.mlengine_operator.MLEngineHook') 
\
+                as mock_hook:
+            success_response = {'name': 'some-name', 'done': True}
+            hook_instance = mock_hook.return_value
+            hook_instance.create_version.return_value = success_response
+
+            training_op = MLEngineVersionOperator(
+                version=self.VERSION_INPUT,
+                **self.VERSION_DEFAULT_ARGS)
+            training_op.execute(None)
+
+            mock_hook.assert_called_with(gcp_conn_id='google_cloud_default',
+                                         delegate_to=None)
+            # Make sure only 'create_version' is invoked on hook instance
+            self.assertEquals(len(hook_instance.mock_calls), 1)
+            hook_instance.create_version.assert_called_with(
+                'test-project', 'test-model', self.VERSION_INPUT)
+
+
 if __name__ == '__main__':
     unittest.main()

Reply via email to