http://git-wip-us.apache.org/repos/asf/beam/blob/7b062d71/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index 3e941e4..fa5ae95 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -18,6 +18,7 @@
 package org.apache.beam.runners.spark.translation;
 
+import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkState;
 import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputDirectory;
 import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputFilePrefix;
@@ -28,6 +29,7 @@ import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers;
 import com.google.common.collect.Maps;
 import java.io.IOException;
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 import org.apache.avro.mapred.AvroKey;
 import org.apache.avro.mapreduce.AvroJob;
@@ -63,9 +65,8 @@ import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.WindowingStrategy;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PCollectionList;
-import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TaggedPValue;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.NullWritable;
@@ -94,14 +95,19 @@ public final class TransformTranslator {
     @SuppressWarnings("unchecked")
     @Override
     public void evaluate(Flatten.FlattenPCollectionList<T> transform, EvaluationContext context) {
-      PCollectionList<T> pcs = context.getInput(transform);
+      List<TaggedPValue> pcs = context.getInputs(transform);
       JavaRDD<WindowedValue<T>> unionRDD;
       if (pcs.size() == 0) {
         unionRDD = context.getSparkContext().emptyRDD();
       } else {
         JavaRDD<WindowedValue<T>>[] rdds = new JavaRDD[pcs.size()];
         for (int i = 0; i < rdds.length; i++) {
-          rdds[i] = ((BoundedDataset<T>) context.borrowDataset(pcs.get(i))).getRDD();
+          checkArgument(
+              pcs.get(i).getValue() instanceof PCollection,
+              "Flatten had non-PCollection value in input: %s of type %s",
+              pcs.get(i).getValue(),
+              pcs.get(i).getValue().getClass().getSimpleName());
+          rdds[i] = ((BoundedDataset<T>) context.borrowDataset(pcs.get(i).getValue())).getRDD();
         }
         unionRDD = context.getSparkContext().union(rdds);
       }
@@ -124,9 +130,15 @@ public final class TransformTranslator {
         final Accumulator<NamedAggregators> accum =
            SparkAggregators.getNamedAggregators(context.getSparkContext());
 
-        context.putDataset(transform,
-            new BoundedDataset<>(GroupCombineFunctions.groupByKey(inRDD, accum, coder,
-                context.getRuntimeContext(), context.getInput(transform).getWindowingStrategy())));
+        context.putDataset(
+            transform,
+            new BoundedDataset<>(
+                GroupCombineFunctions.groupByKey(
+                    inRDD,
+                    accum,
+                    coder,
+                    context.getRuntimeContext(),
+                    context.getInput(transform).getWindowingStrategy())));
       }
     };
   }
@@ -265,11 +277,11 @@ public final class TransformTranslator {
               new MultiDoFnFunction<>(accum, doFn, context.getRuntimeContext(),
                   transform.getMainOutputTag(), TranslationUtils.getSideInputs(
                       transform.getSideInputs(), context), windowingStrategy)).cache();
-      PCollectionTuple pct = context.getOutput(transform);
-      for (Map.Entry<TupleTag<?>, PCollection<?>> e : pct.getAll().entrySet()) {
+      List<TaggedPValue> pct = context.getOutputs(transform);
+      for (TaggedPValue e : pct) {
         @SuppressWarnings("unchecked")
         JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered =
-            all.filter(new TranslationUtils.TupleTagFilter(e.getKey()));
+            all.filter(new TranslationUtils.TupleTagFilter(e.getTag()));
         @SuppressWarnings("unchecked")
         // Object is the best we can do since different outputs can have different tags
         JavaRDD<WindowedValue<Object>> values =
@@ -529,7 +541,7 @@ public final class TransformTranslator {
       @Override
      public void evaluate(View.AsSingleton<T> transform, EvaluationContext context) {
         Iterable<? extends WindowedValue<?>> iter =
-                context.getWindowedValues(context.getInput(transform));
+            context.getWindowedValues(context.getInput(transform));
         PCollectionView<T> output = context.getOutput(transform);
         Coder<Iterable<WindowedValue<?>>> coderInternal = output.getCoderInternal();
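Reviewer note: the batch-side changes above all follow the same migration -- the typed PCollectionList/PCollectionTuple accessors are replaced by the flat List<TaggedPValue> returned from EvaluationContext.getInputs/getOutputs, where each element pairs a TupleTag with its PValue. A minimal sketch of the new iteration pattern (getOutputs, getTag, and getValue are the calls used in the diff above; the loop body is illustrative only, not part of the commit):

    List<TaggedPValue> outputs = context.getOutputs(transform);
    for (TaggedPValue tagged : outputs) {
      TupleTag<?> tag = tagged.getTag();   // the tag naming this output
      PValue value = tagged.getValue();    // the output itself, e.g. a PCollection
      // ... translate or register the Dataset backing this tagged output ...
    }
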
http://git-wip-us.apache.org/repos/asf/beam/blob/7b062d71/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 3c89b99..a2a1d3b 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.runners.spark.translation.streaming;
 
+import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkState;
 import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers;
 
@@ -64,8 +65,7 @@ import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.WindowingStrategy;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PCollectionList;
-import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.TaggedPValue;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.spark.Accumulator;
 import org.apache.spark.api.java.JavaPairRDD;
@@ -125,14 +125,20 @@ final class StreamingTransformTranslator {
     @SuppressWarnings("unchecked")
     @Override
     public void evaluate(Flatten.FlattenPCollectionList<T> transform, EvaluationContext context) {
-      PCollectionList<T> pcs = context.getInput(transform);
+      List<TaggedPValue> pcs = context.getInputs(transform);
       // since this is a streaming pipeline, at least one of the PCollections to "flatten" are
       // unbounded, meaning it represents a DStream.
       // So we could end up with an unbounded unified DStream.
       final List<JavaRDD<WindowedValue<T>>> rdds = new ArrayList<>();
       final List<JavaDStream<WindowedValue<T>>> dStreams = new ArrayList<>();
-      for (PCollection<T> pcol : pcs.getAll()) {
-        Dataset dataset = context.borrowDataset(pcol);
+      for (TaggedPValue pv : pcs) {
+        checkArgument(
+            pv.getValue() instanceof PCollection,
+            "Flatten had non-PCollection value in input: %s of type %s",
+            pv.getValue(),
+            pv.getValue().getClass().getSimpleName());
+        PCollection<T> pcol = (PCollection<T>) pv.getValue();
+        Dataset dataset = context.borrowDataset(pcol);
         if (dataset instanceof UnboundedDataset) {
           dStreams.add(((UnboundedDataset<T>) dataset).getDStream());
         } else {
@@ -144,14 +150,15 @@ final class StreamingTransformTranslator {
           context.getStreamingContext().union(dStreams.remove(0), dStreams);
       // now unify in RDDs.
       if (rdds.size() > 0) {
-        JavaDStream<WindowedValue<T>> joined = unifiedStreams.transform(
-            new Function<JavaRDD<WindowedValue<T>>, JavaRDD<WindowedValue<T>>>() {
-              @Override
-              public JavaRDD<WindowedValue<T>> call(JavaRDD<WindowedValue<T>> streamRdd)
-                  throws Exception {
-                return new JavaSparkContext(streamRdd.context()).union(streamRdd, rdds);
-              }
-            });
+        JavaDStream<WindowedValue<T>> joined =
+            unifiedStreams.transform(
+                new Function<JavaRDD<WindowedValue<T>>, JavaRDD<WindowedValue<T>>>() {
+                  @Override
+                  public JavaRDD<WindowedValue<T>> call(JavaRDD<WindowedValue<T>> streamRdd)
+                      throws Exception {
+                    return new JavaSparkContext(streamRdd.context()).union(streamRdd, rdds);
+                  }
+                });
         context.putDataset(transform, new UnboundedDataset<>(joined));
       } else {
         context.putDataset(transform, new UnboundedDataset<>(unifiedStreams));
@@ -284,8 +291,9 @@ final class StreamingTransformTranslator {
 
     @SuppressWarnings("unchecked")
     @Override
-    public void evaluate(final Combine.Globally<InputT, OutputT> transform,
-        EvaluationContext context) {
+    public void evaluate(
+        final Combine.Globally<InputT, OutputT> transform,
+        EvaluationContext context) {
       final PCollection<InputT> input = context.getInput(transform);
       // serializable arguments to pass.
       final Coder<InputT> iCoder = context.getInput(transform).getCoder();
@@ -372,7 +380,6 @@ final class StreamingTransformTranslator {
       final WindowingStrategy<?, ?> windowingStrategy =
           context.getInput(transform).getWindowingStrategy();
       final SparkPCollectionView pviews = context.getPViews();
-
       JavaDStream<WindowedValue<InputT>> dStream =
           ((UnboundedDataset<InputT>) context.borrowDataset(transform)).getDStream();
 
@@ -431,11 +438,11 @@ final class StreamingTransformTranslator {
                     runtimeContext, transform.getMainOutputTag(), sideInputs, windowingStrategy));
           }
         }).cache();
-      PCollectionTuple pct = context.getOutput(transform);
-      for (Map.Entry<TupleTag<?>, PCollection<?>> e : pct.getAll().entrySet()) {
+      List<TaggedPValue> pct = context.getOutputs(transform);
+      for (TaggedPValue e : pct) {
         @SuppressWarnings("unchecked")
         JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered =
-            all.filter(new TranslationUtils.TupleTagFilter(e.getKey()));
+            all.filter(new TranslationUtils.TupleTagFilter(e.getTag()));
         @SuppressWarnings("unchecked")
         // Object is the best we can do since different outputs can have different tags
         JavaDStream<WindowedValue<Object>> values =
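Reviewer note: both Flatten translators now fail fast on non-PCollection inputs via Guava's checkArgument, which throws IllegalArgumentException with the templated message when the condition is false; the %s placeholders are substituted only on failure, so the passing path pays no formatting cost. A standalone sketch of the guard (validateFlattenInput is a hypothetical helper name, not part of the commit):

    import static com.google.common.base.Preconditions.checkArgument;

    void validateFlattenInput(TaggedPValue pv) {
      // Rejects any input whose value is not a PCollection, mirroring the
      // guard added to both Flatten translators above.
      checkArgument(
          pv.getValue() instanceof PCollection,
          "Flatten had non-PCollection value in input: %s of type %s",
          pv.getValue(),
          pv.getValue().getClass().getSimpleName());
    }
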
http://git-wip-us.apache.org/repos/asf/beam/blob/7b062d71/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
index 77de54a..a6d8859 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
@@ -18,8 +18,11 @@
 package org.apache.beam.sdk.transforms;
 
 import com.google.auto.value.AutoValue;
+import java.util.List;
+import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
+import org.apache.beam.sdk.values.TaggedPValue;
 
 /**
  * Represents the application of a {@link PTransform} to a specific input to produce
@@ -43,14 +46,16 @@ public abstract class AppliedPTransform
       AppliedPTransform<InputT, OutputT, TransformT> of(
           String fullName, InputT input, OutputT output, TransformT transform) {
     return new AutoValue_AppliedPTransform<InputT, OutputT, TransformT>(
-        fullName, input, output, transform);
+        fullName, input.expand(), output.expand(), transform, input.getPipeline());
   }
 
   public abstract String getFullName();
 
-  public abstract InputT getInput();
+  public abstract List<TaggedPValue> getInputs();
 
-  public abstract OutputT getOutput();
+  public abstract List<TaggedPValue> getOutputs();
 
   public abstract TransformT getTransform();
+
+  public abstract Pipeline getPipeline();
 }
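Reviewer note: with this change AppliedPTransform no longer stores the typed InputT/OutputT values -- of() immediately expands both into List<TaggedPValue> and records the owning Pipeline (taken from the input) so it stays reachable once the typed views are gone. Callers that previously used getInput()/getOutput() must switch to the tagged lists, as the Spark translator changes above do. A sketch of the resulting call pattern (MyTransform and the variable names are illustrative; the accessors are from the diff above):

    AppliedPTransform<PCollection<String>, PCollection<Integer>, MyTransform> applied =
        AppliedPTransform.of("MyStep", input, output, transform);
    List<TaggedPValue> inputs = applied.getInputs();    // == input.expand()
    List<TaggedPValue> outputs = applied.getOutputs();  // == output.expand()
    Pipeline pipeline = applied.getPipeline();          // == input.getPipeline()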
