This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new cdf48147bdd [RunInference] Add content-aware dynamic batching via element_size_fn… (#37428)
cdf48147bdd is described below

commit cdf48147bdd5cec78914f1a434af9fc87782b893
Author: Elia LIU <[email protected]>
AuthorDate: Sat Jan 31 00:43:49 2026 +1100

    [RunInference] Add content-aware dynamic batching via element_size_fn… (#37428)
    
    * Refactor: Unify batching args in ModelHandler constructors
    
    - Added 'max_batch_weight' and 'element_size_fn' to __init__ of all ModelHandlers (PyTorch, Sklearn, TF, ONNX, XGBoost, TensorRT, Hugging Face, vLLM, VertexAI).
    - Updated subclasses to delegate these args to 'super().__init__' or internal batching kwargs.
    - Removed 'with_element_size_fn' builder method from base class to enforce API consistency.
    - Updated tests to reflect the new API signature.
    
    * Address review comments: refactor tests and fix linting
---
 sdks/python/apache_beam/ml/inference/base.py       |  52 +++++++-
 sdks/python/apache_beam/ml/inference/base_test.py  | 146 +++++++++++++++++++++
 .../apache_beam/ml/inference/gemini_inference.py   |  16 ++-
 .../ml/inference/huggingface_inference.py          |  96 ++++++--------
 .../apache_beam/ml/inference/onnx_inference.py     |  33 ++---
 .../apache_beam/ml/inference/pytorch_inference.py  |  64 ++++-----
 .../apache_beam/ml/inference/sklearn_inference.py  |  64 ++++-----
 .../ml/inference/tensorflow_inference.py           |  66 ++++------
 .../apache_beam/ml/inference/tensorrt_inference.py |  33 ++---
 .../ml/inference/vertex_ai_inference.py            |  12 +-
 .../apache_beam/ml/inference/vllm_inference.py     |  49 ++++++-
 .../apache_beam/ml/inference/xgboost_inference.py  |  26 ++--
 12 files changed, 419 insertions(+), 238 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index e0f870669f7..ad2e2f8d0e3 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -167,10 +167,48 @@ class KeyModelPathMapping(Generic[KeyT]):
 
 class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
   """Has the ability to load and apply an ML model."""
-  def __init__(self):
-    """Environment variables are set using a dict named 'env_vars' before
-    loading the model. Child classes can accept this dict as a kwarg."""
-    self._env_vars = {}
+  def __init__(
+      self,
+      *,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None,
+      large_model: bool = False,
+      model_copies: Optional[int] = None,
+      **kwargs):
+    """Initializes the ModelHandler.
+
+    Args:
+      min_batch_size: the minimum batch size to use when batching inputs.
+      max_batch_size: the maximum batch size to use when batching inputs.
+      max_batch_duration_secs: the maximum amount of time to buffer a batch
+        before emitting; used in streaming contexts.
+      max_batch_weight: the maximum weight of a batch. Requires element_size_fn.
+      element_size_fn: a function that returns the size (weight) of an element.
+      large_model: set to true if your model is large enough to run into
+        memory pressure if you load multiple copies.
+      model_copies: The exact number of models that you would like loaded
+        onto your machine.
+      kwargs: 'env_vars' can be used to set environment variables
+        before loading the model.
+    """
+    self._env_vars = kwargs.get('env_vars', {})
+    self._batching_kwargs: dict[str, Any] = {}
+    if min_batch_size is not None:
+      self._batching_kwargs['min_batch_size'] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs['max_batch_size'] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs['max_batch_duration_secs'] = max_batch_duration_secs
+    if max_batch_weight is not None:
+      self._batching_kwargs['max_batch_weight'] = max_batch_weight
+    if element_size_fn is not None:
+      self._batching_kwargs['element_size_fn'] = element_size_fn
+    self._large_model = large_model
+    self._model_copies = model_copies
+    self._share_across_processes = large_model or (model_copies is not None)
 
   def load_model(self) -> ModelT:
     """Loads and initializes a model for processing."""
@@ -220,7 +258,7 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
     Returns:
        kwargs suitable for beam.BatchElements.
     """
-    return {}
+    return getattr(self, '_batching_kwargs', {})
 
   def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
     """
@@ -325,14 +363,14 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
     memory. Multi-process support may vary by runner, but this will fallback to
     loading per process as necessary. See
     
https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html""";
-    return False
+    return getattr(self, '_share_across_processes', False)
 
   def model_copies(self) -> int:
     """Returns the maximum number of model copies that should be loaded at one
     time. This only impacts model handlers that are using
     share_model_across_processes to share their model across processes instead
     of being loaded per process."""
-    return 1
+    return getattr(self, '_model_copies', None) or 1
 
   def override_metrics(self, metrics_namespace: str = '') -> bool:
     """Returns a boolean representing whether or not a model handler will
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 381bf545660..55784166ad5 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -2133,5 +2133,151 @@ class RunInferenceRemoteTest(unittest.TestCase):
       model_handler.run_inference([1], FakeModel())
 
 
+class FakeModelHandlerForSizing(base.ModelHandler[int, int, FakeModel]):
+  """A ModelHandler used to test element sizing behavior."""
+  def __init__(
+      self,
+      max_batch_size: int = 10,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn=None):
+    super().__init__(
+        max_batch_size=max_batch_size,
+        max_batch_weight=max_batch_weight,
+        element_size_fn=element_size_fn)
+
+  def load_model(self) -> FakeModel:
+    return FakeModel()
+
+  def run_inference(self, batch, model, inference_args=None):
+    return [model.predict(x) for x in batch]
+
+
+class RunInferenceSizeTest(unittest.TestCase):
+  """Tests for ModelHandler.batch_elements_kwargs with element_size_fn."""
+  def test_kwargs_are_passed_correctly(self):
+    """Adds element_size_fn without clobbering existing kwargs."""
+    def size_fn(x):
+      return 10
+
+    sized_handler = FakeModelHandlerForSizing(
+        max_batch_size=20, max_batch_weight=100, element_size_fn=size_fn)
+
+    kwargs = sized_handler.batch_elements_kwargs()
+
+    self.assertEqual(kwargs['max_batch_size'], 20)
+    self.assertEqual(kwargs['max_batch_weight'], 100)
+    self.assertIn('element_size_fn', kwargs)
+    self.assertEqual(kwargs['element_size_fn'](1), 10)
+
+  def test_sizing_with_edge_cases(self):
+    """Allows extreme values from element_size_fn."""
+    zero_size_fn = lambda x: 0
+    sized_handler = FakeModelHandlerForSizing(
+        max_batch_size=1, element_size_fn=zero_size_fn)
+    kwargs = sized_handler.batch_elements_kwargs()
+    self.assertEqual(kwargs['element_size_fn'](999), 0)
+
+    large_size_fn = lambda x: 1000000
+    sized_handler = FakeModelHandlerForSizing(
+        max_batch_size=1, element_size_fn=large_size_fn)
+    kwargs = sized_handler.batch_elements_kwargs()
+    self.assertEqual(kwargs['element_size_fn'](1), 1000000)
+
+
+class FakeModelHandlerForBatching(base.ModelHandler[int, int, FakeModel]):
+  """A ModelHandler used to test batching behavior via base class __init__."""
+  def __init__(self, **kwargs):
+    super().__init__(**kwargs)
+
+  def load_model(self) -> FakeModel:
+    return FakeModel()
+
+  def run_inference(self, batch, model, inference_args=None):
+    return [model.predict(x) for x in batch]
+
+
+class ModelHandlerBatchingArgsTest(unittest.TestCase):
+  """Tests for ModelHandler.__init__ batching parameters."""
+  def test_batch_elements_kwargs_all_args(self):
+    """All batching args passed to __init__ are in batch_elements_kwargs."""
+    def size_fn(x):
+      return 10
+
+    handler = FakeModelHandlerForBatching(
+        min_batch_size=5,
+        max_batch_size=20,
+        max_batch_duration_secs=30,
+        max_batch_weight=100,
+        element_size_fn=size_fn)
+
+    kwargs = handler.batch_elements_kwargs()
+
+    self.assertEqual(kwargs['min_batch_size'], 5)
+    self.assertEqual(kwargs['max_batch_size'], 20)
+    self.assertEqual(kwargs['max_batch_duration_secs'], 30)
+    self.assertEqual(kwargs['max_batch_weight'], 100)
+    self.assertIn('element_size_fn', kwargs)
+    self.assertEqual(kwargs['element_size_fn'](1), 10)
+
+  def test_batch_elements_kwargs_partial_args(self):
+    """Only provided batching args are included in kwargs."""
+    handler = FakeModelHandlerForBatching(max_batch_size=50)
+    kwargs = handler.batch_elements_kwargs()
+
+    self.assertEqual(kwargs, {'max_batch_size': 50})
+
+  def test_batch_elements_kwargs_empty_when_no_args(self):
+    """No batching kwargs when none are provided."""
+    handler = FakeModelHandlerForBatching()
+    kwargs = handler.batch_elements_kwargs()
+
+    self.assertEqual(kwargs, {})
+
+  def test_large_model_sets_share_across_processes(self):
+    """Setting large_model=True enables share_model_across_processes."""
+    handler = FakeModelHandlerForBatching(large_model=True)
+
+    self.assertTrue(handler.share_model_across_processes())
+
+  def test_model_copies_sets_share_across_processes(self):
+    """Setting model_copies enables share_model_across_processes."""
+    handler = FakeModelHandlerForBatching(model_copies=2)
+
+    self.assertTrue(handler.share_model_across_processes())
+    self.assertEqual(handler.model_copies(), 2)
+
+  def test_default_share_across_processes_is_false(self):
+    """Default share_model_across_processes is False."""
+    handler = FakeModelHandlerForBatching()
+
+    self.assertFalse(handler.share_model_across_processes())
+
+  def test_default_model_copies_is_one(self):
+    """Default model_copies is 1."""
+    handler = FakeModelHandlerForBatching()
+
+    self.assertEqual(handler.model_copies(), 1)
+
+  def test_env_vars_from_kwargs(self):
+    """Environment variables can be passed via kwargs."""
+    handler = FakeModelHandlerForBatching(env_vars={'MY_VAR': 'value'})
+
+    self.assertEqual(handler._env_vars, {'MY_VAR': 'value'})
+
+  def test_min_batch_size_only(self):
+    """min_batch_size can be passed alone."""
+    handler = FakeModelHandlerForBatching(min_batch_size=10)
+    kwargs = handler.batch_elements_kwargs()
+
+    self.assertEqual(kwargs, {'min_batch_size': 10})
+
+  def test_max_batch_duration_secs_only(self):
+    """max_batch_duration_secs can be passed alone."""
+    handler = FakeModelHandlerForBatching(max_batch_duration_secs=60)
+    kwargs = handler.batch_elements_kwargs()
+
+    self.assertEqual(kwargs, {'max_batch_duration_secs': 60})
+
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference.py b/sdks/python/apache_beam/ml/inference/gemini_inference.py
index c840efedd8f..a79fbe8a555 100644
--- a/sdks/python/apache_beam/ml/inference/gemini_inference.py
+++ b/sdks/python/apache_beam/ml/inference/gemini_inference.py
@@ -112,6 +112,8 @@ class GeminiModelHandler(RemoteModelHandler[Any, PredictionResult,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for Google Gemini.
     **NOTE:** This API and its implementation are under development and
@@ -134,15 +136,18 @@ class GeminiModelHandler(RemoteModelHandler[Any, PredictionResult,
       project: the GCP project to use for Vertex AI requests. Setting this
         parameter routes requests to Vertex AI. If this paramter is provided,
         location must also be provided and api_key should not be set.
-      location: the GCP project to use for Vertex AI requests. Setting this 
+      location: the GCP project to use for Vertex AI requests. Setting this
         parameter routes requests to Vertex AI. If this paramter is provided,
         project must also be provided and api_key should not be set.
       min_batch_size: optional. the minimum batch size to use when batching
         inputs.
       max_batch_size: optional. the maximum batch size to use when batching
         inputs.
-      max_batch_duration_secs: optional. the maximum amount of time to buffer 
+      max_batch_duration_secs: optional. the maximum amount of time to buffer
         a batch before emitting; used in streaming contexts.
+      max_batch_weight: optional. the maximum total weight of a batch.
+      element_size_fn: optional. a function that returns the size (weight)
+        of an element.
     """
     self._batching_kwargs = {}
     self._env_vars = kwargs.get('env_vars', {})
@@ -152,6 +157,10 @@ class GeminiModelHandler(RemoteModelHandler[Any, PredictionResult,
       self._batching_kwargs["max_batch_size"] = max_batch_size
     if max_batch_duration_secs is not None:
       self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
+    if max_batch_weight is not None:
+      self._batching_kwargs["max_batch_weight"] = max_batch_weight
+    if element_size_fn is not None:
+      self._batching_kwargs['element_size_fn'] = element_size_fn
 
     self.model_name = model_name
     self.request_fn = request_fn
@@ -174,6 +183,9 @@ class GeminiModelHandler(RemoteModelHandler[Any, PredictionResult,
         retry_filter=_retry_on_appropriate_service_error,
         **kwargs)
 
+  def batch_elements_kwargs(self):
+    return self._batching_kwargs
+
   def create_client(self) -> genai.Client:
     """Creates the GenAI client used to send requests. Creates a version for
     the Vertex AI API or the Gemini Developer API based on the arguments
diff --git a/sdks/python/apache_beam/ml/inference/huggingface_inference.py b/sdks/python/apache_beam/ml/inference/huggingface_inference.py
index 501a019c378..2c1f5e2cc90 100644
--- a/sdks/python/apache_beam/ml/inference/huggingface_inference.py
+++ b/sdks/python/apache_beam/ml/inference/huggingface_inference.py
@@ -227,6 +227,8 @@ class HuggingFaceModelHandlerKeyedTensor(ModelHandler[dict[str,
       max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       model_copies: Optional[int] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None,
       **kwargs):
     """
     Implementation of the ModelHandler interface for HuggingFace with
@@ -262,27 +264,28 @@ class HuggingFaceModelHandlerKeyedTensor(ModelHandler[dict[str,
       model_copies: The exact number of models that you would like loaded
         onto your machine. This can be useful if you exactly know your CPU or
         GPU capacity and want to maximize resource utilization.
+      max_batch_weight: the maximum total weight of a batch.
+      element_size_fn: a function that returns the size (weight) of an element.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
 
     **Supported Versions:** HuggingFaceModelHandler supports
     transformers>=4.18.0.
     """
+    super().__init__(
+        min_batch_size=min_batch_size,
+        max_batch_size=max_batch_size,
+        max_batch_duration_secs=max_batch_duration_secs,
+        max_batch_weight=max_batch_weight,
+        element_size_fn=element_size_fn,
+        large_model=large_model,
+        model_copies=model_copies,
+        **kwargs)
     self._model_uri = model_uri
     self._model_class = model_class
     self._device = device
     self._inference_fn = inference_fn
     self._model_config_args = load_model_args if load_model_args else {}
-    self._batching_kwargs = {}
-    self._env_vars = kwargs.get("env_vars", {})
-    if min_batch_size is not None:
-      self._batching_kwargs["min_batch_size"] = min_batch_size
-    if max_batch_size is not None:
-      self._batching_kwargs["max_batch_size"] = max_batch_size
-    if max_batch_duration_secs is not None:
-      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
-    self._share_across_processes = large_model or (model_copies is not None)
-    self._model_copies = model_copies or 1
     self._framework = framework
 
     _validate_constructor_args(
@@ -352,15 +355,6 @@ class HuggingFaceModelHandlerKeyedTensor(ModelHandler[dict[str,
       return sum(
           (el.element_size() for tensor in batch for el in tensor.values()))
 
-  def batch_elements_kwargs(self):
-    return self._batching_kwargs
-
-  def share_model_across_processes(self) -> bool:
-    return self._share_across_processes
-
-  def model_copies(self) -> int:
-    return self._model_copies
-
   def get_metrics_namespace(self) -> str:
     """
     Returns:
@@ -415,6 +409,8 @@ class HuggingFaceModelHandlerTensor(ModelHandler[Union[tf.Tensor, torch.Tensor],
       max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       model_copies: Optional[int] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None,
       **kwargs):
     """
     Implementation of the ModelHandler interface for HuggingFace with
@@ -450,27 +446,28 @@ class HuggingFaceModelHandlerTensor(ModelHandler[Union[tf.Tensor, torch.Tensor],
       model_copies: The exact number of models that you would like loaded
         onto your machine. This can be useful if you exactly know your CPU or
         GPU capacity and want to maximize resource utilization.
+      max_batch_weight: the maximum total weight of a batch.
+      element_size_fn: a function that returns the size (weight) of an element.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
 
     **Supported Versions:** HuggingFaceModelHandler supports
     transformers>=4.18.0.
     """
+    super().__init__(
+        min_batch_size=min_batch_size,
+        max_batch_size=max_batch_size,
+        max_batch_duration_secs=max_batch_duration_secs,
+        max_batch_weight=max_batch_weight,
+        element_size_fn=element_size_fn,
+        large_model=large_model,
+        model_copies=model_copies,
+        **kwargs)
     self._model_uri = model_uri
     self._model_class = model_class
     self._device = device
     self._inference_fn = inference_fn
     self._model_config_args = load_model_args if load_model_args else {}
-    self._batching_kwargs = {}
-    self._env_vars = kwargs.get("env_vars", {})
-    if min_batch_size is not None:
-      self._batching_kwargs["min_batch_size"] = min_batch_size
-    if max_batch_size is not None:
-      self._batching_kwargs["max_batch_size"] = max_batch_size
-    if max_batch_duration_secs is not None:
-      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
-    self._share_across_processes = large_model or (model_copies is not None)
-    self._model_copies = model_copies or 1
     self._framework = ""
 
     _validate_constructor_args(
@@ -547,15 +544,6 @@ class HuggingFaceModelHandlerTensor(ModelHandler[Union[tf.Tensor, torch.Tensor],
       return sum(
           (el.element_size() for tensor in batch for el in tensor.values()))
 
-  def batch_elements_kwargs(self):
-    return self._batching_kwargs
-
-  def share_model_across_processes(self) -> bool:
-    return self._share_across_processes
-
-  def model_copies(self) -> int:
-    return self._model_copies
-
   def get_metrics_namespace(self) -> str:
     """
     Returns:
@@ -586,6 +574,8 @@ class HuggingFacePipelineModelHandler(ModelHandler[str,
       max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       model_copies: Optional[int] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None,
       **kwargs):
     """
     Implementation of the ModelHandler interface for Hugging Face Pipelines.
@@ -629,27 +619,28 @@ class HuggingFacePipelineModelHandler(ModelHandler[str,
       model_copies: The exact number of models that you would like loaded
         onto your machine. This can be useful if you exactly know your CPU or
         GPU capacity and want to maximize resource utilization.
+      max_batch_weight: the maximum total weight of a batch.
+      element_size_fn: a function that returns the size (weight) of an element.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
 
     **Supported Versions:** HuggingFacePipelineModelHandler supports
     transformers>=4.18.0.
     """
+    super().__init__(
+        min_batch_size=min_batch_size,
+        max_batch_size=max_batch_size,
+        max_batch_duration_secs=max_batch_duration_secs,
+        max_batch_weight=max_batch_weight,
+        element_size_fn=element_size_fn,
+        large_model=large_model,
+        model_copies=model_copies,
+        **kwargs)
     self._task = task
     self._model = model
     self._inference_fn = inference_fn
     self._load_pipeline_args = load_pipeline_args if load_pipeline_args else {}
-    self._batching_kwargs = {}
     self._framework = "pt"
-    self._env_vars = kwargs.get('env_vars', {})
-    if min_batch_size is not None:
-      self._batching_kwargs['min_batch_size'] = min_batch_size
-    if max_batch_size is not None:
-      self._batching_kwargs['max_batch_size'] = max_batch_size
-    if max_batch_duration_secs is not None:
-      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
-    self._share_across_processes = large_model or (model_copies is not None)
-    self._model_copies = model_copies or 1
 
     # Check if the device is specified twice. If true then the device parameter
     # of model handler is overridden.
@@ -726,15 +717,6 @@ class HuggingFacePipelineModelHandler(ModelHandler[str,
     """
     return sum(sys.getsizeof(element) for element in batch)
 
-  def batch_elements_kwargs(self):
-    return self._batching_kwargs
-
-  def share_model_across_processes(self) -> bool:
-    return self._share_across_processes
-
-  def model_copies(self) -> int:
-    return self._model_copies
-
   def get_metrics_namespace(self) -> str:
     """
     Returns:
diff --git a/sdks/python/apache_beam/ml/inference/onnx_inference.py b/sdks/python/apache_beam/ml/inference/onnx_inference.py
index 3485866f11c..4423eed2e40 100644
--- a/sdks/python/apache_beam/ml/inference/onnx_inference.py
+++ b/sdks/python/apache_beam/ml/inference/onnx_inference.py
@@ -17,7 +17,6 @@
 
 from collections.abc import Callable
 from collections.abc import Iterable
-from collections.abc import Mapping
 from collections.abc import Sequence
 from typing import Any
 from typing import Optional
@@ -67,6 +66,8 @@ class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None,
       **kwargs):
     """ Implementation of the ModelHandler interface for onnx
     using numpy arrays as input.
@@ -91,24 +92,25 @@ class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray,
       max_batch_size: the maximum batch size to use when batching inputs.
       max_batch_duration_secs: the maximum amount of time to buffer a batch
         before emitting; used in streaming contexts.
+      max_batch_weight: the maximum total weight of a batch.
+      element_size_fn: a function that returns the size (weight) of an element.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
     """
+    super().__init__(
+        min_batch_size=min_batch_size,
+        max_batch_size=max_batch_size,
+        max_batch_duration_secs=max_batch_duration_secs,
+        max_batch_weight=max_batch_weight,
+        element_size_fn=element_size_fn,
+        large_model=large_model,
+        model_copies=model_copies,
+        **kwargs)
     self._model_uri = model_uri
     self._session_options = session_options
     self._providers = providers
     self._provider_options = provider_options
     self._model_inference_fn = inference_fn
-    self._env_vars = kwargs.get('env_vars', {})
-    self._share_across_processes = large_model or (model_copies is not None)
-    self._model_copies = model_copies or 1
-    self._batching_kwargs = {}
-    if min_batch_size is not None:
-      self._batching_kwargs["min_batch_size"] = min_batch_size
-    if max_batch_size is not None:
-      self._batching_kwargs["max_batch_size"] = max_batch_size
-    if max_batch_duration_secs is not None:
-      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
 
   def load_model(self) -> ort.InferenceSession:
     """Loads and initializes an onnx inference session for processing."""
@@ -167,12 +169,3 @@ class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray,
        A namespace for metrics collected by the RunInference transform.
     """
     return 'BeamML_Onnx'
-
-  def share_model_across_processes(self) -> bool:
-    return self._share_across_processes
-
-  def model_copies(self) -> int:
-    return self._model_copies
-
-  def batch_elements_kwargs(self) -> Mapping[str, Any]:
-    return self._batching_kwargs
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index affbcd977f5..63c2a116fcc 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -197,6 +197,8 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
       large_model: bool = False,
       model_copies: Optional[int] = None,
       load_model_args: Optional[dict[str, Any]] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for PyTorch.
 
@@ -240,12 +242,23 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
         GPU capacity and want to maximize resource utilization.
       load_model_args: a dictionary of parameters passed to the torch.load
         function to specify custom config for loading models.
+      max_batch_weight: the maximum total weight of a batch.
+      element_size_fn: a function that returns the size (weight) of an element.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
 
     **Supported Versions:** RunInference APIs in Apache Beam have been tested
     with PyTorch 1.9 and 1.10.
     """
+    super().__init__(
+        min_batch_size=min_batch_size,
+        max_batch_size=max_batch_size,
+        max_batch_duration_secs=max_batch_duration_secs,
+        max_batch_weight=max_batch_weight,
+        element_size_fn=element_size_fn,
+        large_model=large_model,
+        model_copies=model_copies,
+        **kwargs)
     self._state_dict_path = state_dict_path
     if device == 'GPU':
       logging.info("Device is set to CUDA")
@@ -256,18 +269,8 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
     self._model_class = model_class
     self._model_params = model_params if model_params else {}
     self._inference_fn = inference_fn
-    self._batching_kwargs = {}
-    if min_batch_size is not None:
-      self._batching_kwargs['min_batch_size'] = min_batch_size
-    if max_batch_size is not None:
-      self._batching_kwargs['max_batch_size'] = max_batch_size
-    if max_batch_duration_secs is not None:
-      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
     self._torch_script_model_path = torch_script_model_path
     self._load_model_args = load_model_args if load_model_args else {}
-    self._env_vars = kwargs.get('env_vars', {})
-    self._share_across_processes = large_model or (model_copies is not None)
-    self._model_copies = model_copies or 1
 
     _validate_constructor_args(
         state_dict_path=self._state_dict_path,
@@ -342,15 +345,6 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
     """
     return 'BeamML_PyTorch'
 
-  def batch_elements_kwargs(self):
-    return self._batching_kwargs
-
-  def share_model_across_processes(self) -> bool:
-    return self._share_across_processes
-
-  def model_copies(self) -> int:
-    return self._model_copies
-
 
 def default_keyed_tensor_inference_fn(
     batch: Sequence[dict[str, torch.Tensor]],
@@ -435,6 +429,8 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[dict[str, torch.Tensor],
       large_model: bool = False,
       model_copies: Optional[int] = None,
       load_model_args: Optional[dict[str, Any]] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for PyTorch.
 
@@ -483,12 +479,23 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[dict[str, torch.Tensor],
         GPU capacity and want to maximize resource utilization.
       load_model_args: a dictionary of parameters passed to the torch.load
         function to specify custom config for loading models.
+      max_batch_weight: the maximum total weight of a batch.
+      element_size_fn: a function that returns the size (weight) of an element.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
 
     **Supported Versions:** RunInference APIs in Apache Beam have been tested
     on torch>=1.9.0,<1.14.0.
     """
+    super().__init__(
+        min_batch_size=min_batch_size,
+        max_batch_size=max_batch_size,
+        max_batch_duration_secs=max_batch_duration_secs,
+        max_batch_weight=max_batch_weight,
+        element_size_fn=element_size_fn,
+        large_model=large_model,
+        model_copies=model_copies,
+        **kwargs)
     self._state_dict_path = state_dict_path
     if device == 'GPU':
       logging.info("Device is set to CUDA")
@@ -499,18 +506,8 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[dict[str, torch.Tensor],
     self._model_class = model_class
     self._model_params = model_params if model_params else {}
     self._inference_fn = inference_fn
-    self._batching_kwargs = {}
-    if min_batch_size is not None:
-      self._batching_kwargs['min_batch_size'] = min_batch_size
-    if max_batch_size is not None:
-      self._batching_kwargs['max_batch_size'] = max_batch_size
-    if max_batch_duration_secs is not None:
-      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
     self._torch_script_model_path = torch_script_model_path
     self._load_model_args = load_model_args if load_model_args else {}
-    self._env_vars = kwargs.get('env_vars', {})
-    self._share_across_processes = large_model or (model_copies is not None)
-    self._model_copies = model_copies or 1
 
     _validate_constructor_args(
         state_dict_path=self._state_dict_path,
@@ -586,12 +583,3 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[dict[str, torch.Tensor],
        A namespace for metrics collected by the RunInference transform.
     """
     return 'BeamML_PyTorch'
-
-  def batch_elements_kwargs(self):
-    return self._batching_kwargs
-
-  def share_model_across_processes(self) -> bool:
-    return self._share_across_processes
-
-  def model_copies(self) -> int:
-    return self._model_copies
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index 84947bec3df..e61ef9c194a 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -93,6 +93,8 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
       max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       model_copies: Optional[int] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None,
       **kwargs):
     """ Implementation of the ModelHandler interface for scikit-learn
     using numpy arrays as input.
@@ -122,22 +124,23 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
       model_copies: The exact number of models that you would like loaded
         onto your machine. This can be useful if you exactly know your CPU or
         GPU capacity and want to maximize resource utilization.
+      max_batch_weight: the maximum total weight of a batch.
+      element_size_fn: a function that returns the size (weight) of an element.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
     """
+    super().__init__(
+        min_batch_size=min_batch_size,
+        max_batch_size=max_batch_size,
+        max_batch_duration_secs=max_batch_duration_secs,
+        max_batch_weight=max_batch_weight,
+        element_size_fn=element_size_fn,
+        large_model=large_model,
+        model_copies=model_copies,
+        **kwargs)
     self._model_uri = model_uri
     self._model_file_type = model_file_type
     self._model_inference_fn = inference_fn
-    self._batching_kwargs = {}
-    if min_batch_size is not None:
-      self._batching_kwargs['min_batch_size'] = min_batch_size
-    if max_batch_size is not None:
-      self._batching_kwargs['max_batch_size'] = max_batch_size
-    if max_batch_duration_secs is not None:
-      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
-    self._env_vars = kwargs.get('env_vars', {})
-    self._share_across_processes = large_model or (model_copies is not None)
-    self._model_copies = model_copies or 1
 
   def load_model(self) -> BaseEstimator:
     """Loads and initializes a model for processing."""
@@ -187,15 +190,6 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
     """
     return 'BeamML_Sklearn'
 
-  def batch_elements_kwargs(self):
-    return self._batching_kwargs
-
-  def share_model_across_processes(self) -> bool:
-    return self._share_across_processes
-
-  def model_copies(self) -> int:
-    return self._model_copies
-
 
 PandasInferenceFn = Callable[
     [BaseEstimator, Sequence[pandas.DataFrame], Optional[dict[str, Any]]], Any]
@@ -228,6 +222,8 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
       max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       model_copies: Optional[int] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for scikit-learn that
     supports pandas dataframes.
@@ -260,22 +256,23 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
       model_copies: The exact number of models that you would like loaded
         onto your machine. This can be useful if you exactly know your CPU or
         GPU capacity and want to maximize resource utilization.
+      max_batch_weight: the maximum total weight of a batch.
+      element_size_fn: a function that returns the size (weight) of an element.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
     """
+    super().__init__(
+        min_batch_size=min_batch_size,
+        max_batch_size=max_batch_size,
+        max_batch_duration_secs=max_batch_duration_secs,
+        max_batch_weight=max_batch_weight,
+        element_size_fn=element_size_fn,
+        large_model=large_model,
+        model_copies=model_copies,
+        **kwargs)
     self._model_uri = model_uri
     self._model_file_type = model_file_type
     self._model_inference_fn = inference_fn
-    self._batching_kwargs = {}
-    if min_batch_size is not None:
-      self._batching_kwargs['min_batch_size'] = min_batch_size
-    if max_batch_size is not None:
-      self._batching_kwargs['max_batch_size'] = max_batch_size
-    if max_batch_duration_secs is not None:
-      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
-    self._env_vars = kwargs.get('env_vars', {})
-    self._share_across_processes = large_model or (model_copies is not None)
-    self._model_copies = model_copies or 1
 
   def load_model(self) -> BaseEstimator:
     """Loads and initializes a model for processing."""
@@ -326,12 +323,3 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
        A namespace for metrics collected by the RunInference transform.
     """
     return 'BeamML_Sklearn'
-
-  def batch_elements_kwargs(self):
-    return self._batching_kwargs
-
-  def share_model_across_processes(self) -> bool:
-    return self._share_across_processes
-
-  def model_copies(self) -> int:
-    return self._model_copies
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
index 5ce293a06ac..97b74eb360a 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
@@ -112,6 +112,8 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
       max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       model_copies: Optional[int] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for Tensorflow.
 
@@ -140,28 +142,30 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
         model_copies: The exact number of models that you would like loaded
           onto your machine. This can be useful if you exactly know your CPU or
           GPU capacity and want to maximize resource utilization.
+        max_batch_weight: the maximum total weight of a batch.
+        element_size_fn: a function that returns the size (weight) of an
+          element.
         kwargs: 'env_vars' can be used to set environment variables
           before loading the model.
 
     **Supported Versions:** RunInference APIs in Apache Beam have been tested
     with Tensorflow 2.9, 2.10, 2.11.
     """
+    super().__init__(
+        min_batch_size=min_batch_size,
+        max_batch_size=max_batch_size,
+        max_batch_duration_secs=max_batch_duration_secs,
+        max_batch_weight=max_batch_weight,
+        element_size_fn=element_size_fn,
+        large_model=large_model,
+        model_copies=model_copies,
+        **kwargs)
     self._model_uri = model_uri
     self._model_type = model_type
     self._inference_fn = inference_fn
     self._create_model_fn = create_model_fn
-    self._env_vars = kwargs.get('env_vars', {})
     self._load_model_args = {} if not load_model_args else load_model_args
     self._custom_weights = custom_weights
-    self._batching_kwargs = {}
-    if min_batch_size is not None:
-      self._batching_kwargs['min_batch_size'] = min_batch_size
-    if max_batch_size is not None:
-      self._batching_kwargs['max_batch_size'] = max_batch_size
-    if max_batch_duration_secs is not None:
-      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
-    self._share_across_processes = large_model or (model_copies is not None)
-    self._model_copies = model_copies or 1
 
   def load_model(self) -> tf.Module:
     """Loads and initializes a Tensorflow model for processing."""
@@ -219,15 +223,6 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
     """
     return 'BeamML_TF_Numpy'
 
-  def batch_elements_kwargs(self):
-    return self._batching_kwargs
-
-  def share_model_across_processes(self) -> bool:
-    return self._share_across_processes
-
-  def model_copies(self) -> int:
-    return self._model_copies
-
 
 class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
                                         tf.Module]):
@@ -245,6 +240,8 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
       max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       model_copies: Optional[int] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for Tensorflow.
 
@@ -278,28 +275,30 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
         model_copies: The exact number of models that you would like loaded
           onto your machine. This can be useful if you exactly know your CPU or
           GPU capacity and want to maximize resource utilization.
+        max_batch_weight: the maximum total weight of a batch.
+        element_size_fn: a function that returns the size (weight) of an
+          element.
         kwargs: 'env_vars' can be used to set environment variables
           before loading the model.
 
     **Supported Versions:** RunInference APIs in Apache Beam have been tested
     with Tensorflow 2.11.
     """
+    super().__init__(
+        min_batch_size=min_batch_size,
+        max_batch_size=max_batch_size,
+        max_batch_duration_secs=max_batch_duration_secs,
+        max_batch_weight=max_batch_weight,
+        element_size_fn=element_size_fn,
+        large_model=large_model,
+        model_copies=model_copies,
+        **kwargs)
     self._model_uri = model_uri
     self._model_type = model_type
     self._inference_fn = inference_fn
     self._create_model_fn = create_model_fn
-    self._env_vars = kwargs.get('env_vars', {})
     self._load_model_args = {} if not load_model_args else load_model_args
     self._custom_weights = custom_weights
-    self._batching_kwargs = {}
-    if min_batch_size is not None:
-      self._batching_kwargs['min_batch_size'] = min_batch_size
-    if max_batch_size is not None:
-      self._batching_kwargs['max_batch_size'] = max_batch_size
-    if max_batch_duration_secs is not None:
-      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
-    self._share_across_processes = large_model or (model_copies is not None)
-    self._model_copies = model_copies or 1
 
   def load_model(self) -> tf.Module:
     """Loads and initializes a tensorflow model for processing."""
@@ -356,12 +355,3 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
        A namespace for metrics collected by the RunInference transform.
     """
     return 'BeamML_TF_Tensor'
-
-  def batch_elements_kwargs(self):
-    return self._batching_kwargs
-
-  def share_model_across_processes(self) -> bool:
-    return self._share_across_processes
-
-  def model_copies(self) -> int:
-    return self._model_copies
diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
index b575dfa849d..00a61b4934a 100644
--- a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
@@ -230,6 +230,8 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
       large_model: bool = False,
       model_copies: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for TensorRT.
 
@@ -258,6 +260,8 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
         GPU capacity and want to maximize resource utilization.
       max_batch_duration_secs: the maximum amount of time to buffer
         a batch before emitting; used in streaming contexts.
+      max_batch_weight: the maximum total weight of a batch.
+      element_size_fn: a function that returns the size (weight) of an element.
       kwargs: Additional arguments like 'engine_path' and 'onnx_path' are
         currently supported. 'env_vars' can be used to set environment variables
         before loading the model.
@@ -265,25 +269,20 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
     See https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/
     for details
     """
-    self.min_batch_size = min_batch_size
-    self.max_batch_size = max_batch_size
-    self.max_batch_duration_secs = max_batch_duration_secs
+    super().__init__(
+        min_batch_size=min_batch_size,
+        max_batch_size=max_batch_size,
+        max_batch_duration_secs=max_batch_duration_secs,
+        max_batch_weight=max_batch_weight,
+        element_size_fn=element_size_fn,
+        large_model=large_model,
+        model_copies=model_copies,
+        **kwargs)
     self.inference_fn = inference_fn
     if 'engine_path' in kwargs:
       self.engine_path = kwargs.get('engine_path')
     elif 'onnx_path' in kwargs:
       self.onnx_path = kwargs.get('onnx_path')
-    self._env_vars = kwargs.get('env_vars', {})
-    self._share_across_processes = large_model or (model_copies is not None)
-    self._model_copies = model_copies or 1
-
-  def batch_elements_kwargs(self):
-    """Sets min_batch_size and max_batch_size of a TensorRT engine."""
-    return {
-        'min_batch_size': self.min_batch_size,
-        'max_batch_size': self.max_batch_size,
-        'max_batch_duration_secs': self.max_batch_duration_secs
-    }
 
   def load_model(self) -> TensorRTEngine:
     """Loads and initializes a TensorRT engine for processing."""
@@ -336,12 +335,6 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
     """
     return 'BeamML_TensorRT'
 
-  def share_model_across_processes(self) -> bool:
-    return self._share_across_processes
-
-  def model_copies(self) -> int:
-    return self._model_copies
-
   def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
     """
     Currently, this model handler does not support inference args. Given that,
diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
index cd3d0beb593..02827f9578f 100644
--- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
+++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
@@ -17,6 +17,7 @@
 
 import json
 import logging
+from collections.abc import Callable
 from collections.abc import Iterable
 from collections.abc import Mapping
 from collections.abc import Sequence
@@ -69,6 +70,8 @@ class VertexAIModelHandlerJSON(RemoteModelHandler[Any,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for Vertex AI.
     **NOTE:** This API and its implementation are under development and
@@ -107,8 +110,11 @@ class VertexAIModelHandlerJSON(RemoteModelHandler[Any,
         inputs.
       max_batch_size: optional. the maximum batch size to use when batching
         inputs.
-      max_batch_duration_secs: optional. the maximum amount of time to buffer 
+      max_batch_duration_secs: optional. the maximum amount of time to buffer
         a batch before emitting; used in streaming contexts.
+      max_batch_weight: optional. the maximum total weight of a batch.
+      element_size_fn: optional. a function that returns the size (weight)
+        of an element.
     """
     self._batching_kwargs = {}
     self._env_vars = kwargs.get('env_vars', {})
@@ -119,6 +125,10 @@ class VertexAIModelHandlerJSON(RemoteModelHandler[Any,
       self._batching_kwargs["max_batch_size"] = max_batch_size
     if max_batch_duration_secs is not None:
       self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
+    if max_batch_weight is not None:
+      self._batching_kwargs["max_batch_weight"] = max_batch_weight
+    if element_size_fn is not None:
+      self._batching_kwargs['element_size_fn'] = element_size_fn
 
     if private and network is None:
       raise ValueError(
diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py
index bdbee9e51fd..918b4915560 100644
--- a/sdks/python/apache_beam/ml/inference/vllm_inference.py
+++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py
@@ -25,6 +25,7 @@ import sys
 import threading
 import time
 import uuid
+from collections.abc import Callable
 from collections.abc import Iterable
 from collections.abc import Sequence
 from dataclasses import dataclass
@@ -175,7 +176,13 @@ class VLLMCompletionsModelHandler(ModelHandler[str,
   def __init__(
       self,
       model_name: str,
-      vllm_server_kwargs: Optional[dict[str, str]] = None):
+      vllm_server_kwargs: Optional[dict[str, str]] = None,
+      *,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None):
     """Implementation of the ModelHandler interface for vLLM using text as
     input.
 
@@ -194,10 +201,24 @@ class VLLMCompletionsModelHandler(ModelHandler[str,
         `{'echo': 'true'}` to prepend new messages with the previous message.
         For a list of possible kwargs, see
         https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters-for-completions-api
+      min_batch_size: optional. the minimum batch size to use when batching
+        inputs.
+      max_batch_size: optional. the maximum batch size to use when batching
+        inputs.
+      max_batch_duration_secs: optional. the maximum amount of time to buffer
+        a batch before emitting; used in streaming contexts.
+      max_batch_weight: optional. the maximum total weight of a batch.
+      element_size_fn: optional. a function that returns the size (weight) of
+        an element.
     """
+    super().__init__(
+        min_batch_size=min_batch_size,
+        max_batch_size=max_batch_size,
+        max_batch_duration_secs=max_batch_duration_secs,
+        max_batch_weight=max_batch_weight,
+        element_size_fn=element_size_fn)
     self._model_name = model_name
     self._vllm_server_kwargs: dict[str, str] = vllm_server_kwargs or {}
-    self._env_vars = {}
 
   def load_model(self) -> _VLLMModelServer:
     return _VLLMModelServer(self._model_name, self._vllm_server_kwargs)
@@ -253,7 +274,13 @@ class VLLMChatModelHandler(ModelHandler[Sequence[OpenAIChatMessage],
       self,
       model_name: str,
       chat_template_path: Optional[str] = None,
-      vllm_server_kwargs: Optional[dict[str, str]] = None):
+      vllm_server_kwargs: Optional[dict[str, str]] = None,
+      *,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None):
     """ Implementation of the ModelHandler interface for vLLM using previous
     messages as input.
 
@@ -277,10 +304,24 @@ class VLLMChatModelHandler(ModelHandler[Sequence[OpenAIChatMessage],
         `{'echo': 'true'}` to prepend new messages with the previous message.
         For a list of possible kwargs, see
         https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters-for-chat-api
+      min_batch_size: optional. the minimum batch size to use when batching
+        inputs.
+      max_batch_size: optional. the maximum batch size to use when batching
+        inputs.
+      max_batch_duration_secs: optional. the maximum amount of time to buffer
+        a batch before emitting; used in streaming contexts.
+      max_batch_weight: optional. the maximum total weight of a batch.
+      element_size_fn: optional. a function that returns the size (weight) of
+        an element.
     """
+    super().__init__(
+        min_batch_size=min_batch_size,
+        max_batch_size=max_batch_size,
+        max_batch_duration_secs=max_batch_duration_secs,
+        max_batch_weight=max_batch_weight,
+        element_size_fn=element_size_fn)
     self._model_name = model_name
     self._vllm_server_kwargs: dict[str, str] = vllm_server_kwargs or {}
-    self._env_vars = {}
     self._chat_template_path = chat_template_path
     self._chat_file = f'template-{uuid.uuid4().hex}.jinja'
 
diff --git a/sdks/python/apache_beam/ml/inference/xgboost_inference.py b/sdks/python/apache_beam/ml/inference/xgboost_inference.py
index 10289b07641..9d741368511 100644
--- a/sdks/python/apache_beam/ml/inference/xgboost_inference.py
+++ b/sdks/python/apache_beam/ml/inference/xgboost_inference.py
@@ -19,7 +19,6 @@ import sys
 from abc import ABC
 from collections.abc import Callable
 from collections.abc import Iterable
-from collections.abc import Mapping
 from collections.abc import Sequence
 from typing import Any
 from typing import Optional
@@ -79,6 +78,8 @@ class XGBoostModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for XGBoost.
 
@@ -103,8 +104,11 @@ class XGBoostModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
         inputs.
       max_batch_size: optional. the maximum batch size to use when batching
         inputs.
-      max_batch_duration_secs: optional. the maximum amount of time to buffer 
+      max_batch_duration_secs: optional. the maximum amount of time to buffer
         a batch before emitting; used in streaming contexts.
+      max_batch_weight: optional. the maximum total weight of a batch.
+      element_size_fn: optional. a function that returns the size (weight)
+        of an element.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
 
@@ -121,17 +125,16 @@ class XGBoostModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
     and should not be instantiated directly. (See instead
     XGBoostModelHandlerNumpy, XGBoostModelHandlerPandas, etc.)
     """
+    super().__init__(
+        min_batch_size=min_batch_size,
+        max_batch_size=max_batch_size,
+        max_batch_duration_secs=max_batch_duration_secs,
+        max_batch_weight=max_batch_weight,
+        element_size_fn=element_size_fn,
+        **kwargs)
     self._model_class = model_class
     self._model_state = model_state
     self._inference_fn = inference_fn
-    self._env_vars = kwargs.get('env_vars', {})
-    self._batching_kwargs = {}
-    if min_batch_size is not None:
-      self._batching_kwargs["min_batch_size"] = min_batch_size
-    if max_batch_size is not None:
-      self._batching_kwargs["max_batch_size"] = max_batch_size
-    if max_batch_duration_secs is not None:
-      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
 
   def load_model(self) -> Union[xgboost.Booster, xgboost.XGBModel]:
     model = self._model_class()
@@ -146,9 +149,6 @@ class XGBoostModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
   def get_metrics_namespace(self) -> str:
     return 'BeamML_XGBoost'
 
-  def batch_elements_kwargs(self) -> Mapping[str, Any]:
-    return self._batching_kwargs
-
 
 class XGBoostModelHandlerNumpy(XGBoostModelHandler[numpy.ndarray,
                                                    PredictionResult,

Reply via email to