This is an automated email from the ASF dual-hosted git repository.

liyu pushed a commit to branch release-1.9
in repository https://gitbox.apache.org/repos/asf/flink.git

commit bb6de2ddb37287da0def0f9d81dbc4792512e97a
Author: klion26 <[email protected]>
AuthorDate: Thu Apr 2 12:40:14 2020 +0800

    [FLINK-16576][state backends] Fix the problem of wrong mapping between 
stateId and metaInfo in HeapRestoreOperation
---
 .../checkpoint/StateAssignmentOperation.java       |  32 +---
 .../runtime/state/heap/HeapRestoreOperation.java   |   8 +-
 .../flink/runtime/state/StateBackendTestBase.java  | 209 ++++++++++++++-------
 .../util/AbstractStreamOperatorTestHarness.java    |  12 +-
 4 files changed, 155 insertions(+), 106 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index fdb62eb..e8ec90a 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
@@ -461,10 +462,11 @@ public class StateAssignmentOperation {
        /**
         * Extracts certain key group ranges from the given state handles and 
adds them to the collector.
         */
-       private static void extractIntersectingState(
-               Collection<KeyedStateHandle> originalSubtaskStateHandles,
-               KeyGroupRange rangeToExtract,
-               List<KeyedStateHandle> extractedStateCollector) {
+       @VisibleForTesting
+       public static void extractIntersectingState(
+                       Collection<? extends KeyedStateHandle> 
originalSubtaskStateHandles,
+                       KeyGroupRange rangeToExtract,
+                       List<KeyedStateHandle> extractedStateCollector) {
 
                for (KeyedStateHandle keyedStateHandle : 
originalSubtaskStateHandles) {
 
@@ -620,26 +622,4 @@ public class StateAssignmentOperation {
                        newParallelism);
                }
 
-       /**
-        * Determine the subset of {@link KeyGroupsStateHandle 
KeyGroupsStateHandles} with correct
-        * key group index for the given subtask {@link KeyGroupRange}.
-        *
-        * <p>This is publicly visible to be used in tests.
-        */
-       public static List<KeyedStateHandle> getKeyedStateHandles(
-               Collection<? extends KeyedStateHandle> keyedStateHandles,
-               KeyGroupRange subtaskKeyGroupRange) {
-
-               List<KeyedStateHandle> subtaskKeyedStateHandles = new 
ArrayList<>(keyedStateHandles.size());
-
-               for (KeyedStateHandle keyedStateHandle : keyedStateHandles) {
-                       KeyedStateHandle intersectedKeyedStateHandle = 
keyedStateHandle.getIntersection(subtaskKeyGroupRange);
-
-                       if (intersectedKeyedStateHandle != null) {
-                               
subtaskKeyedStateHandles.add(intersectedKeyedStateHandle);
-                       }
-               }
-
-               return subtaskKeyedStateHandles;
-       }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapRestoreOperation.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapRestoreOperation.java
index b2ddbd1..90272b1 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapRestoreOperation.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapRestoreOperation.java
@@ -103,7 +103,6 @@ public class HeapRestoreOperation<K> implements 
RestoreOperation<Void> {
        @Override
        public Void restore() throws Exception {
 
-               final Map<Integer, StateMetaInfoSnapshot> kvStatesById = new 
HashMap<>();
                registeredKVStates.clear();
                registeredPQStates.clear();
 
@@ -148,6 +147,8 @@ public class HeapRestoreOperation<K> implements 
RestoreOperation<Void> {
                                List<StateMetaInfoSnapshot> restoredMetaInfos =
                                        
serializationProxy.getStateMetaInfoSnapshots();
 
+                               final Map<Integer, StateMetaInfoSnapshot> 
kvStatesById = new HashMap<>();
+
                                
createOrCheckStateForMetaInfo(restoredMetaInfos, kvStatesById);
 
                                readStateHandleStateData(
@@ -198,9 +199,8 @@ public class HeapRestoreOperation<K> implements 
RestoreOperation<Void> {
                                                
metaInfoSnapshot.getBackendStateType() + ".");
                        }
 
-                       if (registeredState == null) {
-                               kvStatesById.put(kvStatesById.size(), 
metaInfoSnapshot);
-                       }
+                       // always put metaInfo into kvStatesById: ids are positions in the meta info list of the current 
KeyGroupsStateHandle, so skipping already registered states would shift them
+                       kvStatesById.put(kvStatesById.size(), metaInfoSnapshot);
                }
        }
 
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index 132fd01..6868bdf 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -112,6 +112,7 @@ import java.util.concurrent.TimeUnit;
 import java.util.stream.Stream;
 
 import static java.util.Arrays.asList;
+import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.hamcrest.CoreMatchers.anyOf;
 import static org.hamcrest.CoreMatchers.isA;
 import static org.hamcrest.Matchers.containsInAnyOrder;
@@ -2910,98 +2911,149 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
         * This test verifies that state is correctly assigned to key groups 
and that restore
         * restores the relevant key groups in the backend.
         *
-        * <p>We have ten key groups. Initially, one backend is responsible for 
all ten key groups.
-        * Then we snapshot, split up the state and restore in to backends 
where each is responsible
-        * for five key groups. Then we make sure that the state is only 
available in the correct
-        * backend.
-        * @throws Exception
+        * <p>We have 128 key groups. Initially, four backends with different 
states are responsible for all the key groups equally.
+        * Different backends for the same operator may contain different 
states if we create the state at runtime (such as {@link DeltaTrigger#onElement}).
+        * Then we snapshot, split up the state and restore into 2 backends 
where each is responsible
+        * for 64 key groups. Then we make sure that the state is only 
available in the correct backend.
         */
        @Test
-       public void testKeyGroupSnapshotRestore() throws Exception {
-               final int MAX_PARALLELISM = 10;
+       public void testKeyGroupSnapshotRestoreScaleDown() throws Exception {
+               testKeyGroupSnapshotRestore(4, 2, 128);
+       }
 
-               CheckpointStreamFactory streamFactory = createStreamFactory();
-               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
-               final AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(
-                               IntSerializer.INSTANCE,
-                               MAX_PARALLELISM,
-                               new KeyGroupRange(0, MAX_PARALLELISM - 1),
-                               new DummyEnvironment());
+       /**
+        * This test verifies that state is correctly assigned to key groups 
and that restore
+        * restores the relevant key groups in the backend.
+        *
+        * <p>We have 128 key groups. Initially, two backends with different 
states are responsible for all the key groups equally.
+        * Different backends for the same operator may contain different 
states if we create the state at runtime (such as {@link DeltaTrigger#onElement}).
+        * Then we snapshot, split up the state and restore into 4 backends 
where each is responsible
+        * for 32 key groups. Then we make sure that the state is only 
available in the correct backend.
+        */
+       @Test
+       public void testKeyGroupSnapshotRestoreScaleUp() throws Exception {
+               testKeyGroupSnapshotRestore(2, 4, 128);
+       }
 
-               ValueStateDescriptor<String> kvId = new 
ValueStateDescriptor<>("id", String.class);
+       /**
+        * This test verifies that state is correctly assigned to key groups 
and that restore
+        * restores the relevant key groups in the backend.
+        *
+        * <p>We have 128 key groups. Initially, two backends with different 
states are responsible for all the key groups equally.
+        * Different backends for the same operator may contain different 
states if we create the state at runtime (such as {@link DeltaTrigger#onElement}).
+        * Then we snapshot, split up the state and restore into 2 backends 
where each is responsible
+        * for 64 key groups. Then we make sure that the state is only 
available in the correct backend.
+        */
+       @Test
+       public void testKeyGroupsSnapshotRestoreNoRescale() throws Exception {
+               testKeyGroupSnapshotRestore(2, 2, 128);
+       }
 
-               ValueState<String> state = 
backend.getPartitionedState(VoidNamespace.INSTANCE, 
VoidNamespaceSerializer.INSTANCE, kvId);
+       /**
+        * Similar to testKeyGroupSnapshotRestoreScaleUp, but the key groups 
are distributed unevenly.
+        */
+       @Test
+       public void testKeyGroupsSnapshotRestoreScaleUpUnEvenDistribute() 
throws Exception {
+               testKeyGroupSnapshotRestore(15, 77, 128);
+       }
+
+       /**
+        * Similar to testKeyGroupSnapshotRestoreScaleDown, but the key groups 
are distributed unevenly.
+        */
+       @Test
+       public void testKeyGroupsSnapshotRestoreScaleDownUnEvenDistribute() 
throws Exception {
+               testKeyGroupSnapshotRestore(77, 15, 128);
+       }
 
-               // keys that fall into the first half/second half of the key 
groups, respectively
-               int keyInFirstHalf = 17;
-               int keyInSecondHalf = 42;
-               Random rand = new Random(0);
+       private void testKeyGroupSnapshotRestore(int sourceParallelism, int 
targetParallelism, int maxParallelism) throws Exception {
+               checkArgument(sourceParallelism > 0, "parallelism must be 
positive, current is %s.", sourceParallelism);
+               checkArgument(targetParallelism > 0, "parallelism must be 
positive, current is %s.", targetParallelism);
+               checkArgument(sourceParallelism <= maxParallelism, "Maximum 
parallelism must not be smaller than parallelism.");
+               checkArgument(targetParallelism <= maxParallelism, "Maximum 
parallelism must not be smaller than parallelism.");
 
-               // for each key, determine into which half of the key-group 
space they fall
-               int firstKeyHalf = 
KeyGroupRangeAssignment.assignKeyToParallelOperator(keyInFirstHalf, 
MAX_PARALLELISM, 2);
-               int secondKeyHalf = 
KeyGroupRangeAssignment.assignKeyToParallelOperator(keyInFirstHalf, 
MAX_PARALLELISM, 2);
+               Random random = new Random();
 
-               while (firstKeyHalf == secondKeyHalf) {
-                       keyInSecondHalf = rand.nextInt();
-                       secondKeyHalf = 
KeyGroupRangeAssignment.assignKeyToParallelOperator(keyInSecondHalf, 
MAX_PARALLELISM, 2);
+               CheckpointStreamFactory streamFactory = createStreamFactory();
+               SharedStateRegistry sharedStateRegistry = new 
SharedStateRegistry();
+               List<KeyGroupRange> keyGroupRanges = new ArrayList<>();
+               List<AbstractKeyedStateBackend<Integer>> stateBackends = new 
ArrayList<>();
+               for (int i = 0; i < sourceParallelism; ++i) {
+                       keyGroupRanges.add(KeyGroupRange.of(maxParallelism * i 
/ sourceParallelism, maxParallelism * (i + 1) / sourceParallelism - 1));
+                       
stateBackends.add(createKeyedBackend(IntSerializer.INSTANCE, maxParallelism, 
keyGroupRanges.get(i), new DummyEnvironment()));
                }
 
-               backend.setCurrentKey(keyInFirstHalf);
-               state.update("ShouldBeInFirstHalf");
-
-               backend.setCurrentKey(keyInSecondHalf);
-               state.update("ShouldBeInSecondHalf");
+               List<ValueStateDescriptor<String>> stateDescriptors = new 
ArrayList<>(maxParallelism);
 
+               for (int i = 0; i < maxParallelism; ++i) {
+                       // all states have different name to mock that all the 
parallelisms of one operator have different states.
+                       stateDescriptors.add(new ValueStateDescriptor<>("state" 
+ i, String.class));
+               }
 
-               KeyedStateHandle snapshot = runSnapshot(
-                       backend.snapshot(0, 0, streamFactory, 
CheckpointOptions.forCheckpointWithDefaultLocation()),
-                       sharedStateRegistry);
+               List<Integer> keyInKeyGroups = new ArrayList<>(maxParallelism);
+               List<String> expectedValue = new ArrayList<>(maxParallelism);
+               for (int i = 0; i < sourceParallelism; ++i) {
+                       AbstractKeyedStateBackend<Integer> backend = 
stateBackends.get(i);
+                       KeyGroupRange range = keyGroupRanges.get(i);
+                       for (int j = range.getStartKeyGroup(); j <= 
range.getEndKeyGroup(); ++j) {
+                               ValueState<String> state = 
backend.getPartitionedState(VoidNamespace.INSTANCE, 
VoidNamespaceSerializer.INSTANCE, stateDescriptors.get(j));
+                               int keyInKeyGroup = getKeyInKeyGroup(random, 
maxParallelism, KeyGroupRange.of(j, j));
+                               backend.setCurrentKey(keyInKeyGroup);
+                               keyInKeyGroups.add(keyInKeyGroup);
+                               String updateValue = i + ":" + j;
+                               state.update(updateValue);
+                               expectedValue.add(updateValue);
+                       }
+               }
 
-               List<KeyedStateHandle> firstHalfKeyGroupStates = 
StateAssignmentOperation.getKeyedStateHandles(
-                               Collections.singletonList(snapshot),
-                               
KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(MAX_PARALLELISM, 
2, 0));
+               // snapshot
+               List<KeyedStateHandle> snapshots = new 
ArrayList<>(sourceParallelism);
+               for (int i = 0; i < sourceParallelism; ++i) {
+                       snapshots.add(
+                               runSnapshot(stateBackends.get(i).snapshot(0, 0, 
streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+                               sharedStateRegistry));
+               }
 
-               List<KeyedStateHandle> secondHalfKeyGroupStates = 
StateAssignmentOperation.getKeyedStateHandles(
-                               Collections.singletonList(snapshot),
-                               
KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(MAX_PARALLELISM, 
2, 1));
+               for (int i = 0; i < sourceParallelism; ++i) {
+                       stateBackends.get(i).dispose();
+               }
 
-               backend.dispose();
+               // redistribute the stateHandle
+               List<KeyGroupRange> keyGroupRangesRestore = new ArrayList<>();
+               for (int i = 0; i < targetParallelism; ++i) {
+                       
keyGroupRangesRestore.add(KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(maxParallelism,
 targetParallelism, i));
+               }
+               List<List<KeyedStateHandle>> keyGroupStatesAfterDistribute = 
new ArrayList<>(targetParallelism);
+               for (int i = 0; i < targetParallelism; ++i) {
+                       List<KeyedStateHandle> keyedStateHandles = new 
ArrayList<>();
+                       StateAssignmentOperation.extractIntersectingState(
+                               snapshots,
+                               keyGroupRangesRestore.get(i),
+                               keyedStateHandles);
+                       keyGroupStatesAfterDistribute.add(keyedStateHandles);
+               }
 
-               // backend for the first half of the key group range
-               final AbstractKeyedStateBackend<Integer> firstHalfBackend = 
restoreKeyedBackend(
-                               IntSerializer.INSTANCE,
-                               MAX_PARALLELISM,
-                               new KeyGroupRange(0, 4),
-                               firstHalfKeyGroupStates,
-                               new DummyEnvironment());
+               // restore and verify
+               List<AbstractKeyedStateBackend<Integer>> targetBackends = new 
ArrayList<>(targetParallelism);
 
-               // backend for the second half of the key group range
-               final AbstractKeyedStateBackend<Integer> secondHalfBackend = 
restoreKeyedBackend(
+               for (int i = 0; i < targetParallelism; ++i) {
+                       AbstractKeyedStateBackend<Integer> backend = 
restoreKeyedBackend(
                                IntSerializer.INSTANCE,
-                               MAX_PARALLELISM,
-                               new KeyGroupRange(5, 9),
-                               secondHalfKeyGroupStates,
+                               maxParallelism,
+                               keyGroupRangesRestore.get(i),
+                               keyGroupStatesAfterDistribute.get(i),
                                new DummyEnvironment());
+                       targetBackends.add(backend);
+                       KeyGroupRange range = keyGroupRangesRestore.get(i);
+                       for (int j = range.getStartKeyGroup(); j <= 
range.getEndKeyGroup(); ++j) {
+                               ValueState<String> state = 
targetBackends.get(i).getPartitionedState(VoidNamespace.INSTANCE, 
VoidNamespaceSerializer.INSTANCE, stateDescriptors.get(j));
+                               backend.setCurrentKey(keyInKeyGroups.get(j));
+                               assertEquals(expectedValue.get(j), 
state.value());
+                       }
+               }
 
-
-               ValueState<String> firstHalfState = 
firstHalfBackend.getPartitionedState(VoidNamespace.INSTANCE, 
VoidNamespaceSerializer.INSTANCE, kvId);
-
-               firstHalfBackend.setCurrentKey(keyInFirstHalf);
-               
assertTrue(firstHalfState.value().equals("ShouldBeInFirstHalf"));
-
-               firstHalfBackend.setCurrentKey(keyInSecondHalf);
-               assertTrue(firstHalfState.value() == null);
-
-               ValueState<String> secondHalfState = 
secondHalfBackend.getPartitionedState(VoidNamespace.INSTANCE, 
VoidNamespaceSerializer.INSTANCE, kvId);
-
-               secondHalfBackend.setCurrentKey(keyInFirstHalf);
-               assertTrue(secondHalfState.value() == null);
-
-               secondHalfBackend.setCurrentKey(keyInSecondHalf);
-               
assertTrue(secondHalfState.value().equals("ShouldBeInSecondHalf"));
-
-               firstHalfBackend.dispose();
-               secondHalfBackend.dispose();
+               for (int i = 0; i < targetParallelism; ++i) {
+                       targetBackends.get(i).dispose();
+               }
        }
 
        @Test
@@ -4004,6 +4056,19 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
        }
 
        /**
+        * Returns an Integer key in specified keyGroupRange.
+        */
+       private int getKeyInKeyGroup(Random random, int maxParallelism, 
KeyGroupRange keyGroupRange) {
+               int keyInKG = random.nextInt();
+               int kg = KeyGroupRangeAssignment.assignToKeyGroup(keyInKG, 
maxParallelism);
+               while (!keyGroupRange.contains(kg)) {
+                       keyInKG = random.nextInt();
+                       kg = KeyGroupRangeAssignment.assignToKeyGroup(keyInKG, 
maxParallelism);
+               }
+               return keyInKG;
+       }
+
+       /**
         * Returns the value by getting the serialized value and deserializing 
it
         * if it is not null.
         */
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
index fc24956..c2b300d 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
@@ -329,13 +329,17 @@ public class AbstractStreamOperatorTestHarness<OUT> 
implements AutoCloseable {
 
                KeyGroupRange localKeyGroupRange = 
keyGroupPartitions.get(subtaskIndex);
 
-               List<KeyedStateHandle> localManagedKeyGroupState = 
StateAssignmentOperation.getKeyedStateHandles(
+               List<KeyedStateHandle> localManagedKeyGroupState = new 
ArrayList<>();
+               StateAssignmentOperation.extractIntersectingState(
                        operatorStateHandles.getManagedKeyedState(),
-                       localKeyGroupRange);
+                       localKeyGroupRange,
+                       localManagedKeyGroupState);
 
-               List<KeyedStateHandle> localRawKeyGroupState = 
StateAssignmentOperation.getKeyedStateHandles(
+               List<KeyedStateHandle> localRawKeyGroupState = new 
ArrayList<>();
+               StateAssignmentOperation.extractIntersectingState(
                        operatorStateHandles.getRawKeyedState(),
-                       localKeyGroupRange);
+                       localKeyGroupRange,
+                       localRawKeyGroupState);
 
                StateObjectCollection<OperatorStateHandle> 
managedOperatorStates = operatorStateHandles.getManagedOperatorState();
                Collection<OperatorStateHandle> localManagedOperatorState;

Reply via email to