Eliaaazzz opened a new issue, #37414:
URL: https://github.com/apache/beam/issues/37414
### What would you like to happen?
I’ve been evaluating RunInference for variable-length inputs (NLP/LLMs), and
the current batching approach (via BatchElements) feels too coarse for these
workloads.
Right now batching is mostly count-based (batch size) or byte-based. That
works well for fixed-shape inputs (e.g., images), but for highly variable
sequence lengths it creates a few issues:
1. Padding waste: One long sequence in a batch forces padding for many short
sequences, wasting GPU compute on padding tokens.
2. Unpredictable OOMs: A single outlier (very long input) can spike memory
usage for the entire batch and cause hard-to-reproduce OOMs.
3. Boilerplate for users: I ended up writing a custom DoFn to “bucket by
length” before RunInference (a rough sketch follows this list). This is a
common pattern in NLP and would be better supported natively.
**Proposed solution**
Add content-aware batching to RunInference, so batching can be driven by
computational cost rather than only element count.
Conceptually:
- [ ] Cost-based thresholds: Let users provide a cost_fn (e.g., token count)
and a max_cost_per_batch (e.g., 4096 tokens) to form batches.
- [ ] Optional bucketing: Buffer elements and group similar lengths to
reduce padding overhead.
- [ ] Dynamic padding integration: Ensure the ModelHandler (e.g., PyTorch)
pads to the batch max length (per batch), not a global max.
**Sketch API (conceptual)**

```python
pipeline | RunInference(
    model_handler=...,
    batching_kwargs={
        "mode": "dynamic",
        "max_cost": 4096,             # e.g., total tokens per batch
        "cost_fn": lambda x: len(x),  # user-defined cost metric
        "bucket": True,               # optional
    },
)
```
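For the dynamic-padding item, the important part is that padding is computed per batch rather than against a global maximum length. A standalone PyTorch sketch of that step (not tied to any specific ModelHandler hook; `pad_id=0` is an assumption about the tokenizer):

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_to_batch_max(batch, pad_id=0):
  """Pads a list of 1-D token-id tensors to the longest sequence in *this*
  batch and returns the padded ids plus an attention mask."""
  padded = pad_sequence(batch, batch_first=True, padding_value=pad_id)
  mask = (padded != pad_id).long()
  return padded, mask


# The longest sequence here has 5 tokens, so the batch is padded to length 5,
# not to some global maximum like 512.
batch = [torch.tensor([101, 7592, 102]),
         torch.tensor([101, 7592, 2088, 999, 102])]
input_ids, attention_mask = pad_to_batch_max(batch)  # both have shape (2, 5)
```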
**Alternatives considered**
1. GroupIntoBatches with weights: Works, but is verbose and separates
batching from inference/padding logic.
2. Static padding to max length: Too inefficient for production latency/cost
requirements.
**Additional context**
I’ve skimmed base.py and pytorch_inference.py. My initial thought is a
TokenBasedBatcher (or an extension around BatchElements) that can be reused by
RunInference.
I’m interested in working on this for GSoC 2026 and can draft a design. Any
pointers to existing work or a preferred direction would be appreciated.
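To make the direction question concrete, this is roughly the shape I imagine for a TokenBasedBatcher. The class name and the cost_fn / max_cost parameters mirror the conceptual API above and are hypothetical, not existing Beam symbols:

```python
import apache_beam as beam
from apache_beam.transforms.window import GlobalWindow
from apache_beam.utils.timestamp import MIN_TIMESTAMP
from apache_beam.utils.windowed_value import WindowedValue


class _TokenBudgetBatchFn(beam.DoFn):
  """Accumulates elements within a bundle until the summed cost would exceed
  the budget, then emits the buffered elements as one batch."""

  def __init__(self, cost_fn, max_cost):
    self._cost_fn = cost_fn
    self._max_cost = max_cost

  def start_bundle(self):
    self._buffer = []
    self._cost = 0

  def process(self, element):
    cost = self._cost_fn(element)
    if self._buffer and self._cost + cost > self._max_cost:
      yield self._buffer
      self._buffer, self._cost = [], 0
    self._buffer.append(element)
    self._cost += cost

  def finish_bundle(self):
    if self._buffer:
      yield WindowedValue(self._buffer, MIN_TIMESTAMP, [GlobalWindow()])


class TokenBasedBatcher(beam.PTransform):
  """Hypothetical transform: cost-aware batching driven by a user-supplied
  cost_fn (e.g. token count) instead of element count."""

  def __init__(self, cost_fn=len, max_cost=4096):
    self._cost_fn = cost_fn
    self._max_cost = max_cost

  def expand(self, pcoll):
    return pcoll | beam.ParDo(
        _TokenBudgetBatchFn(self._cost_fn, self._max_cost))
```

Inside RunInference this could either replace the internal BatchElements step or be selected via something like the `batching_kwargs` above, so the ModelHandler still receives batches of elements as it does today.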
### Issue Priority
Priority: 2 (default / most feature requests should be filed as P2)
### Issue Components
- [x] Component: Python SDK
- [ ] Component: Java SDK
- [ ] Component: Go SDK
- [ ] Component: Typescript SDK
- [ ] Component: IO connector
- [ ] Component: Beam YAML
- [ ] Component: Beam examples
- [ ] Component: Beam playground
- [ ] Component: Beam katas
- [ ] Component: Website
- [ ] Component: Infrastructure
- [ ] Component: Spark Runner
- [ ] Component: Flink Runner
- [ ] Component: Samza Runner
- [ ] Component: Twister2 Runner
- [ ] Component: Hazelcast Jet Runner
- [ ] Component: Google Cloud Dataflow Runner