nehsyc commented on a change in pull request #14852:
URL: https://github.com/apache/beam/pull/14852#discussion_r638216523
##########
File path:
runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/GroupIntoBatchesOverride.java
##########
@@ -87,14 +94,22 @@ private BatchGroupIntoBatches(long batchSize) {
new DoFn<KV<K, Iterable<V>>, KV<K, Iterable<V>>>() {
@ProcessElement
public void process(ProcessContext c) {
- // Iterators.partition lazily creates the partitions as
they are accessed
- // allowing it to partition very large iterators.
- Iterator<List<V>> iterator =
-
Iterators.partition(c.element().getValue().iterator(), (int) batchSize);
-
- // Note that GroupIntoBatches only outputs when the
batch is non-empty.
- while (iterator.hasNext()) {
- c.output(KV.of(c.element().getKey(), iterator.next()));
+ List<V> currentBatch = Lists.newArrayList();
+ long batchSizeBytes = 0;
+ for (V element : c.element().getValue()) {
+ currentBatch.add(element);
+ if (weigher != null) {
+ batchSizeBytes += weigher.apply(element);
+ }
+ if (currentBatch.size() == maxBatchSizeElements
+ || (maxBatchSizeBytes != Long.MAX_VALUE
+ && batchSizeBytes >= maxBatchSizeBytes)) {
+ c.output(KV.of(c.element().getKey(), currentBatch));
+ // Call clear() since that allows us to reuse the
array memory for
+ // subsequent batches.
+ currentBatch.clear();
+ batchSizeBytes = 0;
+ }
}
Review comment:
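Should we also emit `currentBatch` (the remaining partial batch) after
the loop? Otherwise any trailing elements that never fill a complete batch
are silently dropped, whereas the previous `Iterators.partition`-based
version emitted that final partial batch.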
##########
File path:
runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
##########
@@ -1728,24 +1733,93 @@ private void verifyGroupIntoBatchesOverride(
}
PAssert.thatMultimap(output)
.satisfies(
- new SerializableFunction<Map<String, Iterable<Iterable<Integer>>>,
Void>() {
- @Override
- public Void apply(Map<String, Iterable<Iterable<Integer>>>
input) {
- assertEquals(2, input.size());
- assertThat(input.keySet(), containsInAnyOrder("A", "B"));
- Map<String, Integer> sums = new HashMap<>();
- for (Map.Entry<String, Iterable<Iterable<Integer>>> entry :
input.entrySet()) {
- for (Iterable<Integer> batch : entry.getValue()) {
- assertThat(Iterables.size(batch),
lessThanOrEqualTo(batchSize));
- for (Integer value : batch) {
- sums.put(entry.getKey(), value +
sums.getOrDefault(entry.getKey(), 0));
- }
+ i -> {
+ assertEquals(2, i.size());
+ assertThat(i.keySet(), containsInAnyOrder("A", "B"));
+ Map<String, Integer> sums = new HashMap<>();
+ for (Map.Entry<String, Iterable<Iterable<Integer>>> entry :
i.entrySet()) {
+ for (Iterable<Integer> batch : entry.getValue()) {
+ assertThat(Iterables.size(batch),
lessThanOrEqualTo(batchSize));
+ for (Integer value : batch) {
+ sums.put(entry.getKey(), value +
sums.getOrDefault(entry.getKey(), 0));
}
}
- assertEquals(15, (int) sums.get("A"));
- assertEquals(0, (int) sums.get("B"));
- return null;
}
+ assertEquals(15, (int) sums.get("A"));
+ assertEquals(0, (int) sums.get("B"));
+ return null;
+ });
+ p.run();
+
+ AtomicBoolean sawGroupIntoBatchesOverride = new AtomicBoolean(false);
+ p.traverseTopologically(
+ new PipelineVisitor.Defaults() {
+
+ @Override
+ public CompositeBehavior enterCompositeTransform(Node node) {
+ if (p.getOptions().as(StreamingOptions.class).isStreaming()
+ && node.getTransform()
+ instanceof
GroupIntoBatchesOverride.StreamingGroupIntoBatchesWithShardedKey) {
+ sawGroupIntoBatchesOverride.set(true);
+ }
+ if (!p.getOptions().as(StreamingOptions.class).isStreaming()
+ && node.getTransform() instanceof
GroupIntoBatchesOverride.BatchGroupIntoBatches) {
+ sawGroupIntoBatchesOverride.set(true);
+ }
+ if (!p.getOptions().as(StreamingOptions.class).isStreaming()
+ && node.getTransform()
+ instanceof
GroupIntoBatchesOverride.BatchGroupIntoBatchesWithShardedKey) {
+ sawGroupIntoBatchesOverride.set(true);
+ }
+ return CompositeBehavior.ENTER_TRANSFORM;
+ }
+ });
+ if (expectOverriden) {
+ assertTrue(sawGroupIntoBatchesOverride.get());
+ } else {
+ assertFalse(sawGroupIntoBatchesOverride.get());
+ }
+ }
+
+ private void verifyGroupIntoBatchesOverrideBytes(
+ Pipeline p, Boolean withShardedKey, Boolean expectOverriden) {
+ final long batchSizeBytes = 2;
+ List<KV<String, String>> testValues =
+ Arrays.asList(
+ KV.of("A", "a"),
+ KV.of("A", "ab"),
+ KV.of("A", "abc"),
+ KV.of("A", "abcd"),
+ KV.of("A", "abcde"));
+ PCollection<KV<String, String>> input = p.apply(Create.of(testValues));
+ PCollection<KV<String, Iterable<String>>> output;
+ if (withShardedKey) {
+ output =
+ input
+ .apply(GroupIntoBatches.<String,
String>ofSize(batchSizeBytes).withShardedKey())
Review comment:
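Should this be `ofByteSize(batchSizeBytes)`? This helper is meant to verify
byte-size batching, but `ofSize(...)` limits batches by element count.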
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]