Repository: beam
Updated Branches:
  refs/heads/master 7954896a5 -> efc701ed6


Test runner to stop on EOT watermark, or timeout.

Remove timeout since it is already a pipeline option.

Advance to infinity at the end of pipelines.

Add EOT watermark and expected assertions test options.

SparkPipelineResult should avoid returning null, and handle exceptions better.

Make ResumeFromCheckpointStreamingTest use TestSparkRunner and stop on EOT watermark.

Stop the context and update the state in finally.

Addressed comments - better name for a watermark that stops execution.
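
For illustration, a streaming test that should stop on the end-of-time (EOT)
watermark could configure the new options roughly as follows. This is a minimal
sketch: the option names and the runner come from this commit, while the
surrounding test setup is hypothetical.

    // Sketch only - not part of this commit's diff.
    TestSparkPipelineOptions options =
        PipelineOptionsFactory.as(TestSparkPipelineOptions.class);
    options.setRunner(TestSparkRunner.class);
    options.setForceStreaming(true);
    // Stop once the global watermark reaches end-of-time.
    options.setStopPipelineWatermark(BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis());
    // Expect exactly one successful PAssert; when unset, the runner falls back
    // to the number of assertions it finds in the pipeline.
    options.setExpectedAssertions(1);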


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/2e308463
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/2e308463
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/2e308463

Branch: refs/heads/master
Commit: 2e3084631dbb56acd3971f12596953b63290c9f2
Parents: 7954896
Author: Sela <ans...@paypal.com>
Authored: Sat Mar 4 21:04:02 2017 +0200
Committer: Amit Sela <amitsel...@gmail.com>
Committed: Thu Mar 9 18:02:39 2017 +0200

----------------------------------------------------------------------
 .../beam/runners/spark/SparkPipelineResult.java | 52 ++++++++++++------
 .../runners/spark/TestSparkPipelineOptions.java | 26 +++++++++
 .../beam/runners/spark/TestSparkRunner.java     | 53 +++++++++++++++++--
 .../spark/stateful/SparkTimerInternals.java     | 10 +++-
 .../apache/beam/runners/spark/PipelineRule.java | 33 ++++--------
 .../runners/spark/SparkPipelineStateTest.java   | 12 ++---
 .../translation/streaming/CreateStreamTest.java | 11 ++--
 .../ResumeFromCheckpointStreamingTest.java      | 55 ++++++++++----------
 8 files changed, 172 insertions(+), 80 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/2e308463/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java
index ab59fb2..ddc1964 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java
@@ -19,6 +19,7 @@
 package org.apache.beam.runners.spark;
 
 import java.io.IOException;
+import java.util.Objects;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
@@ -56,11 +57,11 @@ public abstract class SparkPipelineResult implements PipelineResult {
     state = State.RUNNING;
   }
 
-  private RuntimeException runtimeExceptionFrom(final Throwable e) {
+  private static RuntimeException runtimeExceptionFrom(final Throwable e) {
     return (e instanceof RuntimeException) ? (RuntimeException) e : new RuntimeException(e);
   }
 
-  private RuntimeException beamExceptionFrom(final Throwable e) {
+  private static RuntimeException beamExceptionFrom(final Throwable e) {
     // Scala doesn't declare checked exceptions in the bytecode, and the Java compiler
     // won't let you catch something that is not declared, so we can't catch
     // SparkException directly, instead we do an instanceof check.
@@ -107,15 +108,15 @@ public abstract class SparkPipelineResult implements PipelineResult {
     try {
       state = awaitTermination(duration);
     } catch (final TimeoutException e) {
-      state = null;
+      // ignore.
     } catch (final ExecutionException e) {
       state = PipelineResult.State.FAILED;
+      stop();
       throw beamExceptionFrom(e.getCause());
     } catch (final Exception e) {
       state = PipelineResult.State.FAILED;
-      throw beamExceptionFrom(e);
-    } finally {
       stop();
+      throw beamExceptionFrom(e);
     }
 
     return state;
@@ -149,6 +150,9 @@ public abstract class SparkPipelineResult implements PipelineResult {
     @Override
     protected void stop() {
       SparkContextFactory.stopSparkContext(javaSparkContext);
+      if (Objects.equals(state, State.RUNNING)) {
+        state = State.STOPPED;
+      }
     }
 
     @Override
@@ -178,19 +182,37 @@ public abstract class SparkPipelineResult implements PipelineResult {
       // after calling stop, if exception occurs in "grace period" it won't propagate.
       // calling the StreamingContext's waiter with 0 msec will throw any error that might have
       // been thrown during the "grace period".
-      javaStreamingContext.awaitTermination(0);
-      SparkContextFactory.stopSparkContext(javaSparkContext);
+      try {
+        javaStreamingContext.awaitTermination(0);
+      } catch (Exception e) {
+        throw beamExceptionFrom(e);
+      } finally {
+        SparkContextFactory.stopSparkContext(javaSparkContext);
+        if (Objects.equals(state, State.RUNNING)) {
+          state = State.STOPPED;
+        }
+      }
     }
 
     @Override
-    protected State awaitTermination(final Duration duration) throws TimeoutException,
-        ExecutionException, InterruptedException {
-      pipelineExecution.get(duration.getMillis(), TimeUnit.MILLISECONDS);
-      if (javaStreamingContext.awaitTerminationOrTimeout(duration.getMillis())) {
-        return State.DONE;
-      } else {
-        return null;
-      }
+    protected State awaitTermination(final Duration duration) throws ExecutionException,
+        InterruptedException {
+      pipelineExecution.get(); // execution is asynchronous anyway so no need to time-out.
+      javaStreamingContext.awaitTerminationOrTimeout(duration.getMillis());
+
+      State terminationState = null;
+      switch (javaStreamingContext.getState()) {
+        case ACTIVE:
+          terminationState = State.RUNNING;
+          break;
+        case STOPPED:
+          terminationState = State.DONE;
+          break;
+        default:
+          terminationState = State.UNKNOWN;
+          break;
+      }
+      return terminationState;
     }
 
   }
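
The caller-visible effect of the change above: a short waitUntilFinish on a
still-running streaming pipeline now reports RUNNING instead of leaving the
state null. A minimal sketch, mirroring the updated SparkPipelineStateTest
further down:

    // Sketch only - not part of this commit's diff.
    SparkPipelineResult result = (SparkPipelineResult) pipeline.run();
    result.waitUntilFinish(Duration.millis(1)); // times out while still running.
    assertThat(result.getState(), is(PipelineResult.State.RUNNING));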

http://git-wip-us.apache.org/repos/asf/beam/blob/2e308463/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java
index d50b652..902b250 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java
@@ -17,9 +17,13 @@
  */
 package org.apache.beam.runners.spark;
 
+import javax.annotation.Nullable;
 import org.apache.beam.sdk.options.Default;
+import org.apache.beam.sdk.options.DefaultValueFactory;
 import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.testing.TestPipelineOptions;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 
 
 /**
@@ -32,4 +36,26 @@ public interface TestSparkPipelineOptions extends SparkPipelineOptions, TestPipe
   boolean isForceStreaming();
   void setForceStreaming(boolean forceStreaming);
 
+  @Description("A hard-coded expected number of assertions for this test pipeline.")
+  @Nullable
+  Integer getExpectedAssertions();
+  void setExpectedAssertions(Integer expectedAssertions);
+
+  @Description("A watermark (time in millis) that causes a pipeline that reads "
+      + "from an unbounded source to stop.")
+  @Default.InstanceFactory(DefaultStopPipelineWatermarkFactory.class)
+  Long getStopPipelineWatermark();
+  void setStopPipelineWatermark(Long stopPipelineWatermark);
+
+  /**
+   * A factory to provide the default watermark to stop a pipeline that reads
+   * from an unbounded source.
+   */
+  class DefaultStopPipelineWatermarkFactory implements DefaultValueFactory<Long> {
+    @Override
+    public Long create(PipelineOptions options) {
+      return BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis();
+    }
+  }
+
 }
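
Because the default is supplied by a DefaultValueFactory, an options instance
created without an explicit value resolves to end-of-time. A minimal sketch of
that behavior:

    // Sketch only - not part of this commit's diff.
    TestSparkPipelineOptions options =
        PipelineOptionsFactory.as(TestSparkPipelineOptions.class);
    long stopWatermark = options.getStopPipelineWatermark();
    // stopWatermark == BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis()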

http://git-wip-us.apache.org/repos/asf/beam/blob/2e308463/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
index 16ddc9e..d321f99 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
@@ -21,19 +21,24 @@ package org.apache.beam.runners.spark;
 import static com.google.common.base.Preconditions.checkNotNull;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.isOneOf;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.util.concurrent.Uninterruptibles;
 import java.io.File;
 import java.io.IOException;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.TimeUnit;
 import org.apache.beam.runners.core.UnboundedReadFromBoundedSource;
 import org.apache.beam.runners.core.construction.PTransformMatchers;
 import org.apache.beam.runners.core.construction.ReplacementOutputs;
 import org.apache.beam.runners.spark.aggregators.AggregatorsAccumulator;
 import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
+import org.apache.beam.runners.spark.stateful.SparkTimerInternals;
 import org.apache.beam.runners.spark.util.GlobalWatermarkHolder;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.PipelineResult;
 import org.apache.beam.sdk.io.BoundedReadFromUnboundedSource;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsValidator;
@@ -49,6 +54,7 @@ import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.TaggedPValue;
 import org.apache.commons.io.FileUtils;
 import org.joda.time.Duration;
+import org.joda.time.Instant;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -118,16 +124,25 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
     if (isForceStreaming) {
       try {
         result = delegate.run(pipeline);
-        Long timeout = testSparkPipelineOptions.getTestTimeoutSeconds();
-        result.waitUntilFinish(Duration.standardSeconds(checkNotNull(timeout)));
+        awaitWatermarksOrTimeout(testSparkPipelineOptions, result);
+        result.stop();
+        PipelineResult.State finishState = result.getState();
+        // assert finish state.
+        assertThat(
+            String.format("Finish state %s is not allowed.", finishState),
+            finishState,
+            isOneOf(PipelineResult.State.STOPPED, PipelineResult.State.DONE));
+
         // validate assertion succeeded (at least once).
         int successAssertions = result.getAggregatorValue(PAssert.SUCCESS_COUNTER, Integer.class);
+        Integer expectedAssertions = testSparkPipelineOptions.getExpectedAssertions() != null
+            ? testSparkPipelineOptions.getExpectedAssertions() : expectedNumberOfAssertions;
         assertThat(
             String.format(
                 "Expected %d successful assertions, but found %d.",
-                expectedNumberOfAssertions, successAssertions),
+                expectedAssertions, successAssertions),
             successAssertions,
-            is(expectedNumberOfAssertions));
+            is(expectedAssertions));
         // validate assertion didn't fail.
         int failedAssertions = result.getAggregatorValue(PAssert.FAILURE_COUNTER, Integer.class);
         assertThat(
@@ -152,6 +167,13 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
       // for batch test pipelines, run and block until done.
       result = delegate.run(pipeline);
       result.waitUntilFinish();
+      result.stop();
+      PipelineResult.State finishState = result.getState();
+      // assert finish state.
+      assertThat(
+          String.format("Finish state %s is not allowed.", finishState),
+          finishState,
+          is(PipelineResult.State.DONE));
       // assert via matchers.
       assertThat(result, testSparkPipelineOptions.getOnCreateMatcher());
       assertThat(result, testSparkPipelineOptions.getOnSuccessMatcher());
@@ -159,6 +181,29 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
     return result;
   }
 
+  private static void awaitWatermarksOrTimeout(
+      TestSparkPipelineOptions testSparkPipelineOptions, SparkPipelineResult result) {
+    Long timeoutMillis = Duration.standardSeconds(
+        checkNotNull(testSparkPipelineOptions.getTestTimeoutSeconds())).getMillis();
+    Long batchDurationMillis = testSparkPipelineOptions.getBatchIntervalMillis();
+    Instant stopPipelineWatermark =
+        new Instant(testSparkPipelineOptions.getStopPipelineWatermark());
+    // we poll for pipeline status in batch-intervals. while this is not in-sync with Spark's
+    // execution clock, this is good enough.
+    // we break on timeout or end-of-time WM, whichever comes first.
+    Instant globalWatermark;
+    result.waitUntilFinish(Duration.millis(batchDurationMillis));
+    do {
+      SparkTimerInternals sparkTimerInternals =
+          SparkTimerInternals.global(GlobalWatermarkHolder.get());
+      sparkTimerInternals.advanceWatermark();
+      globalWatermark = sparkTimerInternals.currentInputWatermarkTime();
+      // allow another batch-interval of execution so the WM can propagate.
+      Uninterruptibles.sleepUninterruptibly(batchDurationMillis, TimeUnit.MILLISECONDS);
+    } while ((timeoutMillis -= batchDurationMillis) > 0
+          && globalWatermark.isBefore(stopPipelineWatermark));
+  }
+
+
   @VisibleForTesting
   void adaptBoundedReads(Pipeline pipeline) {
     pipeline.replace(
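
A test that knows the timestamp of its last element can stop earlier than
end-of-time by injecting a finite stop watermark, as
ResumeFromCheckpointStreamingTest does further down. A minimal fragment
(options is a TestSparkPipelineOptions instance; the values are the ones that
test uses):

    // Sketch only - not part of this commit's diff.
    // Stop polling once the global watermark passes the last injected element.
    options.setStopPipelineWatermark(new Instant(400).getMillis());
    options.setExpectedAssertions(0); // the first run asserts nothing yet.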

http://git-wip-us.apache.org/repos/asf/beam/blob/2e308463/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java
index 1949e1d..646e269 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java
@@ -19,6 +19,7 @@ package org.apache.beam.runners.spark.stateful;
 
 import static com.google.common.base.Preconditions.checkArgument;
 
+import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 import java.util.Collection;
 import java.util.Collections;
@@ -40,7 +41,7 @@ import org.joda.time.Instant;
 /**
  * An implementation of {@link TimerInternals} for the SparkRunner.
  */
-class SparkTimerInternals implements TimerInternals {
+public class SparkTimerInternals implements TimerInternals {
   private final Instant highWatermark;
   private final Instant synchronizedProcessingTime;
   private final Set<TimerData> timers = Sets.newHashSet();
@@ -92,6 +93,13 @@ class SparkTimerInternals implements TimerInternals {
         slowestLowWatermark, slowestHighWatermark, synchronizedProcessingTime);
   }
 
+  /** Build a global {@link TimerInternals} for all feeding streams.*/
+  public static SparkTimerInternals global(
+      @Nullable Broadcast<Map<Integer, SparkWatermarks>> broadcast) {
+    return broadcast == null ? forStreamFromSources(Collections.<Integer>emptyList(), null)
+        : forStreamFromSources(Lists.newArrayList(broadcast.getValue().keySet()), broadcast);
+  }
+
   Collection<TimerData> getTimers() {
     return timers;
   }
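
The new global(...) factory is what TestSparkRunner uses to observe a
pipeline-wide input watermark. A minimal usage sketch, with names taken from
this commit:

    // Sketch only - not part of this commit's diff.
    SparkTimerInternals sparkTimerInternals =
        SparkTimerInternals.global(GlobalWatermarkHolder.get());
    sparkTimerInternals.advanceWatermark();
    Instant globalWatermark = sparkTimerInternals.currentInputWatermarkTime();
    // A null broadcast (no watermarks published yet) is also accepted.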

http://git-wip-us.apache.org/repos/asf/beam/blob/2e308463/runners/spark/src/test/java/org/apache/beam/runners/spark/PipelineRule.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/PipelineRule.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/PipelineRule.java
index 77519cd..f8499f3 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/PipelineRule.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/PipelineRule.java
@@ -33,34 +33,29 @@ import org.junit.runners.model.Statement;
  */
 public class PipelineRule implements TestRule {
 
-  private final TestName testName = new TestName();
-
   private final SparkPipelineRule delegate;
   private final RuleChain chain;
 
-  private PipelineRule() {
-    this.delegate = new SparkPipelineRule(testName);
-    this.chain = RuleChain.outerRule(testName).around(this.delegate);
-  }
-
-  private PipelineRule(Duration forcedTimeout) {
-    this.delegate = new SparkStreamingPipelineRule(forcedTimeout, testName);
+  private PipelineRule(SparkPipelineRule delegate) {
+    TestName testName = new TestName();
+    this.delegate = delegate;
+    this.delegate.setTestName(testName);
     this.chain = RuleChain.outerRule(testName).around(this.delegate);
   }
 
   public static PipelineRule streaming() {
-    return new PipelineRule(Duration.standardSeconds(5));
+    return new PipelineRule(new SparkStreamingPipelineRule());
   }
 
   public static PipelineRule batch() {
-    return new PipelineRule();
+    return new PipelineRule(new SparkPipelineRule());
   }
 
   public Duration batchDuration() {
     return Duration.millis(delegate.options.getBatchIntervalMillis());
   }
 
-  public SparkPipelineOptions getOptions() {
+  public TestSparkPipelineOptions getOptions() {
     return delegate.options;
   }
 
@@ -76,19 +71,12 @@ public class PipelineRule implements TestRule {
   private static class SparkStreamingPipelineRule extends SparkPipelineRule {
 
     private final TemporaryFolder temporaryFolder = new TemporaryFolder();
-    private final Duration forcedTimeout;
-
-    SparkStreamingPipelineRule(Duration forcedTimeout, TestName testName) {
-      super(testName);
-      this.forcedTimeout = forcedTimeout;
-    }
 
     @Override
     protected void before() throws Throwable {
       super.before();
       temporaryFolder.create();
       options.setForceStreaming(true);
-      options.setTestTimeoutSeconds(forcedTimeout.getStandardSeconds());
       options.setCheckpointDir(
           temporaryFolder.newFolder(options.getJobName()).toURI().toURL().toString());
     }
@@ -104,9 +92,9 @@ public class PipelineRule implements TestRule {
     protected final TestSparkPipelineOptions options =
         PipelineOptionsFactory.as(TestSparkPipelineOptions.class);
 
-    private final TestName testName;
+    private TestName testName;
 
-    private SparkPipelineRule(TestName testName) {
+    public void setTestName(TestName testName) {
       this.testName = testName;
     }
 
@@ -114,7 +102,8 @@ public class PipelineRule implements TestRule {
     protected void before() throws Throwable {
       options.setRunner(TestSparkRunner.class);
       options.setEnableSparkMetricSinks(false);
-      options.setJobName(testName.getMethodName());
+      options.setJobName(
+          testName != null ? testName.getMethodName() : "test-at-" + System.currentTimeMillis());
     }
   }
 }
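
After this refactoring, tests obtain the rule from a factory method and the
rule derives the job name itself. A typical usage sketch (the test body is
hypothetical):

    // Sketch only - not part of this commit's diff.
    @Rule
    public final transient PipelineRule pipelineRule = PipelineRule.streaming();

    @Test
    public void testSomething() throws Exception {
      TestSparkPipelineOptions options = pipelineRule.getOptions();
      Pipeline pipeline = pipelineRule.createPipeline();
      // ... build the pipeline and run it ...
    }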

http://git-wip-us.apache.org/repos/asf/beam/blob/2e308463/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPipelineStateTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPipelineStateTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPipelineStateTest.java
index 37a201c..3a68d6f 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPipelineStateTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPipelineStateTest.java
@@ -20,7 +20,6 @@ package org.apache.beam.runners.spark;
 
 import static org.hamcrest.CoreMatchers.instanceOf;
 import static org.hamcrest.Matchers.is;
-import static org.hamcrest.Matchers.nullValue;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.fail;
 
@@ -141,7 +140,7 @@ public class SparkPipelineStateTest implements Serializable {
 
     result.waitUntilFinish(Duration.millis(1));
 
-    assertThat(result.getState(), nullValue());
+    assertThat(result.getState(), is(PipelineResult.State.RUNNING));
 
     result.cancel();
   }
@@ -188,11 +187,10 @@ public class SparkPipelineStateTest implements Serializable {
     testCanceledPipeline(getBatchOptions());
   }
 
-  //TODO: fix this!
-//  @Test
-//  public void testStreamingPipelineFailedState() throws Exception {
-//    testFailedPipeline(getStreamingOptions());
-//  }
+  @Test
+  public void testStreamingPipelineFailedState() throws Exception {
+    testFailedPipeline(getStreamingOptions());
+  }
 
   @Test
   public void testBatchPipelineFailedState() throws Exception {

http://git-wip-us.apache.org/repos/asf/beam/blob/2e308463/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
index b32f5f3..75abc8b 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
@@ -313,7 +313,9 @@ public class CreateStreamTest implements Serializable {
             .advanceWatermarkForNextBatch(instant.plus(Duration.standardMinutes(2)))
             .nextBatch(
                 TimestampedValue.of(5, instant))
-            .advanceWatermarkForNextBatch(instant.plus(Duration.standardMinutes(5)));
+            .advanceWatermarkForNextBatch(instant.plus(Duration.standardMinutes(5)))
+            .emptyBatch()
+            .advanceNextBatchWatermarkToInfinity();
 
     PCollection<Integer> windowed1 = p
         .apply(source1)
@@ -346,7 +348,8 @@ public class CreateStreamTest implements Serializable {
   public void testElementAtPositiveInfinityThrows() {
     CreateStream<Integer> source =
         CreateStream.of(VarIntCoder.of(), pipelineRule.batchDuration())
-            .nextBatch(TimestampedValue.of(-1, BoundedWindow.TIMESTAMP_MAX_VALUE.minus(1L)));
+            .nextBatch(TimestampedValue.of(-1, BoundedWindow.TIMESTAMP_MAX_VALUE.minus(1L)))
+            .advanceNextBatchWatermarkToInfinity();
     thrown.expect(IllegalArgumentException.class);
     source.nextBatch(TimestampedValue.of(1, BoundedWindow.TIMESTAMP_MAX_VALUE));
   }
@@ -357,7 +360,9 @@ public class CreateStreamTest implements Serializable {
         CreateStream.of(VarIntCoder.of(), pipelineRule.batchDuration())
             .advanceWatermarkForNextBatch(new Instant(0L));
     thrown.expect(IllegalArgumentException.class);
-    source.advanceWatermarkForNextBatch(new Instant(-1L));
+    source
+        .advanceWatermarkForNextBatch(new Instant(-1L))
+        .advanceNextBatchWatermarkToInfinity();
   }
 
   @Test
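
With streaming test sources now required to advance to infinity, a CreateStream
chain typically ends as below. A minimal sketch based on the updated tests (the
batch content is illustrative):

    // Sketch only - not part of this commit's diff.
    CreateStream<Integer> source =
        CreateStream.of(VarIntCoder.of(), pipelineRule.batchDuration())
            .nextBatch(TimestampedValue.of(1, new Instant(0)))
            .emptyBatch() // let the watermark propagate before closing the stream.
            .advanceNextBatchWatermarkToInfinity();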

http://git-wip-us.apache.org/repos/asf/beam/blob/2e308463/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
index 7706777..bc22980 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
@@ -24,6 +24,7 @@ import static org.hamcrest.Matchers.hasItem;
 import static org.hamcrest.Matchers.is;
 import static org.junit.Assert.assertThat;
 
+import com.google.common.base.Optional;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.util.concurrent.Uninterruptibles;
@@ -33,10 +34,10 @@ import java.util.List;
 import java.util.Map;
 import java.util.Properties;
 import java.util.concurrent.TimeUnit;
+import org.apache.beam.runners.spark.PipelineRule;
 import org.apache.beam.runners.spark.ReuseSparkContextRule;
-import org.apache.beam.runners.spark.SparkPipelineOptions;
 import org.apache.beam.runners.spark.SparkPipelineResult;
-import org.apache.beam.runners.spark.SparkRunner;
+import org.apache.beam.runners.spark.TestSparkPipelineOptions;
 import org.apache.beam.runners.spark.aggregators.AggregatorsAccumulator;
 import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
@@ -50,7 +51,6 @@ import org.apache.beam.sdk.metrics.Counter;
 import org.apache.beam.sdk.metrics.MetricNameFilter;
 import org.apache.beam.sdk.metrics.Metrics;
 import org.apache.beam.sdk.metrics.MetricsFilter;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.transforms.Aggregator;
 import org.apache.beam.sdk.transforms.Create;
@@ -82,8 +82,6 @@ import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
-import org.junit.rules.TestName;
 
 
 /**
@@ -92,7 +90,8 @@ import org.junit.rules.TestName;
 * <p>Runs the pipeline reading from a Kafka backlog with a WM function that will move to infinity
 * on an EOF signal.
 * After resuming from checkpoint, a single output (guaranteed by the WM) is asserted, along with
- * {@link Aggregator}s values that are expected to resume from previous count as well.
+ * {@link Aggregator}s and {@link Metrics} values that are expected to resume from previous count
+ * and a side-input that is expected to recover as well.
  */
 public class ResumeFromCheckpointStreamingTest {
   private static final EmbeddedKafkaCluster.EmbeddedZookeeper EMBEDDED_ZOOKEEPER =
@@ -102,11 +101,9 @@ public class ResumeFromCheckpointStreamingTest {
   private static final String TOPIC = "kafka_beam_test_topic";
 
   @Rule
-  public TemporaryFolder tmpFolder = new TemporaryFolder();
+  public final transient ReuseSparkContextRule noContextReuse = ReuseSparkContextRule.no();
   @Rule
-  public ReuseSparkContextRule noContextResue = ReuseSparkContextRule.no();
-  @Rule
-  public transient TestName testName = new TestName();
+  public final transient PipelineRule pipelineRule = PipelineRule.streaming();
 
   @BeforeClass
   public static void init() throws IOException {
@@ -144,13 +141,6 @@ public class ResumeFromCheckpointStreamingTest {
 
   @Test
   public void testWithResume() throws Exception {
-    SparkPipelineOptions options = PipelineOptionsFactory.create().as(SparkPipelineOptions.class);
-    options.setRunner(SparkRunner.class);
-    options.setCheckpointDir(tmpFolder.newFolder().toString());
-    options.setCheckpointDurationMillis(500L);
-    options.setJobName(testName.getMethodName());
-    options.setSparkMaster("local[*]");
-
     // write to Kafka
     produce(ImmutableMap.of(
         "k1", new Instant(100),
@@ -164,9 +154,8 @@ public class ResumeFromCheckpointStreamingTest {
             
.addNameFilter(MetricNameFilter.inNamespace(ResumeFromCheckpointStreamingTest.class))
             .build();
 
-    // first run will read from Kafka backlog - "auto.offset.reset=smallest"
-    SparkPipelineResult res = run(options);
-    res.waitUntilFinish(Duration.standardSeconds(5));
+    // first run should expect EOT matching the last injected element.
+    SparkPipelineResult res = run(pipelineRule, Optional.of(new Instant(400)), 0);
     // assertions 1:
     long processedMessages1 = res.getAggregatorValue("processedMessages", Long.class);
     assertThat(
@@ -192,8 +181,7 @@ public class ResumeFromCheckpointStreamingTest {
     ));
 
     // recovery should resume from last read offset, and read the second batch of input.
-    res = runAgain(options);
-    res.waitUntilFinish(Duration.standardSeconds(5));
+    res = runAgain(pipelineRule, 1);
     // assertions 2:
     long processedMessages2 = res.getAggregatorValue("processedMessages", Long.class);
     assertThat(
@@ -219,13 +207,15 @@ public class ResumeFromCheckpointStreamingTest {
 
   }
 
-  private SparkPipelineResult runAgain(SparkPipelineOptions options) {
+  private SparkPipelineResult runAgain(PipelineRule pipelineRule, int expectedAssertions) {
     // sleep before next run.
     Uninterruptibles.sleepUninterruptibly(10, TimeUnit.MILLISECONDS);
-    return run(options);
+    return run(pipelineRule, Optional.<Instant>absent(), expectedAssertions);
   }
 
-  private static SparkPipelineResult run(SparkPipelineOptions options) {
+  @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
+  private static SparkPipelineResult run(
+      PipelineRule pipelineRule, Optional<Instant> stopWatermarkOption, int expectedAssertions) {
     KafkaIO.Read<String, Instant> read = KafkaIO.<String, Instant>read()
         .withBootstrapServers(EMBEDDED_KAFKA_CLUSTER.getBrokerList())
         .withTopics(Collections.singletonList(TOPIC))
@@ -247,14 +237,23 @@ public class ResumeFromCheckpointStreamingTest {
           }
         });
 
-    Pipeline p = Pipeline.create(options);
+    TestSparkPipelineOptions options = pipelineRule.getOptions();
+    options.setSparkMaster("local[*]");
+    options.setCheckpointDurationMillis(options.getBatchIntervalMillis());
+    options.setExpectedAssertions(expectedAssertions);
+    // timeout is per execution so it can be injected by the caller.
+    if (stopWatermarkOption.isPresent()) {
+      options.setStopPipelineWatermark(stopWatermarkOption.get().getMillis());
+    }
+    Pipeline p = pipelineRule.createPipeline();
 
     PCollection<String> expectedCol =
         p.apply(Create.of(ImmutableList.of("side1", "side2")).withCoder(StringUtf8Coder.of()));
     PCollectionView<List<String>> view = expectedCol.apply(View.<String>asList());
 
-    PCollection<Iterable<String>> grouped = p
-        .apply(read.withoutMetadata())
+    PCollection<KV<String, Instant>> kafkaStream = p.apply(read.withoutMetadata());
+
+    PCollection<Iterable<String>> grouped = kafkaStream
         .apply(Keys.<String>create())
         .apply("EOFShallNotPassFn", ParDo.of(new EOFShallNotPassFn(view)).withSideInputs(view))
         .apply(Window.<String>into(FixedWindows.of(Duration.millis(500)))
