Repository: flink
Updated Branches:
  refs/heads/master fc343e0c3 -> 82ed79999
[FLINK-5407] IT case for savepoint with iterative job


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/82ed7999
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/82ed7999
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/82ed7999

Branch: refs/heads/master
Commit: 82ed799999e3f05ebfd67d69dfb56ff13dbd497a
Parents: 9c6eb57
Author: Stefan Richter <[email protected]>
Authored: Tue Jan 10 16:08:06 2017 +0100
Committer: Aljoscha Krettek <[email protected]>
Committed: Thu Jan 12 17:40:32 2017 +0100

----------------------------------------------------------------------
 .../runtime/testingUtils/TestingCluster.scala |  74 ++++++-
 .../test/checkpointing/SavepointITCase.java   | 198 +++++++++++++++++++
 2 files changed, 269 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/82ed7999/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingCluster.scala
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingCluster.scala b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingCluster.scala
index 269a66f..d6215eb 100644
--- a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingCluster.scala
+++ b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingCluster.scala
@@ -18,26 +18,31 @@
 
 package org.apache.flink.runtime.testingUtils
 
-import java.util.concurrent.{Executor, ExecutorService, TimeUnit, TimeoutException}
+import java.io.IOException
+import java.util.concurrent.{Executor, TimeUnit, TimeoutException}
 
 import akka.actor.{ActorRef, ActorSystem, Props}
 import akka.pattern.Patterns._
 import akka.pattern.ask
 import akka.testkit.CallingThreadDispatcher
+import org.apache.flink.api.common.JobID
 import org.apache.flink.configuration.{ConfigConstants, Configuration}
 import org.apache.flink.runtime.akka.AkkaUtils
 import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory
+import org.apache.flink.runtime.checkpoint.savepoint.Savepoint
 import org.apache.flink.runtime.clusterframework.FlinkResourceManager
 import org.apache.flink.runtime.clusterframework.types.ResourceIDRetrievable
 import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager
 import org.apache.flink.runtime.executiongraph.restart.RestartStrategyFactory
-import org.apache.flink.runtime.instance.InstanceManager
+import org.apache.flink.runtime.instance.{ActorGateway, InstanceManager}
 import org.apache.flink.runtime.jobmanager.scheduler.Scheduler
 import org.apache.flink.runtime.jobmanager.{JobManager, MemoryArchivist, SubmittedJobGraphStore}
 import org.apache.flink.runtime.leaderelection.LeaderElectionService
+import org.apache.flink.runtime.messages.JobManagerMessages._
 import org.apache.flink.runtime.metrics.MetricRegistry
 import org.apache.flink.runtime.minicluster.LocalFlinkMiniCluster
 import org.apache.flink.runtime.taskmanager.TaskManager
+import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.ResponseSavepoint
 import org.apache.flink.runtime.testingUtils.TestingMessages.Alive
 import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages.NotifyWhenRegisteredAtJobManager
 import org.apache.flink.runtime.testutils.TestingResourceManager
@@ -281,7 +286,70 @@ class TestingCluster(
       }
     }
   }
-}
+
+  @throws(classOf[IOException])
+  def triggerSavepoint(jobId: JobID): String = {
+    val timeout = AkkaUtils.getTimeout(configuration)
+    triggerSavepoint(jobId, getLeaderGateway(timeout), timeout)
+  }
+
+  @throws(classOf[IOException])
+  def requestSavepoint(savepointPath: String): Savepoint = {
+    val timeout = AkkaUtils.getTimeout(configuration)
+    requestSavepoint(savepointPath, getLeaderGateway(timeout), timeout)
+  }
+
+  @throws(classOf[IOException])
+  def disposeSavepoint(savepointPath: String): Unit = {
+    val timeout = AkkaUtils.getTimeout(configuration)
+    disposeSavepoint(savepointPath, getLeaderGateway(timeout), timeout)
+  }
+
+  @throws(classOf[IOException])
+  def triggerSavepoint(
+      jobId: JobID,
+      jobManager: ActorGateway,
+      timeout: FiniteDuration): String = {
+    val result = Await.result(
+      jobManager.ask(TriggerSavepoint(jobId), timeout), timeout)
+
+    result match {
+      case success: TriggerSavepointSuccess => success.savepointPath
+      case fail: TriggerSavepointFailure => throw new IOException(fail.cause)
+      case _ => throw new IllegalStateException("Trigger savepoint failed")
+    }
+  }
+
+  @throws(classOf[IOException])
+  def requestSavepoint(
+      savepointPath: String,
+      jobManager: ActorGateway,
+      timeout: FiniteDuration): Savepoint = {
+    val result = Await.result(
+      jobManager.ask(TestingJobManagerMessages.RequestSavepoint(savepointPath), timeout), timeout)
+
+    result match {
+      case success: ResponseSavepoint => success.savepoint
+      case _ => throw new IOException("Request savepoint failed")
+    }
+  }
+
+  @throws(classOf[IOException])
+  def disposeSavepoint(
+      savepointPath: String,
+      jobManager: ActorGateway,
+      timeout: FiniteDuration): Unit = {
+    val result = Await.result(
+      jobManager.ask(DisposeSavepoint(savepointPath), timeout), timeout)
+
+    result match {
+      case DisposeSavepointSuccess =>
+      case _ => throw new IOException("Dispose savepoint failed")
+    }
+  }
+}
 
 object TestingCluster {
 
   val MAX_RESTART_DURATION = new FiniteDuration(2, TimeUnit.MINUTES)
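
The helpers above are written to be convenient from Java tests as well; the IT case below calls them directly. As a minimal, illustrative sketch of the intended usage (here `cluster` and `jobGraph` are assumed to be an already started TestingCluster and a submitted JobGraph, not names from this commit):

    // Trigger a savepoint for the running job; blocks until the path is known.
    String savepointPath = cluster.triggerSavepoint(jobGraph.getJobID());

    // Fetch the savepoint metadata back from the JobManager for inspection.
    Savepoint savepoint = cluster.requestSavepoint(savepointPath);

    // Clean up the savepoint once the test is done with it.
    cluster.disposeSavepoint(savepointPath);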

http://git-wip-us.apache.org/repos/asf/flink/blob/82ed7999/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
index d52f115..9f957e5 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
@@ -25,10 +25,16 @@ import com.google.common.collect.HashMultimap;
 import com.google.common.collect.Multimap;
 import org.apache.commons.io.FileUtils;
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.akka.AkkaUtils;
 import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.checkpoint.TaskState;
@@ -58,11 +64,17 @@ import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages;
 import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages.ResponseSubmitTaskListener;
 import org.apache.flink.runtime.testutils.CommonTestUtils;
 import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.IterativeStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
+import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.graph.StreamGraph;
+import org.apache.flink.util.Collector;
 import org.apache.flink.util.TestLogger;
+import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
@@ -78,6 +90,7 @@ import java.io.File;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 import java.util.Random;
 import java.util.concurrent.TimeUnit;
@@ -613,4 +626,189 @@ public class SavepointITCase extends TestLogger {
         }
     }
 
+    private static final int ITER_TEST_PARALLELISM = 1;
+    private static OneShotLatch[] ITER_TEST_SNAPSHOT_WAIT = new OneShotLatch[ITER_TEST_PARALLELISM];
+    private static OneShotLatch[] ITER_TEST_RESTORE_WAIT = new OneShotLatch[ITER_TEST_PARALLELISM];
+    private static int[] ITER_TEST_CHECKPOINT_VERIFY = new int[ITER_TEST_PARALLELISM];
+
+    @Test
+    public void testSavepointForJobWithIteration() throws Exception {
+
+        for (int i = 0; i < ITER_TEST_PARALLELISM; ++i) {
+            ITER_TEST_SNAPSHOT_WAIT[i] = new OneShotLatch();
+            ITER_TEST_RESTORE_WAIT[i] = new OneShotLatch();
+            ITER_TEST_CHECKPOINT_VERIFY[i] = 0;
+        }
+
+        TemporaryFolder folder = new TemporaryFolder();
+        folder.create();
+        // Temporary directory for file state backend
+        final File tmpDir = folder.newFolder();
+
+        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+
+        final IntegerStreamSource source = new IntegerStreamSource();
+        IterativeStream<Integer> iteration = env.addSource(source)
+            .flatMap(new RichFlatMapFunction<Integer, Integer>() {
+
+                private static final long serialVersionUID = 1L;
+
+                @Override
+                public void flatMap(Integer in, Collector<Integer> clctr) throws Exception {
+                    clctr.collect(in);
+                }
+            }).setParallelism(ITER_TEST_PARALLELISM)
+            .keyBy(new KeySelector<Integer, Object>() {
+
+                private static final long serialVersionUID = 1L;
+
+                @Override
+                public Object getKey(Integer value) throws Exception {
+                    return value;
+                }
+            })
+            .flatMap(new DuplicateFilter())
+            .setParallelism(ITER_TEST_PARALLELISM)
+            .iterate();
+
+        DataStream<Integer> iterationBody = iteration
+            .map(new MapFunction<Integer, Integer>() {
+
+                private static final long serialVersionUID = 1L;
+
+                @Override
+                public Integer map(Integer value) throws Exception {
+                    return value;
+                }
+            })
+            .setParallelism(ITER_TEST_PARALLELISM);
+
+        iteration.closeWith(iterationBody);
+
+        StreamGraph streamGraph = env.getStreamGraph();
+        streamGraph.setJobName("Test");
+
+        JobGraph jobGraph = streamGraph.getJobGraph();
+
+        Configuration config = new Configuration();
+        config.addAll(jobGraph.getJobConfiguration());
+        config.setLong(ConfigConstants.TASK_MANAGER_MEMORY_SIZE_KEY, -1L);
+        config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, 2 * jobGraph.getMaximumParallelism());
+
+        final File checkpointDir = new File(tmpDir, "checkpoints");
+        final File savepointDir = new File(tmpDir, "savepoints");
+
+        if (!checkpointDir.mkdir() || !savepointDir.mkdirs()) {
+            fail("Test setup failed: failed to create temporary directories.");
+        }
+
+        config.setString(ConfigConstants.STATE_BACKEND, "filesystem");
+        config.setString(FsStateBackendFactory.CHECKPOINT_DIRECTORY_URI_CONF_KEY,
+            checkpointDir.toURI().toString());
+        config.setString(FsStateBackendFactory.MEMORY_THRESHOLD_CONF_KEY, "0");
+        config.setString(ConfigConstants.SAVEPOINT_DIRECTORY_KEY,
+            savepointDir.toURI().toString());
+
+        TestingCluster cluster = new TestingCluster(config, false);
+        String savepointPath = null;
+        try {
+            cluster.start();
+
+            cluster.submitJobDetached(jobGraph);
+
+            // Wait until every source subtask has built up enough state to snapshot.
+            for (OneShotLatch latch : ITER_TEST_SNAPSHOT_WAIT) {
+                latch.await();
+            }
+            savepointPath = cluster.triggerSavepoint(jobGraph.getJobID());
+            source.cancel();
+
+            // Resubmit the same topology, this time restoring from the savepoint.
+            jobGraph = streamGraph.getJobGraph();
+            jobGraph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath(savepointPath));
+
+            cluster.submitJobDetached(jobGraph);
+
+            for (OneShotLatch latch : ITER_TEST_RESTORE_WAIT) {
+                latch.await();
+            }
+            source.cancel();
+        } finally {
+            if (null != savepointPath) {
+                cluster.disposeSavepoint(savepointPath);
+            }
+            cluster.stop();
+            cluster.awaitTermination();
+        }
+    }
+
+    private static final class IntegerStreamSource
+            extends RichSourceFunction<Integer>
+            implements ListCheckpointed<Integer> {
+
+        private static final long serialVersionUID = 1L;
+        private volatile boolean running;
+        private volatile boolean isRestored;
+        private int emittedCount;
+
+        public IntegerStreamSource() {
+            this.running = true;
+            this.isRestored = false;
+            this.emittedCount = 0;
+        }
+
+        @Override
+        public void run(SourceContext<Integer> ctx) throws Exception {
+
+            while (running) {
+                synchronized (ctx.getCheckpointLock()) {
+                    ctx.collect(emittedCount);
+                }
+
+                // Cycle through 0..100 so the keyed filter keeps seeing fresh keys.
+                if (emittedCount < 100) {
+                    ++emittedCount;
+                } else {
+                    emittedCount = 0;
+                }
+                Thread.sleep(1);
+            }
+        }
+
+        @Override
+        public void cancel() {
+            running = false;
+        }
+
+        @Override
+        public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
+            // Remember what was snapshotted so restoreState() can verify it.
+            ITER_TEST_CHECKPOINT_VERIFY[getRuntimeContext().getIndexOfThisSubtask()] = emittedCount;
+            return Collections.singletonList(emittedCount);
+        }
+
+        @Override
+        public void restoreState(List<Integer> state) throws Exception {
+            if (!state.isEmpty()) {
+                this.emittedCount = state.get(0);
+            }
+            Assert.assertEquals(ITER_TEST_CHECKPOINT_VERIFY[getRuntimeContext().getIndexOfThisSubtask()], emittedCount);
+            ITER_TEST_RESTORE_WAIT[getRuntimeContext().getIndexOfThisSubtask()].trigger();
+        }
+    }
+
+    public static class DuplicateFilter extends RichFlatMapFunction<Integer, Integer> {
+
+        static final ValueStateDescriptor<Boolean> descriptor =
+            new ValueStateDescriptor<>("seen", Boolean.class, false);
+
+        private static final long serialVersionUID = 1L;
+        private ValueState<Boolean> operatorState;
+
+        @Override
+        public void open(Configuration configuration) {
+            operatorState = this.getRuntimeContext().getState(descriptor);
+        }
+
+        @Override
+        public void flatMap(Integer value, Collector<Integer> out) throws Exception {
+            // Emit each key only the first time it is seen.
+            if (!operatorState.value()) {
+                out.collect(value);
+                operatorState.update(true);
+            }
+
+            // By value 30 enough state has accumulated; let the test take the savepoint.
+            if (30 == value) {
+                ITER_TEST_SNAPSHOT_WAIT[getRuntimeContext().getIndexOfThisSubtask()].trigger();
+            }
+        }
+    }
 }
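
A note on how the IT case coordinates with the job: the user functions run on task threads inside the same mini-cluster JVM, so the test signals progress through static OneShotLatch arrays (ITER_TEST_SNAPSHOT_WAIT, ITER_TEST_RESTORE_WAIT) rather than through return values. A self-contained sketch of that handshake pattern follows; the class and member names are illustrative and not part of this commit:

    import org.apache.flink.core.testutils.OneShotLatch;

    public class LatchHandshakeSketch {

        // Static, so the test thread and the task thread share the same latch.
        // This only works because the mini cluster runs in the test's JVM.
        private static final OneShotLatch WORK_DONE = new OneShotLatch();

        // Task side: signal once the interesting state has been produced.
        static void taskSide() {
            WORK_DONE.trigger();
        }

        // Test side: block until the task has reached that point.
        static void testSide() throws InterruptedException {
            WORK_DONE.await();
        }
    }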
