[BEAM-2359] Fix watermark broadcasting to executors in Spark runner
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/20820fa5 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/20820fa5 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/20820fa5 Branch: refs/heads/master Commit: 20820fa5477ffcdd4a9ef2e9340353ed3c5691a9 Parents: b3099bb Author: Aviem Zur <aviem...@gmail.com> Authored: Mon Jun 12 17:04:00 2017 +0300 Committer: Aviem Zur <aviem...@gmail.com> Committed: Thu Jun 22 14:51:02 2017 +0300 ---------------------------------------------------------------------- .../apache/beam/runners/spark/SparkRunner.java | 2 +- .../beam/runners/spark/TestSparkRunner.java | 2 +- .../SparkGroupAlsoByWindowViaWindowSet.java | 6 +- .../spark/stateful/SparkTimerInternals.java | 18 ++- .../spark/util/GlobalWatermarkHolder.java | 127 ++++++++++++++----- .../spark/GlobalWatermarkHolderTest.java | 18 +-- 6 files changed, 120 insertions(+), 53 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/20820fa5/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java index d008718..595521f 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java @@ -171,7 +171,7 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> { } // register Watermarks listener to broadcast the advanced WMs. 
- jssc.addStreamingListener(new JavaStreamingListenerWrapper(new WatermarksListener(jssc))); + jssc.addStreamingListener(new JavaStreamingListenerWrapper(new WatermarksListener())); // The reason we call initAccumulators here even though it is called in // SparkRunnerStreamingContextFactory is because the factory is not called when resuming http://git-wip-us.apache.org/repos/asf/beam/blob/20820fa5/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java index eccee57..a13a3b1 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java @@ -169,7 +169,7 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> { result.waitUntilFinish(Duration.millis(batchDurationMillis)); do { SparkTimerInternals sparkTimerInternals = - SparkTimerInternals.global(GlobalWatermarkHolder.get()); + SparkTimerInternals.global(GlobalWatermarkHolder.get(batchDurationMillis)); sparkTimerInternals.advanceWatermark(); globalWatermark = sparkTimerInternals.currentInputWatermarkTime(); // let another batch-interval period of execution, just to reason about WM propagation. 
http://git-wip-us.apache.org/repos/asf/beam/blob/20820fa5/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java index be4f3f6..1385e07 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java @@ -104,13 +104,15 @@ public class SparkGroupAlsoByWindowViaWindowSet { public static <K, InputT, W extends BoundedWindow> JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> groupAlsoByWindow( - JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> inputDStream, + final JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> inputDStream, final Coder<K> keyCoder, final Coder<WindowedValue<InputT>> wvCoder, final WindowingStrategy<?, W> windowingStrategy, final SparkRuntimeContext runtimeContext, final List<Integer> sourceIds) { + final long batchDurationMillis = + runtimeContext.getPipelineOptions().as(SparkPipelineOptions.class).getBatchIntervalMillis(); final IterableCoder<WindowedValue<InputT>> itrWvCoder = IterableCoder.of(wvCoder); final Coder<InputT> iCoder = ((FullWindowedValueCoder<InputT>) wvCoder).getValueCoder(); final Coder<? extends BoundedWindow> wCoder = @@ -239,7 +241,7 @@ public class SparkGroupAlsoByWindowViaWindowSet { SparkStateInternals<K> stateInternals; SparkTimerInternals timerInternals = SparkTimerInternals.forStreamFromSources( - sourceIds, GlobalWatermarkHolder.get()); + sourceIds, GlobalWatermarkHolder.get(batchDurationMillis)); // get state(internals) per key. 
if (prevStateAndTimersOpt.isEmpty()) { // no previous state. http://git-wip-us.apache.org/repos/asf/beam/blob/20820fa5/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java index 107915f..a68da55 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java @@ -34,7 +34,6 @@ import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.runners.spark.util.GlobalWatermarkHolder.SparkWatermarks; import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.spark.broadcast.Broadcast; import org.joda.time.Instant; @@ -58,10 +57,10 @@ public class SparkTimerInternals implements TimerInternals { /** Build the {@link TimerInternals} according to the feeding streams. */ public static SparkTimerInternals forStreamFromSources( List<Integer> sourceIds, - @Nullable Broadcast<Map<Integer, SparkWatermarks>> broadcast) { - // if broadcast is invalid for the specific ids, use defaults. - if (broadcast == null || broadcast.getValue().isEmpty() - || Collections.disjoint(sourceIds, broadcast.getValue().keySet())) { + Map<Integer, SparkWatermarks> watermarks) { + // if watermarks are invalid for the specific ids, use defaults. 
+ if (watermarks == null || watermarks.isEmpty() + || Collections.disjoint(sourceIds, watermarks.keySet())) { return new SparkTimerInternals( BoundedWindow.TIMESTAMP_MIN_VALUE, BoundedWindow.TIMESTAMP_MIN_VALUE, new Instant(0)); } @@ -71,7 +70,7 @@ public class SparkTimerInternals implements TimerInternals { // synchronized processing time should clearly be synchronized. Instant synchronizedProcessingTime = null; for (Integer sourceId: sourceIds) { - SparkWatermarks sparkWatermarks = broadcast.getValue().get(sourceId); + SparkWatermarks sparkWatermarks = watermarks.get(sourceId); if (sparkWatermarks != null) { // keep slowest WMs. slowestLowWatermark = slowestLowWatermark.isBefore(sparkWatermarks.getLowWatermark()) @@ -94,10 +93,9 @@ public class SparkTimerInternals implements TimerInternals { } /** Build a global {@link TimerInternals} for all feeding streams.*/ - public static SparkTimerInternals global( - @Nullable Broadcast<Map<Integer, SparkWatermarks>> broadcast) { - return broadcast == null ? forStreamFromSources(Collections.<Integer>emptyList(), null) - : forStreamFromSources(Lists.newArrayList(broadcast.getValue().keySet()), broadcast); + public static SparkTimerInternals global(Map<Integer, SparkWatermarks> watermarks) { + return watermarks == null ? 
forStreamFromSources(Collections.<Integer>emptyList(), null) + : forStreamFromSources(Lists.newArrayList(watermarks.keySet()), watermarks); } Collection<TimerData> getTimers() { http://git-wip-us.apache.org/repos/asf/beam/blob/20820fa5/runners/spark/src/main/java/org/apache/beam/runners/spark/util/GlobalWatermarkHolder.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/GlobalWatermarkHolder.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/GlobalWatermarkHolder.java index 8b384d8..2cb6f26 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/GlobalWatermarkHolder.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/GlobalWatermarkHolder.java @@ -21,31 +21,43 @@ package org.apache.beam.runners.spark.util; import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import com.google.common.collect.Maps; import java.io.Serializable; import java.util.HashMap; import java.util.Map; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nonnull; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.SparkEnv; import org.apache.spark.broadcast.Broadcast; -import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.BlockResult; +import org.apache.spark.storage.BlockStore; +import org.apache.spark.storage.StorageLevel; import 
org.apache.spark.streaming.api.java.JavaStreamingListener; import org.apache.spark.streaming.api.java.JavaStreamingListenerBatchCompleted; import org.joda.time.Instant; - +import scala.Option; /** - * A {@link Broadcast} variable to hold the global watermarks for a micro-batch. + * Holds the global watermarks for a micro-batch as a block in Spark's {@link BlockStore}. * * <p>For each source, holds a queue for the watermarks of each micro-batch that was read, * and advances the watermarks according to the queue (first-in-first-out). */ public class GlobalWatermarkHolder { - // the broadcast is broadcasted to the workers. - private static volatile Broadcast<Map<Integer, SparkWatermarks>> broadcast = null; - // this should only live in the driver so transient. - private static final transient Map<Integer, Queue<SparkWatermarks>> sourceTimes = new HashMap<>(); + private static final Map<Integer, Queue<SparkWatermarks>> sourceTimes = new HashMap<>(); + private static final BlockId WATERMARKS_BLOCK_ID = BlockId.apply("broadcast_0WATERMARKS"); + + private static volatile Map<Integer, SparkWatermarks> driverWatermarks = null; + private static volatile LoadingCache<String, Map<Integer, SparkWatermarks>> watermarkCache = null; public static void add(int sourceId, SparkWatermarks sparkWatermarks) { Queue<SparkWatermarks> timesQueue = sourceTimes.get(sourceId); @@ -71,22 +83,48 @@ public class GlobalWatermarkHolder { * Returns the {@link Broadcast} containing the {@link SparkWatermarks} mapped * to their sources. */ - public static Broadcast<Map<Integer, SparkWatermarks>> get() { - return broadcast; + @SuppressWarnings("unchecked") + public static Map<Integer, SparkWatermarks> get(Long cacheInterval) { + if (driverWatermarks != null) { + // values were set in this JVM by advance() (e.g. driver or local mode): use them directly.
+ return driverWatermarks; + } else { + if (watermarkCache == null) { + initWatermarkCache(cacheInterval); + } + try { + return watermarkCache.get("SINGLETON"); + } catch (ExecutionException e) { + throw new RuntimeException(e); + } + } + } + + private static synchronized void initWatermarkCache(Long batchDuration) { + if (watermarkCache == null) { + watermarkCache = + CacheBuilder.newBuilder() + // expire watermarks every half batch duration to ensure they update in every batch. + .expireAfterWrite(batchDuration / 2, TimeUnit.MILLISECONDS) + .build(new WatermarksLoader()); + } } /** * Advances the watermarks to the next-in-line watermarks. * SparkWatermarks are monotonically increasing. */ - public static void advance(JavaSparkContext jsc) { - synchronized (GlobalWatermarkHolder.class){ + @SuppressWarnings("unchecked") + public static void advance() { + synchronized (GlobalWatermarkHolder.class) { + BlockManager blockManager = SparkEnv.get().blockManager(); + if (sourceTimes.isEmpty()) { return; } // update all sources' watermarks into the new broadcast. 
- Map<Integer, SparkWatermarks> newBroadcast = new HashMap<>(); + Map<Integer, SparkWatermarks> newValues = new HashMap<>(); for (Map.Entry<Integer, Queue<SparkWatermarks>> en: sourceTimes.entrySet()) { if (en.getValue().isEmpty()) { @@ -99,8 +137,22 @@ public class GlobalWatermarkHolder { Instant currentLowWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; Instant currentHighWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; Instant currentSynchronizedProcessingTime = BoundedWindow.TIMESTAMP_MIN_VALUE; - if (broadcast != null && broadcast.getValue().containsKey(sourceId)) { - SparkWatermarks currentTimes = broadcast.getValue().get(sourceId); + + Option<BlockResult> currentOption = blockManager.getRemote(WATERMARKS_BLOCK_ID); + Map<Integer, SparkWatermarks> current; + if (currentOption.isDefined()) { + current = (Map<Integer, SparkWatermarks>) currentOption.get().data().next(); + } else { + current = Maps.newHashMap(); + blockManager.putSingle( + WATERMARKS_BLOCK_ID, + current, + StorageLevel.MEMORY_ONLY(), + true); + } + + if (current.containsKey(sourceId)) { + SparkWatermarks currentTimes = current.get(sourceId); currentLowWatermark = currentTimes.getLowWatermark(); currentHighWatermark = currentTimes.getHighWatermark(); currentSynchronizedProcessingTime = currentTimes.getSynchronizedProcessingTime(); @@ -119,20 +171,21 @@ public class GlobalWatermarkHolder { nextLowWatermark, nextHighWatermark)); checkState(nextSynchronizedProcessingTime.isAfter(currentSynchronizedProcessingTime), "Synchronized processing time must advance."); - newBroadcast.put( + newValues.put( sourceId, new SparkWatermarks( nextLowWatermark, nextHighWatermark, nextSynchronizedProcessingTime)); } // update the watermarks broadcast only if something has changed. - if (!newBroadcast.isEmpty()) { - if (broadcast != null) { - // for now this is blocking, we could make this asynchronous - // but it could slow down WM propagation. 
- broadcast.destroy(); - } - broadcast = jsc.broadcast(newBroadcast); + if (!newValues.isEmpty()) { + driverWatermarks = newValues; + blockManager.removeBlock(WATERMARKS_BLOCK_ID, true); + blockManager.putSingle( + WATERMARKS_BLOCK_ID, + newValues, + StorageLevel.MEMORY_ONLY(), + true); } } } @@ -140,7 +193,12 @@ public class GlobalWatermarkHolder { @VisibleForTesting public static synchronized void clear() { sourceTimes.clear(); - broadcast = null; + driverWatermarks = null; + SparkEnv sparkEnv = SparkEnv.get(); + if (sparkEnv != null) { + BlockManager blockManager = sparkEnv.blockManager(); + blockManager.removeBlock(WATERMARKS_BLOCK_ID, true); + } } /** @@ -185,15 +243,24 @@ public class GlobalWatermarkHolder { /** Advance the WMs onBatchCompleted event. */ public static class WatermarksListener extends JavaStreamingListener { - private final JavaStreamingContext jssc; - - public WatermarksListener(JavaStreamingContext jssc) { - this.jssc = jssc; + @Override + public void onBatchCompleted(JavaStreamingListenerBatchCompleted batchCompleted) { + GlobalWatermarkHolder.advance(); } + } + + private static class WatermarksLoader extends CacheLoader<String, Map<Integer, SparkWatermarks>> { + @SuppressWarnings("unchecked") @Override - public void onBatchCompleted(JavaStreamingListenerBatchCompleted batchCompleted) { - GlobalWatermarkHolder.advance(jssc.sparkContext()); + public Map<Integer, SparkWatermarks> load(@Nonnull String key) throws Exception { + Option<BlockResult> blockResultOption = + SparkEnv.get().blockManager().getRemote(WATERMARKS_BLOCK_ID); + if (blockResultOption.isDefined()) { + return (Map<Integer, SparkWatermarks>) blockResultOption.get().data().next(); + } else { + return Maps.newHashMap(); + } } } } http://git-wip-us.apache.org/repos/asf/beam/blob/20820fa5/runners/spark/src/test/java/org/apache/beam/runners/spark/GlobalWatermarkHolderTest.java ---------------------------------------------------------------------- diff --git 
a/runners/spark/src/test/java/org/apache/beam/runners/spark/GlobalWatermarkHolderTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/GlobalWatermarkHolderTest.java index 47a6e3f..1708123 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/GlobalWatermarkHolderTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/GlobalWatermarkHolderTest.java @@ -65,17 +65,17 @@ public class GlobalWatermarkHolderTest { instant.plus(Duration.millis(5)), instant.plus(Duration.millis(5)), instant)); - GlobalWatermarkHolder.advance(jsc); + GlobalWatermarkHolder.advance(); // low < high. GlobalWatermarkHolder.add(1, new SparkWatermarks( instant.plus(Duration.millis(10)), instant.plus(Duration.millis(15)), instant.plus(Duration.millis(100)))); - GlobalWatermarkHolder.advance(jsc); + GlobalWatermarkHolder.advance(); // assert watermarks in Broadcast. - SparkWatermarks currentWatermarks = GlobalWatermarkHolder.get().getValue().get(1); + SparkWatermarks currentWatermarks = GlobalWatermarkHolder.get(0L).get(1); assertThat(currentWatermarks.getLowWatermark(), equalTo(instant.plus(Duration.millis(10)))); assertThat(currentWatermarks.getHighWatermark(), equalTo(instant.plus(Duration.millis(15)))); assertThat(currentWatermarks.getSynchronizedProcessingTime(), @@ -93,7 +93,7 @@ public class GlobalWatermarkHolderTest { instant.plus(Duration.millis(25)), instant.plus(Duration.millis(20)), instant.plus(Duration.millis(200)))); - GlobalWatermarkHolder.advance(jsc); + GlobalWatermarkHolder.advance(); } @Test @@ -106,7 +106,7 @@ public class GlobalWatermarkHolderTest { instant.plus(Duration.millis(5)), instant.plus(Duration.millis(10)), instant)); - GlobalWatermarkHolder.advance(jsc); + GlobalWatermarkHolder.advance(); thrown.expect(IllegalStateException.class); thrown.expectMessage("Synchronized processing time must advance."); @@ -117,7 +117,7 @@ public class GlobalWatermarkHolderTest { instant.plus(Duration.millis(5)), 
instant.plus(Duration.millis(10)), instant)); - GlobalWatermarkHolder.advance(jsc); + GlobalWatermarkHolder.advance(); } @Test @@ -136,15 +136,15 @@ public class GlobalWatermarkHolderTest { instant.plus(Duration.millis(6)), instant)); - GlobalWatermarkHolder.advance(jsc); + GlobalWatermarkHolder.advance(); // assert watermarks for source 1. - SparkWatermarks watermarksForSource1 = GlobalWatermarkHolder.get().getValue().get(1); + SparkWatermarks watermarksForSource1 = GlobalWatermarkHolder.get(0L).get(1); assertThat(watermarksForSource1.getLowWatermark(), equalTo(instant.plus(Duration.millis(5)))); assertThat(watermarksForSource1.getHighWatermark(), equalTo(instant.plus(Duration.millis(10)))); // assert watermarks for source 2. - SparkWatermarks watermarksForSource2 = GlobalWatermarkHolder.get().getValue().get(2); + SparkWatermarks watermarksForSource2 = GlobalWatermarkHolder.get(0L).get(2); assertThat(watermarksForSource2.getLowWatermark(), equalTo(instant.plus(Duration.millis(3)))); assertThat(watermarksForSource2.getHighWatermark(), equalTo(instant.plus(Duration.millis(6)))); }