This is an automated email from the ASF dual-hosted git repository.

ibzib pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 4836b14  [BEAM-7864] Simplify/generalize Spark reshuffle translation
     new a5e7e67  Merge pull request #9410 from ibzib/spark-reshuffle
4836b14 is described below

commit 4836b14826b546db1bf4cc1ac36d415a53aad868
Author: Kyle Weaver <[email protected]>
AuthorDate: Thu Aug 22 18:24:30 2019 -0700

    [BEAM-7864] Simplify/generalize Spark reshuffle translation
---
 .../spark/translation/GroupCombineFunctions.java   | 13 ++++--------
 .../SparkBatchPortablePipelineTranslator.java      | 23 +++++++++-------------
 .../spark/translation/TransformTranslator.java     |  7 +++----
 .../streaming/StreamingTransformTranslator.java    |  7 +++----
 4 files changed, 19 insertions(+), 31 deletions(-)

diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
index b718725..9a10354 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
@@ -183,17 +183,12 @@ public class GroupCombineFunctions {
   }
 
   /** An implementation of {@link Reshuffle} for the Spark runner. */
-  public static <K, V> JavaRDD<WindowedValue<KV<K, V>>> reshuffle(
-      JavaRDD<WindowedValue<KV<K, V>>> rdd, Coder<K> keyCoder, WindowedValueCoder<V> wvCoder) {
-
+  public static <T> JavaRDD<WindowedValue<T>> reshuffle(
+      JavaRDD<WindowedValue<T>> rdd, WindowedValueCoder<T> wvCoder) {
     // Use coders to convert objects in the PCollection to byte arrays, so they
     // can be transferred over the network for the shuffle.
-    return rdd.map(new ReifyTimestampsAndWindowsFunction<>())
-        .mapToPair(TranslationUtils.toPairFunction())
-        .mapToPair(CoderHelpers.toByteFunction(keyCoder, wvCoder))
+    return rdd.map(CoderHelpers.toByteFunction(wvCoder))
         .repartition(rdd.getNumPartitions())
-        .mapToPair(new CoderHelpers.FromByteFunction(keyCoder, wvCoder))
-        .map(new TranslationUtils.FromPairFunction())
-        .map(new TranslationUtils.ToKVByWindowInValueFunction<>());
+        .map(CoderHelpers.fromByteFunction(wvCoder));
   }
 }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
index 4bed3a0..07368f3 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
@@ -64,11 +64,16 @@ import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.storage.StorageLevel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 import scala.Tuple2;
 
 /** Translates a bounded portable pipeline into a Spark job. */
 public class SparkBatchPortablePipelineTranslator {
 
+  private static final Logger LOG =
+      LoggerFactory.getLogger(SparkBatchPortablePipelineTranslator.class);
+
   private final ImmutableMap<String, PTransformTranslator> urnToTransformTranslator;
 
   interface PTransformTranslator {
@@ -350,22 +355,12 @@ public class SparkBatchPortablePipelineTranslator {
     context.pushDataset(getOutputId(transformNode), new BoundedDataset<>(unionRDD));
   }
 
-  private static <K, V> void translateReshuffle(
+  private static <T> void translateReshuffle(
       PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
     String inputId = getInputId(transformNode);
-    JavaRDD<WindowedValue<KV<K, V>>> inRDD =
-        ((BoundedDataset<KV<K, V>>) context.popDataset(inputId)).getRDD();
-    RunnerApi.Components components = pipeline.getComponents();
-    WindowingStrategy windowingStrategy = getWindowingStrategy(inputId, components);
-    WindowedValueCoder<KV<K, V>> windowedCoder = getWindowedValueCoder(inputId, components);
-    KvCoder<K, V> coder = (KvCoder<K, V>) windowedCoder.getValueCoder();
-    final WindowFn windowFn = windowingStrategy.getWindowFn();
-    final Coder<K> keyCoder = coder.getKeyCoder();
-    final WindowedValue.WindowedValueCoder<V> wvCoder =
-        WindowedValue.FullWindowedValueCoder.of(coder.getValueCoder(), windowFn.windowCoder());
-
-    JavaRDD<WindowedValue<KV<K, V>>> reshuffled =
-        GroupCombineFunctions.reshuffle(inRDD, keyCoder, wvCoder);
+    WindowedValueCoder<T> coder = getWindowedValueCoder(inputId, pipeline.getComponents());
+    JavaRDD<WindowedValue<T>> inRDD = ((BoundedDataset<T>) context.popDataset(inputId)).getRDD();
+    JavaRDD<WindowedValue<T>> reshuffled = GroupCombineFunctions.reshuffle(inRDD, coder);
     context.pushDataset(getOutputId(transformNode), new BoundedDataset<>(reshuffled));
   }
 
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index 23560e2..dd4d426 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -547,12 +547,11 @@ public final class TransformTranslator {
         @SuppressWarnings("unchecked")
         final WindowFn<Object, W> windowFn = (WindowFn<Object, W>) windowingStrategy.getWindowFn();
 
-        final Coder<K> keyCoder = coder.getKeyCoder();
-        final WindowedValue.WindowedValueCoder<V> wvCoder =
-            WindowedValue.FullWindowedValueCoder.of(coder.getValueCoder(), windowFn.windowCoder());
+        final WindowedValue.WindowedValueCoder<KV<K, V>> wvCoder =
+            WindowedValue.FullWindowedValueCoder.of(coder, windowFn.windowCoder());
 
         JavaRDD<WindowedValue<KV<K, V>>> reshuffled =
-            GroupCombineFunctions.reshuffle(inRDD, keyCoder, wvCoder);
+            GroupCombineFunctions.reshuffle(inRDD, wvCoder);
 
         context.putDataset(transform, new BoundedDataset<>(reshuffled));
       }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 2480c6b..3eed320 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -475,12 +475,11 @@ public final class StreamingTransformTranslator {
         @SuppressWarnings("unchecked")
         final WindowFn<Object, W> windowFn = (WindowFn<Object, W>) windowingStrategy.getWindowFn();
 
-        final WindowedValue.WindowedValueCoder<V> wvCoder =
-            WindowedValue.FullWindowedValueCoder.of(coder.getValueCoder(), windowFn.windowCoder());
+        final WindowedValue.WindowedValueCoder<KV<K, V>> wvCoder =
+            WindowedValue.FullWindowedValueCoder.of(coder, windowFn.windowCoder());
 
         JavaDStream<WindowedValue<KV<K, V>>> reshuffledStream =
-            dStream.transform(
-                rdd -> GroupCombineFunctions.reshuffle(rdd, coder.getKeyCoder(), wvCoder));
+            dStream.transform(rdd -> GroupCombineFunctions.reshuffle(rdd, wvCoder));
 
         context.putDataset(transform, new UnboundedDataset<>(reshuffledStream, streamSources));
       }

Reply via email to