This is an automated email from the ASF dual-hosted git repository.
kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 6e3d7b6 Add masterConfig parameter to MLEngineStartTrainingJobOperator (#10578)
6e3d7b6 is described below
commit 6e3d7b63d3b34c34f8b38a7b41f4a5876e1f731f
Author: Antonio Davide Calì <[email protected]>
AuthorDate: Fri Sep 4 23:58:24 2020 +0200
Add masterConfig parameter to MLEngineStartTrainingJobOperator (#10578)
Co-authored-by: antonio-davide-cali <[email protected]>
---
.../providers/google/cloud/operators/mlengine.py | 11 ++++++
.../google/cloud/operators/test_mlengine.py | 39 ++++++++++++++++++++++
2 files changed, 50 insertions(+)
diff --git a/airflow/providers/google/cloud/operators/mlengine.py b/airflow/providers/google/cloud/operators/mlengine.py
index 5dfcda2..f202254 100644
--- a/airflow/providers/google/cloud/operators/mlengine.py
+++ b/airflow/providers/google/cloud/operators/mlengine.py
@@ -1102,6 +1102,9 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
:param master_type: Cloud ML Engine machine name.
Must be set when scale_tier is CUSTOM. (templated)
:type master_type: str
+ :param master_config: Cloud ML Engine master config.
+ master_type must be set if master_config is provided. (templated)
+ :type master_config: dict
:param runtime_version: The Google Cloud ML runtime version to use for
training. (templated)
:type runtime_version: str
@@ -1147,6 +1150,7 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
'_region',
'_scale_tier',
'_master_type',
+ '_master_config',
'_runtime_version',
'_python_version',
'_job_dir',
@@ -1166,6 +1170,7 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
region: str,
scale_tier: Optional[str] = None,
master_type: Optional[str] = None,
+ master_config: Optional[Dict] = None,
runtime_version: Optional[str] = None,
python_version: Optional[str] = None,
job_dir: Optional[str] = None,
@@ -1186,6 +1191,7 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
self._region = region
self._scale_tier = scale_tier
self._master_type = master_type
+ self._master_config = master_config
self._runtime_version = runtime_version
self._python_version = python_version
self._job_dir = job_dir
@@ -1209,6 +1215,8 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
raise AirflowException('Google Compute Engine region is required.')
if self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM" and not self._master_type:
raise AirflowException('master_type must be set when scale_tier is CUSTOM')
+ if self._master_config and not self._master_type:
+ raise AirflowException('master_type must be set when master_config is provided')
def execute(self, context):
job_id = _normalize_mlengine_job_id(self._job_id)
@@ -1237,6 +1245,9 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
if self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM":
training_request['trainingInput']['masterType'] = self._master_type
+ if self._master_config:
+ training_request['trainingInput']['masterConfig'] = self._master_config
+
if self._mode == 'DRY_RUN':
self.log.info('In dry_run mode.')
self.log.info('MLEngine Training job request is: %s', training_request)
diff --git a/tests/providers/google/cloud/operators/test_mlengine.py b/tests/providers/google/cloud/operators/test_mlengine.py
index 30ad775..dad9b9d 100644
--- a/tests/providers/google/cloud/operators/test_mlengine.py
+++ b/tests/providers/google/cloud/operators/test_mlengine.py
@@ -351,6 +351,45 @@ class TestMLEngineTrainingOperator(unittest.TestCase):
)
@patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
+ def test_success_create_training_job_with_master_config(self, mock_hook):
+ custom_training_default_args: dict = copy.deepcopy(self.TRAINING_DEFAULT_ARGS)
+ custom_training_default_args['scale_tier'] = 'CUSTOM'
+
+ training_input = copy.deepcopy(self.TRAINING_INPUT)
+ training_input['trainingInput']['runtimeVersion'] = '1.6'
+ training_input['trainingInput']['pythonVersion'] = '3.5'
+ training_input['trainingInput']['jobDir'] = 'gs://some-bucket/jobs/test_training'
+ training_input['trainingInput']['scaleTier'] = 'CUSTOM'
+ training_input['trainingInput']['masterType'] = 'n1-standard-4'
+ training_input['trainingInput']['masterConfig'] = {
+ 'acceleratorConfig': {'count': '1', 'type': 'NVIDIA_TESLA_P4'},
+ }
+
+ success_response = training_input.copy()
+ success_response['state'] = 'SUCCEEDED'
+ hook_instance = mock_hook.return_value
+ hook_instance.create_job.return_value = success_response
+
+ training_op = MLEngineStartTrainingJobOperator(
+ runtime_version='1.6',
+ python_version='3.5',
+ job_dir='gs://some-bucket/jobs/test_training',
+ master_type='n1-standard-4',
+ master_config={'acceleratorConfig': {'count': '1', 'type': 'NVIDIA_TESLA_P4'},},
+ **custom_training_default_args,
+ )
+ training_op.execute(MagicMock())
+
+ mock_hook.assert_called_once_with(
+ gcp_conn_id='google_cloud_default', delegate_to=None, impersonation_chain=None,
+ )
+ # Make sure only 'create_job' is invoked on hook instance
+ self.assertEqual(len(hook_instance.mock_calls), 1)
+ hook_instance.create_job.assert_called_once_with(
+ project_id='test-project', job=training_input, use_existing_job_fn=ANY
+ )
+
+ @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
def test_success_create_training_job_with_optional_args(self, mock_hook):
training_input = copy.deepcopy(self.TRAINING_INPUT)
training_input['trainingInput']['runtimeVersion'] = '1.6'