This is an automated email from the ASF dual-hosted git repository.
yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new e2be9fd8e36 Plumb custom batch parameters for autosharding from WriteFiles to FileIO. (#37463)
e2be9fd8e36 is described below
commit e2be9fd8e36d3ada86a23fffcb8681674138ae89
Author: Celeste Zeng <[email protected]>
AuthorDate: Thu Feb 5 10:23:24 2026 -0800
Plumb custom batch parameters for autosharding from WriteFiles to FileIO. (#37463)
* Address circular dependencies.
* Fix formatting.
* Fix tests.
* Fix lint.
* Remove unused import.
* Resolve circular dependency without removing __repr__.
* Fix formatting.
* Remove nexmark_json_util and move all its methods into nexmark_model.
* Restore millis_to_timestamp.
* Plumb custom batch params and add tests.
* Fix formatting and imports.
* Fix imports and test.
* Add missing import.
* Add another test case for byte count.
* Added checks for positive values.
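For context, a minimal usage sketch of the new API (the destination path and threshold values here are illustrative, not taken from this commit; per the javadoc below, the three batching knobs apply only to unbounded writes combined with withAutoSharding()):

import org.apache.beam.sdk.io.FileIO;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.values.PCollection;
import org.joda.time.Duration;

class BatchedWriteExample {
  // `lines` is assumed to be an unbounded, already-windowed PCollection.
  static void writeBatched(PCollection<String> lines) {
    lines.apply(
        FileIO.<String>write()
            .via(TextIO.sink())
            .to("/tmp/output") // illustrative destination directory
            .withSuffix(".txt")
            .withAutoSharding() // the three thresholds below take effect only with auto-sharding
            .withBatchSize(500) // flush after 500 buffered records, or...
            .withBatchSizeBytes(1 << 20) // ...after 1 MiB is buffered, or...
            .withBatchMaxBufferingDuration(Duration.standardSeconds(30))); // ...after 30s
  }
}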
---
.../main/java/org/apache/beam/sdk/io/FileIO.java | 66 ++++++++++
.../java/org/apache/beam/sdk/io/FileIOTest.java | 139 +++++++++++++++++++++
2 files changed, 205 insertions(+)
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java
index cfa06f3cf0d..5c9e19da160 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java
@@ -1059,6 +1059,12 @@ public class FileIO {
abstract @Nullable Integer getMaxNumWritersPerBundle();
+ abstract @Nullable Integer getBatchSize();
+
+ abstract @Nullable Integer getBatchSizeBytes();
+
+ abstract @Nullable Duration getBatchMaxBufferingDuration();
+
abstract @Nullable ErrorHandler<BadRecord, ?> getBadRecordErrorHandler();
abstract Builder<DestinationT, UserT> toBuilder();
@@ -1112,6 +1118,13 @@ public class FileIO {
abstract Builder<DestinationT, UserT> setMaxNumWritersPerBundle(
@Nullable Integer maxNumWritersPerBundle);
+ abstract Builder<DestinationT, UserT> setBatchSize(@Nullable Integer batchSize);
+
+ abstract Builder<DestinationT, UserT> setBatchSizeBytes(@Nullable Integer batchSizeBytes);
+
+ abstract Builder<DestinationT, UserT> setBatchMaxBufferingDuration(
+ @Nullable Duration batchMaxBufferingDuration);
+
abstract Builder<DestinationT, UserT> setBadRecordErrorHandler(
@Nullable ErrorHandler<BadRecord, ?> badRecordErrorHandler);
@@ -1301,6 +1314,7 @@ public class FileIO {
*/
public Write<DestinationT, UserT> withNumShards(int numShards) {
checkArgument(numShards >= 0, "numShards must be non-negative, but was: %s", numShards);
+ checkArgument(!getAutoSharding(), "Cannot set numShards when withAutoSharding() is used");
if (numShards == 0) {
return withNumShards(null);
}
@@ -1311,6 +1325,7 @@ public class FileIO {
* Like {@link #withNumShards(int)}. Specifying {@code null} means runner-determined sharding.
*/
public Write<DestinationT, UserT> withNumShards(@Nullable ValueProvider<Integer> numShards) {
+ checkArgument(!getAutoSharding(), "Cannot set numShards when withAutoSharding() is used");
return toBuilder().setNumShards(numShards).build();
}
@@ -1321,6 +1336,7 @@ public class FileIO {
public Write<DestinationT, UserT> withSharding(
PTransform<PCollection<UserT>, PCollectionView<Integer>> sharding) {
checkArgument(sharding != null, "sharding can not be null");
+ checkArgument(!getAutoSharding(), "Cannot set sharding when withAutoSharding() is used");
return toBuilder().setSharding(sharding).build();
}
@@ -1337,6 +1353,9 @@ public class FileIO {
}
public Write<DestinationT, UserT> withAutoSharding() {
+ checkArgument(
+ getNumShards() == null && getSharding() == null,
+ "Cannot use withAutoSharding() when withNumShards() or withSharding() is set");
return toBuilder().setAutoSharding(true).build();
}
@@ -1366,6 +1385,44 @@ public class FileIO {
return toBuilder().setBadRecordErrorHandler(errorHandler).build();
}
+ /**
+ * Returns a new {@link Write} that will batch the input records using the specified batch size. The
+ * default value is {@link WriteFiles#FILE_TRIGGERING_RECORD_COUNT}.
+ *
+ * <p>This option is used only for writing unbounded data with auto-sharding.
+ */
+ public Write<DestinationT, UserT> withBatchSize(@Nullable Integer batchSize) {
+ checkArgument(batchSize > 0, "batchSize must be positive, but was: %s", batchSize);
+ return toBuilder().setBatchSize(batchSize).build();
+ }
+
+ /**
+ * Returns a new {@link Write} that will batch the input records using the specified batch size in
+ * bytes. The default value is {@link WriteFiles#FILE_TRIGGERING_BYTE_COUNT}.
+ *
+ * <p>This option is used only for writing unbounded data with auto-sharding.
+ */
+ public Write<DestinationT, UserT> withBatchSizeBytes(@Nullable Integer batchSizeBytes) {
+ checkArgument(
+ batchSizeBytes > 0, "batchSizeBytes must be positive, but was: %s", batchSizeBytes);
+ return toBuilder().setBatchSizeBytes(batchSizeBytes).build();
+ }
+
+ /**
+ * Returns a new {@link Write} that will batch the input records using the specified max buffering
+ * duration. The default value is {@link WriteFiles#FILE_TRIGGERING_RECORD_BUFFERING_DURATION}.
+ *
+ * <p>This option is used only for writing unbounded data with auto-sharding.
+ */
+ public Write<DestinationT, UserT> withBatchMaxBufferingDuration(
+ @Nullable Duration batchMaxBufferingDuration) {
+ checkArgument(
+ batchMaxBufferingDuration.isLongerThan(Duration.ZERO),
+ "batchMaxBufferingDuration must be positive, but was: %s",
+ batchMaxBufferingDuration);
+ return toBuilder().setBatchMaxBufferingDuration(batchMaxBufferingDuration).build();
+ }
+
@VisibleForTesting
Contextful<Fn<DestinationT, FileNaming>> resolveFileNamingFn() {
if (getDynamic()) {
@@ -1482,6 +1539,15 @@ public class FileIO {
if (getBadRecordErrorHandler() != null) {
writeFiles = writeFiles.withBadRecordErrorHandler(getBadRecordErrorHandler());
}
+ if (getBatchSize() != null) {
+ writeFiles = writeFiles.withBatchSize(getBatchSize());
+ }
+ if (getBatchSizeBytes() != null) {
+ writeFiles = writeFiles.withBatchSizeBytes(getBatchSizeBytes());
+ }
+ if (getBatchMaxBufferingDuration() != null) {
+ writeFiles = writeFiles.withBatchMaxBufferingDuration(getBatchMaxBufferingDuration());
+ }
return input.apply(writeFiles);
}
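The checkArgument guards added above make explicit sharding and auto-sharding mutually exclusive. A minimal sketch of the call patterns that now fail fast (sink choice and path are illustrative; the error messages are quoted from the diff):

import org.apache.beam.sdk.io.FileIO;
import org.apache.beam.sdk.io.TextIO;

class ShardingGuardExample {
  static void demonstrateGuards() {
    FileIO.Write<Void, String> autoSharded =
        FileIO.<String>write().via(TextIO.sink()).to("/tmp/output").withAutoSharding();

    // Throws IllegalArgumentException:
    // "Cannot set numShards when withAutoSharding() is used"
    autoSharded.withNumShards(5);

    // The opposite order is rejected as well:
    // "Cannot use withAutoSharding() when withNumShards() or withSharding() is set"
    FileIO.<String>write().via(TextIO.sink()).to("/tmp/output").withNumShards(5).withAutoSharding();
  }
}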
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileIOTest.java
index d3c1f6680be..dffc4943bfa 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileIOTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileIOTest.java
@@ -19,11 +19,14 @@ package org.apache.beam.sdk.io;
import static org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions.RESOLVE_FILE;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects.firstNonNull;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.isA;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
+import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
@@ -38,7 +41,9 @@ import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.nio.file.attribute.FileTime;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.zip.GZIPOutputStream;
@@ -46,6 +51,7 @@ import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.io.fs.EmptyMatchTreatment;
import org.apache.beam.sdk.io.fs.MatchResult;
+import org.apache.beam.sdk.io.fs.MatchResult.Metadata;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
@@ -53,23 +59,30 @@ import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.testing.NeedsRunner;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.UsesUnboundedPCollections;
import org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo;
import org.apache.beam.sdk.transforms.Contextful;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Requirements;
import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.Watch;
+import org.apache.beam.sdk.transforms.windowing.AfterWatermark;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollection.IsBounded;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.joda.time.Duration;
import org.junit.Rule;
import org.junit.Test;
@@ -547,4 +560,130 @@ public class FileIOTest implements Serializable {
"Output file shard 0 exists after pipeline completes",
new File(outputFileName + "-0").exists());
}
+
+ @Test
+ @Category({NeedsRunner.class, UsesUnboundedPCollections.class})
+ public void testWriteUnboundedWithCustomBatchSize() throws IOException {
+ File root = tmpFolder.getRoot();
+ List<String> inputs = Arrays.asList("one", "two", "three", "four", "five", "six");
+
+ PTransform<PCollection<String>, PCollection<String>> transform =
+ Window.<String>into(FixedWindows.of(Duration.standardSeconds(10)))
+ .triggering(AfterWatermark.pastEndOfWindow())
+ .withAllowedLateness(Duration.ZERO)
+ .discardingFiredPanes();
+
+ FileIO.Write<Void, String> write =
+ FileIO.<String>write()
+ .via(TextIO.sink())
+ .to(root.getAbsolutePath())
+ .withPrefix("output")
+ .withSuffix(".txt")
+ .withAutoSharding()
+ .withBatchSize(3)
+ .withBatchSizeBytes(1024 * 1024) // Set high to avoid triggering flushing by byte count.
+ .withBatchMaxBufferingDuration(
+ Duration.standardMinutes(1)); // Set high to avoid triggering flushing by duration.
+
+ // Prepare timestamps for the elements.
+ List<Long> timestamps = new ArrayList<>();
+ for (long i = 0; i < inputs.size(); i++) {
+ timestamps.add(i + 1);
+ }
+
+ p.apply(Create.timestamped(inputs, timestamps).withCoder(StringUtf8Coder.of()))
+ .setIsBoundedInternal(IsBounded.UNBOUNDED)
+ .apply(transform)
+ .apply(write);
+ p.run().waitUntilFinish();
+
+ // Verify that the custom batch parameters are set.
+ assertEquals(3, write.getBatchSize().intValue());
+ assertEquals(1024 * 1024, write.getBatchSizeBytes().intValue());
+ assertEquals(Duration.standardMinutes(1), write.getBatchMaxBufferingDuration());
+
+ // Verify file contents.
+ checkFileContents(root, "output", inputs);
+
+ // With auto-sharding, we can't assert on the exact number of output files, but because
+ // batch size is 3 and there are 6 elements, we expect at least 2 files.
+ final String pattern = new File(root, "output").getAbsolutePath() + "*";
+ List<Metadata> metadata =
+ FileSystems.match(Collections.singletonList(pattern)).get(0).metadata();
+ assertTrue(metadata.size() >= 2);
+ }
+
+ @Test
+ @Category({NeedsRunner.class, UsesUnboundedPCollections.class})
+ public void testWriteUnboundedWithCustomBatchSizeBytes() throws IOException {
+ File root = tmpFolder.getRoot();
+ // The elements plus newline characters give a total of 4+4+6+5+5+4=28 bytes.
+ List<String> inputs = Arrays.asList("one", "two", "three", "four", "five", "six");
+ // Assign timestamps so that all elements fall into the same 10s window.
+ List<Long> timestamps = Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L);
+
+ FileIO.Write<Void, String> write =
+ FileIO.<String>write()
+ .via(TextIO.sink())
+ .to(root.getAbsolutePath())
+ .withPrefix("output")
+ .withSuffix(".txt")
+ .withAutoSharding()
+ .withBatchSize(1000) // Set high to avoid flushing by record count.
+ .withBatchSizeBytes(10)
+ .withBatchMaxBufferingDuration(
+ Duration.standardMinutes(1)); // Set high to avoid flushing by duration.
+
+ p.apply(Create.timestamped(inputs, timestamps).withCoder(StringUtf8Coder.of()))
+ .setIsBoundedInternal(IsBounded.UNBOUNDED)
+ .apply(
+ Window.<String>into(FixedWindows.of(Duration.standardSeconds(10)))
+ .triggering(AfterWatermark.pastEndOfWindow())
+ .withAllowedLateness(Duration.ZERO)
+ .discardingFiredPanes())
+ .apply(write);
+
+ p.run().waitUntilFinish();
+
+ // Verify that the custom batch parameters are set.
+ assertEquals(1000, write.getBatchSize().intValue());
+ assertEquals(10, write.getBatchSizeBytes().intValue());
+ assertEquals(Duration.standardMinutes(1), write.getBatchMaxBufferingDuration());
+ checkFileContents(root, "output", inputs);
+
+ // With auto-sharding, we cannot assert on the exact number of output files. The batchSizeBytes
+ // acts as a threshold for flushing; once the buffer size reaches 10 bytes, a flush is triggered,
+ // but more items may be added before it completes. With 28 bytes total, we can only guarantee
+ // at least 2 files are produced.
+ final String pattern = new File(root, "output").getAbsolutePath() + "*";
+ List<Metadata> metadata =
+ FileSystems.match(Collections.singletonList(pattern)).get(0).metadata();
+ assertTrue(metadata.size() >= 2);
+ }
+
+ static void checkFileContents(File rootDir, String prefix, List<String> inputs)
+ throws IOException {
+ List<File> outputFiles = Lists.newArrayList();
+ final String pattern = new File(rootDir, prefix).getAbsolutePath() + "*";
+ List<Metadata> metadata =
+ FileSystems.match(Collections.singletonList(pattern)).get(0).metadata();
+ for (Metadata meta : metadata) {
+ outputFiles.add(new File(meta.resourceId().toString()));
+ }
+ assertFalse("Should have produced at least 1 output file", outputFiles.isEmpty());
+
+ List<String> actual = Lists.newArrayList();
+ for (File outputFile : outputFiles) {
+ List<String> actualShard = Lists.newArrayList();
+ try (BufferedReader reader =
+ Files.newBufferedReader(outputFile.toPath(), StandardCharsets.UTF_8)) {
+ String line;
+ while ((line = reader.readLine()) != null) {
+ actualShard.add(line);
+ }
+ }
+ actual.addAll(actualShard);
+ }
+ assertThat(actual, containsInAnyOrder(inputs.toArray()));
+ }
}