This is an automated email from the ASF dual-hosted git repository.
shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f278e62255 Minor fixes to ensure successful Vertex AI LLMops pipeline
(#41997)
f278e62255 is described below
commit f278e62255e513c26f26c76e41a8734ec36fb07a
Author: Christian Yarros <[email protected]>
AuthorDate: Thu Sep 19 15:40:23 2024 -0400
Minor fixes to ensure successful Vertex AI LLMops pipeline (#41997)
* generative ai operator cleanup
* return fix
---
.../cloud/hooks/vertex_ai/generative_model.py | 48 +++++++++-------
.../cloud/operators/vertex_ai/generative_model.py | 65 ++++++++++++----------
.../cloud/hooks/vertex_ai/test_generative_model.py | 9 ++-
.../operators/vertex_ai/test_generative_model.py | 3 +
4 files changed, 73 insertions(+), 52 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
b/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
index acadbb788a..04242306f8 100644
--- a/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
+++ b/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
@@ -128,6 +128,8 @@ class GenerativeModelHook(GoogleBaseHook):
"""
Use the Vertex AI PaLM API to generate natural language text.
+ :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
+ :param location: Required. The ID of the Google Cloud location that
the service belongs to.
:param prompt: Required. Inputs or queries that a user or a program
gives
to the Vertex AI PaLM API, in order to elicit a specific response.
:param pretrained_model: A pre-trained model optimized for performing
natural
@@ -141,8 +143,6 @@ class GenerativeModelHook(GoogleBaseHook):
of their probabilities equals the top_p value. Defaults to 0.8.
:param top_k: A top_k of 1 means the selected token is the most
probable
among all tokens.
- :param location: Required. The ID of the Google Cloud location that
the service belongs to.
- :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
"""
vertexai.init(project=project_id, location=location,
credentials=self.get_credentials())
@@ -178,11 +178,11 @@ class GenerativeModelHook(GoogleBaseHook):
"""
Use the Vertex AI PaLM API to generate text embeddings.
+ :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
+ :param location: Required. The ID of the Google Cloud location that
the service belongs to.
:param prompt: Required. Inputs or queries that a user or a program
gives
to the Vertex AI PaLM API, in order to elicit a specific response.
:param pretrained_model: A pre-trained model optimized for generating
text embeddings.
- :param location: Required. The ID of the Google Cloud location that
the service belongs to.
- :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
"""
vertexai.init(project=project_id, location=location,
credentials=self.get_credentials())
model = self.get_text_embedding_model(pretrained_model)
@@ -210,16 +210,16 @@ class GenerativeModelHook(GoogleBaseHook):
"""
Use the Vertex AI Gemini Pro foundation model to generate natural
language text.
+ :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
+ :param location: Required. The ID of the Google Cloud location that
the service belongs to.
:param prompt: Required. Inputs or queries that a user or a program
gives
to the Multi-modal model, in order to elicit a specific response.
- :param location: Required. The ID of the Google Cloud location that
the service belongs to.
:param generation_config: Optional. Generation configuration settings.
:param safety_settings: Optional. Per request settings for blocking
unsafe content.
:param pretrained_model: By default uses the pre-trained model
`gemini-pro`,
supporting prompts with text-only input, including natural language
tasks, multi-turn text and code chat, and code generation. It can
output text and code.
- :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
"""
vertexai.init(project=project_id, location=location,
credentials=self.get_credentials())
@@ -251,6 +251,8 @@ class GenerativeModelHook(GoogleBaseHook):
"""
Use the Vertex AI Gemini Pro foundation model to generate natural
language text.
+ :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
+ :param location: Required. The ID of the Google Cloud location that
the service belongs to.
:param prompt: Required. Inputs or queries that a user or a program
gives
to the Multi-modal model, in order to elicit a specific response.
:param generation_config: Optional. Generation configuration settings.
@@ -262,8 +264,6 @@ class GenerativeModelHook(GoogleBaseHook):
:param media_gcs_path: A GCS path to a content file such as an image
or a video.
Can be passed to the multi-modal model as part of the prompt. Used
with vision models.
:param mime_type: Validates the media type presented by the file in
the media_gcs_path.
- :param location: Required. The ID of the Google Cloud location that
the service belongs to.
- :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
"""
vertexai.init(project=project_id, location=location,
credentials=self.get_credentials())
@@ -290,6 +290,8 @@ class GenerativeModelHook(GoogleBaseHook):
"""
Use the Vertex AI PaLM API to generate natural language text.
+ :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
+ :param location: Required. The ID of the Google Cloud location that
the service belongs to.
:param prompt: Required. Inputs or queries that a user or a program
gives
to the Vertex AI PaLM API, in order to elicit a specific response.
:param pretrained_model: A pre-trained model optimized for performing
natural
@@ -303,8 +305,6 @@ class GenerativeModelHook(GoogleBaseHook):
of their probabilities equals the top_p value. Defaults to 0.8.
:param top_k: A top_k of 1 means the selected token is the most
probable
among all tokens.
- :param location: Required. The ID of the Google Cloud location that
the service belongs to.
- :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
"""
vertexai.init(project=project_id, location=location,
credentials=self.get_credentials())
@@ -334,11 +334,11 @@ class GenerativeModelHook(GoogleBaseHook):
"""
Use the Vertex AI PaLM API to generate text embeddings.
+ :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
+ :param location: Required. The ID of the Google Cloud location that
the service belongs to.
:param prompt: Required. Inputs or queries that a user or a program
gives
to the Vertex AI PaLM API, in order to elicit a specific response.
:param pretrained_model: A pre-trained model optimized for generating
text embeddings.
- :param location: Required. The ID of the Google Cloud location that
the service belongs to.
- :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
"""
vertexai.init(project=project_id, location=location,
credentials=self.get_credentials())
model = self.get_text_embedding_model(pretrained_model)
@@ -355,26 +355,31 @@ class GenerativeModelHook(GoogleBaseHook):
tools: list | None = None,
generation_config: dict | None = None,
safety_settings: dict | None = None,
+ system_instruction: str | None = None,
pretrained_model: str = "gemini-pro",
project_id: str = PROVIDE_PROJECT_ID,
) -> str:
"""
Use the Vertex AI Gemini Pro foundation model to generate natural
language text.
+ :param location: Required. The ID of the Google Cloud location that
the service belongs to.
+ :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
:param contents: Required. The multi-part content of a message that a
user or a program
gives to the generative model, in order to elicit a specific
response.
- :param location: Required. The ID of the Google Cloud location that
the service belongs to.
:param generation_config: Optional. Generation configuration settings.
:param safety_settings: Optional. Per request settings for blocking
unsafe content.
+ :param tools: Optional. A list of tools available to the model during
evaluation, such as a data store.
+ :param system_instruction: Optional. An instruction given to the model
to guide its behavior.
:param pretrained_model: By default uses the pre-trained model
`gemini-pro`,
supporting prompts with text-only input, including natural language
tasks, multi-turn text and code chat, and code generation. It can
output text and code.
- :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
"""
vertexai.init(project=project_id, location=location,
credentials=self.get_credentials())
- model = self.get_generative_model(pretrained_model)
+ model = self.get_generative_model(
+ pretrained_model=pretrained_model,
system_instruction=system_instruction
+ )
response = model.generate_content(
contents=contents,
tools=tools,
@@ -400,12 +405,13 @@ class GenerativeModelHook(GoogleBaseHook):
"""
Use the Supervised Fine Tuning API to create a tuning job.
+ :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
+ :param location: Required. The ID of the Google Cloud location that
the service belongs to.
:param source_model: Required. A pre-trained model optimized for
performing natural
language tasks such as classification, summarization, extraction,
content
creation, and ideation.
:param train_dataset: Required. Cloud Storage URI of your training
dataset. The dataset
must be formatted as a JSONL file. For best results, provide at
least 100 to 500 examples.
- :param location: Required. The ID of the Google Cloud location that
the service belongs to.
:param tuned_model_display_name: Optional. Display name of the
TunedModel. The name can be up
to 128 characters long and can consist of any UTF-8 characters.
:param validation_dataset: Optional. Cloud Storage URI of your
validation dataset. The dataset must be
@@ -447,18 +453,18 @@ class GenerativeModelHook(GoogleBaseHook):
"""
Use the Vertex AI Count Tokens API to calculate the number of input
tokens before sending a request to the Gemini API.
+ :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
+ :param location: Required. The ID of the Google Cloud location that
the service belongs to.
:param contents: Required. The multi-part content of a message that a
user or a program
gives to the generative model, in order to elicit a specific
response.
- :param location: Required. The ID of the Google Cloud location that
the service belongs to.
:param pretrained_model: By default uses the pre-trained model
`gemini-pro`,
supporting prompts with text-only input, including natural language
tasks, multi-turn text and code chat, and code generation. It can
output text and code.
- :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
"""
vertexai.init(project=project_id, location=location,
credentials=self.get_credentials())
- model = self.get_generative_model(pretrained_model)
+ model = self.get_generative_model(pretrained_model=pretrained_model)
response = model.count_tokens(
contents=contents,
)
@@ -484,6 +490,8 @@ class GenerativeModelHook(GoogleBaseHook):
"""
Use the Rapid Evaluation API to evaluate a model.
+ :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
+ :param location: Required. The ID of the Google Cloud location that
the service belongs to.
:param pretrained_model: Required. A pre-trained model optimized for
performing natural
language tasks such as classification, summarization, extraction,
content
creation, and ideation.
@@ -492,8 +500,6 @@ class GenerativeModelHook(GoogleBaseHook):
:param experiment_name: Required. The name of the evaluation
experiment.
:param experiment_run_name: Required. The specific run name or ID for
this experiment.
:param prompt_template: Required. The template used to format the
model's prompts during evaluation. Adheres to Rapid Evaluation API.
- :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
- :param location: Required. The ID of the Google Cloud location that
the service belongs to.
:param generation_config: Optional. A dictionary containing generation
parameters for the model.
:param safety_settings: Optional. A dictionary specifying harm
category thresholds for blocking model outputs.
:param system_instruction: Optional. An instruction given to the model
to guide its behavior.
diff --git
a/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
b/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
index b0b9462e21..fddd5dcf72 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
@@ -21,7 +21,6 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Sequence
-from google.cloud.aiplatform_v1 import types as types_v1
from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
from airflow.exceptions import AirflowProviderDeprecationWarning
@@ -510,12 +509,14 @@ class
GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
:param project_id: Required. The ID of the Google Cloud project that the
service belongs to (templated).
- :param contents: Required. The multi-part content of a message that a user
or a program
- gives to the generative model, in order to elicit a specific response.
:param location: Required. The ID of the Google Cloud location that the
service belongs to (templated).
+ :param contents: Required. The multi-part content of a message that a user
or a program
+ gives to the generative model, in order to elicit a specific response.
:param generation_config: Optional. Generation configuration settings.
:param safety_settings: Optional. Per request settings for blocking unsafe
content.
+ :param tools: Optional. A list of tools available to the model during
evaluation, such as a data store.
+ :param system_instruction: Optional. An instruction given to the model to
guide its behavior.
:param pretrained_model: By default uses the pre-trained model
`gemini-pro`,
supporting prompts with text-only input, including natural language
tasks, multi-turn text and code chat, and code generation. It can
@@ -537,11 +538,12 @@ class
GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
self,
*,
project_id: str,
- contents: list,
location: str,
+ contents: list,
tools: list | None = None,
generation_config: dict | None = None,
safety_settings: dict | None = None,
+ system_instruction: str | None = None,
pretrained_model: str = "gemini-pro",
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
@@ -554,6 +556,7 @@ class
GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
self.tools = tools
self.generation_config = generation_config
self.safety_settings = safety_settings
+ self.system_instruction = system_instruction
self.pretrained_model = pretrained_model
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
@@ -570,6 +573,7 @@ class
GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
tools=self.tools,
generation_config=self.generation_config,
safety_settings=self.safety_settings,
+ system_instruction=self.system_instruction,
pretrained_model=self.pretrained_model,
)
@@ -583,14 +587,14 @@ class
SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
"""
Use the Supervised Fine Tuning API to create a tuning job.
+ :param project_id: Required. The ID of the Google Cloud project that the
+ service belongs to.
+ :param location: Required. The ID of the Google Cloud location that the
service belongs to.
:param source_model: Required. A pre-trained model optimized for
performing natural
language tasks such as classification, summarization, extraction,
content
creation, and ideation.
:param train_dataset: Required. Cloud Storage URI of your training
dataset. The dataset
must be formatted as a JSONL file. For best results, provide at least
100 to 500 examples.
- :param project_id: Required. The ID of the Google Cloud project that the
- service belongs to.
- :param location: Required. The ID of the Google Cloud location that the
service belongs to.
:param tuned_model_display_name: Optional. Display name of the TunedModel.
The name can be up
to 128 characters long and can consist of any UTF-8 characters.
:param validation_dataset: Optional. Cloud Storage URI of your validation
dataset. The dataset must be
@@ -617,10 +621,10 @@ class
SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
def __init__(
self,
*,
- source_model: str,
- train_dataset: str,
project_id: str,
location: str,
+ source_model: str,
+ train_dataset: str,
tuned_model_display_name: str | None = None,
validation_dataset: str | None = None,
epochs: int | None = None,
@@ -631,6 +635,8 @@ class
SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
**kwargs,
) -> None:
super().__init__(**kwargs)
+ self.project_id = project_id
+ self.location = location
self.source_model = source_model
self.train_dataset = train_dataset
self.tuned_model_display_name = tuned_model_display_name
@@ -638,8 +644,6 @@ class
SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
self.epochs = epochs
self.adapter_size = adapter_size
self.learning_rate_multiplier = learning_rate_multiplier
- self.project_id = project_id
- self.location = location
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
@@ -649,10 +653,10 @@ class
SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
impersonation_chain=self.impersonation_chain,
)
response = self.hook.supervised_fine_tuning_train(
- source_model=self.source_model,
- train_dataset=self.train_dataset,
project_id=self.project_id,
location=self.location,
+ source_model=self.source_model,
+ train_dataset=self.train_dataset,
validation_dataset=self.validation_dataset,
epochs=self.epochs,
adapter_size=self.adapter_size,
@@ -666,7 +670,12 @@ class
SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
self.xcom_push(context, key="tuned_model_name",
value=response.tuned_model_name)
self.xcom_push(context, key="tuned_model_endpoint_name",
value=response.tuned_model_endpoint_name)
- return types_v1.TuningJob.to_dict(response)
+ result = {
+ "tuned_model_name": response.tuned_model_name,
+ "tuned_model_endpoint_name": response.tuned_model_endpoint_name,
+ }
+
+ return result
class CountTokensOperator(GoogleCloudBaseOperator):
@@ -675,12 +684,10 @@ class CountTokensOperator(GoogleCloudBaseOperator):
:param project_id: Required. The ID of the Google Cloud project that the
service belongs to (templated).
- :param contents: Required. The multi-part content of a message that a user
or a program
- gives to the generative model, in order to elicit a specific response.
:param location: Required. The ID of the Google Cloud location that the
service belongs to (templated).
- :param system_instruction: Optional. Instructions for the model to steer
it toward better
- performance. For example, "Answer as concisely as possible"
+ :param contents: Required. The multi-part content of a message that a user
or a program
+ gives to the generative model, in order to elicit a specific response.
:param pretrained_model: By default uses the pre-trained model
`gemini-pro`,
supporting prompts with text-only input, including natural language
tasks, multi-turn text and code chat, and code generation. It can
@@ -702,8 +709,8 @@ class CountTokensOperator(GoogleCloudBaseOperator):
self,
*,
project_id: str,
- contents: list,
location: str,
+ contents: list,
pretrained_model: str = "gemini-pro",
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
@@ -742,6 +749,8 @@ class RunEvaluationOperator(GoogleCloudBaseOperator):
"""
Use the Rapid Evaluation API to evaluate a model.
+ :param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
+ :param location: Required. The ID of the Google Cloud location that the
service belongs to.
:param pretrained_model: Required. A pre-trained model optimized for
performing natural
language tasks such as classification, summarization, extraction,
content
creation, and ideation.
@@ -750,8 +759,6 @@ class RunEvaluationOperator(GoogleCloudBaseOperator):
:param experiment_name: Required. The name of the evaluation experiment.
:param experiment_run_name: Required. The specific run name or ID for this
experiment.
:param prompt_template: Required. The template used to format the model's
prompts during evaluation. Adheres to Rapid Evaluation API.
- :param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
- :param location: Required. The ID of the Google Cloud location that the
service belongs to.
:param generation_config: Optional. A dictionary containing generation
parameters for the model.
:param safety_settings: Optional. A dictionary specifying harm category
thresholds for blocking model outputs.
:param system_instruction: Optional. An instruction given to the model to
guide its behavior.
@@ -781,14 +788,14 @@ class RunEvaluationOperator(GoogleCloudBaseOperator):
def __init__(
self,
*,
+ project_id: str,
+ location: str,
pretrained_model: str,
eval_dataset: dict,
metrics: list,
experiment_name: str,
experiment_run_name: str,
prompt_template: str,
- project_id: str,
- location: str,
generation_config: dict | None = None,
safety_settings: dict | None = None,
system_instruction: str | None = None,
@@ -799,18 +806,18 @@ class RunEvaluationOperator(GoogleCloudBaseOperator):
) -> None:
super().__init__(**kwargs)
+ self.project_id = project_id
+ self.location = location
self.pretrained_model = pretrained_model
self.eval_dataset = eval_dataset
self.metrics = metrics
self.experiment_name = experiment_name
self.experiment_run_name = experiment_run_name
self.prompt_template = prompt_template
- self.system_instruction = system_instruction
self.generation_config = generation_config
self.safety_settings = safety_settings
+ self.system_instruction = system_instruction
self.tools = tools
- self.project_id = project_id
- self.location = location
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
@@ -820,17 +827,17 @@ class RunEvaluationOperator(GoogleCloudBaseOperator):
impersonation_chain=self.impersonation_chain,
)
response = self.hook.run_evaluation(
+ project_id=self.project_id,
+ location=self.location,
pretrained_model=self.pretrained_model,
eval_dataset=self.eval_dataset,
metrics=self.metrics,
experiment_name=self.experiment_name,
experiment_run_name=self.experiment_run_name,
prompt_template=self.prompt_template,
- project_id=self.project_id,
- location=self.location,
- system_instruction=self.system_instruction,
generation_config=self.generation_config,
safety_settings=self.safety_settings,
+ system_instruction=self.system_instruction,
tools=self.tools,
)
diff --git
a/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
b/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
index 2004eec29f..19723a51b1 100644
--- a/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
+++ b/tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
@@ -222,7 +222,10 @@ class TestGenerativeModelWithDefaultProjectIdHook:
safety_settings=TEST_SAFETY_SETTINGS,
pretrained_model=TEST_MULTIMODAL_PRETRAINED_MODEL,
)
- mock_model.assert_called_once_with(TEST_MULTIMODAL_PRETRAINED_MODEL)
+ mock_model.assert_called_once_with(
+ pretrained_model=TEST_MULTIMODAL_PRETRAINED_MODEL,
+ system_instruction=None,
+ )
mock_model.return_value.generate_content.assert_called_once_with(
contents=TEST_CONTENTS,
tools=TEST_TOOLS,
@@ -257,7 +260,9 @@ class TestGenerativeModelWithDefaultProjectIdHook:
location=GCP_LOCATION,
pretrained_model=TEST_MULTIMODAL_PRETRAINED_MODEL,
)
- mock_model.assert_called_once_with(TEST_MULTIMODAL_PRETRAINED_MODEL)
+ mock_model.assert_called_once_with(
+ pretrained_model=TEST_MULTIMODAL_PRETRAINED_MODEL,
+ )
mock_model.return_value.count_tokens.assert_called_once_with(
contents=TEST_CONTENTS,
)
diff --git
a/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
b/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
index 3745850e72..e8efb9601f 100644
--- a/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
+++ b/tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
@@ -356,6 +356,7 @@ class TestVertexAIGenerativeModelGenerateContentOperator:
HarmCategory.HARM_CATEGORY_HARASSMENT:
HarmBlockThreshold.BLOCK_ONLY_HIGH,
}
generation_config = {"max_output_tokens": 256, "top_p": 0.8,
"temperature": 0.0}
+ system_instruction = "be concise."
op = GenerativeModelGenerateContentOperator(
task_id=TASK_ID,
@@ -366,6 +367,7 @@ class TestVertexAIGenerativeModelGenerateContentOperator:
generation_config=generation_config,
safety_settings=safety_settings,
pretrained_model=pretrained_model,
+ system_instruction=system_instruction,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
@@ -382,6 +384,7 @@ class TestVertexAIGenerativeModelGenerateContentOperator:
generation_config=generation_config,
safety_settings=safety_settings,
pretrained_model=pretrained_model,
+ system_instruction=system_instruction,
)