damccorm commented on code in PR #37428:
URL: https://github.com/apache/beam/pull/37428#discussion_r2742852549


##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -167,10 +167,48 @@ class KeyModelPathMapping(Generic[KeyT]):
 
 class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
   """Has the ability to load and apply an ML model."""
-  def __init__(self):
-    """Environment variables are set using a dict named 'env_vars' before
-    loading the model. Child classes can accept this dict as a kwarg."""
-    self._env_vars = {}
+  def __init__(
+      self,
+      *,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None,
+      large_model: bool = False,
+      model_copies: Optional[int] = None,
+      **kwargs):
+    """Initializes the ModelHandler.
+
+    Args:
+      min_batch_size: the minimum batch size to use when batching inputs.
+      max_batch_size: the maximum batch size to use when batching inputs.
+      max_batch_duration_secs: the maximum amount of time to buffer a batch
+        before emitting; used in streaming contexts.
+      max_batch_weight: the maximum weight of a batch. Requires element_size_fn.
+      element_size_fn: a function that returns the size (weight) of an element.
+      large_model: set to true if your model is large enough to run into
+        memory pressure if you load multiple copies.
+      model_copies: The exact number of models that you would like loaded
+        onto your machine.
+      kwargs: 'env_vars' can be used to set environment variables
+        before loading the model.
+    """
+    self._env_vars = kwargs.get('env_vars', {})
+    self._batching_kwargs: dict[str, Any] = {}
+    if min_batch_size is not None:
+      self._batching_kwargs['min_batch_size'] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs['max_batch_size'] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs['max_batch_duration_secs'] = max_batch_duration_secs
+    if max_batch_weight is not None:
+      self._batching_kwargs['max_batch_weight'] = max_batch_weight
+    if element_size_fn is not None:
+      self._batching_kwargs['element_size_fn'] = element_size_fn

Review Comment:
   I don't think this is cleaner



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to