This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 74c3fb366e Remove parent_model version suffix if it is passed to
Vertex AI operators (#39640)
74c3fb366e is described below
commit 74c3fb366ecf830b6e0fb961dd6668216d21cdeb
Author: Eugene <[email protected]>
AuthorDate: Fri May 17 07:36:00 2024 +0000
Remove parent_model version suffix if it is passed to Vertex AI operators
(#39640)
---
.../google/cloud/operators/vertex_ai/auto_ml.py | 5 +
.../google/cloud/operators/vertex_ai/custom_job.py | 6 +
.../operators/cloud/vertex_ai.rst | 4 +-
.../google/cloud/operators/test_vertex_ai.py | 516 +++++++++++++++++++++
.../vertex_ai/example_vertex_ai_custom_job.py | 43 +-
5 files changed, 553 insertions(+), 21 deletions(-)
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
index 9ab9d06002..1475232012 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
@@ -176,6 +176,7 @@ class
CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
+ self.parent_model = self.parent_model.split("@")[0] if self.parent_model else None
model, training_id = self.hook.create_auto_ml_forecasting_training_job(
project_id=self.project_id,
region=self.region,
@@ -283,6 +284,7 @@ class
CreateAutoMLImageTrainingJobOperator(AutoMLTrainingJobBaseOperator):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
+ self.parent_model = self.parent_model.split("@")[0] if self.parent_model else None
model, training_id = self.hook.create_auto_ml_image_training_job(
project_id=self.project_id,
region=self.region,
@@ -391,6 +393,7 @@ class
CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
impersonation_chain=self.impersonation_chain,
)
credentials, _ = self.hook.get_credentials_and_project_id()
+ self.parent_model = self.parent_model.split("@")[0] if self.parent_model else None
model, training_id = self.hook.create_auto_ml_tabular_training_job(
project_id=self.project_id,
region=self.region,
@@ -485,6 +488,7 @@ class
CreateAutoMLTextTrainingJobOperator(AutoMLTrainingJobBaseOperator):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
+ self.parent_model = self.parent_model.split("@")[0] if self.parent_model else None
model, training_id = self.hook.create_auto_ml_text_training_job(
project_id=self.project_id,
region=self.region,
@@ -561,6 +565,7 @@ class
CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
+ self.parent_model = self.parent_model.split("@")[0] if self.parent_model else None
model, training_id = self.hook.create_auto_ml_video_training_job(
project_id=self.project_id,
region=self.region,
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
index 9264852050..3d61f2ac77 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
@@ -493,6 +493,8 @@ class
CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
def execute(self, context: Context):
super().execute(context)
+ self.parent_model = self.parent_model.split("@")[0] if self.parent_model else None
+
if self.deferrable:
self.invoke_defer(context=context)
@@ -966,6 +968,8 @@ class
CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
def execute(self, context: Context):
super().execute(context)
+ self.parent_model = self.parent_model.split("@")[0] if self.parent_model else None
+
if self.deferrable:
self.invoke_defer(context=context)
@@ -1446,6 +1450,8 @@ class
CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
def execute(self, context: Context):
super().execute(context)
+ self.parent_model = self.parent_model.split("@")[0] if self.parent_model else None
+
if self.deferrable:
self.invoke_defer(context=context)
diff --git a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
index c93ce54577..f5c12039ff 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
@@ -187,8 +187,8 @@ The same operation can be performed in the deferrable mode:
.. exampleinclude::
/../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py
:language: python
:dedent: 4
- :start-after: [START
how_to_cloud_vertex_ai_create_custom_training_job_v2_operator_deferrable]
- :end-before: [END
how_to_cloud_vertex_ai_create_custom_training_job_v2_operator_deferrable]
+ :start-after: [START
how_to_cloud_vertex_ai_create_custom_training_job_v2_deferrable_operator]
+ :end-before: [END
how_to_cloud_vertex_ai_create_custom_training_job_v2_deferrable_operator]
You can get a list of Training Jobs using
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py
b/tests/providers/google/cloud/operators/test_vertex_ai.py
index e891d65249..3f8649f588 100644
--- a/tests/providers/google/cloud/operators/test_vertex_ai.py
+++ b/tests/providers/google/cloud/operators/test_vertex_ai.py
@@ -141,6 +141,7 @@ TEST_DATASET = {
}
TEST_DATASET_ID = "test-dataset-id"
TEST_PARENT_MODEL = "test-parent-model"
+VERSIONED_TEST_PARENT_MODEL = f"{TEST_PARENT_MODEL}@1"
TEST_EXPORT_CONFIG = {
"annotationsFilter": "test-filter",
"gcs_destination": {"output_uri_prefix": "airflow-system-tests-data"},
@@ -292,6 +293,93 @@ class TestVertexAICreateCustomContainerTrainingJobOperator:
model_version_description=None,
)
+ @mock.patch(VERTEX_AI_PATH.format("custom_job.Dataset"))
+ @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
+ def test_execute__parent_model_version_index_is_removed(self, mock_hook,
mock_dataset):
+
mock_hook.return_value.create_custom_container_training_job.return_value = (
+ None,
+ "training_id",
+ "custom_job_id",
+ )
+ op = CreateCustomContainerTrainingJobOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ staging_bucket=STAGING_BUCKET,
+ display_name=DISPLAY_NAME,
+ args=ARGS,
+ container_uri=CONTAINER_URI,
+ model_serving_container_image_uri=CONTAINER_URI,
+ command=COMMAND_2,
+ model_display_name=DISPLAY_NAME_2,
+ replica_count=REPLICA_COUNT,
+ machine_type=MACHINE_TYPE,
+ accelerator_type=ACCELERATOR_TYPE,
+ accelerator_count=ACCELERATOR_COUNT,
+ training_fraction_split=TRAINING_FRACTION_SPLIT,
+ validation_fraction_split=VALIDATION_FRACTION_SPLIT,
+ test_fraction_split=TEST_FRACTION_SPLIT,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ dataset_id=TEST_DATASET_ID,
+ parent_model=VERSIONED_TEST_PARENT_MODEL,
+ )
+ with pytest.warns(AirflowProviderDeprecationWarning,
match=SYNC_DEPRECATION_WARNING):
+ op.execute(context={"ti": mock.MagicMock()})
+
mock_hook.return_value.create_custom_container_training_job.assert_called_once_with(
+ staging_bucket=STAGING_BUCKET,
+ display_name=DISPLAY_NAME,
+ args=ARGS,
+ container_uri=CONTAINER_URI,
+ model_serving_container_image_uri=CONTAINER_URI,
+ command=COMMAND_2,
+ dataset=mock_dataset.return_value,
+ model_display_name=DISPLAY_NAME_2,
+ replica_count=REPLICA_COUNT,
+ machine_type=MACHINE_TYPE,
+ accelerator_type=ACCELERATOR_TYPE,
+ accelerator_count=ACCELERATOR_COUNT,
+ training_fraction_split=TRAINING_FRACTION_SPLIT,
+ validation_fraction_split=VALIDATION_FRACTION_SPLIT,
+ test_fraction_split=TEST_FRACTION_SPLIT,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ parent_model=TEST_PARENT_MODEL,
+ model_serving_container_predict_route=None,
+ model_serving_container_health_route=None,
+ model_serving_container_command=None,
+ model_serving_container_args=None,
+ model_serving_container_environment_variables=None,
+ model_serving_container_ports=None,
+ model_description=None,
+ model_instance_schema_uri=None,
+ model_parameters_schema_uri=None,
+ model_prediction_schema_uri=None,
+ labels=None,
+ training_encryption_spec_key_name=None,
+ model_encryption_spec_key_name=None,
+ # RUN
+ annotation_schema_uri=None,
+ model_labels=None,
+ base_output_dir=None,
+ service_account=None,
+ network=None,
+ bigquery_destination=None,
+ environment_variables=None,
+ boot_disk_type="pd-ssd",
+ boot_disk_size_gb=100,
+ training_filter_split=None,
+ validation_filter_split=None,
+ test_filter_split=None,
+ predefined_split_column_name=None,
+ timestamp_split_column_name=None,
+ tensorboard=None,
+ sync=True,
+ is_default_version=None,
+ model_version_aliases=None,
+ model_version_description=None,
+ )
+
@mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomContainerTrainingJobOperator.hook"))
def test_execute_enters_deferred_state(self, mock_hook):
task = CreateCustomContainerTrainingJobOperator(
@@ -476,6 +564,95 @@ class
TestVertexAICreateCustomPythonPackageTrainingJobOperator:
sync=True,
)
+ @mock.patch(VERTEX_AI_PATH.format("custom_job.Dataset"))
+ @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
+ def test_execute__parent_model_version_index_is_removed(self, mock_hook,
mock_dataset):
+
mock_hook.return_value.create_custom_python_package_training_job.return_value =
(
+ None,
+ "training_id",
+ "custom_job_id",
+ )
+ op = CreateCustomPythonPackageTrainingJobOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ staging_bucket=STAGING_BUCKET,
+ display_name=DISPLAY_NAME,
+ python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI,
+ python_module_name=PYTHON_MODULE_NAME,
+ container_uri=CONTAINER_URI,
+ args=ARGS,
+ model_serving_container_image_uri=CONTAINER_URI,
+ model_display_name=DISPLAY_NAME_2,
+ replica_count=REPLICA_COUNT,
+ machine_type=MACHINE_TYPE,
+ accelerator_type=ACCELERATOR_TYPE,
+ accelerator_count=ACCELERATOR_COUNT,
+ training_fraction_split=TRAINING_FRACTION_SPLIT,
+ validation_fraction_split=VALIDATION_FRACTION_SPLIT,
+ test_fraction_split=TEST_FRACTION_SPLIT,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ dataset_id=TEST_DATASET_ID,
+ parent_model=VERSIONED_TEST_PARENT_MODEL,
+ )
+ with pytest.warns(AirflowProviderDeprecationWarning,
match=SYNC_DEPRECATION_WARNING):
+ op.execute(context={"ti": mock.MagicMock()})
+
mock_hook.return_value.create_custom_python_package_training_job.assert_called_once_with(
+ staging_bucket=STAGING_BUCKET,
+ display_name=DISPLAY_NAME,
+ args=ARGS,
+ container_uri=CONTAINER_URI,
+ model_serving_container_image_uri=CONTAINER_URI,
+ python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI,
+ python_module_name=PYTHON_MODULE_NAME,
+ dataset=mock_dataset.return_value,
+ model_display_name=DISPLAY_NAME_2,
+ replica_count=REPLICA_COUNT,
+ machine_type=MACHINE_TYPE,
+ accelerator_type=ACCELERATOR_TYPE,
+ accelerator_count=ACCELERATOR_COUNT,
+ training_fraction_split=TRAINING_FRACTION_SPLIT,
+ validation_fraction_split=VALIDATION_FRACTION_SPLIT,
+ test_fraction_split=TEST_FRACTION_SPLIT,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ parent_model=TEST_PARENT_MODEL,
+ is_default_version=None,
+ model_version_aliases=None,
+ model_version_description=None,
+ model_serving_container_predict_route=None,
+ model_serving_container_health_route=None,
+ model_serving_container_command=None,
+ model_serving_container_args=None,
+ model_serving_container_environment_variables=None,
+ model_serving_container_ports=None,
+ model_description=None,
+ model_instance_schema_uri=None,
+ model_parameters_schema_uri=None,
+ model_prediction_schema_uri=None,
+ labels=None,
+ training_encryption_spec_key_name=None,
+ model_encryption_spec_key_name=None,
+ # RUN
+ annotation_schema_uri=None,
+ model_labels=None,
+ base_output_dir=None,
+ service_account=None,
+ network=None,
+ bigquery_destination=None,
+ environment_variables=None,
+ boot_disk_type="pd-ssd",
+ boot_disk_size_gb=100,
+ training_filter_split=None,
+ validation_filter_split=None,
+ test_filter_split=None,
+ predefined_split_column_name=None,
+ timestamp_split_column_name=None,
+ tensorboard=None,
+ sync=True,
+ )
+
@mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomPythonPackageTrainingJobOperator.hook"))
def test_execute_enters_deferred_state(self, mock_hook):
task = CreateCustomPythonPackageTrainingJobOperator(
@@ -656,6 +833,88 @@ class TestVertexAICreateCustomTrainingJobOperator:
model_version_description=None,
)
+ @mock.patch(VERTEX_AI_PATH.format("custom_job.Dataset"))
+ @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
+ def test_execute__parent_model_version_index_is_removed(self, mock_hook,
mock_dataset):
+ mock_hook.return_value.create_custom_training_job.return_value = (
+ None,
+ "training_id",
+ "custom_job_id",
+ )
+ op = CreateCustomTrainingJobOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ staging_bucket=STAGING_BUCKET,
+ display_name=DISPLAY_NAME,
+ script_path=PYTHON_PACKAGE,
+ args=PYTHON_PACKAGE_CMDARGS,
+ container_uri=CONTAINER_URI,
+ model_serving_container_image_uri=CONTAINER_URI,
+ requirements=[],
+ replica_count=1,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ dataset_id=TEST_DATASET_ID,
+ parent_model=VERSIONED_TEST_PARENT_MODEL,
+ )
+ with pytest.warns(AirflowProviderDeprecationWarning,
match=SYNC_DEPRECATION_WARNING):
+ op.execute(context={"ti": mock.MagicMock()})
+
mock_hook.return_value.create_custom_training_job.assert_called_once_with(
+ staging_bucket=STAGING_BUCKET,
+ display_name=DISPLAY_NAME,
+ args=PYTHON_PACKAGE_CMDARGS,
+ container_uri=CONTAINER_URI,
+ model_serving_container_image_uri=CONTAINER_URI,
+ script_path=PYTHON_PACKAGE,
+ requirements=[],
+ dataset=mock_dataset.return_value,
+ model_display_name=None,
+ replica_count=REPLICA_COUNT,
+ machine_type=MACHINE_TYPE,
+ accelerator_type=ACCELERATOR_TYPE,
+ accelerator_count=ACCELERATOR_COUNT,
+ training_fraction_split=None,
+ validation_fraction_split=None,
+ test_fraction_split=None,
+ parent_model=TEST_PARENT_MODEL,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ model_serving_container_predict_route=None,
+ model_serving_container_health_route=None,
+ model_serving_container_command=None,
+ model_serving_container_args=None,
+ model_serving_container_environment_variables=None,
+ model_serving_container_ports=None,
+ model_description=None,
+ model_instance_schema_uri=None,
+ model_parameters_schema_uri=None,
+ model_prediction_schema_uri=None,
+ labels=None,
+ training_encryption_spec_key_name=None,
+ model_encryption_spec_key_name=None,
+ # RUN
+ annotation_schema_uri=None,
+ model_labels=None,
+ base_output_dir=None,
+ service_account=None,
+ network=None,
+ bigquery_destination=None,
+ environment_variables=None,
+ boot_disk_type="pd-ssd",
+ boot_disk_size_gb=100,
+ training_filter_split=None,
+ validation_filter_split=None,
+ test_filter_split=None,
+ predefined_split_column_name=None,
+ timestamp_split_column_name=None,
+ tensorboard=None,
+ sync=True,
+ is_default_version=None,
+ model_version_aliases=None,
+ model_version_description=None,
+ )
+
@mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomTrainingJobOperator.hook"))
def test_execute_enters_deferred_state(self, mock_hook):
task = CreateCustomTrainingJobOperator(
@@ -1083,6 +1342,71 @@ class
TestVertexAICreateAutoMLForecastingTrainingJobOperator:
model_version_description=None,
)
+ @mock.patch("google.cloud.aiplatform.datasets.TimeSeriesDataset")
+ @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
+ def test_execute__parent_model_version_index_is_removed(self, mock_hook,
mock_dataset):
+
mock_hook.return_value.create_auto_ml_forecasting_training_job.return_value =
(None, "training_id")
+ op = CreateAutoMLForecastingTrainingJobOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ display_name=DISPLAY_NAME,
+ dataset_id=TEST_DATASET_ID,
+ target_column=TEST_TRAINING_TARGET_COLUMN,
+ time_column=TEST_TRAINING_TIME_COLUMN,
+
time_series_identifier_column=TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN,
+
unavailable_at_forecast_columns=TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS,
+
available_at_forecast_columns=TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS,
+ forecast_horizon=TEST_TRAINING_FORECAST_HORIZON,
+ data_granularity_unit=TEST_TRAINING_DATA_GRANULARITY_UNIT,
+ data_granularity_count=TEST_TRAINING_DATA_GRANULARITY_COUNT,
+ sync=True,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ parent_model=VERSIONED_TEST_PARENT_MODEL,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+
mock_hook.return_value.create_auto_ml_forecasting_training_job.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ display_name=DISPLAY_NAME,
+ dataset=mock_dataset.return_value,
+ target_column=TEST_TRAINING_TARGET_COLUMN,
+ time_column=TEST_TRAINING_TIME_COLUMN,
+
time_series_identifier_column=TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN,
+
unavailable_at_forecast_columns=TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS,
+
available_at_forecast_columns=TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS,
+ forecast_horizon=TEST_TRAINING_FORECAST_HORIZON,
+ data_granularity_unit=TEST_TRAINING_DATA_GRANULARITY_UNIT,
+ data_granularity_count=TEST_TRAINING_DATA_GRANULARITY_COUNT,
+ parent_model=TEST_PARENT_MODEL,
+ optimization_objective=None,
+ column_specs=None,
+ column_transformations=None,
+ labels=None,
+ training_encryption_spec_key_name=None,
+ model_encryption_spec_key_name=None,
+ training_fraction_split=None,
+ validation_fraction_split=None,
+ test_fraction_split=None,
+ predefined_split_column_name=None,
+ weight_column=None,
+ time_series_attribute_columns=None,
+ context_window=None,
+ export_evaluated_data_items=False,
+ export_evaluated_data_items_bigquery_destination_uri=None,
+ export_evaluated_data_items_override_destination=False,
+ quantiles=None,
+ validation_options=None,
+ budget_milli_node_hours=1000,
+ model_display_name=None,
+ model_labels=None,
+ sync=True,
+ is_default_version=None,
+ model_version_aliases=None,
+ model_version_description=None,
+ )
+
class TestVertexAICreateAutoMLImageTrainingJobOperator:
@mock.patch("google.cloud.aiplatform.datasets.ImageDataset")
@@ -1135,6 +1459,54 @@ class TestVertexAICreateAutoMLImageTrainingJobOperator:
model_version_description=None,
)
+ @mock.patch("google.cloud.aiplatform.datasets.ImageDataset")
+ @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
+ def test_execute__parent_model_version_index_is_removed(self, mock_hook,
mock_dataset):
+ mock_hook.return_value.create_auto_ml_image_training_job.return_value
= (None, "training_id")
+ op = CreateAutoMLImageTrainingJobOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ display_name=DISPLAY_NAME,
+ dataset_id=TEST_DATASET_ID,
+ prediction_type="classification",
+ multi_label=False,
+ model_type="CLOUD",
+ sync=True,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ parent_model=VERSIONED_TEST_PARENT_MODEL,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+
mock_hook.return_value.create_auto_ml_image_training_job.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ display_name=DISPLAY_NAME,
+ dataset=mock_dataset.return_value,
+ prediction_type="classification",
+ parent_model=TEST_PARENT_MODEL,
+ multi_label=False,
+ model_type="CLOUD",
+ base_model=None,
+ labels=None,
+ training_encryption_spec_key_name=None,
+ model_encryption_spec_key_name=None,
+ training_fraction_split=None,
+ validation_fraction_split=None,
+ test_fraction_split=None,
+ training_filter_split=None,
+ validation_filter_split=None,
+ test_filter_split=None,
+ budget_milli_node_hours=None,
+ model_display_name=None,
+ model_labels=None,
+ disable_early_stopping=False,
+ sync=True,
+ is_default_version=None,
+ model_version_aliases=None,
+ model_version_description=None,
+ )
+
class TestVertexAICreateAutoMLTabularTrainingJobOperator:
@mock.patch("google.cloud.aiplatform.datasets.TabularDataset")
@@ -1199,6 +1571,64 @@ class TestVertexAICreateAutoMLTabularTrainingJobOperator:
model_version_description=None,
)
+ @mock.patch("google.cloud.aiplatform.datasets.TabularDataset")
+ @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
+ def test_execute__parent_model_version_index_is_removed(self, mock_hook,
mock_dataset):
+ mock_hook.return_value = MagicMock(
+ **{
+ "create_auto_ml_tabular_training_job.return_value": (None,
"training_id"),
+ "get_credentials_and_project_id.return_value": ("creds",
"project_id"),
+ }
+ )
+ op = CreateAutoMLTabularTrainingJobOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ display_name=DISPLAY_NAME,
+ dataset_id=TEST_DATASET_ID,
+ target_column=None,
+ optimization_prediction_type=None,
+ sync=True,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ parent_model=VERSIONED_TEST_PARENT_MODEL,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+
mock_hook.return_value.create_auto_ml_tabular_training_job.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ display_name=DISPLAY_NAME,
+ dataset=mock_dataset.return_value,
+ parent_model=TEST_PARENT_MODEL,
+ target_column=None,
+ optimization_prediction_type=None,
+ optimization_objective=None,
+ column_specs=None,
+ column_transformations=None,
+ optimization_objective_recall_value=None,
+ optimization_objective_precision_value=None,
+ labels=None,
+ training_encryption_spec_key_name=None,
+ model_encryption_spec_key_name=None,
+ training_fraction_split=None,
+ validation_fraction_split=None,
+ test_fraction_split=None,
+ predefined_split_column_name=None,
+ timestamp_split_column_name=None,
+ weight_column=None,
+ budget_milli_node_hours=1000,
+ model_display_name=None,
+ model_labels=None,
+ disable_early_stopping=False,
+ export_evaluated_data_items=False,
+ export_evaluated_data_items_bigquery_destination_uri=None,
+ export_evaluated_data_items_override_destination=False,
+ sync=True,
+ is_default_version=None,
+ model_version_aliases=None,
+ model_version_description=None,
+ )
+
class TestVertexAICreateAutoMLTextTrainingJobOperator:
@mock.patch("google.cloud.aiplatform.datasets.TextDataset")
@@ -1248,6 +1678,51 @@ class TestVertexAICreateAutoMLTextTrainingJobOperator:
model_version_description=None,
)
+ @mock.patch("google.cloud.aiplatform.datasets.TextDataset")
+ @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
+ def test_execute__parent_model_version_index_is_removed(self, mock_hook,
mock_dataset):
+ mock_hook.return_value.create_auto_ml_text_training_job.return_value =
(None, "training_id")
+ op = CreateAutoMLTextTrainingJobOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ display_name=DISPLAY_NAME,
+ dataset_id=TEST_DATASET_ID,
+ prediction_type=None,
+ multi_label=False,
+ sentiment_max=10,
+ sync=True,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ parent_model=VERSIONED_TEST_PARENT_MODEL,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+
mock_hook.return_value.create_auto_ml_text_training_job.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ display_name=DISPLAY_NAME,
+ dataset=mock_dataset.return_value,
+ parent_model=TEST_PARENT_MODEL,
+ prediction_type=None,
+ multi_label=False,
+ sentiment_max=10,
+ labels=None,
+ training_encryption_spec_key_name=None,
+ model_encryption_spec_key_name=None,
+ training_fraction_split=None,
+ validation_fraction_split=None,
+ test_fraction_split=None,
+ training_filter_split=None,
+ validation_filter_split=None,
+ test_filter_split=None,
+ model_display_name=None,
+ model_labels=None,
+ sync=True,
+ is_default_version=None,
+ model_version_aliases=None,
+ model_version_description=None,
+ )
+
class TestVertexAICreateAutoMLVideoTrainingJobOperator:
@mock.patch("google.cloud.aiplatform.datasets.VideoDataset")
@@ -1293,6 +1768,47 @@ class TestVertexAICreateAutoMLVideoTrainingJobOperator:
model_version_description=None,
)
+ @mock.patch("google.cloud.aiplatform.datasets.VideoDataset")
+ @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
+ def test_execute__parent_model_version_index_is_removed(self, mock_hook,
mock_dataset):
+ mock_hook.return_value.create_auto_ml_video_training_job.return_value
= (None, "training_id")
+ op = CreateAutoMLVideoTrainingJobOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ display_name=DISPLAY_NAME,
+ dataset_id=TEST_DATASET_ID,
+ prediction_type="classification",
+ model_type="CLOUD",
+ sync=True,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ parent_model=VERSIONED_TEST_PARENT_MODEL,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+
mock_hook.return_value.create_auto_ml_video_training_job.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ display_name=DISPLAY_NAME,
+ dataset=mock_dataset.return_value,
+ parent_model=TEST_PARENT_MODEL,
+ prediction_type="classification",
+ model_type="CLOUD",
+ labels=None,
+ training_encryption_spec_key_name=None,
+ model_encryption_spec_key_name=None,
+ training_fraction_split=None,
+ test_fraction_split=None,
+ training_filter_split=None,
+ test_filter_split=None,
+ model_display_name=None,
+ model_labels=None,
+ sync=True,
+ is_default_version=None,
+ model_version_aliases=None,
+ model_version_description=None,
+ )
+
class TestVertexAIDeleteAutoMLTrainingJobOperator:
@mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
diff --git
a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py
b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py
index 702e6a6c51..c90c1aac23 100644
---
a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py
+++
b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py
@@ -17,9 +17,7 @@
# under the License.
-"""
-Example Airflow DAG for Google Vertex AI service testing Custom Jobs
operations.
-"""
+"""Example Airflow DAG for Google Vertex AI service testing Custom Jobs
operations."""
from __future__ import annotations
@@ -49,15 +47,13 @@ from airflow.providers.google.cloud.transfers.gcs_to_local
import GCSToLocalFile
from airflow.utils.trigger_rule import TriggerRule
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
-DAG_ID = "example_vertex_ai_custom_job_operations"
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
REGION = "us-central1"
CUSTOM_DISPLAY_NAME = f"train-housing-custom-{ENV_ID}"
MODEL_DISPLAY_NAME = f"custom-housing-model-{ENV_ID}"
-
+DAG_ID = "vertex_ai_custom_job_operations"
RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"
CUSTOM_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
-
DATA_SAMPLE_GCS_OBJECT_NAME = "vertex-ai/california_housing_train.csv"
@@ -133,6 +129,7 @@ with DAG(
region=REGION,
project_id=PROJECT_ID,
)
+
model_id_v1 = create_custom_training_job.output["model_id"]
# [END how_to_cloud_vertex_ai_create_custom_training_job_operator]
@@ -140,7 +137,7 @@ with DAG(
create_custom_training_job_deferrable = CreateCustomTrainingJobOperator(
task_id="custom_task_deferrable",
staging_bucket=f"gs://{CUSTOM_GCS_BUCKET_NAME}",
- display_name=f"{CUSTOM_DISPLAY_NAME}_DEF",
+ display_name=f"{CUSTOM_DISPLAY_NAME}-def",
script_path=LOCAL_TRAINING_SCRIPT_PATH,
container_uri=CONTAINER_URI,
requirements=["gcsfs==0.7.1"],
@@ -148,12 +145,12 @@ with DAG(
# run params
dataset_id=tabular_dataset_id,
replica_count=REPLICA_COUNT,
- model_display_name=f"{MODEL_DISPLAY_NAME}_DEF",
+ model_display_name=f"{MODEL_DISPLAY_NAME}-def",
region=REGION,
project_id=PROJECT_ID,
deferrable=True,
)
- model_id_v1_deferrable = create_custom_training_job.output["model_id"]
+ model_id_deferrable_v1 = create_custom_training_job_deferrable.output["model_id"]
# [END
how_to_cloud_vertex_ai_create_custom_training_job_operator_deferrable]
# [START how_to_cloud_vertex_ai_create_custom_training_job_v2_operator]
@@ -176,26 +173,26 @@ with DAG(
)
# [END how_to_cloud_vertex_ai_create_custom_training_job_v2_operator]
- # [START
how_to_cloud_vertex_ai_create_custom_training_job_v2_operator_deferrable]
- create_custom_training_job_v2_deferrable = CreateCustomTrainingJobOperator(
- task_id="custom_task_v2_deferrable",
+ # [START
how_to_cloud_vertex_ai_create_custom_training_job_v2_deferrable_operator]
+ create_custom_training_job_deferrable_v2 = CreateCustomTrainingJobOperator(
+ task_id="custom_task_deferrable_v2",
staging_bucket=f"gs://{CUSTOM_GCS_BUCKET_NAME}",
- display_name=f"{CUSTOM_DISPLAY_NAME}_DEF",
+ display_name=f"{CUSTOM_DISPLAY_NAME}-def",
script_path=LOCAL_TRAINING_SCRIPT_PATH,
container_uri=CONTAINER_URI,
requirements=["gcsfs==0.7.1"],
model_serving_container_image_uri=MODEL_SERVING_CONTAINER_URI,
- parent_model=model_id_v1,
+ parent_model=model_id_deferrable_v1,
# run params
dataset_id=tabular_dataset_id,
replica_count=REPLICA_COUNT,
- model_display_name=f"{MODEL_DISPLAY_NAME}_DEF",
+ model_display_name=f"{MODEL_DISPLAY_NAME}-def",
sync=False,
region=REGION,
project_id=PROJECT_ID,
deferrable=True,
)
- # [END
how_to_cloud_vertex_ai_create_custom_training_job_v2_operator_deferrable]
+ # [END
how_to_cloud_vertex_ai_create_custom_training_job_v2_deferrable_operator]
# [START how_to_cloud_vertex_ai_delete_custom_training_job_operator]
delete_custom_training_job = DeleteCustomTrainingJobOperator(
@@ -208,6 +205,15 @@ with DAG(
)
# [END how_to_cloud_vertex_ai_delete_custom_training_job_operator]
+ delete_custom_training_job_deferrable = DeleteCustomTrainingJobOperator(
+ task_id="delete_custom_training_job_deferrable",
+ training_pipeline_id="{{
task_instance.xcom_pull(task_ids='custom_task_deferrable', key='training_id')
}}",
+ custom_job_id="{{
task_instance.xcom_pull(task_ids='custom_task_deferrable', key='custom_job_id')
}}",
+ region=REGION,
+ project_id=PROJECT_ID,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
delete_tabular_dataset = DeleteDatasetOperator(
task_id="delete_tabular_dataset",
dataset_id=tabular_dataset_id,
@@ -230,15 +236,14 @@ with DAG(
create_tabular_dataset,
# TEST BODY
[create_custom_training_job,
create_custom_training_job_deferrable],
- [create_custom_training_job_v2,
create_custom_training_job_v2_deferrable],
+ [create_custom_training_job_v2,
create_custom_training_job_deferrable_v2],
# TEST TEARDOWN
- delete_custom_training_job,
+ [delete_custom_training_job,
delete_custom_training_job_deferrable],
delete_tabular_dataset,
delete_bucket,
)
)
-
from tests.system.utils import get_test_run # noqa: E402
# Needed to run the example DAG with pytest (see:
tests/system/README.md#run_via_pytest)