This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f332c770f5b Deprecate VertexAI PaLM text generative model (#44719)
f332c770f5b is described below

commit f332c770f5b2a18733c0f0352411ca9f73b90e1f
Author: Maksim <[email protected]>
AuthorDate: Fri Dec 6 05:15:01 2024 -0800

    Deprecate VertexAI PaLM text generative model (#44719)
---
 .../operators/cloud/vertex_ai.rst                  | 10 ----
 .../cloud/hooks/vertex_ai/generative_model.py      | 10 ++++
 .../cloud/operators/vertex_ai/generative_model.py  |  5 ++
 .../cloud/hooks/vertex_ai/test_generative_model.py | 30 ++++------
 .../operators/vertex_ai/test_generative_model.py   | 70 ++++++++++++++--------
 .../example_vertex_ai_generative_model.py          | 12 ----
 tests/always/test_project_structure.py             | 13 ++--
 7 files changed, 78 insertions(+), 72 deletions(-)

diff --git a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
index f8f87040f9f..173b23dfa30 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
@@ -573,16 +573,6 @@ To get a pipeline job list you can use
 Interacting with Generative AI
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-To generate a prediction via language model you can use
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.TextGenerationModelPredictOperator`.
-The operator returns the model's response in :ref:`XCom <concepts:xcom>` under ``model_response`` key.
-
-.. exampleinclude:: /../../providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
-    :language: python
-    :dedent: 4
-    :start-after: [START how_to_cloud_vertex_ai_text_generation_model_predict_operator]
-    :end-before: [END how_to_cloud_vertex_ai_text_generation_model_predict_operator]
-
 To generate text embeddings you can use
 :class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.TextEmbeddingModelGetEmbeddingsOperator`.
 The operator returns the model's response in :ref:`XCom <concepts:xcom>` under ``model_response`` key.
diff --git a/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py b/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
index 931cc192737..7e506641484 100644
--- a/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
+++ b/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
@@ -43,6 +43,11 @@ if TYPE_CHECKING:
 class GenerativeModelHook(GoogleBaseHook):
     """Hook for Google Cloud Vertex AI Generative Model APIs."""
 
+    @deprecated(
+        planned_removal_date="April 09, 2025",
+        use_instead="GenerativeModelHook.get_generative_model",
+        category=AirflowProviderDeprecationWarning,
+    )
     def get_text_generation_model(self, pretrained_model: str):
         """Return a Model Garden Model object based on Text Generation."""
         model = TextGenerationModel.from_pretrained(pretrained_model)
@@ -275,6 +280,11 @@ class GenerativeModelHook(GoogleBaseHook):
 
         return response.text
 
+    @deprecated(
+        planned_removal_date="April 09, 2025",
+        use_instead="GenerativeModelHook.generative_model_generate_content",
+        category=AirflowProviderDeprecationWarning,
+    )
     @GoogleBaseHook.fallback_to_default_project_id
     def text_generation_model_predict(
         self,
diff --git a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
index 78eba333896..42e4fdc588e 100644
--- a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
+++ b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
@@ -353,6 +353,11 @@ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
         return response
 
 
+@deprecated(
+    planned_removal_date="April 09, 2025",
+    use_instead="GenerativeModelGenerateContentOperator",
+    category=AirflowProviderDeprecationWarning,
+)
 class TextGenerationModelPredictOperator(GoogleCloudBaseOperator):
     """
     Uses the Vertex AI PaLM API to generate natural language text.
diff --git a/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py b/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py
index 35d3fc9256e..21741a617ea 100644
--- a/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py
+++ b/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py
@@ -205,24 +205,18 @@ class TestGenerativeModelWithDefaultProjectIdHook:
 
     @mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_text_generation_model"))
     def test_text_generation_model_predict(self, mock_model) -> None:
-        self.hook.text_generation_model_predict(
-            project_id=GCP_PROJECT,
-            location=GCP_LOCATION,
-            prompt=TEST_PROMPT,
-            pretrained_model=TEST_LANGUAGE_PRETRAINED_MODEL,
-            temperature=TEST_TEMPERATURE,
-            max_output_tokens=TEST_MAX_OUTPUT_TOKENS,
-            top_p=TEST_TOP_P,
-            top_k=TEST_TOP_K,
-        )
-        mock_model.assert_called_once_with(TEST_LANGUAGE_PRETRAINED_MODEL)
-        mock_model.return_value.predict.assert_called_once_with(
-            prompt=TEST_PROMPT,
-            temperature=TEST_TEMPERATURE,
-            max_output_tokens=TEST_MAX_OUTPUT_TOKENS,
-            top_p=TEST_TOP_P,
-            top_k=TEST_TOP_K,
-        )
+        with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
+            self.hook.text_generation_model_predict(
+                project_id=GCP_PROJECT,
+                location=GCP_LOCATION,
+                prompt=TEST_PROMPT,
+                pretrained_model=TEST_LANGUAGE_PRETRAINED_MODEL,
+                temperature=TEST_TEMPERATURE,
+                max_output_tokens=TEST_MAX_OUTPUT_TOKENS,
+                top_p=TEST_TOP_P,
+                top_k=TEST_TOP_K,
+            )
+            assert_warning("generative_model_generate_content", warnings)
 
     @mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_text_embedding_model"))
     def test_text_embedding_model_get_embeddings(self, mock_model) -> None:
diff --git a/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py b/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py
index 5bdb04cb3ed..709e5d1f784 100644
--- a/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py
+++ b/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py
@@ -278,28 +278,46 @@ class TestVertexAIPromptMultimodalModelWithMediaOperator:
 
 
 class TestVertexAITextGenerationModelPredictOperator:
+    prompt = "In 10 words or less, what is Apache Airflow?"
+    pretrained_model = "text-bison"
+    temperature = 0.0
+    max_output_tokens = 256
+    top_p = 0.8
+    top_k = 40
+
+    def test_deprecation_warning(self):
+        with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
+            TextGenerationModelPredictOperator(
+                task_id=TASK_ID,
+                project_id=GCP_PROJECT,
+                location=GCP_LOCATION,
+                prompt=self.prompt,
+                pretrained_model=self.pretrained_model,
+                temperature=self.temperature,
+                max_output_tokens=self.max_output_tokens,
+                top_p=self.top_p,
+                top_k=self.top_k,
+                gcp_conn_id=GCP_CONN_ID,
+                impersonation_chain=IMPERSONATION_CHAIN,
+            )
+            assert_warning("GenerativeModelGenerateContentOperator", warnings)
+
     @mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
     def test_execute(self, mock_hook):
-        prompt = "In 10 words or less, what is Apache Airflow?"
-        pretrained_model = "text-bison"
-        temperature = 0.0
-        max_output_tokens = 256
-        top_p = 0.8
-        top_k = 40
-
-        op = TextGenerationModelPredictOperator(
-            task_id=TASK_ID,
-            project_id=GCP_PROJECT,
-            location=GCP_LOCATION,
-            prompt=prompt,
-            pretrained_model=pretrained_model,
-            temperature=temperature,
-            max_output_tokens=max_output_tokens,
-            top_p=top_p,
-            top_k=top_k,
-            gcp_conn_id=GCP_CONN_ID,
-            impersonation_chain=IMPERSONATION_CHAIN,
-        )
+        with pytest.warns(AirflowProviderDeprecationWarning):
+            op = TextGenerationModelPredictOperator(
+                task_id=TASK_ID,
+                project_id=GCP_PROJECT,
+                location=GCP_LOCATION,
+                prompt=self.prompt,
+                pretrained_model=self.pretrained_model,
+                temperature=self.temperature,
+                max_output_tokens=self.max_output_tokens,
+                top_p=self.top_p,
+                top_k=self.top_k,
+                gcp_conn_id=GCP_CONN_ID,
+                impersonation_chain=IMPERSONATION_CHAIN,
+            )
         op.execute(context={"ti": mock.MagicMock()})
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
@@ -308,12 +326,12 @@ class TestVertexAITextGenerationModelPredictOperator:
         mock_hook.return_value.text_generation_model_predict.assert_called_once_with(
             project_id=GCP_PROJECT,
             location=GCP_LOCATION,
-            prompt=prompt,
-            pretrained_model=pretrained_model,
-            temperature=temperature,
-            max_output_tokens=max_output_tokens,
-            top_p=top_p,
-            top_k=top_k,
+            prompt=self.prompt,
+            pretrained_model=self.pretrained_model,
+            temperature=self.temperature,
+            max_output_tokens=self.max_output_tokens,
+            top_p=self.top_p,
+            top_k=self.top_k,
         )
 
 
diff --git a/providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py b/providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
index 4384626999d..bafb361bc47 100644
--- a/providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
+++ b/providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
@@ -36,7 +36,6 @@ from airflow.providers.google.cloud.operators.vertex_ai.generative_model import
     GenerativeModelGenerateContentOperator,
     RunEvaluationOperator,
     TextEmbeddingModelGetEmbeddingsOperator,
-    TextGenerationModelPredictOperator,
 )
 
 PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
@@ -44,7 +43,6 @@ DAG_ID = "vertex_ai_generative_model_dag"
 REGION = "us-central1"
 PROMPT = "In 10 words or less, why is Apache Airflow amazing?"
 CONTENTS = [PROMPT]
-LANGUAGE_MODEL = "text-bison"
 TEXT_EMBEDDING_MODEL = "textembedding-gecko"
 MULTIMODAL_MODEL = "gemini-pro"
 MULTIMODAL_VISION_MODEL = "gemini-pro-vision"
@@ -117,16 +115,6 @@ with DAG(
     catchup=False,
     tags=["example", "vertex_ai", "generative_model"],
 ) as dag:
-    # [START how_to_cloud_vertex_ai_text_generation_model_predict_operator]
-    predict_task = TextGenerationModelPredictOperator(
-        task_id="predict_task",
-        project_id=PROJECT_ID,
-        location=REGION,
-        prompt=PROMPT,
-        pretrained_model=LANGUAGE_MODEL,
-    )
-    # [END how_to_cloud_vertex_ai_text_generation_model_predict_operator]
-
     # [START how_to_cloud_vertex_ai_text_embedding_model_get_embeddings_operator]
     generate_embeddings_task = TextEmbeddingModelGetEmbeddingsOperator(
         task_id="generate_embeddings_task",
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index fca84fdb4de..ad24f34e0c3 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -353,6 +353,7 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
         "airflow.providers.google.cloud.operators.automl.AutoMLTablesListTableSpecsOperator",
         "airflow.providers.google.cloud.operators.automl.AutoMLTablesUpdateDatasetOperator",
         "airflow.providers.google.cloud.operators.automl.AutoMLDeployModelOperator",
+        "airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator",
         "airflow.providers.google.cloud.operators.datapipeline.CreateDataPipelineOperator",
         "airflow.providers.google.cloud.operators.datapipeline.RunDataPipelineOperator",
         "airflow.providers.google.cloud.operators.dataproc.DataprocScaleClusterOperator",
@@ -367,6 +368,12 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
         "airflow.providers.google.cloud.operators.mlengine.MLEngineSetDefaultVersionOperator",
         "airflow.providers.google.cloud.operators.mlengine.MLEngineStartBatchPredictionJobOperator",
         "airflow.providers.google.cloud.operators.mlengine.MLEngineStartTrainingJobOperator",
+        "airflow.providers.google.cloud.operators.mlengine.MLEngineTrainingCancelJobOperator",
+        "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptLanguageModelOperator",
+        "airflow.providers.google.cloud.operators.vertex_ai.generative_model.GenerateTextEmbeddingsOperator",
+        "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelOperator",
+        "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelWithMediaOperator",
+        "airflow.providers.google.cloud.operators.vertex_ai.generative_model.TextGenerationModelPredictOperator",
         "airflow.providers.google.marketing_platform.operators.GoogleDisplayVideo360CreateQueryOperator",
         "airflow.providers.google.marketing_platform.operators.GoogleDisplayVideo360RunQueryOperator",
         "airflow.providers.google.marketing_platform.operators.GoogleDisplayVideo360DownloadReportV2Operator",
@@ -385,7 +392,6 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
     }
 
     MISSING_EXAMPLES_FOR_CLASSES = {
-        "airflow.providers.google.cloud.operators.mlengine.MLEngineTrainingCancelJobOperator",
         "airflow.providers.google.cloud.operators.dlp.CloudDLPRedactImageOperator",
         "airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraToGCSOperator",
         "airflow.providers.google.cloud.transfers.adls_to_gcs.ADLSToGCSOperator",
@@ -394,11 +400,6 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
         "airflow.providers.google.cloud.operators.vertex_ai.auto_ml.AutoMLTrainingJobBaseOperator",
         "airflow.providers.google.cloud.operators.vertex_ai.endpoint_service.UpdateEndpointOperator",
         "airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.GetBatchPredictionJobOperator",
-        "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptLanguageModelOperator",
-        "airflow.providers.google.cloud.operators.vertex_ai.generative_model.GenerateTextEmbeddingsOperator",
-        "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelOperator",
-        "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelWithMediaOperator",
-        "airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator",
     }
 
     ASSETS_NOT_REQUIRED = {

Reply via email to