Serialize state stream with coders for shuffle and checkpointing.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/da5f8497 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/da5f8497 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/da5f8497 Branch: refs/heads/master Commit: da5f84970a30aa5afbbe3cf93e3ed28b9df3260d Parents: b21de69 Author: Sela <[email protected]> Authored: Mon Feb 20 00:13:23 2017 +0200 Committer: Sela <[email protected]> Committed: Wed Mar 1 00:18:01 2017 +0200 ---------------------------------------------------------------------- .../beam/runners/spark/coders/CoderHelpers.java | 23 ++++ .../SparkGroupAlsoByWindowViaWindowSet.java | 135 ++++++++++++------- .../spark/stateful/SparkTimerInternals.java | 17 ++- .../streaming/StreamingTransformTranslator.java | 3 +- 4 files changed, 124 insertions(+), 54 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/da5f8497/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/CoderHelpers.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/CoderHelpers.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/CoderHelpers.java index 0df66c2..9c46ecf 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/CoderHelpers.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/CoderHelpers.java @@ -18,16 +18,21 @@ package org.apache.beam.runners.spark.coders; +import static com.google.common.base.Preconditions.checkNotNull; + import com.google.common.collect.Iterables; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.util.Collection; import java.util.LinkedList; import java.util.List; +import javax.annotation.Nonnull; import org.apache.beam.runners.spark.util.ByteArray; import 
org.apache.beam.sdk.coders.Coder; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.PairFunction; + import scala.Tuple2; /** @@ -89,6 +94,24 @@ public final class CoderHelpers { } /** + * Utility method for deserializing an Iterable of byte arrays using the specified coder. + * + * @param serialized bytearrays to be deserialized. + * @param coder Coder to deserialize with. + * @param <T> Type of object to be returned. + * @return Iterable of deserialized objects. + */ + public static <T> Iterable<T> fromByteArrays( + Collection<byte[]> serialized, final Coder<T> coder) { + return Iterables.transform(serialized, new com.google.common.base.Function<byte[], T>() { + @Override + public T apply(@Nonnull byte[] bytes) { + return fromByteArray(checkNotNull(bytes, "Cannot decode null values."), coder); + } + }); + } + + /** * A function wrapper for converting an object to a bytearray. * * @param coder Coder to serialize with. http://git-wip-us.apache.org/repos/asf/beam/blob/da5f8497/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java index 2fb4100..5589d82 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java @@ -27,23 +27,28 @@ import org.apache.beam.runners.core.GroupAlsoByWindowsDoFn; import org.apache.beam.runners.core.OutputWindowedValue; import org.apache.beam.runners.core.ReduceFnRunner; import org.apache.beam.runners.core.SystemReduceFn; -import org.apache.beam.runners.core.TimerInternals.TimerData; 
+import org.apache.beam.runners.core.TimerInternals; import org.apache.beam.runners.core.triggers.ExecutableTriggerStateMachine; import org.apache.beam.runners.core.triggers.TriggerStateMachines; import org.apache.beam.runners.spark.SparkPipelineOptions; +import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.runners.spark.translation.SparkRuntimeContext; import org.apache.beam.runners.spark.translation.TranslationUtils; import org.apache.beam.runners.spark.translation.WindowingHelpers; +import org.apache.beam.runners.spark.util.ByteArray; import org.apache.beam.runners.spark.util.GlobalWatermarkHolder; import org.apache.beam.runners.spark.util.LateDataUtils; import org.apache.beam.runners.spark.util.UnsupportedSideInputReader; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.IterableCoder; +import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.transforms.Aggregator; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.Sum; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder; import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.TupleTag; @@ -97,45 +102,67 @@ public class SparkGroupAlsoByWindowViaWindowSet { public static <K, InputT, W extends BoundedWindow> JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> groupAlsoByWindow( JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> inputDStream, - final Coder<InputT> iCoder, + final Coder<K> keyCoder, + final Coder<WindowedValue<InputT>> wvCoder, final WindowingStrategy<?, W> windowingStrategy, final SparkRuntimeContext runtimeContext, final List<Integer> sourceIds) { + final IterableCoder<WindowedValue<InputT>> itrWvCoder = IterableCoder.of(wvCoder); + 
final Coder<InputT> iCoder = ((FullWindowedValueCoder<InputT>) wvCoder).getValueCoder(); + final Coder<? extends BoundedWindow> wCoder = + ((FullWindowedValueCoder<InputT>) wvCoder).getWindowCoder(); + final Coder<WindowedValue<KV<K, Iterable<InputT>>>> wvKvIterCoder = + FullWindowedValueCoder.of(KvCoder.of(keyCoder, IterableCoder.of(iCoder)), wCoder); + final TimerInternals.TimerDataCoder timerDataCoder = + TimerInternals.TimerDataCoder.of(windowingStrategy.getWindowFn().windowCoder()); + long checkpointDurationMillis = runtimeContext.getPipelineOptions().as(SparkPipelineOptions.class) .getCheckpointDurationMillis(); // we have to switch to Scala API to avoid Optional in the Java API, see: SPARK-4819. // we also have a broader API for Scala (access to the actual key and entire iterator). - DStream<Tuple2<K, Iterable<WindowedValue<InputT>>>> pairDStream = + // we use coders to convert objects in the PCollection to byte arrays, so they + // can be transferred over the network for the shuffle and be in serialized form + // for checkpointing. + // for readability, we add comments with actual type next to byte[]. + // to shorten line length, we use: + //---- WV: WindowedValue + //---- Iterable: Itr + //---- AccumT: A + //---- InputT: I + DStream<Tuple2</*K*/ ByteArray, /*Itr<WV<I>>*/ byte[]>> pairDStream = inputDStream .map(WindowingHelpers.<KV<K, Iterable<WindowedValue<InputT>>>>unwindowFunction()) .mapToPair(TranslationUtils.<K, Iterable<WindowedValue<InputT>>>toPairFunction()) + // move to bytes and use coders for deserialization because there's a shuffle + // and checkpointing involved. 
+ .mapToPair(CoderHelpers.toByteFunction(keyCoder, itrWvCoder)) .dstream(); - PairDStreamFunctions<K, Iterable<WindowedValue<InputT>>> pairDStreamFunctions = + PairDStreamFunctions<ByteArray, byte[]> pairDStreamFunctions = DStream.toPairDStreamFunctions( pairDStream, - JavaSparkContext$.MODULE$.<K>fakeClassTag(), - JavaSparkContext$.MODULE$.<Iterable<WindowedValue<InputT>>>fakeClassTag(), + JavaSparkContext$.MODULE$.<ByteArray>fakeClassTag(), + JavaSparkContext$.MODULE$.<byte[]>fakeClassTag(), null); int defaultNumPartitions = pairDStreamFunctions.defaultPartitioner$default$1(); Partitioner partitioner = pairDStreamFunctions.defaultPartitioner(defaultNumPartitions); // use updateStateByKey to scan through the state and update elements and timers. - DStream<Tuple2<K, Tuple2<StateAndTimers, List<WindowedValue<KV<K, Iterable<InputT>>>>>>> + DStream<Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>> firedStream = pairDStreamFunctions.updateStateByKey( new SerializableFunction1< - scala.collection.Iterator<Tuple3<K, Seq<Iterable<WindowedValue<InputT>>>, - Option<Tuple2<StateAndTimers, List<WindowedValue<KV<K, Iterable<InputT>>>>>>>>, - scala.collection.Iterator<Tuple2<K, Tuple2<StateAndTimers, - List<WindowedValue<KV<K, Iterable<InputT>>>>>>>>() { + scala.collection.Iterator<Tuple3</*K*/ ByteArray, Seq</*Itr<WV<I>>*/ byte[]>, + Option<Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>>>, + scala.collection.Iterator<Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, + /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>>>() { @Override - public scala.collection.Iterator<Tuple2<K, Tuple2<StateAndTimers, - List<WindowedValue<KV<K, Iterable<InputT>>>>>>> apply( - final scala.collection.Iterator<Tuple3<K, Seq<Iterable<WindowedValue<InputT>>>, - Option<Tuple2<StateAndTimers, List<WindowedValue<KV<K, Iterable<InputT>>>>>>>> iter) { + public scala.collection.Iterator<Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, + /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>> 
apply( + final scala.collection.Iterator<Tuple3</*K*/ ByteArray, Seq</*Itr<WV<I>>*/ byte[]>, + Option<Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>>> iter) { //--- ACTUAL STATEFUL OPERATION: // // Input Iterator: the partition (~bundle) of a cogrouping of the input @@ -149,7 +176,8 @@ public class SparkGroupAlsoByWindowViaWindowSet { // (3) Seq.nonEmpty && Option<S>.isDefined: new data with previous state. final SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn = - SystemReduceFn.buffering(iCoder); + SystemReduceFn.buffering( + ((FullWindowedValueCoder<InputT>) wvCoder).getValueCoder()); final OutputWindowedValueHolder<K, InputT> outputHolder = new OutputWindowedValueHolder<>(); // use in memory Aggregators since Spark Accumulators are not resilient @@ -160,25 +188,25 @@ public class SparkGroupAlsoByWindowViaWindowSet { GroupAlsoByWindowsDoFn.DROPPED_DUE_TO_LATENESS_COUNTER); AbstractIterator< - Tuple2<K, Tuple2<StateAndTimers, List<WindowedValue<KV<K, Iterable<InputT>>>>>>> - outIter = new AbstractIterator<Tuple2<K, - Tuple2<StateAndTimers, List<WindowedValue<KV<K, Iterable<InputT>>>>>>>() { + Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>> + outIter = new AbstractIterator<Tuple2</*K*/ ByteArray, + Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>>() { @Override - protected Tuple2<K, Tuple2<StateAndTimers, - List<WindowedValue<KV<K, Iterable<InputT>>>>>> computeNext() { + protected Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, + /*WV<KV<K, Itr<I>>>*/ List<byte[]>>> computeNext() { // input iterator is a Spark partition (~bundle), containing keys and their // (possibly) previous-state and (possibly) new data. 
while (iter.hasNext()) { // for each element in the partition: - Tuple3<K, Seq<Iterable<WindowedValue<InputT>>>, Option<Tuple2<StateAndTimers, - List<WindowedValue<KV<K, Iterable<InputT>>>>>>> next = iter.next(); - K key = next._1(); + Tuple3<ByteArray, Seq<byte[]>, + Option<Tuple2<StateAndTimers, List<byte[]>>>> next = iter.next(); + ByteArray encodedKey = next._1(); + K key = CoderHelpers.fromByteArray(encodedKey.getValue(), keyCoder); - Seq<Iterable<WindowedValue<InputT>>> seq = next._2(); + Seq<byte[]> seq = next._2(); Option<Tuple2<StateAndTimers, - List<WindowedValue<KV<K, Iterable<InputT>>>>>> - prevStateAndTimersOpt = next._3(); + List<byte[]>>> prevStateAndTimersOpt = next._3(); SparkStateInternals<K> stateInternals; SparkTimerInternals timerInternals = SparkTimerInternals.forStreamFromSources( @@ -192,7 +220,9 @@ public class SparkGroupAlsoByWindowViaWindowSet { StateAndTimers prevStateAndTimers = prevStateAndTimersOpt.get()._1(); stateInternals = SparkStateInternals.forKeyAndState(key, prevStateAndTimers.getState()); - timerInternals.addTimers(prevStateAndTimers.getTimers()); + Collection<byte[]> serTimers = prevStateAndTimers.getTimers(); + timerInternals.addTimers( + SparkTimerInternals.deserializeTimers(serTimers, timerDataCoder)); } ReduceFnRunner<K, InputT, Iterable<InputT>, W> reduceFnRunner = @@ -214,7 +244,8 @@ public class SparkGroupAlsoByWindowViaWindowSet { if (!seq.isEmpty()) { // new input for key. 
try { - Iterable<WindowedValue<InputT>> elementsIterable = seq.head(); + Iterable<WindowedValue<InputT>> elementsIterable = + CoderHelpers.fromByteArray(seq.head(), itrWvCoder); Iterable<WindowedValue<InputT>> validElements = LateDataUtils .dropExpiredWindows( @@ -247,9 +278,11 @@ public class SparkGroupAlsoByWindowViaWindowSet { List<WindowedValue<KV<K, Iterable<InputT>>>> outputs = outputHolder.get(); if (!outputs.isEmpty() || !stateInternals.getState().isEmpty()) { StateAndTimers updated = new StateAndTimers(stateInternals.getState(), - timerInternals.getTimers()); + SparkTimerInternals.serializeTimers( + timerInternals.getTimers(), timerDataCoder)); // persist Spark's state by outputting. - return new Tuple2<>(key, new Tuple2<>(updated, outputs)); + List<byte[]> serOutput = CoderHelpers.toByteArrays(outputs, wvKvIterCoder); + return new Tuple2<>(encodedKey, new Tuple2<>(updated, serOutput)); } // an empty state with no output, can be evicted completely - do nothing. } @@ -271,43 +304,43 @@ public class SparkGroupAlsoByWindowViaWindowSet { return scala.collection.JavaConversions.asScalaIterator(outIter); } - }, partitioner, true, JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, - List<WindowedValue<KV<K, Iterable<InputT>>>>>>fakeClassTag()); + }, partitioner, true, + JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, List<byte[]>>>fakeClassTag()); if (checkpointDurationMillis > 0) { firedStream.checkpoint(new Duration(checkpointDurationMillis)); } // go back to Java now. 
- JavaPairDStream<K, Tuple2<StateAndTimers, List<WindowedValue<KV<K, Iterable<InputT>>>>>> + JavaPairDStream</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>> javaFiredStream = JavaPairDStream.fromPairDStream( firedStream, - JavaSparkContext$.MODULE$.<K>fakeClassTag(), - JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, - List<WindowedValue<KV<K, Iterable<InputT>>>>>>fakeClassTag()); + JavaSparkContext$.MODULE$.<ByteArray>fakeClassTag(), + JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, List<byte[]>>>fakeClassTag()); // filter state-only output (nothing to fire) and remove the state from the output. return javaFiredStream.filter( - new Function<Tuple2<K, Tuple2<StateAndTimers, - List<WindowedValue<KV<K, Iterable<InputT>>>>>>, Boolean>() { + new Function<Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, + /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>, Boolean>() { @Override public Boolean call( - Tuple2<K, Tuple2<StateAndTimers, - List<WindowedValue<KV<K, Iterable<InputT>>>>>> t2) throws Exception { + Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, + /*WV<KV<K, Itr<I>>>*/ List<byte[]>>> t2) throws Exception { // filter output if defined. return !t2._2()._2().isEmpty(); } }) .flatMap( - new FlatMapFunction<Tuple2<K, Tuple2<StateAndTimers, - List<WindowedValue<KV<K, Iterable<InputT>>>>>>, + new FlatMapFunction<Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, + /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>, WindowedValue<KV<K, Iterable<InputT>>>>() { @Override public Iterable<WindowedValue<KV<K, Iterable<InputT>>>> call( - Tuple2<K, Tuple2<StateAndTimers, - List<WindowedValue<KV<K, Iterable<InputT>>>>>> t2) throws Exception { + Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, + /*WV<KV<K, Itr<I>>>*/ List<byte[]>>> t2) throws Exception { // drop the state since it is already persisted at this point. - return t2._2()._2(); + // return in serialized form. 
+ return CoderHelpers.fromByteArrays(t2._2()._2(), wvKvIterCoder); } }); } @@ -315,20 +348,20 @@ public class SparkGroupAlsoByWindowViaWindowSet { private static class StateAndTimers { //Serializable state for internals (namespace to state tag to coded value). private final Table<String, String, byte[]> state; - private final Collection<TimerData> timers; + private final Collection<byte[]> serTimers; private StateAndTimers( - Table<String, String, byte[]> state, Collection<TimerData> timers) { + Table<String, String, byte[]> state, Collection<byte[]> timers) { this.state = state; - this.timers = timers; + this.serTimers = timers; } public Table<String, String, byte[]> getState() { return state; } - public Collection<TimerData> getTimers() { - return timers; + public Collection<byte[]> getTimers() { + return serTimers; } } http://git-wip-us.apache.org/repos/asf/beam/blob/da5f8497/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java index b9783ef..1949e1d 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java @@ -29,6 +29,7 @@ import java.util.Set; import javax.annotation.Nullable; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.TimerInternals; +import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.runners.spark.util.GlobalWatermarkHolder.SparkWatermarks; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.TimeDomain; @@ -109,8 +110,10 @@ class SparkTimerInternals implements TimerInternals { return toFire; } - 
void addTimers(Collection<TimerData> timers) { - this.timers.addAll(timers); + void addTimers(Iterable<TimerData> timers) { + for (TimerData timer: timers) { + this.timers.add(timer); + } } @Override @@ -169,4 +172,14 @@ class SparkTimerInternals implements TimerInternals { throw new UnsupportedOperationException("Deleting a timer by ID is not yet supported."); } + public static Collection<byte[]> serializeTimers( + Collection<TimerData> timers, TimerDataCoder timerDataCoder) { + return CoderHelpers.toByteArrays(timers, timerDataCoder); + } + + public static Iterable<TimerData> deserializeTimers( + Collection<byte[]> serTimers, TimerDataCoder timerDataCoder) { + return CoderHelpers.fromByteArrays(serTimers, timerDataCoder); + } + } http://git-wip-us.apache.org/repos/asf/beam/blob/da5f8497/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java index 7abf5be..a98eff2 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java @@ -258,7 +258,8 @@ final class StreamingTransformTranslator { JavaDStream<WindowedValue<KV<K, Iterable<V>>>> outStream = SparkGroupAlsoByWindowViaWindowSet.groupAlsoByWindow( groupedByKeyStream, - coder.getValueCoder(), + coder.getKeyCoder(), + wvCoder, windowingStrategy, runtimeContext, streamSources);
