[FLINK-4821] [kinesis] General improvements to rescalable FlinkKinesisConsumer

This commit adds some general improvements to the rescalable
implementation of FlinkKinesisConsumer, including:
- Refactor setup procedures in KinesisDataFetcher so that duplicate work
  isn't done on a restored run
- Harden handling of corner cases where the fetcher was not yet fully
  seeded with initial state when a snapshot is taken

This closes #3001.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/e5b65a7f
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/e5b65a7f
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/e5b65a7f

Branch: refs/heads/master
Commit: e5b65a7fc2b4a7532ca40748f81bcbf8ace46815
Parents: a05b574
Author: Tzu-Li (Gordon) Tai <[email protected]>
Authored: Sun May 7 16:29:32 2017 +0800
Committer: Tzu-Li (Gordon) Tai <[email protected]>
Committed: Sun May 7 17:33:04 2017 +0800

----------------------------------------------------------------------
 .../kinesis/FlinkKinesisConsumer.java           | 150 ++++++++---------
 .../kinesis/internals/KinesisDataFetcher.java   |  52 +-----
 .../FlinkKinesisConsumerMigrationTest.java      |   5 +-
 .../kinesis/FlinkKinesisConsumerTest.java       | 159 +++++++++++--------
 .../internals/KinesisDataFetcherTest.java       |  65 ++++++--
 .../testutils/TestableKinesisDataFetcher.java   |  14 ++
 6 files changed, 233 insertions(+), 212 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
----------------------------------------------------------------------
diff --git 
a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
 
b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
index dfcd552..4982f7f 100644
--- 
a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
+++ 
b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
@@ -25,13 +25,14 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
 import org.apache.flink.api.java.typeutils.TupleTypeInfo;
-import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.state.FunctionInitializationContext;
 import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring;
 import 
org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import 
org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
+import 
org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants.InitialPosition;
 import 
org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
 import 
org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
@@ -67,9 +68,9 @@ import static 
org.apache.flink.util.Preconditions.checkNotNull;
  * @param <T> the type of data emitted
  */
 public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> 
implements
-       ResultTypeQueryable<T>,
-       CheckpointedFunction,
-       CheckpointedRestoring<HashMap<KinesisStreamShard, SequenceNumber>> {
+               ResultTypeQueryable<T>,
+               CheckpointedFunction,
+               CheckpointedRestoring<HashMap<KinesisStreamShard, 
SequenceNumber>> {
 
        private static final long serialVersionUID = 4724006128720664870L;
 
@@ -86,7 +87,7 @@ public class FlinkKinesisConsumer<T> extends 
RichParallelSourceFunction<T> imple
         * shard list retrieval behaviours, etc */
        private final Properties configProps;
 
-       /** User supplied deseriliazation schema to convert Kinesis byte 
messages to Flink objects */
+       /** User supplied deserialization schema to convert Kinesis byte 
messages to Flink objects */
        private final KinesisDeserializationSchema<T> deserializer;
 
        // 
------------------------------------------------------------------------
@@ -96,9 +97,6 @@ public class FlinkKinesisConsumer<T> extends 
RichParallelSourceFunction<T> imple
        /** Per-task fetcher for Kinesis data records, where each fetcher pulls 
data from one or more Kinesis shards */
        private transient KinesisDataFetcher<T> fetcher;
 
-       /** The sequence numbers in the last state snapshot of this subtask */
-       private transient HashMap<KinesisStreamShard, SequenceNumber> 
lastStateSnapshot;
-
        /** The sequence numbers to restore to upon restore from failure */
        private transient HashMap<KinesisStreamShard, SequenceNumber> 
sequenceNumsToRestore;
 
@@ -108,7 +106,7 @@ public class FlinkKinesisConsumer<T> extends 
RichParallelSourceFunction<T> imple
        //  State for Checkpoint
        // 
------------------------------------------------------------------------
 
-       /** The name is the key for sequence numbers state, and cannot be 
changed. */
+       /** State name to access shard sequence number states; cannot be 
changed */
        private static final String sequenceNumsStateStoreName = 
"Kinesis-Stream-Shard-State";
 
        private transient ListState<Tuple2<KinesisStreamShard, SequenceNumber>> 
sequenceNumsStateForCheckpoint;
@@ -191,57 +189,33 @@ public class FlinkKinesisConsumer<T> extends 
RichParallelSourceFunction<T> imple
        // 
------------------------------------------------------------------------
 
        @Override
-       public void open(Configuration parameters) throws Exception {
-               super.open(parameters);
-
-               // restore to the last known sequence numbers from the latest 
complete snapshot
-               if (sequenceNumsToRestore != null) {
-                       if (LOG.isInfoEnabled()) {
-                               LOG.info("Subtask {} is restoring sequence 
numbers {} from previous checkpointed state",
-                                       
getRuntimeContext().getIndexOfThisSubtask(), sequenceNumsToRestore.toString());
-                       }
-
-                       // initialize sequence numbers with restored state
-                       lastStateSnapshot = sequenceNumsToRestore;
-               } else {
-                       // start fresh with empty sequence numbers if there are 
no snapshots to restore from.
-                       lastStateSnapshot = new HashMap<>();
-               }
-       }
-
-       @Override
        public void run(SourceContext<T> sourceContext) throws Exception {
 
                // all subtasks will run a fetcher, regardless of whether or 
not the subtask will initially have
                // shards to subscribe to; fetchers will continuously poll for 
changes in the shard list, so all subtasks
                // can potentially have new shards to subscribe to later on
-               fetcher = createFetcher(streams, sourceContext, 
getRuntimeContext(), configProps, deserializer);
-
-               boolean isRestoringFromFailure = (sequenceNumsToRestore != 
null);
-               fetcher.setIsRestoringFromFailure(isRestoringFromFailure);
-
-               // if we are restoring from a checkpoint, we iterate over the 
restored
-               // state and accordingly seed the fetcher with subscribed 
shards states
-               if (isRestoringFromFailure) {
-                       // Since there may have a situation that some subtasks 
did not finish discovering before rescale,
-                       // and KinesisDataFetcher will always discover the 
shard from the largest shard id. To prevent from
-                       // missing some shards which didn't be discovered and 
whose id is not the largest one, we force the
-                       // consumer to discover once from the smallest id and 
make sure each shard have its initial sequence
-                       // number from restored state or 
SENTINEL_EARLIEST_SEQUENCE_NUM.
-                       List<KinesisStreamShard> 
newShardsCreatedWhileNotRunning = fetcher.discoverNewShardsToSubscribe();
-                       for (KinesisStreamShard shard : 
newShardsCreatedWhileNotRunning) {
-                               SequenceNumber startingStateForNewShard;
-
-                               if (lastStateSnapshot.containsKey(shard)) {
-                                       startingStateForNewShard = 
lastStateSnapshot.get(shard);
+               KinesisDataFetcher<T> fetcher = createFetcher(streams, 
sourceContext, getRuntimeContext(), configProps, deserializer);
+
+               // initial discovery
+               List<KinesisStreamShard> allShards = 
fetcher.discoverNewShardsToSubscribe();
+
+               for (KinesisStreamShard shard : allShards) {
+                       if (sequenceNumsToRestore != null) {
+                               if (sequenceNumsToRestore.containsKey(shard)) {
+                                       // if the shard was already seen and is 
contained in the state,
+                                       // just use the sequence number stored 
in the state
+                                       fetcher.registerNewSubscribedShardState(
+                                               new 
KinesisStreamShardState(shard, sequenceNumsToRestore.get(shard)));
 
                                        if (LOG.isInfoEnabled()) {
                                                LOG.info("Subtask {} is seeding 
the fetcher with restored shard {}," +
                                                                " starting 
state set to the restored sequence number {}",
-                                                       
getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), 
startingStateForNewShard);
+                                                       
getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), 
sequenceNumsToRestore.get(shard));
                                        }
                                } else {
-                                       startingStateForNewShard = 
SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get();
+                                       // the shard wasn't discovered in the 
previous run, therefore should be consumed from the beginning
+                                       fetcher.registerNewSubscribedShardState(
+                                               new 
KinesisStreamShardState(shard, 
SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get()));
 
                                        if (LOG.isInfoEnabled()) {
                                                LOG.info("Subtask {} is seeding 
the fetcher with new discovered shard {}," +
@@ -249,9 +223,20 @@ public class FlinkKinesisConsumer<T> extends 
RichParallelSourceFunction<T> imple
                                                        
getRuntimeContext().getIndexOfThisSubtask(), shard.toString());
                                        }
                                }
+                       } else {
+                               // we're starting fresh; use the configured 
start position as initial state
+                               SentinelSequenceNumber startingSeqNum =
+                                       
InitialPosition.valueOf(configProps.getProperty(
+                                               
ConsumerConfigConstants.STREAM_INITIAL_POSITION,
+                                               
ConsumerConfigConstants.DEFAULT_STREAM_INITIAL_POSITION)).toSentinelSequenceNumber();
 
                                fetcher.registerNewSubscribedShardState(
-                                       new KinesisStreamShardState(shard, 
startingStateForNewShard));
+                                       new KinesisStreamShardState(shard, 
startingSeqNum.get()));
+
+                               if (LOG.isInfoEnabled()) {
+                                       LOG.info("Subtask {} will be seeded 
with initial shard {}, starting state set as sequence number {}",
+                                               
getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), 
startingSeqNum.get());
+                               }
                        }
                }
 
@@ -260,6 +245,10 @@ public class FlinkKinesisConsumer<T> extends 
RichParallelSourceFunction<T> imple
                        return;
                }
 
+               // expose the fetcher from this point, so that state
+               // snapshots can be taken from the fetcher's state holders
+               this.fetcher = fetcher;
+
                // start the fetcher loop. The fetcher will stop running only 
when cancel() or
                // close() is called, or an error is thrown by threads created 
by the fetcher
                fetcher.runFetcher();
@@ -306,13 +295,12 @@ public class FlinkKinesisConsumer<T> extends 
RichParallelSourceFunction<T> imple
 
        @Override
        public void initializeState(FunctionInitializationContext context) 
throws Exception {
-               TypeInformation<Tuple2<KinesisStreamShard, SequenceNumber>> 
tuple = new TupleTypeInfo<>(
+               TypeInformation<Tuple2<KinesisStreamShard, SequenceNumber>> 
shardsStateTypeInfo = new TupleTypeInfo<>(
                        TypeInformation.of(KinesisStreamShard.class),
-                       TypeInformation.of(SequenceNumber.class)
-               );
+                       TypeInformation.of(SequenceNumber.class));
 
                sequenceNumsStateForCheckpoint = 
context.getOperatorStateStore().getUnionListState(
-                       new ListStateDescriptor<>(sequenceNumsStateStoreName, 
tuple));
+                       new ListStateDescriptor<>(sequenceNumsStateStoreName, 
shardsStateTypeInfo));
 
                if (context.isRestored()) {
                        if (sequenceNumsToRestore == null) {
@@ -323,8 +311,6 @@ public class FlinkKinesisConsumer<T> extends 
RichParallelSourceFunction<T> imple
 
                                LOG.info("Setting restore state in the 
FlinkKinesisConsumer. Using the following offsets: {}",
                                        sequenceNumsToRestore);
-                       } else if (sequenceNumsToRestore.isEmpty()) {
-                               sequenceNumsToRestore = null;
                        }
                } else {
                        LOG.info("No restore state for FlinkKinesisConsumer.");
@@ -333,11 +319,7 @@ public class FlinkKinesisConsumer<T> extends 
RichParallelSourceFunction<T> imple
 
        @Override
        public void snapshotState(FunctionSnapshotContext context) throws 
Exception {
-               if (lastStateSnapshot == null) {
-                       LOG.debug("snapshotState() requested on not yet opened 
source; returning null.");
-               } else if (fetcher == null) {
-                       LOG.debug("snapshotState() requested on not yet running 
source; returning null.");
-               } else if (!running) {
+               if (!running) {
                        LOG.debug("snapshotState() called on closed source; 
returning null.");
                } else {
                        if (LOG.isDebugEnabled()) {
@@ -345,15 +327,33 @@ public class FlinkKinesisConsumer<T> extends 
RichParallelSourceFunction<T> imple
                        }
 
                        sequenceNumsStateForCheckpoint.clear();
-                       lastStateSnapshot = fetcher.snapshotState();
 
-                       if (LOG.isDebugEnabled()) {
-                               LOG.debug("Snapshotted state, last processed 
sequence numbers: {}, checkpoint id: {}, timestamp: {}",
-                                       lastStateSnapshot.toString(), 
context.getCheckpointId(), context.getCheckpointTimestamp());
-                       }
+                       if (fetcher == null) {
+                               if (sequenceNumsToRestore != null) {
+                                       for (Map.Entry<KinesisStreamShard, 
SequenceNumber> entry : sequenceNumsToRestore.entrySet()) {
+                                               // sequenceNumsToRestore is the 
restored global union state;
+                                               // should only snapshot shards 
that actually belong to us
+
+                                               if 
(KinesisDataFetcher.isThisSubtaskShouldSubscribeTo(
+                                                               entry.getKey(),
+                                                               
getRuntimeContext().getNumberOfParallelSubtasks(),
+                                                               
getRuntimeContext().getIndexOfThisSubtask())) {
+
+                                                       
sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue()));
+                                               }
+                                       }
+                               }
+                       } else {
+                               HashMap<KinesisStreamShard, SequenceNumber> 
lastStateSnapshot = fetcher.snapshotState();
+
+                               if (LOG.isDebugEnabled()) {
+                                       LOG.debug("Snapshotted state, last 
processed sequence numbers: {}, checkpoint id: {}, timestamp: {}",
+                                               lastStateSnapshot.toString(), 
context.getCheckpointId(), context.getCheckpointTimestamp());
+                               }
 
-                       for (Map.Entry<KinesisStreamShard, SequenceNumber> 
entry : lastStateSnapshot.entrySet()) {
-                               
sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue()));
+                               for (Map.Entry<KinesisStreamShard, 
SequenceNumber> entry : lastStateSnapshot.entrySet()) {
+                                       
sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue()));
+                               }
                        }
                }
        }
@@ -366,12 +366,14 @@ public class FlinkKinesisConsumer<T> extends 
RichParallelSourceFunction<T> imple
                sequenceNumsToRestore = restoredState.isEmpty() ? null : 
restoredState;
        }
 
-       /** This method is created for tests that can mock the 
KinesisDataFetcher in the consumer. */
-       protected KinesisDataFetcher<T> createFetcher(List<String> streams,
-                                                                               
                        SourceFunction.SourceContext<T> sourceContext,
-                                                                               
                        RuntimeContext runtimeContext,
-                                                                               
                        Properties configProps,
-                                                                               
                        KinesisDeserializationSchema<T> deserializationSchema) {
+       /** This method is exposed for tests that need to mock the 
KinesisDataFetcher in the consumer. */
+       protected KinesisDataFetcher<T> createFetcher(
+                       List<String> streams,
+                       SourceFunction.SourceContext<T> sourceContext,
+                       RuntimeContext runtimeContext,
+                       Properties configProps,
+                       KinesisDeserializationSchema<T> deserializationSchema) {
+
                return new KinesisDataFetcher<>(streams, sourceContext, 
runtimeContext, configProps, deserializationSchema);
        }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
----------------------------------------------------------------------
diff --git 
a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
 
b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
index c5b4b04..99305cb 100644
--- 
a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
+++ 
b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
@@ -19,9 +19,7 @@ package 
org.apache.flink.streaming.connectors.kinesis.internals;
 
 import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
-import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer;
 import 
org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
-import 
org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants.InitialPosition;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
 import 
org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
 import 
org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber;
@@ -99,12 +97,6 @@ public class KinesisDataFetcher<T> {
 
        private final int indexOfThisConsumerSubtask;
 
-       /**
-        * This flag should be set by {@link FlinkKinesisConsumer} using
-        * {@link KinesisDataFetcher#setIsRestoringFromFailure(boolean)}
-        */
-       private boolean isRestoredFromFailure;
-
        // 
------------------------------------------------------------------------
        //  Executor services to run created threads
        // 
------------------------------------------------------------------------
@@ -235,41 +227,7 @@ public class KinesisDataFetcher<T> {
                //  Procedures before starting the infinite while loop:
                // 
------------------------------------------------------------------------
 
-               //  1. query for any new shards that may have been created 
while the Kinesis consumer was not running,
-               //     and register them to the subscribedShardState list.
-               if (LOG.isDebugEnabled()) {
-                       String logFormat = (!isRestoredFromFailure)
-                               ? "Subtask {} is trying to discover initial 
shards ..."
-                               : "Subtask {} is trying to discover any new 
shards that were created while the consumer wasn't " +
-                               "running due to failure ...";
-
-                       LOG.debug(logFormat, indexOfThisConsumerSubtask);
-               }
-               List<KinesisStreamShard> newShardsCreatedWhileNotRunning = 
discoverNewShardsToSubscribe();
-               for (KinesisStreamShard shard : 
newShardsCreatedWhileNotRunning) {
-                       // the starting state for new shards created while the 
consumer wasn't running depends on whether or not
-                       // we are starting fresh (not restoring from a 
checkpoint); when we are starting fresh, this simply means
-                       // all existing shards of streams we are subscribing to 
are new shards; when we are restoring from checkpoint,
-                       // any new shards due to Kinesis resharding from the 
time of the checkpoint will be considered new shards.
-                       InitialPosition initialPosition = 
InitialPosition.valueOf(configProps.getProperty(
-                               
ConsumerConfigConstants.STREAM_INITIAL_POSITION, 
ConsumerConfigConstants.DEFAULT_STREAM_INITIAL_POSITION));
-
-                       SentinelSequenceNumber startingStateForNewShard = 
(isRestoredFromFailure)
-                               ? 
SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM
-                               : initialPosition.toSentinelSequenceNumber();
-
-                       if (LOG.isInfoEnabled()) {
-                               String logFormat = (!isRestoredFromFailure)
-                                       ? "Subtask {} will be seeded with 
initial shard {}, starting state set as sequence number {}"
-                                       : "Subtask {} will be seeded with new 
shard {} that was created while the consumer wasn't " +
-                                       "running due to failure, starting state 
set as sequence number {}";
-
-                               LOG.info(logFormat, indexOfThisConsumerSubtask, 
shard.toString(), startingStateForNewShard.get());
-                       }
-                       registerNewSubscribedShardState(new 
KinesisStreamShardState(shard, startingStateForNewShard.get()));
-               }
-
-               //  2. check that there is at least one shard in the subscribed 
streams to consume from (can be done by
+               //  1. check that there is at least one shard in the subscribed 
streams to consume from (can be done by
                //     checking if at least one value in 
subscribedStreamsToLastDiscoveredShardIds is not null)
                boolean hasShards = false;
                StringBuilder streamsWithNoShardsFound = new StringBuilder();
@@ -290,7 +248,7 @@ public class KinesisDataFetcher<T> {
                        throw new RuntimeException("No shards can be found for 
all subscribed streams: " + streams);
                }
 
-               //  3. start consuming any shard state we already have in the 
subscribedShardState up to this point; the
+               //  2. start consuming any shard state we already have in the 
subscribedShardState up to this point; the
                //     subscribedShardState may already be seeded with values 
due to step 1., or explicitly added by the
                //     consumer using a restored state checkpoint
                for (int seededStateIndex = 0; seededStateIndex < 
subscribedShardsState.size(); seededStateIndex++) {
@@ -489,10 +447,6 @@ public class KinesisDataFetcher<T> {
        //  Functions to get / set information about the consumer
        // 
------------------------------------------------------------------------
 
-       public void setIsRestoringFromFailure(boolean bool) {
-               this.isRestoredFromFailure = bool;
-       }
-
        protected Properties getConsumerConfiguration() {
                return configProps;
        }
@@ -595,7 +549,7 @@ public class KinesisDataFetcher<T> {
         * @param totalNumberOfConsumerSubtasks total number of consumer 
subtasks
         * @param indexOfThisConsumerSubtask index of this consumer subtask
         */
-       private static boolean 
isThisSubtaskShouldSubscribeTo(KinesisStreamShard shard,
+       public static boolean isThisSubtaskShouldSubscribeTo(KinesisStreamShard 
shard,
                                                                                
                                int totalNumberOfConsumerSubtasks,
                                                                                
                                int indexOfThisConsumerSubtask) {
                return (Math.abs(shard.hashCode() % 
totalNumberOfConsumerSubtasks)) == indexOfThisConsumerSubtask;

http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
----------------------------------------------------------------------
diff --git 
a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
 
b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
index 2f46e09..7629f9d 100644
--- 
a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
+++ 
b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
@@ -42,10 +42,7 @@ import static org.mockito.Mockito.mock;
 
 /**
  * Tests for checking whether {@link FlinkKinesisConsumer} can restore from 
snapshots that were
- * done using the Flink 1.1 {@link FlinkKinesisConsumer}.
- *
- * <p>For regenerating the binary snapshot file you have to run the commented 
out portion
- * of each test on a checkout of the Flink 1.1 branch.
+ * done using the Flink 1.1 {@code FlinkKinesisConsumer}.
  */
 public class FlinkKinesisConsumerMigrationTest {
 

http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
----------------------------------------------------------------------
diff --git 
a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
 
b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
index bf8e44f..4b178c7 100644
--- 
a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
+++ 
b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
@@ -40,6 +40,7 @@ import 
org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGen
 import 
org.apache.flink.streaming.connectors.kinesis.testutils.TestableFlinkKinesisConsumer;
 import org.apache.flink.streaming.connectors.kinesis.util.KinesisConfigUtil;
 import org.apache.flink.streaming.util.serialization.SimpleStringSchema;
+import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
@@ -57,10 +58,8 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Properties;
 import java.util.UUID;
-import java.io.Serializable;
 
 import static org.junit.Assert.fail;
-import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotEquals;
 import static org.mockito.Mockito.mock;
@@ -530,7 +529,7 @@ public class FlinkKinesisConsumerTest {
        // 
----------------------------------------------------------------------
 
        @Test
-       public void testSnapshotStateShouldNotClearListStateIfSourceNotOpened() 
throws Exception {
+       public void testUseRestoredStateForSnapshotIfFetcherNotInitialized() 
throws Exception {
                Properties config = new Properties();
                config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
                config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, 
"accessKeyId");
@@ -538,57 +537,63 @@ public class FlinkKinesisConsumerTest {
 
                OperatorStateStore operatorStateStore = 
mock(OperatorStateStore.class);
 
-               TestingListState<Serializable> listState = new 
TestingListState<>();
-
-               FlinkKinesisConsumer<String> consumer = new 
FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
-
-               
when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
-
-               StateInitializationContext initializationContext = 
mock(StateInitializationContext.class);
-
-               
when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
-               when(initializationContext.isRestored()).thenReturn(false);
-
-               consumer.initializeState(initializationContext);
-
-               consumer.snapshotState(new 
StateSnapshotContextSynchronousImpl(123, 123)); //arbitrary checkpoint id and 
timestamp
-
-               assertFalse(listState.isClearCalled());
-       }
-
-       @Test
-       public void testSnapshotStateShouldNotClearListStateIfSourceNotRun() 
throws Exception {
-               Properties config = new Properties();
-               config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
-               config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, 
"accessKeyId");
-               config.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, 
"secretKey");
-
-               OperatorStateStore operatorStateStore = 
mock(OperatorStateStore.class);
+               List<Tuple2<KinesisStreamShard, SequenceNumber>> 
globalUnionState = new ArrayList<>(4);
+               globalUnionState.add(Tuple2.of(
+                       new KinesisStreamShard("fakeStream",
+                               new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
+                       new SequenceNumber("1")));
+               globalUnionState.add(Tuple2.of(
+                       new KinesisStreamShard("fakeStream",
+                               new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
+                       new SequenceNumber("1")));
+               globalUnionState.add(Tuple2.of(
+                       new KinesisStreamShard("fakeStream",
+                               new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
+                       new SequenceNumber("1")));
+               globalUnionState.add(Tuple2.of(
+                       new KinesisStreamShard("fakeStream",
+                               new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(3))),
+                       new SequenceNumber("1")));
 
-               TestingListState<Serializable> listState = new 
TestingListState<>();
+               TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> 
listState = new TestingListState<>();
+               for (Tuple2<KinesisStreamShard, SequenceNumber> state : 
globalUnionState) {
+                       listState.add(state);
+               }
 
                FlinkKinesisConsumer<String> consumer = new 
FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
+               RuntimeContext context = mock(RuntimeContext.class);
+               when(context.getIndexOfThisSubtask()).thenReturn(0);
+               when(context.getNumberOfParallelSubtasks()).thenReturn(2);
+               consumer.setRuntimeContext(context);
 
                
when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
 
                StateInitializationContext initializationContext = 
mock(StateInitializationContext.class);
 
                
when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
-               when(initializationContext.isRestored()).thenReturn(false);
+               when(initializationContext.isRestored()).thenReturn(true);
 
                consumer.initializeState(initializationContext);
 
-               consumer.open(new Configuration()); // only opened, not run
+               // only opened, not run
+               consumer.open(new Configuration());
+
+               // arbitrary checkpoint id and timestamp
+               consumer.snapshotState(new 
StateSnapshotContextSynchronousImpl(123, 123));
 
-               consumer.snapshotState(new 
StateSnapshotContextSynchronousImpl(123, 123)); //arbitrary checkpoint id and 
timestamp
+               Assert.assertTrue(listState.isClearCalled());
 
-               assertFalse(listState.isClearCalled());
+               // the checkpointed list state should contain only the shards 
that it should subscribe to
+               Assert.assertEquals(globalUnionState.size() / 2, 
listState.getList().size());
+               
Assert.assertTrue(listState.getList().contains(globalUnionState.get(0)));
+               
Assert.assertTrue(listState.getList().contains(globalUnionState.get(2)));
        }
 
        @Test
        public void testListStateChangedAfterSnapshotState() throws Exception {
+
                // 
----------------------------------------------------------------------
-               // setting config, initial state and state after snapshot
+               // setup config, initial state and expected state snapshot
                // 
----------------------------------------------------------------------
                Properties config = new Properties();
                config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
@@ -601,16 +606,16 @@ public class FlinkKinesisConsumerTest {
                                new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
                        new SequenceNumber("1")));
 
-               ArrayList<Tuple2<KinesisStreamShard, SequenceNumber>> 
snapShotState = new ArrayList<>(3);
-               snapShotState.add(Tuple2.of(
+               ArrayList<Tuple2<KinesisStreamShard, SequenceNumber>> 
expectedStateSnapshot = new ArrayList<>(3);
+               expectedStateSnapshot.add(Tuple2.of(
                        new KinesisStreamShard("fakeStream1",
                                new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
                        new SequenceNumber("12")));
-               snapShotState.add(Tuple2.of(
+               expectedStateSnapshot.add(Tuple2.of(
                        new KinesisStreamShard("fakeStream1",
                                new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
                        new SequenceNumber("11")));
-               snapShotState.add(Tuple2.of(
+               expectedStateSnapshot.add(Tuple2.of(
                        new KinesisStreamShard("fakeStream1",
                                new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
                        new SequenceNumber("31")));
@@ -618,8 +623,9 @@ public class FlinkKinesisConsumerTest {
                // 
----------------------------------------------------------------------
                // mock operator state backend and initial state for 
initializeState()
                // 
----------------------------------------------------------------------
-               TestingListState<Serializable> listState = new 
TestingListState<>();
-               for (Serializable state: initialState) {
+
+               TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> 
listState = new TestingListState<>();
+               for (Tuple2<KinesisStreamShard, SequenceNumber> state: 
initialState) {
                        listState.add(state);
                }
 
@@ -633,8 +639,9 @@ public class FlinkKinesisConsumerTest {
                // 
----------------------------------------------------------------------
                // mock a running fetcher and its state for snapshot
                // 
----------------------------------------------------------------------
+
                HashMap<KinesisStreamShard, SequenceNumber> stateSnapshot = new 
HashMap<>();
-               for (Tuple2<KinesisStreamShard, SequenceNumber> tuple: 
snapShotState) {
+               for (Tuple2<KinesisStreamShard, SequenceNumber> tuple: 
expectedStateSnapshot) {
                        stateSnapshot.put(tuple.f0, tuple.f1);
                }
 
@@ -644,6 +651,7 @@ public class FlinkKinesisConsumerTest {
                // 
----------------------------------------------------------------------
                // create a consumer and test the snapshotState()
                // 
----------------------------------------------------------------------
+
                FlinkKinesisConsumer<String> consumer = new 
FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
                FlinkKinesisConsumer<?> mockedConsumer = spy(consumer);
 
@@ -653,22 +661,22 @@ public class FlinkKinesisConsumerTest {
                mockedConsumer.setRuntimeContext(context);
                mockedConsumer.initializeState(initializationContext);
                mockedConsumer.open(new Configuration());
-               Whitebox.setInternalState(mockedConsumer, "fetcher", 
mockedFetcher); // mock as consumer is running.
+               Whitebox.setInternalState(mockedConsumer, "fetcher", 
mockedFetcher); // mock consumer as running.
 
                
mockedConsumer.snapshotState(mock(FunctionSnapshotContext.class));
 
                assertEquals(true, listState.clearCalled);
                assertEquals(3, listState.getList().size());
 
-               for (Serializable state: initialState) {
-                       for (Serializable currentState: listState.getList()) {
+               for (Tuple2<KinesisStreamShard, SequenceNumber> state: 
initialState) {
+                       for (Tuple2<KinesisStreamShard, SequenceNumber> 
currentState: listState.getList()) {
                                assertNotEquals(state, currentState);
                        }
                }
 
-               for (Serializable state: snapShotState) {
+               for (Tuple2<KinesisStreamShard, SequenceNumber> state: 
expectedStateSnapshot) {
                        boolean hasOneIsSame = false;
-                       for (Serializable currentState: listState.getList()) {
+                       for (Tuple2<KinesisStreamShard, SequenceNumber> 
currentState: listState.getList()) {
                                hasOneIsSame = hasOneIsSame || 
state.equals(currentState);
                        }
                        assertEquals(true, hasOneIsSame);
@@ -693,8 +701,6 @@ public class FlinkKinesisConsumerTest {
                        "fakeStream", new Properties(), 10, 2);
                consumer.open(new Configuration());
                consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
-
-               Mockito.verify(mockedFetcher).setIsRestoringFromFailure(false);
        }
 
        @Test
@@ -718,7 +724,6 @@ public class FlinkKinesisConsumerTest {
                consumer.open(new Configuration());
                consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
 
-               Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
                for (Map.Entry<KinesisStreamShard, SequenceNumber> 
restoredShard : fakeRestoredState.entrySet()) {
                        
Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
                                new 
KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
@@ -728,15 +733,18 @@ public class FlinkKinesisConsumerTest {
        @Test
        @SuppressWarnings("unchecked")
        public void 
testFetcherShouldBeCorrectlySeededIfRestoringFromCheckpoint() throws Exception {
+
                // 
----------------------------------------------------------------------
-               // setting initial state
+               // setup initial state
                // 
----------------------------------------------------------------------
+
                HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = 
getFakeRestoredStore("all");
 
                // 
----------------------------------------------------------------------
                // mock operator state backend and initial state for 
initializeState()
                // 
----------------------------------------------------------------------
-               TestingListState<Serializable> listState = new 
TestingListState<>();
+
+               TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> 
listState = new TestingListState<>();
                for (Map.Entry<KinesisStreamShard, SequenceNumber> state: 
fakeRestoredState.entrySet()) {
                        listState.add(Tuple2.of(state.getKey(), 
state.getValue()));
                }
@@ -751,6 +759,7 @@ public class FlinkKinesisConsumerTest {
                // 
----------------------------------------------------------------------
                // mock fetcher
                // 
----------------------------------------------------------------------
+
                KinesisDataFetcher mockedFetcher = 
Mockito.mock(KinesisDataFetcher.class);
                List<KinesisStreamShard> shards = new ArrayList<>();
                shards.addAll(fakeRestoredState.keySet());
@@ -762,15 +771,15 @@ public class FlinkKinesisConsumerTest {
                PowerMockito.doNothing().when(KinesisConfigUtil.class);
 
                // 
----------------------------------------------------------------------
-               // start to test seed initial state to fetcher
+               // start to test fetcher's initial state seeding
                // 
----------------------------------------------------------------------
+
                TestableFlinkKinesisConsumer consumer = new 
TestableFlinkKinesisConsumer(
                        "fakeStream", new Properties(), 10, 2);
                consumer.initializeState(initializationContext);
                consumer.open(new Configuration());
                consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
 
-               Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
                for (Map.Entry<KinesisStreamShard, SequenceNumber> 
restoredShard : fakeRestoredState.entrySet()) {
                        
Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
                                new 
KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
@@ -780,9 +789,11 @@ public class FlinkKinesisConsumerTest {
        @Test
        @SuppressWarnings("unchecked")
        public void testFetcherShouldBeCorrectlySeededOnlyItsOwnStates() throws 
Exception {
+
                // 
----------------------------------------------------------------------
-               // setting initial state
+               // setup initial state
                // 
----------------------------------------------------------------------
+
                HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = 
getFakeRestoredStore("fakeStream1");
 
                HashMap<KinesisStreamShard, SequenceNumber> 
fakeRestoredStateForOthers = getFakeRestoredStore("fakeStream2");
@@ -790,7 +801,8 @@ public class FlinkKinesisConsumerTest {
                // 
----------------------------------------------------------------------
                // mock operator state backend and initial state for 
initializeState()
                // 
----------------------------------------------------------------------
-               TestingListState<Serializable> listState = new 
TestingListState<>();
+
+               TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> 
listState = new TestingListState<>();
                for (Map.Entry<KinesisStreamShard, SequenceNumber> state: 
fakeRestoredState.entrySet()) {
                        listState.add(Tuple2.of(state.getKey(), 
state.getValue()));
                }
@@ -808,6 +820,7 @@ public class FlinkKinesisConsumerTest {
                // 
----------------------------------------------------------------------
                // mock fetcher
                // 
----------------------------------------------------------------------
+
                KinesisDataFetcher mockedFetcher = 
Mockito.mock(KinesisDataFetcher.class);
                List<KinesisStreamShard> shards = new ArrayList<>();
                shards.addAll(fakeRestoredState.keySet());
@@ -819,15 +832,15 @@ public class FlinkKinesisConsumerTest {
                PowerMockito.doNothing().when(KinesisConfigUtil.class);
 
                // 
----------------------------------------------------------------------
-               // start to test seed initial state to fetcher
+               // start to test fetcher's initial state seeding
                // 
----------------------------------------------------------------------
+
                TestableFlinkKinesisConsumer consumer = new 
TestableFlinkKinesisConsumer(
                        "fakeStream", new Properties(), 10, 2);
                consumer.initializeState(initializationContext);
                consumer.open(new Configuration());
                consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
 
-               Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
                for (Map.Entry<KinesisStreamShard, SequenceNumber> 
restoredShard : fakeRestoredStateForOthers.entrySet()) {
                        // should never get restored state not belonging to 
itself
                        Mockito.verify(mockedFetcher, 
never()).registerNewSubscribedShardState(
@@ -841,42 +854,49 @@ public class FlinkKinesisConsumerTest {
        }
 
        /*
-        * If the original parallelism is 2 and states is:
+        * This tests that the consumer correctly picks up shards that were not 
discovered on the previous run.
+        *
+        * Case under test:
+        *
+        * If the original parallelism is 2 and states are:
         *   Consumer subtask 1:
         *     stream1, shard1, SequentialNumber(xxx)
         *   Consumer subtask 2:
         *     stream1, shard2, SequentialNumber(yyy)
-        * After discoverNewShardsToSubscribe() if there are two shards 
(shard3, shard4) been created:
+        *
+        * After discoverNewShardsToSubscribe(), if two new shards 
(shard3, shard4) were created:
         *   Consumer subtask 1 (late for discoverNewShardsToSubscribe()):
         *     stream1, shard1, SequentialNumber(xxx)
         *   Consumer subtask 2:
         *     stream1, shard2, SequentialNumber(yyy)
         *     stream1, shard4, SequentialNumber(zzz)
-        *  If snapshotState() occur and parallelism is changed to 1:
-        *    Union state will be:
+        *
+        * If snapshotState() occurs and parallelism is changed to 1:
+        *   Union state will be:
         *     stream1, shard1, SequentialNumber(xxx)
         *     stream1, shard2, SequentialNumber(yyy)
         *     stream1, shard4, SequentialNumber(zzz)
-        *    Fetcher should be seeded with:
+        *   Fetcher should be seeded with:
         *     stream1, shard1, SequentialNumber(xxx)
         *     stream1, shard2, SequentialNumber(yyy)
         *     stream1, shard3, 
SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM
         *     stream1, shard4, SequentialNumber(zzz)
-        *
-        *  This test is to guarantee the fetcher will be seeded correctly for 
such situation.
         */
        @Test
        @SuppressWarnings("unchecked")
        public void 
testFetcherShouldBeCorrectlySeededWithNewDiscoveredKinesisStreamShard() throws 
Exception {
+
                // 
----------------------------------------------------------------------
-               // setting initial state
+               // setup initial state
                // 
----------------------------------------------------------------------
+
                HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = 
getFakeRestoredStore("all");
 
                // 
----------------------------------------------------------------------
                // mock operator state backend and initial state for 
initializeState()
                // 
----------------------------------------------------------------------
-               TestingListState<Serializable> listState = new 
TestingListState<>();
+
+               TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> 
listState = new TestingListState<>();
                for (Map.Entry<KinesisStreamShard, SequenceNumber> state: 
fakeRestoredState.entrySet()) {
                        listState.add(Tuple2.of(state.getKey(), 
state.getValue()));
                }
@@ -891,6 +911,7 @@ public class FlinkKinesisConsumerTest {
                // 
----------------------------------------------------------------------
                // mock fetcher
                // 
----------------------------------------------------------------------
+
                KinesisDataFetcher mockedFetcher = 
Mockito.mock(KinesisDataFetcher.class);
                List<KinesisStreamShard> shards = new ArrayList<>();
                shards.addAll(fakeRestoredState.keySet());
@@ -904,8 +925,9 @@ public class FlinkKinesisConsumerTest {
                PowerMockito.doNothing().when(KinesisConfigUtil.class);
 
                // 
----------------------------------------------------------------------
-               // start to test seed initial state to fetcher
+               // start to test fetcher's initial state seeding
                // 
----------------------------------------------------------------------
+
                TestableFlinkKinesisConsumer consumer = new 
TestableFlinkKinesisConsumer(
                        "fakeStream", new Properties(), 10, 2);
                consumer.initializeState(initializationContext);
@@ -915,7 +937,6 @@ public class FlinkKinesisConsumerTest {
                fakeRestoredState.put(new KinesisStreamShard("fakeStream2",
                                new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
                        
SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get());
-               Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
                for (Map.Entry<KinesisStreamShard, SequenceNumber> 
restoredShard : fakeRestoredState.entrySet()) {
                        
Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
                                new 
KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));

http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
----------------------------------------------------------------------
diff --git 
a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
 
b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
index e79f9b1..800fde5 100644
--- 
a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
+++ 
b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
@@ -18,9 +18,14 @@
 package org.apache.flink.streaming.connectors.kinesis.internals;
 
 import com.amazonaws.services.kinesis.model.Shard;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer;
+import 
org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
 import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
 import 
org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
+import 
org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema;
 import 
org.apache.flink.streaming.connectors.kinesis.testutils.FakeKinesisBehavioursFactory;
 import 
org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGenerator;
 import 
org.apache.flink.streaming.connectors.kinesis.testutils.TestableKinesisDataFetcher;
@@ -42,6 +47,8 @@ import java.util.UUID;
 import java.util.concurrent.atomic.AtomicReference;
 
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 @RunWith(PowerMockRunner.class)
 @PrepareForTest(TestableKinesisDataFetcher.class)
@@ -67,8 +74,6 @@ public class KinesisDataFetcherTest {
                                subscribedStreamsToLastSeenShardIdsUnderTest,
                                
FakeKinesisBehavioursFactory.noShardsFoundForRequestedStreamsBehaviour());
 
-               fetcher.setIsRestoringFromFailure(false); // not restoring
-
                fetcher.runFetcher(); // this should throw RuntimeException
        }
 
@@ -100,23 +105,30 @@ public class KinesisDataFetcherTest {
                                subscribedStreamsToLastSeenShardIdsUnderTest,
                                
FakeKinesisBehavioursFactory.nonReshardedStreamsBehaviour(streamToShardCount));
 
-               fetcher.setIsRestoringFromFailure(false);
+               Properties testConfig = new Properties();
+               testConfig.setProperty(ConsumerConfigConstants.AWS_REGION, 
"us-east-1");
+               
testConfig.setProperty(ConsumerConfigConstants.AWS_CREDENTIALS_PROVIDER, 
"BASIC");
+               
testConfig.setProperty(ConsumerConfigConstants.AWS_ACCESS_KEY_ID, 
"accessKeyId");
+               
testConfig.setProperty(ConsumerConfigConstants.AWS_SECRET_ACCESS_KEY, 
"secretKey");
+
+               final DummyFlinkKafkaConsumer<String> consumer = new 
DummyFlinkKafkaConsumer<>(testConfig, fetcher);
 
                
PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
-               Thread runFetcherThread = new Thread(new Runnable() {
+               Thread consumerThread = new Thread(new Runnable() {
                        @Override
                        public void run() {
                                try {
-                                       fetcher.runFetcher();
+                                       
consumer.run(mock(SourceFunction.SourceContext.class));
                                } catch (Exception e) {
                                        //
                                }
                        }
                });
-               runFetcherThread.start();
-               Thread.sleep(1000); // sleep a while before closing
-               fetcher.shutdownFetcher();
+               consumerThread.start();
 
+               fetcher.waitUntilRun();
+               consumer.cancel();
+               consumerThread.join();
 
                // assert that the streams tracked in the state are identical 
to the subscribed streams
                Set<String> streamsInState = 
subscribedStreamsToLastSeenShardIdsUnderTest.keySet();
@@ -192,8 +204,6 @@ public class KinesisDataFetcherTest {
                                new 
KinesisStreamShardState(restoredState.getKey(), new 
SequenceNumber(restoredState.getValue())));
                }
 
-               fetcher.setIsRestoringFromFailure(true);
-
                
PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
                Thread runFetcherThread = new Thread(new Runnable() {
                        @Override
@@ -284,8 +294,6 @@ public class KinesisDataFetcherTest {
                                new 
KinesisStreamShardState(restoredState.getKey(), new 
SequenceNumber(restoredState.getValue())));
                }
 
-               fetcher.setIsRestoringFromFailure(true);
-
                
PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
                Thread runFetcherThread = new Thread(new Runnable() {
                        @Override
@@ -380,8 +388,6 @@ public class KinesisDataFetcherTest {
                                new 
KinesisStreamShardState(restoredState.getKey(), new 
SequenceNumber(restoredState.getValue())));
                }
 
-               fetcher.setIsRestoringFromFailure(true);
-
                
PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
                Thread runFetcherThread = new Thread(new Runnable() {
                        @Override
@@ -477,8 +483,6 @@ public class KinesisDataFetcherTest {
                                new 
KinesisStreamShardState(restoredState.getKey(), new 
SequenceNumber(restoredState.getValue())));
                }
 
-               fetcher.setIsRestoringFromFailure(true);
-
                
PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
                Thread runFetcherThread = new Thread(new Runnable() {
                        @Override
@@ -507,4 +511,33 @@ public class KinesisDataFetcherTest {
                
assertTrue(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream3") == 
null);
                
assertTrue(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream4") == 
null);
        }
+
+       private static class DummyFlinkKafkaConsumer<T> extends 
FlinkKinesisConsumer<T> {
+               private static final long serialVersionUID = 1L;
+
+               private KinesisDataFetcher<T> fetcher;
+
+               @SuppressWarnings("unchecked")
+               DummyFlinkKafkaConsumer(Properties properties, 
KinesisDataFetcher<T> fetcher) {
+                       super("test", mock(KinesisDeserializationSchema.class), 
properties);
+                       this.fetcher = fetcher;
+               }
+
+               @Override
+               protected KinesisDataFetcher<T> createFetcher(List<String> 
streams,
+                                                                               
                          SourceFunction.SourceContext<T> sourceContext,
+                                                                               
                          RuntimeContext runtimeContext,
+                                                                               
                          Properties configProps,
+                                                                               
                          KinesisDeserializationSchema<T> 
deserializationSchema) {
+                       return fetcher;
+               }
+
+               @Override
+               public RuntimeContext getRuntimeContext() {
+                       RuntimeContext context = mock(RuntimeContext.class);
+                       when(context.getIndexOfThisSubtask()).thenReturn(0);
+                       
when(context.getNumberOfParallelSubtasks()).thenReturn(1);
+                       return context;
+               }
+       }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
----------------------------------------------------------------------
diff --git 
a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
 
b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
index 57886fe..bb644ba 100644
--- 
a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
+++ 
b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
@@ -18,6 +18,7 @@
 package org.apache.flink.streaming.connectors.kinesis.testutils;
 
 import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import 
org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
 import 
org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
@@ -42,6 +43,8 @@ public class TestableKinesisDataFetcher extends 
KinesisDataFetcher<String> {
 
        private long numElementsCollected;
 
+       private OneShotLatch runWaiter;
+
        public TestableKinesisDataFetcher(List<String> fakeStreams,
                                                                          
Properties fakeConfiguration,
                                                                          int 
fakeTotalCountOfSubtasks,
@@ -62,6 +65,7 @@ public class TestableKinesisDataFetcher extends 
KinesisDataFetcher<String> {
                        fakeKinesis);
 
                this.numElementsCollected = 0;
+               this.runWaiter = new OneShotLatch();
        }
 
        public long getNumOfElementsCollected() {
@@ -81,6 +85,16 @@ public class TestableKinesisDataFetcher extends 
KinesisDataFetcher<String> {
                }
        }
 
+       @Override
+       public void runFetcher() throws Exception {
+               runWaiter.trigger();
+               super.runFetcher();
+       }
+
+       public void waitUntilRun() throws Exception {
+               runWaiter.await();
+       }
+
        @SuppressWarnings("unchecked")
        private static SourceFunction.SourceContext<String> 
getMockedSourceContext() {
                return Mockito.mock(SourceFunction.SourceContext.class);

Reply via email to