This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 30518c840bf [Stateful] Implement length-aware keying to minimize
padding in BatchElements (Part 2/3) (#37565)
30518c840bf is described below
commit 30518c840bf9dcbc93c5ec4ac1b7b9717e455543
Author: Elia LIU <[email protected]>
AuthorDate: Fri Feb 27 04:49:41 2026 -0800
[Stateful] Implement length-aware keying to minimize padding in
BatchElements (Part 2/3) (#37565)
* Add length-aware batching to BatchElements and ModelHandler
- Add length_fn and bucket_boundaries parameters to ModelHandler.__init__
to support length-aware bucketed keying for ML inference batching
- Add WithLengthBucketKey DoFn to route elements by length buckets
- Update BatchElements to support length-aware batching when
max_batch_duration_secs is set, reducing padding waste for
variable-length sequences (e.g., NLP workloads)
- Default bucket boundaries: [16, 32, 64, 128, 256, 512]
- Add comprehensive tests validating bucket assignment, mixed-length
batching, and padding efficiency improvements (77% vs 68% on bimodal data)
- All formatting (yapf) and lint (pylint 10/10) checks passed
* Refine length bucketing docs and fix boundary inclusivity
Expands parameter documentation for clarity and replaces bisect_left with
bisect_right to ensure bucket boundaries are inclusive on the lower bound.
Updates util_test.py assertions accordingly.
---
sdks/python/apache_beam/ml/inference/base.py | 17 ++
sdks/python/apache_beam/ml/inference/base_test.py | 39 ++++
sdks/python/apache_beam/transforms/util.py | 60 +++++-
sdks/python/apache_beam/transforms/util_test.py | 233 ++++++++++++++++++++++
4 files changed, 347 insertions(+), 2 deletions(-)
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index ef5d15264b5..b2441281dd1 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -178,6 +178,8 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
max_batch_duration_secs: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
+ batch_length_fn: Optional[Callable[[Any], int]] = None,
+ batch_bucket_boundaries: Optional[list[int]] = None,
large_model: bool = False,
model_copies: Optional[int] = None,
**kwargs):
@@ -190,6 +192,17 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
before emitting; used in streaming contexts.
max_batch_weight: the maximum weight of a batch. Requires
element_size_fn.
element_size_fn: a function that returns the size (weight) of an element.
+ batch_length_fn: a callable mapping an element to its length (int). When
+ set together with max_batch_duration_secs, enables length-aware bucketed
+ keying so that elements of similar length are batched together, reducing
+ padding waste for variable-length inputs. Bucket assignment uses
+ bisect_right so boundaries are lower-inclusive: e.g., for boundaries
+ [10, 50], buckets are (-inf, 10), [10, 50), [50, inf).
+ batch_bucket_boundaries: a sorted list of positive boundary values for
+ length bucketing. Boundaries are lower-inclusive (bisect_right
+ semantics): bucket i covers lengths in [boundaries[i-1], boundaries[i]).
+ Requires batch_length_fn. Defaults to [16, 32, 64, 128, 256, 512] when
+ batch_length_fn is set.
large_model: set to true if your model is large enough to run into
memory pressure if you load multiple copies.
model_copies: The exact number of models that you would like loaded
@@ -209,6 +222,10 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
self._batching_kwargs['max_batch_weight'] = max_batch_weight
if element_size_fn is not None:
self._batching_kwargs['element_size_fn'] = element_size_fn
+ if batch_length_fn is not None:
+ self._batching_kwargs['length_fn'] = batch_length_fn
+ if batch_bucket_boundaries is not None:
+ self._batching_kwargs['bucket_boundaries'] = batch_bucket_boundaries
self._large_model = large_model
self._model_copies = model_copies
self._share_across_processes = large_model or (model_copies is not None)
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 8236ac5c1e5..f25316f474f 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -2279,6 +2279,45 @@ class ModelHandlerBatchingArgsTest(unittest.TestCase):
self.assertEqual(kwargs, {'max_batch_duration_secs': 60})
+ def test_batch_length_fn_and_batch_bucket_boundaries(self):
+ """batch_length_fn and batch_bucket_boundaries passed through to kwargs."""
+ handler = FakeModelHandlerForBatching(
+ batch_length_fn=len, batch_bucket_boundaries=[16, 32, 64])
+ kwargs = handler.batch_elements_kwargs()
+
+ self.assertIs(kwargs['length_fn'], len)
+ self.assertEqual(kwargs['bucket_boundaries'], [16, 32, 64])
+
+ def test_batch_length_fn_only(self):
+ """batch_length_fn alone is passed through without bucket_boundaries."""
+ handler = FakeModelHandlerForBatching(batch_length_fn=len)
+ kwargs = handler.batch_elements_kwargs()
+
+ self.assertIs(kwargs['length_fn'], len)
+ self.assertNotIn('bucket_boundaries', kwargs)
+
+ def test_batch_bucket_boundaries_without_batch_length_fn(self):
+ """Passing batch_bucket_boundaries without batch_length_fn should fail in
+ BatchElements.
+
+ Note: ModelHandler.__init__ doesn't validate this; the error is raised
+ by BatchElements when batch_elements_kwargs are used."""
+ handler = FakeModelHandlerForBatching(batch_bucket_boundaries=[10, 20])
+ kwargs = handler.batch_elements_kwargs()
+ # The kwargs are stored, but BatchElements will reject them
+ self.assertEqual(kwargs['bucket_boundaries'], [10, 20])
+ self.assertNotIn('length_fn', kwargs)
+
+ def test_batching_kwargs_none_values_omitted(self):
+ """None values for batch_length_fn and batch_bucket_boundaries are not in
+ kwargs."""
+ handler = FakeModelHandlerForBatching(
+ min_batch_size=5, batch_length_fn=None, batch_bucket_boundaries=None)
+ kwargs = handler.batch_elements_kwargs()
+ self.assertNotIn('length_fn', kwargs)
+ self.assertNotIn('bucket_boundaries', kwargs)
+ self.assertEqual(kwargs['min_batch_size'], 5)
+
class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]):
def load_model(self):
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index 770a5baec36..1faac99b64d 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -20,6 +20,7 @@
# pytype: skip-file
+import bisect
import collections
import contextlib
import hashlib
@@ -1208,6 +1209,30 @@ class WithSharedKey(DoFn):
yield (self.key, element)
+class WithLengthBucketKey(DoFn):
+ """Keys elements with (worker_uuid, length_bucket) for length-aware
+ stateful batching. Elements of similar length are routed to the same
+ state partition, reducing padding waste."""
+ def __init__(self, length_fn, bucket_boundaries):
+ self.shared_handle = shared.Shared()
+ self._length_fn = length_fn
+ self._bucket_boundaries = bucket_boundaries
+
+ def setup(self):
+ self.key = self.shared_handle.acquire(
+ load_shared_key, "WithLengthBucketKey").key
+
+ def _get_bucket(self, length):
+ # bisect_right: boundaries are lower-inclusive.
+ # e.g., for boundaries [10, 50], buckets are (-inf, 10), [10, 50), [50, inf)
+ return bisect.bisect_right(self._bucket_boundaries, length)
+
+ def process(self, element):
+ length = self._length_fn(element)
+ bucket = self._get_bucket(length)
+ yield ((self.key, bucket), element)
+
+
@typehints.with_input_types(T)
@typehints.with_output_types(list[T])
class BatchElements(PTransform):
@@ -1267,7 +1292,18 @@ class BatchElements(PTransform):
donwstream operations (mostly for testing)
record_metrics: (optional) whether or not to record beam metrics on
distributions of the batch size. Defaults to True.
+ length_fn: (optional) a callable mapping an element to its length (int).
+ When set together with bucket_boundaries, enables length-aware bucketed
+ keying on the stateful path so that elements of similar length are
+ routed to the same batch, reducing padding waste.
+ bucket_boundaries: (optional) a sorted list of positive boundary values
+ for length bucketing. Boundaries are lower-inclusive (bisect_right
+ semantics): e.g., for boundaries [10, 50], buckets are (-inf, 10),
+ [10, 50), [50, inf). Defaults to [16, 32, 64, 128, 256, 512] when
+ length_fn is set. Requires length_fn.
"""
+ _DEFAULT_BUCKET_BOUNDARIES = [16, 32, 64, 128, 256, 512]
+
def __init__(
self,
min_batch_size=1,
@@ -1280,7 +1316,17 @@ class BatchElements(PTransform):
element_size_fn=lambda x: 1,
variance=0.25,
clock=time.time,
- record_metrics=True):
+ record_metrics=True,
+ length_fn=None,
+ bucket_boundaries=None):
+ if bucket_boundaries is not None and length_fn is None:
+ raise ValueError('bucket_boundaries requires length_fn to be set.')
+ if bucket_boundaries is not None:
+ if (not bucket_boundaries or any(b <= 0 for b in bucket_boundaries) or
+ bucket_boundaries != sorted(bucket_boundaries)):
+ raise ValueError(
+ 'bucket_boundaries must be a non-empty sorted list of '
+ 'positive values.')
self._batch_size_estimator = _BatchSizeEstimator(
min_batch_size=min_batch_size,
max_batch_size=max_batch_size,
@@ -1294,13 +1340,23 @@ class BatchElements(PTransform):
self._element_size_fn = element_size_fn
self._max_batch_dur = max_batch_duration_secs
self._clock = clock
+ self._length_fn = length_fn
+ if length_fn is not None and bucket_boundaries is None:
+ self._bucket_boundaries = self._DEFAULT_BUCKET_BOUNDARIES
+ else:
+ self._bucket_boundaries = bucket_boundaries
def expand(self, pcoll):
if getattr(pcoll.pipeline.runner, 'is_streaming', False):
raise NotImplementedError("Requires stateful processing (BEAM-2687)")
elif self._max_batch_dur is not None:
coder = coders.registry.get_coder(pcoll)
- return pcoll | ParDo(WithSharedKey()) | ParDo(
+ if self._length_fn is not None:
+ keying_dofn = WithLengthBucketKey(
+ self._length_fn, self._bucket_boundaries)
+ else:
+ keying_dofn = WithSharedKey()
+ return pcoll | ParDo(keying_dofn) | ParDo(
_pardo_stateful_batch_elements(
coder,
self._batch_size_estimator,
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index 98edb4cc2bd..47c3ee54452 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -65,6 +65,7 @@ from apache_beam.testing.util import TestWindowedValue
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import contains_in_any_order
from apache_beam.testing.util import equal_to
+from apache_beam.testing.util import is_not_empty
from apache_beam.transforms import trigger
from apache_beam.transforms import util
from apache_beam.transforms import window
@@ -1025,6 +1026,238 @@ class BatchElementsTest(unittest.TestCase):
| beam.Map(len))
assert_that(res, equal_to([1, 1, 2, 4, 8, 16, 32, 50, 50]))
+ def test_length_bucket_assignment(self):
+ """WithLengthBucketKey assigns correct bucket indices."""
+ boundaries = [10, 50, 100]
+ dofn = util.WithLengthBucketKey(length_fn=len, bucket_boundaries=boundaries)
+ # bisect_right: boundaries are lower-inclusive.
+ # e.g., for boundaries [10, 50, 100], buckets are:
+ # (-inf, 10), [10, 50), [50, 100), [100, inf)
+ self.assertEqual(dofn._get_bucket(5), 0)
+ self.assertEqual(dofn._get_bucket(10), 1)
+ self.assertEqual(dofn._get_bucket(11), 1)
+ self.assertEqual(dofn._get_bucket(50), 2)
+ self.assertEqual(dofn._get_bucket(51), 2)
+ self.assertEqual(dofn._get_bucket(100), 3)
+ self.assertEqual(dofn._get_bucket(101), 3)
+ self.assertEqual(dofn._get_bucket(999), 3)
+
+ def test_stateful_length_aware_constant_batch(self):
+ """Elements in distinct length groups produce separate batches."""
+ # Create short strings (len 1-5) and long strings (len 50-55)
+ short = ['x' * i for i in range(1, 6)] * 4 # 20 short strings
+ long = ['y' * i for i in range(50, 56)] * 4 # 24 long strings
+ elements = short + long
+
+ p = TestPipeline('FnApiRunner')
+ batches = (
+ p
+ | beam.Create(elements)
+ | util.BatchElements(
+ min_batch_size=5,
+ max_batch_size=10,
+ max_batch_duration_secs=100,
+ length_fn=len,
+ bucket_boundaries=[10, 50]))
+
+ # Verify that no batch mixes short and long elements
+ def check_no_mixing(batch):
+ lengths = [len(s) for s in batch]
+ min_len, max_len = min(lengths), max(lengths)
+ # Within a bucket, all elements should have similar length
+ assert max_len - min_len < 50, (
+ f'Batch mixed short and long: lengths {lengths}')
+ return True
+
+ checks = batches | beam.Map(check_no_mixing)
+ assert_that(checks, is_not_empty())
+ res = p.run()
+ res.wait_until_finish()
+
+ def test_stateful_length_aware_default_boundaries(self):
+ """Default boundaries [16, 32, 64, 128, 256, 512] are applied."""
+ be = util.BatchElements(max_batch_duration_secs=100, length_fn=len)
+ self.assertEqual(be._bucket_boundaries, [16, 32, 64, 128, 256, 512])
+
+ def test_length_aware_requires_length_fn(self):
+ """bucket_boundaries without length_fn raises ValueError."""
+ with self.assertRaises(ValueError):
+ util.BatchElements(
+ max_batch_duration_secs=100, bucket_boundaries=[10, 20])
+
+ def test_bucket_boundaries_must_be_sorted(self):
+ """Unsorted boundaries raise ValueError."""
+ with self.assertRaises(ValueError):
+ util.BatchElements(
+ max_batch_duration_secs=100,
+ length_fn=len,
+ bucket_boundaries=[50, 10, 100])
+
+ def test_bucket_boundaries_must_be_positive(self):
+ """Non-positive boundaries raise ValueError."""
+ with self.assertRaises(ValueError):
+ util.BatchElements(
+ max_batch_duration_secs=100,
+ length_fn=len,
+ bucket_boundaries=[0, 10, 100])
+
+ def test_length_fn_without_stateful_is_ignored(self):
+ """length_fn without max_batch_duration_secs uses non-stateful path."""
+ with TestPipeline() as p:
+ res = (
+ p
+ | beam.Create(['a', 'bb', 'ccc'])
+ | util.BatchElements(
+ min_batch_size=3, max_batch_size=3, length_fn=len)
+ | beam.Map(len))
+ assert_that(res, equal_to([3]))
+
+ def test_padding_efficiency_bimodal(self):
+ """Benchmark: length-aware bucketing yields better padding efficiency
+ than unbucketed batching on a bimodal length distribution.
+
+ Padding efficiency per batch = sum(lengths) / (max_len * batch_size).
+ With bucketing, short and long elements land in separate batches,
+ so each batch pads to a smaller max, improving efficiency.
+ """
+ random.seed(42)
+ short = ['x' * random.randint(5, 30) for _ in range(500)]
+ long = ['y' * random.randint(200, 512) for _ in range(500)]
+ elements = short + long
+ batch_size = 32
+
+ def batch_efficiency(batch):
+ """Returns (useful_tokens, padded_tokens) for one batch."""
+ lengths = [len(s) for s in batch]
+ return (sum(lengths), max(lengths) * len(lengths))
+
+ # Run WITH bucketing — collect (useful, padded) per batch
+ p_bucketed = TestPipeline('FnApiRunner')
+ bucketed_eff = (
+ p_bucketed
+ | 'CreateBucketed' >> beam.Create(elements)
+ | 'BatchBucketed' >> util.BatchElements(
+ min_batch_size=batch_size,
+ max_batch_size=batch_size,
+ max_batch_duration_secs=100,
+ length_fn=len,
+ bucket_boundaries=[16, 32, 64, 128, 256, 512])
+ | 'EffBucketed' >> beam.Map(batch_efficiency)
+ | 'SumBucketed' >> beam.CombineGlobally(
+ lambda pairs: (sum(p[0] for p in pairs), sum(p[1] for p in pairs))))
+
+ # Run WITHOUT bucketing
+ p_unbucketed = TestPipeline('FnApiRunner')
+ unbucketed_eff = (
+ p_unbucketed
+ | 'CreateUnbucketed' >> beam.Create(elements)
+ | 'BatchUnbucketed' >> util.BatchElements(
+ min_batch_size=batch_size,
+ max_batch_size=batch_size,
+ max_batch_duration_secs=100)
+ | 'EffUnbucketed' >> beam.Map(batch_efficiency)
+ | 'SumUnbucketed' >> beam.CombineGlobally(
+ lambda pairs: (sum(p[0] for p in pairs), sum(p[1] for p in pairs))))
+
+ def check_bucketed_above_threshold(totals):
+ useful, padded = totals[0]
+ eff = useful / padded if padded else 0
+ assert eff > 0.70, (
+ f'Bucketed padding efficiency {eff:.2%} should be > 70%')
+
+ def check_unbucketed_below_bucketed(totals):
+ useful, padded = totals[0]
+ eff = useful / padded if padded else 0
+ # With bimodal data in a single key, short elements get padded
+ # to the max of each batch which often includes long elements.
+ assert eff < 0.70, (
+ f'Unbucketed efficiency {eff:.2%} expected < 70% for '
+ f'bimodal distribution (sanity check)')
+
+ assert_that(bucketed_eff, check_bucketed_above_threshold)
+ res = p_bucketed.run()
+ res.wait_until_finish()
+
+ assert_that(unbucketed_eff, check_unbucketed_below_bucketed)
+ res = p_unbucketed.run()
+ res.wait_until_finish()
+
+ def test_with_length_bucket_key_setup_and_process(self):
+ """WithLengthBucketKey.setup() and process() work correctly in pipeline."""
+ boundaries = [10, 50]
+ elements = ['short', 'x' * 30, 'y' * 60]
+
+ with TestPipeline('FnApiRunner') as p:
+ result = (
+ p
+ | beam.Create(elements)
+ | beam.ParDo(util.WithLengthBucketKey(len, boundaries)))
+
+ def check_keys(keyed_elements):
+ # Each element should have format ((worker_key, bucket), element)
+ for (key, bucket), elem in keyed_elements:
+ # Verify key is a UUID string
+ assert isinstance(key, str) and len(key) > 0
+ # Verify bucket is correct
+ if len(elem) < 10:
+ assert bucket == 0, f'Expected bucket 0 for {elem}'
+ elif len(elem) < 50:
+ assert bucket == 1, f'Expected bucket 1 for {elem}'
+ else:
+ assert bucket == 2, f'Expected bucket 2 for {elem}'
+
+ assert_that(result, check_keys)
+
+ def test_bucket_boundaries_empty_list(self):
+ """Empty bucket_boundaries list raises ValueError."""
+ with self.assertRaises(ValueError):
+ util.BatchElements(
+ max_batch_duration_secs=100, length_fn=len, bucket_boundaries=[])
+
+ def test_with_custom_bucket_boundaries(self):
+ """Custom bucket_boundaries are used instead of defaults."""
+ custom_boundaries = [5, 15, 25]
+ be = util.BatchElements(
+ max_batch_duration_secs=100,
+ length_fn=len,
+ bucket_boundaries=custom_boundaries)
+ self.assertEqual(be._bucket_boundaries, custom_boundaries)
+
+ def test_length_fn_applied_in_pipeline(self):
+ """Verify length_fn is used for bucketing in stateful batching."""
+ # Create strings of different lengths that should go to different buckets
+ short_strings = ['x' * i for i in range(1, 5)] # lengths 1-4, bucket 0
+ medium_strings = ['y' * i for i in range(20, 24)]  # lengths 20-23, bucket 1
+ elements = short_strings + medium_strings
+
+ with TestPipeline('FnApiRunner') as p:
+ batches = (
+ p
+ | beam.Create(elements)
+ | util.BatchElements(
+ min_batch_size=2,
+ max_batch_size=10,
+ max_batch_duration_secs=100,
+ length_fn=len,
+ bucket_boundaries=[10, 30]))
+
+ def check_batch_homogeneity(batch):
+ """Batches should contain elements of similar length."""
+ lengths = [len(s) for s in batch]
+ # If bucketing works, all elements should be in same bucket
+ # (either all < 10 or all between 10 and 30)
+ min_len, max_len = min(lengths), max(lengths)
+ if min_len < 10:
+ # Short bucket: all should be < 10
+ assert max_len < 10, f'Mixed batch: {lengths}'
+ else:
+ # Medium bucket: all should be >= 10
+ assert min_len >= 10, f'Mixed batch: {lengths}'
+ return True
+
+ checks = batches | beam.Map(check_batch_homogeneity)
+ assert_that(checks, is_not_empty())
+
class IdentityWindowTest(unittest.TestCase):
def test_window_preserved(self):