This is an automated email from the ASF dual-hosted git repository. aromanenko pushed a commit to branch spark-runner_structured-streaming in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/spark-runner_structured-streaming by this push: new 16cf3c2 Simplify logic of ParDo translator 16cf3c2 is described below commit 16cf3c2ca6e5a82f1959ce2976a330badd6e6c44 Author: Alexey Romanenko <aromanenko....@gmail.com> AuthorDate: Mon Feb 4 11:22:10 2019 +0100 Simplify logic of ParDo translator --- .../translation/batch/DoFnFunction.java | 9 ++-- .../translation/batch/ParDoTranslatorBatch.java | 59 ++++------------------ 2 files changed, 13 insertions(+), 55 deletions(-) diff --git a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java index 8ce98a8..2989d0d 100644 --- a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java +++ b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java @@ -20,7 +20,6 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.batch; import com.google.common.base.Function; import com.google.common.collect.Iterators; import com.google.common.collect.LinkedListMultimap; -import com.google.common.collect.Lists; import com.google.common.collect.Multimap; import java.util.Collections; import java.util.Iterator; @@ -60,7 +59,7 @@ public class DoFnFunction<InputT, OutputT> private final WindowingStrategy<?, ?> windowingStrategy; - private final Map<TupleTag<?>, Integer> outputMap; + private final List<TupleTag<?>> additionalOutputTags; private final TupleTag<OutputT> mainOutputTag; private final Coder<InputT> inputCoder; private final Map<TupleTag<?>, Coder<?>> outputCoderMap; @@ -72,7 +71,7 @@ public class DoFnFunction<InputT, OutputT> WindowingStrategy<?, ?> windowingStrategy, Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs, PipelineOptions options, - Map<TupleTag<?>, Integer> outputMap, + List<TupleTag<?>> additionalOutputTags, TupleTag<OutputT> mainOutputTag, Coder<InputT> inputCoder, Map<TupleTag<?>, Coder<?>> outputCoderMap) { @@ -81,7 +80,7 @@ public class DoFnFunction<InputT, OutputT> this.sideInputs = sideInputs; this.serializedOptions = new SerializablePipelineOptions(options); this.windowingStrategy = windowingStrategy; - this.outputMap = outputMap; + this.additionalOutputTags = additionalOutputTags; this.mainOutputTag = mainOutputTag; this.inputCoder = inputCoder; this.outputCoderMap = outputCoderMap; @@ -93,8 +92,6 @@ public class DoFnFunction<InputT, OutputT> DoFnOutputManager outputManager = new DoFnOutputManager(); - List<TupleTag<?>> additionalOutputTags = Lists.newArrayList(outputMap.keySet()); - DoFnRunner<InputT, OutputT> doFnRunner = DoFnRunners.simpleRunner( serializedOptions.get(), diff --git a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java index fbb6649..5c9cb16 100644 --- a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java +++ b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java @@ -20,7 +20,6 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.batch; import static com.google.common.base.Preconditions.checkState; import com.google.common.collect.Lists; -import com.google.common.collect.Maps; import java.io.IOException; import java.util.HashMap; import java.util.List; @@ -32,7 +31,6 @@ import org.apache.beam.runners.spark.structuredstreaming.translation.Translation import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.join.UnionCoder; import org.apache.beam.sdk.transforms.reflect.DoFnSignature; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.util.WindowedValue; @@ -61,7 +59,7 @@ class ParDoTranslatorBatch<InputT, OutputT> public void translateTransform( PTransform<PCollection<InputT>, PCollectionTuple> transform, TranslationContext context) { - // Check for not-supported advanced features + // Check for not supported advanced features // TODO: add support of Splittable DoFn DoFn<InputT, OutputT> doFn = getDoFn(context); checkState( @@ -80,51 +78,13 @@ class ParDoTranslatorBatch<InputT, OutputT> final boolean hasSideInputs = sideInputs != null && sideInputs.size() > 0; checkState(!hasSideInputs, "SideInputs are not supported for the moment."); - // Init main variables Dataset<WindowedValue<InputT>> inputDataSet = context.getDataset(context.getInput()); Map<TupleTag<?>, PValue> outputs = context.getOutputs(); TupleTag<?> mainOutputTag = getTupleTag(context); - Map<TupleTag<?>, Integer> outputTags = Maps.newHashMap(); - - outputTags.put(mainOutputTag, 0); - int count = 1; - for (TupleTag<?> tag : outputs.keySet()) { - if (!outputTags.containsKey(tag)) { - outputTags.put(tag, count++); - } - } - - // Union coder elements must match the order of the output tags. - Map<Integer, TupleTag<?>> indexMap = Maps.newTreeMap(); - for (Map.Entry<TupleTag<?>, Integer> entry : outputTags.entrySet()) { - indexMap.put(entry.getValue(), entry.getKey()); - } - - // assume that the windowing strategy is the same for all outputs - WindowingStrategy<?, ?> windowingStrategy = null; - - // collect all output Coders and create a UnionCoder for our tagged outputs -// List<Coder<?>> outputCoders = Lists.newArrayList(); - for (TupleTag<?> tag : indexMap.values()) { - PValue taggedValue = outputs.get(tag); - checkState( - taggedValue instanceof PCollection, - "Within ParDo, got a non-PCollection output %s of type %s", - taggedValue, - taggedValue.getClass().getSimpleName()); - PCollection<?> coll = (PCollection<?>) taggedValue; -// outputCoders.add(coll.getCoder()); - windowingStrategy = coll.getWindowingStrategy(); - } - - if (windowingStrategy == null) { - throw new IllegalStateException("No outputs defined."); - } - -// UnionCoder unionCoder = UnionCoder.of(outputCoders); - - + List<TupleTag<?>> outputTags = Lists.newArrayList(outputs.keySet()); + WindowingStrategy<?, ?> windowingStrategy = + ((PCollection<InputT>) context.getInput()).getWindowingStrategy(); // construct a map from side input to WindowingStrategy so that // the DoFn runner can map main-input windows to side input windows @@ -134,6 +94,7 @@ class ParDoTranslatorBatch<InputT, OutputT> } Map<TupleTag<?>, Coder<?>> outputCoderMap = context.getOutputCoders(); + Coder<InputT> inputCoder = ((PCollection<InputT>) context.getInput()).getCoder(); @SuppressWarnings("unchecked") DoFnFunction<InputT, OutputT> doFnWrapper = @@ -144,14 +105,14 @@ class ParDoTranslatorBatch<InputT, OutputT> context.getOptions(), outputTags, mainOutputTag, - ((PCollection<InputT>)context.getInput()).getCoder(), + inputCoder, outputCoderMap); - Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> allOutputsDataset = + Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> allOutputs = inputDataSet.mapPartitions(doFnWrapper, EncoderHelpers.tuple2Encoder()); for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) { - pruneOutputFilteredByTag(context, allOutputsDataset, output); + pruneOutputFilteredByTag(context, allOutputs, output); } } @@ -188,10 +149,10 @@ class ParDoTranslatorBatch<InputT, OutputT> private void pruneOutputFilteredByTag( TranslationContext context, - Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> tmpDataset, + Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> allOutputs, Map.Entry<TupleTag<?>, PValue> output) { Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> filteredDataset = - tmpDataset.filter(new SparkDoFnFilterFunction(output.getKey())); + allOutputs.filter(new SparkDoFnFilterFunction(output.getKey())); Dataset<WindowedValue<?>> outputDataset = filteredDataset.map( (MapFunction<Tuple2<TupleTag<?>, WindowedValue<?>>, WindowedValue<?>>)