This is an automated email from the ASF dual-hosted git repository.

dmvk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit b7c314b46acdde0936a0e2d329f48972d1ea3e0b
Author: David Moravek <d...@apache.org>
AuthorDate: Tue Jan 16 00:07:44 2024 +0100

    [FLINK-34063][runtime] Fix OperatorState repartitioning when compression is enabled. We should only write compression headers once, at the end of the "value" part of the serialized stream, to make sure we can always seek to a split point.
---
 ...efaultOperatorStateBackendSnapshotStrategy.java | 31 ++++-----
 .../state/OperatorStateRestoreOperation.java       | 32 +++++-----
 .../state/OperatorStateRestoreOperationTest.java   | 73 ++++++++++++++++++++--
 3 files changed, 93 insertions(+), 43 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackendSnapshotStrategy.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackendSnapshotStrategy.java
index 52df5cad736..d81f7b3a1c0 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackendSnapshotStrategy.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackendSnapshotStrategy.java
@@ -170,35 +170,26 @@ class DefaultOperatorStateBackendSnapshotStrategy
             final Map<String, OperatorStateHandle.StateMetaInfo> 
writtenStatesMetaData =
                     
CollectionUtil.newHashMapWithExpectedSize(initialMapCapacity);
 
-            for (Map.Entry<String, PartitionableListState<?>> entry :
-                    registeredOperatorStatesDeepCopies.entrySet()) {
+            try (final CompressibleFSDataOutputStream compressedLocalOut =
+                    new CompressibleFSDataOutputStream(
+                            localOut,
+                            compressionDecorator)) { // closes only the outer 
compression stream
+                for (Map.Entry<String, PartitionableListState<?>> entry :
+                        registeredOperatorStatesDeepCopies.entrySet()) {
 
-                PartitionableListState<?> value = entry.getValue();
-                // create the compressed stream for each state to have the 
compression header for
-                // each
-                try (final CompressibleFSDataOutputStream compressedLocalOut =
-                        new CompressibleFSDataOutputStream(
-                                localOut,
-                                compressionDecorator)) { // closes only the 
outer compression stream
+                    PartitionableListState<?> value = entry.getValue();
                     long[] partitionOffsets = value.write(compressedLocalOut);
                     OperatorStateHandle.Mode mode = 
value.getStateMetaInfo().getAssignmentMode();
                     writtenStatesMetaData.put(
                             entry.getKey(),
                             new 
OperatorStateHandle.StateMetaInfo(partitionOffsets, mode));
                 }
-            }
 
-            // ... and the broadcast states themselves ...
-            for (Map.Entry<String, BackendWritableBroadcastState<?, ?>> entry :
-                    registeredBroadcastStatesDeepCopies.entrySet()) {
+                // ... and the broadcast states themselves ...
+                for (Map.Entry<String, BackendWritableBroadcastState<?, ?>> 
entry :
+                        registeredBroadcastStatesDeepCopies.entrySet()) {
 
-                BackendWritableBroadcastState<?, ?> value = entry.getValue();
-                // create the compressed stream for each state to have the 
compression header for
-                // each
-                try (final CompressibleFSDataOutputStream compressedLocalOut =
-                        new CompressibleFSDataOutputStream(
-                                localOut,
-                                compressionDecorator)) { // closes only the 
outer compression stream
+                    BackendWritableBroadcastState<?, ?> value = 
entry.getValue();
                     long[] partitionOffsets = 
{value.write(compressedLocalOut)};
                     OperatorStateHandle.Mode mode = 
value.getStateMetaInfo().getAssignmentMode();
                     writtenStatesMetaData.put(
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateRestoreOperation.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateRestoreOperation.java
index 1634641d68b..33ec3d92ab8 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateRestoreOperation.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateRestoreOperation.java
@@ -177,24 +177,20 @@ public class OperatorStateRestoreOperation implements 
RestoreOperation<Void> {
                 restoredBroadcastMetaInfoSnapshots.forEach(
                         stateName -> toRestore.add(stateName.getName()));
 
-                for (String stateName : toRestore) {
-
-                    final OperatorStateHandle.StateMetaInfo offsets =
-                            
stateHandle.getStateNameToPartitionOffsets().get(stateName);
-
-                    PartitionableListState<?> listStateForName =
-                            registeredOperatorStates.get(stateName);
-                    final StreamCompressionDecorator compressionDecorator =
-                            backendSerializationProxy.isUsingStateCompression()
-                                    ? SnappyStreamCompressionDecorator.INSTANCE
-                                    : 
UncompressedStreamCompressionDecorator.INSTANCE;
-                    // create the compressed stream for each state to have the 
compression header
-                    // for each
-                    try (final CompressibleFSDataInputStream compressedIn =
-                            new CompressibleFSDataInputStream(
-                                    in,
-                                    compressionDecorator)) { // closes only 
the outer compression
-                        // stream
+                final StreamCompressionDecorator compressionDecorator =
+                        backendSerializationProxy.isUsingStateCompression()
+                                ? SnappyStreamCompressionDecorator.INSTANCE
+                                : 
UncompressedStreamCompressionDecorator.INSTANCE;
+
+                try (final CompressibleFSDataInputStream compressedIn =
+                        new CompressibleFSDataInputStream(
+                                in,
+                                compressionDecorator)) { // closes only the 
outer compression stream
+                    for (String stateName : toRestore) {
+                        final OperatorStateHandle.StateMetaInfo offsets =
+                                
stateHandle.getStateNameToPartitionOffsets().get(stateName);
+                        PartitionableListState<?> listStateForName =
+                                registeredOperatorStates.get(stateName);
                         if (listStateForName == null) {
                             BackendWritableBroadcastState<?, ?> 
broadcastStateForName =
                                     registeredBroadcastStates.get(stateName);
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateRestoreOperationTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateRestoreOperationTest.java
index 47eca087a1f..a22d7e30c39 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateRestoreOperationTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateRestoreOperationTest.java
@@ -24,6 +24,7 @@ import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.state.MapStateDescriptor;
 import org.apache.flink.core.fs.CloseableRegistry;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import 
org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner;
 import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
 
 import org.junit.jupiter.params.ParameterizedTest;
@@ -38,6 +39,8 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -157,8 +160,7 @@ public class OperatorStateRestoreOperationTest {
 
     @ParameterizedTest
     @ValueSource(booleans = {true, false})
-    void testRestoreAndRescalePartitionedOperatorState(boolean 
snapshotCompressionEnabled)
-            throws Exception {
+    void testMergeOperatorState(boolean snapshotCompressionEnabled) throws 
Exception {
         final ExecutionConfig cfg = new ExecutionConfig();
         cfg.setUseSnapshotCompression(snapshotCompressionEnabled);
         final ThrowingFunction<Collection<OperatorStateHandle>, 
OperatorStateBackend>
@@ -213,14 +215,75 @@ public class OperatorStateRestoreOperationTest {
         listStates.put("bufferState", Collections.emptyList());
         listStates.put("offsetState", Collections.singletonList("foo"));
 
+        final Map<String, Map<String, String>> broadcastStates = new 
HashMap<>();
+        broadcastStates.put("whateverState", Collections.emptyMap());
+
         final OperatorStateHandle stateHandle =
-                createOperatorStateHandle(
-                        operatorStateBackendFactory, listStates, 
Collections.emptyMap());
+                createOperatorStateHandle(operatorStateBackendFactory, 
listStates, broadcastStates);
 
         verifyOperatorStateHandle(
                 operatorStateBackendFactory,
                 Collections.singletonList(stateHandle),
                 listStates,
-                Collections.emptyMap());
+                broadcastStates);
+    }
+
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    void testRepartitionOperatorState(boolean snapshotCompressionEnabled) 
throws Exception {
+        final ExecutionConfig cfg = new ExecutionConfig();
+        cfg.setUseSnapshotCompression(snapshotCompressionEnabled);
+        final ThrowingFunction<Collection<OperatorStateHandle>, 
OperatorStateBackend>
+                operatorStateBackendFactory =
+                        createOperatorStateBackendFactory(
+                                cfg, new CloseableRegistry(), 
this.getClass().getClassLoader());
+
+        final Map<String, List<String>> listStates = new HashMap<>();
+        listStates.put(
+                "bufferState",
+                IntStream.range(0, 10).mapToObj(idx -> "foo" + 
idx).collect(Collectors.toList()));
+        listStates.put(
+                "offsetState",
+                IntStream.range(0, 10).mapToObj(idx -> "bar" + 
idx).collect(Collectors.toList()));
+
+        final OperatorStateHandle stateHandle =
+                createOperatorStateHandle(
+                        operatorStateBackendFactory, listStates, 
Collections.emptyMap());
+
+        for (int newParallelism : Arrays.asList(1, 2, 5, 10)) {
+            final RoundRobinOperatorStateRepartitioner partitioner =
+                    new RoundRobinOperatorStateRepartitioner();
+            final List<List<OperatorStateHandle>> repartitioned =
+                    partitioner.repartitionState(
+                            
Collections.singletonList(Collections.singletonList(stateHandle)),
+                            1,
+                            newParallelism);
+            for (int idx = 0; idx < newParallelism; idx++) {
+                verifyOperatorStateHandle(
+                        operatorStateBackendFactory,
+                        repartitioned.get(idx),
+                        getExpectedSplit(listStates, newParallelism, idx),
+                        Collections.emptyMap());
+            }
+        }
+    }
+
+    /**
+     * This is a simplified version of what RR partitioner does, so it only 
works in case there is
+     * no remainder.
+     */
+    private static Map<String, List<String>> getExpectedSplit(
+            Map<String, List<String>> states, int newParallelism, int idx) {
+        final Map<String, List<String>> newStates = new HashMap<>();
+        for (String stateName : states.keySet()) {
+            final int stateSize = states.get(stateName).size();
+            newStates.put(
+                    stateName,
+                    states.get(stateName)
+                            .subList(
+                                    idx * stateSize / newParallelism,
+                                    (idx + 1) * stateSize / newParallelism));
+        }
+        return newStates;
     }
 }

Reply via email to