Repository: incubator-airflow
Updated Branches:
  refs/heads/master 65e7025f3 -> d9bf1edd4
[AIRFLOW-2291] Add optional params to ML Engine This commit adds three extra optional parameters to the `MLEngineTrainingOperator` as well as a new unit test for the `MLEngineVersionOperator` Closes #3202 from dlebech/ml-engine-python-version Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/d9bf1edd Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/d9bf1edd Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/d9bf1edd Branch: refs/heads/master Commit: d9bf1edd44af847aa1f9ecd9d797dfacb806cc00 Parents: 65e7025 Author: David Volquartz Lebech <da...@lebech.info> Authored: Wed Apr 11 14:15:06 2018 +0200 Committer: Fokko Driesprong <fokkodriespr...@godatadriven.com> Committed: Wed Apr 11 14:15:06 2018 +0200 ---------------------------------------------------------------------- airflow/contrib/operators/mlengine_operator.py | 28 +++++++++ .../contrib/operators/test_mlengine_operator.py | 63 +++++++++++++++++++- 2 files changed, 90 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/d9bf1edd/airflow/contrib/operators/mlengine_operator.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/operators/mlengine_operator.py b/airflow/contrib/operators/mlengine_operator.py index 3dd63f2..a6f186b 100644 --- a/airflow/contrib/operators/mlengine_operator.py +++ b/airflow/contrib/operators/mlengine_operator.py @@ -474,6 +474,16 @@ class MLEngineTrainingOperator(BaseOperator): :param scale_tier: Resource tier for MLEngine training job. :type scale_tier: string + :param runtime_version: The Google Cloud ML runtime version to use for training. + :type runtime_version: string + + :param python_version: The version of Python used in training. 
+ :type python_version: string + + :param job_dir: A Google Cloud Storage path in which to store training + outputs and other data needed for training. + :type job_dir: string + :param gcp_conn_id: The connection ID to use when fetching connection info. :type gcp_conn_id: string @@ -497,6 +507,9 @@ class MLEngineTrainingOperator(BaseOperator): '_training_args', '_region', '_scale_tier', + '_runtime_version', + '_python_version', + '_job_dir' ] @apply_defaults @@ -508,6 +521,9 @@ class MLEngineTrainingOperator(BaseOperator): training_args, region, scale_tier=None, + runtime_version=None, + python_version=None, + job_dir=None, gcp_conn_id='google_cloud_default', delegate_to=None, mode='PRODUCTION', @@ -521,6 +537,9 @@ class MLEngineTrainingOperator(BaseOperator): self._training_args = training_args self._region = region self._scale_tier = scale_tier + self._runtime_version = runtime_version + self._python_version = python_version + self._job_dir = job_dir self._gcp_conn_id = gcp_conn_id self._delegate_to = delegate_to self._mode = mode @@ -555,6 +574,15 @@ class MLEngineTrainingOperator(BaseOperator): } } + if self._runtime_version: + training_request['trainingInput']['runtimeVersion'] = self._runtime_version + + if self._python_version: + training_request['trainingInput']['pythonVersion'] = self._python_version + + if self._job_dir: + training_request['trainingInput']['jobDir'] = self._job_dir + if self._mode == 'DRY_RUN': self.log.info('In dry_run mode.') self.log.info('MLEngine Training job request is: {}'.format( http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/d9bf1edd/tests/contrib/operators/test_mlengine_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_mlengine_operator.py b/tests/contrib/operators/test_mlengine_operator.py index 2766e5d..9a8d56c 100644 --- a/tests/contrib/operators/test_mlengine_operator.py +++ b/tests/contrib/operators/test_mlengine_operator.py @@ -17,6 
+17,7 @@ from __future__ import absolute_import, division, print_function +import copy import datetime import unittest @@ -26,7 +27,8 @@ from mock import ANY, patch from airflow import DAG, configuration from airflow.contrib.operators.mlengine_operator import (MLEngineBatchPredictionOperator, - MLEngineTrainingOperator) + MLEngineTrainingOperator, + MLEngineVersionOperator) from airflow.exceptions import AirflowException DEFAULT_DATE = datetime.datetime(2017, 6, 6) @@ -322,6 +324,33 @@ class MLEngineTrainingOperatorTest(unittest.TestCase): hook_instance.create_job.assert_called_with( 'test-project', self.TRAINING_INPUT, ANY) + def testSuccessCreateTrainingJobWithOptionalArgs(self): + training_input = copy.deepcopy(self.TRAINING_INPUT) + training_input['trainingInput']['runtimeVersion'] = '1.6' + training_input['trainingInput']['pythonVersion'] = '3.5' + training_input['trainingInput']['jobDir'] = 'gs://some-bucket/jobs/test_training' + + with patch('airflow.contrib.operators.mlengine_operator.MLEngineHook') \ + as mock_hook: + success_response = self.TRAINING_INPUT.copy() + success_response['state'] = 'SUCCEEDED' + hook_instance = mock_hook.return_value + hook_instance.create_job.return_value = success_response + + training_op = MLEngineTrainingOperator( + runtime_version='1.6', + python_version='3.5', + job_dir='gs://some-bucket/jobs/test_training', + **self.TRAINING_DEFAULT_ARGS) + training_op.execute(None) + + mock_hook.assert_called_with(gcp_conn_id='google_cloud_default', + delegate_to=None) + # Make sure only 'create_job' is invoked on hook instance + self.assertEquals(len(hook_instance.mock_calls), 1) + hook_instance.create_job.assert_called_with( + 'test-project', training_input, ANY) + def testHttpError(self): http_error_code = 403 with patch('airflow.contrib.operators.mlengine_operator.MLEngineHook') \ @@ -369,5 +398,37 @@ class MLEngineTrainingOperatorTest(unittest.TestCase): self.assertEquals('A failure message', str(context.exception)) +class 
MLEngineVersionOperatorTest(unittest.TestCase): + VERSION_DEFAULT_ARGS = { + 'project_id': 'test-project', + 'model_name': 'test-model', + 'task_id': 'test-version' + } + VERSION_INPUT = { + 'name': 'v1', + 'deploymentUri': 'gs://some-bucket/jobs/test_training/model.pb', + 'runtimeVersion': '1.6' + } + + def testSuccessCreateVersion(self): + with patch('airflow.contrib.operators.mlengine_operator.MLEngineHook') \ + as mock_hook: + success_response = {'name': 'some-name', 'done': True} + hook_instance = mock_hook.return_value + hook_instance.create_version.return_value = success_response + + training_op = MLEngineVersionOperator( + version=self.VERSION_INPUT, + **self.VERSION_DEFAULT_ARGS) + training_op.execute(None) + + mock_hook.assert_called_with(gcp_conn_id='google_cloud_default', + delegate_to=None) + # Make sure only 'create_version' is invoked on hook instance + self.assertEquals(len(hook_instance.mock_calls), 1) + hook_instance.create_version.assert_called_with( + 'test-project', 'test-model', self.VERSION_INPUT) + + if __name__ == '__main__': unittest.main()