gemini-code-assist[bot] commented on code in PR #37428:
URL: https://github.com/apache/beam/pull/37428#discussion_r2742839163


##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -167,10 +167,48 @@ class KeyModelPathMapping(Generic[KeyT]):
 
 class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
   """Has the ability to load and apply an ML model."""
-  def __init__(self):
-    """Environment variables are set using a dict named 'env_vars' before
-    loading the model. Child classes can accept this dict as a kwarg."""
-    self._env_vars = {}
+  def __init__(
+      self,
+      *,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None,
+      large_model: bool = False,
+      model_copies: Optional[int] = None,
+      **kwargs):
+    """Initializes the ModelHandler.
+
+    Args:
+      min_batch_size: the minimum batch size to use when batching inputs.
+      max_batch_size: the maximum batch size to use when batching inputs.
+      max_batch_duration_secs: the maximum amount of time to buffer a batch
+        before emitting; used in streaming contexts.
+      max_batch_weight: the maximum weight of a batch. Requires element_size_fn.
+      element_size_fn: a function that returns the size (weight) of an element.
+      large_model: set to true if your model is large enough to run into
+        memory pressure if you load multiple copies.
+      model_copies: The exact number of models that you would like loaded
+        onto your machine.
+      kwargs: 'env_vars' can be used to set environment variables
+        before loading the model.
+    """
+    self._env_vars = kwargs.get('env_vars', {})
+    self._batching_kwargs: dict[str, Any] = {}
+    if min_batch_size is not None:
+      self._batching_kwargs['min_batch_size'] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs['max_batch_size'] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs['max_batch_duration_secs'] = max_batch_duration_secs
+    if max_batch_weight is not None:
+      self._batching_kwargs['max_batch_weight'] = max_batch_weight
+    if element_size_fn is not None:
+      self._batching_kwargs['element_size_fn'] = element_size_fn

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   This block of `if` statements can be made more concise and maintainable by 
using a dictionary comprehension to build `_batching_kwargs`.
   
   ```suggestion
       self._batching_kwargs: dict[str, Any] = {
           k: v for k, v in (
               ('min_batch_size', min_batch_size),
               ('max_batch_size', max_batch_size),
               ('max_batch_duration_secs', max_batch_duration_secs),
               ('max_batch_weight', max_batch_weight),
               ('element_size_fn', element_size_fn),
           ) if v is not None
       }
   ```



##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -2133,5 +2133,63 @@ def request(self, batch, model, inference_args=None):
       model_handler.run_inference([1], FakeModel())
 
 
+class FakeModelHandlerForSizing(base.ModelHandler[int, int, FakeModel]):
+  """A ModelHandler used to test element sizing behavior."""
+  def __init__(
+      self,
+      max_batch_size: int = 10,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn=None):
+    self._max_batch_size = max_batch_size
+    self._max_batch_weight = max_batch_weight
+    self._element_size_fn = element_size_fn
+
+  def load_model(self) -> FakeModel:
+    return FakeModel()
+
+  def run_inference(self, batch, model, inference_args=None):
+    return [model.predict(x) for x in batch]
+
+  def batch_elements_kwargs(self):
+    kwargs = {'max_batch_size': self._max_batch_size}
+    if self._max_batch_weight is not None:
+      kwargs['max_batch_weight'] = self._max_batch_weight
+    if self._element_size_fn:
+      kwargs['element_size_fn'] = self._element_size_fn
+    return kwargs

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The `FakeModelHandlerForSizing` re-implements the logic for handling 
batching keyword arguments, instead of leveraging the new implementation in the 
`ModelHandler` base class. To ensure the test correctly validates the base 
class behavior, this test handler should be refactored to call 
`super().__init__` and remove the overridden `batch_elements_kwargs` method.
   
   ```suggestion
     def __init__(
         self,
         max_batch_size: int = 10,
         max_batch_weight: Optional[int] = None,
         element_size_fn=None):
       super().__init__(
           max_batch_size=max_batch_size,
           max_batch_weight=max_batch_weight,
           element_size_fn=element_size_fn)
   
     def load_model(self) -> FakeModel:
       return FakeModel()
   
     def run_inference(self, batch, model, inference_args=None):
       return [model.predict(x) for x in batch]
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to