This is an automated email from the ASF dual-hosted git repository.

kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 6e3d7b6  Add masterConfig parameter to MLEngineStartTrainingJobOperator (#10578)
6e3d7b6 is described below

commit 6e3d7b63d3b34c34f8b38a7b41f4a5876e1f731f
Author: Antonio Davide Calì <[email protected]>
AuthorDate: Fri Sep 4 23:58:24 2020 +0200

    Add masterConfig parameter to MLEngineStartTrainingJobOperator (#10578)
    
    Co-authored-by: antonio-davide-cali <[email protected]>
---
 .../providers/google/cloud/operators/mlengine.py   | 11 ++++++
 .../google/cloud/operators/test_mlengine.py        | 39 ++++++++++++++++++++++
 2 files changed, 50 insertions(+)

diff --git a/airflow/providers/google/cloud/operators/mlengine.py b/airflow/providers/google/cloud/operators/mlengine.py
index 5dfcda2..f202254 100644
--- a/airflow/providers/google/cloud/operators/mlengine.py
+++ b/airflow/providers/google/cloud/operators/mlengine.py
@@ -1102,6 +1102,9 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
     :param master_type: Cloud ML Engine machine name.
         Must be set when scale_tier is CUSTOM. (templated)
     :type master_type: str
+    :param master_config: Cloud ML Engine master config (sent as trainingInput.masterConfig).
+        master_type must be set if master_config is provided; only applied when
+        scale_tier is CUSTOM. (templated)
+    :type master_config: dict
     :param runtime_version: The Google Cloud ML runtime version to use for
         training. (templated)
     :type runtime_version: str
@@ -1147,6 +1150,7 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
         '_region',
         '_scale_tier',
         '_master_type',
+        '_master_config',
         '_runtime_version',
         '_python_version',
         '_job_dir',
@@ -1166,6 +1170,7 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
         region: str,
         scale_tier: Optional[str] = None,
         master_type: Optional[str] = None,
+        master_config: Optional[Dict] = None,
         runtime_version: Optional[str] = None,
         python_version: Optional[str] = None,
         job_dir: Optional[str] = None,
@@ -1186,6 +1191,7 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
         self._region = region
         self._scale_tier = scale_tier
         self._master_type = master_type
+        self._master_config = master_config
         self._runtime_version = runtime_version
         self._python_version = python_version
         self._job_dir = job_dir
@@ -1209,6 +1215,8 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
             raise AirflowException('Google Compute Engine region is required.')
         if self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM" and not self._master_type:
             raise AirflowException('master_type must be set when scale_tier is CUSTOM')
+        if self._master_config and not self._master_type:
+            raise AirflowException('master_type must be set when master_config is provided')
 
     def execute(self, context):
         job_id = _normalize_mlengine_job_id(self._job_id)
@@ -1237,6 +1245,9 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
         if self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM":
             training_request['trainingInput']['masterType'] = self._master_type
 
+            if self._master_config:
+                training_request['trainingInput']['masterConfig'] = self._master_config
+
         if self._mode == 'DRY_RUN':
             self.log.info('In dry_run mode.')
             self.log.info('MLEngine Training job request is: %s', training_request)
diff --git a/tests/providers/google/cloud/operators/test_mlengine.py b/tests/providers/google/cloud/operators/test_mlengine.py
index 30ad775..dad9b9d 100644
--- a/tests/providers/google/cloud/operators/test_mlengine.py
+++ b/tests/providers/google/cloud/operators/test_mlengine.py
@@ -351,6 +351,45 @@ class TestMLEngineTrainingOperator(unittest.TestCase):
         )
 
     @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
+    def test_success_create_training_job_with_master_config(self, mock_hook):
+        custom_training_default_args: dict = copy.deepcopy(self.TRAINING_DEFAULT_ARGS)
+        custom_training_default_args['scale_tier'] = 'CUSTOM'
+
+        training_input = copy.deepcopy(self.TRAINING_INPUT)
+        training_input['trainingInput']['runtimeVersion'] = '1.6'
+        training_input['trainingInput']['pythonVersion'] = '3.5'
+        training_input['trainingInput']['jobDir'] = 'gs://some-bucket/jobs/test_training'
+        training_input['trainingInput']['scaleTier'] = 'CUSTOM'
+        training_input['trainingInput']['masterType'] = 'n1-standard-4'
+        training_input['trainingInput']['masterConfig'] = {
+            'acceleratorConfig': {'count': '1', 'type': 'NVIDIA_TESLA_P4'},
+        }
+
+        success_response = training_input.copy()
+        success_response['state'] = 'SUCCEEDED'
+        hook_instance = mock_hook.return_value
+        hook_instance.create_job.return_value = success_response
+
+        training_op = MLEngineStartTrainingJobOperator(
+            runtime_version='1.6',
+            python_version='3.5',
+            job_dir='gs://some-bucket/jobs/test_training',
+            master_type='n1-standard-4',
+            master_config={'acceleratorConfig': {'count': '1', 'type': 'NVIDIA_TESLA_P4'},},
+            **custom_training_default_args,
+        )
+        training_op.execute(MagicMock())
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id='google_cloud_default', delegate_to=None, impersonation_chain=None,
+        )
+        # Make sure only 'create_job' is invoked on hook instance
+        self.assertEqual(len(hook_instance.mock_calls), 1)
+        hook_instance.create_job.assert_called_once_with(
+            project_id='test-project', job=training_input, use_existing_job_fn=ANY
+        )
+
+    @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
     def test_success_create_training_job_with_optional_args(self, mock_hook):
         training_input = copy.deepcopy(self.TRAINING_INPUT)
         training_input['trainingInput']['runtimeVersion'] = '1.6'

Reply via email to