This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9442435b87 Rename Vertex AI AutoML operators fields' names to comply with templated fields validation (#38049)
9442435b87 is described below
commit 9442435b87973e48c6e726d970ab1c8de0dd8265
Author: Shahar Epstein <[email protected]>
AuthorDate: Fri Mar 15 15:00:31 2024 +0200
Rename Vertex AI AutoML operators fields' names to comply with templated fields validation (#38049)
---
.pre-commit-config.yaml | 1 -
.../google/cloud/operators/vertex_ai/auto_ml.py | 16 +++++++++++++--
.../google/cloud/operators/test_vertex_ai.py | 24 ++++++++++++++++++++++
3 files changed, 38 insertions(+), 3 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 68eea393f1..6f5cdc50b8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -337,7 +337,6 @@ repos:
^airflow\/providers\/google\/cloud\/operators\/mlengine.py$|
^airflow\/providers\/google\/cloud\/operators\/cloud_storage_transfer_service.py$|
^airflow\/providers\/apache\/spark\/operators\/spark_submit.py\.py$|
-
^airflow\/providers\/google\/cloud\/operators\/vertex_ai\/auto_ml\.py$|
^airflow\/providers\/apache\/spark\/operators\/spark_submit\.py$|
^airflow\/providers\/databricks\/operators\/databricks_sql\.py$|
)$
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
index d7eb3e01f7..5269c4db25 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
@@ -22,12 +22,14 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Sequence
+from deprecated import deprecated
from google.api_core.exceptions import NotFound
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform.models import Model
from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline
+from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.vertex_ai.auto_ml import AutoMLHook
from airflow.providers.google.cloud.links.vertex_ai import (
VertexAIModelLink,
@@ -607,7 +609,7 @@ class DeleteAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
AutoMLTabularTrainingJob, AutoMLTextTrainingJob, or AutoMLVideoTrainingJob.
"""
- template_fields = ("training_pipeline", "region", "project_id",
"impersonation_chain")
+ template_fields = ("training_pipeline_id", "region", "project_id",
"impersonation_chain")
def __init__(
self,
@@ -623,7 +625,7 @@ class DeleteAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
**kwargs,
) -> None:
super().__init__(**kwargs)
- self.training_pipeline = training_pipeline_id
+ self.training_pipeline_id = training_pipeline_id
self.region = region
self.project_id = project_id
self.retry = retry
@@ -632,6 +634,16 @@ class DeleteAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
+ @property
+ @deprecated(
+ reason="`training_pipeline` is deprecated and will be removed in the
future. "
+ "Please use `training_pipeline_id` instead.",
+ category=AirflowProviderDeprecationWarning,
+ )
+ def training_pipeline(self):
+ """Alias for ``training_pipeline_id``, used for compatibility
(deprecated)."""
+ return self.training_pipeline_id
+
def execute(self, context: Context):
hook = AutoMLHook(
gcp_conn_id=self.gcp_conn_id,
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py
index a092cebafb..3f14525471 100644
--- a/tests/providers/google/cloud/operators/test_vertex_ai.py
+++ b/tests/providers/google/cloud/operators/test_vertex_ai.py
@@ -1028,6 +1028,30 @@ class TestVertexAIDeleteAutoMLTrainingJobOperator:
metadata=METADATA,
)
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ DeleteAutoMLTrainingJobOperator,
+ # Templated fields
+ training_pipeline_id="{{ 'training-pipeline-id' }}",
+ region="{{ 'region' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: DeleteAutoMLTrainingJobOperator = ti.task
+ assert task.training_pipeline_id == "training-pipeline-id"
+ assert task.region == "region"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
+ with pytest.warns(AirflowProviderDeprecationWarning):
+ assert task.training_pipeline == "training-pipeline-id"
+
class TestVertexAIListAutoMLTrainingJobOperator:
@mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))