/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.state.memory;

import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.state.KvState;
import org.apache.flink.runtime.state.KvStateSnapshot;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

/**
 * Heap-backed partitioned {@link org.apache.flink.api.common.state.ReducingState} that is
 * snapshotted into a serialized memory copy.
 *
 * @param <K> The type of the key.
 * @param <N> The type of the namespace.
 * @param <V> The type of the values in the reducing state.
 */
public class MemReducingState<K, N, V>
	extends AbstractMemState<K, N, V, ReducingState<V>, ReducingStateDescriptor<V>>
	implements ReducingState<V> {

	/** The function used to combine the current aggregate with each added value. */
	private final ReduceFunction<V> reduceFunction;

	/**
	 * Creates a new, empty reducing state.
	 *
	 * @param keySerializer Serializer for the keys.
	 * @param namespaceSerializer Serializer for the namespaces.
	 * @param stateDesc Descriptor supplying the value serializer and reduce function.
	 */
	public MemReducingState(TypeSerializer<K> keySerializer,
		TypeSerializer<N> namespaceSerializer,
		ReducingStateDescriptor<V> stateDesc) {
		super(keySerializer, namespaceSerializer, stateDesc.getSerializer(), stateDesc);
		this.reduceFunction = stateDesc.getReduceFunction();
	}

	/**
	 * Creates a reducing state pre-populated with the given state map
	 * (used when restoring from a snapshot).
	 *
	 * @param keySerializer Serializer for the keys.
	 * @param namespaceSerializer Serializer for the namespaces.
	 * @param stateDesc Descriptor supplying the value serializer and reduce function.
	 * @param state The restored namespace-to-(key-to-value) state map.
	 */
	public MemReducingState(TypeSerializer<K> keySerializer,
		TypeSerializer<N> namespaceSerializer,
		ReducingStateDescriptor<V> stateDesc,
		HashMap<N, Map<K, V>> state) {
		super(keySerializer, namespaceSerializer, stateDesc.getSerializer(), stateDesc, state);
		this.reduceFunction = stateDesc.getReduceFunction();
	}

	/**
	 * Returns the current aggregate for the current key/namespace,
	 * or {@code null} if nothing has been added yet.
	 */
	@Override
	public V get() {
		// lazily resolve the map for the current namespace; it is cached
		// in currentNSState until the namespace changes
		if (currentNSState == null) {
			currentNSState = state.get(currentNamespace);
		}
		if (currentNSState != null) {
			return currentNSState.get(currentKey);
		}
		return null;
	}

	/**
	 * Adds the given value to the state, combining it with any existing
	 * aggregate via the configured {@link ReduceFunction}.
	 *
	 * @param value The value to add.
	 * @throws IOException Declared by the {@link ReducingState} contract;
	 *         not thrown by this in-memory implementation.
	 * @throws RuntimeException If no key is set, or if the reduce function fails.
	 */
	@Override
	public void add(V value) throws IOException {
		if (currentKey == null) {
			throw new RuntimeException("No key available.");
		}

		if (currentNSState == null) {
			currentNSState = new HashMap<>();
			state.put(currentNamespace, currentNSState);
		}

		V currentValue = currentNSState.get(currentKey);
		if (currentValue == null) {
			// first value for this key: store it as-is
			currentNSState.put(currentKey, value);
		} else {
			// combine with the existing aggregate; the reduce function may
			// throw a checked Exception, which we wrap since add() only
			// declares IOException
			try {
				currentNSState.put(currentKey, reduceFunction.reduce(currentValue, value));
			} catch (Exception e) {
				throw new RuntimeException("Could not add value to reducing state.", e);
			}
		}
	}

	@Override
	public KvStateSnapshot<K, N, ReducingState<V>, ReducingStateDescriptor<V>, MemoryStateBackend> createHeapSnapshot(byte[] bytes) {
		return new Snapshot<>(getKeySerializer(), getNamespaceSerializer(), stateSerializer, stateDesc, bytes);
	}

	/**
	 * Serialized-bytes snapshot of a {@link MemReducingState}; restores by
	 * deserializing the state map and rebuilding the state object.
	 */
	public static class Snapshot<K, N, V> extends AbstractMemStateSnapshot<K, N, V, ReducingState<V>, ReducingStateDescriptor<V>> {
		private static final long serialVersionUID = 1L;

		public Snapshot(TypeSerializer<K> keySerializer,
			TypeSerializer<N> namespaceSerializer,
			TypeSerializer<V> stateSerializer,
			ReducingStateDescriptor<V> stateDescs, byte[] data) {
			super(keySerializer, namespaceSerializer, stateSerializer, stateDescs, data);
		}

		@Override
		public KvState<K, N, ReducingState<V>, ReducingStateDescriptor<V>, MemoryStateBackend> createMemState(HashMap<N, Map<K, V>> stateMap) {
			return new MemReducingState<>(keySerializer, namespaceSerializer, stateDesc, stateMap);
		}
	}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemValueState.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemValueState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemValueState.java new file mode 100644 index 0000000..8ce166a --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemValueState.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.memory; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.state.KvState; +import org.apache.flink.runtime.state.KvStateSnapshot; + +import java.util.HashMap; +import java.util.Map; + +/** + * Heap-backed key/value state that is snapshotted into a serialized memory copy. + * + * @param <K> The type of the key. + * @param <N> The type of the namespace. + * @param <V> The type of the value. 
+ */ +public class MemValueState<K, N, V> + extends AbstractMemState<K, N, V, ValueState<V>, ValueStateDescriptor<V>> + implements ValueState<V> { + + public MemValueState(TypeSerializer<K> keySerializer, + TypeSerializer<N> namespaceSerializer, + ValueStateDescriptor<V> stateDesc) { + super(keySerializer, namespaceSerializer, stateDesc.getSerializer(), stateDesc); + } + + public MemValueState(TypeSerializer<K> keySerializer, + TypeSerializer<N> namespaceSerializer, + ValueStateDescriptor<V> stateDesc, + HashMap<N, Map<K, V>> state) { + super(keySerializer, namespaceSerializer, stateDesc.getSerializer(), stateDesc, state); + } + + @Override + public V value() { + if (currentNSState == null) { + currentNSState = state.get(currentNamespace); + } + if (currentNSState != null) { + V value = currentNSState.get(currentKey); + return value != null ? value : stateDesc.getDefaultValue(); + } + return stateDesc.getDefaultValue(); + } + + @Override + public void update(V value) { + if (currentKey == null) { + throw new RuntimeException("No key available."); + } + + if (currentNSState == null) { + currentNSState = new HashMap<>(); + state.put(currentNamespace, currentNSState); + } + + currentNSState.put(currentKey, value); + } + + @Override + public KvStateSnapshot<K, N, ValueState<V>, ValueStateDescriptor<V>, MemoryStateBackend> createHeapSnapshot(byte[] bytes) { + return new Snapshot<>(getKeySerializer(), getNamespaceSerializer(), stateSerializer, stateDesc, bytes); + } + + public static class Snapshot<K, N, V> extends AbstractMemStateSnapshot<K, N, V, ValueState<V>, ValueStateDescriptor<V>> { + private static final long serialVersionUID = 1L; + + public Snapshot(TypeSerializer<K> keySerializer, + TypeSerializer<N> namespaceSerializer, + TypeSerializer<V> stateSerializer, + ValueStateDescriptor<V> stateDescs, byte[] data) { + super(keySerializer, namespaceSerializer, stateSerializer, stateDescs, data); + } + + @Override + public KvState<K, N, ValueState<V>, 
ValueStateDescriptor<V>, MemoryStateBackend> createMemState(HashMap<N, Map<K, V>> stateMap) { + return new MemValueState<>(keySerializer, namespaceSerializer, stateDesc, stateMap); + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryHeapKvStateSnapshot.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryHeapKvStateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryHeapKvStateSnapshot.java deleted file mode 100644 index 0cb7fa4..0000000 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryHeapKvStateSnapshot.java +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.runtime.state.memory; - -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.runtime.state.KvStateSnapshot; -import org.apache.flink.runtime.util.DataInputDeserializer; - -import java.util.HashMap; - -/** - * A snapshot of a {@link MemHeapKvState} for a checkpoint. 
The data is stored in a heap byte - * array, in serialized form. - * - * @param <K> The type of the key in the snapshot state. - * @param <V> The type of the value in the snapshot state. - */ -public class MemoryHeapKvStateSnapshot<K, V> implements KvStateSnapshot<K, V, MemoryStateBackend> { - - private static final long serialVersionUID = 1L; - - /** Name of the key serializer class */ - private final String keySerializerClassName; - - /** Name of the value serializer class */ - private final String valueSerializerClassName; - - /** The serialized data of the state key/value pairs */ - private final byte[] data; - - /** The number of key/value pairs */ - private final int numEntries; - - /** - * Creates a new heap memory state snapshot. - * - * @param keySerializer The serializer for the keys. - * @param valueSerializer The serializer for the values. - * @param data The serialized data of the state key/value pairs - * @param numEntries The number of key/value pairs - */ - public MemoryHeapKvStateSnapshot(TypeSerializer<K> keySerializer, - TypeSerializer<V> valueSerializer, byte[] data, int numEntries) { - this.keySerializerClassName = keySerializer.getClass().getName(); - this.valueSerializerClassName = valueSerializer.getClass().getName(); - this.data = data; - this.numEntries = numEntries; - } - - @Override - public MemHeapKvState<K, V> restoreState( - MemoryStateBackend stateBackend, - final TypeSerializer<K> keySerializer, - final TypeSerializer<V> valueSerializer, - V defaultValue, - ClassLoader classLoader, - long recoveryTimestamp) throws Exception { - - // validity checks - if (!keySerializer.getClass().getName().equals(keySerializerClassName) || - !valueSerializer.getClass().getName().equals(valueSerializerClassName)) { - throw new IllegalArgumentException( - "Cannot restore the state from the snapshot with the given serializers. 
" + - "State (K/V) was serialized with (" + valueSerializerClassName + - "/" + keySerializerClassName + ")"); - } - - // restore state - HashMap<K, V> stateMap = new HashMap<>(numEntries); - DataInputDeserializer in = new DataInputDeserializer(data, 0, data.length); - - for (int i = 0; i < numEntries; i++) { - K key = keySerializer.deserialize(in); - V value = valueSerializer.deserialize(in); - stateMap.put(key, value); - } - - return new MemHeapKvState<K, V>(keySerializer, valueSerializer, defaultValue, stateMap); - } - - /** - * Discarding the heap state is a no-op. - */ - @Override - public void discardState() {} - - @Override - public long getStateSize() { - return data.length; - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java index 2963237..2b7b5f1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java @@ -18,10 +18,15 @@ package org.apache.flink.runtime.state.memory; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.ReducingState; +import org.apache.flink.api.common.state.ReducingStateDescriptor; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.runtime.execution.Environment; -import org.apache.flink.runtime.state.StateBackend; import org.apache.flink.runtime.state.StateHandle; +import 
org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.StreamStateHandle; import java.io.ByteArrayOutputStream; @@ -29,11 +34,11 @@ import java.io.IOException; import java.io.Serializable; /** - * A {@link StateBackend} that stores all its data and checkpoints in memory and has no + * A {@link AbstractStateBackend} that stores all its data and checkpoints in memory and has no * capabilities to spill to disk. Checkpoints are serialized and the serialized data is * transferred */ -public class MemoryStateBackend extends StateBackend<MemoryStateBackend> { +public class MemoryStateBackend extends AbstractStateBackend { private static final long serialVersionUID = 4109305377809414635L; @@ -66,11 +71,6 @@ public class MemoryStateBackend extends StateBackend<MemoryStateBackend> { // ------------------------------------------------------------------------ @Override - public void initializeForJob(Environment env) { - // nothing to do here - } - - @Override public void disposeAllStateForCurrentJob() { // nothing to do here, GC will do it } @@ -83,9 +83,18 @@ public class MemoryStateBackend extends StateBackend<MemoryStateBackend> { // ------------------------------------------------------------------------ @Override - public <K, V> MemHeapKvState<K, V> createKvState(String stateId, String stateName, - TypeSerializer<K> keySerializer, TypeSerializer<V> valueSerializer, V defaultValue) { - return new MemHeapKvState<K, V>(keySerializer, valueSerializer, defaultValue); + public <N, V> ValueState<V> createValueState(TypeSerializer<N> namespaceSerializer, ValueStateDescriptor<V> stateDesc) throws Exception { + return new MemValueState<>(keySerializer, namespaceSerializer, stateDesc); + } + + @Override + public <N, T> ListState<T> createListState(TypeSerializer<N> namespaceSerializer, ListStateDescriptor<T> stateDesc) throws Exception { + return new MemListState<>(keySerializer, namespaceSerializer, stateDesc); + } + + @Override + public <N, T> 
ReducingState<T> createReducingState(TypeSerializer<N> namespaceSerializer, ReducingStateDescriptor<T> stateDesc) throws Exception { + return new MemReducingState<>(keySerializer, namespaceSerializer, stateDesc); } /** @@ -196,14 +205,11 @@ public class MemoryStateBackend extends StateBackend<MemoryStateBackend> { // Static default instance // ------------------------------------------------------------------------ - /** The default instance of this state backend, using the default maximal state size */ - private static final MemoryStateBackend DEFAULT_INSTANCE = new MemoryStateBackend(); - /** * Gets the default instance of this state backend, using the default maximal state size. * @return The default instance of this state backend. */ - public static MemoryStateBackend defaultInstance() { - return DEFAULT_INSTANCE; + public static MemoryStateBackend create() { + return new MemoryStateBackend(); } } http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java index 05bc8fa..e7bf80e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java @@ -18,29 +18,8 @@ package org.apache.flink.runtime.state; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; - -import java.io.File; -import java.io.IOException; -import java.io.InputStream; -import 
java.net.URI; -import java.util.Random; -import java.util.UUID; - import org.apache.commons.io.FileUtils; - -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.base.FloatSerializer; import org.apache.flink.api.common.typeutils.base.IntSerializer; -import org.apache.flink.api.common.typeutils.base.IntValueSerializer; -import org.apache.flink.api.common.typeutils.base.StringSerializer; -import org.apache.flink.api.java.typeutils.runtime.ValueSerializer; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.core.fs.Path; import org.apache.flink.core.testutils.CommonTestUtils; @@ -48,12 +27,32 @@ import org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.state.filesystem.FileStreamStateHandle; import org.apache.flink.runtime.state.filesystem.FsStateBackend; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; -import org.apache.flink.types.IntValue; -import org.apache.flink.types.StringValue; import org.junit.Test; -public class FileStateBackendTest { +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.util.Random; +import java.util.UUID; + +import static org.junit.Assert.*; + +public class FileStateBackendTest extends StateBackendTestBase<FsStateBackend> { + + private File stateDir; + + @Override + protected FsStateBackend getStateBackend() throws Exception { + stateDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()); + return new FsStateBackend(localFileUri(stateDir)); + } + + @Override + protected void cleanup() throws Exception { + deleteDirectorySilently(stateDir); + } @Test public void testSetupAndSerialization() { @@ -80,7 +79,7 @@ public class FileStateBackendTest { // supreme! 
} - backend.initializeForJob(new DummyEnvironment("test", 1, 0)); + backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test-op", IntSerializer.INSTANCE); assertNotNull(backend.getCheckpointDirectory()); File checkpointDir = new File(backend.getCheckpointDirectory().toUri().getPath()); @@ -105,9 +104,8 @@ public class FileStateBackendTest { public void testSerializableState() { File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()); try { - FsStateBackend backend = CommonTestUtils.createCopySerializable( - new FsStateBackend(tempDir.toURI(), 40)); - backend.initializeForJob(new DummyEnvironment("test", 1, 0)); + FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(localFileUri(tempDir))); + backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test-op", IntSerializer.INSTANCE); File checkpointDir = new File(backend.getCheckpointDirectory().toUri().getPath()); @@ -118,13 +116,13 @@ public class FileStateBackendTest { StateHandle<String> handle1 = backend.checkpointStateSerializable(state1, 439568923746L, System.currentTimeMillis()); StateHandle<String> handle2 = backend.checkpointStateSerializable(state2, 439568923746L, System.currentTimeMillis()); StateHandle<Integer> handle3 = backend.checkpointStateSerializable(state3, 439568923746L, System.currentTimeMillis()); - + assertEquals(state1, handle1.getState(getClass().getClassLoader())); handle1.discardState(); - + assertEquals(state2, handle2.getState(getClass().getClassLoader())); handle2.discardState(); - + assertEquals(state3, handle3.getState(getClass().getClassLoader())); handle3.discardState(); @@ -144,10 +142,9 @@ public class FileStateBackendTest { File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()); try { // the state backend has a very low in-mem state threshold (15 bytes) - FsStateBackend backend = CommonTestUtils.createCopySerializable( - new 
FsStateBackend(tempDir.toURI(), 15)); - - backend.initializeForJob(new DummyEnvironment("test", 1, 0)); + FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(tempDir.toURI(), 15)); + + backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test-op", IntSerializer.INSTANCE); File checkpointDir = new File(backend.getCheckpointDirectory().toUri().getPath()); @@ -181,14 +178,14 @@ public class FileStateBackendTest { // use with try-with-resources FileStreamStateHandle handle4; - try (StateBackend.CheckpointStateOutputStream stream4 = + try (AbstractStateBackend.CheckpointStateOutputStream stream4 = backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis())) { stream4.write(state4); handle4 = (FileStreamStateHandle) stream4.closeAndGetHandle(); } // close before accessing handle - StateBackend.CheckpointStateOutputStream stream5 = + AbstractStateBackend.CheckpointStateOutputStream stream5 = backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis()); stream5.write(state4); stream5.close(); @@ -223,197 +220,6 @@ public class FileStateBackendTest { } } - @Test - public void testKeyValueState() { - File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()); - try { - FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(localFileUri(tempDir))); - backend.initializeForJob(new DummyEnvironment("test", 1, 0)); - - File checkpointDir = new File(backend.getCheckpointDirectory().toUri().getPath()); - - KvState<Integer, String, FsStateBackend> kv = - backend.createKvState("0", "a", IntSerializer.INSTANCE, StringSerializer.INSTANCE, null); - - assertEquals(0, kv.size()); - - // some modifications to the state - kv.setCurrentKey(1); - assertNull(kv.value()); - kv.update("1"); - assertEquals(1, kv.size()); - kv.setCurrentKey(2); - assertNull(kv.value()); - kv.update("2"); - assertEquals(2, kv.size()); - kv.setCurrentKey(1); - 
assertEquals("1", kv.value()); - assertEquals(2, kv.size()); - - // draw a snapshot - KvStateSnapshot<Integer, String, FsStateBackend> snapshot1 = - kv.snapshot(682375462378L, System.currentTimeMillis()); - - // make some more modifications - kv.setCurrentKey(1); - kv.update("u1"); - kv.setCurrentKey(2); - kv.update("u2"); - kv.setCurrentKey(3); - kv.update("u3"); - - // draw another snapshot - KvStateSnapshot<Integer, String, FsStateBackend> snapshot2 = - kv.snapshot(682375462379L, System.currentTimeMillis()); - - // validate the original state - assertEquals(3, kv.size()); - kv.setCurrentKey(1); - assertEquals("u1", kv.value()); - kv.setCurrentKey(2); - assertEquals("u2", kv.value()); - kv.setCurrentKey(3); - assertEquals("u3", kv.value()); - - // restore the first snapshot and validate it - KvState<Integer, String, FsStateBackend> restored1 = snapshot1.restoreState(backend, - IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader(), 1); - - assertEquals(2, restored1.size()); - restored1.setCurrentKey(1); - assertEquals("1", restored1.value()); - restored1.setCurrentKey(2); - assertEquals("2", restored1.value()); - - // restore the first snapshot and validate it - KvState<Integer, String, FsStateBackend> restored2 = snapshot2.restoreState(backend, - IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader(), 1); - - assertEquals(3, restored2.size()); - restored2.setCurrentKey(1); - assertEquals("u1", restored2.value()); - restored2.setCurrentKey(2); - assertEquals("u2", restored2.value()); - restored2.setCurrentKey(3); - assertEquals("u3", restored2.value()); - - snapshot1.discardState(); - assertFalse(isDirectoryEmpty(checkpointDir)); - - snapshot2.discardState(); - assertTrue(isDirectoryEmpty(checkpointDir)); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - finally { - deleteDirectorySilently(tempDir); - } - } - - @Test - public void testRestoreWithWrongSerializers() { - File 
tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()); - try { - FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(localFileUri(tempDir))); - backend.initializeForJob(new DummyEnvironment("test", 1, 0)); - - File checkpointDir = new File(backend.getCheckpointDirectory().toUri().getPath()); - - KvState<Integer, String, FsStateBackend> kv = - backend.createKvState("a_0", "a", IntSerializer.INSTANCE, StringSerializer.INSTANCE, null); - - kv.setCurrentKey(1); - kv.update("1"); - kv.setCurrentKey(2); - kv.update("2"); - - KvStateSnapshot<Integer, String, FsStateBackend> snapshot = - kv.snapshot(682375462378L, System.currentTimeMillis()); - - - @SuppressWarnings("unchecked") - TypeSerializer<Integer> fakeIntSerializer = - (TypeSerializer<Integer>) (TypeSerializer<?>) FloatSerializer.INSTANCE; - - @SuppressWarnings("unchecked") - TypeSerializer<String> fakeStringSerializer = - (TypeSerializer<String>) (TypeSerializer<?>) new ValueSerializer<StringValue>(StringValue.class); - - try { - snapshot.restoreState(backend, fakeIntSerializer, - StringSerializer.INSTANCE, null, getClass().getClassLoader(), 1); - fail("should recognize wrong serializers"); - } catch (IllegalArgumentException e) { - // expected - } catch (Exception e) { - fail("wrong exception"); - } - - try { - snapshot.restoreState(backend, IntSerializer.INSTANCE, - fakeStringSerializer, null, getClass().getClassLoader(), 1); - fail("should recognize wrong serializers"); - } catch (IllegalArgumentException e) { - // expected - } catch (Exception e) { - fail("wrong exception"); - } - - try { - snapshot.restoreState(backend, fakeIntSerializer, - fakeStringSerializer, null, getClass().getClassLoader(), 1); - fail("should recognize wrong serializers"); - } catch (IllegalArgumentException e) { - // expected - } catch (Exception e) { - fail("wrong exception"); - } - - snapshot.discardState(); - - assertTrue(isDirectoryEmpty(checkpointDir)); - } - 
catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - finally { - deleteDirectorySilently(tempDir); - } - } - - @Test - public void testCopyDefaultValue() { - File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()); - try { - FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(localFileUri(tempDir))); - backend.initializeForJob(new DummyEnvironment("test", 1, 0)); - - KvState<Integer, IntValue, FsStateBackend> kv = - backend.createKvState("a_0", "a", IntSerializer.INSTANCE, IntValueSerializer.INSTANCE, new IntValue(-1)); - - kv.setCurrentKey(1); - IntValue default1 = kv.value(); - - kv.setCurrentKey(2); - IntValue default2 = kv.value(); - - assertNotNull(default1); - assertNotNull(default2); - assertEquals(default1, default2); - assertFalse(default1 == default2); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - finally { - deleteDirectorySilently(tempDir); - } - } - // ------------------------------------------------------------------------ // Utilities // ------------------------------------------------------------------------ @@ -437,6 +243,9 @@ public class FileStateBackendTest { } private static boolean isDirectoryEmpty(File directory) { + if (!directory.exists()) { + return true; + } String[] nested = directory.list(); return nested == null || nested.length == 0; } @@ -447,15 +256,16 @@ public class FileStateBackendTest { private static void validateBytesInStream(InputStream is, byte[] data) throws IOException { byte[] holder = new byte[data.length]; - int numBytesRead = is.read(holder); - - if (holder.length == 0) { - assertTrue("stream not empty", numBytesRead == 0 || numBytesRead == -1); - } else { - assertEquals("not enough data", holder.length, numBytesRead); + + int pos = 0; + int read; + while (pos < holder.length && (read = is.read(holder, pos, holder.length - pos)) != -1) { + pos += read; } - + + assertEquals("not enough 
data", holder.length, pos); assertEquals("too much data", -1, is.read()); assertArrayEquals("wrong data", data, holder); } + } http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-runtime/src/test/java/org/apache/flink/runtime/state/FsCheckpointStateOutputStreamTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FsCheckpointStateOutputStreamTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FsCheckpointStateOutputStreamTest.java index 66a7271..5964b72 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FsCheckpointStateOutputStreamTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FsCheckpointStateOutputStreamTest.java @@ -48,7 +48,7 @@ public class FsCheckpointStateOutputStreamTest { @Test public void testEmptyState() throws Exception { - StateBackend.CheckpointStateOutputStream stream = new FsStateBackend.FsCheckpointStateOutputStream( + AbstractStateBackend.CheckpointStateOutputStream stream = new FsStateBackend.FsCheckpointStateOutputStream( TEMP_DIR_PATH, FileSystem.getLocalFileSystem(), 1024, 512); StreamStateHandle handle = stream.closeAndGetHandle(); @@ -79,7 +79,7 @@ public class FsCheckpointStateOutputStreamTest { } private void runTest(int numBytes, int bufferSize, int threshold, boolean expectFile) throws Exception { - StateBackend.CheckpointStateOutputStream stream = + AbstractStateBackend.CheckpointStateOutputStream stream = new FsStateBackend.FsCheckpointStateOutputStream( TEMP_DIR_PATH, FileSystem.getLocalFileSystem(), bufferSize, threshold); http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java index 4b5aebd..34354c1 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java @@ -18,15 +18,7 @@ package org.apache.flink.runtime.state; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.base.FloatSerializer; -import org.apache.flink.api.common.typeutils.base.IntSerializer; -import org.apache.flink.api.common.typeutils.base.IntValueSerializer; -import org.apache.flink.api.common.typeutils.base.StringSerializer; -import org.apache.flink.api.java.typeutils.runtime.ValueSerializer; import org.apache.flink.runtime.state.memory.MemoryStateBackend; -import org.apache.flink.types.IntValue; -import org.apache.flink.types.StringValue; import org.junit.Test; import java.io.IOException; @@ -39,7 +31,15 @@ import static org.junit.Assert.*; /** * Tests for the {@link org.apache.flink.runtime.state.memory.MemoryStateBackend}. 
*/ -public class MemoryStateBackendTest { +public class MemoryStateBackendTest extends StateBackendTestBase<MemoryStateBackend> { + + @Override + protected MemoryStateBackend getStateBackend() throws Exception { + return new MemoryStateBackend(); + } + + @Override + protected void cleanup() throws Exception { } @Test public void testSerializableState() { @@ -94,7 +94,7 @@ public class MemoryStateBackendTest { state.put("hey there", 2); state.put("the crazy brown fox stumbles over a sentence that does not contain every letter", 77); - StateBackend.CheckpointStateOutputStream os = backend.createCheckpointStateOutputStream(1, 2); + AbstractStateBackend.CheckpointStateOutputStream os = backend.createCheckpointStateOutputStream(1, 2); ObjectOutputStream oos = new ObjectOutputStream(os); oos.writeObject(state); oos.flush(); @@ -122,7 +122,7 @@ public class MemoryStateBackendTest { state.put("hey there", 2); state.put("the crazy brown fox stumbles over a sentence that does not contain every letter", 77); - StateBackend.CheckpointStateOutputStream os = backend.createCheckpointStateOutputStream(1, 2); + AbstractStateBackend.CheckpointStateOutputStream os = backend.createCheckpointStateOutputStream(1, 2); ObjectOutputStream oos = new ObjectOutputStream(os); try { @@ -140,164 +140,4 @@ public class MemoryStateBackendTest { fail(e.getMessage()); } } - - @Test - public void testKeyValueState() { - try { - MemoryStateBackend backend = new MemoryStateBackend(); - - KvState<Integer, String, MemoryStateBackend> kv = - backend.createKvState("s_0", "s", IntSerializer.INSTANCE, StringSerializer.INSTANCE, null); - - assertEquals(0, kv.size()); - - // some modifications to the state - kv.setCurrentKey(1); - assertNull(kv.value()); - kv.update("1"); - assertEquals(1, kv.size()); - kv.setCurrentKey(2); - assertNull(kv.value()); - kv.update("2"); - assertEquals(2, kv.size()); - kv.setCurrentKey(1); - assertEquals("1", kv.value()); - assertEquals(2, kv.size()); - - // draw a snapshot - 
KvStateSnapshot<Integer, String, MemoryStateBackend> snapshot1 = - kv.snapshot(682375462378L, System.currentTimeMillis()); - - // make some more modifications - kv.setCurrentKey(1); - kv.update("u1"); - kv.setCurrentKey(2); - kv.update("u2"); - kv.setCurrentKey(3); - kv.update("u3"); - - // draw another snapshot - KvStateSnapshot<Integer, String, MemoryStateBackend> snapshot2 = - kv.snapshot(682375462379L, System.currentTimeMillis()); - - // validate the original state - assertEquals(3, kv.size()); - kv.setCurrentKey(1); - assertEquals("u1", kv.value()); - kv.setCurrentKey(2); - assertEquals("u2", kv.value()); - kv.setCurrentKey(3); - assertEquals("u3", kv.value()); - - // restore the first snapshot and validate it - KvState<Integer, String, MemoryStateBackend> restored1 = snapshot1.restoreState(backend, - IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader(), 1); - - assertEquals(2, restored1.size()); - restored1.setCurrentKey(1); - assertEquals("1", restored1.value()); - restored1.setCurrentKey(2); - assertEquals("2", restored1.value()); - - // restore the first snapshot and validate it - KvState<Integer, String, MemoryStateBackend> restored2 = snapshot2.restoreState(backend, - IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader(), 1); - - assertEquals(3, restored2.size()); - restored2.setCurrentKey(1); - assertEquals("u1", restored2.value()); - restored2.setCurrentKey(2); - assertEquals("u2", restored2.value()); - restored2.setCurrentKey(3); - assertEquals("u3", restored2.value()); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void testRestoreWithWrongSerializers() { - try { - MemoryStateBackend backend = new MemoryStateBackend(); - KvState<Integer, String, MemoryStateBackend> kv = - backend.createKvState("s_0", "s", IntSerializer.INSTANCE, StringSerializer.INSTANCE, null); - - kv.setCurrentKey(1); - kv.update("1"); - kv.setCurrentKey(2); - 
kv.update("2"); - - KvStateSnapshot<Integer, String, MemoryStateBackend> snapshot = - kv.snapshot(682375462378L, System.currentTimeMillis()); - - - @SuppressWarnings("unchecked") - TypeSerializer<Integer> fakeIntSerializer = - (TypeSerializer<Integer>) (TypeSerializer<?>) FloatSerializer.INSTANCE; - - @SuppressWarnings("unchecked") - TypeSerializer<String> fakeStringSerializer = - (TypeSerializer<String>) (TypeSerializer<?>) new ValueSerializer<StringValue>(StringValue.class); - - try { - snapshot.restoreState(backend, fakeIntSerializer, - StringSerializer.INSTANCE, null, getClass().getClassLoader(), 1); - fail("should recognize wrong serializers"); - } catch (IllegalArgumentException e) { - // expected - } catch (Exception e) { - fail("wrong exception"); - } - - try { - snapshot.restoreState(backend, IntSerializer.INSTANCE, - fakeStringSerializer, null, getClass().getClassLoader(), 1); - fail("should recognize wrong serializers"); - } catch (IllegalArgumentException e) { - // expected - } catch (Exception e) { - fail("wrong exception"); - } - - try { - snapshot.restoreState(backend, fakeIntSerializer, - fakeStringSerializer, null, getClass().getClassLoader(), 1); - fail("should recognize wrong serializers"); - } catch (IllegalArgumentException e) { - // expected - } catch (Exception e) { - fail("wrong exception"); - } - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void testCopyDefaultValue() { - try { - MemoryStateBackend backend = new MemoryStateBackend(); - KvState<Integer, IntValue, MemoryStateBackend> kv = - backend.createKvState("a_0", "a", IntSerializer.INSTANCE, IntValueSerializer.INSTANCE, new IntValue(-1)); - - kv.setCurrentKey(1); - IntValue default1 = kv.value(); - - kv.setCurrentKey(2); - IntValue default2 = kv.value(); - - assertNotNull(default1); - assertNotNull(default2); - assertEquals(default1, default2); - assertFalse(default1 == default2); - } - catch (Exception e) { - e.printStackTrace(); - 
fail(e.getMessage()); - } - } } http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java new file mode 100644 index 0000000..82ab3b3 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java @@ -0,0 +1,494 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.state; + +import com.google.common.base.Joiner; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.ReducingState; +import org.apache.flink.api.common.state.ReducingStateDescriptor; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.FloatSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.IntValueSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.api.common.typeutils.base.VoidSerializer; +import org.apache.flink.runtime.operators.testutils.DummyEnvironment; +import org.apache.flink.types.IntValue; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * Generic tests for the partitioned state part of {@link AbstractStateBackend}. 
+ */ +public abstract class StateBackendTestBase<B extends AbstractStateBackend> { + + protected B backend; + + protected abstract B getStateBackend() throws Exception; + + protected abstract void cleanup() throws Exception; + + @Before + public void setup() throws Exception { + this.backend = getStateBackend(); + } + + @After + public void teardown() throws Exception { + this.backend.dispose(); + cleanup(); + } + + @Test + public void testValueState() throws Exception { + + backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE); + + ValueStateDescriptor<String> kvId = new ValueStateDescriptor<>("id", null, StringSerializer.INSTANCE); + ValueState<String> state = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + + @SuppressWarnings("unchecked") + KvState<Integer, Void, ValueState<String>, ValueStateDescriptor<String>, B> kv = + (KvState<Integer, Void, ValueState<String>, ValueStateDescriptor<String>, B>) state; + + // some modifications to the state + kv.setCurrentKey(1); + assertNull(state.value()); + state.update("1"); + kv.setCurrentKey(2); + assertNull(state.value()); + state.update("2"); + kv.setCurrentKey(1); + assertEquals("1", state.value()); + + // draw a snapshot + KvStateSnapshot<Integer, Void, ValueState<String>, ValueStateDescriptor<String>, B> snapshot1 = + kv.snapshot(682375462378L, 2); + + // make some more modifications + kv.setCurrentKey(1); + state.update("u1"); + kv.setCurrentKey(2); + state.update("u2"); + kv.setCurrentKey(3); + state.update("u3"); + + // draw another snapshot + KvStateSnapshot<Integer, Void, ValueState<String>, ValueStateDescriptor<String>, B> snapshot2 = + kv.snapshot(682375462379L, 4); + + // validate the original state + kv.setCurrentKey(1); + assertEquals("u1", state.value()); + kv.setCurrentKey(2); + assertEquals("u2", state.value()); + kv.setCurrentKey(3); + assertEquals("u3", state.value()); + + kv.dispose(); + +// restore the first snapshot and validate it + 
KvState<Integer, Void, ValueState<String>, ValueStateDescriptor<String>, B> restored1 = snapshot1.restoreState( + backend, + IntSerializer.INSTANCE, + this.getClass().getClassLoader(), 10); + + @SuppressWarnings("unchecked") + ValueState<String> restored1State = (ValueState<String>) restored1; + + restored1.setCurrentKey(1); + assertEquals("1", restored1State.value()); + restored1.setCurrentKey(2); + assertEquals("2", restored1State.value()); + + restored1.dispose(); + + // restore the second snapshot and validate it + KvState<Integer, Void, ValueState<String>, ValueStateDescriptor<String>, B> restored2 = snapshot2.restoreState( + backend, + IntSerializer.INSTANCE, + this.getClass().getClassLoader(), 10); + + @SuppressWarnings("unchecked") + ValueState<String> restored2State = (ValueState<String>) restored2; + + restored2.setCurrentKey(1); + assertEquals("u1", restored2State.value()); + restored2.setCurrentKey(2); + assertEquals("u2", restored2State.value()); + restored2.setCurrentKey(3); + assertEquals("u3", restored2State.value()); + } + + @Test + @SuppressWarnings("unchecked,rawtypes") + public void testListState() { + try { + backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE); + + ListStateDescriptor<String> kvId = new ListStateDescriptor<>("id", StringSerializer.INSTANCE); + ListState<String> state = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + + @SuppressWarnings("unchecked") + KvState<Integer, Void, ListState<String>, ListStateDescriptor<String>, B> kv = + (KvState<Integer, Void, ListState<String>, ListStateDescriptor<String>, B>) state; + + Joiner joiner = Joiner.on(","); + // some modifications to the state + kv.setCurrentKey(1); + assertEquals("", joiner.join(state.get())); + state.add("1"); + kv.setCurrentKey(2); + assertEquals("", joiner.join(state.get())); + state.add("2"); + kv.setCurrentKey(1); + assertEquals("1", joiner.join(state.get())); + + // draw a snapshot + 
KvStateSnapshot<Integer, Void, ListState<String>, ListStateDescriptor<String>, B> snapshot1 = + kv.snapshot(682375462378L, 2); + + // make some more modifications + kv.setCurrentKey(1); + state.add("u1"); + kv.setCurrentKey(2); + state.add("u2"); + kv.setCurrentKey(3); + state.add("u3"); + + // draw another snapshot + KvStateSnapshot<Integer, Void, ListState<String>, ListStateDescriptor<String>, B> snapshot2 = + kv.snapshot(682375462379L, 4); + + // validate the original state + kv.setCurrentKey(1); + assertEquals("1,u1", joiner.join(state.get())); + kv.setCurrentKey(2); + assertEquals("2,u2", joiner.join(state.get())); + kv.setCurrentKey(3); + assertEquals("u3", joiner.join(state.get())); + + kv.dispose(); + + // restore the first snapshot and validate it + KvState<Integer, Void, ListState<String>, ListStateDescriptor<String>, B> restored1 = snapshot1.restoreState( + backend, + IntSerializer.INSTANCE, + this.getClass().getClassLoader(), 10); + + @SuppressWarnings("unchecked") + ListState<String> restored1State = (ListState<String>) restored1; + + restored1.setCurrentKey(1); + assertEquals("1", joiner.join(restored1State.get())); + restored1.setCurrentKey(2); + assertEquals("2", joiner.join(restored1State.get())); + + restored1.dispose(); + + // restore the second snapshot and validate it + KvState<Integer, Void, ListState<String>, ListStateDescriptor<String>, B> restored2 = snapshot2.restoreState( + backend, + IntSerializer.INSTANCE, + this.getClass().getClassLoader(), 20); + + @SuppressWarnings("unchecked") + ListState<String> restored2State = (ListState<String>) restored2; + + restored2.setCurrentKey(1); + assertEquals("1,u1", joiner.join(restored2State.get())); + restored2.setCurrentKey(2); + assertEquals("2,u2", joiner.join(restored2State.get())); + restored2.setCurrentKey(3); + assertEquals("u3", joiner.join(restored2State.get())); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + 
@SuppressWarnings("unchecked,rawtypes") + public void testReducingState() { + try { + backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE); + + ReducingStateDescriptor<String> kvId = new ReducingStateDescriptor<>("id", + new ReduceFunction<String>() { + private static final long serialVersionUID = 1L; + + @Override + public String reduce(String value1, String value2) throws Exception { + return value1 + "," + value2; + } + }, + StringSerializer.INSTANCE); + ReducingState<String> state = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + + @SuppressWarnings("unchecked") + KvState<Integer, Void, ReducingState<String>, ReducingStateDescriptor<String>, B> kv = + (KvState<Integer, Void, ReducingState<String>, ReducingStateDescriptor<String>, B>) state; + + Joiner joiner = Joiner.on(","); + // some modifications to the state + kv.setCurrentKey(1); + assertEquals(null, state.get()); + state.add("1"); + kv.setCurrentKey(2); + assertEquals(null, state.get()); + state.add("2"); + kv.setCurrentKey(1); + assertEquals("1", state.get()); + + // draw a snapshot + KvStateSnapshot<Integer, Void, ReducingState<String>, ReducingStateDescriptor<String>, B> snapshot1 = + kv.snapshot(682375462378L, 2); + + // make some more modifications + kv.setCurrentKey(1); + state.add("u1"); + kv.setCurrentKey(2); + state.add("u2"); + kv.setCurrentKey(3); + state.add("u3"); + + // draw another snapshot + KvStateSnapshot<Integer, Void, ReducingState<String>, ReducingStateDescriptor<String>, B> snapshot2 = + kv.snapshot(682375462379L, 4); + + // validate the original state + kv.setCurrentKey(1); + assertEquals("1,u1", state.get()); + kv.setCurrentKey(2); + assertEquals("2,u2", state.get()); + kv.setCurrentKey(3); + assertEquals("u3", state.get()); + + kv.dispose(); + + // restore the first snapshot and validate it + KvState<Integer, Void, ReducingState<String>, ReducingStateDescriptor<String>, B> restored1 = snapshot1.restoreState( + backend, 
+ IntSerializer.INSTANCE, + this.getClass().getClassLoader(), 10); + + @SuppressWarnings("unchecked") + ReducingState<String> restored1State = (ReducingState<String>) restored1; + + restored1.setCurrentKey(1); + assertEquals("1", restored1State.get()); + restored1.setCurrentKey(2); + assertEquals("2", restored1State.get()); + + restored1.dispose(); + + // restore the second snapshot and validate it + KvState<Integer, Void, ReducingState<String>, ReducingStateDescriptor<String>, B> restored2 = snapshot2.restoreState( + backend, + IntSerializer.INSTANCE, + this.getClass().getClassLoader(), 20); + + @SuppressWarnings("unchecked") + ReducingState<String> restored2State = (ReducingState<String>) restored2; + + restored2.setCurrentKey(1); + assertEquals("1,u1", restored2State.get()); + restored2.setCurrentKey(2); + assertEquals("2,u2", restored2State.get()); + restored2.setCurrentKey(3); + assertEquals("u3", restored2State.get()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + + @Test + public void testValueStateRestoreWithWrongSerializers() { + try { + backend.initializeForJob(new DummyEnvironment("test", 1, 0), + "test_op", + IntSerializer.INSTANCE); + + ValueStateDescriptor<String> kvId = new ValueStateDescriptor<>("id", + null, + StringSerializer.INSTANCE); + ValueState<String> state = backend.getPartitionedState(null, + VoidSerializer.INSTANCE, + kvId); + + @SuppressWarnings("unchecked") + KvState<Integer, Void, ValueState<String>, ValueStateDescriptor<String>, B> kv = + (KvState<Integer, Void, ValueState<String>, ValueStateDescriptor<String>, B>) state; + + kv.setCurrentKey(1); + state.update("1"); + kv.setCurrentKey(2); + state.update("2"); + + KvStateSnapshot<Integer, Void, ValueState<String>, ValueStateDescriptor<String>, B> snapshot = + kv.snapshot(682375462378L, System.currentTimeMillis()); + + @SuppressWarnings("unchecked") + TypeSerializer<Integer> fakeIntSerializer = + (TypeSerializer<Integer>) (TypeSerializer<?>) 
FloatSerializer.INSTANCE; + + try { + snapshot.restoreState(backend, fakeIntSerializer, getClass().getClassLoader(), 1); + fail("should recognize wrong serializers"); + } catch (IllegalArgumentException e) { + // expected + } catch (Exception e) { + fail("wrong exception"); + } + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testListStateRestoreWithWrongSerializers() { + try { + backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE); + + ListStateDescriptor<String> kvId = new ListStateDescriptor<>("id", StringSerializer.INSTANCE); + ListState<String> state = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + + @SuppressWarnings("unchecked") + KvState<Integer, Void, ListState<String>, ListStateDescriptor<String>, B> kv = + (KvState<Integer, Void, ListState<String>, ListStateDescriptor<String>, B>) state; + + kv.setCurrentKey(1); + state.add("1"); + kv.setCurrentKey(2); + state.add("2"); + + KvStateSnapshot<Integer, Void, ListState<String>, ListStateDescriptor<String>, B> snapshot = + kv.snapshot(682375462378L, System.currentTimeMillis()); + + kv.dispose(); + + @SuppressWarnings("unchecked") + TypeSerializer<Integer> fakeIntSerializer = + (TypeSerializer<Integer>) (TypeSerializer<?>) FloatSerializer.INSTANCE; + + try { + snapshot.restoreState(backend, fakeIntSerializer, getClass().getClassLoader(), 1); + fail("should recognize wrong serializers"); + } catch (IllegalArgumentException e) { + // expected + } catch (Exception e) { + fail("wrong exception " + e); + } + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testReducingStateRestoreWithWrongSerializers() { + try { + backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE); + + ReducingStateDescriptor<String> kvId = new ReducingStateDescriptor<>("id", + new ReduceFunction<String>() { + @Override + public 
String reduce(String value1, String value2) throws Exception { + return value1 + "," + value2; + } + }, + StringSerializer.INSTANCE); + ReducingState<String> state = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + + @SuppressWarnings("unchecked") + KvState<Integer, Void, ReducingState<String>, ReducingStateDescriptor<String>, B> kv = + (KvState<Integer, Void, ReducingState<String>, ReducingStateDescriptor<String>, B>) state; + + kv.setCurrentKey(1); + state.add("1"); + kv.setCurrentKey(2); + state.add("2"); + + KvStateSnapshot<Integer, Void, ReducingState<String>, ReducingStateDescriptor<String>, B> snapshot = + kv.snapshot(682375462378L, System.currentTimeMillis()); + + kv.dispose(); + + @SuppressWarnings("unchecked") + TypeSerializer<Integer> fakeIntSerializer = + (TypeSerializer<Integer>) (TypeSerializer<?>) FloatSerializer.INSTANCE; + + try { + snapshot.restoreState(backend, fakeIntSerializer, getClass().getClassLoader(), 1); + fail("should recognize wrong serializers"); + } catch (IllegalArgumentException e) { + // expected + } catch (Exception e) { + fail("wrong exception " + e); + } + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testCopyDefaultValue() { + try { + backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE); + + ValueStateDescriptor<IntValue> kvId = new ValueStateDescriptor<>("id", new IntValue(-1), IntValueSerializer.INSTANCE); + ValueState<IntValue> state = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + + @SuppressWarnings("unchecked") + KvState<Integer, Void, ValueState<IntValue>, ValueStateDescriptor<IntValue>, B> kv = + (KvState<Integer, Void, ValueState<IntValue>, ValueStateDescriptor<IntValue>, B>) state; + + kv.setCurrentKey(1); + IntValue default1 = state.value(); + + kv.setCurrentKey(2); + IntValue default2 = state.value(); + + assertNotNull(default1); + assertNotNull(default2); + 
assertEquals(default1, default2); + assertFalse(default1 == default2); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-connectors/flink-connector-filesystem/src/main/java/org/apache/flink/streaming/connectors/fs/RollingSink.java ---------------------------------------------------------------------- diff --git a/flink-streaming-connectors/flink-connector-filesystem/src/main/java/org/apache/flink/streaming/connectors/fs/RollingSink.java b/flink-streaming-connectors/flink-connector-filesystem/src/main/java/org/apache/flink/streaming/connectors/fs/RollingSink.java index 2112b28..afae68f 100644 --- a/flink-streaming-connectors/flink-connector-filesystem/src/main/java/org/apache/flink/streaming/connectors/fs/RollingSink.java +++ b/flink-streaming-connectors/flink-connector-filesystem/src/main/java/org/apache/flink/streaming/connectors/fs/RollingSink.java @@ -22,7 +22,7 @@ import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.typeutils.InputTypeConfigurable; import org.apache.flink.configuration.Configuration; -import org.apache.flink.streaming.api.checkpoint.CheckpointNotifier; +import org.apache.flink.runtime.state.CheckpointListener; import org.apache.flink.streaming.api.checkpoint.Checkpointed; import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; import org.apache.hadoop.fs.FSDataOutputStream; @@ -119,7 +119,7 @@ import java.util.UUID; * * @param <T> Type of the elements emitted by this sink */ -public class RollingSink<T> extends RichSinkFunction<T> implements InputTypeConfigurable, Checkpointed<RollingSink.BucketState>, CheckpointNotifier { +public class RollingSink<T> extends RichSinkFunction<T> implements InputTypeConfigurable, Checkpointed<RollingSink.BucketState>, CheckpointListener { private static final long serialVersionUID = 1L; private 
static Logger LOG = LoggerFactory.getLogger(RollingSink.class); http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java ---------------------------------------------------------------------- diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java index 3c36586..a513637 100644 --- a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java +++ b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java @@ -20,7 +20,7 @@ package org.apache.flink.streaming.connectors.kafka; import org.apache.commons.collections.map.LinkedMap; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.typeutils.ResultTypeQueryable; -import org.apache.flink.streaming.api.checkpoint.CheckpointNotifier; +import org.apache.flink.runtime.state.CheckpointListener; import org.apache.flink.streaming.api.checkpoint.CheckpointedAsynchronously; import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition; @@ -39,7 +39,7 @@ import static com.google.common.base.Preconditions.checkNotNull; public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFunction<T> - implements CheckpointNotifier, CheckpointedAsynchronously<HashMap<KafkaTopicPartition, Long>>, ResultTypeQueryable<T> { + implements CheckpointListener, CheckpointedAsynchronously<HashMap<KafkaTopicPartition, Long>>, ResultTypeQueryable<T> { // 
------------------------------------------------------------------------ http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java ---------------------------------------------------------------------- diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java index 3d39869..2d9f2fc 100644 --- a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java +++ b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java @@ -45,7 +45,7 @@ import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.runtime.client.JobExecutionException; import org.apache.flink.runtime.jobmanager.scheduler.NoResourceAvailableException; -import org.apache.flink.streaming.api.checkpoint.CheckpointNotifier; +import org.apache.flink.runtime.state.CheckpointListener; import org.apache.flink.streaming.api.checkpoint.Checkpointed; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.DataStreamSource; @@ -1296,7 +1296,7 @@ public abstract class KafkaConsumerTestBase extends KafkaTestBase { public static class BrokerKillingMapper<T> extends RichMapFunction<T,T> - implements Checkpointed<Integer>, CheckpointNotifier { + implements Checkpointed<Integer>, CheckpointListener { private static final long serialVersionUID = 6334389850158707313L; 
http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/FailingIdentityMapper.java ---------------------------------------------------------------------- diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/FailingIdentityMapper.java b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/FailingIdentityMapper.java index 5a8ffaa..2bd400c 100644 --- a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/FailingIdentityMapper.java +++ b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/FailingIdentityMapper.java @@ -20,14 +20,14 @@ package org.apache.flink.streaming.connectors.kafka.testutils; import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.configuration.Configuration; -import org.apache.flink.streaming.api.checkpoint.CheckpointNotifier; +import org.apache.flink.runtime.state.CheckpointListener; import org.apache.flink.streaming.api.checkpoint.Checkpointed; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class FailingIdentityMapper<T> extends RichMapFunction<T,T> implements - Checkpointed<Integer>, CheckpointNotifier, Runnable { + Checkpointed<Integer>, CheckpointListener, Runnable { private static final Logger LOG = LoggerFactory.getLogger(FailingIdentityMapper.class); http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java ---------------------------------------------------------------------- diff --git 
a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java index 50c57ab..ee246bb 100644 --- a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java +++ b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java @@ -26,7 +26,9 @@ import org.apache.flink.api.common.accumulators.IntCounter; import org.apache.flink.api.common.accumulators.LongCounter; import org.apache.flink.api.common.cache.DistributedCache; import org.apache.flink.api.common.functions.BroadcastVariableInitializer; -import org.apache.flink.api.common.state.OperatorState; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.operators.testutils.MockEnvironment; @@ -146,12 +148,17 @@ public class MockRuntimeContext extends StreamingRuntimeContext { } @Override - public <S> OperatorState<S> getKeyValueState(String name, Class<S> stateType, S defaultState) { + public <S> ValueState<S> getKeyValueState(String name, Class<S> stateType, S defaultState) { throw new UnsupportedOperationException(); } @Override - public <S> OperatorState<S> getKeyValueState(String name, TypeInformation<S> stateType, S defaultState) { + public <S> ValueState<S> getKeyValueState(String name, TypeInformation<S> stateType, S defaultState) { + throw new UnsupportedOperationException(); + } + + @Override + public <S extends State> S getPartitionedState(StateDescriptor<S> 
stateDescriptor) { throw new UnsupportedOperationException(); } } http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointNotifier.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointNotifier.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointNotifier.java deleted file mode 100644 index c2d2182..0000000 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointNotifier.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.streaming.api.checkpoint; - -/** - * This interface must be implemented by functions/operations that want to receive - * a commit notification once a checkpoint has been completely acknowledged by all - * participants. - */ -public interface CheckpointNotifier { - - /** - * This method is called as a notification once a distributed checkpoint has been completed. 
- * - * Note that any exception during this method will not cause the checkpoint to - * fail any more. - * - * @param checkpointId The ID of the checkpoint that has been completed. - * @throws Exception - */ - void notifyCheckpointComplete(long checkpointId) throws Exception; -} http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/ConnectedStreams.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/ConnectedStreams.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/ConnectedStreams.java index 4074a1d..395b329 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/ConnectedStreams.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/ConnectedStreams.java @@ -321,6 +321,21 @@ public class ConnectedStreams<IN1, IN2> { outTypeInfo, environment.getParallelism()); + if (inputStream1 instanceof KeyedStream && inputStream2 instanceof KeyedStream) { + KeyedStream<IN1, ?> keyedInput1 = (KeyedStream<IN1, ?>) inputStream1; + KeyedStream<IN2, ?> keyedInput2 = (KeyedStream<IN2, ?>) inputStream2; + + TypeInformation<?> keyType1 = keyedInput1.getKeyType(); + TypeInformation<?> keyType2 = keyedInput2.getKeyType(); + if (!(keyType1.canEqual(keyType2) && keyType1.equals(keyType2))) { + throw new UnsupportedOperationException("Key types if input KeyedStreams " + + "don't match: " + keyType1 + " and " + keyType2 + "."); + } + + transform.setStateKeySelectors(keyedInput1.getKeySelector(), keyedInput2.getKeySelector()); + transform.setStateKeyType(keyType1); + } + @SuppressWarnings({ "unchecked", "rawtypes" }) SingleOutputStreamOperator<OUT, ?> returnStream = new SingleOutputStreamOperator(environment, transform); 
http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java index cb5fce5..f4b3184 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java @@ -26,6 +26,7 @@ import org.apache.flink.api.common.JobExecutionResult; import org.apache.flink.api.common.functions.InvalidTypesException; import org.apache.flink.api.common.io.FileInputFormat; import org.apache.flink.api.common.io.InputFormat; +import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.ClosureCleaner; @@ -62,7 +63,7 @@ import org.apache.flink.streaming.api.functions.source.StatefulSequenceSource; import org.apache.flink.streaming.api.graph.StreamGraph; import org.apache.flink.streaming.api.graph.StreamGraphGenerator; import org.apache.flink.streaming.api.operators.StreamSource; -import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.streaming.api.transformations.StreamTransformation; import org.apache.flink.types.StringValue; import org.apache.flink.util.SplittableIterator; @@ -124,7 +125,7 @@ public abstract class StreamExecutionEnvironment { protected boolean isChainingEnabled = true; /** The state backend used for storing k/v state and state snapshots */ - private StateBackend<?> 
defaultStateBackend; + private AbstractStateBackend defaultStateBackend; /** The time characteristic used by the data streams */ private TimeCharacteristic timeCharacteristic = DEFAULT_TIME_CHARACTERISTIC; @@ -376,7 +377,7 @@ public abstract class StreamExecutionEnvironment { /** * Sets the state backend that describes how to store and checkpoint operator state. It defines in - * what form the key/value state ({@link org.apache.flink.api.common.state.OperatorState}, accessible + * what form the key/value state ({@link ValueState}, accessible * from operations on {@link org.apache.flink.streaming.api.datastream.KeyedStream}) is maintained * (heap, managed memory, externally), and where state snapshots/checkpoints are stored, both for * the key/value state, and for checkpointed functions (implementing the interface @@ -396,7 +397,7 @@ public abstract class StreamExecutionEnvironment { * * @see #getStateBackend() */ - public StreamExecutionEnvironment setStateBackend(StateBackend<?> backend) { + public StreamExecutionEnvironment setStateBackend(AbstractStateBackend backend) { this.defaultStateBackend = requireNonNull(backend); return this; } @@ -405,9 +406,9 @@ public abstract class StreamExecutionEnvironment { * Returns the state backend that defines how to store and checkpoint state. * @return The state backend that defines how to store and checkpoint state. 
* - * @see #setStateBackend(StateBackend) + * @see #setStateBackend(AbstractStateBackend) */ - public StateBackend<?> getStateBackend() { + public AbstractStateBackend getStateBackend() { return defaultStateBackend; } http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MessageAcknowledgingSourceBase.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MessageAcknowledgingSourceBase.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MessageAcknowledgingSourceBase.java index 4385884..e7da5f8 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MessageAcknowledgingSourceBase.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MessageAcknowledgingSourceBase.java @@ -31,7 +31,7 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.configuration.Configuration; -import org.apache.flink.streaming.api.checkpoint.CheckpointNotifier; +import org.apache.flink.runtime.state.CheckpointListener; import org.apache.flink.streaming.api.checkpoint.Checkpointed; import org.apache.flink.runtime.state.SerializedCheckpointData; import org.slf4j.Logger; @@ -78,7 +78,7 @@ import org.slf4j.LoggerFactory; */ public abstract class MessageAcknowledgingSourceBase<Type, UId> extends RichSourceFunction<Type> - implements Checkpointed<SerializedCheckpointData[]>, CheckpointNotifier { + implements Checkpointed<SerializedCheckpointData[]>, CheckpointListener { private static final long serialVersionUID = -8689291992192955579L; 
http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java index 11bf84f..7a07c79 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java @@ -31,7 +31,7 @@ import org.apache.flink.runtime.util.ClassLoaderUtil; import org.apache.flink.streaming.api.CheckpointingMode; import org.apache.flink.streaming.api.collector.selector.OutputSelectorWrapper; import org.apache.flink.streaming.api.operators.StreamOperator; -import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.streaming.runtime.tasks.StreamTaskException; import org.apache.flink.util.InstantiationUtil; @@ -370,7 +370,7 @@ public class StreamConfig implements Serializable { // State backend // ------------------------------------------------------------------------ - public void setStateBackend(StateBackend<?> backend) { + public void setStateBackend(AbstractStateBackend backend) { try { InstantiationUtil.writeObjectToConfig(backend, this.config, STATE_BACKEND); } catch (Exception e) { @@ -378,7 +378,7 @@ public class StreamConfig implements Serializable { } } - public StateBackend<?> getStateBackend(ClassLoader cl) { + public AbstractStateBackend getStateBackend(ClassLoader cl) { try { return InstantiationUtil.readObjectFromConfig(this.config, STATE_BACKEND, cl); } catch (Exception e) { @@ -386,17 +386,17 @@ public class StreamConfig implements Serializable { } } - public void setStatePartitioner(KeySelector<?, ?> partitioner) { + public void 
setStatePartitioner(int input, KeySelector<?, ?> partitioner) { try { - InstantiationUtil.writeObjectToConfig(partitioner, this.config, STATE_PARTITIONER); + InstantiationUtil.writeObjectToConfig(partitioner, this.config, STATE_PARTITIONER + input); } catch (IOException e) { throw new StreamTaskException("Could not serialize state partitioner.", e); } } - public KeySelector<?, Serializable> getStatePartitioner(ClassLoader cl) { + public KeySelector<?, Serializable> getStatePartitioner(int input, ClassLoader cl) { try { - return InstantiationUtil.readObjectFromConfig(this.config, STATE_PARTITIONER, cl); + return InstantiationUtil.readObjectFromConfig(this.config, STATE_PARTITIONER + input, cl); } catch (Exception e) { throw new StreamTaskException("Could not instantiate state partitioner.", e); } http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java index fa8c9d4..ea85f05 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java @@ -47,7 +47,7 @@ import org.apache.flink.streaming.api.operators.OutputTypeConfigurable; import org.apache.flink.streaming.api.operators.StreamOperator; import org.apache.flink.streaming.api.operators.StreamSource; import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; -import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner; import 
org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; @@ -85,7 +85,7 @@ public class StreamGraph extends StreamingPlan { protected Map<Integer, String> vertexIDtoBrokerID; protected Map<Integer, Long> vertexIDtoLoopTimeout; - private StateBackend<?> stateBackend; + private AbstractStateBackend stateBackend; private Set<Tuple2<StreamNode, StreamNode>> iterationSourceSinkPairs; @@ -132,11 +132,11 @@ public class StreamGraph extends StreamingPlan { this.chaining = chaining; } - public void setStateBackend(StateBackend<?> backend) { + public void setStateBackend(AbstractStateBackend backend) { this.stateBackend = backend; } - public StateBackend<?> getStateBackend() { + public AbstractStateBackend getStateBackend() { return this.stateBackend; } @@ -363,9 +363,16 @@ public class StreamGraph extends StreamingPlan { } } - public void setKey(Integer vertexID, KeySelector<?, ?> keySelector, TypeSerializer<?> keySerializer) { + public void setOneInputStateKey(Integer vertexID, KeySelector<?, ?> keySelector, TypeSerializer<?> keySerializer) { StreamNode node = getStreamNode(vertexID); - node.setStatePartitioner(keySelector); + node.setStatePartitioner1(keySelector); + node.setStateKeySerializer(keySerializer); + } + + public void setTwoInputStateKey(Integer vertexID, KeySelector<?, ?> keySelector1, KeySelector<?, ?> keySelector2, TypeSerializer<?> keySerializer) { + StreamNode node = getStreamNode(vertexID); + node.setStatePartitioner1(keySelector1); + node.setStatePartitioner2(keySelector2); node.setStateKeySerializer(keySerializer); } http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java index 91c5e0f..f200bed 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java @@ -439,7 +439,7 @@ public class StreamGraphGenerator { if (sink.getStateKeySelector() != null) { TypeSerializer<?> keySerializer = sink.getStateKeyType().createSerializer(env.getConfig()); - streamGraph.setKey(sink.getId(), sink.getStateKeySelector(), keySerializer); + streamGraph.setOneInputStateKey(sink.getId(), sink.getStateKeySelector(), keySerializer); } return Collections.emptyList(); @@ -469,10 +469,7 @@ public class StreamGraphGenerator { if (transform.getStateKeySelector() != null) { TypeSerializer<?> keySerializer = transform.getStateKeyType().createSerializer(env.getConfig()); - streamGraph.setKey(transform.getId(), transform.getStateKeySelector(), keySerializer); - } - if (transform.getStateKeyType() != null) { - + streamGraph.setOneInputStateKey(transform.getId(), transform.getStateKeySelector(), keySerializer); } streamGraph.setParallelism(transform.getId(), transform.getParallelism()); @@ -509,6 +506,12 @@ public class StreamGraphGenerator { transform.getOutputType(), transform.getName()); + if (transform.getStateKeySelector1() != null) { + TypeSerializer<?> keySerializer = transform.getStateKeyType().createSerializer(env.getConfig()); + streamGraph.setTwoInputStateKey(transform.getId(), transform.getStateKeySelector1(), transform.getStateKeySelector2(), keySerializer); + } + + streamGraph.setParallelism(transform.getId(), transform.getParallelism()); for (Integer inputId: inputIds1) { http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java ---------------------------------------------------------------------- diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java index 77b7cb4..0a612f3 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java @@ -49,7 +49,8 @@ public class StreamNode implements Serializable { private String operatorName; private Integer slotSharingID; private boolean isolatedSlot = false; - private KeySelector<?,?> statePartitioner; + private KeySelector<?,?> statePartitioner1; + private KeySelector<?,?> statePartitioner2; private TypeSerializer<?> stateKeySerializer; private transient StreamOperator<?> operator; @@ -228,12 +229,20 @@ public class StreamNode implements Serializable { return operatorName + "-" + id; } - public KeySelector<?, ?> getStatePartitioner() { - return statePartitioner; + public KeySelector<?, ?> getStatePartitioner1() { + return statePartitioner1; } - public void setStatePartitioner(KeySelector<?, ?> statePartitioner) { - this.statePartitioner = statePartitioner; + public KeySelector<?, ?> getStatePartitioner2() { + return statePartitioner2; + } + + public void setStatePartitioner1(KeySelector<?, ?> statePartitioner) { + this.statePartitioner1 = statePartitioner; + } + + public void setStatePartitioner2(KeySelector<?, ?> statePartitioner) { + this.statePartitioner2 = statePartitioner; } public TypeSerializer<?> getStateKeySerializer() { http://git-wip-us.apache.org/repos/asf/flink/blob/caf46728/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java index 56b16a4..5da2caa 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java @@ -327,7 +327,8 @@ public class StreamingJobGraphGenerator { // so we use that one if checkpointing is not enabled config.setCheckpointMode(CheckpointingMode.AT_LEAST_ONCE); } - config.setStatePartitioner(vertex.getStatePartitioner()); + config.setStatePartitioner(0, vertex.getStatePartitioner1()); + config.setStatePartitioner(1, vertex.getStatePartitioner2()); config.setStateKeySerializer(vertex.getStateKeySerializer()); Class<? extends AbstractInvokable> vertexClass = vertex.getJobVertexClass();
