Repository: beam Updated Branches: refs/heads/master 7954896a5 -> efc701ed6
Make the test runner stop on the end-of-time (EOT) watermark or on timeout, whichever comes first. Remove the forced timeout from the test rule, since it is already a pipeline option. Advance the watermark to infinity at the end of test pipelines. Add EOT-watermark and expected-assertions test options. SparkPipelineResult should avoid returning null and handle exceptions better. Make ResumeFromCheckpointStreamingTest use TestSparkRunner and stop on the EOT watermark. Stop the context and update the state in a finally block. Address review comments: use a better name for the watermark that stops execution. Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/2e308463 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/2e308463 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/2e308463 Branch: refs/heads/master Commit: 2e3084631dbb56acd3971f12596953b63290c9f2 Parents: 7954896 Author: Sela <ans...@paypal.com> Authored: Sat Mar 4 21:04:02 2017 +0200 Committer: Amit Sela <amitsel...@gmail.com> Committed: Thu Mar 9 18:02:39 2017 +0200 ---------------------------------------------------------------------- .../beam/runners/spark/SparkPipelineResult.java | 52 ++++++++++++------ .../runners/spark/TestSparkPipelineOptions.java | 26 +++++++++ .../beam/runners/spark/TestSparkRunner.java | 53 +++++++++++++++++-- .../spark/stateful/SparkTimerInternals.java | 10 +++- .../apache/beam/runners/spark/PipelineRule.java | 33 ++++-------- .../runners/spark/SparkPipelineStateTest.java | 12 ++--- .../translation/streaming/CreateStreamTest.java | 11 ++-- .../ResumeFromCheckpointStreamingTest.java | 55 ++++++++++---------- 8 files changed, 172 insertions(+), 80 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/2e308463/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java ---------------------------------------------------------------------- diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java index ab59fb2..ddc1964 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java @@ -19,6 +19,7 @@ package org.apache.beam.runners.spark; import java.io.IOException; +import java.util.Objects; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; @@ -56,11 +57,11 @@ public abstract class SparkPipelineResult implements PipelineResult { state = State.RUNNING; } - private RuntimeException runtimeExceptionFrom(final Throwable e) { + private static RuntimeException runtimeExceptionFrom(final Throwable e) { return (e instanceof RuntimeException) ? (RuntimeException) e : new RuntimeException(e); } - private RuntimeException beamExceptionFrom(final Throwable e) { + private static RuntimeException beamExceptionFrom(final Throwable e) { // Scala doesn't declare checked exceptions in the bytecode, and the Java compiler // won't let you catch something that is not declared, so we can't catch // SparkException directly, instead we do an instanceof check. @@ -107,15 +108,15 @@ public abstract class SparkPipelineResult implements PipelineResult { try { state = awaitTermination(duration); } catch (final TimeoutException e) { - state = null; + // ignore. 
} catch (final ExecutionException e) { state = PipelineResult.State.FAILED; + stop(); throw beamExceptionFrom(e.getCause()); } catch (final Exception e) { state = PipelineResult.State.FAILED; - throw beamExceptionFrom(e); - } finally { stop(); + throw beamExceptionFrom(e); } return state; @@ -149,6 +150,9 @@ public abstract class SparkPipelineResult implements PipelineResult { @Override protected void stop() { SparkContextFactory.stopSparkContext(javaSparkContext); + if (Objects.equals(state, State.RUNNING)) { + state = State.STOPPED; + } } @Override @@ -178,19 +182,37 @@ public abstract class SparkPipelineResult implements PipelineResult { // after calling stop, if exception occurs in "grace period" it won't propagate. // calling the StreamingContext's waiter with 0 msec will throw any error that might have // been thrown during the "grace period". - javaStreamingContext.awaitTermination(0); - SparkContextFactory.stopSparkContext(javaSparkContext); + try { + javaStreamingContext.awaitTermination(0); + } catch (Exception e) { + throw beamExceptionFrom(e); + } finally { + SparkContextFactory.stopSparkContext(javaSparkContext); + if (Objects.equals(state, State.RUNNING)) { + state = State.STOPPED; + } + } } @Override - protected State awaitTermination(final Duration duration) throws TimeoutException, - ExecutionException, InterruptedException { - pipelineExecution.get(duration.getMillis(), TimeUnit.MILLISECONDS); - if (javaStreamingContext.awaitTerminationOrTimeout(duration.getMillis())) { - return State.DONE; - } else { - return null; - } + protected State awaitTermination(final Duration duration) throws ExecutionException, + InterruptedException { + pipelineExecution.get(); // execution is asynchronous anyway so no need to time-out. 
+ javaStreamingContext.awaitTerminationOrTimeout(duration.getMillis()); + + State terminationState = null; + switch (javaStreamingContext.getState()) { + case ACTIVE: + terminationState = State.RUNNING; + break; + case STOPPED: + terminationState = State.DONE; + break; + default: + state = State.UNKNOWN; + break; + } + return terminationState; } } http://git-wip-us.apache.org/repos/asf/beam/blob/2e308463/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java index d50b652..902b250 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java @@ -17,9 +17,13 @@ */ package org.apache.beam.runners.spark; +import javax.annotation.Nullable; import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.DefaultValueFactory; import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.testing.TestPipelineOptions; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; /** @@ -32,4 +36,26 @@ public interface TestSparkPipelineOptions extends SparkPipelineOptions, TestPipe boolean isForceStreaming(); void setForceStreaming(boolean forceStreaming); + @Description("A hard-coded expected number of assertions for this test pipeline.") + @Nullable + Integer getExpectedAssertions(); + void setExpectedAssertions(Integer expectedAssertions); + + @Description("A watermark (time in millis) that causes a pipeline that reads " + + "from an unbounded source to stop.") + @Default.InstanceFactory(DefaultStopPipelineWatermarkFactory.class) + Long getStopPipelineWatermark(); + void 
setStopPipelineWatermark(Long stopPipelineWatermark); + + /** + * A factory to provide the default watermark to stop a pipeline that reads + * from an unbounded source. + */ + class DefaultStopPipelineWatermarkFactory implements DefaultValueFactory<Long> { + @Override + public Long create(PipelineOptions options) { + return BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis(); + } + } + } http://git-wip-us.apache.org/repos/asf/beam/blob/2e308463/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java index 16ddc9e..d321f99 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java @@ -21,19 +21,24 @@ package org.apache.beam.runners.spark; import static com.google.common.base.Preconditions.checkNotNull; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isOneOf; import com.google.common.annotations.VisibleForTesting; +import com.google.common.util.concurrent.Uninterruptibles; import java.io.File; import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import org.apache.beam.runners.core.UnboundedReadFromBoundedSource; import org.apache.beam.runners.core.construction.PTransformMatchers; import org.apache.beam.runners.core.construction.ReplacementOutputs; import org.apache.beam.runners.spark.aggregators.AggregatorsAccumulator; import org.apache.beam.runners.spark.metrics.SparkMetricsContainer; +import org.apache.beam.runners.spark.stateful.SparkTimerInternals; import org.apache.beam.runners.spark.util.GlobalWatermarkHolder; import org.apache.beam.sdk.Pipeline; 
+import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.io.BoundedReadFromUnboundedSource; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsValidator; @@ -49,6 +54,7 @@ import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TaggedPValue; import org.apache.commons.io.FileUtils; import org.joda.time.Duration; +import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -118,16 +124,25 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> { if (isForceStreaming) { try { result = delegate.run(pipeline); - Long timeout = testSparkPipelineOptions.getTestTimeoutSeconds(); - result.waitUntilFinish(Duration.standardSeconds(checkNotNull(timeout))); + awaitWatermarksOrTimeout(testSparkPipelineOptions, result); + result.stop(); + PipelineResult.State finishState = result.getState(); + // assert finish state. + assertThat( + String.format("Finish state %s is not allowed.", finishState), + finishState, + isOneOf(PipelineResult.State.STOPPED, PipelineResult.State.DONE)); + // validate assertion succeeded (at least once). int successAssertions = result.getAggregatorValue(PAssert.SUCCESS_COUNTER, Integer.class); + Integer expectedAssertions = testSparkPipelineOptions.getExpectedAssertions() != null + ? testSparkPipelineOptions.getExpectedAssertions() : expectedNumberOfAssertions; assertThat( String.format( "Expected %d successful assertions, but found %d.", - expectedNumberOfAssertions, successAssertions), + expectedAssertions, successAssertions), successAssertions, - is(expectedNumberOfAssertions)); + is(expectedAssertions)); // validate assertion didn't fail. int failedAssertions = result.getAggregatorValue(PAssert.FAILURE_COUNTER, Integer.class); assertThat( @@ -152,6 +167,13 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> { // for batch test pipelines, run and block until done. 
result = delegate.run(pipeline); result.waitUntilFinish(); + result.stop(); + PipelineResult.State finishState = result.getState(); + // assert finish state. + assertThat( + String.format("Finish state %s is not allowed.", finishState), + finishState, + is(PipelineResult.State.DONE)); // assert via matchers. assertThat(result, testSparkPipelineOptions.getOnCreateMatcher()); assertThat(result, testSparkPipelineOptions.getOnSuccessMatcher()); @@ -159,6 +181,29 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> { return result; } + private static void awaitWatermarksOrTimeout( + TestSparkPipelineOptions testSparkPipelineOptions, SparkPipelineResult result) { + Long timeoutMillis = Duration.standardSeconds( + checkNotNull(testSparkPipelineOptions.getTestTimeoutSeconds())).getMillis(); + Long batchDurationMillis = testSparkPipelineOptions.getBatchIntervalMillis(); + Instant stopPipelineWatermark = + new Instant(testSparkPipelineOptions.getStopPipelineWatermark()); + // we poll for pipeline status in batch-intervals. while this is not in-sync with Spark's + // execution clock, this is good enough. + // we break on timeout or end-of-time WM, which ever comes first. + Instant globalWatermark; + result.waitUntilFinish(Duration.millis(batchDurationMillis)); + do { + SparkTimerInternals sparkTimerInternals = + SparkTimerInternals.global(GlobalWatermarkHolder.get()); + sparkTimerInternals.advanceWatermark(); + globalWatermark = sparkTimerInternals.currentInputWatermarkTime(); + // let another batch-interval period of execution, just to reason about WM propagation. 
+ Uninterruptibles.sleepUninterruptibly(batchDurationMillis, TimeUnit.MILLISECONDS); + } while ((timeoutMillis -= batchDurationMillis) > 0 + && globalWatermark.isBefore(stopPipelineWatermark)); + } + @VisibleForTesting void adaptBoundedReads(Pipeline pipeline) { pipeline.replace( http://git-wip-us.apache.org/repos/asf/beam/blob/2e308463/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java index 1949e1d..646e269 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java @@ -19,6 +19,7 @@ package org.apache.beam.runners.spark.stateful; import static com.google.common.base.Preconditions.checkArgument; +import com.google.common.collect.Lists; import com.google.common.collect.Sets; import java.util.Collection; import java.util.Collections; @@ -40,7 +41,7 @@ import org.joda.time.Instant; /** * An implementation of {@link TimerInternals} for the SparkRunner. */ -class SparkTimerInternals implements TimerInternals { +public class SparkTimerInternals implements TimerInternals { private final Instant highWatermark; private final Instant synchronizedProcessingTime; private final Set<TimerData> timers = Sets.newHashSet(); @@ -92,6 +93,13 @@ class SparkTimerInternals implements TimerInternals { slowestLowWatermark, slowestHighWatermark, synchronizedProcessingTime); } + /** Build a global {@link TimerInternals} for all feeding streams.*/ + public static SparkTimerInternals global( + @Nullable Broadcast<Map<Integer, SparkWatermarks>> broadcast) { + return broadcast == null ? 
forStreamFromSources(Collections.<Integer>emptyList(), null) + : forStreamFromSources(Lists.newArrayList(broadcast.getValue().keySet()), broadcast); + } + Collection<TimerData> getTimers() { return timers; } http://git-wip-us.apache.org/repos/asf/beam/blob/2e308463/runners/spark/src/test/java/org/apache/beam/runners/spark/PipelineRule.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/PipelineRule.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/PipelineRule.java index 77519cd..f8499f3 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/PipelineRule.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/PipelineRule.java @@ -33,34 +33,29 @@ import org.junit.runners.model.Statement; */ public class PipelineRule implements TestRule { - private final TestName testName = new TestName(); - private final SparkPipelineRule delegate; private final RuleChain chain; - private PipelineRule() { - this.delegate = new SparkPipelineRule(testName); - this.chain = RuleChain.outerRule(testName).around(this.delegate); - } - - private PipelineRule(Duration forcedTimeout) { - this.delegate = new SparkStreamingPipelineRule(forcedTimeout, testName); + private PipelineRule(SparkPipelineRule delegate) { + TestName testName = new TestName(); + this.delegate = delegate; + this.delegate.setTestName(testName); this.chain = RuleChain.outerRule(testName).around(this.delegate); } public static PipelineRule streaming() { - return new PipelineRule(Duration.standardSeconds(5)); + return new PipelineRule(new SparkStreamingPipelineRule()); } public static PipelineRule batch() { - return new PipelineRule(); + return new PipelineRule(new SparkPipelineRule()); } public Duration batchDuration() { return Duration.millis(delegate.options.getBatchIntervalMillis()); } - public SparkPipelineOptions getOptions() { + public TestSparkPipelineOptions getOptions() { return 
delegate.options; } @@ -76,19 +71,12 @@ public class PipelineRule implements TestRule { private static class SparkStreamingPipelineRule extends SparkPipelineRule { private final TemporaryFolder temporaryFolder = new TemporaryFolder(); - private final Duration forcedTimeout; - - SparkStreamingPipelineRule(Duration forcedTimeout, TestName testName) { - super(testName); - this.forcedTimeout = forcedTimeout; - } @Override protected void before() throws Throwable { super.before(); temporaryFolder.create(); options.setForceStreaming(true); - options.setTestTimeoutSeconds(forcedTimeout.getStandardSeconds()); options.setCheckpointDir( temporaryFolder.newFolder(options.getJobName()).toURI().toURL().toString()); } @@ -104,9 +92,9 @@ public class PipelineRule implements TestRule { protected final TestSparkPipelineOptions options = PipelineOptionsFactory.as(TestSparkPipelineOptions.class); - private final TestName testName; + private TestName testName; - private SparkPipelineRule(TestName testName) { + public void setTestName(TestName testName) { this.testName = testName; } @@ -114,7 +102,8 @@ public class PipelineRule implements TestRule { protected void before() throws Throwable { options.setRunner(TestSparkRunner.class); options.setEnableSparkMetricSinks(false); - options.setJobName(testName.getMethodName()); + options.setJobName( + testName != null ? 
testName.getMethodName() : "test-at-" + System.currentTimeMillis()); } } } http://git-wip-us.apache.org/repos/asf/beam/blob/2e308463/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPipelineStateTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPipelineStateTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPipelineStateTest.java index 37a201c..3a68d6f 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPipelineStateTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPipelineStateTest.java @@ -20,7 +20,6 @@ package org.apache.beam.runners.spark; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.nullValue; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; @@ -141,7 +140,7 @@ public class SparkPipelineStateTest implements Serializable { result.waitUntilFinish(Duration.millis(1)); - assertThat(result.getState(), nullValue()); + assertThat(result.getState(), is(PipelineResult.State.RUNNING)); result.cancel(); } @@ -188,11 +187,10 @@ public class SparkPipelineStateTest implements Serializable { testCanceledPipeline(getBatchOptions()); } - //TODO: fix this! 
-// @Test -// public void testStreamingPipelineFailedState() throws Exception { -// testFailedPipeline(getStreamingOptions()); -// } + @Test + public void testStreamingPipelineFailedState() throws Exception { + testFailedPipeline(getStreamingOptions()); + } @Test public void testBatchPipelineFailedState() throws Exception { http://git-wip-us.apache.org/repos/asf/beam/blob/2e308463/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java index b32f5f3..75abc8b 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java @@ -313,7 +313,9 @@ public class CreateStreamTest implements Serializable { .advanceWatermarkForNextBatch(instant.plus(Duration.standardMinutes(2))) .nextBatch( TimestampedValue.of(5, instant)) - .advanceWatermarkForNextBatch(instant.plus(Duration.standardMinutes(5))); + .advanceWatermarkForNextBatch(instant.plus(Duration.standardMinutes(5))) + .emptyBatch() + .advanceNextBatchWatermarkToInfinity(); PCollection<Integer> windowed1 = p .apply(source1) @@ -346,7 +348,8 @@ public class CreateStreamTest implements Serializable { public void testElementAtPositiveInfinityThrows() { CreateStream<Integer> source = CreateStream.of(VarIntCoder.of(), pipelineRule.batchDuration()) - .nextBatch(TimestampedValue.of(-1, BoundedWindow.TIMESTAMP_MAX_VALUE.minus(1L))); + .nextBatch(TimestampedValue.of(-1, BoundedWindow.TIMESTAMP_MAX_VALUE.minus(1L))) + .advanceNextBatchWatermarkToInfinity(); thrown.expect(IllegalArgumentException.class); source.nextBatch(TimestampedValue.of(1, 
BoundedWindow.TIMESTAMP_MAX_VALUE)); } @@ -357,7 +360,9 @@ public class CreateStreamTest implements Serializable { CreateStream.of(VarIntCoder.of(), pipelineRule.batchDuration()) .advanceWatermarkForNextBatch(new Instant(0L)); thrown.expect(IllegalArgumentException.class); - source.advanceWatermarkForNextBatch(new Instant(-1L)); + source + .advanceWatermarkForNextBatch(new Instant(-1L)) + .advanceNextBatchWatermarkToInfinity(); } @Test http://git-wip-us.apache.org/repos/asf/beam/blob/2e308463/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java index 7706777..bc22980 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java @@ -24,6 +24,7 @@ import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertThat; +import com.google.common.base.Optional; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.Uninterruptibles; @@ -33,10 +34,10 @@ import java.util.List; import java.util.Map; import java.util.Properties; import java.util.concurrent.TimeUnit; +import org.apache.beam.runners.spark.PipelineRule; import org.apache.beam.runners.spark.ReuseSparkContextRule; -import org.apache.beam.runners.spark.SparkPipelineOptions; import org.apache.beam.runners.spark.SparkPipelineResult; -import org.apache.beam.runners.spark.SparkRunner; +import 
org.apache.beam.runners.spark.TestSparkPipelineOptions; import org.apache.beam.runners.spark.aggregators.AggregatorsAccumulator; import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.runners.spark.metrics.SparkMetricsContainer; @@ -50,7 +51,6 @@ import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.MetricNameFilter; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.metrics.MetricsFilter; -import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.transforms.Aggregator; import org.apache.beam.sdk.transforms.Create; @@ -82,8 +82,6 @@ import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.junit.rules.TestName; /** @@ -92,7 +90,8 @@ import org.junit.rules.TestName; * <p>Runs the pipeline reading from a Kafka backlog with a WM function that will move to infinity * on a EOF signal. * After resuming from checkpoint, a single output (guaranteed by the WM) is asserted, along with - * {@link Aggregator}s values that are expected to resume from previous count as well. + * {@link Aggregator}s and {@link Metrics} values that are expected to resume from previous count + * and a side-input that is expected to recover as well. 
*/ public class ResumeFromCheckpointStreamingTest { private static final EmbeddedKafkaCluster.EmbeddedZookeeper EMBEDDED_ZOOKEEPER = @@ -102,11 +101,9 @@ public class ResumeFromCheckpointStreamingTest { private static final String TOPIC = "kafka_beam_test_topic"; @Rule - public TemporaryFolder tmpFolder = new TemporaryFolder(); + public final transient ReuseSparkContextRule noContextReuse = ReuseSparkContextRule.no(); @Rule - public ReuseSparkContextRule noContextResue = ReuseSparkContextRule.no(); - @Rule - public transient TestName testName = new TestName(); + public final transient PipelineRule pipelineRule = PipelineRule.streaming(); @BeforeClass public static void init() throws IOException { @@ -144,13 +141,6 @@ public class ResumeFromCheckpointStreamingTest { @Test public void testWithResume() throws Exception { - SparkPipelineOptions options = PipelineOptionsFactory.create().as(SparkPipelineOptions.class); - options.setRunner(SparkRunner.class); - options.setCheckpointDir(tmpFolder.newFolder().toString()); - options.setCheckpointDurationMillis(500L); - options.setJobName(testName.getMethodName()); - options.setSparkMaster("local[*]"); - // write to Kafka produce(ImmutableMap.of( "k1", new Instant(100), @@ -164,9 +154,8 @@ public class ResumeFromCheckpointStreamingTest { .addNameFilter(MetricNameFilter.inNamespace(ResumeFromCheckpointStreamingTest.class)) .build(); - // first run will read from Kafka backlog - "auto.offset.reset=smallest" - SparkPipelineResult res = run(options); - res.waitUntilFinish(Duration.standardSeconds(5)); + // first run should expect EOT matching the last injected element. + SparkPipelineResult res = run(pipelineRule, Optional.of(new Instant(400)), 0); // assertions 1: long processedMessages1 = res.getAggregatorValue("processedMessages", Long.class); assertThat( @@ -192,8 +181,7 @@ public class ResumeFromCheckpointStreamingTest { )); // recovery should resume from last read offset, and read the second batch of input. 
- res = runAgain(options); - res.waitUntilFinish(Duration.standardSeconds(5)); + res = runAgain(pipelineRule, 1); // assertions 2: long processedMessages2 = res.getAggregatorValue("processedMessages", Long.class); assertThat( @@ -219,13 +207,15 @@ public class ResumeFromCheckpointStreamingTest { } - private SparkPipelineResult runAgain(SparkPipelineOptions options) { + private SparkPipelineResult runAgain(PipelineRule pipelineRule, int expectedAssertions) { // sleep before next run. Uninterruptibles.sleepUninterruptibly(10, TimeUnit.MILLISECONDS); - return run(options); + return run(pipelineRule, Optional.<Instant>absent(), expectedAssertions); } - private static SparkPipelineResult run(SparkPipelineOptions options) { + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + private static SparkPipelineResult run( + PipelineRule pipelineRule, Optional<Instant> stopWatermarkOption, int expectedAssertions) { KafkaIO.Read<String, Instant> read = KafkaIO.<String, Instant>read() .withBootstrapServers(EMBEDDED_KAFKA_CLUSTER.getBrokerList()) .withTopics(Collections.singletonList(TOPIC)) @@ -247,14 +237,23 @@ public class ResumeFromCheckpointStreamingTest { } }); - Pipeline p = Pipeline.create(options); + TestSparkPipelineOptions options = pipelineRule.getOptions(); + options.setSparkMaster("local[*]"); + options.setCheckpointDurationMillis(options.getBatchIntervalMillis()); + options.setExpectedAssertions(expectedAssertions); + // timeout is per execution so it can be injected by the caller. 
+ if (stopWatermarkOption.isPresent()) { + options.setStopPipelineWatermark(stopWatermarkOption.get().getMillis()); + } + Pipeline p = pipelineRule.createPipeline(); PCollection<String> expectedCol = p.apply(Create.of(ImmutableList.of("side1", "side2")).withCoder(StringUtf8Coder.of())); PCollectionView<List<String>> view = expectedCol.apply(View.<String>asList()); - PCollection<Iterable<String>> grouped = p - .apply(read.withoutMetadata()) + PCollection<KV<String, Instant>> kafkaStream = p.apply(read.withoutMetadata()); + + PCollection<Iterable<String>> grouped = kafkaStream .apply(Keys.<String>create()) .apply("EOFShallNotPassFn", ParDo.of(new EOFShallNotPassFn(view)).withSideInputs(view)) .apply(Window.<String>into(FixedWindows.of(Duration.millis(500)))