Repository: beam Updated Branches: refs/heads/master d35e1b0d9 -> 434eadb53
[BEAM-1815] Force a "default" partitioner based on Spark default parallelism to avoid unnecessary shuffles in the composite GBK implementation. Add Javadoc. Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/25569eaf Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/25569eaf Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/25569eaf Branch: refs/heads/master Commit: 25569eafeca647561ef10cedc03f06b1de53b8cd Parents: d35e1b0 Author: Amit Sela <[email protected]> Authored: Mon Mar 27 15:30:03 2017 +0300 Committer: Amit Sela <[email protected]> Committed: Tue Mar 28 16:44:32 2017 +0300 ---------------------------------------------------------------------- .../SparkGroupAlsoByWindowViaWindowSet.java | 35 ++- .../translation/GroupCombineFunctions.java | 38 ++- .../spark/translation/TranslationUtils.java | 245 +++++++++++++++---- 3 files changed, 248 insertions(+), 70 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/25569eaf/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java index 2f1713a..1f2fcb6 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java @@ -54,6 +54,8 @@ import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.TupleTag; import org.apache.spark.Partitioner; +import org.apache.spark.api.java.JavaPairRDD; +import 
org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext$; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function; @@ -135,12 +137,35 @@ public class SparkGroupAlsoByWindowViaWindowSet { //---- InputT: I DStream<Tuple2</*K*/ ByteArray, /*Itr<WV<I>>*/ byte[]>> pairDStream = inputDStream - .map(WindowingHelpers.<KV<K, Iterable<WindowedValue<InputT>>>>unwindowFunction()) - .mapToPair(TranslationUtils.<K, Iterable<WindowedValue<InputT>>>toPairFunction()) - // move to bytes and use coders for deserialization because there's a shuffle - // and checkpointing involved. - .mapToPair(CoderHelpers.toByteFunction(keyCoder, itrWvCoder)) + .transformToPair( + new Function< + JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>>, + JavaPairRDD<ByteArray, byte[]>>() { + // we use mapPartitions with the RDD API because it's the only available API + // that allows preserving partitioning. + @Override + public JavaPairRDD<ByteArray, byte[]> call( + JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> rdd) + throws Exception { + return rdd.mapPartitions( + TranslationUtils.functionToFlatMapFunction( + WindowingHelpers + .<KV<K, Iterable<WindowedValue<InputT>>>>unwindowFunction()), + true) + .mapPartitionsToPair( + TranslationUtils + .<K, Iterable<WindowedValue<InputT>>>toPairFlatMapFunction(), + true) + // move to a byte representation and use coders for deserialization + // because of checkpointing.
+ .mapPartitionsToPair( + TranslationUtils.pairFunctionToPairFlatMapFunction( + CoderHelpers.toByteFunction(keyCoder, itrWvCoder)), + true); + } + }) .dstream(); + PairDStreamFunctions<ByteArray, byte[]> pairDStreamFunctions = DStream.toPairDStreamFunctions( pairDStream, http://git-wip-us.apache.org/repos/asf/beam/blob/25569eaf/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java index 917a9ee..6a67cce 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java @@ -28,13 +28,14 @@ import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder; import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; +import org.apache.spark.HashPartitioner; +import org.apache.spark.Partitioner; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; - /** * A set of group/combine functions to apply to Spark {@link org.apache.spark.rdd.RDD}s. */ @@ -49,18 +50,31 @@ public class GroupCombineFunctions { JavaRDD<WindowedValue<KV<K, V>>> rdd, Coder<K> keyCoder, WindowedValueCoder<V> wvCoder) { - - // Use coders to convert objects in the PCollection to byte arrays, so they + // we use coders to convert objects in the PCollection to byte arrays, so they // can be transferred over the network for the shuffle. 
- return rdd - .map(new ReifyTimestampsAndWindowsFunction<K, V>()) - .map(WindowingHelpers.<KV<K, WindowedValue<V>>>unwindowFunction()) - .mapToPair(TranslationUtils.<K, WindowedValue<V>>toPairFunction()) - .mapToPair(CoderHelpers.toByteFunction(keyCoder, wvCoder)) - .groupByKey() - .mapToPair(CoderHelpers.fromByteFunctionIterable(keyCoder, wvCoder)) - .map(TranslationUtils.<K, Iterable<WindowedValue<V>>>fromPairFunction()) - .map(WindowingHelpers.<KV<K, Iterable<WindowedValue<V>>>>windowFunction()); + JavaPairRDD<ByteArray, byte[]> pairRDD = + rdd + .map(new ReifyTimestampsAndWindowsFunction<K, V>()) + .map(WindowingHelpers.<KV<K, WindowedValue<V>>>unwindowFunction()) + .mapToPair(TranslationUtils.<K, WindowedValue<V>>toPairFunction()) + .mapToPair(CoderHelpers.toByteFunction(keyCoder, wvCoder)); + // use a HashPartitioner sized by Spark's default parallelism. + Partitioner partitioner = new HashPartitioner(rdd.rdd().sparkContext().defaultParallelism()); + + // using mapPartitions (with preservesPartitioning = true) keeps the partitioner + // and avoids an unnecessary shuffle downstream.
+ return pairRDD + .groupByKey(partitioner) + .mapPartitionsToPair( + TranslationUtils.pairFunctionToPairFlatMapFunction( + CoderHelpers.fromByteFunctionIterable(keyCoder, wvCoder)), + true) + .mapPartitions( + TranslationUtils.<K, Iterable<WindowedValue<V>>>fromPairFlatMapFunction(), true) + .mapPartitions( + TranslationUtils.functionToFlatMapFunction( + WindowingHelpers.<KV<K, Iterable<WindowedValue<V>>>>windowFunction()), + true); } /** http://git-wip-us.apache.org/repos/asf/beam/blob/25569eaf/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java index 8545b36..ef1ff9f 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java @@ -19,8 +19,10 @@ package org.apache.beam.runners.spark.translation; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterators; import com.google.common.collect.Maps; import java.io.Serializable; +import java.util.Iterator; import java.util.List; import java.util.Map; import org.apache.beam.runners.core.InMemoryStateInternals; @@ -41,7 +43,9 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.streaming.api.java.JavaDStream; 
@@ -49,21 +53,17 @@ import org.apache.spark.streaming.api.java.JavaPairDStream; import scala.Tuple2; -/** - * A set of utilities to help translating Beam transformations into Spark transformations. - */ +/** A set of utilities to help translating Beam transformations into Spark transformations. */ public final class TranslationUtils { - private TranslationUtils() { - } + private TranslationUtils() {} /** * In-memory state internals factory. * * @param <K> State key type. */ - static class InMemoryStateInternalsFactory<K> implements StateInternalsFactory<K>, - Serializable { + static class InMemoryStateInternalsFactory<K> implements StateInternalsFactory<K>, Serializable { @Override public StateInternals<K> stateInternalsForKey(K key) { return InMemoryStateInternals.forKey(key); @@ -73,12 +73,12 @@ public final class TranslationUtils { /** * A SparkKeyedCombineFn function applied to grouped KVs. * - * @param <K> Grouped key type. - * @param <InputT> Grouped values type. + * @param <K> Grouped key type. + * @param <InputT> Grouped values type. * @param <OutputT> Output type. 
*/ - public static class CombineGroupedValues<K, InputT, OutputT> implements - Function<WindowedValue<KV<K, Iterable<InputT>>>, WindowedValue<KV<K, OutputT>>> { + public static class CombineGroupedValues<K, InputT, OutputT> + implements Function<WindowedValue<KV<K, Iterable<InputT>>>, WindowedValue<KV<K, OutputT>>> { private final SparkKeyedCombineFn<K, InputT, ?, OutputT> fn; public CombineGroupedValues(SparkKeyedCombineFn<K, InputT, ?, OutputT> fn) { @@ -88,44 +88,46 @@ public final class TranslationUtils { @Override public WindowedValue<KV<K, OutputT>> call(WindowedValue<KV<K, Iterable<InputT>>> windowedKv) throws Exception { - return WindowedValue.of(KV.of(windowedKv.getValue().getKey(), fn.apply(windowedKv)), - windowedKv.getTimestamp(), windowedKv.getWindows(), windowedKv.getPane()); + return WindowedValue.of( + KV.of(windowedKv.getValue().getKey(), fn.apply(windowedKv)), + windowedKv.getTimestamp(), + windowedKv.getWindows(), + windowedKv.getPane()); } } /** * Checks if the window transformation should be applied or skipped. * - * <p> - * Avoid running assign windows if both source and destination are global window - * or if the user has not specified the WindowFn (meaning they are just messing - * with triggering or allowed lateness). - * </p> + * <p>Avoid running assign windows if both source and destination are global window or if the user + * has not specified the WindowFn (meaning they are just messing with triggering or allowed + * lateness). * * @param transform The {@link Window.Assign} transformation. - * @param context The {@link EvaluationContext}. - * @param <T> PCollection type. - * @param <W> {@link BoundedWindow} type. + * @param context The {@link EvaluationContext}. + * @param <T> PCollection type. + * @param <W> {@link BoundedWindow} type. * @return if to apply the transformation. 
*/ - public static <T, W extends BoundedWindow> boolean - skipAssignWindows(Window.Assign<T> transform, EvaluationContext context) { + public static <T, W extends BoundedWindow> boolean skipAssignWindows( + Window.Assign<T> transform, EvaluationContext context) { @SuppressWarnings("unchecked") WindowFn<? super T, W> windowFn = (WindowFn<? super T, W>) transform.getWindowFn(); return windowFn == null || (context.getInput(transform).getWindowingStrategy().getWindowFn() - instanceof GlobalWindows - && windowFn instanceof GlobalWindows); + instanceof GlobalWindows + && windowFn instanceof GlobalWindows); } /** Transform a pair stream into a value stream. */ public static <T1, T2> JavaDStream<T2> dStreamValues(JavaPairDStream<T1, T2> pairDStream) { - return pairDStream.map(new Function<Tuple2<T1, T2>, T2>() { - @Override - public T2 call(Tuple2<T1, T2> v1) throws Exception { - return v1._2(); - } - }); + return pairDStream.map( + new Function<Tuple2<T1, T2>, T2>() { + @Override + public T2 call(Tuple2<T1, T2> v1) throws Exception { + return v1._2(); + } + }); } /** {@link KV} to pair function. */ @@ -138,7 +140,33 @@ public final class TranslationUtils { }; } - /** A pair to {@link KV} function . */ + /** {@link KV} to pair flatmap function. */ + public static <K, V> PairFlatMapFunction<Iterator<KV<K, V>>, K, V> toPairFlatMapFunction() { + return new PairFlatMapFunction<Iterator<KV<K, V>>, K, V>() { + @Override + public Iterable<Tuple2<K, V>> call(final Iterator<KV<K, V>> itr) { + final Iterator<Tuple2<K, V>> outputItr = + Iterators.transform( + itr, + new com.google.common.base.Function<KV<K, V>, Tuple2<K, V>>() { + + @Override + public Tuple2<K, V> apply(KV<K, V> kv) { + return new Tuple2<>(kv.getKey(), kv.getValue()); + } + }); + return new Iterable<Tuple2<K, V>>() { + + @Override + public Iterator<Tuple2<K, V>> iterator() { + return outputItr; + } + }; + } + }; + } + + /** A pair to {@link KV} function . 
*/ static <K, V> Function<Tuple2<K, V>, KV<K, V>> fromPairFunction() { return new Function<Tuple2<K, V>, KV<K, V>>() { @Override @@ -148,22 +176,48 @@ public final class TranslationUtils { }; } - /** Extract key from a {@link WindowedValue} {@link KV} into a pair. */ - public static <K, V> PairFunction<WindowedValue<KV<K, V>>, K, WindowedValue<KV<K, V>>> - toPairByKeyInWindowedValue() { - return new PairFunction<WindowedValue<KV<K, V>>, K, WindowedValue<KV<K, V>>>() { + /** A pair to {@link KV} flatmap function . */ + static <K, V> FlatMapFunction<Iterator<Tuple2<K, V>>, KV<K, V>> fromPairFlatMapFunction() { + return new FlatMapFunction<Iterator<Tuple2<K, V>>, KV<K, V>>() { + @Override + public Iterable<KV<K, V>> call(Iterator<Tuple2<K, V>> itr) { + final Iterator<KV<K, V>> outputItr = + Iterators.transform( + itr, + new com.google.common.base.Function<Tuple2<K, V>, KV<K, V>>() { + @Override + public KV<K, V> apply(Tuple2<K, V> t2) { + return KV.of(t2._1(), t2._2()); + } + }); + return new Iterable<KV<K, V>>() { @Override - public Tuple2<K, WindowedValue<KV<K, V>>> call( - WindowedValue<KV<K, V>> windowedKv) throws Exception { - return new Tuple2<>(windowedKv.getValue().getKey(), windowedKv); - } + public Iterator<KV<K, V>> iterator() { + return outputItr; + } }; } + }; + } + + /** Extract key from a {@link WindowedValue} {@link KV} into a pair. */ + public static <K, V> + PairFunction<WindowedValue<KV<K, V>>, K, WindowedValue<KV<K, V>>> + toPairByKeyInWindowedValue() { + return new PairFunction<WindowedValue<KV<K, V>>, K, WindowedValue<KV<K, V>>>() { + @Override + public Tuple2<K, WindowedValue<KV<K, V>>> call(WindowedValue<KV<K, V>> windowedKv) + throws Exception { + return new Tuple2<>(windowedKv.getValue().getKey(), windowedKv); + } + }; + } /** Extract window from a {@link KV} with {@link WindowedValue} value. 
*/ static <K, V> Function<KV<K, WindowedValue<V>>, WindowedValue<KV<K, V>>> toKVByWindowInValue() { return new Function<KV<K, WindowedValue<V>>, WindowedValue<KV<K, V>>>() { - @Override public WindowedValue<KV<K, V>> call(KV<K, WindowedValue<V>> kv) throws Exception { + @Override + public WindowedValue<KV<K, V>> call(KV<K, WindowedValue<V>> kv) throws Exception { WindowedValue<V> wv = kv.getValue(); return wv.withValue(KV.of(kv.getKey(), wv.getValue())); } @@ -193,28 +247,25 @@ public final class TranslationUtils { /** * Create SideInputs as Broadcast variables. * - * @param views The {@link PCollectionView}s. + * @param views The {@link PCollectionView}s. * @param context The {@link EvaluationContext}. * @return a map of tagged {@link SideInputBroadcast}s and their {@link WindowingStrategy}. */ - static Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> - getSideInputs(List<PCollectionView<?>> views, EvaluationContext context) { + static Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> getSideInputs( + List<PCollectionView<?>> views, EvaluationContext context) { return getSideInputs(views, context.getSparkContext(), context.getPViews()); } /** * Create SideInputs as Broadcast variables. * - * @param views The {@link PCollectionView}s. + * @param views The {@link PCollectionView}s. * @param context The {@link JavaSparkContext}. - * @param pviews The {@link SparkPCollectionView}. + * @param pviews The {@link SparkPCollectionView}. * @return a map of tagged {@link SideInputBroadcast}s and their {@link WindowingStrategy}. 
*/ - public static Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> - getSideInputs( - List<PCollectionView<?>> views, - JavaSparkContext context, - SparkPCollectionView pviews) { + public static Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> getSideInputs( + List<PCollectionView<?>> views, JavaSparkContext context, SparkPCollectionView pviews) { if (views == null) { return ImmutableMap.of(); } else { @@ -223,7 +274,8 @@ public final class TranslationUtils { for (PCollectionView<?> view : views) { SideInputBroadcast helper = pviews.getPCollectionView(view, context); WindowingStrategy<?, ?> windowingStrategy = view.getWindowingStrategyInternal(); - sideInputs.put(view.getTagInternal(), + sideInputs.put( + view.getTagInternal(), KV.<WindowingStrategy<?, ?>, SideInputBroadcast<?>>of(windowingStrategy, helper)); } return sideInputs; @@ -270,9 +322,96 @@ public final class TranslationUtils { public static <T> VoidFunction<T> emptyVoidFunction() { return new VoidFunction<T>() { - @Override public void call(T t) throws Exception { + @Override + public void call(T t) throws Exception { // Empty implementation. } }; } + + /** + * A utility method that adapts {@link PairFunction} to a {@link PairFlatMapFunction} with an + * {@link Iterator} input. This is particularly useful because it allows functions written + * for mapToPair to be reused with mapPartitionsToPair. + * + * @param pairFunction the {@link PairFunction} to adapt. + * @param <T> the input type. + * @param <K> the output key type. + * @param <V> the output value type. + * @return a {@link PairFlatMapFunction} that accepts an {@link Iterator} as an input and applies + * the {@link PairFunction} on every element.
+ */ + public static <T, K, V> PairFlatMapFunction<Iterator<T>, K, V> pairFunctionToPairFlatMapFunction( + final PairFunction<T, K, V> pairFunction) { + return new PairFlatMapFunction<Iterator<T>, K, V>() { + + @Override + public Iterable<Tuple2<K, V>> call(Iterator<T> itr) throws Exception { + final Iterator<Tuple2<K, V>> outputItr = + Iterators.transform( + itr, + new com.google.common.base.Function<T, Tuple2<K, V>>() { + + @Override + public Tuple2<K, V> apply(T t) { + try { + return pairFunction.call(t); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + }); + return new Iterable<Tuple2<K, V>>() { + + @Override + public Iterator<Tuple2<K, V>> iterator() { + return outputItr; + } + }; + } + }; + } + + /** + * A utility method that adapts {@link Function} to a {@link FlatMapFunction} with an {@link + * Iterator} input. This is particularly useful because it allows to use functions written for map + * functions in flatmap functions. + * + * @param func the {@link Function} to adapt. + * @param <InputT> the input type. + * @param <OutputT> the output type. + * @return a {@link FlatMapFunction} that accepts an {@link Iterator} as an input and applies the + * {@link Function} on every element. + */ + public static <InputT, OutputT> + FlatMapFunction<Iterator<InputT>, OutputT> functionToFlatMapFunction( + final Function<InputT, OutputT> func) { + return new FlatMapFunction<Iterator<InputT>, OutputT>() { + + @Override + public Iterable<OutputT> call(Iterator<InputT> itr) throws Exception { + final Iterator<OutputT> outputItr = + Iterators.transform( + itr, + new com.google.common.base.Function<InputT, OutputT>() { + + @Override + public OutputT apply(InputT t) { + try { + return func.call(t); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + }); + return new Iterable<OutputT>() { + + @Override + public Iterator<OutputT> iterator() { + return outputItr; + } + }; + } + }; + } }
