http://git-wip-us.apache.org/repos/asf/flink/blob/484fedd4/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java index d318575..c791fd8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java @@ -2731,10 +2731,15 @@ public class CheckpointCoordinatorTest extends TestLogger { @Test public void testReplicateModeStateHandle() { Map<String, OperatorStateHandle.StateMetaInfo> metaInfoMap = new HashMap<>(1); - metaInfoMap.put("t-1", new OperatorStateHandle.StateMetaInfo(new long[]{0, 23}, OperatorStateHandle.Mode.BROADCAST)); - metaInfoMap.put("t-2", new OperatorStateHandle.StateMetaInfo(new long[]{42, 64}, OperatorStateHandle.Mode.BROADCAST)); + metaInfoMap.put("t-1", new OperatorStateHandle.StateMetaInfo(new long[]{0, 23}, OperatorStateHandle.Mode.UNION)); + metaInfoMap.put("t-2", new OperatorStateHandle.StateMetaInfo(new long[]{42, 64}, OperatorStateHandle.Mode.UNION)); metaInfoMap.put("t-3", new OperatorStateHandle.StateMetaInfo(new long[]{72, 83}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); - OperatorStateHandle osh = new OperatorStateHandle(metaInfoMap, new ByteStreamStateHandle("test", new byte[100])); + metaInfoMap.put("t-4", new OperatorStateHandle.StateMetaInfo(new long[]{87, 94, 95}, OperatorStateHandle.Mode.BROADCAST)); + metaInfoMap.put("t-5", new OperatorStateHandle.StateMetaInfo(new long[]{97, 108, 112}, OperatorStateHandle.Mode.BROADCAST)); + metaInfoMap.put("t-6", new OperatorStateHandle.StateMetaInfo(new long[]{121, 143, 147}, OperatorStateHandle.Mode.BROADCAST)); + + // this is what a 
single task will return + OperatorStateHandle osh = new OperatorStateHandle(metaInfoMap, new ByteStreamStateHandle("test", new byte[150])); OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE; List<Collection<OperatorStateHandle>> repartitionedStates = @@ -2757,18 +2762,26 @@ public class CheckpointCoordinatorTest extends TestLogger { OperatorStateHandle.StateMetaInfo stateMetaInfo = stateNameToMetaInfo.getValue(); if (OperatorStateHandle.Mode.SPLIT_DISTRIBUTE.equals(stateMetaInfo.getDistributionMode())) { + // SPLIT_DISTRIBUTE: so split the state and re-distribute it -> each one will go to one task Assert.assertEquals(1, stateNameToMetaInfo.getValue().getOffsets().length); - } else { + } else if (OperatorStateHandle.Mode.UNION.equals(stateMetaInfo.getDistributionMode())) { + // UNION: so all to all Assert.assertEquals(2, stateNameToMetaInfo.getValue().getOffsets().length); + } else { + // BROADCAST: so all to all Assert.assertEquals(3, stateNameToMetaInfo.getValue().getOffsets().length); } } } } - Assert.assertEquals(3, checkCounts.size()); + Assert.assertEquals(6, checkCounts.size()); Assert.assertEquals(3, checkCounts.get("t-1").intValue()); Assert.assertEquals(3, checkCounts.get("t-2").intValue()); Assert.assertEquals(2, checkCounts.get("t-3").intValue()); + Assert.assertEquals(3, checkCounts.get("t-4").intValue()); + Assert.assertEquals(3, checkCounts.get("t-5").intValue()); + Assert.assertEquals(3, checkCounts.get("t-6").intValue()); } // ------------------------------------------------------------------------ @@ -3243,7 +3256,7 @@ public class CheckpointCoordinatorTest extends TestLogger { Path fakePath = new Path("/fake-" + i); Map<String, OperatorStateHandle.StateMetaInfo> namedStatesToOffsets = new HashMap<>(); int off = 0; - for (int s = 0; s < numNamedStates; ++s) { + for (int s = 0; s < numNamedStates - 1; ++s) { long[] offs = new long[1 + r.nextInt(maxPartitionsPerState)]; for (int o = 0; o <
offs.length; ++o) { @@ -3252,19 +3265,29 @@ public class CheckpointCoordinatorTest extends TestLogger { } OperatorStateHandle.Mode mode = r.nextInt(10) == 0 ? - OperatorStateHandle.Mode.BROADCAST : OperatorStateHandle.Mode.SPLIT_DISTRIBUTE; + OperatorStateHandle.Mode.UNION : OperatorStateHandle.Mode.SPLIT_DISTRIBUTE; namedStatesToOffsets.put( "State-" + s, new OperatorStateHandle.StateMetaInfo(offs, mode)); } + if (numNamedStates % 2 == 0) { + // finally add a broadcast state + long[] offs = {off + 1, off + 2, off + 3, off + 4}; + + namedStatesToOffsets.put( + "State-" + (numNamedStates - 1), + new OperatorStateHandle.StateMetaInfo(offs, OperatorStateHandle.Mode.BROADCAST)); + } + previousParallelOpInstanceStates.add( new OperatorStateHandle(namedStatesToOffsets, new FileStateHandle(fakePath, -1))); } Map<StreamStateHandle, Map<String, List<Long>>> expected = new HashMap<>(); + int taskIndex = 0; int expectedTotalPartitions = 0; for (OperatorStateHandle psh : previousParallelOpInstanceStates) { Map<String, OperatorStateHandle.StateMetaInfo> offsMap = psh.getStateNameToPartitionOffsets(); @@ -3272,20 +3295,39 @@ public class CheckpointCoordinatorTest extends TestLogger { for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e : offsMap.entrySet()) { long[] offs = e.getValue().getOffsets(); - int replication = e.getValue().getDistributionMode().equals(OperatorStateHandle.Mode.BROADCAST) ? - newParallelism : 1; + int replication; + switch (e.getValue().getDistributionMode()) { + case UNION: + replication = newParallelism; + break; + case BROADCAST: + int extra = taskIndex < (newParallelism % oldParallelism) ? 
1 : 0; + replication = newParallelism / oldParallelism + extra; + break; + case SPLIT_DISTRIBUTE: + replication = 1; + break; + default: + throw new RuntimeException("Unknown distribution mode " + e.getValue().getDistributionMode()); + } - expectedTotalPartitions += replication * offs.length; - List<Long> offsList = new ArrayList<>(offs.length); + if (replication > 0) { + expectedTotalPartitions += replication * offs.length; + List<Long> offsList = new ArrayList<>(offs.length); - for (long off : offs) { - for (int p = 0; p < replication; ++p) { - offsList.add(off); + for (long off : offs) { + for (int p = 0; p < replication; ++p) { + offsList.add(off); + } } + offsMapWithList.put(e.getKey(), offsList); } - offsMapWithList.put(e.getKey(), offsList); } - expected.put(psh.getDelegateStateHandle(), offsMapWithList); + + if (!offsMapWithList.isEmpty()) { + expected.put(psh.getDelegateStateHandle(), offsMapWithList); + } + taskIndex++; } OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
http://git-wip-us.apache.org/repos/asf/flink/blob/484fedd4/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java index acedb50..d1d67ff 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java @@ -97,7 +97,7 @@ public class CheckpointTestUtils { Map<String, StateMetaInfo> offsetsMap = new HashMap<>(); offsetsMap.put("A", new OperatorStateHandle.StateMetaInfo(new long[]{0, 10, 20}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); offsetsMap.put("B", new OperatorStateHandle.StateMetaInfo(new long[]{30, 40, 50}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); - offsetsMap.put("C", new OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, OperatorStateHandle.Mode.BROADCAST)); + offsetsMap.put("C", new OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, OperatorStateHandle.Mode.UNION)); if (hasOperatorStateBackend) { operatorStateHandleBackend = new OperatorStateHandle(offsetsMap, operatorStateBackend); @@ -179,7 +179,7 @@ public class CheckpointTestUtils { Map<String, StateMetaInfo> offsetsMap = new HashMap<>(); offsetsMap.put("A", new OperatorStateHandle.StateMetaInfo(new long[]{0, 10, 20}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); offsetsMap.put("B", new OperatorStateHandle.StateMetaInfo(new long[]{30, 40, 50}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); - offsetsMap.put("C", new OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, OperatorStateHandle.Mode.BROADCAST)); + offsetsMap.put("C", new OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, 
OperatorStateHandle.Mode.UNION)); if (chainIdx != noOperatorStateBackendAtIndex) { OperatorStateHandle operatorStateHandleBackend = http://git-wip-us.apache.org/repos/asf/flink/blob/484fedd4/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java index ef390db..1881dad 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java @@ -18,8 +18,11 @@ package org.apache.flink.runtime.state; import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.state.BroadcastState; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeutils.CompatibilityResult; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot; @@ -49,7 +52,9 @@ import java.io.File; import java.io.IOException; import java.io.Serializable; import java.util.Collections; +import java.util.HashMap; import java.util.Iterator; +import java.util.Map; import java.util.concurrent.CancellationException; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -85,6 +90,7 @@ public class OperatorStateBackendTest { assertNotNull(operatorStateBackend); assertTrue(operatorStateBackend.getRegisteredStateNames().isEmpty()); + assertTrue(operatorStateBackend.getRegisteredBroadcastStateNames().isEmpty()); } @Test @@ -233,6 +239,20 @@ public 
class OperatorStateBackendTest { listState.add(42); + AtomicInteger keyCopyCounter = new AtomicInteger(0); + AtomicInteger valueCopyCounter = new AtomicInteger(0); + + TypeSerializer<Integer> keySerializer = new VerifyingIntSerializer(env.getUserClassLoader(), keyCopyCounter); + TypeSerializer<Integer> valueSerializer = new VerifyingIntSerializer(env.getUserClassLoader(), valueCopyCounter); + + MapStateDescriptor<Integer, Integer> broadcastStateDesc = new MapStateDescriptor<>( + "test-broadcast", keySerializer, valueSerializer); + + BroadcastState<Integer, Integer> broadcastState = operatorStateBackend.getBroadcastState(broadcastStateDesc); + broadcastState.put(1, 2); + broadcastState.put(3, 4); + broadcastState.put(5, 6); + CheckpointStreamFactory streamFactory = new MemCheckpointStreamFactory(4096); RunnableFuture<OperatorStateHandle> runnableFuture = operatorStateBackend.snapshot(1, 1, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()); @@ -240,6 +260,8 @@ public class OperatorStateBackendTest { // make sure that the copy method has been called assertTrue(copyCounter.get() > 0); + assertTrue(keyCopyCounter.get() > 0); + assertTrue(valueCopyCounter.get() > 0); } /** @@ -361,17 +383,102 @@ public class OperatorStateBackendTest { } @Test + public void testSnapshotBroadcastStateWithEmptyOperatorState() throws Exception { + final AbstractStateBackend abstractStateBackend = new MemoryStateBackend(4096); + + final OperatorStateBackend operatorStateBackend = + abstractStateBackend.createOperatorStateBackend(createMockEnvironment(), "testOperator"); + + final MapStateDescriptor<Integer, Integer> broadcastStateDesc = new MapStateDescriptor<>( + "test-broadcast", BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO); + + final Map<Integer, Integer> expected = new HashMap<>(3); + expected.put(1, 2); + expected.put(3, 4); + expected.put(5, 6); + + final BroadcastState<Integer, Integer> broadcastState = 
operatorStateBackend.getBroadcastState(broadcastStateDesc); + broadcastState.putAll(expected); + + final CheckpointStreamFactory streamFactory = new MemCheckpointStreamFactory(4096); + OperatorStateHandle stateHandle = null; + + try { + RunnableFuture<OperatorStateHandle> snapshot = + operatorStateBackend.snapshot(0L, 0L, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()); + + stateHandle = FutureUtil.runIfNotDoneAndGet(snapshot); + assertNotNull(stateHandle); + + final Map<Integer, Integer> retrieved = new HashMap<>(); + + operatorStateBackend.restore(Collections.singleton(stateHandle)); + BroadcastState<Integer, Integer> retrievedState = operatorStateBackend.getBroadcastState(broadcastStateDesc); + for (Map.Entry<Integer, Integer> e: retrievedState.entries()) { + retrieved.put(e.getKey(), e.getValue()); + } + assertEquals(expected, retrieved); + + // remove an element from both expected and stored state. + broadcastState.remove(1); + expected.remove(1); + + snapshot = operatorStateBackend.snapshot(1L, 1L, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()); + stateHandle = FutureUtil.runIfNotDoneAndGet(snapshot); + + retrieved.clear(); + operatorStateBackend.restore(Collections.singleton(stateHandle)); + retrievedState = operatorStateBackend.getBroadcastState(broadcastStateDesc); + for (Map.Entry<Integer, Integer> e: retrievedState.immutableEntries()) { + retrieved.put(e.getKey(), e.getValue()); + } + assertEquals(expected, retrieved); + + // remove all elements from both expected and stored state. 
+ broadcastState.clear(); + expected.clear(); + + snapshot = operatorStateBackend.snapshot(2L, 2L, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()); + stateHandle = FutureUtil.runIfNotDoneAndGet(snapshot); + + retrieved.clear(); + operatorStateBackend.restore(Collections.singleton(stateHandle)); + retrievedState = operatorStateBackend.getBroadcastState(broadcastStateDesc); + for (Map.Entry<Integer, Integer> e: retrievedState.immutableEntries()) { + retrieved.put(e.getKey(), e.getValue()); + } + assertTrue(expected.isEmpty()); + assertEquals(expected, retrieved); + } finally { + operatorStateBackend.close(); + operatorStateBackend.dispose(); + if (stateHandle != null) { + stateHandle.discardState(); + } + } + } + + @Test public void testSnapshotRestoreSync() throws Exception { - AbstractStateBackend abstractStateBackend = new MemoryStateBackend(4096); + AbstractStateBackend abstractStateBackend = new MemoryStateBackend(2 * 4096); OperatorStateBackend operatorStateBackend = abstractStateBackend.createOperatorStateBackend(createMockEnvironment(), "test-op-name"); ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>()); ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>()); ListStateDescriptor<Serializable> stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<>()); + + MapStateDescriptor<Serializable, Serializable> broadcastStateDescriptor1 = new MapStateDescriptor<>("test4", new JavaSerializer<>(), new JavaSerializer<>()); + MapStateDescriptor<Serializable, Serializable> broadcastStateDescriptor2 = new MapStateDescriptor<>("test5", new JavaSerializer<>(), new JavaSerializer<>()); + MapStateDescriptor<Serializable, Serializable> broadcastStateDescriptor3 = new MapStateDescriptor<>("test6", new JavaSerializer<>(), new JavaSerializer<>()); + ListState<Serializable> listState1 = 
operatorStateBackend.getListState(stateDescriptor1); ListState<Serializable> listState2 = operatorStateBackend.getListState(stateDescriptor2); ListState<Serializable> listState3 = operatorStateBackend.getUnionListState(stateDescriptor3); + BroadcastState<Serializable, Serializable> broadcastState1 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor1); + BroadcastState<Serializable, Serializable> broadcastState2 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor2); + BroadcastState<Serializable, Serializable> broadcastState3 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor3); + listState1.add(42); listState1.add(4711); @@ -384,7 +491,12 @@ public class OperatorStateBackendTest { listState3.add(19); listState3.add(20); - CheckpointStreamFactory streamFactory = new MemCheckpointStreamFactory(4096); + broadcastState1.put(1, 2); + broadcastState1.put(2, 5); + + broadcastState2.put(2, 5); + + CheckpointStreamFactory streamFactory = new MemCheckpointStreamFactory(2 * 4096); RunnableFuture<OperatorStateHandle> runnableFuture = operatorStateBackend.snapshot(1, 1, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()); OperatorStateHandle stateHandle = FutureUtil.runIfNotDoneAndGet(runnableFuture); @@ -401,12 +513,18 @@ public class OperatorStateBackendTest { operatorStateBackend.restore(Collections.singletonList(stateHandle)); assertEquals(3, operatorStateBackend.getRegisteredStateNames().size()); + assertEquals(3, operatorStateBackend.getRegisteredBroadcastStateNames().size()); listState1 = operatorStateBackend.getListState(stateDescriptor1); listState2 = operatorStateBackend.getListState(stateDescriptor2); listState3 = operatorStateBackend.getUnionListState(stateDescriptor3); + broadcastState1 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor1); + broadcastState2 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor2); + broadcastState3 = 
operatorStateBackend.getBroadcastState(broadcastStateDescriptor3); + assertEquals(3, operatorStateBackend.getRegisteredStateNames().size()); + assertEquals(3, operatorStateBackend.getRegisteredBroadcastStateNames().size()); Iterator<Serializable> it = listState1.get().iterator(); assertEquals(42, it.next()); @@ -426,6 +544,27 @@ public class OperatorStateBackendTest { assertEquals(20, it.next()); assertFalse(it.hasNext()); + Iterator<Map.Entry<Serializable, Serializable>> bIt = broadcastState1.iterator(); + assertTrue(bIt.hasNext()); + Map.Entry<Serializable, Serializable> entry = bIt.next(); + assertEquals(1, entry.getKey()); + assertEquals(2, entry.getValue()); + assertTrue(bIt.hasNext()); + entry = bIt.next(); + assertEquals(2, entry.getKey()); + assertEquals(5, entry.getValue()); + assertFalse(bIt.hasNext()); + + bIt = broadcastState2.iterator(); + assertTrue(bIt.hasNext()); + entry = bIt.next(); + assertEquals(2, entry.getKey()); + assertEquals(5, entry.getValue()); + assertFalse(bIt.hasNext()); + + bIt = broadcastState3.iterator(); + assertFalse(bIt.hasNext()); + operatorStateBackend.close(); operatorStateBackend.dispose(); } finally { @@ -444,10 +583,22 @@ public class OperatorStateBackendTest { new ListStateDescriptor<>("test2", new JavaSerializer<MutableType>()); ListStateDescriptor<MutableType> stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<MutableType>()); + + MapStateDescriptor<MutableType, MutableType> broadcastStateDescriptor1 = + new MapStateDescriptor<>("test4", new JavaSerializer<MutableType>(), new JavaSerializer<MutableType>()); + MapStateDescriptor<MutableType, MutableType> broadcastStateDescriptor2 = + new MapStateDescriptor<>("test5", new JavaSerializer<MutableType>(), new JavaSerializer<MutableType>()); + MapStateDescriptor<MutableType, MutableType> broadcastStateDescriptor3 = + new MapStateDescriptor<>("test6", new JavaSerializer<MutableType>(), new JavaSerializer<MutableType>()); + ListState<MutableType> listState1 
= operatorStateBackend.getListState(stateDescriptor1); ListState<MutableType> listState2 = operatorStateBackend.getListState(stateDescriptor2); ListState<MutableType> listState3 = operatorStateBackend.getUnionListState(stateDescriptor3); + BroadcastState<MutableType, MutableType> broadcastState1 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor1); + BroadcastState<MutableType, MutableType> broadcastState2 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor2); + BroadcastState<MutableType, MutableType> broadcastState3 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor3); + listState1.add(MutableType.of(42)); listState1.add(MutableType.of(4711)); @@ -460,6 +611,11 @@ public class OperatorStateBackendTest { listState3.add(MutableType.of(19)); listState3.add(MutableType.of(20)); + broadcastState1.put(MutableType.of(1), MutableType.of(2)); + broadcastState1.put(MutableType.of(2), MutableType.of(5)); + + broadcastState2.put(MutableType.of(2), MutableType.of(5)); + BlockerCheckpointStreamFactory streamFactory = new BlockerCheckpointStreamFactory(1024 * 1024); OneShotLatch waiterLatch = new OneShotLatch(); @@ -482,6 +638,8 @@ public class OperatorStateBackendTest { listState1.add(MutableType.of(77)); + broadcastState1.put(MutableType.of(32), MutableType.of(97)); + int n = 0; for (MutableType mutableType : listState2.get()) { @@ -493,6 +651,7 @@ public class OperatorStateBackendTest { } listState3.clear(); + broadcastState2.clear(); operatorStateBackend.getListState( new ListStateDescriptor<>("test4", new JavaSerializer<MutableType>())); @@ -514,12 +673,18 @@ public class OperatorStateBackendTest { operatorStateBackend.restore(Collections.singletonList(stateHandle)); assertEquals(3, operatorStateBackend.getRegisteredStateNames().size()); + assertEquals(3, operatorStateBackend.getRegisteredBroadcastStateNames().size()); listState1 = operatorStateBackend.getListState(stateDescriptor1); listState2 = 
operatorStateBackend.getListState(stateDescriptor2); listState3 = operatorStateBackend.getUnionListState(stateDescriptor3); + broadcastState1 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor1); + broadcastState2 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor2); + broadcastState3 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor3); + assertEquals(3, operatorStateBackend.getRegisteredStateNames().size()); + assertEquals(3, operatorStateBackend.getRegisteredBroadcastStateNames().size()); Iterator<MutableType> it = listState1.get().iterator(); assertEquals(42, it.next().value); @@ -539,6 +704,27 @@ public class OperatorStateBackendTest { assertEquals(20, it.next().value); assertFalse(it.hasNext()); + Iterator<Map.Entry<MutableType, MutableType>> bIt = broadcastState1.iterator(); + assertTrue(bIt.hasNext()); + Map.Entry<MutableType, MutableType> entry = bIt.next(); + assertEquals(1, entry.getKey().value); + assertEquals(2, entry.getValue().value); + assertTrue(bIt.hasNext()); + entry = bIt.next(); + assertEquals(2, entry.getKey().value); + assertEquals(5, entry.getValue().value); + assertFalse(bIt.hasNext()); + + bIt = broadcastState2.iterator(); + assertTrue(bIt.hasNext()); + entry = bIt.next(); + assertEquals(2, entry.getKey().value); + assertEquals(5, entry.getValue().value); + assertFalse(bIt.hasNext()); + + bIt = broadcastState3.iterator(); + assertFalse(bIt.hasNext()); + operatorStateBackend.close(); operatorStateBackend.dispose(); } finally { @@ -558,10 +744,16 @@ public class OperatorStateBackendTest { ListState<MutableType> listState1 = operatorStateBackend.getOperatorState(stateDescriptor1); - listState1.add(MutableType.of(42)); listState1.add(MutableType.of(4711)); + MapStateDescriptor<MutableType, MutableType> broadcastStateDescriptor1 = + new MapStateDescriptor<>("test4", new JavaSerializer<MutableType>(), new JavaSerializer<MutableType>()); + + BroadcastState<MutableType, MutableType> broadcastState1 = 
operatorStateBackend.getBroadcastState(broadcastStateDescriptor1); + broadcastState1.put(MutableType.of(1), MutableType.of(2)); + broadcastState1.put(MutableType.of(2), MutableType.of(5)); + BlockerCheckpointStreamFactory streamFactory = new BlockerCheckpointStreamFactory(1024 * 1024); OneShotLatch waiterLatch = new OneShotLatch(); @@ -602,7 +794,6 @@ public class OperatorStateBackendTest { ListState<MutableType> listState1 = operatorStateBackend.getOperatorState(stateDescriptor1); - listState1.add(MutableType.of(42)); listState1.add(MutableType.of(4711)); http://git-wip-us.apache.org/repos/asf/flink/blob/484fedd4/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java index ab801b6..88f9cd7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java @@ -28,10 +28,11 @@ public class OperatorStateHandleTest { // Ensure the order / ordinal of all values of enum 'mode' are fixed, as this is used for serialization Assert.assertEquals(0, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE.ordinal()); - Assert.assertEquals(1, OperatorStateHandle.Mode.BROADCAST.ordinal()); + Assert.assertEquals(1, OperatorStateHandle.Mode.UNION.ordinal()); + Assert.assertEquals(2, OperatorStateHandle.Mode.BROADCAST.ordinal()); // Ensure all enum values are registered and fixed forever by this test - Assert.assertEquals(2, OperatorStateHandle.Mode.values().length); + Assert.assertEquals(3, OperatorStateHandle.Mode.values().length); // Byte is used to encode enum value on serialization Assert.assertTrue(OperatorStateHandle.Mode.values().length <= Byte.MAX_VALUE); 
http://git-wip-us.apache.org/repos/asf/flink/blob/484fedd4/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java index 341d4fe..57e4aed 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java @@ -24,6 +24,7 @@ import org.apache.flink.api.common.typeutils.TypeSerializerSerializationUtil; import org.apache.flink.api.common.typeutils.base.DoubleSerializer; import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.core.memory.ByteArrayInputStreamWithPos; import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos; import org.apache.flink.core.memory.DataInputViewStreamWrapper; @@ -205,6 +206,8 @@ public class SerializationProxiesTest { public void testOperatorBackendSerializationProxyRoundtrip() throws Exception { TypeSerializer<?> stateSerializer = DoubleSerializer.INSTANCE; + TypeSerializer<?> keySerializer = DoubleSerializer.INSTANCE; + TypeSerializer<?> valueSerializer = StringSerializer.INSTANCE; List<RegisteredOperatorBackendStateMetaInfo.Snapshot<?>> stateMetaInfoSnapshots = new ArrayList<>(); @@ -213,10 +216,17 @@ public class SerializationProxiesTest { stateMetaInfoSnapshots.add(new RegisteredOperatorBackendStateMetaInfo<>( "b", stateSerializer, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE).snapshot()); stateMetaInfoSnapshots.add(new RegisteredOperatorBackendStateMetaInfo<>( - "c", stateSerializer, OperatorStateHandle.Mode.BROADCAST).snapshot()); + "c", 
stateSerializer, OperatorStateHandle.Mode.UNION).snapshot()); + + List<RegisteredBroadcastBackendStateMetaInfo.Snapshot<?, ?>> broadcastStateMetaInfoSnapshots = new ArrayList<>(); + + broadcastStateMetaInfoSnapshots.add(new RegisteredBroadcastBackendStateMetaInfo<>( + "d", OperatorStateHandle.Mode.BROADCAST, keySerializer, valueSerializer).snapshot()); + broadcastStateMetaInfoSnapshots.add(new RegisteredBroadcastBackendStateMetaInfo<>( + "e", OperatorStateHandle.Mode.BROADCAST, valueSerializer, keySerializer).snapshot()); OperatorBackendSerializationProxy serializationProxy = - new OperatorBackendSerializationProxy(stateMetaInfoSnapshots); + new OperatorBackendSerializationProxy(stateMetaInfoSnapshots, broadcastStateMetaInfoSnapshots); byte[] serialized; try (ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos()) { @@ -231,7 +241,8 @@ public class SerializationProxiesTest { serializationProxy.read(new DataInputViewStreamWrapper(in)); } - Assert.assertEquals(stateMetaInfoSnapshots, serializationProxy.getStateMetaInfoSnapshots()); + Assert.assertEquals(stateMetaInfoSnapshots, serializationProxy.getOperatorStateMetaInfoSnapshots()); + Assert.assertEquals(broadcastStateMetaInfoSnapshots, serializationProxy.getBroadcastStateMetaInfoSnapshots()); } @Test @@ -242,24 +253,58 @@ public class SerializationProxiesTest { RegisteredOperatorBackendStateMetaInfo.Snapshot<?> metaInfo = new RegisteredOperatorBackendStateMetaInfo<>( - name, stateSerializer, OperatorStateHandle.Mode.BROADCAST).snapshot(); + name, stateSerializer, OperatorStateHandle.Mode.UNION).snapshot(); byte[] serialized; try (ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos()) { OperatorBackendStateMetaInfoSnapshotReaderWriters - .getWriterForVersion(OperatorBackendSerializationProxy.VERSION, metaInfo) - .writeStateMetaInfo(new DataOutputViewStreamWrapper(out)); + .getOperatorStateWriterForVersion(OperatorBackendSerializationProxy.VERSION, metaInfo) + 
.writeOperatorStateMetaInfo(new DataOutputViewStreamWrapper(out)); serialized = out.toByteArray(); } try (ByteArrayInputStreamWithPos in = new ByteArrayInputStreamWithPos(serialized)) { metaInfo = OperatorBackendStateMetaInfoSnapshotReaderWriters - .getReaderForVersion(OperatorBackendSerializationProxy.VERSION, Thread.currentThread().getContextClassLoader()) - .readStateMetaInfo(new DataInputViewStreamWrapper(in)); + .getOperatorStateReaderForVersion(OperatorBackendSerializationProxy.VERSION, Thread.currentThread().getContextClassLoader()) + .readOperatorStateMetaInfo(new DataInputViewStreamWrapper(in)); + } + + Assert.assertEquals(name, metaInfo.getName()); + Assert.assertEquals(OperatorStateHandle.Mode.UNION, metaInfo.getAssignmentMode()); + Assert.assertEquals(stateSerializer, metaInfo.getPartitionStateSerializer()); + } + + @Test + public void testBroadcastStateMetaInfoSerialization() throws Exception { + + String name = "test"; + TypeSerializer<?> keySerializer = DoubleSerializer.INSTANCE; + TypeSerializer<?> valueSerializer = StringSerializer.INSTANCE; + + RegisteredBroadcastBackendStateMetaInfo.Snapshot<?, ?> metaInfo = + new RegisteredBroadcastBackendStateMetaInfo<>( + name, OperatorStateHandle.Mode.BROADCAST, keySerializer, valueSerializer).snapshot(); + + byte[] serialized; + try (ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos()) { + OperatorBackendStateMetaInfoSnapshotReaderWriters + .getBroadcastStateWriterForVersion(OperatorBackendSerializationProxy.VERSION, metaInfo) + .writeBroadcastStateMetaInfo(new DataOutputViewStreamWrapper(out)); + + serialized = out.toByteArray(); + } + + try (ByteArrayInputStreamWithPos in = new ByteArrayInputStreamWithPos(serialized)) { + metaInfo = OperatorBackendStateMetaInfoSnapshotReaderWriters + .getBroadcastStateReaderForVersion(OperatorBackendSerializationProxy.VERSION, Thread.currentThread().getContextClassLoader()) + .readBroadcastStateMetaInfo(new DataInputViewStreamWrapper(in)); } 
Assert.assertEquals(name, metaInfo.getName()); + Assert.assertEquals(OperatorStateHandle.Mode.BROADCAST, metaInfo.getAssignmentMode()); + Assert.assertEquals(keySerializer, metaInfo.getKeySerializer()); + Assert.assertEquals(valueSerializer, metaInfo.getValueSerializer()); } @Test @@ -269,13 +314,13 @@ public class SerializationProxiesTest { RegisteredOperatorBackendStateMetaInfo.Snapshot<?> metaInfo = new RegisteredOperatorBackendStateMetaInfo<>( - name, stateSerializer, OperatorStateHandle.Mode.BROADCAST).snapshot(); + name, stateSerializer, OperatorStateHandle.Mode.UNION).snapshot(); byte[] serialized; try (ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos()) { OperatorBackendStateMetaInfoSnapshotReaderWriters - .getWriterForVersion(OperatorBackendSerializationProxy.VERSION, metaInfo) - .writeStateMetaInfo(new DataOutputViewStreamWrapper(out)); + .getOperatorStateWriterForVersion(OperatorBackendSerializationProxy.VERSION, metaInfo) + .writeOperatorStateMetaInfo(new DataOutputViewStreamWrapper(out)); serialized = out.toByteArray(); } @@ -288,8 +333,8 @@ public class SerializationProxiesTest { try (ByteArrayInputStreamWithPos in = new ByteArrayInputStreamWithPos(serialized)) { metaInfo = OperatorBackendStateMetaInfoSnapshotReaderWriters - .getReaderForVersion(OperatorBackendSerializationProxy.VERSION, Thread.currentThread().getContextClassLoader()) - .readStateMetaInfo(new DataInputViewStreamWrapper(in)); + .getOperatorStateReaderForVersion(OperatorBackendSerializationProxy.VERSION, Thread.currentThread().getContextClassLoader()) + .readOperatorStateMetaInfo(new DataInputViewStreamWrapper(in)); } Assert.assertEquals(name, metaInfo.getName()); @@ -297,6 +342,44 @@ public class SerializationProxiesTest { Assert.assertEquals(stateSerializer.snapshotConfiguration(), metaInfo.getPartitionStateSerializerConfigSnapshot()); } + @Test + public void testBroadcastStateMetaInfoReadSerializerFailureResilience() throws Exception { + String broadcastName = 
"broadcastTest"; + TypeSerializer<?> keySerializer = DoubleSerializer.INSTANCE; + TypeSerializer<?> valueSerializer = StringSerializer.INSTANCE; + + RegisteredBroadcastBackendStateMetaInfo.Snapshot<?, ?> broadcastMetaInfo = + new RegisteredBroadcastBackendStateMetaInfo<>( + broadcastName, OperatorStateHandle.Mode.BROADCAST, keySerializer, valueSerializer).snapshot(); + + byte[] serialized; + try (ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos()) { + OperatorBackendStateMetaInfoSnapshotReaderWriters + .getBroadcastStateWriterForVersion(OperatorBackendSerializationProxy.VERSION, broadcastMetaInfo) + .writeBroadcastStateMetaInfo(new DataOutputViewStreamWrapper(out)); + + serialized = out.toByteArray(); + } + + // mock failure when deserializing serializer + TypeSerializerSerializationUtil.TypeSerializerSerializationProxy<?> mockProxy = + mock(TypeSerializerSerializationUtil.TypeSerializerSerializationProxy.class); + doThrow(new IOException()).when(mockProxy).read(any(DataInputViewStreamWrapper.class)); + PowerMockito.whenNew(TypeSerializerSerializationUtil.TypeSerializerSerializationProxy.class).withAnyArguments().thenReturn(mockProxy); + + try (ByteArrayInputStreamWithPos in = new ByteArrayInputStreamWithPos(serialized)) { + broadcastMetaInfo = OperatorBackendStateMetaInfoSnapshotReaderWriters + .getBroadcastStateReaderForVersion(OperatorBackendSerializationProxy.VERSION, Thread.currentThread().getContextClassLoader()) + .readBroadcastStateMetaInfo(new DataInputViewStreamWrapper(in)); + } + + Assert.assertEquals(broadcastName, broadcastMetaInfo.getName()); + Assert.assertEquals(null, broadcastMetaInfo.getKeySerializer()); + Assert.assertEquals(keySerializer.snapshotConfiguration(), broadcastMetaInfo.getKeySerializerConfigSnapshot()); + Assert.assertEquals(null, broadcastMetaInfo.getValueSerializer()); + Assert.assertEquals(valueSerializer.snapshotConfiguration(), broadcastMetaInfo.getValueSerializerConfigSnapshot()); + } + /** * This test fixes 
the order of elements in the enum which is important for serialization. Do not modify this test * except if you are entirely sure what you are doing.
