This is an automated email from the ASF dual-hosted git repository. dmvk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit b7c314b46acdde0936a0e2d329f48972d1ea3e0b Author: David Moravek <d...@apache.org> AuthorDate: Tue Jan 16 00:07:44 2024 +0100 [FLINK-34063][runtime] Fix OperatorState repartitioning when compression is enabled. We should only write compression headers once, at the end of the "value" part of the serialized stream, to make sure we can always seek to a split point. --- ...efaultOperatorStateBackendSnapshotStrategy.java | 31 ++++----- .../state/OperatorStateRestoreOperation.java | 32 +++++----- .../state/OperatorStateRestoreOperationTest.java | 73 ++++++++++++++++++++-- 3 files changed, 93 insertions(+), 43 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackendSnapshotStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackendSnapshotStrategy.java index 52df5cad736..d81f7b3a1c0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackendSnapshotStrategy.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackendSnapshotStrategy.java @@ -170,35 +170,26 @@ class DefaultOperatorStateBackendSnapshotStrategy final Map<String, OperatorStateHandle.StateMetaInfo> writtenStatesMetaData = CollectionUtil.newHashMapWithExpectedSize(initialMapCapacity); - for (Map.Entry<String, PartitionableListState<?>> entry : - registeredOperatorStatesDeepCopies.entrySet()) { + try (final CompressibleFSDataOutputStream compressedLocalOut = + new CompressibleFSDataOutputStream( + localOut, + compressionDecorator)) { // closes only the outer compression stream + for (Map.Entry<String, PartitionableListState<?>> entry : + registeredOperatorStatesDeepCopies.entrySet()) { - PartitionableListState<?> value = entry.getValue(); - // create the compressed stream for each state to have the compression header for - // each - try (final CompressibleFSDataOutputStream compressedLocalOut = - new CompressibleFSDataOutputStream( - localOut, - compressionDecorator)) { // closes only the outer compression stream + PartitionableListState<?> value = entry.getValue(); long[] partitionOffsets = value.write(compressedLocalOut); OperatorStateHandle.Mode mode = value.getStateMetaInfo().getAssignmentMode(); writtenStatesMetaData.put( entry.getKey(), new OperatorStateHandle.StateMetaInfo(partitionOffsets, mode)); } - } - // ... and the broadcast states themselves ... - for (Map.Entry<String, BackendWritableBroadcastState<?, ?>> entry : - registeredBroadcastStatesDeepCopies.entrySet()) { + // ... and the broadcast states themselves ... + for (Map.Entry<String, BackendWritableBroadcastState<?, ?>> entry : + registeredBroadcastStatesDeepCopies.entrySet()) { - BackendWritableBroadcastState<?, ?> value = entry.getValue(); - // create the compressed stream for each state to have the compression header for - // each - try (final CompressibleFSDataOutputStream compressedLocalOut = - new CompressibleFSDataOutputStream( - localOut, - compressionDecorator)) { // closes only the outer compression stream + BackendWritableBroadcastState<?, ?> value = entry.getValue(); long[] partitionOffsets = {value.write(compressedLocalOut)}; OperatorStateHandle.Mode mode = value.getStateMetaInfo().getAssignmentMode(); writtenStatesMetaData.put( diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateRestoreOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateRestoreOperation.java index 1634641d68b..33ec3d92ab8 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateRestoreOperation.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateRestoreOperation.java @@ -177,24 +177,20 @@ public class OperatorStateRestoreOperation implements RestoreOperation<Void> { restoredBroadcastMetaInfoSnapshots.forEach( stateName -> toRestore.add(stateName.getName())); - for (String stateName : toRestore) { - - final OperatorStateHandle.StateMetaInfo offsets = - stateHandle.getStateNameToPartitionOffsets().get(stateName); - - PartitionableListState<?> listStateForName = - registeredOperatorStates.get(stateName); - final StreamCompressionDecorator compressionDecorator = - backendSerializationProxy.isUsingStateCompression() - ? SnappyStreamCompressionDecorator.INSTANCE - : UncompressedStreamCompressionDecorator.INSTANCE; - // create the compressed stream for each state to have the compression header - // for each - try (final CompressibleFSDataInputStream compressedIn = - new CompressibleFSDataInputStream( - in, - compressionDecorator)) { // closes only the outer compression - // stream + final StreamCompressionDecorator compressionDecorator = + backendSerializationProxy.isUsingStateCompression() + ? SnappyStreamCompressionDecorator.INSTANCE + : UncompressedStreamCompressionDecorator.INSTANCE; + + try (final CompressibleFSDataInputStream compressedIn = + new CompressibleFSDataInputStream( + in, + compressionDecorator)) { // closes only the outer compression stream + for (String stateName : toRestore) { + final OperatorStateHandle.StateMetaInfo offsets = + stateHandle.getStateNameToPartitionOffsets().get(stateName); + PartitionableListState<?> listStateForName = + registeredOperatorStates.get(stateName); if (listStateForName == null) { BackendWritableBroadcastState<?, ?> broadcastStateForName = registeredBroadcastStates.get(stateName); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateRestoreOperationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateRestoreOperationTest.java index 47eca087a1f..a22d7e30c39 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateRestoreOperationTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateRestoreOperationTest.java @@ -24,6 +24,7 @@ import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.core.fs.CloseableRegistry; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner; import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory; import org.junit.jupiter.params.ParameterizedTest; @@ -38,6 +39,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import static org.assertj.core.api.Assertions.assertThat; @@ -157,8 +160,7 @@ public class OperatorStateRestoreOperationTest { @ParameterizedTest @ValueSource(booleans = {true, false}) - void testRestoreAndRescalePartitionedOperatorState(boolean snapshotCompressionEnabled) - throws Exception { + void testMergeOperatorState(boolean snapshotCompressionEnabled) throws Exception { final ExecutionConfig cfg = new ExecutionConfig(); cfg.setUseSnapshotCompression(snapshotCompressionEnabled); final ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> @@ -213,14 +215,75 @@ public class OperatorStateRestoreOperationTest { listStates.put("bufferState", Collections.emptyList()); listStates.put("offsetState", Collections.singletonList("foo")); + final Map<String, Map<String, String>> broadcastStates = new HashMap<>(); + broadcastStates.put("whateverState", Collections.emptyMap()); + final OperatorStateHandle stateHandle = - createOperatorStateHandle( - operatorStateBackendFactory, listStates, Collections.emptyMap()); + createOperatorStateHandle(operatorStateBackendFactory, listStates, broadcastStates); verifyOperatorStateHandle( operatorStateBackendFactory, Collections.singletonList(stateHandle), listStates, - Collections.emptyMap()); + broadcastStates); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testRepartitionOperatorState(boolean snapshotCompressionEnabled) throws Exception { + final ExecutionConfig cfg = new ExecutionConfig(); + cfg.setUseSnapshotCompression(snapshotCompressionEnabled); + final ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> + operatorStateBackendFactory = + createOperatorStateBackendFactory( + cfg, new CloseableRegistry(), this.getClass().getClassLoader()); + + final Map<String, List<String>> listStates = new HashMap<>(); + listStates.put( + "bufferState", + IntStream.range(0, 10).mapToObj(idx -> "foo" + idx).collect(Collectors.toList())); + listStates.put( + "offsetState", + IntStream.range(0, 10).mapToObj(idx -> "bar" + idx).collect(Collectors.toList())); + + final OperatorStateHandle stateHandle = + createOperatorStateHandle( + operatorStateBackendFactory, listStates, Collections.emptyMap()); + + for (int newParallelism : Arrays.asList(1, 2, 5, 10)) { + final RoundRobinOperatorStateRepartitioner partitioner = + new RoundRobinOperatorStateRepartitioner(); + final List<List<OperatorStateHandle>> repartitioned = + partitioner.repartitionState( + Collections.singletonList(Collections.singletonList(stateHandle)), + 1, + newParallelism); + for (int idx = 0; idx < newParallelism; idx++) { + verifyOperatorStateHandle( + operatorStateBackendFactory, + repartitioned.get(idx), + getExpectedSplit(listStates, newParallelism, idx), + Collections.emptyMap()); + } + } + } + + /** + * This is a simplified version of what RR partitioner does, so it only works in case there is + * no remainder. + */ + private static Map<String, List<String>> getExpectedSplit( + Map<String, List<String>> states, int newParallelism, int idx) { + final Map<String, List<String>> newStates = new HashMap<>(); + for (String stateName : states.keySet()) { + final int stateSize = states.get(stateName).size(); + newStates.put( + stateName, + states.get(stateName) + .subList( + idx * stateSize / newParallelism, + (idx + 1) * stateSize / newParallelism)); + } + return newStates; } }