This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7cf54a734e Add CountTokensOperator for Google Generative AI 
CountTokensAPI (#41908)
7cf54a734e is described below

commit 7cf54a734e1eefa04ab710cb2eb364d529c6b1b1
Author: Christian Yarros <[email protected]>
AuthorDate: Sun Sep 1 15:13:09 2024 -0400

    Add CountTokensOperator for Google Generative AI CountTokensAPI (#41908)
    
    * Add CountTokensOperator for Google Generative AI CountTokensAPI
    
    * Update system test DAG with correct arguments
---
 .../cloud/hooks/vertex_ai/generative_model.py      | 34 +++++++++-
 .../cloud/operators/vertex_ai/generative_model.py  | 74 +++++++++++++++++++++-
 .../operators/cloud/vertex_ai.rst                  | 11 ++++
 .../cloud/hooks/vertex_ai/test_generative_model.py | 13 ++++
 .../operators/vertex_ai/test_generative_model.py   | 31 +++++++++
 .../example_vertex_ai_generative_model.py          | 11 ++++
 6 files changed, 170 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py 
b/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
index eaf306a1f0..2fe73221bf 100644
--- a/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
+++ b/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
@@ -32,7 +32,8 @@ from airflow.providers.google.common.deprecated import 
deprecated
 from airflow.providers.google.common.hooks.base_google import 
PROVIDE_PROJECT_ID, GoogleBaseHook
 
 if TYPE_CHECKING:
-    from google.cloud.aiplatform_v1 import types
+    from google.cloud.aiplatform_v1 import types as types_v1
+    from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
 
 
 class GenerativeModelHook(GoogleBaseHook):
@@ -367,7 +368,7 @@ class GenerativeModelHook(GoogleBaseHook):
         adapter_size: int | None = None,
         learning_rate_multiplier: float | None = None,
         project_id: str = PROVIDE_PROJECT_ID,
-    ) -> types.TuningJob:
+    ) -> types_v1.TuningJob:
         """
         Use the Supervised Fine Tuning API to create a tuning job.
 
@@ -406,3 +407,32 @@ class GenerativeModelHook(GoogleBaseHook):
             sft_tuning_job.refresh()
 
         return sft_tuning_job
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def count_tokens(
+        self,
+        contents: list,
+        location: str,
+        pretrained_model: str = "gemini-pro",
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> types_v1beta1.CountTokensResponse:
+        """
+        Use the Vertex AI Count Tokens API to calculate the number of input 
tokens before sending a request to the Gemini API.
+
+        :param contents: Required. The multi-part content of a message that a 
user or a program
+            gives to the generative model, in order to elicit a specific 
response.
+        :param location: Required. The ID of the Google Cloud location that 
the service belongs to.
+        :param pretrained_model: By default uses the pre-trained model 
`gemini-pro`,
+            supporting prompts with text-only input, including natural language
+            tasks, multi-turn text and code chat, and code generation. It can
+            output text and code.
+        :param project_id: Required. The ID of the Google Cloud project that 
the service belongs to.
+        """
+        vertexai.init(project=project_id, location=location, 
credentials=self.get_credentials())
+
+        model = self.get_generative_model(pretrained_model)
+        response = model.count_tokens(
+            contents=contents,
+        )
+
+        return response
diff --git 
a/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py 
b/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
index 13e76cd17c..3a49035002 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
@@ -21,7 +21,8 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Sequence
 
-from google.cloud.aiplatform_v1 import types
+from google.cloud.aiplatform_v1 import types as types_v1
+from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import 
GenerativeModelHook
@@ -665,4 +666,73 @@ class 
SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
         self.xcom_push(context, key="tuned_model_name", 
value=response.tuned_model_name)
         self.xcom_push(context, key="tuned_model_endpoint_name", 
value=response.tuned_model_endpoint_name)
 
-        return types.TuningJob.to_dict(response)
+        return types_v1.TuningJob.to_dict(response)
+
+
+class CountTokensOperator(GoogleCloudBaseOperator):
+    """
+    Use the Vertex AI Count Tokens API to calculate the number of input tokens 
before sending a request to the Gemini API.
+
+    :param project_id: Required. The ID of the Google Cloud project that the
+        service belongs to (templated).
+    :param contents: Required. The multi-part content of a message that a user 
or a program
+        gives to the generative model, in order to elicit a specific response.
+    :param location: Required. The ID of the Google Cloud location that the
+        service belongs to (templated).
+    :param system_instruction: Optional. Instructions for the model to steer 
it toward better
+        performance. For example, "Answer as concisely as possible"
+    :param pretrained_model: By default uses the pre-trained model 
`gemini-pro`,
+        supporting prompts with text-only input, including natural language
+        tasks, multi-turn text and code chat, and code generation. It can
+        output text and code.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using 
short-term
+        credentials, or chained list of accounts required to get the 
access_token
+        of the last account in the list, which will be impersonated in the 
request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding 
identity, with first
+        account from the list granting this role to the originating account 
(templated).
+    """
+
+    template_fields = ("location", "project_id", "impersonation_chain", 
"contents", "pretrained_model")
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        contents: list,
+        location: str,
+        pretrained_model: str = "gemini-pro",
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.location = location
+        self.contents = contents
+        self.pretrained_model = pretrained_model
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context):
+        self.hook = GenerativeModelHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        response = self.hook.count_tokens(
+            project_id=self.project_id,
+            location=self.location,
+            contents=self.contents,
+            pretrained_model=self.pretrained_model,
+        )
+
+        self.log.info("Total tokens: %s", response.total_tokens)
+        self.log.info("Total billable characters: %s", 
response.total_billable_characters)
+
+        self.xcom_push(context, key="total_tokens", 
value=response.total_tokens)
+        self.xcom_push(context, key="total_billable_characters", 
value=response.total_billable_characters)
+
+        return types_v1beta1.CountTokensResponse.to_dict(response)
diff --git a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst 
b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
index e359cd21d2..87c7b81bc9 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
@@ -625,6 +625,17 @@ The operator returns the tuned model's endpoint name in 
:ref:`XCom <concepts:xco
     :start-after: [START 
how_to_cloud_vertex_ai_supervised_fine_tuning_train_operator]
     :end-before: [END 
how_to_cloud_vertex_ai_supervised_fine_tuning_train_operator]
 
+
+To calculate the number of input tokens before sending a request to the Gemini API, you can use:
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.CountTokensOperator`.
+The operator pushes the total tokens to :ref:`XCom <concepts:xcom>` under the ``total_tokens`` key and the total billable characters under the ``total_billable_characters`` key.
+
+.. exampleinclude:: 
/../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_count_tokens_operator]
+    :end-before: [END how_to_cloud_vertex_ai_count_tokens_operator]
+
 Reference
 ^^^^^^^^^
 
diff --git 
a/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py 
b/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
index cee720d519..5eff60b91d 100644
--- a/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
+++ b/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
@@ -217,3 +217,16 @@ class TestGenerativeModelWithDefaultProjectIdHook:
             learning_rate_multiplier=None,
             tuned_model_display_name=None,
         )
+
+    
@mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_generative_model"))
+    def test_count_tokens(self, mock_model) -> None:
+        self.hook.count_tokens(
+            project_id=GCP_PROJECT,
+            contents=TEST_CONTENTS,
+            location=GCP_LOCATION,
+            pretrained_model=TEST_MULTIMODAL_PRETRAINED_MODEL,
+        )
+        mock_model.assert_called_once_with(TEST_MULTIMODAL_PRETRAINED_MODEL)
+        mock_model.return_value.count_tokens.assert_called_once_with(
+            contents=TEST_CONTENTS,
+        )
diff --git 
a/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py 
b/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
index eb6b6f946e..ada9d7843d 100644
--- a/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
+++ b/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
@@ -26,10 +26,12 @@ from airflow.exceptions import (
 
 # For no Pydantic environment, we need to skip the tests
 pytest.importorskip("google.cloud.aiplatform_v1")
+pytest.importorskip("google.cloud.aiplatform_v1beta1")
 vertexai = pytest.importorskip("vertexai.generative_models")
 from vertexai.generative_models import HarmBlockThreshold, HarmCategory, Tool, 
grounding
 
 from airflow.providers.google.cloud.operators.vertex_ai.generative_model 
import (
+    CountTokensOperator,
     GenerateTextEmbeddingsOperator,
     GenerativeModelGenerateContentOperator,
     PromptLanguageModelOperator,
@@ -417,3 +419,32 @@ class TestVertexAISupervisedFineTuningTrainOperator:
             tuned_model_display_name=None,
             validation_dataset=None,
         )
+
+
+class TestVertexAICountTokensOperator:
+    @mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
+    
@mock.patch("google.cloud.aiplatform_v1beta1.types.CountTokensResponse.to_dict")
+    def test_execute(self, to_dict_mock, mock_hook):
+        contents = ["In 10 words or less, what is Apache Airflow?"]
+        pretrained_model = "gemini-pro"
+
+        op = CountTokensOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            contents=contents,
+            pretrained_model=pretrained_model,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        op.execute(context={"ti": mock.MagicMock()})
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        mock_hook.return_value.count_tokens.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            contents=contents,
+            pretrained_model=pretrained_model,
+        )
diff --git 
a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
 
b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
index fdacfef2be..d28fa89c98 100644
--- 
a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
+++ 
b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
@@ -29,6 +29,7 @@ from vertexai.generative_models import HarmBlockThreshold, 
HarmCategory, Tool, g
 
 from airflow.models.dag import DAG
 from airflow.providers.google.cloud.operators.vertex_ai.generative_model 
import (
+    CountTokensOperator,
     GenerativeModelGenerateContentOperator,
     TextEmbeddingModelGetEmbeddingsOperator,
     TextGenerationModelPredictOperator,
@@ -84,6 +85,16 @@ with DAG(
     )
     # [END how_to_cloud_vertex_ai_text_embedding_model_get_embeddings_operator]
 
+    # [START how_to_cloud_vertex_ai_count_tokens_operator]
+    count_tokens_task = CountTokensOperator(
+        task_id="count_tokens_task",
+        project_id=PROJECT_ID,
+        contents=CONTENTS,
+        location=REGION,
+        pretrained_model=MULTIMODAL_MODEL,
+    )
+    # [END how_to_cloud_vertex_ai_count_tokens_operator]
+
     # [START how_to_cloud_vertex_ai_generative_model_generate_content_operator]
     generate_content_task = GenerativeModelGenerateContentOperator(
         task_id="generate_content_task",

Reply via email to