gemini-code-assist[bot] commented on code in PR #37428:
URL: https://github.com/apache/beam/pull/37428#discussion_r2742839163
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -167,10 +167,48 @@ class KeyModelPathMapping(Generic[KeyT]):
class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
"""Has the ability to load and apply an ML model."""
-  def __init__(self):
-    """Environment variables are set using a dict named 'env_vars' before
-    loading the model. Child classes can accept this dict as a kwarg."""
-    self._env_vars = {}
+  def __init__(
+      self,
+      *,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None,
+      large_model: bool = False,
+      model_copies: Optional[int] = None,
+      **kwargs):
+    """Initializes the ModelHandler.
+
+    Args:
+      min_batch_size: the minimum batch size to use when batching inputs.
+      max_batch_size: the maximum batch size to use when batching inputs.
+      max_batch_duration_secs: the maximum amount of time to buffer a batch
+        before emitting; used in streaming contexts.
+      max_batch_weight: the maximum weight of a batch. Requires
+        element_size_fn.
+      element_size_fn: a function that returns the size (weight) of an
+        element.
+      large_model: set to true if your model is large enough to run into
+        memory pressure if you load multiple copies.
+      model_copies: The exact number of models that you would like loaded
+        onto your machine.
+      kwargs: 'env_vars' can be used to set environment variables
+        before loading the model.
+    """
+    self._env_vars = kwargs.get('env_vars', {})
+    self._batching_kwargs: dict[str, Any] = {}
+    if min_batch_size is not None:
+      self._batching_kwargs['min_batch_size'] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs['max_batch_size'] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs['max_batch_duration_secs'] = (
+          max_batch_duration_secs)
+    if max_batch_weight is not None:
+      self._batching_kwargs['max_batch_weight'] = max_batch_weight
+    if element_size_fn is not None:
+      self._batching_kwargs['element_size_fn'] = element_size_fn
Review Comment:

This block of `if` statements can be made more concise and maintainable by
using a dictionary comprehension to build `_batching_kwargs`.
```suggestion
    self._batching_kwargs: dict[str, Any] = {
        k: v for k, v in (
            ('min_batch_size', min_batch_size),
            ('max_batch_size', max_batch_size),
            ('max_batch_duration_secs', max_batch_duration_secs),
            ('max_batch_weight', max_batch_weight),
            ('element_size_fn', element_size_fn),
        ) if v is not None
    }
```
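For illustration, here is a minimal sketch of how a subclass would forward these options and what the resulting kwargs look like. The `MyHandler` class and the asserted dict are assumptions for this example only, and it presumes the base class surfaces `_batching_kwargs` through `batch_elements_kwargs()` as this PR intends (that method is not shown in this hunk):
```python
from typing import Any, Optional

from apache_beam.ml.inference import base


class MyHandler(base.ModelHandler[int, int, Any]):
  """Illustrative subclass; assumes the constructor added in this PR."""
  def __init__(self, max_batch_size: Optional[int] = None):
    # Forward batching options so the base class builds _batching_kwargs;
    # options left as None are simply omitted.
    super().__init__(max_batch_size=max_batch_size)


handler = MyHandler(max_batch_size=8)
# Assumed behavior after this PR: only explicitly-set options appear in the
# kwargs that RunInference later uses for batching.
assert handler.batch_elements_kwargs() == {'max_batch_size': 8}
```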
##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -2133,5 +2133,63 @@ def request(self, batch, model, inference_args=None):
model_handler.run_inference([1], FakeModel())
+class FakeModelHandlerForSizing(base.ModelHandler[int, int, FakeModel]):
+  """A ModelHandler used to test element sizing behavior."""
+  def __init__(
+      self,
+      max_batch_size: int = 10,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn=None):
+    self._max_batch_size = max_batch_size
+    self._max_batch_weight = max_batch_weight
+    self._element_size_fn = element_size_fn
+
+  def load_model(self) -> FakeModel:
+    return FakeModel()
+
+  def run_inference(self, batch, model, inference_args=None):
+    return [model.predict(x) for x in batch]
+
+  def batch_elements_kwargs(self):
+    kwargs = {'max_batch_size': self._max_batch_size}
+    if self._max_batch_weight is not None:
+      kwargs['max_batch_weight'] = self._max_batch_weight
+    if self._element_size_fn:
+      kwargs['element_size_fn'] = self._element_size_fn
+    return kwargs
Review Comment:

The `FakeModelHandlerForSizing` re-implements the logic for handling
batching keyword arguments, instead of leveraging the new implementation in the
`ModelHandler` base class. To ensure the test correctly validates the base
class behavior, this test handler should be refactored to call
`super().__init__` and remove the overridden `batch_elements_kwargs` method.
```suggestion
  def __init__(
      self,
      max_batch_size: int = 10,
      max_batch_weight: Optional[int] = None,
      element_size_fn=None):
    super().__init__(
        max_batch_size=max_batch_size,
        max_batch_weight=max_batch_weight,
        element_size_fn=element_size_fn)

  def load_model(self) -> FakeModel:
    return FakeModel()

  def run_inference(self, batch, model, inference_args=None):
    return [model.predict(x) for x in batch]
```
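As a follow-up, a hedged sketch of a test that exercises the refactored handler through the base-class path rather than an overridden method. It assumes `FakeModelHandlerForSizing` from this file is in scope, the test class name and expected dict are illustrative, and it again presumes `batch_elements_kwargs()` reflects `_batching_kwargs` after this PR:
```python
import unittest


class FakeModelHandlerForSizingKwargsTest(unittest.TestCase):
  def test_kwargs_come_from_base_class(self):
    # Assumes FakeModelHandlerForSizing now delegates to
    # ModelHandler.__init__, so only explicitly-set options appear.
    handler = FakeModelHandlerForSizing(
        max_batch_size=4, max_batch_weight=100, element_size_fn=len)
    self.assertEqual(
        handler.batch_elements_kwargs(), {
            'max_batch_size': 4,
            'max_batch_weight': 100,
            'element_size_fn': len,
        })
```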