This is an automated email from the ASF dual-hosted git repository.
shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new de894ecf4de Add new operator for Vertex AI (#51965)
de894ecf4de is described below
commit de894ecf4de89c7307276cde8e861e7a9e2dc28d
Author: Nitochkin <[email protected]>
AuthorDate: Fri Jun 20 22:23:38 2025 +0200
Add new operator for Vertex AI (#51965)
Co-authored-by: Anton Nitochkin <[email protected]>
---
.../google/docs/operators/cloud/vertex_ai.rst | 12 ++++
.../cloud/hooks/vertex_ai/generative_model.py | 30 +++++++++
.../cloud/operators/vertex_ai/generative_model.py | 73 +++++++++++++++++++++-
.../example_vertex_ai_generative_model.py | 12 ++++
.../operators/vertex_ai/test_generative_model.py | 29 +++++++++
5 files changed, 155 insertions(+), 1 deletion(-)
diff --git a/providers/google/docs/operators/cloud/vertex_ai.rst
b/providers/google/docs/operators/cloud/vertex_ai.rst
index f870623bee0..646cb4fee51 100644
--- a/providers/google/docs/operators/cloud/vertex_ai.rst
+++ b/providers/google/docs/operators/cloud/vertex_ai.rst
@@ -741,6 +741,18 @@ To update cluster you can use
:start-after: [START how_to_cloud_vertex_ai_update_ray_cluster_operator]
:end-before: [END how_to_cloud_vertex_ai_update_ray_cluster_operator]
+Interacting with experiment runs
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To delete an experiment run, you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.DeleteExperimentRunOperator`.
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_vertex_ai_delete_experiment_run_operator]
+ :end-before: [END how_to_cloud_vertex_ai_delete_experiment_run_operator]
+
Reference
^^^^^^^^^
diff --git
a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
index d6d5a6cbadc..f31e131b0a2 100644
---
a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
+++
b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
@@ -24,6 +24,7 @@ from datetime import timedelta
from typing import TYPE_CHECKING
import vertexai
+from google.cloud import aiplatform
from vertexai.generative_models import GenerativeModel
from vertexai.language_models import TextEmbeddingModel
from vertexai.preview.caching import CachedContent
@@ -359,3 +360,32 @@ class GenerativeModelHook(GoogleBaseHook):
)
return response.text
+
+
class ExperimentRunHook(GoogleBaseHook):
    """Use the Vertex AI SDK for Python to create and manage your experiment runs."""

    @GoogleBaseHook.fallback_to_default_project_id
    def delete_experiment_run(
        self,
        experiment_run_name: str,
        experiment_name: str,
        location: str,
        project_id: str = PROVIDE_PROJECT_ID,
        delete_backing_tensorboard_run: bool = False,
    ) -> None:
        """
        Delete experiment run from the experiment.

        :param experiment_run_name: Required. The specific run name or ID for this experiment.
        :param experiment_name: Required. The name of the evaluation experiment.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
            If omitted, the default project of the connection is used.
        :param delete_backing_tensorboard_run: Whether to delete the backing Vertex AI TensorBoard run
            that stores time series metrics for this run.
        """
        self.log.info("Next experiment run will be deleted: %s", experiment_run_name)
        # Hand the hook's credentials to the SDK explicitly. Without this, the SDK
        # resolves credentials on its own and the configured ``gcp_conn_id`` /
        # ``impersonation_chain`` would be silently ignored.
        experiment_run = aiplatform.ExperimentRun(
            run_name=experiment_run_name,
            experiment=experiment_name,
            project=project_id,
            location=location,
            credentials=self.get_credentials(),
        )
        experiment_run.delete(delete_backing_tensorboard_run=delete_backing_tensorboard_run)
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
index ec8d7749cd0..bf7ffa5475c 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
@@ -22,7 +22,13 @@ from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
-from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import
GenerativeModelHook
+from google.api_core import exceptions
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import (
+ ExperimentRunHook,
+ GenerativeModelHook,
+)
from airflow.providers.google.cloud.operators.cloud_base import
GoogleCloudBaseOperator
if TYPE_CHECKING:
@@ -580,3 +586,68 @@ class
GenerateFromCachedContentOperator(GoogleCloudBaseOperator):
self.log.info("Cached Content Response: %s", cached_content_text)
return cached_content_text
+
+
class DeleteExperimentRunOperator(GoogleCloudBaseOperator):
    """
    Delete an experiment run from a Vertex AI experiment.

    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
    :param location: Required. The ID of the Google Cloud location that the service belongs to.
    :param experiment_name: Required. The name of the evaluation experiment.
    :param experiment_run_name: Required. The specific run name or ID for this experiment.
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    :raises AirflowException: If the experiment run does not exist (the underlying
        ``google.api_core.exceptions.NotFound`` is kept as the exception cause).
    """

    template_fields = (
        "location",
        "project_id",
        "impersonation_chain",
        "experiment_name",
        "experiment_run_name",
    )

    def __init__(
        self,
        *,
        project_id: str,
        location: str,
        experiment_name: str,
        experiment_run_name: str,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.project_id = project_id
        self.location = location
        self.experiment_name = experiment_name
        self.experiment_run_name = experiment_run_name
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def execute(self, context: Context) -> None:
        self.hook = ExperimentRunHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )

        try:
            self.hook.delete_experiment_run(
                project_id=self.project_id,
                location=self.location,
                experiment_name=self.experiment_name,
                experiment_run_name=self.experiment_run_name,
            )
        except exceptions.NotFound as err:
            # Surface a task-level error but keep the API error as the explicit cause
            # so the original response details remain in the traceback.
            raise AirflowException(
                f"Experiment Run with name {self.experiment_run_name} not found"
            ) from err

        self.log.info("Deleted experiment run: %s", self.experiment_run_name)
diff --git
a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
index eb453b2d854..413c1f08ec0 100644
---
a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
+++
b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
@@ -33,6 +33,7 @@ from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.vertex_ai.generative_model
import (
CountTokensOperator,
CreateCachedContentOperator,
+ DeleteExperimentRunOperator,
GenerateFromCachedContentOperator,
GenerativeModelGenerateContentOperator,
RunEvaluationOperator,
@@ -238,6 +239,16 @@ with DAG(
)
# [END how_to_cloud_vertex_ai_run_evaluation_operator]
+ # [START how_to_cloud_vertex_ai_delete_experiment_run_operator]
+ delete_experiment_run = DeleteExperimentRunOperator(
+ task_id="delete_experiment_run_task",
+ project_id=PROJECT_ID,
+ location=REGION,
+ experiment_name=EXPERIMENT_NAME,
+ experiment_run_name=EXPERIMENT_RUN_NAME,
+ )
+ # [END how_to_cloud_vertex_ai_delete_experiment_run_operator]
+
# [START how_to_cloud_vertex_ai_create_cached_content_operator]
create_cached_content_task = CreateCachedContentOperator(
task_id="create_cached_content_task",
@@ -264,6 +275,7 @@ with DAG(
# [END how_to_cloud_vertex_ai_generate_from_cached_content_operator]
create_cached_content_task >> generate_from_cached_content_task
+ run_evaluation_task >> delete_experiment_run
from tests_common.test_utils.watcher import watcher
diff --git
a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_generative_model.py
b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_generative_model.py
index 2683936e7f5..fbd8fde9c0d 100644
---
a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_generative_model.py
+++
b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_generative_model.py
@@ -30,6 +30,7 @@ from vertexai.preview.evaluation import
MetricPromptTemplateExamples
from airflow.providers.google.cloud.operators.vertex_ai.generative_model
import (
CountTokensOperator,
CreateCachedContentOperator,
+ DeleteExperimentRunOperator,
GenerateFromCachedContentOperator,
GenerativeModelGenerateContentOperator,
RunEvaluationOperator,
@@ -355,3 +356,31 @@ class TestVertexAIGenerateFromCachedContentOperator:
generation_config=None,
safety_settings=None,
)
+
+
class TestVertexAIDeleteExperimentRunOperator:
    """Unit tests for ``DeleteExperimentRunOperator``."""

    @mock.patch(VERTEX_AI_PATH.format("generative_model.ExperimentRunHook"))
    def test_execute(self, mock_hook):
        test_experiment_name = "test_experiment_name"
        test_experiment_run_name = "test_experiment_run_name"

        op = DeleteExperimentRunOperator(
            task_id=TASK_ID,
            project_id=GCP_PROJECT,
            location=GCP_LOCATION,
            experiment_name=test_experiment_name,
            experiment_run_name=test_experiment_run_name,
            gcp_conn_id=GCP_CONN_ID,
            impersonation_chain=IMPERSONATION_CHAIN,
        )
        op.execute(context={"ti": mock.MagicMock()})
        mock_hook.assert_called_once_with(
            gcp_conn_id=GCP_CONN_ID,
            impersonation_chain=IMPERSONATION_CHAIN,
        )
        mock_hook.return_value.delete_experiment_run.assert_called_once_with(
            project_id=GCP_PROJECT,
            location=GCP_LOCATION,
            experiment_name=test_experiment_name,
            experiment_run_name=test_experiment_run_name,
        )

    @mock.patch(VERTEX_AI_PATH.format("generative_model.ExperimentRunHook"))
    def test_execute_not_found(self, mock_hook):
        """A missing experiment run is reported to the task as an AirflowException."""
        # Local imports: the module-level import block of this test file is not
        # guaranteed to provide these names.
        import pytest
        from google.api_core import exceptions

        from airflow.exceptions import AirflowException

        test_experiment_name = "test_experiment_name"
        test_experiment_run_name = "test_experiment_run_name"
        mock_hook.return_value.delete_experiment_run.side_effect = exceptions.NotFound("run missing")

        op = DeleteExperimentRunOperator(
            task_id=TASK_ID,
            project_id=GCP_PROJECT,
            location=GCP_LOCATION,
            experiment_name=test_experiment_name,
            experiment_run_name=test_experiment_run_name,
            gcp_conn_id=GCP_CONN_ID,
            impersonation_chain=IMPERSONATION_CHAIN,
        )
        with pytest.raises(AirflowException, match="test_experiment_run_name not found"):
            op.execute(context={"ti": mock.MagicMock()})