Repository: flink
Updated Branches:
  refs/heads/master fde4f9097 -> e5b65a7fc


[FLINK-4821] Implement rescalable non-partitioned state for Kinesis Connector


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/a05b574c
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/a05b574c
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/a05b574c

Branch: refs/heads/master
Commit: a05b574cc68d3f652d11fece46c23bbc24f35430
Parents: fde4f90
Author: Tony Wei <[email protected]>
Authored: Wed Dec 14 10:18:25 2016 +0800
Committer: Tzu-Li (Gordon) Tai <[email protected]>
Committed: Sun May 7 16:28:52 2017 +0800

----------------------------------------------------------------------
 .../flink-connector-kinesis/pom.xml             |  16 +
 .../kinesis/FlinkKinesisConsumer.java           | 150 +++++--
 .../kinesis/internals/KinesisDataFetcher.java   |   2 +-
 .../FlinkKinesisConsumerMigrationTest.java      | 149 +++++++
 .../kinesis/FlinkKinesisConsumerTest.java       | 440 +++++++++++++++++--
 ...is-consumer-migration-test-flink1.1-snapshot | Bin 0 -> 1140 bytes
 ...sumer-migration-test-flink1.1-snapshot-empty | Bin 0 -> 468 bytes
 7 files changed, 690 insertions(+), 67 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/pom.xml
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/pom.xml 
b/flink-connectors/flink-connector-kinesis/pom.xml
index d457199..080626f 100644
--- a/flink-connectors/flink-connector-kinesis/pom.xml
+++ b/flink-connectors/flink-connector-kinesis/pom.xml
@@ -57,6 +57,14 @@ under the License.
 
                <dependency>
                        <groupId>org.apache.flink</groupId>
+                       <artifactId>flink-runtime_2.10</artifactId>
+                       <version>${project.version}</version>
+                       <type>test-jar</type>
+                       <scope>test</scope>
+               </dependency>
+
+               <dependency>
+                       <groupId>org.apache.flink</groupId>
                        <artifactId>flink-tests_2.10</artifactId>
                        <version>${project.version}</version>
                        <scope>test</scope>
@@ -65,6 +73,14 @@ under the License.
 
                <dependency>
                        <groupId>org.apache.flink</groupId>
+                       <artifactId>flink-streaming-java_2.10</artifactId>
+                       <version>${project.version}</version>
+                       <type>test-jar</type>
+                       <scope>test</scope>
+               </dependency>
+
+               <dependency>
+                       <groupId>org.apache.flink</groupId>
                        <artifactId>flink-test-utils_2.10</artifactId>
                        <version>${project.version}</version>
                        <scope>test</scope>

http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
----------------------------------------------------------------------
diff --git 
a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
 
b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
index a62dc10..dfcd552 100644
--- 
a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
+++ 
b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
@@ -17,15 +17,26 @@
 
 package org.apache.flink.streaming.connectors.kinesis;
 
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
 import org.apache.flink.configuration.Configuration;
-import org.apache.flink.streaming.api.checkpoint.CheckpointedAsynchronously;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring;
 import 
org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import 
org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
-import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
 import 
org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
+import 
org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber;
+import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
 import 
org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema;
 import 
org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchemaWrapper;
 import org.apache.flink.streaming.connectors.kinesis.util.KinesisConfigUtil;
@@ -55,8 +66,10 @@ import static 
org.apache.flink.util.Preconditions.checkNotNull;
  *
  * @param <T> the type of data emitted
  */
-public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T>
-       implements CheckpointedAsynchronously<HashMap<KinesisStreamShard, 
SequenceNumber>>, ResultTypeQueryable<T> {
+public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> 
implements
+       ResultTypeQueryable<T>,
+       CheckpointedFunction,
+       CheckpointedRestoring<HashMap<KinesisStreamShard, SequenceNumber>> {
 
        private static final long serialVersionUID = 4724006128720664870L;
 
@@ -91,6 +104,14 @@ public class FlinkKinesisConsumer<T> extends 
RichParallelSourceFunction<T>
 
        private volatile boolean running = true;
 
+       // 
------------------------------------------------------------------------
+       //  State for Checkpoint
+       // 
------------------------------------------------------------------------
+
+       /** The name is the key for sequence numbers state, and cannot be 
changed. */
+       private static final String sequenceNumsStateStoreName = 
"Kinesis-Stream-Shard-State";
+
+       private transient ListState<Tuple2<KinesisStreamShard, SequenceNumber>> 
sequenceNumsStateForCheckpoint;
 
        // 
------------------------------------------------------------------------
        //  Constructors
@@ -194,8 +215,7 @@ public class FlinkKinesisConsumer<T> extends 
RichParallelSourceFunction<T>
                // all subtasks will run a fetcher, regardless of whether or 
not the subtask will initially have
                // shards to subscribe to; fetchers will continuously poll for 
changes in the shard list, so all subtasks
                // can potentially have new shards to subscribe to later on
-               fetcher = new KinesisDataFetcher<>(
-                       streams, sourceContext, getRuntimeContext(), 
configProps, deserializer);
+               fetcher = createFetcher(streams, sourceContext, 
getRuntimeContext(), configProps, deserializer);
 
                boolean isRestoringFromFailure = (sequenceNumsToRestore != 
null);
                fetcher.setIsRestoringFromFailure(isRestoringFromFailure);
@@ -203,17 +223,35 @@ public class FlinkKinesisConsumer<T> extends 
RichParallelSourceFunction<T>
                // if we are restoring from a checkpoint, we iterate over the 
restored
                // state and accordingly seed the fetcher with subscribed 
shards states
                if (isRestoringFromFailure) {
-                       for (Map.Entry<KinesisStreamShard, SequenceNumber> 
restored : lastStateSnapshot.entrySet()) {
-                               fetcher.advanceLastDiscoveredShardOfStream(
-                                       restored.getKey().getStreamName(), 
restored.getKey().getShard().getShardId());
-
-                               if (LOG.isInfoEnabled()) {
-                                       LOG.info("Subtask {} is seeding the 
fetcher with restored shard {}," +
-                                                       " starting state set to 
the restored sequence number {}",
-                                               
getRuntimeContext().getIndexOfThisSubtask(), restored.getKey().toString(), 
restored.getValue());
+                       // Some subtasks may not have finished discovering 
shards before the rescale, and
+                       // KinesisDataFetcher always discovers shards starting 
from the largest shard id. To avoid
+                       // missing shards that were not yet discovered and 
whose id is not the largest one, we force the
+                       // consumer to run discovery once from the smallest id 
and make sure each shard has its initial sequence
+                       // number set from the restored state or to 
SENTINEL_EARLIEST_SEQUENCE_NUM.
+                       List<KinesisStreamShard> 
newShardsCreatedWhileNotRunning = fetcher.discoverNewShardsToSubscribe();
+                       for (KinesisStreamShard shard : 
newShardsCreatedWhileNotRunning) {
+                               SequenceNumber startingStateForNewShard;
+
+                               if (lastStateSnapshot.containsKey(shard)) {
+                                       startingStateForNewShard = 
lastStateSnapshot.get(shard);
+
+                                       if (LOG.isInfoEnabled()) {
+                                               LOG.info("Subtask {} is seeding 
the fetcher with restored shard {}," +
+                                                               " starting 
state set to the restored sequence number {}",
+                                                       
getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), 
startingStateForNewShard);
+                                       }
+                               } else {
+                                       startingStateForNewShard = 
SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get();
+
+                                       if (LOG.isInfoEnabled()) {
+                                               LOG.info("Subtask {} is seeding 
the fetcher with new discovered shard {}," +
+                                                               " starting 
state set to the SENTINEL_EARLIEST_SEQUENCE_NUM",
+                                                       
getRuntimeContext().getIndexOfThisSubtask(), shard.toString());
+                                       }
                                }
+
                                fetcher.registerNewSubscribedShardState(
-                                       new 
KinesisStreamShardState(restored.getKey(), restored.getValue()));
+                                       new KinesisStreamShardState(shard, 
startingStateForNewShard));
                        }
                }
 
@@ -267,38 +305,78 @@ public class FlinkKinesisConsumer<T> extends 
RichParallelSourceFunction<T>
        // 
------------------------------------------------------------------------
 
        @Override
-       public HashMap<KinesisStreamShard, SequenceNumber> snapshotState(long 
checkpointId, long checkpointTimestamp) throws Exception {
-               if (lastStateSnapshot == null) {
-                       LOG.debug("snapshotState() requested on not yet opened 
source; returning null.");
-                       return null;
-               }
+       public void initializeState(FunctionInitializationContext context) 
throws Exception {
+               TypeInformation<Tuple2<KinesisStreamShard, SequenceNumber>> 
tuple = new TupleTypeInfo<>(
+                       TypeInformation.of(KinesisStreamShard.class),
+                       TypeInformation.of(SequenceNumber.class)
+               );
+
+               sequenceNumsStateForCheckpoint = 
context.getOperatorStateStore().getUnionListState(
+                       new ListStateDescriptor<>(sequenceNumsStateStoreName, 
tuple));
+
+               if (context.isRestored()) {
+                       if (sequenceNumsToRestore == null) {
+                               sequenceNumsToRestore = new HashMap<>();
+                               for (Tuple2<KinesisStreamShard, SequenceNumber> 
kinesisSequenceNumber : sequenceNumsStateForCheckpoint.get()) {
+                                       
sequenceNumsToRestore.put(kinesisSequenceNumber.f0, kinesisSequenceNumber.f1);
+                               }
 
-               if (fetcher == null) {
-                       LOG.debug("snapshotState() requested on not yet running 
source; returning null.");
-                       return null;
+                               LOG.info("Setting restore state in the 
FlinkKinesisConsumer. Using the following offsets: {}",
+                                       sequenceNumsToRestore);
+                       } else if (sequenceNumsToRestore.isEmpty()) {
+                               sequenceNumsToRestore = null;
+                       }
+               } else {
+                       LOG.info("No restore state for FlinkKinesisConsumer.");
                }
+       }
 
-               if (!running) {
+       @Override
+       public void snapshotState(FunctionSnapshotContext context) throws 
Exception {
+               if (lastStateSnapshot == null) {
+                       LOG.debug("snapshotState() requested on not yet opened 
source; returning null.");
+               } else if (fetcher == null) {
+                       LOG.debug("snapshotState() requested on not yet running 
source; returning null.");
+               } else if (!running) {
                        LOG.debug("snapshotState() called on closed source; 
returning null.");
-                       return null;
-               }
+               } else {
+                       if (LOG.isDebugEnabled()) {
+                               LOG.debug("Snapshotting state ...");
+                       }
 
-               if (LOG.isDebugEnabled()) {
-                       LOG.debug("Snapshotting state ...");
-               }
+                       sequenceNumsStateForCheckpoint.clear();
+                       lastStateSnapshot = fetcher.snapshotState();
 
-               lastStateSnapshot = fetcher.snapshotState();
+                       if (LOG.isDebugEnabled()) {
+                               LOG.debug("Snapshotted state, last processed 
sequence numbers: {}, checkpoint id: {}, timestamp: {}",
+                                       lastStateSnapshot.toString(), 
context.getCheckpointId(), context.getCheckpointTimestamp());
+                       }
 
-               if (LOG.isDebugEnabled()) {
-                       LOG.debug("Snapshotted state, last processed sequence 
numbers: {}, checkpoint id: {}, timestamp: {}",
-                               lastStateSnapshot.toString(), checkpointId, 
checkpointTimestamp);
+                       for (Map.Entry<KinesisStreamShard, SequenceNumber> 
entry : lastStateSnapshot.entrySet()) {
+                               
sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue()));
+                       }
                }
-
-               return lastStateSnapshot;
        }
 
        @Override
        public void restoreState(HashMap<KinesisStreamShard, SequenceNumber> 
restoredState) throws Exception {
-               sequenceNumsToRestore = restoredState;
+               LOG.info("Subtask {} restoring offsets from an older Flink 
version: {}",
+                       getRuntimeContext().getIndexOfThisSubtask(), 
sequenceNumsToRestore);
+
+               sequenceNumsToRestore = restoredState.isEmpty() ? null : 
restoredState;
+       }
+
+       /** This method exists so that tests can override it to mock the 
KinesisDataFetcher used by the consumer. */
+       protected KinesisDataFetcher<T> createFetcher(List<String> streams,
+                                                                               
                        SourceFunction.SourceContext<T> sourceContext,
+                                                                               
                        RuntimeContext runtimeContext,
+                                                                               
                        Properties configProps,
+                                                                               
                        KinesisDeserializationSchema<T> deserializationSchema) {
+               return new KinesisDataFetcher<>(streams, sourceContext, 
runtimeContext, configProps, deserializationSchema);
+       }
+
+       @VisibleForTesting
+       HashMap<KinesisStreamShard, SequenceNumber> getRestoredState() {
+               return sequenceNumsToRestore;
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
----------------------------------------------------------------------
diff --git 
a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
 
b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
index 8f7ca6c..c5b4b04 100644
--- 
a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
+++ 
b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
@@ -461,7 +461,7 @@ public class KinesisDataFetcher<T> {
         * 3. Update the subscribedStreamsToLastDiscoveredShardIds state so 
that we won't get shards
         *    that we have already seen before the next time this function is 
called
         */
-       private List<KinesisStreamShard> discoverNewShardsToSubscribe() throws 
InterruptedException {
+       public List<KinesisStreamShard> discoverNewShardsToSubscribe() throws 
InterruptedException {
 
                List<KinesisStreamShard> newShardsToSubscribe = new 
LinkedList<>();
 

http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
----------------------------------------------------------------------
diff --git 
a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
 
b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
new file mode 100644
index 0000000..2f46e09
--- /dev/null
+++ 
b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kinesis;
+
+import com.amazonaws.services.kinesis.model.Shard;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.operators.StreamSource;
+import 
org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
+import 
org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
+import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
+import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
+import 
org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema;
+import 
org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGenerator;
+import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
+import org.junit.Test;
+
+import java.net.URL;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Properties;
+
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests for checking whether {@link FlinkKinesisConsumer} can restore from 
snapshots that were
+ * done using the Flink 1.1 {@link FlinkKinesisConsumer}.
+ *
+ * <p>For regenerating the binary snapshot file you have to run the commented 
out portion
+ * of each test on a checkout of the Flink 1.1 branch.
+ */
+public class FlinkKinesisConsumerMigrationTest {
+
+       @Test
+       public void testRestoreFromFlink11WithEmptyState() throws Exception {
+               Properties testConfig = new Properties();
+               testConfig.setProperty(ConsumerConfigConstants.AWS_REGION, 
"us-east-1");
+               
testConfig.setProperty(ConsumerConfigConstants.AWS_CREDENTIALS_PROVIDER, 
"BASIC");
+               
testConfig.setProperty(ConsumerConfigConstants.AWS_ACCESS_KEY_ID, 
"accessKeyId");
+               
testConfig.setProperty(ConsumerConfigConstants.AWS_SECRET_ACCESS_KEY, 
"secretKey");
+
+               final DummyFlinkKafkaConsumer<String> consumerFunction = new 
DummyFlinkKafkaConsumer<>(testConfig);
+
+               StreamSource<String, DummyFlinkKafkaConsumer<String>> 
consumerOperator = new StreamSource<>(consumerFunction);
+
+               final AbstractStreamOperatorTestHarness<String> testHarness =
+                       new 
AbstractStreamOperatorTestHarness<>(consumerOperator, 1, 1, 0);
+
+               
testHarness.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+
+               testHarness.setup();
+               // restore state from binary snapshot file using legacy method
+               testHarness.initializeStateFromLegacyCheckpoint(
+                       
getResourceFilename("kinesis-consumer-migration-test-flink1.1-snapshot-empty"));
+               testHarness.open();
+
+               // assert that no state was restored
+               assertEquals(null, consumerFunction.getRestoredState());
+
+               consumerOperator.close();
+               consumerOperator.cancel();
+       }
+
+       @Test
+       public void testRestoreFromFlink11() throws Exception {
+               Properties testConfig = new Properties();
+               testConfig.setProperty(ConsumerConfigConstants.AWS_REGION, 
"us-east-1");
+               
testConfig.setProperty(ConsumerConfigConstants.AWS_CREDENTIALS_PROVIDER, 
"BASIC");
+               
testConfig.setProperty(ConsumerConfigConstants.AWS_ACCESS_KEY_ID, 
"accessKeyId");
+               
testConfig.setProperty(ConsumerConfigConstants.AWS_SECRET_ACCESS_KEY, 
"secretKey");
+
+               final DummyFlinkKafkaConsumer<String> consumerFunction = new 
DummyFlinkKafkaConsumer<>(testConfig);
+
+               StreamSource<String, DummyFlinkKafkaConsumer<String>> 
consumerOperator =
+                       new StreamSource<>(consumerFunction);
+
+               final AbstractStreamOperatorTestHarness<String> testHarness =
+                       new 
AbstractStreamOperatorTestHarness<>(consumerOperator, 1, 1, 0);
+
+               
testHarness.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+
+               testHarness.setup();
+               // restore state from binary snapshot file using legacy method
+               testHarness.initializeStateFromLegacyCheckpoint(
+                       
getResourceFilename("kinesis-consumer-migration-test-flink1.1-snapshot"));
+               testHarness.open();
+
+               // the expected state in 
"kinesis-consumer-migration-test-flink1.1-snapshot"
+               final HashMap<KinesisStreamShard, SequenceNumber> expectedState 
= new HashMap<>();
+               expectedState.put(new KinesisStreamShard("fakeStream1",
+                               new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
+                       new SequenceNumber("987654321"));
+
+               // assert that state is correctly restored from legacy 
checkpoint
+               assertNotEquals(null, consumerFunction.getRestoredState());
+               assertEquals(1, consumerFunction.getRestoredState().size());
+               assertEquals(expectedState, 
consumerFunction.getRestoredState());
+
+               consumerOperator.close();
+               consumerOperator.cancel();
+       }
+
+       // 
------------------------------------------------------------------------
+
+       private static String getResourceFilename(String filename) {
+               ClassLoader cl = 
FlinkKinesisConsumerMigrationTest.class.getClassLoader();
+               URL resource = cl.getResource(filename);
+               if (resource == null) {
+                       throw new NullPointerException("Missing snapshot 
resource.");
+               }
+               return resource.getFile();
+       }
+
+       private static class DummyFlinkKafkaConsumer<T> extends 
FlinkKinesisConsumer<T> {
+               private static final long serialVersionUID = 1L;
+
+               @SuppressWarnings("unchecked")
+               DummyFlinkKafkaConsumer(Properties properties) {
+                       super("test", mock(KinesisDeserializationSchema.class), 
properties);
+               }
+
+               @Override
+               protected KinesisDataFetcher<T> createFetcher(List<String> 
streams,
+                                                                               
                                SourceFunction.SourceContext<T> sourceContext,
+                                                                               
                                RuntimeContext runtimeContext,
+                                                                               
                                Properties configProps,
+                                                                               
                                KinesisDeserializationSchema<T> 
deserializationSchema) {
+                       return mock(KinesisDataFetcher.class);
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
----------------------------------------------------------------------
diff --git 
a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
 
b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
index 741f0ca..bf8e44f 100644
--- 
a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
+++ 
b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
@@ -18,13 +18,22 @@
 package org.apache.flink.streaming.connectors.kinesis;
 
 import com.amazonaws.services.kinesis.model.Shard;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.connectors.kinesis.config.AWSConfigConstants;
 import 
org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
 import 
org.apache.flink.streaming.connectors.kinesis.config.ProducerConfigConstants;
 import 
org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
+import 
org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber;
 import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
 import 
org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
 import 
org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGenerator;
@@ -35,19 +44,29 @@ import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
+import org.mockito.Matchers;
 import org.mockito.Mockito;
+import org.mockito.internal.util.reflection.Whitebox;
 import org.powermock.api.mockito.PowerMockito;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
 
-import java.text.SimpleDateFormat;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Properties;
 import java.util.UUID;
+import java.io.Serializable;
 
-import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.never;
 
 /**
  * Suite of FlinkKinesisConsumer tests for the methods called throughout the 
source life cycle.
@@ -511,28 +530,149 @@ public class FlinkKinesisConsumerTest {
        // 
----------------------------------------------------------------------
 
        @Test
-       public void testSnapshotStateShouldBeNullIfSourceNotOpened() throws 
Exception {
+       public void testSnapshotStateShouldNotClearListStateIfSourceNotOpened() 
throws Exception {
                Properties config = new Properties();
                config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
                config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, 
"accessKeyId");
                config.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, 
"secretKey");
 
+               OperatorStateStore operatorStateStore = 
mock(OperatorStateStore.class);
+
+               TestingListState<Serializable> listState = new 
TestingListState<>();
+
                FlinkKinesisConsumer<String> consumer = new 
FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
 
-               assertTrue(consumer.snapshotState(123, 123) == null); 
//arbitrary checkpoint id and timestamp
+               
when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+               StateInitializationContext initializationContext = 
mock(StateInitializationContext.class);
+
+               
when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
+               when(initializationContext.isRestored()).thenReturn(false);
+
+               consumer.initializeState(initializationContext);
+
+               consumer.snapshotState(new 
StateSnapshotContextSynchronousImpl(123, 123)); //arbitrary checkpoint id and 
timestamp
+
+               assertFalse(listState.isClearCalled());
        }
 
        @Test
-       public void testSnapshotStateShouldBeNullIfSourceNotRun() throws 
Exception {
+       public void testSnapshotStateShouldNotClearListStateIfSourceNotRun() 
throws Exception {
                Properties config = new Properties();
                config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
                config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, 
"accessKeyId");
                config.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, 
"secretKey");
 
+               OperatorStateStore operatorStateStore = 
mock(OperatorStateStore.class);
+
+               TestingListState<Serializable> listState = new 
TestingListState<>();
+
                FlinkKinesisConsumer<String> consumer = new 
FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
+
+               
when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+               StateInitializationContext initializationContext = 
mock(StateInitializationContext.class);
+
+               
when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
+               when(initializationContext.isRestored()).thenReturn(false);
+
+               consumer.initializeState(initializationContext);
+
                consumer.open(new Configuration()); // only opened, not run
 
-               assertTrue(consumer.snapshotState(123, 123) == null); 
//arbitrary checkpoint id and timestamp
+               consumer.snapshotState(new 
StateSnapshotContextSynchronousImpl(123, 123)); //arbitrary checkpoint id and 
timestamp
+
+               assertFalse(listState.isClearCalled());
+       }
+
+       @Test
+       public void testListStateChangedAfterSnapshotState() throws Exception {
+               // 
----------------------------------------------------------------------
+               // setting config, initial state and state after snapshot
+               // 
----------------------------------------------------------------------
+               Properties config = new Properties();
+               config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
+               config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, 
"accessKeyId");
+               config.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, 
"secretKey");
+
+               ArrayList<Tuple2<KinesisStreamShard, SequenceNumber>> 
initialState = new ArrayList<>(1);
+               initialState.add(Tuple2.of(
+                       new KinesisStreamShard("fakeStream1",
+                               new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
+                       new SequenceNumber("1")));
+
+               ArrayList<Tuple2<KinesisStreamShard, SequenceNumber>> 
snapShotState = new ArrayList<>(3);
+               snapShotState.add(Tuple2.of(
+                       new KinesisStreamShard("fakeStream1",
+                               new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
+                       new SequenceNumber("12")));
+               snapShotState.add(Tuple2.of(
+                       new KinesisStreamShard("fakeStream1",
+                               new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
+                       new SequenceNumber("11")));
+               snapShotState.add(Tuple2.of(
+                       new KinesisStreamShard("fakeStream1",
+                               new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
+                       new SequenceNumber("31")));
+
+               // 
----------------------------------------------------------------------
+               // mock operator state backend and initial state for 
initializeState()
+               // 
----------------------------------------------------------------------
+               TestingListState<Serializable> listState = new 
TestingListState<>();
+               for (Serializable state: initialState) {
+                       listState.add(state);
+               }
+
+               OperatorStateStore operatorStateStore = 
mock(OperatorStateStore.class);
+               
when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+               StateInitializationContext initializationContext = 
mock(StateInitializationContext.class);
+               
when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
+               when(initializationContext.isRestored()).thenReturn(true);
+
+               // 
----------------------------------------------------------------------
+               // mock a running fetcher and its state for snapshot
+               // 
----------------------------------------------------------------------
+               HashMap<KinesisStreamShard, SequenceNumber> stateSnapshot = new 
HashMap<>();
+               for (Tuple2<KinesisStreamShard, SequenceNumber> tuple: 
snapShotState) {
+                       stateSnapshot.put(tuple.f0, tuple.f1);
+               }
+
+               KinesisDataFetcher mockedFetcher = 
mock(KinesisDataFetcher.class);
+               when(mockedFetcher.snapshotState()).thenReturn(stateSnapshot);
+
+               // 
----------------------------------------------------------------------
+               // create a consumer and test the snapshotState()
+               // 
----------------------------------------------------------------------
+               FlinkKinesisConsumer<String> consumer = new 
FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
+               FlinkKinesisConsumer<?> mockedConsumer = spy(consumer);
+
+               RuntimeContext context = mock(RuntimeContext.class);
+               when(context.getIndexOfThisSubtask()).thenReturn(1);
+
+               mockedConsumer.setRuntimeContext(context);
+               mockedConsumer.initializeState(initializationContext);
+               mockedConsumer.open(new Configuration());
+               Whitebox.setInternalState(mockedConsumer, "fetcher", 
mockedFetcher); // mock as consumer is running.
+
+               
mockedConsumer.snapshotState(mock(FunctionSnapshotContext.class));
+
+               assertEquals(true, listState.clearCalled);
+               assertEquals(3, listState.getList().size());
+
+               for (Serializable state: initialState) {
+                       for (Serializable currentState: listState.getList()) {
+                               assertNotEquals(state, currentState);
+                       }
+               }
+
+               for (Serializable state: snapShotState) {
+                       boolean hasOneIsSame = false;
+                       for (Serializable currentState: listState.getList()) {
+                               hasOneIsSame = hasOneIsSame || 
state.equals(currentState);
+                       }
+                       assertEquals(true, hasOneIsSame);
+               }
        }
 
        // 
----------------------------------------------------------------------
@@ -559,48 +699,288 @@ public class FlinkKinesisConsumerTest {
 
        @Test
        @SuppressWarnings("unchecked")
+       public void 
testFetcherShouldBeCorrectlySeededIfRestoringFromLegacyCheckpoint() throws 
Exception {
+               HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = 
getFakeRestoredStore("all");
+
+               KinesisDataFetcher mockedFetcher = 
Mockito.mock(KinesisDataFetcher.class);
+               List<KinesisStreamShard> shards = new ArrayList<>();
+               shards.addAll(fakeRestoredState.keySet());
+               
when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards);
+               
PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher);
+
+               // assume the given config is correct
+               PowerMockito.mockStatic(KinesisConfigUtil.class);
+               PowerMockito.doNothing().when(KinesisConfigUtil.class);
+
+               TestableFlinkKinesisConsumer consumer = new 
TestableFlinkKinesisConsumer(
+                       "fakeStream", new Properties(), 10, 2);
+               consumer.restoreState(fakeRestoredState);
+               consumer.open(new Configuration());
+               consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
+
+               Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
+               for (Map.Entry<KinesisStreamShard, SequenceNumber> 
restoredShard : fakeRestoredState.entrySet()) {
+                       
Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
+                               new 
KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
+               }
+       }
+
+       @Test
+       @SuppressWarnings("unchecked")
        public void 
testFetcherShouldBeCorrectlySeededIfRestoringFromCheckpoint() throws Exception {
+               // 
----------------------------------------------------------------------
+               // setting initial state
+               // 
----------------------------------------------------------------------
+               HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = 
getFakeRestoredStore("all");
+
+               // 
----------------------------------------------------------------------
+               // mock operator state backend and initial state for 
initializeState()
+               // 
----------------------------------------------------------------------
+               TestingListState<Serializable> listState = new 
TestingListState<>();
+               for (Map.Entry<KinesisStreamShard, SequenceNumber> state: 
fakeRestoredState.entrySet()) {
+                       listState.add(Tuple2.of(state.getKey(), 
state.getValue()));
+               }
+
+               OperatorStateStore operatorStateStore = 
mock(OperatorStateStore.class);
+               
when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+               StateInitializationContext initializationContext = 
mock(StateInitializationContext.class);
+               
when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
+               when(initializationContext.isRestored()).thenReturn(true);
+
+               // 
----------------------------------------------------------------------
+               // mock fetcher
+               // 
----------------------------------------------------------------------
                KinesisDataFetcher mockedFetcher = 
Mockito.mock(KinesisDataFetcher.class);
+               List<KinesisStreamShard> shards = new ArrayList<>();
+               shards.addAll(fakeRestoredState.keySet());
+               
when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards);
                
PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher);
 
                // assume the given config is correct
                PowerMockito.mockStatic(KinesisConfigUtil.class);
                PowerMockito.doNothing().when(KinesisConfigUtil.class);
 
-               HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = 
new HashMap<>();
-               fakeRestoredState.put(
-                       new KinesisStreamShard("fakeStream1",
-                               new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
-                       new SequenceNumber(UUID.randomUUID().toString()));
-               fakeRestoredState.put(
-                       new KinesisStreamShard("fakeStream1",
-                               new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
-                       new SequenceNumber(UUID.randomUUID().toString()));
-               fakeRestoredState.put(
-                       new KinesisStreamShard("fakeStream1",
-                               new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
-                       new SequenceNumber(UUID.randomUUID().toString()));
-               fakeRestoredState.put(
-                       new KinesisStreamShard("fakeStream2",
-                               new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
-                       new SequenceNumber(UUID.randomUUID().toString()));
-               fakeRestoredState.put(
-                       new KinesisStreamShard("fakeStream2",
-                               new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
-                       new SequenceNumber(UUID.randomUUID().toString()));
+               // 
----------------------------------------------------------------------
+               // start testing that the initial state is seeded to the fetcher
+               // 
----------------------------------------------------------------------
+               TestableFlinkKinesisConsumer consumer = new 
TestableFlinkKinesisConsumer(
+                       "fakeStream", new Properties(), 10, 2);
+               consumer.initializeState(initializationContext);
+               consumer.open(new Configuration());
+               consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
+
+               Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
+               for (Map.Entry<KinesisStreamShard, SequenceNumber> 
restoredShard : fakeRestoredState.entrySet()) {
+                       
Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
+                               new 
KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
+               }
+       }
+
+       @Test
+       @SuppressWarnings("unchecked")
+       public void testFetcherShouldBeCorrectlySeededOnlyItsOwnStates() throws 
Exception {
+               // 
----------------------------------------------------------------------
+               // setting initial state
+               // 
----------------------------------------------------------------------
+               HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = 
getFakeRestoredStore("fakeStream1");
+
+               HashMap<KinesisStreamShard, SequenceNumber> 
fakeRestoredStateForOthers = getFakeRestoredStore("fakeStream2");
+
+               // 
----------------------------------------------------------------------
+               // mock operator state backend and initial state for 
initializeState()
+               // 
----------------------------------------------------------------------
+               TestingListState<Serializable> listState = new 
TestingListState<>();
+               for (Map.Entry<KinesisStreamShard, SequenceNumber> state: 
fakeRestoredState.entrySet()) {
+                       listState.add(Tuple2.of(state.getKey(), 
state.getValue()));
+               }
+               for (Map.Entry<KinesisStreamShard, SequenceNumber> state: 
fakeRestoredStateForOthers.entrySet()) {
+                       listState.add(Tuple2.of(state.getKey(), 
state.getValue()));
+               }
 
+               OperatorStateStore operatorStateStore = 
mock(OperatorStateStore.class);
+               
when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+               StateInitializationContext initializationContext = 
mock(StateInitializationContext.class);
+               
when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
+               when(initializationContext.isRestored()).thenReturn(true);
+
+               // 
----------------------------------------------------------------------
+               // mock fetcher
+               // 
----------------------------------------------------------------------
+               KinesisDataFetcher mockedFetcher = 
Mockito.mock(KinesisDataFetcher.class);
+               List<KinesisStreamShard> shards = new ArrayList<>();
+               shards.addAll(fakeRestoredState.keySet());
+               
when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards);
+               
PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher);
+
+               // assume the given config is correct
+               PowerMockito.mockStatic(KinesisConfigUtil.class);
+               PowerMockito.doNothing().when(KinesisConfigUtil.class);
+
+               // 
----------------------------------------------------------------------
+               // start testing that the initial state is seeded to the fetcher
+               // 
----------------------------------------------------------------------
                TestableFlinkKinesisConsumer consumer = new 
TestableFlinkKinesisConsumer(
                        "fakeStream", new Properties(), 10, 2);
-               consumer.restoreState(fakeRestoredState);
+               consumer.initializeState(initializationContext);
                consumer.open(new Configuration());
                consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
 
                Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
+               for (Map.Entry<KinesisStreamShard, SequenceNumber> 
restoredShard : fakeRestoredStateForOthers.entrySet()) {
+                       // should never get restored state not belonging to 
itself
+                       Mockito.verify(mockedFetcher, 
never()).registerNewSubscribedShardState(
+                               new 
KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
+               }
                for (Map.Entry<KinesisStreamShard, SequenceNumber> 
restoredShard : fakeRestoredState.entrySet()) {
-                       
Mockito.verify(mockedFetcher).advanceLastDiscoveredShardOfStream(
-                               restoredShard.getKey().getStreamName(), 
restoredShard.getKey().getShard().getShardId());
+                       // should get restored state belonging to itself
                        
Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
                                new 
KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
                }
        }
+
+       /*
+        * If the original parallelism is 2 and the states are:
+        *   Consumer subtask 1:
+        *     stream1, shard1, SequenceNumber(xxx)
+        *   Consumer subtask 2:
+        *     stream1, shard2, SequenceNumber(yyy)
+        * After discoverNewShardsToSubscribe(), if two new shards 
(shard3, shard4) have been created:
+        *   Consumer subtask 1 (late for discoverNewShardsToSubscribe()):
+        *     stream1, shard1, SequenceNumber(xxx)
+        *   Consumer subtask 2:
+        *     stream1, shard2, SequenceNumber(yyy)
+        *     stream1, shard4, SequenceNumber(zzz)
+        *  If snapshotState() occurs and the parallelism is changed to 1:
+        *    Union state will be:
+        *     stream1, shard1, SequenceNumber(xxx)
+        *     stream1, shard2, SequenceNumber(yyy)
+        *     stream1, shard4, SequenceNumber(zzz)
+        *    Fetcher should be seeded with:
+        *     stream1, shard1, SequenceNumber(xxx)
+        *     stream1, shard2, SequenceNumber(yyy)
+        *     stream1, shard3, 
SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM
+        *     stream1, shard4, SequenceNumber(zzz)
+        *
+        *  This test guarantees that the fetcher is seeded correctly in 
such a situation.
+        */
+       @Test
+       @SuppressWarnings("unchecked")
+       public void 
testFetcherShouldBeCorrectlySeededWithNewDiscoveredKinesisStreamShard() throws 
Exception {
+               // 
----------------------------------------------------------------------
+               // setting initial state
+               // 
----------------------------------------------------------------------
+               HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = 
getFakeRestoredStore("all");
+
+               // 
----------------------------------------------------------------------
+               // mock operator state backend and initial state for 
initializeState()
+               // 
----------------------------------------------------------------------
+               TestingListState<Serializable> listState = new 
TestingListState<>();
+               for (Map.Entry<KinesisStreamShard, SequenceNumber> state: 
fakeRestoredState.entrySet()) {
+                       listState.add(Tuple2.of(state.getKey(), 
state.getValue()));
+               }
+
+               OperatorStateStore operatorStateStore = 
mock(OperatorStateStore.class);
+               
when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+               StateInitializationContext initializationContext = 
mock(StateInitializationContext.class);
+               
when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
+               when(initializationContext.isRestored()).thenReturn(true);
+
+               // 
----------------------------------------------------------------------
+               // mock fetcher
+               // 
----------------------------------------------------------------------
+               KinesisDataFetcher mockedFetcher = 
Mockito.mock(KinesisDataFetcher.class);
+               List<KinesisStreamShard> shards = new ArrayList<>();
+               shards.addAll(fakeRestoredState.keySet());
+               shards.add(new KinesisStreamShard("fakeStream2",
+                       new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))));
+               
when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards);
+               
PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher);
+
+               // assume the given config is correct
+               PowerMockito.mockStatic(KinesisConfigUtil.class);
+               PowerMockito.doNothing().when(KinesisConfigUtil.class);
+
+               // 
----------------------------------------------------------------------
+               // start testing that the initial state is seeded to the fetcher
+               // 
----------------------------------------------------------------------
+               TestableFlinkKinesisConsumer consumer = new 
TestableFlinkKinesisConsumer(
+                       "fakeStream", new Properties(), 10, 2);
+               consumer.initializeState(initializationContext);
+               consumer.open(new Configuration());
+               consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
+
+               fakeRestoredState.put(new KinesisStreamShard("fakeStream2",
+                               new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
+                       
SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get());
+               Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
+               for (Map.Entry<KinesisStreamShard, SequenceNumber> 
restoredShard : fakeRestoredState.entrySet()) {
+                       
Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
+                               new 
KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
+               }
+       }
+
+       private static final class TestingListState<T> implements ListState<T> {
+
+               private final List<T> list = new ArrayList<>();
+               private boolean clearCalled = false;
+
+               @Override
+               public void clear() {
+                       list.clear();
+                       clearCalled = true;
+               }
+
+               @Override
+               public Iterable<T> get() throws Exception {
+                       return list;
+               }
+
+               @Override
+               public void add(T value) throws Exception {
+                       list.add(value);
+               }
+
+               public List<T> getList() {
+                       return list;
+               }
+
+               public boolean isClearCalled() {
+                       return clearCalled;
+               }
+       }
+
+       private HashMap<KinesisStreamShard, SequenceNumber> 
getFakeRestoredStore(String streamName) {
+               HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = 
new HashMap<>();
+
+               if (streamName.equals("fakeStream1") || 
streamName.equals("all")) {
+                       fakeRestoredState.put(
+                               new KinesisStreamShard("fakeStream1",
+                                       new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
+                               new 
SequenceNumber(UUID.randomUUID().toString()));
+                       fakeRestoredState.put(
+                               new KinesisStreamShard("fakeStream1",
+                                       new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
+                               new 
SequenceNumber(UUID.randomUUID().toString()));
+                       fakeRestoredState.put(
+                               new KinesisStreamShard("fakeStream1",
+                                       new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
+                               new 
SequenceNumber(UUID.randomUUID().toString()));
+               }
+
+               if (streamName.equals("fakeStream2") || 
streamName.equals("all")) {
+                       fakeRestoredState.put(
+                               new KinesisStreamShard("fakeStream2",
+                                       new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
+                               new 
SequenceNumber(UUID.randomUUID().toString()));
+                       fakeRestoredState.put(
+                               new KinesisStreamShard("fakeStream2",
+                                       new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
+                               new 
SequenceNumber(UUID.randomUUID().toString()));
+               }
+
+               return fakeRestoredState;
+       }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot
----------------------------------------------------------------------
diff --git 
a/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot
 
b/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot
new file mode 100644
index 0000000..b60402e
Binary files /dev/null and 
b/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot
 differ

http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot-empty
----------------------------------------------------------------------
diff --git 
a/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot-empty
 
b/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot-empty
new file mode 100644
index 0000000..f4dd96d
Binary files /dev/null and 
b/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot-empty
 differ

Reply via email to