This is an automated email from the ASF dual-hosted git repository.

dmvk pushed a commit to branch release-1.18
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 1cd68f954ee7d5360a98a779bd93c3cc7144c5a6
Author: David Moravek <d...@apache.org>
AuthorDate: Fri Jan 12 10:44:35 2024 +0100

    [FLINK-34063][runtime] Always flush compression buffers, when retrieving stream position during OperatorState snapshot.
---
 .../state/CompressibleFSDataOutputStream.java      |   4 +
 .../state/OperatorStateRestoreOperationTest.java   | 196 +++++++++++++++------
 2 files changed, 146 insertions(+), 54 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompressibleFSDataOutputStream.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompressibleFSDataOutputStream.java
index a573417085a..9c3628d1223 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompressibleFSDataOutputStream.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompressibleFSDataOutputStream.java
@@ -41,6 +41,10 @@ public class CompressibleFSDataOutputStream extends FSDataOutputStream {
 
     @Override
     public long getPos() throws IOException {
+        // Underlying compression involves buffering, so the only way to report correct position is
+        // to flush the underlying stream. This lowers the effectivity of compression, but there is
+        // no other way, since the position is often used as a split point.
+        flush();
         return delegate.getPos();
     }
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateRestoreOperationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateRestoreOperationTest.java
index e0aecd5d723..4ce170aed16 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateRestoreOperationTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateRestoreOperationTest.java
@@ -26,55 +26,60 @@ import org.apache.flink.core.fs.CloseableRegistry;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
 
-import org.junit.jupiter.api.Test;
-
-import javax.annotation.Nullable;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import org.testcontainers.utility.ThrowingFunction;
 
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.assertj.core.api.Assertions.assertThat;
 
 /** Tests for the {@link org.apache.flink.runtime.state.OperatorStateRestoreOperation}. */
 public class OperatorStateRestoreOperationTest {
 
-    @Nullable
+    private static ThrowingFunction<Collection<OperatorStateHandle>, 
OperatorStateBackend>
+            createOperatorStateBackendFactory(
+                    ExecutionConfig cfg,
+                    CloseableRegistry cancelStreamRegistry,
+                    ClassLoader classLoader) {
+        return handles ->
+                new DefaultOperatorStateBackendBuilder(
+                                classLoader, cfg, false, handles, 
cancelStreamRegistry)
+                        .build();
+    }
+
     private static OperatorStateHandle createOperatorStateHandle(
-            ExecutionConfig cfg,
-            CloseableRegistry cancelStreamRegistry,
-            ClassLoader classLoader,
-            List<String> stateNames,
-            List<String> broadcastStateNames)
+            ThrowingFunction<Collection<OperatorStateHandle>, 
OperatorStateBackend>
+                    operatorStateBackendFactory,
+            Map<String, List<String>> listStates,
+            Map<String, Map<String, String>> broadcastStates)
             throws Exception {
-
         try (OperatorStateBackend operatorStateBackend =
-                new DefaultOperatorStateBackendBuilder(
-                                classLoader,
-                                cfg,
-                                false,
-                                Collections.emptyList(),
-                                cancelStreamRegistry)
-                        .build()) {
-            CheckpointStreamFactory streamFactory = new 
MemCheckpointStreamFactory(4096);
-
-            for (String stateName : stateNames) {
-                ListStateDescriptor<String> descriptor =
+                operatorStateBackendFactory.apply(Collections.emptyList())) {
+            final CheckpointStreamFactory streamFactory = new 
MemCheckpointStreamFactory(4096);
+            for (String stateName : listStates.keySet()) {
+                final ListStateDescriptor<String> descriptor =
                         new ListStateDescriptor<>(stateName, String.class);
-                PartitionableListState<String> state =
+                final PartitionableListState<String> state =
                         (PartitionableListState<String>)
                                 operatorStateBackend.getListState(descriptor);
-                state.add("value1");
+                state.addAll(listStates.get(stateName));
             }
-
-            for (String broadcastStateName : broadcastStateNames) {
-                MapStateDescriptor<String, String> descriptor =
-                        new MapStateDescriptor<>(broadcastStateName, 
String.class, String.class);
-                BroadcastState<String, String> state =
+            for (String stateName : broadcastStates.keySet()) {
+                final MapStateDescriptor<String, String> descriptor =
+                        new MapStateDescriptor<>(stateName, String.class, 
String.class);
+                final BroadcastState<String, String> state =
                         operatorStateBackend.getBroadcastState(descriptor);
-                state.put("key1", "value1");
+                state.putAll(broadcastStates.get(stateName));
             }
-
-            SnapshotResult<OperatorStateHandle> result =
+            final SnapshotResult<OperatorStateHandle> result =
                     operatorStateBackend
                             .snapshot(
                                     1,
@@ -82,33 +87,116 @@ public class OperatorStateRestoreOperationTest {
                                     streamFactory,
                                     
CheckpointOptions.forCheckpointWithDefaultLocation())
                             .get();
-            return result.getJobManagerOwnedSnapshot();
+            return Objects.requireNonNull(result.getJobManagerOwnedSnapshot());
         }
     }
 
-    @Test
-    public void 
testRestoringMixedOperatorStateWhenSnapshotCompressionIsEnabled() throws 
Exception {
-        ExecutionConfig cfg = new ExecutionConfig();
-        cfg.setUseSnapshotCompression(true);
-        CloseableRegistry cancelStreamRegistry = new CloseableRegistry();
-        ClassLoader classLoader = this.getClass().getClassLoader();
+    private static void verifyOperatorStateHandle(
+            ThrowingFunction<Collection<OperatorStateHandle>, 
OperatorStateBackend>
+                    operatorStateBackendFactory,
+            Collection<OperatorStateHandle> stateHandles,
+            Map<String, List<String>> listStates,
+            Map<String, Map<String, String>> broadcastStates)
+            throws Exception {
+        try (OperatorStateBackend operatorStateBackend =
+                operatorStateBackendFactory.apply(stateHandles)) {
+            for (String stateName : listStates.keySet()) {
+                final ListStateDescriptor<String> descriptor =
+                        new ListStateDescriptor<>(stateName, String.class);
+                final PartitionableListState<String> state =
+                        (PartitionableListState<String>)
+                                operatorStateBackend.getListState(descriptor);
+                
assertThat(state.get()).containsExactlyElementsOf(listStates.get(stateName));
+            }
+            for (String stateName : listStates.keySet()) {
+                final ListStateDescriptor<String> descriptor =
+                        new ListStateDescriptor<>(stateName, String.class);
+                final PartitionableListState<String> state =
+                        (PartitionableListState<String>)
+                                operatorStateBackend.getListState(descriptor);
+                
assertThat(state.get()).containsExactlyElementsOf(listStates.get(stateName));
+            }
+            for (String stateName : broadcastStates.keySet()) {
+                final MapStateDescriptor<String, String> descriptor =
+                        new MapStateDescriptor<>(stateName, String.class, 
String.class);
+                final BroadcastState<String, String> state =
+                        operatorStateBackend.getBroadcastState(descriptor);
+                final Map<String, String> content = new HashMap<>();
+                state.iterator().forEachRemaining(e -> content.put(e.getKey(), 
e.getValue()));
+                
assertThat(content).containsAllEntriesOf(broadcastStates.get(stateName));
+            }
+        }
+    }
+
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    void testRestoringMixedOperatorState(boolean snapshotCompressionEnabled) 
throws Exception {
+        final ExecutionConfig cfg = new ExecutionConfig();
+        cfg.setUseSnapshotCompression(snapshotCompressionEnabled);
+        ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend>
+                operatorStateBackendFactory =
+                        createOperatorStateBackendFactory(
+                                cfg, new CloseableRegistry(), 
this.getClass().getClassLoader());
+
+        final Map<String, List<String>> listStates = new HashMap<>();
+        listStates.put("s1", Arrays.asList("foo1", "foo2", "foo3"));
+        listStates.put("s2", Arrays.asList("bar1", "bar2", "bar3"));
+
+        final Map<String, Map<String, String>> broadcastStates = new 
HashMap<>();
+        broadcastStates.put("a1", Collections.singletonMap("foo", "bar"));
+        broadcastStates.put("a2", Collections.singletonMap("bar", "foo"));
+
+        final OperatorStateHandle stateHandle =
+                createOperatorStateHandle(operatorStateBackendFactory, 
listStates, broadcastStates);
+
+        verifyOperatorStateHandle(
+                operatorStateBackendFactory,
+                Collections.singletonList(stateHandle),
+                listStates,
+                broadcastStates);
+    }
+
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    void testRestoreAndRescalePartitionedOperatorState(boolean 
snapshotCompressionEnabled)
+            throws Exception {
+        final ExecutionConfig cfg = new ExecutionConfig();
+        cfg.setUseSnapshotCompression(snapshotCompressionEnabled);
+        ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend>
+                operatorStateBackendFactory =
+                        createOperatorStateBackendFactory(
+                                cfg, new CloseableRegistry(), 
this.getClass().getClassLoader());
+
+        final Map<String, List<String>> firstListStates = new HashMap<>();
+        firstListStates.put("s1", Arrays.asList("foo1", "foo2", "foo3"));
+        firstListStates.put("s2", Arrays.asList("bar1", "bar2", "bar3"));
+
+        final Map<String, List<String>> secondListStates = new HashMap<>();
+        secondListStates.put("s1", Arrays.asList("foo4", "foo5", "foo6"));
+        secondListStates.put("s2", Arrays.asList("bar1", "bar2", "bar3"));
 
-        OperatorStateHandle handle =
+        final OperatorStateHandle firstStateHandle =
                 createOperatorStateHandle(
-                        cfg,
-                        cancelStreamRegistry,
-                        classLoader,
-                        Arrays.asList("s1", "s2"),
-                        Collections.singletonList("b2"));
-
-        OperatorStateRestoreOperation operatorStateRestoreOperation =
-                new OperatorStateRestoreOperation(
-                        cancelStreamRegistry,
-                        classLoader,
-                        new HashMap<>(),
-                        new HashMap<>(),
-                        Collections.singletonList(handle));
-
-        operatorStateRestoreOperation.restore();
+                        operatorStateBackendFactory, firstListStates, 
Collections.emptyMap());
+        final OperatorStateHandle secondStateHandle =
+                createOperatorStateHandle(
+                        operatorStateBackendFactory, firstListStates, 
Collections.emptyMap());
+
+        final Map<String, List<String>> mergedListStates = new HashMap<>();
+        for (String stateName : firstListStates.keySet()) {
+            mergedListStates
+                    .computeIfAbsent(stateName, k -> new ArrayList<>())
+                    .addAll(firstListStates.get(stateName));
+        }
+        for (String stateName : secondListStates.keySet()) {
+            mergedListStates
+                    .computeIfAbsent(stateName, k -> new ArrayList<>())
+                    .addAll(firstListStates.get(stateName));
+        }
+        verifyOperatorStateHandle(
+                operatorStateBackendFactory,
+                Arrays.asList(firstStateHandle, secondStateHandle),
+                mergedListStates,
+                Collections.emptyMap());
     }
 }

Reply via email to