jrmccluskey commented on code in PR #34379: URL: https://github.com/apache/beam/pull/34379#discussion_r2025303498
########## sdks/python/apache_beam/ml/inference/base.py: ########## @@ -339,6 +346,139 @@ def should_garbage_collect_on_timeout(self) -> bool: return self.share_model_across_processes() +class RemoteModelHandler(ABC, ModelHandler[ExampleT, PredictionT, ModelT]): + """Has the ability to call a model at a remote endpoint.""" + def __init__( + self, + namespace: str = '', + num_retries: int = 5, + throttle_delay_secs: int = 5, + retry_filter: Callable[[Exception], bool] = lambda x: True, + *, + window_ms: int = 1 * _MILLISECOND_TO_SECOND, + bucket_ms: int = 1 * _MILLISECOND_TO_SECOND, + overload_ratio: float = 2): + """Initializes metrics tracking + an AdaptiveThrottler class for enabling + client-side throttling for remote calls to an inference service. + See https://s.apache.org/beam-client-side-throttling for more details + on the configuration of the throttling and retry + mechanics. + + Args: + namespace: the metrics and logging namespace + num_retries: the maximum number of times to retry a request on retriable + errors before failing + throttle_delay_secs: the amount of time to throttle when the client-side + elects to throttle + retry_filter: a function accepting an exception as an argument and + returning a boolean. On a true return, the run_inference call will + be retried. Defaults to always retrying. + window_ms: length of history to consider, in ms, to set throttling. + bucket_ms: granularity of time buckets that we store data in, in ms. + overload_ratio: the target ratio between requests sent and successful + requests. This is "K" in the formula in + https://landing.google.com/sre/book/chapters/handling-overload.html. + """ + # Configure AdaptiveThrottler and throttling metrics for client-side + # throttling behavior. 
+ self.throttled_secs = Metrics.counter( + namespace, "cumulativeThrottlingSeconds") + self.throttler = AdaptiveThrottler( + window_ms=window_ms, bucket_ms=bucket_ms, overload_ratio=overload_ratio) + self.logger = logging.getLogger(namespace) Review Comment: It will work; the logging and the metric will just not be specific to the model handler (for logging that isn't a big deal, but if you hypothetically had multiple distinct remote model handler classes, they would share the same cumulativeThrottlingSeconds counter). -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscribe@beam.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org