This is an automated email from the ASF dual-hosted git repository. dmvk pushed a commit to branch release-1.18 in repository https://gitbox.apache.org/repos/asf/flink.git
commit 1cd68f954ee7d5360a98a779bd93c3cc7144c5a6 Author: David Moravek <d...@apache.org> AuthorDate: Fri Jan 12 10:44:35 2024 +0100 [FLINK-34063][runtime] Always flush compression buffers, when retrieving stream position during OperatorState snapshot. --- .../state/CompressibleFSDataOutputStream.java | 4 + .../state/OperatorStateRestoreOperationTest.java | 196 +++++++++++++++------ 2 files changed, 146 insertions(+), 54 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompressibleFSDataOutputStream.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompressibleFSDataOutputStream.java index a573417085a..9c3628d1223 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompressibleFSDataOutputStream.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompressibleFSDataOutputStream.java @@ -41,6 +41,10 @@ public class CompressibleFSDataOutputStream extends FSDataOutputStream { @Override public long getPos() throws IOException { + // Underlying compression involves buffering, so the only way to report correct position is + // to flush the underlying stream. This lowers the effectivity of compression, but there is + // no other way, since the position is often used as a split point. + flush(); return delegate.getPos(); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateRestoreOperationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateRestoreOperationTest.java index e0aecd5d723..4ce170aed16 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateRestoreOperationTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateRestoreOperationTest.java @@ -26,55 +26,60 @@ import org.apache.flink.core.fs.CloseableRegistry; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory; -import org.junit.jupiter.api.Test; - -import javax.annotation.Nullable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.testcontainers.utility.ThrowingFunction; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.assertj.core.api.Assertions.assertThat; /** Tests for the {@link org.apache.flink.runtime.state.OperatorStateRestoreOperation}. */ public class OperatorStateRestoreOperationTest { - @Nullable + private static ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> + createOperatorStateBackendFactory( + ExecutionConfig cfg, + CloseableRegistry cancelStreamRegistry, + ClassLoader classLoader) { + return handles -> + new DefaultOperatorStateBackendBuilder( + classLoader, cfg, false, handles, cancelStreamRegistry) + .build(); + } + private static OperatorStateHandle createOperatorStateHandle( - ExecutionConfig cfg, - CloseableRegistry cancelStreamRegistry, - ClassLoader classLoader, - List<String> stateNames, - List<String> broadcastStateNames) + ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> + operatorStateBackendFactory, + Map<String, List<String>> listStates, + Map<String, Map<String, String>> broadcastStates) throws Exception { - try (OperatorStateBackend operatorStateBackend = - new DefaultOperatorStateBackendBuilder( - classLoader, - cfg, - false, - Collections.emptyList(), - cancelStreamRegistry) - .build()) { - CheckpointStreamFactory streamFactory = new MemCheckpointStreamFactory(4096); - - for (String stateName : stateNames) { - ListStateDescriptor<String> descriptor = + operatorStateBackendFactory.apply(Collections.emptyList())) { + final CheckpointStreamFactory streamFactory = new MemCheckpointStreamFactory(4096); + for (String stateName : listStates.keySet()) { + final ListStateDescriptor<String> descriptor = new ListStateDescriptor<>(stateName, String.class); - PartitionableListState<String> state = + final PartitionableListState<String> state = (PartitionableListState<String>) operatorStateBackend.getListState(descriptor); - state.add("value1"); + state.addAll(listStates.get(stateName)); } - - for (String broadcastStateName : broadcastStateNames) { - MapStateDescriptor<String, String> descriptor = - new MapStateDescriptor<>(broadcastStateName, String.class, String.class); - BroadcastState<String, String> state = + for (String stateName : broadcastStates.keySet()) { + final MapStateDescriptor<String, String> descriptor = + new MapStateDescriptor<>(stateName, String.class, String.class); + final BroadcastState<String, String> state = operatorStateBackend.getBroadcastState(descriptor); - state.put("key1", "value1"); + state.putAll(broadcastStates.get(stateName)); } - - SnapshotResult<OperatorStateHandle> result = + final SnapshotResult<OperatorStateHandle> result = operatorStateBackend .snapshot( 1, @@ -82,33 +87,116 @@ public class OperatorStateRestoreOperationTest { streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()) .get(); - return result.getJobManagerOwnedSnapshot(); + return Objects.requireNonNull(result.getJobManagerOwnedSnapshot()); } } - @Test - public void testRestoringMixedOperatorStateWhenSnapshotCompressionIsEnabled() throws Exception { - ExecutionConfig cfg = new ExecutionConfig(); - cfg.setUseSnapshotCompression(true); - CloseableRegistry cancelStreamRegistry = new CloseableRegistry(); - ClassLoader classLoader = this.getClass().getClassLoader(); + private static void verifyOperatorStateHandle( + ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> + operatorStateBackendFactory, + Collection<OperatorStateHandle> stateHandles, + Map<String, List<String>> listStates, + Map<String, Map<String, String>> broadcastStates) + throws Exception { + try (OperatorStateBackend operatorStateBackend = + operatorStateBackendFactory.apply(stateHandles)) { + for (String stateName : listStates.keySet()) { + final ListStateDescriptor<String> descriptor = + new ListStateDescriptor<>(stateName, String.class); + final PartitionableListState<String> state = + (PartitionableListState<String>) + operatorStateBackend.getListState(descriptor); + assertThat(state.get()).containsExactlyElementsOf(listStates.get(stateName)); + } + for (String stateName : listStates.keySet()) { + final ListStateDescriptor<String> descriptor = + new ListStateDescriptor<>(stateName, String.class); + final PartitionableListState<String> state = + (PartitionableListState<String>) + operatorStateBackend.getListState(descriptor); + assertThat(state.get()).containsExactlyElementsOf(listStates.get(stateName)); + } + for (String stateName : broadcastStates.keySet()) { + final MapStateDescriptor<String, String> descriptor = + new MapStateDescriptor<>(stateName, String.class, String.class); + final BroadcastState<String, String> state = + operatorStateBackend.getBroadcastState(descriptor); + final Map<String, String> content = new HashMap<>(); + state.iterator().forEachRemaining(e -> content.put(e.getKey(), e.getValue())); + assertThat(content).containsAllEntriesOf(broadcastStates.get(stateName)); + } + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testRestoringMixedOperatorState(boolean snapshotCompressionEnabled) throws Exception { + final ExecutionConfig cfg = new ExecutionConfig(); + cfg.setUseSnapshotCompression(snapshotCompressionEnabled); + ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> + operatorStateBackendFactory = + createOperatorStateBackendFactory( + cfg, new CloseableRegistry(), this.getClass().getClassLoader()); + + final Map<String, List<String>> listStates = new HashMap<>(); + listStates.put("s1", Arrays.asList("foo1", "foo2", "foo3")); + listStates.put("s2", Arrays.asList("bar1", "bar2", "bar3")); + + final Map<String, Map<String, String>> broadcastStates = new HashMap<>(); + broadcastStates.put("a1", Collections.singletonMap("foo", "bar")); + broadcastStates.put("a2", Collections.singletonMap("bar", "foo")); + + final OperatorStateHandle stateHandle = + createOperatorStateHandle(operatorStateBackendFactory, listStates, broadcastStates); + + verifyOperatorStateHandle( + operatorStateBackendFactory, + Collections.singletonList(stateHandle), + listStates, + broadcastStates); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testRestoreAndRescalePartitionedOperatorState(boolean snapshotCompressionEnabled) + throws Exception { + final ExecutionConfig cfg = new ExecutionConfig(); + cfg.setUseSnapshotCompression(snapshotCompressionEnabled); + ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> + operatorStateBackendFactory = + createOperatorStateBackendFactory( + cfg, new CloseableRegistry(), this.getClass().getClassLoader()); + + final Map<String, List<String>> firstListStates = new HashMap<>(); + firstListStates.put("s1", Arrays.asList("foo1", "foo2", "foo3")); + firstListStates.put("s2", Arrays.asList("bar1", "bar2", "bar3")); + + final Map<String, List<String>> secondListStates = new HashMap<>(); + secondListStates.put("s1", Arrays.asList("foo4", "foo5", "foo6")); + secondListStates.put("s2", Arrays.asList("bar1", "bar2", "bar3")); - OperatorStateHandle handle = + final OperatorStateHandle firstStateHandle = createOperatorStateHandle( - cfg, - cancelStreamRegistry, - classLoader, - Arrays.asList("s1", "s2"), - Collections.singletonList("b2")); - - OperatorStateRestoreOperation operatorStateRestoreOperation = - new OperatorStateRestoreOperation( - cancelStreamRegistry, - classLoader, - new HashMap<>(), - new HashMap<>(), - Collections.singletonList(handle)); - - operatorStateRestoreOperation.restore(); + operatorStateBackendFactory, firstListStates, Collections.emptyMap()); + final OperatorStateHandle secondStateHandle = + createOperatorStateHandle( + operatorStateBackendFactory, firstListStates, Collections.emptyMap()); + + final Map<String, List<String>> mergedListStates = new HashMap<>(); + for (String stateName : firstListStates.keySet()) { + mergedListStates + .computeIfAbsent(stateName, k -> new ArrayList<>()) + .addAll(firstListStates.get(stateName)); + } + for (String stateName : secondListStates.keySet()) { + mergedListStates + .computeIfAbsent(stateName, k -> new ArrayList<>()) + .addAll(firstListStates.get(stateName)); + } + verifyOperatorStateHandle( + operatorStateBackendFactory, + Arrays.asList(firstStateHandle, secondStateHandle), + mergedListStates, + Collections.emptyMap()); } }