http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index 73e2808..2f21574 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -80,11 +80,11 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                return getStateBackend().createStreamFactory(new JobID(), 
"test_op");
        }
 
-       protected <K> KeyedStateBackend<K> createKeyedBackend(TypeSerializer<K> 
keySerializer) throws Exception {
+       protected <K> AbstractKeyedStateBackend<K> 
createKeyedBackend(TypeSerializer<K> keySerializer) throws Exception {
                return createKeyedBackend(keySerializer, new 
DummyEnvironment("test", 1, 0));
        }
 
-       protected <K> KeyedStateBackend<K> createKeyedBackend(TypeSerializer<K> 
keySerializer, Environment env) throws Exception {
+       protected <K> AbstractKeyedStateBackend<K> 
createKeyedBackend(TypeSerializer<K> keySerializer, Environment env) throws 
Exception {
                return createKeyedBackend(
                                keySerializer,
                                10,
@@ -92,7 +92,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                                env);
        }
 
-       protected <K> KeyedStateBackend<K> createKeyedBackend(
+       protected <K> AbstractKeyedStateBackend<K> createKeyedBackend(
                        TypeSerializer<K> keySerializer,
                        int numberOfKeyGroups,
                        KeyGroupRange keyGroupRange,
@@ -104,14 +104,15 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                                keySerializer,
                                numberOfKeyGroups,
                                keyGroupRange,
-                               env.getTaskKvStateRegistry());
+                               env.getTaskKvStateRegistry())
+;
        }
 
-       protected <K> KeyedStateBackend<K> 
restoreKeyedBackend(TypeSerializer<K> keySerializer, KeyGroupsStateHandle 
state) throws Exception {
+       protected <K> AbstractKeyedStateBackend<K> 
restoreKeyedBackend(TypeSerializer<K> keySerializer, KeyGroupsStateHandle 
state) throws Exception {
                return restoreKeyedBackend(keySerializer, state, new 
DummyEnvironment("test", 1, 0));
        }
 
-       protected <K> KeyedStateBackend<K> restoreKeyedBackend(
+       protected <K> AbstractKeyedStateBackend<K> restoreKeyedBackend(
                        TypeSerializer<K> keySerializer,
                        KeyGroupsStateHandle state,
                        Environment env) throws Exception {
@@ -123,7 +124,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                                env);
        }
 
-       protected <K> KeyedStateBackend<K> restoreKeyedBackend(
+       protected <K> AbstractKeyedStateBackend<K> restoreKeyedBackend(
                        TypeSerializer<K> keySerializer,
                        int numberOfKeyGroups,
                        KeyGroupRange keyGroupRange,
@@ -144,7 +145,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
        @SuppressWarnings("unchecked")
        public void testValueState() throws Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
-               KeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
+               AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                ValueStateDescriptor<String> kvId = new 
ValueStateDescriptor<>("id", String.class, null);
                kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -195,7 +196,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                assertEquals("u3", state.value());
                assertEquals("u3", getSerializedValue(kvState, 3, 
keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-               backend.close();
+               backend.dispose();
                backend = restoreKeyedBackend(IntSerializer.INSTANCE, 
snapshot1);
 
                snapshot1.discardState();
@@ -211,7 +212,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                assertEquals("2", restored1.value());
                assertEquals("2", getSerializedValue(restoredKvState1, 2, 
keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-               backend.close();
+               backend.dispose();
                backend = restoreKeyedBackend(IntSerializer.INSTANCE, 
snapshot2);
 
                snapshot2.discardState();
@@ -230,7 +231,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                assertEquals("u3", restored2.value());
                assertEquals("u3", getSerializedValue(restoredKvState2, 3, 
keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-               backend.close();
+               backend.dispose();
        }
 
        @Test
@@ -238,7 +239,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
        public void testMultipleValueStates() throws Exception {
                CheckpointStreamFactory streamFactory = createStreamFactory();
 
-               KeyedStateBackend<Integer> backend = createKeyedBackend(
+               AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(
                                IntSerializer.INSTANCE,
                                1,
                                new KeyGroupRange(0, 0),
@@ -271,7 +272,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                // draw a snapshot
                KeyGroupsStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
-               backend.close();
+               backend.dispose();
                backend = restoreKeyedBackend(
                                IntSerializer.INSTANCE,
                                1,
@@ -290,7 +291,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                assertEquals("1", state1.value());
                assertEquals(13, (int) state2.value());
 
-               backend.close();
+               backend.dispose();
        }
 
        /**
@@ -313,7 +314,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                }
 
                CheckpointStreamFactory streamFactory = createStreamFactory();
-               KeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
+               AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                ValueStateDescriptor<Long> kvId = new 
ValueStateDescriptor<>("id", LongSerializer.INSTANCE, 42L);
                kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -344,14 +345,14 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                // draw a snapshot
                KeyGroupsStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
-               backend.close();
+               backend.dispose();
                backend = restoreKeyedBackend(IntSerializer.INSTANCE, 
snapshot1);
 
                snapshot1.discardState();
 
                backend.getPartitionedState(VoidNamespace.INSTANCE, 
VoidNamespaceSerializer.INSTANCE, kvId);
 
-               backend.close();
+               backend.dispose();
        }
 
        @Test
@@ -359,7 +360,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
        public void testListState() {
                try {
                        CheckpointStreamFactory streamFactory = 
createStreamFactory();
-                       KeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
+                       AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                        ListStateDescriptor<String> kvId = new 
ListStateDescriptor<>("id", String.class);
                        kvId.initializeSerializerUnlessSet(new 
ExecutionConfig());
@@ -411,7 +412,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                        assertEquals("u3", joiner.join(state.get()));
                        assertEquals("u3", 
joiner.join(getSerializedList(kvState, 3, keySerializer, 
VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)));
 
-                       backend.close();
+                       backend.dispose();
                        // restore the first snapshot and validate it
                        backend = restoreKeyedBackend(IntSerializer.INSTANCE, 
snapshot1);
                        snapshot1.discardState();
@@ -427,7 +428,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                        assertEquals("2", joiner.join(restored1.get()));
                        assertEquals("2", 
joiner.join(getSerializedList(restoredKvState1, 2, keySerializer, 
VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)));
 
-                       backend.close();
+                       backend.dispose();
                        // restore the second snapshot and validate it
                        backend = restoreKeyedBackend(IntSerializer.INSTANCE, 
snapshot2);
                        snapshot2.discardState();
@@ -446,7 +447,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                        assertEquals("u3", joiner.join(restored2.get()));
                        assertEquals("u3", 
joiner.join(getSerializedList(restoredKvState2, 3, keySerializer, 
VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)));
 
-                       backend.close();
+                       backend.dispose();
                }
                catch (Exception e) {
                        e.printStackTrace();
@@ -459,7 +460,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
        public void testReducingState() {
                try {
                        CheckpointStreamFactory streamFactory = 
createStreamFactory();
-                       KeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
+                       AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                        ReducingStateDescriptor<String> kvId = new 
ReducingStateDescriptor<>("id", new AppendingReduce(), String.class);
                        kvId.initializeSerializerUnlessSet(new 
ExecutionConfig());
@@ -510,7 +511,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                        assertEquals("u3", state.get());
                        assertEquals("u3", getSerializedValue(kvState, 3, 
keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-                       backend.close();
+                       backend.dispose();
                        // restore the first snapshot and validate it
                        backend = restoreKeyedBackend(IntSerializer.INSTANCE, 
snapshot1);
                        snapshot1.discardState();
@@ -526,7 +527,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                        assertEquals("2", restored1.get());
                        assertEquals("2", getSerializedValue(restoredKvState1, 
2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, 
valueSerializer));
 
-                       backend.close();
+                       backend.dispose();
                        // restore the second snapshot and validate it
                        backend = restoreKeyedBackend(IntSerializer.INSTANCE, 
snapshot2);
                        snapshot2.discardState();
@@ -545,7 +546,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                        assertEquals("u3", restored2.get());
                        assertEquals("u3", getSerializedValue(restoredKvState2, 
3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, 
valueSerializer));
 
-                       backend.close();
+                       backend.dispose();
                }
                catch (Exception e) {
                        e.printStackTrace();
@@ -558,7 +559,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
        public void testFoldingState() {
                try {
                        CheckpointStreamFactory streamFactory = 
createStreamFactory();
-                       KeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
+                       AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                        FoldingStateDescriptor<Integer, String> kvId = new 
FoldingStateDescriptor<>("id",
                                        "Fold-Initial:",
@@ -613,7 +614,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                        assertEquals("Fold-Initial:,103", state.get());
                        assertEquals("Fold-Initial:,103", 
getSerializedValue(kvState, 3, keySerializer, VoidNamespace.INSTANCE, 
namespaceSerializer, valueSerializer));
 
-                       backend.close();
+                       backend.dispose();
                        // restore the first snapshot and validate it
                        backend = restoreKeyedBackend(IntSerializer.INSTANCE, 
snapshot1);
                        snapshot1.discardState();
@@ -629,7 +630,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                        assertEquals("Fold-Initial:,2", restored1.get());
                        assertEquals("Fold-Initial:,2", 
getSerializedValue(restoredKvState1, 2, keySerializer, VoidNamespace.INSTANCE, 
namespaceSerializer, valueSerializer));
 
-                       backend.close();
+                       backend.dispose();
                        // restore the second snapshot and validate it
                        backend = restoreKeyedBackend(IntSerializer.INSTANCE, 
snapshot2);
                        snapshot1.discardState();
@@ -649,7 +650,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                        assertEquals("Fold-Initial:,103", restored2.get());
                        assertEquals("Fold-Initial:,103", 
getSerializedValue(restoredKvState2, 3, keySerializer, VoidNamespace.INSTANCE, 
namespaceSerializer, valueSerializer));
 
-                       backend.close();
+                       backend.dispose();
                }
                catch (Exception e) {
                        e.printStackTrace();
@@ -672,7 +673,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                final int MAX_PARALLELISM = 10;
 
                CheckpointStreamFactory streamFactory = createStreamFactory();
-               KeyedStateBackend<Integer> backend = createKeyedBackend(
+               AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(
                                IntSerializer.INSTANCE,
                                MAX_PARALLELISM,
                                new KeyGroupRange(0, MAX_PARALLELISM - 1),
@@ -714,10 +715,10 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                                Collections.singletonList(snapshot),
                                
KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(MAX_PARALLELISM, 
2, 1));
 
-               backend.close();
+               backend.dispose();
 
                // backend for the first half of the key group range
-               KeyedStateBackend<Integer> firstHalfBackend = 
restoreKeyedBackend(
+               AbstractKeyedStateBackend<Integer> firstHalfBackend = 
restoreKeyedBackend(
                                IntSerializer.INSTANCE,
                                MAX_PARALLELISM,
                                new KeyGroupRange(0, 4),
@@ -725,7 +726,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                                new DummyEnvironment("test", 1, 0));
 
                // backend for the second half of the key group range
-               KeyedStateBackend<Integer> secondHalfBackend = 
restoreKeyedBackend(
+               AbstractKeyedStateBackend<Integer> secondHalfBackend = 
restoreKeyedBackend(
                                IntSerializer.INSTANCE,
                                MAX_PARALLELISM,
                                new KeyGroupRange(5, 9),
@@ -749,8 +750,8 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                secondHalfBackend.setCurrentKey(keyInSecondHalf);
                
assertTrue(secondHalfState.value().equals("ShouldBeInSecondHalf"));
 
-               firstHalfBackend.close();
-               secondHalfBackend.close();
+               firstHalfBackend.dispose();
+               secondHalfBackend.dispose();
        }
 
        @Test
@@ -758,7 +759,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
        public void testValueStateRestoreWithWrongSerializers() {
                try {
                        CheckpointStreamFactory streamFactory = 
createStreamFactory();
-                       KeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
+                       AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                        ValueStateDescriptor<String> kvId = new 
ValueStateDescriptor<>("id", String.class, null);
                        kvId.initializeSerializerUnlessSet(new 
ExecutionConfig());
@@ -773,7 +774,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                        // draw a snapshot
                        KeyGroupsStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
-                       backend.close();
+                       backend.dispose();
                        // restore the first snapshot and validate it
                        backend = restoreKeyedBackend(IntSerializer.INSTANCE, 
snapshot1);
                        snapshot1.discardState();
@@ -798,7 +799,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                        } catch (Exception e) {
                                fail("wrong exception " + e);
                        }
-                       backend.close();
+                       backend.dispose();
                }
                catch (Exception e) {
                        e.printStackTrace();
@@ -811,7 +812,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
        public void testListStateRestoreWithWrongSerializers() {
                try {
                        CheckpointStreamFactory streamFactory = 
createStreamFactory();
-                       KeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
+                       AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                        ListStateDescriptor<String> kvId = new 
ListStateDescriptor<>("id", String.class);
                        ListState<String> state = 
backend.getPartitionedState(VoidNamespace.INSTANCE, 
VoidNamespaceSerializer.INSTANCE, kvId);
@@ -824,7 +825,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                        // draw a snapshot
                        KeyGroupsStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
-                       backend.close();
+                       backend.dispose();
                        // restore the first snapshot and validate it
                        backend = restoreKeyedBackend(IntSerializer.INSTANCE, 
snapshot1);
                        snapshot1.discardState();
@@ -849,7 +850,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                        } catch (Exception e) {
                                fail("wrong exception " + e);
                        }
-                       backend.close();
+                       backend.dispose();
                }
                catch (Exception e) {
                        e.printStackTrace();
@@ -862,7 +863,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
        public void testReducingStateRestoreWithWrongSerializers() {
                try {
                        CheckpointStreamFactory streamFactory = 
createStreamFactory();
-                       KeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
+                       AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                        ReducingStateDescriptor<String> kvId = new 
ReducingStateDescriptor<>("id",
                                        new AppendingReduce(),
@@ -877,7 +878,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                        // draw a snapshot
                        KeyGroupsStateHandle snapshot1 = 
runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
-                       backend.close();
+                       backend.dispose();
                        // restore the first snapshot and validate it
                        backend = restoreKeyedBackend(IntSerializer.INSTANCE, 
snapshot1);
                        snapshot1.discardState();
@@ -902,7 +903,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                        } catch (Exception e) {
                                fail("wrong exception " + e);
                        }
-                       backend.close();
+                       backend.dispose();
                }
                catch (Exception e) {
                        e.printStackTrace();
@@ -912,7 +913,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
 
        @Test
        public void testCopyDefaultValue() throws Exception {
-               KeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
+               AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                ValueStateDescriptor<IntValue> kvId = new 
ValueStateDescriptor<>("id", IntValue.class, new IntValue(-1));
                kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -930,7 +931,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                assertEquals(default1, default2);
                assertFalse(default1 == default2);
 
-               backend.close();
+               backend.dispose();
        }
 
        /**
@@ -940,7 +941,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
         */
        @Test
        public void testRequireNonNullNamespace() throws Exception {
-               KeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
+               AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                ValueStateDescriptor<IntValue> kvId = new 
ValueStateDescriptor<>("id", IntValue.class, new IntValue(-1));
                kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -963,7 +964,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                } catch (NullPointerException ignored) {
                }
 
-               backend.close();
+               backend.dispose();
        }
 
        /**
@@ -973,7 +974,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
        @SuppressWarnings("unchecked")
        protected void testConcurrentMapIfQueryable() throws Exception {
                final int numberOfKeyGroups = 1;
-               KeyedStateBackend<Integer> backend = createKeyedBackend(
+               AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(
                                IntSerializer.INSTANCE,
                                numberOfKeyGroups,
                                new KeyGroupRange(0, 0),
@@ -1095,7 +1096,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                        
assertTrue(stateTable.get(keyGroupIndex).get(VoidNamespace.INSTANCE) instanceof 
ConcurrentHashMap);
                }
 
-               backend.close();
+               backend.dispose();
        }
 
        /**
@@ -1107,7 +1108,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                KvStateRegistry registry = env.getKvStateRegistry();
 
                CheckpointStreamFactory streamFactory = createStreamFactory();
-               KeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE, env);
+               AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE, env);
                KeyGroupRange expectedKeyGroupRange = 
backend.getKeyGroupRange();
 
                KvStateRegistryListener listener = 
mock(KvStateRegistryListener.class);
@@ -1128,11 +1129,11 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
 
                KeyGroupsStateHandle snapshot = 
runSnapshot(backend.snapshot(682375462379L, 4, streamFactory));
 
-               backend.close();
+               backend.dispose();
 
                verify(listener, times(1)).notifyKvStateUnregistered(
                                eq(env.getJobID()), eq(env.getJobVertexId()), 
eq(expectedKeyGroupRange), eq("banana"));
-               backend.close();
+               backend.dispose();
                // Initialize again
                backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot, 
env);
                snapshot.discardState();
@@ -1143,7 +1144,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
                verify(listener, times(2)).notifyKvStateRegistered(
                                eq(env.getJobID()), eq(env.getJobVertexId()), 
eq(expectedKeyGroupRange), eq("banana"), any(KvStateID.class));
 
-               backend.close();
+               backend.dispose();
 
        }
 
@@ -1152,17 +1153,17 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> {
 
                try {
                        CheckpointStreamFactory streamFactory = 
createStreamFactory();
-                       KeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
+                       AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
 
                        ListStateDescriptor<String> kvId = new 
ListStateDescriptor<>("id", String.class);
 
                        // draw a snapshot
                        KeyGroupsStateHandle snapshot = 
runSnapshot(backend.snapshot(682375462379L, 1, streamFactory));
                        assertNull(snapshot);
-                       backend.close();
+                       backend.dispose();
 
                        backend = restoreKeyedBackend(IntSerializer.INSTANCE, 
snapshot);
-                       backend.close();
+                       backend.dispose();
                }
                catch (Exception e) {
                        e.printStackTrace();

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java
index a6a555d..d484f2e 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java
@@ -20,8 +20,6 @@ package org.apache.flink.runtime.state.filesystem;
 
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.runtime.state.AbstractStateBackend;
-
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.junit.Test;
@@ -31,7 +29,8 @@ import java.io.File;
 import java.io.InputStream;
 import java.util.Random;
 
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertTrue;
 
 public class FsCheckpointStateOutputStreamTest {
 
@@ -112,13 +111,14 @@ public class FsCheckpointStateOutputStreamTest {
                // make sure the writing process did not alter the original 
byte array
                assertArrayEquals(original, bytes);
 
-               InputStream inStream = handle.openInputStream();
-               byte[] validation = new byte[bytes.length];
+               try (InputStream inStream = handle.openInputStream()) {
+                       byte[] validation = new byte[bytes.length];
 
-               DataInputStream dataInputStream = new DataInputStream(inStream);
-               dataInputStream.readFully(validation);
+                       DataInputStream dataInputStream = new 
DataInputStream(inStream);
+                       dataInputStream.readFully(validation);
 
-               assertArrayEquals(bytes, validation);
+                       assertArrayEquals(bytes, validation);
+               }
 
                handle.discardState();
        }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
index 454196f..7bc2c29 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
@@ -46,6 +46,7 @@ import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 
 import org.apache.flink.util.SerializedValue;
@@ -53,6 +54,7 @@ import org.junit.Before;
 import org.junit.Test;
 
 import java.net.URL;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.Executor;
@@ -209,7 +211,8 @@ public class TaskAsyncCallTest {
 
                @Override
                public void 
setInitialState(ChainedStateHandle<StreamStateHandle> chainedState,
-                               List<KeyGroupsStateHandle> keyGroupsState) 
throws Exception {
+                                                                       
List<KeyGroupsStateHandle> keyGroupsState,
+                                                                       
List<Collection<OperatorStateHandle>> partitionableOperatorState) throws 
Exception {
 
                }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java
index 7e8868c..8f9c932 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java
@@ -33,7 +33,6 @@ import org.junit.Test;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
-import java.io.IOException;
 import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
@@ -587,8 +586,5 @@ public class ZooKeeperStateHandleStoreITCase extends 
TestLogger {
                public int getNumberOfDiscardCalls() {
                        return numberOfDiscardCalls;
                }
-
-               @Override
-               public void close() throws IOException {}
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-0.8/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer08.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-connectors/flink-connector-kafka-0.8/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer08.java
 
b/flink-streaming-connectors/flink-connector-kafka-0.8/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer08.java
index c16629d..d7a6364 100644
--- 
a/flink-streaming-connectors/flink-connector-kafka-0.8/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer08.java
+++ 
b/flink-streaming-connectors/flink-connector-kafka-0.8/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer08.java
@@ -177,7 +177,7 @@ public class FlinkKafkaConsumer08<T> extends 
FlinkKafkaConsumerBase<T> {
         *           The properties that are used to configure both the fetcher 
and the offset handler.
         */
        public FlinkKafkaConsumer08(List<String> topics, 
KeyedDeserializationSchema<T> deserializer, Properties props) {
-               super(deserializer);
+               super(topics, deserializer);
 
                checkNotNull(topics, "topics");
                this.kafkaProperties = checkNotNull(props, "props");
@@ -187,22 +187,6 @@ public class FlinkKafkaConsumer08<T> extends 
FlinkKafkaConsumerBase<T> {
 
                this.invalidOffsetBehavior = getInvalidOffsetBehavior(props);
                this.autoCommitInterval = PropertiesUtil.getLong(props, 
"auto.commit.interval.ms", 60000);
-
-               // Connect to a broker to get the partitions for all topics
-               List<KafkaTopicPartition> partitionInfos = 
-                               
KafkaTopicPartition.dropLeaderData(getPartitionsForTopic(topics, props));
-
-               if (partitionInfos.size() == 0) {
-                       throw new RuntimeException(
-                                       "Unable to retrieve any partitions for 
the requested topics " + topics + 
-                                                       ". Please check 
previous log entries");
-               }
-
-               if (LOG.isInfoEnabled()) {
-                       logPartitionInfo(LOG, partitionInfos);
-               }
-
-               setSubscribedPartitions(partitionInfos);
        }
 
        @Override
@@ -221,6 +205,25 @@ public class FlinkKafkaConsumer08<T> extends 
FlinkKafkaConsumerBase<T> {
                                invalidOffsetBehavior, autoCommitInterval, 
useMetrics);
        }
 
+       @Override
+       protected List<KafkaTopicPartition> getKafkaPartitions(List<String> 
topics) {
+               // Connect to a broker to get the partitions for all topics
+               List<KafkaTopicPartition> partitionInfos =
+                       
KafkaTopicPartition.dropLeaderData(getPartitionsForTopic(topics, 
kafkaProperties));
+
+               if (partitionInfos.size() == 0) {
+                       throw new RuntimeException(
+                               "Unable to retrieve any partitions for the 
requested topics " + topics +
+                                       ". Please check previous log entries");
+               }
+
+               if (LOG.isInfoEnabled()) {
+                       logPartitionInfo(LOG, partitionInfos);
+               }
+
+               return partitionInfos;
+       }
+
        // 
------------------------------------------------------------------------
        //  Kafka / ZooKeeper communication utilities
        // 
------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-0.8/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumer08Test.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-connectors/flink-connector-kafka-0.8/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumer08Test.java
 
b/flink-streaming-connectors/flink-connector-kafka-0.8/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumer08Test.java
index 36fb7e6..f0b58cf 100644
--- 
a/flink-streaming-connectors/flink-connector-kafka-0.8/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumer08Test.java
+++ 
b/flink-streaming-connectors/flink-connector-kafka-0.8/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumer08Test.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.streaming.connectors.kafka;
 
+import org.apache.flink.configuration.Configuration;
 import org.apache.flink.streaming.util.serialization.SimpleStringSchema;
 
 import org.apache.kafka.clients.consumer.ConsumerConfig;
@@ -80,7 +81,8 @@ public class KafkaConsumer08Test {
                        props.setProperty("bootstrap.servers", 
"localhost:11111, localhost:22222");
                        props.setProperty("group.id", "non-existent-group");
 
-                       new 
FlinkKafkaConsumer08<>(Collections.singletonList("no op topic"), new 
SimpleStringSchema(), props);
+                       FlinkKafkaConsumer08<String> consumer = new 
FlinkKafkaConsumer08<>(Collections.singletonList("no op topic"), new 
SimpleStringSchema(), props);
+                       consumer.open(new Configuration());
                        fail();
                }
                catch (Exception e) {

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer09.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer09.java
 
b/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer09.java
index 8c3eaf8..9708777 100644
--- 
a/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer09.java
+++ 
b/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer09.java
@@ -149,9 +149,8 @@ public class FlinkKafkaConsumer09<T> extends 
FlinkKafkaConsumerBase<T> {
         *           The properties that are used to configure both the fetcher 
and the offset handler.
         */
        public FlinkKafkaConsumer09(List<String> topics, 
KeyedDeserializationSchema<T> deserializer, Properties props) {
-               super(deserializer);
+               super(topics, deserializer);
 
-               checkNotNull(topics, "topics");
                this.properties = checkNotNull(props, "props");
                setDeserializer(this.properties);
 
@@ -166,7 +165,27 @@ public class FlinkKafkaConsumer09<T> extends 
FlinkKafkaConsumerBase<T> {
                catch (Exception e) {
                        throw new IllegalArgumentException("Cannot parse poll 
timeout for '" + KEY_POLL_TIMEOUT + '\'', e);
                }
+       }
+
+       @Override
+       protected AbstractFetcher<T, ?> createFetcher(
+                       SourceContext<T> sourceContext,
+                       List<KafkaTopicPartition> thisSubtaskPartitions,
+                       SerializedValue<AssignerWithPeriodicWatermarks<T>> 
watermarksPeriodic,
+                       SerializedValue<AssignerWithPunctuatedWatermarks<T>> 
watermarksPunctuated,
+                       StreamingRuntimeContext runtimeContext) throws 
Exception {
+
+               boolean useMetrics = 
!Boolean.valueOf(properties.getProperty(KEY_DISABLE_METRICS, "false"));
+
+               return new Kafka09Fetcher<>(sourceContext, 
thisSubtaskPartitions,
+                               watermarksPeriodic, watermarksPunctuated,
+                               runtimeContext, deserializer,
+                               properties, pollTimeout, useMetrics);
+               
+       }
 
+       @Override
+       protected List<KafkaTopicPartition> getKafkaPartitions(List<String> 
topics) {
                // read the partitions that belong to the listed topics
                final List<KafkaTopicPartition> partitions = new ArrayList<>();
 
@@ -192,25 +211,7 @@ public class FlinkKafkaConsumer09<T> extends 
FlinkKafkaConsumerBase<T> {
                        logPartitionInfo(LOG, partitions);
                }
 
-               // register these partitions
-               setSubscribedPartitions(partitions);
-       }
-
-       @Override
-       protected AbstractFetcher<T, ?> createFetcher(
-                       SourceContext<T> sourceContext,
-                       List<KafkaTopicPartition> thisSubtaskPartitions,
-                       SerializedValue<AssignerWithPeriodicWatermarks<T>> 
watermarksPeriodic,
-                       SerializedValue<AssignerWithPunctuatedWatermarks<T>> 
watermarksPunctuated,
-                       StreamingRuntimeContext runtimeContext) throws 
Exception {
-
-               boolean useMetrics = 
!Boolean.valueOf(properties.getProperty(KEY_DISABLE_METRICS, "false"));
-
-               return new Kafka09Fetcher<>(sourceContext, 
thisSubtaskPartitions,
-                               watermarksPeriodic, watermarksPunctuated,
-                               runtimeContext, deserializer,
-                               properties, pollTimeout, useMetrics);
-               
+               return partitions;
        }
 
        // 
------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
 
b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
index 2b2c527..939b77b 100644
--- 
a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
+++ 
b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
@@ -18,11 +18,16 @@
 package org.apache.flink.streaming.connectors.kafka;
 
 import org.apache.commons.collections.map.LinkedMap;
-
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.typeinfo.TypeHint;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
+import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.state.CheckpointListener;
-import org.apache.flink.streaming.api.checkpoint.CheckpointedAsynchronously;
+import org.apache.flink.runtime.state.OperatorStateStore;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
 import 
org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
 import 
org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
@@ -30,18 +35,21 @@ import 
org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
 import 
org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import 
org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
 import 
org.apache.flink.streaming.util.serialization.KeyedDeserializationSchema;
 import org.apache.flink.util.SerializedValue;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
+import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
@@ -55,11 +63,12 @@ import static 
org.apache.flink.util.Preconditions.checkNotNull;
  */
 public abstract class FlinkKafkaConsumerBase<T> extends 
RichParallelSourceFunction<T> implements 
                CheckpointListener,
-               CheckpointedAsynchronously<HashMap<KafkaTopicPartition, Long>>,
-               ResultTypeQueryable<T>
-{
+               ResultTypeQueryable<T>,
+               CheckpointedFunction {
        private static final long serialVersionUID = -6272159445203409112L;
 
+       private static final String KAFKA_OFFSETS = "kafka_offsets";
+
        protected static final Logger LOG = 
LoggerFactory.getLogger(FlinkKafkaConsumerBase.class);
        
        /** The maximum number of pending non-committed checkpoints to track, 
to avoid memory leaks */
@@ -71,12 +80,14 @@ public abstract class FlinkKafkaConsumerBase<T> extends 
RichParallelSourceFuncti
        // 
------------------------------------------------------------------------
        //  configuration state, set on the client relevant for all subtasks
        // 
------------------------------------------------------------------------
+
+       private final List<String> topics;
        
        /** The schema to convert between Kafka's byte messages, and Flink's 
objects */
        protected final KeyedDeserializationSchema<T> deserializer;
 
        /** The set of topic partitions that the source will read */
-       protected List<KafkaTopicPartition> allSubscribedPartitions;
+       protected List<KafkaTopicPartition> subscribedPartitions;
        
        /** Optional timestamp extractor / watermark generator that will be run 
per Kafka partition,
         * to exploit per-partition timestamp characteristics.
@@ -88,6 +99,8 @@ public abstract class FlinkKafkaConsumerBase<T> extends 
RichParallelSourceFuncti
         * The assigner is kept in serialized form, to deserialize it into 
multiple copies */
        private SerializedValue<AssignerWithPunctuatedWatermarks<T>> 
punctuatedWatermarkAssigner;
 
+       private transient OperatorStateStore stateStore;
+
        // 
------------------------------------------------------------------------
        //  runtime state (used individually by each parallel subtask) 
        // 
------------------------------------------------------------------------
@@ -112,8 +125,14 @@ public abstract class FlinkKafkaConsumerBase<T> extends 
RichParallelSourceFuncti
         * @param deserializer
         *           The deserializer to turn raw byte messages into Java/Scala 
objects.
         */
-       public FlinkKafkaConsumerBase(KeyedDeserializationSchema<T> 
deserializer) {
+       public FlinkKafkaConsumerBase(List<String> topics, 
KeyedDeserializationSchema<T> deserializer) {
+               this.topics = checkNotNull(topics);
+               checkArgument(topics.size() > 0, "You have to define at least 
one topic.");
+
                this.deserializer = checkNotNull(deserializer, 
"valueDeserializer");
+
+               TypeInformation<Tuple2<KafkaTopicPartition, Long>> typeInfo =
+                               TypeInformation.of(new 
TypeHint<Tuple2<KafkaTopicPartition, Long>>(){});
        }
 
        /**
@@ -124,7 +143,7 @@ public abstract class FlinkKafkaConsumerBase<T> extends 
RichParallelSourceFuncti
         */
        protected void setSubscribedPartitions(List<KafkaTopicPartition> 
allSubscribedPartitions) {
                checkNotNull(allSubscribedPartitions);
-               this.allSubscribedPartitions = 
Collections.unmodifiableList(allSubscribedPartitions);
+               this.subscribedPartitions = 
Collections.unmodifiableList(allSubscribedPartitions);
        }
 
        // 
------------------------------------------------------------------------
@@ -205,20 +224,16 @@ public abstract class FlinkKafkaConsumerBase<T> extends 
RichParallelSourceFuncti
 
        @Override
        public void run(SourceContext<T> sourceContext) throws Exception {
-               if (allSubscribedPartitions == null) {
+               if (subscribedPartitions == null) {
                        throw new Exception("The partitions were not set for 
the consumer");
                }
-               
-               // figure out which partitions this subtask should process
-               final List<KafkaTopicPartition> thisSubtaskPartitions = 
assignPartitions(allSubscribedPartitions,
-                               
getRuntimeContext().getNumberOfParallelSubtasks(), 
getRuntimeContext().getIndexOfThisSubtask());
-               
+
                // we need only do work, if we actually have partitions assigned
-               if (!thisSubtaskPartitions.isEmpty()) {
+               if (!subscribedPartitions.isEmpty()) {
 
                        // (1) create the fetcher that will communicate with 
the Kafka brokers
                        final AbstractFetcher<T, ?> fetcher = createFetcher(
-                                       sourceContext, thisSubtaskPartitions, 
+                                       sourceContext, subscribedPartitions,
                                        periodicWatermarkAssigner, 
punctuatedWatermarkAssigner,
                                        (StreamingRuntimeContext) 
getRuntimeContext());
 
@@ -277,6 +292,15 @@ public abstract class FlinkKafkaConsumerBase<T> extends 
RichParallelSourceFuncti
        }
 
        @Override
+       public void open(Configuration configuration) {
+               List<KafkaTopicPartition> kafkaTopicPartitions = 
getKafkaPartitions(topics);
+
+               if (kafkaTopicPartitions != null) {
+                       assignTopicPartitions(kafkaTopicPartitions);
+               }
+       }
+
+       @Override
        public void close() throws Exception {
                // pretty much the same logic as cancelling
                try {
@@ -289,44 +313,76 @@ public abstract class FlinkKafkaConsumerBase<T> extends 
RichParallelSourceFuncti
        // 
------------------------------------------------------------------------
        //  Checkpoint and restore
        // 
------------------------------------------------------------------------
-       
+
+
        @Override
-       public HashMap<KafkaTopicPartition, Long> snapshotState(long 
checkpointId, long checkpointTimestamp) throws Exception {
-               if (!running) {
-                       LOG.debug("snapshotState() called on closed source");
-                       return null;
-               }
-               
-               final AbstractFetcher<?, ?> fetcher = this.kafkaFetcher;
-               if (fetcher == null) {
-                       // the fetcher has not yet been initialized, which 
means we need to return the
-                       // originally restored offsets
-                       return restoreToOffset;
-               }
+       public void initializeState(OperatorStateStore stateStore) throws 
Exception {
 
-               HashMap<KafkaTopicPartition, Long> currentOffsets = 
fetcher.snapshotCurrentState();
+               this.stateStore = stateStore;
 
-               if (LOG.isDebugEnabled()) {
-                       LOG.debug("Snapshotting state. Offsets: {}, checkpoint 
id: {}, timestamp: {}",
-                                       
KafkaTopicPartition.toString(currentOffsets), checkpointId, 
checkpointTimestamp);
-               }
+               ListState<Serializable> offsets = 
stateStore.getPartitionableState(ListCheckpointed.DEFAULT_LIST_DESCRIPTOR);
 
-               // the map cannot be asynchronously updated, because only one 
checkpoint call can happen
-               // on this function at a time: either snapshotState() or 
notifyCheckpointComplete()
-               pendingCheckpoints.put(checkpointId, currentOffsets);
-               
-               // truncate the map, to prevent infinite growth
-               while (pendingCheckpoints.size() > MAX_NUM_PENDING_CHECKPOINTS) 
{
-                       pendingCheckpoints.remove(0);
+               restoreToOffset = new HashMap<>();
+
+               for (Serializable serializable : offsets.get()) {
+                       @SuppressWarnings("unchecked")
+                       Tuple2<KafkaTopicPartition, Long> kafkaOffset = 
(Tuple2<KafkaTopicPartition, Long>) serializable;
+                       restoreToOffset.put(kafkaOffset.f0, kafkaOffset.f1);
                }
 
-               return currentOffsets;
+               LOG.info("Setting restore state in the FlinkKafkaConsumer: {}", 
restoreToOffset);
        }
 
        @Override
-       public void restoreState(HashMap<KafkaTopicPartition, Long> 
restoredOffsets) {
-               LOG.info("Setting restore state in the FlinkKafkaConsumer: {}", 
restoredOffsets);
-               restoreToOffset = restoredOffsets;
+       public void prepareSnapshot(long checkpointId, long timestamp) throws 
Exception {
+               if (!running) {
+                       LOG.debug("storeOperatorState() called on closed 
source");
+               } else {
+
+                       ListState<Serializable> listState = 
stateStore.getPartitionableState(ListCheckpointed.DEFAULT_LIST_DESCRIPTOR);
+
+                       listState.clear();
+
+                       final AbstractFetcher<?, ?> fetcher = this.kafkaFetcher;
+                       if (fetcher == null) {
+                               // the fetcher has not yet been initialized, 
which means we need to return the
+                               // originally restored offsets or the assigned 
partitions
+
+                               if (restoreToOffset != null) {
+                                       // the map cannot be asynchronously 
updated, because only one checkpoint call can happen
+                                       // on this function at a time: either 
snapshotState() or notifyCheckpointComplete()
+                                       pendingCheckpoints.put(checkpointId, 
restoreToOffset);
+
+                                       // truncate the map, to prevent 
infinite growth
+                                       while (pendingCheckpoints.size() > 
MAX_NUM_PENDING_CHECKPOINTS) {
+                                               pendingCheckpoints.remove(0);
+                                       }
+
+                                       for (Map.Entry<KafkaTopicPartition, 
Long> kafkaTopicPartitionLongEntry : restoreToOffset.entrySet()) {
+                                               
listState.add(Tuple2.of(kafkaTopicPartitionLongEntry.getKey(), 
kafkaTopicPartitionLongEntry.getValue()));
+                                       }
+                               } else if (subscribedPartitions != null) {
+                                       for (KafkaTopicPartition 
subscribedPartition : subscribedPartitions) {
+                                               
listState.add(Tuple2.of(subscribedPartition, 
KafkaTopicPartitionState.OFFSET_NOT_SET));
+                                       }
+                               }
+                       } else {
+                               HashMap<KafkaTopicPartition, Long> 
currentOffsets = fetcher.snapshotCurrentState();
+
+                               // the map cannot be asynchronously updated, 
because only one checkpoint call can happen
+                               // on this function at a time: either 
snapshotState() or notifyCheckpointComplete()
+                               pendingCheckpoints.put(checkpointId, 
currentOffsets);
+
+                               // truncate the map, to prevent infinite growth
+                               while (pendingCheckpoints.size() > 
MAX_NUM_PENDING_CHECKPOINTS) {
+                                       pendingCheckpoints.remove(0);
+                               }
+
+                               for (Map.Entry<KafkaTopicPartition, Long> 
kafkaTopicPartitionLongEntry : currentOffsets.entrySet()) {
+                                       
listState.add(Tuple2.of(kafkaTopicPartitionLongEntry.getKey(), 
kafkaTopicPartitionLongEntry.getValue()));
+                               }
+                       }
+               }
        }
 
        @Override
@@ -401,6 +457,8 @@ public abstract class FlinkKafkaConsumerBase<T> extends 
RichParallelSourceFuncti
                        SerializedValue<AssignerWithPeriodicWatermarks<T>> 
watermarksPeriodic,
                        SerializedValue<AssignerWithPunctuatedWatermarks<T>> 
watermarksPunctuated,
                        StreamingRuntimeContext runtimeContext) throws 
Exception;
+
+       protected abstract List<KafkaTopicPartition> 
getKafkaPartitions(List<String> topics);
        
        // 
------------------------------------------------------------------------
        //  ResultTypeQueryable methods 
@@ -415,6 +473,35 @@ public abstract class FlinkKafkaConsumerBase<T> extends 
RichParallelSourceFuncti
        //  Utilities
        // 
------------------------------------------------------------------------
 
+       private void assignTopicPartitions(List<KafkaTopicPartition> 
kafkaTopicPartitions) {
+               subscribedPartitions = new ArrayList<>();
+
+               if (restoreToOffset != null) {
+                       for (KafkaTopicPartition kafkaTopicPartition : 
kafkaTopicPartitions) {
+                               if 
(restoreToOffset.containsKey(kafkaTopicPartition)) {
+                                       
subscribedPartitions.add(kafkaTopicPartition);
+                               }
+                       }
+               } else {
+                       Collections.sort(kafkaTopicPartitions, new 
Comparator<KafkaTopicPartition>() {
+                               @Override
+                               public int compare(KafkaTopicPartition o1, 
KafkaTopicPartition o2) {
+                                       int topicComparison = 
o1.getTopic().compareTo(o2.getTopic());
+
+                                       if (topicComparison == 0) {
+                                               return o1.getPartition() - 
o2.getPartition();
+                                       } else {
+                                               return topicComparison;
+                                       }
+                               }
+                       });
+
+                       for (int i = 
getRuntimeContext().getIndexOfThisSubtask(); i < kafkaTopicPartitions.size(); i 
+= getRuntimeContext().getNumberOfParallelSubtasks()) {
+                               
subscribedPartitions.add(kafkaTopicPartitions.get(i));
+                       }
+               }
+       }
+
        /**
         * Selects which of the given partitions should be handled by a 
specific consumer,
         * given a certain number of consumers.
@@ -427,8 +514,7 @@ public abstract class FlinkKafkaConsumerBase<T> extends 
RichParallelSourceFuncti
         */
        protected static List<KafkaTopicPartition> assignPartitions(
                        List<KafkaTopicPartition> allPartitions,
-                       int numConsumers, int consumerIndex)
-       {
+                       int numConsumers, int consumerIndex) {
                final List<KafkaTopicPartition> thisSubtaskPartitions = new 
ArrayList<>(
                                allPartitions.size() / numConsumers + 1);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
 
b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
index e63f033..8b87004 100644
--- 
a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
+++ 
b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
@@ -20,16 +20,16 @@ package org.apache.flink.streaming.connectors.kafka;
 import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.api.java.ClosureCleaner;
 import org.apache.flink.configuration.Configuration;
-import org.apache.flink.runtime.util.SerializableObject;
 import org.apache.flink.metrics.MetricGroup;
-import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.runtime.state.OperatorStateStore;
+import org.apache.flink.runtime.util.SerializableObject;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
 import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
 import 
org.apache.flink.streaming.connectors.kafka.internals.metrics.KafkaMetricWrapper;
 import 
org.apache.flink.streaming.connectors.kafka.partitioner.KafkaPartitioner;
 import org.apache.flink.streaming.util.serialization.KeyedSerializationSchema;
 import org.apache.flink.util.NetUtils;
-
 import org.apache.kafka.clients.producer.Callback;
 import org.apache.kafka.clients.producer.KafkaProducer;
 import org.apache.kafka.clients.producer.Producer;
@@ -40,11 +40,9 @@ import org.apache.kafka.common.Metric;
 import org.apache.kafka.common.MetricName;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.serialization.ByteArraySerializer;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.Serializable;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
@@ -61,7 +59,7 @@ import static java.util.Objects.requireNonNull;
  *
  * @param <IN> Type of the messages to write into Kafka.
  */
-public abstract class FlinkKafkaProducerBase<IN> extends RichSinkFunction<IN> 
implements Checkpointed<Serializable> {
+public abstract class FlinkKafkaProducerBase<IN> extends RichSinkFunction<IN> 
implements CheckpointedFunction {
 
        private static final Logger LOG = 
LoggerFactory.getLogger(FlinkKafkaProducerBase.class);
 
@@ -126,6 +124,8 @@ public abstract class FlinkKafkaProducerBase<IN> extends 
RichSinkFunction<IN> im
        /** Number of unacknowledged records. */
        protected long pendingRecords;
 
+       protected OperatorStateStore stateStore;
+
 
        /**
         * The main constructor for creating a FlinkKafkaProducer.
@@ -330,7 +330,12 @@ public abstract class FlinkKafkaProducerBase<IN> extends 
RichSinkFunction<IN> im
        protected abstract void flush();
 
        @Override
-       public Serializable snapshotState(long checkpointId, long 
checkpointTimestamp) {
+       public void initializeState(OperatorStateStore stateStore) throws 
Exception {
+               this.stateStore = stateStore;
+       }
+
+       @Override
+       public void prepareSnapshot(long checkpointId, long timestamp) throws 
Exception {
                if (flushOnCheckpoint) {
                        // flushing is activated: We need to wait until 
pendingRecords is 0
                        flush();
@@ -341,16 +346,8 @@ public abstract class FlinkKafkaProducerBase<IN> extends 
RichSinkFunction<IN> im
                                // pending records count is 0. We can now 
confirm the checkpoint
                        }
                }
-               // return empty state
-               return null;
-       }
-
-       @Override
-       public void restoreState(Serializable state) {
-               // nothing to do here
        }
 
-
        // ----------------------------------- Utilities 
--------------------------
 
        protected void checkErroneous() throws Exception {

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractFetcher.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractFetcher.java
 
b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractFetcher.java
index 9255445..7ce3a9d 100644
--- 
a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractFetcher.java
+++ 
b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractFetcher.java
@@ -183,9 +183,7 @@ public abstract class AbstractFetcher<T, KPH> {
 
                HashMap<KafkaTopicPartition, Long> state = new 
HashMap<>(allPartitions.length);
                for (KafkaTopicPartitionState<?> partition : 
subscribedPartitions()) {
-                       if (partition.isOffsetDefined()) {
-                               state.put(partition.getKafkaTopicPartition(), 
partition.getOffset());
-                       }
+                       state.put(partition.getKafkaTopicPartition(), 
partition.getOffset());
                }
                return state;
        }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
 
b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
index b02593c..766a107 100644
--- 
a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
+++ 
b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.streaming.connectors.kafka;
 
 import org.apache.flink.api.java.tuple.Tuple1;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.OperatorStateStore;
 import 
org.apache.flink.streaming.connectors.kafka.testutils.MockRuntimeContext;
 import org.apache.flink.streaming.util.serialization.KeyedSerializationSchema;
 import 
org.apache.flink.streaming.util.serialization.KeyedSerializationSchemaWrapper;
@@ -37,7 +38,6 @@ import org.junit.Test;
 import scala.concurrent.duration.Deadline;
 import scala.concurrent.duration.FiniteDuration;
 
-import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
@@ -45,6 +45,8 @@ import java.util.Properties;
 import java.util.concurrent.Future;
 import java.util.concurrent.atomic.AtomicBoolean;
 
+import static org.mockito.Mockito.mock;
+
 /**
  * Test ensuring that the producer is not dropping buffered records
  */
@@ -111,7 +113,7 @@ public class AtLeastOnceProducerTest {
                Thread threadB = new Thread(confirmer);
                threadB.start();
                // this should block:
-               producer.snapshotState(0, 0);
+               producer.prepareSnapshot(0, 0);
                synchronized (threadA) {
                        threadA.notifyAll(); // just in case, to let the test 
fail faster
                }
@@ -130,6 +132,8 @@ public class AtLeastOnceProducerTest {
 
 
        private static class TestingKafkaProducer<T> extends 
FlinkKafkaProducerBase<T> {
+               private static final long serialVersionUID = 
-1759403646061180067L;
+
                private MockProducer prod;
                private AtomicBoolean snapshottingFinished;
 
@@ -145,12 +149,11 @@ public class AtLeastOnceProducerTest {
                }
 
                @Override
-               public Serializable snapshotState(long checkpointId, long 
checkpointTimestamp) {
+               public void prepareSnapshot(long checkpointId, long timestamp) 
throws Exception {
                        // call the actual snapshot state
-                       Serializable ret = super.snapshotState(checkpointId, 
checkpointTimestamp);
+                       super.prepareSnapshot(checkpointId, timestamp);
                        // notify test that snapshotting has been done
                        snapshottingFinished.set(true);
-                       return ret;
                }
 
                @Override

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
 
b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
index 9b517df..fc8b7e9 100644
--- 
a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
+++ 
b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
@@ -19,6 +19,11 @@
 package org.apache.flink.streaming.connectors.kafka;
 
 import org.apache.commons.collections.map.LinkedMap;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.state.OperatorStateStore;
 import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
 import 
org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
 import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
@@ -26,15 +31,26 @@ import 
org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
 import 
org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
 import 
org.apache.flink.streaming.util.serialization.KeyedDeserializationSchema;
 import org.apache.flink.util.SerializedValue;
-
 import org.junit.Test;
+import org.mockito.Matchers;
 
 import java.lang.reflect.Field;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 public class FlinkKafkaConsumerBaseTest {
 
@@ -82,7 +98,13 @@ public class FlinkKafkaConsumerBaseTest {
                final AbstractFetcher<String, ?> fetcher = 
mock(AbstractFetcher.class);
 
                FlinkKafkaConsumerBase<String> consumer = getConsumer(fetcher, 
new LinkedMap(), false);
-               assertNull(consumer.snapshotState(17L, 23L));
+               OperatorStateStore operatorStateStore = 
mock(OperatorStateStore.class);
+               TestingListState<Tuple2<KafkaTopicPartition, Long>> listState = 
new TestingListState<>();
+               
when(operatorStateStore.getPartitionableState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+               consumer.prepareSnapshot(17L, 17L);
+
+               assertFalse(listState.get().iterator().hasNext());
                consumer.notifyCheckpointComplete(66L);
        }
 
@@ -91,14 +113,37 @@ public class FlinkKafkaConsumerBaseTest {
         */
        @Test
        public void checkRestoredCheckpointWhenFetcherNotReady() throws 
Exception {
-               HashMap<KafkaTopicPartition, Long> restoreState = new 
HashMap<>();
-               restoreState.put(new KafkaTopicPartition("abc", 13), 16768L);
-               restoreState.put(new KafkaTopicPartition("def", 7), 987654321L);
+               OperatorStateStore operatorStateStore = 
mock(OperatorStateStore.class);
+
+               TestingListState<Tuple2<KafkaTopicPartition, Long>> 
expectedState = new TestingListState<>();
+               expectedState.add(Tuple2.of(new KafkaTopicPartition("abc", 13), 
16768L));
+               expectedState.add(Tuple2.of(new KafkaTopicPartition("def", 7), 
987654321L));
+
+               TestingListState<Tuple2<KafkaTopicPartition, Long>> listState = 
new TestingListState<>();
 
                FlinkKafkaConsumerBase<String> consumer = getConsumer(null, new 
LinkedMap(), true);
-               consumer.restoreState(restoreState);
-               
-               assertEquals(restoreState, consumer.snapshotState(17L, 23L));
+
+               
when(operatorStateStore.getPartitionableState(Matchers.any(ListStateDescriptor.class))).thenReturn(expectedState);
+               consumer.initializeState(operatorStateStore);
+
+               
when(operatorStateStore.getPartitionableState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+               consumer.prepareSnapshot(17L, 17L);
+
+               Set<Tuple2<KafkaTopicPartition, Long>> expected = new 
HashSet<Tuple2<KafkaTopicPartition, Long>>();
+
+               for (Tuple2<KafkaTopicPartition, Long> 
kafkaTopicPartitionLongTuple2 : expectedState.get()) {
+                       expected.add(kafkaTopicPartitionLongTuple2);
+               }
+
+               int counter = 0;
+
+               for (Tuple2<KafkaTopicPartition, Long> 
kafkaTopicPartitionLongTuple2 : listState.get()) {
+                       
assertTrue(expected.contains(kafkaTopicPartitionLongTuple2));
+                       counter++;
+               }
+
+               assertEquals(expected.size(), counter);
        }
 
        /**
@@ -107,7 +152,15 @@ public class FlinkKafkaConsumerBaseTest {
        @Test
        public void checkRestoredNullCheckpointWhenFetcherNotReady() throws 
Exception {
                FlinkKafkaConsumerBase<String> consumer = getConsumer(null, new 
LinkedMap(), true);
-               assertNull(consumer.snapshotState(17L, 23L));
+
+               OperatorStateStore operatorStateStore = 
mock(OperatorStateStore.class);
+               TestingListState<Tuple2<KafkaTopicPartition, Long>> listState = 
new TestingListState<>();
+               
when(operatorStateStore.getPartitionableState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+               consumer.initializeState(operatorStateStore);
+               consumer.prepareSnapshot(17L, 17L);
+
+               assertFalse(listState.get().iterator().hasNext());
        }
        
        @Test
@@ -132,15 +185,40 @@ public class FlinkKafkaConsumerBaseTest {
        
                FlinkKafkaConsumerBase<String> consumer = getConsumer(fetcher, 
pendingCheckpoints, true);
                assertEquals(0, pendingCheckpoints.size());
-               
+
+               OperatorStateStore backend = mock(OperatorStateStore.class);
+
+               TestingListState<Tuple2<KafkaTopicPartition, Long>> listState1 
= new TestingListState<>();
+               TestingListState<Tuple2<KafkaTopicPartition, Long>> listState2 
= new TestingListState<>();
+               TestingListState<Tuple2<KafkaTopicPartition, Long>> listState3 
= new TestingListState<>();
+
+               
when(backend.getPartitionableState(Matchers.any(ListStateDescriptor.class))).
+                               thenReturn(listState1, listState1, listState2, 
listState2, listState3, listState3);
+
+               consumer.initializeState(backend);
+
                // checkpoint 1
-               HashMap<KafkaTopicPartition, Long> snapshot1 = 
consumer.snapshotState(138L, 19L);
+               consumer.prepareSnapshot(138L, 138L);
+
+               HashMap<KafkaTopicPartition, Long> snapshot1 = new HashMap<>();
+
+               for (Tuple2<KafkaTopicPartition, Long> 
kafkaTopicPartitionLongTuple2 : listState1.get()) {
+                       snapshot1.put(kafkaTopicPartitionLongTuple2.f0, 
kafkaTopicPartitionLongTuple2.f1);
+               }
+
                assertEquals(state1, snapshot1);
                assertEquals(1, pendingCheckpoints.size());
                assertEquals(state1, pendingCheckpoints.get(138L));
 
                // checkpoint 2
-               HashMap<KafkaTopicPartition, Long> snapshot2 = 
consumer.snapshotState(140L, 1578L);
+               consumer.prepareSnapshot(140L, 140L);
+
+               HashMap<KafkaTopicPartition, Long> snapshot2 = new HashMap<>();
+
+               for (Tuple2<KafkaTopicPartition, Long> 
kafkaTopicPartitionLongTuple2 : listState2.get()) {
+                       snapshot2.put(kafkaTopicPartitionLongTuple2.f0, 
kafkaTopicPartitionLongTuple2.f1);
+               }
+
                assertEquals(state2, snapshot2);
                assertEquals(2, pendingCheckpoints.size());
                assertEquals(state2, pendingCheckpoints.get(140L));
@@ -151,7 +229,14 @@ public class FlinkKafkaConsumerBaseTest {
                assertTrue(pendingCheckpoints.containsKey(140L));
 
                // checkpoint 3
-               HashMap<KafkaTopicPartition, Long> snapshot3 = 
consumer.snapshotState(141L, 1578L);
+               consumer.prepareSnapshot(141L, 141L);
+
+               HashMap<KafkaTopicPartition, Long> snapshot3 = new HashMap<>();
+
+               for (Tuple2<KafkaTopicPartition, Long> 
kafkaTopicPartitionLongTuple2 : listState1.get()) {
+                       snapshot1.put(kafkaTopicPartitionLongTuple2.f0, 
kafkaTopicPartitionLongTuple2.f1);
+               }
+
                assertEquals(state3, snapshot3);
                assertEquals(2, pendingCheckpoints.size());
                assertEquals(state3, pendingCheckpoints.get(141L));
@@ -164,9 +249,14 @@ public class FlinkKafkaConsumerBaseTest {
                consumer.notifyCheckpointComplete(666); // invalid checkpoint
                assertEquals(0, pendingCheckpoints.size());
 
+               OperatorStateStore operatorStateStore = 
mock(OperatorStateStore.class);
+               TestingListState<Tuple2<KafkaTopicPartition, Long>> listState = 
new TestingListState<>();
+               
when(operatorStateStore.getPartitionableState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
                // create 500 snapshots
                for (int i = 100; i < 600; i++) {
-                       consumer.snapshotState(i, 15 * i);
+                       consumer.prepareSnapshot(i, i);
+                       listState.clear();
                }
                
assertEquals(FlinkKafkaConsumerBase.MAX_NUM_PENDING_CHECKPOINTS, 
pendingCheckpoints.size());
 
@@ -211,12 +301,37 @@ public class FlinkKafkaConsumerBaseTest {
 
                @SuppressWarnings("unchecked")
                public DummyFlinkKafkaConsumer() {
-                       super((KeyedDeserializationSchema<T>) 
mock(KeyedDeserializationSchema.class));
+                       super(Arrays.asList("abc", "def"), 
(KeyedDeserializationSchema < T >) mock(KeyedDeserializationSchema.class));
                }
 
                @Override
                protected AbstractFetcher<T, ?> createFetcher(SourceContext<T> 
sourceContext, List<KafkaTopicPartition> thisSubtaskPartitions, 
SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic, 
SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated, 
StreamingRuntimeContext runtimeContext) throws Exception {
                        return null;
                }
+
+               @Override
+               protected List<KafkaTopicPartition> 
getKafkaPartitions(List<String> topics) {
+                       return Collections.emptyList();
+               }
+       }
+
+       private static final class TestingListState<T> implements ListState<T> {
+
+               private final List<T> list = new ArrayList<>();
+
+               @Override
+               public void clear() {
+                       list.clear();
+               }
+
+               @Override
+               public Iterable<T> get() throws Exception {
+                       return list;
+               }
+
+               @Override
+               public void add(T value) throws Exception {
+                       list.add(value);
+               }
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java
 
b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java
index a87ff8a..9c36b43 100644
--- 
a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java
+++ 
b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java
@@ -68,7 +68,6 @@ import 
org.apache.flink.streaming.api.functions.sink.SinkFunction;
 import 
org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
 import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
-import org.apache.flink.streaming.api.functions.source.StatefulSequenceSource;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.watermark.Watermark;
@@ -92,7 +91,6 @@ import org.apache.flink.test.util.SuccessException;
 import org.apache.flink.testutils.junit.RetryOnException;
 import org.apache.flink.testutils.junit.RetryRule;
 import org.apache.flink.util.Collector;
-import org.apache.flink.util.StringUtils;
 import org.apache.kafka.clients.producer.ProducerConfig;
 import org.apache.kafka.common.errors.TimeoutException;
 import org.junit.Assert;
@@ -186,15 +184,27 @@ public abstract class KafkaConsumerTestBase extends 
KafkaTestBase {
                        DataStream<String> stream = see.addSource(source);
                        stream.print();
                        see.execute("No broker test");
-               } catch(RuntimeException re) {
+               } catch(ProgramInvocationException pie) {
                        if(kafkaServer.getVersion().equals("0.9")) {
-                               Assert.assertTrue("Wrong RuntimeException 
thrown: " + StringUtils.stringifyException(re),
-                                               
re.getClass().equals(TimeoutException.class) &&
-                                                               
re.getMessage().contains("Timeout expired while fetching topic metadata"));
+                               assertTrue(pie.getCause() instanceof 
JobExecutionException);
+
+                               JobExecutionException jee = 
(JobExecutionException) pie.getCause();
+
+                               assertTrue(jee.getCause() instanceof 
TimeoutException);
+
+                               TimeoutException te = (TimeoutException) 
jee.getCause();
+
+                               assertEquals("Timeout expired while fetching 
topic metadata", te.getMessage());
                        } else {
-                               Assert.assertTrue("Wrong RuntimeException 
thrown: " + StringUtils.stringifyException(re),
-                                               
re.getClass().equals(RuntimeException.class) &&
-                                                               
re.getMessage().contains("Unable to retrieve any partitions for the requested 
topics [doesntexist]"));
+                               assertTrue(pie.getCause() instanceof 
JobExecutionException);
+
+                               JobExecutionException jee = 
(JobExecutionException) pie.getCause();
+
+                               assertTrue(jee.getCause() instanceof 
RuntimeException);
+
+                               RuntimeException re = (RuntimeException) 
jee.getCause();
+
+                               assertTrue(re.getMessage().contains("Unable to 
retrieve any partitions for the requested topics [doesntexist]"));
                        }
                }
        }
@@ -413,7 +423,7 @@ public abstract class KafkaConsumerTestBase extends 
KafkaTestBase {
                DataGenerators.generateRandomizedIntegerSequence(
                                
StreamExecutionEnvironment.createRemoteEnvironment("localhost", flinkPort),
                                kafkaServer,
-                               topic, numPartitions, numElementsPerPartition, 
true);
+                               topic, numPartitions, numElementsPerPartition, 
false);
 
                // run the topology that fails and recovers
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java
 
b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java
index da2c652..5be4195 100644
--- 
a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java
+++ 
b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java
@@ -173,16 +173,6 @@ public class MockRuntimeContext extends 
StreamingRuntimeContext {
        }
 
        @Override
-       public <S> org.apache.flink.api.common.state.OperatorState<S> 
getKeyValueState(String name, Class<S> stateType, S defaultState) {
-               throw new UnsupportedOperationException();
-       }
-
-       @Override
-       public <S> org.apache.flink.api.common.state.OperatorState<S> 
getKeyValueState(String name, TypeInformation<S> stateType, S defaultState) {
-               throw new UnsupportedOperationException();
-       }
-
-       @Override
        public <T> ValueState<T> getState(ValueStateDescriptor<T> 
stateProperties) {
                throw new UnsupportedOperationException();
        }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
index 6e2850c..4a0fd60 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
@@ -36,6 +36,7 @@ import java.io.Serializable;
  * 
  * @param <T> The type of the operator state.
  */
+@Deprecated
 @PublicEvolving
 public interface Checkpointed<T extends Serializable> {
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java
new file mode 100644
index 0000000..2227201
--- /dev/null
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.checkpoint;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.runtime.state.OperatorStateStore;
+
+/**
+ *
+ * Similar to {@link Checkpointed}, this interface must be implemented by 
functions that have potentially
+ * repartitionable state that needs to be checkpointed. Methods from this 
interface are called upon checkpointing and
+ * restoring of state.
+ *
+ * <p>On {@link #initializeState} the implementing class receives the {@link 
org.apache.flink.runtime.state.OperatorStateStore}
+ * to store its state. At least before each snapshot, all persistent 
state must be stored in the state store.
+ *
+ * <p>When the state store is received for initialization, the user registers states 
with the store via
+ * {@link org.apache.flink.api.common.state.StateDescriptor}. Then, all 
previously stored state is found in the
+ * received {@link org.apache.flink.api.common.state.State} (currently only
+ * {@link org.apache.flink.api.common.state.ListState} is supported).
+ *
+ * <p>In {@link #prepareSnapshot}, the implementing class must ensure that all operator 
state is passed to the operator state store,
+ * i.e. that the state was stored in the relevant {@link 
org.apache.flink.api.common.state.State} instances that
+ * are requested on restore. Notice that users might want to clear and 
reinsert the complete state first if incremental
+ * updates of the states are not possible.
+ */
+@PublicEvolving
+public interface CheckpointedFunction {
+
+       /**
+        *
+        * This method is called when state should be stored for a checkpoint. 
The state can be registered and written to
+        * the {@link OperatorStateStore} that was passed to {@link #initializeState}.
+        *
+        * @param checkpointId Id of the checkpoint to perform
+        * @param timestamp Timestamp of the checkpoint
+        * @throws Exception
+        */
+       void prepareSnapshot(long checkpointId, long timestamp) throws 
Exception;
+
+       /**
+        * This method is called when an operator is opened, so that the 
function can set the state store to which it
+        * hands its state on snapshot.
+        *
+        * @param stateStore the state store to which this function stores its 
state
+        * @throws Exception
+        */
+       void initializeState(OperatorStateStore stateStore) throws Exception;
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java
new file mode 100644
index 0000000..430b2b9
--- /dev/null
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.checkpoint;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.java.typeutils.runtime.JavaSerializer;
+
+import java.io.Serializable;
+import java.util.List;
+
+/**
+ * This interface must be implemented by functions that have state that needs to be
+ * checkpointed. The functions get a call whenever a checkpoint should take place
+ * and return a snapshot of their state as a list of redistributable sub-states,
+ * which will be checkpointed.
+ *
+ * @param <T> The type of the operator state.
+ */
+@PublicEvolving
+public interface ListCheckpointed<T extends Serializable> {
+
+       /**
+        * A shared list state descriptor with an empty name ({@code ""}) whose elements are serialized with
+        * {@link JavaSerializer}. As an interface field it is implicitly {@code public static final}, so a single
+        * instance is shared by every user of this interface.
+        *
+        * <p>NOTE(review): presumably used by the runtime to register the list state of {@code ListCheckpointed}
+        * functions — confirm at the call sites, and confirm that sharing one descriptor instance (and one state name)
+        * across functions is intended.
+        */
+       ListStateDescriptor<Serializable> DEFAULT_LIST_DESCRIPTOR =
+                       new ListStateDescriptor<>("", new JavaSerializer<>());
+
+       /**
+        * Gets the current state of the function or operator. The state must reflect the result of all
+        * prior invocations to this function.
+        *
+        * @param checkpointId The ID of the checkpoint.
+        * @param timestamp Timestamp of the checkpoint.
+        * @return The operator state in a list of redistributable, atomic sub-states.
+        * @throws Exception Thrown if the creation of the state object failed. This causes the
+        *                   checkpoint to fail. The system may decide to fail the operation (and trigger
+        *                   recovery), or to discard this checkpoint attempt and to continue running
+        *                   and to try again with the next checkpoint attempt.
+        */
+       List<T> snapshotState(long checkpointId, long timestamp) throws Exception;
+
+       /**
+        * Restores the state of the function or operator to that of a previous checkpoint.
+        * This method is invoked when a function is executed as part of a recovery run.
+        *
+        * <p>Note that restoreState() is called before open().
+        *
+        * @param state The state to be restored as a list of atomic sub-states.
+        * @throws Exception Thrown if the state could not be restored.
+        */
+       void restoreState(List<T> state) throws Exception;
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java
index 0c0b81a..838bee6 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java
@@ -28,6 +28,7 @@ import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileInputSplit;
 import org.apache.flink.metrics.Counter;
+import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.OutputTypeConfigurable;
@@ -60,7 +61,7 @@ import static 
org.apache.flink.util.Preconditions.checkNotNull;
  */
 @Internal
 public class ContinuousFileReaderOperator<OUT, S extends Serializable> extends 
AbstractStreamOperator<OUT>
-       implements OneInputStreamOperator<FileInputSplit, OUT>, 
OutputTypeConfigurable<OUT> {
+       implements OneInputStreamOperator<FileInputSplit, OUT>, 
OutputTypeConfigurable<OUT>, StreamCheckpointedOperator {
 
        private static final long serialVersionUID = 1L;
 
@@ -374,7 +375,6 @@ public class ContinuousFileReaderOperator<OUT, S extends 
Serializable> extends A
 
        @Override
        public void snapshotState(FSDataOutputStream os, long checkpointId, 
long timestamp) throws Exception {
-               super.snapshotState(os, checkpointId, timestamp);
 
                final ObjectOutputStream oos = new ObjectOutputStream(os);
 
@@ -397,7 +397,6 @@ public class ContinuousFileReaderOperator<OUT, S extends 
Serializable> extends A
 
        @Override
        public void restoreState(FSDataInputStream is) throws Exception {
-               super.restoreState(is);
 
                final ObjectInputStream ois = new ObjectInputStream(is);
 

Reply via email to