nehsyc commented on a change in pull request #14852:
URL: https://github.com/apache/beam/pull/14852#discussion_r638216523



##########
File path: 
runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/GroupIntoBatchesOverride.java
##########
@@ -87,14 +94,22 @@ private BatchGroupIntoBatches(long batchSize) {
                   new DoFn<KV<K, Iterable<V>>, KV<K, Iterable<V>>>() {
                     @ProcessElement
                     public void process(ProcessContext c) {
-                      // Iterators.partition lazily creates the partitions as 
they are accessed
-                      // allowing it to partition very large iterators.
-                      Iterator<List<V>> iterator =
-                          
Iterators.partition(c.element().getValue().iterator(), (int) batchSize);
-
-                      // Note that GroupIntoBatches only outputs when the 
batch is non-empty.
-                      while (iterator.hasNext()) {
-                        c.output(KV.of(c.element().getKey(), iterator.next()));
+                      List<V> currentBatch = Lists.newArrayList();
+                      long batchSizeBytes = 0;
+                      for (V element : c.element().getValue()) {
+                        currentBatch.add(element);
+                        if (weigher != null) {
+                          batchSizeBytes += weigher.apply(element);
+                        }
+                        if (currentBatch.size() == maxBatchSizeElements
+                            || (maxBatchSizeBytes != Long.MAX_VALUE
+                                && batchSizeBytes >= maxBatchSizeBytes)) {
+                          c.output(KV.of(c.element().getKey(), currentBatch));
+                          // Call clear() since that allows us to reuse the 
array memory for
+                          // subsequent batches.
+                          currentBatch.clear();
+                          batchSizeBytes = 0;
+                        }
                       }

Review comment:
       Should we also emit `currentBatch` (the remaining partial batch) after 
the loop?

##########
File path: 
runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
##########
@@ -1728,24 +1733,93 @@ private void verifyGroupIntoBatchesOverride(
     }
     PAssert.thatMultimap(output)
         .satisfies(
-            new SerializableFunction<Map<String, Iterable<Iterable<Integer>>>, 
Void>() {
-              @Override
-              public Void apply(Map<String, Iterable<Iterable<Integer>>> 
input) {
-                assertEquals(2, input.size());
-                assertThat(input.keySet(), containsInAnyOrder("A", "B"));
-                Map<String, Integer> sums = new HashMap<>();
-                for (Map.Entry<String, Iterable<Iterable<Integer>>> entry : 
input.entrySet()) {
-                  for (Iterable<Integer> batch : entry.getValue()) {
-                    assertThat(Iterables.size(batch), 
lessThanOrEqualTo(batchSize));
-                    for (Integer value : batch) {
-                      sums.put(entry.getKey(), value + 
sums.getOrDefault(entry.getKey(), 0));
-                    }
+            i -> {
+              assertEquals(2, i.size());
+              assertThat(i.keySet(), containsInAnyOrder("A", "B"));
+              Map<String, Integer> sums = new HashMap<>();
+              for (Map.Entry<String, Iterable<Iterable<Integer>>> entry : 
i.entrySet()) {
+                for (Iterable<Integer> batch : entry.getValue()) {
+                  assertThat(Iterables.size(batch), 
lessThanOrEqualTo(batchSize));
+                  for (Integer value : batch) {
+                    sums.put(entry.getKey(), value + 
sums.getOrDefault(entry.getKey(), 0));
                   }
                 }
-                assertEquals(15, (int) sums.get("A"));
-                assertEquals(0, (int) sums.get("B"));
-                return null;
               }
+              assertEquals(15, (int) sums.get("A"));
+              assertEquals(0, (int) sums.get("B"));
+              return null;
+            });
+    p.run();
+
+    AtomicBoolean sawGroupIntoBatchesOverride = new AtomicBoolean(false);
+    p.traverseTopologically(
+        new PipelineVisitor.Defaults() {
+
+          @Override
+          public CompositeBehavior enterCompositeTransform(Node node) {
+            if (p.getOptions().as(StreamingOptions.class).isStreaming()
+                && node.getTransform()
+                    instanceof 
GroupIntoBatchesOverride.StreamingGroupIntoBatchesWithShardedKey) {
+              sawGroupIntoBatchesOverride.set(true);
+            }
+            if (!p.getOptions().as(StreamingOptions.class).isStreaming()
+                && node.getTransform() instanceof 
GroupIntoBatchesOverride.BatchGroupIntoBatches) {
+              sawGroupIntoBatchesOverride.set(true);
+            }
+            if (!p.getOptions().as(StreamingOptions.class).isStreaming()
+                && node.getTransform()
+                    instanceof 
GroupIntoBatchesOverride.BatchGroupIntoBatchesWithShardedKey) {
+              sawGroupIntoBatchesOverride.set(true);
+            }
+            return CompositeBehavior.ENTER_TRANSFORM;
+          }
+        });
+    if (expectOverriden) {
+      assertTrue(sawGroupIntoBatchesOverride.get());
+    } else {
+      assertFalse(sawGroupIntoBatchesOverride.get());
+    }
+  }
+
+  private void verifyGroupIntoBatchesOverrideBytes(
+      Pipeline p, Boolean withShardedKey, Boolean expectOverriden) {
+    final long batchSizeBytes = 2;
+    List<KV<String, String>> testValues =
+        Arrays.asList(
+            KV.of("A", "a"),
+            KV.of("A", "ab"),
+            KV.of("A", "abc"),
+            KV.of("A", "abcd"),
+            KV.of("A", "abcde"));
+    PCollection<KV<String, String>> input = p.apply(Create.of(testValues));
+    PCollection<KV<String, Iterable<String>>> output;
+    if (withShardedKey) {
+      output =
+          input
+              .apply(GroupIntoBatches.<String, 
String>ofSize(batchSizeBytes).withShardedKey())

Review comment:
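       This should be `ofByteSize(batchSizeBytes)` rather than `ofSize(batchSizeBytes)`: this helper is meant to exercise the byte-size limit, not the element-count limit.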




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to