Repository: beam
Updated Branches:
  refs/heads/master 7e9233bbd -> bb8cd72b9
http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index e3445bf..628b713 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -331,58 +331,20 @@ final class StreamingTransformTranslator {
     };
   }
 
-  private static <InputT, OutputT> TransformEvaluator<ParDo.Bound<InputT, OutputT>> parDo() {
-    return new TransformEvaluator<ParDo.Bound<InputT, OutputT>>() {
-      @Override
-      public void evaluate(final ParDo.Bound<InputT, OutputT> transform,
-          final EvaluationContext context) {
-        final DoFn<InputT, OutputT> doFn = transform.getFn();
-        rejectSplittable(doFn);
-        rejectStateAndTimers(doFn);
-        final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
-        final WindowingStrategy<?, ?> windowingStrategy =
-            context.getInput(transform).getWindowingStrategy();
-        final SparkPCollectionView pviews = context.getPViews();
-
-        @SuppressWarnings("unchecked")
-        UnboundedDataset<InputT> unboundedDataset =
-            ((UnboundedDataset<InputT>) context.borrowDataset(transform));
-        JavaDStream<WindowedValue<InputT>> dStream = unboundedDataset.getDStream();
-
-        final String stepName = context.getCurrentTransform().getFullName();
-
-        JavaDStream<WindowedValue<OutputT>> outStream =
-            dStream.transform(new Function<JavaRDD<WindowedValue<InputT>>,
-                JavaRDD<WindowedValue<OutputT>>>() {
-              @Override
-              public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd) throws
-                  Exception {
-                final JavaSparkContext jsc = new JavaSparkContext(rdd.context());
-                final Accumulator<NamedAggregators> aggAccum =
-                    SparkAggregators.getNamedAggregators(jsc);
-                final Accumulator<SparkMetricsContainer> metricsAccum =
-                    MetricsAccumulator.getInstance();
-                final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
-                    TranslationUtils.getSideInputs(transform.getSideInputs(),
-                        jsc, pviews);
-                return rdd.mapPartitions(
-                    new DoFnFunction<>(aggAccum, metricsAccum, stepName, doFn, runtimeContext,
-                        sideInputs, windowingStrategy));
-              }
-            });
-
-        context.putDataset(transform,
-            new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources()));
-      }
-    };
-  }
-
   private static <InputT, OutputT> TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>> multiDo() {
     return new TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>>() {
-      @Override
-      public void evaluate(final ParDo.BoundMulti<InputT, OutputT> transform,
-          final EvaluationContext context) {
+      public void evaluate(
+          final ParDo.BoundMulti<InputT, OutputT> transform, final EvaluationContext context) {
+        if (transform.getSideOutputTags().size() == 0) {
+          evaluateSingle(transform, context);
+        } else {
+          evaluateMulti(transform, context);
+        }
+      }
+
+      private void evaluateMulti(
+          final ParDo.BoundMulti<InputT, OutputT> transform, final EvaluationContext context) {
         final DoFn<InputT, OutputT> doFn = transform.getFn();
         rejectSplittable(doFn);
         rejectStateAndTimers(doFn);
@@ -426,10 +388,60 @@ final class StreamingTransformTranslator {
           JavaDStream<WindowedValue<Object>> values =
               (JavaDStream<WindowedValue<Object>>) (JavaDStream<?>)
                   TranslationUtils.dStreamValues(filtered);
-          context.putDataset(e.getValue(),
-              new UnboundedDataset<>(values, unboundedDataset.getStreamSources()));
+          context.putDataset(
+              e.getValue(), new UnboundedDataset<>(values, unboundedDataset.getStreamSources()));
         }
       }
+
+      private void evaluateSingle(
+          final ParDo.BoundMulti<InputT, OutputT> transform, final EvaluationContext context) {
+        final DoFn<InputT, OutputT> doFn = transform.getFn();
+        rejectSplittable(doFn);
+        rejectStateAndTimers(doFn);
+        final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
+        final WindowingStrategy<?, ?> windowingStrategy =
+            context.getInput(transform).getWindowingStrategy();
+        final SparkPCollectionView pviews = context.getPViews();
+
+        @SuppressWarnings("unchecked")
+        UnboundedDataset<InputT> unboundedDataset =
+            ((UnboundedDataset<InputT>) context.borrowDataset(transform));
+        JavaDStream<WindowedValue<InputT>> dStream = unboundedDataset.getDStream();
+
+        final String stepName = context.getCurrentTransform().getFullName();
+
+        JavaDStream<WindowedValue<OutputT>> outStream =
+            dStream.transform(
+                new Function<JavaRDD<WindowedValue<InputT>>, JavaRDD<WindowedValue<OutputT>>>() {
+                  @Override
+                  public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd)
+                      throws Exception {
+                    final JavaSparkContext jsc = new JavaSparkContext(rdd.context());
+                    final Accumulator<NamedAggregators> aggAccum =
+                        SparkAggregators.getNamedAggregators(jsc);
+                    final Accumulator<SparkMetricsContainer> metricsAccum =
+                        MetricsAccumulator.getInstance();
+                    final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
+                        sideInputs =
+                            TranslationUtils.getSideInputs(transform.getSideInputs(), jsc, pviews);
+                    return rdd.mapPartitions(
+                        new DoFnFunction<>(
+                            aggAccum,
+                            metricsAccum,
+                            stepName,
+                            doFn,
+                            runtimeContext,
+                            sideInputs,
+                            windowingStrategy));
+                  }
+                });
+
+        PCollection<OutputT> output =
+            (PCollection<OutputT>)
+                Iterables.getOnlyElement(context.getOutputs(transform)).getValue();
+        context.putDataset(
+            output, new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources()));
+      }
     };
   }
@@ -440,7 +452,6 @@ final class StreamingTransformTranslator {
     EVALUATORS.put(Read.Unbounded.class, readUnbounded());
     EVALUATORS.put(GroupByKey.class, groupByKey());
     EVALUATORS.put(Combine.GroupedValues.class, combineGrouped());
-    EVALUATORS.put(ParDo.Bound.class, parDo());
    EVALUATORS.put(ParDo.BoundMulti.class, multiDo());
     EVALUATORS.put(ConsoleIO.Write.Unbound.class, print());
     EVALUATORS.put(CreateStream.class, createFromQueue());
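For orientation, a minimal sketch of the two pipeline shapes the reworked multiDo() evaluator now distinguishes. MyDoFn, MAIN, and SIDE are illustrative names, not part of this change:

    TupleTag<String> MAIN = new TupleTag<String>() {};
    TupleTag<String> SIDE = new TupleTag<String>() {};

    // No side-output tags: with the ParDo.java change further down, ParDo.of(fn)
    // expands into a ParDo.BoundMulti whose side-output TupleTagList is empty, so
    // getSideOutputTags().size() == 0 and evaluateSingle() takes over.
    PCollection<String> single = input.apply(ParDo.of(new MyDoFn()));

    // Explicit side-output tags: getSideOutputTags() is non-empty, so
    // evaluateMulti() filters the tagged stream per output, as before.
    PCollectionTuple multi =
        input.apply(ParDo.of(new MyDoFn()).withOutputTags(MAIN, TupleTagList.of(SIDE)));
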
http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
index b181a04..d66633b 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
@@ -83,7 +83,7 @@ public class TrackStreamingSourcesTest {
 
     p.apply(emptyStream).apply(ParDo.of(new PassthroughFn<>()));
 
-    p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.Bound.class, 0));
+    p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.BoundMulti.class, 0));
     assertThat(StreamingSourceTracker.numAssertions, equalTo(1));
   }
@@ -111,7 +111,7 @@ public class TrackStreamingSourcesTest {
         PCollectionList.of(pcol1).and(pcol2).apply(Flatten.<Integer>pCollections());
     flattened.apply(ParDo.of(new PassthroughFn<>()));
 
-    p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.Bound.class, 0, 1));
+    p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.BoundMulti.class, 0, 1));
     assertThat(StreamingSourceTracker.numAssertions, equalTo(1));
   }

http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
index 19c5a2d..9225231 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
@@ -738,12 +738,8 @@ public class ParDo {
 
     @Override
     public PCollection<OutputT> expand(PCollection<? extends InputT> input) {
-      validateWindowType(input, fn);
-      return PCollection.<OutputT>createPrimitiveOutputInternal(
-          input.getPipeline(),
-          input.getWindowingStrategy(),
-          input.isBounded())
-          .setTypeDescriptor(getFn().getOutputTypeDescriptor());
+      TupleTag<OutputT> mainOutput = new TupleTag<>();
+      return input.apply(withOutputTags(mainOutput, TupleTagList.empty())).get(mainOutput);
     }
 
     @Override
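The ParDo.java change above is what makes dropping the dedicated ParDo.Bound evaluator safe: a single-output ParDo is now sugar over the multi-output path. A hedged sketch of the equivalence (boundParDo stands in for any ParDo.Bound instance; the name is illustrative):

    // Applying the ParDo.BoundMulti produced by withOutputTags yields a
    // PCollectionTuple; get(mainOutput) then selects the main output, which is
    // exactly what the new expand() returns.
    TupleTag<OutputT> mainOutput = new TupleTag<>();
    PCollectionTuple all =
        input.apply(boundParDo.withOutputTags(mainOutput, TupleTagList.empty()));
    PCollection<OutputT> result = all.get(mainOutput);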
