[FLINK-4603] [checkpoints] Fix user code classloading in KeyedStateBackend

This closes #2533


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/3b8fe95e
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/3b8fe95e
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/3b8fe95e

Branch: refs/heads/master
Commit: 3b8fe95ec728d59e3ffba2901450c56d7cca2b24
Parents: e6fbda9
Author: Stefan Richter <s.rich...@data-artisans.com>
Authored: Wed Sep 21 14:55:58 2016 +0200
Committer: Stephan Ewen <se...@apache.org>
Committed: Thu Sep 22 14:42:12 2016 +0200

----------------------------------------------------------------------
 .../state/RocksDBKeyedStateBackend.java         |  19 +-
 .../streaming/state/RocksDBStateBackend.java    |   2 +
 .../apache/flink/util/InstantiationUtil.java    |   6 +-
 .../flink/runtime/state/KeyedStateBackend.java  |   4 +
 .../state/filesystem/FsStateBackend.java        |   2 +
 .../state/heap/HeapKeyedStateBackend.java       |  31 +--
 .../state/memory/MemoryStateBackend.java        |   5 +-
 .../streaming/runtime/tasks/StreamTask.java     |   8 +-
 flink-tests/pom.xml                             |  19 ++
 ...t-checkpointing-custom_kv_state-assembly.xml |  38 +++
 .../test/classloading/ClassLoaderITCase.java    |  25 +-
 .../jar/CheckpointingCustomKvStateProgram.java  | 233 +++++++++++++++++++
 12 files changed, 363 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
----------------------------------------------------------------------
diff --git 
a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
 
b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 177c09f..d5a96af 100644
--- 
a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ 
b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -47,6 +47,7 @@ import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.util.SerializableObject;
+import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
 import org.rocksdb.ColumnFamilyDescriptor;
 import org.rocksdb.ColumnFamilyHandle;
@@ -63,8 +64,6 @@ import org.slf4j.LoggerFactory;
 import javax.annotation.concurrent.GuardedBy;
 import java.io.File;
 import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
 import java.util.ArrayList;
 import java.util.Comparator;
 import java.util.HashMap;
@@ -125,6 +124,7 @@ public class RocksDBKeyedStateBackend<K> extends 
KeyedStateBackend<K> {
        public RocksDBKeyedStateBackend(
                        JobID jobId,
                        String operatorIdentifier,
+                       ClassLoader userCodeClassLoader,
                        File instanceBasePath,
                        DBOptions dbOptions,
                        ColumnFamilyOptions columnFamilyOptions,
@@ -134,7 +134,7 @@ public class RocksDBKeyedStateBackend<K> extends 
KeyedStateBackend<K> {
                        KeyGroupRange keyGroupRange
        ) throws Exception {
 
-               super(kvStateRegistry, keySerializer, numberOfKeyGroups, 
keyGroupRange);
+               super(kvStateRegistry, keySerializer, userCodeClassLoader, 
numberOfKeyGroups, keyGroupRange);
 
                this.operatorIdentifier = operatorIdentifier;
                this.jobId = jobId;
@@ -177,6 +177,7 @@ public class RocksDBKeyedStateBackend<K> extends 
KeyedStateBackend<K> {
        public RocksDBKeyedStateBackend(
                        JobID jobId,
                        String operatorIdentifier,
+                       ClassLoader userCodeClassLoader,
                        File instanceBasePath,
                        DBOptions dbOptions,
                        ColumnFamilyOptions columnFamilyOptions,
@@ -189,6 +190,7 @@ public class RocksDBKeyedStateBackend<K> extends 
KeyedStateBackend<K> {
                this(
                        jobId,
                        operatorIdentifier,
+                       userCodeClassLoader,
                        instanceBasePath,
                        dbOptions,
                        columnFamilyOptions,
@@ -455,8 +457,8 @@ public class RocksDBKeyedStateBackend<K> extends 
KeyedStateBackend<K> {
                                checkInterrupted();
 
                                //write StateDescriptor for this k/v state
-                               ObjectOutputStream ooOut = new 
ObjectOutputStream(outStream);
-                               ooOut.writeObject(column.getValue().f1);
+                               InstantiationUtil.serializeObject(outStream, 
column.getValue().f1);
+
                                //retrieve iterator for this k/v states
                                ReadOptions readOptions = new ReadOptions();
                                readOptions.setSnapshot(snapshot);
@@ -649,8 +651,11 @@ public class RocksDBKeyedStateBackend<K> extends 
KeyedStateBackend<K> {
 
                        //restore the empty columns for the k/v states through 
the metadata
                        for (int i = 0; i < numColumns; i++) {
-                               ObjectInputStream ooIn = new 
ObjectInputStream(currentStateHandleInStream);
-                               StateDescriptor stateDescriptor = 
(StateDescriptor) ooIn.readObject();
+
+                               StateDescriptor stateDescriptor = 
InstantiationUtil.deserializeObject(
+                                               currentStateHandleInStream,
+                                               
rocksDBKeyedStateBackend.userCodeClassLoader);
+
                                Tuple2<ColumnFamilyHandle, StateDescriptor> 
columnFamily = rocksDBKeyedStateBackend.
                                                
kvStateInformation.get(stateDescriptor.getName());
 

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
----------------------------------------------------------------------
diff --git 
a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
 
b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
index 0fdbd5f..b6ce224 100644
--- 
a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
+++ 
b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
@@ -240,6 +240,7 @@ public class RocksDBStateBackend extends 
AbstractStateBackend {
                return new RocksDBKeyedStateBackend<>(
                                jobID,
                                operatorIdentifier,
+                               env.getUserClassLoader(),
                                instanceBasePath,
                                getDbOptions(),
                                getColumnOptions(),
@@ -264,6 +265,7 @@ public class RocksDBStateBackend extends 
AbstractStateBackend {
                return new RocksDBKeyedStateBackend<>(
                                jobID,
                                operatorIdentifier,
+                               env.getUserClassLoader(),
                                instanceBasePath,
                                getDbOptions(),
                                getColumnOptions(),

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java
----------------------------------------------------------------------
diff --git 
a/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java 
b/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java
index b1dddae..de4cffb 100644
--- a/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java
+++ b/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java
@@ -299,7 +299,10 @@ public final class InstantiationUtil {
        @SuppressWarnings("unchecked")
        public static <T> T deserializeObject(InputStream in, ClassLoader cl) 
throws IOException, ClassNotFoundException {
                final ClassLoader old = 
Thread.currentThread().getContextClassLoader();
-               try (ObjectInputStream oois = new 
ClassLoaderObjectInputStream(in, cl)) {
+               ObjectInputStream oois;
+               // not using resource try to avoid AutoClosable's close() on 
the given stream
+               try {
+                       oois = new ClassLoaderObjectInputStream(in, cl);
                        Thread.currentThread().setContextClassLoader(cl);
                        return (T) oois.readObject();
                }
@@ -332,7 +335,6 @@ public final class InstantiationUtil {
        public static void serializeObject(OutputStream out, Object o) throws 
IOException {
                ObjectOutputStream oos = new ObjectOutputStream(out);
                oos.writeObject(o);
-               oos.flush();
        }
 
        /**

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
index bf9018e..8db63ee 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
@@ -77,14 +77,18 @@ public abstract class KeyedStateBackend<K> {
        /** KvStateRegistry helper for this task */
        protected final TaskKvStateRegistry kvStateRegistry;
 
+       protected final ClassLoader userCodeClassLoader;
+
        public KeyedStateBackend(
                        TaskKvStateRegistry kvStateRegistry,
                        TypeSerializer<K> keySerializer,
+                       ClassLoader userCodeClassLoader,
                        int numberOfKeyGroups,
                        KeyGroupRange keyGroupRange) {
 
                this.kvStateRegistry = 
Preconditions.checkNotNull(kvStateRegistry);
                this.keySerializer = Preconditions.checkNotNull(keySerializer);
+               this.userCodeClassLoader = 
Preconditions.checkNotNull(userCodeClassLoader);
                this.numberOfKeyGroups = 
Preconditions.checkNotNull(numberOfKeyGroups);
                this.keyGroupRange = Preconditions.checkNotNull(keyGroupRange);
        }

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
index 6d92a4d..99e3684 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
@@ -186,6 +186,7 @@ public class FsStateBackend extends AbstractStateBackend {
                return new HeapKeyedStateBackend<>(
                                kvStateRegistry,
                                keySerializer,
+                               env.getUserClassLoader(),
                                numberOfKeyGroups,
                                keyGroupRange);
        }
@@ -203,6 +204,7 @@ public class FsStateBackend extends AbstractStateBackend {
                return new HeapKeyedStateBackend<>(
                                kvStateRegistry,
                                keySerializer,
+                               env.getUserClassLoader(),
                                numberOfKeyGroups,
                                keyGroupRange,
                                restoredState);

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 8d13941..c13be70 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -39,12 +39,11 @@ import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
@@ -75,20 +74,23 @@ public class HeapKeyedStateBackend<K> extends 
KeyedStateBackend<K> {
        public HeapKeyedStateBackend(
                        TaskKvStateRegistry kvStateRegistry,
                        TypeSerializer<K> keySerializer,
+                       ClassLoader userCodeClassLoader,
                        int numberOfKeyGroups,
                        KeyGroupRange keyGroupRange) {
 
-               super(kvStateRegistry, keySerializer, numberOfKeyGroups, 
keyGroupRange);
+               super(kvStateRegistry, keySerializer, userCodeClassLoader, 
numberOfKeyGroups, keyGroupRange);
 
                LOG.info("Initializing heap keyed state backend with stream 
factory.");
        }
 
-       public HeapKeyedStateBackend(TaskKvStateRegistry kvStateRegistry,
+       public HeapKeyedStateBackend(
+                       TaskKvStateRegistry kvStateRegistry,
                        TypeSerializer<K> keySerializer,
+                       ClassLoader userCodeClassLoader,
                        int numberOfKeyGroups,
                        KeyGroupRange keyGroupRange,
                        List<KeyGroupsStateHandle> restoredState) throws 
Exception {
-               super(kvStateRegistry, keySerializer, numberOfKeyGroups, 
keyGroupRange);
+               super(kvStateRegistry, keySerializer, userCodeClassLoader, 
numberOfKeyGroups, keyGroupRange);
 
                LOG.info("Initializing heap keyed state backend from 
snapshot.");
 
@@ -135,7 +137,6 @@ public class HeapKeyedStateBackend<K> extends 
KeyedStateBackend<K> {
                @SuppressWarnings("unchecked,rawtypes")
                StateTable<K, N, T> stateTable = (StateTable) 
stateTables.get(stateDesc.getName());
 
-
                if (stateTable == null) {
                        stateTable = new 
StateTable<>(stateDesc.getSerializer(), namespaceSerializer, keyGroupRange);
                        stateTables.put(stateDesc.getName(), stateTable);
@@ -190,10 +191,8 @@ public class HeapKeyedStateBackend<K> extends 
KeyedStateBackend<K> {
                        TypeSerializer namespaceSerializer = 
kvState.getValue().getNamespaceSerializer();
                        TypeSerializer stateSerializer = 
kvState.getValue().getStateSerializer();
 
-                       ObjectOutputStream oos = new 
ObjectOutputStream(outView);
-                       oos.writeObject(namespaceSerializer);
-                       oos.writeObject(stateSerializer);
-                       oos.flush();
+                       InstantiationUtil.serializeObject(stream, 
namespaceSerializer);
+                       InstantiationUtil.serializeObject(stream, 
stateSerializer);
 
                        kVStateToId.put(kvState.getKey(), kVStateToId.size());
                }
@@ -266,18 +265,20 @@ public class HeapKeyedStateBackend<K> extends 
KeyedStateBackend<K> {
                        for (int i = 0; i < numKvStates; ++i) {
                                String stateName = inView.readUTF();
 
-                               ObjectInputStream ois = new 
ObjectInputStream(inView);
+                               TypeSerializer namespaceSerializer =
+                                               
InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader);
+                               TypeSerializer stateSerializer =
+                                               
InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader);
 
-                               TypeSerializer namespaceSerializer = 
(TypeSerializer) ois.readObject();
-                               TypeSerializer stateSerializer = 
(TypeSerializer) ois.readObject();
-                               StateTable<K, ?, ?> stateTable = new 
StateTable(stateSerializer,
+                               StateTable<K, ?, ?> stateTable = new StateTable(
+                                               stateSerializer,
                                                namespaceSerializer,
                                                keyGroupRange);
                                stateTables.put(stateName, stateTable);
                                kvStatesById.put(i, stateName);
                        }
 
-                       for (int keyGroupIndex = 
keyGroupRange.getStartKeyGroup(); keyGroupIndex <= 
keyGroupRange.getEndKeyGroup(); keyGroupIndex++) {
+                       for (int keyGroupIndex = 
keyGroupRange.getStartKeyGroup();  keyGroupIndex <= 
keyGroupRange.getEndKeyGroup(); ++keyGroupIndex) {
                                long offset = 
keyGroupsHandle.getOffsetForKeyGroup(keyGroupIndex);
                                fsDataInputStream.seek(offset);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
index 179dfe7..cc145ff 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
@@ -78,7 +78,8 @@ public class MemoryStateBackend extends AbstractStateBackend {
        @Override
        public <K> KeyedStateBackend<K> createKeyedStateBackend(
                        Environment env, JobID jobID,
-                       String operatorIdentifier, TypeSerializer<K> 
keySerializer,
+                       String operatorIdentifier,
+                       TypeSerializer<K> keySerializer,
                        int numberOfKeyGroups,
                        KeyGroupRange keyGroupRange,
                        TaskKvStateRegistry kvStateRegistry) throws IOException 
{
@@ -86,6 +87,7 @@ public class MemoryStateBackend extends AbstractStateBackend {
                return new HeapKeyedStateBackend<>(
                                kvStateRegistry,
                                keySerializer,
+                               env.getUserClassLoader(),
                                numberOfKeyGroups,
                                keyGroupRange);
        }
@@ -103,6 +105,7 @@ public class MemoryStateBackend extends 
AbstractStateBackend {
                return new HeapKeyedStateBackend<>(
                                kvStateRegistry,
                                keySerializer,
+                               env.getUserClassLoader(),
                                numberOfKeyGroups,
                                keyGroupRange,
                                restoredState);

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 9c26509..d4638a4 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.IllegalConfigurationException;
+import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.metrics.Gauge;
 import org.apache.flink.runtime.execution.CancelTaskException;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
@@ -585,7 +586,12 @@ public abstract class StreamTask<OUT, Operator extends 
StreamOperator<OUT>>
 
                                                if (operator != null) {
                                                        LOG.debug("Restore 
state of task {} in chain ({}).", i, getName());
-                                                       
operator.restoreState(state.openInputStream());
+                                                       FSDataInputStream 
inputStream = state.openInputStream();
+                                                       try {
+                                                               
operator.restoreState(inputStream);
+                                                       } finally {
+                                                               
inputStream.close();
+                                                       }
                                                }
                                        }
                                }

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-tests/pom.xml
----------------------------------------------------------------------
diff --git a/flink-tests/pom.xml b/flink-tests/pom.xml
index b09db1f..efc95ab 100644
--- a/flink-tests/pom.xml
+++ b/flink-tests/pom.xml
@@ -485,6 +485,25 @@ under the License.
                                                        </descriptors>
                                                </configuration>
                                        </execution>
+                                       <execution>
+                                               
<id>create-checkpointing_custom_kv_state-jar</id>
+                                               
<phase>process-test-classes</phase>
+                                               <goals>
+                                                       <goal>single</goal>
+                                               </goals>
+                                               <configuration>
+                                                       <archive>
+                                                               <manifest>
+                                                                       
<mainClass>org.apache.flink.test.classloading.jar.CheckpointingCustomKvStateProgram</mainClass>
+                                                               </manifest>
+                                                       </archive>
+                                                       
<finalName>checkpointing_custom_kv_state</finalName>
+                                                       <attach>false</attach>
+                                                       <descriptors>
+                                                               
<descriptor>src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml</descriptor>
+                                                       </descriptors>
+                                               </configuration>
+                                       </execution>
                                </executions>
                        </plugin>
 

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-tests/src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml
----------------------------------------------------------------------
diff --git 
a/flink-tests/src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml 
b/flink-tests/src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml
new file mode 100644
index 0000000..fdebfdd
--- /dev/null
+++ 
b/flink-tests/src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml
@@ -0,0 +1,38 @@
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+
+-->
+
+<assembly>
+       <id>test-jar</id>
+       <formats>
+               <format>jar</format>
+       </formats>
+       <includeBaseDirectory>false</includeBaseDirectory>
+       <fileSets>
+               <fileSet>
+                       
<directory>${project.build.testOutputDirectory}</directory>
+                       <outputDirectory>/</outputDirectory>
+                       <!--modify/add include to match your package(s) -->
+                       <includes>
+                               
<include>org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.class</include>
+                               
<include>org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram*.class</include>
+                       </includes>
+               </fileSet>
+       </fileSets>
+</assembly>

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java
----------------------------------------------------------------------
diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java
 
b/flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java
index 7afafe4..65da33f 100644
--- 
a/flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java
+++ 
b/flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java
@@ -39,6 +39,7 @@ import 
org.apache.flink.runtime.state.filesystem.FsStateBackendFactory;
 import org.apache.flink.runtime.testingUtils.TestingCluster;
 import 
org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.WaitForAllVerticesToBeRunning;
 import org.apache.flink.test.testdata.KMeansData;
+import org.apache.flink.test.util.SuccessException;
 import org.apache.flink.util.TestLogger;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
@@ -46,7 +47,6 @@ import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import scala.Option;
 import scala.concurrent.Await;
 import scala.concurrent.Future;
 import scala.concurrent.duration.Deadline;
@@ -79,6 +79,8 @@ public class ClassLoaderITCase extends TestLogger {
 
        private static final String CUSTOM_KV_STATE_JAR_PATH = 
"custom_kv_state-test-jar.jar";
 
+       private static final String CHECKPOINTING_CUSTOM_KV_STATE_JAR_PATH = 
"checkpointing_custom_kv_state-test-jar.jar";
+
        public static final TemporaryFolder FOLDER = new TemporaryFolder();
 
        private static TestingCluster testCluster;
@@ -199,9 +201,26 @@ public class ClassLoaderITCase extends TestLogger {
                                        });
 
                        userCodeTypeProg.invokeInteractiveModeForExecution();
+
+                       File checkpointDir = FOLDER.newFolder();
+                       File outputDir = FOLDER.newFolder();
+
+                       final PackagedProgram program = new PackagedProgram(
+                                       new 
File(CHECKPOINTING_CUSTOM_KV_STATE_JAR_PATH),
+                                       new String[] {
+                                                       
CHECKPOINTING_CUSTOM_KV_STATE_JAR_PATH,
+                                                       "localhost",
+                                                       String.valueOf(port),
+                                                       
checkpointDir.toURI().toString(),
+                                                       
outputDir.toURI().toString()
+                                       });
+
+                       program.invokeInteractiveModeForExecution();
+
                } catch (Exception e) {
-                       e.printStackTrace();
-                       fail(e.getMessage());
+                       if (!(e.getCause().getCause() instanceof 
SuccessException)) {
+                               fail(e.getMessage());
+                       }
                }
        }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.java
----------------------------------------------------------------------
diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.java
 
b/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.java
new file mode 100644
index 0000000..6796cb0
--- /dev/null
+++ 
b/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.classloading.jar;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.state.ReducingState;
+import org.apache.flink.api.common.state.ReducingStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.runtime.state.CheckpointListener;
+import org.apache.flink.runtime.state.filesystem.FsStateBackend;
+import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
+import org.apache.flink.test.util.SuccessException;
+import org.apache.flink.util.Collector;
+
+import java.io.IOException;
+import java.util.concurrent.ThreadLocalRandom;
+
+/**
+ * Test program that runs a keyed streaming job whose {@link ReducingState} is serialized with a
+ * {@link TypeSerializer} defined in this program ({@link CustomIntSerializer}), and that fails
+ * once on purpose so the job is restarted and its state restored from a completed checkpoint.
+ *
+ * <p>The job never finishes normally: once the restored attempt processes a record,
+ * {@link ReducingStateFlatMap} throws a {@link SuccessException}. The submitting test is expected
+ * to treat a failure caused by {@code SuccessException} as success.
+ *
+ * <p>Arguments: {@code <jar file> <JobManager host> <JobManager port> <checkpoint URI> <output path>}.
+ */
+public class CheckpointingCustomKvStateProgram {
+
+       /**
+        * Builds the job and submits it to the remote cluster described by {@code args}.
+        *
+        * @param args {@code args[0]}: jar file shipped with the job; {@code args[1]}: JobManager host;
+        *             {@code args[2]}: JobManager port; {@code args[3]}: checkpoint directory URI for
+        *             the {@link FsStateBackend}; {@code args[4]}: output path for {@code writeAsText}
+        * @throws Exception if the job cannot be built or its execution fails; by design, execution
+        *                   ends with a failure caused by {@code SuccessException}
+        */
+       public static void main(String[] args) throws Exception {
+               final String jarFile = args[0];
+               final String host = args[1];
+               final int port = Integer.parseInt(args[2]);
+               final String checkpointPath = args[3];
+               final String outputPath = args[4];
+               final int parallelism = 1;
+
+               StreamExecutionEnvironment env = StreamExecutionEnvironment.createRemoteEnvironment(host, port, jarFile);
+
+               env.setParallelism(parallelism);
+               env.getConfig().disableSysoutLogging();
+               // Checkpoint every 100 ms and allow exactly one restart (after 1000 ms), so the
+               // intentional failure in ReducingStateFlatMap is followed by a single restore.
+               env.enableCheckpointing(100);
+               env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 1000));
+               env.setStateBackend(new FsStateBackend(checkpointPath));
+
+               // Pipeline: source -> random key -> keyBy -> ReducingStateFlatMap -> text sink.
+               // ReducingStateFlatMap never emits records, so the text output stays empty.
+               DataStream<Integer> source = env.addSource(new InfiniteIntegerSource());
+               source
+                               .map(new MapFunction<Integer, Tuple2<Integer, Integer>>() {
+                                       private static final long serialVersionUID = 1L;
+
+                                       // Pairs each value with a random key in [0, parallelism);
+                                       // with parallelism 1 the key is always 0.
+                                       @Override
+                                       public Tuple2<Integer, Integer> map(Integer value) throws Exception {
+                                               return new Tuple2<>(ThreadLocalRandom.current().nextInt(parallelism), value);
+                                       }
+                               })
+                               .keyBy(new KeySelector<Tuple2<Integer,Integer>, Integer>() {
+                                       private static final long serialVersionUID = 1L;
+
+                                       // Keys by the random key stored in field f0.
+                                       @Override
+                                       public Integer getKey(Tuple2<Integer, Integer> value) throws Exception {
+                                               return value.f0;
+                                       }
+                               }).flatMap(new ReducingStateFlatMap()).writeAsText(outputPath, FileSystem.WriteMode.OVERWRITE);
+
+               env.execute();
+       }
+
+       /**
+        * Source that emits 0, 1, 2, ... until it is cancelled.
+        *
+        * <p>Its {@link Checkpointed} state is a placeholder: {@link #snapshotState} always returns 0
+        * and {@link #restoreState} ignores the value, so the counter starts again from 0 after a
+        * restart.
+        */
+       private static class InfiniteIntegerSource implements ParallelSourceFunction<Integer>, Checkpointed<Integer> {
+               private static final long serialVersionUID = -7517574288730066280L;
+               // Cleared by cancel(); volatile because cancel() may be called from another thread.
+               private volatile boolean running = true;
+
+               @Override
+               public void run(SourceContext<Integer> ctx) throws Exception {
+                       int counter = 0;
+                       while (running) {
+                               // Emit while holding the checkpoint lock so emission does not
+                               // interleave with a snapshot.
+                               synchronized (ctx.getCheckpointLock()) {
+                                       ctx.collect(counter++);
+                               }
+                       }
+               }
+
+               @Override
+               public void cancel() {
+                       running = false;
+               }
+
+               // Placeholder snapshot: the source's position is not actually checkpointed.
+               @Override
+               public Integer snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
+                       return 0;
+               }
+
+               @Override
+               public void restoreState(Integer state) throws Exception {
+                       // Intentionally empty: the emitted position is not restored.
+               }
+       }
+
+       /**
+        * Keyed flat map that drives the fail / restore / succeed sequence of this program.
+        *
+        * <p>Every record's value is first added to a {@link ReducingState} (summed by
+        * {@link ReduceSum}, serialized with {@link CustomIntSerializer}). Then:
+        * <ul>
+        *     <li>if no checkpoint has completed yet and no restore happened, nothing else occurs;</li>
+        *     <li>if a checkpoint has completed but this instance was not restored, a
+        *         {@link RuntimeException} is thrown to fail the job and trigger a restart;</li>
+        *     <li>if this instance was restored, a {@link SuccessException} is thrown.</li>
+        * </ul>
+        * Nothing is ever emitted to the collector.
+        */
+       private static class ReducingStateFlatMap extends RichFlatMapFunction<Tuple2<Integer, Integer>, Integer> implements Checkpointed<ReducingStateFlatMap>, CheckpointListener {
+
+               private static final long serialVersionUID = -5939722892793950253L;
+               // Keyed state handle; transient, so it is not serialized with the function and is
+               // obtained again in open().
+               private transient ReducingState<Integer> kvState;
+
+               // Set by notifyCheckpointComplete() or by restoreState().
+               private boolean atLeastOneSnapshotComplete = false;
+               // Set only by restoreState(), i.e. true in the attempt that was restored.
+               private boolean restored = false;
+
+               /**
+                * Registers the keyed reducing state named {@code "reducing-state"}, which uses
+                * {@link ReduceSum} and the program-defined {@link CustomIntSerializer}.
+                */
+               @Override
+               public void open(Configuration parameters) throws Exception {
+                       ReducingStateDescriptor<Integer> stateDescriptor =
+                                       new ReducingStateDescriptor<>(
+                                                       "reducing-state",
+                                                       new ReduceSum(),
+                                                       CustomIntSerializer.INSTANCE);
+
+                       this.kvState = getRuntimeContext().getReducingState(stateDescriptor);
+               }
+
+
+               /**
+                * Adds {@code value.f1} to the keyed state, then fails or signals success depending on
+                * the checkpoint / restore flags (see the class documentation).
+                *
+                * @throws SuccessException if this instance was restored from a checkpoint
+                * @throws RuntimeException if a checkpoint completed but no restore has happened yet
+                */
+               @Override
+               public void flatMap(Tuple2<Integer, Integer> value, Collector<Integer> out) throws Exception {
+                       kvState.add(value.f1);
+
+                       if(atLeastOneSnapshotComplete) {
+                               if (restored) {
+                                       throw new SuccessException();
+                               } else {
+                                       throw new RuntimeException("Intended failure, to trigger restore");
+                               }
+                       }
+               }
+
+               /**
+                * Returns this function instance as its snapshot. The snapshot's contents are not used:
+                * {@link #restoreState} only relies on being called.
+                */
+               @Override
+               public ReducingStateFlatMap snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
+                       return this;
+               }
+
+               /**
+                * Marks this instance as restored. The {@code state} argument is ignored. Both flags
+                * are set, so the next record processed by {@link #flatMap} throws
+                * {@link SuccessException}.
+                */
+               @Override
+               public void restoreState(ReducingStateFlatMap state) throws Exception {
+                       restored = true;
+                       atLeastOneSnapshotComplete = true;
+               }
+
+               /**
+                * Records that a checkpoint completed. The next record processed by {@link #flatMap}
+                * then throws (an intentional failure, or success if already restored).
+                */
+               @Override
+               public void notifyCheckpointComplete(long checkpointId) throws Exception {
+                       atLeastOneSnapshotComplete = true;
+               }
+
+               /**
+                * Sums two integers. The sum wraps on int overflow; the accumulated value is never read
+                * by this program.
+                */
+               private static class ReduceSum implements ReduceFunction<Integer> {
+                       private static final long serialVersionUID = 1L;
+
+                       @Override
+                       public Integer reduce(Integer value1, Integer value2) throws Exception {
+                               return value1 + value2;
+                       }
+               }
+       }
+
+       /**
+        * Integer serializer defined in this program: each value is written as a single {@code int}
+        * (fixed length 4). Values are treated as immutable, so the copy methods return the given
+        * instance.
+        *
+        * <p>NOTE(review): presumably used so that restoring the keyed state has to resolve a class
+        * that is only available from the user-code jar (cf. FLINK-4603); confirm against the
+        * jar assembly used by the test.
+        */
+       private static final class CustomIntSerializer extends TypeSerializerSingleton<Integer> {
+
+               private static final long serialVersionUID = 4572452915892737448L;
+
+               public static final TypeSerializer<Integer> INSTANCE = new CustomIntSerializer();
+
+               @Override
+               public boolean isImmutableType() {
+                       return true;
+               }
+
+               @Override
+               public Integer createInstance() {
+                       return 0;
+               }
+
+               // Integers are immutable, so a "copy" is the same instance.
+               @Override
+               public Integer copy(Integer from) {
+                       return from;
+               }
+
+               @Override
+               public Integer copy(Integer from, Integer reuse) {
+                       return from;
+               }
+
+               // Every value occupies exactly 4 bytes (one writeInt).
+               @Override
+               public int getLength() {
+                       return 4;
+               }
+
+               @Override
+               public void serialize(Integer record, DataOutputView target) throws IOException {
+                       target.writeInt(record.intValue());
+               }
+
+               @Override
+               public Integer deserialize(DataInputView source) throws IOException {
+                       return Integer.valueOf(source.readInt());
+               }
+
+               // The reuse instance is ignored; a fresh Integer is returned.
+               @Override
+               public Integer deserialize(Integer reuse, DataInputView source) throws IOException {
+                       return Integer.valueOf(source.readInt());
+               }
+
+               // Copies one serialized record (a single int) from source to target.
+               @Override
+               public void copy(DataInputView source, DataOutputView target) throws IOException {
+                       target.writeInt(source.readInt());
+               }
+
+               @Override
+               public boolean canEqual(Object obj) {
+                       return obj instanceof CustomIntSerializer;
+               }
+
+       }
+}

Reply via email to