This is an automated email from the ASF dual-hosted git repository. damccorm pushed a commit to branch users/damccorm/embeddingThrottling in repository https://gitbox.apache.org/repos/asf/beam.git
commit 8540966513d9379e09c468d19396f2f2c1ea5e6b Author: Danny Mccormick <[email protected]> AuthorDate: Fri Dec 6 11:03:05 2024 -0500 Add throttling metrics and retries to vertex embeddings --- .../ml/transforms/embeddings/vertex_ai.py | 95 +++++++++++++++++++++- 1 file changed, 92 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py index 6fe8320e758..a453f3ce305 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py @@ -21,18 +21,25 @@ from collections.abc import Iterable from collections.abc import Sequence +import logging +import time from typing import Any from typing import Optional +from google.api_core.exceptions import ServerError +from google.api_core.exceptions import TooManyRequests from google.auth.credentials import Credentials import apache_beam as beam import vertexai +from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler +from apache_beam.metrics.metric import Metrics from apache_beam.ml.inference.base import ModelHandler from apache_beam.ml.inference.base import RunInference from apache_beam.ml.transforms.base import EmbeddingsManager from apache_beam.ml.transforms.base import _ImageEmbeddingHandler from apache_beam.ml.transforms.base import _TextEmbeddingHandler +from apache_beam.utils import retry from vertexai.language_models import TextEmbeddingInput from vertexai.language_models import TextEmbeddingModel from vertexai.vision_models import Image @@ -51,6 +58,26 @@ TASK_TYPE_INPUTS = [ "CLUSTERING" ] _BATCH_SIZE = 5 # Vertex AI limits requests to 5 at a time. +_MSEC_TO_SEC = 1000 + +LOGGER = logging.getLogger("VertexAIEmbeddings") + + +def _retry_on_appropriate_gcp_error(exception): + """ + Retry filter that returns True if a returned HTTP error code is 5xx or 429. + This is used to retry remote requests that fail, most notably 429 + (TooManyRequests.) + + Args: + exception: the returned exception encountered during the request/response + loop. + + Returns: + boolean indication whether or not the exception is a Server Error (5xx) or + a TooManyRequests (429) error. + """ + return isinstance(exception, (TooManyRequests, ServerError)) class _VertexAITextEmbeddingHandler(ModelHandler): @@ -74,6 +101,33 @@ class _VertexAITextEmbeddingHandler(ModelHandler): self.task_type = task_type self.title = title + @retry.with_exponential_backoff( + num_retries=5, retry_filter=_retry_on_appropriate_gcp_error) + def get_request( + self, + text_batch: Sequence[TextEmbeddingInput], + model: MultiModalEmbeddingModel, + throttle_delay_secs: int): + + while self.throttler.throttle_request(time.time() * _MSEC_TO_SEC): + LOGGER.info( + "Delaying request for %d seconds due to previous failures", + throttle_delay_secs) + time.sleep(throttle_delay_secs) + self.throttled_secs.inc(throttle_delay_secs) + + try: + req_time = time.time() + prediction = model.get_embeddings(text_batch) + self.throttler.successful_request(req_time * _MSEC_TO_SEC) + return prediction + except TooManyRequests as e: + LOGGER.warning("request was limited by the service with code %i", e.code) + raise + except Exception as e: + LOGGER.error("unexpected exception raised as part of request, got %s", e) + raise + def run_inference( self, batch: Sequence[str], @@ -89,7 +143,7 @@ class _VertexAITextEmbeddingHandler(ModelHandler): text=text, title=self.title, task_type=self.task_type) for text in text_batch ] - embeddings_batch = model.get_embeddings(text_batch) + embeddings_batch = self.get_request(text_batch=text_batch, model=model, throttle_delay_secs=5) embeddings.extend([el.values for el in embeddings_batch]) return embeddings @@ -173,6 +227,42 @@ class _VertexAIImageEmbeddingHandler(ModelHandler): self.model_name = model_name self.dimension = dimension + # Configure AdaptiveThrottler and throttling metrics for client-side + # throttling behavior. + # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing + # for more details. + self.throttled_secs = Metrics.counter( + VertexAIImageEmbeddings, "cumulativeThrottlingSeconds") + self.throttler = AdaptiveThrottler( + window_ms=1, bucket_ms=1, overload_ratio=2) + + @retry.with_exponential_backoff( + num_retries=5, retry_filter=_retry_on_appropriate_gcp_error) + def get_request( + self, + img: Image, + model: MultiModalEmbeddingModel, + throttle_delay_secs: int): + + while self.throttler.throttle_request(time.time() * _MSEC_TO_SEC): + LOGGER.info( + "Delaying request for %d seconds due to previous failures", + throttle_delay_secs) + time.sleep(throttle_delay_secs) + self.throttled_secs.inc(throttle_delay_secs) + + try: + req_time = time.time() + prediction = model.get_embeddings(image=img, dimension=self.dimension) + self.throttler.successful_request(req_time * _MSEC_TO_SEC) + return prediction + except TooManyRequests as e: + LOGGER.warning("request was limited by the service with code %i", e.code) + raise + except Exception as e: + LOGGER.error("unexpected exception raised as part of request, got %s", e) + raise + def run_inference( self, batch: Sequence[Image], @@ -182,8 +272,7 @@ class _VertexAIImageEmbeddingHandler(ModelHandler): embeddings = [] # Maximum request size for muli-model embedding models is 1. for img in batch: - embedding_response = model.get_embeddings( - image=img, dimension=self.dimension) + embedding_response = self.get_request(img, model, throttle_delay_secs=5) embeddings.append(embedding_response.image_embedding) return embeddings
