This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e2b8f68ddf add generation_config and safety_settings to google cloud multimodal model operators (#40126)
e2b8f68ddf is described below

commit e2b8f68ddf84bf55f350461aaf1c801e9a152ac1
Author: Christian Yarros <[email protected]>
AuthorDate: Fri Jun 14 10:13:50 2024 -0400

    add generation_config and safety_settings to google cloud multimodal model operators (#40126)
---
 .../cloud/hooks/vertex_ai/generative_model.py      | 18 ++++++++++++---
 .../cloud/operators/vertex_ai/generative_model.py  | 16 +++++++++++++
 airflow/providers/google/provider.yaml             |  2 +-
 generated/provider_dependencies.json               |  2 +-
 .../cloud/hooks/vertex_ai/test_generative_model.py | 27 ++++++++++++++++++++--
 .../operators/vertex_ai/test_generative_model.py   | 24 +++++++++++++++++++
 .../example_vertex_ai_generative_model.py          | 13 +++++++++++
 7 files changed, 95 insertions(+), 7 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py 
b/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
index c71759c890..eb3db0f3c6 100644
--- a/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
+++ b/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
@@ -141,6 +141,8 @@ class GenerativeModelHook(GoogleBaseHook):
         self,
         prompt: str,
         location: str,
+        generation_config: dict | None = None,
+        safety_settings: dict | None = None,
         pretrained_model: str = "gemini-pro",
         project_id: str = PROVIDE_PROJECT_ID,
     ) -> str:
@@ -149,17 +151,21 @@ class GenerativeModelHook(GoogleBaseHook):
 
         :param prompt: Required. Inputs or queries that a user or a program 
gives
             to the Multi-modal model, in order to elicit a specific response.
+        :param location: Required. The ID of the Google Cloud location that 
the service belongs to.
+        :param generation_config: Optional. Generation configuration settings.
+        :param safety_settings: Optional. Per request settings for blocking 
unsafe content.
         :param pretrained_model: By default uses the pre-trained model 
`gemini-pro`,
             supporting prompts with text-only input, including natural language
             tasks, multi-turn text and code chat, and code generation. It can
             output text and code.
-        :param location: Required. The ID of the Google Cloud location that 
the service belongs to.
         :param project_id: Required. The ID of the Google Cloud project that 
the service belongs to.
         """
         vertexai.init(project=project_id, location=location, 
credentials=self.get_credentials())
 
         model = self.get_generative_model(pretrained_model)
-        response = model.generate_content(prompt)
+        response = model.generate_content(
+            contents=[prompt], generation_config=generation_config, 
safety_settings=safety_settings
+        )
 
         return response.text
 
@@ -170,6 +176,8 @@ class GenerativeModelHook(GoogleBaseHook):
         location: str,
         media_gcs_path: str,
         mime_type: str,
+        generation_config: dict | None = None,
+        safety_settings: dict | None = None,
         pretrained_model: str = "gemini-pro-vision",
         project_id: str = PROVIDE_PROJECT_ID,
     ) -> str:
@@ -178,6 +186,8 @@ class GenerativeModelHook(GoogleBaseHook):
 
         :param prompt: Required. Inputs or queries that a user or a program 
gives
             to the Multi-modal model, in order to elicit a specific response.
+        :param generation_config: Optional. Generation configuration settings.
+        :param safety_settings: Optional. Per request settings for blocking 
unsafe content.
         :param pretrained_model: By default uses the pre-trained model 
`gemini-pro-vision`,
             supporting prompts with text-only input, including natural language
             tasks, multi-turn text and code chat, and code generation. It can
@@ -192,6 +202,8 @@ class GenerativeModelHook(GoogleBaseHook):
 
         model = self.get_generative_model(pretrained_model)
         part = self.get_generative_model_part(media_gcs_path, mime_type)
-        response = model.generate_content([prompt, part])
+        response = model.generate_content(
+            contents=[prompt, part], generation_config=generation_config, 
safety_settings=safety_settings
+        )
 
         return response.text
diff --git 
a/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py 
b/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
index da1436a6ab..a42b00c677 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
@@ -187,6 +187,8 @@ class 
PromptMultimodalModelOperator(GoogleCloudBaseOperator):
         service belongs to (templated).
     :param prompt: Required. Inputs or queries that a user or a program gives
         to the Multi-modal model, in order to elicit a specific response 
(templated).
+    :param generation_config: Optional. Generation configuration settings.
+    :param safety_settings: Optional. Per request settings for blocking unsafe 
content.
     :param pretrained_model: By default uses the pre-trained model 
`gemini-pro`,
         supporting prompts with text-only input, including natural language
         tasks, multi-turn text and code chat, and code generation. It can
@@ -210,6 +212,8 @@ class 
PromptMultimodalModelOperator(GoogleCloudBaseOperator):
         project_id: str,
         location: str,
         prompt: str,
+        generation_config: dict | None = None,
+        safety_settings: dict | None = None,
         pretrained_model: str = "gemini-pro",
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
@@ -219,6 +223,8 @@ class 
PromptMultimodalModelOperator(GoogleCloudBaseOperator):
         self.project_id = project_id
         self.location = location
         self.prompt = prompt
+        self.generation_config = generation_config
+        self.safety_settings = safety_settings
         self.pretrained_model = pretrained_model
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
@@ -232,6 +238,8 @@ class 
PromptMultimodalModelOperator(GoogleCloudBaseOperator):
             project_id=self.project_id,
             location=self.location,
             prompt=self.prompt,
+            generation_config=self.generation_config,
+            safety_settings=self.safety_settings,
             pretrained_model=self.pretrained_model,
         )
 
@@ -251,6 +259,8 @@ class 
PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
         service belongs to (templated).
     :param prompt: Required. Inputs or queries that a user or a program gives
         to the Multi-modal model, in order to elicit a specific response 
(templated).
+    :param generation_config: Optional. Generation configuration settings.
+    :param safety_settings: Optional. Per request settings for blocking unsafe 
content.
     :param pretrained_model: By default uses the pre-trained model 
`gemini-pro-vision`,
         supporting prompts with text-only input, including natural language
         tasks, multi-turn text and code chat, and code generation. It can
@@ -279,6 +289,8 @@ class 
PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
         prompt: str,
         media_gcs_path: str,
         mime_type: str,
+        generation_config: dict | None = None,
+        safety_settings: dict | None = None,
         pretrained_model: str = "gemini-pro-vision",
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
@@ -288,6 +300,8 @@ class 
PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
         self.project_id = project_id
         self.location = location
         self.prompt = prompt
+        self.generation_config = generation_config
+        self.safety_settings = safety_settings
         self.pretrained_model = pretrained_model
         self.media_gcs_path = media_gcs_path
         self.mime_type = mime_type
@@ -303,6 +317,8 @@ class 
PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
             project_id=self.project_id,
             location=self.location,
             prompt=self.prompt,
+            generation_config=self.generation_config,
+            safety_settings=self.safety_settings,
             pretrained_model=self.pretrained_model,
             media_gcs_path=self.media_gcs_path,
             mime_type=self.mime_type,
diff --git a/airflow/providers/google/provider.yaml 
b/airflow/providers/google/provider.yaml
index 205e9f398e..376dcb92f2 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -113,7 +113,7 @@ dependencies:
   - google-api-python-client>=2.0.2
   - google-auth>=2.29.0
   - google-auth-httplib2>=0.0.1
-  - google-cloud-aiplatform>=1.42.1
+  - google-cloud-aiplatform>=1.54.0
   - google-cloud-automl>=2.12.0
   # google-cloud-bigquery version 3.21.0 introduced a performance enhancement 
in QueryJob.result(),
   # which has led to backward compatibility issues
diff --git a/generated/provider_dependencies.json 
b/generated/provider_dependencies.json
index 6c7ca749f7..6fdf5d9c29 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -593,7 +593,7 @@
       "google-api-python-client>=2.0.2",
       "google-auth-httplib2>=0.0.1",
       "google-auth>=2.29.0",
-      "google-cloud-aiplatform>=1.42.1",
+      "google-cloud-aiplatform>=1.54.0",
       "google-cloud-automl>=2.12.0",
       "google-cloud-batch>=0.13.0",
       "google-cloud-bigquery-datatransfer>=3.13.0",
diff --git 
a/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py 
b/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
index 1899222174..2308903485 100644
--- a/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
+++ b/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
@@ -23,6 +23,8 @@ import pytest
 
 # For no Pydantic environment, we need to skip the tests
 pytest.importorskip("google.cloud.aiplatform_v1")
+vertexai = pytest.importorskip("vertexai.generative_models")
+from vertexai.generative_models import HarmBlockThreshold, HarmCategory
 
 from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import (
     GenerativeModelHook,
@@ -45,6 +47,17 @@ TEST_TOP_K = 40
 TEST_TEXT_EMBEDDING_MODEL = ""
 
 TEST_MULTIMODAL_PRETRAINED_MODEL = "gemini-pro"
+TEST_SAFETY_SETTINGS = {
+    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+}
+TEST_GENERATION_CONFIG = {
+    "max_output_tokens": TEST_MAX_OUTPUT_TOKENS,
+    "top_p": TEST_TOP_P,
+    "temperature": TEST_TEMPERATURE,
+}
 
 TEST_MULTIMODAL_VISION_MODEL = "gemini-pro-vision"
 TEST_VISION_PROMPT = "In 10 words or less, describe this content."
@@ -104,10 +117,16 @@ class TestGenerativeModelWithDefaultProjectIdHook:
             project_id=GCP_PROJECT,
             location=GCP_LOCATION,
             prompt=TEST_PROMPT,
+            generation_config=TEST_GENERATION_CONFIG,
+            safety_settings=TEST_SAFETY_SETTINGS,
             pretrained_model=TEST_MULTIMODAL_PRETRAINED_MODEL,
         )
         mock_model.assert_called_once_with(TEST_MULTIMODAL_PRETRAINED_MODEL)
-        
mock_model.return_value.generate_content.assert_called_once_with(TEST_PROMPT)
+        mock_model.return_value.generate_content.assert_called_once_with(
+            contents=[TEST_PROMPT],
+            generation_config=TEST_GENERATION_CONFIG,
+            safety_settings=TEST_SAFETY_SETTINGS,
+        )
 
     
@mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_generative_model_part"))
     
@mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_generative_model"))
@@ -116,6 +135,8 @@ class TestGenerativeModelWithDefaultProjectIdHook:
             project_id=GCP_PROJECT,
             location=GCP_LOCATION,
             prompt=TEST_VISION_PROMPT,
+            generation_config=TEST_GENERATION_CONFIG,
+            safety_settings=TEST_SAFETY_SETTINGS,
             pretrained_model=TEST_MULTIMODAL_VISION_MODEL,
             media_gcs_path=TEST_MEDIA_GCS_PATH,
             mime_type=TEST_MIME_TYPE,
@@ -124,5 +145,7 @@ class TestGenerativeModelWithDefaultProjectIdHook:
         mock_part.assert_called_once_with(TEST_MEDIA_GCS_PATH, TEST_MIME_TYPE)
 
         mock_model.return_value.generate_content.assert_called_once_with(
-            [TEST_VISION_PROMPT, mock_part.return_value]
+            contents=[TEST_VISION_PROMPT, mock_part.return_value],
+            generation_config=TEST_GENERATION_CONFIG,
+            safety_settings=TEST_SAFETY_SETTINGS,
         )
diff --git 
a/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py 
b/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
index c9c23019f5..a5afd8ac27 100644
--- a/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
+++ b/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
@@ -22,6 +22,8 @@ import pytest
 
 # For no Pydantic environment, we need to skip the tests
 pytest.importorskip("google.cloud.aiplatform_v1")
+vertexai = pytest.importorskip("vertexai.generative_models")
+from vertexai.generative_models import HarmBlockThreshold, HarmCategory
 
 from airflow.providers.google.cloud.operators.vertex_ai.generative_model 
import (
     GenerateTextEmbeddingsOperator,
@@ -112,12 +114,21 @@ class TestVertexAIPromptMultimodalModelOperator:
     def test_execute(self, mock_hook):
         prompt = "In 10 words or less, what is Apache Airflow?"
         pretrained_model = "gemini-pro"
+        safety_settings = {
+            HarmCategory.HARM_CATEGORY_HATE_SPEECH: 
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            HarmCategory.HARM_CATEGORY_HARASSMENT: 
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+        }
+        generation_config = {"max_output_tokens": 256, "top_p": 0.8, 
"temperature": 0.0}
 
         op = PromptMultimodalModelOperator(
             task_id=TASK_ID,
             project_id=GCP_PROJECT,
             location=GCP_LOCATION,
             prompt=prompt,
+            generation_config=generation_config,
+            safety_settings=safety_settings,
             pretrained_model=pretrained_model,
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=IMPERSONATION_CHAIN,
@@ -131,6 +142,8 @@ class TestVertexAIPromptMultimodalModelOperator:
             project_id=GCP_PROJECT,
             location=GCP_LOCATION,
             prompt=prompt,
+            generation_config=generation_config,
+            safety_settings=safety_settings,
             pretrained_model=pretrained_model,
         )
 
@@ -142,12 +155,21 @@ class TestVertexAIPromptMultimodalModelWithMediaOperator:
         vision_prompt = "In 10 words or less, describe this content."
         media_gcs_path = 
"gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg"
         mime_type = "image/jpeg"
+        safety_settings = {
+            HarmCategory.HARM_CATEGORY_HATE_SPEECH: 
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            HarmCategory.HARM_CATEGORY_HARASSMENT: 
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+        }
+        generation_config = {"max_output_tokens": 256, "top_p": 0.8, 
"temperature": 0.0}
 
         op = PromptMultimodalModelWithMediaOperator(
             task_id=TASK_ID,
             project_id=GCP_PROJECT,
             location=GCP_LOCATION,
             prompt=vision_prompt,
+            generation_config=generation_config,
+            safety_settings=safety_settings,
             pretrained_model=pretrained_model,
             media_gcs_path=media_gcs_path,
             mime_type=mime_type,
@@ -163,6 +185,8 @@ class TestVertexAIPromptMultimodalModelWithMediaOperator:
             project_id=GCP_PROJECT,
             location=GCP_LOCATION,
             prompt=vision_prompt,
+            generation_config=generation_config,
+            safety_settings=safety_settings,
             pretrained_model=pretrained_model,
             media_gcs_path=media_gcs_path,
             mime_type=mime_type,
diff --git 
a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
 
b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
index 101cedaf7e..95c141c7af 100644
--- 
a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
+++ 
b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
@@ -25,6 +25,8 @@ from __future__ import annotations
 import os
 from datetime import datetime
 
+from vertexai.generative_models import HarmBlockThreshold, HarmCategory
+
 from airflow.models.dag import DAG
 from airflow.providers.google.cloud.operators.vertex_ai.generative_model 
import (
     GenerateTextEmbeddingsOperator,
@@ -44,6 +46,13 @@ MULTIMODAL_VISION_MODEL = "gemini-pro-vision"
 VISION_PROMPT = "In 10 words or less, describe this content."
 MEDIA_GCS_PATH = 
"gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg"
 MIME_TYPE = "image/jpeg"
+GENERATION_CONFIG = {"max_output_tokens": 256, "top_p": 0.95, "temperature": 
0.0}
+SAFETY_SETTINGS = {
+    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+}
 
 with DAG(
     dag_id=DAG_ID,
@@ -79,6 +88,8 @@ with DAG(
         project_id=PROJECT_ID,
         location=REGION,
         prompt=PROMPT,
+        generation_config=GENERATION_CONFIG,
+        safety_settings=SAFETY_SETTINGS,
         pretrained_model=MULTIMODAL_MODEL,
     )
     # [END how_to_cloud_vertex_ai_prompt_multimodal_model_operator]
@@ -89,6 +100,8 @@ with DAG(
         project_id=PROJECT_ID,
         location=REGION,
         prompt=VISION_PROMPT,
+        generation_config=GENERATION_CONFIG,
+        safety_settings=SAFETY_SETTINGS,
         pretrained_model=MULTIMODAL_VISION_MODEL,
         media_gcs_path=MEDIA_GCS_PATH,
         mime_type=MIME_TYPE,

Reply via email to