Repository: flink
Updated Branches:
  refs/heads/master 369837971 -> 81d8fe16a


[FLINK-5163] Port the StatefulSequenceSource to the new state abstractions.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/81d8fe16
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/81d8fe16
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/81d8fe16

Branch: refs/heads/master
Commit: 81d8fe16a04a7826ba72c89fdc98fa50e3f86f5e
Parents: 956ffa6
Author: kl0u <[email protected]>
Authored: Mon Nov 21 18:50:30 2016 +0100
Committer: Aljoscha Krettek <[email protected]>
Committed: Tue Dec 13 13:38:18 2016 +0100

----------------------------------------------------------------------
 .../source/StatefulSequenceSource.java          |  98 ++++++--
 .../flink/streaming/api/SourceFunctionTest.java |   8 -
 .../functions/StatefulSequenceSourceTest.java   | 242 +++++++++++++++++++
 3 files changed, 315 insertions(+), 33 deletions(-)
----------------------------------------------------------------------
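
For readers skimming the diff: the port replaces the old Checkpointed<Long> interface (which checkpointed a single collected counter) with CheckpointedFunction backed by an operator ListState, so the pending elements themselves are checkpointed and can be redistributed when the parallelism changes on restore. A condensed sketch of that pattern follows (names taken from the diff below; class boilerplate and the first-run partitioning are omitted):

    // Condensed sketch of the new state handling (see the full diff below).
    private transient ListState<Long> checkpointedState;
    private transient Deque<Long> valuesToEmit;

    @Override
    public void initializeState(FunctionInitializationContext context) throws Exception {
        checkpointedState = context.getOperatorStateStore().getOperatorState(
            new ListStateDescriptor<>("stateful-sequence-source-state", LongSerializer.INSTANCE));

        valuesToEmit = new ArrayDeque<>();
        if (context.isRestored()) {
            // on restore, the not-yet-emitted elements become this subtask's work queue again
            for (Long v : checkpointedState.get()) {
                valuesToEmit.add(v);
            }
        }
        // otherwise the subtask precomputes its share of [start, end], as in the diff below
    }

    @Override
    public void snapshotState(FunctionSnapshotContext context) throws Exception {
        // every still-unemitted element is written as its own redistributable state entry
        checkpointedState.clear();
        for (Long v : valuesToEmit) {
            checkpointedState.add(v);
        }
    }

Because every not-yet-emitted number is its own list entry, the runtime is free to split or merge those entries across subtasks when the job is rescaled.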


http://git-wip-us.apache.org/repos/asf/flink/blob/81d8fe16/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java
index 563f6ef..bdb12f3 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java
@@ -1,4 +1,4 @@
-/**
+/*
  * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file
  * distributed with this work for additional information
@@ -18,25 +18,42 @@
 package org.apache.flink.streaming.api.functions.source;
 
 import org.apache.flink.annotation.PublicEvolving;
-import org.apache.flink.api.common.functions.RuntimeContext;
-import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayDeque;
+import java.util.Deque;
 
 /**
  * A stateful streaming source that emits each number from a given interval exactly once,
  * possibly in parallel.
+ *
+ * <p>For the source to be re-scalable, the first time the job is run, we precompute all the elements
+ * that each of the tasks should emit and upon checkpointing, each element constitutes its own
+ * partition. When rescaling, these partitions will be randomly re-assigned to the new tasks.
+ *
+ * <p>This strategy guarantees that each element will be emitted exactly-once, but elements will not
+ * necessarily be emitted in ascending order, even for the same tasks.
  */
 @PublicEvolving
-public class StatefulSequenceSource extends RichParallelSourceFunction<Long> implements Checkpointed<Long> {
+public class StatefulSequenceSource extends RichParallelSourceFunction<Long> implements CheckpointedFunction {
        
        private static final long serialVersionUID = 1L;
 
        private final long start;
        private final long end;
 
-       private long collected;
-
        private volatile boolean isRunning = true;
 
+       private transient Deque<Long> valuesToEmit;
+
+       private transient ListState<Long> checkpointedState;
+
        /**
         * Creates a source that emits all numbers from the given interval exactly once.
         *
@@ -49,24 +66,47 @@ public class StatefulSequenceSource extends RichParallelSourceFunction<Long> imp
        }
 
        @Override
-       public void run(SourceContext<Long> ctx) throws Exception {
-               final Object checkpointLock = ctx.getCheckpointLock();
+       public void initializeState(FunctionInitializationContext context) throws Exception {
 
-               RuntimeContext context = getRuntimeContext();
+               Preconditions.checkState(this.checkpointedState == null,
+                       "The " + getClass().getSimpleName() + " has already been initialized.");
 
-               final long stepSize = context.getNumberOfParallelSubtasks();
-               final long congruence = start + context.getIndexOfThisSubtask();
+               this.checkpointedState = context.getOperatorStateStore().getOperatorState(
+                       new ListStateDescriptor<>(
+                               "stateful-sequence-source-state",
+                               LongSerializer.INSTANCE
+                       )
+               );
 
-               final long toCollect =
-                               ((end - start + 1) % stepSize > (congruence - start)) ?
-                                       ((end - start + 1) / stepSize + 1) :
-                                       ((end - start + 1) / stepSize);
-               
+               this.valuesToEmit = new ArrayDeque<>();
+               if (context.isRestored()) {
+                       // upon restoring
 
-               while (isRunning && collected < toCollect) {
-                       synchronized (checkpointLock) {
-                               ctx.collect(collected * stepSize + congruence);
-                               collected++;
+                       for (Long v : this.checkpointedState.get()) {
+                               this.valuesToEmit.add(v);
+                       }
+               } else {
+                       // the first time the job is executed
+
+                       final int stepSize = getRuntimeContext().getNumberOfParallelSubtasks();
+                       final int taskIdx = getRuntimeContext().getIndexOfThisSubtask();
+                       final long congruence = start + taskIdx;
+
+                       long totalNoOfElements = Math.abs(end - start + 1);
+                       final int baseSize = safeDivide(totalNoOfElements, stepSize);
+                       final int toCollect = (totalNoOfElements % stepSize > taskIdx) ? baseSize + 1 : baseSize;
+
+                       for (long collected = 0; collected < toCollect; collected++) {
+                               this.valuesToEmit.add(collected * stepSize + congruence);
+                       }
+               }
+       }
+
+       @Override
+       public void run(SourceContext<Long> ctx) throws Exception {
+               while (isRunning && !this.valuesToEmit.isEmpty()) {
+                       synchronized (ctx.getCheckpointLock()) {
+                               ctx.collect(this.valuesToEmit.poll());
                        }
                }
        }
@@ -77,12 +117,20 @@ public class StatefulSequenceSource extends RichParallelSourceFunction<Long> imp
        }
 
        @Override
-       public Long snapshotState(long checkpointId, long checkpointTimestamp) {
-               return collected;
+       public void snapshotState(FunctionSnapshotContext context) throws Exception {
+               Preconditions.checkState(this.checkpointedState != null,
+                       "The " + getClass().getSimpleName() + " state has not been properly initialized.");
+
+               this.checkpointedState.clear();
+               for (Long v : this.valuesToEmit) {
+                       this.checkpointedState.add(v);
+               }
        }
 
-       @Override
-       public void restoreState(Long state) {
-               collected = state;
+       private static int safeDivide(long left, long right) {
+               Preconditions.checkArgument(right > 0);
+               Preconditions.checkArgument(left >= 0);
+               Preconditions.checkArgument(left <= Integer.MAX_VALUE * right);
+               return (int) (left / right);
        }
 }
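
The element partitioning in initializeState() above is easiest to see with concrete numbers. The standalone sketch below is not part of the commit; the interval (0..100) and parallelism (2) are chosen to match the test added next, and the arithmetic mirrors the diff:

    // Hypothetical standalone sketch: reproduces the element partitioning of
    // initializeState() for start=0, end=100 and parallelism 2.
    import java.util.ArrayDeque;
    import java.util.Deque;

    public class PartitioningSketch {
        public static void main(String[] args) {
            final long start = 0, end = 100;
            final int stepSize = 2; // number of parallel subtasks

            for (int taskIdx = 0; taskIdx < stepSize; taskIdx++) {
                final long congruence = start + taskIdx;
                final long total = Math.abs(end - start + 1);       // 101 elements
                final int baseSize = (int) (total / stepSize);      // 50
                final int toCollect = (total % stepSize > taskIdx) ? baseSize + 1 : baseSize;

                Deque<Long> valuesToEmit = new ArrayDeque<>();
                for (long collected = 0; collected < toCollect; collected++) {
                    valuesToEmit.add(collected * stepSize + congruence);
                }
                // task 0 gets 0, 2, 4, ..., 100 (51 values); task 1 gets 1, 3, ..., 99 (50 values)
                System.out.println("task " + taskIdx + " emits " + valuesToEmit.size() + " values");
            }
        }
    }

Every number in the interval lands in exactly one subtask's queue, and on restore the checkpointed queues can be re-split across a different number of subtasks without loss or duplication, which is what the new test below exercises.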

http://git-wip-us.apache.org/repos/asf/flink/blob/81d8fe16/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/SourceFunctionTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/SourceFunctionTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/SourceFunctionTest.java
index 946b474..dd4ff33 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/SourceFunctionTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/SourceFunctionTest.java
@@ -52,12 +52,4 @@ public class SourceFunctionTest {
                                                Arrays.asList(1, 2, 3))));
                assertEquals(expectedList, actualList);
        }
-
-       @Test
-       public void generateSequenceTest() throws Exception {
-               List<Long> expectedList = Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L);
-               List<Long> actualList = SourceFunctionUtil.runSourceFunction(new StatefulSequenceSource(1,
-                               7));
-               assertEquals(expectedList, actualList);
-       }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/81d8fe16/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/StatefulSequenceSourceTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/StatefulSequenceSourceTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/StatefulSequenceSourceTest.java
new file mode 100644
index 0000000..8332cb3
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/StatefulSequenceSourceTest.java
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.functions;
+
+import org.apache.flink.core.testutils.OneShotLatch;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.functions.source.StatefulSequenceSource;
+import org.apache.flink.streaming.api.operators.StreamSource;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
+import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+
+public class StatefulSequenceSourceTest {
+
+       @Test
+       public void testCheckpointRestore() throws Exception {
+               final int initElement = 0;
+               final int maxElement = 100;
+
+               final Set<Long> expectedOutput = new HashSet<>();
+               for (long i = initElement; i <= maxElement; i++) {
+                       expectedOutput.add(i);
+               }
+
+               final ConcurrentHashMap<String, List<Long>> outputCollector = new ConcurrentHashMap<>();
+               final OneShotLatch latchToTrigger1 = new OneShotLatch();
+               final OneShotLatch latchToWait1 = new OneShotLatch();
+               final OneShotLatch latchToTrigger2 = new OneShotLatch();
+               final OneShotLatch latchToWait2 = new OneShotLatch();
+
+               final StatefulSequenceSource source1 = new StatefulSequenceSource(initElement, maxElement);
+               StreamSource<Long, StatefulSequenceSource> src1 = new StreamSource<>(source1);
+
+               final AbstractStreamOperatorTestHarness<Long> testHarness1 =
+                       new AbstractStreamOperatorTestHarness<>(src1, 2, 2, 0);
+               testHarness1.open();
+
+               final StatefulSequenceSource source2 = new StatefulSequenceSource(initElement, maxElement);
+               StreamSource<Long, StatefulSequenceSource> src2 = new StreamSource<>(source2);
+
+               final AbstractStreamOperatorTestHarness<Long> testHarness2 =
+                       new AbstractStreamOperatorTestHarness<>(src2, 2, 2, 1);
+               testHarness2.open();
+
+               final Throwable[] error = new Throwable[3];
+
+               // run the source asynchronously
+               Thread runner1 = new Thread() {
+                       @Override
+                       public void run() {
+                               try {
+                                       source1.run(new BlockingSourceContext("1", latchToTrigger1, latchToWait1, outputCollector, 21));
+                               }
+                               catch (Throwable t) {
+                                       t.printStackTrace();
+                                       error[0] = t;
+                               }
+                       }
+               };
+
+               // run the source asynchronously
+               Thread runner2 = new Thread() {
+                       @Override
+                       public void run() {
+                               try {
+                                       source2.run(new BlockingSourceContext("2", latchToTrigger2, latchToWait2, outputCollector, 32));
+                               }
+                               catch (Throwable t) {
+                                       t.printStackTrace();
+                                       error[1] = t;
+                               }
+                       }
+               };
+
+               runner1.start();
+               runner2.start();
+
+               if (!latchToTrigger1.isTriggered()) {
+                       latchToTrigger1.await();
+               }
+
+               if (!latchToTrigger2.isTriggered()) {
+                       latchToTrigger2.await();
+               }
+
+               OperatorStateHandles snapshot = AbstractStreamOperatorTestHarness.repackageState(
+                       testHarness1.snapshot(0L, 0L),
+                       testHarness2.snapshot(0L, 0L)
+               );
+
+               final StatefulSequenceSource source3 = new StatefulSequenceSource(initElement, maxElement);
+               StreamSource<Long, StatefulSequenceSource> src3 = new StreamSource<>(source3);
+
+               final AbstractStreamOperatorTestHarness<Long> testHarness3 =
+                       new AbstractStreamOperatorTestHarness<>(src3, 2, 1, 0);
+               testHarness3.setup();
+               testHarness3.initializeState(snapshot);
+               testHarness3.open();
+
+               final OneShotLatch latchToTrigger3 = new OneShotLatch();
+               final OneShotLatch latchToWait3 = new OneShotLatch();
+               latchToWait3.trigger();
+
+               // run the source asynchronously
+               Thread runner3 = new Thread() {
+                       @Override
+                       public void run() {
+                               try {
+                                       source3.run(new BlockingSourceContext("3", latchToTrigger3, latchToWait3, outputCollector, 3));
+                               }
+                               catch (Throwable t) {
+                                       t.printStackTrace();
+                                       error[2] = t;
+                               }
+                       }
+               };
+               runner3.start();
+               runner3.join();
+
+               Assert.assertEquals(3, outputCollector.size()); // we have 3 tasks.
+
+               // test for at-most-once
+               Set<Long> dedupRes = new HashSet<>(Math.abs(maxElement - initElement) + 1);
+               for (Map.Entry<String, List<Long>> elementsPerTask: outputCollector.entrySet()) {
+                       String key = elementsPerTask.getKey();
+                       List<Long> elements = outputCollector.get(key);
+
+                       // this tests the correctness of the latches in the test
+                       Assert.assertTrue(elements.size() > 0);
+
+                       for (Long elem : elements) {
+                               if (!dedupRes.add(elem)) {
+                                       Assert.fail("Duplicate entry: " + elem);
+                               }
+
+                               if (!expectedOutput.contains(elem)) {
+                                       Assert.fail("Unexpected element: " + elem);
+                               }
+                       }
+               }
+
+               // test for exactly-once
+               Assert.assertEquals(Math.abs(initElement - maxElement) + 1, dedupRes.size());
+
+               latchToWait1.trigger();
+               latchToWait2.trigger();
+
+               // wait for everybody to finish.
+               runner1.join();
+               runner2.join();
+       }
+
+       private static class BlockingSourceContext implements SourceFunction.SourceContext<Long> {
+
+               private final String name;
+
+               private final Object lock;
+               private final OneShotLatch latchToTrigger;
+               private final OneShotLatch latchToWait;
+               private final ConcurrentHashMap<String, List<Long>> collector;
+
+               private final int threshold;
+               private int counter = 0;
+
+               private final List<Long> localOutput;
+
+               public BlockingSourceContext(String name, OneShotLatch latchToTrigger, OneShotLatch latchToWait,
+                                                                        ConcurrentHashMap<String, List<Long>> output, int elemToFire) {
+                       this.name = name;
+                       this.lock = new Object();
+                       this.latchToTrigger = latchToTrigger;
+                       this.latchToWait = latchToWait;
+                       this.collector = output;
+                       this.threshold = elemToFire;
+
+                       this.localOutput = new ArrayList<>();
+                       List<Long> prev = collector.put(name, localOutput);
+                       if (prev != null) {
+                               Assert.fail();
+                       }
+               }
+
+               @Override
+               public void collectWithTimestamp(Long element, long timestamp) {
+                       collect(element);
+               }
+
+               @Override
+               public void collect(Long element) {
+                       localOutput.add(element);
+                       if (++counter == threshold) {
+                               latchToTrigger.trigger();
+                               try {
+                                       if (!latchToWait.isTriggered()) {
+                                               latchToWait.await();
+                                       }
+                               } catch (InterruptedException e) {
+                                       e.printStackTrace();
+                               }
+                       }
+               }
+
+
+               @Override
+               public void emitWatermark(Watermark mark) {
+               }
+
+               @Override
+               public Object getCheckpointLock() {
+                       return lock;
+               }
+
+               @Override
+               public void close() {
+               }
+       }
+}
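
For completeness, usage of the source is unchanged by this port. A minimal, hypothetical job would look like the sketch below; the checkpoint interval, parallelism, and job name are illustrative only:

    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
    import org.apache.flink.streaming.api.functions.source.StatefulSequenceSource;

    public class SequenceJob {
        public static void main(String[] args) throws Exception {
            StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
            env.enableCheckpointing(5000); // triggers the snapshotState() calls shown above
            env.setParallelism(2);

            // emits every number in [0, 100] exactly once across the subtasks
            env.addSource(new StatefulSequenceSource(0, 100))
                .print();

            env.execute("stateful-sequence-source-example");
        }
    }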
