echauchot commented on code in PR #22446:
URL: https://github.com/apache/beam/pull/22446#discussion_r977397062


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java:
##########
@@ -17,98 +17,129 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 
-import java.util.ArrayList;
-import java.util.List;
-import 
org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
-import 
org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import 
org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
-import 
org.apache.beam.runners.spark.structuredstreaming.translation.helpers.KVHelpers;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
+
+import java.util.Collection;
+import 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop;
+import 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1;
 import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderRegistry;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.transforms.Combine;
-import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
-import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.WindowingStrategy;
-import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.KeyValueGroupedDataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.expressions.Aggregator;
 import scala.Tuple2;
+import scala.collection.TraversableOnce;
 
-@SuppressWarnings({
-  "rawtypes" // TODO(https://github.com/apache/beam/issues/20447)
-})
-class CombinePerKeyTranslatorBatch<K, InputT, AccumT, OutputT>
-    implements TransformTranslator<
-        PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>>> {
+/**
+ * Translator for {@link Combine.PerKey} using {@link Dataset#groupByKey} with 
a Spark {@link
+ * Aggregator}.
+ *
+ * <ul>
+ *   <li>When using the default global window, window information is dropped 
and restored after the
+ *       aggregation.
+ *   <li>For non-merging windows, windows are exploded and moved into a 
composite key for better
+ *       distribution. After the aggregation, windowed values are restored 
from the composite key.
+ *   <li>All other cases use an aggregator on windowed values that is 
optimized for the current
+ *       windowing strategy.
+ * </ul>
+ *
+ * TODOs:
+ * <li>combine with context (CombineFnWithContext)?
+ * <li>combine with sideInputs?
+ * <li>are there other missing features?
+ */
+class CombinePerKeyTranslatorBatch<K, InT, AccT, OutT>
+    extends GroupingTranslator<K, InT, OutT, Combine.PerKey<K, InT, OutT>> {
 
   @Override
-  public void translateTransform(
-      PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>> 
transform,
-      AbstractTranslationContext context) {
+  public void translate(Combine.PerKey<K, InT, OutT> transform, Context cxt) {
+    WindowingStrategy<?, ?> windowing = cxt.getInput().getWindowingStrategy();
+    CombineFn<InT, AccT, OutT> combineFn = (CombineFn<InT, AccT, OutT>) 
transform.getFn();
+
+    KvCoder<K, InT> inputCoder = (KvCoder<K, InT>) cxt.getInput().getCoder();
+    KvCoder<K, OutT> outputCoder = (KvCoder<K, OutT>) 
cxt.getOutput().getCoder();
+
+    Encoder<K> keyEnc = cxt.keyEncoderOf(inputCoder);
+    Encoder<KV<K, InT>> inputEnc = cxt.encoderOf(inputCoder);
+    Encoder<WindowedValue<KV<K, OutT>>> wvOutputEnc = 
cxt.windowedEncoder(outputCoder);
+    Encoder<AccT> accumEnc = accumEncoder(combineFn, 
inputCoder.getValueCoder(), cxt);
+
+    final Dataset<WindowedValue<KV<K, OutT>>> result;
+
+    boolean globalGroupBy = eligibleForGlobalGroupBy(windowing, true);
+    boolean groupByWindow = eligibleForGroupByWindow(windowing, true);
 
-    Combine.PerKey combineTransform = (Combine.PerKey) transform;
-    @SuppressWarnings("unchecked")
-    final PCollection<KV<K, InputT>> input = (PCollection<KV<K, InputT>>) 
context.getInput();
-    @SuppressWarnings("unchecked")
-    final PCollection<KV<K, OutputT>> output = (PCollection<KV<K, OutputT>>) 
context.getOutput();
-    @SuppressWarnings("unchecked")
-    final Combine.CombineFn<InputT, AccumT, OutputT> combineFn =
-        (Combine.CombineFn<InputT, AccumT, OutputT>) combineTransform.getFn();
-    WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy();
+    if (globalGroupBy || groupByWindow) {
+      Aggregator<KV<K, InT>, ?, OutT> valueAgg =
+          Aggregators.value(combineFn, KV::getValue, accumEnc, 
cxt.valueEncoderOf(outputCoder));
 
-    Dataset<WindowedValue<KV<K, InputT>>> inputDataset = 
context.getDataset(input);
+      if (globalGroupBy) {
+        // Drop window and group by key globally to run the aggregation 
(combineFn), afterwards the
+        // global window is restored
+        result =
+            cxt.getDataset(cxt.getInput())
+                .groupByKey(valueKey(), keyEnc)
+                .mapValues(value(), inputEnc)
+                .agg(valueAgg.toColumn())
+                .map(globalKV(), wvOutputEnc);
+      } else {
+        Encoder<Tuple2<BoundedWindow, K>> windowedKeyEnc = 
windowedKeyEnc(keyEnc, cxt);
 
-    KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) input.getCoder();
-    Coder<K> keyCoder = inputCoder.getKeyCoder();
-    KvCoder<K, OutputT> outputKVCoder = (KvCoder<K, OutputT>) 
output.getCoder();
-    Coder<OutputT> outputCoder = outputKVCoder.getValueCoder();
+        // Group by window and key to run the aggregation (combineFn)
+        result =
+            cxt.getDataset(cxt.getInput())
+                .flatMap(explodeWindowedKey(value()), 
cxt.tupleEncoder(windowedKeyEnc, inputEnc))
+                .groupByKey(fun1(Tuple2::_1), windowedKeyEnc)
+                .mapValues(fun1(Tuple2::_2), inputEnc)
+                .agg(valueAgg.toColumn())
+                .map(windowedKV(), wvOutputEnc);
+      }
+    } else {
+      // Use an optimized aggregator for session window fns

Review Comment:
   agree



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to