This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a66edcbb2c `CreateBatchPredictionJobOperator` Add batch_size param for Vertex AI BatchPredictionJob objects (#31118)
a66edcbb2c is described below
commit a66edcbb2c9c04031326a3965c76a419043338ff
Author: Viacheslav <[email protected]>
AuthorDate: Sat May 13 10:45:19 2023 -0700
`CreateBatchPredictionJobOperator` Add batch_size param for Vertex AI BatchPredictionJob objects (#31118)
* Add batch_size param for BatchPredictionJob objects
Co-authored-by: Jarek Potiuk <[email protected]>
---------
Co-authored-by: Jarek Potiuk <[email protected]>
---
.../google/cloud/hooks/vertex_ai/batch_prediction_job.py | 12 ++++++++++++
.../cloud/operators/vertex_ai/batch_prediction_job.py | 14 ++++++++++++++
airflow/providers/google/provider.yaml | 2 +-
docs/apache-airflow-providers-google/index.rst | 2 +-
generated/provider_dependencies.json | 2 +-
tests/providers/google/cloud/operators/test_vertex_ai.py | 7 +++++++
6 files changed, 36 insertions(+), 3 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py
index 93cd6902bb..a3d0d20513 100644
--- a/airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py
+++ b/airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py
@@ -114,6 +114,8 @@ class BatchPredictionJobHook(GoogleBaseHook):
labels: dict[str, str] | None = None,
encryption_spec_key_name: str | None = None,
sync: bool = True,
+ create_request_timeout: float | None = None,
+ batch_size: int | None = None,
) -> BatchPredictionJob:
"""
Create a batch prediction job.
@@ -207,6 +209,14 @@ class BatchPredictionJobHook(GoogleBaseHook):
:param sync: Whether to execute this method synchronously. If False, this method will be executed in
concurrent Future and any downstream object will be immediately returned and synced when the
Future has completed.
+ :param create_request_timeout: Optional. The timeout for the create request in seconds.
+ :param batch_size: Optional. The number of records (e.g. instances)
+ of the operation sent in each batch to a machine replica.
+ The machine type and the size of a single record should be considered when setting this parameter:
+ a higher value speeds up the batch operation's execution,
+ but a value that is too high can result in a whole batch not fitting in a machine's memory,
+ in which case the whole operation will fail.
+ The default value is the same as in the aiplatform's BatchPredictionJob.
"""
self._batch_prediction_job = BatchPredictionJob.create(
job_display_name=job_display_name,
@@ -232,6 +242,8 @@ class BatchPredictionJobHook(GoogleBaseHook):
credentials=self.get_credentials(),
encryption_spec_key_name=encryption_spec_key_name,
sync=sync,
+ create_request_timeout=create_request_timeout,
+ batch_size=batch_size,
)
return self._batch_prediction_job
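
For context, a minimal sketch of calling the hook directly with the two new parameters. Everything except create_request_timeout and batch_size (including gcs_source, gcs_destination_prefix, and all values) is an illustrative assumption based on the hook's existing interface, not part of this commit:

    from airflow.providers.google.cloud.hooks.vertex_ai.batch_prediction_job import BatchPredictionJobHook

    hook = BatchPredictionJobHook(gcp_conn_id="google_cloud_default")
    job = hook.create_batch_prediction_job(
        project_id="my-project",                         # placeholder project
        region="us-central1",                            # placeholder region
        job_display_name="example-batch-prediction",
        model_name="projects/my-project/locations/us-central1/models/123",
        instances_format="jsonl",
        predictions_format="jsonl",
        gcs_source="gs://my-bucket/input.jsonl",         # assumed source argument
        gcs_destination_prefix="gs://my-bucket/output",  # assumed destination argument
        create_request_timeout=100.5,                    # new: timeout in seconds for the create request
        batch_size=4000,                                 # new: records sent to each machine replica per batch
    )
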
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py b/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py
index 5c12322775..dc4775fe24 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py
@@ -139,6 +139,14 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
:param sync: Whether to execute this method synchronously. If False, this method will be executed in
concurrent Future and any downstream object will be immediately returned and synced when the
Future has completed.
+ :param create_request_timeout: Optional. The timeout for the create request in seconds.
+ :param batch_size: Optional. The number of records (e.g. instances)
+ of the operation sent in each batch to a machine replica.
+ The machine type and the size of a single record should be considered when setting this parameter:
+ a higher value speeds up the batch operation's execution,
+ but a value that is too high can result in a whole batch not fitting in a machine's memory,
+ in which case the whole operation will fail.
+ The default value is the same as in the aiplatform's BatchPredictionJob.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
@@ -181,6 +189,8 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
labels: dict[str, str] | None = None,
encryption_spec_key_name: str | None = None,
sync: bool = True,
+ create_request_timeout: float | None = None,
+ batch_size: int | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
@@ -208,6 +218,8 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
self.labels = labels
self.encryption_spec_key_name = encryption_spec_key_name
self.sync = sync
+ self.create_request_timeout = create_request_timeout
+ self.batch_size = batch_size
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.hook: BatchPredictionJobHook | None = None
@@ -241,6 +253,8 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
labels=self.labels,
encryption_spec_key_name=self.encryption_spec_key_name,
sync=self.sync,
+ create_request_timeout=self.create_request_timeout,
+ batch_size=self.batch_size,
)
batch_prediction_job = result.to_dict()
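
On the operator side, a minimal task-level sketch using the new parameters. task_id, project_id, region, and all values are placeholders assumed from the operator's existing interface rather than shown in this diff:

    from airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job import (
        CreateBatchPredictionJobOperator,
    )

    create_batch_prediction_job = CreateBatchPredictionJobOperator(
        task_id="create_batch_prediction_job",
        project_id="my-project",          # placeholder project
        region="us-central1",             # placeholder region
        job_display_name="example-batch-prediction",
        model_name="projects/my-project/locations/us-central1/models/123",
        instances_format="jsonl",
        predictions_format="jsonl",
        create_request_timeout=100.5,     # new in this commit
        batch_size=4000,                  # new in this commit
    )
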
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index f79c54509a..41263cd87c 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -95,7 +95,7 @@ dependencies:
- google-api-python-client>=1.6.0,<2.0.0
- google-auth>=1.0.0
- google-auth-httplib2>=0.0.1
- - google-cloud-aiplatform>=1.7.1,<2.0.0
+ - google-cloud-aiplatform>=1.13.1,<2.0.0
- google-cloud-automl>=2.1.0
- google-cloud-bigquery-datatransfer>=3.0.0
- google-cloud-bigtable>=2.0.0,<3.0.0
diff --git a/docs/apache-airflow-providers-google/index.rst b/docs/apache-airflow-providers-google/index.rst
index d78fbf67dd..85dc79385c 100644
--- a/docs/apache-airflow-providers-google/index.rst
+++ b/docs/apache-airflow-providers-google/index.rst
@@ -115,7 +115,7 @@ PIP package Version required
``google-api-python-client`` ``>=1.6.0,<2.0.0``
``google-auth`` ``>=1.0.0``
``google-auth-httplib2`` ``>=0.0.1``
-``google-cloud-aiplatform`` ``>=1.7.1,<2.0.0``
+``google-cloud-aiplatform`` ``>=1.13.1,<2.0.0``
``google-cloud-automl`` ``>=2.1.0``
``google-cloud-bigquery-datatransfer`` ``>=3.0.0``
``google-cloud-bigtable`` ``>=2.0.0,<3.0.0``
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 9777e3bfa1..a9aae0f43b 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -346,7 +346,7 @@
"google-auth-httplib2>=0.0.1",
"google-auth-oauthlib<1.0.0,>=0.3.0",
"google-auth>=1.0.0",
- "google-cloud-aiplatform>=1.7.1,<2.0.0",
+ "google-cloud-aiplatform>=1.13.1,<2.0.0",
"google-cloud-automl>=2.1.0",
"google-cloud-bigquery-datatransfer>=3.0.0",
"google-cloud-bigtable>=2.0.0,<3.0.0",
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py
index 2c5b9a694d..f4b3ad154d 100644
--- a/tests/providers/google/cloud/operators/test_vertex_ai.py
+++ b/tests/providers/google/cloud/operators/test_vertex_ai.py
@@ -165,6 +165,9 @@ TEST_OUTPUT_CONFIG = {
"export_format_id": "tf-saved-model",
}
+TEST_CREATE_REQUEST_TIMEOUT = 100.5
+TEST_BATCH_SIZE = 4000
+
class TestVertexAICreateCustomContainerTrainingJobOperator:
@mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
@@ -989,6 +992,8 @@ class TestVertexAICreateBatchPredictionJobOperator:
model_name=TEST_MODEL_NAME,
instances_format="jsonl",
predictions_format="jsonl",
+ create_request_timeout=TEST_CREATE_REQUEST_TIMEOUT,
+ batch_size=TEST_BATCH_SIZE,
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
@@ -1015,6 +1020,8 @@ class TestVertexAICreateBatchPredictionJobOperator:
labels=None,
encryption_spec_key_name=None,
sync=True,
+ create_request_timeout=TEST_CREATE_REQUEST_TIMEOUT,
+ batch_size=TEST_BATCH_SIZE,
)