This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 74c3fb366e Remove parent_model version suffix if it is passed to 
Vertex AI operators (#39640)
74c3fb366e is described below

commit 74c3fb366ecf830b6e0fb961dd6668216d21cdeb
Author: Eugene <[email protected]>
AuthorDate: Fri May 17 07:36:00 2024 +0000

    Remove parent_model version suffix if it is passed to Vertex AI operators 
(#39640)
---
 .../google/cloud/operators/vertex_ai/auto_ml.py    |   5 +
 .../google/cloud/operators/vertex_ai/custom_job.py |   6 +
 .../operators/cloud/vertex_ai.rst                  |   4 +-
 .../google/cloud/operators/test_vertex_ai.py       | 516 +++++++++++++++++++++
 .../vertex_ai/example_vertex_ai_custom_job.py      |  43 +-
 5 files changed, 553 insertions(+), 21 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py 
b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
index 9ab9d06002..1475232012 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
@@ -176,6 +176,7 @@ class 
CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
+        self.parent_model = self.parent_model.split("@")[0] if 
self.parent_model else None
         model, training_id = self.hook.create_auto_ml_forecasting_training_job(
             project_id=self.project_id,
             region=self.region,
@@ -283,6 +284,7 @@ class 
CreateAutoMLImageTrainingJobOperator(AutoMLTrainingJobBaseOperator):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
+        self.parent_model = self.parent_model.split("@")[0] if 
self.parent_model else None
         model, training_id = self.hook.create_auto_ml_image_training_job(
             project_id=self.project_id,
             region=self.region,
@@ -391,6 +393,7 @@ class 
CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
             impersonation_chain=self.impersonation_chain,
         )
         credentials, _ = self.hook.get_credentials_and_project_id()
+        self.parent_model = self.parent_model.split("@")[0] if 
self.parent_model else None
         model, training_id = self.hook.create_auto_ml_tabular_training_job(
             project_id=self.project_id,
             region=self.region,
@@ -485,6 +488,7 @@ class 
CreateAutoMLTextTrainingJobOperator(AutoMLTrainingJobBaseOperator):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
+        self.parent_model = self.parent_model.split("@")[0] if 
self.parent_model else None
         model, training_id = self.hook.create_auto_ml_text_training_job(
             project_id=self.project_id,
             region=self.region,
@@ -561,6 +565,7 @@ class 
CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
+        self.parent_model = self.parent_model.split("@")[0] if 
self.parent_model else None
         model, training_id = self.hook.create_auto_ml_video_training_job(
             project_id=self.project_id,
             region=self.region,
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py 
b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
index 9264852050..3d61f2ac77 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
@@ -493,6 +493,8 @@ class 
CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
     def execute(self, context: Context):
         super().execute(context)
 
+        self.parent_model = self.parent_model.split("@")[0] if 
self.parent_model else None
+
         if self.deferrable:
             self.invoke_defer(context=context)
 
@@ -966,6 +968,8 @@ class 
CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
     def execute(self, context: Context):
         super().execute(context)
 
+        self.parent_model = self.parent_model.split("@")[0] if 
self.parent_model else None
+
         if self.deferrable:
             self.invoke_defer(context=context)
 
@@ -1446,6 +1450,8 @@ class 
CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
     def execute(self, context: Context):
         super().execute(context)
 
+        self.parent_model = self.parent_model.split("@")[0] if 
self.parent_model else None
+
         if self.deferrable:
             self.invoke_defer(context=context)
 
diff --git a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst 
b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
index c93ce54577..f5c12039ff 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
@@ -187,8 +187,8 @@ The same operation can be performed in the deferrable mode:
 .. exampleinclude:: 
/../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py
     :language: python
     :dedent: 4
-    :start-after: [START 
how_to_cloud_vertex_ai_create_custom_training_job_v2_operator_deferrable]
-    :end-before: [END 
how_to_cloud_vertex_ai_create_custom_training_job_v2_operator_deferrable]
+    :start-after: [START 
how_to_cloud_vertex_ai_create_custom_training_job_v2_deferrable_operator]
+    :end-before: [END 
how_to_cloud_vertex_ai_create_custom_training_job_v2_deferrable_operator]
 
 
 You can get a list of Training Jobs using
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py 
b/tests/providers/google/cloud/operators/test_vertex_ai.py
index e891d65249..3f8649f588 100644
--- a/tests/providers/google/cloud/operators/test_vertex_ai.py
+++ b/tests/providers/google/cloud/operators/test_vertex_ai.py
@@ -141,6 +141,7 @@ TEST_DATASET = {
 }
 TEST_DATASET_ID = "test-dataset-id"
 TEST_PARENT_MODEL = "test-parent-model"
+VERSIONED_TEST_PARENT_MODEL = f"{TEST_PARENT_MODEL}@1"
 TEST_EXPORT_CONFIG = {
     "annotationsFilter": "test-filter",
     "gcs_destination": {"output_uri_prefix": "airflow-system-tests-data"},
@@ -292,6 +293,93 @@ class TestVertexAICreateCustomContainerTrainingJobOperator:
             model_version_description=None,
         )
 
+    @mock.patch(VERTEX_AI_PATH.format("custom_job.Dataset"))
+    @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
+    def test_execute__parent_model_version_index_is_removed(self, mock_hook, 
mock_dataset):
+        
mock_hook.return_value.create_custom_container_training_job.return_value = (
+            None,
+            "training_id",
+            "custom_job_id",
+        )
+        op = CreateCustomContainerTrainingJobOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            staging_bucket=STAGING_BUCKET,
+            display_name=DISPLAY_NAME,
+            args=ARGS,
+            container_uri=CONTAINER_URI,
+            model_serving_container_image_uri=CONTAINER_URI,
+            command=COMMAND_2,
+            model_display_name=DISPLAY_NAME_2,
+            replica_count=REPLICA_COUNT,
+            machine_type=MACHINE_TYPE,
+            accelerator_type=ACCELERATOR_TYPE,
+            accelerator_count=ACCELERATOR_COUNT,
+            training_fraction_split=TRAINING_FRACTION_SPLIT,
+            validation_fraction_split=VALIDATION_FRACTION_SPLIT,
+            test_fraction_split=TEST_FRACTION_SPLIT,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            dataset_id=TEST_DATASET_ID,
+            parent_model=VERSIONED_TEST_PARENT_MODEL,
+        )
+        with pytest.warns(AirflowProviderDeprecationWarning, 
match=SYNC_DEPRECATION_WARNING):
+            op.execute(context={"ti": mock.MagicMock()})
+        
mock_hook.return_value.create_custom_container_training_job.assert_called_once_with(
+            staging_bucket=STAGING_BUCKET,
+            display_name=DISPLAY_NAME,
+            args=ARGS,
+            container_uri=CONTAINER_URI,
+            model_serving_container_image_uri=CONTAINER_URI,
+            command=COMMAND_2,
+            dataset=mock_dataset.return_value,
+            model_display_name=DISPLAY_NAME_2,
+            replica_count=REPLICA_COUNT,
+            machine_type=MACHINE_TYPE,
+            accelerator_type=ACCELERATOR_TYPE,
+            accelerator_count=ACCELERATOR_COUNT,
+            training_fraction_split=TRAINING_FRACTION_SPLIT,
+            validation_fraction_split=VALIDATION_FRACTION_SPLIT,
+            test_fraction_split=TEST_FRACTION_SPLIT,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            parent_model=TEST_PARENT_MODEL,
+            model_serving_container_predict_route=None,
+            model_serving_container_health_route=None,
+            model_serving_container_command=None,
+            model_serving_container_args=None,
+            model_serving_container_environment_variables=None,
+            model_serving_container_ports=None,
+            model_description=None,
+            model_instance_schema_uri=None,
+            model_parameters_schema_uri=None,
+            model_prediction_schema_uri=None,
+            labels=None,
+            training_encryption_spec_key_name=None,
+            model_encryption_spec_key_name=None,
+            # RUN
+            annotation_schema_uri=None,
+            model_labels=None,
+            base_output_dir=None,
+            service_account=None,
+            network=None,
+            bigquery_destination=None,
+            environment_variables=None,
+            boot_disk_type="pd-ssd",
+            boot_disk_size_gb=100,
+            training_filter_split=None,
+            validation_filter_split=None,
+            test_filter_split=None,
+            predefined_split_column_name=None,
+            timestamp_split_column_name=None,
+            tensorboard=None,
+            sync=True,
+            is_default_version=None,
+            model_version_aliases=None,
+            model_version_description=None,
+        )
+
     
@mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomContainerTrainingJobOperator.hook"))
     def test_execute_enters_deferred_state(self, mock_hook):
         task = CreateCustomContainerTrainingJobOperator(
@@ -476,6 +564,95 @@ class 
TestVertexAICreateCustomPythonPackageTrainingJobOperator:
             sync=True,
         )
 
+    @mock.patch(VERTEX_AI_PATH.format("custom_job.Dataset"))
+    @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
+    def test_execute__parent_model_version_index_is_removed(self, mock_hook, 
mock_dataset):
+        
mock_hook.return_value.create_custom_python_package_training_job.return_value = 
(
+            None,
+            "training_id",
+            "custom_job_id",
+        )
+        op = CreateCustomPythonPackageTrainingJobOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            staging_bucket=STAGING_BUCKET,
+            display_name=DISPLAY_NAME,
+            python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI,
+            python_module_name=PYTHON_MODULE_NAME,
+            container_uri=CONTAINER_URI,
+            args=ARGS,
+            model_serving_container_image_uri=CONTAINER_URI,
+            model_display_name=DISPLAY_NAME_2,
+            replica_count=REPLICA_COUNT,
+            machine_type=MACHINE_TYPE,
+            accelerator_type=ACCELERATOR_TYPE,
+            accelerator_count=ACCELERATOR_COUNT,
+            training_fraction_split=TRAINING_FRACTION_SPLIT,
+            validation_fraction_split=VALIDATION_FRACTION_SPLIT,
+            test_fraction_split=TEST_FRACTION_SPLIT,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            dataset_id=TEST_DATASET_ID,
+            parent_model=VERSIONED_TEST_PARENT_MODEL,
+        )
+        with pytest.warns(AirflowProviderDeprecationWarning, 
match=SYNC_DEPRECATION_WARNING):
+            op.execute(context={"ti": mock.MagicMock()})
+        
mock_hook.return_value.create_custom_python_package_training_job.assert_called_once_with(
+            staging_bucket=STAGING_BUCKET,
+            display_name=DISPLAY_NAME,
+            args=ARGS,
+            container_uri=CONTAINER_URI,
+            model_serving_container_image_uri=CONTAINER_URI,
+            python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI,
+            python_module_name=PYTHON_MODULE_NAME,
+            dataset=mock_dataset.return_value,
+            model_display_name=DISPLAY_NAME_2,
+            replica_count=REPLICA_COUNT,
+            machine_type=MACHINE_TYPE,
+            accelerator_type=ACCELERATOR_TYPE,
+            accelerator_count=ACCELERATOR_COUNT,
+            training_fraction_split=TRAINING_FRACTION_SPLIT,
+            validation_fraction_split=VALIDATION_FRACTION_SPLIT,
+            test_fraction_split=TEST_FRACTION_SPLIT,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            parent_model=TEST_PARENT_MODEL,
+            is_default_version=None,
+            model_version_aliases=None,
+            model_version_description=None,
+            model_serving_container_predict_route=None,
+            model_serving_container_health_route=None,
+            model_serving_container_command=None,
+            model_serving_container_args=None,
+            model_serving_container_environment_variables=None,
+            model_serving_container_ports=None,
+            model_description=None,
+            model_instance_schema_uri=None,
+            model_parameters_schema_uri=None,
+            model_prediction_schema_uri=None,
+            labels=None,
+            training_encryption_spec_key_name=None,
+            model_encryption_spec_key_name=None,
+            # RUN
+            annotation_schema_uri=None,
+            model_labels=None,
+            base_output_dir=None,
+            service_account=None,
+            network=None,
+            bigquery_destination=None,
+            environment_variables=None,
+            boot_disk_type="pd-ssd",
+            boot_disk_size_gb=100,
+            training_filter_split=None,
+            validation_filter_split=None,
+            test_filter_split=None,
+            predefined_split_column_name=None,
+            timestamp_split_column_name=None,
+            tensorboard=None,
+            sync=True,
+        )
+
     
@mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomPythonPackageTrainingJobOperator.hook"))
     def test_execute_enters_deferred_state(self, mock_hook):
         task = CreateCustomPythonPackageTrainingJobOperator(
@@ -656,6 +833,88 @@ class TestVertexAICreateCustomTrainingJobOperator:
             model_version_description=None,
         )
 
+    @mock.patch(VERTEX_AI_PATH.format("custom_job.Dataset"))
+    @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
+    def test_execute__parent_model_version_index_is_removed(self, mock_hook, 
mock_dataset):
+        mock_hook.return_value.create_custom_training_job.return_value = (
+            None,
+            "training_id",
+            "custom_job_id",
+        )
+        op = CreateCustomTrainingJobOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            staging_bucket=STAGING_BUCKET,
+            display_name=DISPLAY_NAME,
+            script_path=PYTHON_PACKAGE,
+            args=PYTHON_PACKAGE_CMDARGS,
+            container_uri=CONTAINER_URI,
+            model_serving_container_image_uri=CONTAINER_URI,
+            requirements=[],
+            replica_count=1,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            dataset_id=TEST_DATASET_ID,
+            parent_model=VERSIONED_TEST_PARENT_MODEL,
+        )
+        with pytest.warns(AirflowProviderDeprecationWarning, 
match=SYNC_DEPRECATION_WARNING):
+            op.execute(context={"ti": mock.MagicMock()})
+        
mock_hook.return_value.create_custom_training_job.assert_called_once_with(
+            staging_bucket=STAGING_BUCKET,
+            display_name=DISPLAY_NAME,
+            args=PYTHON_PACKAGE_CMDARGS,
+            container_uri=CONTAINER_URI,
+            model_serving_container_image_uri=CONTAINER_URI,
+            script_path=PYTHON_PACKAGE,
+            requirements=[],
+            dataset=mock_dataset.return_value,
+            model_display_name=None,
+            replica_count=REPLICA_COUNT,
+            machine_type=MACHINE_TYPE,
+            accelerator_type=ACCELERATOR_TYPE,
+            accelerator_count=ACCELERATOR_COUNT,
+            training_fraction_split=None,
+            validation_fraction_split=None,
+            test_fraction_split=None,
+            parent_model=TEST_PARENT_MODEL,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            model_serving_container_predict_route=None,
+            model_serving_container_health_route=None,
+            model_serving_container_command=None,
+            model_serving_container_args=None,
+            model_serving_container_environment_variables=None,
+            model_serving_container_ports=None,
+            model_description=None,
+            model_instance_schema_uri=None,
+            model_parameters_schema_uri=None,
+            model_prediction_schema_uri=None,
+            labels=None,
+            training_encryption_spec_key_name=None,
+            model_encryption_spec_key_name=None,
+            # RUN
+            annotation_schema_uri=None,
+            model_labels=None,
+            base_output_dir=None,
+            service_account=None,
+            network=None,
+            bigquery_destination=None,
+            environment_variables=None,
+            boot_disk_type="pd-ssd",
+            boot_disk_size_gb=100,
+            training_filter_split=None,
+            validation_filter_split=None,
+            test_filter_split=None,
+            predefined_split_column_name=None,
+            timestamp_split_column_name=None,
+            tensorboard=None,
+            sync=True,
+            is_default_version=None,
+            model_version_aliases=None,
+            model_version_description=None,
+        )
+
     
@mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomTrainingJobOperator.hook"))
     def test_execute_enters_deferred_state(self, mock_hook):
         task = CreateCustomTrainingJobOperator(
@@ -1083,6 +1342,71 @@ class 
TestVertexAICreateAutoMLForecastingTrainingJobOperator:
             model_version_description=None,
         )
 
+    @mock.patch("google.cloud.aiplatform.datasets.TimeSeriesDataset")
+    @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
+    def test_execute__parent_model_version_index_is_removed(self, mock_hook, 
mock_dataset):
+        
mock_hook.return_value.create_auto_ml_forecasting_training_job.return_value = 
(None, "training_id")
+        op = CreateAutoMLForecastingTrainingJobOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            display_name=DISPLAY_NAME,
+            dataset_id=TEST_DATASET_ID,
+            target_column=TEST_TRAINING_TARGET_COLUMN,
+            time_column=TEST_TRAINING_TIME_COLUMN,
+            
time_series_identifier_column=TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN,
+            
unavailable_at_forecast_columns=TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS,
+            
available_at_forecast_columns=TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS,
+            forecast_horizon=TEST_TRAINING_FORECAST_HORIZON,
+            data_granularity_unit=TEST_TRAINING_DATA_GRANULARITY_UNIT,
+            data_granularity_count=TEST_TRAINING_DATA_GRANULARITY_COUNT,
+            sync=True,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            parent_model=VERSIONED_TEST_PARENT_MODEL,
+        )
+        op.execute(context={"ti": mock.MagicMock()})
+        
mock_hook.return_value.create_auto_ml_forecasting_training_job.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            region=GCP_LOCATION,
+            display_name=DISPLAY_NAME,
+            dataset=mock_dataset.return_value,
+            target_column=TEST_TRAINING_TARGET_COLUMN,
+            time_column=TEST_TRAINING_TIME_COLUMN,
+            
time_series_identifier_column=TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN,
+            
unavailable_at_forecast_columns=TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS,
+            
available_at_forecast_columns=TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS,
+            forecast_horizon=TEST_TRAINING_FORECAST_HORIZON,
+            data_granularity_unit=TEST_TRAINING_DATA_GRANULARITY_UNIT,
+            data_granularity_count=TEST_TRAINING_DATA_GRANULARITY_COUNT,
+            parent_model=TEST_PARENT_MODEL,
+            optimization_objective=None,
+            column_specs=None,
+            column_transformations=None,
+            labels=None,
+            training_encryption_spec_key_name=None,
+            model_encryption_spec_key_name=None,
+            training_fraction_split=None,
+            validation_fraction_split=None,
+            test_fraction_split=None,
+            predefined_split_column_name=None,
+            weight_column=None,
+            time_series_attribute_columns=None,
+            context_window=None,
+            export_evaluated_data_items=False,
+            export_evaluated_data_items_bigquery_destination_uri=None,
+            export_evaluated_data_items_override_destination=False,
+            quantiles=None,
+            validation_options=None,
+            budget_milli_node_hours=1000,
+            model_display_name=None,
+            model_labels=None,
+            sync=True,
+            is_default_version=None,
+            model_version_aliases=None,
+            model_version_description=None,
+        )
+
 
 class TestVertexAICreateAutoMLImageTrainingJobOperator:
     @mock.patch("google.cloud.aiplatform.datasets.ImageDataset")
@@ -1135,6 +1459,54 @@ class TestVertexAICreateAutoMLImageTrainingJobOperator:
             model_version_description=None,
         )
 
+    @mock.patch("google.cloud.aiplatform.datasets.ImageDataset")
+    @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
+    def test_execute__parent_model_version_index_is_removed(self, mock_hook, 
mock_dataset):
+        mock_hook.return_value.create_auto_ml_image_training_job.return_value 
= (None, "training_id")
+        op = CreateAutoMLImageTrainingJobOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            display_name=DISPLAY_NAME,
+            dataset_id=TEST_DATASET_ID,
+            prediction_type="classification",
+            multi_label=False,
+            model_type="CLOUD",
+            sync=True,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            parent_model=VERSIONED_TEST_PARENT_MODEL,
+        )
+        op.execute(context={"ti": mock.MagicMock()})
+        
mock_hook.return_value.create_auto_ml_image_training_job.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            region=GCP_LOCATION,
+            display_name=DISPLAY_NAME,
+            dataset=mock_dataset.return_value,
+            prediction_type="classification",
+            parent_model=TEST_PARENT_MODEL,
+            multi_label=False,
+            model_type="CLOUD",
+            base_model=None,
+            labels=None,
+            training_encryption_spec_key_name=None,
+            model_encryption_spec_key_name=None,
+            training_fraction_split=None,
+            validation_fraction_split=None,
+            test_fraction_split=None,
+            training_filter_split=None,
+            validation_filter_split=None,
+            test_filter_split=None,
+            budget_milli_node_hours=None,
+            model_display_name=None,
+            model_labels=None,
+            disable_early_stopping=False,
+            sync=True,
+            is_default_version=None,
+            model_version_aliases=None,
+            model_version_description=None,
+        )
+
 
 class TestVertexAICreateAutoMLTabularTrainingJobOperator:
     @mock.patch("google.cloud.aiplatform.datasets.TabularDataset")
@@ -1199,6 +1571,64 @@ class TestVertexAICreateAutoMLTabularTrainingJobOperator:
             model_version_description=None,
         )
 
+    @mock.patch("google.cloud.aiplatform.datasets.TabularDataset")
+    @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
+    def test_execute__parent_model_version_index_is_removed(self, mock_hook, 
mock_dataset):
+        mock_hook.return_value = MagicMock(
+            **{
+                "create_auto_ml_tabular_training_job.return_value": (None, 
"training_id"),
+                "get_credentials_and_project_id.return_value": ("creds", 
"project_id"),
+            }
+        )
+        op = CreateAutoMLTabularTrainingJobOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            display_name=DISPLAY_NAME,
+            dataset_id=TEST_DATASET_ID,
+            target_column=None,
+            optimization_prediction_type=None,
+            sync=True,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            parent_model=VERSIONED_TEST_PARENT_MODEL,
+        )
+        op.execute(context={"ti": mock.MagicMock()})
+        
mock_hook.return_value.create_auto_ml_tabular_training_job.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            region=GCP_LOCATION,
+            display_name=DISPLAY_NAME,
+            dataset=mock_dataset.return_value,
+            parent_model=TEST_PARENT_MODEL,
+            target_column=None,
+            optimization_prediction_type=None,
+            optimization_objective=None,
+            column_specs=None,
+            column_transformations=None,
+            optimization_objective_recall_value=None,
+            optimization_objective_precision_value=None,
+            labels=None,
+            training_encryption_spec_key_name=None,
+            model_encryption_spec_key_name=None,
+            training_fraction_split=None,
+            validation_fraction_split=None,
+            test_fraction_split=None,
+            predefined_split_column_name=None,
+            timestamp_split_column_name=None,
+            weight_column=None,
+            budget_milli_node_hours=1000,
+            model_display_name=None,
+            model_labels=None,
+            disable_early_stopping=False,
+            export_evaluated_data_items=False,
+            export_evaluated_data_items_bigquery_destination_uri=None,
+            export_evaluated_data_items_override_destination=False,
+            sync=True,
+            is_default_version=None,
+            model_version_aliases=None,
+            model_version_description=None,
+        )
+
 
 class TestVertexAICreateAutoMLTextTrainingJobOperator:
     @mock.patch("google.cloud.aiplatform.datasets.TextDataset")
@@ -1248,6 +1678,51 @@ class TestVertexAICreateAutoMLTextTrainingJobOperator:
             model_version_description=None,
         )
 
+    @mock.patch("google.cloud.aiplatform.datasets.TextDataset")
+    @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
+    def test_execute__parent_model_version_index_is_removed(self, mock_hook, 
mock_dataset):
+        mock_hook.return_value.create_auto_ml_text_training_job.return_value = 
(None, "training_id")
+        op = CreateAutoMLTextTrainingJobOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            display_name=DISPLAY_NAME,
+            dataset_id=TEST_DATASET_ID,
+            prediction_type=None,
+            multi_label=False,
+            sentiment_max=10,
+            sync=True,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            parent_model=VERSIONED_TEST_PARENT_MODEL,
+        )
+        op.execute(context={"ti": mock.MagicMock()})
+        
mock_hook.return_value.create_auto_ml_text_training_job.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            region=GCP_LOCATION,
+            display_name=DISPLAY_NAME,
+            dataset=mock_dataset.return_value,
+            parent_model=TEST_PARENT_MODEL,
+            prediction_type=None,
+            multi_label=False,
+            sentiment_max=10,
+            labels=None,
+            training_encryption_spec_key_name=None,
+            model_encryption_spec_key_name=None,
+            training_fraction_split=None,
+            validation_fraction_split=None,
+            test_fraction_split=None,
+            training_filter_split=None,
+            validation_filter_split=None,
+            test_filter_split=None,
+            model_display_name=None,
+            model_labels=None,
+            sync=True,
+            is_default_version=None,
+            model_version_aliases=None,
+            model_version_description=None,
+        )
+
 
 class TestVertexAICreateAutoMLVideoTrainingJobOperator:
     @mock.patch("google.cloud.aiplatform.datasets.VideoDataset")
@@ -1293,6 +1768,47 @@ class TestVertexAICreateAutoMLVideoTrainingJobOperator:
             model_version_description=None,
         )
 
+    @mock.patch("google.cloud.aiplatform.datasets.VideoDataset")
+    @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
+    def test_execute__parent_model_version_index_is_removed(self, mock_hook, 
mock_dataset):
+        mock_hook.return_value.create_auto_ml_video_training_job.return_value 
= (None, "training_id")
+        op = CreateAutoMLVideoTrainingJobOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            display_name=DISPLAY_NAME,
+            dataset_id=TEST_DATASET_ID,
+            prediction_type="classification",
+            model_type="CLOUD",
+            sync=True,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            parent_model=VERSIONED_TEST_PARENT_MODEL,
+        )
+        op.execute(context={"ti": mock.MagicMock()})
+        
mock_hook.return_value.create_auto_ml_video_training_job.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            region=GCP_LOCATION,
+            display_name=DISPLAY_NAME,
+            dataset=mock_dataset.return_value,
+            parent_model=TEST_PARENT_MODEL,
+            prediction_type="classification",
+            model_type="CLOUD",
+            labels=None,
+            training_encryption_spec_key_name=None,
+            model_encryption_spec_key_name=None,
+            training_fraction_split=None,
+            test_fraction_split=None,
+            training_filter_split=None,
+            test_filter_split=None,
+            model_display_name=None,
+            model_labels=None,
+            sync=True,
+            is_default_version=None,
+            model_version_aliases=None,
+            model_version_description=None,
+        )
+
 
 class TestVertexAIDeleteAutoMLTrainingJobOperator:
     @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
diff --git 
a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py 
b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py
index 702e6a6c51..c90c1aac23 100644
--- 
a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py
+++ 
b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py
@@ -17,9 +17,7 @@
 # under the License.
 
 
-"""
-Example Airflow DAG for Google Vertex AI service testing Custom Jobs 
operations.
-"""
+"""Example Airflow DAG for Google Vertex AI service testing Custom Jobs 
operations."""
 
 from __future__ import annotations
 
@@ -49,15 +47,13 @@ from airflow.providers.google.cloud.transfers.gcs_to_local 
import GCSToLocalFile
 from airflow.utils.trigger_rule import TriggerRule
 
 ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
-DAG_ID = "example_vertex_ai_custom_job_operations"
 PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
 REGION = "us-central1"
 CUSTOM_DISPLAY_NAME = f"train-housing-custom-{ENV_ID}"
 MODEL_DISPLAY_NAME = f"custom-housing-model-{ENV_ID}"
-
+DAG_ID = "vertex_ai_custom_job_operations"
 RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"
 CUSTOM_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
-
 DATA_SAMPLE_GCS_OBJECT_NAME = "vertex-ai/california_housing_train.csv"
 
 
@@ -133,6 +129,7 @@ with DAG(
         region=REGION,
         project_id=PROJECT_ID,
     )
+
     model_id_v1 = create_custom_training_job.output["model_id"]
     # [END how_to_cloud_vertex_ai_create_custom_training_job_operator]
 
@@ -140,7 +137,7 @@ with DAG(
     create_custom_training_job_deferrable = CreateCustomTrainingJobOperator(
         task_id="custom_task_deferrable",
         staging_bucket=f"gs://{CUSTOM_GCS_BUCKET_NAME}",
-        display_name=f"{CUSTOM_DISPLAY_NAME}_DEF",
+        display_name=f"{CUSTOM_DISPLAY_NAME}-def",
         script_path=LOCAL_TRAINING_SCRIPT_PATH,
         container_uri=CONTAINER_URI,
         requirements=["gcsfs==0.7.1"],
@@ -148,12 +145,12 @@ with DAG(
         # run params
         dataset_id=tabular_dataset_id,
         replica_count=REPLICA_COUNT,
-        model_display_name=f"{MODEL_DISPLAY_NAME}_DEF",
+        model_display_name=f"{MODEL_DISPLAY_NAME}-def",
         region=REGION,
         project_id=PROJECT_ID,
         deferrable=True,
     )
-    model_id_v1_deferrable = create_custom_training_job.output["model_id"]
+    model_id_deferrable_v1 = create_custom_training_job_deferrable.output["model_id"]
     # [END 
how_to_cloud_vertex_ai_create_custom_training_job_operator_deferrable]
 
     # [START how_to_cloud_vertex_ai_create_custom_training_job_v2_operator]
@@ -176,26 +173,26 @@ with DAG(
     )
     # [END how_to_cloud_vertex_ai_create_custom_training_job_v2_operator]
 
-    # [START how_to_cloud_vertex_ai_create_custom_training_job_v2_operator_deferrable]
-    create_custom_training_job_v2_deferrable = CreateCustomTrainingJobOperator(
-        task_id="custom_task_v2_deferrable",
+    # [START how_to_cloud_vertex_ai_create_custom_training_job_v2_deferrable_operator]
+    create_custom_training_job_deferrable_v2 = CreateCustomTrainingJobOperator(
+        task_id="custom_task_deferrable_v2",
         staging_bucket=f"gs://{CUSTOM_GCS_BUCKET_NAME}",
-        display_name=f"{CUSTOM_DISPLAY_NAME}_DEF",
+        display_name=f"{CUSTOM_DISPLAY_NAME}-def",
         script_path=LOCAL_TRAINING_SCRIPT_PATH,
         container_uri=CONTAINER_URI,
         requirements=["gcsfs==0.7.1"],
         model_serving_container_image_uri=MODEL_SERVING_CONTAINER_URI,
-        parent_model=model_id_v1,
+        parent_model=model_id_deferrable_v1,
         # run params
         dataset_id=tabular_dataset_id,
         replica_count=REPLICA_COUNT,
-        model_display_name=f"{MODEL_DISPLAY_NAME}_DEF",
+        model_display_name=f"{MODEL_DISPLAY_NAME}-def",
         sync=False,
         region=REGION,
         project_id=PROJECT_ID,
         deferrable=True,
     )
-    # [END how_to_cloud_vertex_ai_create_custom_training_job_v2_operator_deferrable]
+    # [END how_to_cloud_vertex_ai_create_custom_training_job_v2_deferrable_operator]
 
     # [START how_to_cloud_vertex_ai_delete_custom_training_job_operator]
     delete_custom_training_job = DeleteCustomTrainingJobOperator(
@@ -208,6 +205,15 @@ with DAG(
     )
     # [END how_to_cloud_vertex_ai_delete_custom_training_job_operator]
 
+    delete_custom_training_job_deferrable = DeleteCustomTrainingJobOperator(
+        task_id="delete_custom_training_job_deferrable",
+        training_pipeline_id="{{ task_instance.xcom_pull(task_ids='custom_task_deferrable', key='training_id') }}",
+        custom_job_id="{{ task_instance.xcom_pull(task_ids='custom_task_deferrable', key='custom_job_id') }}",
+        region=REGION,
+        project_id=PROJECT_ID,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
     delete_tabular_dataset = DeleteDatasetOperator(
         task_id="delete_tabular_dataset",
         dataset_id=tabular_dataset_id,
@@ -230,15 +236,14 @@ with DAG(
             create_tabular_dataset,
             # TEST BODY
             [create_custom_training_job, create_custom_training_job_deferrable],
-            [create_custom_training_job_v2, create_custom_training_job_v2_deferrable],
+            [create_custom_training_job_v2, create_custom_training_job_deferrable_v2],
             # TEST TEARDOWN
-            delete_custom_training_job,
+            [delete_custom_training_job, delete_custom_training_job_deferrable],
             delete_tabular_dataset,
             delete_bucket,
         )
     )
 
-
 from tests.system.utils import get_test_run  # noqa: E402
 
 # Needed to run the example DAG with pytest (see: 
tests/system/README.md#run_via_pytest)


Reply via email to