This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f332c770f5b Deprecate VertexAI PaLM text generative model (#44719)
f332c770f5b is described below
commit f332c770f5b2a18733c0f0352411ca9f73b90e1f
Author: Maksim <[email protected]>
AuthorDate: Fri Dec 6 05:15:01 2024 -0800
Deprecate VertexAI PaLM text generative model (#44719)
---
.../operators/cloud/vertex_ai.rst | 10 ----
.../cloud/hooks/vertex_ai/generative_model.py | 10 ++++
.../cloud/operators/vertex_ai/generative_model.py | 5 ++
.../cloud/hooks/vertex_ai/test_generative_model.py | 30 ++++------
.../operators/vertex_ai/test_generative_model.py | 70 ++++++++++++++--------
.../example_vertex_ai_generative_model.py | 12 ----
tests/always/test_project_structure.py | 13 ++--
7 files changed, 78 insertions(+), 72 deletions(-)
diff --git a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
index f8f87040f9f..173b23dfa30 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
@@ -573,16 +573,6 @@ To get a pipeline job list you can use
Interacting with Generative AI
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-To generate a prediction via language model you can use
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.TextGenerationModelPredictOperator`.
-The operator returns the model's response in :ref:`XCom <concepts:xcom>` under ``model_response`` key.
-
-.. exampleinclude:: /../../providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_text_generation_model_predict_operator]
- :end-before: [END how_to_cloud_vertex_ai_text_generation_model_predict_operator]
-
To generate text embeddings you can use
:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.TextEmbeddingModelGetEmbeddingsOperator`.
The operator returns the model's response in :ref:`XCom <concepts:xcom>` under ``model_response`` key.
diff --git a/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py b/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
index 931cc192737..7e506641484 100644
--- a/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
+++ b/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
@@ -43,6 +43,11 @@ if TYPE_CHECKING:
class GenerativeModelHook(GoogleBaseHook):
"""Hook for Google Cloud Vertex AI Generative Model APIs."""
+ @deprecated(
+ planned_removal_date="April 09, 2025",
+ use_instead="GenerativeModelHook.get_generative_model",
+ category=AirflowProviderDeprecationWarning,
+ )
def get_text_generation_model(self, pretrained_model: str):
"""Return a Model Garden Model object based on Text Generation."""
model = TextGenerationModel.from_pretrained(pretrained_model)
@@ -275,6 +280,11 @@ class GenerativeModelHook(GoogleBaseHook):
return response.text
+ @deprecated(
+ planned_removal_date="April 09, 2025",
+ use_instead="GenerativeModelHook.generative_model_generate_content",
+ category=AirflowProviderDeprecationWarning,
+ )
@GoogleBaseHook.fallback_to_default_project_id
def text_generation_model_predict(
self,
diff --git a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
index 78eba333896..42e4fdc588e 100644
--- a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
+++ b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
@@ -353,6 +353,11 @@ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
return response
+@deprecated(
+ planned_removal_date="April 09, 2025",
+ use_instead="GenerativeModelGenerateContentOperator",
+ category=AirflowProviderDeprecationWarning,
+)
class TextGenerationModelPredictOperator(GoogleCloudBaseOperator):
"""
Uses the Vertex AI PaLM API to generate natural language text.
diff --git a/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py b/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py
index 35d3fc9256e..21741a617ea 100644
--- a/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py
+++ b/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py
@@ -205,24 +205,18 @@ class TestGenerativeModelWithDefaultProjectIdHook:
@mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_text_generation_model"))
def test_text_generation_model_predict(self, mock_model) -> None:
- self.hook.text_generation_model_predict(
- project_id=GCP_PROJECT,
- location=GCP_LOCATION,
- prompt=TEST_PROMPT,
- pretrained_model=TEST_LANGUAGE_PRETRAINED_MODEL,
- temperature=TEST_TEMPERATURE,
- max_output_tokens=TEST_MAX_OUTPUT_TOKENS,
- top_p=TEST_TOP_P,
- top_k=TEST_TOP_K,
- )
- mock_model.assert_called_once_with(TEST_LANGUAGE_PRETRAINED_MODEL)
- mock_model.return_value.predict.assert_called_once_with(
- prompt=TEST_PROMPT,
- temperature=TEST_TEMPERATURE,
- max_output_tokens=TEST_MAX_OUTPUT_TOKENS,
- top_p=TEST_TOP_P,
- top_k=TEST_TOP_K,
- )
+ with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
+ self.hook.text_generation_model_predict(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ prompt=TEST_PROMPT,
+ pretrained_model=TEST_LANGUAGE_PRETRAINED_MODEL,
+ temperature=TEST_TEMPERATURE,
+ max_output_tokens=TEST_MAX_OUTPUT_TOKENS,
+ top_p=TEST_TOP_P,
+ top_k=TEST_TOP_K,
+ )
+ assert_warning("generative_model_generate_content", warnings)
@mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_text_embedding_model"))
def test_text_embedding_model_get_embeddings(self, mock_model) -> None:
diff --git a/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py b/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py
index 5bdb04cb3ed..709e5d1f784 100644
--- a/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py
+++ b/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py
@@ -278,28 +278,46 @@ class TestVertexAIPromptMultimodalModelWithMediaOperator:
class TestVertexAITextGenerationModelPredictOperator:
+ prompt = "In 10 words or less, what is Apache Airflow?"
+ pretrained_model = "text-bison"
+ temperature = 0.0
+ max_output_tokens = 256
+ top_p = 0.8
+ top_k = 40
+
+ def test_deprecation_warning(self):
+ with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
+ TextGenerationModelPredictOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ prompt=self.prompt,
+ pretrained_model=self.pretrained_model,
+ temperature=self.temperature,
+ max_output_tokens=self.max_output_tokens,
+ top_p=self.top_p,
+ top_k=self.top_k,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ assert_warning("GenerativeModelGenerateContentOperator", warnings)
+
@mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
def test_execute(self, mock_hook):
- prompt = "In 10 words or less, what is Apache Airflow?"
- pretrained_model = "text-bison"
- temperature = 0.0
- max_output_tokens = 256
- top_p = 0.8
- top_k = 40
-
- op = TextGenerationModelPredictOperator(
- task_id=TASK_ID,
- project_id=GCP_PROJECT,
- location=GCP_LOCATION,
- prompt=prompt,
- pretrained_model=pretrained_model,
- temperature=temperature,
- max_output_tokens=max_output_tokens,
- top_p=top_p,
- top_k=top_k,
- gcp_conn_id=GCP_CONN_ID,
- impersonation_chain=IMPERSONATION_CHAIN,
- )
+ with pytest.warns(AirflowProviderDeprecationWarning):
+ op = TextGenerationModelPredictOperator(
+ task_id=TASK_ID,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ prompt=self.prompt,
+ pretrained_model=self.pretrained_model,
+ temperature=self.temperature,
+ max_output_tokens=self.max_output_tokens,
+ top_p=self.top_p,
+ top_k=self.top_k,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
@@ -308,12 +326,12 @@ class TestVertexAITextGenerationModelPredictOperator:
mock_hook.return_value.text_generation_model_predict.assert_called_once_with(
project_id=GCP_PROJECT,
location=GCP_LOCATION,
- prompt=prompt,
- pretrained_model=pretrained_model,
- temperature=temperature,
- max_output_tokens=max_output_tokens,
- top_p=top_p,
- top_k=top_k,
+ prompt=self.prompt,
+ pretrained_model=self.pretrained_model,
+ temperature=self.temperature,
+ max_output_tokens=self.max_output_tokens,
+ top_p=self.top_p,
+ top_k=self.top_k,
)
diff --git a/providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py b/providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
index 4384626999d..bafb361bc47 100644
--- a/providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
+++ b/providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
@@ -36,7 +36,6 @@ from airflow.providers.google.cloud.operators.vertex_ai.generative_model import
GenerativeModelGenerateContentOperator,
RunEvaluationOperator,
TextEmbeddingModelGetEmbeddingsOperator,
- TextGenerationModelPredictOperator,
)
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
@@ -44,7 +43,6 @@ DAG_ID = "vertex_ai_generative_model_dag"
REGION = "us-central1"
PROMPT = "In 10 words or less, why is Apache Airflow amazing?"
CONTENTS = [PROMPT]
-LANGUAGE_MODEL = "text-bison"
TEXT_EMBEDDING_MODEL = "textembedding-gecko"
MULTIMODAL_MODEL = "gemini-pro"
MULTIMODAL_VISION_MODEL = "gemini-pro-vision"
@@ -117,16 +115,6 @@ with DAG(
catchup=False,
tags=["example", "vertex_ai", "generative_model"],
) as dag:
- # [START how_to_cloud_vertex_ai_text_generation_model_predict_operator]
- predict_task = TextGenerationModelPredictOperator(
- task_id="predict_task",
- project_id=PROJECT_ID,
- location=REGION,
- prompt=PROMPT,
- pretrained_model=LANGUAGE_MODEL,
- )
- # [END how_to_cloud_vertex_ai_text_generation_model_predict_operator]
-
# [START how_to_cloud_vertex_ai_text_embedding_model_get_embeddings_operator]
generate_embeddings_task = TextEmbeddingModelGetEmbeddingsOperator(
task_id="generate_embeddings_task",
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index fca84fdb4de..ad24f34e0c3 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -353,6 +353,7 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
"airflow.providers.google.cloud.operators.automl.AutoMLTablesListTableSpecsOperator",
"airflow.providers.google.cloud.operators.automl.AutoMLTablesUpdateDatasetOperator",
"airflow.providers.google.cloud.operators.automl.AutoMLDeployModelOperator",
+ "airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator",
"airflow.providers.google.cloud.operators.datapipeline.CreateDataPipelineOperator",
"airflow.providers.google.cloud.operators.datapipeline.RunDataPipelineOperator",
"airflow.providers.google.cloud.operators.dataproc.DataprocScaleClusterOperator",
@@ -367,6 +368,12 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
"airflow.providers.google.cloud.operators.mlengine.MLEngineSetDefaultVersionOperator",
"airflow.providers.google.cloud.operators.mlengine.MLEngineStartBatchPredictionJobOperator",
"airflow.providers.google.cloud.operators.mlengine.MLEngineStartTrainingJobOperator",
+ "airflow.providers.google.cloud.operators.mlengine.MLEngineTrainingCancelJobOperator",
+ "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptLanguageModelOperator",
+ "airflow.providers.google.cloud.operators.vertex_ai.generative_model.GenerateTextEmbeddingsOperator",
+ "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelOperator",
+ "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelWithMediaOperator",
+ "airflow.providers.google.cloud.operators.vertex_ai.generative_model.TextGenerationModelPredictOperator",
"airflow.providers.google.marketing_platform.operators.GoogleDisplayVideo360CreateQueryOperator",
"airflow.providers.google.marketing_platform.operators.GoogleDisplayVideo360RunQueryOperator",
"airflow.providers.google.marketing_platform.operators.GoogleDisplayVideo360DownloadReportV2Operator",
@@ -385,7 +392,6 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
}
MISSING_EXAMPLES_FOR_CLASSES = {
- "airflow.providers.google.cloud.operators.mlengine.MLEngineTrainingCancelJobOperator",
"airflow.providers.google.cloud.operators.dlp.CloudDLPRedactImageOperator",
"airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraToGCSOperator",
"airflow.providers.google.cloud.transfers.adls_to_gcs.ADLSToGCSOperator",
@@ -394,11 +400,6 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
"airflow.providers.google.cloud.operators.vertex_ai.auto_ml.AutoMLTrainingJobBaseOperator",
"airflow.providers.google.cloud.operators.vertex_ai.endpoint_service.UpdateEndpointOperator",
"airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.GetBatchPredictionJobOperator",
- "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptLanguageModelOperator",
- "airflow.providers.google.cloud.operators.vertex_ai.generative_model.GenerateTextEmbeddingsOperator",
- "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelOperator",
- "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelWithMediaOperator",
- "airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator",
}
ASSETS_NOT_REQUIRED = {