Repository: flink Updated Branches: refs/heads/master 369837971 -> 81d8fe16a
[FLINK-5163] Port the StatefulSequenceSource to the new state abstractions. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/81d8fe16 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/81d8fe16 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/81d8fe16 Branch: refs/heads/master Commit: 81d8fe16a04a7826ba72c89fdc98fa50e3f86f5e Parents: 956ffa6 Author: kl0u <[email protected]> Authored: Mon Nov 21 18:50:30 2016 +0100 Committer: Aljoscha Krettek <[email protected]> Committed: Tue Dec 13 13:38:18 2016 +0100 ---------------------------------------------------------------------- .../source/StatefulSequenceSource.java | 98 ++++++-- .../flink/streaming/api/SourceFunctionTest.java | 8 - .../functions/StatefulSequenceSourceTest.java | 242 +++++++++++++++++++ 3 files changed, 315 insertions(+), 33 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/81d8fe16/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java index 563f6ef..bdb12f3 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java @@ -1,4 +1,4 @@ -/** +/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -18,25 +18,42 @@ package org.apache.flink.streaming.api.functions.source; import org.apache.flink.annotation.PublicEvolving; -import org.apache.flink.api.common.functions.RuntimeContext; -import org.apache.flink.streaming.api.checkpoint.Checkpointed; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.util.Preconditions; + +import java.util.ArrayDeque; +import java.util.Deque; /** * A stateful streaming source that emits each number from a given interval exactly once, * possibly in parallel. + * + * <p>For the source to be re-scalable, the first time the job is run, we precompute all the elements + * that each of the tasks should emit and upon checkpointing, each element constitutes its own + * partition. When rescaling, these partitions will be randomly re-assigned to the new tasks. + * + * <p>This strategy guarantees that each element will be emitted exactly-once, but elements will not + * necessarily be emitted in ascending order, even for the same tasks. */ @PublicEvolving -public class StatefulSequenceSource extends RichParallelSourceFunction<Long> implements Checkpointed<Long> { +public class StatefulSequenceSource extends RichParallelSourceFunction<Long> implements CheckpointedFunction { private static final long serialVersionUID = 1L; private final long start; private final long end; - private long collected; - private volatile boolean isRunning = true; + private transient Deque<Long> valuesToEmit; + + private transient ListState<Long> checkpointedState; + /** * Creates a source that emits all numbers from the given interval exactly once. * @@ -49,24 +66,47 @@ public class StatefulSequenceSource extends RichParallelSourceFunction<Long> imp } @Override - public void run(SourceContext<Long> ctx) throws Exception { - final Object checkpointLock = ctx.getCheckpointLock(); + public void initializeState(FunctionInitializationContext context) throws Exception { - RuntimeContext context = getRuntimeContext(); + Preconditions.checkState(this.checkpointedState == null, + "The " + getClass().getSimpleName() + " has already been initialized."); - final long stepSize = context.getNumberOfParallelSubtasks(); - final long congruence = start + context.getIndexOfThisSubtask(); + this.checkpointedState = context.getOperatorStateStore().getOperatorState( + new ListStateDescriptor<>( + "stateful-sequence-source-state", + LongSerializer.INSTANCE + ) + ); - final long toCollect = - ((end - start + 1) % stepSize > (congruence - start)) ? - ((end - start + 1) / stepSize + 1) : - ((end - start + 1) / stepSize); - + this.valuesToEmit = new ArrayDeque<>(); + if (context.isRestored()) { + // upon restoring - while (isRunning && collected < toCollect) { - synchronized (checkpointLock) { - ctx.collect(collected * stepSize + congruence); - collected++; + for (Long v : this.checkpointedState.get()) { + this.valuesToEmit.add(v); + } + } else { + // the first time the job is executed + + final int stepSize = getRuntimeContext().getNumberOfParallelSubtasks(); + final int taskIdx = getRuntimeContext().getIndexOfThisSubtask(); + final long congruence = start + taskIdx; + + long totalNoOfElements = Math.abs(end - start + 1); + final int baseSize = safeDivide(totalNoOfElements, stepSize); + final int toCollect = (totalNoOfElements % stepSize > taskIdx) ? baseSize + 1 : baseSize; + + for (long collected = 0; collected < toCollect; collected++) { + this.valuesToEmit.add(collected * stepSize + congruence); + } + } + } + + @Override + public void run(SourceContext<Long> ctx) throws Exception { + while (isRunning && !this.valuesToEmit.isEmpty()) { + synchronized (ctx.getCheckpointLock()) { + ctx.collect(this.valuesToEmit.poll()); } } } @@ -77,12 +117,20 @@ public class StatefulSequenceSource extends RichParallelSourceFunction<Long> imp } @Override - public Long snapshotState(long checkpointId, long checkpointTimestamp) { - return collected; + public void snapshotState(FunctionSnapshotContext context) throws Exception { + Preconditions.checkState(this.checkpointedState != null, + "The " + getClass().getSimpleName() + " state has not been properly initialized."); + + this.checkpointedState.clear(); + for (Long v : this.valuesToEmit) { + this.checkpointedState.add(v); + } } - @Override - public void restoreState(Long state) { - collected = state; + private static int safeDivide(long left, long right) { + Preconditions.checkArgument(right > 0); + Preconditions.checkArgument(left >= 0); + Preconditions.checkArgument(left <= Integer.MAX_VALUE * right); + return (int) (left / right); } } http://git-wip-us.apache.org/repos/asf/flink/blob/81d8fe16/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/SourceFunctionTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/SourceFunctionTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/SourceFunctionTest.java index 946b474..dd4ff33 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/SourceFunctionTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/SourceFunctionTest.java @@ -52,12 +52,4 @@ public class SourceFunctionTest { Arrays.asList(1, 2, 3)))); assertEquals(expectedList, actualList); } - - @Test - public void generateSequenceTest() throws Exception { - List<Long> expectedList = Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L); - List<Long> actualList = SourceFunctionUtil.runSourceFunction(new StatefulSequenceSource(1, - 7)); - assertEquals(expectedList, actualList); - } } http://git-wip-us.apache.org/repos/asf/flink/blob/81d8fe16/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/StatefulSequenceSourceTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/StatefulSequenceSourceTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/StatefulSequenceSourceTest.java new file mode 100644 index 0000000..8332cb3 --- /dev/null +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/StatefulSequenceSourceTest.java @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.functions; + +import org.apache.flink.core.testutils.OneShotLatch; +import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.streaming.api.functions.source.StatefulSequenceSource; +import org.apache.flink.streaming.api.operators.StreamSource; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles; +import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness; +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +public class StatefulSequenceSourceTest { + + @Test + public void testCheckpointRestore() throws Exception { + final int initElement = 0; + final int maxElement = 100; + + final Set<Long> expectedOutput = new HashSet<>(); + for (long i = initElement; i <= maxElement; i++) { + expectedOutput.add(i); + } + + final ConcurrentHashMap<String, List<Long>> outputCollector = new ConcurrentHashMap<>(); + final OneShotLatch latchToTrigger1 = new OneShotLatch(); + final OneShotLatch latchToWait1 = new OneShotLatch(); + final OneShotLatch latchToTrigger2 = new OneShotLatch(); + final OneShotLatch latchToWait2 = new OneShotLatch(); + + final StatefulSequenceSource source1 = new StatefulSequenceSource(initElement, maxElement); + StreamSource<Long, StatefulSequenceSource> src1 = new StreamSource<>(source1); + + final AbstractStreamOperatorTestHarness<Long> testHarness1 = + new AbstractStreamOperatorTestHarness<>(src1, 2, 2, 0); + testHarness1.open(); + + final StatefulSequenceSource source2 = new StatefulSequenceSource(initElement, maxElement); + StreamSource<Long, StatefulSequenceSource> src2 = new StreamSource<>(source2); + + final AbstractStreamOperatorTestHarness<Long> testHarness2 = + new AbstractStreamOperatorTestHarness<>(src2, 2, 2, 1); + testHarness2.open(); + + final Throwable[] error = new Throwable[3]; + + // run the source asynchronously + Thread runner1 = new Thread() { + @Override + public void run() { + try { + source1.run(new BlockingSourceContext("1", latchToTrigger1, latchToWait1, outputCollector, 21)); + } + catch (Throwable t) { + t.printStackTrace(); + error[0] = t; + } + } + }; + + // run the source asynchronously + Thread runner2 = new Thread() { + @Override + public void run() { + try { + source2.run(new BlockingSourceContext("2", latchToTrigger2, latchToWait2, outputCollector, 32)); + } + catch (Throwable t) { + t.printStackTrace(); + error[1] = t; + } + } + }; + + runner1.start(); + runner2.start(); + + if (!latchToTrigger1.isTriggered()) { + latchToTrigger1.await(); + } + + if (!latchToTrigger2.isTriggered()) { + latchToTrigger2.await(); + } + + OperatorStateHandles snapshot = AbstractStreamOperatorTestHarness.repackageState( + testHarness1.snapshot(0L, 0L), + testHarness2.snapshot(0L, 0L) + ); + + final StatefulSequenceSource source3 = new StatefulSequenceSource(initElement, maxElement); + StreamSource<Long, StatefulSequenceSource> src3 = new StreamSource<>(source3); + + final AbstractStreamOperatorTestHarness<Long> testHarness3 = + new AbstractStreamOperatorTestHarness<>(src3, 2, 1, 0); + testHarness3.setup(); + testHarness3.initializeState(snapshot); + testHarness3.open(); + + final OneShotLatch latchToTrigger3 = new OneShotLatch(); + final OneShotLatch latchToWait3 = new OneShotLatch(); + latchToWait3.trigger(); + + // run the source asynchronously + Thread runner3 = new Thread() { + @Override + public void run() { + try { + source3.run(new BlockingSourceContext("3", latchToTrigger3, latchToWait3, outputCollector, 3)); + } + catch (Throwable t) { + t.printStackTrace(); + error[2] = t; + } + } + }; + runner3.start(); + runner3.join(); + + Assert.assertEquals(3, outputCollector.size()); // we have 3 tasks. + + // test for at-most-once + Set<Long> dedupRes = new HashSet<>(Math.abs(maxElement - initElement) + 1); + for (Map.Entry<String, List<Long>> elementsPerTask: outputCollector.entrySet()) { + String key = elementsPerTask.getKey(); + List<Long> elements = outputCollector.get(key); + + // this tests the correctness of the latches in the test + Assert.assertTrue(elements.size() > 0); + + for (Long elem : elements) { + if (!dedupRes.add(elem)) { + Assert.fail("Duplicate entry: " + elem); + } + + if (!expectedOutput.contains(elem)) { + Assert.fail("Unexpected element: " + elem); + } + } + } + + // test for exactly-once + Assert.assertEquals(Math.abs(initElement - maxElement) + 1, dedupRes.size()); + + latchToWait1.trigger(); + latchToWait2.trigger(); + + // wait for everybody ot finish. + runner1.join(); + runner2.join(); + } + + private static class BlockingSourceContext implements SourceFunction.SourceContext<Long> { + + private final String name; + + private final Object lock; + private final OneShotLatch latchToTrigger; + private final OneShotLatch latchToWait; + private final ConcurrentHashMap<String, List<Long>> collector; + + private final int threshold; + private int counter = 0; + + private final List<Long> localOutput; + + public BlockingSourceContext(String name, OneShotLatch latchToTrigger, OneShotLatch latchToWait, + ConcurrentHashMap<String, List<Long>> output, int elemToFire) { + this.name = name; + this.lock = new Object(); + this.latchToTrigger = latchToTrigger; + this.latchToWait = latchToWait; + this.collector = output; + this.threshold = elemToFire; + + this.localOutput = new ArrayList<>(); + List<Long> prev = collector.put(name, localOutput); + if (prev != null) { + Assert.fail(); + } + } + + @Override + public void collectWithTimestamp(Long element, long timestamp) { + collect(element); + } + + @Override + public void collect(Long element) { + localOutput.add(element); + if (++counter == threshold) { + latchToTrigger.trigger(); + try { + if (!latchToWait.isTriggered()) { + latchToWait.await(); + } + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + } + + + @Override + public void emitWatermark(Watermark mark) { + } + + @Override + public Object getCheckpointLock() { + return lock; + } + + @Override + public void close() { + } + } +}
