[FLINK-5265] Introduce state handle replication mode for CheckpointCoordinator
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/1020ba2c Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/1020ba2c Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/1020ba2c Branch: refs/heads/master Commit: 1020ba2c9cfc1d01703e97c72e20a922bae0732d Parents: 8b1b4a1 Author: Stefan Richter <[email protected]> Authored: Sat Dec 3 02:42:25 2016 +0100 Committer: Aljoscha Krettek <[email protected]> Committed: Fri Jan 13 21:29:19 2017 +0100 ---------------------------------------------------------------------- .../api/common/state/OperatorStateStore.java | 37 +++- .../RoundRobinOperatorStateRepartitioner.java | 133 ++++++++++++--- .../checkpoint/StateAssignmentOperation.java | 15 +- .../savepoint/SavepointV1Serializer.java | 24 ++- .../state/DefaultOperatorStateBackend.java | 160 +++++++++++------ .../OperatorBackendSerializationProxy.java | 51 +++++- .../OperatorStateCheckpointOutputStream.java | 10 +- .../runtime/state/OperatorStateHandle.java | 86 +++++++++- .../state/StateInitializationContextImpl.java | 18 +- .../checkpoint/CheckpointCoordinatorTest.java | 119 +++++++++---- .../checkpoint/savepoint/SavepointV1Test.java | 7 +- .../runtime/state/OperatorStateBackendTest.java | 54 +++++- .../runtime/state/OperatorStateHandleTest.java | 39 +++++ ...OperatorStateOutputCheckpointStreamTest.java | 11 +- .../runtime/state/SerializationProxiesTest.java | 63 ++++++- .../StateInitializationContextImplTest.java | 6 +- .../tasks/InterruptSensitiveRestoreTest.java | 6 +- .../test/checkpointing/RescalingITCase.java | 171 +++++++++++-------- 18 files changed, 787 insertions(+), 223 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java ---------------------------------------------------------------------- diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java index 43dbe51..87a7759 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java @@ -30,8 +30,22 @@ import java.util.Set; public interface OperatorStateStore { /** - * Creates a state descriptor of the given name that uses Java serialization to persist the - * state. + * Creates (or restores) a list state. Each state is registered under a unique name. + * The provided serializer is used to de/serialize the state in case of checkpointing (snapshot/restore). + * + * The items in the list are repartitionable by the system in case of changed operator parallelism. + * + * @param stateDescriptor The descriptor for this state, providing a name and serializer. + * @param <S> The generic type of the state + * + * @return A list for all state partitions. + * @throws Exception + */ + <S> ListState<S> getOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception; + + /** + * Creates a state of the given name that uses Java serialization to persist the state. The items in the list + * are repartitionable by the system in case of changed operator parallelism. * * <p>This is a simple convenience method. For more flexibility on how state serialization * should happen, use the {@link #getOperatorState(ListStateDescriptor)} method. @@ -46,13 +60,28 @@ public interface OperatorStateStore { * Creates (or restores) a list state. Each state is registered under a unique name. * The provided serializer is used to de/serialize the state in case of checkpointing (snapshot/restore). * + * On restore, all items in the list are broadcasted to all parallel operator instances. + * * @param stateDescriptor The descriptor for this state, providing a name and serializer. * @param <S> The generic type of the state - * + * * @return A list for all state partitions. * @throws Exception */ - <S> ListState<S> getOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception; + <S> ListState<S> getBroadcastOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception; + + /** + * Creates a state of the given name that uses Java serialization to persist the state. On restore, all items + * in the list are broadcasted to all parallel operator instances. + * + * <p>This is a simple convenience method. For more flexibility on how state serialization + * should happen, use the {@link #getBroadcastOperatorState(ListStateDescriptor)} method. + * + * @param stateName The name of state to create + * @return A list state using Java serialization to serialize state objects. + * @throws Exception + */ + <T extends Serializable> ListState<T> getBroadcastSerializableListState(String stateName) throws Exception; /** * Returns a set with the names of all currently registered states. http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java index 16a7e27..046096f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java @@ -26,6 +26,7 @@ import org.apache.flink.util.Preconditions; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.EnumMap; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -47,8 +48,7 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart Preconditions.checkArgument(parallelism > 0); // Reorganize: group by (State Name -> StreamStateHandle + Offsets) - Map<String, List<Tuple2<StreamStateHandle, long[]>>> nameToState = - groupByStateName(previousParallelSubtaskStates); + GroupByStateNameResults nameToStateByMode = groupByStateName(previousParallelSubtaskStates); if (OPTIMIZE_MEMORY_USE) { previousParallelSubtaskStates.clear(); // free for GC at to cost that old handles are no longer available @@ -59,7 +59,7 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart // Do the actual repartitioning for all named states List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList = - repartition(nameToState, parallelism); + repartition(nameToStateByMode, parallelism); for (int i = 0; i < mergeMapList.size(); ++i) { result.add(i, new ArrayList<>(mergeMapList.get(i).values())); @@ -71,16 +71,33 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart /** * Group by the different named states. */ - private Map<String, List<Tuple2<StreamStateHandle, long[]>>> groupByStateName( + @SuppressWarnings("unchecked, rawtype") + private GroupByStateNameResults groupByStateName( List<OperatorStateHandle> previousParallelSubtaskStates) { - //Reorganize: group by (State Name -> StreamStateHandle + Offsets) - Map<String, List<Tuple2<StreamStateHandle, long[]>>> nameToState = new HashMap<>(); + //Reorganize: group by (State Name -> StreamStateHandle + StateMetaInfo) + EnumMap<OperatorStateHandle.Mode, + Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>> nameToStateByMode = + new EnumMap<>(OperatorStateHandle.Mode.class); + + for (OperatorStateHandle.Mode mode : OperatorStateHandle.Mode.values()) { + Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> map = new HashMap<>(); + nameToStateByMode.put( + mode, + new HashMap<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>()); + } + for (OperatorStateHandle psh : previousParallelSubtaskStates) { - for (Map.Entry<String, long[]> e : psh.getStateNameToPartitionOffsets().entrySet()) { + for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e : + psh.getStateNameToPartitionOffsets().entrySet()) { + OperatorStateHandle.StateMetaInfo metaInfo = e.getValue(); + + Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> nameToState = + nameToStateByMode.get(metaInfo.getDistributionMode()); - List<Tuple2<StreamStateHandle, long[]>> stateLocations = nameToState.get(e.getKey()); + List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> stateLocations = + nameToState.get(e.getKey()); if (stateLocations == null) { stateLocations = new ArrayList<>(); @@ -90,32 +107,40 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart stateLocations.add(new Tuple2<>(psh.getDelegateStateHandle(), e.getValue())); } } - return nameToState; + + return new GroupByStateNameResults(nameToStateByMode); } /** * Repartition all named states. */ private List<Map<StreamStateHandle, OperatorStateHandle>> repartition( - Map<String, List<Tuple2<StreamStateHandle, long[]>>> nameToState, int parallelism) { + GroupByStateNameResults nameToStateByMode, + int parallelism) { // We will use this to merge w.r.t. StreamStateHandles for each parallel subtask inside the maps List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList = new ArrayList<>(parallelism); + // Initialize for (int i = 0; i < parallelism; ++i) { mergeMapList.add(new HashMap<StreamStateHandle, OperatorStateHandle>()); } - int startParallelOP = 0; + // Start with the state handles we distribute round robin by splitting by offsets + Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> distributeNameToState = + nameToStateByMode.getByMode(OperatorStateHandle.Mode.SPLIT_DISTRIBUTE); + + int startParallelOp = 0; // Iterate all named states and repartition one named state at a time per iteration - for (Map.Entry<String, List<Tuple2<StreamStateHandle, long[]>>> e : nameToState.entrySet()) { + for (Map.Entry<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> e : + distributeNameToState.entrySet()) { - List<Tuple2<StreamStateHandle, long[]>> current = e.getValue(); + List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> current = e.getValue(); // Determine actual number of partitions for this named state int totalPartitions = 0; - for (Tuple2<StreamStateHandle, long[]> offsets : current) { - totalPartitions += offsets.f1.length; + for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> offsets : current) { + totalPartitions += offsets.f1.getOffsets().length; } // Repartition the state across the parallel operator instances @@ -124,12 +149,12 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart int baseFraction = totalPartitions / parallelism; int remainder = totalPartitions % parallelism; - int newStartParallelOp = startParallelOP; + int newStartParallelOp = startParallelOp; for (int i = 0; i < parallelism; ++i) { // Preparation: calculate the actual index considering wrap around - int parallelOpIdx = (i + startParallelOP) % parallelism; + int parallelOpIdx = (i + startParallelOp) % parallelism; // Now calculate the number of partitions we will assign to the parallel instance in this round ... int numberOfPartitionsToAssign = baseFraction; @@ -146,11 +171,14 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart } // Now start collection the partitions for the parallel instance into this list - List<Tuple2<StreamStateHandle, long[]>> parallelOperatorState = new ArrayList<>(); + List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> parallelOperatorState = + new ArrayList<>(); while (numberOfPartitionsToAssign > 0) { - Tuple2<StreamStateHandle, long[]> handleWithOffsets = current.get(lstIdx); - long[] offsets = handleWithOffsets.f1; + Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithOffsets = + current.get(lstIdx); + + long[] offsets = handleWithOffsets.f1.getOffsets(); int remaining = offsets.length - offsetIdx; // Repartition offsets long[] offs; @@ -166,25 +194,74 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart ++lstIdx; } - parallelOperatorState.add( - new Tuple2<>(handleWithOffsets.f0, offs)); + parallelOperatorState.add(new Tuple2<>( + handleWithOffsets.f0, + new OperatorStateHandle.StateMetaInfo(offs, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE))); numberOfPartitionsToAssign -= remaining; // As a last step we merge partitions that use the same StreamStateHandle in a single // OperatorStateHandle Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(parallelOpIdx); - OperatorStateHandle psh = mergeMap.get(handleWithOffsets.f0); - if (psh == null) { - psh = new OperatorStateHandle(new HashMap<String, long[]>(), handleWithOffsets.f0); - mergeMap.put(handleWithOffsets.f0, psh); + OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithOffsets.f0); + if (operatorStateHandle == null) { + operatorStateHandle = new OperatorStateHandle( + new HashMap<String, OperatorStateHandle.StateMetaInfo>(), + handleWithOffsets.f0); + + mergeMap.put(handleWithOffsets.f0, operatorStateHandle); } - psh.getStateNameToPartitionOffsets().put(e.getKey(), offs); + operatorStateHandle.getStateNameToPartitionOffsets().put( + e.getKey(), + new OperatorStateHandle.StateMetaInfo(offs, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); } } - startParallelOP = newStartParallelOp; + startParallelOp = newStartParallelOp; e.setValue(null); } + + // Now we also add the state handles marked for broadcast to all parallel instances + Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> broadcastNameToState = + nameToStateByMode.getByMode(OperatorStateHandle.Mode.BROADCAST); + + for (int i = 0; i < parallelism; ++i) { + + Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(i); + + for (Map.Entry<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> e : + broadcastNameToState.entrySet()) { + + List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> current = e.getValue(); + + for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithMetaInfo : current) { + OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithMetaInfo.f0); + if (operatorStateHandle == null) { + operatorStateHandle = new OperatorStateHandle( + new HashMap<String, OperatorStateHandle.StateMetaInfo>(), + handleWithMetaInfo.f0); + + mergeMap.put(handleWithMetaInfo.f0, operatorStateHandle); + } + operatorStateHandle.getStateNameToPartitionOffsets().put(e.getKey(), handleWithMetaInfo.f1); + } + } + } return mergeMapList; } + + private static final class GroupByStateNameResults { + private final EnumMap<OperatorStateHandle.Mode, + Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>> byMode; + + public GroupByStateNameResults( + EnumMap<OperatorStateHandle.Mode, + Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>> byMode) { + this.byMode = Preconditions.checkNotNull(byMode); + } + + public Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> getByMode( + OperatorStateHandle.Mode mode) { + return byMode.get(mode); + } + } } http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java index 2e05a85..f11f69b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java @@ -338,9 +338,22 @@ public class StateAssignmentOperation { chainOpParallelStates, newParallelism); } else { - List<Collection<OperatorStateHandle>> repackStream = new ArrayList<>(newParallelism); for (OperatorStateHandle operatorStateHandle : chainOpParallelStates) { + + Map<String, OperatorStateHandle.StateMetaInfo> partitionOffsets = + operatorStateHandle.getStateNameToPartitionOffsets(); + + for (OperatorStateHandle.StateMetaInfo metaInfo : partitionOffsets.values()) { + + // if we find any broadcast state, we cannot take the shortcut and need to go through repartitioning + if (OperatorStateHandle.Mode.BROADCAST.equals(metaInfo.getDistributionMode())) { + return opStateRepartitioner.repartitionState( + chainOpParallelStates, + newParallelism); + } + } + repackStream.add(Collections.singletonList(operatorStateHandle)); } return repackStream; http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java index 48324ca..ba1949a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java @@ -250,11 +250,18 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> { if (stateHandle != null) { dos.writeByte(PARTITIONABLE_OPERATOR_STATE_HANDLE); - Map<String, long[]> partitionOffsetsMap = stateHandle.getStateNameToPartitionOffsets(); + Map<String, OperatorStateHandle.StateMetaInfo> partitionOffsetsMap = + stateHandle.getStateNameToPartitionOffsets(); dos.writeInt(partitionOffsetsMap.size()); - for (Map.Entry<String, long[]> entry : partitionOffsetsMap.entrySet()) { + for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry : partitionOffsetsMap.entrySet()) { dos.writeUTF(entry.getKey()); - long[] offsets = entry.getValue(); + + OperatorStateHandle.StateMetaInfo stateMetaInfo = entry.getValue(); + + int mode = stateMetaInfo.getDistributionMode().ordinal(); + dos.writeByte(mode); + + long[] offsets = stateMetaInfo.getOffsets(); dos.writeInt(offsets.length); for (long offset : offsets) { dos.writeLong(offset); @@ -274,14 +281,21 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> { return null; } else if (PARTITIONABLE_OPERATOR_STATE_HANDLE == type) { int mapSize = dis.readInt(); - Map<String, long[]> offsetsMap = new HashMap<>(mapSize); + Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(mapSize); for (int i = 0; i < mapSize; ++i) { String key = dis.readUTF(); + + int modeOrdinal = dis.readByte(); + OperatorStateHandle.Mode mode = OperatorStateHandle.Mode.values()[modeOrdinal]; + long[] offsets = new long[dis.readInt()]; for (int j = 0; j < offsets.length; ++j) { offsets[j] = dis.readLong(); } - offsetsMap.put(key, offsets); + + OperatorStateHandle.StateMetaInfo metaInfo = + new OperatorStateHandle.StateMetaInfo(offsets, mode); + offsetsMap.put(key, metaInfo); } StreamStateHandle stateHandle = deserializeStreamStateHandle(dis); return new OperatorStateHandle(offsetsMap, stateHandle); http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java index 10bb409..6c65088 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.state; import org.apache.commons.io.IOUtils; +import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; @@ -44,6 +45,7 @@ import java.util.concurrent.RunnableFuture; /** * Default implementation of OperatorStateStore that provides the ability to make snapshots. */ +@Internal public class DefaultOperatorStateBackend implements OperatorStateBackend { /** The default namespace for state in cases where no state name is provided */ @@ -62,14 +64,46 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend { this.registeredStates = new HashMap<>(); } + @Override + public Set<String> getRegisteredStateNames() { + return registeredStates.keySet(); + } + + @Override + public void close() throws IOException { + closeStreamOnCancelRegistry.close(); + } + + @Override + public void dispose() { + registeredStates.clear(); + } + @SuppressWarnings("unchecked") @Override public <T extends Serializable> ListState<T> getSerializableListState(String stateName) throws Exception { return (ListState<T>) getOperatorState(new ListStateDescriptor<>(stateName, javaSerializer)); } - + @Override public <S> ListState<S> getOperatorState(ListStateDescriptor<S> stateDescriptor) throws IOException { + return getOperatorState(stateDescriptor, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE); + } + + @SuppressWarnings("unchecked") + @Override + public <T extends Serializable> ListState<T> getBroadcastSerializableListState(String stateName) throws Exception { + return (ListState<T>) getBroadcastOperatorState(new ListStateDescriptor<>(stateName, javaSerializer)); + } + + @Override + public <S> ListState<S> getBroadcastOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception { + return getOperatorState(stateDescriptor, OperatorStateHandle.Mode.BROADCAST); + } + + private <S> ListState<S> getOperatorState( + ListStateDescriptor<S> stateDescriptor, + OperatorStateHandle.Mode mode) throws IOException { Preconditions.checkNotNull(stateDescriptor); @@ -81,10 +115,18 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend { if (null == partitionableListState) { - partitionableListState = new PartitionableListState<>(name, partitionStateSerializer); + partitionableListState = new PartitionableListState<>( + name, + partitionStateSerializer, + mode); + registeredStates.put(name, partitionableListState); } else { Preconditions.checkState( + partitionableListState.getAssignmentMode().equals(mode), + "Incompatible assignment mode. Provided: " + mode + ", expected: " + + partitionableListState.getAssignmentMode()); + Preconditions.checkState( partitionableListState.getPartitionStateSerializer(). isCompatibleWith(stateDescriptor.getSerializer()), "Incompatible type serializers. Provided: " + stateDescriptor.getSerializer() + @@ -97,16 +139,21 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend { private static <S> void deserializeStateValues( PartitionableListState<S> stateListForName, FSDataInputStream in, - long[] offsets) throws IOException { - - DataInputView div = new DataInputViewStreamWrapper(in); - TypeSerializer<S> serializer = stateListForName.getPartitionStateSerializer(); - for (long offset : offsets) { - in.seek(offset); - stateListForName.add(serializer.deserialize(div)); + OperatorStateHandle.StateMetaInfo metaInfo) throws IOException { + + if (null != metaInfo) { + long[] offsets = metaInfo.getOffsets(); + if (null != offsets) { + DataInputView div = new DataInputViewStreamWrapper(in); + TypeSerializer<S> serializer = stateListForName.getPartitionStateSerializer(); + for (long offset : offsets) { + in.seek(offset); + stateListForName.add(serializer.deserialize(div)); + } + } } } - + @Override public RunnableFuture<OperatorStateHandle> snapshot( long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception { @@ -123,11 +170,12 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend { OperatorBackendSerializationProxy.StateMetaInfo<?> metaInfo = new OperatorBackendSerializationProxy.StateMetaInfo<>( state.getName(), - state.getPartitionStateSerializer()); + state.getPartitionStateSerializer(), + state.getAssignmentMode()); metaInfoList.add(metaInfo); } - Map<String, long[]> writtenStatesMetaData = new HashMap<>(registeredStates.size()); + Map<String, OperatorStateHandle.StateMetaInfo> writtenStatesMetaData = new HashMap<>(registeredStates.size()); CheckpointStreamFactory.CheckpointStateOutputStream out = streamFactory. createCheckpointStateOutputStream(checkpointId, timestamp); @@ -145,8 +193,10 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend { dov.writeInt(registeredStates.size()); for (Map.Entry<String, PartitionableListState<?>> entry : registeredStates.entrySet()) { - long[] partitionOffsets = entry.getValue().write(out); - writtenStatesMetaData.put(entry.getKey(), partitionOffsets); + PartitionableListState<?> value = entry.getValue(); + long[] partitionOffsets = value.write(out); + OperatorStateHandle.Mode mode = value.getAssignmentMode(); + writtenStatesMetaData.put(entry.getKey(), new OperatorStateHandle.StateMetaInfo(partitionOffsets, mode)); } OperatorStateHandle handle = new OperatorStateHandle(writtenStatesMetaData, out.closeAndGetHandle()); @@ -193,7 +243,8 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend { if (null == listState) { listState = new PartitionableListState<>( stateMetaInfo.getName(), - stateMetaInfo.getStateSerializer()); + stateMetaInfo.getStateSerializer(), + stateMetaInfo.getMode()); registeredStates.put(listState.getName(), listState); } else { @@ -205,7 +256,9 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend { } // Restore all the state in PartitionableListStates - for (Map.Entry<String, long[]> nameToOffsets : stateHandle.getStateNameToPartitionOffsets().entrySet()) { + for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> nameToOffsets : + stateHandle.getStateNameToPartitionOffsets().entrySet()) { + PartitionableListState<?> stateListForName = registeredStates.get(nameToOffsets.getKey()); Preconditions.checkState(null != stateListForName, "Found state without " + @@ -222,60 +275,40 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend { } } - @Override - public void dispose() { - registeredStates.clear(); - } - - @Override - public Set<String> getRegisteredStateNames() { - return registeredStates.keySet(); - } - - @Override - public void close() throws IOException { - closeStreamOnCancelRegistry.close(); - } - static final class PartitionableListState<S> implements ListState<S> { - private final List<S> internalList; private final String name; private final TypeSerializer<S> partitionStateSerializer; + private final OperatorStateHandle.Mode assignmentMode; + private final List<S> internalList; - public PartitionableListState(String name, TypeSerializer<S> partitionStateSerializer) { - this.internalList = new ArrayList<>(); - this.partitionStateSerializer = Preconditions.checkNotNull(partitionStateSerializer); - this.name = Preconditions.checkNotNull(name); - } - - public long[] write(FSDataOutputStream out) throws IOException { - - long[] partitionOffsets = new long[internalList.size()]; - - DataOutputView dov = new DataOutputViewStreamWrapper(out); - - for (int i = 0; i < internalList.size(); ++i) { - S element = internalList.get(i); - partitionOffsets[i] = out.getPos(); - partitionStateSerializer.serialize(element, dov); - } - - return partitionOffsets; - } + public PartitionableListState( + String name, + TypeSerializer<S> partitionStateSerializer, + OperatorStateHandle.Mode assignmentMode) { - public List<S> getInternalList() { - return internalList; + this.name = Preconditions.checkNotNull(name); + this.partitionStateSerializer = Preconditions.checkNotNull(partitionStateSerializer); + this.assignmentMode = Preconditions.checkNotNull(assignmentMode); + this.internalList = new ArrayList<>(); } public String getName() { return name; } + public OperatorStateHandle.Mode getAssignmentMode() { + return assignmentMode; + } + public TypeSerializer<S> getPartitionStateSerializer() { return partitionStateSerializer; } + public List<S> getInternalList() { + return internalList; + } + @Override public void clear() { internalList.clear(); @@ -294,8 +327,25 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend { @Override public String toString() { return "PartitionableListState{" + - "listState=" + internalList + + "name='" + name + '\'' + + ", assignmentMode=" + assignmentMode + + ", internalList=" + internalList + '}'; } + + public long[] write(FSDataOutputStream out) throws IOException { + + long[] partitionOffsets = new long[internalList.size()]; + + DataOutputView dov = new DataOutputViewStreamWrapper(out); + + for (int i = 0; i < internalList.size(); ++i) { + S element = internalList.get(i); + partitionOffsets[i] = out.getPos(); + partitionStateSerializer.serialize(element, dov); + } + + return partitionOffsets; + } } } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendSerializationProxy.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendSerializationProxy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendSerializationProxy.java index 61df979..d571dcc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendSerializationProxy.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendSerializationProxy.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.state; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.typeutils.runtime.DataInputViewStream; import org.apache.flink.api.java.typeutils.runtime.DataOutputViewStream; @@ -91,15 +92,19 @@ public class OperatorBackendSerializationProxy extends VersionedIOReadableWritab private String name; private TypeSerializer<S> stateSerializer; + private OperatorStateHandle.Mode mode; + private ClassLoader userClassLoader; - private StateMetaInfo(ClassLoader userClassLoader) { + @VisibleForTesting + public StateMetaInfo(ClassLoader userClassLoader) { this.userClassLoader = Preconditions.checkNotNull(userClassLoader); } - public StateMetaInfo(String name, TypeSerializer<S> stateSerializer) { + public StateMetaInfo(String name, TypeSerializer<S> stateSerializer, OperatorStateHandle.Mode mode) { this.name = Preconditions.checkNotNull(name); this.stateSerializer = Preconditions.checkNotNull(stateSerializer); + this.mode = Preconditions.checkNotNull(mode); } public String getName() { @@ -118,9 +123,18 @@ public class OperatorBackendSerializationProxy extends VersionedIOReadableWritab this.stateSerializer = stateSerializer; } + public OperatorStateHandle.Mode getMode() { + return mode; + } + + public void setMode(OperatorStateHandle.Mode mode) { + this.mode = mode; + } + @Override public void write(DataOutputView out) throws IOException { out.writeUTF(getName()); + out.writeByte(getMode().ordinal()); DataOutputViewStream dos = new DataOutputViewStream(out); InstantiationUtil.serializeObject(dos, getStateSerializer()); } @@ -128,6 +142,7 @@ public class OperatorBackendSerializationProxy extends VersionedIOReadableWritab @Override public void read(DataInputView in) throws IOException { setName(in.readUTF()); + setMode(OperatorStateHandle.Mode.values()[in.readByte()]); DataInputViewStream dis = new DataInputViewStream(in); try { TypeSerializer<S> stateSerializer = InstantiationUtil.deserializeObject(dis, userClassLoader); @@ -136,5 +151,37 @@ public class OperatorBackendSerializationProxy extends VersionedIOReadableWritab throw new IOException(exception); } } + + @Override + public boolean equals(Object o) { + + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + StateMetaInfo<?> metaInfo = (StateMetaInfo<?>) o; + + if (!getName().equals(metaInfo.getName())) { + return false; + } + + if (!getStateSerializer().equals(metaInfo.getStateSerializer())) { + return false; + } + + return getMode() == metaInfo.getMode(); + } + + @Override + public int hashCode() { + int result = getName().hashCode(); + result = 31 * result + getStateSerializer().hashCode(); + result = 31 * result + getMode().hashCode(); + return result; + } } } http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java index eaa9fd9..036aed0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java @@ -66,8 +66,14 @@ public final class OperatorStateCheckpointOutputStream startNewPartition(); } - Map<String, long[]> offsetsMap = new HashMap<>(1); - offsetsMap.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, partitionOffsets.toArray()); + Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(1); + + OperatorStateHandle.StateMetaInfo metaInfo = + new OperatorStateHandle.StateMetaInfo( + partitionOffsets.toArray(), + OperatorStateHandle.Mode.SPLIT_DISTRIBUTE); + + offsetsMap.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, metaInfo); return new OperatorStateHandle(offsetsMap, streamStateHandle); } http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java index 3cd37c9..c59fbad 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java @@ -22,6 +22,7 @@ import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.util.Preconditions; import java.io.IOException; +import java.io.Serializable; import java.util.Arrays; import java.util.Map; @@ -31,21 +32,27 @@ import java.util.Map; */ public class OperatorStateHandle implements StreamStateHandle { + public enum Mode { + SPLIT_DISTRIBUTE, BROADCAST + } + private static final long serialVersionUID = 35876522969227335L; - /** unique state name -> offsets for available partitions in the handle stream */ - private final Map<String, long[]> stateNameToPartitionOffsets; + /** + * unique state name -> offsets for available partitions in the handle stream + */ + private final Map<String, StateMetaInfo> stateNameToPartitionOffsets; private final StreamStateHandle delegateStateHandle; public OperatorStateHandle( - Map<String, long[]> stateNameToPartitionOffsets, + Map<String, StateMetaInfo> stateNameToPartitionOffsets, StreamStateHandle delegateStateHandle) { this.delegateStateHandle = Preconditions.checkNotNull(delegateStateHandle); this.stateNameToPartitionOffsets = Preconditions.checkNotNull(stateNameToPartitionOffsets); } - public Map<String, long[]> getStateNameToPartitionOffsets() { + public Map<String, StateMetaInfo> getStateNameToPartitionOffsets() { return stateNameToPartitionOffsets; } @@ -80,12 +87,12 @@ public class OperatorStateHandle implements StreamStateHandle { OperatorStateHandle that = (OperatorStateHandle) o; - if(stateNameToPartitionOffsets.size() != that.stateNameToPartitionOffsets.size()) { + if (stateNameToPartitionOffsets.size() != that.stateNameToPartitionOffsets.size()) { return false; } - for (Map.Entry<String, long[]> entry : stateNameToPartitionOffsets.entrySet()) { - if (!Arrays.equals(entry.getValue(), that.stateNameToPartitionOffsets.get(entry.getKey()))) { + for (Map.Entry<String, StateMetaInfo> entry : stateNameToPartitionOffsets.entrySet()) { + if (!entry.getValue().equals(that.stateNameToPartitionOffsets.get(entry.getKey()))) { return false; } } @@ -96,14 +103,75 @@ public class OperatorStateHandle implements StreamStateHandle { @Override public int hashCode() { int result = delegateStateHandle.hashCode(); - for (Map.Entry<String, long[]> entry : stateNameToPartitionOffsets.entrySet()) { + for (Map.Entry<String, StateMetaInfo> entry : stateNameToPartitionOffsets.entrySet()) { int entryHash = entry.getKey().hashCode(); if (entry.getValue() != null) { - entryHash += Arrays.hashCode(entry.getValue()); + entryHash += entry.getValue().hashCode(); } result = 31 * result + entryHash; } return result; } + + @Override + public String toString() { + return "OperatorStateHandle{" + + "stateNameToPartitionOffsets=" + stateNameToPartitionOffsets + + ", delegateStateHandle=" + delegateStateHandle + + '}'; + } + + public static class StateMetaInfo implements Serializable { + + private static final long serialVersionUID = 3593817615858941166L; + + private final long[] offsets; + private final Mode distributionMode; + + public StateMetaInfo(long[] offsets, Mode distributionMode) { + this.offsets = Preconditions.checkNotNull(offsets); + this.distributionMode = Preconditions.checkNotNull(distributionMode); + } + + public long[] getOffsets() { + return offsets; + } + + public Mode getDistributionMode() { + return distributionMode; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + StateMetaInfo that = (StateMetaInfo) o; + + if (!Arrays.equals(getOffsets(), that.getOffsets())) { + return false; + } + return getDistributionMode() == that.getDistributionMode(); + } + + @Override + public int hashCode() { + int result = Arrays.hashCode(getOffsets()); + result = 31 * result + getDistributionMode().hashCode(); + return result; + } + + @Override + public String toString() { + return "StateMetaInfo{" + + "offsets=" + Arrays.toString(offsets) + + ", distributionMode=" + distributionMode + + '}'; + } + } } http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java index 46445d2..886d214 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java @@ -220,13 +220,21 @@ public class StateInitializationContextImpl implements StateInitializationContex while (stateHandleIterator.hasNext()) { currentStateHandle = stateHandleIterator.next(); - long[] offsets = currentStateHandle.getStateNameToPartitionOffsets().get(stateName); - if (null != offsets && offsets.length > 0) { + OperatorStateHandle.StateMetaInfo metaInfo = + currentStateHandle.getStateNameToPartitionOffsets().get(stateName); - this.offsets = offsets; - this.offPos = 0; + if (null != metaInfo) { + long[] metaOffsets = metaInfo.getOffsets(); + if (null != metaOffsets && metaOffsets.length > 0) { + this.offsets = metaOffsets; + this.offPos = 0; - return true; + closableRegistry.unregisterClosable(currentStream); + IOUtils.closeQuietly(currentStream); + currentStream = null; + + return true; + } } } http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java index daacbfb..ca9dbc2 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java @@ -941,7 +941,7 @@ public class CheckpointCoordinatorTest { } @Test - public void handleMessagesForNonExistingCheckpoints() { + public void testHandleMessagesForNonExistingCheckpoints() { try { final JobID jid = new JobID(); final long timestamp = System.currentTimeMillis(); @@ -1937,8 +1937,8 @@ public class CheckpointCoordinatorTest { coord.restoreLatestCheckpointedState(tasks, true, false); // verify the restored state - verifiyStateRestore(jobVertexID1, jobVertex1, keyGroupPartitions1); - verifiyStateRestore(jobVertexID2, jobVertex2, keyGroupPartitions2); + verifyStateRestore(jobVertexID1, jobVertex1, keyGroupPartitions1); + verifyStateRestore(jobVertexID2, jobVertex2, keyGroupPartitions2); } /** @@ -2318,7 +2318,7 @@ public class CheckpointCoordinatorTest { coord.restoreLatestCheckpointedState(tasks, true, false); // verify the restored state - verifiyStateRestore(jobVertexID1, newJobVertex1, keyGroupPartitions1); + verifyStateRestore(jobVertexID1, newJobVertex1, keyGroupPartitions1); List<List<Collection<OperatorStateHandle>>> actualOpStatesBackend = new ArrayList<>(newJobVertex2.getParallelism()); List<List<Collection<OperatorStateHandle>>> actualOpStatesRaw = new ArrayList<>(newJobVertex2.getParallelism()); for (int i = 0; i < newJobVertex2.getParallelism(); i++) { @@ -2390,6 +2390,49 @@ public class CheckpointCoordinatorTest { } } + @Test + public void testReplicateModeStateHandle() { + Map<String, OperatorStateHandle.StateMetaInfo> metaInfoMap = new HashMap<>(1); + metaInfoMap.put("t-1", new OperatorStateHandle.StateMetaInfo(new long[]{0, 23}, OperatorStateHandle.Mode.BROADCAST)); + metaInfoMap.put("t-2", new OperatorStateHandle.StateMetaInfo(new long[]{42, 64}, OperatorStateHandle.Mode.BROADCAST)); + metaInfoMap.put("t-3", new OperatorStateHandle.StateMetaInfo(new long[]{72, 83}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); + OperatorStateHandle osh = new OperatorStateHandle(metaInfoMap, new ByteStreamStateHandle("test", new byte[100])); + + OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE; + List<Collection<OperatorStateHandle>> repartitionedStates = + repartitioner.repartitionState(Collections.singletonList(osh), 3); + + Map<String, Integer> checkCounts = new HashMap<>(3); + + for (Collection<OperatorStateHandle> operatorStateHandles : repartitionedStates) { + for (OperatorStateHandle operatorStateHandle : operatorStateHandles) { + for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> stateNameToMetaInfo : + operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) { + + String stateName = stateNameToMetaInfo.getKey(); + Integer count = checkCounts.get(stateName); + if (null == count) { + checkCounts.put(stateName, 1); + } else { + checkCounts.put(stateName, 1 + count); + } + + OperatorStateHandle.StateMetaInfo stateMetaInfo = stateNameToMetaInfo.getValue(); + if (OperatorStateHandle.Mode.SPLIT_DISTRIBUTE.equals(stateMetaInfo.getDistributionMode())) { + Assert.assertEquals(1, stateNameToMetaInfo.getValue().getOffsets().length); + } else { + Assert.assertEquals(2, stateNameToMetaInfo.getValue().getOffsets().length); + } + } + } + } + + Assert.assertEquals(3, checkCounts.size()); + Assert.assertEquals(3, checkCounts.get("t-1").intValue()); + Assert.assertEquals(3, checkCounts.get("t-2").intValue()); + Assert.assertEquals(2, checkCounts.get("t-3").intValue()); + } + // ------------------------------------------------------------------------ // Utilities // ------------------------------------------------------------------------ @@ -2520,11 +2563,15 @@ public class CheckpointCoordinatorTest { Tuple2<byte[], List<long[]>> serializationWithOffsets = serializeTogetherAndTrackOffsets(namedStateSerializables); - Map<String, long[]> offsetsMap = new HashMap<>(states.size()); + Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(states.size()); int idx = 0; for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) { - offsetsMap.put(entry.getKey(), serializationWithOffsets.f1.get(idx)); + offsetsMap.put( + entry.getKey(), + new OperatorStateHandle.StateMetaInfo( + serializationWithOffsets.f1.get(idx), + OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); ++idx; } @@ -2601,7 +2648,7 @@ public class CheckpointCoordinatorTest { return vertex; } - public static void verifiyStateRestore( + public static void verifyStateRestore( JobVertexID jobVertexID, ExecutionJobVertex executionJobVertex, List<KeyGroupRange> keyGroupPartitions) throws Exception { @@ -2697,8 +2744,8 @@ public class CheckpointCoordinatorTest { private static void collectResult(int opIdx, OperatorStateHandle operatorStateHandle, List<String> resultCollector) throws Exception { try (FSDataInputStream in = operatorStateHandle.openInputStream()) { - for (Map.Entry<String, long[]> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) { - for (long offset : entry.getValue()) { + for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) { + for (long offset : entry.getValue().getOffsets()) { in.seek(offset); Integer state = InstantiationUtil. deserializeObject(in, Thread.currentThread().getContextClassLoader()); @@ -2801,17 +2848,22 @@ public class CheckpointCoordinatorTest { for (int i = 0; i < oldParallelism; ++i) { Path fakePath = new Path("/fake-" + i); - Map<String, long[]> namedStatesToOffsets = new HashMap<>(); + Map<String, OperatorStateHandle.StateMetaInfo> namedStatesToOffsets = new HashMap<>(); int off = 0; for (int s = 0; s < numNamedStates; ++s) { long[] offs = new long[1 + r.nextInt(maxPartitionsPerState)]; - if (offs.length > 0) { - for (int o = 0; o < offs.length; ++o) { - offs[o] = off; - ++off; - } - namedStatesToOffsets.put("State-" + s, offs); + + for (int o = 0; o < offs.length; ++o) { + offs[o] = off; + ++off; } + + OperatorStateHandle.Mode mode = r.nextInt(10) == 0 ? + OperatorStateHandle.Mode.BROADCAST : OperatorStateHandle.Mode.SPLIT_DISTRIBUTE; + namedStatesToOffsets.put( + "State-" + s, + new OperatorStateHandle.StateMetaInfo(offs, mode)); + } previousParallelOpInstanceStates.add( @@ -2822,14 +2874,21 @@ public class CheckpointCoordinatorTest { int expectedTotalPartitions = 0; for (OperatorStateHandle psh : previousParallelOpInstanceStates) { - Map<String, long[]> offsMap = psh.getStateNameToPartitionOffsets(); + Map<String, OperatorStateHandle.StateMetaInfo> offsMap = psh.getStateNameToPartitionOffsets(); Map<String, List<Long>> offsMapWithList = new HashMap<>(offsMap.size()); - for (Map.Entry<String, long[]> e : offsMap.entrySet()) { - long[] offs = e.getValue(); - expectedTotalPartitions += offs.length; + for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e : offsMap.entrySet()) { + + long[] offs = e.getValue().getOffsets(); + int replication = e.getValue().getDistributionMode().equals(OperatorStateHandle.Mode.BROADCAST) ? + newParallelism : 1; + + expectedTotalPartitions += replication * offs.length; List<Long> offsList = new ArrayList<>(offs.length); + for (int i = 0; i < offs.length; ++i) { - offsList.add(i, offs[i]); + for(int p = 0; p < replication; ++p) { + offsList.add(offs[i]); + } } offsMapWithList.put(e.getKey(), offsList); } @@ -2851,25 +2910,25 @@ public class CheckpointCoordinatorTest { Collection<OperatorStateHandle> pshc = pshs.get(p); for (OperatorStateHandle sh : pshc) { - for (Map.Entry<String, long[]> namedState : sh.getStateNameToPartitionOffsets().entrySet()) { + for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> namedState : sh.getStateNameToPartitionOffsets().entrySet()) { - Map<String, List<Long>> x = actual.get(sh.getDelegateStateHandle()); - if (x == null) { - x = new HashMap<>(); - actual.put(sh.getDelegateStateHandle(), x); + Map<String, List<Long>> stateToOffsets = actual.get(sh.getDelegateStateHandle()); + if (stateToOffsets == null) { + stateToOffsets = new HashMap<>(); + actual.put(sh.getDelegateStateHandle(), stateToOffsets); } - List<Long> actualOffs = x.get(namedState.getKey()); + List<Long> actualOffs = stateToOffsets.get(namedState.getKey()); if (actualOffs == null) { actualOffs = new ArrayList<>(); - x.put(namedState.getKey(), actualOffs); + stateToOffsets.put(namedState.getKey(), actualOffs); } - long[] add = namedState.getValue(); + long[] add = namedState.getValue().getOffsets(); for (int i = 0; i < add.length; ++i) { actualOffs.add(add[i]); } - partitionCount += namedState.getValue().length; + partitionCount += namedState.getValue().getOffsets().length; } } http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java index db5c35b..5184db8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java @@ -99,9 +99,10 @@ public class SavepointV1Test { new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes()); StreamStateHandle operatorStateStream = new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes()); - Map<String, long[]> offsetsMap = new HashMap<>(); - offsetsMap.put("A", new long[]{0, 10, 20}); - offsetsMap.put("B", new long[]{30, 40, 50}); + Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(); + offsetsMap.put("A", new OperatorStateHandle.StateMetaInfo(new long[]{0, 10, 20}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); + offsetsMap.put("B", new OperatorStateHandle.StateMetaInfo(new long[]{30, 40, 50}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); + offsetsMap.put("C", new OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, OperatorStateHandle.Mode.BROADCAST)); if (chainIdx != noNonPartitionableStateAtIndex) { nonPartitionableStates.add(nonPartitionableState); http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java index 515011f..cd0391f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java @@ -31,12 +31,13 @@ import java.util.Iterator; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; public class OperatorStateBackendTest { - AbstractStateBackend abstractStateBackend = new MemoryStateBackend(1024); + AbstractStateBackend abstractStateBackend = new MemoryStateBackend(4096); static Environment createMockEnvironment() { Environment env = mock(Environment.class); @@ -62,6 +63,7 @@ public class OperatorStateBackendTest { OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend(); ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>()); ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>()); + ListStateDescriptor<Serializable> stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<>()); ListState<Serializable> listState1 = operatorStateBackend.getOperatorState(stateDescriptor1); assertNotNull(listState1); assertEquals(1, operatorStateBackend.getRegisteredStateNames().size()); @@ -89,6 +91,20 @@ public class OperatorStateBackendTest { assertEquals(23, it.next()); assertTrue(!it.hasNext()); + ListState<Serializable> listState3 = operatorStateBackend.getBroadcastOperatorState(stateDescriptor3); + assertNotNull(listState3); + assertEquals(3, operatorStateBackend.getRegisteredStateNames().size()); + assertTrue(!it.hasNext()); + listState3.add(17); + listState3.add(3); + listState3.add(123); + + it = listState3.get().iterator(); + assertEquals(17, it.next()); + assertEquals(3, it.next()); + assertEquals(123, it.next()); + assertTrue(!it.hasNext()); + ListState<Serializable> listState1b = operatorStateBackend.getOperatorState(stateDescriptor1); assertNotNull(listState1b); listState1b.add(123); @@ -109,6 +125,20 @@ public class OperatorStateBackendTest { assertEquals(4711, it.next()); assertEquals(123, it.next()); assertTrue(!it.hasNext()); + + try { + operatorStateBackend.getBroadcastOperatorState(stateDescriptor2); + fail("Did not detect changed mode"); + } catch (IllegalStateException ignored) { + + } + + try { + operatorStateBackend.getOperatorState(stateDescriptor3); + fail("Did not detect changed mode"); + } catch (IllegalStateException ignored) { + + } } @Test @@ -116,8 +146,10 @@ public class OperatorStateBackendTest { OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend(); ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>()); ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>()); + ListStateDescriptor<Serializable> stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<>()); ListState<Serializable> listState1 = operatorStateBackend.getOperatorState(stateDescriptor1); ListState<Serializable> listState2 = operatorStateBackend.getOperatorState(stateDescriptor2); + ListState<Serializable> listState3 = operatorStateBackend.getBroadcastOperatorState(stateDescriptor3); listState1.add(42); listState1.add(4711); @@ -126,11 +158,17 @@ public class OperatorStateBackendTest { listState2.add(13); listState2.add(23); + listState3.add(17); + listState3.add(18); + listState3.add(19); + listState3.add(20); + CheckpointStreamFactory streamFactory = abstractStateBackend.createStreamFactory(new JobID(), "testOperator"); OperatorStateHandle stateHandle = operatorStateBackend.snapshot(1, 1, streamFactory).get(); try { + operatorStateBackend.close(); operatorStateBackend.dispose(); operatorStateBackend = abstractStateBackend.createOperatorStateBackend( @@ -139,13 +177,13 @@ public class OperatorStateBackendTest { operatorStateBackend.restore(Collections.singletonList(stateHandle)); - assertEquals(2, operatorStateBackend.getRegisteredStateNames().size()); + assertEquals(3, operatorStateBackend.getRegisteredStateNames().size()); listState1 = operatorStateBackend.getOperatorState(stateDescriptor1); listState2 = operatorStateBackend.getOperatorState(stateDescriptor2); + listState3 = operatorStateBackend.getBroadcastOperatorState(stateDescriptor3); - assertEquals(2, operatorStateBackend.getRegisteredStateNames().size()); - + assertEquals(3, operatorStateBackend.getRegisteredStateNames().size()); Iterator<Serializable> it = listState1.get().iterator(); assertEquals(42, it.next()); @@ -158,6 +196,14 @@ public class OperatorStateBackendTest { assertEquals(23, it.next()); assertTrue(!it.hasNext()); + it = listState3.get().iterator(); + assertEquals(17, it.next()); + assertEquals(18, it.next()); + assertEquals(19, it.next()); + assertEquals(20, it.next()); + assertTrue(!it.hasNext()); + + operatorStateBackend.close(); operatorStateBackend.dispose(); } finally { stateHandle.discardState(); http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java new file mode 100644 index 0000000..ab801b6 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.junit.Assert; +import org.junit.Test; + +public class OperatorStateHandleTest { + + @Test + public void testFixedEnumOrder() { + + // Ensure the order / ordinal of all values of enum 'mode' are fixed, as this is used for serialization + Assert.assertEquals(0, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE.ordinal()); + Assert.assertEquals(1, OperatorStateHandle.Mode.BROADCAST.ordinal()); + + // Ensure all enum values are registered and fixed forever by this test + Assert.assertEquals(2, OperatorStateHandle.Mode.values().length); + + // Byte is used to encode enum value on serialization + Assert.assertTrue(OperatorStateHandle.Mode.values().length <= Byte.MAX_VALUE); + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java index c6ef0f0..7efcd0d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java @@ -27,6 +27,7 @@ import org.junit.Assert; import org.junit.Test; import java.io.IOException; +import java.util.Map; public class OperatorStateOutputCheckpointStreamTest { @@ -77,15 +78,23 @@ public class OperatorStateOutputCheckpointStreamTest { OperatorStateHandle fullHandle = writeAllTestKeyGroups(stream, numPartitions); Assert.assertNotNull(fullHandle); + Map<String, OperatorStateHandle.StateMetaInfo> stateNameToPartitionOffsets = + fullHandle.getStateNameToPartitionOffsets(); + for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry : stateNameToPartitionOffsets.entrySet()) { + + Assert.assertEquals(OperatorStateHandle.Mode.SPLIT_DISTRIBUTE, entry.getValue().getDistributionMode()); + } verifyRead(fullHandle, numPartitions); } private static void verifyRead(OperatorStateHandle fullHandle, int numPartitions) throws IOException { int count = 0; try (FSDataInputStream in = fullHandle.openInputStream()) { - long[] offsets = fullHandle.getStateNameToPartitionOffsets(). + OperatorStateHandle.StateMetaInfo metaInfo = fullHandle.getStateNameToPartitionOffsets(). get(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME); + long[] offsets = metaInfo.getOffsets(); + Assert.assertNotNull(offsets); DataInputView div = new DataInputViewStreamWrapper(in); http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java index 832b022..2448540 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java @@ -36,7 +36,7 @@ import java.util.List; public class SerializationProxiesTest { @Test - public void testSerializationRoundtrip() throws Exception { + public void testKeyedBackendSerializationProxyRoundtrip() throws Exception { TypeSerializer<?> keySerializer = IntSerializer.INSTANCE; TypeSerializer<?> namespaceSerializer = LongSerializer.INSTANCE; @@ -67,13 +67,12 @@ public class SerializationProxiesTest { serializationProxy.read(new DataInputViewStreamWrapper(in)); } - Assert.assertEquals(keySerializer, serializationProxy.getKeySerializerProxy().getTypeSerializer()); Assert.assertEquals(stateMetaInfoList, serializationProxy.getNamedStateSerializationProxies()); } @Test - public void testMetaInfoSerialization() throws Exception { + public void testKeyedStateMetaInfoSerialization() throws Exception { String name = "test"; TypeSerializer<?> namespaceSerializer = LongSerializer.INSTANCE; @@ -97,6 +96,64 @@ public class SerializationProxiesTest { Assert.assertEquals(name, metaInfo.getStateName()); } + + @Test + public void testOperatorBackendSerializationProxyRoundtrip() throws Exception { + + TypeSerializer<?> stateSerializer = DoubleSerializer.INSTANCE; + + List<OperatorBackendSerializationProxy.StateMetaInfo<?>> stateMetaInfoList = new ArrayList<>(); + + stateMetaInfoList.add( + new OperatorBackendSerializationProxy.StateMetaInfo<>("a", stateSerializer, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); + stateMetaInfoList.add( + new OperatorBackendSerializationProxy.StateMetaInfo<>("b", stateSerializer, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); + stateMetaInfoList.add( + new OperatorBackendSerializationProxy.StateMetaInfo<>("c", stateSerializer, OperatorStateHandle.Mode.BROADCAST)); + + OperatorBackendSerializationProxy serializationProxy = + new OperatorBackendSerializationProxy(stateMetaInfoList); + + byte[] serialized; + try (ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos()) { + serializationProxy.write(new DataOutputViewStreamWrapper(out)); + serialized = out.toByteArray(); + } + + serializationProxy = + new OperatorBackendSerializationProxy(Thread.currentThread().getContextClassLoader()); + + try (ByteArrayInputStreamWithPos in = new ByteArrayInputStreamWithPos(serialized)) { + serializationProxy.read(new DataInputViewStreamWrapper(in)); + } + + Assert.assertEquals(stateMetaInfoList, serializationProxy.getNamedStateSerializationProxies()); + } + + @Test + public void testOperatorStateMetaInfoSerialization() throws Exception { + + String name = "test"; + TypeSerializer<?> stateSerializer = DoubleSerializer.INSTANCE; + + OperatorBackendSerializationProxy.StateMetaInfo<?> metaInfo = + new OperatorBackendSerializationProxy.StateMetaInfo<>(name, stateSerializer, OperatorStateHandle.Mode.BROADCAST); + + byte[] serialized; + try (ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos()) { + metaInfo.write(new DataOutputViewStreamWrapper(out)); + serialized = out.toByteArray(); + } + + metaInfo = new OperatorBackendSerializationProxy.StateMetaInfo<>(Thread.currentThread().getContextClassLoader()); + + try (ByteArrayInputStreamWithPos in = new ByteArrayInputStreamWithPos(serialized)) { + metaInfo.read(new DataInputViewStreamWrapper(in)); + } + + Assert.assertEquals(name, metaInfo.getName()); + } + /** * This test fixes the order of elements in the enum which is important for serialization. Do not modify this test * except if you are entirely sure what you are doing. http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java index 39dc5d6..963c42c 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java @@ -111,8 +111,10 @@ public class StateInitializationContextImplTest { writtenOperatorStates.add(val); } - Map<String, long[]> offsetsMap = new HashMap<>(); - offsetsMap.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, offsets.toArray()); + Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(); + offsetsMap.put( + DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, + new OperatorStateHandle.StateMetaInfo(offsets.toArray(), OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); OperatorStateHandle operatorStateHandle = new OperatorStateHandle(offsetsMap, new ByteStateHandleCloseChecking("os-" + i, out.toByteArray())); operatorStateHandles.add(operatorStateHandle); http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java index 0206cf5..58cfefd 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java @@ -191,8 +191,10 @@ public class InterruptSensitiveRestoreTest { List<Collection<OperatorStateHandle>> operatorStateBackend = Collections.emptyList(); List<Collection<OperatorStateHandle>> operatorStateStream = Collections.emptyList(); - Map<String, long[]> operatorStateMetadata = new HashMap<>(1); - operatorStateMetadata.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, new long[]{0}); + Map<String, OperatorStateHandle.StateMetaInfo> operatorStateMetadata = new HashMap<>(1); + OperatorStateHandle.StateMetaInfo metaInfo = + new OperatorStateHandle.StateMetaInfo(new long[]{0}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE); + operatorStateMetadata.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, metaInfo); KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(new KeyGroupRange(0,0)); http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java index da4a01b..45fcc25 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java @@ -86,7 +86,7 @@ public class RescalingITCase extends TestLogger { private static final int numSlots = numTaskManagers * slotsPerTaskManager; enum OperatorCheckpointMethod { - NON_PARTITIONED, CHECKPOINTED_FUNCTION, LIST_CHECKPOINTED + NON_PARTITIONED, CHECKPOINTED_FUNCTION, CHECKPOINTED_FUNCTION_BROADCAST, LIST_CHECKPOINTED } private static TestingCluster cluster; @@ -179,7 +179,7 @@ public class RescalingITCase extends TestLogger { Future<Object> savepointPathFuture = jobManager.ask(new JobManagerMessages.TriggerSavepoint(jobID, Option.<String>empty()), deadline.timeLeft()); final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess) - Await.result(savepointPathFuture, deadline.timeLeft())).savepointPath(); + Await.result(savepointPathFuture, deadline.timeLeft())).savepointPath(); Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft()); @@ -270,7 +270,7 @@ public class RescalingITCase extends TestLogger { assertTrue(String.valueOf(savepointResponse), savepointResponse instanceof JobManagerMessages.TriggerSavepointSuccess); - final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess)savepointResponse).savepointPath(); + final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess) savepointResponse).savepointPath(); Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft()); @@ -339,16 +339,16 @@ public class RescalingITCase extends TestLogger { JobID jobID = null; try { - jobManager = cluster.getLeaderGateway(deadline.timeLeft()); + jobManager = cluster.getLeaderGateway(deadline.timeLeft()); JobGraph jobGraph = createJobGraphWithKeyedAndNonPartitionedOperatorState( - parallelism, - maxParallelism, - parallelism, - numberKeys, - numberElements, - false, - 100); + parallelism, + maxParallelism, + parallelism, + numberKeys, + numberElements, + false, + 100); jobID = jobGraph.getJobID(); @@ -366,7 +366,7 @@ public class RescalingITCase extends TestLogger { for (int key = 0; key < numberKeys; key++) { int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism); - expectedResult.add(Tuple2.of(KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, keyGroupIndex) , numberElements * key)); + expectedResult.add(Tuple2.of(KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, keyGroupIndex), numberElements * key)); } assertEquals(expectedResult, actualResult); @@ -377,7 +377,7 @@ public class RescalingITCase extends TestLogger { Future<Object> savepointPathFuture = jobManager.ask(new JobManagerMessages.TriggerSavepoint(jobID, Option.<String>empty()), deadline.timeLeft()); final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess) - Await.result(savepointPathFuture, deadline.timeLeft())).savepointPath(); + Await.result(savepointPathFuture, deadline.timeLeft())).savepointPath(); Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft()); @@ -392,13 +392,13 @@ public class RescalingITCase extends TestLogger { jobID = null; JobGraph scaledJobGraph = createJobGraphWithKeyedAndNonPartitionedOperatorState( - parallelism2, - maxParallelism, - parallelism, - numberKeys, - numberElements + numberElements2, - true, - 100); + parallelism2, + maxParallelism, + parallelism, + numberKeys, + numberElements + numberElements2, + true, + 100); scaledJobGraph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath(savepointPath)); @@ -447,6 +447,16 @@ public class RescalingITCase extends TestLogger { } @Test + public void testSavepointRescalingInBroadcastOperatorState() throws Exception { + testSavepointRescalingPartitionedOperatorState(false, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST); + } + + @Test + public void testSavepointRescalingOutBroadcastOperatorState() throws Exception { + testSavepointRescalingPartitionedOperatorState(true, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST); + } + + @Test public void testSavepointRescalingInPartitionedOperatorStateList() throws Exception { testSavepointRescalingPartitionedOperatorState(false, OperatorCheckpointMethod.LIST_CHECKPOINTED); } @@ -474,7 +484,8 @@ public class RescalingITCase extends TestLogger { int counterSize = Math.max(parallelism, parallelism2); - if(checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION) { + if (checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION || + checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST) { PartitionedStateSource.CHECK_CORRECT_SNAPSHOT = new int[counterSize]; PartitionedStateSource.CHECK_CORRECT_RESTORE = new int[counterSize]; } else { @@ -505,11 +516,12 @@ public class RescalingITCase extends TestLogger { if (savepointResponse instanceof JobManagerMessages.TriggerSavepointSuccess) { break; } + System.out.println(savepointResponse); } assertTrue(savepointResponse instanceof JobManagerMessages.TriggerSavepointSuccess); - final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess)savepointResponse).savepointPath(); + final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess) savepointResponse).savepointPath(); Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft()); @@ -543,6 +555,16 @@ public class RescalingITCase extends TestLogger { for (int c : PartitionedStateSource.CHECK_CORRECT_RESTORE) { sumAct += c; } + } else if (checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST) { + for (int c : PartitionedStateSource.CHECK_CORRECT_SNAPSHOT) { + sumExp += c; + } + + for (int c : PartitionedStateSource.CHECK_CORRECT_RESTORE) { + sumAct += c; + } + + sumExp *= parallelism2; } else { for (int c : PartitionedStateSourceListCheckpointed.CHECK_CORRECT_SNAPSHOT) { sumExp += c; @@ -587,7 +609,10 @@ public class RescalingITCase extends TestLogger { switch (checkpointMethod) { case CHECKPOINTED_FUNCTION: - src = new PartitionedStateSource(); + src = new PartitionedStateSource(false); + break; + case CHECKPOINTED_FUNCTION_BROADCAST: + src = new PartitionedStateSource(true); break; case LIST_CHECKPOINTED: src = new PartitionedStateSourceListCheckpointed(); @@ -607,12 +632,12 @@ public class RescalingITCase extends TestLogger { } private static JobGraph createJobGraphWithKeyedState( - int parallelism, - int maxParallelism, - int numberKeys, - int numberElements, - boolean terminateAfterEmission, - int checkpointingInterval) { + int parallelism, + int maxParallelism, + int numberKeys, + int numberElements, + boolean terminateAfterEmission, + int checkpointingInterval) { StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(parallelism); @@ -621,17 +646,17 @@ public class RescalingITCase extends TestLogger { env.setRestartStrategy(RestartStrategies.noRestart()); DataStream<Integer> input = env.addSource(new SubtaskIndexSource( - numberKeys, - numberElements, - terminateAfterEmission)) - .keyBy(new KeySelector<Integer, Integer>() { - private static final long serialVersionUID = -7952298871120320940L; - - @Override - public Integer getKey(Integer value) throws Exception { - return value; - } - }); + numberKeys, + numberElements, + terminateAfterEmission)) + .keyBy(new KeySelector<Integer, Integer>() { + private static final long serialVersionUID = -7952298871120320940L; + + @Override + public Integer getKey(Integer value) throws Exception { + return value; + } + }); SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys); @@ -643,13 +668,13 @@ public class RescalingITCase extends TestLogger { } private static JobGraph createJobGraphWithKeyedAndNonPartitionedOperatorState( - int parallelism, - int maxParallelism, - int fixedParallelism, - int numberKeys, - int numberElements, - boolean terminateAfterEmission, - int checkpointingInterval) { + int parallelism, + int maxParallelism, + int fixedParallelism, + int numberKeys, + int numberElements, + boolean terminateAfterEmission, + int checkpointingInterval) { StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(parallelism); @@ -658,18 +683,18 @@ public class RescalingITCase extends TestLogger { env.setRestartStrategy(RestartStrategies.noRestart()); DataStream<Integer> input = env.addSource(new SubtaskIndexNonPartitionedStateSource( - numberKeys, - numberElements, - terminateAfterEmission)) - .setParallelism(fixedParallelism) - .keyBy(new KeySelector<Integer, Integer>() { - private static final long serialVersionUID = -7952298871120320940L; - - @Override - public Integer getKey(Integer value) throws Exception { - return value; - } - }); + numberKeys, + numberElements, + terminateAfterEmission)) + .setParallelism(fixedParallelism) + .keyBy(new KeySelector<Integer, Integer>() { + private static final long serialVersionUID = -7952298871120320940L; + + @Override + public Integer getKey(Integer value) throws Exception { + return value; + } + }); SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys); @@ -681,7 +706,7 @@ public class RescalingITCase extends TestLogger { } private static class SubtaskIndexSource - extends RichParallelSourceFunction<Integer> { + extends RichParallelSourceFunction<Integer> { private static final long serialVersionUID = -400066323594122516L; @@ -694,9 +719,9 @@ public class RescalingITCase extends TestLogger { private boolean running = true; SubtaskIndexSource( - int numberKeys, - int numberElements, - boolean terminateAfterEmission) { + int numberKeys, + int numberElements, + boolean terminateAfterEmission) { this.numberKeys = numberKeys; this.numberElements = numberElements; @@ -713,8 +738,8 @@ public class RescalingITCase extends TestLogger { if (counter < numberElements) { synchronized (lock) { for (int value = subtaskIndex; - value < numberKeys; - value += getRuntimeContext().getNumberOfParallelSubtasks()) { + value < numberKeys; + value += getRuntimeContext().getNumberOfParallelSubtasks()) { ctx.collect(value); } @@ -836,6 +861,7 @@ public class RescalingITCase extends TestLogger { } Thread.sleep(2); + if (counter == 10) { workStartedLatch.countDown(); } @@ -910,10 +936,14 @@ public class RescalingITCase extends TestLogger { private static final int NUM_PARTITIONS = 7; private ListState<Integer> counterPartitions; + private boolean broadcast; private static int[] CHECK_CORRECT_SNAPSHOT; private static int[] CHECK_CORRECT_RESTORE; + public PartitionedStateSource(boolean broadcast) { + this.broadcast = broadcast; + } @Override public void snapshotState(FunctionSnapshotContext context) throws Exception { @@ -937,8 +967,15 @@ public class RescalingITCase extends TestLogger { @Override public void initializeState(FunctionInitializationContext context) throws Exception { - this.counterPartitions = - context.getOperatorStateStore().getSerializableListState("counter_partitions"); + + if (broadcast) { + this.counterPartitions = + context.getOperatorStateStore().getBroadcastSerializableListState("counter_partitions"); + } else { + this.counterPartitions = + context.getOperatorStateStore().getSerializableListState("counter_partitions"); + } + if (context.isRestored()) { for (int v : counterPartitions.get()) { counter += v;
