Repository: incubator-beam
Updated Branches:
  refs/heads/master 5c23f4954 -> 1ceb12aeb


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/test/java/org/apache/beam/runners/spark/ClearAggregatorsRule.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/ClearAggregatorsRule.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/ClearAggregatorsRule.java
new file mode 100644
index 0000000..beaae13
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/ClearAggregatorsRule.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark;
+
+import org.apache.beam.runners.spark.aggregators.AccumulatorSingleton;
+import org.junit.rules.ExternalResource;
+
+/**
+ * A rule that clears the {@link org.apache.beam.runners.spark.aggregators.AccumulatorSingleton}
+ * which represents the Beam {@link org.apache.beam.sdk.transforms.Aggregator}s.
+ */
+class ClearAggregatorsRule extends ExternalResource {
+  @Override
+  protected void before() throws Throwable {
+    AccumulatorSingleton.clear();
+  }
+}
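
For illustration, a test class would typically apply this rule as a JUnit @Rule field so that shared accumulator state is cleared before each test method; the SimpleWordCountTest change below does exactly that. A minimal sketch (the test class and method here are hypothetical, not part of this commit):

    import org.junit.Rule;
    import org.junit.Test;

    public class MyAggregatorTest {
      // Clears the shared AccumulatorSingleton before each test method,
      // so Aggregator values cannot leak between tests in the same JVM.
      @Rule
      public ClearAggregatorsRule clearAggregators = new ClearAggregatorsRule();

      @Test
      public void testPipelineWithAggregators() {
        // ... run a pipeline that uses Beam Aggregators ...
      }
    }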

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/test/java/org/apache/beam/runners/spark/SimpleWordCountTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SimpleWordCountTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SimpleWordCountTest.java
index 8b7762f..238d7ba 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SimpleWordCountTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SimpleWordCountTest.java
@@ -29,6 +29,7 @@ import java.io.File;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Set;
+
 import org.apache.beam.runners.spark.aggregators.metrics.sink.InMemoryMetrics;
 import org.apache.beam.runners.spark.examples.WordCount;
 import org.apache.beam.sdk.Pipeline;
@@ -53,6 +54,9 @@ public class SimpleWordCountTest {
   @Rule
   public ExternalResource inMemoryMetricsSink = new InMemoryMetricsSinkRule();
 
+  @Rule
+  public ClearAggregatorsRule clearAggregators = new ClearAggregatorsRule();
+
   private static final String[] WORDS_ARRAY = {
       "hi there", "hi", "hi sue bob",
       "hi sue", "", "bob hi"};

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SideEffectsTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SideEffectsTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SideEffectsTest.java
index 0d15d12..f85baab 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SideEffectsTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SideEffectsTest.java
@@ -67,8 +67,7 @@ public class SideEffectsTest implements Serializable {
 
      // TODO: remove the version check (and the setup and teardown methods) when we no
       // longer support Spark 1.3 or 1.4
-      String version = SparkContextFactory.getSparkContext(options.getSparkMaster(),
-          options.getAppName()).version();
+      String version = SparkContextFactory.getSparkContext(options).version();
       if (!version.startsWith("1.3.") && !version.startsWith("1.4.")) {
         assertTrue(e.getCause() instanceof UserException);
       }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/FlattenStreamingTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/FlattenStreamingTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/FlattenStreamingTest.java
index a6fe755..8210b0d 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/FlattenStreamingTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/FlattenStreamingTest.java
@@ -22,19 +22,21 @@ import java.util.Collections;
 import java.util.List;
 import org.apache.beam.runners.spark.EvaluationResult;
 import org.apache.beam.runners.spark.SparkPipelineOptions;
-import org.apache.beam.runners.spark.SparkRunner;
 import org.apache.beam.runners.spark.io.CreateStream;
 import org.apache.beam.runners.spark.translation.streaming.utils.PAssertStreaming;
+import org.apache.beam.runners.spark.translation.streaming.utils.TestOptionsForStreaming;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionList;
 import org.joda.time.Duration;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
 
 /**
  * Test Flatten (union) implementation for streaming.
@@ -51,26 +53,50 @@ public class FlattenStreamingTest {
           Collections.<Iterable<String>>singletonList(Arrays.asList(WORDS_ARRAY_2));
   private static final String[] EXPECTED_UNION = {
           "one", "two", "three", "four", "five", "six", "seven", "eight"};
-  private static final long TEST_TIMEOUT_MSEC = 1000L;
+
+  @Rule
+  public TemporaryFolder checkpointParentDir = new TemporaryFolder();
+
+  @Rule
+  public TestOptionsForStreaming commonOptions = new TestOptionsForStreaming();
 
   @Test
-  public void testRun() throws Exception {
-    SparkPipelineOptions options =
-        PipelineOptionsFactory.as(SparkPipelineOptions.class);
-    options.setRunner(SparkRunner.class);
-    options.setStreaming(true);
-    // using the default 1000 msec interval
-    options.setTimeout(TEST_TIMEOUT_MSEC); // run for one interval
+  public void testFlattenUnbounded() throws Exception {
+    SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(
+        checkpointParentDir.newFolder(getClass().getSimpleName()));
+
     Pipeline p = Pipeline.create(options);
+    PCollection<String> w1 =
+        p.apply(CreateStream.fromQueue(WORDS_QUEUE_1)).setCoder(StringUtf8Coder.of());
+    PCollection<String> windowedW1 =
+        w1.apply(Window.<String>into(FixedWindows.of(Duration.standardSeconds(1))));
+    PCollection<String> w2 =
+        p.apply(CreateStream.fromQueue(WORDS_QUEUE_2)).setCoder(StringUtf8Coder.of());
+    PCollection<String> windowedW2 =
+        w2.apply(Window.<String>into(FixedWindows.of(Duration.standardSeconds(1))));
+    PCollectionList<String> list = PCollectionList.of(windowedW1).and(windowedW2);
+    PCollection<String> union = list.apply(Flatten.<String>pCollections());
 
+    PAssertStreaming.assertContents(union, EXPECTED_UNION);
+
+    EvaluationResult res = (EvaluationResult) p.run();
+    res.close();
+  }
+
+  @Test
+  public void testFlattenBoundedUnbounded() throws Exception {
+    SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(
+        checkpointParentDir.newFolder(getClass().getSimpleName()));
+
+    Pipeline p = Pipeline.create(options);
     PCollection<String> w1 =
-            p.apply(CreateStream.fromQueue(WORDS_QUEUE_1)).setCoder(StringUtf8Coder.of());
+        p.apply(CreateStream.fromQueue(WORDS_QUEUE_1)).setCoder(StringUtf8Coder.of());
     PCollection<String> windowedW1 =
-            w1.apply(Window.<String>into(FixedWindows.of(Duration.standardSeconds(1))));
+        w1.apply(Window.<String>into(FixedWindows.of(Duration.standardSeconds(1))));
     PCollection<String> w2 =
-            p.apply(CreateStream.fromQueue(WORDS_QUEUE_2)).setCoder(StringUtf8Coder.of());
+        p.apply(Create.of(WORDS_ARRAY_2)).setCoder(StringUtf8Coder.of());
     PCollection<String> windowedW2 =
-            w2.apply(Window.<String>into(FixedWindows.of(Duration.standardSeconds(1))));
+        w2.apply(Window.<String>into(FixedWindows.of(Duration.standardSeconds(1))));
     PCollectionList<String> list = PCollectionList.of(windowedW1).and(windowedW2);
     PCollection<String> union = list.apply(Flatten.<String>pCollections());
 

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/KafkaStreamingTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/KafkaStreamingTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/KafkaStreamingTest.java
index ac77922..caf5d13 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/KafkaStreamingTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/KafkaStreamingTest.java
@@ -25,14 +25,13 @@ import java.util.Properties;
 import kafka.serializer.StringDecoder;
 import org.apache.beam.runners.spark.EvaluationResult;
 import org.apache.beam.runners.spark.SparkPipelineOptions;
-import org.apache.beam.runners.spark.SparkRunner;
 import org.apache.beam.runners.spark.io.KafkaIO;
 import org.apache.beam.runners.spark.translation.streaming.utils.EmbeddedKafkaCluster;
 import org.apache.beam.runners.spark.translation.streaming.utils.PAssertStreaming;
+import org.apache.beam.runners.spark.translation.streaming.utils.TestOptionsForStreaming;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
@@ -43,11 +42,13 @@ import org.apache.kafka.clients.producer.KafkaProducer;
 import org.apache.kafka.clients.producer.ProducerRecord;
 import org.apache.kafka.common.serialization.Serializer;
 import org.apache.kafka.common.serialization.StringSerializer;
-import org.apache.spark.streaming.Durations;
 import org.joda.time.Duration;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
 /**
  * Test Kafka as input.
  */
@@ -61,7 +62,6 @@ public class KafkaStreamingTest {
       "k1", "v1", "k2", "v2", "k3", "v3", "k4", "v4"
   );
  private static final String[] EXPECTED = {"k1,v1", "k2,v2", "k3,v3", "k4,v4"};
-  private static final long TEST_TIMEOUT_MSEC = 1000L;
 
   @BeforeClass
   public static void init() throws IOException {
@@ -82,22 +82,22 @@ public class KafkaStreamingTest {
     }
   }
 
+  @Rule
+  public TemporaryFolder checkpointParentDir = new TemporaryFolder();
+
+  @Rule
+  public TestOptionsForStreaming commonOptions = new TestOptionsForStreaming();
+
   @Test
   public void testRun() throws Exception {
-    // test read from Kafka
-    SparkPipelineOptions options =
-        PipelineOptionsFactory.as(SparkPipelineOptions.class);
-    options.setRunner(SparkRunner.class);
-    options.setStreaming(true);
-    options.setBatchIntervalMillis(Durations.seconds(1).milliseconds());
-    options.setTimeout(TEST_TIMEOUT_MSEC); // run for one interval
-    Pipeline p = Pipeline.create(options);
-
+    SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(
+        checkpointParentDir.newFolder(getClass().getSimpleName()));
     Map<String, String> kafkaParams = ImmutableMap.of(
         "metadata.broker.list", EMBEDDED_KAFKA_CLUSTER.getBrokerList(),
         "auto.offset.reset", "smallest"
     );
 
+    Pipeline p = Pipeline.create(options);
     PCollection<KV<String, String>> kafkaInput = p.apply(KafkaIO.Read.from(StringDecoder.class,
        StringDecoder.class, String.class, String.class, Collections.singleton(TOPIC),
         kafkaParams))

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/RecoverFromCheckpointStreamingTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/RecoverFromCheckpointStreamingTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/RecoverFromCheckpointStreamingTest.java
new file mode 100644
index 0000000..4a96690
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/RecoverFromCheckpointStreamingTest.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.translation.streaming;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.junit.Assert.assertThat;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.util.concurrent.Uninterruptibles;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.Map;
+import java.util.Properties;
+import java.util.concurrent.TimeUnit;
+import kafka.serializer.StringDecoder;
+import org.apache.beam.runners.spark.EvaluationResult;
+import org.apache.beam.runners.spark.SparkPipelineOptions;
+import org.apache.beam.runners.spark.aggregators.AccumulatorSingleton;
+import org.apache.beam.runners.spark.io.KafkaIO;
+import org.apache.beam.runners.spark.translation.streaming.utils.EmbeddedKafkaCluster;
+import org.apache.beam.runners.spark.translation.streaming.utils.PAssertStreaming;
+import org.apache.beam.runners.spark.translation.streaming.utils.TestOptionsForStreaming;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.transforms.Aggregator;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.kafka.clients.producer.KafkaProducer;
+import org.apache.kafka.clients.producer.ProducerRecord;
+import org.apache.kafka.common.serialization.Serializer;
+import org.apache.kafka.common.serialization.StringSerializer;
+import org.joda.time.Duration;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+
+/**
+ * Tests DStream recovery from checkpoint - recreate the job and continue (from checkpoint).
+ *
+ * Tests Aggregators, which rely on Accumulators - Aggregators should be available, though state
+ * is not preserved (Spark issue), so they start from initial value.
+ * //TODO: after the runner supports recovering the state of Aggregators, update this test's
+ * expected values for the recovered (second) run.
+ */
+public class RecoverFromCheckpointStreamingTest {
+  private static final EmbeddedKafkaCluster.EmbeddedZookeeper EMBEDDED_ZOOKEEPER =
+      new EmbeddedKafkaCluster.EmbeddedZookeeper();
+  private static final EmbeddedKafkaCluster EMBEDDED_KAFKA_CLUSTER =
+      new EmbeddedKafkaCluster(EMBEDDED_ZOOKEEPER.getConnection(), new Properties());
+  private static final String TOPIC = "kafka_beam_test_topic";
+  private static final Map<String, String> KAFKA_MESSAGES = ImmutableMap.of(
+      "k1", "v1", "k2", "v2", "k3", "v3", "k4", "v4"
+  );
+  private static final String[] EXPECTED = {"k1,v1", "k2,v2", "k3,v3", "k4,v4"};
+  private static final long EXPECTED_AGG_FIRST = 4L;
+
+  @Rule
+  public TemporaryFolder checkpointParentDir = new TemporaryFolder();
+
+  @Rule
+  public TestOptionsForStreaming commonOptions = new TestOptionsForStreaming();
+
+  @BeforeClass
+  public static void init() throws IOException {
+    EMBEDDED_ZOOKEEPER.startup();
+    EMBEDDED_KAFKA_CLUSTER.startup();
+    // this test actually requires to NOT reuse the context but rather to stop it and start again
+    // from the checkpoint with a brand new context.
+    System.setProperty("beam.spark.test.reuseSparkContext", "false");
+    // write to Kafka
+    Properties producerProps = new Properties();
+    producerProps.putAll(EMBEDDED_KAFKA_CLUSTER.getProps());
+    producerProps.put("request.required.acks", 1);
+    producerProps.put("bootstrap.servers", 
EMBEDDED_KAFKA_CLUSTER.getBrokerList());
+    Serializer<String> stringSerializer = new StringSerializer();
+    try (@SuppressWarnings("unchecked") KafkaProducer<String, String> kafkaProducer =
+        new KafkaProducer(producerProps, stringSerializer, stringSerializer)) {
+      for (Map.Entry<String, String> en : KAFKA_MESSAGES.entrySet()) {
+        kafkaProducer.send(new ProducerRecord<>(TOPIC, en.getKey(), en.getValue()));
+      }
+      kafkaProducer.close();
+    }
+  }
+
+  @Test
+  public void testRun() throws Exception {
+    SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(
+        checkpointParentDir.newFolder(getClass().getSimpleName()));
+
+    // checkpoint after first (and only) interval.
+    options.setCheckpointDurationMillis(options.getBatchIntervalMillis());
+
+    // first run will read from Kafka backlog - "auto.offset.reset=smallest"
+    EvaluationResult res = run(options);
+    res.close();
+    long processedMessages1 = res.getAggregatorValue("processedMessages", Long.class);
+    assertThat(String.format("Expected %d processed messages count but "
+        + "found %d", EXPECTED_AGG_FIRST, processedMessages1), processedMessages1,
+            equalTo(EXPECTED_AGG_FIRST));
+
+    // recovery should resume from last read offset, so nothing is read here.
+    res = runAgain(options);
+    res.close();
+    long processedMessages2 = res.getAggregatorValue("processedMessages", Long.class);
+    assertThat(String.format("Expected %d processed messages count but "
+        + "found %d", 0, processedMessages2), processedMessages2, equalTo(0L));
+  }
+
+  private static EvaluationResult runAgain(SparkPipelineOptions options) {
+    AccumulatorSingleton.clear();
+    // sleep before next run.
+    Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS);
+    return run(options);
+  }
+
+  private static EvaluationResult run(SparkPipelineOptions options) {
+    Map<String, String> kafkaParams = ImmutableMap.of(
+            "metadata.broker.list", EMBEDDED_KAFKA_CLUSTER.getBrokerList(),
+            "auto.offset.reset", "smallest"
+    );
+    Pipeline p = Pipeline.create(options);
+    PCollection<KV<String, String>> kafkaInput = p.apply(KafkaIO.Read.from(
+        StringDecoder.class, StringDecoder.class, String.class, String.class,
+            Collections.singleton(TOPIC), kafkaParams)).setCoder(KvCoder.of(StringUtf8Coder.of(),
+                StringUtf8Coder.of()));
+    PCollection<KV<String, String>> windowedWords = kafkaInput
+        .apply(Window.<KV<String, String>>into(FixedWindows.of(Duration.standardSeconds(1))));
+    PCollection<String> formattedKV = windowedWords.apply(ParDo.of(
+        new FormatAsText()));
+
+    PAssertStreaming.assertContents(formattedKV, EXPECTED);
+
+    return (EvaluationResult) p.run();
+  }
+
+  @AfterClass
+  public static void tearDown() {
+    EMBEDDED_KAFKA_CLUSTER.shutdown();
+    EMBEDDED_ZOOKEEPER.shutdown();
+  }
+
+  private static class FormatAsText extends DoFn<KV<String, String>, String> {
+
+    private final Aggregator<Long, Long> aggregator =
+        createAggregator("processedMessages", new Sum.SumLongFn());
+
+    @ProcessElement
+    public void process(ProcessContext c) {
+      aggregator.addValue(1L);
+      String formatted = c.element().getKey() + "," + c.element().getValue();
+      c.output(formatted);
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SimpleStreamingWordCountTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SimpleStreamingWordCountTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SimpleStreamingWordCountTest.java
index 671d227..1464273 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SimpleStreamingWordCountTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SimpleStreamingWordCountTest.java
@@ -24,20 +24,21 @@ import java.util.Collections;
 import java.util.List;
 import org.apache.beam.runners.spark.EvaluationResult;
 import org.apache.beam.runners.spark.SparkPipelineOptions;
-import org.apache.beam.runners.spark.SparkRunner;
 import org.apache.beam.runners.spark.examples.WordCount;
 import org.apache.beam.runners.spark.io.CreateStream;
 import org.apache.beam.runners.spark.translation.streaming.utils.PAssertStreaming;
+import org.apache.beam.runners.spark.translation.streaming.utils.TestOptionsForStreaming;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.values.PCollection;
-import org.apache.spark.streaming.Durations;
 import org.joda.time.Duration;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
 
 /**
  * Simple word count streaming test.
@@ -49,23 +50,23 @@ public class SimpleStreamingWordCountTest implements Serializable {
   private static final List<Iterable<String>> WORDS_QUEUE =
       Collections.<Iterable<String>>singletonList(Arrays.asList(WORDS_ARRAY));
  private static final String[] EXPECTED_COUNTS = {"hi: 5", "there: 1", "sue: 2", "bob: 2"};
-  private static final long TEST_TIMEOUT_MSEC = 1000L;
+
+  @Rule
+  public TemporaryFolder checkpointParentDir = new TemporaryFolder();
+
+  @Rule
+  public TestOptionsForStreaming commonOptions = new TestOptionsForStreaming();
 
   @Test
   public void testRun() throws Exception {
-    SparkPipelineOptions options =
-        PipelineOptionsFactory.as(SparkPipelineOptions.class);
-    options.setRunner(SparkRunner.class);
-    options.setStreaming(true);
-    options.setBatchIntervalMillis(Durations.seconds(1).milliseconds());
-    options.setTimeout(TEST_TIMEOUT_MSEC); // run for one interval
-    Pipeline p = Pipeline.create(options);
+    SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(
+        checkpointParentDir.newFolder(getClass().getSimpleName()));
 
+    Pipeline p = Pipeline.create(options);
     PCollection<String> inputWords =
         p.apply(CreateStream.fromQueue(WORDS_QUEUE)).setCoder(StringUtf8Coder.of());
     PCollection<String> windowedWords = inputWords
         .apply(Window.<String>into(FixedWindows.of(Duration.standardSeconds(1))));
-
     PCollection<String> output = windowedWords.apply(new WordCount.CountWords())
         .apply(MapElements.via(new WordCount.FormatAsTextFn()));
 

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/TestOptionsForStreaming.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/TestOptionsForStreaming.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/TestOptionsForStreaming.java
new file mode 100644
index 0000000..d695df0
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/TestOptionsForStreaming.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation.streaming.utils;
+
+
+import java.io.File;
+import java.net.MalformedURLException;
+import org.apache.beam.runners.spark.SparkPipelineOptions;
+import org.apache.beam.runners.spark.SparkRunner;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.junit.rules.ExternalResource;
+
+
+/**
+ * A rule to create a common {@link SparkPipelineOptions} for testing streaming pipelines.
+ */
+public class TestOptionsForStreaming extends ExternalResource {
+  private final SparkPipelineOptions options =
+      PipelineOptionsFactory.as(SparkPipelineOptions.class);
+
+  @Override
+  protected void before() throws Throwable {
+    options.setRunner(SparkRunner.class);
+    options.setStreaming(true);
+    options.setTimeout(1000L);
+  }
+
+  public SparkPipelineOptions withTmpCheckpointDir(File checkpointDir)
+      throws MalformedURLException {
+    // tests use JUnit's TemporaryFolder path in the form of: /.../junit/...
+    // so need to add the missing protocol.
+    options.setCheckpointDir(checkpointDir.toURI().toURL().toString());
+    return options;
+  }
+
+  public SparkPipelineOptions getOptions() {
+    return options;
+  }
+}
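
As the streaming tests above illustrate, this rule is meant to pair with JUnit's TemporaryFolder: the folder rule supplies a throwaway parent directory, and withTmpCheckpointDir(...) converts a per-test subfolder into the URL form Spark expects for checkpointing. A minimal sketch of the intended usage (class and test names hypothetical, not part of this commit):

    import org.apache.beam.runners.spark.SparkPipelineOptions;
    import org.apache.beam.runners.spark.translation.streaming.utils.TestOptionsForStreaming;
    import org.apache.beam.sdk.Pipeline;
    import org.junit.Rule;
    import org.junit.Test;
    import org.junit.rules.TemporaryFolder;

    public class MyStreamingTest {
      @Rule
      public TemporaryFolder checkpointParentDir = new TemporaryFolder();

      @Rule
      public TestOptionsForStreaming commonOptions = new TestOptionsForStreaming();

      @Test
      public void testStreamingPipeline() throws Exception {
        // The rule pre-configures SparkRunner, streaming mode and a 1000 msec timeout;
        // we only add a per-test checkpoint directory on top of that.
        SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(
            checkpointParentDir.newFolder(getClass().getSimpleName()));
        Pipeline p = Pipeline.create(options);
        // ... build the pipeline, assert with PAssertStreaming, then p.run() ...
      }
    }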

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java
index 2b89372..a00dcba 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java
@@ -1410,6 +1410,13 @@ public class Combine {
           ImmutableList.copyOf(sideInputs));
     }
 
+    /**
+     * Returns the {@link GlobalCombineFn} used by this Combine operation.
+     */
+    public GlobalCombineFn<? super InputT, ?, OutputT> getFn() {
+      return fn;
+    }
+
     @Override
     public PCollection<OutputT> apply(PCollection<InputT> input) {
       PCollection<KV<Void, InputT>> withKeys = input
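
The new accessor gives runners (and any other caller holding a Combine.Globally instance) a way to inspect the combine function during pipeline translation, for example to translate the combine natively instead of expanding it generically. A hedged sketch of such a caller, not taken from this commit (variable names hypothetical; Sum.SumIntegerFn is the SDK's built-in integer sum):

    // A hypothetical caller, e.g. a runner's translator. GlobalCombineFn here is
    // org.apache.beam.sdk.transforms.CombineFnBase.GlobalCombineFn in this SDK.
    Combine.Globally<Integer, Integer> sum = Combine.globally(new Sum.SumIntegerFn());
    // Inspect the combine function backing the transform, e.g. to choose a
    // specialized (native) translation instead of the generic expansion.
    GlobalCombineFn<? super Integer, ?, Integer> fn = sum.getFn();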
