This is an automated email from the ASF dual-hosted git repository.

yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new e2be9fd8e36 Plumb custom batch parameters for autosharding from WriteFiles to FileIO. (#37463)
e2be9fd8e36 is described below

commit e2be9fd8e36d3ada86a23fffcb8681674138ae89
Author: Celeste Zeng <[email protected]>
AuthorDate: Thu Feb 5 10:23:24 2026 -0800

    Plumb custom batch parameters for autosharding from WriteFiles to FileIO. (#37463)
    
    * Address circular dependencies.
    
    * Fix formatting.
    
    * Fix tests.
    
    * Fix lint.
    
    * Remove unused import.
    
    * Resolve circular dependency without removing __repr__.
    
    * Fix formatting.
    
    * Remove nextmark_json_util and move all its methods into nextmark_model.
    
    * Restore millis_to_timestamp.
    
    * Plumb custom batch params and add tests.
    
    * Fix formatting and imports.
    
    * Fix imports and test.
    
    * Add missing import.
    
    * Add another test case for byte count.
    
    * Added checks for positive values.
---
 .../main/java/org/apache/beam/sdk/io/FileIO.java   |  66 ++++++++++
 .../java/org/apache/beam/sdk/io/FileIOTest.java    | 139 +++++++++++++++++++++
 2 files changed, 205 insertions(+)

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java
index cfa06f3cf0d..5c9e19da160 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java
@@ -1059,6 +1059,12 @@ public class FileIO {
 
     abstract @Nullable Integer getMaxNumWritersPerBundle();
 
+    abstract @Nullable Integer getBatchSize();
+
+    abstract @Nullable Integer getBatchSizeBytes();
+
+    abstract @Nullable Duration getBatchMaxBufferingDuration();
+
     abstract @Nullable ErrorHandler<BadRecord, ?> getBadRecordErrorHandler();
 
     abstract Builder<DestinationT, UserT> toBuilder();
@@ -1112,6 +1118,13 @@ public class FileIO {
       abstract Builder<DestinationT, UserT> setMaxNumWritersPerBundle(
           @Nullable Integer maxNumWritersPerBundle);
 
+      abstract Builder<DestinationT, UserT> setBatchSize(@Nullable Integer batchSize);
+
+      abstract Builder<DestinationT, UserT> setBatchSizeBytes(@Nullable Integer batchSizeBytes);
+
+      abstract Builder<DestinationT, UserT> setBatchMaxBufferingDuration(
+          @Nullable Duration batchMaxBufferingDuration);
+
       abstract Builder<DestinationT, UserT> setBadRecordErrorHandler(
           @Nullable ErrorHandler<BadRecord, ?> badRecordErrorHandler);
 
@@ -1301,6 +1314,7 @@ public class FileIO {
      */
     public Write<DestinationT, UserT> withNumShards(int numShards) {
       checkArgument(numShards >= 0, "numShards must be non-negative, but was: %s", numShards);
+      checkArgument(!getAutoSharding(), "Cannot set numShards when withAutoSharding() is used");
       if (numShards == 0) {
         return withNumShards(null);
       }
@@ -1311,6 +1325,7 @@ public class FileIO {
      * Like {@link #withNumShards(int)}. Specifying {@code null} means runner-determined sharding.
      */
     public Write<DestinationT, UserT> withNumShards(@Nullable ValueProvider<Integer> numShards) {
+      checkArgument(!getAutoSharding(), "Cannot set numShards when withAutoSharding() is used");
       return toBuilder().setNumShards(numShards).build();
     }
 
@@ -1321,6 +1336,7 @@ public class FileIO {
     public Write<DestinationT, UserT> withSharding(
         PTransform<PCollection<UserT>, PCollectionView<Integer>> sharding) {
       checkArgument(sharding != null, "sharding can not be null");
+      checkArgument(!getAutoSharding(), "Cannot set sharding when withAutoSharding() is used");
       return toBuilder().setSharding(sharding).build();
     }
 
@@ -1337,6 +1353,9 @@ public class FileIO {
     }
 
     public Write<DestinationT, UserT> withAutoSharding() {
+      checkArgument(
+          getNumShards() == null && getSharding() == null,
+          "Cannot use withAutoSharding() when withNumShards() or withSharding() is set");
       return toBuilder().setAutoSharding(true).build();
     }
 
@@ -1366,6 +1385,58 @@ public class FileIO {
       return toBuilder().setBadRecordErrorHandler(errorHandler).build();
     }
 
+    /**
+     * Returns a new {@link Write} that will batch the input records using specified batch size. The
+     * default value, also used when {@code batchSize} is {@code null}, is {@link
+     * WriteFiles#FILE_TRIGGERING_RECORD_COUNT}.
+     *
+     * <p>This option is used only for writing unbounded data with auto-sharding.
+     *
+     * @throws IllegalArgumentException if {@code batchSize} is non-null and not positive
+     */
+    public Write<DestinationT, UserT> withBatchSize(@Nullable Integer batchSize) {
+      checkArgument(
+          batchSize == null || batchSize > 0, "batchSize must be positive, but was: %s", batchSize);
+      return toBuilder().setBatchSize(batchSize).build();
+    }
+
+    /**
+     * Returns a new {@link Write} that will batch the input records using specified batch size in
+     * bytes. The default value, also used when {@code batchSizeBytes} is {@code null}, is {@link
+     * WriteFiles#FILE_TRIGGERING_BYTE_COUNT}.
+     *
+     * <p>This option is used only for writing unbounded data with auto-sharding.
+     *
+     * @throws IllegalArgumentException if {@code batchSizeBytes} is non-null and not positive
+     */
+    public Write<DestinationT, UserT> withBatchSizeBytes(@Nullable Integer batchSizeBytes) {
+      checkArgument(
+          batchSizeBytes == null || batchSizeBytes > 0,
+          "batchSizeBytes must be positive, but was: %s",
+          batchSizeBytes);
+      return toBuilder().setBatchSizeBytes(batchSizeBytes).build();
+    }
+
+    /**
+     * Returns a new {@link Write} that will batch the input records using specified max buffering
+     * duration. The default value, also used when {@code batchMaxBufferingDuration} is {@code
+     * null}, is {@link WriteFiles#FILE_TRIGGERING_RECORD_BUFFERING_DURATION}.
+     *
+     * <p>This option is used only for writing unbounded data with auto-sharding.
+     *
+     * @throws IllegalArgumentException if {@code batchMaxBufferingDuration} is non-null and not
+     *     positive
+     */
+    public Write<DestinationT, UserT> withBatchMaxBufferingDuration(
+        @Nullable Duration batchMaxBufferingDuration) {
+      checkArgument(
+          batchMaxBufferingDuration == null
+              || batchMaxBufferingDuration.isLongerThan(Duration.ZERO),
+          "batchMaxBufferingDuration must be positive, but was: %s",
+          batchMaxBufferingDuration);
+      return toBuilder().setBatchMaxBufferingDuration(batchMaxBufferingDuration).build();
+    }
+
     @VisibleForTesting
     Contextful<Fn<DestinationT, FileNaming>> resolveFileNamingFn() {
       if (getDynamic()) {
@@ -1482,6 +1539,17 @@ public class FileIO {
       if (getBadRecordErrorHandler() != null) {
         writeFiles = writeFiles.withBadRecordErrorHandler(getBadRecordErrorHandler());
       }
+      // Forward only explicitly configured batch settings; leaving them unset keeps the
+      // WriteFiles defaults (these settings apply only to unbounded writes with auto-sharding).
+      if (getBatchSize() != null) {
+        writeFiles = writeFiles.withBatchSize(getBatchSize());
+      }
+      if (getBatchSizeBytes() != null) {
+        writeFiles = writeFiles.withBatchSizeBytes(getBatchSizeBytes());
+      }
+      if (getBatchMaxBufferingDuration() != null) {
+        writeFiles = writeFiles.withBatchMaxBufferingDuration(getBatchMaxBufferingDuration());
+      }
       return input.apply(writeFiles);
     }
 
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileIOTest.java
index d3c1f6680be..dffc4943bfa 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileIOTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileIOTest.java
@@ -19,11 +19,14 @@ package org.apache.beam.sdk.io;
 
 import static 
org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions.RESOLVE_FILE;
 import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects.firstNonNull;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.isA;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 
+import java.io.BufferedReader;
 import java.io.File;
 import java.io.FileNotFoundException;
 import java.io.FileOutputStream;
@@ -38,7 +41,9 @@ import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.nio.file.StandardCopyOption;
 import java.nio.file.attribute.FileTime;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 import java.util.zip.GZIPOutputStream;
@@ -46,6 +51,7 @@ import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.io.fs.EmptyMatchTreatment;
 import org.apache.beam.sdk.io.fs.MatchResult;
+import org.apache.beam.sdk.io.fs.MatchResult.Metadata;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.state.StateSpec;
 import org.apache.beam.sdk.state.StateSpecs;
@@ -53,23 +59,30 @@ import org.apache.beam.sdk.state.ValueState;
 import org.apache.beam.sdk.testing.NeedsRunner;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.UsesUnboundedPCollections;
 import org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo;
 import org.apache.beam.sdk.transforms.Contextful;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Requirements;
 import org.apache.beam.sdk.transforms.SerializableFunctions;
 import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.Watch;
+import org.apache.beam.sdk.transforms.windowing.AfterWatermark;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollection.IsBounded;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TypeDescriptor;
 import org.apache.beam.sdk.values.TypeDescriptors;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
 import org.joda.time.Duration;
 import org.junit.Rule;
 import org.junit.Test;
@@ -547,4 +560,154 @@ public class FileIOTest implements Serializable {
         "Output file shard 0 exists after pipeline completes",
         new File(outputFileName + "-0").exists());
   }
+
+  @Test
+  @Category({NeedsRunner.class, UsesUnboundedPCollections.class})
+  public void testWriteUnboundedWithCustomBatchSize() throws IOException {
+    File root = tmpFolder.getRoot();
+    List<String> inputs = Arrays.asList("one", "two", "three", "four", "five", "six");
+
+    PTransform<PCollection<String>, PCollection<String>> transform =
+        Window.<String>into(FixedWindows.of(Duration.standardSeconds(10)))
+            .triggering(AfterWatermark.pastEndOfWindow())
+            .withAllowedLateness(Duration.ZERO)
+            .discardingFiredPanes();
+
+    FileIO.Write<Void, String> write =
+        FileIO.<String>write()
+            .via(TextIO.sink())
+            .to(root.getAbsolutePath())
+            .withPrefix("output")
+            .withSuffix(".txt")
+            .withAutoSharding()
+            .withBatchSize(3)
+            .withBatchSizeBytes(1024 * 1024) // Set high to avoid triggering flushing by byte count.
+            .withBatchMaxBufferingDuration(
+                Duration.standardMinutes(1)); // Set high to avoid triggering flushing by duration.
+
+    // Prepare timestamps for the elements.
+    List<Long> timestamps = new ArrayList<>();
+    for (long i = 0; i < inputs.size(); i++) {
+      timestamps.add(i + 1);
+    }
+
+    p.apply(Create.timestamped(inputs, timestamps).withCoder(StringUtf8Coder.of()))
+        .setIsBoundedInternal(IsBounded.UNBOUNDED)
+        .apply(transform)
+        .apply(write);
+    p.run().waitUntilFinish();
+
+    // Verify that the custom batch parameters are set.
+    assertEquals(3, write.getBatchSize().intValue());
+    assertEquals(1024 * 1024, write.getBatchSizeBytes().intValue());
+    assertEquals(Duration.standardMinutes(1), write.getBatchMaxBufferingDuration());
+
+    // Verify file contents.
+    checkFileContents(root, "output", inputs);
+
+    // With auto-sharding, we can't assert on the exact number of output files, but because
+    // batch size is 3 and there are 6 elements, we expect at least 2 files.
+    final String pattern = new File(root, "output").getAbsolutePath() + "*";
+    List<Metadata> metadata =
+        FileSystems.match(Collections.singletonList(pattern)).get(0).metadata();
+    assertTrue(
+        "Expected at least 2 output files but found " + metadata.size(), metadata.size() >= 2);
+  }
+
+  @Test
+  @Category({NeedsRunner.class, UsesUnboundedPCollections.class})
+  public void testWriteUnboundedWithCustomBatchSizeBytes() throws IOException {
+    File root = tmpFolder.getRoot();
+    // The elements plus newline characters give a total of 4+4+6+5+5+4=28 bytes.
+    List<String> inputs = Arrays.asList("one", "two", "three", "four", "five", "six");
+    // Assign timestamps so that all elements fall into the same 10s window.
+    List<Long> timestamps = Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L);
+
+    FileIO.Write<Void, String> write =
+        FileIO.<String>write()
+            .via(TextIO.sink())
+            .to(root.getAbsolutePath())
+            .withPrefix("output")
+            .withSuffix(".txt")
+            .withAutoSharding()
+            .withBatchSize(1000) // Set high to avoid flushing by record count.
+            .withBatchSizeBytes(10)
+            .withBatchMaxBufferingDuration(
+                Duration.standardMinutes(1)); // Set high to avoid flushing by duration.
+
+    p.apply(Create.timestamped(inputs, timestamps).withCoder(StringUtf8Coder.of()))
+        .setIsBoundedInternal(IsBounded.UNBOUNDED)
+        .apply(
+            Window.<String>into(FixedWindows.of(Duration.standardSeconds(10)))
+                .triggering(AfterWatermark.pastEndOfWindow())
+                .withAllowedLateness(Duration.ZERO)
+                .discardingFiredPanes())
+        .apply(write);
+
+    p.run().waitUntilFinish();
+
+    // Verify that the custom batch parameters are set.
+    assertEquals(1000, write.getBatchSize().intValue());
+    assertEquals(10, write.getBatchSizeBytes().intValue());
+    assertEquals(Duration.standardMinutes(1), write.getBatchMaxBufferingDuration());
+    checkFileContents(root, "output", inputs);
+
+    // With auto-sharding, we cannot assert on the exact number of output files. The BatchSizeBytes
+    // acts as a threshold for flushing; once buffer size reaches 10 bytes, a flush is triggered,
+    // but more items may be added before it completes. With 28 bytes total, we can only guarantee
+    // at least 2 files are produced.
+    final String pattern = new File(root, "output").getAbsolutePath() + "*";
+    List<Metadata> metadata =
+        FileSystems.match(Collections.singletonList(pattern)).get(0).metadata();
+    assertTrue(
+        "Expected at least 2 output files but found " + metadata.size(), metadata.size() >= 2);
+  }
+
+  @Test
+  public void testWriteRejectsNonPositiveBatchParameters() {
+    FileIO.Write<Void, String> write = FileIO.<String>write().withAutoSharding();
+    org.junit.Assert.assertThrows(IllegalArgumentException.class, () -> write.withBatchSize(0));
+    org.junit.Assert.assertThrows(
+        IllegalArgumentException.class, () -> write.withBatchSizeBytes(-1));
+    org.junit.Assert.assertThrows(
+        IllegalArgumentException.class, () -> write.withBatchMaxBufferingDuration(Duration.ZERO));
+  }
+
+  @Test
+  public void testAutoShardingIsMutuallyExclusiveWithExplicitSharding() {
+    org.junit.Assert.assertThrows(
+        IllegalArgumentException.class,
+        () -> FileIO.<String>write().withNumShards(3).withAutoSharding());
+    org.junit.Assert.assertThrows(
+        IllegalArgumentException.class,
+        () -> FileIO.<String>write().withAutoSharding().withNumShards(3));
+  }
+
+  /**
+   * Asserts that the files matching {@code prefix*} under {@code rootDir} together contain exactly
+   * the lines in {@code inputs}, in any order.
+   */
+  private static void checkFileContents(File rootDir, String prefix, List<String> inputs)
+      throws IOException {
+    List<File> outputFiles = Lists.newArrayList();
+    final String pattern = new File(rootDir, prefix).getAbsolutePath() + "*";
+    List<Metadata> metadata =
+        FileSystems.match(Collections.singletonList(pattern)).get(0).metadata();
+    for (Metadata meta : metadata) {
+      outputFiles.add(new File(meta.resourceId().toString()));
+    }
+    assertFalse("Should have produced at least 1 output file", outputFiles.isEmpty());
+
+    List<String> actual = Lists.newArrayList();
+    for (File outputFile : outputFiles) {
+      try (BufferedReader reader =
+          Files.newBufferedReader(outputFile.toPath(), StandardCharsets.UTF_8)) {
+        String line;
+        while ((line = reader.readLine()) != null) {
+          actual.add(line);
+        }
+      }
+    }
+    assertThat(actual, containsInAnyOrder(inputs.toArray()));
+  }
 }

Reply via email to