This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 01022368c3d Add Operators for Gemini Batch API and Gemini Files API. (#59518)
01022368c3d is described below
commit 01022368c3d2cc974615beab9243e1541c64e5ab
Author: Nitochkin <[email protected]>
AuthorDate: Wed Jan 7 16:28:07 2026 +0100
Add Operators for Gemini Batch API and Gemini Files API. (#59518)
Co-authored-by: Anton Nitochkin <[email protected]>
---
docs/spelling_wordlist.txt | 1 +
providers/google/docs/operators/cloud/gen_ai.rst | 189 ++++-
.../airflow/providers/google/cloud/hooks/gen_ai.py | 188 +++++
.../providers/google/cloud/operators/gen_ai.py | 865 +++++++++++++++++++++
.../gen_ai/example_gen_ai_gemini_batch_api.py | 341 ++++++++
.../gemini_batch_embeddings_requests.jsonl | 2 +
.../gen_ai/resources/gemini_batch_requests.jsonl | 2 +
.../tests/unit/google/cloud/hooks/test_gen_ai.py | 184 +++++
.../unit/google/cloud/operators/test_gen_ai.py | 259 ++++++
9 files changed, 2025 insertions(+), 6 deletions(-)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index a3e10044b57..f6d0e2d5fc5 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -753,6 +753,7 @@ gdbm
gdrive
Gelf
gelf
+gemini
generateUploadUrl
Gentner
geq
diff --git a/providers/google/docs/operators/cloud/gen_ai.rst b/providers/google/docs/operators/cloud/gen_ai.rst
index f43cba500f3..1cbb4534aa1 100644
--- a/providers/google/docs/operators/cloud/gen_ai.rst
+++ b/providers/google/docs/operators/cloud/gen_ai.rst
@@ -15,8 +15,18 @@
specific language governing permissions and limitations
under the License.
-Google Cloud Generative AI on Vertex AI Operators
-=================================================
+Google Cloud Generative AI Operators
+====================================
+
+The Google Cloud Generative AI Operators ecosystem is anchored by the Gemini family of multimodal models, which
+provide interfaces for generating and processing diverse inputs like text, images, and audio. By leveraging
+these foundation models, developers can securely prompt, tune, and ground AI using their own proprietary data.
+This capability enables the construction of versatile applications, ranging from custom chatbots and code
+assistants to automated content summarization tools.
+
+
+Interacting with Generative AI on Vertex AI
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The `Google Cloud VertexAI
<https://cloud.google.com/vertex-ai/generative-ai/docs/>`__
extends Vertex AI with powerful foundation models capable of generating text,
images, and other modalities.
@@ -26,10 +36,6 @@ applications such as chat bots, content creation tools, code
assistants, and sum
With Vertex AI, you can securely integrate generative capabilities into
enterprise workflows, monitor usage,
evaluate model quality, and deploy models at scale — all within the same
managed ML platform.
-
-Interacting with Generative AI
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
To generate text embeddings you can use
:class:`~airflow.providers.google.cloud.operators.gen_ai.GenAIGenerateEmbeddingsOperator`.
The operator returns the model's response in :ref:`XCom <concepts:xcom>` under
``model_response`` key.
@@ -103,6 +109,175 @@ The operator returns the cached content response in
:ref:`XCom <concepts:xcom>`
:start-after: [START
how_to_cloud_gen_ai_generate_from_cached_content_operator]
:end-before: [END
how_to_cloud_gen_ai_generate_from_cached_content_operator]
+Interacting with Gemini Batch API
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The `Gemini Batch API <https://ai.google.dev/gemini-api/docs/batch-api>`__ is designed to process large volumes
+of requests asynchronously at 50% of the standard cost. The target turnaround time is 24 hours,
+but in the majority of cases it is much quicker. Use the Batch API for large-scale, non-urgent tasks such as
+data pre-processing or running evaluations where an immediate response is not required.
+
+Create batch job
+""""""""""""""""
+
+To create a batch job via the Batch API you can use
+:class:`~airflow.providers.google.cloud.operators.gen_ai.GenAIGeminiCreateBatchJobOperator`.
+The operator returns the job name in :ref:`XCom <concepts:xcom>` under the ``job_name`` key.
+
+Two input sources are supported: inline requests or a file name.
+
+If you use inline requests, take a look at this example:
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_gen_ai_batch_api_create_batch_job_with_inlined_requests_task]
+    :end-before: [END how_to_cloud_gen_ai_batch_api_create_batch_job_with_inlined_requests_task]
+
+If you use a file, take a look at this example:
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_gen_ai_batch_api_create_batch_job_with_file_task]
+    :end-before: [END how_to_cloud_gen_ai_batch_api_create_batch_job_with_file_task]
+
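+The file input is expected to be in JSON Lines format, one request per line. A minimal
+sketch of building such a file in Python (the ``key``/``request`` envelope follows the
+Gemini Batch API documentation; adapt the request bodies to your use case):
+
+.. code-block:: python
+
+    import json
+
+    requests = [
+        {"key": "request-1", "request": {"contents": [{"parts": [{"text": "Tell me a one-sentence joke."}], "role": "user"}]}},
+        {"key": "request-2", "request": {"contents": [{"parts": [{"text": "Why is the sky blue?"}], "role": "user"}]}},
+    ]
+    with open("gemini_batch_requests.jsonl", "w") as f:
+        for request in requests:
+            f.write(json.dumps(request) + "\n")
+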
+Get batch job
+"""""""""""""
+
+To get a batch job you can use
+:class:`~airflow.providers.google.cloud.operators.gen_ai.GenAIGeminiGetBatchJobOperator`.
+The operator returns the job status in :ref:`XCom <concepts:xcom>` under the ``job_status`` key.
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_gen_ai_batch_api_get_batch_job_task]
+ :end-before: [END how_to_cloud_gen_ai_batch_api_get_batch_job_task]
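+
+The pushed ``job_status`` can be consumed downstream via ``xcom_pull``; a minimal
+sketch (the task id matches the example DAG):
+
+.. code-block:: python
+
+    JOB_STATUS = "{{ task_instance.xcom_pull(task_ids='get_batch_job_task', key='job_status') }}"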
+
+List batch jobs
+"""""""""""""""
+
+To list batch jobs via the Batch API you can use
+:class:`~airflow.providers.google.cloud.operators.gen_ai.GenAIGeminiListBatchJobsOperator`.
+The operator returns the job names in :ref:`XCom <concepts:xcom>` under the ``job_names`` key.
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_gen_ai_batch_api_list_batch_jobs_task]
+ :end-before: [END how_to_cloud_gen_ai_batch_api_list_batch_jobs_task]
+
+Cancel batch job
+""""""""""""""""
+
+To cancel a batch job via the Batch API you can use
+:class:`~airflow.providers.google.cloud.operators.gen_ai.GenAIGeminiCancelBatchJobOperator`.
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_gen_ai_batch_api_cancel_batch_job_task]
+ :end-before: [END how_to_cloud_gen_ai_batch_api_cancel_batch_job_task]
+
+Delete batch job
+""""""""""""""""
+
+To queue a batch job for deletion via the Batch API you can use
+:class:`~airflow.providers.google.cloud.operators.gen_ai.GenAIGeminiDeleteBatchJobOperator`.
+The job will not be deleted immediately. After submitting it for deletion, it will still be available
+through GenAIGeminiListBatchJobsOperator or GenAIGeminiGetBatchJobOperator for some time. This is the behavior
+of the API, not of the operator.
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_gen_ai_batch_api_delete_batch_job_task]
+ :end-before: [END how_to_cloud_gen_ai_batch_api_delete_batch_job_task]
+
+Create embeddings
+"""""""""""""""""
+
+To create an embeddings batch job via the Batch API you can use
+:class:`~airflow.providers.google.cloud.operators.gen_ai.GenAIGeminiCreateEmbeddingsBatchJobOperator`.
+The operator returns the job name in :ref:`XCom <concepts:xcom>` under the ``job_name`` key.
+
+Two input sources are supported: inline requests or a file name.
+
+If you use inline requests, take a look at this example:
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_gen_ai_batch_api_create_embeddings_with_inlined_requests_task]
+    :end-before: [END how_to_cloud_gen_ai_batch_api_create_embeddings_with_inlined_requests_task]
+
+If you use a file, take a look at this example:
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_gen_ai_batch_api_create_embeddings_with_file_task]
+    :end-before: [END how_to_cloud_gen_ai_batch_api_create_embeddings_with_file_task]
+
+Interacting with Gemini Files API
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The `Gemini Files API <https://ai.google.dev/gemini-api/docs/files>`__ helps Gemini handle various types of
+input data, including text, images, and audio, at the same time. Please note that the Gemini Batch API mostly
+works with files that were uploaded via the Gemini Files API. The Files API lets you store up to 20GB of files
+per project, with each file not exceeding 2GB in size. Files are stored for 48 hours.
+
+Upload file
+"""""""""""
+
+To upload a file via the Files API you can use
+:class:`~airflow.providers.google.cloud.operators.gen_ai.GenAIGeminiUploadFileOperator`.
+The operator returns the file name in :ref:`XCom <concepts:xcom>` under the ``file_name`` key.
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_gen_ai_files_api_upload_file_task]
+ :end-before: [END how_to_cloud_gen_ai_files_api_upload_file_task]
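+
+The pushed ``file_name`` can then be passed to the Batch API operators as
+``input_source``; a minimal sketch using ``xcom_pull``, mirroring the example DAG:
+
+.. code-block:: python
+
+    UPLOADED_FILE_NAME = (
+        "{{ task_instance.xcom_pull(task_ids='upload_file_for_batch_job_task', key='file_name') }}"
+    )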
+
+Get file
+""""""""
+
+To get a file's metadata via the Files API you can use
+:class:`~airflow.providers.google.cloud.operators.gen_ai.GenAIGeminiGetFileOperator`.
+The operator returns the file URI in :ref:`XCom <concepts:xcom>` under the ``file_uri`` key.
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_gen_ai_files_api_get_file_task]
+ :end-before: [END how_to_cloud_gen_ai_files_api_get_file_task]
+
+List files
+""""""""""
+
+To list files via the Files API you can use
+:class:`~airflow.providers.google.cloud.operators.gen_ai.GenAIGeminiListFilesOperator`.
+The operator returns the file names in :ref:`XCom <concepts:xcom>` under the ``file_names`` key.
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_gen_ai_files_api_list_files_task]
+ :end-before: [END how_to_cloud_gen_ai_files_api_list_files_task]
+
+Delete file
+"""""""""""
+
+To delete a file via the Files API you can use
+:class:`~airflow.providers.google.cloud.operators.gen_ai.GenAIGeminiDeleteFileOperator`.
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_gen_ai_files_api_delete_file_task]
+ :end-before: [END how_to_cloud_gen_ai_files_api_delete_file_task]
Reference
^^^^^^^^^
@@ -110,3 +285,5 @@ Reference
For further information, look at:
* `Client Library Documentation
<https://cloud.google.com/vertex-ai/generative-ai/docs/sdks/overview>`__
+* `Gemini Batch API Documentation
<https://ai.google.dev/gemini-api/docs/batch-api>`__
+* `Gemini Files API Documentation
<https://ai.google.dev/gemini-api/docs/files>`__
diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/gen_ai.py b/providers/google/src/airflow/providers/google/cloud/hooks/gen_ai.py
index 25c8ad65911..827470cbe2e 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/gen_ai.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/gen_ai.py
@@ -27,16 +27,23 @@ from google import genai
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
if TYPE_CHECKING:
+ from google.genai.pagers import Pager
from google.genai.types import (
+ BatchJob,
ContentListUnion,
ContentListUnionDict,
CountTokensConfigOrDict,
CountTokensResponse,
+ CreateBatchJobConfig,
CreateCachedContentConfigOrDict,
CreateTuningJobConfigOrDict,
+ DeleteFileResponse,
+ DeleteResourceJob,
EmbedContentConfigOrDict,
EmbedContentResponse,
+ File,
GenerateContentConfig,
+ ListBatchJobsConfig,
TuningDatasetOrDict,
TuningJob,
)
@@ -194,3 +201,184 @@ class GenAIGenerativeModelHook(GoogleBaseHook):
)
return resp.name
+
+
+class GenAIGeminiAPIHook(GoogleBaseHook):
+ """Class for Google Cloud Generative AI Gemini Developer API hook."""
+
+ def __init__(self, gemini_api_key: str, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.gemini_api_key = gemini_api_key
+
+ def get_genai_client(self):
+ return genai.Client(
+ api_key=self.gemini_api_key,
+ vertexai=False,
+ )
+
+ def get_batch_job(
+ self,
+ job_name: str,
+ ) -> BatchJob:
+ """
+ Get batch job using Gemini Batch API.
+
+ :param job_name: Required. Batch job name.
+ """
+ client = self.get_genai_client()
+ resp = client.batches.get(name=job_name)
+ return resp
+
+ def list_batch_jobs(
+ self,
+ list_batch_jobs_config: ListBatchJobsConfig | dict | None = None,
+ ) -> Pager[BatchJob]:
+ """
+ Get list of batch jobs using Gemini Batch API.
+
+ :param list_batch_jobs_config: Optional. Configuration of returned
iterator.
+ """
+ client = self.get_genai_client()
+ resp = client.batches.list(
+ config=list_batch_jobs_config,
+ )
+ return resp
+
+ def create_batch_job(
+ self,
+ model: str,
+ source: list | str,
+ create_batch_job_config: CreateBatchJobConfig | dict | None = None,
+ ) -> BatchJob:
+ """
+ Create batch job using Gemini Batch API to process large-scale,
non-urgent tasks.
+
+ :param model: Required. Gemini model name to process requests.
+ :param source: Required. Requests that will be sent to chosen model.
+ Can be in format of Inline requests or file name.
+ :param create_batch_job_config: Optional. Configuration parameters for
batch job.
+ """
+ client = self.get_genai_client()
+ resp = client.batches.create(
+ model=model,
+ src=source,
+ config=create_batch_job_config,
+ )
+ return resp
+
+ def delete_batch_job(
+ self,
+ job_name: str,
+ ) -> DeleteResourceJob:
+ """
+ Delete batch job using Gemini Batch API.
+
+ :param job_name: Required. Batch job name.
+ """
+ client = self.get_genai_client()
+ resp = client.batches.delete(name=job_name)
+ return resp
+
+ def cancel_batch_job(
+ self,
+ job_name: str,
+ ) -> None:
+ """
+ Cancel batch job using Gemini Batch API.
+
+ :param job_name: Required. Batch job name.
+ """
+ client = self.get_genai_client()
+ client.batches.cancel(
+ name=job_name,
+ )
+
+ def create_embeddings(
+ self,
+ model: str,
+ source: dict | str,
+ create_embeddings_config: CreateBatchJobConfig | dict | None = None,
+ ) -> BatchJob:
+ """
+ Create batch job for embeddings using Gemini Batch API to process
large-scale, non-urgent tasks.
+
+ :param model: Required. Gemini model name to process requests.
+ :param source: Required. Requests that will be sent to chosen model.
+ Can be in format of Inline requests or file name.
+ :param create_embeddings_config: Optional. Configuration parameters
for embeddings batch job.
+ """
+ client = self.get_genai_client()
+ input_type = "inlined_requests"
+
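+        # A string source is treated as an uploaded file name; any other value is sent as inline requests.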
+ if isinstance(source, str):
+ input_type = "file_name"
+
+ self.log.info("Using %s to create embeddings", input_type)
+
+ resp = client.batches.create_embeddings(
+ model=model,
+ src={input_type: source},
+ config=create_embeddings_config,
+ )
+ return resp
+
+ def upload_file(self, path_to_file: str, upload_file_config: dict | None =
None) -> File:
+ """
+ Upload file for batch job or embeddings batch job using Gemini Files
API.
+
+ :param path_to_file: Required. Path to file on local filesystem.
+ :param upload_file_config: Optional. Configuration for file upload.
+ """
+ client = self.get_genai_client()
+
+ if upload_file_config is None:
+ self.log.info("Default configuration will be used to upload file")
+ try:
+                file_name, file_type = path_to_file.split("/")[-1].split(".")
+                upload_file_config = {"display_name": file_name, "mime_type": file_type}
+            except ValueError as exc:
+                raise ValueError(
+                    "Error during unpacking file name or mime type. Please check the file path"
+                ) from exc
+
+ resp = client.files.upload(
+ file=path_to_file,
+ config=upload_file_config,
+ )
+ return resp
+
+ def get_file(self, file_name: str) -> File:
+ """
+ Get file's metadata for batch job or embeddings batch job using Gemini
Files API.
+
+ :param file_name: Required. Name of the file in Gemini Files API.
+ """
+ client = self.get_genai_client()
+ resp = client.files.get(name=file_name)
+ return resp
+
+ def download_file(self, file_name: str) -> bytes:
+ """
+ Download file for batch job or embeddings batch job using Gemini Files
API.
+
+ :param file_name: Required. Name of the file in Gemini Files API.
+ """
+ client = self.get_genai_client()
+ resp = client.files.download(file=file_name)
+ return resp
+
+ def list_files(self) -> Pager[File]:
+ """List files for stored in Gemini Files API."""
+ client = self.get_genai_client()
+ resp = client.files.list()
+ return resp
+
+ def delete_file(self, file_name: str) -> DeleteFileResponse:
+ """
+ Delete file from Gemini Files API storage.
+
+ :param file_name: Required. Name of the file in Gemini Files API.
+ """
+ client = self.get_genai_client()
+ resp = client.files.delete(name=file_name)
+ return resp
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/gen_ai.py b/providers/google/src/airflow/providers/google/cloud/operators/gen_ai.py
index 66328d9b6d3..94e3688ff0e 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/gen_ai.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/gen_ai.py
@@ -19,10 +19,17 @@
from __future__ import annotations
+import enum
+import os.path
+import time
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any
+from google.genai.errors import ClientError
+
+from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.gen_ai import (
+ GenAIGeminiAPIHook,
GenAIGenerativeModelHook,
)
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
@@ -32,10 +39,12 @@ if TYPE_CHECKING:
ContentListUnion,
ContentListUnionDict,
CountTokensConfigOrDict,
+ CreateBatchJobConfig,
CreateCachedContentConfigOrDict,
CreateTuningJobConfigOrDict,
EmbedContentConfigOrDict,
GenerateContentConfig,
+ ListBatchJobsConfig,
TuningDatasetOrDict,
)
@@ -387,3 +396,859 @@ class GenAICreateCachedContentOperator(GoogleCloudBaseOperator):
context["ti"].xcom_push(key="cached_content",
value=cached_content_name)
return cached_content_name
+
+
+class BatchJobStatus(enum.Enum):
+ """Possible states of batch job in Gemini Batch API."""
+
+ SUCCEEDED = "JOB_STATE_SUCCEEDED"
+ PENDING = "JOB_STATE_PENDING"
+ FAILED = "JOB_STATE_FAILED"
+ RUNNING = "JOB_STATE_RUNNING"
+ CANCELLED = "JOB_STATE_CANCELLED"
+ EXPIRED = "JOB_STATE_EXPIRED"
+
+
+class GenAIGeminiCreateBatchJobOperator(GoogleCloudBaseOperator):
+ """
+    Create a Batch job using the Gemini Batch API. Use it to generate model responses for several requests.
+
+ :param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
+ :param location: Required. The ID of the Google Cloud location that the
service belongs to.
+ :param model: Required. The name of the publisher model to use for Batch
job.
+ :param gemini_api_key: Required. Key to interact with Gemini Batch API.
+ :param input_source: Required. Source of requests, could be inline
requests or file name.
+ :param results_folder: Optional. Path to a folder on local machine where
file with results will be saved.
+ :param create_batch_job_config: Optional. Config for batch job creation.
+ :param wait_until_complete: Optional. Await job completion.
+ :param retrieve_result: Optional. Push the result to XCom. If the
input_source is inline, this pushes
+ the execution result. If a file name is specified, this pushes the
output file path.
+ :param polling_interval: Optional. The interval, in seconds, to poll the
job status.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields = (
+ "location",
+ "project_id",
+ "impersonation_chain",
+ "model",
+ "create_batch_job_config",
+ "gemini_api_key",
+ "input_source",
+ )
+
+ def __init__(
+ self,
+ *,
+ project_id: str,
+ location: str,
+ model: str,
+ input_source: list | str,
+ gemini_api_key: str,
+ create_batch_job_config: CreateBatchJobConfig | dict | None = None,
+ results_folder: str | None = None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ retrieve_result: bool = False,
+ wait_until_complete: bool = False,
+ polling_interval: int = 30,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.project_id = project_id
+ self.location = location
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.model = model
+ self.input_source = input_source
+ self.create_batch_job_config = create_batch_job_config
+ self.gemini_api_key = gemini_api_key
+ self.retrieve_result = retrieve_result
+ self.wait_until_complete = wait_until_complete
+ self.polling_interval = polling_interval
+ self.results_folder = results_folder
+
+        if self.retrieve_result and not self.wait_until_complete:
+            raise AirflowException("Retrieving results is possible only if wait_until_complete is set to True")
+        if self.results_folder and not isinstance(self.input_source, str):
+            raise AirflowException("results_folder works only when input_source is a file name")
+        if self.results_folder and not os.path.exists(os.path.abspath(self.results_folder)):
+            raise AirflowException("path to results_folder does not exist, please provide a correct path")
+
+ def _wait_until_complete(self, job, polling_interval: int = 30):
+ try:
+ while True:
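+                # Poll until the job reaches a terminal state (succeeded, failed, expired, or cancelled).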
+ job = self.hook.get_batch_job(job_name=job.name)
+ if job.state.name == BatchJobStatus.SUCCEEDED.value:
+ self.log.info("Job execution completed")
+ break
+ if job.state.name in [
+ BatchJobStatus.FAILED.value,
+ BatchJobStatus.EXPIRED.value,
+ BatchJobStatus.CANCELLED.value,
+ ]:
+ self.log.error("Job execution was not completed!")
+ break
+                self.log.info(
+                    "Waiting for job execution, polling interval: %s seconds, current state: %s",
+                    polling_interval,
+                    job.state.name,
+                )
+ time.sleep(polling_interval)
+        except Exception as e:
+            raise AirflowException("Something went wrong while waiting for the batch job.") from e
+ return job
+
+ def _prepare_results_for_xcom(self, job):
+ results = []
+ if job.dest and job.dest.inlined_responses:
+ self.log.info("Results are inline")
+ for inline_response in job.dest.inlined_responses:
+ if inline_response.response:
+ # Accessing response, structure may vary.
+ try:
+ results.append(inline_response.response.text)
+ except AttributeError:
+ results.append(inline_response.response)
+ elif inline_response.error:
+ self.log.warning("Error found in the inline result")
+ results.append(inline_response.error)
+ elif job.dest and job.dest.file_name:
+            file_content_bytes = self.hook.download_file(file_name=job.dest.file_name)
+            file_content = file_content_bytes.decode("utf-8")
+            file_name = job.display_name or job.name.replace("/", "-")
+            path_to_file = os.path.abspath(f"{self.results_folder}/{file_name}.jsonl")
+ with open(path_to_file, "w") as file_with_results:
+ file_with_results.writelines(file_content.splitlines(True))
+ results = path_to_file
+
+ return results
+
+ def execute(self, context: Context):
+ self.hook = GenAIGeminiAPIHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ gemini_api_key=self.gemini_api_key,
+ )
+
+ try:
+ job = self.hook.create_batch_job(
+ model=self.model,
+ source=self.input_source,
+ create_batch_job_config=self.create_batch_job_config,
+ )
+        except Exception as e:
+            raise AirflowException(f"Something went wrong during creation of the batch job: {e}") from e
+
+ self.log.info("Job with name %s was successfully created!", job.name)
+ context["ti"].xcom_push(key="job_name", value=job.name)
+
+ if self.wait_until_complete:
+ job = self._wait_until_complete(job, self.polling_interval)
+ if self.retrieve_result and job.error is None:
+ job_results = self._prepare_results_for_xcom(job)
+ context["ti"].xcom_push(key="job_results", value=job_results)
+
+ return dict(job)
+
+
+class GenAIGeminiGetBatchJobOperator(GoogleCloudBaseOperator):
+ """
+ Get Batch job using Gemini API.
+
+ :param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
+ :param location: Required. The ID of the Google Cloud location that the
service belongs to.
+ :param gemini_api_key: Required. Key to interact with Gemini Batch API.
+ :param job_name: Required. Name of the batch job.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields = ("location", "project_id", "impersonation_chain",
"job_name", "gemini_api_key")
+
+ def __init__(
+ self,
+ *,
+ project_id: str,
+ location: str,
+ job_name: str,
+ gemini_api_key: str,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.project_id = project_id
+ self.location = location
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.job_name = job_name
+ self.gemini_api_key = gemini_api_key
+
+ def execute(self, context: Context):
+ self.hook = GenAIGeminiAPIHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ gemini_api_key=self.gemini_api_key,
+ )
+
+ try:
+ job = self.hook.get_batch_job(job_name=self.job_name)
+ except ValueError:
+ raise AirflowException("Job with name %s not found", self.job_name)
+
+ context["ti"].xcom_push(key="job_status", value=job.state)
+ return dict(job)
+
+
+class GenAIGeminiListBatchJobsOperator(GoogleCloudBaseOperator):
+ """
+    Get a list of Batch jobs' metadata using the Gemini API.
+
+ :param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
+ :param location: Required. The ID of the Google Cloud location that the
service belongs to.
+    :param list_batch_jobs_config: Optional. Configuration of the returned iterator.
+ :param gemini_api_key: Required. Key to interact with Gemini Batch API.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields = (
+ "location",
+ "project_id",
+ "impersonation_chain",
+ "list_batch_jobs_config",
+ "gemini_api_key",
+ )
+
+ def __init__(
+ self,
+ *,
+ project_id: str,
+ location: str,
+ gemini_api_key: str,
+ list_batch_jobs_config: ListBatchJobsConfig | dict | None = None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.project_id = project_id
+ self.location = location
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.list_batch_jobs_config = list_batch_jobs_config
+ self.gemini_api_key = gemini_api_key
+
+ def execute(self, context: Context):
+ self.hook = GenAIGeminiAPIHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ gemini_api_key=self.gemini_api_key,
+ )
+
+        jobs_list = self.hook.list_batch_jobs(list_batch_jobs_config=self.list_batch_jobs_config)
+
+ job_names = []
+ job_objs = []
+
+ try:
+ for job in jobs_list:
+ job_names.append(job.name)
+ job_objs.append(job.model_dump(exclude={"dest"}))
+ except RuntimeError:
+ self.log.info("%s jobs found", len(job_names))
+
+ context["ti"].xcom_push(key="job_names", value=job_names)
+
+ return job_objs
+
+
+class GenAIGeminiDeleteBatchJobOperator(GoogleCloudBaseOperator):
+ """
+ Queue a batch job for deletion using the Gemini API.
+
+ The job will not be deleted immediately. After submitting it for deletion,
it will still be available
+ through GenAIGeminiListBatchJobsOperator or GenAIGeminiGetBatchJobOperator
for some time.
+
+ :param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
+ :param location: Required. The ID of the Google Cloud location that the
service belongs to.
+ :param gemini_api_key: Required. Key to interact with Gemini Batch API.
+ :param job_name: Required. Name of the batch job.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields = ("location", "project_id", "impersonation_chain",
"job_name", "gemini_api_key")
+
+ def __init__(
+ self,
+ *,
+ project_id: str,
+ location: str,
+ job_name: str,
+ gemini_api_key: str,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.project_id = project_id
+ self.location = location
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.job_name = job_name
+ self.gemini_api_key = gemini_api_key
+
+ def execute(self, context: Context):
+ self.hook = GenAIGeminiAPIHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ gemini_api_key=self.gemini_api_key,
+ )
+
+        try:
+            delete_response = self.hook.delete_batch_job(job_name=self.job_name)
+        except ValueError:
+            raise AirflowException(f"Job with name {self.job_name} was not found")
+
+        self.log.info("Job with name %s was submitted for deletion.", self.job_name)
+
+        if delete_response.error:
+            raise AirflowException(
+                f"Job with name {self.job_name} was not deleted due to error: {delete_response.error}"
+            )
+
+ return delete_response.model_dump()
+
+
+class GenAIGeminiCancelBatchJobOperator(GoogleCloudBaseOperator):
+ """
+ Cancel Batch job using Gemini API.
+
+ :param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
+ :param location: Required. The ID of the Google Cloud location that the
service belongs to.
+ :param gemini_api_key: Required. Key to interact with Gemini Batch API.
+ :param job_name: Required. Name of the batch job.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields = ("location", "project_id", "impersonation_chain",
"job_name", "gemini_api_key")
+
+ def __init__(
+ self,
+ *,
+ project_id: str,
+ location: str,
+ job_name: str,
+ gemini_api_key: str,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.project_id = project_id
+ self.location = location
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.job_name = job_name
+ self.gemini_api_key = gemini_api_key
+
+ def execute(self, context: Context):
+ self.hook = GenAIGeminiAPIHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ gemini_api_key=self.gemini_api_key,
+ )
+ self.log.info("Cancelling job with name %s ...", self.job_name)
+
+ try:
+ self.hook.cancel_batch_job(job_name=self.job_name)
+ except ValueError:
+ raise AirflowException("Job with name %s was not found",
self.job_name)
+
+ self.log.info("Job with name %s was successfully cancelled",
self.job_name)
+
+
+class GenAIGeminiCreateEmbeddingsBatchJobOperator(GoogleCloudBaseOperator):
+ """
+ Create embeddings Batch job using Gemini Batch API.
+
+ Use to generate embeddings for words, phrases, sentences, and code for
several requests.
+
+ :param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
+ :param location: Required. The ID of the Google Cloud location that the
service belongs to.
+ :param model: Required. The name of the publisher model to use for Batch
job.
+ :param gemini_api_key: Required. Key to interact with Gemini Batch API.
+ :param input_source: Required. Source of requests, could be inline
requests or file name.
+ :param results_folder: Optional. Path to a folder on local machine where
file with results will be saved.
+ :param create_embeddings_config: Optional. Config for batch job creation.
+ :param wait_until_complete: Optional. Await job completion.
+ :param retrieve_result: Optional. Push the result to XCom. If the
input_source is inline, this pushes
+ the execution result. If a file name is specified, this pushes the
output file path.
+ :param polling_interval: Optional. The interval, in seconds, to poll the
job status.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields = (
+ "location",
+ "project_id",
+ "impersonation_chain",
+ "model",
+ "create_embeddings_config",
+ "gemini_api_key",
+ "input_source",
+ )
+
+ def __init__(
+ self,
+ *,
+ project_id: str,
+ location: str,
+ model: str,
+ gemini_api_key: str,
+ input_source: dict | str,
+ results_folder: str | None = None,
+ create_embeddings_config: CreateBatchJobConfig | dict | None = None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ wait_until_complete: bool = False,
+ retrieve_result: bool = False,
+ polling_interval: int = 30,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.project_id = project_id
+ self.location = location
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.model = model
+ self.input_source = input_source
+ self.create_embeddings_config = create_embeddings_config
+ self.gemini_api_key = gemini_api_key
+ self.wait_until_complete = wait_until_complete
+ self.retrieve_result = retrieve_result
+ self.polling_interval = polling_interval
+ self.results_folder = results_folder
+
+        if self.retrieve_result and not self.wait_until_complete:
+            raise AirflowException("Retrieving results is possible only if wait_until_complete is set to True")
+        if self.results_folder and not isinstance(self.input_source, str):
+            raise AirflowException("results_folder works only when input_source is a file name")
+        if self.results_folder and not os.path.exists(os.path.abspath(self.results_folder)):
+            raise AirflowException("path to results_folder does not exist, please provide a correct path")
+
+ def _wait_until_complete(self, job, polling_interval: int = 30):
+ try:
+ while True:
+ job = self.hook.get_batch_job(job_name=job.name)
+ if job.state.name == BatchJobStatus.SUCCEEDED.value:
+ self.log.info("Job execution completed")
+ break
+ if job.state.name in [
+ BatchJobStatus.FAILED.value,
+ BatchJobStatus.EXPIRED.value,
+ BatchJobStatus.CANCELLED.value,
+ ]:
+ self.log.error("Job execution was not completed!")
+ break
+                self.log.info(
+                    "Waiting for job execution, polling interval: %s seconds, current state: %s",
+                    polling_interval,
+                    job.state.name,
+                )
+ time.sleep(polling_interval)
+ except Exception as e:
+ raise AirflowException("Something went wrong during waiting of the
batch job: %s", e)
+ return job
+
+ def _prepare_results_for_xcom(self, job):
+ results = []
+ if job.dest and job.dest.inlined_embed_content_responses:
+ self.log.info("Results are inline")
+ for inline_embed_response in
job.dest.inlined_embed_content_responses:
+ if inline_embed_response.response:
+ # Accessing response, structure may vary.
+ try:
+                        results.append(dict(inline_embed_response.response.embedding))
+ except AttributeError:
+ results.append(inline_embed_response.response)
+ elif inline_embed_response.error:
+ self.log.warning("Error found in the inline result")
+ results.append(inline_embed_response.error)
+ elif job.dest and job.dest.file_name:
+            file_content_bytes = self.hook.download_file(file_name=job.dest.file_name)
+            file_content = file_content_bytes.decode("utf-8")
+            file_name = job.display_name or job.name.replace("/", "-")
+            path_to_file = os.path.abspath(f"{self.results_folder}/{file_name}.jsonl")
+ with open(path_to_file, "w") as file_with_results:
+ file_with_results.writelines(file_content.splitlines(True))
+ results = path_to_file
+
+ return results
+
+ def execute(self, context: Context):
+ self.hook = GenAIGeminiAPIHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ gemini_api_key=self.gemini_api_key,
+ )
+
+ try:
+ embeddings_job = self.hook.create_embeddings(
+ model=self.model,
+ source=self.input_source,
+ create_embeddings_config=self.create_embeddings_config,
+ )
+        except Exception as e:
+            raise AirflowException(f"Something went wrong during creation of the embeddings job: {e}") from e
+
+        self.log.info("Embeddings Job with name %s was successfully created!", embeddings_job.name)
+ context["ti"].xcom_push(key="job_name", value=embeddings_job.name)
+
+ if self.wait_until_complete:
+            embeddings_job = self._wait_until_complete(embeddings_job, self.polling_interval)
+ if self.retrieve_result and embeddings_job.error is None:
+ job_results = self._prepare_results_for_xcom(embeddings_job)
+ context["ti"].xcom_push(key="job_results", value=job_results)
+
+ return embeddings_job.model_dump()
+
+
+class GenAIGeminiUploadFileOperator(GoogleCloudBaseOperator):
+ """
+    Upload a file to Gemini Files API.
+
+ The Files API lets you store up to 20GB of files per project, with each
file not exceeding 2GB in size.
+ Supported types are audio files, images, videos, documents, and others.
Files are stored for 48 hours.
+
+ :param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
+ :param location: Required. The ID of the Google Cloud location that the
service belongs to.
+    :param gemini_api_key: Required. Key to interact with Gemini Files API.
+ :param file_path: Required. Path to file on your local machine.
+ :param upload_file_config: Optional. Metadata configuration for file
upload.
+ Defaults to display name and mime type parsed from file_path.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields = (
+ "location",
+ "project_id",
+ "impersonation_chain",
+ "file_path",
+ "gemini_api_key",
+ )
+
+ def __init__(
+ self,
+ *,
+ project_id: str,
+ location: str,
+ file_path: str,
+ gemini_api_key: str,
+ upload_file_config: dict | None = None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.project_id = project_id
+ self.location = location
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.gemini_api_key = gemini_api_key
+ self.file_path = file_path
+ self.upload_file_config = upload_file_config
+
+ def execute(self, context: Context):
+ self.hook = GenAIGeminiAPIHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ gemini_api_key=self.gemini_api_key,
+ )
+
+        try:
+            file = self.hook.upload_file(
+                path_to_file=self.file_path, upload_file_config=self.upload_file_config
+            )
+        except ValueError:
+            raise AirflowException("Error during file upload! Check the file name or mime type!")
+        except FileNotFoundError:
+            raise AirflowException("Provided file was not found!")
+
+ self.log.info("File with name %s successfully uploaded!", file.name)
+ context["ti"].xcom_push(key="file_name", value=file.name)
+
+ return file.model_dump()
+
+
+class GenAIGeminiGetFileOperator(GoogleCloudBaseOperator):
+ """
+    Get metadata of a file uploaded to Gemini Files API by GenAIGeminiUploadFileOperator.
+
+ The Files API lets you store up to 20GB of files per project, with each
file not exceeding 2GB in size.
+ Files are stored for 48 hours.
+
+ :param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
+ :param location: Required. The ID of the Google Cloud location that the
service belongs to.
+    :param gemini_api_key: Required. Key to interact with Gemini Files API.
+    :param file_name: Required. File name in Gemini Files API to get.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields = (
+ "location",
+ "project_id",
+ "impersonation_chain",
+ "file_name",
+ "gemini_api_key",
+ )
+
+ def __init__(
+ self,
+ *,
+ project_id: str,
+ location: str,
+ file_name: str,
+ gemini_api_key: str,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.project_id = project_id
+ self.location = location
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.gemini_api_key = gemini_api_key
+ self.file_name = file_name
+
+ def execute(self, context: Context):
+ self.hook = GenAIGeminiAPIHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ gemini_api_key=self.gemini_api_key,
+ )
+ self.log.info("Looking for file with name: %s", self.file_name)
+
+ try:
+ file = self.hook.get_file(file_name=self.file_name)
+ except ClientError:
+ raise AirflowException("File with name %s not found",
self.file_name)
+
+ self.log.info("Find file with name: %s", file.name)
+ context["ti"].xcom_push(key="file_uri", value=file.uri)
+
+ return file.model_dump()
+
+
+class GenAIGeminiListFilesOperator(GoogleCloudBaseOperator):
+ """
+ List files uploaded to Gemini Files API.
+
+ The Files API lets you store up to 20GB of files per project, with each
file not exceeding 2GB in size.
+ Files are stored for 48 hours.
+
+ :param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
+ :param location: Required. The ID of the Google Cloud location that the
service belongs to.
+    :param gemini_api_key: Required. Key to interact with Gemini Files API.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields = (
+ "location",
+ "project_id",
+ "impersonation_chain",
+ "gemini_api_key",
+ )
+
+ def __init__(
+ self,
+ *,
+ project_id: str,
+ location: str,
+ gemini_api_key: str,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.project_id = project_id
+ self.location = location
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.gemini_api_key = gemini_api_key
+
+ def execute(self, context: Context):
+ self.hook = GenAIGeminiAPIHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ gemini_api_key=self.gemini_api_key,
+ )
+
+ files = self.hook.list_files()
+
+ if files:
+ xcom_file_names = []
+ xcom_files = []
+ try:
+ for file in files:
+ xcom_file_names.append(file.name)
+ xcom_files.append(file.model_dump())
+ except RuntimeError:
+ self.log.info("%s files found", len(xcom_files))
+
+ context["ti"].xcom_push(key="file_names", value=xcom_file_names)
+ return xcom_files
+
+ self.log.info("No files found")
+
+
+class GenAIGeminiDeleteFileOperator(GoogleCloudBaseOperator):
+ """
+ Delete file uploaded to Gemini Files API.
+
+ The Files API lets you store up to 20GB of files per project, with each
file not exceeding 2GB in size.
+ Files are stored for 48 hours.
+
+ :param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
+ :param location: Required. The ID of the Google Cloud location that the
service belongs to.
+    :param gemini_api_key: Required. Key to interact with Gemini Files API.
+ :param file_name: Required. File name in Gemini Files API to delete.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields = (
+ "location",
+ "project_id",
+ "impersonation_chain",
+ "file_name",
+ "gemini_api_key",
+ )
+
+ def __init__(
+ self,
+ *,
+ project_id: str,
+ location: str,
+ file_name: str,
+ gemini_api_key: str,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.project_id = project_id
+ self.location = location
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.file_name = file_name
+ self.gemini_api_key = gemini_api_key
+
+ def execute(self, context: Context):
+ self.hook = GenAIGeminiAPIHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ gemini_api_key=self.gemini_api_key,
+ )
+
+ try:
+ delete_response = self.hook.delete_file(file_name=self.file_name)
+ except ClientError:
+ raise AirflowException("File %s not found!", self.file_name)
+
+ self.log.info("File %s was successfully deleted!", self.file_name)
+
+ return delete_response.model_dump()
diff --git a/providers/google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py b/providers/google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py
new file mode 100644
index 00000000000..d942466c17f
--- /dev/null
+++ b/providers/google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py
@@ -0,0 +1,341 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Example Airflow DAG for Google Gen AI Gemini Batch API and File API.
+"""
+
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+try:
+ from airflow.sdk import task
+except ImportError:
+ # Airflow 2 path
+ from airflow.decorators import task # type: ignore[attr-defined,no-redef]
+
+try:
+ from airflow.sdk import TriggerRule
+except ImportError:
+ # Compatibility for Airflow < 3.1
+    from airflow.utils.trigger_rule import TriggerRule  # type: ignore[no-redef,attr-defined]
+
+from pathlib import Path
+
+from airflow.models.dag import DAG
+from airflow.providers.google.cloud.operators.gen_ai import (
+ GenAIGeminiCancelBatchJobOperator,
+ GenAIGeminiCreateBatchJobOperator,
+ GenAIGeminiCreateEmbeddingsBatchJobOperator,
+ GenAIGeminiDeleteBatchJobOperator,
+ GenAIGeminiDeleteFileOperator,
+ GenAIGeminiGetBatchJobOperator,
+ GenAIGeminiGetFileOperator,
+ GenAIGeminiListBatchJobsOperator,
+ GenAIGeminiListFilesOperator,
+ GenAIGeminiUploadFileOperator,
+)
+from airflow.providers.google.common.utils.get_secret import get_secret
+from airflow.providers.standard.operators.bash import BashOperator
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+REGION = "us-central1"
+DAG_ID = "gen_ai_gemini_batch_api"
+
+GEMINI_API_KEY = "api_key"
+
+INLINED_REQUESTS_FOR_BATCH_JOB = [
+ {"contents": [{"parts": [{"text": "Tell me a one-sentence joke."}],
"role": "user"}]},
+ {"contents": [{"parts": [{"text": "Why is the sky blue?"}], "role":
"user"}]},
+]
+
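+# Unlike the generate-content batch job (a list of requests), the embeddings batch job
+# takes a single request dict or an uploaded file name.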
+INLINED_REQUESTS_FOR_EMBEDDINGS_BATCH_JOB = {
+ "contents": [{"parts": [{"text": "Why is the sky blue?"}], "role": "user"}]
+}
+
+
+GEMINI_XCOM_API_KEY = "{{ task_instance.xcom_pull('get_gemini_api_key') }}"
+
+LOCAL_FILE_NAME = "gemini_batch_requests.jsonl"
+LOCAL_EMBEDDINGS_FILE_NAME = "gemini_batch_embeddings_requests.jsonl"
+UPLOAD_FILE_PATH = str(Path(__file__).parent / "resources" / LOCAL_FILE_NAME)
+PATH_TO_SAVE_RESULTS = str(Path(__file__).parent / "resources")
+UPLOAD_EMBEDDINGS_FILE_PATH = str(Path(__file__).parent / "resources" / LOCAL_EMBEDDINGS_FILE_NAME)
+
+UPLOADED_FILE_NAME = (
+    "{{ task_instance.xcom_pull(task_ids='upload_file_for_batch_job_task', key='file_name') }}"
+)
+UPLOADED_EMBEDDINGS_FILE_NAME = (
+    "{{ task_instance.xcom_pull(task_ids='upload_file_for_embeddings_batch_job_task', key='file_name') }}"
+)
+
+BATCH_JOB_WITH_INLINED_REQUESTS_NAME = (
+    "{{ task_instance.xcom_pull(task_ids='create_batch_job_using_inlined_requests_task', key='job_name') }}"
+)
+BATCH_JOB_WITH_FILE_NAME = (
+    "{{ task_instance.xcom_pull(task_ids='create_batch_job_using_file_task', key='job_name') }}"
+)
+EMBEDDINGS_BATCH_JOB_WITH_INLINED_REQUESTS_NAME = (
+    "{{ task_instance.xcom_pull(task_ids='create_embeddings_job_using_inlined_requests_task', key='job_name') }}"
+)
+EMBEDDINGS_BATCH_JOB_WITH_FILE_NAME = (
+    "{{ task_instance.xcom_pull(task_ids='create_embeddings_job_using_file_task', key='job_name') }}"
+)
+
+
+with DAG(
+ dag_id=DAG_ID,
+ description="Sample DAG with Gemini Batch API.",
+ schedule="@once",
+ start_date=datetime(2024, 1, 1),
+ catchup=False,
+ tags=["example", "gen_ai", "gemini_batch_api", "gemini_file_api"],
+ render_template_as_native_obj=True,
+) as dag:
+
+ @task
+ def get_gemini_api_key():
+ return get_secret(GEMINI_API_KEY)
+
+ get_gemini_api_key_task = get_gemini_api_key()
+
+ # [START how_to_cloud_gen_ai_files_api_upload_file_task]
+ upload_file = GenAIGeminiUploadFileOperator(
+ task_id="upload_file_for_batch_job_task",
+ project_id=PROJECT_ID,
+ location=REGION,
+ file_path=UPLOAD_FILE_PATH,
+ gemini_api_key=GEMINI_XCOM_API_KEY,
+ )
+ # [END how_to_cloud_gen_ai_files_api_upload_file_task]
+
+ upload_embeddings_file = GenAIGeminiUploadFileOperator(
+ task_id="upload_file_for_embeddings_batch_job_task",
+ project_id=PROJECT_ID,
+ location=REGION,
+ file_path=UPLOAD_EMBEDDINGS_FILE_PATH,
+ gemini_api_key=GEMINI_XCOM_API_KEY,
+ )
+
+ # [START how_to_cloud_gen_ai_files_api_get_file_task]
+ get_file = GenAIGeminiGetFileOperator(
+ task_id="get_file_for_batch_job_task",
+ project_id=PROJECT_ID,
+ location=REGION,
+ file_name=UPLOADED_FILE_NAME,
+ gemini_api_key=GEMINI_XCOM_API_KEY,
+ )
+ # [END how_to_cloud_gen_ai_files_api_get_file_task]
+
+ # [START how_to_cloud_gen_ai_files_api_list_files_task]
+ list_files = GenAIGeminiListFilesOperator(
+ task_id="list_files_files_api_task",
+ project_id=PROJECT_ID,
+ location=REGION,
+ gemini_api_key=GEMINI_XCOM_API_KEY,
+ )
+ # [END how_to_cloud_gen_ai_files_api_list_files_task]
+
+ # [START how_to_cloud_gen_ai_files_api_delete_file_task]
+ delete_file = GenAIGeminiDeleteFileOperator(
+ task_id="delete_file_for_batch_job_task",
+ project_id=PROJECT_ID,
+ location=REGION,
+ file_name=UPLOADED_FILE_NAME,
+ gemini_api_key=GEMINI_XCOM_API_KEY,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+ # [END how_to_cloud_gen_ai_files_api_delete_file_task]
+
+ delete_embeddings_file = GenAIGeminiDeleteFileOperator(
+ task_id="delete_file_for_embeddings_batch_job_task",
+ project_id=PROJECT_ID,
+ location=REGION,
+ file_name=UPLOADED_EMBEDDINGS_FILE_NAME,
+ gemini_api_key=GEMINI_XCOM_API_KEY,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
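+ # A batch job takes either a list of inlined request dicts or the name of a previously uploaded file as input_source.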
+ # [START how_to_cloud_gen_ai_batch_api_create_batch_job_with_inlined_requests_task]
+ create_batch_job_using_inlined_requests = GenAIGeminiCreateBatchJobOperator(
+ task_id="create_batch_job_using_inlined_requests_task",
+ project_id=PROJECT_ID,
+ location=REGION,
+ model="gemini-3-pro-preview",
+ gemini_api_key=GEMINI_XCOM_API_KEY,
+ create_batch_job_config={
+ "display_name": "inlined-requests-batch-job",
+ },
+ input_source=INLINED_REQUESTS_FOR_BATCH_JOB,
+ wait_until_complete=True,
+ retrieve_result=True,
+ )
+ # [END how_to_cloud_gen_ai_batch_api_create_batch_job_with_inlined_requests_task]
+
+ # [START how_to_cloud_gen_ai_batch_api_create_batch_job_with_file_task]
+ create_batch_job_using_file = GenAIGeminiCreateBatchJobOperator(
+ task_id="create_batch_job_using_file_task",
+ project_id=PROJECT_ID,
+ location=REGION,
+ model="gemini-3-pro-preview",
+ gemini_api_key=GEMINI_XCOM_API_KEY,
+ create_batch_job_config={
+ "display_name": "file-upload-batch-job",
+ },
+ input_source=UPLOADED_FILE_NAME,
+ wait_until_complete=True,
+ retrieve_result=True,
+ results_folder=PATH_TO_SAVE_RESULTS,
+ )
+ # [END how_to_cloud_gen_ai_batch_api_create_batch_job_with_file_task]
+
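+ # Embeddings batch jobs are created without waiting for completion and are cleaned up later in the chain.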
+ # [START how_to_cloud_gen_ai_batch_api_create_embeddings_with_inlined_requests_task]
+ create_embeddings_job_using_inlined_requests = GenAIGeminiCreateEmbeddingsBatchJobOperator(
+ task_id="create_embeddings_job_using_inlined_requests_task",
+ project_id=PROJECT_ID,
+ location=REGION,
+ model="gemini-embedding-001",
+ wait_until_complete=False,
+ gemini_api_key=GEMINI_XCOM_API_KEY,
+ create_embeddings_config={
+ "display_name": "inlined-requests-embeddings-job",
+ },
+ input_source=INLINED_REQUESTS_FOR_EMBEDDINGS_BATCH_JOB,
+ )
+ # [END how_to_cloud_gen_ai_batch_api_create_embeddings_with_inlined_requests_task]
+
+ # [START how_to_cloud_gen_ai_batch_api_create_embeddings_with_file_task]
+ create_embeddings_job_using_file = GenAIGeminiCreateEmbeddingsBatchJobOperator(
+ task_id="create_embeddings_job_using_file_task",
+ project_id=PROJECT_ID,
+ location=REGION,
+ model="gemini-embedding-001",
+ wait_until_complete=False,
+ gemini_api_key=GEMINI_XCOM_API_KEY,
+ create_embeddings_config={
+ "display_name": "file-upload-embeddings-job",
+ },
+ input_source=UPLOADED_EMBEDDINGS_FILE_NAME,
+ )
+ # [END how_to_cloud_gen_ai_batch_api_create_embeddings_with_file_task]
+
+ # [START how_to_cloud_gen_ai_batch_api_get_batch_job_task]
+ get_batch_job = GenAIGeminiGetBatchJobOperator(
+ task_id="get_batch_job_task",
+ project_id=PROJECT_ID,
+ location=REGION,
+ gemini_api_key=GEMINI_XCOM_API_KEY,
+ job_name=BATCH_JOB_WITH_INLINED_REQUESTS_NAME,
+ )
+ # [END how_to_cloud_gen_ai_batch_api_get_batch_job_task]
+
+ # [START how_to_cloud_gen_ai_batch_api_list_batch_jobs_task]
+ list_batch_jobs = GenAIGeminiListBatchJobsOperator(
+ task_id="list_batch_jobs_task",
+ project_id=PROJECT_ID,
+ location=REGION,
+ gemini_api_key=GEMINI_XCOM_API_KEY,
+ )
+ # [END how_to_cloud_gen_ai_batch_api_list_batch_jobs_task]
+
+ # [START how_to_cloud_gen_ai_batch_api_cancel_batch_job_task]
+ cancel_batch_job = GenAIGeminiCancelBatchJobOperator(
+ task_id="cancel_batch_job_task",
+ project_id=PROJECT_ID,
+ location=REGION,
+ gemini_api_key=GEMINI_XCOM_API_KEY,
+ job_name=BATCH_JOB_WITH_FILE_NAME,
+ )
+ # [END how_to_cloud_gen_ai_batch_api_cancel_batch_job_task]
+
+ # [START how_to_cloud_gen_ai_batch_api_delete_batch_job_task]
+ delete_batch_job_1 = GenAIGeminiDeleteBatchJobOperator(
+ task_id="delete_batch_job_1_task",
+ project_id=PROJECT_ID,
+ location=REGION,
+ gemini_api_key=GEMINI_XCOM_API_KEY,
+ job_name=BATCH_JOB_WITH_INLINED_REQUESTS_NAME,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+ # [END how_to_cloud_gen_ai_batch_api_delete_batch_job_task]
+
+ delete_batch_job_2 = GenAIGeminiDeleteBatchJobOperator(
+ task_id="delete_batch_job_2_task",
+ project_id=PROJECT_ID,
+ location=REGION,
+ gemini_api_key=GEMINI_XCOM_API_KEY,
+ job_name=BATCH_JOB_WITH_FILE_NAME,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ delete_embeddings_batch_job_1 = GenAIGeminiDeleteBatchJobOperator(
+ task_id="delete_embeddings_batch_job_1_task",
+ project_id=PROJECT_ID,
+ location=REGION,
+ gemini_api_key=GEMINI_XCOM_API_KEY,
+ job_name=EMBEDDINGS_BATCH_JOB_WITH_INLINED_REQUESTS_NAME,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ delete_embeddings_batch_job_2 = GenAIGeminiDeleteBatchJobOperator(
+ task_id="delete_embeddings_batch_job_2_task",
+ project_id=PROJECT_ID,
+ location=REGION,
+ gemini_api_key=GEMINI_XCOM_API_KEY,
+ job_name=EMBEDDINGS_BATCH_JOB_WITH_FILE_NAME,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+ delete_result_file = BashOperator(
+ task_id="delete_result_file_task",
+ bash_command="rm "
+ "{{ task_instance.xcom_pull(task_ids='create_batch_job_using_file_task', key='job_results') }}",
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
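+ # Upload files, inspect them, create batch jobs, inspect and cancel jobs, then tear everything down.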
+ (
+ get_gemini_api_key_task
+ >> [upload_file, upload_embeddings_file]
+ >> get_file
+ >> list_files
+ >> [
+ create_batch_job_using_inlined_requests,
+ create_batch_job_using_file,
+ create_embeddings_job_using_file,
+ create_embeddings_job_using_inlined_requests,
+ ]
+ >> get_batch_job
+ >> list_batch_jobs
+ >> cancel_batch_job
+ >> delete_batch_job_1
+ >> delete_batch_job_2
+ >> delete_embeddings_batch_job_1
+ >> delete_embeddings_batch_job_2
+ >> [delete_file, delete_embeddings_file, delete_result_file]
+ )
+
+ from tests_common.test_utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when a "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests_common.test_utils.system_tests import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/providers/google/tests/system/google/cloud/gen_ai/resources/gemini_batch_embeddings_requests.jsonl b/providers/google/tests/system/google/cloud/gen_ai/resources/gemini_batch_embeddings_requests.jsonl
new file mode 100644
index 00000000000..b898b469662
--- /dev/null
+++ b/providers/google/tests/system/google/cloud/gen_ai/resources/gemini_batch_embeddings_requests.jsonl
@@ -0,0 +1,2 @@
+{"key": "request_1", "request": {"output_dimensionality": 3, "content":
{"parts": [{"text": "1"}]}}}
+{"key": "request_2", "request": {"output_dimensionality": 4, "content":
{"parts": [{"text": "2"}]}}}
diff --git a/providers/google/tests/system/google/cloud/gen_ai/resources/gemini_batch_requests.jsonl b/providers/google/tests/system/google/cloud/gen_ai/resources/gemini_batch_requests.jsonl
new file mode 100644
index 00000000000..ec67464ed8a
--- /dev/null
+++ b/providers/google/tests/system/google/cloud/gen_ai/resources/gemini_batch_requests.jsonl
@@ -0,0 +1,2 @@
+{"key": "request-1", "request": {"contents": [{"parts": [{"text": "Describe
the process of photosynthesis."}]}]}}
+{"key": "request-2", "request": {"contents": [{"parts": [{"text": "What are
the main ingredients in a Margherita pizza?"}]}]}}
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_gen_ai.py b/providers/google/tests/unit/google/cloud/hooks/test_gen_ai.py
index e16bf64d969..095c5466461 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_gen_ai.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_gen_ai.py
@@ -29,6 +29,7 @@ from google.genai.types import (
)
from airflow.providers.google.cloud.hooks.gen_ai import (
+ GenAIGeminiAPIHook,
GenAIGenerativeModelHook,
)
@@ -95,6 +96,25 @@ CACHED_CONTENT_CONFIG = CreateCachedContentConfig(
BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
GENERATIVE_MODEL_STRING = "airflow.providers.google.cloud.hooks.gen_ai.{}"
+TEST_API_KEY = "test-api-key"
+TEST_JOB_NAME = "batches/test-job-id"
+TEST_MODEL = "models/gemini-2.5-flash"
+TEST_BATCH_JOB_SOURCE_INLINE = [
+ {"contents": [{"parts": [{"text": "Tell me a one-sentence joke."}],
"role": "user"}]},
+ {"contents": [{"parts": [{"text": "Why is the sky blue?"}], "role":
"user"}]},
+]
+TEST_EMBEDDINGS_JOB_SOURCE_INLINE = {
+ "contents": [{"parts": [{"text": "Why is the sky blue?"}], "role": "user"}]
+}
+TEST_SOURCE_FILE = "test-bucket/source.jsonl"
+TEST_LOCAL_FILE_PATH = "/tmp/data/test_file.json"
+TEST_FILE_NAME = "files/test-file-id"
+
+# Mock constants for configuration objects
+TEST_LIST_BATCH_JOBS_CONFIG = {"page_size": 10}
+TEST_CREATE_BATCH_JOB_CONFIG = {"display_name": "test-job"}
+TEST_UPLOAD_FILE_CONFIG = {"display_name": "custom_name", "mime_type": "text/plain"}
+
def assert_warning(msg: str, warnings):
assert any(msg in str(w) for w in warnings)
@@ -191,3 +211,167 @@ class TestGenAIGenerativeModelHookWithDefaultProjectId:
model=TEST_CACHED_MODEL,
config=CACHED_CONTENT_CONFIG,
)
+
+
+class TestGenAIGeminiAPIHook:
+ def setup_method(self):
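+ # Patch GoogleBaseHook.__init__ so the hook can be constructed without a real Airflow connection.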
+ with mock.patch(
+ BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id
+ ):
+ self.hook = GenAIGeminiAPIHook(gemini_api_key=TEST_API_KEY)
+
+ @mock.patch("google.genai.Client")
+ def test_get_genai_client(self, mock_client):
+ """Test client initialization with correct parameters."""
+ self.hook.get_genai_client()
+
+ mock_client.assert_called_once_with(
+ api_key=TEST_API_KEY,
+ vertexai=False,
+ )
+
+ @mock.patch(GENERATIVE_MODEL_STRING.format("GenAIGeminiAPIHook.get_genai_client"))
+ def test_get_batch_job(self, mock_get_client):
+ client_mock = mock_get_client.return_value
+ client_mock.batches = mock.Mock()
+
+ self.hook.get_batch_job(job_name=TEST_JOB_NAME)
+
+ client_mock.batches.get.assert_called_once_with(name=TEST_JOB_NAME)
+
+ @mock.patch(GENERATIVE_MODEL_STRING.format("GenAIGeminiAPIHook.get_genai_client"))
+ def test_list_batch_jobs(self, mock_get_client):
+ client_mock = mock_get_client.return_value
+ client_mock.batches = mock.Mock()
+
+ self.hook.list_batch_jobs(list_batch_jobs_config=TEST_LIST_BATCH_JOBS_CONFIG)
+
+ client_mock.batches.list.assert_called_once_with(config=TEST_LIST_BATCH_JOBS_CONFIG)
+
+ @mock.patch(GENERATIVE_MODEL_STRING.format("GenAIGeminiAPIHook.get_genai_client"))
+ def test_create_batch_job(self, mock_get_client):
+ client_mock = mock_get_client.return_value
+ client_mock.batches = mock.Mock()
+
+ self.hook.create_batch_job(
+ model=TEST_MODEL,
+ source=TEST_BATCH_JOB_SOURCE_INLINE,
+ create_batch_job_config=TEST_CREATE_BATCH_JOB_CONFIG,
+ )
+
+ client_mock.batches.create.assert_called_once_with(
+ model=TEST_MODEL, src=TEST_BATCH_JOB_SOURCE_INLINE, config=TEST_CREATE_BATCH_JOB_CONFIG
+ )
+
+ @mock.patch(GENERATIVE_MODEL_STRING.format("GenAIGeminiAPIHook.get_genai_client"))
+ def test_delete_batch_job(self, mock_get_client):
+ client_mock = mock_get_client.return_value
+ client_mock.batches = mock.Mock()
+
+ self.hook.delete_batch_job(job_name=TEST_JOB_NAME)
+
+ client_mock.batches.delete.assert_called_once_with(name=TEST_JOB_NAME)
+
+ @mock.patch(GENERATIVE_MODEL_STRING.format("GenAIGeminiAPIHook.get_genai_client"))
+ def test_cancel_batch_job(self, mock_get_client):
+ client_mock = mock_get_client.return_value
+ client_mock.batches = mock.Mock()
+
+ self.hook.cancel_batch_job(job_name=TEST_JOB_NAME)
+
+ client_mock.batches.cancel.assert_called_once_with(name=TEST_JOB_NAME)
+
+ @mock.patch(GENERATIVE_MODEL_STRING.format("GenAIGeminiAPIHook.get_genai_client"))
+ def test_create_embeddings_with_inline_source(self, mock_get_client):
+ """Test create_embeddings when source is a dict (inline)."""
+ client_mock = mock_get_client.return_value
+ client_mock.batches = mock.Mock()
+
+ self.hook.create_embeddings(
+ model=TEST_MODEL,
+ source=TEST_EMBEDDINGS_JOB_SOURCE_INLINE,
+ create_embeddings_config=TEST_CREATE_BATCH_JOB_CONFIG,
+ )
+
+ client_mock.batches.create_embeddings.assert_called_once_with(
+ model=TEST_MODEL,
+ src={"inlined_requests": TEST_EMBEDDINGS_JOB_SOURCE_INLINE},
+ config=TEST_CREATE_BATCH_JOB_CONFIG,
+ )
+
+ @mock.patch(GENERATIVE_MODEL_STRING.format("GenAIGeminiAPIHook.get_genai_client"))
+ def test_create_embeddings_with_file_source(self, mock_get_client):
+ """Test create_embeddings when source is a string (file name)."""
+ client_mock = mock_get_client.return_value
+ client_mock.batches = mock.Mock()
+
+ # Test with str (File name)
+ source_file = "/bucket/file.jsonl"
+ self.hook.create_embeddings(
+ model=TEST_MODEL, source=source_file, create_embeddings_config=TEST_CREATE_BATCH_JOB_CONFIG
+ )
+
+ client_mock.batches.create_embeddings.assert_called_once_with(
+ model=TEST_MODEL, src={"file_name": source_file}, config=TEST_CREATE_BATCH_JOB_CONFIG
+ )
+
+ @mock.patch(GENERATIVE_MODEL_STRING.format("GenAIGeminiAPIHook.get_genai_client"))
+ def test_upload_file_with_provided_config(self, mock_get_client):
+ """Test upload_file when explicit config is provided."""
+ client_mock = mock_get_client.return_value
+ client_mock.files = mock.Mock()
+
+ self.hook.upload_file(path_to_file=TEST_LOCAL_FILE_PATH, upload_file_config=TEST_UPLOAD_FILE_CONFIG)
+
+ client_mock.files.upload.assert_called_once_with(
+ file=TEST_LOCAL_FILE_PATH, config=TEST_UPLOAD_FILE_CONFIG
+ )
+
+ @mock.patch(GENERATIVE_MODEL_STRING.format("GenAIGeminiAPIHook.get_genai_client"))
+ def test_upload_file_default_config_generation(self, mock_get_client):
+ """Test that upload_file generates correct config from filename if
config is None."""
+ client_mock = mock_get_client.return_value
+ client_mock.files = mock.Mock()
+
+ # Path: /tmp/data/test_file.json -> name: test_file, type: json
+ self.hook.upload_file(path_to_file=TEST_LOCAL_FILE_PATH, upload_file_config=None)
+
+ expected_config = {"display_name": "test_file", "mime_type": "json"}
+
+ client_mock.files.upload.assert_called_once_with(file=TEST_LOCAL_FILE_PATH, config=expected_config)
+
+ @mock.patch(GENERATIVE_MODEL_STRING.format("GenAIGeminiAPIHook.get_genai_client"))
+ def test_get_file(self, mock_get_client):
+ client_mock = mock_get_client.return_value
+ client_mock.files = mock.Mock()
+
+ self.hook.get_file(file_name=TEST_FILE_NAME)
+
+ client_mock.files.get.assert_called_once_with(name=TEST_FILE_NAME)
+
+ @mock.patch(GENERATIVE_MODEL_STRING.format("GenAIGeminiAPIHook.get_genai_client"))
+ def test_download_file(self, mock_get_client):
+ client_mock = mock_get_client.return_value
+ client_mock.files = mock.Mock()
+
+ self.hook.download_file(file_name=TEST_FILE_NAME)
+
+ client_mock.files.download.assert_called_once_with(file=TEST_FILE_NAME)
+
+ @mock.patch(GENERATIVE_MODEL_STRING.format("GenAIGeminiAPIHook.get_genai_client"))
+ def test_list_files(self, mock_get_client):
+ client_mock = mock_get_client.return_value
+ client_mock.files = mock.Mock()
+
+ self.hook.list_files()
+
+ client_mock.files.list.assert_called_once()
+
+ @mock.patch(GENERATIVE_MODEL_STRING.format("GenAIGeminiAPIHook.get_genai_client"))
+ def test_delete_file(self, mock_get_client):
+ client_mock = mock_get_client.return_value
+ client_mock.files = mock.Mock()
+
+ self.hook.delete_file(file_name=TEST_FILE_NAME)
+
+ client_mock.files.delete.assert_called_once_with(name=TEST_FILE_NAME)
diff --git a/providers/google/tests/unit/google/cloud/operators/test_gen_ai.py b/providers/google/tests/unit/google/cloud/operators/test_gen_ai.py
index aa1a4640a0d..4489ec84b94 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_gen_ai.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_gen_ai.py
@@ -31,6 +31,16 @@ from google.genai.types import (
from airflow.providers.google.cloud.operators.gen_ai import (
GenAICountTokensOperator,
GenAICreateCachedContentOperator,
+ GenAIGeminiCancelBatchJobOperator,
+ GenAIGeminiCreateBatchJobOperator,
+ GenAIGeminiCreateEmbeddingsBatchJobOperator,
+ GenAIGeminiDeleteBatchJobOperator,
+ GenAIGeminiDeleteFileOperator,
+ GenAIGeminiGetBatchJobOperator,
+ GenAIGeminiGetFileOperator,
+ GenAIGeminiListBatchJobsOperator,
+ GenAIGeminiListFilesOperator,
+ GenAIGeminiUploadFileOperator,
GenAIGenerateContentOperator,
GenAIGenerateEmbeddingsOperator,
GenAISupervisedFineTuningTrainOperator,
@@ -84,6 +94,20 @@ GENERATE_FROM_CACHED_MODEL_CONFIG = {
"cached_content": "cached_name",
}
+TEST_BATCH_JOB_INLINED_REQUESTS = [
+ {"contents": [{"parts": [{"text": "Tell me a one-sentence joke."}],
"role": "user"}]},
+ {"contents": [{"parts": [{"text": "Why is the sky blue?"}], "role":
"user"}]},
+]
+
+TEST_EMBEDDINGS_JOB_INLINED_REQUESTS = {
+ "contents": [{"parts": [{"text": "Why is the sky blue?"}], "role": "user"}]
+}
+TEST_GEMINI_API_KEY = "test-key"
+TEST_GEMINI_MODEL = "test-gemini-model"
+TEST_BATCH_JOB_NAME = "test-name"
+TEST_FILE_NAME = "test-file"
+TEST_FILE_PATH = "test/path/to/file"
+
def assert_warning(msg: str, warnings):
assert any(msg in str(w) for w in warnings)
@@ -248,3 +272,238 @@ class TestGenAIGenerateFromCachedContentOperator:
contents=CONTENTS,
generation_config=GENERATE_FROM_CACHED_MODEL_CONFIG,
)
+
+
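+# The tests below mock GenAIGeminiAPIHook and verify only the arguments each operator forwards to it.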
+class TestGenAIGeminiCreateBatchJobOperator:
+ @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+ def test_execute(self, mock_hook):
+ op = GenAIGeminiCreateBatchJobOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ model=TEST_GEMINI_MODEL,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ input_source=TEST_BATCH_JOB_INLINED_REQUESTS,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ wait_until_complete=False,
+ retrieve_result=False,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ )
+ mock_hook.return_value.create_batch_job.assert_called_once_with(
+ source=TEST_BATCH_JOB_INLINED_REQUESTS,
+ model=TEST_GEMINI_MODEL,
+ create_batch_job_config=None,
+ )
+
+
+class TestGenAIGeminiGetBatchJobOperator:
+ @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+ def test_execute(self, mock_hook):
+ op = GenAIGeminiGetBatchJobOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ job_name=TEST_BATCH_JOB_NAME,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ )
+ mock_hook.return_value.get_batch_job.assert_called_once_with(
+ job_name=TEST_BATCH_JOB_NAME,
+ )
+
+
+class TestGenAIGeminiListBatchJobsOperator:
+ @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+ def test_execute(self, mock_hook):
+ op = GenAIGeminiListBatchJobsOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ )
+ mock_hook.return_value.list_batch_jobs.assert_called_once_with(list_batch_jobs_config=None)
+
+
+class TestGenAIGeminiDeleteBatchJobOperator:
+ @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+ def test_execute(self, mock_hook):
+ mock_hook.return_value.delete_batch_job.return_value = mock.MagicMock(error=False)
+ op = GenAIGeminiDeleteBatchJobOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ job_name=TEST_BATCH_JOB_NAME,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ )
+ mock_hook.return_value.delete_batch_job.assert_called_once_with(
+ job_name=TEST_BATCH_JOB_NAME,
+ )
+
+
+class TestGenAIGeminiCancelBatchJobOperator:
+ @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+ def test_execute(self, mock_hook):
+ op = GenAIGeminiCancelBatchJobOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ job_name=TEST_BATCH_JOB_NAME,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ )
+ mock_hook.return_value.cancel_batch_job.assert_called_once_with(
+ job_name=TEST_BATCH_JOB_NAME,
+ )
+
+
+class TestGenAIGeminiCreateEmbeddingsBatchJobOperator:
+ @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+ def test_execute(self, mock_hook):
+ op = GenAIGeminiCreateEmbeddingsBatchJobOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ input_source=TEST_EMBEDDINGS_JOB_INLINED_REQUESTS,
+ model=EMBEDDING_MODEL,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ wait_until_complete=False,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ )
+ mock_hook.return_value.create_embeddings.assert_called_once_with(
+ source=TEST_EMBEDDINGS_JOB_INLINED_REQUESTS,
+ model=EMBEDDING_MODEL,
+ create_embeddings_config=None,
+ )
+
+
+class TestGenAIGeminiGetFileOperator:
+ @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+ def test_execute(self, mock_hook):
+ op = GenAIGeminiGetFileOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ file_name=TEST_FILE_NAME,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ )
+ mock_hook.return_value.get_file.assert_called_once_with(
+ file_name=TEST_FILE_NAME,
+ )
+
+
+class TestGenAIGeminiUploadFileOperator:
+ @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+ def test_execute(self, mock_hook):
+ op = GenAIGeminiUploadFileOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ file_path=TEST_FILE_PATH,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ )
+ mock_hook.return_value.upload_file.assert_called_once_with(
+ path_to_file=TEST_FILE_PATH,
+ upload_file_config=None,
+ )
+
+
+class TestGenAIGeminiListFilesOperator:
+ @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+ def test_execute(self, mock_hook):
+ op = GenAIGeminiListFilesOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ )
+ mock_hook.return_value.list_files.assert_called_once_with()
+
+
+class TestGenAIGeminiDeleteFileOperator:
+ @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+ def test_execute(self, mock_hook):
+ op = GenAIGeminiDeleteFileOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ file_name=TEST_FILE_NAME,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ gemini_api_key=TEST_GEMINI_API_KEY,
+ )
+ mock_hook.return_value.delete_file.assert_called_once_with(
+ file_name=TEST_FILE_NAME,
+ )