This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e2b8f68ddf add generation_config and safety_settings to google cloud
multimodal model operators (#40126)
e2b8f68ddf is described below
commit e2b8f68ddf84bf55f350461aaf1c801e9a152ac1
Author: Christian Yarros <[email protected]>
AuthorDate: Fri Jun 14 10:13:50 2024 -0400
add generation_config and safety_settings to google cloud multimodal model
operators (#40126)
---
.../cloud/hooks/vertex_ai/generative_model.py | 18 ++++++++++++---
.../cloud/operators/vertex_ai/generative_model.py | 16 +++++++++++++
airflow/providers/google/provider.yaml | 2 +-
generated/provider_dependencies.json | 2 +-
.../cloud/hooks/vertex_ai/test_generative_model.py | 27 ++++++++++++++++++++--
.../operators/vertex_ai/test_generative_model.py | 24 +++++++++++++++++++
.../example_vertex_ai_generative_model.py | 13 +++++++++++
7 files changed, 95 insertions(+), 7 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
b/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
index c71759c890..eb3db0f3c6 100644
--- a/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
+++ b/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
@@ -141,6 +141,8 @@ class GenerativeModelHook(GoogleBaseHook):
self,
prompt: str,
location: str,
+ generation_config: dict | None = None,
+ safety_settings: dict | None = None,
pretrained_model: str = "gemini-pro",
project_id: str = PROVIDE_PROJECT_ID,
) -> str:
@@ -149,17 +151,21 @@ class GenerativeModelHook(GoogleBaseHook):
:param prompt: Required. Inputs or queries that a user or a program
gives
to the Multi-modal model, in order to elicit a specific response.
+ :param location: Required. The ID of the Google Cloud location that
the service belongs to.
+ :param generation_config: Optional. Generation configuration settings.
+ :param safety_settings: Optional. Per-request settings for blocking
unsafe content.
:param pretrained_model: By default uses the pre-trained model
`gemini-pro`,
supporting prompts with text-only input, including natural language
tasks, multi-turn text and code chat, and code generation. It can
output text and code.
- :param location: Required. The ID of the Google Cloud location that
the service belongs to.
:param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
"""
vertexai.init(project=project_id, location=location,
credentials=self.get_credentials())
model = self.get_generative_model(pretrained_model)
- response = model.generate_content(prompt)
+ response = model.generate_content(
+ contents=[prompt], generation_config=generation_config,
safety_settings=safety_settings
+ )
return response.text
@@ -170,6 +176,8 @@ class GenerativeModelHook(GoogleBaseHook):
location: str,
media_gcs_path: str,
mime_type: str,
+ generation_config: dict | None = None,
+ safety_settings: dict | None = None,
pretrained_model: str = "gemini-pro-vision",
project_id: str = PROVIDE_PROJECT_ID,
) -> str:
@@ -178,6 +186,8 @@ class GenerativeModelHook(GoogleBaseHook):
:param prompt: Required. Inputs or queries that a user or a program
gives
to the Multi-modal model, in order to elicit a specific response.
+ :param generation_config: Optional. Generation configuration settings.
+ :param safety_settings: Optional. Per-request settings for blocking
unsafe content.
:param pretrained_model: By default uses the pre-trained model
`gemini-pro-vision`,
supporting prompts with text-only input, including natural language
tasks, multi-turn text and code chat, and code generation. It can
@@ -192,6 +202,8 @@ class GenerativeModelHook(GoogleBaseHook):
model = self.get_generative_model(pretrained_model)
part = self.get_generative_model_part(media_gcs_path, mime_type)
- response = model.generate_content([prompt, part])
+ response = model.generate_content(
+ contents=[prompt, part], generation_config=generation_config,
safety_settings=safety_settings
+ )
return response.text
diff --git
a/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
b/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
index da1436a6ab..a42b00c677 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
@@ -187,6 +187,8 @@ class
PromptMultimodalModelOperator(GoogleCloudBaseOperator):
service belongs to (templated).
:param prompt: Required. Inputs or queries that a user or a program gives
to the Multi-modal model, in order to elicit a specific response
(templated).
+ :param generation_config: Optional. Generation configuration settings.
+ :param safety_settings: Optional. Per-request settings for blocking unsafe
content.
:param pretrained_model: By default uses the pre-trained model
`gemini-pro`,
supporting prompts with text-only input, including natural language
tasks, multi-turn text and code chat, and code generation. It can
@@ -210,6 +212,8 @@ class
PromptMultimodalModelOperator(GoogleCloudBaseOperator):
project_id: str,
location: str,
prompt: str,
+ generation_config: dict | None = None,
+ safety_settings: dict | None = None,
pretrained_model: str = "gemini-pro",
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
@@ -219,6 +223,8 @@ class
PromptMultimodalModelOperator(GoogleCloudBaseOperator):
self.project_id = project_id
self.location = location
self.prompt = prompt
+ self.generation_config = generation_config
+ self.safety_settings = safety_settings
self.pretrained_model = pretrained_model
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
@@ -232,6 +238,8 @@ class
PromptMultimodalModelOperator(GoogleCloudBaseOperator):
project_id=self.project_id,
location=self.location,
prompt=self.prompt,
+ generation_config=self.generation_config,
+ safety_settings=self.safety_settings,
pretrained_model=self.pretrained_model,
)
@@ -251,6 +259,8 @@ class
PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
service belongs to (templated).
:param prompt: Required. Inputs or queries that a user or a program gives
to the Multi-modal model, in order to elicit a specific response
(templated).
+ :param generation_config: Optional. Generation configuration settings.
+ :param safety_settings: Optional. Per-request settings for blocking unsafe
content.
:param pretrained_model: By default uses the pre-trained model
`gemini-pro-vision`,
supporting prompts with text-only input, including natural language
tasks, multi-turn text and code chat, and code generation. It can
@@ -279,6 +289,8 @@ class
PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
prompt: str,
media_gcs_path: str,
mime_type: str,
+ generation_config: dict | None = None,
+ safety_settings: dict | None = None,
pretrained_model: str = "gemini-pro-vision",
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
@@ -288,6 +300,8 @@ class
PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
self.project_id = project_id
self.location = location
self.prompt = prompt
+ self.generation_config = generation_config
+ self.safety_settings = safety_settings
self.pretrained_model = pretrained_model
self.media_gcs_path = media_gcs_path
self.mime_type = mime_type
@@ -303,6 +317,8 @@ class
PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
project_id=self.project_id,
location=self.location,
prompt=self.prompt,
+ generation_config=self.generation_config,
+ safety_settings=self.safety_settings,
pretrained_model=self.pretrained_model,
media_gcs_path=self.media_gcs_path,
mime_type=self.mime_type,
diff --git a/airflow/providers/google/provider.yaml
b/airflow/providers/google/provider.yaml
index 205e9f398e..376dcb92f2 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -113,7 +113,7 @@ dependencies:
- google-api-python-client>=2.0.2
- google-auth>=2.29.0
- google-auth-httplib2>=0.0.1
- - google-cloud-aiplatform>=1.42.1
+ - google-cloud-aiplatform>=1.54.0
- google-cloud-automl>=2.12.0
# google-cloud-bigquery version 3.21.0 introduced a performance enhancement
in QueryJob.result(),
# which has led to backward compatibility issues
diff --git a/generated/provider_dependencies.json
b/generated/provider_dependencies.json
index 6c7ca749f7..6fdf5d9c29 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -593,7 +593,7 @@
"google-api-python-client>=2.0.2",
"google-auth-httplib2>=0.0.1",
"google-auth>=2.29.0",
- "google-cloud-aiplatform>=1.42.1",
+ "google-cloud-aiplatform>=1.54.0",
"google-cloud-automl>=2.12.0",
"google-cloud-batch>=0.13.0",
"google-cloud-bigquery-datatransfer>=3.13.0",
diff --git
a/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
b/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
index 1899222174..2308903485 100644
--- a/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
+++ b/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
@@ -23,6 +23,8 @@ import pytest
# For no Pydantic environment, we need to skip the tests
pytest.importorskip("google.cloud.aiplatform_v1")
+vertexai = pytest.importorskip("vertexai.generative_models")
+from vertexai.generative_models import HarmBlockThreshold, HarmCategory
from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import (
GenerativeModelHook,
@@ -45,6 +47,17 @@ TEST_TOP_K = 40
TEST_TEXT_EMBEDDING_MODEL = ""
TEST_MULTIMODAL_PRETRAINED_MODEL = "gemini-pro"
+TEST_SAFETY_SETTINGS = {
+ HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT:
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+ HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT:
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+ HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+}
+TEST_GENERATION_CONFIG = {
+ "max_output_tokens": TEST_MAX_OUTPUT_TOKENS,
+ "top_p": TEST_TOP_P,
+ "temperature": TEST_TEMPERATURE,
+}
TEST_MULTIMODAL_VISION_MODEL = "gemini-pro-vision"
TEST_VISION_PROMPT = "In 10 words or less, describe this content."
@@ -104,10 +117,16 @@ class TestGenerativeModelWithDefaultProjectIdHook:
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=TEST_PROMPT,
+ generation_config=TEST_GENERATION_CONFIG,
+ safety_settings=TEST_SAFETY_SETTINGS,
pretrained_model=TEST_MULTIMODAL_PRETRAINED_MODEL,
)
mock_model.assert_called_once_with(TEST_MULTIMODAL_PRETRAINED_MODEL)
-
mock_model.return_value.generate_content.assert_called_once_with(TEST_PROMPT)
+ mock_model.return_value.generate_content.assert_called_once_with(
+ contents=[TEST_PROMPT],
+ generation_config=TEST_GENERATION_CONFIG,
+ safety_settings=TEST_SAFETY_SETTINGS,
+ )
@mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_generative_model_part"))
@mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_generative_model"))
@@ -116,6 +135,8 @@ class TestGenerativeModelWithDefaultProjectIdHook:
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=TEST_VISION_PROMPT,
+ generation_config=TEST_GENERATION_CONFIG,
+ safety_settings=TEST_SAFETY_SETTINGS,
pretrained_model=TEST_MULTIMODAL_VISION_MODEL,
media_gcs_path=TEST_MEDIA_GCS_PATH,
mime_type=TEST_MIME_TYPE,
@@ -124,5 +145,7 @@ class TestGenerativeModelWithDefaultProjectIdHook:
mock_part.assert_called_once_with(TEST_MEDIA_GCS_PATH, TEST_MIME_TYPE)
mock_model.return_value.generate_content.assert_called_once_with(
- [TEST_VISION_PROMPT, mock_part.return_value]
+ contents=[TEST_VISION_PROMPT, mock_part.return_value],
+ generation_config=TEST_GENERATION_CONFIG,
+ safety_settings=TEST_SAFETY_SETTINGS,
)
diff --git
a/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
b/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
index c9c23019f5..a5afd8ac27 100644
--- a/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
+++ b/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
@@ -22,6 +22,8 @@ import pytest
# For no Pydantic environment, we need to skip the tests
pytest.importorskip("google.cloud.aiplatform_v1")
+vertexai = pytest.importorskip("vertexai.generative_models")
+from vertexai.generative_models import HarmBlockThreshold, HarmCategory
from airflow.providers.google.cloud.operators.vertex_ai.generative_model
import (
GenerateTextEmbeddingsOperator,
@@ -112,12 +114,21 @@ class TestVertexAIPromptMultimodalModelOperator:
def test_execute(self, mock_hook):
prompt = "In 10 words or less, what is Apache Airflow?"
pretrained_model = "gemini-pro"
+ safety_settings = {
+ HarmCategory.HARM_CATEGORY_HATE_SPEECH:
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT:
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+ HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT:
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+ HarmCategory.HARM_CATEGORY_HARASSMENT:
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+ }
+ generation_config = {"max_output_tokens": 256, "top_p": 0.8,
"temperature": 0.0}
op = PromptMultimodalModelOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=prompt,
+ generation_config=generation_config,
+ safety_settings=safety_settings,
pretrained_model=pretrained_model,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
@@ -131,6 +142,8 @@ class TestVertexAIPromptMultimodalModelOperator:
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=prompt,
+ generation_config=generation_config,
+ safety_settings=safety_settings,
pretrained_model=pretrained_model,
)
@@ -142,12 +155,21 @@ class TestVertexAIPromptMultimodalModelWithMediaOperator:
vision_prompt = "In 10 words or less, describe this content."
media_gcs_path =
"gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg"
mime_type = "image/jpeg"
+ safety_settings = {
+ HarmCategory.HARM_CATEGORY_HATE_SPEECH:
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT:
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+ HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT:
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+ HarmCategory.HARM_CATEGORY_HARASSMENT:
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+ }
+ generation_config = {"max_output_tokens": 256, "top_p": 0.8,
"temperature": 0.0}
op = PromptMultimodalModelWithMediaOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=vision_prompt,
+ generation_config=generation_config,
+ safety_settings=safety_settings,
pretrained_model=pretrained_model,
media_gcs_path=media_gcs_path,
mime_type=mime_type,
@@ -163,6 +185,8 @@ class TestVertexAIPromptMultimodalModelWithMediaOperator:
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=vision_prompt,
+ generation_config=generation_config,
+ safety_settings=safety_settings,
pretrained_model=pretrained_model,
media_gcs_path=media_gcs_path,
mime_type=mime_type,
diff --git
a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
index 101cedaf7e..95c141c7af 100644
---
a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
+++
b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
@@ -25,6 +25,8 @@ from __future__ import annotations
import os
from datetime import datetime
+from vertexai.generative_models import HarmBlockThreshold, HarmCategory
+
from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.vertex_ai.generative_model
import (
GenerateTextEmbeddingsOperator,
@@ -44,6 +46,13 @@ MULTIMODAL_VISION_MODEL = "gemini-pro-vision"
VISION_PROMPT = "In 10 words or less, describe this content."
MEDIA_GCS_PATH =
"gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg"
MIME_TYPE = "image/jpeg"
+GENERATION_CONFIG = {"max_output_tokens": 256, "top_p": 0.95, "temperature":
0.0}
+SAFETY_SETTINGS = {
+ HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT:
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+ HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT:
HarmBlockThreshold.BLOCK_ONLY_HIGH,
+ HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+}
with DAG(
dag_id=DAG_ID,
@@ -79,6 +88,8 @@ with DAG(
project_id=PROJECT_ID,
location=REGION,
prompt=PROMPT,
+ generation_config=GENERATION_CONFIG,
+ safety_settings=SAFETY_SETTINGS,
pretrained_model=MULTIMODAL_MODEL,
)
# [END how_to_cloud_vertex_ai_prompt_multimodal_model_operator]
@@ -89,6 +100,8 @@ with DAG(
project_id=PROJECT_ID,
location=REGION,
prompt=VISION_PROMPT,
+ generation_config=GENERATION_CONFIG,
+ safety_settings=SAFETY_SETTINGS,
pretrained_model=MULTIMODAL_VISION_MODEL,
media_gcs_path=MEDIA_GCS_PATH,
mime_type=MIME_TYPE,