This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7cf54a734e Add CountTokensOperator for Google Generative AI
CountTokensAPI (#41908)
7cf54a734e is described below
commit 7cf54a734e1eefa04ab710cb2eb364d529c6b1b1
Author: Christian Yarros <[email protected]>
AuthorDate: Sun Sep 1 15:13:09 2024 -0400
Add CountTokensOperator for Google Generative AI CountTokensAPI (#41908)
* Add CountTokensOperator for Google Generative AI CountTokensAPI
* Update system test DAG with correct arguments
---
.../cloud/hooks/vertex_ai/generative_model.py | 34 +++++++++-
.../cloud/operators/vertex_ai/generative_model.py | 74 +++++++++++++++++++++-
.../operators/cloud/vertex_ai.rst | 11 ++++
.../cloud/hooks/vertex_ai/test_generative_model.py | 13 ++++
.../operators/vertex_ai/test_generative_model.py | 31 +++++++++
.../example_vertex_ai_generative_model.py | 11 ++++
6 files changed, 170 insertions(+), 4 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
b/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
index eaf306a1f0..2fe73221bf 100644
--- a/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
+++ b/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
@@ -32,7 +32,8 @@ from airflow.providers.google.common.deprecated import
deprecated
from airflow.providers.google.common.hooks.base_google import
PROVIDE_PROJECT_ID, GoogleBaseHook
if TYPE_CHECKING:
- from google.cloud.aiplatform_v1 import types
+ from google.cloud.aiplatform_v1 import types as types_v1
+ from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
class GenerativeModelHook(GoogleBaseHook):
@@ -367,7 +368,7 @@ class GenerativeModelHook(GoogleBaseHook):
adapter_size: int | None = None,
learning_rate_multiplier: float | None = None,
project_id: str = PROVIDE_PROJECT_ID,
- ) -> types.TuningJob:
+ ) -> types_v1.TuningJob:
"""
Use the Supervised Fine Tuning API to create a tuning job.
@@ -406,3 +407,32 @@ class GenerativeModelHook(GoogleBaseHook):
sft_tuning_job.refresh()
return sft_tuning_job
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def count_tokens(
+ self,
+ contents: list,
+ location: str,
+ pretrained_model: str = "gemini-pro",
+ project_id: str = PROVIDE_PROJECT_ID,
+ ) -> types_v1beta1.CountTokensResponse:
+ """
+ Use the Vertex AI Count Tokens API to calculate the number of input
tokens before sending a request to the Gemini API.
+
+ :param contents: Required. The multi-part content of a message that a
user or a program
+ gives to the generative model, in order to elicit a specific
response.
+ :param location: Required. The ID of the Google Cloud location that
the service belongs to.
+ :param pretrained_model: By default uses the pre-trained model
`gemini-pro`,
+ supporting prompts with text-only input, including natural language
+ tasks, multi-turn text and code chat, and code generation. It can
+ output text and code.
+ :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
+ """
+ vertexai.init(project=project_id, location=location,
credentials=self.get_credentials())
+
+ model = self.get_generative_model(pretrained_model)
+ response = model.count_tokens(
+ contents=contents,
+ )
+
+ return response
diff --git
a/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
b/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
index 13e76cd17c..3a49035002 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
@@ -21,7 +21,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Sequence
-from google.cloud.aiplatform_v1 import types
+from google.cloud.aiplatform_v1 import types as types_v1
+from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import
GenerativeModelHook
@@ -665,4 +666,73 @@ class
SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
self.xcom_push(context, key="tuned_model_name",
value=response.tuned_model_name)
self.xcom_push(context, key="tuned_model_endpoint_name",
value=response.tuned_model_endpoint_name)
- return types.TuningJob.to_dict(response)
+ return types_v1.TuningJob.to_dict(response)
+
+
+class CountTokensOperator(GoogleCloudBaseOperator):
+ """
+ Use the Vertex AI Count Tokens API to calculate the number of input tokens
before sending a request to the Gemini API.
+
+ :param project_id: Required. The ID of the Google Cloud project that the
+ service belongs to (templated).
+ :param contents: Required. The multi-part content of a message that a user
or a program
+ gives to the generative model, in order to elicit a specific response.
+ :param location: Required. The ID of the Google Cloud location that the
+ service belongs to (templated).
+ :param system_instruction: Optional. Instructions for the model to steer
it toward better
+ performance. For example, "Answer as concisely as possible"
+ :param pretrained_model: By default uses the pre-trained model
`gemini-pro`,
+ supporting prompts with text-only input, including natural language
+ tasks, multi-turn text and code chat, and code generation. It can
+ output text and code.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields = ("location", "project_id", "impersonation_chain",
"contents", "pretrained_model")
+
+ def __init__(
+ self,
+ *,
+ project_id: str,
+ contents: list,
+ location: str,
+ pretrained_model: str = "gemini-pro",
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.project_id = project_id
+ self.location = location
+ self.contents = contents
+ self.pretrained_model = pretrained_model
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ self.hook = GenerativeModelHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ response = self.hook.count_tokens(
+ project_id=self.project_id,
+ location=self.location,
+ contents=self.contents,
+ pretrained_model=self.pretrained_model,
+ )
+
+ self.log.info("Total tokens: %s", response.total_tokens)
+ self.log.info("Total billable characters: %s",
response.total_billable_characters)
+
+ self.xcom_push(context, key="total_tokens",
value=response.total_tokens)
+ self.xcom_push(context, key="total_billable_characters",
value=response.total_billable_characters)
+
+ return types_v1beta1.CountTokensResponse.to_dict(response)
diff --git a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
index e359cd21d2..87c7b81bc9 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
@@ -625,6 +625,17 @@ The operator returns the tuned model's endpoint name in
:ref:`XCom <concepts:xco
:start-after: [START
how_to_cloud_vertex_ai_supervised_fine_tuning_train_operator]
:end-before: [END
how_to_cloud_vertex_ai_supervised_fine_tuning_train_operator]
+
+To calculate the number of input tokens before sending a request to the
Gemini API you can use:
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.CountTokensOperator`.
+The operator returns the total tokens in :ref:`XCom <concepts:xcom>` under
``total_tokens`` key.
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_vertex_ai_count_tokens_operator]
+ :end-before: [END how_to_cloud_vertex_ai_count_tokens_operator]
+
Reference
^^^^^^^^^
diff --git
a/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
b/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
index cee720d519..5eff60b91d 100644
--- a/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
+++ b/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
@@ -217,3 +217,16 @@ class TestGenerativeModelWithDefaultProjectIdHook:
learning_rate_multiplier=None,
tuned_model_display_name=None,
)
+
+
@mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_generative_model"))
+ def test_count_tokens(self, mock_model) -> None:
+ self.hook.count_tokens(
+ project_id=GCP_PROJECT,
+ contents=TEST_CONTENTS,
+ location=GCP_LOCATION,
+ pretrained_model=TEST_MULTIMODAL_PRETRAINED_MODEL,
+ )
+ mock_model.assert_called_once_with(TEST_MULTIMODAL_PRETRAINED_MODEL)
+ mock_model.return_value.count_tokens.assert_called_once_with(
+ contents=TEST_CONTENTS,
+ )
diff --git
a/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
b/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
index eb6b6f946e..ada9d7843d 100644
--- a/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
+++ b/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
@@ -26,10 +26,12 @@ from airflow.exceptions import (
# For no Pydantic environment, we need to skip the tests
pytest.importorskip("google.cloud.aiplatform_v1")
+pytest.importorskip("google.cloud.aiplatform_v1beta1")
vertexai = pytest.importorskip("vertexai.generative_models")
from vertexai.generative_models import HarmBlockThreshold, HarmCategory, Tool,
grounding
from airflow.providers.google.cloud.operators.vertex_ai.generative_model
import (
+ CountTokensOperator,
GenerateTextEmbeddingsOperator,
GenerativeModelGenerateContentOperator,
PromptLanguageModelOperator,
@@ -417,3 +419,32 @@ class TestVertexAISupervisedFineTuningTrainOperator:
tuned_model_display_name=None,
validation_dataset=None,
)
+
+
+class TestVertexAICountTokensOperator:
+ @mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
+
@mock.patch("google.cloud.aiplatform_v1beta1.types.CountTokensResponse.to_dict")
+ def test_execute(self, to_dict_mock, mock_hook):
+ contents = ["In 10 words or less, what is Apache Airflow?"]
+ pretrained_model = "gemini-pro"
+
+ op = CountTokensOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ contents=contents,
+ pretrained_model=pretrained_model,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ mock_hook.return_value.count_tokens.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ contents=contents,
+ pretrained_model=pretrained_model,
+ )
diff --git
a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
index fdacfef2be..d28fa89c98 100644
---
a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
+++
b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
@@ -29,6 +29,7 @@ from vertexai.generative_models import HarmBlockThreshold,
HarmCategory, Tool, g
from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.vertex_ai.generative_model
import (
+ CountTokensOperator,
GenerativeModelGenerateContentOperator,
TextEmbeddingModelGetEmbeddingsOperator,
TextGenerationModelPredictOperator,
@@ -84,6 +85,16 @@ with DAG(
)
# [END how_to_cloud_vertex_ai_text_embedding_model_get_embeddings_operator]
+ # [START how_to_cloud_vertex_ai_count_tokens_operator]
+ count_tokens_task = CountTokensOperator(
+ task_id="count_tokens_task",
+ project_id=PROJECT_ID,
+ contents=CONTENTS,
+ location=REGION,
+ pretrained_model=MULTIMODAL_MODEL,
+ )
+ # [END how_to_cloud_vertex_ai_count_tokens_operator]
+
# [START how_to_cloud_vertex_ai_generative_model_generate_content_operator]
generate_content_task = GenerativeModelGenerateContentOperator(
task_id="generate_content_task",