Repository: beam
Updated Branches:
  refs/heads/master 7e9233bbd -> bb8cd72b9


http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index e3445bf..628b713 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -331,58 +331,20 @@ final class StreamingTransformTranslator {
     };
   }
 
-  private static <InputT, OutputT> TransformEvaluator<ParDo.Bound<InputT, OutputT>> parDo() {
-    return new TransformEvaluator<ParDo.Bound<InputT, OutputT>>() {
-      @Override
-      public void evaluate(final ParDo.Bound<InputT, OutputT> transform,
-                           final EvaluationContext context) {
-        final DoFn<InputT, OutputT> doFn = transform.getFn();
-        rejectSplittable(doFn);
-        rejectStateAndTimers(doFn);
-        final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
-        final WindowingStrategy<?, ?> windowingStrategy =
-            context.getInput(transform).getWindowingStrategy();
-        final SparkPCollectionView pviews = context.getPViews();
-
-        @SuppressWarnings("unchecked")
-        UnboundedDataset<InputT> unboundedDataset =
-            ((UnboundedDataset<InputT>) context.borrowDataset(transform));
-        JavaDStream<WindowedValue<InputT>> dStream = unboundedDataset.getDStream();
-
-        final String stepName = context.getCurrentTransform().getFullName();
-
-        JavaDStream<WindowedValue<OutputT>> outStream =
-            dStream.transform(new Function<JavaRDD<WindowedValue<InputT>>,
-                JavaRDD<WindowedValue<OutputT>>>() {
-          @Override
-          public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd) throws
-              Exception {
-            final JavaSparkContext jsc = new JavaSparkContext(rdd.context());
-            final Accumulator<NamedAggregators> aggAccum =
-                SparkAggregators.getNamedAggregators(jsc);
-            final Accumulator<SparkMetricsContainer> metricsAccum =
-                MetricsAccumulator.getInstance();
-            final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
-                TranslationUtils.getSideInputs(transform.getSideInputs(),
-                    jsc, pviews);
-            return rdd.mapPartitions(
-                new DoFnFunction<>(aggAccum, metricsAccum, stepName, doFn, runtimeContext,
-                    sideInputs, windowingStrategy));
-          }
-        });
-
-        context.putDataset(transform,
-            new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources()));
-      }
-    };
-  }
-
   private static <InputT, OutputT> TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>>
   multiDo() {
     return new TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>>() {
-      @Override
-      public void evaluate(final ParDo.BoundMulti<InputT, OutputT> transform,
-                           final EvaluationContext context) {
+      public void evaluate(
+          final ParDo.BoundMulti<InputT, OutputT> transform, final EvaluationContext context) {
+        if (transform.getSideOutputTags().size() == 0) {
+          evaluateSingle(transform, context);
+        } else {
+          evaluateMulti(transform, context);
+        }
+      }
+
+      private void evaluateMulti(
+          final ParDo.BoundMulti<InputT, OutputT> transform, final EvaluationContext context) {
         final DoFn<InputT, OutputT> doFn = transform.getFn();
         rejectSplittable(doFn);
         rejectStateAndTimers(doFn);
@@ -426,10 +388,60 @@ final class StreamingTransformTranslator {
           JavaDStream<WindowedValue<Object>> values =
               (JavaDStream<WindowedValue<Object>>)
                   (JavaDStream<?>) TranslationUtils.dStreamValues(filtered);
-          context.putDataset(e.getValue(),
-              new UnboundedDataset<>(values, unboundedDataset.getStreamSources()));
+          context.putDataset(
+              e.getValue(), new UnboundedDataset<>(values, unboundedDataset.getStreamSources()));
         }
       }
+
+      private void evaluateSingle(
+          final ParDo.BoundMulti<InputT, OutputT> transform, final EvaluationContext context) {
+        final DoFn<InputT, OutputT> doFn = transform.getFn();
+        rejectSplittable(doFn);
+        rejectStateAndTimers(doFn);
+        final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
+        final WindowingStrategy<?, ?> windowingStrategy =
+            context.getInput(transform).getWindowingStrategy();
+        final SparkPCollectionView pviews = context.getPViews();
+
+        @SuppressWarnings("unchecked")
+        UnboundedDataset<InputT> unboundedDataset =
+            ((UnboundedDataset<InputT>) context.borrowDataset(transform));
+        JavaDStream<WindowedValue<InputT>> dStream = unboundedDataset.getDStream();
+
+        final String stepName = context.getCurrentTransform().getFullName();
+
+        JavaDStream<WindowedValue<OutputT>> outStream =
+            dStream.transform(
+                new Function<JavaRDD<WindowedValue<InputT>>, JavaRDD<WindowedValue<OutputT>>>() {
+                  @Override
+                  public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd)
+                      throws Exception {
+                    final JavaSparkContext jsc = new JavaSparkContext(rdd.context());
+                    final Accumulator<NamedAggregators> aggAccum =
+                        SparkAggregators.getNamedAggregators(jsc);
+                    final Accumulator<SparkMetricsContainer> metricsAccum =
+                        MetricsAccumulator.getInstance();
+                    final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
+                        sideInputs =
+                            TranslationUtils.getSideInputs(transform.getSideInputs(), jsc, pviews);
+                    return rdd.mapPartitions(
+                        new DoFnFunction<>(
+                            aggAccum,
+                            metricsAccum,
+                            stepName,
+                            doFn,
+                            runtimeContext,
+                            sideInputs,
+                            windowingStrategy));
+                  }
+                });
+
+        PCollection<OutputT> output =
+            (PCollection<OutputT>)
+                Iterables.getOnlyElement(context.getOutputs(transform)).getValue();
+        context.putDataset(
+            output, new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources()));
+      }
     };
   }
 
@@ -440,7 +452,6 @@ final class StreamingTransformTranslator {
     EVALUATORS.put(Read.Unbounded.class, readUnbounded());
     EVALUATORS.put(GroupByKey.class, groupByKey());
     EVALUATORS.put(Combine.GroupedValues.class, combineGrouped());
-    EVALUATORS.put(ParDo.Bound.class, parDo());
     EVALUATORS.put(ParDo.BoundMulti.class, multiDo());
     EVALUATORS.put(ConsoleIO.Write.Unbound.class, print());
     EVALUATORS.put(CreateStream.class, createFromQueue());

http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
index b181a04..d66633b 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
@@ -83,7 +83,7 @@ public class TrackStreamingSourcesTest {
 
     p.apply(emptyStream).apply(ParDo.of(new PassthroughFn<>()));
 
-    p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.Bound.class,  0));
+    p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.BoundMulti.class,  0));
     assertThat(StreamingSourceTracker.numAssertions, equalTo(1));
   }
 
@@ -111,7 +111,7 @@ public class TrackStreamingSourcesTest {
         PCollectionList.of(pcol1).and(pcol2).apply(Flatten.<Integer>pCollections());
     flattened.apply(ParDo.of(new PassthroughFn<>()));
 
-    p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.Bound.class, 0, 1));
+    p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.BoundMulti.class, 0, 1));
     assertThat(StreamingSourceTracker.numAssertions, equalTo(1));
   }
 

http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
index 19c5a2d..9225231 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
@@ -738,12 +738,8 @@ public class ParDo {
 
     @Override
     public PCollection<OutputT> expand(PCollection<? extends InputT> input) {
-      validateWindowType(input, fn);
-      return PCollection.<OutputT>createPrimitiveOutputInternal(
-              input.getPipeline(),
-              input.getWindowingStrategy(),
-              input.isBounded())
-          .setTypeDescriptor(getFn().getOutputTypeDescriptor());
+      TupleTag<OutputT> mainOutput = new TupleTag<>();
+      return input.apply(withOutputTags(mainOutput, TupleTagList.empty())).get(mainOutput);
     }
 
     @Override

Reply via email to