This is an automated email from the ASF dual-hosted git repository.

Abacn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new b30af9eb79c BatchElements transform for Java SDK (#38369)
b30af9eb79c is described below

commit b30af9eb79cea657888fa8642c9e0b82387313ad
Author: Ganesh Sivakumar <[email protected]>
AuthorDate: Thu May 7 23:54:45 2026 +0530

    BatchElements transform for Java SDK (#38369)
    
    * skeleton
    
    * calculate batch size
    
    * global window batching
    
    * window aware batching
    
    * unit tests
    
    * java api docs
    
    * gemini comments
    
    * checkstyle
    
    * fix spotbug
    
    * update doc and changes.md
    
    * checkstyle
    
    * whitespace
    
    ---------
    
    Co-authored-by: Ganeshsivakumar <[email protected]>
---
 CHANGES.md                                         |   3 +-
 .../apache/beam/sdk/transforms/BatchElements.java  | 601 +++++++++++++++++++++
 .../beam/sdk/transforms/BatchElementsTest.java     | 598 ++++++++++++++++++++
 3 files changed, 1201 insertions(+), 1 deletion(-)

diff --git a/CHANGES.md b/CHANGES.md
index f9b9f1d2848..922c38ffef3 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -77,6 +77,7 @@
   compatible. Both coders can decode encoded bytes from the other coder
   ([#38139](https://github.com/apache/beam/issues/38139)).
 * (Python) Added type alias for with_exception_handling to be used for 
typehints. ([#38173](https://github.com/apache/beam/issues/38173)).
+* (Java) Added BatchElements transform for the Java SDK ([#38369](https://github.com/apache/beam/issues/38369)).
 * Added plugin mechanism to support different Lineage implementations (Java) 
([#36790](https://github.com/apache/beam/issues/36790)).
 * (Python) Supported Python user type in Beam SQL. For example, SQL statements 
like `SELECT some_field from PCOLLECTION` can now operate a PCollection of Beam 
Row containing pickable Python user type 
([#20738](https://github.com/apache/beam/issues/20738)).
 * (Python) Introduced `beam.coders.registry.register_row` as preferred API to 
register a named tuple or dataclass with a Beam Row. At pipelne runtime, the 
original type associated with the registered row are preserved across the 
serialization boundary ([#38108](https://github.com/apache/beam/issues/38108)).
@@ -2435,4 +2436,4 @@ Schema Options, it will be removed in version `2.23.0`. 
([BEAM-9704](https://iss
 
 ## Highlights
 
-- For versions 2.19.0 and older release notes are available on [Apache Beam 
Blog](https://beam.apache.org/blog/).
+- For versions 2.19.0 and older release notes are available on [Apache Beam 
Blog](https://beam.apache.org/blog/).
\ No newline at end of file
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/BatchElements.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/BatchElements.java
new file mode 100644
index 00000000000..35796d1b138
--- /dev/null
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/BatchElements.java
@@ -0,0 +1,601 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.transforms;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.WindowingStrategy;
+
+/**
+ * A {@link PTransform} that batches elements for amortized processing.
+ *
+ * <p>This transform is designed to precede operations whose processing cost is of the form:
+ *
+ * <pre>
+ *   time = fixed_cost + num_elements * per_element_cost
+ * </pre>
+ *
+ * <p>When the per-element cost is significantly smaller than the fixed cost, batching multiple
+ * elements together can amortize that fixed cost and improve overall throughput.
+ *
+ * <p>The transform consumes a {@code PCollection<T>} and produces a {@code PCollection<List<T>>},
+ * where each output element is a batch of input elements.
+ *
+ * <p>This transform dynamically determines an optimal batch size between the configured minimum
+ * and maximum values by profiling the execution time of downstream (fused) operations. To enforce a
+ * fixed batch size, set {@code minBatchSize == maxBatchSize}.
+ *
+ * <p>Elements are batched per window. Each emitted batch belongs to the same window as its elements
+ * and is assigned a timestamp at the end of that window.
+ *
+ * <p>For non-default windowing, only a small fixed number of windows are buffered at once. An
+ * element that would open an additional window is emitted immediately in a batch of its own,
+ * because output from {@code @ProcessElement} is always assigned to the current element's window.
+ *
+ * <h3>Example</h3>
+ *
+ * <pre>{@code
+ * // With default configuration
+ * pipeline
+ *     .apply("Create", Create.of(range(200)))
+ *     .apply("Batch", BatchElements.withDefaults())
+ *     .apply(...);
+ *
+ * // With custom configuration
+ * BatchElements.BatchConfig config =
+ *     BatchElements.BatchConfig.builder()
+ *         .withMinBatchSize(1)
+ *         .withMaxBatchSize(15)
+ *         .withTargetBatchDurationSecs(10.0)
+ *         .withTargetBatchOverhead(0.05)
+ *         .withVariance(0.0)
+ *         .build();
+ *
+ * pipeline
+ *     .apply("Create", Create.of(range(200)))
+ *     .apply("Batch", BatchElements.withConfig(config))
+ *     .apply(
+ *         "Sizes",
+ *         MapElements.via(
+ *             new SimpleFunction<List<Integer>, Integer>() {
+ *               @Override
+ *               public Integer apply(List<Integer> input) {
+ *                 return input.size();
+ *               }
+ *             }));
+ * }</pre>
+ *
+ * @param <T> the type of input elements
+ */
+public class BatchElements<T> extends PTransform<PCollection<T>, PCollection<List<T>>> {
+
+  private final BatchConfig config;
+
+  private BatchElements(BatchConfig config) {
+    this.config = config;
+  }
+
+  /** Batch elements using {@link BatchConfig#builder()}'s default settings. */
+  public static <T> BatchElements<T> withDefaults() {
+    return withConfig(BatchConfig.defaults());
+  }
+
+  /**
+   * Batch elements using the given configuration.
+   *
+   * @param config batching constraints and tuning parameters
+   */
+  public static <T> BatchElements<T> withConfig(BatchConfig config) {
+    return new BatchElements<>(config);
+  }
+
+  /**
+   * Configuration for {@link BatchElements}.
+   *
+   * <p>Controls how batch sizes are selected and adapted over time. Instances are immutable and are
+   * created through {@link #builder()}, which validates every setting.
+   */
+  public static final class BatchConfig implements Serializable {
+    final int minBatchSize;
+    final int maxBatchSize;
+    final double targetBatchOverhead;
+    final double targetBatchDurationSecs;
+    /** Target duration including fixed cost, or -1 when it was not configured. */
+    final double targetBatchDurationSecsWithFixedCost;
+    final double variance;
+
+    private BatchConfig(Builder builder) {
+      this.minBatchSize = builder.minBatchSize;
+      this.maxBatchSize = builder.maxBatchSize;
+      this.targetBatchOverhead = builder.targetBatchOverhead;
+      this.targetBatchDurationSecs = builder.targetBatchDurationSecs;
+      this.targetBatchDurationSecsWithFixedCost = builder.targetBatchDurationSecsWithFixedCost;
+      this.variance = builder.variance;
+    }
+
+    /**
+     * Returns the default configuration: batch sizes 1..10000, 5% target overhead, 10 second
+     * target duration and 25% variance. These are exactly the {@link Builder}'s initial values.
+     */
+    static BatchConfig defaults() {
+      return builder().build();
+    }
+
+    /** Returns a new {@link Builder} pre-populated with the default settings. */
+    public static Builder builder() {
+      return new Builder();
+    }
+
+    /**
+     * Builder for {@link BatchConfig}.
+     *
+     * <p>Every setting has a default, so all setters are optional. {@link #build()} validates the
+     * combination and throws {@link IllegalArgumentException} for invalid values.
+     */
+    public static final class Builder {
+      /** Sentinel meaning {@code targetBatchDurationSecsWithFixedCost} was not configured. */
+      private static final double UNSET = -1;
+
+      private int minBatchSize = 1;
+      private int maxBatchSize = 10_000;
+      private double targetBatchOverhead = 0.05;
+      private double targetBatchDurationSecs = 10.0;
+      private double targetBatchDurationSecsWithFixedCost = UNSET;
+      private double variance = 0.25;
+
+      private Builder() {}
+
+      /**
+       * Sets the minimum batch size.
+       *
+       * @param minBatchSize minimum number of elements per batch
+       */
+      public Builder withMinBatchSize(int minBatchSize) {
+        this.minBatchSize = minBatchSize;
+        return this;
+      }
+
+      /**
+       * Sets the maximum batch size.
+       *
+       * @param maxBatchSize maximum number of elements per batch
+       */
+      public Builder withMaxBatchSize(int maxBatchSize) {
+        this.maxBatchSize = maxBatchSize;
+        return this;
+      }
+
+      /**
+       * Sets the target batch overhead ratio.
+       *
+       * <p>This represents the desired ratio:
+       *
+       * <p>fixed_cost / total_time
+       *
+       * <p>Lower values favor larger batches (higher throughput, higher latency).
+       *
+       * @param targetBatchOverhead value in (0, 1]
+       */
+      public Builder withTargetBatchOverhead(double targetBatchOverhead) {
+        this.targetBatchOverhead = targetBatchOverhead;
+        return this;
+      }
+
+      /**
+       * Sets the target batch duration excluding fixed cost.
+       *
+       * <p>This controls the desired time spent processing elements in a batch, ignoring fixed
+       * overhead.
+       *
+       * @param targetBatchDurationSecs target duration in seconds; must be positive
+       */
+      public Builder withTargetBatchDurationSecs(double targetBatchDurationSecs) {
+        this.targetBatchDurationSecs = targetBatchDurationSecs;
+        return this;
+      }
+
+      /**
+       * Sets the target batch duration including fixed cost.
+       *
+       * <p>If set, this provides a stricter upper bound on total batch processing time.
+       *
+       * @param value target duration in seconds; must be positive
+       */
+      public Builder withTargetBatchDurationSecsWithFixedCost(double value) {
+        this.targetBatchDurationSecsWithFixedCost = value;
+        return this;
+      }
+
+      /**
+       * Sets the allowed variance when selecting batch sizes.
+       *
+       * <p>This introduces controlled randomness to avoid converging to a single batch size and
+       * improves robustness of estimation.
+       *
+       * @param variance non-negative relative deviation (e.g., 0.25 for ±25%)
+       */
+      public Builder withVariance(double variance) {
+        this.variance = variance;
+        return this;
+      }
+
+      /**
+       * Validates the settings and builds an immutable {@link BatchConfig}.
+       *
+       * @throws IllegalArgumentException if any setting is out of range
+       */
+      public BatchConfig build() {
+        validate();
+        return new BatchConfig(this);
+      }
+
+      // Range checks are written as !(x > 0) rather than x <= 0 so that NaN is rejected too.
+      private void validate() {
+        if (minBatchSize > maxBatchSize) {
+          throw new IllegalArgumentException(
+              String.format(
+                  "Minimum (%d) must not be greater than maximum (%d)",
+                  minBatchSize, maxBatchSize));
+        }
+        if (!(targetBatchOverhead > 0 && targetBatchOverhead <= 1)) {
+          throw new IllegalArgumentException(
+              String.format(
+                  "targetBatchOverhead (%f) must be between 0 and 1", targetBatchOverhead));
+        }
+        if (!(targetBatchDurationSecs > 0)) {
+          throw new IllegalArgumentException(
+              String.format(
+                  "targetBatchDurationSecs (%f) must be positive", targetBatchDurationSecs));
+        }
+        if (targetBatchDurationSecsWithFixedCost != UNSET
+            && !(targetBatchDurationSecsWithFixedCost > 0)) {
+          throw new IllegalArgumentException(
+              String.format(
+                  "targetBatchDurationSecsWithFixedCost (%f) must be positive",
+                  targetBatchDurationSecsWithFixedCost));
+        }
+        // A NaN variance would turn every computed target into NaN, i.e. a batch size of 0.
+        if (!(variance >= 0)) {
+          throw new IllegalArgumentException(
+              String.format("variance (%f) must be non-negative", variance));
+        }
+      }
+    }
+  }
+
+  /**
+   * Picks batch sizes from observed (batch size, processing time) samples.
+   *
+   * <p>Fits {@code time = a + b * batchSize} by least squares and chooses the largest size that
+   * satisfies every configured target, limited to at most {@code MAX_GROWTH_FACTOR} times the last
+   * recorded size and to {@code maxBatchSize}. State is mutated without synchronization; each DoFn
+   * instance creates its own estimator in {@code @Setup}.
+   */
+  static class BatchSizeEstimator implements Serializable {
+    private static final int MAX_DATA_POINTS = 100;
+    private static final int MAX_GROWTH_FACTOR = 2;
+    private static final int WARMUP_BATCH_COUNT = 1;
+
+    private final BatchConfig config;
+    /** Recorded samples as {@code {batchSize, elapsedNanos}}. */
+    private final List<long[]> data;
+    /** How often each batch size has been handed out; first uses are treated as warm-up. */
+    private final Map<Integer, Integer> batchSizeNumSeen;
+    private final Random random;
+    /** Batch size to hand out again because its timing was discarded; null = nothing pending. */
+    private @Nullable Integer replayLastBatchSize = null;
+    private boolean ignoreNextTiming = false;
+
+    public BatchSizeEstimator(BatchConfig config) {
+      this.config = config;
+      this.data = new ArrayList<>();
+      this.batchSizeNumSeen = new HashMap<>();
+      this.random = new Random();
+    }
+
+    /**
+     * Times the emission of one batch. Closing it records a sample, unless the timing was flagged
+     * to be ignored, in which case the batch size is scheduled to be replayed instead.
+     */
+    public class Stopwatch implements AutoCloseable {
+      private final long startNanos;
+      private final int batchSize;
+
+      public Stopwatch(int batchSize) {
+        this.batchSize = batchSize;
+        // nanoTime is monotonic and sub-millisecond; currentTimeMillis can jump with wall-clock
+        // adjustments and rounds fast batches down to zero.
+        this.startNanos = System.nanoTime();
+      }
+
+      @Override
+      public void close() {
+        long elapsedNanos = System.nanoTime() - startNanos;
+        if (ignoreNextTiming) {
+          ignoreNextTiming = false;
+          replayLastBatchSize = Math.min(batchSize, config.maxBatchSize);
+        } else {
+          data.add(new long[] {batchSize, elapsedNanos});
+          if (data.size() >= MAX_DATA_POINTS) {
+            thinData();
+          }
+        }
+      }
+    }
+
+    /** Starts timing a batch of {@code batchSize} elements; close the result when done. */
+    public Stopwatch recordTime(int batchSize) {
+      return new Stopwatch(batchSize);
+    }
+
+    /** Drops two random samples, biased towards older ones, to keep the sample set bounded. */
+    private void thinData() {
+      data.remove(random.nextInt(data.size() / 4));
+      data.remove(random.nextInt(data.size() / 2));
+    }
+
+    /** Discards the next recorded timing and replays its batch size on the next request. */
+    public void ignoreNextTiming() {
+      this.ignoreNextTiming = true;
+    }
+
+    /**
+     * Least-squares fit of {@code y = a + b * x}.
+     *
+     * @return {@code {a, b}}; {@code {mean(y), 0}} if the mean of x is 0, and {@code {0,
+     *     mean(y) / mean(x)}} if all x are equal (fixed and per-element cost are inseparable)
+     */
+    private double[] linearRegression(double[] xs, double[] ys) {
+      int n = xs.length;
+      double xbar = 0;
+      double ybar = 0;
+      for (int i = 0; i < n; i++) {
+        xbar += xs[i];
+        ybar += ys[i];
+      }
+      xbar /= n;
+      ybar /= n;
+
+      if (xbar == 0) {
+        return new double[] {ybar, 0};
+      }
+
+      boolean allSame = true;
+      for (double x : xs) {
+        if (x != xs[0]) {
+          allSame = false;
+          break;
+        }
+      }
+      if (allSame) {
+        return new double[] {0, ybar / xbar};
+      }
+
+      double num = 0;
+      double den = 0;
+      for (int i = 0; i < n; i++) {
+        num += (xs[i] - xbar) * (ys[i] - ybar);
+        den += (xs[i] - xbar) * (xs[i] - xbar);
+      }
+      double b = num / den;
+      double a = ybar - b * xbar;
+      return new double[] {a, b};
+    }
+
+    /** Computes a fresh batch size from the recorded samples; all durations are in milliseconds. */
+    private int calculateNextBatchSize() {
+      if (config.minBatchSize == config.maxBatchSize) {
+        return config.minBatchSize;
+      } else if (data.isEmpty()) {
+        // Cold start: nothing has been measured yet.
+        return config.minBatchSize;
+      } else if (data.size() < 2) {
+        // Force a second, distinct batch size so the regression below has variety.
+        return Math.max(
+            Math.min(config.maxBatchSize, config.minBatchSize * MAX_GROWTH_FACTOR),
+            config.minBatchSize + 1);
+      }
+
+      // The top of the range tends to be noisy and has outsized influence on the fit, so once
+      // there are more than 20 samples keep only the lowest 80%. Ordering is by batch size, then
+      // by elapsed time, so among equal sizes the slowest samples are the ones trimmed.
+      List<long[]> sorted = new ArrayList<>(data);
+      sorted.sort(Comparator.<long[]>comparingLong(p -> p[0]).thenComparingLong(p -> p[1]));
+      int keep = Math.min(sorted.size(), Math.max(20, sorted.size() * 4 / 5));
+
+      double[] xs = new double[keep];
+      double[] ys = new double[keep];
+      for (int i = 0; i < keep; i++) {
+        xs[i] = sorted.get(i)[0]; // batch size
+        ys[i] = sorted.get(i)[1] / 1e6; // elapsed nanoseconds -> milliseconds
+      }
+      double[] ab = linearRegression(xs, ys);
+      // Noise can yield zero or negative coefficients; floor them to keep the math below sane.
+      double a = Math.max(ab[0], 1e-10); // fixed cost (ms)
+      double b = Math.max(ab[1], 1e-20); // per-element cost (ms)
+
+      long lastBatchSize = data.get(data.size() - 1)[0];
+      int cap = (int) Math.min(lastBatchSize * MAX_GROWTH_FACTOR, config.maxBatchSize);
+
+      double target = config.maxBatchSize;
+
+      if (config.targetBatchDurationSecsWithFixedCost > 0) {
+        // Solve a + b*x = targetBatchDurationSecsWithFixedCost.
+        double targetWithFixedCostMs = config.targetBatchDurationSecsWithFixedCost * 1000.0;
+        target = Math.min(target, (targetWithFixedCostMs - a) / b);
+      }
+
+      if (config.targetBatchDurationSecs > 0) {
+        // Solve b*x = targetBatchDurationSecs (fixed cost excluded by definition).
+        double targetDurationMs = config.targetBatchDurationSecs * 1000.0;
+        target = Math.min(target, targetDurationMs / b);
+      }
+
+      if (config.targetBatchOverhead > 0) {
+        // Solve a / (a + b*x) = targetBatchOverhead.
+        target = Math.min(target, (a / b) * (1.0 / config.targetBatchOverhead - 1));
+      }
+
+      // The jitter alternates between 0 and 1 so we never get stuck at exactly minBatchSize. Once
+      // there is enough data, samples are also smeared around the target (±variance).
+      int jitter = data.size() % 2;
+      if (data.size() > 10) {
+        target += (int) (target * config.variance * 2 * (random.nextDouble() - 0.5));
+      }
+
+      return (int) Math.max(config.minBatchSize + jitter, Math.min(target, cap));
+    }
+
+    /**
+     * Returns the size to use for the next batch.
+     *
+     * <p>A size whose timing was discarded is replayed first. The first time any size is handed
+     * out, its timing is flagged to be ignored as warm-up.
+     */
+    public int nextBatchSize() {
+      int result;
+      if (replayLastBatchSize != null) {
+        result = replayLastBatchSize;
+        replayLastBatchSize = null;
+      } else {
+        result = calculateNextBatchSize();
+      }
+
+      int seenCount = batchSizeNumSeen.getOrDefault(result, 0) + 1;
+      if (seenCount <= WARMUP_BATCH_COUNT) {
+        ignoreNextTiming();
+      }
+      batchSizeNumSeen.put(result, seenCount);
+
+      return result;
+    }
+  }
+
+  /**
+   * A {@link DoFn} that batches elements in the global window.
+   *
+   * <p>Buffers elements until adding one more would exceed the current target size, then emits the
+   * buffered batch at the end of the global window. Leftovers are flushed in {@code finishBundle}.
+   */
+  @SuppressWarnings("initialization")
+  static class GlobalWindowsBatchingDoFn<T> extends DoFn<T, List<T>> {
+    private transient BatchSizeEstimator estimator;
+    private final BatchConfig config;
+    private List<T> batch;
+    private int runningBatchSize;
+    private int targetBatchSize;
+
+    public GlobalWindowsBatchingDoFn(BatchConfig config) {
+      this.config = config;
+    }
+
+    @Setup
+    public void setup() {
+      estimator = new BatchSizeEstimator(config);
+    }
+
+    @StartBundle
+    public void startBundle() {
+      batch = new ArrayList<>();
+      runningBatchSize = 0;
+      targetBatchSize = estimator.nextBatchSize();
+    }
+
+    @ProcessElement
+    public void processElement(@Element T element, OutputReceiver<List<T>> receiver) {
+      int elementSize = 1;
+      if (runningBatchSize + elementSize > targetBatchSize) {
+        if (!batch.isEmpty()) {
+          try (BatchSizeEstimator.Stopwatch sw = estimator.recordTime(runningBatchSize)) {
+            // End-of-window timestamp, matching finishBundle and the documented contract; the
+            // current element is not part of this batch, so its timestamp is not meaningful.
+            receiver.outputWithTimestamp(batch, GlobalWindow.INSTANCE.maxTimestamp());
+          }
+        }
+        // Start a fresh list; the emitted one must not be mutated after output.
+        batch = new ArrayList<>();
+        runningBatchSize = 0;
+        targetBatchSize = estimator.nextBatchSize();
+      }
+      batch.add(element);
+      runningBatchSize += elementSize;
+    }
+
+    @FinishBundle
+    public void finishBundle(FinishBundleContext context) {
+      if (!batch.isEmpty()) {
+        try (BatchSizeEstimator.Stopwatch sw = estimator.recordTime(runningBatchSize)) {
+          context.output(batch, GlobalWindow.INSTANCE.maxTimestamp(), GlobalWindow.INSTANCE);
+        }
+        batch = new ArrayList<>();
+        runningBatchSize = 0;
+      }
+    }
+  }
+
+  /**
+   * A {@link DoFn} that batches elements per window.
+   *
+   * <p>Maintains a separate batch for each live window and emits a window's batch, at that window's
+   * end-of-window timestamp, when adding one more element would exceed the target size. At most
+   * {@code MAX_LIVE_WINDOWS} windows are buffered: an element that would open an additional window
+   * is emitted immediately in a batch of its own. Remaining batches are flushed in {@code
+   * finishBundle}.
+   */
+  @SuppressWarnings("initialization")
+  static class WindowAwareBatchingDoFn<T> extends DoFn<T, List<T>> {
+    private static final int MAX_LIVE_WINDOWS = 10;
+
+    private transient BatchSizeEstimator estimator;
+    private final BatchConfig config;
+    private transient Map<BoundedWindow, SizedBatch<T>> batches;
+    private int targetBatchSize;
+
+    /** Elements buffered for one window, plus their accumulated size. */
+    private static class SizedBatch<T> implements Serializable {
+      final List<T> elements = new ArrayList<>();
+      int size = 0;
+    }
+
+    private WindowAwareBatchingDoFn(BatchConfig config) {
+      this.config = config;
+    }
+
+    @Setup
+    public void setup() {
+      estimator = new BatchSizeEstimator(config);
+    }
+
+    @StartBundle
+    public void startBundle() {
+      batches = new HashMap<>();
+      targetBatchSize = estimator.nextBatchSize();
+    }
+
+    @ProcessElement
+    public void processElement(
+        @Element T element, BoundedWindow window, OutputReceiver<List<T>> receiver) {
+      int elementSize = 1;
+      SizedBatch<T> batch = batches.computeIfAbsent(window, w -> new SizedBatch<>());
+
+      // Flush this window's batch first if adding the element would exceed the target. Never
+      // emit an empty batch (possible when the target is below 1).
+      if (batch.size + elementSize > targetBatchSize && !batch.elements.isEmpty()) {
+        emitCurrentWindowBatch(window, batch, receiver);
+        batch = batches.computeIfAbsent(window, w -> new SizedBatch<>());
+      }
+
+      batch.elements.add(element);
+      batch.size += elementSize;
+
+      // Bound the number of buffered windows. OutputReceiver.output/outputWithTimestamp always
+      // assign the current element's window, so emitting another window's batch here would put
+      // its elements in the wrong window (and its end-of-window timestamp could precede this
+      // element's timestamp, which outputWithTimestamp rejects). The map can only exceed the
+      // limit right after this window was opened, so this emits the current element on its own.
+      if (batches.size() > MAX_LIVE_WINDOWS) {
+        emitCurrentWindowBatch(window, batch, receiver);
+      }
+    }
+
+    /**
+     * Emits {@code batch}, which must belong to {@code window} (the current element's window), at
+     * that window's end-of-window timestamp, then drops it and refreshes the target batch size.
+     */
+    private void emitCurrentWindowBatch(
+        BoundedWindow window, SizedBatch<T> batch, OutputReceiver<List<T>> receiver) {
+      try (BatchSizeEstimator.Stopwatch sw = estimator.recordTime(batch.size)) {
+        receiver.outputWithTimestamp(batch.elements, window.maxTimestamp());
+      }
+      batches.remove(window);
+      targetBatchSize = estimator.nextBatchSize();
+    }
+
+    @FinishBundle
+    public void finishBundle(FinishBundleContext context) {
+      for (Map.Entry<BoundedWindow, SizedBatch<T>> entry : batches.entrySet()) {
+        BoundedWindow window = entry.getKey();
+        SizedBatch<T> batch = entry.getValue();
+        if (!batch.elements.isEmpty()) {
+          try (BatchSizeEstimator.Stopwatch sw = estimator.recordTime(batch.size)) {
+            context.output(batch.elements, window.maxTimestamp(), window);
+          }
+        }
+      }
+      batches.clear();
+    }
+  }
+
+  /**
+   * Uses the cheaper global-window batcher for the default windowing strategy, and the per-window
+   * batcher for anything else.
+   */
+  @Override
+  public PCollection<List<T>> expand(PCollection<T> input) {
+    if (input.getWindowingStrategy().equals(WindowingStrategy.globalDefault())) {
+      return input.apply(ParDo.of(new GlobalWindowsBatchingDoFn<>(config)));
+    } else {
+      return input.apply(ParDo.of(new WindowAwareBatchingDoFn<>(config)));
+    }
+  }
+}
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/BatchElementsTest.java
 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/BatchElementsTest.java
new file mode 100644
index 00000000000..70167b43faf
--- /dev/null
+++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/BatchElementsTest.java
@@ -0,0 +1,598 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.transforms;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import java.io.Serializable;
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TimestampedValue;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class BatchElementsTest implements Serializable {
+
+  @Rule public transient TestPipeline pipeline = TestPipeline.create();
+
+  @Rule public transient Timeout globalTimeout = Timeout.seconds(120);
+
+  // Helpers
+
+  private static BatchElements.BatchConfig constantConfig(int size) {
+    return BatchElements.BatchConfig.builder()
+        .withMinBatchSize(size)
+        .withMaxBatchSize(size)
+        .withTargetBatchDurationSecs(10.0)
+        .withTargetBatchOverhead(0.05)
+        .withVariance(0.0)
+        .build();
+  }
+
+  // BatchConfig validation
+
+  @Test
+  public void testBatchConfigDefaults() {
+    BatchElements.BatchConfig config = BatchElements.BatchConfig.defaults();
+    assertEquals(1, config.minBatchSize);
+    assertEquals(10_000, config.maxBatchSize);
+    assertEquals(0.05, config.targetBatchOverhead, 1e-9);
+    assertEquals(10.0, config.targetBatchDurationSecs, 1e-9);
+    assertEquals(0.25, config.variance, 1e-9);
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testBatchConfigMinGreaterThanMaxThrows() {
+    
BatchElements.BatchConfig.builder().withMinBatchSize(100).withMaxBatchSize(10).build();
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testBatchConfigZeroTargetDurationThrows() {
+    
BatchElements.BatchConfig.builder().withTargetBatchDurationSecs(0.0).build();
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testBatchConfigNegativeTargetDurationThrows() {
+    
BatchElements.BatchConfig.builder().withTargetBatchDurationSecs(-5.0).build();
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testBatchConfigZeroOverheadThrows() {
+    BatchElements.BatchConfig.builder().withTargetBatchOverhead(0.0).build();
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testBatchConfigOverheadAboveOneThrows() {
+    BatchElements.BatchConfig.builder().withTargetBatchOverhead(1.5).build();
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testBatchConfigNegativeFixedCostDurationThrows() {
+    
BatchElements.BatchConfig.builder().withTargetBatchDurationSecsWithFixedCost(-5.0).build();
+  }
+
+  //  BatchSizeEstimator unit tests
+
+  @Test
+  public void testEstimatorConstantBatchSize() {
+    BatchElements.BatchConfig config = constantConfig(42);
+    BatchElements.BatchSizeEstimator estimator = new 
BatchElements.BatchSizeEstimator(config);
+    // When min == max, always return that size
+    for (int i = 0; i < 10; i++) {
+      assertEquals(42, estimator.nextBatchSize());
+    }
+  }
+
+  @Test
+  public void testEstimatorColdStartReturnsMinBatchSize() {
+    BatchElements.BatchConfig config =
+        BatchElements.BatchConfig.builder()
+            .withMinBatchSize(5)
+            .withMaxBatchSize(500)
+            .withTargetBatchDurationSecs(10.0)
+            .withTargetBatchOverhead(0.05)
+            .build();
+    BatchElements.BatchSizeEstimator estimator = new 
BatchElements.BatchSizeEstimator(config);
+    // No timing data yet — should return minBatchSize
+    assertEquals(5, estimator.nextBatchSize());
+  }
+
+  @Test
+  public void testEstimatorGrowsAfterTimingData() {
+    BatchElements.BatchConfig config =
+        BatchElements.BatchConfig.builder()
+            .withMinBatchSize(1)
+            .withMaxBatchSize(500)
+            .withTargetBatchDurationSecs(10.0)
+            .withTargetBatchOverhead(0.05)
+            .withVariance(0.0)
+            .build();
+    BatchElements.BatchSizeEstimator estimator = new 
BatchElements.BatchSizeEstimator(config);
+
+    // Warm up with some fake recordings
+    int size = estimator.nextBatchSize();
+    try (BatchElements.BatchSizeEstimator.Stopwatch sw = 
estimator.recordTime(size)) {
+      Thread.sleep(10);
+    } catch (InterruptedException e) {
+      throw new RuntimeException(e);
+    }
+    size = estimator.nextBatchSize();
+    try (BatchElements.BatchSizeEstimator.Stopwatch sw = 
estimator.recordTime(size)) {
+      Thread.sleep(10);
+    } catch (InterruptedException e) {
+      throw new RuntimeException(e);
+    }
+
+    int grown = estimator.nextBatchSize();
+    assertTrue("Estimator should grow beyond minBatchSize after data, got: " + 
grown, grown > 1);
+  }
+
+  @Test
+  public void testEstimatorNeverExceedsMaxBatchSize() {
+    BatchElements.BatchConfig config =
+        BatchElements.BatchConfig.builder()
+            .withMinBatchSize(1)
+            .withMaxBatchSize(10)
+            .withTargetBatchDurationSecs(10.0)
+            .withTargetBatchOverhead(0.05)
+            .withVariance(0.0)
+            .build();
+    BatchElements.BatchSizeEstimator estimator = new 
BatchElements.BatchSizeEstimator(config);
+
+    for (int i = 0; i < 50; i++) {
+      int next = estimator.nextBatchSize();
+      assertTrue("Batch size " + next + " exceeds max of 10", next <= 10);
+      try (BatchElements.BatchSizeEstimator.Stopwatch sw = 
estimator.recordTime(next)) {
+        Thread.sleep(10 + i);
+      } catch (InterruptedException e) {
+        throw new RuntimeException(e);
+      }
+    }
+  }
+
+  /**
+   * Runs 20 timed batches and checks that the estimator never proposes a batch size below {@code
+   * minBatchSize} (7).
+   *
+   * <p>Interrupts from the simulated work propagate as {@link InterruptedException}; wrapping them
+   * in a {@link RuntimeException} without re-interrupting would lose the interrupt status.
+   */
+  @Test
+  public void testEstimatorNeverGoesBelowMinBatchSize() throws InterruptedException {
+    BatchElements.BatchConfig config =
+        BatchElements.BatchConfig.builder()
+            .withMinBatchSize(7)
+            .withMaxBatchSize(500)
+            .withTargetBatchDurationSecs(10.0)
+            .withTargetBatchOverhead(0.05)
+            .build();
+    BatchElements.BatchSizeEstimator estimator = new BatchElements.BatchSizeEstimator(config);
+
+    for (int i = 0; i < 20; i++) {
+      int next = estimator.nextBatchSize();
+      assertTrue("Batch size " + next + " is below min of 7", next >= 7);
+      // The sleep stands in for processing `next` elements while recordTime's stopwatch is open.
+      try (BatchElements.BatchSizeEstimator.Stopwatch sw = estimator.recordTime(next)) {
+        Thread.sleep(10 + i);
+      }
+    }
+  }
+
+  /**
+   * Checks the replay path: after {@code ignoreNextTiming()}, closing the stopwatch for the next
+   * batch should make the following {@code nextBatchSize()} call return that same size again.
+   */
+  @Test
+  public void testIgnoreNextTimingReplaysBatchSize() {
+    BatchElements.BatchConfig cfg =
+        BatchElements.BatchConfig.builder()
+            .withMinBatchSize(1)
+            .withMaxBatchSize(500)
+            .withTargetBatchDurationSecs(10.0)
+            .withTargetBatchOverhead(0.05)
+            .withVariance(0.0)
+            .build();
+    BatchElements.BatchSizeEstimator sizeEstimator = new BatchElements.BatchSizeEstimator(cfg);
+
+    sizeEstimator.ignoreNextTiming();
+    int initialSize = sizeEstimator.nextBatchSize();
+
+    try (BatchElements.BatchSizeEstimator.Stopwatch stopwatch =
+        sizeEstimator.recordTime(initialSize)) {
+      // No work to time: per the estimator's ignore-next-timing behavior, closing this stopwatch
+      // is what arranges for the last batch size to be replayed.
+    }
+
+    int replayedSize = sizeEstimator.nextBatchSize();
+    String failureMessage =
+        "Expected replay of batch size " + initialSize + " but got " + replayedSize;
+    assertEquals(failureMessage, initialSize, replayedSize);
+  }
+
+  // GlobalWindows pipeline tests
+  @Test
+  @Category(NeedsRunner.class)
+  public void testConstantBatchInGlobalWindow() {
+    // Counterpart of the Python SDK's test_constant_batch. Bundle boundaries are chosen by the
+    // runner, so a bundle may end with a partial batch; only the size cap and the element total
+    // are checked.
+    PCollection<Integer> numbers = pipeline.apply("Create", Create.of(range(35)));
+    PCollection<List<Integer>> batches =
+        numbers.apply("Batch", BatchElements.withConfig(constantConfig(10)));
+    PCollection<Integer> output =
+        batches.apply(
+            "Sizes",
+            MapElements.via(
+                new SimpleFunction<List<Integer>, Integer>() {
+                  @Override
+                  public Integer apply(List<Integer> batch) {
+                    return batch.size();
+                  }
+                }));
+
+    PAssert.that(output)
+        .satisfies(batchSizes -> assertBatchSizesWithinLimitAndTotal(batchSizes, 10, 35));
+    pipeline.run().waitUntilFinish();
+  }
+
+  /**
+   * Batches 200 globally-windowed elements with a 1..15 size range and checks that every emitted
+   * batch is non-empty and at most 15, and that the batch sizes add up to all 200 inputs.
+   */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testPipelineRespectsMaxBatchSizeAndPreservesElements() {
+    BatchElements.BatchConfig config =
+        BatchElements.BatchConfig.builder()
+            .withMinBatchSize(1)
+            .withMaxBatchSize(15)
+            .withTargetBatchDurationSecs(10.0)
+            .withTargetBatchOverhead(0.05)
+            .withVariance(0.0)
+            .build();
+
+    PCollection<Integer> sizes =
+        pipeline
+            .apply("Create", Create.of(range(200)))
+            .apply("Batch", BatchElements.withConfig(config))
+            .apply(
+                "Sizes",
+                MapElements.via(
+                    new SimpleFunction<List<Integer>, Integer>() {
+                      @Override
+                      public Integer apply(List<Integer> input) {
+                        return input.size();
+                      }
+                    }));
+
+    // Reuse the shared checker instead of duplicating the size/total assertions inline.
+    PAssert.that(sizes).satisfies(s -> assertBatchSizesWithinLimitAndTotal(s, 15, 200));
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  // WindowAware pipeline tests
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWindowedBatches() {
+    // 47 elements across FixedWindows(30s); runner bundle boundaries may 
split batches.
+    List<TimestampedValue<Integer>> timestamped = new ArrayList<>();
+    for (int i = 0; i < 47; i++) {
+      timestamped.add(TimestampedValue.of(i, new Instant((long) i * 1000)));
+    }
+
+    PCollection<Integer> sizes =
+        pipeline
+            .apply("Create", Create.timestamped(timestamped))
+            .apply(
+                "Window",
+                
Window.<Integer>into(FixedWindows.of(Duration.standardSeconds(30)))
+                    .withAllowedLateness(Duration.ZERO)
+                    .discardingFiredPanes())
+            .apply("Batch", BatchElements.withConfig(constantConfig(10)))
+            .apply(
+                "Sizes",
+                MapElements.via(
+                    new SimpleFunction<List<Integer>, Integer>() {
+                      @Override
+                      public Integer apply(List<Integer> input) {
+                        return input.size();
+                      }
+                    }));
+
+    PAssert.that(sizes).satisfies(s -> assertBatchSizesWithinLimitAndTotal(s, 
10, 47));
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  /**
+   * Elements from different windows must never share a batch. Each element is labelled
+   * {@code "w<window>-e<index>"}, and its timestamp places it in window {@code <window>}, so a
+   * batch is valid only if all of its elements carry the same window label.
+   */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testCrossWindowIsolation() {
+    List<TimestampedValue<String>> elements = new ArrayList<>();
+    for (int i = 0; i < 2000; i++) {
+      // 4 windows of 2s each, 500 elements per window (offsets 0..499 ms inside the window).
+      long ts = (long) (i / 500) * 2000L + (i % 500);
+      elements.add(TimestampedValue.of("w" + (i / 500) + "-e" + i, new Instant(ts)));
+    }
+
+    PCollection<Boolean> isolationChecks =
+        pipeline
+            .apply("Create", Create.timestamped(elements))
+            .apply(
+                "Window",
+                Window.<String>into(FixedWindows.of(Duration.standardSeconds(2)))
+                    .withAllowedLateness(Duration.ZERO)
+                    .discardingFiredPanes())
+            .apply("Batch", BatchElements.withConfig(constantConfig(50)))
+            .apply(
+                "CheckIsolation",
+                MapElements.via(
+                    new SimpleFunction<List<String>, Boolean>() {
+                      @Override
+                      public Boolean apply(List<String> batch) {
+                        if (batch.isEmpty()) {
+                          throw new AssertionError("BatchElements emitted an empty batch");
+                        }
+                        // Compare the whole window label (text before '-') rather than a fixed
+                        // two-character prefix, which would confuse e.g. "w1" with "w10".
+                        String firstWindow = windowLabel(batch.get(0));
+                        for (String el : batch) {
+                          if (!windowLabel(el).equals(firstWindow)) {
+                            throw new AssertionError(
+                                "Cross-window contamination: "
+                                    + el
+                                    + " in batch starting with "
+                                    + firstWindow);
+                          }
+                        }
+                        return true;
+                      }
+
+                      /** Returns the {@code w<window>} part of a {@code w<window>-e<i>} label. */
+                      private String windowLabel(String element) {
+                        return element.substring(0, element.indexOf('-'));
+                      }
+                    }));
+
+    PAssert.that(isolationChecks)
+        .satisfies(
+            checks -> {
+              int count = 0;
+              for (boolean ok : checks) {
+                assertTrue(ok);
+                count++;
+              }
+              assertTrue("Expected at least one batch", count > 0);
+              return null;
+            });
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWindowedBatchesPreserveAllElements() {
+    // The batch sizes, summed over every window, must account for each input element.
+    int numElements = 500;
+    List<TimestampedValue<Integer>> input = new ArrayList<>();
+    for (int value = 0; value < numElements; value++) {
+      // Each run of 100 consecutive values shares one 5 s window; the +value ms offset keeps the
+      // timestamps distinct while staying inside that window.
+      long millis = (long) (value / 100) * 5000L + value;
+      input.add(TimestampedValue.of(value, new Instant(millis)));
+    }
+
+    PCollection<Integer> windowed =
+        pipeline
+            .apply("Create", Create.timestamped(input))
+            .apply(
+                "Window",
+                Window.<Integer>into(FixedWindows.of(Duration.standardSeconds(5)))
+                    .withAllowedLateness(Duration.ZERO)
+                    .discardingFiredPanes());
+    PCollection<Integer> sizes =
+        windowed
+            .apply("Batch", BatchElements.withConfig(constantConfig(30)))
+            .apply(
+                "Sizes",
+                MapElements.via(
+                    new SimpleFunction<List<Integer>, Integer>() {
+                      @Override
+                      public Integer apply(List<Integer> batch) {
+                        return batch.size();
+                      }
+                    }));
+
+    PAssert.that(sizes)
+        .satisfies(
+            batchSizes -> assertBatchSizesWithinLimitAndTotal(batchSizes, 30, numElements));
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  /**
+   * Forces window-buffer eviction and checks that evicted batches are emitted in the window their
+   * elements came from, and that eviction neither drops nor duplicates any element.
+   */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testEvictionPreservesWindowMetadata() {
+    // 1. Use 12 windows to exceed the MAX_LIVE_WINDOWS (10) limit and trigger eviction.
+    int numWindows = 12;
+    List<TimestampedValue<Integer>> elements = new ArrayList<>();
+    for (int w = 0; w < numWindows; w++) {
+      long timestamp = w * 10000L; // Each element lands in its own 5s window.
+      elements.add(TimestampedValue.of(w, new Instant(timestamp)));
+    }
+
+    PCollection<List<Integer>> batched =
+        pipeline
+            .apply("Create", Create.timestamped(elements))
+            .apply("Window", Window.<Integer>into(FixedWindows.of(Duration.standardSeconds(5))))
+            // Batch size is larger than the single element per window, so buffers do not flush
+            // by filling up.
+            .apply("Batch", BatchElements.withConfig(constantConfig(10)));
+
+    // 2. Use ParDo to capture the ACTUAL window each batch was emitted in.
+    PCollection<KV<IntervalWindow, List<Integer>>> windowCaptured =
+        batched.apply(
+            "CaptureWindow",
+            ParDo.of(
+                new DoFn<List<Integer>, KV<IntervalWindow, List<Integer>>>() {
+                  @ProcessElement
+                  public void process(
+                      @Element List<Integer> e,
+                      BoundedWindow w,
+                      OutputReceiver<KV<IntervalWindow, List<Integer>>> r) {
+                    r.output(KV.of((IntervalWindow) w, e));
+                  }
+                }));
+
+    // 3. Each value must sit in the window it was timestamped into, and every value must be
+    //    emitted exactly once (eviction must flush buffered elements, not discard them).
+    PAssert.that(windowCaptured)
+        .satisfies(
+            items -> {
+              int[] counts = new int[numWindows];
+              for (KV<IntervalWindow, List<Integer>> item : items) {
+                int windowIndex = (int) (item.getKey().start().getMillis() / 10000L);
+                for (Integer val : item.getValue()) {
+                  assertEquals(
+                      "Element " + val + " found in wrong window!", (Integer) windowIndex, val);
+                  counts[val]++;
+                }
+              }
+              for (int idx = 0; idx < numWindows; idx++) {
+                assertEquals("Element " + idx + " must be emitted exactly once", 1, counts[idx]);
+              }
+              return null;
+            });
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  /**
+   * Two 3 s windows of 150 elements each, batched with a 1..{@code maxBatch} size range. Every
+   * batch must be non-empty and at most {@code maxBatch}, and the sizes must add up to all 300
+   * inputs.
+   */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWindowedBatchMaxSizeRespected() {
+    int maxBatch = 20;
+    List<TimestampedValue<Integer>> elements = new ArrayList<>();
+    for (int i = 0; i < 300; i++) {
+      elements.add(TimestampedValue.of(i, new Instant((long) (i / 150) * 3000L + i)));
+    }
+
+    BatchElements.BatchConfig config =
+        BatchElements.BatchConfig.builder()
+            .withMinBatchSize(1)
+            .withMaxBatchSize(maxBatch)
+            .withTargetBatchDurationSecs(10.0)
+            .withTargetBatchOverhead(0.05)
+            .withVariance(0.0)
+            .build();
+
+    PCollection<Integer> sizes =
+        pipeline
+            .apply("Create", Create.timestamped(elements))
+            .apply(
+                "Window",
+                Window.<Integer>into(FixedWindows.of(Duration.standardSeconds(3)))
+                    .withAllowedLateness(Duration.ZERO)
+                    .discardingFiredPanes())
+            .apply("Batch", BatchElements.withConfig(config))
+            .apply(
+                "Sizes",
+                MapElements.via(
+                    new SimpleFunction<List<Integer>, Integer>() {
+                      @Override
+                      public Integer apply(List<Integer> input) {
+                        return input.size();
+                      }
+                    }));
+
+    // Reuse the shared checker instead of duplicating the size/total assertions inline.
+    PAssert.that(sizes).satisfies(s -> assertBatchSizesWithinLimitAndTotal(s, maxBatch, 300));
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWindowRoutingGlobalWindowUsesGlobalDoFn() {
+    // Input from a plain Create stays in the global window and should be routed through
+    // GlobalWindowsBatchingDoFn. The routing itself is not observable here, so the test checks it
+    // indirectly: the pipeline runs without windowing errors and emits well-formed batches.
+    PCollection<Integer> numbers = pipeline.apply("Create", Create.of(range(25)));
+    PCollection<List<Integer>> batches =
+        numbers.apply("Batch", BatchElements.withConfig(constantConfig(10)));
+    PCollection<Integer> sizes =
+        batches.apply(
+            "Sizes",
+            MapElements.via(
+                new SimpleFunction<List<Integer>, Integer>() {
+                  @Override
+                  public Integer apply(List<Integer> batch) {
+                    return batch.size();
+                  }
+                }));
+
+    PAssert.that(sizes)
+        .satisfies(batchSizes -> assertBatchSizesWithinLimitAndTotal(batchSizes, 10, 25));
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  // LinearRegression unit tests
+
+  @Test
+  public void testLinearRegressionPerfectFit() throws Exception {
+    double[] ab = linearRegression(new double[] {1, 2, 3, 4, 5}, new double[] 
{3, 5, 7, 9, 11});
+
+    assertEquals(1.0, ab[0], 1e-9);
+    assertEquals(2.0, ab[1], 1e-9);
+  }
+
+  @Test
+  public void testLinearRegressionRepeatedXsUsesMeanTimePerElement() throws 
Exception {
+    double[] ab = linearRegression(new double[] {5, 5, 5, 5}, new double[] 
{10, 15, 20, 25});
+
+    assertEquals(0.0, ab[0], 1e-9);
+    assertEquals(3.5, ab[1], 1e-9);
+  }
+
+  // Utility
+
+  /**
+   * Asserts that every batch size is positive and at most {@code maxBatchSize}, and that the sizes
+   * sum to {@code expectedTotal}.
+   *
+   * @return always {@code null}, so the call can serve directly as a {@code PAssert.satisfies}
+   *     function body
+   */
+  private static Void assertBatchSizesWithinLimitAndTotal(
+      Iterable<Integer> sizes, int maxBatchSize, int expectedTotal) {
+    int sum = 0;
+    for (Integer boxed : sizes) {
+      int batchSize = boxed;
+      assertTrue("Batch size must be > 0", batchSize > 0);
+      boolean withinLimit = batchSize <= maxBatchSize;
+      assertTrue("Batch size " + batchSize + " exceeded max " + maxBatchSize, withinLimit);
+      sum += batchSize;
+    }
+    assertEquals("All elements must be present", expectedTotal, sum);
+    return null;
+  }
+
+  /**
+   * Calls the estimator's private {@code linearRegression(double[], double[])} via reflection on an
+   * estimator built from the default config.
+   *
+   * @return the two coefficients produced by the estimator for {@code xs} and {@code ys}
+   * @throws Exception the estimator's own exception, unwrapped from the reflection wrapper so test
+   *     failures point at the real cause; or a reflection error if the method cannot be found
+   */
+  private static double[] linearRegression(double[] xs, double[] ys) throws Exception {
+    BatchElements.BatchSizeEstimator estimator =
+        new BatchElements.BatchSizeEstimator(BatchElements.BatchConfig.defaults());
+    Method method =
+        BatchElements.BatchSizeEstimator.class.getDeclaredMethod(
+            "linearRegression", double[].class, double[].class);
+    method.setAccessible(true);
+    try {
+      return (double[]) method.invoke(estimator, xs, ys);
+    } catch (java.lang.reflect.InvocationTargetException e) {
+      // Method.invoke wraps anything the target throws; rethrow the original instead.
+      Throwable cause = e.getCause();
+      if (cause instanceof Exception) {
+        throw (Exception) cause;
+      }
+      if (cause instanceof Error) {
+        throw (Error) cause;
+      }
+      throw e;
+    }
+  }
+
+  /** Returns a new mutable list holding {@code 0, 1, ..., n - 1}, used as pipeline input. */
+  private static List<Integer> range(int n) {
+    List<Integer> values = new ArrayList<>(n);
+    int next = 0;
+    while (next < n) {
+      values.add(next++);
+    }
+    return values;
+  }
+}


Reply via email to