Repository: flink
Updated Branches:
  refs/heads/master fc343e0c3 -> 82ed79999


[FLINK-5407] IT case for savepoint with iterative job
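
Adds helpers to TestingCluster for triggering, requesting, and disposing
savepoints, and an integration test that triggers a savepoint for a job
with an iteration and resumes the job from that savepoint.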


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/82ed7999
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/82ed7999
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/82ed7999

Branch: refs/heads/master
Commit: 82ed799999e3f05ebfd67d69dfb56ff13dbd497a
Parents: 9c6eb57
Author: Stefan Richter <[email protected]>
Authored: Tue Jan 10 16:08:06 2017 +0100
Committer: Aljoscha Krettek <[email protected]>
Committed: Thu Jan 12 17:40:32 2017 +0100

----------------------------------------------------------------------
 .../runtime/testingUtils/TestingCluster.scala   |  78 ++++++-
 .../test/checkpointing/SavepointITCase.java     | 201 +++++++++++++++++++
 2 files changed, 276 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/82ed7999/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingCluster.scala
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingCluster.scala b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingCluster.scala
index 269a66f..d6215eb 100644
--- a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingCluster.scala
+++ b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingCluster.scala
@@ -18,26 +18,31 @@
 
 package org.apache.flink.runtime.testingUtils
 
-import java.util.concurrent.{Executor, ExecutorService, TimeUnit, TimeoutException}
+import java.io.IOException
+import java.util.concurrent.{Executor, TimeUnit, TimeoutException}
 
 import akka.actor.{ActorRef, ActorSystem, Props}
 import akka.pattern.Patterns._
 import akka.pattern.ask
 import akka.testkit.CallingThreadDispatcher
+import org.apache.flink.api.common.JobID
 import org.apache.flink.configuration.{ConfigConstants, Configuration}
 import org.apache.flink.runtime.akka.AkkaUtils
 import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory
+import org.apache.flink.runtime.checkpoint.savepoint.Savepoint
 import org.apache.flink.runtime.clusterframework.FlinkResourceManager
 import org.apache.flink.runtime.clusterframework.types.ResourceIDRetrievable
 import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager
 import org.apache.flink.runtime.executiongraph.restart.RestartStrategyFactory
-import org.apache.flink.runtime.instance.InstanceManager
+import org.apache.flink.runtime.instance.{ActorGateway, InstanceManager}
 import org.apache.flink.runtime.jobmanager.scheduler.Scheduler
 import org.apache.flink.runtime.jobmanager.{JobManager, MemoryArchivist, SubmittedJobGraphStore}
 import org.apache.flink.runtime.leaderelection.LeaderElectionService
+import org.apache.flink.runtime.messages.JobManagerMessages._
 import org.apache.flink.runtime.metrics.MetricRegistry
 import org.apache.flink.runtime.minicluster.LocalFlinkMiniCluster
 import org.apache.flink.runtime.taskmanager.TaskManager
+import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.ResponseSavepoint
 import org.apache.flink.runtime.testingUtils.TestingMessages.Alive
 import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages.NotifyWhenRegisteredAtJobManager
 import org.apache.flink.runtime.testutils.TestingResourceManager
@@ -281,7 +286,74 @@ class TestingCluster(
       }
     }
   }
-}
+
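+  /** Triggers a savepoint for the given job via the leader gateway and returns the savepoint path. */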
+  @throws(classOf[IOException])
+  def triggerSavepoint(jobId: JobID): String = {
+    val timeout = AkkaUtils.getTimeout(configuration)
+    triggerSavepoint(jobId, getLeaderGateway(timeout), timeout)
+  }
+
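+  /** Fetches the savepoint stored under the given path from the job manager. */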
+  @throws(classOf[IOException])
+  def requestSavepoint(savepointPath: String): Savepoint = {
+    val timeout = AkkaUtils.getTimeout(configuration)
+    requestSavepoint(savepointPath, getLeaderGateway(timeout), timeout)
+  }
+
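+  /** Disposes the savepoint stored under the given path. */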
+  @throws(classOf[IOException])
+  def disposeSavepoint(savepointPath: String): Unit = {
+    val timeout = AkkaUtils.getTimeout(configuration)
+    disposeSavepoint(savepointPath, getLeaderGateway(timeout), timeout)
+  }
+
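+  /** Triggers a savepoint through the given job manager gateway, waiting up to the given timeout. */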
+  @throws(classOf[IOException])
+  def triggerSavepoint(
+      jobId: JobID,
+      jobManager: ActorGateway,
+      timeout: FiniteDuration): String = {
+    val result = Await.result(
+      jobManager.ask(
+        TriggerSavepoint(jobId), timeout), timeout)
+
+    result match {
+      case success: TriggerSavepointSuccess => success.savepointPath
+      case fail: TriggerSavepointFailure => throw new IOException(fail.cause)
+      case _ => throw new IllegalStateException("Trigger savepoint failed")
+    }
+  }
+
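+  /** Requests the savepoint under the given path through the given job manager gateway. */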
+  @throws(classOf[IOException])
+  def requestSavepoint(
+      savepointPath: String,
+      jobManager: ActorGateway,
+      timeout: FiniteDuration): Savepoint = {
+    val result = Await.result(
+      jobManager.ask(
+        TestingJobManagerMessages.RequestSavepoint(savepointPath), timeout), timeout)
+
+    result match {
+      case success: ResponseSavepoint => success.savepoint
+      case _ => throw new IOException("Request savepoint failed")
+    }
+  }
+
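+  /** Disposes the savepoint under the given path through the given job manager gateway. */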
+  @throws(classOf[IOException])
+  def disposeSavepoint(
+      savepointPath: String,
+      jobManager: ActorGateway,
+      timeout: FiniteDuration): Unit = {
+    val result = Await.result(jobManager.ask(DisposeSavepoint(savepointPath), timeout), timeout)
+    result match {
+      case DisposeSavepointSuccess =>
+      case _ => throw new IOException("Dispose savepoint failed")
+    }
+  }
+}
 
 object TestingCluster {
   val MAX_RESTART_DURATION = new FiniteDuration(2, TimeUnit.MINUTES)

http://git-wip-us.apache.org/repos/asf/flink/blob/82ed7999/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
index d52f115..9f957e5 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
@@ -25,10 +25,16 @@ import com.google.common.collect.HashMultimap;
 import com.google.common.collect.Multimap;
 import org.apache.commons.io.FileUtils;
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.akka.AkkaUtils;
 import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.checkpoint.TaskState;
@@ -58,11 +64,17 @@ import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages;
 import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages.ResponseSubmitTaskListener;
 import org.apache.flink.runtime.testutils.CommonTestUtils;
 import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.IterativeStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
+import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.graph.StreamGraph;
+import org.apache.flink.util.Collector;
 import org.apache.flink.util.TestLogger;
+import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
@@ -78,6 +90,7 @@ import java.io.File;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 import java.util.Random;
 import java.util.concurrent.TimeUnit;
@@ -613,4 +626,192 @@ public class SavepointITCase extends TestLogger {
                }
        }
 
+       private static final int ITER_TEST_PARALLELISM = 1;
+       private static OneShotLatch[] ITER_TEST_SNAPSHOT_WAIT = new OneShotLatch[ITER_TEST_PARALLELISM];
+       private static OneShotLatch[] ITER_TEST_RESTORE_WAIT = new OneShotLatch[ITER_TEST_PARALLELISM];
+       private static int[] ITER_TEST_CHECKPOINT_VERIFY = new int[ITER_TEST_PARALLELISM];
+
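+       // Verifies that a savepoint can be taken for a job with an iteration and that the job can be restored from it.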
+       @Test
+       public void testSavepointForJobWithIteration() throws Exception {
+
+               for (int i = 0; i < ITER_TEST_PARALLELISM; ++i) {
+                       ITER_TEST_SNAPSHOT_WAIT[i] = new OneShotLatch();
+                       ITER_TEST_RESTORE_WAIT[i] = new OneShotLatch();
+                       ITER_TEST_CHECKPOINT_VERIFY[i] = 0;
+               }
+
+               TemporaryFolder folder = new TemporaryFolder();
+               folder.create();
+               // Temporary directory for file state backend
+               final File tmpDir = folder.newFolder();
+
+               final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+               final IntegerStreamSource source = new IntegerStreamSource();
+               IterativeStream<Integer> iteration = env.addSource(source)
+                               .flatMap(new RichFlatMapFunction<Integer, Integer>() {
+
+                                       private static final long serialVersionUID = 1L;
+
+                                       @Override
+                                       public void flatMap(Integer in, Collector<Integer> clctr) throws Exception {
+                                               clctr.collect(in);
+                                       }
+                               }).setParallelism(ITER_TEST_PARALLELISM)
+                               .keyBy(new KeySelector<Integer, Object>() {
+
+                                       private static final long serialVersionUID = 1L;
+
+                                       @Override
+                                       public Object getKey(Integer value) throws Exception {
+                                               return value;
+                                       }
+                               })
+                               .flatMap(new DuplicateFilter())
+                               .setParallelism(ITER_TEST_PARALLELISM)
+                               .iterate();
+
+               DataStream<Integer> iterationBody = iteration
+                               .map(new MapFunction<Integer, Integer>() {
+                                       private static final long serialVersionUID = 1L;
+
+                                       @Override
+                                       public Integer map(Integer value) throws Exception {
+                                               return value;
+                                       }
+                               })
+                               .setParallelism(ITER_TEST_PARALLELISM);
+
+               iteration.closeWith(iterationBody);
+
+               StreamGraph streamGraph = env.getStreamGraph();
+               streamGraph.setJobName("Test");
+
+               JobGraph jobGraph = streamGraph.getJobGraph();
+
+               Configuration config = new Configuration();
+               config.addAll(jobGraph.getJobConfiguration());
+               config.setLong(ConfigConstants.TASK_MANAGER_MEMORY_SIZE_KEY, -1L);
+               config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, 2 * jobGraph.getMaximumParallelism());
+               final File checkpointDir = new File(tmpDir, "checkpoints");
+               final File savepointDir = new File(tmpDir, "savepoints");
+
+               if (!checkpointDir.mkdir() || !savepointDir.mkdirs()) {
+                       fail("Test setup failed: failed to create temporary directories.");
+               }
+
+               config.setString(ConfigConstants.STATE_BACKEND, "filesystem");
+               config.setString(FsStateBackendFactory.CHECKPOINT_DIRECTORY_URI_CONF_KEY,
+                               checkpointDir.toURI().toString());
+               config.setString(FsStateBackendFactory.MEMORY_THRESHOLD_CONF_KEY, "0");
+               config.setString(ConfigConstants.SAVEPOINT_DIRECTORY_KEY,
+                               savepointDir.toURI().toString());
+
+               TestingCluster cluster = new TestingCluster(config, false);
+               String savepointPath = null;
+               try {
+                       cluster.start();
+
+                       cluster.submitJobDetached(jobGraph);
+                       for (OneShotLatch latch : ITER_TEST_SNAPSHOT_WAIT) {
+                               latch.await();
+                       }
+                       savepointPath = cluster.triggerSavepoint(jobGraph.getJobID());
+                       source.cancel();
+
+                       jobGraph = streamGraph.getJobGraph();
+                       jobGraph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath(savepointPath));
+
+                       cluster.submitJobDetached(jobGraph);
+                       for (OneShotLatch latch : ITER_TEST_RESTORE_WAIT) {
+                               latch.await();
+                       }
+                       source.cancel();
+               } finally {
+                       if (null != savepointPath) {
+                               cluster.disposeSavepoint(savepointPath);
+                       }
+                       cluster.stop();
+                       cluster.awaitTermination();
+               }
+       }
+
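+       // Test source: emits integers 0..100 in a cycle; checkpoints the emit count and verifies it after restore.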
+       private static final class IntegerStreamSource
+                       extends RichSourceFunction<Integer>
+                       implements ListCheckpointed<Integer> {
+
+               private static final long serialVersionUID = 1L;
+               private volatile boolean running;
+               private volatile boolean isRestored;
+               private int emittedCount;
+
+               public IntegerStreamSource() {
+                       this.running = true;
+                       this.isRestored = false;
+                       this.emittedCount = 0;
+               }
+
+               @Override
+               public void run(SourceContext<Integer> ctx) throws Exception {
+
+                       while (running) {
+                               synchronized (ctx.getCheckpointLock()) {
+                                       ctx.collect(emittedCount);
+                               }
+
+                               if (emittedCount < 100) {
+                                       ++emittedCount;
+                               } else {
+                                       emittedCount = 0;
+                               }
+                               Thread.sleep(1);
+                       }
+               }
+
+               @Override
+               public void cancel() {
+                       running = false;
+               }
+
+               @Override
+               public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
+                       ITER_TEST_CHECKPOINT_VERIFY[getRuntimeContext().getIndexOfThisSubtask()] = emittedCount;
+                       return Collections.singletonList(emittedCount);
+               }
+
+               @Override
+               public void restoreState(List<Integer> state) throws Exception {
+                       if (!state.isEmpty()) {
+                               this.emittedCount = state.get(0);
+                       }
+                       Assert.assertEquals(ITER_TEST_CHECKPOINT_VERIFY[getRuntimeContext().getIndexOfThisSubtask()], emittedCount);
+                       ITER_TEST_RESTORE_WAIT[getRuntimeContext().getIndexOfThisSubtask()].trigger();
+               }
+       }
+
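+       // Forwards each distinct value exactly once; releases the snapshot latch when value 30 is seen.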
+       public static class DuplicateFilter extends RichFlatMapFunction<Integer, Integer> {
+
+               static final ValueStateDescriptor<Boolean> descriptor = new ValueStateDescriptor<>("seen", Boolean.class, false);
+               private static final long serialVersionUID = 1L;
+               private ValueState<Boolean> operatorState;
+
+               @Override
+               public void open(Configuration configuration) {
+                       operatorState = this.getRuntimeContext().getState(descriptor);
+               }
+
+               @Override
+               public void flatMap(Integer value, Collector<Integer> out) throws Exception {
+                       if (!operatorState.value()) {
+                               out.collect(value);
+                               operatorState.update(true);
+                       }
+
+                       if (30 == value) {
+                               ITER_TEST_SNAPSHOT_WAIT[getRuntimeContext().getIndexOfThisSubtask()].trigger();
+                       }
+               }
+       }
 }
