jrmccluskey commented on code in PR #34379:
URL: https://github.com/apache/beam/pull/34379#discussion_r2025303498


##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -339,6 +346,139 @@ def should_garbage_collect_on_timeout(self) -> bool:
     return self.share_model_across_processes()
 
 
+class RemoteModelHandler(ABC, ModelHandler[ExampleT, PredictionT, ModelT]):
+  """Has the ability to call a model at a remote endpoint."""
+  def __init__(
+      self,
+      namespace: str = '',
+      num_retries: int = 5,
+      throttle_delay_secs: int = 5,
+      retry_filter: Callable[[Exception], bool] = lambda x: True,
+      *,
+      window_ms: int = 1 * _MILLISECOND_TO_SECOND,
+      bucket_ms: int = 1 * _MILLISECOND_TO_SECOND,
+      overload_ratio: float = 2):
+    """Initializes metrics tracking + an AdaptiveThrottler class for enabling
+    client-side throttling for remote calls to an inference service.
+    See https://s.apache.org/beam-client-side-throttling for more details
+    on the configuration of the throttling and retry
+    mechanics.
+
+    Args:
+      namespace: the metrics and logging namespace.
+      num_retries: the maximum number of times to retry a request on retriable
+        errors before failing
+      throttle_delay_secs: the amount of time to throttle when the client-side
+        elects to throttle
+      retry_filter: a function accepting an exception as an argument and
+        returning a boolean. On a true return, the run_inference call will
+        be retried. Defaults to always retrying.
+      window_ms: length of history to consider, in ms, to set throttling.
+      bucket_ms: granularity of time buckets that we store data in, in ms.
+      overload_ratio: the target ratio between requests sent and successful
+        requests. This is "K" in the formula in
+        https://landing.google.com/sre/book/chapters/handling-overload.html.
+    """
+    # Configure AdaptiveThrottler and throttling metrics for client-side
+    # throttling behavior.
+    self.throttled_secs = Metrics.counter(
+        namespace, "cumulativeThrottlingSeconds")
+    self.throttler = AdaptiveThrottler(
+        window_ms=window_ms, bucket_ms=bucket_ms, overload_ratio=overload_ratio)
+    self.logger = logging.getLogger(namespace)

Review Comment:
   it will work, the logging and the metric will just not be specific to the 
model handler (for logging that isn't a big deal, but if you hypothetically had 
multiple distinct remote model handler classes they would share the same 
cumulativeThrottlingSeconds counter)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@beam.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to