This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 30518c840bf [Stateful] Implement length-aware keying to minimize padding in BatchElements (Part 2/3) (#37565)
30518c840bf is described below

commit 30518c840bf9dcbc93c5ec4ac1b7b9717e455543
Author: Elia LIU <[email protected]>
AuthorDate: Fri Feb 27 04:49:41 2026 -0800

    [Stateful] Implement length-aware keying to minimize padding in BatchElements (Part 2/3) (#37565)
    
    * Add length-aware batching to BatchElements and ModelHandler
    
    - Add length_fn and bucket_boundaries parameters to ModelHandler.__init__
      to support length-aware bucketed keying for ML inference batching
    - Add WithLengthBucketKey DoFn to route elements by length buckets
    - Update BatchElements to support length-aware batching when
      max_batch_duration_secs is set, reducing padding waste for
      variable-length sequences (e.g., NLP workloads)
    - Default bucket boundaries: [16, 32, 64, 128, 256, 512]
    - Add comprehensive tests validating bucket assignment, mixed-length
      batching, and padding efficiency improvements (77% vs 68% on bimodal data)
    - All formatting (yapf) and lint (pylint 10/10) checks passed
    
    * Refine length bucketing docs and fix boundary inclusivity
    
    Expands parameter documentation for clarity and replaces bisect_left with
    bisect_right to ensure bucket boundaries are inclusive on the lower bound.
    Updates util_test.py assertions accordingly.
---
 sdks/python/apache_beam/ml/inference/base.py      |  17 ++
 sdks/python/apache_beam/ml/inference/base_test.py |  39 ++++
 sdks/python/apache_beam/transforms/util.py        |  60 +++++-
 sdks/python/apache_beam/transforms/util_test.py   | 233 ++++++++++++++++++++++
 4 files changed, 347 insertions(+), 2 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index ef5d15264b5..b2441281dd1 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -178,6 +178,8 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
       max_batch_duration_secs: Optional[int] = None,
       max_batch_weight: Optional[int] = None,
       element_size_fn: Optional[Callable[[Any], int]] = None,
+      batch_length_fn: Optional[Callable[[Any], int]] = None,
+      batch_bucket_boundaries: Optional[list[int]] = None,
       large_model: bool = False,
       model_copies: Optional[int] = None,
       **kwargs):
@@ -190,6 +192,17 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
         before emitting; used in streaming contexts.
       max_batch_weight: the maximum weight of a batch. Requires element_size_fn.
       element_size_fn: a function that returns the size (weight) of an element.
+      batch_length_fn: a callable mapping an element to its length (int). When
+        set together with max_batch_duration_secs, enables length-aware bucketed
+        keying so that elements of similar length are batched together, reducing
+        padding waste for variable-length inputs. Bucket assignment uses
+        bisect_right so boundaries are lower-inclusive: e.g., for boundaries
+        [10, 50], buckets are (-inf, 10), [10, 50), [50, inf).
+      batch_bucket_boundaries: a sorted list of positive boundary values for
+        length bucketing. Boundaries are lower-inclusive (bisect_right
+        semantics): bucket i covers lengths in [boundaries[i-1], boundaries[i]).
+        Requires batch_length_fn. Defaults to [16, 32, 64, 128, 256, 512] when
+        batch_length_fn is set.
       large_model: set to true if your model is large enough to run into
         memory pressure if you load multiple copies.
       model_copies: The exact number of models that you would like loaded
@@ -209,6 +222,10 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
       self._batching_kwargs['max_batch_weight'] = max_batch_weight
     if element_size_fn is not None:
       self._batching_kwargs['element_size_fn'] = element_size_fn
+    if batch_length_fn is not None:
+      self._batching_kwargs['length_fn'] = batch_length_fn
+    if batch_bucket_boundaries is not None:
+      self._batching_kwargs['bucket_boundaries'] = batch_bucket_boundaries
     self._large_model = large_model
     self._model_copies = model_copies
     self._share_across_processes = large_model or (model_copies is not None)
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 8236ac5c1e5..f25316f474f 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -2279,6 +2279,45 @@ class ModelHandlerBatchingArgsTest(unittest.TestCase):
 
     self.assertEqual(kwargs, {'max_batch_duration_secs': 60})
 
+  def test_batch_length_fn_and_batch_bucket_boundaries(self):
+    """batch_length_fn and batch_bucket_boundaries passed through to kwargs."""
+    handler = FakeModelHandlerForBatching(
+        batch_length_fn=len, batch_bucket_boundaries=[16, 32, 64])
+    kwargs = handler.batch_elements_kwargs()
+
+    self.assertIs(kwargs['length_fn'], len)
+    self.assertEqual(kwargs['bucket_boundaries'], [16, 32, 64])
+
+  def test_batch_length_fn_only(self):
+    """batch_length_fn alone is passed through without bucket_boundaries."""
+    handler = FakeModelHandlerForBatching(batch_length_fn=len)
+    kwargs = handler.batch_elements_kwargs()
+
+    self.assertIs(kwargs['length_fn'], len)
+    self.assertNotIn('bucket_boundaries', kwargs)
+
+  def test_batch_bucket_boundaries_without_batch_length_fn(self):
+    """Passing batch_bucket_boundaries without batch_length_fn should fail in
+    BatchElements.
+
+    Note: ModelHandler.__init__ doesn't validate this; the error is raised
+    by BatchElements when batch_elements_kwargs are used."""
+    handler = FakeModelHandlerForBatching(batch_bucket_boundaries=[10, 20])
+    kwargs = handler.batch_elements_kwargs()
+    # The kwargs are stored, but BatchElements will reject them
+    self.assertEqual(kwargs['bucket_boundaries'], [10, 20])
+    self.assertNotIn('length_fn', kwargs)
+
+  def test_batching_kwargs_none_values_omitted(self):
+    """None values for batch_length_fn and batch_bucket_boundaries are not in
+    kwargs."""
+    handler = FakeModelHandlerForBatching(
+        min_batch_size=5, batch_length_fn=None, batch_bucket_boundaries=None)
+    kwargs = handler.batch_elements_kwargs()
+    self.assertNotIn('length_fn', kwargs)
+    self.assertNotIn('bucket_boundaries', kwargs)
+    self.assertEqual(kwargs['min_batch_size'], 5)
+
 
 class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]):
   def load_model(self):
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index 770a5baec36..1faac99b64d 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -20,6 +20,7 @@
 
 # pytype: skip-file
 
+import bisect
 import collections
 import contextlib
 import hashlib
@@ -1208,6 +1209,30 @@ class WithSharedKey(DoFn):
     yield (self.key, element)
 
 
+class WithLengthBucketKey(DoFn):
+  """Keys elements with (worker_uuid, length_bucket) for length-aware
+  stateful batching. Elements of similar length are routed to the same
+  state partition, reducing padding waste."""
+  def __init__(self, length_fn, bucket_boundaries):
+    self.shared_handle = shared.Shared()
+    self._length_fn = length_fn
+    self._bucket_boundaries = bucket_boundaries
+
+  def setup(self):
+    self.key = self.shared_handle.acquire(
+        load_shared_key, "WithLengthBucketKey").key
+
+  def _get_bucket(self, length):
+    # bisect_right: boundaries are lower-inclusive.
+    # e.g., for boundaries [10, 50], buckets are (-inf, 10), [10, 50), [50, inf)
+    return bisect.bisect_right(self._bucket_boundaries, length)
+
+  def process(self, element):
+    length = self._length_fn(element)
+    bucket = self._get_bucket(length)
+    yield ((self.key, bucket), element)
+
+
 @typehints.with_input_types(T)
 @typehints.with_output_types(list[T])
 class BatchElements(PTransform):
@@ -1267,7 +1292,18 @@ class BatchElements(PTransform):
         donwstream operations (mostly for testing)
     record_metrics: (optional) whether or not to record beam metrics on
         distributions of the batch size. Defaults to True.
+    length_fn: (optional) a callable mapping an element to its length (int).
+        When set together with bucket_boundaries, enables length-aware bucketed
+        keying on the stateful path so that elements of similar length are
+        routed to the same batch, reducing padding waste.
+    bucket_boundaries: (optional) a sorted list of positive boundary values
+        for length bucketing. Boundaries are lower-inclusive (bisect_right
+        semantics): e.g., for boundaries [10, 50], buckets are (-inf, 10),
+        [10, 50), [50, inf). Defaults to [16, 32, 64, 128, 256, 512] when
+        length_fn is set. Requires length_fn.
   """
+  _DEFAULT_BUCKET_BOUNDARIES = [16, 32, 64, 128, 256, 512]
+
   def __init__(
       self,
       min_batch_size=1,
@@ -1280,7 +1316,17 @@ class BatchElements(PTransform):
       element_size_fn=lambda x: 1,
       variance=0.25,
       clock=time.time,
-      record_metrics=True):
+      record_metrics=True,
+      length_fn=None,
+      bucket_boundaries=None):
+    if bucket_boundaries is not None and length_fn is None:
+      raise ValueError('bucket_boundaries requires length_fn to be set.')
+    if bucket_boundaries is not None:
+      if (not bucket_boundaries or any(b <= 0 for b in bucket_boundaries) or
+          bucket_boundaries != sorted(bucket_boundaries)):
+        raise ValueError(
+            'bucket_boundaries must be a non-empty sorted list of '
+            'positive values.')
     self._batch_size_estimator = _BatchSizeEstimator(
         min_batch_size=min_batch_size,
         max_batch_size=max_batch_size,
@@ -1294,13 +1340,23 @@ class BatchElements(PTransform):
     self._element_size_fn = element_size_fn
     self._max_batch_dur = max_batch_duration_secs
     self._clock = clock
+    self._length_fn = length_fn
+    if length_fn is not None and bucket_boundaries is None:
+      self._bucket_boundaries = self._DEFAULT_BUCKET_BOUNDARIES
+    else:
+      self._bucket_boundaries = bucket_boundaries
 
   def expand(self, pcoll):
     if getattr(pcoll.pipeline.runner, 'is_streaming', False):
       raise NotImplementedError("Requires stateful processing (BEAM-2687)")
     elif self._max_batch_dur is not None:
       coder = coders.registry.get_coder(pcoll)
-      return pcoll | ParDo(WithSharedKey()) | ParDo(
+      if self._length_fn is not None:
+        keying_dofn = WithLengthBucketKey(
+            self._length_fn, self._bucket_boundaries)
+      else:
+        keying_dofn = WithSharedKey()
+      return pcoll | ParDo(keying_dofn) | ParDo(
           _pardo_stateful_batch_elements(
               coder,
               self._batch_size_estimator,
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index 98edb4cc2bd..47c3ee54452 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -65,6 +65,7 @@ from apache_beam.testing.util import TestWindowedValue
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import contains_in_any_order
 from apache_beam.testing.util import equal_to
+from apache_beam.testing.util import is_not_empty
 from apache_beam.transforms import trigger
 from apache_beam.transforms import util
 from apache_beam.transforms import window
@@ -1025,6 +1026,238 @@ class BatchElementsTest(unittest.TestCase):
           | beam.Map(len))
       assert_that(res, equal_to([1, 1, 2, 4, 8, 16, 32, 50, 50]))
 
+  def test_length_bucket_assignment(self):
+    """WithLengthBucketKey assigns correct bucket indices."""
+    boundaries = [10, 50, 100]
+    dofn = util.WithLengthBucketKey(length_fn=len, bucket_boundaries=boundaries)
+    # bisect_right: boundaries are lower-inclusive.
+    # e.g., for boundaries [10, 50, 100], buckets are:
+    #   (-inf, 10), [10, 50), [50, 100), [100, inf)
+    self.assertEqual(dofn._get_bucket(5), 0)
+    self.assertEqual(dofn._get_bucket(10), 1)
+    self.assertEqual(dofn._get_bucket(11), 1)
+    self.assertEqual(dofn._get_bucket(50), 2)
+    self.assertEqual(dofn._get_bucket(51), 2)
+    self.assertEqual(dofn._get_bucket(100), 3)
+    self.assertEqual(dofn._get_bucket(101), 3)
+    self.assertEqual(dofn._get_bucket(999), 3)
+
+  def test_stateful_length_aware_constant_batch(self):
+    """Elements in distinct length groups produce separate batches."""
+    # Create short strings (len 1-5) and long strings (len 50-55)
+    short = ['x' * i for i in range(1, 6)] * 4  # 20 short strings
+    long = ['y' * i for i in range(50, 56)] * 4  # 24 long strings
+    elements = short + long
+
+    p = TestPipeline('FnApiRunner')
+    batches = (
+        p
+        | beam.Create(elements)
+        | util.BatchElements(
+            min_batch_size=5,
+            max_batch_size=10,
+            max_batch_duration_secs=100,
+            length_fn=len,
+            bucket_boundaries=[10, 50]))
+
+    # Verify that no batch mixes short and long elements
+    def check_no_mixing(batch):
+      lengths = [len(s) for s in batch]
+      min_len, max_len = min(lengths), max(lengths)
+      # Within a bucket, all elements should have similar length
+      assert max_len - min_len < 50, (
+          f'Batch mixed short and long: lengths {lengths}')
+      return True
+
+    checks = batches | beam.Map(check_no_mixing)
+    assert_that(checks, is_not_empty())
+    res = p.run()
+    res.wait_until_finish()
+
+  def test_stateful_length_aware_default_boundaries(self):
+    """Default boundaries [16, 32, 64, 128, 256, 512] are applied."""
+    be = util.BatchElements(max_batch_duration_secs=100, length_fn=len)
+    self.assertEqual(be._bucket_boundaries, [16, 32, 64, 128, 256, 512])
+
+  def test_length_aware_requires_length_fn(self):
+    """bucket_boundaries without length_fn raises ValueError."""
+    with self.assertRaises(ValueError):
+      util.BatchElements(
+          max_batch_duration_secs=100, bucket_boundaries=[10, 20])
+
+  def test_bucket_boundaries_must_be_sorted(self):
+    """Unsorted boundaries raise ValueError."""
+    with self.assertRaises(ValueError):
+      util.BatchElements(
+          max_batch_duration_secs=100,
+          length_fn=len,
+          bucket_boundaries=[50, 10, 100])
+
+  def test_bucket_boundaries_must_be_positive(self):
+    """Non-positive boundaries raise ValueError."""
+    with self.assertRaises(ValueError):
+      util.BatchElements(
+          max_batch_duration_secs=100,
+          length_fn=len,
+          bucket_boundaries=[0, 10, 100])
+
+  def test_length_fn_without_stateful_is_ignored(self):
+    """length_fn without max_batch_duration_secs uses non-stateful path."""
+    with TestPipeline() as p:
+      res = (
+          p
+          | beam.Create(['a', 'bb', 'ccc'])
+          | util.BatchElements(
+              min_batch_size=3, max_batch_size=3, length_fn=len)
+          | beam.Map(len))
+      assert_that(res, equal_to([3]))
+
+  def test_padding_efficiency_bimodal(self):
+    """Benchmark: length-aware bucketing yields better padding efficiency
+    than unbucketed batching on a bimodal length distribution.
+
+    Padding efficiency per batch = sum(lengths) / (max_len * batch_size).
+    With bucketing, short and long elements land in separate batches,
+    so each batch pads to a smaller max, improving efficiency.
+    """
+    random.seed(42)
+    short = ['x' * random.randint(5, 30) for _ in range(500)]
+    long = ['y' * random.randint(200, 512) for _ in range(500)]
+    elements = short + long
+    batch_size = 32
+
+    def batch_efficiency(batch):
+      """Returns (useful_tokens, padded_tokens) for one batch."""
+      lengths = [len(s) for s in batch]
+      return (sum(lengths), max(lengths) * len(lengths))
+
+    # Run WITH bucketing — collect (useful, padded) per batch
+    p_bucketed = TestPipeline('FnApiRunner')
+    bucketed_eff = (
+        p_bucketed
+        | 'CreateBucketed' >> beam.Create(elements)
+        | 'BatchBucketed' >> util.BatchElements(
+            min_batch_size=batch_size,
+            max_batch_size=batch_size,
+            max_batch_duration_secs=100,
+            length_fn=len,
+            bucket_boundaries=[16, 32, 64, 128, 256, 512])
+        | 'EffBucketed' >> beam.Map(batch_efficiency)
+        | 'SumBucketed' >> beam.CombineGlobally(
+            lambda pairs: (sum(p[0] for p in pairs), sum(p[1] for p in pairs))))
+
+    # Run WITHOUT bucketing
+    p_unbucketed = TestPipeline('FnApiRunner')
+    unbucketed_eff = (
+        p_unbucketed
+        | 'CreateUnbucketed' >> beam.Create(elements)
+        | 'BatchUnbucketed' >> util.BatchElements(
+            min_batch_size=batch_size,
+            max_batch_size=batch_size,
+            max_batch_duration_secs=100)
+        | 'EffUnbucketed' >> beam.Map(batch_efficiency)
+        | 'SumUnbucketed' >> beam.CombineGlobally(
+            lambda pairs: (sum(p[0] for p in pairs), sum(p[1] for p in pairs))))
+
+    def check_bucketed_above_threshold(totals):
+      useful, padded = totals[0]
+      eff = useful / padded if padded else 0
+      assert eff > 0.70, (
+          f'Bucketed padding efficiency {eff:.2%} should be > 70%')
+
+    def check_unbucketed_below_bucketed(totals):
+      useful, padded = totals[0]
+      eff = useful / padded if padded else 0
+      # With bimodal data in a single key, short elements get padded
+      # to the max of each batch which often includes long elements.
+      assert eff < 0.70, (
+          f'Unbucketed efficiency {eff:.2%} expected < 70% for '
+          f'bimodal distribution (sanity check)')
+
+    assert_that(bucketed_eff, check_bucketed_above_threshold)
+    res = p_bucketed.run()
+    res.wait_until_finish()
+
+    assert_that(unbucketed_eff, check_unbucketed_below_bucketed)
+    res = p_unbucketed.run()
+    res.wait_until_finish()
+
+  def test_with_length_bucket_key_setup_and_process(self):
+    """WithLengthBucketKey.setup() and process() work correctly in pipeline."""
+    boundaries = [10, 50]
+    elements = ['short', 'x' * 30, 'y' * 60]
+
+    with TestPipeline('FnApiRunner') as p:
+      result = (
+          p
+          | beam.Create(elements)
+          | beam.ParDo(util.WithLengthBucketKey(len, boundaries)))
+
+      def check_keys(keyed_elements):
+        # Each element should have format ((worker_key, bucket), element)
+        for (key, bucket), elem in keyed_elements:
+          # Verify key is a UUID string
+          assert isinstance(key, str) and len(key) > 0
+          # Verify bucket is correct
+          if len(elem) < 10:
+            assert bucket == 0, f'Expected bucket 0 for {elem}'
+          elif len(elem) < 50:
+            assert bucket == 1, f'Expected bucket 1 for {elem}'
+          else:
+            assert bucket == 2, f'Expected bucket 2 for {elem}'
+
+      assert_that(result, check_keys)
+
+  def test_bucket_boundaries_empty_list(self):
+    """Empty bucket_boundaries list raises ValueError."""
+    with self.assertRaises(ValueError):
+      util.BatchElements(
+          max_batch_duration_secs=100, length_fn=len, bucket_boundaries=[])
+
+  def test_with_custom_bucket_boundaries(self):
+    """Custom bucket_boundaries are used instead of defaults."""
+    custom_boundaries = [5, 15, 25]
+    be = util.BatchElements(
+        max_batch_duration_secs=100,
+        length_fn=len,
+        bucket_boundaries=custom_boundaries)
+    self.assertEqual(be._bucket_boundaries, custom_boundaries)
+
+  def test_length_fn_applied_in_pipeline(self):
+    """Verify length_fn is used for bucketing in stateful batching."""
+    # Create strings of different lengths that should go to different buckets
+    short_strings = ['x' * i for i in range(1, 5)]  # lengths 1-4, bucket 0
+    medium_strings = ['y' * i for i in range(20, 24)]  # lengths 20-23, bucket 1
+    elements = short_strings + medium_strings
+
+    with TestPipeline('FnApiRunner') as p:
+      batches = (
+          p
+          | beam.Create(elements)
+          | util.BatchElements(
+              min_batch_size=2,
+              max_batch_size=10,
+              max_batch_duration_secs=100,
+              length_fn=len,
+              bucket_boundaries=[10, 30]))
+
+      def check_batch_homogeneity(batch):
+        """Batches should contain elements of similar length."""
+        lengths = [len(s) for s in batch]
+        # If bucketing works, all elements should be in same bucket
+        # (either all < 10 or all between 10 and 30)
+        min_len, max_len = min(lengths), max(lengths)
+        if min_len < 10:
+          # Short bucket: all should be < 10
+          assert max_len < 10, f'Mixed batch: {lengths}'
+        else:
+          # Medium bucket: all should be >= 10
+          assert min_len >= 10, f'Mixed batch: {lengths}'
+        return True
+
+      checks = batches | beam.Map(check_batch_homogeneity)
+      assert_that(checks, is_not_empty())
+
 
 class IdentityWindowTest(unittest.TestCase):
   def test_window_preserved(self):

Reply via email to