This is an automated email from the ASF dual-hosted git repository.

shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new de894ecf4de Add new operator for Vertex AI (#51965)
de894ecf4de is described below

commit de894ecf4de89c7307276cde8e861e7a9e2dc28d
Author: Nitochkin <[email protected]>
AuthorDate: Fri Jun 20 22:23:38 2025 +0200

    Add new operator for Vertex AI (#51965)
    
    Co-authored-by: Anton Nitochkin <[email protected]>
---
 .../google/docs/operators/cloud/vertex_ai.rst      | 12 ++++
 .../cloud/hooks/vertex_ai/generative_model.py      | 30 +++++++++
 .../cloud/operators/vertex_ai/generative_model.py  | 73 +++++++++++++++++++++-
 .../example_vertex_ai_generative_model.py          | 12 ++++
 .../operators/vertex_ai/test_generative_model.py   | 29 +++++++++
 5 files changed, 155 insertions(+), 1 deletion(-)

diff --git a/providers/google/docs/operators/cloud/vertex_ai.rst b/providers/google/docs/operators/cloud/vertex_ai.rst
index f870623bee0..646cb4fee51 100644
--- a/providers/google/docs/operators/cloud/vertex_ai.rst
+++ b/providers/google/docs/operators/cloud/vertex_ai.rst
@@ -741,6 +741,18 @@ To update cluster you can use
     :start-after: [START how_to_cloud_vertex_ai_update_ray_cluster_operator]
     :end-before: [END how_to_cloud_vertex_ai_update_ray_cluster_operator]
 
+Interacting with experiment runs
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To delete an experiment run, you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.DeleteExperimentRunOperator`.
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_delete_experiment_run_operator]
+    :end-before: [END how_to_cloud_vertex_ai_delete_experiment_run_operator]
+
 Reference
 ^^^^^^^^^
 
diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
index d6d5a6cbadc..f31e131b0a2 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
@@ -24,6 +24,7 @@ from datetime import timedelta
 from typing import TYPE_CHECKING
 
 import vertexai
+from google.cloud import aiplatform
 from vertexai.generative_models import GenerativeModel
 from vertexai.language_models import TextEmbeddingModel
 from vertexai.preview.caching import CachedContent
@@ -359,3 +360,32 @@ class GenerativeModelHook(GoogleBaseHook):
         )
 
         return response.text
+
+
+class ExperimentRunHook(GoogleBaseHook):
+    """Use the Vertex AI SDK for Python to create and manage your experiment runs."""
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def delete_experiment_run(
+        self,
+        experiment_run_name: str,
+        experiment_name: str,
+        location: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        delete_backing_tensorboard_run: bool = False,
+    ) -> None:
+        """
+        Delete an experiment run from the given experiment.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param experiment_name: Required. The name of the evaluation experiment.
+        :param experiment_run_name: Required. The specific run name or ID for this experiment.
+        :param delete_backing_tensorboard_run: Whether to delete the backing Vertex AI TensorBoard run
+            that stores time series metrics for this run.
+        """
+        self.log.info("Next experiment run will be deleted: %s", experiment_run_name)
+        experiment_run = aiplatform.ExperimentRun(
+            run_name=experiment_run_name, experiment=experiment_name, project=project_id, location=location
+        )
+        experiment_run.delete(delete_backing_tensorboard_run=delete_backing_tensorboard_run)
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
index ec8d7749cd0..bf7ffa5475c 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
@@ -22,7 +22,13 @@ from __future__ import annotations
 from collections.abc import Sequence
 from typing import TYPE_CHECKING
 
-from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook
+from google.api_core import exceptions
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import (
+    ExperimentRunHook,
+    GenerativeModelHook,
+)
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
 
 if TYPE_CHECKING:
@@ -580,3 +586,68 @@ class GenerateFromCachedContentOperator(GoogleCloudBaseOperator):
         self.log.info("Cached Content Response: %s", cached_content_text)
 
         return cached_content_text
+
+
+class DeleteExperimentRunOperator(GoogleCloudBaseOperator):
+    """
+    Delete an experiment run from a Vertex AI experiment.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param location: Required. The ID of the Google Cloud location that the service belongs to.
+    :param experiment_name: Required. The name of the evaluation experiment.
+    :param experiment_run_name: Required. The specific run name or ID for this experiment.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields = (
+        "location",
+        "project_id",
+        "impersonation_chain",
+        "experiment_name",
+        "experiment_run_name",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        experiment_name: str,
+        experiment_run_name: str,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.location = location
+        self.experiment_name = experiment_name
+        self.experiment_run_name = experiment_run_name
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context) -> None:
+        self.hook = ExperimentRunHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        try:
+            self.hook.delete_experiment_run(
+                project_id=self.project_id,
+                location=self.location,
+                experiment_name=self.experiment_name,
+                experiment_run_name=self.experiment_run_name,
+            )
+        except exceptions.NotFound:
+            raise AirflowException(f"Experiment Run with name {self.experiment_run_name} not found")
+
+        self.log.info("Deleted experiment run: %s", self.experiment_run_name)
diff --git a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
index eb453b2d854..413c1f08ec0 100644
--- a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
+++ b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
@@ -33,6 +33,7 @@ from airflow.models.dag import DAG
 from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
     CountTokensOperator,
     CreateCachedContentOperator,
+    DeleteExperimentRunOperator,
     GenerateFromCachedContentOperator,
     GenerativeModelGenerateContentOperator,
     RunEvaluationOperator,
@@ -238,6 +239,16 @@ with DAG(
     )
     # [END how_to_cloud_vertex_ai_run_evaluation_operator]
 
+    # [START how_to_cloud_vertex_ai_delete_experiment_run_operator]
+    delete_experiment_run = DeleteExperimentRunOperator(
+        task_id="delete_experiment_run_task",
+        project_id=PROJECT_ID,
+        location=REGION,
+        experiment_name=EXPERIMENT_NAME,
+        experiment_run_name=EXPERIMENT_RUN_NAME,
+    )
+    # [END how_to_cloud_vertex_ai_delete_experiment_run_operator]
+
     # [START how_to_cloud_vertex_ai_create_cached_content_operator]
     create_cached_content_task = CreateCachedContentOperator(
         task_id="create_cached_content_task",
@@ -264,6 +275,7 @@ with DAG(
     # [END how_to_cloud_vertex_ai_generate_from_cached_content_operator]
 
     create_cached_content_task >> generate_from_cached_content_task
+    run_evaluation_task >> delete_experiment_run
 
     from tests_common.test_utils.watcher import watcher
 
diff --git a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_generative_model.py b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_generative_model.py
index 2683936e7f5..fbd8fde9c0d 100644
--- a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_generative_model.py
+++ b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_generative_model.py
@@ -30,6 +30,7 @@ from vertexai.preview.evaluation import MetricPromptTemplateExamples
 from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
     CountTokensOperator,
     CreateCachedContentOperator,
+    DeleteExperimentRunOperator,
     GenerateFromCachedContentOperator,
     GenerativeModelGenerateContentOperator,
     RunEvaluationOperator,
@@ -355,3 +356,31 @@ class TestVertexAIGenerateFromCachedContentOperator:
             generation_config=None,
             safety_settings=None,
         )
+
+
+class TestVertexAIDeleteExperimentRunOperator:
+    @mock.patch(VERTEX_AI_PATH.format("generative_model.ExperimentRunHook"))
+    def test_execute(self, mock_hook):
+        test_experiment_name = "test_experiment_name"
+        test_experiment_run_name = "test_experiment_run_name"
+
+        op = DeleteExperimentRunOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            experiment_name=test_experiment_name,
+            experiment_run_name=test_experiment_run_name,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        op.execute(context={"ti": mock.MagicMock()})
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        mock_hook.return_value.delete_experiment_run.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            experiment_name=test_experiment_name,
+            experiment_run_name=test_experiment_run_name,
+        )

Reply via email to