echauchot commented on code in PR #22446:
URL: https://github.com/apache/beam/pull/22446#discussion_r976302083


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java:
##########
@@ -17,74 +17,264 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 
+import static 
org.apache.beam.repackaged.core.org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers.toByteArray;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderOf;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.concat;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun2;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.javaIterator;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.listOf;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf;
+import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING;
+import static 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
+import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.sql.functions.collect_list;
+import static org.apache.spark.sql.functions.explode;
+import static org.apache.spark.sql.functions.max;
+import static org.apache.spark.sql.functions.min;
+import static org.apache.spark.sql.functions.struct;
+
 import java.io.Serializable;
 import org.apache.beam.runners.core.InMemoryStateInternals;
-import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.ReduceFnRunner;
 import org.apache.beam.runners.core.StateInternalsFactory;
 import org.apache.beam.runners.core.SystemReduceFn;
-import 
org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
-import 
org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
 import 
org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.GroupAlsoByWindowViaOutputBufferFn;
-import 
org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
-import 
org.apache.beam.runners.spark.structuredstreaming.translation.helpers.KVHelpers;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
-import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
-import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.spark.sql.Column;
 import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.KeyValueGroupedDataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.TypedColumn;
+import org.apache.spark.sql.catalyst.expressions.CreateArray;
+import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.Literal;
+import org.apache.spark.sql.catalyst.expressions.Literal$;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import scala.Tuple2;
+import scala.collection.Iterator;
+import scala.collection.Seq;
+import scala.collection.immutable.List;
 
+/**
+ * Translator for {@link GroupByKey} using {@link Dataset#groupByKey} with the 
build-in aggregation
+ * function {@code collect_list} when applicable.
+ *
+ * <p>Note: Using {@code collect_list} isn't any worse than using {@link 
ReduceFnRunner}. In the
+ * latter case the entire group (iterator) has to be loaded into memory as 
well, risking OOM errors
+ * in both cases. When disabling {@link #useCollectList}, a more memory 
sensitive iterable is used
+ * that can be traversed just once. Attempting to traverse the iterable again 
will throw.
+ *
+ * <ul>
+ *   <li>When using the default global window, window information is dropped 
and restored after the
+ *       aggregation.
+ *   <li>For non-merging windows, windows are exploded and moved into a 
composite key for better
+ *       distribution. Though, to keep the amount of shuffled data low, this 
is only done if values
+ *       are assigned to a single window or if there are only few keys and 
distributing data is
+ *       important. After the aggregation, windowed values are restored from 
the composite key.
+ *   <li>All other cases are implemented using the SDK {@link ReduceFnRunner}.
+ * </ul>
+ */
 class GroupByKeyTranslatorBatch<K, V>
-    implements TransformTranslator<
-        PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>>> {
+    extends GroupingTranslator<K, V, Iterable<V>, GroupByKey<K, V>> {
+
+  /** Literal of binary encoded Pane info. */
+  private static final Expression PANE_NO_FIRING = lit(toByteArray(NO_FIRING, 
PaneInfoCoder.of()));
+
+  /** Defaults for value in single global window. */
+  private static final List<Expression> GLOBAL_WINDOW_DETAILS =
+      windowDetails(lit(new byte[][] {EMPTY_BYTE_ARRAY}));
+
+  private boolean useCollectList = true;
+
+  public GroupByKeyTranslatorBatch() {}
+
+  public GroupByKeyTranslatorBatch(boolean useCollectList) {
+    this.useCollectList = useCollectList;
+  }
 
   @Override
-  public void translateTransform(
-      PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> 
transform,
-      AbstractTranslationContext context) {
-
-    @SuppressWarnings("unchecked")
-    final PCollection<KV<K, V>> inputPCollection = (PCollection<KV<K, V>>) 
context.getInput();
-    Dataset<WindowedValue<KV<K, V>>> input = 
context.getDataset(inputPCollection);
-    WindowingStrategy<?, ?> windowingStrategy = 
inputPCollection.getWindowingStrategy();
-    KvCoder<K, V> kvCoder = (KvCoder<K, V>) inputPCollection.getCoder();
-    Coder<V> valueCoder = kvCoder.getValueCoder();
-
-    // group by key only
-    Coder<K> keyCoder = kvCoder.getKeyCoder();
-    KeyValueGroupedDataset<K, WindowedValue<KV<K, V>>> groupByKeyOnly =
-        input.groupByKey(KVHelpers.extractKey(), 
EncoderHelpers.fromBeamCoder(keyCoder));
-
-    // group also by windows
-    WindowedValue.FullWindowedValueCoder<KV<K, Iterable<V>>> outputCoder =
-        WindowedValue.FullWindowedValueCoder.of(
-            KvCoder.of(keyCoder, IterableCoder.of(valueCoder)),
-            windowingStrategy.getWindowFn().windowCoder());
-    Dataset<WindowedValue<KV<K, Iterable<V>>>> output =
-        groupByKeyOnly.flatMapGroups(
-            new GroupAlsoByWindowViaOutputBufferFn<>(
-                windowingStrategy,
-                new InMemoryStateInternalsFactory<>(),
-                SystemReduceFn.buffering(valueCoder),
-                context.getSerializableOptions()),
-            EncoderHelpers.fromBeamCoder(outputCoder));
-
-    context.putDataset(context.getOutput(), output);
+  public void translate(GroupByKey<K, V> transform, Context cxt) {
+    WindowingStrategy<?, ?> windowing = cxt.getInput().getWindowingStrategy();
+    TimestampCombiner tsCombiner = windowing.getTimestampCombiner();
+
+    Dataset<WindowedValue<KV<K, V>>> input = cxt.getDataset(cxt.getInput());
+
+    KvCoder<K, V> inputCoder = (KvCoder<K, V>) cxt.getInput().getCoder();
+    KvCoder<K, Iterable<V>> outputCoder = (KvCoder<K, Iterable<V>>) 
cxt.getOutput().getCoder();
+
+    Encoder<V> valueEnc = cxt.valueEncoderOf(inputCoder);
+    Encoder<K> keyEnc = cxt.keyEncoderOf(inputCoder);
+
+    // In batch we can ignore triggering and allowed lateness parameters

Review Comment:
   A processing-time trigger in batch mode is more a synonym of "not based on 
event time" than of "based on wall clock".  In that case the trigger is 
based on elements: it fires every 2 elements seen. So I think that makes 
sense in batch mode. 
   Regarding watermark triggers in batch mode, they are better understood as 
triggers based on event timestamps, as there is no flow of data in batch mode: 
please consider that event timestamps can be put on elements even in batch mode 
(cf. replay of a streaming recording), and in this case triggering on the event 
timestamp makes sense.  
   But I agree these are corner cases not frequent at all among users.
   Both Nexmark queries run fine, but beware: Nexmark does not check the 
correctness of data. So if there is no VR test for these corner cases, we 
cannot be sure this behavior is correct.
   
   Anyway, fair enough, as the RDD runner does the same and AFAIK no user 
has complained.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to