Repository: flink Updated Branches: refs/heads/master de2605ea7 -> 30c9e2b68
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java index e386e0f..04e4fbc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java @@ -23,6 +23,7 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.state.AggregatingStateDescriptor; import org.apache.flink.api.common.state.FoldingStateDescriptor; import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.state.ReducingStateDescriptor; import org.apache.flink.api.common.state.StateDescriptor; import org.apache.flink.api.common.state.ValueStateDescriptor; @@ -44,6 +45,7 @@ import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.ArrayListSerializer; import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.DoneFuture; +import org.apache.flink.runtime.state.HashMapSerializer; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupRangeOffsets; import org.apache.flink.runtime.state.KeyGroupsStateHandle; @@ -55,6 +57,7 @@ import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.runtime.state.internal.InternalAggregatingState; import org.apache.flink.runtime.state.internal.InternalFoldingState; import org.apache.flink.runtime.state.internal.InternalListState; +import org.apache.flink.runtime.state.internal.InternalMapState; 
import org.apache.flink.runtime.state.internal.InternalReducingState; import org.apache.flink.runtime.state.internal.InternalValueState; import org.apache.flink.util.InstantiationUtil; @@ -186,7 +189,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> { } @Override - protected <N, T, ACC> InternalFoldingState<N, T, ACC> createFoldingState( + public <N, T, ACC> InternalFoldingState<N, T, ACC> createFoldingState( TypeSerializer<N> namespaceSerializer, FoldingStateDescriptor<T, ACC> stateDesc) throws Exception { @@ -195,6 +198,19 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> { } @Override + public <N, UK, UV> InternalMapState<N, UK, UV> createMapState(TypeSerializer<N> namespaceSerializer, + MapStateDescriptor<UK, UV> stateDesc) throws Exception { + + StateTable<K, N, HashMap<UK, UV>> stateTable = tryRegisterStateTable( + stateDesc.getName(), + stateDesc.getType(), + namespaceSerializer, + new HashMapSerializer<>(stateDesc.getKeySerializer(), stateDesc.getValueSerializer())); + + return new HeapMapState<>(this, stateDesc, stateTable, keySerializer, namespaceSerializer); + } + + @Override @SuppressWarnings("unchecked") public RunnableFuture<KeyGroupsStateHandle> snapshot( long checkpointId, http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapMapState.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapMapState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapMapState.java new file mode 100644 index 0000000..b28d661 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapMapState.java @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.heap; + +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; +import org.apache.flink.runtime.state.KeyGroupRangeAssignment; +import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.runtime.state.internal.InternalMapState; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; + +/** + * Heap-backed partitioned {@link MapState} that is snapshotted into files. + * + * @param <K> The type of the key. + * @param <N> The type of the namespace. + * @param <UK> The type of the keys in the state. + * @param <UV> The type of the values in the state. + */ +public class HeapMapState<K, N, UK, UV> + extends AbstractHeapState<K, N, HashMap<UK, UV>, MapState<UK, UV>, MapStateDescriptor<UK, UV>> + implements InternalMapState<N, UK, UV> { + + /** + * Creates a new key/value state for the given hash map of key/value pairs. + * + * @param backend The state backend backing that created this state. 
+ * @param stateDesc The state identifier for the state. This contains name + and can create a default state value. + * @param stateTable The state table to use in this key/value state. May contain initial state. + */ + public HeapMapState(KeyedStateBackend<K> backend, + MapStateDescriptor<UK, UV> stateDesc, + StateTable<K, N, HashMap<UK, UV>> stateTable, + TypeSerializer<K> keySerializer, + TypeSerializer<N> namespaceSerializer) { + super(backend, stateDesc, stateTable, keySerializer, namespaceSerializer); + } + + @Override + public UV get(UK userKey) { + Preconditions.checkState(currentNamespace != null, "No namespace set."); + Preconditions.checkState(backend.getCurrentKey() != null, "No key set."); + + Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = stateTable.get(backend.getCurrentKeyGroupIndex()); + if (namespaceMap == null) { + return null; + } + + Map<K, HashMap<UK, UV>> keyedMap = namespaceMap.get(currentNamespace); + if (keyedMap == null) { + return null; + } + + HashMap<UK, UV> userMap = keyedMap.get(backend.<K>getCurrentKey()); + if (userMap == null) { + return null; + } + + return userMap.get(userKey); + } + + @Override + public void put(UK userKey, UV userValue) { + Preconditions.checkState(currentNamespace != null, "No namespace set."); + Preconditions.checkState(backend.getCurrentKey() != null, "No key set."); + + Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = stateTable.get(backend.getCurrentKeyGroupIndex()); + if (namespaceMap == null) { + namespaceMap = createNewMap(); + stateTable.set(backend.getCurrentKeyGroupIndex(), namespaceMap); + } + + Map<K, HashMap<UK, UV>> keyedMap = namespaceMap.get(currentNamespace); + if (keyedMap == null) { + keyedMap = createNewMap(); + namespaceMap.put(currentNamespace, keyedMap); + } + + HashMap<UK, UV> userMap = keyedMap.get(backend.getCurrentKey()); + if (userMap == null) { + userMap = new HashMap<>(); + keyedMap.put(backend.getCurrentKey(), userMap); + } + + userMap.put(userKey, userValue); + } + + @Override 
+ public void putAll(Map<UK, UV> value) { + Preconditions.checkState(currentNamespace != null, "No namespace set."); + Preconditions.checkState(backend.getCurrentKey() != null, "No key set."); + + Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = stateTable.get(backend.getCurrentKeyGroupIndex()); + if (namespaceMap == null) { + namespaceMap = createNewMap(); + stateTable.set(backend.getCurrentKeyGroupIndex(), namespaceMap); + } + + Map<K, HashMap<UK, UV>> keyedMap = namespaceMap.get(currentNamespace); + if (keyedMap == null) { + keyedMap = createNewMap(); + namespaceMap.put(currentNamespace, keyedMap); + } + + HashMap<UK, UV> userMap = keyedMap.get(backend.getCurrentKey()); + if (userMap == null) { + userMap = new HashMap<>(); + keyedMap.put(backend.getCurrentKey(), userMap); + } + + userMap.putAll(value); + } + + @Override + public void remove(UK userKey) { + Preconditions.checkState(currentNamespace != null, "No namespace set."); + Preconditions.checkState(backend.getCurrentKey() != null, "No key set."); + + Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = stateTable.get(backend.getCurrentKeyGroupIndex()); + if (namespaceMap == null) { + return; + } + + Map<K, HashMap<UK, UV>> keyedMap = namespaceMap.get(currentNamespace); + if (keyedMap == null) { + return; + } + + HashMap<UK, UV> userMap = keyedMap.get(backend.getCurrentKey()); + if (userMap == null) { + return; + } + + userMap.remove(userKey); + + if (userMap.isEmpty()) { + clear(); + } + } + + @Override + public boolean contains(UK userKey) { + Preconditions.checkState(currentNamespace != null, "No namespace set."); + Preconditions.checkState(backend.getCurrentKey() != null, "No key set."); + + Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = stateTable.get(backend.getCurrentKeyGroupIndex()); + if (namespaceMap == null) { + return false; + } + + Map<K, HashMap<UK, UV>> keyedMap = namespaceMap.get(currentNamespace); + if (keyedMap == null) { + return false; + } + + HashMap<UK, UV> userMap = 
keyedMap.get(backend.<K>getCurrentKey()); + + return userMap != null && userMap.containsKey(userKey); + } + + @Override + public int size() { + Preconditions.checkState(currentNamespace != null, "No namespace set."); + Preconditions.checkState(backend.getCurrentKey() != null, "No key set."); + + Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = stateTable.get(backend.getCurrentKeyGroupIndex()); + if (namespaceMap == null) { + return 0; + } + + Map<K, HashMap<UK, UV>> keyedMap = namespaceMap.get(currentNamespace); + if (keyedMap == null) { + return 0; + } + + HashMap<UK, UV> userMap = keyedMap.get(backend.<K>getCurrentKey()); + + return userMap == null ? 0 : userMap.size(); + } + + @Override + public Iterable<Map.Entry<UK, UV>> entries() { + Preconditions.checkState(currentNamespace != null, "No namespace set."); + Preconditions.checkState(backend.getCurrentKey() != null, "No key set."); + + Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = stateTable.get(backend.getCurrentKeyGroupIndex()); + if (namespaceMap == null) { + return null; + } + + Map<K, HashMap<UK, UV>> keyedMap = namespaceMap.get(currentNamespace); + if (keyedMap == null) { + return null; + } + + HashMap<UK, UV> userMap = keyedMap.get(backend.<K>getCurrentKey()); + + return userMap == null ? null : userMap.entrySet(); + } + + @Override + public Iterable<UK> keys() { + Preconditions.checkState(currentNamespace != null, "No namespace set."); + Preconditions.checkState(backend.getCurrentKey() != null, "No key set."); + + Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = stateTable.get(backend.getCurrentKeyGroupIndex()); + if (namespaceMap == null) { + return null; + } + + Map<K, HashMap<UK, UV>> keyedMap = namespaceMap.get(currentNamespace); + if (keyedMap == null) { + return null; + } + + HashMap<UK, UV> userMap = keyedMap.get(backend.<K>getCurrentKey()); + + return userMap == null ? 
null : userMap.keySet(); + } + + @Override + public Iterable<UV> values() { + Preconditions.checkState(currentNamespace != null, "No namespace set."); + Preconditions.checkState(backend.getCurrentKey() != null, "No key set."); + + Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = stateTable.get(backend.getCurrentKeyGroupIndex()); + if (namespaceMap == null) { + return null; + } + + Map<K, HashMap<UK, UV>> keyedMap = namespaceMap.get(currentNamespace); + if (keyedMap == null) { + return null; + } + + HashMap<UK, UV> userMap = keyedMap.get(backend.<K>getCurrentKey()); + + return userMap == null ? null : userMap.values(); + } + + @Override + public Iterator<Map.Entry<UK, UV>> iterator() { + Preconditions.checkState(currentNamespace != null, "No namespace set."); + Preconditions.checkState(backend.getCurrentKey() != null, "No key set."); + + Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = stateTable.get(backend.getCurrentKeyGroupIndex()); + if (namespaceMap == null) { + return null; + } + + Map<K, HashMap<UK, UV>> keyedMap = namespaceMap.get(currentNamespace); + if (keyedMap == null) { + return null; + } + + HashMap<UK, UV> userMap = keyedMap.get(backend.<K>getCurrentKey()); + + return userMap == null ? 
null : userMap.entrySet().iterator(); + } + + @Override + public byte[] getSerializedValue(K key, N namespace) throws IOException { + Preconditions.checkState(namespace != null, "No namespace given."); + Preconditions.checkState(key != null, "No key given."); + + Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = stateTable.get(KeyGroupRangeAssignment.assignToKeyGroup(key, backend.getNumberOfKeyGroups())); + + if (namespaceMap == null) { + return null; + } + + Map<K, HashMap<UK, UV>> keyedMap = namespaceMap.get(namespace); + if (keyedMap == null) { + return null; + } + + HashMap<UK, UV> result = keyedMap.get(key); + if (result == null) { + return null; + } + + TypeSerializer<UK> userKeySerializer = stateDesc.getKeySerializer(); + TypeSerializer<UV> userValueSerializer = stateDesc.getValueSerializer(); + + return KvStateRequestSerializer.serializeMap(result.entrySet(), userKeySerializer, userValueSerializer); + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/main/java/org/apache/flink/runtime/state/internal/InternalMapState.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/internal/InternalMapState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/internal/InternalMapState.java new file mode 100644 index 0000000..f2a7b41 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/internal/InternalMapState.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.internal; + +import org.apache.flink.api.common.state.MapState; + +/** + * The peer to the {@link MapState} in the internal state type hierarchy. + * + * <p>See {@link InternalKvState} for a description of the internal state hierarchy. + * + * @param <N> The type of the namespace + * @param <UK> Type of the keys in the state + * @param <UV> Type of the values in the state + */ + public interface InternalMapState<N, UK, UV> extends InternalKvState<N>, MapState<UK, UV> {} http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializerTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializerTest.java index 69dbe6f..dd61a3f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializerTest.java @@ -23,7 +23,9 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.UnpooledByteBufAllocator; import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; +import 
org.apache.flink.api.common.typeutils.base.ByteSerializer; import org.apache.flink.api.common.typeutils.base.LongSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.api.java.tuple.Tuple2; @@ -36,11 +38,15 @@ import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend; import org.apache.flink.runtime.state.internal.InternalKvState; import org.apache.flink.runtime.state.internal.InternalListState; +import org.apache.flink.runtime.state.internal.InternalMapState; import org.junit.Test; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.ThreadLocalRandom; import static org.junit.Assert.assertArrayEquals; @@ -410,6 +416,131 @@ public class KvStateRequestSerializerTest { KvStateRequestSerializer.deserializeList(new byte[] {1, 1, 1, 1, 1, 1, 1, 1, 2, 3}, LongSerializer.INSTANCE); } + + /** + * Tests map serialization utils. + */ + @Test + public void testMapSerialization() throws Exception { + final long key = 0L; + + // objects for heap state map serialisation + final HeapKeyedStateBackend<Long> longHeapKeyedStateBackend = + new HeapKeyedStateBackend<>( + mock(TaskKvStateRegistry.class), + LongSerializer.INSTANCE, + ClassLoader.getSystemClassLoader(), + 1, new KeyGroupRange(0, 0) + ); + longHeapKeyedStateBackend.setCurrentKey(key); + + final InternalMapState<VoidNamespace, Long, String> mapState = (InternalMapState<VoidNamespace, Long, String>) longHeapKeyedStateBackend.getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + new MapStateDescriptor<>("test", LongSerializer.INSTANCE, StringSerializer.INSTANCE)); + + testMapSerialization(key, mapState); + } + + /** + * Verifies that the serialization of a map using the given map state + * matches the deserialization with {@link KvStateRequestSerializer#deserializeMap}.
+ * + * @param key + * key of the map state + * @param mapState + * map state using the {@link VoidNamespace}, must also be a {@link InternalKvState} instance + * + * @throws Exception + */ + public static void testMapSerialization( + final long key, + final InternalMapState<VoidNamespace, Long, String> mapState) throws Exception { + + TypeSerializer<Long> userKeySerializer = LongSerializer.INSTANCE; + TypeSerializer<String> userValueSerializer = StringSerializer.INSTANCE; + mapState.setCurrentNamespace(VoidNamespace.INSTANCE); + + // Map + final int numElements = 10; + + final Map<Long, String> expectedValues = new HashMap<>(); + for (int i = 1; i <= numElements; i++) { + final long value = ThreadLocalRandom.current().nextLong(); + expectedValues.put(value, Long.toString(value)); + mapState.put(value, Long.toString(value)); + } + + expectedValues.put(0L, null); + mapState.put(0L, null); + + final byte[] serializedKey = + KvStateRequestSerializer.serializeKeyAndNamespace( + key, LongSerializer.INSTANCE, + VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE); + + final byte[] serializedValues = mapState.getSerializedValue(serializedKey); + + Map<Long, String> actualValues = KvStateRequestSerializer.deserializeMap(serializedValues, userKeySerializer, userValueSerializer); + assertEquals(expectedValues.size(), actualValues.size()); + for (Map.Entry<Long, String> actualEntry : actualValues.entrySet()) { + assertEquals(expectedValues.get(actualEntry.getKey()), actualEntry.getValue()); + } + + // Single value + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + long expectedKey = ThreadLocalRandom.current().nextLong(); + String expectedValue = Long.toString(expectedKey); + byte[] isNull = {0}; + + baos.write(KvStateRequestSerializer.serializeValue(expectedKey, userKeySerializer)); + baos.write(isNull); + baos.write(KvStateRequestSerializer.serializeValue(expectedValue, userValueSerializer)); + byte[] serializedValue = baos.toByteArray(); + + Map<Long, 
String> actualValue = KvStateRequestSerializer.deserializeMap(serializedValue, userKeySerializer, userValueSerializer); + assertEquals(1, actualValue.size()); + assertEquals(expectedValue, actualValue.get(expectedKey)); + } + + /** + * Tests map deserialization of an empty byte array. + */ + @Test + public void testDeserializeMapEmpty() throws Exception { + Map<Long, String> actualValue = KvStateRequestSerializer + .deserializeMap(new byte[] {}, LongSerializer.INSTANCE, StringSerializer.INSTANCE); + assertEquals(0, actualValue.size()); + } + + /** + * Tests map deserialization with too few bytes. + */ + @Test(expected = IOException.class) + public void testDeserializeMapTooShort1() throws Exception { + // 1 byte (incomplete Key) + KvStateRequestSerializer.deserializeMap(new byte[] {1}, LongSerializer.INSTANCE, StringSerializer.INSTANCE); + } + + /** + * Tests map deserialization with too few bytes. + */ + @Test(expected = IOException.class) + public void testDeserializeMapTooShort2() throws Exception { + // Long (Key) + Boolean (false) + missing Value + KvStateRequestSerializer.deserializeMap(new byte[]{1, 1, 1, 1, 1, 1, 1, 1, 0}, + LongSerializer.INSTANCE, LongSerializer.INSTANCE); + } + + /** + * Tests map deserialization with too few bytes.
+ */ + @Test(expected = IOException.class) + public void testDeserializeMapTooShort3() throws Exception { + // Long (Key1) + Boolean (false) + Long (Value1) + 1 byte (incomplete Key2) + KvStateRequestSerializer.deserializeMap(new byte[] {1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 3}, + LongSerializer.INSTANCE, LongSerializer.INSTANCE); + } private byte[] randomByteArray(int capacity) { byte[] bytes = new byte[capacity]; http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java index 57f4572..75014e7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java @@ -65,6 +65,10 @@ public class FileStateBackendTest extends StateBackendTestBase<FsStateBackend> { @Override @Test public void testReducingStateRestoreWithWrongSerializers() {} + + @Override + @Test + public void testMapStateRestoreWithWrongSerializers() {} @Test public void testStateOutputStream() throws IOException { http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java index c267afc..362fcd6 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java @@ -59,6 
+59,10 @@ public class MemoryStateBackendTest extends StateBackendTestBase<MemoryStateBack @Override @Test public void testReducingStateRestoreWithWrongSerializers() {} + + @Override + @Test + public void testMapStateRestoreWithWrongSerializers() {} @Test @SuppressWarnings("unchecked") http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java index 66e8d02..0dbe2eb 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java @@ -161,7 +161,7 @@ public class SerializationProxiesTest { @Test public void testFixTypeOrder() { // ensure all elements are covered - Assert.assertEquals(6, StateDescriptor.Type.values().length); + Assert.assertEquals(7, StateDescriptor.Type.values().length); // fix the order of elements to keep serialization format stable Assert.assertEquals(0, StateDescriptor.Type.UNKNOWN.ordinal()); Assert.assertEquals(1, StateDescriptor.Type.VALUE.ordinal()); @@ -169,5 +169,6 @@ public class SerializationProxiesTest { Assert.assertEquals(3, StateDescriptor.Type.REDUCING.ordinal()); Assert.assertEquals(4, StateDescriptor.Type.FOLDING.ordinal()); Assert.assertEquals(5, StateDescriptor.Type.AGGREGATING.ordinal()); + Assert.assertEquals(6, StateDescriptor.Type.MAP.ordinal()); } } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java ---------------------------------------------------------------------- diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java index 7737ecf..3b0350d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java @@ -28,6 +28,8 @@ import org.apache.flink.api.common.state.FoldingState; import org.apache.flink.api.common.state.FoldingStateDescriptor; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.state.ReducingState; import org.apache.flink.api.common.state.ReducingStateDescriptor; import org.apache.flink.api.common.state.ValueState; @@ -54,8 +56,12 @@ import org.apache.flink.util.TestLogger; import org.junit.Test; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.Random; import java.util.Timer; import java.util.TimerTask; @@ -784,6 +790,169 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten fail(e.getMessage()); } } + + @Test + @SuppressWarnings("unchecked,rawtypes") + public void testMapState() { + try { + CheckpointStreamFactory streamFactory = createStreamFactory(); + AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE); + + MapStateDescriptor<Integer, String> kvId = new MapStateDescriptor<>("id", Integer.class, String.class); + kvId.initializeSerializerUnlessSet(new ExecutionConfig()); + + TypeSerializer<Integer> keySerializer = IntSerializer.INSTANCE; + TypeSerializer<VoidNamespace> namespaceSerializer = VoidNamespaceSerializer.INSTANCE; + 
TypeSerializer<Integer> userKeySerializer = kvId.getKeySerializer(); + TypeSerializer<String> userValueSerializer = kvId.getValueSerializer(); + + MapState<Integer, String> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + @SuppressWarnings("unchecked") + InternalKvState<VoidNamespace> kvState = (InternalKvState<VoidNamespace>) state; + + // some modifications to the state + backend.setCurrentKey(1); + assertEquals(0, state.size()); + assertEquals(null, state.get(1)); + assertEquals(null, getSerializedMap(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer)); + state.put(1, "1"); + backend.setCurrentKey(2); + assertEquals(0, state.size()); + assertEquals(null, state.get(2)); + assertEquals(null, getSerializedMap(kvState, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer)); + state.put(2, "2"); + backend.setCurrentKey(1); + assertEquals(1, state.size()); + assertTrue(state.contains(1)); + assertEquals("1", state.get(1)); + assertEquals(new HashMap<Integer, String>() {{ put (1, "1"); }}, + getSerializedMap(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer)); + + // draw a snapshot + KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory)); + + // make some more modifications + backend.setCurrentKey(1); + state.put(1, "101"); + backend.setCurrentKey(2); + state.put(102, "102"); + backend.setCurrentKey(3); + state.put(103, "103"); + state.putAll(new HashMap<Integer, String>() {{ put(1031, "1031"); put(1032, "1032"); }}); + + // draw another snapshot + KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory)); + + // validate the original state + backend.setCurrentKey(1); + assertEquals("101", state.get(1)); + assertEquals(new HashMap<Integer, String>() {{ put(1, "101"); 
}}, + getSerializedMap(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer)); + backend.setCurrentKey(2); + assertEquals("102", state.get(102)); + assertEquals(new HashMap<Integer, String>() {{ put(2, "2"); put(102, "102"); }}, + getSerializedMap(kvState, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer)); + backend.setCurrentKey(3); + assertEquals(3, state.size()); + assertTrue(state.contains(103)); + assertEquals("103", state.get(103)); + assertEquals(new HashMap<Integer, String>() {{ put(103, "103"); put(1031, "1031"); put(1032, "1032"); }}, + getSerializedMap(kvState, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer)); + + List<Integer> keys = new ArrayList<>(); + for (Integer key : state.keys()) { + keys.add(key); + } + List<Integer> expectedKeys = new ArrayList<Integer>() {{ add(103); add(1031); add(1032); }}; + assertEquals(keys.size(), expectedKeys.size()); + keys.removeAll(expectedKeys); + assertTrue(keys.isEmpty()); + + List<String> values = new ArrayList<>(); + for (String value : state.values()) { + values.add(value); + } + List<String> expectedValues = new ArrayList<String>() {{ add("103"); add("1031"); add("1032"); }}; + assertEquals(values.size(), expectedValues.size()); + values.removeAll(expectedValues); + assertTrue(values.isEmpty()); + + // make some more modifications + backend.setCurrentKey(1); + state.clear(); + backend.setCurrentKey(2); + state.remove(102); + backend.setCurrentKey(3); + final String updateSuffix = "_updated"; + Iterator<Map.Entry<Integer, String>> iterator = state.iterator(); + while (iterator.hasNext()) { + Map.Entry<Integer, String> entry = iterator.next(); + if (entry.getValue().length() != 4) { + iterator.remove(); + } else { + entry.setValue(entry.getValue() + updateSuffix); + } + } + + // validate the state + backend.setCurrentKey(1); + assertEquals(0, 
state.size()); + backend.setCurrentKey(2); + assertFalse(state.contains(102)); + backend.setCurrentKey(3); + for (Map.Entry<Integer, String> entry : state.entries()) { + assertEquals(4 + updateSuffix.length(), entry.getValue().length()); + assertTrue(entry.getValue().endsWith(updateSuffix)); + } + + backend.dispose(); + // restore the first snapshot and validate it + backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1); + snapshot1.discardState(); + + MapState<Integer, String> restored1 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + @SuppressWarnings("unchecked") + InternalKvState<VoidNamespace> restoredKvState1 = (InternalKvState<VoidNamespace>) restored1; + + backend.setCurrentKey(1); + assertEquals("1", restored1.get(1)); + assertEquals(new HashMap<Integer, String>() {{ put (1, "1"); }}, + getSerializedMap(restoredKvState1, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer)); + backend.setCurrentKey(2); + assertEquals("2", restored1.get(2)); + assertEquals(new HashMap<Integer, String>() {{ put (2, "2"); }}, + getSerializedMap(restoredKvState1, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer)); + + backend.dispose(); + // restore the second snapshot and validate it + backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot2); + snapshot2.discardState(); + + @SuppressWarnings("unchecked") + MapState<Integer, String> restored2 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + @SuppressWarnings("unchecked") + InternalKvState<VoidNamespace> restoredKvState2 = (InternalKvState<VoidNamespace>) restored2; + + backend.setCurrentKey(1); + assertEquals("101", restored2.get(1)); + assertEquals(new HashMap<Integer, String>() {{ put (1, "101"); }}, + getSerializedMap(restoredKvState2, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, 
userKeySerializer, userValueSerializer)); + backend.setCurrentKey(2); + assertEquals("102", restored2.get(102)); + assertEquals(new HashMap<Integer, String>() {{ put(2, "2"); put (102, "102"); }}, + getSerializedMap(restoredKvState2, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer)); + backend.setCurrentKey(3); + assertEquals("103", restored2.get(103)); + assertEquals(new HashMap<Integer, String>() {{ put(103, "103"); put(1031, "1031"); put(1032, "1032"); }}, + getSerializedMap(restoredKvState2, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer)); + + backend.dispose(); + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + + } /** * Verify that {@link ValueStateDescriptor} allows {@code null} as default. @@ -917,9 +1086,36 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten backend.dispose(); } + /** + * Verify that an empty {@code MapState} yields {@code null}. + */ + @Test + public void testMapStateDefaultValue() throws Exception { + AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE); + + MapStateDescriptor<String, String> kvId = new MapStateDescriptor<>("id", String.class, String.class); + kvId.initializeSerializerUnlessSet(new ExecutionConfig()); + + MapState<String, String> state = backend.getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, kvId); + + backend.setCurrentKey(1); + assertNull(state.entries()); + state.put("Ciao", "Hello"); + state.put("Bello", "Nice"); + + assertEquals(state.size(), 2); + assertEquals(state.get("Ciao"), "Hello"); + assertEquals(state.get("Bello"), "Nice"); + state.clear(); + assertNull(state.entries()); + backend.dispose(); + } + /** * This test verifies that state is correctly assigned to key groups and that restore * restores the relevant key groups in the backend. 
@@ -1172,6 +1368,58 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten fail(e.getMessage()); } } + + @Test + @SuppressWarnings("unchecked") + public void testMapStateRestoreWithWrongSerializers() { + try { + CheckpointStreamFactory streamFactory = createStreamFactory(); + AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE); + + MapStateDescriptor<String, String> kvId = new MapStateDescriptor<>("id", StringSerializer.INSTANCE, StringSerializer.INSTANCE); + MapState<String, String> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + + backend.setCurrentKey(1); + state.put("1", "First"); + backend.setCurrentKey(2); + state.put("2", "Second"); + + // draw a snapshot + KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory)); + + backend.dispose(); + // restore the first snapshot and validate it + backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1); + snapshot1.discardState(); + + @SuppressWarnings("unchecked") + TypeSerializer<String> fakeStringSerializer = + (TypeSerializer<String>) (TypeSerializer<?>) FloatSerializer.INSTANCE; + + try { + kvId = new MapStateDescriptor<>("id", fakeStringSerializer, StringSerializer.INSTANCE); + + state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + + state.entries(); + + fail("should recognize wrong serializers"); + } catch (IOException e) { + if (!e.getMessage().contains("Trying to access state using wrong ")) { + fail("wrong exception " + e); + } + // expected + } catch (Exception e) { + fail("wrong exception " + e); + } + backend.dispose(); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + @Test public void testCopyDefaultValue() throws Exception { @@ -1357,6 +1605,31 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten 
assertTrue(stateTable.get(keyGroupIndex) instanceof ConcurrentHashMap); assertTrue(stateTable.get(keyGroupIndex).get(VoidNamespace.INSTANCE) instanceof ConcurrentHashMap); } + + { + // MapState + MapStateDescriptor<Integer, String> desc = new MapStateDescriptor<>("map-state", Integer.class, String.class); + desc.setQueryable("my-query"); + desc.initializeSerializerUnlessSet(new ExecutionConfig()); + + MapState<Integer, String> state = backend.getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + desc); + + InternalKvState<VoidNamespace> kvState = (InternalKvState<VoidNamespace>) state; + assertTrue(kvState instanceof AbstractHeapState); + + kvState.setCurrentNamespace(VoidNamespace.INSTANCE); + backend.setCurrentKey(1); + state.put(121818273, "121818273"); + + int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(1, numberOfKeyGroups); + StateTable stateTable = ((AbstractHeapState) kvState).getStateTable(); + assertNotNull("State not set", stateTable.get(keyGroupIndex)); + assertTrue(stateTable.get(keyGroupIndex) instanceof ConcurrentHashMap); + assertTrue(stateTable.get(keyGroupIndex).get(VoidNamespace.INSTANCE) instanceof ConcurrentHashMap); + } backend.dispose(); } @@ -1495,6 +1768,32 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten return KvStateRequestSerializer.deserializeList(serializedValue, valueSerializer); } } + + /** + * Returns the value by getting the serialized value and deserializing it + * if it is not null. 
+ */ + private static <UK, UV, K, N> Map<UK, UV> getSerializedMap( + InternalKvState<N> kvState, + K key, + TypeSerializer<K> keySerializer, + N namespace, + TypeSerializer<N> namespaceSerializer, + TypeSerializer<UK> userKeySerializer, + TypeSerializer<UV> userValueSerializer + ) throws Exception { + + byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace( + key, keySerializer, namespace, namespaceSerializer); + + byte[] serializedValue = kvState.getSerializedValue(serializedKeyAndNamespace); + + if (serializedValue == null) { + return null; + } else { + return KvStateRequestSerializer.deserializeMap(serializedValue, userKeySerializer, userValueSerializer); + } + } private KeyGroupsStateHandle runSnapshot(RunnableFuture<KeyGroupsStateHandle> snapshotRunnableFuture) throws Exception { if(!snapshotRunnableFuture.isDone()) { http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunction.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunction.java index e6a186a..7971460 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunction.java @@ -36,6 +36,8 @@ import org.apache.flink.api.common.state.FoldingState; import org.apache.flink.api.common.state.FoldingStateDescriptor; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.state.ReducingState; 
import org.apache.flink.api.common.state.ReducingStateDescriptor; import org.apache.flink.api.common.state.ValueState; @@ -171,6 +173,11 @@ public abstract class RichAsyncFunction<IN, OUT> extends AbstractRichFunction im throw new UnsupportedOperationException("State is not supported in rich async functions."); } + @Override + public <UK, UV> MapState<UK, UV> getMapState(MapStateDescriptor<UK, UV> stateProperties) { + throw new UnsupportedOperationException("State is not supported in rich async functions."); + } + @Override public <V, A extends Serializable> void addAccumulator(String name, Accumulator<V, A> accumulator) { http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java index b9c9b9b..b666a2b 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java @@ -27,6 +27,8 @@ import org.apache.flink.api.common.state.FoldingStateDescriptor; import org.apache.flink.api.common.state.KeyedStateStore; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.state.ReducingState; import org.apache.flink.api.common.state.ReducingStateDescriptor; import org.apache.flink.api.common.state.StateDescriptor; @@ -136,6 +138,13 @@ public class StreamingRuntimeContext extends AbstractRuntimeUDFContext { 
stateProperties.initializeSerializerUnlessSet(getExecutionConfig()); return keyedStateStore.getFoldingState(stateProperties); } + + @Override + public <UK, UV> MapState<UK, UV> getMapState(MapStateDescriptor<UK, UV> stateProperties) { + KeyedStateStore keyedStateStore = checkPreconditionsAndGetKeyedStateStore(stateProperties); + stateProperties.initializeSerializerUnlessSet(getExecutionConfig()); + return keyedStateStore.getMapState(stateProperties); + } private KeyedStateStore checkPreconditionsAndGetKeyedStateStore(StateDescriptor<?, ?> stateDescriptor) { Preconditions.checkNotNull(stateDescriptor, "The state properties must not be null"); http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunctionTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunctionTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunctionTest.java index 815f856..562883d 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunctionTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunctionTest.java @@ -27,6 +27,7 @@ import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.api.common.state.FoldingStateDescriptor; import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.state.ReducingStateDescriptor; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.metrics.MetricGroup; @@ -165,7 +166,6 @@ public class RichAsyncFunctionTest { } catch (UnsupportedOperationException e) { // expected } 
- try { runtimeContext.getFoldingState(new FoldingStateDescriptor<>("foobar", 0, new FoldFunction<Integer, Integer>() { @Override @@ -178,6 +178,12 @@ public class RichAsyncFunctionTest { } try { + runtimeContext.getMapState(new MapStateDescriptor<>("foobar", Integer.class, String.class)); + } catch (UnsupportedOperationException e) { + // expected + } + + try { runtimeContext.addAccumulator("foobar", new Accumulator<Integer, Integer>() { private static final long serialVersionUID = -4673320336846482358L; http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java index 294b8da..36496f2 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java @@ -27,6 +27,8 @@ import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.state.FoldingStateDescriptor; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.state.ReducingStateDescriptor; import org.apache.flink.api.common.state.StateDescriptor; import org.apache.flink.api.common.state.ValueStateDescriptor; @@ -52,6 +54,7 @@ import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import java.util.Collections; +import java.util.Map; import java.util.concurrent.Future; import 
java.util.concurrent.atomic.AtomicReference; @@ -178,7 +181,7 @@ public class StreamingRuntimeContextTest { public void testListStateReturnsEmptyListByDefault() throws Exception { StreamingRuntimeContext context = new StreamingRuntimeContext( - createPlainMockOp(), + createListPlainMockOp(), createMockEnvironment(), Collections.<String, Accumulator<?, ?>>emptyMap()); @@ -190,6 +193,48 @@ public class StreamingRuntimeContextTest { assertFalse(value.iterator().hasNext()); } + @Test + public void testMapStateInstantiation() throws Exception { + + final ExecutionConfig config = new ExecutionConfig(); + config.registerKryoType(Path.class); + + final AtomicReference<Object> descriptorCapture = new AtomicReference<>(); + + StreamingRuntimeContext context = new StreamingRuntimeContext( + createDescriptorCapturingMockOp(descriptorCapture, config), + createMockEnvironment(), + Collections.<String, Accumulator<?, ?>>emptyMap()); + + MapStateDescriptor<String, TaskInfo> descr = + new MapStateDescriptor<>("name", String.class, TaskInfo.class); + + context.getMapState(descr); + + MapStateDescriptor<?, ?> descrIntercepted = (MapStateDescriptor<?, ?>) descriptorCapture.get(); + TypeSerializer<?> valueSerializer = descrIntercepted.getValueSerializer(); + + // check that the Path class is really registered, i.e., the execution config was applied + assertTrue(valueSerializer instanceof KryoSerializer); + assertTrue(((KryoSerializer<?>) valueSerializer).getKryo().getRegistration(Path.class).getId() > 0); + } + + @Test + public void testMapStateReturnsEmptyMapByDefault() throws Exception { + + StreamingRuntimeContext context = new StreamingRuntimeContext( + createMapPlainMockOp(), + createMockEnvironment(), + Collections.<String, Accumulator<?, ?>>emptyMap()); + + MapStateDescriptor<Integer, String> descr = new MapStateDescriptor<>("name", Integer.class, String.class); + MapState<Integer, String> state = context.getMapState(descr); + + Iterable<Map.Entry<Integer, String>> value = 
state.entries(); + assertNotNull(value); + assertFalse(value.iterator().hasNext()); + } + // ------------------------------------------------------------------------ // // ------------------------------------------------------------------------ @@ -221,7 +266,7 @@ public class StreamingRuntimeContextTest { } @SuppressWarnings("unchecked") - private static AbstractStreamOperator<?> createPlainMockOp() throws Exception { + private static AbstractStreamOperator<?> createListPlainMockOp() throws Exception { AbstractStreamOperator<?> operatorMock = mock(AbstractStreamOperator.class); ExecutionConfig config = new ExecutionConfig(); @@ -256,6 +301,42 @@ public class StreamingRuntimeContextTest { return operatorMock; } + @SuppressWarnings("unchecked") + private static AbstractStreamOperator<?> createMapPlainMockOp() throws Exception { + + AbstractStreamOperator<?> operatorMock = mock(AbstractStreamOperator.class); + ExecutionConfig config = new ExecutionConfig(); + + KeyedStateBackend keyedStateBackend= mock(KeyedStateBackend.class); + + DefaultKeyedStateStore keyedStateStore = new DefaultKeyedStateStore(keyedStateBackend, config); + + when(operatorMock.getExecutionConfig()).thenReturn(config); + + doAnswer(new Answer<MapState<Integer, String>>() { + + @Override + public MapState<Integer, String> answer(InvocationOnMock invocationOnMock) throws Throwable { + MapStateDescriptor<Integer, String> descr = + (MapStateDescriptor<Integer, String>) invocationOnMock.getArguments()[2]; + + AbstractKeyedStateBackend<Integer> backend = new MemoryStateBackend().createKeyedStateBackend( + new DummyEnvironment("test_task", 1, 0), + new JobID(), + "test_op", + IntSerializer.INSTANCE, + 1, + new KeyGroupRange(0, 0), + new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID())); + backend.setCurrentKey(0); + return backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descr); + } + }).when(keyedStateBackend).getPartitionedState(Matchers.any(), 
any(TypeSerializer.class), any(MapStateDescriptor.class)); + + when(operatorMock.getKeyedStateStore()).thenReturn(keyedStateStore); + return operatorMock; + } + private static Environment createMockEnvironment() { Environment env = mock(Environment.class); when(env.getUserClassLoader()).thenReturn(StreamingRuntimeContextTest.class.getClassLoader()); http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-tests/src/test/java/org/apache/flink/test/query/KVStateRequestSerializerRocksDBTest.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/query/KVStateRequestSerializerRocksDBTest.java b/flink-tests/src/test/java/org/apache/flink/test/query/KVStateRequestSerializerRocksDBTest.java index 0562443..6e2fd62 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/query/KVStateRequestSerializerRocksDBTest.java +++ b/flink-tests/src/test/java/org/apache/flink/test/query/KVStateRequestSerializerRocksDBTest.java @@ -20,8 +20,10 @@ package org.apache.flink.test.query; import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.contrib.streaming.state.PredefinedOptions; import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend; import org.apache.flink.runtime.query.TaskKvStateRegistry; @@ -32,6 +34,7 @@ import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.runtime.state.internal.InternalListState; +import org.apache.flink.runtime.state.internal.InternalMapState; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; @@ -122,4 
+125,41 @@ public final class KVStateRequestSerializerRocksDBTest { KvStateRequestSerializerTest.testListSerialization(key, listState); } + + /** + * Tests map serialization and deserialization match. + * + * @see KvStateRequestSerializerTest#testMapSerialization() + * KvStateRequestSerializerTest#testMapSerialization() using the heap state back-end + * test + */ + @Test + public void testMapSerialization() throws Exception { + final long key = 0L; + + // objects for RocksDB state map serialisation + DBOptions dbOptions = PredefinedOptions.DEFAULT.createDBOptions(); + dbOptions.setCreateIfMissing(true); + ColumnFamilyOptions columnFamilyOptions = PredefinedOptions.DEFAULT.createColumnOptions(); + final RocksDBKeyedStateBackend<Long> longHeapKeyedStateBackend = + new RocksDBKeyedStateBackend<>( + new JobID(), "no-op", + ClassLoader.getSystemClassLoader(), + temporaryFolder.getRoot(), + dbOptions, + columnFamilyOptions, + mock(TaskKvStateRegistry.class), + LongSerializer.INSTANCE, + 1, new KeyGroupRange(0, 0) + ); + longHeapKeyedStateBackend.setCurrentKey(key); + + final InternalMapState<VoidNamespace, Long, String> mapState = (InternalMapState<VoidNamespace, Long, String>) + longHeapKeyedStateBackend.getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + new MapStateDescriptor<>("test", LongSerializer.INSTANCE, StringSerializer.INSTANCE)); + + KvStateRequestSerializerTest.testMapSerialization(key, mapState); + } }
