Eliaaazzz opened a new pull request, #37428:
URL: https://github.com/apache/beam/pull/37428
[RunInference] Add content-aware dynamic batching via element_size_fn (Issue
#37414)
### Rationale
This PR addresses #37414 by introducing content-aware dynamic batching to
`RunInference`.
Currently, `RunInference` relies on `BatchElements` with a strict
count-based limit (`max_batch_size`). However, for workloads like NLP and LLMs,
variable-length inputs (tokens) lead to significant variance in computational
cost and memory usage. A batch of 10 short sentences is vastly different from a
batch of 10 long documents.
This change allows users to provide a custom `element_size_fn` to
`ModelHandler`, which is then passed down to the underlying `BatchElements`
transform. This enables batching based on total "weight" (e.g., token count)
rather than just element count, improving GPU utilization and preventing OOM
errors.
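To make the count-versus-weight distinction concrete, here is a small, self-contained sketch of greedy weight-capped batching (plain Python, no Beam APIs; the `weight_capped_batches` helper and its parameters are illustrative only and not part of this PR):
```python
from typing import Callable, Iterable, List


def weight_capped_batches(
    elements: Iterable[str],
    element_size_fn: Callable[[str], int],
    max_batch_weight: int) -> List[List[str]]:
  """Greedily packs elements so each batch's total weight stays under the cap."""
  batches, current, current_weight = [], [], 0
  for element in elements:
    weight = element_size_fn(element)
    # Start a new batch if adding this element would exceed the cap.
    if current and current_weight + weight > max_batch_weight:
      batches.append(current)
      current, current_weight = [], 0
    current.append(element)
    current_weight += weight
  if current:
    batches.append(current)
  return batches


# Two short sentences share a batch; the long document gets its own.
print(weight_capped_batches(
    ["a b", "c d e", "f g h i j k l m n o p q r s t"],
    element_size_fn=lambda text: len(text.split()),
    max_batch_weight=10))
```
In this PR the same idea is not hand-rolled; it is delegated to `BatchElements` by forwarding the user's `element_size_fn`.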
### Design Principles
This implementation prioritizes modularity and type safety through the
following design choices:
* **Decorator Pattern (Composition over Inheritance)**:
Implemented `_SizingModelHandler` as a wrapper to dynamically attach
sizing behavior to *any* `ModelHandler` implementation. This avoids the
combinatorial explosion of subclasses (e.g., `TFModelHandlerWithSizing`,
`PyTorchModelHandlerWithSizing`) and keeps the codebase DRY.
* **Open-Closed Principle (OCP)**:
The change is strictly additive. The base `ModelHandler` remains closed
for modification, ensuring zero regression risk for existing pipelines.
Functionality is extended purely by overriding `batch_elements_kwargs` in the
wrapper and safely delegating all other methods to the base instance.
* **Architectural Consistency**:
The implementation mirrors the existing `_PreProcessingModelHandler`
pattern in Apache Beam. This ensures API consistency and reduces cognitive load
for maintainers.
### Changes
* **`apache_beam/ml/inference/base.py`**:
* Added `with_element_size_fn` method to the `ModelHandler` interface.
* Implemented `_SizingModelHandler` wrapper class.
* Overrode `batch_elements_kwargs` to inject `element_size_fn` while
preserving existing configuration (using a safe dictionary copy).
* Implemented full delegation for all `ModelHandler` methods (e.g.,
`update_model_paths`, `get_metrics_namespace`) to ensure transparency.
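For orientation, here is a minimal sketch of the wrapper shape described above. It is illustrative only, not the actual diff: only a few of the delegated `ModelHandler` methods are shown, and the kwargs-copy behavior follows the description in this PR.
```python
from apache_beam.ml.inference.base import ModelHandler


class _SizingModelHandler(ModelHandler):
  """Illustrative decorator: wraps any ModelHandler and injects
  element_size_fn into the kwargs forwarded to BatchElements."""

  def __init__(self, base, element_size_fn):
    super().__init__()
    self._base = base
    self._element_size_fn = element_size_fn

  # Delegation (the real change delegates every ModelHandler method).
  def load_model(self):
    return self._base.load_model()

  def run_inference(self, batch, model, inference_args=None):
    return self._base.run_inference(batch, model, inference_args)

  def get_metrics_namespace(self) -> str:
    return self._base.get_metrics_namespace()

  def batch_elements_kwargs(self):
    # Copy so the base handler's kwargs mapping is never mutated.
    kwargs = dict(self._base.batch_elements_kwargs())
    kwargs['element_size_fn'] = self._element_size_fn
    return kwargs
```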
### Usage Example
```python
def token_counter(text: str) -> int:
  return len(text.split())


# Configure the handler to batch based on token count (e.g., max 1000 tokens
# per batch). The max_batch_size in BatchElements will now act as a limit on
# the sum of element sizes.
model_handler = MyModelHandler().with_element_size_fn(token_counter)
```
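Assuming the `MyModelHandler` placeholder and `token_counter` from the example above, the wrapped handler plugs into a pipeline like any other handler (sketch only):
```python
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference

with beam.Pipeline() as p:
  # MyModelHandler is the placeholder handler from the usage example above.
  _ = (
      p
      | beam.Create(["a short sentence", "a much longer document ..."])
      | RunInference(MyModelHandler().with_element_size_fn(token_counter)))
```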
### Testing
Comprehensive tests were added in
`sdks/python/apache_beam/ml/inference/base_test.py`:
1. **`test_kwargs_are_passed_correctly`**
* Verifies that `element_size_fn` is correctly injected into
`batch_elements_kwargs`.
2. **`test_batch_elements_integration_with_beam_pipeline`**
* Verifies run-time behavior.
* *Scenario*: Input elements with weight 5, `max_batch_size` set to 10.
* *Result*: `BatchElements` correctly creates batches of 2 elements
(5+5=10), confirming the dynamic sizing logic.
3. **`test_element_size_fn_wrapper_delegates_correctly`**
* Ensures all `ModelHandler` methods are properly delegated (Critical
for features like model updates).
4. **`test_multiple_wrappers_can_be_chained`**
* Verifies compatibility when chained with `with_preprocess_fn`.
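For illustration, here is a hedged sketch of what the first check could look like with a fake handler. This is not the actual test code in `base_test.py`, and it presumes the `with_element_size_fn` method added by this PR.
```python
import unittest

from apache_beam.ml.inference.base import ModelHandler


class _FakeModelHandler(ModelHandler):
  def load_model(self):
    return object()

  def run_inference(self, batch, model, inference_args=None):
    return batch

  def batch_elements_kwargs(self):
    return {'max_batch_size': 10}


class ElementSizeFnKwargsTest(unittest.TestCase):
  def test_kwargs_are_passed_correctly(self):
    size_fn = lambda element: 5
    handler = _FakeModelHandler().with_element_size_fn(size_fn)
    kwargs = handler.batch_elements_kwargs()
    # The existing kwargs survive and element_size_fn is injected.
    self.assertEqual(kwargs['max_batch_size'], 10)
    self.assertIs(kwargs['element_size_fn'], size_fn)


if __name__ == '__main__':
  unittest.main()
```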
------------------------
Thank you for your contribution! Follow this checklist to help us
incorporate your contribution quickly and easily:
[x] fixes #37414