This is an automated email from the ASF dual-hosted git repository. tangyun pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 1bf45b25791cc3fad8b7d0d863caa9b0eef9a87b Author: fredia <[email protected]> AuthorDate: Thu Mar 10 11:33:59 2022 +0800 [FLINK-21321][Runtime/StateBackends] Add ITCases for rescaling from checkpoint --- .../state/RocksDBIncrementalCheckpointUtils.java | 2 +- ...ncrementalCheckpointRescalingBenchmarkTest.java | 240 ----------------- .../RescaleCheckpointManuallyITCase.java | 286 +++++++++++++++++++++ .../flink/test/checkpointing/RescalingITCase.java | 136 +--------- .../ResumeCheckpointManuallyITCase.java | 65 +---- .../checkpointing/utils/RescalingTestUtils.java | 162 ++++++++++++ .../java/org/apache/flink/test/util/TestUtils.java | 22 ++ 7 files changed, 480 insertions(+), 433 deletions(-) diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java index 467156c..23c7867 100644 --- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java +++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java @@ -149,7 +149,7 @@ public class RocksDBIncrementalCheckpointUtils { // Using RocksDB's deleteRange will take advantage of delete // tombstones, which mark the range as deleted. 
// - // https://github.com/facebook/rocksdb/blob/bcd32560dd5898956b9d24553c2bb3c1b1d2319f/include/rocksdb/db.h#L357-L371 + // https://github.com/ververica/frocksdb/blob/FRocksDB-6.20.3/include/rocksdb/db.h#L363-L377 db.deleteRange(columnFamilyHandle, beginKeyBytes, endKeyBytes); } } diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksIncrementalCheckpointRescalingBenchmarkTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksIncrementalCheckpointRescalingBenchmarkTest.java deleted file mode 100644 index a4e267b..0000000 --- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksIncrementalCheckpointRescalingBenchmarkTest.java +++ /dev/null @@ -1,240 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.contrib.streaming.state; - -import org.apache.flink.api.common.state.ValueState; -import org.apache.flink.api.common.state.ValueStateDescriptor; -import org.apache.flink.api.common.typeinfo.BasicTypeInfo; -import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; -import org.apache.flink.runtime.checkpoint.StateAssignmentOperation; -import org.apache.flink.runtime.state.KeyGroupRange; -import org.apache.flink.streaming.api.functions.KeyedProcessFunction; -import org.apache.flink.streaming.api.operators.KeyedProcessOperator; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness; -import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; -import org.apache.flink.testutils.junit.RetryOnFailure; -import org.apache.flink.util.Collector; -import org.apache.flink.util.TestLogger; - -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; - -import java.util.Collection; -import java.util.List; - -/** Test runs the benchmark for incremental checkpoint rescaling. */ -public class RocksIncrementalCheckpointRescalingBenchmarkTest extends TestLogger { - - @Rule public TemporaryFolder rootFolder = new TemporaryFolder(); - - private static final int maxParallelism = 10; - - private static final int recordCount = 1_000; - - /** partitionParallelism is the parallelism to use for creating the partitionedSnapshot. */ - private static final int partitionParallelism = 2; - - /** - * repartitionParallelism is the parallelism to use during the test for the repartition step. 
- * - * <p>NOTE: To trigger {@link - * org.apache.flink.contrib.streaming.state.restore.RocksDBIncrementalRestoreOperation#restoreWithRescaling(Collection)}, - * where the improvement code is exercised, the target parallelism must not be divisible by - * {@link partitionParallelism}. If this parallelism was instead 4, then there is no rescale. - */ - private static final int repartitionParallelism = 3; - - /** partitionedSnapshot is a partitioned incremental RocksDB snapshot. */ - private OperatorSubtaskState partitionedSnapshot; - - private KeySelector<Integer, Integer> keySelector = new TestKeySelector(); - - /** - * The benchmark's preparation will: - * - * <ol> - * <li>Create a stateful operator and process records to persist state. - * <li>Snapshot the state and re-partition it so the test operates on a partitioned state. - * </ol> - * - * @throws Exception - */ - @Before - public void before() throws Exception { - OperatorSubtaskState snapshot; - // Initialize the test harness with a a task parallelism of 1. - try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> harness = - getHarnessTest(keySelector, maxParallelism, 1, 0)) { - // Set the state backend of the harness to RocksDB. - harness.setStateBackend(getStateBackend()); - // Initialize the harness. - harness.open(); - // Push the test records into the operator to trigger state updates. - Integer[] records = new Integer[recordCount]; - for (int i = 0; i < recordCount; i++) { - harness.processElement(new StreamRecord<>(i, 1)); - } - // Grab a snapshot of the state. - snapshot = harness.snapshot(0, 0); - } - - // Now, re-partition to create a partitioned state. 
- KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer>[] partitionedTestHarness = - new KeyedOneInputStreamOperatorTestHarness[partitionParallelism]; - List<KeyGroupRange> keyGroupPartitions = - StateAssignmentOperation.createKeyGroupPartitions( - maxParallelism, partitionParallelism); - try { - for (int i = 0; i < partitionParallelism; i++) { - // Initialize, open, and then re-snapshot the two subtasks to create a partitioned - // incremental RocksDB snapshot. - OperatorSubtaskState subtaskState = - AbstractStreamOperatorTestHarness.repartitionOperatorState( - snapshot, maxParallelism, 1, partitionParallelism, i); - KeyGroupRange localKeyGroupRange20 = keyGroupPartitions.get(i); - - partitionedTestHarness[i] = - getHarnessTest(keySelector, maxParallelism, partitionParallelism, i); - partitionedTestHarness[i].setStateBackend(getStateBackend()); - partitionedTestHarness[i].setup(); - partitionedTestHarness[i].initializeState(subtaskState); - partitionedTestHarness[i].open(); - } - - partitionedSnapshot = - AbstractStreamOperatorTestHarness.repackageState( - partitionedTestHarness[0].snapshot(1, 2), - partitionedTestHarness[1].snapshot(1, 2)); - - } finally { - closeHarness(partitionedTestHarness); - } - } - - @Test(timeout = 1000) - @RetryOnFailure(times = 3) - public void benchmarkScalingUp() throws Exception { - long benchmarkTime = 0; - - // Trigger the incremental re-scaling via restoreWithRescaling by repartitioning it from - // parallelism of >1 to a higher parallelism. Time spent during this step includes the cost - // of incremental rescaling. 
- List<KeyGroupRange> keyGroupPartitions = - StateAssignmentOperation.createKeyGroupPartitions( - maxParallelism, repartitionParallelism); - - long fullStateSize = partitionedSnapshot.getStateSize(); - - for (int i = 0; i < repartitionParallelism; i++) { - OperatorSubtaskState subtaskState = - AbstractStreamOperatorTestHarness.repartitionOperatorState( - partitionedSnapshot, - maxParallelism, - partitionParallelism, - repartitionParallelism, - i); - KeyGroupRange localKeyGroupRange20 = keyGroupPartitions.get(i); - - try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> subtaskHarness = - getHarnessTest(keySelector, maxParallelism, repartitionParallelism, i)) { - RocksDBStateBackend backend = getStateBackend(); - subtaskHarness.setStateBackend(backend); - subtaskHarness.setup(); - - // Precisely measure the call-site that triggers the restore operation. - long startingTime = System.nanoTime(); - subtaskHarness.initializeState(subtaskState); - benchmarkTime += System.nanoTime() - startingTime; - } - } - - log.error( - "--------------> performance for incremental checkpoint re-scaling <--------------"); - log.error( - "rescale from {} to {} with {} records took: {} nanoseconds", - partitionParallelism, - repartitionParallelism, - recordCount, - benchmarkTime); - } - - private KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> getHarnessTest( - KeySelector<Integer, Integer> keySelector, - int maxParallelism, - int taskParallelism, - int subtaskIdx) - throws Exception { - return new KeyedOneInputStreamOperatorTestHarness<>( - new KeyedProcessOperator<>(new TestKeyedFunction()), - keySelector, - BasicTypeInfo.INT_TYPE_INFO, - maxParallelism, - taskParallelism, - subtaskIdx); - } - - private void closeHarness(KeyedOneInputStreamOperatorTestHarness<?, ?, ?>[] harnessArr) - throws Exception { - for (KeyedOneInputStreamOperatorTestHarness<?, ?, ?> harness : harnessArr) { - if (harness != null) { - harness.close(); - } - } - } - - private 
RocksDBStateBackend getStateBackend() throws Exception { - return new RocksDBStateBackend("file://" + rootFolder.newFolder().getAbsolutePath(), true); - } - - /** A simple keyed function for tests. */ - private class TestKeyedFunction extends KeyedProcessFunction<Integer, Integer, Integer> { - - public ValueStateDescriptor<Integer> stateDescriptor; - private ValueState<Integer> counterState; - - @Override - public void open(Configuration parameters) throws Exception { - super.open(parameters); - stateDescriptor = new ValueStateDescriptor<Integer>("counter", Integer.class); - counterState = this.getRuntimeContext().getState(stateDescriptor); - } - - @Override - public void processElement(Integer incomingValue, Context ctx, Collector<Integer> out) - throws Exception { - Integer oldValue = counterState.value(); - Integer newValue = oldValue != null ? oldValue + incomingValue : incomingValue; - counterState.update(newValue); - out.collect(newValue); - } - } - - /** A simple key selector for tests. */ - private class TestKeySelector implements KeySelector<Integer, Integer> { - @Override - public Integer getKey(Integer value) throws Exception { - return value; - } - } -} diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase.java new file mode 100644 index 0000000..8c05afa --- /dev/null +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase.java @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.test.checkpointing; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.client.program.ClusterClient; +import org.apache.flink.configuration.CheckpointingOptions; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.StateBackendOptions; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings; +import org.apache.flink.runtime.state.KeyGroupRangeAssignment; +import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.CheckpointConfig; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.test.checkpointing.utils.RescalingTestUtils; +import org.apache.flink.test.util.MiniClusterWithClientResource; +import org.apache.flink.test.util.TestUtils; +import org.apache.flink.util.TestLogger; + +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.CountDownLatch; + +import static org.apache.flink.test.util.TestUtils.submitJobAndWaitForResult; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** 
Test checkpoint rescaling for incremental rocksdb. */ +public class RescaleCheckpointManuallyITCase extends TestLogger { + + private static final int NUM_TASK_MANAGERS = 2; + private static final int SLOTS_PER_TASK_MANAGER = 2; + + private static MiniClusterWithClientResource cluster; + private File checkpointDir; + + @ClassRule public static TemporaryFolder temporaryFolder = new TemporaryFolder(); + + @Before + public void setup() throws Exception { + Configuration config = new Configuration(); + + checkpointDir = temporaryFolder.newFolder(); + + config.setString(StateBackendOptions.STATE_BACKEND, "rocksdb"); + config.setString( + CheckpointingOptions.CHECKPOINTS_DIRECTORY, checkpointDir.toURI().toString()); + config.setBoolean(CheckpointingOptions.INCREMENTAL_CHECKPOINTS, true); + + cluster = + new MiniClusterWithClientResource( + new MiniClusterResourceConfiguration.Builder() + .setConfiguration(config) + .setNumberTaskManagers(NUM_TASK_MANAGERS) + .setNumberSlotsPerTaskManager(SLOTS_PER_TASK_MANAGER) + .build()); + cluster.before(); + } + + @Test + public void testCheckpointRescalingInKeyedState() throws Exception { + testCheckpointRescalingKeyedState(false); + } + + @Test + public void testCheckpointRescalingOutKeyedState() throws Exception { + testCheckpointRescalingKeyedState(true); + } + + /** + * Tests that a job with purely keyed state can be restarted from a checkpoint with a different + * parallelism. + */ + public void testCheckpointRescalingKeyedState(boolean scaleOut) throws Exception { + final int numberKeys = 42; + final int numberElements = 1000; + final int numberElements2 = 500; + final int parallelism = scaleOut ? 3 : 4; + final int parallelism2 = scaleOut ? 
4 : 3; + final int maxParallelism = 13; + + cluster.before(); + + ClusterClient<?> client = cluster.getClusterClient(); + String checkpointPath = + runJobAndGetCheckpoint( + numberKeys, + numberElements, + parallelism, + maxParallelism, + client, + checkpointDir); + + assertNotNull(checkpointPath); + + restoreAndAssert( + parallelism2, + maxParallelism, + maxParallelism, + numberKeys, + numberElements2, + numberElements + numberElements2, + client, + checkpointPath); + } + + private static String runJobAndGetCheckpoint( + int numberKeys, + int numberElements, + int parallelism, + int maxParallelism, + ClusterClient<?> client, + File checkpointDir) + throws Exception { + try { + JobGraph jobGraph = + createJobGraphWithKeyedState( + parallelism, maxParallelism, numberKeys, numberElements, false, 100); + NotifyingDefiniteKeySource.sourceLatch = new CountDownLatch(parallelism); + client.submitJob(jobGraph).get(); + NotifyingDefiniteKeySource.sourceLatch.await(); + + RescalingTestUtils.SubtaskIndexFlatMapper.workCompletedLatch.await(); + + // verify the current state + Set<Tuple2<Integer, Integer>> actualResult = + RescalingTestUtils.CollectionSink.getElementsSet(); + + Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>(); + + for (int key = 0; key < numberKeys; key++) { + int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism); + expectedResult.add( + Tuple2.of( + KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup( + maxParallelism, parallelism, keyGroupIndex), + numberElements * key)); + } + + assertEquals(expectedResult, actualResult); + NotifyingDefiniteKeySource.sourceLatch.await(); + + TestUtils.waitUntilExternalizedCheckpointCreated(checkpointDir); + client.cancel(jobGraph.getJobID()).get(); + TestUtils.waitUntilJobCanceled(jobGraph.getJobID(), client); + return TestUtils.getMostRecentCompletedCheckpoint(checkpointDir).getAbsolutePath(); + } finally { + RescalingTestUtils.CollectionSink.clearElementsSet(); + } + } + + 
private void restoreAndAssert( + int restoreParallelism, + int restoreMaxParallelism, + int maxParallelismBefore, + int numberKeys, + int numberElements, + int numberElementsExpect, + ClusterClient<?> client, + String restorePath) + throws Exception { + try { + + JobGraph scaledJobGraph = + createJobGraphWithKeyedState( + restoreParallelism, + restoreMaxParallelism, + numberKeys, + numberElements, + true, + 100); + + scaledJobGraph.setSavepointRestoreSettings( + SavepointRestoreSettings.forPath(restorePath)); + + submitJobAndWaitForResult(client, scaledJobGraph, getClass().getClassLoader()); + + Set<Tuple2<Integer, Integer>> actualResult2 = + RescalingTestUtils.CollectionSink.getElementsSet(); + + Set<Tuple2<Integer, Integer>> expectedResult2 = new HashSet<>(); + + for (int key = 0; key < numberKeys; key++) { + int keyGroupIndex = + KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelismBefore); + expectedResult2.add( + Tuple2.of( + KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup( + maxParallelismBefore, restoreParallelism, keyGroupIndex), + key * numberElementsExpect)); + } + assertEquals(expectedResult2, actualResult2); + } finally { + RescalingTestUtils.CollectionSink.clearElementsSet(); + } + } + + private static JobGraph createJobGraphWithKeyedState( + int parallelism, + int maxParallelism, + int numberKeys, + int numberElements, + boolean terminateAfterEmission, + int checkpointingInterval) { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(parallelism); + if (0 < maxParallelism) { + env.getConfig().setMaxParallelism(maxParallelism); + } + env.enableCheckpointing(checkpointingInterval); + env.getCheckpointConfig() + .setExternalizedCheckpointCleanup( + CheckpointConfig.ExternalizedCheckpointCleanup.RETAIN_ON_CANCELLATION); + env.setRestartStrategy(RestartStrategies.noRestart()); + env.getConfig().setUseSnapshotCompression(true); + + DataStream<Integer> input = + env.addSource( + new 
NotifyingDefiniteKeySource( + numberKeys, numberElements, terminateAfterEmission)) + .keyBy( + new KeySelector<Integer, Integer>() { + private static final long serialVersionUID = + -7952298871120320940L; + + @Override + public Integer getKey(Integer value) throws Exception { + return value; + } + }); + RescalingTestUtils.SubtaskIndexFlatMapper.workCompletedLatch = + new CountDownLatch(numberKeys); + + DataStream<Tuple2<Integer, Integer>> result = + input.flatMap(new RescalingTestUtils.SubtaskIndexFlatMapper(numberElements)); + + result.addSink(new RescalingTestUtils.CollectionSink<>()); + + return env.getStreamGraph().getJobGraph(); + } + + private static class NotifyingDefiniteKeySource extends RescalingTestUtils.DefiniteKeySource { + private static final long serialVersionUID = 8120981235081181746L; + + private static CountDownLatch sourceLatch; + + public NotifyingDefiniteKeySource( + int numberKeys, int numberElements, boolean terminateAfterEmission) { + super(numberKeys, numberElements, terminateAfterEmission); + } + + @Override + public void run(SourceContext<Integer> ctx) throws Exception { + if (sourceLatch != null) { + sourceLatch.countDown(); + } + super.run(ctx); + } + } +} diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java index ad43eae..d27d6f4 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java @@ -19,12 +19,9 @@ package org.apache.flink.test.checkpointing; import org.apache.flink.api.common.JobID; -import org.apache.flink.api.common.functions.RichFlatMapFunction; import org.apache.flink.api.common.restartstrategy.RestartStrategies; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; -import 
org.apache.flink.api.common.state.ValueState; -import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.time.Deadline; import org.apache.flink.api.common.time.Time; import org.apache.flink.api.common.typeutils.base.IntSerializer; @@ -50,12 +47,13 @@ import org.apache.flink.streaming.api.checkpoint.ListCheckpointed; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.sink.DiscardingSink; -import org.apache.flink.streaming.api.functions.sink.SinkFunction; import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.test.checkpointing.utils.RescalingTestUtils.CollectionSink; +import org.apache.flink.test.checkpointing.utils.RescalingTestUtils.DefiniteKeySource; +import org.apache.flink.test.checkpointing.utils.RescalingTestUtils.SubtaskIndexFlatMapper; import org.apache.flink.test.util.MiniClusterWithClientResource; import org.apache.flink.testutils.TestingUtils; -import org.apache.flink.util.Collector; import org.apache.flink.util.TestLogger; import org.apache.flink.util.concurrent.FutureUtils; @@ -77,7 +75,6 @@ import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -672,7 +669,7 @@ public class RescalingITCase extends TestLogger { DataStream<Integer> input = env.addSource( - new SubtaskIndexSource( + new DefiniteKeySource( numberKeys, numberElements, terminateAfterEmission)) .keyBy( new KeySelector<Integer, Integer>() { @@ -736,60 +733,7 @@ public class RescalingITCase extends TestLogger { return env.getStreamGraph().getJobGraph(); } - private static 
class SubtaskIndexSource extends RichParallelSourceFunction<Integer> { - - private static final long serialVersionUID = -400066323594122516L; - - private final int numberKeys; - private final int numberElements; - private final boolean terminateAfterEmission; - - protected int counter = 0; - - private boolean running = true; - - SubtaskIndexSource(int numberKeys, int numberElements, boolean terminateAfterEmission) { - - this.numberKeys = numberKeys; - this.numberElements = numberElements; - this.terminateAfterEmission = terminateAfterEmission; - } - - @Override - public void run(SourceContext<Integer> ctx) throws Exception { - final Object lock = ctx.getCheckpointLock(); - final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask(); - - while (running) { - - if (counter < numberElements) { - synchronized (lock) { - for (int value = subtaskIndex; - value < numberKeys; - value += getRuntimeContext().getNumberOfParallelSubtasks()) { - - ctx.collect(value); - } - - counter++; - } - } else { - if (terminateAfterEmission) { - running = false; - } else { - Thread.sleep(100); - } - } - } - } - - @Override - public void cancel() { - running = false; - } - } - - private static class SubtaskIndexNonPartitionedStateSource extends SubtaskIndexSource + private static class SubtaskIndexNonPartitionedStateSource extends DefiniteKeySource implements ListCheckpointed<Integer> { private static final long serialVersionUID = 8388073059042040203L; @@ -814,76 +758,6 @@ public class RescalingITCase extends TestLogger { } } - private static class SubtaskIndexFlatMapper - extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>> - implements CheckpointedFunction { - - private static final long serialVersionUID = 5273172591283191348L; - - private static CountDownLatch workCompletedLatch = new CountDownLatch(1); - - private transient ValueState<Integer> counter; - private transient ValueState<Integer> sum; - - private final int numberElements; - - SubtaskIndexFlatMapper(int 
numberElements) { - this.numberElements = numberElements; - } - - @Override - public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out) - throws Exception { - - int count = counter.value() + 1; - counter.update(count); - - int s = sum.value() + value; - sum.update(s); - - if (count % numberElements == 0) { - out.collect(Tuple2.of(getRuntimeContext().getIndexOfThisSubtask(), s)); - workCompletedLatch.countDown(); - } - } - - @Override - public void snapshotState(FunctionSnapshotContext context) throws Exception { - // all managed, nothing to do. - } - - @Override - public void initializeState(FunctionInitializationContext context) throws Exception { - counter = - context.getKeyedStateStore() - .getState(new ValueStateDescriptor<>("counter", Integer.class, 0)); - sum = - context.getKeyedStateStore() - .getState(new ValueStateDescriptor<>("sum", Integer.class, 0)); - } - } - - private static class CollectionSink<IN> implements SinkFunction<IN> { - - private static Set<Object> elements = - Collections.newSetFromMap(new ConcurrentHashMap<Object, Boolean>()); - - private static final long serialVersionUID = -1652452958040267745L; - - public static <IN> Set<IN> getElementsSet() { - return (Set<IN>) elements; - } - - public static void clearElementsSet() { - elements.clear(); - } - - @Override - public void invoke(IN value) throws Exception { - elements.add(value); - } - } - private static class StateSourceBase extends RichParallelSourceFunction<Integer> { private static final long serialVersionUID = 7512206069681177940L; diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/ResumeCheckpointManuallyITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/ResumeCheckpointManuallyITCase.java index eeda1b0..ab5d7f9 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/ResumeCheckpointManuallyITCase.java +++ 
b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/ResumeCheckpointManuallyITCase.java @@ -18,8 +18,6 @@ package org.apache.flink.test.checkpointing; -import org.apache.flink.api.common.JobID; -import org.apache.flink.api.common.JobStatus; import org.apache.flink.api.common.eventtime.AscendingTimestampsWatermarks; import org.apache.flink.api.common.eventtime.TimestampAssigner; import org.apache.flink.api.common.eventtime.TimestampAssignerSupplier; @@ -45,6 +43,7 @@ import org.apache.flink.streaming.api.windowing.assigners.TumblingEventTimeWindo import org.apache.flink.streaming.api.windowing.time.Time; import org.apache.flink.test.state.ManualWindowSpeedITCase; import org.apache.flink.test.util.MiniClusterWithClientResource; +import org.apache.flink.test.util.TestUtils; import org.apache.flink.util.TestLogger; import org.apache.curator.test.TestingServer; @@ -56,12 +55,7 @@ import javax.annotation.Nullable; import java.io.File; import java.io.IOException; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.Optional; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutionException; -import java.util.stream.Stream; import static org.junit.Assert.assertNotNull; @@ -306,62 +300,11 @@ public class ResumeCheckpointManuallyITCase extends TestLogger { // wait until all sources have been started NotifyingInfiniteTupleSource.countDownLatch.await(); - waitUntilExternalizedCheckpointCreated(checkpointDir, initialJobGraph.getJobID()); + TestUtils.waitUntilExternalizedCheckpointCreated(checkpointDir); client.cancel(initialJobGraph.getJobID()).get(); - waitUntilCanceled(initialJobGraph.getJobID(), client); + TestUtils.waitUntilJobCanceled(initialJobGraph.getJobID(), client); - return getExternalizedCheckpointCheckpointPath(checkpointDir, initialJobGraph.getJobID()); - } - - private static String getExternalizedCheckpointCheckpointPath(File checkpointDir, JobID jobId) - throws IOException { - Optional<Path> checkpoint 
= findExternalizedCheckpoint(checkpointDir, jobId); - if (!checkpoint.isPresent()) { - throw new AssertionError("No complete checkpoint could be found."); - } else { - return checkpoint.get().toString(); - } - } - - private static void waitUntilExternalizedCheckpointCreated(File checkpointDir, JobID jobId) - throws InterruptedException, IOException { - while (true) { - Thread.sleep(50); - Optional<Path> externalizedCheckpoint = - findExternalizedCheckpoint(checkpointDir, jobId); - if (externalizedCheckpoint.isPresent()) { - break; - } - } - } - - private static Optional<Path> findExternalizedCheckpoint(File checkpointDir, JobID jobId) - throws IOException { - try (Stream<Path> checkpoints = - Files.list(checkpointDir.toPath().resolve(jobId.toString()))) { - return checkpoints - .filter(path -> path.getFileName().toString().startsWith("chk-")) - .filter( - path -> { - try (Stream<Path> checkpointFiles = Files.list(path)) { - return checkpointFiles.anyMatch( - child -> - child.getFileName() - .toString() - .contains("meta")); - } catch (IOException ignored) { - return false; - } - }) - .findAny(); - } - } - - private static void waitUntilCanceled(JobID jobId, ClusterClient<?> client) - throws ExecutionException, InterruptedException { - while (client.getJobStatus(jobId).get() != JobStatus.CANCELED) { - Thread.sleep(50); - } + return TestUtils.getMostRecentCompletedCheckpoint(checkpointDir).getAbsolutePath(); } private static JobGraph getJobGraph(StateBackend backend, @Nullable String externalCheckpoint) { diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/RescalingTestUtils.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/RescalingTestUtils.java new file mode 100644 index 0000000..87214b7 --- /dev/null +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/RescalingTestUtils.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.test.checkpointing.utils; + +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.streaming.api.functions.sink.SinkFunction; +import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; +import org.apache.flink.util.Collector; + +import java.util.Collections; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; + +/** Test utilities for rescaling. */ +public class RescalingTestUtils { + + /** A parallel source with definite keys. 
*/ + public static class DefiniteKeySource extends RichParallelSourceFunction<Integer> { + + private static final long serialVersionUID = -400066323594122516L; + + private final int numberKeys; + private final int numberElements; + private final boolean terminateAfterEmission; + + protected int counter = 0; + + private boolean running = true; + + public DefiniteKeySource( + int numberKeys, int numberElements, boolean terminateAfterEmission) { + this.numberKeys = numberKeys; + this.numberElements = numberElements; + this.terminateAfterEmission = terminateAfterEmission; + } + + @Override + public void run(SourceContext<Integer> ctx) throws Exception { + final Object lock = ctx.getCheckpointLock(); + final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask(); + while (running) { + + if (counter < numberElements) { + synchronized (lock) { + for (int value = subtaskIndex; + value < numberKeys; + value += getRuntimeContext().getNumberOfParallelSubtasks()) { + ctx.collect(value); + } + counter++; + } + } else { + if (terminateAfterEmission) { + running = false; + } else { + Thread.sleep(100); + } + } + } + } + + @Override + public void cancel() { + running = false; + } + } + + /** A flatMapper with the index of subtask. 
*/ + public static class SubtaskIndexFlatMapper + extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>> + implements CheckpointedFunction { + + private static final long serialVersionUID = 5273172591283191348L; + + public static CountDownLatch workCompletedLatch = new CountDownLatch(1); + + private transient ValueState<Integer> counter; + private transient ValueState<Integer> sum; + + private final int numberElements; + + public SubtaskIndexFlatMapper(int numberElements) { + this.numberElements = numberElements; + } + + @Override + public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out) + throws Exception { + + int count = counter.value() + 1; + counter.update(count); + + int s = sum.value() + value; + sum.update(s); + + if (count % numberElements == 0) { + out.collect(Tuple2.of(getRuntimeContext().getIndexOfThisSubtask(), s)); + workCompletedLatch.countDown(); + } + } + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + // all managed, nothing to do. + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + counter = + context.getKeyedStateStore() + .getState(new ValueStateDescriptor<>("counter", Integer.class, 0)); + sum = + context.getKeyedStateStore() + .getState(new ValueStateDescriptor<>("sum", Integer.class, 0)); + } + } + + /** Sink for collecting results into a collection. 
*/
+    public static class CollectionSink<IN> implements SinkFunction<IN> {
+
+        // static and shared across ALL sink instances/subtasks: results survive the
+        // serialization of the sink and are collected into one concurrent set.
+        private static final Set<Object> elements =
+                Collections.newSetFromMap(new ConcurrentHashMap<>());
+
+        private static final long serialVersionUID = -1652452958040267745L;
+
+        // Returns the one shared set, viewed as Set<IN>; the cast is unchecked because
+        // the element type is erased — callers must use a consistent IN per test.
+        public static <IN> Set<IN> getElementsSet() {
+            return (Set<IN>) elements;
+        }
+
+        // Must be called between tests, since the backing set is static.
+        public static void clearElementsSet() {
+            elements.clear();
+        }
+
+        @Override
+        public void invoke(IN value) throws Exception {
+            elements.add(value);
+        }
+    }
diff --git a/flink-tests/src/test/java/org/apache/flink/test/util/TestUtils.java b/flink-tests/src/test/java/org/apache/flink/test/util/TestUtils.java
index 1856960..bd69b85 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/util/TestUtils.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/util/TestUtils.java
@@ -18,6 +18,8 @@ package org.apache.flink.test.util;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.JobStatus;
 import org.apache.flink.client.program.ClusterClient;
 import org.apache.flink.core.execution.JobClient;
 import org.apache.flink.runtime.checkpoint.Checkpoints;
@@ -39,6 +41,7 @@ import java.nio.file.Path;
 import java.nio.file.attribute.BasicFileAttributes;
 import java.util.Comparator;
 import java.util.Optional;
+import java.util.concurrent.ExecutionException;
 import static org.apache.flink.runtime.state.filesystem.AbstractFsCheckpointStorageAccess.CHECKPOINT_DIR_PREFIX;
 import static org.apache.flink.runtime.state.filesystem.AbstractFsCheckpointStorageAccess.METADATA_FILE_NAME;
@@ -139,4 +142,23 @@ public class TestUtils {
             return false; // should never happen
         }
     }
+
+    // Polls the checkpoint directory every 50 ms until a completed (externalized)
+    // checkpoint appears. NOTE(review): there is no upper bound on the wait — this
+    // presumably relies on the surrounding test harness timeout; confirm.
+    public static void waitUntilExternalizedCheckpointCreated(File checkpointDir)
+            throws InterruptedException, IOException {
+        while (true) {
+            Thread.sleep(50);
+            Optional<File> externalizedCheckpoint =
+                    getMostRecentCompletedCheckpointMaybe(checkpointDir);
+            if (externalizedCheckpoint.isPresent()) {
+                break;
+            }
+        }
+    }
+
+    // Busy-waits (50 ms polls) until the cluster reports the job as CANCELED.
+    // NOTE(review): also unbounded; terminal states other than CANCELED (e.g. FAILED)
+    // would loop forever — confirm callers only use this after requesting cancellation.
+    public static void waitUntilJobCanceled(JobID jobId, ClusterClient<?> client)
+            throws ExecutionException, InterruptedException {
+        while (client.getJobStatus(jobId).get() != JobStatus.CANCELED) {
+            Thread.sleep(50);
+        }
+    }
 }
