This is an automated email from the ASF dual-hosted git repository.

aromanenko pushed a commit to branch spark-runner_structured-streaming
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to 
refs/heads/spark-runner_structured-streaming by this push:
     new 16cf3c2  Simplify logic of ParDo translator
16cf3c2 is described below

commit 16cf3c2ca6e5a82f1959ce2976a330badd6e6c44
Author: Alexey Romanenko <aromanenko....@gmail.com>
AuthorDate: Mon Feb 4 11:22:10 2019 +0100

    Simplify logic of ParDo translator
---
 .../translation/batch/DoFnFunction.java            |  9 ++--
 .../translation/batch/ParDoTranslatorBatch.java    | 59 ++++------------------
 2 files changed, 13 insertions(+), 55 deletions(-)

diff --git 
a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java
 
b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java
index 8ce98a8..2989d0d 100644
--- 
a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java
+++ 
b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java
@@ -20,7 +20,6 @@ package 
org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 import com.google.common.base.Function;
 import com.google.common.collect.Iterators;
 import com.google.common.collect.LinkedListMultimap;
-import com.google.common.collect.Lists;
 import com.google.common.collect.Multimap;
 import java.util.Collections;
 import java.util.Iterator;
@@ -60,7 +59,7 @@ public class DoFnFunction<InputT, OutputT>
 
   private final WindowingStrategy<?, ?> windowingStrategy;
 
-  private final Map<TupleTag<?>, Integer> outputMap;
+  private final List<TupleTag<?>> additionalOutputTags;
   private final TupleTag<OutputT> mainOutputTag;
   private final Coder<InputT> inputCoder;
   private final Map<TupleTag<?>, Coder<?>> outputCoderMap;
@@ -72,7 +71,7 @@ public class DoFnFunction<InputT, OutputT>
       WindowingStrategy<?, ?> windowingStrategy,
       Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs,
       PipelineOptions options,
-      Map<TupleTag<?>, Integer> outputMap,
+      List<TupleTag<?>> additionalOutputTags,
       TupleTag<OutputT> mainOutputTag,
       Coder<InputT> inputCoder,
       Map<TupleTag<?>, Coder<?>> outputCoderMap) {
@@ -81,7 +80,7 @@ public class DoFnFunction<InputT, OutputT>
     this.sideInputs = sideInputs;
     this.serializedOptions = new SerializablePipelineOptions(options);
     this.windowingStrategy = windowingStrategy;
-    this.outputMap = outputMap;
+    this.additionalOutputTags = additionalOutputTags;
     this.mainOutputTag = mainOutputTag;
     this.inputCoder = inputCoder;
     this.outputCoderMap = outputCoderMap;
@@ -93,8 +92,6 @@ public class DoFnFunction<InputT, OutputT>
 
     DoFnOutputManager outputManager = new DoFnOutputManager();
 
-    List<TupleTag<?>> additionalOutputTags = 
Lists.newArrayList(outputMap.keySet());
-
     DoFnRunner<InputT, OutputT> doFnRunner =
         DoFnRunners.simpleRunner(
             serializedOptions.get(),
diff --git 
a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
 
b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
index fbb6649..5c9cb16 100644
--- 
a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
+++ 
b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
@@ -20,7 +20,6 @@ package 
org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 import static com.google.common.base.Preconditions.checkState;
 
 import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
 import java.io.IOException;
 import java.util.HashMap;
 import java.util.List;
@@ -32,7 +31,6 @@ import 
org.apache.beam.runners.spark.structuredstreaming.translation.Translation
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.transforms.join.UnionCoder;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
 import org.apache.beam.sdk.util.WindowedValue;
@@ -61,7 +59,7 @@ class ParDoTranslatorBatch<InputT, OutputT>
   public void translateTransform(
       PTransform<PCollection<InputT>, PCollectionTuple> transform, 
TranslationContext context) {
 
-    // Check for not-supported advanced features
+    // Check for not supported advanced features
     // TODO: add support of Splittable DoFn
     DoFn<InputT, OutputT> doFn = getDoFn(context);
     checkState(
@@ -80,51 +78,13 @@ class ParDoTranslatorBatch<InputT, OutputT>
     final boolean hasSideInputs = sideInputs != null && sideInputs.size() > 0;
     checkState(!hasSideInputs, "SideInputs are not supported for the moment.");
 
-
     // Init main variables
     Dataset<WindowedValue<InputT>> inputDataSet = 
context.getDataset(context.getInput());
     Map<TupleTag<?>, PValue> outputs = context.getOutputs();
     TupleTag<?> mainOutputTag = getTupleTag(context);
-    Map<TupleTag<?>, Integer> outputTags = Maps.newHashMap();
-
-    outputTags.put(mainOutputTag, 0);
-    int count = 1;
-    for (TupleTag<?> tag : outputs.keySet()) {
-      if (!outputTags.containsKey(tag)) {
-        outputTags.put(tag, count++);
-      }
-    }
-
-    // Union coder elements must match the order of the output tags.
-    Map<Integer, TupleTag<?>> indexMap = Maps.newTreeMap();
-    for (Map.Entry<TupleTag<?>, Integer> entry : outputTags.entrySet()) {
-      indexMap.put(entry.getValue(), entry.getKey());
-    }
-
-    // assume that the windowing strategy is the same for all outputs
-    WindowingStrategy<?, ?> windowingStrategy = null;
-
-    // collect all output Coders and create a UnionCoder for our tagged outputs
-//    List<Coder<?>> outputCoders = Lists.newArrayList();
-    for (TupleTag<?> tag : indexMap.values()) {
-      PValue taggedValue = outputs.get(tag);
-      checkState(
-          taggedValue instanceof PCollection,
-          "Within ParDo, got a non-PCollection output %s of type %s",
-          taggedValue,
-          taggedValue.getClass().getSimpleName());
-      PCollection<?> coll = (PCollection<?>) taggedValue;
-//      outputCoders.add(coll.getCoder());
-      windowingStrategy = coll.getWindowingStrategy();
-    }
-
-    if (windowingStrategy == null) {
-      throw new IllegalStateException("No outputs defined.");
-    }
-
-//    UnionCoder unionCoder = UnionCoder.of(outputCoders);
-
-
+    List<TupleTag<?>> outputTags = Lists.newArrayList(outputs.keySet());
+    WindowingStrategy<?, ?> windowingStrategy =
+        ((PCollection<InputT>) context.getInput()).getWindowingStrategy();
 
     // construct a map from side input to WindowingStrategy so that
     // the DoFn runner can map main-input windows to side input windows
@@ -134,6 +94,7 @@ class ParDoTranslatorBatch<InputT, OutputT>
     }
 
     Map<TupleTag<?>, Coder<?>> outputCoderMap = context.getOutputCoders();
+    Coder<InputT> inputCoder = ((PCollection<InputT>) 
context.getInput()).getCoder();
 
     @SuppressWarnings("unchecked")
     DoFnFunction<InputT, OutputT> doFnWrapper =
@@ -144,14 +105,14 @@ class ParDoTranslatorBatch<InputT, OutputT>
             context.getOptions(),
             outputTags,
             mainOutputTag,
-            ((PCollection<InputT>)context.getInput()).getCoder(),
+            inputCoder,
             outputCoderMap);
 
-    Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> allOutputsDataset =
+    Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> allOutputs =
         inputDataSet.mapPartitions(doFnWrapper, 
EncoderHelpers.tuple2Encoder());
 
     for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
-      pruneOutputFilteredByTag(context, allOutputsDataset, output);
+      pruneOutputFilteredByTag(context, allOutputs, output);
     }
   }
 
@@ -188,10 +149,10 @@ class ParDoTranslatorBatch<InputT, OutputT>
 
   private void pruneOutputFilteredByTag(
       TranslationContext context,
-      Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> tmpDataset,
+      Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> allOutputs,
       Map.Entry<TupleTag<?>, PValue> output) {
     Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> filteredDataset =
-        tmpDataset.filter(new SparkDoFnFilterFunction(output.getKey()));
+        allOutputs.filter(new SparkDoFnFilterFunction(output.getKey()));
     Dataset<WindowedValue<?>> outputDataset =
         filteredDataset.map(
             (MapFunction<Tuple2<TupleTag<?>, WindowedValue<?>>, 
WindowedValue<?>>)

Reply via email to