tillrohrmann commented on a change in pull request #9501: [FLINK-12697] [State Backends] Support on-disk state storage for spill-able heap backend
URL: https://github.com/apache/flink/pull/9501#discussion_r334038508
 
 

 ##########
 File path: 
flink-state-backends/flink-statebackend-heap-spillable/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteSkipListStateMapTest.java
 ##########
 @@ -115,577 +144,783 @@ public void testInitStateMap() {
                assertFalse(stateMap.getStateIncrementalVisitor(100).hasNext());
 
                stateMap.close();
-               assertEquals(0, stateMap.size());
-               assertEquals(0, stateMap.totalSize());
-               assertTrue(stateMap.isClosed());
        }
 
        /**
-        * Test basic operations.
+        * Test state put operation.
         */
        @Test
-       public void testBasicOperations() throws Exception {
-               TypeSerializer<Integer> keySerializer = IntSerializer.INSTANCE;
-               TypeSerializer<Long> namespaceSerializer = 
LongSerializer.INSTANCE;
-               TypeSerializer<String> stateSerializer = 
StringSerializer.INSTANCE;
-               CopyOnWriteSkipListStateMap<Integer, Long, String> stateMap = 
new CopyOnWriteSkipListStateMap<>(
-                       keySerializer, namespaceSerializer, stateSerializer, 
spaceAllocator);
+       public void testPutState() {
+               testWithFunction((totalSize, stateMap, referenceStates) -> 
getDefaultSizes(totalSize));
+       }
 
-               ThreadLocalRandom random = ThreadLocalRandom.current();
-               // map to store expected states, namespace -> key -> state
-               Map<Long, Map<Integer, String>> referenceStates = new 
HashMap<>();
-               int totalSize = 0;
+       /**
+        * Test remove existing state.
+        */
+       @Test
+       public void testRemoveExistingState() {
+               testRemoveState(false, false);
+       }
 
-               // put some states
-               for (long namespace = 0; namespace < 10; namespace++) {
-                       for (int key = 0; key < 100; key++) {
-                               totalSize++;
-                               String state = String.valueOf(key * namespace);
-                               if (random.nextBoolean()) {
-                                       stateMap.put(key, namespace, state);
-                               } else {
-                                       assertNull(stateMap.putAndGetOld(key, 
namespace, state));
+       /**
+        * Test remove and get existing state.
+        */
+       @Test
+       public void testRemoveAndGetExistingState() {
+               testRemoveState(false, true);
+       }
+
+       /**
+        * Test remove absent state.
+        */
+       @Test
+       public void testRemoveAbsentState() {
+               testRemoveState(true, true);
+       }
+
+       /**
+        * Test putting back previously removed states.
+        */
+       @Test
+       public void testPutPreviouslyRemovedState() {
+               testWithFunction(
+                       (totalSize, stateMap, referenceStates) -> 
applyFunctionAfterRemove(stateMap, referenceStates,
+                               (removedCnt, removedStates) -> {
+                                       int size = totalSize - removedCnt;
+                                       for (Map.Entry<Long, Set<Integer>> 
entry : removedStates.entrySet()) {
+                                               long namespace = entry.getKey();
+                                               for (int key : 
entry.getValue()) {
+                                                       size++;
+                                                       String state = 
String.valueOf(key * namespace);
+                                                       
assertNull(stateMap.putAndGetOld(key, namespace, state));
+                                                       
referenceStates.computeIfAbsent(namespace, (none) -> new HashMap<>()).put(key, 
String.valueOf(state));
+                                               }
+                                       }
+                                       return getDefaultSizes(size);
                                }
-                               referenceStates.computeIfAbsent(namespace, 
(none) -> new HashMap<>()).put(key, state);
-                               assertEquals(totalSize, stateMap.size());
-                               assertEquals(totalSize, stateMap.totalSize());
-                       }
-               }
+                       )
+               );
+       }
 
-               // validates space allocation. Each pair need 2 spaces
-               assertEquals(totalSize * 2, 
spaceAllocator.getTotalSpaceNumber());
-               verifyState(referenceStates, stateMap);
+       private void testRemoveState(boolean removeAbsent, boolean getOld) {
+               testWithFunction(
+                       (totalSize, stateMap, referenceStates) -> {
+                               if (removeAbsent) {
+                                       totalSize -= 
removeAbsentState(stateMap, referenceStates);
+                               } else {
+                                       totalSize -= 
removeExistingState(stateMap, referenceStates, getOld);
+                               }
+                               return getDefaultSizes(totalSize);
+                       });
+       }
 
-               // remove some states
-               Map<Long, Set<Integer>> removedStates = new HashMap<>();
+       private int removeExistingState(
+               CopyOnWriteSkipListStateMap<Integer, Long, String> stateMap,
+               @Nonnull Map<Long, Map<Integer, String>> referenceStates,
+               boolean getOld) {
+               int removedCnt = 0;
                for (Map.Entry<Long, Map<Integer, String>> namespaceEntry : 
referenceStates.entrySet()) {
                        long namespace = namespaceEntry.getKey();
-                       for (Map.Entry<Integer, String> keyEntry : 
namespaceEntry.getValue().entrySet()) {
+                       Map<Integer, String> kvMap = namespaceEntry.getValue();
+                       Iterator<Map.Entry<Integer, String>> kvIterator = 
kvMap.entrySet().iterator();
+                       while (kvIterator.hasNext()) {
+                               Map.Entry<Integer, String> keyEntry = 
kvIterator.next();
                                if (random.nextBoolean()) {
                                        int key = keyEntry.getKey();
                                        String state = keyEntry.getValue();
-                                       
removedStates.computeIfAbsent(namespace, (none) -> new HashSet<>()).add(key);
-                                       totalSize--;
-                                       if (random.nextBoolean()) {
-                                               stateMap.remove(key, namespace);
-                                       } else {
+                                       removedCnt++;
+                                       // remove from state map
+                                       if (getOld) {
                                                assertEquals(state, 
stateMap.removeAndGetOld(key, namespace));
+                                       } else {
+                                               stateMap.remove(key, namespace);
                                        }
-                                       assertEquals(totalSize, 
stateMap.size());
-                                       assertEquals(totalSize, 
stateMap.totalSize());
+                                       // remove from reference to keep in 
accordance
+                                       kvIterator.remove();
                                }
                        }
                }
+               return removedCnt;
+       }
 
-               for (Map.Entry<Long, Set<Integer>> entry : 
removedStates.entrySet()) {
-                       long namespace = entry.getKey();
-                       Map<Integer, String> keyMap = 
referenceStates.get(namespace);
-                       if (keyMap != null) {
-                               entry.getValue().forEach(keyMap::remove);
-                               if (keyMap.isEmpty()) {
-                                       referenceStates.remove(namespace);
+       private int removeAbsentState(
+                       CopyOnWriteSkipListStateMap<Integer, Long, String> 
stateMap,
+                       Map<Long, Map<Integer, String>> referenceStates) {
+               return applyFunctionAfterRemove(
+                       stateMap,
+                       referenceStates,
+                       (removedCnt, removedStates) -> {
+                               // remove the same keys again, which would be 
absent already
+                               for (Map.Entry<Long, Set<Integer>> entry : 
removedStates.entrySet()) {
+                                       long namespace = entry.getKey();
+                                       for (int key : entry.getValue()) {
+                                               
assertNull(stateMap.removeAndGetOld(key, namespace));
+                                       }
                                }
+                               return removedCnt;
                        }
-                       for (int key : entry.getValue()) {
-                               assertNull(stateMap.get(key, namespace));
-                               assertFalse(stateMap.containsKey(key, 
namespace));
-                       }
-               }
-
-               assertEquals(totalSize * 2, 
spaceAllocator.getTotalSpaceNumber());
-               verifyState(referenceStates, stateMap);
+               );
+       }
 
-               // update some states
+       /**
+        * Apply the given function after removing some states.
+        *
+        * @param stateMap the state map to test against.
+        * @param referenceStates the reference of states for correctness 
verification.
+        * @param function a {@link BiFunction} which takes [removedCnt, 
removedStates] as input parameters.
+        * @param <R> The type of the result returned by the function.
+        * @return The result of applying the given function.
+        */
+       private <R> R applyFunctionAfterRemove(
+               CopyOnWriteSkipListStateMap<Integer, Long, String> stateMap,
+               @Nonnull Map<Long, Map<Integer, String>> referenceStates,
+               BiFunction<Integer, Map<Long, Set<Integer>>, R> function) {
+               int removedCnt = 0;
+               Map<Long, Set<Integer>> removedStates = new HashMap<>();
+               // remove some state
                for (Map.Entry<Long, Map<Integer, String>> namespaceEntry : 
referenceStates.entrySet()) {
                        long namespace = namespaceEntry.getKey();
-                       for (Map.Entry<Integer, String> keyEntry : 
namespaceEntry.getValue().entrySet()) {
+                       Map<Integer, String> kvMap = namespaceEntry.getValue();
+                       Iterator<Map.Entry<Integer, String>> kvIterator = 
kvMap.entrySet().iterator();
+                       while (kvIterator.hasNext()) {
+                               Map.Entry<Integer, String> keyEntry = 
kvIterator.next();
                                if (random.nextBoolean()) {
                                        int key = keyEntry.getKey();
-                                       String state = keyEntry.getValue();
-                                       String newState = state + "-update";
-                                       keyEntry.setValue(newState);
-                                       if (random.nextBoolean()) {
-                                               stateMap.put(key, namespace, 
newState);
-                                       } else {
-                                               assertEquals(state, 
stateMap.putAndGetOld(key, namespace, newState));
-                                       }
-                                       assertEquals(totalSize, 
stateMap.size());
-                                       assertEquals(totalSize, 
stateMap.totalSize());
+                                       removedCnt++;
+                                       
removedStates.computeIfAbsent(namespace, (none) -> new HashSet<>()).add(key);
+                                       // remove from state map
+                                       stateMap.remove(key, namespace);
+                                       // remove from reference to keep in 
accordance
+                                       kvIterator.remove();
                                }
                        }
                }
+               return function.apply(removedCnt, removedStates);
+       }
 
-               // put some new states
-               for (long namespace = 10; namespace < 15; namespace++) {
-                       for (int key = 0; key < 100; key++) {
-                               totalSize++;
-                               String state = String.valueOf(key * namespace);
-                               if (random.nextBoolean()) {
-                                       stateMap.put(key, namespace, state);
-                               } else {
-                                       assertNull(stateMap.putAndGetOld(key, 
namespace, state));
+       /**
+        * Test state update operation.
+        */
+       @Test
+       public void testUpdateState() {
+               testWithFunction(
+                       (totalSize, stateMap, referenceStates) -> {
+                               // update some states
+                               for (Map.Entry<Long, Map<Integer, String>> 
namespaceEntry : referenceStates.entrySet()) {
+                                       long namespace = 
namespaceEntry.getKey();
+                                       for (Map.Entry<Integer, String> 
keyEntry : namespaceEntry.getValue().entrySet()) {
+                                               if (random.nextBoolean()) {
+                                                       int key = 
keyEntry.getKey();
+                                                       String state = 
keyEntry.getValue();
+                                                       String newState = state 
+ "-update";
+                                                       
keyEntry.setValue(newState);
+                                                       if 
(random.nextBoolean()) {
+                                                               
stateMap.put(key, namespace, newState);
+                                                       } else {
+                                                               
assertEquals(state, stateMap.putAndGetOld(key, namespace, newState));
+                                                       }
+                                               }
+                                       }
                                }
-                               referenceStates.computeIfAbsent(namespace, 
(none) -> new HashMap<>()).put(key, state);
-                               assertEquals(totalSize, stateMap.size());
-                               assertEquals(totalSize, stateMap.totalSize());
-                       }
-               }
+                               return getDefaultSizes(totalSize);
+                       });
+       }
 
-               // remove some absent states
-               for (Map.Entry<Long, Set<Integer>> entry : 
removedStates.entrySet()) {
-                       long namespace = entry.getKey();
-                       for (int key : entry.getValue()) {
-                               if (random.nextBoolean()) {
-                                       stateMap.remove(key, namespace);
-                               } else {
-                                       
assertNull(stateMap.removeAndGetOld(key, namespace));
+       /**
+        * Test transform existing state.
+        */
+       @Test
+       public void testTransformExistingState() throws Exception {
+               final AtomicReference<Exception> exceptionRef = new 
AtomicReference<>(null);
+               testWithFunction(
+                       (totalSize, stateMap, referenceStates) -> {
+                               StateTransformationFunction<String, Integer> 
function =
+                                       (String prevState, Integer value) -> 
prevState == null ? String.valueOf(value) : prevState + value;
+                               // transform existing states
+                               for (Map.Entry<Long, Map<Integer, String>> 
namespaceEntry : referenceStates.entrySet()) {
+                                       long namespace = 
namespaceEntry.getKey();
+                                       try {
+                                               for (Map.Entry<Integer, String> 
keyEntry : namespaceEntry.getValue().entrySet()) {
+                                                       if 
(random.nextBoolean()) {
+                                                               int key = 
keyEntry.getKey();
+                                                               String state = 
keyEntry.getValue();
+                                                               int delta = 
random.nextInt();
+                                                               String newState 
= function.apply(state, delta);
+                                                               
keyEntry.setValue(newState);
+                                                               
stateMap.transform(key, namespace, delta, function);
+                                                       }
+                                               }
+                                       } catch (Exception e) {
+                                               exceptionRef.set(e);
+                                       }
                                }
-                               assertEquals(totalSize, stateMap.size());
-                               assertEquals(totalSize, stateMap.totalSize());
-                       }
+                               return getDefaultSizes(totalSize);
+                       });
+               Exception e = exceptionRef.get();
+               if (e != null) {
+                       throw e;
                }
+       }
 
-               assertEquals(totalSize * 2, 
spaceAllocator.getTotalSpaceNumber());
-               verifyState(referenceStates, stateMap);
-
-               StateTransformationFunction<String, Integer> function =
-                       (String prevState, Integer value) -> prevState == null 
? String.valueOf(value) : prevState + value;
-
-               // transform some old states
-               for (Map.Entry<Long, Map<Integer, String>> namespaceEntry : 
referenceStates.entrySet()) {
-                       long namespace = namespaceEntry.getKey();
-                       for (Map.Entry<Integer, String> keyEntry : 
namespaceEntry.getValue().entrySet()) {
-                               if (random.nextBoolean()) {
-                                       int key = keyEntry.getKey();
-                                       String state = keyEntry.getValue();
-                                       int delta = random.nextInt();
-                                       String newState = function.apply(state, 
delta);
-                                       keyEntry.setValue(newState);
-                                       stateMap.transform(key, namespace, 
delta, function);
-                                       assertEquals(totalSize, 
stateMap.size());
-                                       assertEquals(totalSize, 
stateMap.totalSize());
+       /**
+        * Test transform with previously absent state.
+        */
+       @Test
+       public void testTransformNewState() throws Exception {
+               final AtomicReference<Exception> exceptionRef = new 
AtomicReference<>(null);
+               testWithFunction(
+                       (totalSize, stateMap, referenceStates) -> {
+                               StateTransformationFunction<String, Integer> 
function =
+                                       (String prevState, Integer value) -> 
prevState == null ? String.valueOf(value) : prevState + value;
+                               // transform some new states
+                               for (long namespace = initNamespaceNumber; 
namespace < initNamespaceNumber + 5; namespace++) {
+                                       for (int key = 0; key < 100; key++) {
+                                               totalSize++;
+                                               int value = (int) (key * 
namespace);
+                                               try {
+                                                       stateMap.transform(key, 
namespace, value, function);
+                                                       String state = 
function.apply(null, value);
+                                                       
referenceStates.computeIfAbsent(namespace, (none) -> new HashMap<>()).put(key, 
state);
+                                               } catch (Exception e) {
+                                                       exceptionRef.set(e);
+                                               }
+                                       }
                                }
-                       }
+                               return getDefaultSizes(totalSize);
+                       });
+               Exception e = exceptionRef.get();
+               if (e != null) {
+                       throw e;
                }
+       }
 
-               // transform some new states
-               for (long namespace = 15; namespace < 20; namespace++) {
-                       for (int key = 0; key < 100; key++) {
-                               totalSize++;
-                               int value = (int) (key * namespace);
-                               stateMap.transform(key, namespace, value, 
function);
-                               referenceStates.computeIfAbsent(namespace, 
(none) -> new HashMap<>()).put(key, String.valueOf(value));
-                               assertEquals(totalSize, stateMap.size());
-                               assertEquals(totalSize, stateMap.totalSize());
-                       }
-               }
+       /**
+        * Test remove namespace.
+        */
+       @Test
+       public void testPurgeNamespace() {
+               testPurgeNamespace(false);
+       }
 
-               assertEquals(totalSize * 2, 
spaceAllocator.getTotalSpaceNumber());
-               verifyState(referenceStates, stateMap);
+       /**
+        * Test remove namespace while a snapshot is ongoing.
+        */
+       @Test
+       public void testPurgeNamespaceWithSnapshot() {
+               testPurgeNamespace(true);
+       }
 
-               // put some previously removed states
-               for (Map.Entry<Long, Set<Integer>> entry : 
removedStates.entrySet()) {
-                       long namespace = entry.getKey();
-                       for (int key : entry.getValue()) {
-                               totalSize++;
-                               String state = String.valueOf(key * namespace);
-                               if (random.nextBoolean()) {
-                                       stateMap.put(key, namespace, state);
-                               } else {
-                                       assertNull(stateMap.putAndGetOld(key, 
namespace, state));
+       private void testPurgeNamespace(boolean withSnapshot) {
+               testWithFunction(
+                       (totalSize, stateMap, referenceStates) -> {
+                               CopyOnWriteSkipListStateMapSnapshot<Integer, 
Long, String> snapshot = null;
+                               int totalSizeIncludingLogicalRemove = totalSize;
+                               int totalSpaceSize = totalSize * 2;
+                               if (withSnapshot) {
+                                       snapshot = stateMap.stateSnapshot();
                                }
-                               referenceStates.computeIfAbsent(namespace, 
(none) -> new HashMap<>()).put(key, String.valueOf(state));
-                               assertEquals(totalSize, stateMap.size());
-                               assertEquals(totalSize, stateMap.totalSize());
-                       }
-               }
-
-               assertEquals(totalSize * 2, 
spaceAllocator.getTotalSpaceNumber());
-               verifyState(referenceStates, stateMap);
-
-               // remove some namespaces
-               Set<Long> removedNamespaces = new HashSet<>();
-               for (Map.Entry<Long, Map<Integer, String>> namespaceEntry : 
referenceStates.entrySet()) {
-                       if (random.nextBoolean()) {
-                               long namespace = namespaceEntry.getKey();
-                               removedNamespaces.add(namespace);
-                               for (Map.Entry<Integer, String> keyEntry : 
namespaceEntry.getValue().entrySet()) {
-                                       int key = keyEntry.getKey();
+                               // empty some namespaces
+                               Set<Long> removedNamespaces = new HashSet<>();
+                               for (Map.Entry<Long, Map<Integer, String>> 
namespaceEntry : referenceStates.entrySet()) {
                                        if (random.nextBoolean()) {
-                                               stateMap.remove(key, namespace);
-                                       } else {
-                                               
assertEquals(keyEntry.getValue(), stateMap.removeAndGetOld(key, namespace));
+                                               long namespace = 
namespaceEntry.getKey();
+                                               
removedNamespaces.add(namespace);
+                                               for (Map.Entry<Integer, String> 
keyEntry : namespaceEntry.getValue().entrySet()) {
+                                                       int key = 
keyEntry.getKey();
+                                                       if 
(random.nextBoolean()) {
+                                                               
stateMap.remove(key, namespace);
+                                                       } else {
+                                                               
assertEquals(keyEntry.getValue(), stateMap.removeAndGetOld(key, namespace));
+                                                       }
+                                                       totalSize--;
+                                                       if (withSnapshot) {
+                                                               // logical 
remove with copy-on-write
+                                                               
totalSpaceSize++;
+                                                       } else {
+                                                               // physical 
remove
+                                                               
totalSizeIncludingLogicalRemove--;
+                                                               totalSpaceSize 
-= 2;
+                                                       }
+                                               }
                                        }
-                                       totalSize--;
-                                       assertEquals(totalSize, 
stateMap.size());
-                                       assertEquals(totalSize, 
stateMap.totalSize());
                                }
-                       }
-               }
-
-               for (long namespace : removedNamespaces) {
-                       referenceStates.remove(namespace);
-                       assertEquals(0, stateMap.sizeOfNamespace(namespace));
-                       
assertFalse(stateMap.getKeys(namespace).iterator().hasNext());
-               }
 
-               assertEquals(totalSize * 2, 
spaceAllocator.getTotalSpaceNumber());
-               verifyState(referenceStates, stateMap);
+                               for (long namespace : removedNamespaces) {
+                                       referenceStates.remove(namespace);
+                                       // verify namespace related stuff.
+                                       assertEquals(0, 
stateMap.sizeOfNamespace(namespace));
+                                       
assertFalse(stateMap.getKeys(namespace).iterator().hasNext());
+                               }
+                               if (withSnapshot) {
+                                       snapshot.release();
+                               }
+                               return new Tuple3<>(totalSize, 
totalSizeIncludingLogicalRemove, totalSpaceSize);
+                       }
+               );
+       }
 
+       /**
+        * Test close operation.
+        */
+       @Test
+       public void testClose() {
+               CopyOnWriteSkipListStateMap<Integer, Long, String> stateMap = 
createStateMapForTesting();
+               putStates(stateMap, null);
                stateMap.close();
+               assertTrue(stateMap.isClosed());
                assertEquals(0, stateMap.size());
                assertEquals(0, stateMap.totalSize());
-               // all spaces should be free
                assertEquals(0, spaceAllocator.getTotalSpaceNumber());
-               assertTrue(stateMap.isClosed());
        }
 
        /**
-        *  Tests copy-on-write contracts.
+        * Test with the given function.
+        *
+        * @param function a {@link TriFunction} with [totalSizeBeforeFunction, 
stateMap, referenceStates] as input
+        *                 parameters and returns the [totalSize, 
totalSizeIncludingLogicalRemovedKey, totalSpaceSize]
+        *                 tuple after applying the function.
         */
-       @SuppressWarnings("unchecked")
-       @Test
-       public void testCopyOnWriteContracts() throws IOException {
-               TypeSerializer<Integer> keySerializer = IntSerializer.INSTANCE;
-               TypeSerializer<Long> namespaceSerializer = 
LongSerializer.INSTANCE;
-               TypeSerializer<String> stateSerializer = 
StringSerializer.INSTANCE;
+       private void testWithFunction(
+               @Nonnull TriFunction<
+                       Integer,
+                       CopyOnWriteSkipListStateMap<Integer, Long, String>,
+                       Map<Long, Map<Integer, String>>,
+                       Tuple3<Integer, Integer, Integer>> function) {
                // do not remove states physically when get, put, remove and 
snapshot
-               CopyOnWriteSkipListStateMap<Integer, Long, String> stateMap = 
new CopyOnWriteSkipListStateMap<>(
-                       keySerializer,
-                       namespaceSerializer,
-                       stateSerializer,
-                       spaceAllocator,
-                       0,
-                       1.0f);
+               CopyOnWriteSkipListStateMap<Integer, Long, String> stateMap = 
createStateMapForTesting(0, 1.0f);
+               // map to store expected states, namespace -> key -> state
+               Map<Long, Map<Integer, String>> referenceStates = new 
HashMap<>();
+               int totalSize = putStates(stateMap, referenceStates);
+               Tuple3 tuple3 = function.apply(totalSize, stateMap, 
referenceStates);
+               totalSize = (int) tuple3.f0;
+               int totalSizeIncludingLogicalRemove = (int) tuple3.f1;
+               int totalSpaceSize = (int) tuple3.f2;
+               assertEquals(totalSize, stateMap.size());
+               assertEquals(totalSizeIncludingLogicalRemove, 
stateMap.totalSize());
+               assertEquals(totalSpaceSize, 
spaceAllocator.getTotalSpaceNumber());
+               verifyState(referenceStates, stateMap);
+               stateMap.close();
+       }
 
-               StateSnapshotTransformer<String> transformer = new 
StateSnapshotTransformer<String>() {
-                       @Nullable
-                       @Override
-                       public String filterOrTransform(@Nullable String value) 
{
-                               if (value == null) {
-                                       return null;
+       private int putStates(
+               CopyOnWriteSkipListStateMap<Integer, Long, String> stateMap,
+               Map<Long, Map<Integer, String>> referenceStates) {
+               int totalSize = 0;
+               for (long namespace = 0; namespace < initNamespaceNumber; 
namespace++) {
+                       for (int key = 0; key < 100; key++) {
+                               totalSize++;
+                               String state = String.valueOf(key * namespace);
+                               if (random.nextBoolean()) {
+                                       stateMap.put(key, namespace, state);
+                               } else {
+                                       assertNull(stateMap.putAndGetOld(key, 
namespace, state));
                                }
-                               int op = value.hashCode() % 3;
-                               switch (op) {
-                                       case 0:
-                                               return null;
-                                       case 1:
-                                               return value + "-transform";
-                                       default:
-                                               return value;
+                               if (referenceStates != null) {
+                                       
referenceStates.computeIfAbsent(namespace, (none) -> new HashMap<>()).put(key, 
state);
                                }
                        }
-               };
+               }
+               assertEquals(totalSize, stateMap.size());
+               assertEquals(totalSize, stateMap.totalSize());
+               return totalSize;
+       }
 
+       /**
+        * By default there's no remove/copy-on-write, so space cost would be 
two times (for key and value) of entry number.
+        *
+        * @param totalSize the total number of valid entries in state map.
+        * @return the default tuple3 of [entry_number, 
entry_number_including_logical_remove, space_size]
+        */
+       private Tuple3<Integer, Integer, Integer> getDefaultSizes(int 
totalSize) {
+               return new Tuple3<>(totalSize, totalSize, totalSize * 2);
+       }
+
+       /**
+        * Test snapshot empty state map.
+        */
+       @Test
+       public void testSnapshotEmptyStateMap() throws IOException {
+               CopyOnWriteSkipListStateMap<Integer, Long, String> stateMap = 
createStateMapForTesting();
                // map to store expected states, namespace -> key -> state
                Map<Long, Map<Integer, String>> referenceStates = new 
HashMap<>();
-               int totalStateSize = 0;
-               int totalSizeIncludingLogicalRemovedKey = 0;
-               int totalLogicallyRemovedKey = 0;
-               int totalSpaceNumber = 0;
-
-               // take snapshot 1 which is an empty snapshot
-               Map<Long, Map<Integer, String>> expectedSnapshot1 = 
snapshotReferenceStates(referenceStates);
-               CopyOnWriteSkipListStateMapSnapshot<Integer, Long, String> 
snapshot1 =
-                       (CopyOnWriteSkipListStateMapSnapshot<Integer, Long, 
String>) stateMap.stateSnapshot();
+               // take snapshot on an empty state map
+               Map<Long, Map<Integer, String>> expectedSnapshot = 
snapshotReferenceStates(referenceStates);
+               CopyOnWriteSkipListStateMapSnapshot<Integer, Long, String> 
snapshot = stateMap.stateSnapshot();
                assertEquals(1, stateMap.getHighestRequiredSnapshotVersion());
                assertEquals(1, stateMap.getSnapshotVersions().size());
-               assertEquals(true, stateMap.getSnapshotVersions().contains(1));
+               assertThat(stateMap.getSnapshotVersions(), contains(1));
                assertEquals(1, stateMap.getResourceGuard().getLeaseCount());
                verifySnapshotWithoutTransform(
-                       expectedSnapshot1, snapshot1, keySerializer, 
namespaceSerializer, stateSerializer);
+                       expectedSnapshot, snapshot, keySerializer, 
namespaceSerializer, stateSerializer);
                verifySnapshotWithTransform(
-                       expectedSnapshot1, snapshot1, transformer, 
keySerializer, namespaceSerializer, stateSerializer);
-
-               snapshot1.release();
-               assertEquals(1, stateMap.getStateMapVersion());
-               assertEquals(0, stateMap.getHighestRequiredSnapshotVersion());
-               assertEquals(1, stateMap.getHighestFinishedSnapshotVersion());
-               assertTrue(stateMap.getSnapshotVersions().isEmpty());
-               assertEquals(0, stateMap.getResourceGuard().getLeaseCount());
+                       expectedSnapshot, snapshot, transformer, keySerializer, 
namespaceSerializer, stateSerializer);
+               snapshot.release();
+               stateMap.close();
+       }
 
-               // put some states
-               for (int i = 1; i <= 10; i++) {
-                       totalStateSize++;
-                       totalSizeIncludingLogicalRemovedKey++;
-                       totalSpaceNumber += 2;
-                       stateMap.put(i, (long) i, String.valueOf(i));
-                       addToReferenceState(referenceStates, i, (long) i, 
String.valueOf(i));
+       /**
+        * Test snapshot release.
+        */
+       @Test
+       public void testReleaseSnapshot() {
+               CopyOnWriteSkipListStateMap<Integer, Long, String> stateMap = 
createStateMapForTesting();
+               int expectedSnapshotVersion = 0;
+               int round = 10;
+               for (int i = 0; i < round; i++) {
+                       assertEquals(expectedSnapshotVersion, 
stateMap.getStateMapVersion());
+                       assertEquals(expectedSnapshotVersion, 
stateMap.getHighestFinishedSnapshotVersion());
+                       CopyOnWriteSkipListStateMapSnapshot<Integer, Long, 
String> snapshot = stateMap.stateSnapshot();
+                       expectedSnapshotVersion++;
+                       snapshot.release();
+                       assertEquals(0, 
stateMap.getHighestRequiredSnapshotVersion());
+                       assertTrue(stateMap.getSnapshotVersions().isEmpty());
+                       assertEquals(0, 
stateMap.getResourceGuard().getLeaseCount());
                }
+               stateMap.close();
+       }
 
-               // take snapshot 2
-               Map<Long, Map<Integer, String>> expectedSnapshot2 = 
snapshotReferenceStates(referenceStates);
-               CopyOnWriteSkipListStateMapSnapshot<Integer, Long, String> 
snapshot2 =
-                       (CopyOnWriteSkipListStateMapSnapshot<Integer, Long, 
String>) stateMap.stateSnapshot();
+       /**
+        * Test basic snapshot correctness.
+        */
+       @Test
+       public void testBasicSnapshot() throws IOException {
+               CopyOnWriteSkipListStateMap<Integer, Long, String> stateMap = 
createStateMapForTesting();
+               // map to store expected states, namespace -> key -> state
+               Map<Long, Map<Integer, String>> referenceStates = new 
HashMap<>();
+               // take an empty snapshot
+               CopyOnWriteSkipListStateMapSnapshot<Integer, Long, String> 
snapshot = stateMap.stateSnapshot();
+               snapshot.release();
+               // put some states
+               putStates(stateMap, referenceStates);
+               // take the 2nd snapshot with data
+               Map<Long, Map<Integer, String>> expectedSnapshot = 
snapshotReferenceStates(referenceStates);
+               CopyOnWriteSkipListStateMapSnapshot<Integer, Long, String> 
snapshot2 = stateMap.stateSnapshot();
                assertEquals(2, stateMap.getStateMapVersion());
                assertEquals(2, stateMap.getHighestRequiredSnapshotVersion());
                assertEquals(1, stateMap.getSnapshotVersions().size());
-               assertEquals(true, stateMap.getSnapshotVersions().contains(2));
+               assertThat(stateMap.getSnapshotVersions(), contains(2));
                assertEquals(1, stateMap.getResourceGuard().getLeaseCount());
-
-               // 1. test put -> put -> remove for (key 1, namespace 1)
-
-               // put (key 1, namespace 1), and copy-on-write should happen
-               stateMap.put(1, 1L, String.valueOf("11"));
-               addToReferenceState(referenceStates, 1, 1L, "11");
-               // a space for new value should be allocated
-               totalSpaceNumber += 1;
-               assertEquals(totalSpaceNumber, 
spaceAllocator.getTotalSpaceNumber());
-               assertEquals("11", stateMap.get(1, 1L));
-               assertTrue(stateMap.containsKey(1, 1L));
-               assertEquals(totalStateSize, stateMap.size());
                verifyState(referenceStates, stateMap);
+               verifySnapshotWithoutTransform(
+                       expectedSnapshot, snapshot2, keySerializer, 
namespaceSerializer, stateSerializer);
+               snapshot2.release();
+               stateMap.close();
+       }
 
-               // put (key 1, namespace 1) again, old value should be replaced 
and space will not increase
-               assertEquals("11", stateMap.putAndGetOld(1, 1L, 
String.valueOf("111")));
-               addToReferenceState(referenceStates, 1, 1L, "111");
-               assertEquals(totalSpaceNumber, 
spaceAllocator.getTotalSpaceNumber());
-               assertEquals("111", stateMap.get(1, 1L));
-               assertTrue(stateMap.containsKey(1, 1L));
-               verifyState(referenceStates, stateMap);
+       /**
+        * Test put -> put without snapshot.
+        */
+       @Test
+       public void testPutAndPut() {
+               testPutAndPut(false);
+       }
 
-               // remove (key 1, namespace 1)
-               stateMap.remove(1, 1L);
-               removeFromReferenceState(referenceStates, 1, 1L);
-               totalStateSize--;
-               totalLogicallyRemovedKey++;
-               assertEquals(totalSpaceNumber, 
spaceAllocator.getTotalSpaceNumber());
-               assertNull(stateMap.get(1, 1L));
-               assertFalse(stateMap.containsKey(1, 1L));
-               assertEquals(totalStateSize, stateMap.size());
-               assertEquals(totalSizeIncludingLogicalRemovedKey, 
stateMap.totalSize());
-               assertEquals(totalLogicallyRemovedKey, 
stateMap.getLogicallyRemovedNodes().size());
-               verifyState(referenceStates, stateMap);
+       /**
+        * Test put -> put during snapshot, the first put should trigger 
copy-on-write and the second shouldn't.
+        */
+       @Test
+       public void testPutAndPutWithSnapshot() {
+               testPutAndPut(true);
+       }
 
-               // 2. test remove -> remove -> put for (key 4, namespace 4)
+       private void testPutAndPut(boolean withSnapshot) {
+               final int key = 1;
+               final long namespace = 1L;
+               final String value = "11";
+               final String newValue = "111";
+               testWithFunction(
+                       (totalSize, stateMap, referenceStates) -> {
+                               CopyOnWriteSkipListStateMapSnapshot<Integer, 
Long, String> snapshot = null;
+                               if (withSnapshot) {
+                                       // take a snapshot
+                                       snapshot = stateMap.stateSnapshot();
+                               }
+                               // put (key 1, namespace 1)
+                               int totalSpaceNumber =
+                                       putExistingKey(totalSize, stateMap, 
referenceStates,
+                                               totalSize * 2, key, namespace, 
value, withSnapshot);
+
+                               // put (key 1, namespace 1) again, old value 
should be replaced and space will not increase
+                               assertEquals("11", stateMap.putAndGetOld(key, 
namespace, newValue));
+                               addToReferenceState(referenceStates, key, 
namespace, newValue);
+                               assertEquals("111", stateMap.get(key, 
namespace));
+                               assertTrue(stateMap.containsKey(key, 
namespace));
+                               if (withSnapshot) {
+                                       snapshot.release();
+                               }
+                               return new Tuple3<>(totalSize, totalSize, 
totalSpaceNumber);
+                       }
+               );
+       }
 
-               // remove (key 4, namespace 4), and it should be logically 
removed
-               assertEquals("4", stateMap.removeAndGetOld(4, 4L));
-               removeFromReferenceState(referenceStates, 4, 4L);
-               // a space should be allocated
-               totalStateSize--;
-               totalLogicallyRemovedKey++;
-               totalSpaceNumber += 1;
-               assertEquals(totalSpaceNumber, 
spaceAllocator.getTotalSpaceNumber());
-               assertNull(stateMap.get(4, 4L));
-               assertFalse(stateMap.containsKey(4, 4L));
-               assertEquals(totalStateSize, stateMap.size());
-               assertEquals(totalSizeIncludingLogicalRemovedKey, 
stateMap.totalSize());
-               assertEquals(totalLogicallyRemovedKey, 
stateMap.getLogicallyRemovedNodes().size());
-               verifyState(referenceStates, stateMap);
+       /**
+        * Test put -> remove without snapshot.
+        */
+       @Test
+       public void testPutAndRemove() {
+               testPutAndRemove(false);
+       }
 
-               // remove (key 4, namespace 4) again, and nothing should happen
-               assertNull(stateMap.removeAndGetOld(4, 4L));
-               assertEquals(totalSpaceNumber, 
spaceAllocator.getTotalSpaceNumber());
-               assertEquals(totalStateSize, stateMap.size());
-               assertEquals(totalSizeIncludingLogicalRemovedKey, 
stateMap.totalSize());
-               assertEquals(totalLogicallyRemovedKey, 
stateMap.getLogicallyRemovedNodes().size());
-               verifyState(referenceStates, stateMap);
+       /**
+        * Test put -> remove during snapshot, put should trigger copy-on-write 
and remove shouldn't.
+        */
+       @Test
+       public void testPutAndRemoveWithSnapshot() {
+               testPutAndRemove(true);
+       }
 
-               // put the logically removed (key 4, namespace 4)
-               assertNull(stateMap.putAndGetOld(4, 4L, "44"));
-               addToReferenceState(referenceStates, 4, 4L, "44");
-               totalStateSize++;
-               totalLogicallyRemovedKey--;
-               assertEquals(totalSpaceNumber, 
spaceAllocator.getTotalSpaceNumber());
-               assertEquals("44", stateMap.get(4, 4L));
-               assertTrue(stateMap.containsKey(4, 4L));
-               assertEquals(totalStateSize, stateMap.size());
-               assertEquals(totalSizeIncludingLogicalRemovedKey, 
stateMap.totalSize());
-               assertEquals(totalLogicallyRemovedKey, 
stateMap.getLogicallyRemovedNodes().size());
-               verifyState(referenceStates, stateMap);
+       private void testPutAndRemove(boolean withSnapshot) {
+               final int key = 6;
+               final long namespace = 6L;
+               final String value = "66";
+               testWithFunction(
+                       (totalSize, stateMap, referenceStates) -> {
+                               int totalLogicallyRemovedKey = 0;
+                               int totalSizeIncludingLogicalRemovedKey = 
totalSize;
+                               CopyOnWriteSkipListStateMapSnapshot<Integer, 
Long, String> snapshot = null;
+                               if (withSnapshot) {
+                                       // take a snapshot
+                                       snapshot = stateMap.stateSnapshot();
+                               }
 
 Review comment:
   Instead of pushing all the logic deep into one uber test function, I think a
compositional approach would be better. For example, the test with
`withSnapshot == false` could be written like this:
   
   ```
   putAndRemove()
   makeAssertions()
   ```
   
   and the one with `withSnapshot == true` like this:
   
   ```
   takeSnapshot()
   putAndRemove()
   releaseSnapshot()
   makeAssertions()
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to