This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a66edcbb2c `CreateBatchPredictionJobOperator` Add batch_size param for Vertex AI BatchPredictionJob objects (#31118)
a66edcbb2c is described below

commit a66edcbb2c9c04031326a3965c76a419043338ff
Author: Viacheslav <[email protected]>
AuthorDate: Sat May 13 10:45:19 2023 -0700

    `CreateBatchPredictionJobOperator` Add batch_size param for Vertex AI BatchPredictionJob objects (#31118)
    
    * Add batch_size and create_request_timeout params for BatchPredictionJob objects
    
    Co-authored-by: Jarek Potiuk <[email protected]>
    
    ---------
    
    Co-authored-by: Jarek Potiuk <[email protected]>
---
 .../google/cloud/hooks/vertex_ai/batch_prediction_job.py   | 12 ++++++++++++
 .../cloud/operators/vertex_ai/batch_prediction_job.py      | 14 ++++++++++++++
 airflow/providers/google/provider.yaml                     |  2 +-
 docs/apache-airflow-providers-google/index.rst             |  2 +-
 generated/provider_dependencies.json                       |  2 +-
 tests/providers/google/cloud/operators/test_vertex_ai.py   |  7 +++++++
 6 files changed, 36 insertions(+), 3 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py
index 93cd6902bb..a3d0d20513 100644
--- a/airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py
+++ b/airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py
@@ -114,6 +114,8 @@ class BatchPredictionJobHook(GoogleBaseHook):
         labels: dict[str, str] | None = None,
         encryption_spec_key_name: str | None = None,
         sync: bool = True,
+        create_request_timeout: float | None = None,
+        batch_size: int | None = None,
     ) -> BatchPredictionJob:
         """
         Create a batch prediction job.
@@ -207,6 +209,14 @@ class BatchPredictionJobHook(GoogleBaseHook):
         :param sync: Whether to execute this method synchronously. If False, 
this method will be executed in
             concurrent Future and any downstream object will be immediately 
returned and synced when the
             Future has completed.
+        :param create_request_timeout: Optional. The timeout for the create 
request in seconds.
+        :param batch_size: Optional. The number of the records (e.g. instances)
+            of the operation given in each batch
+            to a machine replica. Machine type, and size of a single record 
should be considered
+            when setting this parameter, higher value speeds up the batch 
operation's execution,
+            but too high value will result in a whole batch not fitting in a 
machine's memory,
+            and the whole operation will fail.
+            The default value is same as in the aiplatform's 
BatchPredictionJob.
         """
         self._batch_prediction_job = BatchPredictionJob.create(
             job_display_name=job_display_name,
@@ -232,6 +242,8 @@ class BatchPredictionJobHook(GoogleBaseHook):
             credentials=self.get_credentials(),
             encryption_spec_key_name=encryption_spec_key_name,
             sync=sync,
+            create_request_timeout=create_request_timeout,
+            batch_size=batch_size,
         )
         return self._batch_prediction_job
 
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py b/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py
index 5c12322775..dc4775fe24 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py
@@ -139,6 +139,14 @@ class 
CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
     :param sync: Whether to execute this method synchronously. If False, this 
method will be executed in
         concurrent Future and any downstream object will be immediately 
returned and synced when the
         Future has completed.
+    :param create_request_timeout: Optional. The timeout for the create 
request in seconds.
+    :param batch_size: Optional. The number of the records (e.g. instances)
+        of the operation given in each batch
+        to a machine replica. Machine type, and size of a single record should 
be considered
+        when setting this parameter, higher value speeds up the batch 
operation's execution,
+        but too high value will result in a whole batch not fitting in a 
machine's memory,
+        and the whole operation will fail.
+        The default value is same as in the aiplatform's BatchPredictionJob.
     :param retry: Designation of what errors, if any, should be retried.
     :param timeout: The timeout for this request.
     :param metadata: Strings which should be sent along with the request as 
metadata.
@@ -181,6 +189,8 @@ class 
CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
         labels: dict[str, str] | None = None,
         encryption_spec_key_name: str | None = None,
         sync: bool = True,
+        create_request_timeout: float | None = None,
+        batch_size: int | None = None,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
         **kwargs,
@@ -208,6 +218,8 @@ class 
CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
         self.labels = labels
         self.encryption_spec_key_name = encryption_spec_key_name
         self.sync = sync
+        self.create_request_timeout = create_request_timeout
+        self.batch_size = batch_size
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
         self.hook: BatchPredictionJobHook | None = None
@@ -241,6 +253,8 @@ class 
CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
             labels=self.labels,
             encryption_spec_key_name=self.encryption_spec_key_name,
             sync=self.sync,
+            create_request_timeout=self.create_request_timeout,
+            batch_size=self.batch_size,
         )
 
         batch_prediction_job = result.to_dict()
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index f79c54509a..41263cd87c 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -95,7 +95,7 @@ dependencies:
   - google-api-python-client>=1.6.0,<2.0.0
   - google-auth>=1.0.0
   - google-auth-httplib2>=0.0.1
-  - google-cloud-aiplatform>=1.7.1,<2.0.0
+  - google-cloud-aiplatform>=1.13.1,<2.0.0
   - google-cloud-automl>=2.1.0
   - google-cloud-bigquery-datatransfer>=3.0.0
   - google-cloud-bigtable>=2.0.0,<3.0.0
diff --git a/docs/apache-airflow-providers-google/index.rst b/docs/apache-airflow-providers-google/index.rst
index d78fbf67dd..85dc79385c 100644
--- a/docs/apache-airflow-providers-google/index.rst
+++ b/docs/apache-airflow-providers-google/index.rst
@@ -115,7 +115,7 @@ PIP package                              Version required
 ``google-api-python-client``             ``>=1.6.0,<2.0.0``
 ``google-auth``                          ``>=1.0.0``
 ``google-auth-httplib2``                 ``>=0.0.1``
-``google-cloud-aiplatform``              ``>=1.7.1,<2.0.0``
+``google-cloud-aiplatform``              ``>=1.13.1,<2.0.0``
 ``google-cloud-automl``                  ``>=2.1.0``
 ``google-cloud-bigquery-datatransfer``   ``>=3.0.0``
 ``google-cloud-bigtable``                ``>=2.0.0,<3.0.0``
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 9777e3bfa1..a9aae0f43b 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -346,7 +346,7 @@
       "google-auth-httplib2>=0.0.1",
       "google-auth-oauthlib<1.0.0,>=0.3.0",
       "google-auth>=1.0.0",
-      "google-cloud-aiplatform>=1.7.1,<2.0.0",
+      "google-cloud-aiplatform>=1.13.1,<2.0.0",
       "google-cloud-automl>=2.1.0",
       "google-cloud-bigquery-datatransfer>=3.0.0",
       "google-cloud-bigtable>=2.0.0,<3.0.0",
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py
index 2c5b9a694d..f4b3ad154d 100644
--- a/tests/providers/google/cloud/operators/test_vertex_ai.py
+++ b/tests/providers/google/cloud/operators/test_vertex_ai.py
@@ -165,6 +165,9 @@ TEST_OUTPUT_CONFIG = {
     "export_format_id": "tf-saved-model",
 }
 
+TEST_CREATE_REQUEST_TIMEOUT = 100.5
+TEST_BATCH_SIZE = 4000
+
 
 class TestVertexAICreateCustomContainerTrainingJobOperator:
     @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
@@ -989,6 +992,8 @@ class TestVertexAICreateBatchPredictionJobOperator:
             model_name=TEST_MODEL_NAME,
             instances_format="jsonl",
             predictions_format="jsonl",
+            create_request_timeout=TEST_CREATE_REQUEST_TIMEOUT,
+            batch_size=TEST_BATCH_SIZE,
         )
         op.execute(context={"ti": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, 
impersonation_chain=IMPERSONATION_CHAIN)
@@ -1015,6 +1020,8 @@ class TestVertexAICreateBatchPredictionJobOperator:
             labels=None,
             encryption_spec_key_name=None,
             sync=True,
+            create_request_timeout=TEST_CREATE_REQUEST_TIMEOUT,
+            batch_size=TEST_BATCH_SIZE,
         )
 
 

Reply via email to