http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java index 8341c6d..1a0511f 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java @@ -19,39 +19,32 @@ package org.apache.beam.runners.spark.translation; +import static com.google.common.base.Preconditions.checkState; import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputDirectory; import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputFilePrefix; import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputFileTemplate; import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.replaceShardCount; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Lists; import com.google.common.collect.Maps; import java.io.IOException; -import java.io.Serializable; -import java.lang.reflect.Field; -import java.util.Arrays; import java.util.Collections; -import java.util.List; import java.util.Map; import org.apache.avro.mapred.AvroKey; import org.apache.avro.mapreduce.AvroJob; import org.apache.avro.mapreduce.AvroKeyInputFormat; import org.apache.beam.runners.core.AssignWindowsDoFn; -import org.apache.beam.runners.core.GroupAlsoByWindowsViaOutputBufferDoFn; import org.apache.beam.runners.core.GroupByKeyViaGroupByKeyOnly.GroupAlsoByWindow; import org.apache.beam.runners.core.GroupByKeyViaGroupByKeyOnly.GroupByKeyOnly; -import org.apache.beam.runners.core.SystemReduceFn; +import org.apache.beam.runners.spark.aggregators.AccumulatorSingleton; +import org.apache.beam.runners.spark.aggregators.NamedAggregators; import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.runners.spark.io.hadoop.HadoopIO; import org.apache.beam.runners.spark.io.hadoop.ShardNameTemplateHelper; import org.apache.beam.runners.spark.io.hadoop.TemplatedAvroKeyOutputFormat; import org.apache.beam.runners.spark.io.hadoop.TemplatedTextOutputFormat; import org.apache.beam.runners.spark.util.BroadcastHelper; -import org.apache.beam.runners.spark.util.ByteArray; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.io.AvroIO; import org.apache.beam.sdk.io.TextIO; @@ -63,36 +56,30 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.GlobalWindows; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder; -import org.apache.beam.sdk.util.WindowingStrategy; -import org.apache.beam.sdk.util.state.InMemoryStateInternals; -import org.apache.beam.sdk.util.state.StateInternals; -import org.apache.beam.sdk.util.state.StateInternalsFactory; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionList; import org.apache.beam.sdk.values.PCollectionTuple; -import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.io.NullWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.Job; import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.spark.Accumulator; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDDLike; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.api.java.function.PairFunction; + import scala.Tuple2; + /** * Supports translation between a Beam transform, and Spark's operations on RDDs. */ @@ -101,31 +88,6 @@ public final class TransformTranslator { private TransformTranslator() { } - /** - * Getter of the field. - */ - public static class FieldGetter { - private final Map<String, Field> fields; - - public FieldGetter(Class<?> clazz) { - this.fields = Maps.newHashMap(); - for (Field f : clazz.getDeclaredFields()) { - f.setAccessible(true); - this.fields.put(f.getName(), f); - } - } - - public <T> T get(String fieldname, Object value) { - try { - @SuppressWarnings("unchecked") - T fieldValue = (T) fields.get(fieldname).get(value); - return fieldValue; - } catch (IllegalAccessException e) { - throw new IllegalStateException(e); - } - } - } - private static <T> TransformEvaluator<Flatten.FlattenPCollectionList<T>> flattenPColl() { return new TransformEvaluator<Flatten.FlattenPCollectionList<T>>() { @SuppressWarnings("unchecked") @@ -142,28 +104,18 @@ public final class TransformTranslator { }; } - private static <K, V> TransformEvaluator<GroupByKeyOnly<K, V>> gbk() { + private static <K, V> TransformEvaluator<GroupByKeyOnly<K, V>> gbko() { return new TransformEvaluator<GroupByKeyOnly<K, V>>() { @Override public void evaluate(GroupByKeyOnly<K, V> transform, EvaluationContext context) { @SuppressWarnings("unchecked") - JavaRDDLike<WindowedValue<KV<K, V>>, ?> inRDD = - (JavaRDDLike<WindowedValue<KV<K, V>>, ?>) context.getInputRDD(transform); + JavaRDD<WindowedValue<KV<K, V>>> inRDD = + (JavaRDD<WindowedValue<KV<K, V>>>) context.getInputRDD(transform); + @SuppressWarnings("unchecked") - KvCoder<K, V> coder = (KvCoder<K, V>) context.getInput(transform).getCoder(); - Coder<K> keyCoder = coder.getKeyCoder(); - Coder<V> valueCoder = coder.getValueCoder(); - - // Use coders to convert objects in the PCollection to byte arrays, so they - // can be transferred over the network for the shuffle. - JavaRDDLike<WindowedValue<KV<K, Iterable<V>>>, ?> outRDD = fromPair( - toPair(inRDD.map(WindowingHelpers.<KV<K, V>>unwindowFunction())) - .mapToPair(CoderHelpers.toByteFunction(keyCoder, valueCoder)) - .groupByKey() - .mapToPair(CoderHelpers.fromByteFunctionIterable(keyCoder, valueCoder))) - // empty windows are OK here, see GroupByKey#evaluateHelper in the SDK - .map(WindowingHelpers.<KV<K, Iterable<V>>>windowFunction()); - context.setOutputRDD(transform, outRDD); + final KvCoder<K, V> coder = (KvCoder<K, V>) context.getInput(transform).getCoder(); + + context.setOutputRDD(transform, GroupCombineFunctions.groupByKeyOnly(inRDD, coder)); } }; } @@ -174,81 +126,52 @@ public final class TransformTranslator { @Override public void evaluate(GroupAlsoByWindow<K, V> transform, EvaluationContext context) { @SuppressWarnings("unchecked") - JavaRDDLike<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>, ?> inRDD = - (JavaRDDLike<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>, ?>) + JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> inRDD = + (JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>>) context.getInputRDD(transform); - Coder<KV<K, Iterable<WindowedValue<V>>>> inputCoder = - context.getInput(transform).getCoder(); - Coder<K> keyCoder = transform.getKeyCoder(inputCoder); - Coder<V> valueCoder = transform.getValueCoder(inputCoder); - @SuppressWarnings("unchecked") - KvCoder<K, Iterable<WindowedValue<V>>> inputKvCoder = + final KvCoder<K, Iterable<WindowedValue<V>>> inputKvCoder = (KvCoder<K, Iterable<WindowedValue<V>>>) context.getInput(transform).getCoder(); - Coder<Iterable<WindowedValue<V>>> inputValueCoder = inputKvCoder.getValueCoder(); - - IterableCoder<WindowedValue<V>> inputIterableValueCoder = - (IterableCoder<WindowedValue<V>>) inputValueCoder; - Coder<WindowedValue<V>> inputIterableElementCoder = inputIterableValueCoder.getElemCoder(); - WindowedValueCoder<V> inputIterableWindowedValueCoder = - (WindowedValueCoder<V>) inputIterableElementCoder; - Coder<V> inputIterableElementValueCoder = inputIterableWindowedValueCoder.getValueCoder(); + final Accumulator<NamedAggregators> accum = + AccumulatorSingleton.getInstance(context.getSparkContext()); - @SuppressWarnings("unchecked") - WindowingStrategy<?, W> windowingStrategy = - (WindowingStrategy<?, W>) transform.getWindowingStrategy(); - - OldDoFn<KV<K, Iterable<WindowedValue<V>>>, KV<K, Iterable<V>>> gabwDoFn = - new GroupAlsoByWindowsViaOutputBufferDoFn<K, V, Iterable<V>, W>( - windowingStrategy, - new InMemoryStateInternalsFactory<K>(), - SystemReduceFn.<K, V, W>buffering(inputIterableElementValueCoder)); - - // GroupAlsoByWindow current uses a dummy in-memory StateInternals - JavaRDDLike<WindowedValue<KV<K, Iterable<V>>>, ?> outRDD = - inRDD.mapPartitions( - new DoFnFunction<KV<K, Iterable<WindowedValue<V>>>, KV<K, Iterable<V>>>( - gabwDoFn, context.getRuntimeContext(), null)); - - context.setOutputRDD(transform, outRDD); + context.setOutputRDD(transform, GroupCombineFunctions.groupAlsoByWindow(inRDD, transform, + context.getRuntimeContext(), accum, inputKvCoder)); } }; } - private static final FieldGetter GROUPED_FG = new FieldGetter(Combine.GroupedValues.class); - private static <K, InputT, OutputT> TransformEvaluator<Combine.GroupedValues<K, InputT, OutputT>> grouped() { return new TransformEvaluator<Combine.GroupedValues<K, InputT, OutputT>>() { @Override public void evaluate(Combine.GroupedValues<K, InputT, OutputT> transform, EvaluationContext context) { - Combine.KeyedCombineFn<K, InputT, ?, OutputT> keyed = GROUPED_FG.get("fn", transform); @SuppressWarnings("unchecked") JavaRDDLike<WindowedValue<KV<K, Iterable<InputT>>>, ?> inRDD = - (JavaRDDLike<WindowedValue<KV<K, Iterable<InputT>>>, ?>) context.getInputRDD(transform); - context.setOutputRDD(transform, - inRDD.map(new KVFunction<>(keyed))); + (JavaRDDLike<WindowedValue<KV<K, Iterable<InputT>>>, ?>) + context.getInputRDD(transform); + context.setOutputRDD(transform, inRDD.map( + new TranslationUtils.CombineGroupedValues<>(transform))); } }; } - private static final FieldGetter COMBINE_GLOBALLY_FG = new FieldGetter(Combine.Globally.class); - private static <InputT, AccumT, OutputT> TransformEvaluator<Combine.Globally<InputT, OutputT>> combineGlobally() { return new TransformEvaluator<Combine.Globally<InputT, OutputT>>() { @Override public void evaluate(Combine.Globally<InputT, OutputT> transform, EvaluationContext context) { - final Combine.CombineFn<InputT, AccumT, OutputT> globally = - COMBINE_GLOBALLY_FG.get("fn", transform); + @SuppressWarnings("unchecked") + JavaRDD<WindowedValue<InputT>> inRdd = + (JavaRDD<WindowedValue<InputT>>) context.getInputRDD(transform); @SuppressWarnings("unchecked") - JavaRDDLike<WindowedValue<InputT>, ?> inRdd = - (JavaRDDLike<WindowedValue<InputT>, ?>) context.getInputRDD(transform); + final Combine.CombineFn<InputT, AccumT, OutputT> globally = + (Combine.CombineFn<InputT, AccumT, OutputT>) transform.getFn(); final Coder<InputT> iCoder = context.getInput(transform).getCoder(); final Coder<AccumT> aCoder; @@ -259,61 +182,26 @@ public final class TransformTranslator { throw new IllegalStateException("Could not determine coder for accumulator", e); } - // Use coders to convert objects in the PCollection to byte arrays, so they - // can be transferred over the network for the shuffle. - JavaRDD<byte[]> inRddBytes = inRdd - .map(WindowingHelpers.<InputT>unwindowFunction()) - .map(CoderHelpers.toByteFunction(iCoder)); - - /*AccumT*/ byte[] acc = inRddBytes.aggregate( - CoderHelpers.toByteArray(globally.createAccumulator(), aCoder), - new Function2</*AccumT*/ byte[], /*InputT*/ byte[], /*AccumT*/ byte[]>() { - @Override - public /*AccumT*/ byte[] call(/*AccumT*/ byte[] ab, /*InputT*/ byte[] ib) - throws Exception { - AccumT a = CoderHelpers.fromByteArray(ab, aCoder); - InputT i = CoderHelpers.fromByteArray(ib, iCoder); - return CoderHelpers.toByteArray(globally.addInput(a, i), aCoder); - } - }, - new Function2</*AccumT*/ byte[], /*AccumT*/ byte[], /*AccumT*/ byte[]>() { - @Override - public /*AccumT*/ byte[] call(/*AccumT*/ byte[] a1b, /*AccumT*/ byte[] a2b) - throws Exception { - AccumT a1 = CoderHelpers.fromByteArray(a1b, aCoder); - AccumT a2 = CoderHelpers.fromByteArray(a2b, aCoder); - // don't use Guava's ImmutableList.of as values may be null - List<AccumT> accumulators = Collections.unmodifiableList(Arrays.asList(a1, a2)); - AccumT merged = globally.mergeAccumulators(accumulators); - return CoderHelpers.toByteArray(merged, aCoder); - } - } - ); - OutputT output = globally.extractOutput(CoderHelpers.fromByteArray(acc, aCoder)); - - Coder<OutputT> coder = context.getOutput(transform).getCoder(); + final Coder<OutputT> oCoder = context.getOutput(transform).getCoder(); JavaRDD<byte[]> outRdd = context.getSparkContext().parallelize( // don't use Guava's ImmutableList.of as output may be null - CoderHelpers.toByteArrays(Collections.singleton(output), coder)); - context.setOutputRDD(transform, outRdd.map(CoderHelpers.fromByteFunction(coder)) + CoderHelpers.toByteArrays(Collections.singleton( + GroupCombineFunctions.combineGlobally(inRdd, globally, iCoder, aCoder)), oCoder)); + context.setOutputRDD(transform, outRdd.map(CoderHelpers.fromByteFunction(oCoder)) .map(WindowingHelpers.<OutputT>windowFunction())); } }; } - private static final FieldGetter COMBINE_PERKEY_FG = new FieldGetter(Combine.PerKey.class); - private static <K, InputT, AccumT, OutputT> TransformEvaluator<Combine.PerKey<K, InputT, OutputT>> combinePerKey() { return new TransformEvaluator<Combine.PerKey<K, InputT, OutputT>>() { @Override - public void evaluate(Combine.PerKey<K, InputT, OutputT> - transform, EvaluationContext context) { - final Combine.KeyedCombineFn<K, InputT, AccumT, OutputT> keyed = - COMBINE_PERKEY_FG.get("fn", transform); + public void evaluate(Combine.PerKey<K, InputT, OutputT> transform, + EvaluationContext context) { @SuppressWarnings("unchecked") - JavaRDDLike<WindowedValue<KV<K, InputT>>, ?> inRdd = - (JavaRDDLike<WindowedValue<KV<K, InputT>>, ?>) context.getInputRDD(transform); + final Combine.KeyedCombineFn<K, InputT, AccumT, OutputT> keyed = + (Combine.KeyedCombineFn<K, InputT, AccumT, OutputT>) transform.getFn(); @SuppressWarnings("unchecked") KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) @@ -329,214 +217,66 @@ public final class TransformTranslator { } Coder<KV<K, InputT>> kviCoder = KvCoder.of(keyCoder, viCoder); Coder<KV<K, AccumT>> kvaCoder = KvCoder.of(keyCoder, vaCoder); - - // We need to duplicate K as both the key of the JavaPairRDD as well as inside the value, - // since the functions passed to combineByKey don't receive the associated key of each - // value, and we need to map back into methods in Combine.KeyedCombineFn, which each - // require the key in addition to the InputT's and AccumT's being merged/accumulated. - // Once Spark provides a way to include keys in the arguments of combine/merge functions, - // we won't need to duplicate the keys anymore. - - // Key has to bw windowed in order to group by window as well - JavaPairRDD<WindowedValue<K>, WindowedValue<KV<K, InputT>>> inRddDuplicatedKeyPair = - inRdd.flatMapToPair( - new PairFlatMapFunction<WindowedValue<KV<K, InputT>>, WindowedValue<K>, - WindowedValue<KV<K, InputT>>>() { - @Override - public Iterable<Tuple2<WindowedValue<K>, - WindowedValue<KV<K, InputT>>>> - call(WindowedValue<KV<K, InputT>> kv) { - List<Tuple2<WindowedValue<K>, - WindowedValue<KV<K, InputT>>>> tuple2s = - Lists.newArrayListWithCapacity(kv.getWindows().size()); - for (BoundedWindow boundedWindow: kv.getWindows()) { - WindowedValue<K> wk = WindowedValue.of(kv.getValue().getKey(), - boundedWindow.maxTimestamp(), boundedWindow, kv.getPane()); - tuple2s.add(new Tuple2<>(wk, kv)); - } - return tuple2s; - } - }); //-- windowed coders final WindowedValue.FullWindowedValueCoder<K> wkCoder = - WindowedValue.FullWindowedValueCoder.of(keyCoder, + WindowedValue.FullWindowedValueCoder.of(keyCoder, context.getInput(transform).getWindowingStrategy().getWindowFn().windowCoder()); final WindowedValue.FullWindowedValueCoder<KV<K, InputT>> wkviCoder = - WindowedValue.FullWindowedValueCoder.of(kviCoder, + WindowedValue.FullWindowedValueCoder.of(kviCoder, context.getInput(transform).getWindowingStrategy().getWindowFn().windowCoder()); final WindowedValue.FullWindowedValueCoder<KV<K, AccumT>> wkvaCoder = - WindowedValue.FullWindowedValueCoder.of(kvaCoder, + WindowedValue.FullWindowedValueCoder.of(kvaCoder, context.getInput(transform).getWindowingStrategy().getWindowFn().windowCoder()); - // Use coders to convert objects in the PCollection to byte arrays, so they - // can be transferred over the network for the shuffle. - JavaPairRDD<ByteArray, byte[]> inRddDuplicatedKeyPairBytes = inRddDuplicatedKeyPair - .mapToPair(CoderHelpers.toByteFunction(wkCoder, wkviCoder)); - - // The output of combineByKey will be "AccumT" (accumulator) - // types rather than "OutputT" (final output types) since Combine.CombineFn - // only provides ways to merge VAs, and no way to merge VOs. - JavaPairRDD</*K*/ ByteArray, /*KV<K, AccumT>*/ byte[]> accumulatedBytes = - inRddDuplicatedKeyPairBytes.combineByKey( - new Function</*KV<K, InputT>*/ byte[], /*KV<K, AccumT>*/ byte[]>() { - @Override - public /*KV<K, AccumT>*/ byte[] call(/*KV<K, InputT>*/ byte[] input) { - WindowedValue<KV<K, InputT>> wkvi = - CoderHelpers.fromByteArray(input, wkviCoder); - AccumT va = keyed.createAccumulator(wkvi.getValue().getKey()); - va = keyed.addInput(wkvi.getValue().getKey(), va, wkvi.getValue().getValue()); - WindowedValue<KV<K, AccumT>> wkva = - WindowedValue.of(KV.of(wkvi.getValue().getKey(), va), wkvi.getTimestamp(), - wkvi.getWindows(), wkvi.getPane()); - return CoderHelpers.toByteArray(wkva, wkvaCoder); - } - }, - new Function2</*KV<K, AccumT>*/ byte[], - /*KV<K, InputT>*/ byte[], - /*KV<K, AccumT>*/ byte[]>() { - @Override - public /*KV<K, AccumT>*/ byte[] call(/*KV<K, AccumT>*/ byte[] acc, - /*KV<K, InputT>*/ byte[] input) { - WindowedValue<KV<K, AccumT>> wkva = - CoderHelpers.fromByteArray(acc, wkvaCoder); - WindowedValue<KV<K, InputT>> wkvi = - CoderHelpers.fromByteArray(input, wkviCoder); - AccumT va = - keyed.addInput(wkva.getValue().getKey(), wkva.getValue().getValue(), - wkvi.getValue().getValue()); - wkva = WindowedValue.of(KV.of(wkva.getValue().getKey(), va), wkva.getTimestamp(), - wkva.getWindows(), wkva.getPane()); - return CoderHelpers.toByteArray(wkva, wkvaCoder); - } - }, - new Function2</*KV<K, AccumT>*/ byte[], - /*KV<K, AccumT>*/ byte[], - /*KV<K, AccumT>*/ byte[]>() { - @Override - public /*KV<K, AccumT>*/ byte[] call(/*KV<K, AccumT>*/ byte[] acc1, - /*KV<K, AccumT>*/ byte[] acc2) { - WindowedValue<KV<K, AccumT>> wkva1 = - CoderHelpers.fromByteArray(acc1, wkvaCoder); - WindowedValue<KV<K, AccumT>> wkva2 = - CoderHelpers.fromByteArray(acc2, wkvaCoder); - AccumT va = keyed.mergeAccumulators(wkva1.getValue().getKey(), - // don't use Guava's ImmutableList.of as values may be null - Collections.unmodifiableList(Arrays.asList(wkva1.getValue().getValue(), - wkva2.getValue().getValue()))); - WindowedValue<KV<K, AccumT>> wkva = - WindowedValue.of(KV.of(wkva1.getValue().getKey(), - va), wkva1.getTimestamp(), wkva1.getWindows(), wkva1.getPane()); - return CoderHelpers.toByteArray(wkva, wkvaCoder); - } - }); - - JavaPairRDD<WindowedValue<K>, WindowedValue<OutputT>> extracted = accumulatedBytes - .mapToPair(CoderHelpers.fromByteFunction(wkCoder, wkvaCoder)) - .mapValues( - new Function<WindowedValue<KV<K, AccumT>>, WindowedValue<OutputT>>() { - @Override - public WindowedValue<OutputT> call(WindowedValue<KV<K, AccumT>> acc) { - return WindowedValue.of(keyed.extractOutput(acc.getValue().getKey(), - acc.getValue().getValue()), acc.getTimestamp(), - acc.getWindows(), acc.getPane()); - } - }); + @SuppressWarnings("unchecked") + JavaRDD<WindowedValue<KV<K, InputT>>> inRdd = + (JavaRDD<WindowedValue<KV<K, InputT>>>) context.getInputRDD(transform); - context.setOutputRDD(transform, - fromPair(extracted) - .map(new Function<KV<WindowedValue<K>, WindowedValue<OutputT>>, - WindowedValue<KV<K, OutputT>>>() { - @Override - public WindowedValue<KV<K, OutputT>> call(KV<WindowedValue<K>, - WindowedValue<OutputT>> kwvo) - throws Exception { - WindowedValue<OutputT> wvo = kwvo.getValue(); - KV<K, OutputT> kvo = KV.of(kwvo.getKey().getValue(), wvo.getValue()); - return WindowedValue.of(kvo, wvo.getTimestamp(), wvo.getWindows(), wvo.getPane()); - } - })); + context.setOutputRDD(transform, GroupCombineFunctions.combinePerKey(inRdd, keyed, wkCoder, + wkviCoder, wkvaCoder)); } }; } - private static final class KVFunction<K, InputT, OutputT> - implements Function<WindowedValue<KV<K, Iterable<InputT>>>, - WindowedValue<KV<K, OutputT>>> { - private final Combine.KeyedCombineFn<K, InputT, ?, OutputT> keyed; - - KVFunction(Combine.KeyedCombineFn<K, InputT, ?, OutputT> keyed) { - this.keyed = keyed; - } - - @Override - public WindowedValue<KV<K, OutputT>> call(WindowedValue<KV<K, - Iterable<InputT>>> windowedKv) - throws Exception { - KV<K, Iterable<InputT>> kv = windowedKv.getValue(); - return WindowedValue.of(KV.of(kv.getKey(), keyed.apply(kv.getKey(), kv.getValue())), - windowedKv.getTimestamp(), windowedKv.getWindows(), windowedKv.getPane()); - } - } - - private static <K, V> JavaPairRDD<K, V> toPair(JavaRDDLike<KV<K, V>, ?> rdd) { - return rdd.mapToPair(new PairFunction<KV<K, V>, K, V>() { - @Override - public Tuple2<K, V> call(KV<K, V> kv) { - return new Tuple2<>(kv.getKey(), kv.getValue()); - } - }); - } - - private static <K, V> JavaRDDLike<KV<K, V>, ?> fromPair(JavaPairRDD<K, V> rdd) { - return rdd.map(new Function<Tuple2<K, V>, KV<K, V>>() { - @Override - public KV<K, V> call(Tuple2<K, V> t2) { - return KV.of(t2._1(), t2._2()); - } - }); - } - private static <InputT, OutputT> TransformEvaluator<ParDo.Bound<InputT, OutputT>> parDo() { return new TransformEvaluator<ParDo.Bound<InputT, OutputT>>() { @Override public void evaluate(ParDo.Bound<InputT, OutputT> transform, EvaluationContext context) { - DoFnFunction<InputT, OutputT> dofn = - new DoFnFunction<>(transform.getFn(), - context.getRuntimeContext(), - getSideInputs(transform.getSideInputs(), context)); @SuppressWarnings("unchecked") JavaRDDLike<WindowedValue<InputT>, ?> inRDD = (JavaRDDLike<WindowedValue<InputT>, ?>) context.getInputRDD(transform); - context.setOutputRDD(transform, inRDD.mapPartitions(dofn)); + Accumulator<NamedAggregators> accum = + AccumulatorSingleton.getInstance(context.getSparkContext()); + Map<TupleTag<?>, BroadcastHelper<?>> sideInputs = + TranslationUtils.getSideInputs(transform.getSideInputs(), context); + context.setOutputRDD(transform, + inRDD.mapPartitions(new DoFnFunction<>(accum, transform.getFn(), + context.getRuntimeContext(), sideInputs))); } }; } - private static final FieldGetter MULTIDO_FG = new FieldGetter(ParDo.BoundMulti.class); - - private static <InputT, OutputT> TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>> multiDo() { + private static <InputT, OutputT> TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>> + multiDo() { return new TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>>() { @Override public void evaluate(ParDo.BoundMulti<InputT, OutputT> transform, EvaluationContext context) { - TupleTag<OutputT> mainOutputTag = MULTIDO_FG.get("mainOutputTag", transform); - MultiDoFnFunction<InputT, OutputT> multifn = new MultiDoFnFunction<>( - transform.getFn(), - context.getRuntimeContext(), - mainOutputTag, - getSideInputs(transform.getSideInputs(), context)); - @SuppressWarnings("unchecked") JavaRDDLike<WindowedValue<InputT>, ?> inRDD = (JavaRDDLike<WindowedValue<InputT>, ?>) context.getInputRDD(transform); + Accumulator<NamedAggregators> accum = + AccumulatorSingleton.getInstance(context.getSparkContext()); JavaPairRDD<TupleTag<?>, WindowedValue<?>> all = inRDD - .mapPartitionsToPair(multifn) - .cache(); - + .mapPartitionsToPair( + new MultiDoFnFunction<>(accum, transform.getFn(), context.getRuntimeContext(), + transform.getMainOutputTag(), TranslationUtils.getSideInputs( + transform.getSideInputs(), context))) + .cache(); PCollectionTuple pct = context.getOutput(transform); for (Map.Entry<TupleTag<?>, PCollection<?>> e : pct.getAll().entrySet()) { @SuppressWarnings("unchecked") JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered = - all.filter(new TupleTagFilter(e.getKey())); + all.filter(new TranslationUtils.TupleTagFilter(e.getKey())); @SuppressWarnings("unchecked") // Object is the best we can do since different outputs can have different tags JavaRDD<WindowedValue<Object>> values = @@ -753,22 +493,17 @@ public final class TransformTranslator { JavaRDDLike<WindowedValue<T>, ?> inRDD = (JavaRDDLike<WindowedValue<T>, ?>) context.getInputRDD(transform); - @SuppressWarnings("unchecked") - WindowFn<? super T, W> windowFn = (WindowFn<? super T, W>) transform.getWindowFn(); - - // Avoid running assign windows if both source and destination are global window - // or if the user has not specified the WindowFn (meaning they are just messing - // with triggering or allowed lateness) - if (windowFn == null - || (context.getInput(transform).getWindowingStrategy().getWindowFn() - instanceof GlobalWindows - && windowFn instanceof GlobalWindows)) { + if (TranslationUtils.skipAssignWindows(transform, context)) { context.setOutputRDD(transform, inRDD); } else { + @SuppressWarnings("unchecked") + WindowFn<? super T, W> windowFn = (WindowFn<? super T, W>) transform.getWindowFn(); OldDoFn<T, T> addWindowsDoFn = new AssignWindowsDoFn<>(windowFn); - DoFnFunction<T, T> dofn = - new DoFnFunction<>(addWindowsDoFn, context.getRuntimeContext(), null); - context.setOutputRDD(transform, inRDD.mapPartitions(dofn)); + Accumulator<NamedAggregators> accum = + AccumulatorSingleton.getInstance(context.getSparkContext()); + context.setOutputRDD(transform, + inRDD.mapPartitions(new DoFnFunction<>(accum, addWindowsDoFn, + context.getRuntimeContext(), null))); } } }; @@ -822,42 +557,6 @@ public final class TransformTranslator { }; } - private static final class TupleTagFilter<V> - implements Function<Tuple2<TupleTag<V>, WindowedValue<?>>, Boolean> { - - private final TupleTag<V> tag; - - private TupleTagFilter(TupleTag<V> tag) { - this.tag = tag; - } - - @Override - public Boolean call(Tuple2<TupleTag<V>, WindowedValue<?>> input) { - return tag.equals(input._1()); - } - } - - private static Map<TupleTag<?>, BroadcastHelper<?>> getSideInputs( - List<PCollectionView<?>> views, - EvaluationContext context) { - if (views == null) { - return ImmutableMap.of(); - } else { - Map<TupleTag<?>, BroadcastHelper<?>> sideInputs = Maps.newHashMap(); - for (PCollectionView<?> view : views) { - Iterable<? extends WindowedValue<?>> collectionView = context.getPCollectionView(view); - Coder<Iterable<WindowedValue<?>>> coderInternal = view.getCoderInternal(); - @SuppressWarnings("unchecked") - BroadcastHelper<?> helper = - BroadcastHelper.create((Iterable<WindowedValue<?>>) collectionView, coderInternal); - //broadcast side inputs - helper.broadcast(context.getSparkContext()); - sideInputs.put(view.getTagInternal(), helper); - } - return sideInputs; - } - } - private static final Map<Class<? extends PTransform>, TransformEvaluator<?>> EVALUATORS = Maps .newHashMap(); @@ -870,7 +569,7 @@ public final class TransformTranslator { EVALUATORS.put(HadoopIO.Write.Bound.class, writeHadoop()); EVALUATORS.put(ParDo.Bound.class, parDo()); EVALUATORS.put(ParDo.BoundMulti.class, multiDo()); - EVALUATORS.put(GroupByKeyOnly.class, gbk()); + EVALUATORS.put(GroupByKeyOnly.class, gbko()); EVALUATORS.put(GroupAlsoByWindow.class, gabw()); EVALUATORS.put(Combine.GroupedValues.class, grouped()); EVALUATORS.put(Combine.Globally.class, combineGlobally()); @@ -883,17 +582,6 @@ public final class TransformTranslator { EVALUATORS.put(Window.Bound.class, window()); } - public static <TransformT extends PTransform<?, ?>> TransformEvaluator<TransformT> - getTransformEvaluator(Class<TransformT> clazz) { - @SuppressWarnings("unchecked") - TransformEvaluator<TransformT> transform = - (TransformEvaluator<TransformT>) EVALUATORS.get(clazz); - if (transform == null) { - throw new IllegalStateException("No TransformEvaluator registered for " + clazz); - } - return transform; - } - /** * Translator matches Beam transformation with the appropriate evaluator. */ @@ -905,17 +593,20 @@ public final class TransformTranslator { } @Override - public <TransformT extends PTransform<?, ?>> TransformEvaluator<TransformT> translate( - Class<TransformT> clazz) { - return getTransformEvaluator(clazz); + public <TransformT extends PTransform<?, ?>> TransformEvaluator<TransformT> + translateBounded (Class<TransformT> clazz) { + @SuppressWarnings("unchecked") TransformEvaluator<TransformT> transformEvaluator = + (TransformEvaluator<TransformT>) EVALUATORS.get(clazz); + checkState(transformEvaluator != null, + "No TransformEvaluator registered for BOUNDED transform %s", clazz); + return transformEvaluator; } - } - private static class InMemoryStateInternalsFactory<K> implements StateInternalsFactory<K>, - Serializable { @Override - public StateInternals<K> stateInternalsForKey(K key) { - return InMemoryStateInternals.forKey(key); + public <TransformT extends PTransform<?, ?>> TransformEvaluator<TransformT> + translateUnbounded(Class<TransformT> clazz) { + throw new IllegalStateException("TransformTranslator used in a batch pipeline only " + + "supports BOUNDED transforms."); } } }
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java new file mode 100644 index 0000000..9b156fe --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark.translation; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Maps; +import java.io.Serializable; +import java.util.List; +import java.util.Map; +import org.apache.beam.runners.spark.util.BroadcastHelper; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.state.InMemoryStateInternals; +import org.apache.beam.sdk.util.state.StateInternals; +import org.apache.beam.sdk.util.state.StateInternalsFactory; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; + +import scala.Tuple2; + +/** + * A set of utilities to help translating Beam transformations into Spark transformations. + */ +public final class TranslationUtils { + + private TranslationUtils() { + } + + /** + * In-memory state internals factory. + * + * @param <K> State key type. + */ + static class InMemoryStateInternalsFactory<K> implements StateInternalsFactory<K>, + Serializable { + @Override + public StateInternals<K> stateInternalsForKey(K key) { + return InMemoryStateInternals.forKey(key); + } + } + + /** + * A {@link Combine.GroupedValues} function applied to grouped KVs. + * + * @param <K> Grouped key type. + * @param <InputT> Grouped values type. + * @param <OutputT> Output type. + */ + public static class CombineGroupedValues<K, InputT, OutputT> implements + Function<WindowedValue<KV<K, Iterable<InputT>>>, WindowedValue<KV<K, OutputT>>> { + private final Combine.KeyedCombineFn<K, InputT, ?, OutputT> keyed; + + public CombineGroupedValues(Combine.GroupedValues<K, InputT, OutputT> transform) { + //noinspection unchecked + keyed = (Combine.KeyedCombineFn<K, InputT, ?, OutputT>) transform.getFn(); + } + + @Override + public WindowedValue<KV<K, OutputT>> call(WindowedValue<KV<K, Iterable<InputT>>> windowedKv) + throws Exception { + KV<K, Iterable<InputT>> kv = windowedKv.getValue(); + return WindowedValue.of(KV.of(kv.getKey(), keyed.apply(kv.getKey(), kv.getValue())), + windowedKv.getTimestamp(), windowedKv.getWindows(), windowedKv.getPane()); + } + } + + /** + * Checks if the window transformation should be applied or skipped. + * + * <p> + * Avoid running assign windows if both source and destination are global window + * or if the user has not specified the WindowFn (meaning they are just messing + * with triggering or allowed lateness). + * </p> + * + * @param transform The {@link Window.Bound} transformation. + * @param context The {@link EvaluationContext}. + * @param <T> PCollection type. + * @param <W> {@link BoundedWindow} type. + * @return if to apply the transformation. + */ + public static <T, W extends BoundedWindow> boolean + skipAssignWindows(Window.Bound<T> transform, EvaluationContext context) { + @SuppressWarnings("unchecked") + WindowFn<? super T, W> windowFn = (WindowFn<? super T, W>) transform.getWindowFn(); + return windowFn == null + || (context.getInput(transform).getWindowingStrategy().getWindowFn() + instanceof GlobalWindows + && windowFn instanceof GlobalWindows); + } + + /** Transform a pair stream into a value stream. */ + public static <T1, T2> JavaDStream<T2> dStreamValues(JavaPairDStream<T1, T2> pairDStream) { + return pairDStream.map(new Function<Tuple2<T1, T2>, T2>() { + @Override + public T2 call(Tuple2<T1, T2> v1) throws Exception { + return v1._2(); + } + }); + } + + /** {@link KV} to pair function. */ + static <K, V> PairFunction<KV<K, V>, K, V> toPairFunction() { + return new PairFunction<KV<K, V>, K, V>() { + @Override + public Tuple2<K, V> call(KV<K, V> kv) { + return new Tuple2<>(kv.getKey(), kv.getValue()); + } + }; + } + + /** A pair to {@link KV} function . */ + static <K, V> Function<Tuple2<K, V>, KV<K, V>> fromPairFunction() { + return new Function<Tuple2<K, V>, KV<K, V>>() { + @Override + public KV<K, V> call(Tuple2<K, V> t2) { + return KV.of(t2._1(), t2._2()); + } + }; + } + + /** + * A utility class to filter {@link TupleTag}s. + * + * @param <V> TupleTag type. + */ + public static final class TupleTagFilter<V> + implements Function<Tuple2<TupleTag<V>, WindowedValue<?>>, Boolean> { + + private final TupleTag<V> tag; + + public TupleTagFilter(TupleTag<V> tag) { + this.tag = tag; + } + + @Override + public Boolean call(Tuple2<TupleTag<V>, WindowedValue<?>> input) { + return tag.equals(input._1()); + } + } + + /*** + * Create SideInputs as Broadcast variables. + * + * @param views The {@link PCollectionView}s. + * @param context The {@link EvaluationContext}. + * @return a map of tagged {@link BroadcastHelper}s. + */ + public static Map<TupleTag<?>, BroadcastHelper<?>> getSideInputs(List<PCollectionView<?>> views, + EvaluationContext context) { + if (views == null) { + return ImmutableMap.of(); + } else { + Map<TupleTag<?>, BroadcastHelper<?>> sideInputs = Maps.newHashMap(); + for (PCollectionView<?> view : views) { + Iterable<? extends WindowedValue<?>> collectionView = context.getPCollectionView(view); + Coder<Iterable<WindowedValue<?>>> coderInternal = view.getCoderInternal(); + @SuppressWarnings("unchecked") + BroadcastHelper<?> helper = + BroadcastHelper.create((Iterable<WindowedValue<?>>) collectionView, coderInternal); + //broadcast side inputs + helper.broadcast(context.getSparkContext()); + sideInputs.put(view.getTagInternal(), helper); + } + return sideInputs; + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java new file mode 100644 index 0000000..b7a407c --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark.translation.streaming; + +import com.google.common.base.Predicates; +import com.google.common.collect.Iterables; +import java.net.MalformedURLException; +import java.net.URL; +import java.util.Arrays; +import org.apache.beam.runners.spark.SparkPipelineOptions; +import org.apache.beam.runners.spark.SparkRunner; +import org.apache.beam.runners.spark.translation.SparkContextFactory; +import org.apache.beam.runners.spark.translation.SparkPipelineTranslator; +import org.apache.beam.runners.spark.translation.TransformTranslator; +import org.apache.beam.sdk.Pipeline; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.streaming.api.java.JavaStreamingContextFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * A {@link JavaStreamingContext} factory for resilience. + * @see <a href="https://spark.apache.org/docs/1.6.2/streaming-programming-guide.html#how-to-configure-checkpointing">how-to-configure-checkpointing</a> + */ +public class SparkRunnerStreamingContextFactory implements JavaStreamingContextFactory { + private static final Logger LOG = + LoggerFactory.getLogger(SparkRunnerStreamingContextFactory.class); + private static final Iterable<String> KNOWN_RELIABLE_FS = Arrays.asList("hdfs", "s3", "gs"); + + private final Pipeline pipeline; + private final SparkPipelineOptions options; + + public SparkRunnerStreamingContextFactory(Pipeline pipeline, SparkPipelineOptions options) { + this.pipeline = pipeline; + this.options = options; + } + + private StreamingEvaluationContext ctxt; + + @Override + public JavaStreamingContext create() { + LOG.info("Creating a new Spark Streaming Context"); + + SparkPipelineTranslator translator = new StreamingTransformTranslator.Translator( + new TransformTranslator.Translator()); + Duration batchDuration = new Duration(options.getBatchIntervalMillis()); + LOG.info("Setting Spark streaming batchDuration to {} msec", batchDuration.milliseconds()); + + JavaSparkContext jsc = SparkContextFactory.getSparkContext(options); + JavaStreamingContext jssc = new JavaStreamingContext(jsc, batchDuration); + ctxt = new StreamingEvaluationContext(jsc, pipeline, jssc, + options.getTimeout()); + pipeline.traverseTopologically(new SparkRunner.Evaluator(translator, ctxt)); + ctxt.computeOutputs(); + + // set checkpoint dir. + String checkpointDir = options.getCheckpointDir(); + LOG.info("Checkpoint dir set to: {}", checkpointDir); + try { + // validate checkpoint dir and warn if not of a known durable filesystem. + URL checkpointDirUrl = new URL(checkpointDir); + if (!Iterables.any(KNOWN_RELIABLE_FS, Predicates.equalTo(checkpointDirUrl.getProtocol()))) { + LOG.warn("Checkpoint dir URL {} does not match a reliable filesystem, in case of failures " + + "this job may not recover properly or even at all.", checkpointDirUrl); + } + } catch (MalformedURLException e) { + throw new RuntimeException("Failed to form checkpoint dir URL. CheckpointDir should be in " + + "the form of hdfs:///path/to/dir or other reliable fs protocol, " + + "or file:///path/to/dir for local mode.", e); + } + jssc.checkpoint(checkpointDir); + + return jssc; + } + + public StreamingEvaluationContext getCtxt() { + return ctxt; + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingEvaluationContext.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingEvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingEvaluationContext.java index 2e4da44..5a43c55 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingEvaluationContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingEvaluationContext.java @@ -18,14 +18,18 @@ package org.apache.beam.runners.spark.translation.streaming; +import com.google.common.collect.Iterables; + import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.Map; import java.util.Queue; import java.util.Set; import java.util.concurrent.LinkedBlockingQueue; +import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.runners.spark.translation.EvaluationContext; import org.apache.beam.runners.spark.translation.SparkRuntimeContext; +import org.apache.beam.runners.spark.translation.WindowingHelpers; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.AppliedPTransform; @@ -82,11 +86,17 @@ public class StreamingEvaluationContext extends EvaluationContext { @SuppressWarnings("unchecked") JavaDStream<WindowedValue<T>> getDStream() { if (dStream == null) { - // create the DStream from values + WindowedValue.ValueOnlyWindowedValueCoder<T> windowCoder = + WindowedValue.getValueOnlyCoder(coder); + // create the DStream from queue Queue<JavaRDD<WindowedValue<T>>> rddQueue = new LinkedBlockingQueue<>(); for (Iterable<T> v : values) { - setOutputRDDFromValues(currentTransform.getTransform(), v, coder); - rddQueue.offer((JavaRDD<WindowedValue<T>>) getOutputRDD(currentTransform.getTransform())); + Iterable<WindowedValue<T>> windowedValues = + Iterables.transform(v, WindowingHelpers.<T>windowValueFunction()); + JavaRDD<WindowedValue<T>> rdd = getSparkContext().parallelize( + CoderHelpers.toByteArrays(windowedValues, windowCoder)).map( + CoderHelpers.fromByteFunction(windowCoder)); + rddQueue.offer(rdd); } // create dstream from queue, one at a time, no defaults // mainly for unit test so no reason to have this configurable @@ -102,7 +112,10 @@ public class StreamingEvaluationContext extends EvaluationContext { } <T> void setStream(PTransform<?, ?> transform, JavaDStream<WindowedValue<T>> dStream) { - PValue pvalue = (PValue) getOutput(transform); + setStream((PValue) getOutput(transform), dStream); + } + + <T> void setStream(PValue pvalue, JavaDStream<WindowedValue<T>> dStream) { DStreamHolder<T> dStreamHolder = new DStreamHolder<>(dStream); pstreams.put(pvalue, dStreamHolder); leafStreams.add(dStreamHolder); @@ -110,6 +123,10 @@ public class StreamingEvaluationContext extends EvaluationContext { boolean hasStream(PTransform<?, ?> transform) { PValue pvalue = (PValue) getInput(transform); + return hasStream(pvalue); + } + + boolean hasStream(PValue pvalue) { return pstreams.containsKey(pvalue); } @@ -141,19 +158,23 @@ public class StreamingEvaluationContext extends EvaluationContext { @Override public void computeOutputs() { + super.computeOutputs(); // in case the pipeline contains bounded branches as well. for (DStreamHolder<?> streamHolder : leafStreams) { computeOutput(streamHolder); - } + } // force a DStream action } private static <T> void computeOutput(DStreamHolder<T> streamHolder) { - streamHolder.getDStream().foreachRDD(new VoidFunction<JavaRDD<WindowedValue<T>>>() { + JavaDStream<WindowedValue<T>> dStream = streamHolder.getDStream(); + // cache in DStream level not RDD + // because there could be a difference in StorageLevel if the DStream is windowed. + dStream.dstream().cache(); + dStream.foreachRDD(new VoidFunction<JavaRDD<WindowedValue<T>>>() { @Override public void call(JavaRDD<WindowedValue<T>> rdd) throws Exception { - rdd.rdd().cache(); rdd.count(); } - }); // force a DStream action + }); } @Override @@ -163,8 +184,9 @@ public class StreamingEvaluationContext extends EvaluationContext { } else { jssc.awaitTermination(); } - //TODO: stop gracefully ? - jssc.stop(false, false); + // stop streaming context gracefully, so checkpointing (and other computations) get to + // finish before shutdown. + jssc.stop(false, true); state = State.DONE; super.close(); } @@ -197,7 +219,7 @@ public class StreamingEvaluationContext extends EvaluationContext { } @Override - protected void setCurrentTransform(AppliedPTransform<?, ?, ?> transform) { + public void setCurrentTransform(AppliedPTransform<?, ?, ?> transform) { super.setCurrentTransform(transform); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java index c55be3d..64ddc57 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java @@ -17,53 +17,68 @@ */ package org.apache.beam.runners.spark.translation.streaming; -import com.google.common.collect.Lists; +import static com.google.common.base.Preconditions.checkState; + import com.google.common.collect.Maps; -import com.google.common.collect.Sets; -import com.google.common.reflect.TypeToken; -import java.lang.reflect.ParameterizedType; -import java.lang.reflect.Type; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; import kafka.serializer.Decoder; import org.apache.beam.runners.core.AssignWindowsDoFn; +import org.apache.beam.runners.core.GroupByKeyViaGroupByKeyOnly; +import org.apache.beam.runners.core.GroupByKeyViaGroupByKeyOnly.GroupAlsoByWindow; +import org.apache.beam.runners.core.GroupByKeyViaGroupByKeyOnly.GroupByKeyOnly; +import org.apache.beam.runners.spark.aggregators.AccumulatorSingleton; +import org.apache.beam.runners.spark.aggregators.NamedAggregators; +import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.runners.spark.io.ConsoleIO; import org.apache.beam.runners.spark.io.CreateStream; import org.apache.beam.runners.spark.io.KafkaIO; -import org.apache.beam.runners.spark.io.hadoop.HadoopIO; import org.apache.beam.runners.spark.translation.DoFnFunction; import org.apache.beam.runners.spark.translation.EvaluationContext; +import org.apache.beam.runners.spark.translation.GroupCombineFunctions; +import org.apache.beam.runners.spark.translation.MultiDoFnFunction; import org.apache.beam.runners.spark.translation.SparkPipelineTranslator; +import org.apache.beam.runners.spark.translation.SparkRuntimeContext; import org.apache.beam.runners.spark.translation.TransformEvaluator; +import org.apache.beam.runners.spark.translation.TranslationUtils; import org.apache.beam.runners.spark.translation.WindowingHelpers; +import org.apache.beam.runners.spark.util.BroadcastHelper; +import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.io.AvroIO; -import org.apache.beam.sdk.io.TextIO; -import org.apache.beam.sdk.transforms.AppliedPTransform; -import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.OldDoFn; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.SlidingWindows; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionList; -import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.spark.Accumulator; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaDStreamLike; +import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaPairInputDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.apache.spark.streaming.kafka.KafkaUtils; + import scala.Tuple2; @@ -114,19 +129,6 @@ public final class StreamingTransformTranslator { }; } - private static <T> TransformEvaluator<Create.Values<T>> create() { - return new TransformEvaluator<Create.Values<T>>() { - @SuppressWarnings("unchecked") - @Override - public void evaluate(Create.Values<T> transform, EvaluationContext context) { - StreamingEvaluationContext sec = (StreamingEvaluationContext) context; - Iterable<T> elems = transform.getElements(); - Coder<T> coder = sec.getOutput(transform).getCoder(); - sec.setDStreamFromQueue(transform, Collections.singletonList(elems), coder); - } - }; - } - private static <T> TransformEvaluator<CreateStream.QueuedValues<T>> createFromQueue() { return new TransformEvaluator<CreateStream.QueuedValues<T>>() { @Override @@ -146,173 +148,325 @@ public final class StreamingTransformTranslator { public void evaluate(Flatten.FlattenPCollectionList<T> transform, EvaluationContext context) { StreamingEvaluationContext sec = (StreamingEvaluationContext) context; PCollectionList<T> pcs = sec.getInput(transform); - JavaDStream<WindowedValue<T>> first = - (JavaDStream<WindowedValue<T>>) sec.getStream(pcs.get(0)); - List<JavaDStream<WindowedValue<T>>> rest = Lists.newArrayListWithCapacity(pcs.size() - 1); - for (int i = 1; i < pcs.size(); i++) { - rest.add((JavaDStream<WindowedValue<T>>) sec.getStream(pcs.get(i))); + // since this is a streaming pipeline, at least one of the PCollections to "flatten" are + // unbounded, meaning it represents a DStream. + // So we could end up with an unbounded unified DStream. + final List<JavaRDD<WindowedValue<T>>> rdds = new ArrayList<>(); + final List<JavaDStream<WindowedValue<T>>> dStreams = new ArrayList<>(); + for (PCollection<T> pcol: pcs.getAll()) { + if (sec.hasStream(pcol)) { + dStreams.add((JavaDStream<WindowedValue<T>>) sec.getStream(pcol)); + } else { + rdds.add((JavaRDD<WindowedValue<T>>) context.getRDD(pcol)); + } + } + // start by unifying streams into a single stream. + JavaDStream<WindowedValue<T>> unifiedStreams = + sec.getStreamingContext().union(dStreams.remove(0), dStreams); + // now unify in RDDs. + if (rdds.size() > 0) { + JavaDStream<WindowedValue<T>> joined = unifiedStreams.transform( + new Function<JavaRDD<WindowedValue<T>>, JavaRDD<WindowedValue<T>>>() { + @Override + public JavaRDD<WindowedValue<T>> call(JavaRDD<WindowedValue<T>> streamRdd) + throws Exception { + return new JavaSparkContext(streamRdd.context()).union(streamRdd, rdds); + } + }); + sec.setStream(transform, joined); + } else { + sec.setStream(transform, unifiedStreams); } - JavaDStream<WindowedValue<T>> dstream = sec.getStreamingContext().union(first, rest); - sec.setStream(transform, dstream); } }; } - private static <TransformT extends PTransform<?, ?>> TransformEvaluator<TransformT> rddTransform( - final SparkPipelineTranslator rddTranslator) { - return new TransformEvaluator<TransformT>() { - @SuppressWarnings("unchecked") + private static <T, W extends BoundedWindow> TransformEvaluator<Window.Bound<T>> window() { + return new TransformEvaluator<Window.Bound<T>>() { @Override - public void evaluate(TransformT transform, EvaluationContext context) { - TransformEvaluator<TransformT> rddEvaluator = - rddTranslator.translate((Class<TransformT>) transform.getClass()); - + public void evaluate(Window.Bound<T> transform, EvaluationContext context) { StreamingEvaluationContext sec = (StreamingEvaluationContext) context; - if (sec.hasStream(transform)) { - JavaDStreamLike<WindowedValue<Object>, ?, JavaRDD<WindowedValue<Object>>> dStream = - (JavaDStreamLike<WindowedValue<Object>, ?, JavaRDD<WindowedValue<Object>>>) - sec.getStream(transform); - - sec.setStream(transform, dStream - .transform(new RDDTransform<>(sec, rddEvaluator, transform))); + @SuppressWarnings("unchecked") + WindowFn<? super T, W> windowFn = (WindowFn<? super T, W>) transform.getWindowFn(); + @SuppressWarnings("unchecked") + JavaDStream<WindowedValue<T>> dStream = + (JavaDStream<WindowedValue<T>>) sec.getStream(transform); + if (windowFn instanceof FixedWindows) { + Duration windowDuration = Durations.milliseconds(((FixedWindows) windowFn).getSize() + .getMillis()); + sec.setStream(transform, dStream.window(windowDuration)); + } else if (windowFn instanceof SlidingWindows) { + Duration windowDuration = Durations.milliseconds(((SlidingWindows) windowFn).getSize() + .getMillis()); + Duration slideDuration = Durations.milliseconds(((SlidingWindows) windowFn).getPeriod() + .getMillis()); + sec.setStream(transform, dStream.window(windowDuration, slideDuration)); + } + //--- then we apply windowing to the elements + @SuppressWarnings("unchecked") + JavaDStream<WindowedValue<T>> dStream2 = + (JavaDStream<WindowedValue<T>>) sec.getStream(transform); + if (TranslationUtils.skipAssignWindows(transform, context)) { + sec.setStream(transform, dStream2); } else { - // if the transformation requires direct access to RDD (not in stream) - // this is used for "fake" transformations like with PAssert - rddEvaluator.evaluate(transform, context); + final OldDoFn<T, T> addWindowsDoFn = new AssignWindowsDoFn<>(windowFn); + final SparkRuntimeContext runtimeContext = sec.getRuntimeContext(); + JavaDStream<WindowedValue<T>> outStream = dStream2.transform( + new Function<JavaRDD<WindowedValue<T>>, JavaRDD<WindowedValue<T>>>() { + @Override + public JavaRDD<WindowedValue<T>> call(JavaRDD<WindowedValue<T>> rdd) throws Exception { + final Accumulator<NamedAggregators> accum = + AccumulatorSingleton.getInstance(new JavaSparkContext(rdd.context())); + return rdd.mapPartitions( + new DoFnFunction<>(accum, addWindowsDoFn, runtimeContext, null)); + } + }); + sec.setStream(transform, outStream); } } }; } - /** - * RDD transform function If the transformation function doesn't have an input, create a fake one - * as an empty RDD. - * - * @param <TransformT> PTransform type - */ - private static final class RDDTransform<TransformT extends PTransform<?, ?>> - implements Function<JavaRDD<WindowedValue<Object>>, JavaRDD<WindowedValue<Object>>> { - - private final StreamingEvaluationContext context; - private final AppliedPTransform<?, ?, ?> appliedPTransform; - private final TransformEvaluator<TransformT> rddEvaluator; - private final TransformT transform; - - - private RDDTransform(StreamingEvaluationContext context, - TransformEvaluator<TransformT> rddEvaluator, - TransformT transform) { - this.context = context; - this.appliedPTransform = context.getCurrentTransform(); - this.rddEvaluator = rddEvaluator; - this.transform = transform; - } + private static <K, V> TransformEvaluator<GroupByKeyOnly<K, V>> gbko() { + return new TransformEvaluator<GroupByKeyOnly<K, V>>() { + @Override + public void evaluate(GroupByKeyOnly<K, V> transform, EvaluationContext context) { + StreamingEvaluationContext sec = (StreamingEvaluationContext) context; - @Override - @SuppressWarnings("unchecked") - public JavaRDD<WindowedValue<Object>> - call(JavaRDD<WindowedValue<Object>> rdd) throws Exception { - AppliedPTransform<?, ?, ?> existingAPT = context.getCurrentTransform(); - context.setCurrentTransform(appliedPTransform); - context.setInputRDD(transform, rdd); - rddEvaluator.evaluate(transform, context); - if (!context.hasOutputRDD(transform)) { - // fake RDD as output - context.setOutputRDD(transform, - context.getSparkContext().<WindowedValue<Object>>emptyRDD()); + @SuppressWarnings("unchecked") + JavaDStream<WindowedValue<KV<K, V>>> dStream = + (JavaDStream<WindowedValue<KV<K, V>>>) sec.getStream(transform); + + @SuppressWarnings("unchecked") + final KvCoder<K, V> coder = (KvCoder<K, V>) sec.getInput(transform).getCoder(); + + JavaDStream<WindowedValue<KV<K, Iterable<V>>>> outStream = + dStream.transform(new Function<JavaRDD<WindowedValue<KV<K, V>>>, + JavaRDD<WindowedValue<KV<K, Iterable<V>>>>>() { + @Override + public JavaRDD<WindowedValue<KV<K, Iterable<V>>>> call( + JavaRDD<WindowedValue<KV<K, V>>> rdd) throws Exception { + return GroupCombineFunctions.groupByKeyOnly(rdd, coder); + } + }); + sec.setStream(transform, outStream); } - JavaRDD<WindowedValue<Object>> outRDD = - (JavaRDD<WindowedValue<Object>>) context.getOutputRDD(transform); - context.setCurrentTransform(existingAPT); - return outRDD; - } + }; } - @SuppressWarnings("unchecked") - private static <TransformT extends PTransform<?, ?>> TransformEvaluator<TransformT> foreachRDD( - final SparkPipelineTranslator rddTranslator) { - return new TransformEvaluator<TransformT>() { + private static <K, V, W extends BoundedWindow> + TransformEvaluator<GroupAlsoByWindow<K, V>> gabw() { + return new TransformEvaluator<GroupAlsoByWindow<K, V>>() { @Override - public void evaluate(TransformT transform, EvaluationContext context) { - TransformEvaluator<TransformT> rddEvaluator = - rddTranslator.translate((Class<TransformT>) transform.getClass()); + public void evaluate(final GroupAlsoByWindow<K, V> transform, EvaluationContext context) { + final StreamingEvaluationContext sec = (StreamingEvaluationContext) context; + final SparkRuntimeContext runtimeContext = sec.getRuntimeContext(); + @SuppressWarnings("unchecked") + JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> dStream = + (JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>>) + sec.getStream(transform); - StreamingEvaluationContext sec = (StreamingEvaluationContext) context; - if (sec.hasStream(transform)) { - JavaDStreamLike<WindowedValue<Object>, ?, JavaRDD<WindowedValue<Object>>> dStream = - (JavaDStreamLike<WindowedValue<Object>, ?, JavaRDD<WindowedValue<Object>>>) - sec.getStream(transform); + @SuppressWarnings("unchecked") + final KvCoder<K, Iterable<WindowedValue<V>>> inputKvCoder = + (KvCoder<K, Iterable<WindowedValue<V>>>) sec.getInput(transform).getCoder(); + + JavaDStream<WindowedValue<KV<K, Iterable<V>>>> outStream = + dStream.transform(new Function<JavaRDD<WindowedValue<KV<K, + Iterable<WindowedValue<V>>>>>, JavaRDD<WindowedValue<KV<K, Iterable<V>>>>>() { + @Override + public JavaRDD<WindowedValue<KV<K, Iterable<V>>>> call(JavaRDD<WindowedValue<KV<K, + Iterable<WindowedValue<V>>>>> rdd) throws Exception { + final Accumulator<NamedAggregators> accum = + AccumulatorSingleton.getInstance(new JavaSparkContext(rdd.context())); + return GroupCombineFunctions.groupAlsoByWindow(rdd, transform, runtimeContext, + accum, inputKvCoder); + } + }); + sec.setStream(transform, outStream); + } + }; + } - dStream.foreachRDD(new RDDOutputOperator<>(sec, rddEvaluator, transform)); - } else { - rddEvaluator.evaluate(transform, context); - } + private static <K, InputT, OutputT> TransformEvaluator<Combine.GroupedValues<K, InputT, OutputT>> + grouped() { + return new TransformEvaluator<Combine.GroupedValues<K, InputT, OutputT>>() { + @Override + public void evaluate(Combine.GroupedValues<K, InputT, OutputT> transform, + EvaluationContext context) { + StreamingEvaluationContext sec = (StreamingEvaluationContext) context; + @SuppressWarnings("unchecked") + JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> dStream = + (JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>>) sec.getStream(transform); + sec.setStream(transform, dStream.map( + new TranslationUtils.CombineGroupedValues<>(transform))); } }; } - /** - * RDD output function. - * - * @param <TransformT> PTransform type - */ - private static final class RDDOutputOperator<TransformT extends PTransform<?, ?>> - implements VoidFunction<JavaRDD<WindowedValue<Object>>> { + private static <InputT, AccumT, OutputT> TransformEvaluator<Combine.Globally<InputT, OutputT>> + combineGlobally() { + return new TransformEvaluator<Combine.Globally<InputT, OutputT>>() { - private final StreamingEvaluationContext context; - private final AppliedPTransform<?, ?, ?> appliedPTransform; - private final TransformEvaluator<TransformT> rddEvaluator; - private final TransformT transform; + @Override + public void evaluate(Combine.Globally<InputT, OutputT> transform, EvaluationContext context) { + StreamingEvaluationContext sec = (StreamingEvaluationContext) context; + @SuppressWarnings("unchecked") + final Combine.CombineFn<InputT, AccumT, OutputT> globally = + (Combine.CombineFn<InputT, AccumT, OutputT>) transform.getFn(); + @SuppressWarnings("unchecked") + JavaDStream<WindowedValue<InputT>> dStream = + (JavaDStream<WindowedValue<InputT>>) sec.getStream(transform); + + final Coder<InputT> iCoder = sec.getInput(transform).getCoder(); + final Coder<OutputT> oCoder = sec.getOutput(transform).getCoder(); + final Coder<AccumT> aCoder; + try { + aCoder = globally.getAccumulatorCoder(sec.getPipeline().getCoderRegistry(), iCoder); + } catch (CannotProvideCoderException e) { + throw new IllegalStateException("Could not determine coder for accumulator", e); + } - private RDDOutputOperator(StreamingEvaluationContext context, - TransformEvaluator<TransformT> rddEvaluator, TransformT transform) { - this.context = context; - this.appliedPTransform = context.getCurrentTransform(); - this.rddEvaluator = rddEvaluator; - this.transform = transform; - } + JavaDStream<WindowedValue<OutputT>> outStream = dStream.transform( + new Function<JavaRDD<WindowedValue<InputT>>, JavaRDD<WindowedValue<OutputT>>>() { + @Override + public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd) + throws Exception { + JavaRDD<byte[]> outRdd = new JavaSparkContext(rdd.context()).parallelize( + // don't use Guava's ImmutableList.of as output may be null + CoderHelpers.toByteArrays(Collections.singleton( + GroupCombineFunctions.combineGlobally(rdd, globally, iCoder, aCoder)), oCoder)); + return outRdd.map(CoderHelpers.fromByteFunction(oCoder)).map( + WindowingHelpers.<OutputT>windowFunction()); + } + }); - @Override - @SuppressWarnings("unchecked") - public void call(JavaRDD<WindowedValue<Object>> rdd) throws Exception { - AppliedPTransform<?, ?, ?> existingAPT = context.getCurrentTransform(); - context.setCurrentTransform(appliedPTransform); - context.setInputRDD(transform, rdd); - rddEvaluator.evaluate(transform, context); - context.setCurrentTransform(existingAPT); - } + sec.setStream(transform, outStream); + } + }; } - private static <T> TransformEvaluator<Window.Bound<T>> window() { - return new TransformEvaluator<Window.Bound<T>>() { + private static <K, InputT, AccumT, OutputT> + TransformEvaluator<Combine.PerKey<K, InputT, OutputT>> combinePerKey() { + return new TransformEvaluator<Combine.PerKey<K, InputT, OutputT>>() { @Override - public void evaluate(Window.Bound<T> transform, EvaluationContext context) { + public void evaluate(Combine.PerKey<K, InputT, OutputT> + transform, EvaluationContext context) { StreamingEvaluationContext sec = (StreamingEvaluationContext) context; - WindowFn<? super T, ?> windowFn = transform.getWindowFn(); @SuppressWarnings("unchecked") - JavaDStream<WindowedValue<T>> dStream = - (JavaDStream<WindowedValue<T>>) sec.getStream(transform); - if (windowFn instanceof FixedWindows) { - Duration windowDuration = Durations.milliseconds(((FixedWindows) windowFn).getSize() - .getMillis()); - sec.setStream(transform, dStream.window(windowDuration)); - } else if (windowFn instanceof SlidingWindows) { - Duration windowDuration = Durations.milliseconds(((SlidingWindows) windowFn).getSize() - .getMillis()); - Duration slideDuration = Durations.milliseconds(((SlidingWindows) windowFn).getPeriod() - .getMillis()); - sec.setStream(transform, dStream.window(windowDuration, slideDuration)); + final Combine.KeyedCombineFn<K, InputT, AccumT, OutputT> keyed = + (Combine.KeyedCombineFn<K, InputT, AccumT, OutputT>) transform.getFn(); + @SuppressWarnings("unchecked") + JavaDStream<WindowedValue<KV<K, InputT>>> dStream = + (JavaDStream<WindowedValue<KV<K, InputT>>>) sec.getStream(transform); + + @SuppressWarnings("unchecked") + KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) sec.getInput(transform).getCoder(); + Coder<K> keyCoder = inputCoder.getKeyCoder(); + Coder<InputT> viCoder = inputCoder.getValueCoder(); + Coder<AccumT> vaCoder; + try { + vaCoder = keyed.getAccumulatorCoder( + context.getPipeline().getCoderRegistry(), keyCoder, viCoder); + } catch (CannotProvideCoderException e) { + throw new IllegalStateException("Could not determine coder for accumulator", e); } - //--- then we apply windowing to the elements - OldDoFn<T, T> addWindowsDoFn = new AssignWindowsDoFn<>(windowFn); - DoFnFunction<T, T> dofn = new DoFnFunction<>(addWindowsDoFn, - ((StreamingEvaluationContext) context).getRuntimeContext(), null); + Coder<KV<K, InputT>> kviCoder = KvCoder.of(keyCoder, viCoder); + Coder<KV<K, AccumT>> kvaCoder = KvCoder.of(keyCoder, vaCoder); + //-- windowed coders + final WindowedValue.FullWindowedValueCoder<K> wkCoder = + WindowedValue.FullWindowedValueCoder.of(keyCoder, + sec.getInput(transform).getWindowingStrategy().getWindowFn().windowCoder()); + final WindowedValue.FullWindowedValueCoder<KV<K, InputT>> wkviCoder = + WindowedValue.FullWindowedValueCoder.of(kviCoder, + sec.getInput(transform).getWindowingStrategy().getWindowFn().windowCoder()); + final WindowedValue.FullWindowedValueCoder<KV<K, AccumT>> wkvaCoder = + WindowedValue.FullWindowedValueCoder.of(kvaCoder, + sec.getInput(transform).getWindowingStrategy().getWindowFn().windowCoder()); + + JavaDStream<WindowedValue<KV<K, OutputT>>> outStream = + dStream.transform(new Function<JavaRDD<WindowedValue<KV<K, InputT>>>, + JavaRDD<WindowedValue<KV<K, OutputT>>>>() { + @Override + public JavaRDD<WindowedValue<KV<K, OutputT>>> call( + JavaRDD<WindowedValue<KV<K, InputT>>> rdd) throws Exception { + return GroupCombineFunctions.combinePerKey(rdd, keyed, wkCoder, wkviCoder, wkvaCoder); + } + }); + + sec.setStream(transform, outStream); + } + }; + } + + private static <InputT, OutputT> TransformEvaluator<ParDo.Bound<InputT, OutputT>> parDo() { + return new TransformEvaluator<ParDo.Bound<InputT, OutputT>>() { + @Override + public void evaluate(final ParDo.Bound<InputT, OutputT> transform, + final EvaluationContext context) { + final StreamingEvaluationContext sec = (StreamingEvaluationContext) context; + final SparkRuntimeContext runtimeContext = sec.getRuntimeContext(); + final Map<TupleTag<?>, BroadcastHelper<?>> sideInputs = + TranslationUtils.getSideInputs(transform.getSideInputs(), context); @SuppressWarnings("unchecked") - JavaDStreamLike<WindowedValue<T>, ?, JavaRDD<WindowedValue<T>>> dstream = - (JavaDStreamLike<WindowedValue<T>, ?, JavaRDD<WindowedValue<T>>>) - sec.getStream(transform); - sec.setStream(transform, dstream.mapPartitions(dofn)); + JavaDStream<WindowedValue<InputT>> dStream = + (JavaDStream<WindowedValue<InputT>>) sec.getStream(transform); + + JavaDStream<WindowedValue<OutputT>> outStream = + dStream.transform(new Function<JavaRDD<WindowedValue<InputT>>, + JavaRDD<WindowedValue<OutputT>>>() { + @Override + public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd) throws + Exception { + final Accumulator<NamedAggregators> accum = + AccumulatorSingleton.getInstance(new JavaSparkContext(rdd.context())); + return rdd.mapPartitions( + new DoFnFunction<>(accum, transform.getFn(), runtimeContext, sideInputs)); + } + }); + + sec.setStream(transform, outStream); + } + }; + } + + private static <InputT, OutputT> TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>> + multiDo() { + return new TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>>() { + @Override + public void evaluate(final ParDo.BoundMulti<InputT, OutputT> transform, + final EvaluationContext context) { + final StreamingEvaluationContext sec = (StreamingEvaluationContext) context; + final SparkRuntimeContext runtimeContext = sec.getRuntimeContext(); + final Map<TupleTag<?>, BroadcastHelper<?>> sideInputs = + TranslationUtils.getSideInputs(transform.getSideInputs(), context); + @SuppressWarnings("unchecked") + JavaDStream<WindowedValue<InputT>> dStream = + (JavaDStream<WindowedValue<InputT>>) sec.getStream(transform); + JavaPairDStream<TupleTag<?>, WindowedValue<?>> all = dStream.transformToPair( + new Function<JavaRDD<WindowedValue<InputT>>, + JavaPairRDD<TupleTag<?>, WindowedValue<?>>>() { + @Override + public JavaPairRDD<TupleTag<?>, WindowedValue<?>> call( + JavaRDD<WindowedValue<InputT>> rdd) throws Exception { + final Accumulator<NamedAggregators> accum = + AccumulatorSingleton.getInstance(new JavaSparkContext(rdd.context())); + return rdd.mapPartitionsToPair(new MultiDoFnFunction<>(accum, transform.getFn(), + runtimeContext, transform.getMainOutputTag(), sideInputs)); + } + }).cache(); + PCollectionTuple pct = sec.getOutput(transform); + for (Map.Entry<TupleTag<?>, PCollection<?>> e : pct.getAll().entrySet()) { + @SuppressWarnings("unchecked") + JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered = + all.filter(new TranslationUtils.TupleTagFilter(e.getKey())); + @SuppressWarnings("unchecked") + // Object is the best we can do since different outputs can have different tags + JavaDStream<WindowedValue<Object>> values = + (JavaDStream<WindowedValue<Object>>) + (JavaDStream<?>) TranslationUtils.dStreamValues(filtered); + sec.setStream(e.getValue(), values); + } } }; } @@ -321,79 +475,54 @@ public final class StreamingTransformTranslator { .newHashMap(); static { + EVALUATORS.put(GroupByKeyViaGroupByKeyOnly.GroupByKeyOnly.class, gbko()); + EVALUATORS.put(GroupByKeyViaGroupByKeyOnly.GroupAlsoByWindow.class, gabw()); + EVALUATORS.put(Combine.GroupedValues.class, grouped()); + EVALUATORS.put(Combine.Globally.class, combineGlobally()); + EVALUATORS.put(Combine.PerKey.class, combinePerKey()); + EVALUATORS.put(ParDo.Bound.class, parDo()); + EVALUATORS.put(ParDo.BoundMulti.class, multiDo()); EVALUATORS.put(ConsoleIO.Write.Unbound.class, print()); EVALUATORS.put(CreateStream.QueuedValues.class, createFromQueue()); - EVALUATORS.put(Create.Values.class, create()); EVALUATORS.put(KafkaIO.Read.Unbound.class, kafka()); EVALUATORS.put(Window.Bound.class, window()); EVALUATORS.put(Flatten.FlattenPCollectionList.class, flattenPColl()); } - private static final Set<Class<? extends PTransform>> UNSUPPORTED_EVALUATORS = Sets - .newHashSet(); - - static { - //TODO - add support for the following - UNSUPPORTED_EVALUATORS.add(TextIO.Read.Bound.class); - UNSUPPORTED_EVALUATORS.add(TextIO.Write.Bound.class); - UNSUPPORTED_EVALUATORS.add(AvroIO.Read.Bound.class); - UNSUPPORTED_EVALUATORS.add(AvroIO.Write.Bound.class); - UNSUPPORTED_EVALUATORS.add(HadoopIO.Read.Bound.class); - UNSUPPORTED_EVALUATORS.add(HadoopIO.Write.Bound.class); - } - - @SuppressWarnings("unchecked") - private static <TransformT extends PTransform<?, ?>> TransformEvaluator<TransformT> - getTransformEvaluator(Class<TransformT> clazz, SparkPipelineTranslator rddTranslator) { - TransformEvaluator<TransformT> transform = - (TransformEvaluator<TransformT>) EVALUATORS.get(clazz); - if (transform == null) { - if (UNSUPPORTED_EVALUATORS.contains(clazz)) { - throw new UnsupportedOperationException("Beam transformation " + clazz - .getCanonicalName() - + " is currently unsupported by the Spark streaming pipeline"); - } - // DStream transformations will transform an RDD into another RDD - // Actions will create output - // In Beam it depends on the PTransform's Input and Output class - Class<?> pTOutputClazz = getPTransformOutputClazz(clazz); - if (PDone.class.equals(pTOutputClazz)) { - return foreachRDD(rddTranslator); - } else { - return rddTransform(rddTranslator); - } - } - return transform; - } - - private static <TransformT extends PTransform<?, ?>> Class<?> - getPTransformOutputClazz(Class<TransformT> clazz) { - Type[] types = ((ParameterizedType) clazz.getGenericSuperclass()).getActualTypeArguments(); - return TypeToken.of(clazz).resolveType(types[1]).getRawType(); - } - /** - * Translator matches Beam transformation with the appropriate Spark streaming evaluator. - * rddTranslator uses Spark evaluators in transform/foreachRDD to evaluate the transformation + * Translator matches Beam transformation with the appropriate evaluator. */ public static class Translator implements SparkPipelineTranslator { - private final SparkPipelineTranslator rddTranslator; + private final SparkPipelineTranslator batchTranslator; - public Translator(SparkPipelineTranslator rddTranslator) { - this.rddTranslator = rddTranslator; + Translator(SparkPipelineTranslator batchTranslator) { + this.batchTranslator = batchTranslator; } @Override public boolean hasTranslation(Class<? extends PTransform<?, ?>> clazz) { - // streaming includes rdd transformations as well - return EVALUATORS.containsKey(clazz) || rddTranslator.hasTranslation(clazz); + // streaming includes rdd/bounded transformations as well + return EVALUATORS.containsKey(clazz) || batchTranslator.hasTranslation(clazz); + } + + @Override + public <TransformT extends PTransform<?, ?>> TransformEvaluator<TransformT> + translateBounded(Class<TransformT> clazz) { + TransformEvaluator<TransformT> transformEvaluator = batchTranslator.translateBounded(clazz); + checkState(transformEvaluator != null, + "No TransformEvaluator registered for BOUNDED transform %s", clazz); + return transformEvaluator; } @Override public <TransformT extends PTransform<?, ?>> TransformEvaluator<TransformT> - translate(Class<TransformT> clazz) { - return getTransformEvaluator(clazz, rddTranslator); + translateUnbounded(Class<TransformT> clazz) { + @SuppressWarnings("unchecked") TransformEvaluator<TransformT> transformEvaluator = + (TransformEvaluator<TransformT>) EVALUATORS.get(clazz); + checkState(transformEvaluator != null, + "No TransformEvaluator registered for for UNBOUNDED transform %s", clazz); + return transformEvaluator; } } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/util/BroadcastHelper.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/BroadcastHelper.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/BroadcastHelper.java index 5c13b80..0e742eb 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/BroadcastHelper.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/BroadcastHelper.java @@ -52,8 +52,12 @@ public abstract class BroadcastHelper<T> implements Serializable { public abstract T getValue(); + public abstract boolean isBroadcasted(); + public abstract void broadcast(JavaSparkContext jsc); + public abstract void unregister(); + /** * A {@link BroadcastHelper} that relies on the underlying * Spark serialization (Kryo) to broadcast values. This is appropriate when @@ -77,9 +81,20 @@ public abstract class BroadcastHelper<T> implements Serializable { } @Override + public boolean isBroadcasted() { + return bcast != null; + } + + @Override public void broadcast(JavaSparkContext jsc) { this.bcast = jsc.broadcast(value); } + + @Override + public void unregister() { + this.bcast.destroy(); + this.bcast = null; + } } /** @@ -107,10 +122,21 @@ public abstract class BroadcastHelper<T> implements Serializable { } @Override + public boolean isBroadcasted() { + return bcast != null; + } + + @Override public void broadcast(JavaSparkContext jsc) { this.bcast = jsc.broadcast(CoderHelpers.toByteArray(value, coder)); } + @Override + public void unregister() { + this.bcast.destroy(); + this.bcast = null; + } + private T deserialize() { T val; try {