[FLINK-4821] [kinesis] General improvements to rescalable FlinkKinesisConsumer
This commit adds some general improvements to the rescalable implementation of FlinkKinesisConsumer, including: - Refactor setup procedures in KinesisDataFetcher so that duplicate work isn't done on a restored run - Strengthen handling of corner cases where the fetcher was not fully seeded with initial state when a snapshot is taken This closes #3001. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/e5b65a7f Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/e5b65a7f Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/e5b65a7f Branch: refs/heads/master Commit: e5b65a7fc2b4a7532ca40748f81bcbf8ace46815 Parents: a05b574 Author: Tzu-Li (Gordon) Tai <[email protected]> Authored: Sun May 7 16:29:32 2017 +0800 Committer: Tzu-Li (Gordon) Tai <[email protected]> Committed: Sun May 7 17:33:04 2017 +0800 ---------------------------------------------------------------------- .../kinesis/FlinkKinesisConsumer.java | 150 ++++++++--------- .../kinesis/internals/KinesisDataFetcher.java | 52 +----- .../FlinkKinesisConsumerMigrationTest.java | 5 +- .../kinesis/FlinkKinesisConsumerTest.java | 159 +++++++++++-------- .../internals/KinesisDataFetcherTest.java | 65 ++++++-- .../testutils/TestableKinesisDataFetcher.java | 14 ++ 6 files changed, 233 insertions(+), 212 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java ---------------------------------------------------------------------- diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java index dfcd552..4982f7f 100644 --- 
a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java @@ -25,13 +25,14 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.ResultTypeQueryable; import org.apache.flink.api.java.typeutils.TupleTypeInfo; -import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.state.FunctionInitializationContext; import org.apache.flink.runtime.state.FunctionSnapshotContext; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring; import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants; +import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants.InitialPosition; import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher; import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard; import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState; @@ -67,9 +68,9 @@ import static org.apache.flink.util.Preconditions.checkNotNull; * @param <T> the type of data emitted */ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> implements - ResultTypeQueryable<T>, - CheckpointedFunction, - CheckpointedRestoring<HashMap<KinesisStreamShard, SequenceNumber>> { + ResultTypeQueryable<T>, + CheckpointedFunction, + CheckpointedRestoring<HashMap<KinesisStreamShard, SequenceNumber>> { private static final long serialVersionUID = 4724006128720664870L; @@ -86,7 +87,7 @@ public class 
FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple * shard list retrieval behaviours, etc */ private final Properties configProps; - /** User supplied deseriliazation schema to convert Kinesis byte messages to Flink objects */ + /** User supplied deserialization schema to convert Kinesis byte messages to Flink objects */ private final KinesisDeserializationSchema<T> deserializer; // ------------------------------------------------------------------------ @@ -96,9 +97,6 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple /** Per-task fetcher for Kinesis data records, where each fetcher pulls data from one or more Kinesis shards */ private transient KinesisDataFetcher<T> fetcher; - /** The sequence numbers in the last state snapshot of this subtask */ - private transient HashMap<KinesisStreamShard, SequenceNumber> lastStateSnapshot; - /** The sequence numbers to restore to upon restore from failure */ private transient HashMap<KinesisStreamShard, SequenceNumber> sequenceNumsToRestore; @@ -108,7 +106,7 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple // State for Checkpoint // ------------------------------------------------------------------------ - /** The name is the key for sequence numbers state, and cannot be changed. 
*/ + /** State name to access shard sequence number states; cannot be changed */ private static final String sequenceNumsStateStoreName = "Kinesis-Stream-Shard-State"; private transient ListState<Tuple2<KinesisStreamShard, SequenceNumber>> sequenceNumsStateForCheckpoint; @@ -191,57 +189,33 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple // ------------------------------------------------------------------------ @Override - public void open(Configuration parameters) throws Exception { - super.open(parameters); - - // restore to the last known sequence numbers from the latest complete snapshot - if (sequenceNumsToRestore != null) { - if (LOG.isInfoEnabled()) { - LOG.info("Subtask {} is restoring sequence numbers {} from previous checkpointed state", - getRuntimeContext().getIndexOfThisSubtask(), sequenceNumsToRestore.toString()); - } - - // initialize sequence numbers with restored state - lastStateSnapshot = sequenceNumsToRestore; - } else { - // start fresh with empty sequence numbers if there are no snapshots to restore from. 
- lastStateSnapshot = new HashMap<>(); - } - } - - @Override public void run(SourceContext<T> sourceContext) throws Exception { // all subtasks will run a fetcher, regardless of whether or not the subtask will initially have // shards to subscribe to; fetchers will continuously poll for changes in the shard list, so all subtasks // can potentially have new shards to subscribe to later on - fetcher = createFetcher(streams, sourceContext, getRuntimeContext(), configProps, deserializer); - - boolean isRestoringFromFailure = (sequenceNumsToRestore != null); - fetcher.setIsRestoringFromFailure(isRestoringFromFailure); - - // if we are restoring from a checkpoint, we iterate over the restored - // state and accordingly seed the fetcher with subscribed shards states - if (isRestoringFromFailure) { - // Since there may have a situation that some subtasks did not finish discovering before rescale, - // and KinesisDataFetcher will always discover the shard from the largest shard id. To prevent from - // missing some shards which didn't be discovered and whose id is not the largest one, we force the - // consumer to discover once from the smallest id and make sure each shard have its initial sequence - // number from restored state or SENTINEL_EARLIEST_SEQUENCE_NUM. 
- List<KinesisStreamShard> newShardsCreatedWhileNotRunning = fetcher.discoverNewShardsToSubscribe(); - for (KinesisStreamShard shard : newShardsCreatedWhileNotRunning) { - SequenceNumber startingStateForNewShard; - - if (lastStateSnapshot.containsKey(shard)) { - startingStateForNewShard = lastStateSnapshot.get(shard); + KinesisDataFetcher<T> fetcher = createFetcher(streams, sourceContext, getRuntimeContext(), configProps, deserializer); + + // initial discovery + List<KinesisStreamShard> allShards = fetcher.discoverNewShardsToSubscribe(); + + for (KinesisStreamShard shard : allShards) { + if (sequenceNumsToRestore != null) { + if (sequenceNumsToRestore.containsKey(shard)) { + // if the shard was already seen and is contained in the state, + // just use the sequence number stored in the state + fetcher.registerNewSubscribedShardState( + new KinesisStreamShardState(shard, sequenceNumsToRestore.get(shard))); if (LOG.isInfoEnabled()) { LOG.info("Subtask {} is seeding the fetcher with restored shard {}," + " starting state set to the restored sequence number {}", - getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), startingStateForNewShard); + getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), sequenceNumsToRestore.get(shard)); } } else { - startingStateForNewShard = SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get(); + // the shard wasn't discovered in the previous run, therefore should be consumed from the beginning + fetcher.registerNewSubscribedShardState( + new KinesisStreamShardState(shard, SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get())); if (LOG.isInfoEnabled()) { LOG.info("Subtask {} is seeding the fetcher with new discovered shard {}," + @@ -249,9 +223,20 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple getRuntimeContext().getIndexOfThisSubtask(), shard.toString()); } } + } else { + // we're starting fresh; use the configured start position as initial state + 
SentinelSequenceNumber startingSeqNum = + InitialPosition.valueOf(configProps.getProperty( + ConsumerConfigConstants.STREAM_INITIAL_POSITION, + ConsumerConfigConstants.DEFAULT_STREAM_INITIAL_POSITION)).toSentinelSequenceNumber(); fetcher.registerNewSubscribedShardState( - new KinesisStreamShardState(shard, startingStateForNewShard)); + new KinesisStreamShardState(shard, startingSeqNum.get())); + + if (LOG.isInfoEnabled()) { + LOG.info("Subtask {} will be seeded with initial shard {}, starting state set as sequence number {}", + getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), startingSeqNum.get()); + } } } @@ -260,6 +245,10 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple return; } + // expose the fetcher from this point, so that state + // snapshots can be taken from the fetcher's state holders + this.fetcher = fetcher; + // start the fetcher loop. The fetcher will stop running only when cancel() or // close() is called, or an error is thrown by threads created by the fetcher fetcher.runFetcher(); @@ -306,13 +295,12 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple @Override public void initializeState(FunctionInitializationContext context) throws Exception { - TypeInformation<Tuple2<KinesisStreamShard, SequenceNumber>> tuple = new TupleTypeInfo<>( + TypeInformation<Tuple2<KinesisStreamShard, SequenceNumber>> shardsStateTypeInfo = new TupleTypeInfo<>( TypeInformation.of(KinesisStreamShard.class), - TypeInformation.of(SequenceNumber.class) - ); + TypeInformation.of(SequenceNumber.class)); sequenceNumsStateForCheckpoint = context.getOperatorStateStore().getUnionListState( - new ListStateDescriptor<>(sequenceNumsStateStoreName, tuple)); + new ListStateDescriptor<>(sequenceNumsStateStoreName, shardsStateTypeInfo)); if (context.isRestored()) { if (sequenceNumsToRestore == null) { @@ -323,8 +311,6 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple 
LOG.info("Setting restore state in the FlinkKinesisConsumer. Using the following offsets: {}", sequenceNumsToRestore); - } else if (sequenceNumsToRestore.isEmpty()) { - sequenceNumsToRestore = null; } } else { LOG.info("No restore state for FlinkKinesisConsumer."); @@ -333,11 +319,7 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple @Override public void snapshotState(FunctionSnapshotContext context) throws Exception { - if (lastStateSnapshot == null) { - LOG.debug("snapshotState() requested on not yet opened source; returning null."); - } else if (fetcher == null) { - LOG.debug("snapshotState() requested on not yet running source; returning null."); - } else if (!running) { + if (!running) { LOG.debug("snapshotState() called on closed source; returning null."); } else { if (LOG.isDebugEnabled()) { @@ -345,15 +327,33 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple } sequenceNumsStateForCheckpoint.clear(); - lastStateSnapshot = fetcher.snapshotState(); - if (LOG.isDebugEnabled()) { - LOG.debug("Snapshotted state, last processed sequence numbers: {}, checkpoint id: {}, timestamp: {}", - lastStateSnapshot.toString(), context.getCheckpointId(), context.getCheckpointTimestamp()); - } + if (fetcher == null) { + if (sequenceNumsToRestore != null) { + for (Map.Entry<KinesisStreamShard, SequenceNumber> entry : sequenceNumsToRestore.entrySet()) { + // sequenceNumsToRestore is the restored global union state; + // should only snapshot shards that actually belong to us + + if (KinesisDataFetcher.isThisSubtaskShouldSubscribeTo( + entry.getKey(), + getRuntimeContext().getNumberOfParallelSubtasks(), + getRuntimeContext().getIndexOfThisSubtask())) { + + sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue())); + } + } + } + } else { + HashMap<KinesisStreamShard, SequenceNumber> lastStateSnapshot = fetcher.snapshotState(); + + if (LOG.isDebugEnabled()) { + LOG.debug("Snapshotted state, last 
processed sequence numbers: {}, checkpoint id: {}, timestamp: {}", + lastStateSnapshot.toString(), context.getCheckpointId(), context.getCheckpointTimestamp()); + } - for (Map.Entry<KinesisStreamShard, SequenceNumber> entry : lastStateSnapshot.entrySet()) { - sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue())); + for (Map.Entry<KinesisStreamShard, SequenceNumber> entry : lastStateSnapshot.entrySet()) { + sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue())); + } } } } @@ -366,12 +366,14 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple sequenceNumsToRestore = restoredState.isEmpty() ? null : restoredState; } - /** This method is created for tests that can mock the KinesisDataFetcher in the consumer. */ - protected KinesisDataFetcher<T> createFetcher(List<String> streams, - SourceFunction.SourceContext<T> sourceContext, - RuntimeContext runtimeContext, - Properties configProps, - KinesisDeserializationSchema<T> deserializationSchema) { + /** This method is exposed for tests that need to mock the KinesisDataFetcher in the consumer. 
*/ + protected KinesisDataFetcher<T> createFetcher( + List<String> streams, + SourceFunction.SourceContext<T> sourceContext, + RuntimeContext runtimeContext, + Properties configProps, + KinesisDeserializationSchema<T> deserializationSchema) { + return new KinesisDataFetcher<>(streams, sourceContext, runtimeContext, configProps, deserializationSchema); } http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java ---------------------------------------------------------------------- diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java index c5b4b04..99305cb 100644 --- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java @@ -19,9 +19,7 @@ package org.apache.flink.streaming.connectors.kinesis.internals; import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.streaming.api.functions.source.SourceFunction; -import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer; import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants; -import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants.InitialPosition; import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard; import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState; import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber; @@ -99,12 +97,6 @@ public class KinesisDataFetcher<T> { 
private final int indexOfThisConsumerSubtask; - /** - * This flag should be set by {@link FlinkKinesisConsumer} using - * {@link KinesisDataFetcher#setIsRestoringFromFailure(boolean)} - */ - private boolean isRestoredFromFailure; - // ------------------------------------------------------------------------ // Executor services to run created threads // ------------------------------------------------------------------------ @@ -235,41 +227,7 @@ public class KinesisDataFetcher<T> { // Procedures before starting the infinite while loop: // ------------------------------------------------------------------------ - // 1. query for any new shards that may have been created while the Kinesis consumer was not running, - // and register them to the subscribedShardState list. - if (LOG.isDebugEnabled()) { - String logFormat = (!isRestoredFromFailure) - ? "Subtask {} is trying to discover initial shards ..." - : "Subtask {} is trying to discover any new shards that were created while the consumer wasn't " + - "running due to failure ..."; - - LOG.debug(logFormat, indexOfThisConsumerSubtask); - } - List<KinesisStreamShard> newShardsCreatedWhileNotRunning = discoverNewShardsToSubscribe(); - for (KinesisStreamShard shard : newShardsCreatedWhileNotRunning) { - // the starting state for new shards created while the consumer wasn't running depends on whether or not - // we are starting fresh (not restoring from a checkpoint); when we are starting fresh, this simply means - // all existing shards of streams we are subscribing to are new shards; when we are restoring from checkpoint, - // any new shards due to Kinesis resharding from the time of the checkpoint will be considered new shards. - InitialPosition initialPosition = InitialPosition.valueOf(configProps.getProperty( - ConsumerConfigConstants.STREAM_INITIAL_POSITION, ConsumerConfigConstants.DEFAULT_STREAM_INITIAL_POSITION)); - - SentinelSequenceNumber startingStateForNewShard = (isRestoredFromFailure) - ? 
SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM - : initialPosition.toSentinelSequenceNumber(); - - if (LOG.isInfoEnabled()) { - String logFormat = (!isRestoredFromFailure) - ? "Subtask {} will be seeded with initial shard {}, starting state set as sequence number {}" - : "Subtask {} will be seeded with new shard {} that was created while the consumer wasn't " + - "running due to failure, starting state set as sequence number {}"; - - LOG.info(logFormat, indexOfThisConsumerSubtask, shard.toString(), startingStateForNewShard.get()); - } - registerNewSubscribedShardState(new KinesisStreamShardState(shard, startingStateForNewShard.get())); - } - - // 2. check that there is at least one shard in the subscribed streams to consume from (can be done by + // 1. check that there is at least one shard in the subscribed streams to consume from (can be done by // checking if at least one value in subscribedStreamsToLastDiscoveredShardIds is not null) boolean hasShards = false; StringBuilder streamsWithNoShardsFound = new StringBuilder(); @@ -290,7 +248,7 @@ public class KinesisDataFetcher<T> { throw new RuntimeException("No shards can be found for all subscribed streams: " + streams); } - // 3. start consuming any shard state we already have in the subscribedShardState up to this point; the + // 2. 
start consuming any shard state we already have in the subscribedShardState up to this point; the // subscribedShardState may already be seeded with values due to step 1., or explicitly added by the // consumer using a restored state checkpoint for (int seededStateIndex = 0; seededStateIndex < subscribedShardsState.size(); seededStateIndex++) { @@ -489,10 +447,6 @@ public class KinesisDataFetcher<T> { // Functions to get / set information about the consumer // ------------------------------------------------------------------------ - public void setIsRestoringFromFailure(boolean bool) { - this.isRestoredFromFailure = bool; - } - protected Properties getConsumerConfiguration() { return configProps; } @@ -595,7 +549,7 @@ public class KinesisDataFetcher<T> { * @param totalNumberOfConsumerSubtasks total number of consumer subtasks * @param indexOfThisConsumerSubtask index of this consumer subtask */ - private static boolean isThisSubtaskShouldSubscribeTo(KinesisStreamShard shard, + public static boolean isThisSubtaskShouldSubscribeTo(KinesisStreamShard shard, int totalNumberOfConsumerSubtasks, int indexOfThisConsumerSubtask) { return (Math.abs(shard.hashCode() % totalNumberOfConsumerSubtasks)) == indexOfThisConsumerSubtask; http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java ---------------------------------------------------------------------- diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java index 2f46e09..7629f9d 100644 --- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java +++ 
b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java @@ -42,10 +42,7 @@ import static org.mockito.Mockito.mock; /** * Tests for checking whether {@link FlinkKinesisConsumer} can restore from snapshots that were - * done using the Flink 1.1 {@link FlinkKinesisConsumer}. - * - * <p>For regenerating the binary snapshot file you have to run the commented out portion - * of each test on a checkout of the Flink 1.1 branch. + * done using the Flink 1.1 {@code FlinkKinesisConsumer}. */ public class FlinkKinesisConsumerMigrationTest { http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java ---------------------------------------------------------------------- diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java index bf8e44f..4b178c7 100644 --- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java +++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java @@ -40,6 +40,7 @@ import org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGen import org.apache.flink.streaming.connectors.kinesis.testutils.TestableFlinkKinesisConsumer; import org.apache.flink.streaming.connectors.kinesis.util.KinesisConfigUtil; import org.apache.flink.streaming.util.serialization.SimpleStringSchema; +import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -57,10 +58,8 @@ import java.util.ArrayList; import java.util.List; import 
java.util.Properties; import java.util.UUID; -import java.io.Serializable; import static org.junit.Assert.fail; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; import static org.mockito.Mockito.mock; @@ -530,7 +529,7 @@ public class FlinkKinesisConsumerTest { // ---------------------------------------------------------------------- @Test - public void testSnapshotStateShouldNotClearListStateIfSourceNotOpened() throws Exception { + public void testUseRestoredStateForSnapshotIfFetcherNotInitialized() throws Exception { Properties config = new Properties(); config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1"); config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId"); @@ -538,57 +537,63 @@ public class FlinkKinesisConsumerTest { OperatorStateStore operatorStateStore = mock(OperatorStateStore.class); - TestingListState<Serializable> listState = new TestingListState<>(); - - FlinkKinesisConsumer<String> consumer = new FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config); - - when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState); - - StateInitializationContext initializationContext = mock(StateInitializationContext.class); - - when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore); - when(initializationContext.isRestored()).thenReturn(false); - - consumer.initializeState(initializationContext); - - consumer.snapshotState(new StateSnapshotContextSynchronousImpl(123, 123)); //arbitrary checkpoint id and timestamp - - assertFalse(listState.isClearCalled()); - } - - @Test - public void testSnapshotStateShouldNotClearListStateIfSourceNotRun() throws Exception { - Properties config = new Properties(); - config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1"); - config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId"); - 
config.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey"); - - OperatorStateStore operatorStateStore = mock(OperatorStateStore.class); + List<Tuple2<KinesisStreamShard, SequenceNumber>> globalUnionState = new ArrayList<>(4); + globalUnionState.add(Tuple2.of( + new KinesisStreamShard("fakeStream", + new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))), + new SequenceNumber("1"))); + globalUnionState.add(Tuple2.of( + new KinesisStreamShard("fakeStream", + new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))), + new SequenceNumber("1"))); + globalUnionState.add(Tuple2.of( + new KinesisStreamShard("fakeStream", + new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))), + new SequenceNumber("1"))); + globalUnionState.add(Tuple2.of( + new KinesisStreamShard("fakeStream", + new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(3))), + new SequenceNumber("1"))); - TestingListState<Serializable> listState = new TestingListState<>(); + TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>(); + for (Tuple2<KinesisStreamShard, SequenceNumber> state : globalUnionState) { + listState.add(state); + } FlinkKinesisConsumer<String> consumer = new FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config); + RuntimeContext context = mock(RuntimeContext.class); + when(context.getIndexOfThisSubtask()).thenReturn(0); + when(context.getNumberOfParallelSubtasks()).thenReturn(2); + consumer.setRuntimeContext(context); when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState); StateInitializationContext initializationContext = mock(StateInitializationContext.class); when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore); - when(initializationContext.isRestored()).thenReturn(false); + when(initializationContext.isRestored()).thenReturn(true); 
consumer.initializeState(initializationContext); - consumer.open(new Configuration()); // only opened, not run + // only opened, not run + consumer.open(new Configuration()); + + // arbitrary checkpoint id and timestamp + consumer.snapshotState(new StateSnapshotContextSynchronousImpl(123, 123)); - consumer.snapshotState(new StateSnapshotContextSynchronousImpl(123, 123)); //arbitrary checkpoint id and timestamp + Assert.assertTrue(listState.isClearCalled()); - assertFalse(listState.isClearCalled()); + // the checkpointed list state should contain only the shards that it should subscribe to + Assert.assertEquals(globalUnionState.size() / 2, listState.getList().size()); + Assert.assertTrue(listState.getList().contains(globalUnionState.get(0))); + Assert.assertTrue(listState.getList().contains(globalUnionState.get(2))); } @Test public void testListStateChangedAfterSnapshotState() throws Exception { + // ---------------------------------------------------------------------- - // setting config, initial state and state after snapshot + // setup config, initial state and expected state snapshot // ---------------------------------------------------------------------- Properties config = new Properties(); config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1"); @@ -601,16 +606,16 @@ public class FlinkKinesisConsumerTest { new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))), new SequenceNumber("1"))); - ArrayList<Tuple2<KinesisStreamShard, SequenceNumber>> snapShotState = new ArrayList<>(3); - snapShotState.add(Tuple2.of( + ArrayList<Tuple2<KinesisStreamShard, SequenceNumber>> expectedStateSnapshot = new ArrayList<>(3); + expectedStateSnapshot.add(Tuple2.of( new KinesisStreamShard("fakeStream1", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))), new SequenceNumber("12"))); - snapShotState.add(Tuple2.of( + expectedStateSnapshot.add(Tuple2.of( new KinesisStreamShard("fakeStream1", new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))), new SequenceNumber("11"))); - snapShotState.add(Tuple2.of( + expectedStateSnapshot.add(Tuple2.of( new KinesisStreamShard("fakeStream1", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))), new SequenceNumber("31"))); @@ -618,8 +623,9 @@ public class FlinkKinesisConsumerTest { // ---------------------------------------------------------------------- // mock operator state backend and initial state for initializeState() // ---------------------------------------------------------------------- - TestingListState<Serializable> listState = new TestingListState<>(); - for (Serializable state: initialState) { + + TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>(); + for (Tuple2<KinesisStreamShard, SequenceNumber> state: initialState) { listState.add(state); } @@ -633,8 +639,9 @@ public class FlinkKinesisConsumerTest { // ---------------------------------------------------------------------- // mock a running fetcher and its state for snapshot // ---------------------------------------------------------------------- + HashMap<KinesisStreamShard, SequenceNumber> stateSnapshot = new HashMap<>(); - for (Tuple2<KinesisStreamShard, SequenceNumber> tuple: snapShotState) { + for (Tuple2<KinesisStreamShard, SequenceNumber> tuple: expectedStateSnapshot) { stateSnapshot.put(tuple.f0, tuple.f1); } @@ -644,6 +651,7 @@ public class FlinkKinesisConsumerTest { // ---------------------------------------------------------------------- // create a consumer and test the snapshotState() // ---------------------------------------------------------------------- + FlinkKinesisConsumer<String> consumer = new FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config); FlinkKinesisConsumer<?> mockedConsumer = spy(consumer); @@ -653,22 +661,22 @@ public class FlinkKinesisConsumerTest { mockedConsumer.setRuntimeContext(context); 
mockedConsumer.initializeState(initializationContext); mockedConsumer.open(new Configuration()); - Whitebox.setInternalState(mockedConsumer, "fetcher", mockedFetcher); // mock as consumer is running. + Whitebox.setInternalState(mockedConsumer, "fetcher", mockedFetcher); // mock consumer as running. mockedConsumer.snapshotState(mock(FunctionSnapshotContext.class)); assertEquals(true, listState.clearCalled); assertEquals(3, listState.getList().size()); - for (Serializable state: initialState) { - for (Serializable currentState: listState.getList()) { + for (Tuple2<KinesisStreamShard, SequenceNumber> state: initialState) { + for (Tuple2<KinesisStreamShard, SequenceNumber> currentState: listState.getList()) { assertNotEquals(state, currentState); } } - for (Serializable state: snapShotState) { + for (Tuple2<KinesisStreamShard, SequenceNumber> state: expectedStateSnapshot) { boolean hasOneIsSame = false; - for (Serializable currentState: listState.getList()) { + for (Tuple2<KinesisStreamShard, SequenceNumber> currentState: listState.getList()) { hasOneIsSame = hasOneIsSame || state.equals(currentState); } assertEquals(true, hasOneIsSame); @@ -693,8 +701,6 @@ public class FlinkKinesisConsumerTest { "fakeStream", new Properties(), 10, 2); consumer.open(new Configuration()); consumer.run(Mockito.mock(SourceFunction.SourceContext.class)); - - Mockito.verify(mockedFetcher).setIsRestoringFromFailure(false); } @Test @@ -718,7 +724,6 @@ public class FlinkKinesisConsumerTest { consumer.open(new Configuration()); consumer.run(Mockito.mock(SourceFunction.SourceContext.class)); - Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true); for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) { Mockito.verify(mockedFetcher).registerNewSubscribedShardState( new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue())); @@ -728,15 +733,18 @@ public class FlinkKinesisConsumerTest { @Test @SuppressWarnings("unchecked") 
public void testFetcherShouldBeCorrectlySeededIfRestoringFromCheckpoint() throws Exception { + // ---------------------------------------------------------------------- - // setting initial state + // setup initial state // ---------------------------------------------------------------------- + HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("all"); // ---------------------------------------------------------------------- // mock operator state backend and initial state for initializeState() // ---------------------------------------------------------------------- - TestingListState<Serializable> listState = new TestingListState<>(); + + TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>(); for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) { listState.add(Tuple2.of(state.getKey(), state.getValue())); } @@ -751,6 +759,7 @@ public class FlinkKinesisConsumerTest { // ---------------------------------------------------------------------- // mock fetcher // ---------------------------------------------------------------------- + KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class); List<KinesisStreamShard> shards = new ArrayList<>(); shards.addAll(fakeRestoredState.keySet()); @@ -762,15 +771,15 @@ public class FlinkKinesisConsumerTest { PowerMockito.doNothing().when(KinesisConfigUtil.class); // ---------------------------------------------------------------------- - // start to test seed initial state to fetcher + // start to test fetcher's initial state seeding // ---------------------------------------------------------------------- + TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer( "fakeStream", new Properties(), 10, 2); consumer.initializeState(initializationContext); consumer.open(new Configuration()); consumer.run(Mockito.mock(SourceFunction.SourceContext.class)); - 
Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true); for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) { Mockito.verify(mockedFetcher).registerNewSubscribedShardState( new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue())); @@ -780,9 +789,11 @@ public class FlinkKinesisConsumerTest { @Test @SuppressWarnings("unchecked") public void testFetcherShouldBeCorrectlySeededOnlyItsOwnStates() throws Exception { + // ---------------------------------------------------------------------- - // setting initial state + // setup initial state // ---------------------------------------------------------------------- + HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("fakeStream1"); HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredStateForOthers = getFakeRestoredStore("fakeStream2"); @@ -790,7 +801,8 @@ public class FlinkKinesisConsumerTest { // ---------------------------------------------------------------------- // mock operator state backend and initial state for initializeState() // ---------------------------------------------------------------------- - TestingListState<Serializable> listState = new TestingListState<>(); + + TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>(); for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) { listState.add(Tuple2.of(state.getKey(), state.getValue())); } @@ -808,6 +820,7 @@ public class FlinkKinesisConsumerTest { // ---------------------------------------------------------------------- // mock fetcher // ---------------------------------------------------------------------- + KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class); List<KinesisStreamShard> shards = new ArrayList<>(); shards.addAll(fakeRestoredState.keySet()); @@ -819,15 +832,15 @@ public class FlinkKinesisConsumerTest { 
PowerMockito.doNothing().when(KinesisConfigUtil.class); // ---------------------------------------------------------------------- - // start to test seed initial state to fetcher + // start to test fetcher's initial state seeding // ---------------------------------------------------------------------- + TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer( "fakeStream", new Properties(), 10, 2); consumer.initializeState(initializationContext); consumer.open(new Configuration()); consumer.run(Mockito.mock(SourceFunction.SourceContext.class)); - Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true); for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredStateForOthers.entrySet()) { // should never get restored state not belonging to itself Mockito.verify(mockedFetcher, never()).registerNewSubscribedShardState( @@ -841,42 +854,49 @@ public class FlinkKinesisConsumerTest { } /* - * If the original parallelism is 2 and states is: + * This tests that the consumer correctly picks up shards that were not discovered on the previous run. 
+ * + * Case under test: + * + * If the original parallelism is 2 and states are: * Consumer subtask 1: * stream1, shard1, SequentialNumber(xxx) * Consumer subtask 2: * stream1, shard2, SequentialNumber(yyy) - * After discoverNewShardsToSubscribe() if there are two shards (shard3, shard4) been created: + * + * After discoverNewShardsToSubscribe() if there were two shards (shard3, shard4) created: * Consumer subtask 1 (late for discoverNewShardsToSubscribe()): * stream1, shard1, SequentialNumber(xxx) * Consumer subtask 2: * stream1, shard2, SequentialNumber(yyy) * stream1, shard4, SequentialNumber(zzz) - * If snapshotState() occur and parallelism is changed to 1: - * Union state will be: + * + * If snapshotState() occurs and parallelism is changed to 1: + * Union state will be: * stream1, shard1, SequentialNumber(xxx) * stream1, shard2, SequentialNumber(yyy) * stream1, shard4, SequentialNumber(zzz) - * Fetcher should be seeded with: + * Fetcher should be seeded with: * stream1, shard1, SequentialNumber(xxx) * stream1, shard2, SequentialNumber(yyy) * stream1, share3, SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM * stream1, shard4, SequentialNumber(zzz) - * - * This test is to guarantee the fetcher will be seeded correctly for such situation. 
*/ @Test @SuppressWarnings("unchecked") public void testFetcherShouldBeCorrectlySeededWithNewDiscoveredKinesisStreamShard() throws Exception { + // ---------------------------------------------------------------------- - // setting initial state + // setup initial state // ---------------------------------------------------------------------- + HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("all"); // ---------------------------------------------------------------------- // mock operator state backend and initial state for initializeState() // ---------------------------------------------------------------------- - TestingListState<Serializable> listState = new TestingListState<>(); + + TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>(); for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) { listState.add(Tuple2.of(state.getKey(), state.getValue())); } @@ -891,6 +911,7 @@ public class FlinkKinesisConsumerTest { // ---------------------------------------------------------------------- // mock fetcher // ---------------------------------------------------------------------- + KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class); List<KinesisStreamShard> shards = new ArrayList<>(); shards.addAll(fakeRestoredState.keySet()); @@ -904,8 +925,9 @@ public class FlinkKinesisConsumerTest { PowerMockito.doNothing().when(KinesisConfigUtil.class); // ---------------------------------------------------------------------- - // start to test seed initial state to fetcher + // start to test fetcher's initial state seeding // ---------------------------------------------------------------------- + TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer( "fakeStream", new Properties(), 10, 2); consumer.initializeState(initializationContext); @@ -915,7 +937,6 @@ public class FlinkKinesisConsumerTest { fakeRestoredState.put(new 
KinesisStreamShard("fakeStream2", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))), SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get()); - Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true); for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) { Mockito.verify(mockedFetcher).registerNewSubscribedShardState( new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue())); http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java ---------------------------------------------------------------------- diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java index e79f9b1..800fde5 100644 --- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java +++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java @@ -18,9 +18,14 @@ package org.apache.flink.streaming.connectors.kinesis.internals; import com.amazonaws.services.kinesis.model.Shard; +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer; +import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants; import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard; import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber; import 
org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState; +import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema; import org.apache.flink.streaming.connectors.kinesis.testutils.FakeKinesisBehavioursFactory; import org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGenerator; import org.apache.flink.streaming.connectors.kinesis.testutils.TestableKinesisDataFetcher; @@ -42,6 +47,8 @@ import java.util.UUID; import java.util.concurrent.atomic.AtomicReference; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; @RunWith(PowerMockRunner.class) @PrepareForTest(TestableKinesisDataFetcher.class) @@ -67,8 +74,6 @@ public class KinesisDataFetcherTest { subscribedStreamsToLastSeenShardIdsUnderTest, FakeKinesisBehavioursFactory.noShardsFoundForRequestedStreamsBehaviour()); - fetcher.setIsRestoringFromFailure(false); // not restoring - fetcher.runFetcher(); // this should throw RuntimeException } @@ -100,23 +105,30 @@ public class KinesisDataFetcherTest { subscribedStreamsToLastSeenShardIdsUnderTest, FakeKinesisBehavioursFactory.nonReshardedStreamsBehaviour(streamToShardCount)); - fetcher.setIsRestoringFromFailure(false); + Properties testConfig = new Properties(); + testConfig.setProperty(ConsumerConfigConstants.AWS_REGION, "us-east-1"); + testConfig.setProperty(ConsumerConfigConstants.AWS_CREDENTIALS_PROVIDER, "BASIC"); + testConfig.setProperty(ConsumerConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId"); + testConfig.setProperty(ConsumerConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey"); + + final DummyFlinkKafkaConsumer<String> consumer = new DummyFlinkKafkaConsumer<>(testConfig, fetcher); PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class)); - Thread runFetcherThread = new Thread(new Runnable() { + Thread consumerThread = new Thread(new Runnable() { @Override public void 
run() { try { - fetcher.runFetcher(); + consumer.run(mock(SourceFunction.SourceContext.class)); } catch (Exception e) { // } } }); - runFetcherThread.start(); - Thread.sleep(1000); // sleep a while before closing - fetcher.shutdownFetcher(); + consumerThread.start(); + fetcher.waitUntilRun(); + consumer.cancel(); + consumerThread.join(); // assert that the streams tracked in the state are identical to the subscribed streams Set<String> streamsInState = subscribedStreamsToLastSeenShardIdsUnderTest.keySet(); @@ -192,8 +204,6 @@ public class KinesisDataFetcherTest { new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue()))); } - fetcher.setIsRestoringFromFailure(true); - PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class)); Thread runFetcherThread = new Thread(new Runnable() { @Override @@ -284,8 +294,6 @@ public class KinesisDataFetcherTest { new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue()))); } - fetcher.setIsRestoringFromFailure(true); - PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class)); Thread runFetcherThread = new Thread(new Runnable() { @Override @@ -380,8 +388,6 @@ public class KinesisDataFetcherTest { new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue()))); } - fetcher.setIsRestoringFromFailure(true); - PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class)); Thread runFetcherThread = new Thread(new Runnable() { @Override @@ -477,8 +483,6 @@ public class KinesisDataFetcherTest { new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue()))); } - fetcher.setIsRestoringFromFailure(true); - PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class)); Thread runFetcherThread = new Thread(new 
Runnable() { @Override @@ -507,4 +511,33 @@ public class KinesisDataFetcherTest { assertTrue(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream3") == null); assertTrue(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream4") == null); } + + private static class DummyFlinkKafkaConsumer<T> extends FlinkKinesisConsumer<T> { + private static final long serialVersionUID = 1L; + + private KinesisDataFetcher<T> fetcher; + + @SuppressWarnings("unchecked") + DummyFlinkKafkaConsumer(Properties properties, KinesisDataFetcher<T> fetcher) { + super("test", mock(KinesisDeserializationSchema.class), properties); + this.fetcher = fetcher; + } + + @Override + protected KinesisDataFetcher<T> createFetcher(List<String> streams, + SourceFunction.SourceContext<T> sourceContext, + RuntimeContext runtimeContext, + Properties configProps, + KinesisDeserializationSchema<T> deserializationSchema) { + return fetcher; + } + + @Override + public RuntimeContext getRuntimeContext() { + RuntimeContext context = mock(RuntimeContext.class); + when(context.getIndexOfThisSubtask()).thenReturn(0); + when(context.getNumberOfParallelSubtasks()).thenReturn(1); + return context; + } + } } http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java ---------------------------------------------------------------------- diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java index 57886fe..bb644ba 100644 --- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java +++ 
b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java @@ -18,6 +18,7 @@ package org.apache.flink.streaming.connectors.kinesis.testutils; import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher; import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState; @@ -42,6 +43,8 @@ public class TestableKinesisDataFetcher extends KinesisDataFetcher<String> { private long numElementsCollected; + private OneShotLatch runWaiter; + public TestableKinesisDataFetcher(List<String> fakeStreams, Properties fakeConfiguration, int fakeTotalCountOfSubtasks, @@ -62,6 +65,7 @@ public class TestableKinesisDataFetcher extends KinesisDataFetcher<String> { fakeKinesis); this.numElementsCollected = 0; + this.runWaiter = new OneShotLatch(); } public long getNumOfElementsCollected() { @@ -81,6 +85,16 @@ public class TestableKinesisDataFetcher extends KinesisDataFetcher<String> { } } + @Override + public void runFetcher() throws Exception { + runWaiter.trigger(); + super.runFetcher(); + } + + public void waitUntilRun() throws Exception { + runWaiter.await(); + } + @SuppressWarnings("unchecked") private static SourceFunction.SourceContext<String> getMockedSourceContext() { return Mockito.mock(SourceFunction.SourceContext.class);
