[FLINK-4603] [checkpoints] Fix user code classloading in KeyedStateBackend This closes #2533
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/3b8fe95e Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/3b8fe95e Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/3b8fe95e Branch: refs/heads/master Commit: 3b8fe95ec728d59e3ffba2901450c56d7cca2b24 Parents: e6fbda9 Author: Stefan Richter <s.rich...@data-artisans.com> Authored: Wed Sep 21 14:55:58 2016 +0200 Committer: Stephan Ewen <se...@apache.org> Committed: Thu Sep 22 14:42:12 2016 +0200 ---------------------------------------------------------------------- .../state/RocksDBKeyedStateBackend.java | 19 +- .../streaming/state/RocksDBStateBackend.java | 2 + .../apache/flink/util/InstantiationUtil.java | 6 +- .../flink/runtime/state/KeyedStateBackend.java | 4 + .../state/filesystem/FsStateBackend.java | 2 + .../state/heap/HeapKeyedStateBackend.java | 31 +-- .../state/memory/MemoryStateBackend.java | 5 +- .../streaming/runtime/tasks/StreamTask.java | 8 +- flink-tests/pom.xml | 19 ++ ...t-checkpointing-custom_kv_state-assembly.xml | 38 +++ .../test/classloading/ClassLoaderITCase.java | 25 +- .../jar/CheckpointingCustomKvStateProgram.java | 233 +++++++++++++++++++ 12 files changed, 363 insertions(+), 29 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java ---------------------------------------------------------------------- diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java index 177c09f..d5a96af 100644 --- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java +++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java @@ -47,6 +47,7 @@ import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.util.SerializableObject; +import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.Preconditions; import org.rocksdb.ColumnFamilyDescriptor; import org.rocksdb.ColumnFamilyHandle; @@ -63,8 +64,6 @@ import org.slf4j.LoggerFactory; import javax.annotation.concurrent.GuardedBy; import java.io.File; import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; @@ -125,6 +124,7 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> { public RocksDBKeyedStateBackend( JobID jobId, String operatorIdentifier, + ClassLoader userCodeClassLoader, File instanceBasePath, DBOptions dbOptions, ColumnFamilyOptions columnFamilyOptions, @@ -134,7 +134,7 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> { KeyGroupRange keyGroupRange ) throws Exception { - super(kvStateRegistry, keySerializer, numberOfKeyGroups, keyGroupRange); + super(kvStateRegistry, keySerializer, userCodeClassLoader, numberOfKeyGroups, keyGroupRange); this.operatorIdentifier = operatorIdentifier; this.jobId = jobId; @@ -177,6 +177,7 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> { public RocksDBKeyedStateBackend( JobID jobId, String operatorIdentifier, + ClassLoader userCodeClassLoader, File instanceBasePath, DBOptions dbOptions, ColumnFamilyOptions columnFamilyOptions, @@ -189,6 +190,7 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> { this( jobId, operatorIdentifier, + userCodeClassLoader, instanceBasePath, dbOptions, columnFamilyOptions, @@ -455,8 +457,8 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> { checkInterrupted(); //write StateDescriptor for this k/v state - ObjectOutputStream ooOut = new ObjectOutputStream(outStream); - ooOut.writeObject(column.getValue().f1); + InstantiationUtil.serializeObject(outStream, column.getValue().f1); + //retrieve iterator for this k/v states ReadOptions readOptions = new ReadOptions(); readOptions.setSnapshot(snapshot); @@ -649,8 +651,11 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> { //restore the empty columns for the k/v states through the metadata for (int i = 0; i < numColumns; i++) { - ObjectInputStream ooIn = new ObjectInputStream(currentStateHandleInStream); - StateDescriptor stateDescriptor = (StateDescriptor) ooIn.readObject(); + + StateDescriptor stateDescriptor = InstantiationUtil.deserializeObject( + currentStateHandleInStream, + rocksDBKeyedStateBackend.userCodeClassLoader); + Tuple2<ColumnFamilyHandle, StateDescriptor> columnFamily = rocksDBKeyedStateBackend. kvStateInformation.get(stateDescriptor.getName()); http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java ---------------------------------------------------------------------- diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java index 0fdbd5f..b6ce224 100644 --- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java +++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java @@ -240,6 +240,7 @@ public class RocksDBStateBackend extends AbstractStateBackend { return new RocksDBKeyedStateBackend<>( jobID, operatorIdentifier, + env.getUserClassLoader(), instanceBasePath, getDbOptions(), getColumnOptions(), @@ -264,6 +265,7 @@ public class RocksDBStateBackend extends AbstractStateBackend { return new RocksDBKeyedStateBackend<>( jobID, operatorIdentifier, + env.getUserClassLoader(), instanceBasePath, getDbOptions(), getColumnOptions(), http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java ---------------------------------------------------------------------- diff --git a/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java b/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java index b1dddae..de4cffb 100644 --- a/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java +++ b/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java @@ -299,7 +299,10 @@ public final class InstantiationUtil { @SuppressWarnings("unchecked") public static <T> T deserializeObject(InputStream in, ClassLoader cl) throws IOException, ClassNotFoundException { final ClassLoader old = Thread.currentThread().getContextClassLoader(); - try (ObjectInputStream oois = new ClassLoaderObjectInputStream(in, cl)) { + ObjectInputStream oois; + // not using resource try to avoid AutoClosable's close() on the given stream + try { + oois = new ClassLoaderObjectInputStream(in, cl); Thread.currentThread().setContextClassLoader(cl); return (T) oois.readObject(); } @@ -332,7 +335,6 @@ public final class InstantiationUtil { public static void serializeObject(OutputStream out, Object o) throws IOException { ObjectOutputStream oos = new ObjectOutputStream(out); oos.writeObject(o); - oos.flush(); } /** http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java index bf9018e..8db63ee 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java @@ -77,14 +77,18 @@ public abstract class KeyedStateBackend<K> { /** KvStateRegistry helper for this task */ protected final TaskKvStateRegistry kvStateRegistry; + protected final ClassLoader userCodeClassLoader; + public KeyedStateBackend( TaskKvStateRegistry kvStateRegistry, TypeSerializer<K> keySerializer, + ClassLoader userCodeClassLoader, int numberOfKeyGroups, KeyGroupRange keyGroupRange) { this.kvStateRegistry = Preconditions.checkNotNull(kvStateRegistry); this.keySerializer = Preconditions.checkNotNull(keySerializer); + this.userCodeClassLoader = Preconditions.checkNotNull(userCodeClassLoader); this.numberOfKeyGroups = Preconditions.checkNotNull(numberOfKeyGroups); this.keyGroupRange = Preconditions.checkNotNull(keyGroupRange); } http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java index 6d92a4d..99e3684 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java @@ -186,6 +186,7 @@ public class FsStateBackend extends AbstractStateBackend { return new HeapKeyedStateBackend<>( kvStateRegistry, keySerializer, + env.getUserClassLoader(), numberOfKeyGroups, keyGroupRange); } @@ -203,6 +204,7 @@ public class FsStateBackend extends AbstractStateBackend { return new HeapKeyedStateBackend<>( kvStateRegistry, keySerializer, + env.getUserClassLoader(), numberOfKeyGroups, keyGroupRange, restoredState); http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java index 8d13941..c13be70 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java @@ -39,12 +39,11 @@ import org.apache.flink.runtime.state.KeyGroupRangeOffsets; import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -75,20 +74,23 @@ public class HeapKeyedStateBackend<K> extends KeyedStateBackend<K> { public HeapKeyedStateBackend( TaskKvStateRegistry kvStateRegistry, TypeSerializer<K> keySerializer, + ClassLoader userCodeClassLoader, int numberOfKeyGroups, KeyGroupRange keyGroupRange) { - super(kvStateRegistry, keySerializer, numberOfKeyGroups, keyGroupRange); + super(kvStateRegistry, keySerializer, userCodeClassLoader, numberOfKeyGroups, keyGroupRange); LOG.info("Initializing heap keyed state backend with stream factory."); } - public HeapKeyedStateBackend(TaskKvStateRegistry kvStateRegistry, + public HeapKeyedStateBackend( + TaskKvStateRegistry kvStateRegistry, TypeSerializer<K> keySerializer, + ClassLoader userCodeClassLoader, int numberOfKeyGroups, KeyGroupRange keyGroupRange, List<KeyGroupsStateHandle> restoredState) throws Exception { - super(kvStateRegistry, keySerializer, numberOfKeyGroups, keyGroupRange); + super(kvStateRegistry, keySerializer, userCodeClassLoader, numberOfKeyGroups, keyGroupRange); LOG.info("Initializing heap keyed state backend from snapshot."); @@ -135,7 +137,6 @@ public class HeapKeyedStateBackend<K> extends KeyedStateBackend<K> { @SuppressWarnings("unchecked,rawtypes") StateTable<K, N, T> stateTable = (StateTable) stateTables.get(stateDesc.getName()); - if (stateTable == null) { stateTable = new StateTable<>(stateDesc.getSerializer(), namespaceSerializer, keyGroupRange); stateTables.put(stateDesc.getName(), stateTable); @@ -190,10 +191,8 @@ public class HeapKeyedStateBackend<K> extends KeyedStateBackend<K> { TypeSerializer namespaceSerializer = kvState.getValue().getNamespaceSerializer(); TypeSerializer stateSerializer = kvState.getValue().getStateSerializer(); - ObjectOutputStream oos = new ObjectOutputStream(outView); - oos.writeObject(namespaceSerializer); - oos.writeObject(stateSerializer); - oos.flush(); + InstantiationUtil.serializeObject(stream, namespaceSerializer); + InstantiationUtil.serializeObject(stream, stateSerializer); kVStateToId.put(kvState.getKey(), kVStateToId.size()); } @@ -266,18 +265,20 @@ public class HeapKeyedStateBackend<K> extends KeyedStateBackend<K> { for (int i = 0; i < numKvStates; ++i) { String stateName = inView.readUTF(); - ObjectInputStream ois = new ObjectInputStream(inView); + TypeSerializer namespaceSerializer = + InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader); + TypeSerializer stateSerializer = + InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader); - TypeSerializer namespaceSerializer = (TypeSerializer) ois.readObject(); - TypeSerializer stateSerializer = (TypeSerializer) ois.readObject(); - StateTable<K, ?, ?> stateTable = new StateTable(stateSerializer, + StateTable<K, ?, ?> stateTable = new StateTable( + stateSerializer, namespaceSerializer, keyGroupRange); stateTables.put(stateName, stateTable); kvStatesById.put(i, stateName); } - for (int keyGroupIndex = keyGroupRange.getStartKeyGroup(); keyGroupIndex <= keyGroupRange.getEndKeyGroup(); keyGroupIndex++) { + for (int keyGroupIndex = keyGroupRange.getStartKeyGroup(); keyGroupIndex <= keyGroupRange.getEndKeyGroup(); ++keyGroupIndex) { long offset = keyGroupsHandle.getOffsetForKeyGroup(keyGroupIndex); fsDataInputStream.seek(offset); http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java index 179dfe7..cc145ff 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java @@ -78,7 +78,8 @@ public class MemoryStateBackend extends AbstractStateBackend { @Override public <K> KeyedStateBackend<K> createKeyedStateBackend( Environment env, JobID jobID, - String operatorIdentifier, TypeSerializer<K> keySerializer, + String operatorIdentifier, + TypeSerializer<K> keySerializer, int numberOfKeyGroups, KeyGroupRange keyGroupRange, TaskKvStateRegistry kvStateRegistry) throws IOException { @@ -86,6 +87,7 @@ public class MemoryStateBackend extends AbstractStateBackend { return new HeapKeyedStateBackend<>( kvStateRegistry, keySerializer, + env.getUserClassLoader(), numberOfKeyGroups, keyGroupRange); } @@ -103,6 +105,7 @@ public class MemoryStateBackend extends AbstractStateBackend { return new HeapKeyedStateBackend<>( kvStateRegistry, keySerializer, + env.getUserClassLoader(), numberOfKeyGroups, keyGroupRange, restoredState); http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index 9c26509..d4638a4 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -23,6 +23,7 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.IllegalConfigurationException; +import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.metrics.Gauge; import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; @@ -585,7 +586,12 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> if (operator != null) { LOG.debug("Restore state of task {} in chain ({}).", i, getName()); - operator.restoreState(state.openInputStream()); + FSDataInputStream inputStream = state.openInputStream(); + try { + operator.restoreState(inputStream); + } finally { + inputStream.close(); + } } } } http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-tests/pom.xml ---------------------------------------------------------------------- diff --git a/flink-tests/pom.xml b/flink-tests/pom.xml index b09db1f..efc95ab 100644 --- a/flink-tests/pom.xml +++ b/flink-tests/pom.xml @@ -485,6 +485,25 @@ under the License. </descriptors> </configuration> </execution> + <execution> + <id>create-checkpointing_custom_kv_state-jar</id> + <phase>process-test-classes</phase> + <goals> + <goal>single</goal> + </goals> + <configuration> + <archive> + <manifest> + <mainClass>org.apache.flink.test.classloading.jar.CheckpointingCustomKvStateProgram</mainClass> + </manifest> + </archive> + <finalName>checkpointing_custom_kv_state</finalName> + <attach>false</attach> + <descriptors> + <descriptor>src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml</descriptor> + </descriptors> + </configuration> + </execution> </executions> </plugin> http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-tests/src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml b/flink-tests/src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml new file mode 100644 index 0000000..fdebfdd --- /dev/null +++ b/flink-tests/src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml @@ -0,0 +1,38 @@ +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +--> + +<assembly> + <id>test-jar</id> + <formats> + <format>jar</format> + </formats> + <includeBaseDirectory>false</includeBaseDirectory> + <fileSets> + <fileSet> + <directory>${project.build.testOutputDirectory}</directory> + <outputDirectory>/</outputDirectory> + <!--modify/add include to match your package(s) --> + <includes> + <include>org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.class</include> + <include>org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram*.class</include> + </includes> + </fileSet> + </fileSets> +</assembly> http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java b/flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java index 7afafe4..65da33f 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java @@ -39,6 +39,7 @@ import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory; import org.apache.flink.runtime.testingUtils.TestingCluster; import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.WaitForAllVerticesToBeRunning; import org.apache.flink.test.testdata.KMeansData; +import org.apache.flink.test.util.SuccessException; import org.apache.flink.util.TestLogger; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -46,7 +47,6 @@ import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import scala.Option; import scala.concurrent.Await; import scala.concurrent.Future; import scala.concurrent.duration.Deadline; @@ -79,6 +79,8 @@ public class ClassLoaderITCase extends TestLogger { private static final String CUSTOM_KV_STATE_JAR_PATH = "custom_kv_state-test-jar.jar"; + private static final String CHECKPOINTING_CUSTOM_KV_STATE_JAR_PATH = "checkpointing_custom_kv_state-test-jar.jar"; + public static final TemporaryFolder FOLDER = new TemporaryFolder(); private static TestingCluster testCluster; @@ -199,9 +201,26 @@ public class ClassLoaderITCase extends TestLogger { }); userCodeTypeProg.invokeInteractiveModeForExecution(); + + File checkpointDir = FOLDER.newFolder(); + File outputDir = FOLDER.newFolder(); + + final PackagedProgram program = new PackagedProgram( + new File(CHECKPOINTING_CUSTOM_KV_STATE_JAR_PATH), + new String[] { + CHECKPOINTING_CUSTOM_KV_STATE_JAR_PATH, + "localhost", + String.valueOf(port), + checkpointDir.toURI().toString(), + outputDir.toURI().toString() + }); + + program.invokeInteractiveModeForExecution(); + } catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); + if (!(e.getCause().getCause() instanceof SuccessException)) { + fail(e.getMessage()); + } } } http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.java b/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.java new file mode 100644 index 0000000..6796cb0 --- /dev/null +++ b/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.test.classloading.jar; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.state.ReducingState; +import org.apache.flink.api.common.state.ReducingStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.fs.FileSystem; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.runtime.state.CheckpointListener; +import org.apache.flink.runtime.state.filesystem.FsStateBackend; +import org.apache.flink.streaming.api.checkpoint.Checkpointed; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction; +import org.apache.flink.test.util.SuccessException; +import org.apache.flink.util.Collector; + +import java.io.IOException; +import java.util.concurrent.ThreadLocalRandom; + +public class CheckpointingCustomKvStateProgram { + + public static void main(String[] args) throws Exception { + final String jarFile = args[0]; + final String host = args[1]; + final int port = Integer.parseInt(args[2]); + final String checkpointPath = args[3]; + final String outputPath = args[4]; + final int parallelism = 1; + + StreamExecutionEnvironment env = StreamExecutionEnvironment.createRemoteEnvironment(host, port, jarFile); + + env.setParallelism(parallelism); + env.getConfig().disableSysoutLogging(); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 1000)); + env.setStateBackend(new FsStateBackend(checkpointPath)); + + DataStream<Integer> source = env.addSource(new InfiniteIntegerSource()); + source + .map(new MapFunction<Integer, Tuple2<Integer, Integer>>() { + private static final long serialVersionUID = 1L; + + @Override + public Tuple2<Integer, Integer> map(Integer value) throws Exception { + return new Tuple2<>(ThreadLocalRandom.current().nextInt(parallelism), value); + } + }) + .keyBy(new KeySelector<Tuple2<Integer,Integer>, Integer>() { + private static final long serialVersionUID = 1L; + + @Override + public Integer getKey(Tuple2<Integer, Integer> value) throws Exception { + return value.f0; + } + }).flatMap(new ReducingStateFlatMap()).writeAsText(outputPath, FileSystem.WriteMode.OVERWRITE); + + env.execute(); + } + + private static class InfiniteIntegerSource implements ParallelSourceFunction<Integer>, Checkpointed<Integer> { + private static final long serialVersionUID = -7517574288730066280L; + private volatile boolean running = true; + + @Override + public void run(SourceContext<Integer> ctx) throws Exception { + int counter = 0; + while (running) { + synchronized (ctx.getCheckpointLock()) { + ctx.collect(counter++); + } + } + } + + @Override + public void cancel() { + running = false; + } + + @Override + public Integer snapshotState(long checkpointId, long checkpointTimestamp) throws Exception { + return 0; + } + + @Override + public void restoreState(Integer state) throws Exception { + + } + } + + private static class ReducingStateFlatMap extends RichFlatMapFunction<Tuple2<Integer, Integer>, Integer> implements Checkpointed<ReducingStateFlatMap>, CheckpointListener { + + private static final long serialVersionUID = -5939722892793950253L; + private transient ReducingState<Integer> kvState; + + private boolean atLeastOneSnapshotComplete = false; + private boolean restored = false; + + @Override + public void open(Configuration parameters) throws Exception { + ReducingStateDescriptor<Integer> stateDescriptor = + new ReducingStateDescriptor<>( + "reducing-state", + new ReduceSum(), + CustomIntSerializer.INSTANCE); + + this.kvState = getRuntimeContext().getReducingState(stateDescriptor); + } + + + @Override + public void flatMap(Tuple2<Integer, Integer> value, Collector<Integer> out) throws Exception { + kvState.add(value.f1); + + if(atLeastOneSnapshotComplete) { + if (restored) { + throw new SuccessException(); + } else { + throw new RuntimeException("Intended failure, to trigger restore"); + } + } + } + + @Override + public ReducingStateFlatMap snapshotState(long checkpointId, long checkpointTimestamp) throws Exception { + return this; + } + + @Override + public void restoreState(ReducingStateFlatMap state) throws Exception { + restored = true; + atLeastOneSnapshotComplete = true; + } + + @Override + public void notifyCheckpointComplete(long checkpointId) throws Exception { + atLeastOneSnapshotComplete = true; + } + + private static class ReduceSum implements ReduceFunction<Integer> { + private static final long serialVersionUID = 1L; + + @Override + public Integer reduce(Integer value1, Integer value2) throws Exception { + return value1 + value2; + } + } + } + + private static final class CustomIntSerializer extends TypeSerializerSingleton<Integer> { + + private static final long serialVersionUID = 4572452915892737448L; + + public static final TypeSerializer<Integer> INSTANCE = new CustomIntSerializer(); + + @Override + public boolean isImmutableType() { + return true; + } + + @Override + public Integer createInstance() { + return 0; + } + + @Override + public Integer copy(Integer from) { + return from; + } + + @Override + public Integer copy(Integer from, Integer reuse) { + return from; + } + + @Override + public int getLength() { + return 4; + } + + @Override + public void serialize(Integer record, DataOutputView target) throws IOException { + target.writeInt(record.intValue()); + } + + @Override + public Integer deserialize(DataInputView source) throws IOException { + return Integer.valueOf(source.readInt()); + } + + @Override + public Integer deserialize(Integer reuse, DataInputView source) throws IOException { + return Integer.valueOf(source.readInt()); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + target.writeInt(source.readInt()); + } + + @Override + public boolean canEqual(Object obj) { + return obj instanceof CustomIntSerializer; + } + + } +}