http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java index 9adaa86..c39e436 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java @@ -20,7 +20,9 @@ package org.apache.flink.runtime.checkpoint; import com.google.common.collect.Iterables; import org.apache.flink.api.common.JobID; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.fs.Path; import org.apache.flink.runtime.checkpoint.savepoint.HeapSavepointStore; import org.apache.flink.runtime.checkpoint.stats.DisabledCheckpointStatsTracker; import org.apache.flink.runtime.execution.ExecutionState; @@ -34,21 +36,21 @@ import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint; import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete; import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint; import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.CheckpointStateHandles; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupRangeAssignment; import org.apache.flink.runtime.state.KeyGroupRangeOffsets; import org.apache.flink.runtime.state.KeyGroupsStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.filesystem.FileStateHandle; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.Preconditions; - import org.junit.Assert; import org.junit.Test; - import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; - import scala.concurrent.ExecutionContext; import scala.concurrent.Future; @@ -56,6 +58,8 @@ import java.io.IOException; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -1459,7 +1463,7 @@ public class CheckpointCoordinatorTest { maxConcurrentAttempts, new ExecutionVertex[] { triggerVertex }, new ExecutionVertex[] { ackVertex }, - new ExecutionVertex[] { commitVertex }, + new ExecutionVertex[] { commitVertex }, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(2, cl), new HeapSavepointStore(), @@ -1531,7 +1535,7 @@ public class CheckpointCoordinatorTest { maxConcurrentAttempts, // max two concurrent checkpoints new ExecutionVertex[] { triggerVertex }, new ExecutionVertex[] { ackVertex }, - new ExecutionVertex[] { commitVertex }, + new ExecutionVertex[] { commitVertex }, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(2, cl), new HeapSavepointStore(), @@ -1791,29 +1795,29 @@ public class CheckpointCoordinatorTest { for (int index = 0; index < jobVertex1.getParallelism(); index++) { ChainedStateHandle<StreamStateHandle> nonPartitionedState = generateStateForVertex(jobVertexID1, index); + ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8); List<KeyGroupsStateHandle> partitionedKeyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index)); + CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(nonPartitionedState, partitionableState, partitionedKeyGroupState); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( - jid, - jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), - checkpointId, - nonPartitionedState, - partitionedKeyGroupState); + jid, + jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), + checkpointId, + checkpointStateHandles); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } - for (int index = 0; index < jobVertex2.getParallelism(); index++) { ChainedStateHandle<StreamStateHandle> nonPartitionedState = generateStateForVertex(jobVertexID2, index); + ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8); List<KeyGroupsStateHandle> partitionedKeyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index)); - + CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(nonPartitionedState, partitionableState, partitionedKeyGroupState); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( - jid, - jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), - checkpointId, - nonPartitionedState, - partitionedKeyGroupState); + jid, + jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), + checkpointId, + checkpointStateHandles); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -1895,13 +1899,12 @@ public class CheckpointCoordinatorTest { for (int index = 0; index < jobVertex1.getParallelism(); index++) { ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index); List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index)); - + CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, null, keyGroupState); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( - jid, - jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), - checkpointId, - valueSizeTuple, - keyGroupState); + jid, + jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), + checkpointId, + checkpointStateHandles); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -1910,13 +1913,12 @@ public class CheckpointCoordinatorTest { for (int index = 0; index < jobVertex2.getParallelism(); index++) { ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID2, index); List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index)); - + CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, null, keyGroupState); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( - jid, - jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), - checkpointId, - valueSizeTuple, - keyGroupState); + jid, + jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), + checkpointId, + checkpointStateHandles); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -2014,12 +2016,12 @@ public class CheckpointCoordinatorTest { List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState( jobVertexID1, keyGroupPartitions1.get(index)); + CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, null, keyGroupState); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( - jid, - jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), - checkpointId, - valueSizeTuple, - keyGroupState); + jid, + jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), + checkpointId, + checkpointStateHandles); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -2031,12 +2033,12 @@ public class CheckpointCoordinatorTest { List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState( jobVertexID2, keyGroupPartitions2.get(index)); + CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(state, null, keyGroupState); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, - state, - keyGroupState); + checkpointStateHandles); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -2132,28 +2134,32 @@ public class CheckpointCoordinatorTest { for (int index = 0; index < jobVertex1.getParallelism(); index++) { ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index); + ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8); List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index)); + + CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, partitionableState, keyGroupState); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, - valueSizeTuple, - keyGroupState); + checkpointStateHandles); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } + final List<ChainedStateHandle<OperatorStateHandle>> originalPartitionableStates = new ArrayList<>(jobVertex2.getParallelism()); for (int index = 0; index < jobVertex2.getParallelism(); index++) { List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index)); - + ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8); + originalPartitionableStates.add(partitionableState); + CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(null, partitionableState, keyGroupState); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, - null, - keyGroupState); + checkpointStateHandles); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -2185,22 +2191,49 @@ public class CheckpointCoordinatorTest { // verify the restored state verifiyStateRestore(jobVertexID1, newJobVertex1, keyGroupPartitions1); - + List<List<Collection<OperatorStateHandle>>> actualPartitionableStates = new ArrayList<>(newJobVertex2.getParallelism()); for (int i = 0; i < newJobVertex2.getParallelism(); i++) { List<KeyGroupsStateHandle> originalKeyGroupState = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i)); ChainedStateHandle<StreamStateHandle> operatorState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle(); + List<Collection<OperatorStateHandle>> partitionableState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getChainedPartitionableStateHandle(); List<KeyGroupsStateHandle> keyGroupState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getKeyGroupsStateHandles(); + actualPartitionableStates.add(partitionableState); assertNull(operatorState); - comparePartitionedState(originalKeyGroupState, keyGroupState); + compareKeyPartitionedState(originalKeyGroupState, keyGroupState); } + comparePartitionableState(originalPartitionableStates, actualPartitionableStates); } // ------------------------------------------------------------------------ // Utilities // ------------------------------------------------------------------------ + static void sendAckMessageToCoordinator( + CheckpointCoordinator coord, + long checkpointId, JobID jid, + ExecutionJobVertex jobVertex, + JobVertexID jobVertexID, + List<KeyGroupRange> keyGroupPartitions) throws Exception { + + for (int index = 0; index < jobVertex.getParallelism(); index++) { + ChainedStateHandle<StreamStateHandle> state = generateStateForVertex(jobVertexID, index); + List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState( + jobVertexID, + keyGroupPartitions.get(index)); + + CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(state, null, keyGroupState); + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( + jid, + jobVertex.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), + checkpointId, + checkpointStateHandles); + + coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); + } + } + public static List<KeyGroupsStateHandle> generateKeyGroupState( JobVertexID jobVertexID, KeyGroupRange keyGroupPartition) throws IOException { @@ -2217,23 +2250,45 @@ public class CheckpointCoordinatorTest { return generateKeyGroupState(keyGroupPartition, testStatesLists); } - public static List<KeyGroupsStateHandle> generateKeyGroupState(KeyGroupRange keyGroupRange, List< ? extends Serializable> states) throws IOException { + public static List<KeyGroupsStateHandle> generateKeyGroupState( + KeyGroupRange keyGroupRange, + List<? extends Serializable> states) throws IOException { + Preconditions.checkArgument(keyGroupRange.getNumberOfKeyGroups() == states.size()); - long[] offsets = new long[keyGroupRange.getNumberOfKeyGroups()]; - List<byte[]> serializedGroupValues = new ArrayList<>(offsets.length); + Tuple2<byte[], List<long[]>> serializedDataWithOffsets = + serializeTogetherAndTrackOffsets(Collections.<List<? extends Serializable>>singletonList(states)); + + KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, serializedDataWithOffsets.f1.get(0)); + + ByteStreamStateHandle allSerializedStatesHandle = new ByteStreamStateHandle( + serializedDataWithOffsets.f0); + KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle( + keyGroupRangeOffsets, + allSerializedStatesHandle); + List<KeyGroupsStateHandle> keyGroupsStateHandleList = new ArrayList<>(); + keyGroupsStateHandleList.add(keyGroupsStateHandle); + return keyGroupsStateHandleList; + } + + public static Tuple2<byte[], List<long[]>> serializeTogetherAndTrackOffsets( + List<List<? extends Serializable>> serializables) throws IOException { - KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, offsets); + List<long[]> offsets = new ArrayList<>(serializables.size()); + List<byte[]> serializedGroupValues = new ArrayList<>(); int runningGroupsOffset = 0; - // generate test state for all keygroups - int idx = 0; - for (int keyGroup : keyGroupRange) { - keyGroupRangeOffsets.setKeyGroupOffset(keyGroup,runningGroupsOffset); - byte[] serializedValue = InstantiationUtil.serializeObject(states.get(idx)); - runningGroupsOffset += serializedValue.length; - serializedGroupValues.add(serializedValue); - ++idx; + for(List<? extends Serializable> list : serializables) { + + long[] currentOffsets = new long[list.size()]; + offsets.add(currentOffsets); + + for (int i = 0; i < list.size(); ++i) { + currentOffsets[i] = runningGroupsOffset; + byte[] serializedValue = InstantiationUtil.serializeObject(list.get(i)); + serializedGroupValues.add(serializedValue); + runningGroupsOffset += serializedValue.length; + } } //write all generated values in a single byte array, which is index by groupOffsetsInFinalByteArray @@ -2248,15 +2303,7 @@ public class CheckpointCoordinatorTest { serializedGroupValue.length); runningGroupsOffset += serializedGroupValue.length; } - - ByteStreamStateHandle allSerializedStatesHandle = new ByteStreamStateHandle( - allSerializedValuesConcatenated); - KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle( - keyGroupRangeOffsets, - allSerializedStatesHandle); - List<KeyGroupsStateHandle> keyGroupsStateHandleList = new ArrayList<>(); - keyGroupsStateHandleList.add(keyGroupsStateHandle); - return keyGroupsStateHandleList; + return new Tuple2<>(allSerializedValuesConcatenated, offsets); } public static ChainedStateHandle<StreamStateHandle> generateStateForVertex( @@ -2273,6 +2320,55 @@ public class CheckpointCoordinatorTest { return ChainedStateHandle.wrapSingleHandle(ByteStreamStateHandle.fromSerializable(value)); } + public static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle( + JobVertexID jobVertexID, + int index, + int namedStates, + int partitionsPerState) throws IOException { + + Map<String, List<? extends Serializable>> statesListsMap = new HashMap<>(namedStates); + + for (int i = 0; i < namedStates; ++i) { + List<Integer> testStatesLists = new ArrayList<>(partitionsPerState); + // generate state + Random random = new Random(jobVertexID.hashCode() * index + i * namedStates); + for (int j = 0; j < partitionsPerState; ++j) { + int simulatedStateValue = random.nextInt(); + testStatesLists.add(simulatedStateValue); + } + statesListsMap.put("state-" + i, testStatesLists); + } + + return generateChainedPartitionableStateHandle(statesListsMap); + } + + public static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle( + Map<String, List<? extends Serializable>> states) throws IOException { + + List<List<? extends Serializable>> namedStateSerializables = new ArrayList<>(states.size()); + + for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) { + namedStateSerializables.add(entry.getValue()); + } + + Tuple2<byte[], List<long[]>> serializationWithOffsets = serializeTogetherAndTrackOffsets(namedStateSerializables); + + Map<String, long[]> offsetsMap = new HashMap<>(states.size()); + + int idx = 0; + for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) { + offsetsMap.put(entry.getKey(), serializationWithOffsets.f1.get(idx)); + ++idx; + } + + ByteStreamStateHandle streamStateHandle = new ByteStreamStateHandle( + serializationWithOffsets.f0); + + OperatorStateHandle operatorStateHandle = + new OperatorStateHandle(streamStateHandle, offsetsMap); + return ChainedStateHandle.wrapSingleHandle(operatorStateHandle); + } + public static ExecutionJobVertex mockExecutionJobVertex( JobVertexID jobVertexID, int parallelism, @@ -2348,16 +2444,24 @@ public class CheckpointCoordinatorTest { getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle(); assertEquals(expectNonPartitionedState.get(0), actualNonPartitionedState.get(0)); + ChainedStateHandle<OperatorStateHandle> expectedPartitionableState = + generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8); + + List<Collection<OperatorStateHandle>> actualPartitionableState = executionJobVertex. + getTaskVertices()[i].getCurrentExecutionAttempt().getChainedPartitionableStateHandle(); + + assertEquals(expectedPartitionableState.get(0), actualPartitionableState.get(0).iterator().next()); + List<KeyGroupsStateHandle> expectPartitionedKeyGroupState = generateKeyGroupState( jobVertexID, keyGroupPartitions.get(i)); List<KeyGroupsStateHandle> actualPartitionedKeyGroupState = executionJobVertex. getTaskVertices()[i].getCurrentExecutionAttempt().getKeyGroupsStateHandles(); - comparePartitionedState(expectPartitionedKeyGroupState, actualPartitionedKeyGroupState); + compareKeyPartitionedState(expectPartitionedKeyGroupState, actualPartitionedKeyGroupState); } } - public static void comparePartitionedState( + public static void compareKeyPartitionedState( List<KeyGroupsStateHandle> expectPartitionedKeyGroupState, List<KeyGroupsStateHandle> actualPartitionedKeyGroupState) throws Exception { @@ -2370,22 +2474,68 @@ public class CheckpointCoordinatorTest { assertEquals(expectedTotalKeyGroups, actualTotalKeyGroups); - FSDataInputStream inputStream = expectedHeadOpKeyGroupStateHandle.getStateHandle().openInputStream(); - for(int groupId : expectedHeadOpKeyGroupStateHandle.keyGroups()) { - long offset = expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId); - inputStream.seek(offset); - int expectedKeyGroupState = InstantiationUtil.deserializeObject(inputStream); - for(KeyGroupsStateHandle oneActualKeyGroupStateHandle : actualPartitionedKeyGroupState) { - if (oneActualKeyGroupStateHandle.containsKeyGroup(groupId)) { - long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId); - FSDataInputStream actualInputStream = oneActualKeyGroupStateHandle.getStateHandle().openInputStream(); - actualInputStream.seek(actualOffset); - int actualGroupState = InstantiationUtil.deserializeObject(actualInputStream); - - assertEquals(expectedKeyGroupState, actualGroupState); + try (FSDataInputStream inputStream = expectedHeadOpKeyGroupStateHandle.getStateHandle().openInputStream()) { + for (int groupId : expectedHeadOpKeyGroupStateHandle.keyGroups()) { + long offset = expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId); + inputStream.seek(offset); + int expectedKeyGroupState = InstantiationUtil.deserializeObject(inputStream); + for (KeyGroupsStateHandle oneActualKeyGroupStateHandle : actualPartitionedKeyGroupState) { + if (oneActualKeyGroupStateHandle.containsKeyGroup(groupId)) { + long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId); + try (FSDataInputStream actualInputStream = + oneActualKeyGroupStateHandle.getStateHandle().openInputStream()) { + actualInputStream.seek(actualOffset); + int actualGroupState = InstantiationUtil.deserializeObject(actualInputStream); + assertEquals(expectedKeyGroupState, actualGroupState); + } + } + } + } + } + } + + public static void comparePartitionableState( + List<ChainedStateHandle<OperatorStateHandle>> expected, + List<List<Collection<OperatorStateHandle>>> actual) throws Exception { + + List<String> expectedResult = new ArrayList<>(); + for (ChainedStateHandle<OperatorStateHandle> chainedStateHandle : expected) { + for (int i = 0; i < chainedStateHandle.getLength(); ++i) { + OperatorStateHandle operatorStateHandle = chainedStateHandle.get(i); + try (FSDataInputStream in = operatorStateHandle.openInputStream()) { + for (Map.Entry<String, long[]> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) { + for (long offset : entry.getValue()) { + in.seek(offset); + Integer state = InstantiationUtil.deserializeObject(in); + expectedResult.add(i + " : " + entry.getKey() + " : " + state); + } + } } } } + Collections.sort(expectedResult); + + List<String> actualResult = new ArrayList<>(); + for (List<Collection<OperatorStateHandle>> collectionList : actual) { + if (collectionList != null) { + for (int i = 0; i < collectionList.size(); ++i) { + Collection<OperatorStateHandle> stateHandles = collectionList.get(i); + for (OperatorStateHandle operatorStateHandle : stateHandles) { + try (FSDataInputStream in = operatorStateHandle.openInputStream()) { + for (Map.Entry<String, long[]> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) { + for (long offset : entry.getValue()) { + in.seek(offset); + Integer state = InstantiationUtil.deserializeObject(in); + actualResult.add(i + " : " + entry.getKey() + " : " + state); + } + } + } + } + } + } + } + Collections.sort(actualResult); + Assert.assertEquals(expectedResult, actualResult); } @Test @@ -2415,4 +2565,117 @@ public class CheckpointCoordinatorTest { } } + + @Test + public void testPartitionableStateRepartitioning() { + Random r = new Random(42); + + for (int run = 0; run < 10000; ++run) { + int oldParallelism = 1 + r.nextInt(9); + int newParallelism = 1 + r.nextInt(9); + + int numNamedStates = 1 + r.nextInt(9); + int maxPartitionsPerState = 1 + r.nextInt(9); + + doTestPartitionableStateRepartitioning( + r, oldParallelism, newParallelism, numNamedStates, maxPartitionsPerState); + } + } + + private void doTestPartitionableStateRepartitioning( + Random r, int oldParallelism, int newParallelism, int numNamedStates, int maxPartitionsPerState) { + + List<OperatorStateHandle> previousParallelOpInstanceStates = new ArrayList<>(oldParallelism); + + for (int i = 0; i < oldParallelism; ++i) { + Path fakePath = new Path("/fake-" + i); + Map<String, long[]> namedStatesToOffsets = new HashMap<>(); + int off = 0; + for (int s = 0; s < numNamedStates; ++s) { + long[] offs = new long[1 + r.nextInt(maxPartitionsPerState)]; + if (offs.length > 0) { + for (int o = 0; o < offs.length; ++o) { + offs[o] = off; + ++off; + } + namedStatesToOffsets.put("State-" + s, offs); + } + } + + previousParallelOpInstanceStates.add( + new OperatorStateHandle(new FileStateHandle(fakePath, -1), namedStatesToOffsets)); + } + + Map<StreamStateHandle, Map<String, List<Long>>> expected = new HashMap<>(); + + int expectedTotalPartitions = 0; + for (OperatorStateHandle psh : previousParallelOpInstanceStates) { + Map<String, long[]> offsMap = psh.getStateNameToPartitionOffsets(); + Map<String, List<Long>> offsMapWithList = new HashMap<>(offsMap.size()); + for (Map.Entry<String, long[]> e : offsMap.entrySet()) { + long[] offs = e.getValue(); + expectedTotalPartitions += offs.length; + List<Long> offsList = new ArrayList<>(offs.length); + for (int i = 0; i < offs.length; ++i) { + offsList.add(i, offs[i]); + } + offsMapWithList.put(e.getKey(), offsList); + } + expected.put(psh.getDelegateStateHandle(), offsMapWithList); + } + + OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE; + + List<Collection<OperatorStateHandle>> pshs = + repartitioner.repartitionState(previousParallelOpInstanceStates, newParallelism); + + Map<StreamStateHandle, Map<String, List<Long>>> actual = new HashMap<>(); + + int minCount = Integer.MAX_VALUE; + int maxCount = 0; + int actualTotalPartitions = 0; + for (int p = 0; p < newParallelism; ++p) { + int partitionCount = 0; + + Collection<OperatorStateHandle> pshc = pshs.get(p); + for (OperatorStateHandle sh : pshc) { + for (Map.Entry<String, long[]> namedState : sh.getStateNameToPartitionOffsets().entrySet()) { + + Map<String, List<Long>> x = actual.get(sh.getDelegateStateHandle()); + if (x == null) { + x = new HashMap<>(); + actual.put(sh.getDelegateStateHandle(), x); + } + + List<Long> actualOffs = x.get(namedState.getKey()); + if (actualOffs == null) { + actualOffs = new ArrayList<>(); + x.put(namedState.getKey(), actualOffs); + } + long[] add = namedState.getValue(); + for (int i = 0; i < add.length; ++i) { + actualOffs.add(add[i]); + } + + partitionCount += namedState.getValue().length; + } + } + + minCount = Math.min(minCount, partitionCount); + maxCount = Math.max(maxCount, partitionCount); + actualTotalPartitions += partitionCount; + } + + for (Map<String, List<Long>> v : actual.values()) { + for (List<Long> l : v.values()) { + Collections.sort(l); + } + } + + int maxLoadDiff = maxCount - minCount; + Assert.assertTrue("Difference in partition load is > 1 : " + maxLoadDiff, maxLoadDiff <= 1); + Assert.assertEquals(expectedTotalPartitions, actualTotalPartitions); + Assert.assertEquals(expected, actual); + } + }
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java index a4896aa..bb78b6a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java @@ -29,14 +29,18 @@ import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.CheckpointStateHandles; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupsStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.util.SerializableObject; - +import org.hamcrest.BaseMatcher; +import org.hamcrest.Description; import org.junit.Test; import org.mockito.Mockito; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -112,9 +116,11 @@ public class CheckpointStateRestoreTest { PendingCheckpoint pending = coord.getPendingCheckpoints().values().iterator().next(); final long checkpointId = pending.getCheckpointId(); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, serializedState, serializedKeyGroupStates)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, serializedState, serializedKeyGroupStates)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, serializedState, serializedKeyGroupStates)); + CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(serializedState, null, serializedKeyGroupStates); + + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, checkpointStateHandles)); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, checkpointStateHandles)); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, checkpointStateHandles)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec2.getAttemptId(), checkpointId)); @@ -125,11 +131,27 @@ public class CheckpointStateRestoreTest { coord.restoreLatestCheckpointedState(map, true, false); // verify that each stateful vertex got the state - verify(statefulExec1, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.<List<KeyGroupsStateHandle>>any()); - verify(statefulExec2, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.<List<KeyGroupsStateHandle>>any()); - verify(statefulExec3, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.<List<KeyGroupsStateHandle>>any()); - verify(statelessExec1, times(0)).setInitialState(Mockito.<ChainedStateHandle<StreamStateHandle>>any(), Mockito.<List<KeyGroupsStateHandle>>any()); - verify(statelessExec2, times(0)).setInitialState(Mockito.<ChainedStateHandle<StreamStateHandle>>any(), Mockito.<List<KeyGroupsStateHandle>>any()); + + BaseMatcher<CheckpointStateHandles> matcher = new BaseMatcher<CheckpointStateHandles>() { + @Override + public boolean matches(Object o) { + if (o instanceof CheckpointStateHandles) { + return ((CheckpointStateHandles) o).getNonPartitionedStateHandles().equals(serializedState); + } + return false; + } + + @Override + public void describeTo(Description description) { + description.appendValue(serializedState); + } + }; + + verify(statefulExec1, times(1)).setInitialState(Mockito.argThat(matcher), Mockito.<List<Collection<OperatorStateHandle>>>any()); + verify(statefulExec2, times(1)).setInitialState(Mockito.argThat(matcher), Mockito.<List<Collection<OperatorStateHandle>>>any()); + verify(statefulExec3, times(1)).setInitialState(Mockito.argThat(matcher), Mockito.<List<Collection<OperatorStateHandle>>>any()); + verify(statelessExec1, times(0)).setInitialState(Mockito.<CheckpointStateHandles>any(), Mockito.<List<Collection<OperatorStateHandle>>>any()); + verify(statelessExec2, times(0)).setInitialState(Mockito.<CheckpointStateHandles>any(), Mockito.<List<Collection<OperatorStateHandle>>>any()); } catch (Exception e) { e.printStackTrace(); @@ -193,9 +215,11 @@ public class CheckpointStateRestoreTest { final long checkpointId = pending.getCheckpointId(); // the difference to the test "testSetState" is that one stateful subtask does not report state - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, serializedState, serializedKeyGroupStates)); + CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(serializedState, null, serializedKeyGroupStates); + + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, checkpointStateHandles)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, serializedState, serializedKeyGroupStates)); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, checkpointStateHandles)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec2.getAttemptId(), checkpointId)); http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java index 6182ffd..289f5c3 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java @@ -197,7 +197,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger { JobVertexID jvid = new JobVertexID(); Map<JobVertexID, TaskState> taskGroupStates = new HashMap<>(); - TaskState taskState = new TaskState(jvid, numberOfStates, numberOfStates); + TaskState taskState = new TaskState(jvid, numberOfStates, numberOfStates, 1); taskGroupStates.put(jvid, taskState); for (int i = 0; i < numberOfStates; i++) { http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java index fd4e02d..b8126e9 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java @@ -106,7 +106,7 @@ public class PendingCheckpointTest { PendingCheckpoint pending = createPendingCheckpoint(); PendingCheckpointTest.setTaskState(pending, state); - pending.acknowledgeTask(ATTEMPT_ID, null, null); + pending.acknowledgeTask(ATTEMPT_ID, null); CompletedCheckpoint checkpoint = pending.finalizeCheckpoint(); http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java index 7258545..3701359 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java @@ -117,7 +117,7 @@ public class PendingSavepointTest { Future<String> future = pending.getCompletionFuture(); - pending.acknowledgeTask(ATTEMPT_ID, null, null); + pending.acknowledgeTask(ATTEMPT_ID, null); CompletedCheckpoint checkpoint = pending.finalizeCheckpoint(); http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java index 6a8d072..9fbe574 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java @@ -186,10 +186,5 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint public long getStateSize() throws IOException { return 0; } - - @Override - public void close() throws IOException { - - } } } http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java index ef10032..c82be18 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java @@ -24,6 +24,7 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.KeyGroupRangeOffsets; import org.apache.flink.runtime.state.KeyGroupsStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.junit.Test; @@ -32,7 +33,9 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.ThreadLocalRandom; import static org.junit.Assert.assertEquals; @@ -67,17 +70,30 @@ public class SavepointV1Test { List<TaskState> taskStates = new ArrayList<>(numTaskStates); for (int i = 0; i < numTaskStates; i++) { - TaskState taskState = new TaskState(new JobVertexID(), numSubtaskStates, numSubtaskStates); + TaskState taskState = new TaskState(new JobVertexID(), numSubtaskStates, numSubtaskStates, 1); for (int j = 0; j < numSubtaskStates; j++) { StreamStateHandle stateHandle = new ByteStreamStateHandle("Hello".getBytes()); taskState.putState(i, new SubtaskState( new ChainedStateHandle<>(Collections.singletonList(stateHandle)), 0)); + + stateHandle = new ByteStreamStateHandle("Beautiful".getBytes()); + Map<String, long[]> offsetsMap = new HashMap<>(); + offsetsMap.put("A", new long[]{0, 10, 20}); + offsetsMap.put("B", new long[]{30, 40, 50}); + + OperatorStateHandle operatorStateHandle = + new OperatorStateHandle(stateHandle, offsetsMap); + + taskState.putPartitionableState( + i, + new ChainedStateHandle<OperatorStateHandle>( + Collections.singletonList(operatorStateHandle))); } taskState.putKeyedState( 0, new KeyGroupsStateHandle( - new KeyGroupRangeOffsets(1,1, new long[] {42}), new ByteStreamStateHandle("Hello".getBytes()))); + new KeyGroupRangeOffsets(1,1, new long[] {42}), new ByteStreamStateHandle("World".getBytes()))); taskStates.add(taskState); } http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java index 504143b..1e95732 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java @@ -319,7 +319,7 @@ public class SimpleCheckpointStatsTrackerTest { JobVertexID operatorId = operatorIds[operatorIndex]; int parallelism = operatorParallelism[operatorIndex]; - TaskState taskState = new TaskState(operatorId, parallelism, maxParallelism); + TaskState taskState = new TaskState(operatorId, parallelism, maxParallelism, 1); taskGroupStates.put(operatorId, taskState); http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java index ef8e3bd..9b12cac 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java @@ -26,6 +26,7 @@ import akka.testkit.JavaTestKit; import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.runtime.akka.AkkaUtils; import org.apache.flink.runtime.akka.ListeningBehaviour; import org.apache.flink.runtime.blob.BlobServer; @@ -54,9 +55,11 @@ import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService; import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService; import org.apache.flink.runtime.messages.JobManagerMessages; import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.CheckpointStateHandles; import org.apache.flink.runtime.state.KeyGroupsStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.RetrievableStreamStateHandle; +import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.apache.flink.runtime.taskmanager.TaskManager; import org.apache.flink.runtime.testingUtils.TestingJobManager; @@ -80,6 +83,7 @@ import scala.concurrent.duration.FiniteDuration; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -441,10 +445,15 @@ public class JobManagerHARecoveryTest { private int completedCheckpoints = 0; @Override - public void setInitialState(ChainedStateHandle<StreamStateHandle> chainedState, List<KeyGroupsStateHandle> keyGroupsState) throws Exception { + public void setInitialState( + ChainedStateHandle<StreamStateHandle> chainedState, + List<KeyGroupsStateHandle> keyGroupsState, + List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception { int subtaskIndex = getIndexInSubtaskGroup(); if (subtaskIndex < recoveredStates.length) { - recoveredStates[subtaskIndex] = InstantiationUtil.deserializeObject(chainedState.get(0).openInputStream()); + try (FSDataInputStream in = chainedState.get(0).openInputStream()) { + recoveredStates[subtaskIndex] = InstantiationUtil.deserializeObject(in); + } } } @@ -456,11 +465,12 @@ public class JobManagerHARecoveryTest { RetrievableStreamStateHandle<Long> state = new RetrievableStreamStateHandle<Long>(byteStreamStateHandle); ChainedStateHandle<StreamStateHandle> chainedStateHandle = new ChainedStateHandle<StreamStateHandle>(Collections.singletonList(state)); + CheckpointStateHandles checkpointStateHandles = + new CheckpointStateHandles(chainedStateHandle, null, Collections.<KeyGroupsStateHandle>emptyList()); getEnvironment().acknowledgeCheckpoint( checkpointId, - chainedStateHandle, - Collections.<KeyGroupsStateHandle>emptyList(), + checkpointStateHandles, 0L, 0L, 0L, 0L); return true; } catch (Exception ex) { http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java index 6a6ac64..4873335 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java @@ -23,11 +23,12 @@ import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.testutils.CommonTestUtils; import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete; import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint; +import org.apache.flink.runtime.state.CheckpointStateHandles; import org.apache.flink.runtime.state.KeyGroupRange; -import org.apache.flink.runtime.state.StateObject; import org.apache.flink.runtime.state.StreamStateHandle; import org.junit.Test; @@ -65,13 +66,17 @@ public class CheckpointMessagesTest { KeyGroupRange keyGroupRange = KeyGroupRange.of(42,42); + CheckpointStateHandles checkpointStateHandles = + new CheckpointStateHandles( + CheckpointCoordinatorTest.generateChainedStateHandle(new MyHandle()), + CheckpointCoordinatorTest.generateChainedPartitionableStateHandle(new JobVertexID(), 0, 2, 8), + CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle()))); + AcknowledgeCheckpoint withState = new AcknowledgeCheckpoint( new JobID(), new ExecutionAttemptID(), 87658976143L, - CheckpointCoordinatorTest.generateChainedStateHandle(new MyHandle()), - CheckpointCoordinatorTest.generateKeyGroupState( - keyGroupRange, Collections.singletonList(new MyHandle()))); + checkpointStateHandles); testSerializabilityEqualsHashCode(noState); testSerializabilityEqualsHashCode(withState); @@ -83,7 +88,6 @@ public class CheckpointMessagesTest { private static void testSerializabilityEqualsHashCode(Serializable o) throws IOException { Object copy = CommonTestUtils.createCopySerializable(o); - System.out.println(o.getClass() +" "+copy.getClass()); assertEquals(o, copy); assertEquals(o.hashCode(), copy.hashCode()); assertNotNull(o.toString()); @@ -117,9 +121,6 @@ public class CheckpointMessagesTest { } @Override - public void close() throws IOException {} - - @Override public FSDataInputStream openInputStream() throws IOException { return null; } http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java index a857d1b..c855230 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java @@ -37,6 +37,7 @@ import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.query.KvStateRegistry; import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.CheckpointStateHandles; import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; @@ -162,7 +163,7 @@ public class DummyEnvironment implements Environment { @Override public void acknowledgeCheckpoint( long checkpointId, - ChainedStateHandle<StreamStateHandle> chainedStateHandle, List<KeyGroupsStateHandle> keyGroupStateHandles, + CheckpointStateHandles checkpointStateHandles, long synchronousDurationMillis, long asynchronousDurationMillis, long bytesBufferedInAlignment, long alignmentDurationNanos) { } http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java index 75e88eb..c3ed6c0 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java @@ -46,6 +46,7 @@ import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.KvStateRegistry; import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.CheckpointStateHandles; import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; @@ -323,7 +324,7 @@ public class MockEnvironment implements Environment { @Override public void acknowledgeCheckpoint( long checkpointId, - ChainedStateHandle<StreamStateHandle> chainedStateHandle, List<KeyGroupsStateHandle> keyGroupStateHandles, + CheckpointStateHandles checkpointStateHandles, long synchronousDurationMillis, long asynchronousDurationMillis, long bytesBufferedInAlignment, long alignmentDurationNanos) { } http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java index 1039568..4279635 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java @@ -32,8 +32,8 @@ import org.apache.flink.runtime.query.netty.KvStateClient; import org.apache.flink.runtime.query.netty.KvStateServer; import org.apache.flink.runtime.query.netty.UnknownKvStateID; import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.KeyGroupRange; -import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.runtime.state.heap.HeapValueState; @@ -246,7 +246,7 @@ public class QueryableStateClientTest { MemoryStateBackend backend = new MemoryStateBackend(); DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0); - KeyedStateBackend<Integer> keyedStateBackend = backend.createKeyedStateBackend(dummyEnv, + AbstractKeyedStateBackend<Integer> keyedStateBackend = backend.createKeyedStateBackend(dummyEnv, new JobID(), "test_op", IntSerializer.INSTANCE, http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java index c8fb4bb..0db8b31 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java @@ -41,9 +41,9 @@ import org.apache.flink.runtime.query.KvStateServerAddress; import org.apache.flink.runtime.query.netty.message.KvStateRequest; import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; import org.apache.flink.runtime.query.netty.message.KvStateRequestType; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.KeyGroupRange; -import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.KvState; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; @@ -538,7 +538,8 @@ public class KvStateClientTest { KvStateRegistry dummyRegistry = new KvStateRegistry(); DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0); dummyEnv.setKvStateRegistry(dummyRegistry); - KeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend( + + AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend( dummyEnv, new JobID(), "test_op", http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java index 7e6d713..ed4a822 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java @@ -38,6 +38,7 @@ import org.apache.flink.runtime.query.netty.message.KvStateRequestFailure; import org.apache.flink.runtime.query.netty.message.KvStateRequestResult; import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; import org.apache.flink.runtime.query.netty.message.KvStateRequestType; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyedStateBackend; @@ -92,7 +93,7 @@ public class KvStateServerHandlerTest { AbstractStateBackend abstractBackend = new MemoryStateBackend(); DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0); dummyEnv.setKvStateRegistry(registry); - KeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend( + AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend( dummyEnv, new JobID(), "test_op", @@ -490,7 +491,7 @@ public class KvStateServerHandlerTest { AbstractStateBackend abstractBackend = new MemoryStateBackend(); DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0); dummyEnv.setKvStateRegistry(registry); - KeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend( + AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend( dummyEnv, new JobID(), "test_op", @@ -586,7 +587,7 @@ public class KvStateServerHandlerTest { AbstractStateBackend abstractBackend = new MemoryStateBackend(); DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0); dummyEnv.setKvStateRegistry(registry); - KeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend( + AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend( dummyEnv, new JobID(), "test_op", http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java index e92fb10..b1c4a9f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java @@ -41,9 +41,9 @@ import org.apache.flink.runtime.query.KvStateServerAddress; import org.apache.flink.runtime.query.netty.message.KvStateRequestResult; import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; import org.apache.flink.runtime.query.netty.message.KvStateRequestType; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.KeyGroupRange; -import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.runtime.state.memory.MemoryStateBackend; @@ -91,7 +91,7 @@ public class KvStateServerTest { AbstractStateBackend abstractBackend = new MemoryStateBackend(); DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0); dummyEnv.setKvStateRegistry(registry); - KeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend( + AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend( dummyEnv, new JobID(), "test_op", http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java deleted file mode 100644 index e613105..0000000 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.runtime.state; - -import org.junit.Test; - -import java.io.Closeable; -import java.io.IOException; - -import static org.junit.Assert.*; -import static org.mockito.Mockito.*; - -public class AbstractCloseableHandleTest { - - @Test - public void testRegisterThenClose() throws Exception { - Closeable closeable = mock(Closeable.class); - - AbstractCloseableHandle handle = new CloseableHandle(); - assertFalse(handle.isClosed()); - - // no immediate closing - handle.registerCloseable(closeable); - verify(closeable, times(0)).close(); - assertFalse(handle.isClosed()); - - // close forwarded once - handle.close(); - verify(closeable, times(1)).close(); - assertTrue(handle.isClosed()); - - // no repeated closing - handle.close(); - verify(closeable, times(1)).close(); - assertTrue(handle.isClosed()); - } - - @Test - public void testCloseThenRegister() throws Exception { - Closeable closeable = mock(Closeable.class); - - AbstractCloseableHandle handle = new CloseableHandle(); - assertFalse(handle.isClosed()); - - // close the handle before setting the closeable - handle.close(); - assertTrue(handle.isClosed()); - - // immediate closing - try { - handle.registerCloseable(closeable); - fail("this should throw an excepion"); - } catch (IOException e) { - // expected - assertTrue(e.getMessage().contains("closed")); - } - - // should still have called "close" on the Closeable - verify(closeable, times(1)).close(); - assertTrue(handle.isClosed()); - - // no repeated closing - handle.close(); - verify(closeable, times(1)).close(); - assertTrue(handle.isClosed()); - } - - // ------------------------------------------------------------------------ - - private static final class CloseableHandle extends AbstractCloseableHandle { - private static final long serialVersionUID = 1L; - - @Override - public void discardState() {} - - @Override - public long getStateSize() { - return 0; - } - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java index bc0b9c3..0b04ebc 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java @@ -20,16 +20,11 @@ package org.apache.flink.runtime.state; import org.apache.commons.io.FileUtils; import org.apache.flink.api.common.JobID; -import org.apache.flink.api.common.typeutils.base.IntSerializer; -import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.core.fs.Path; import org.apache.flink.core.testutils.CommonTestUtils; -import org.apache.flink.runtime.operators.testutils.DummyEnvironment; - import org.apache.flink.runtime.state.filesystem.FileStateHandle; import org.apache.flink.runtime.state.filesystem.FsStateBackend; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; - import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; @@ -39,9 +34,12 @@ import java.io.IOException; import java.io.InputStream; import java.net.URI; import java.util.Random; -import java.util.UUID; -import static org.junit.Assert.*; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; public class FileStateBackendTest extends StateBackendTestBase<FsStateBackend> { @@ -188,18 +186,21 @@ public class FileStateBackendTest extends StateBackendTestBase<FsStateBackend> { } private static void validateBytesInStream(InputStream is, byte[] data) throws IOException { - byte[] holder = new byte[data.length]; + try { + byte[] holder = new byte[data.length]; - int pos = 0; - int read; - while (pos < holder.length && (read = is.read(holder, pos, holder.length - pos)) != -1) { - pos += read; - } + int pos = 0; + int read; + while (pos < holder.length && (read = is.read(holder, pos, holder.length - pos)) != -1) { + pos += read; + } - assertEquals("not enough data", holder.length, pos); - assertEquals("too much data", -1, is.read()); - assertArrayEquals("wrong data", data, holder); - is.close(); + assertEquals("not enough data", holder.length, pos); + assertEquals("too much data", -1, is.read()); + assertArrayEquals("wrong data", data, holder); + } finally { + is.close(); + } } @Test http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java index 944938b..ac6adff 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java @@ -19,8 +19,6 @@ package org.apache.flink.runtime.state; import org.apache.flink.api.common.JobID; -import org.apache.flink.api.common.typeutils.base.IntSerializer; -import org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.junit.Test; @@ -29,7 +27,10 @@ import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.util.HashMap; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; /** * Tests for the {@link org.apache.flink.runtime.state.memory.MemoryStateBackend}. @@ -105,10 +106,10 @@ public class MemoryStateBackendTest extends StateBackendTestBase<MemoryStateBack assertNotNull(handle); - ObjectInputStream ois = new ObjectInputStream(handle.openInputStream()); - assertEquals(state, ois.readObject()); - assertTrue(ois.available() <= 0); - ois.close(); + try (ObjectInputStream ois = new ObjectInputStream(handle.openInputStream())) { + assertEquals(state, ois.readObject()); + assertTrue(ois.available() <= 0); + } } catch (Exception e) { e.printStackTrace(); http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java new file mode 100644 index 0000000..56c8987 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.java.typeutils.runtime.JavaSerializer; +import org.apache.flink.runtime.state.memory.MemoryStateBackend; +import org.junit.Test; + +import java.io.Serializable; +import java.util.Collections; +import java.util.Iterator; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +public class OperatorStateBackendTest { + + AbstractStateBackend abstractStateBackend = new MemoryStateBackend(1024); + + private OperatorStateBackend createNewOperatorStateBackend() throws Exception { + return abstractStateBackend.createOperatorStateBackend(null, "test-operator"); + } + + @Test + public void testCreateNew() throws Exception { + OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend(); + assertNotNull(operatorStateBackend); + assertTrue(operatorStateBackend.getRegisteredStateNames().isEmpty()); + } + + @Test + public void testRegisterStates() throws Exception { + OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend(); + ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>()); + ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>()); + ListState<Serializable> listState1 = operatorStateBackend.getPartitionableState(stateDescriptor1); + assertNotNull(listState1); + assertEquals(1, operatorStateBackend.getRegisteredStateNames().size()); + Iterator<Serializable> it = listState1.get().iterator(); + assertTrue(!it.hasNext()); + listState1.add(42); + listState1.add(4711); + + it = listState1.get().iterator(); + assertEquals(42, it.next()); + assertEquals(4711, it.next()); + assertTrue(!it.hasNext()); + + ListState<Serializable> listState2 = operatorStateBackend.getPartitionableState(stateDescriptor2); + assertNotNull(listState2); + assertEquals(2, operatorStateBackend.getRegisteredStateNames().size()); + assertTrue(!it.hasNext()); + listState2.add(7); + listState2.add(13); + listState2.add(23); + + it = listState2.get().iterator(); + assertEquals(7, it.next()); + assertEquals(13, it.next()); + assertEquals(23, it.next()); + assertTrue(!it.hasNext()); + + ListState<Serializable> listState1b = operatorStateBackend.getPartitionableState(stateDescriptor1); + assertNotNull(listState1b); + listState1b.add(123); + it = listState1b.get().iterator(); + assertEquals(42, it.next()); + assertEquals(4711, it.next()); + assertEquals(123, it.next()); + assertTrue(!it.hasNext()); + + it = listState1.get().iterator(); + assertEquals(42, it.next()); + assertEquals(4711, it.next()); + assertEquals(123, it.next()); + assertTrue(!it.hasNext()); + + it = listState1b.get().iterator(); + assertEquals(42, it.next()); + assertEquals(4711, it.next()); + assertEquals(123, it.next()); + assertTrue(!it.hasNext()); + } + + @Test + public void testSnapshotRestore() throws Exception { + OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend(); + ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>()); + ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>()); + ListState<Serializable> listState1 = operatorStateBackend.getPartitionableState(stateDescriptor1); + ListState<Serializable> listState2 = operatorStateBackend.getPartitionableState(stateDescriptor2); + + listState1.add(42); + listState1.add(4711); + + listState2.add(7); + listState2.add(13); + listState2.add(23); + + CheckpointStreamFactory streamFactory = abstractStateBackend.createStreamFactory(new JobID(), "testOperator"); + OperatorStateHandle stateHandle = operatorStateBackend.snapshot(1, 1, streamFactory).get(); + + try { + + operatorStateBackend.dispose(); + + operatorStateBackend = abstractStateBackend. + restoreOperatorStateBackend(null, "testOperator", Collections.singletonList(stateHandle)); + + assertEquals(0, operatorStateBackend.getRegisteredStateNames().size()); + + listState1 = operatorStateBackend.getPartitionableState(stateDescriptor1); + listState2 = operatorStateBackend.getPartitionableState(stateDescriptor2); + + assertEquals(2, operatorStateBackend.getRegisteredStateNames().size()); + + + Iterator<Serializable> it = listState1.get().iterator(); + assertEquals(42, it.next()); + assertEquals(4711, it.next()); + assertTrue(!it.hasNext()); + + it = listState2.get().iterator(); + assertEquals(7, it.next()); + assertEquals(13, it.next()); + assertEquals(23, it.next()); + assertTrue(!it.hasNext()); + + operatorStateBackend.dispose(); + } finally { + + stateHandle.discardState(); + } + } + +} \ No newline at end of file
