This is an automated email from the ASF dual-hosted git repository.

zhijiang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 693cb6a  [FLINK-16537][network] Implement ResultPartition state recovery for unaligned checkpoint
693cb6a is described below

commit 693cb6adc42d75d1db720b45013430a4c6817d4a
Author: Zhijiang <wangzhijiang...@aliyun.com>
AuthorDate: Fri Apr 3 11:08:56 2020 +0800

    [FLINK-16537][network] Implement ResultPartition state recovery for unaligned checkpoint
    
    During state recovery for unaligned checkpoints, the partition state must also be recovered in addition to the existing operator states.
    
    The ResultPartition requests buffers from its local buffer pool and then interacts with the ChannelStateReader to fill them with the state data.
    Each filled buffer is inserted into the queue of the respective ResultSubpartition in the normal way.
    
    It must be guaranteed that the operators cannot process any input before all of the output state has been recovered; otherwise the recovered output could be mis-ordered with newly emitted data.
---
 .../checkpoint/channel/ChannelStateReader.java     |   5 +-
 .../network/api/writer/ResultPartitionWriter.java  |   7 ++
 .../network/partition/PipelinedSubpartition.java   |  22 +++-
 .../partition/PipelinedSubpartitionView.java       |   4 +-
 .../io/network/partition/ResultPartition.java      |   8 ++
 .../io/network/partition/ResultSubpartition.java   |   4 +
 ...bleNotifyingResultPartitionWriterDecorator.java |   6 +
 .../io/network/api/writer/RecordWriterTest.java    |  62 ++++++++++
 .../buffer/BufferBuilderAndConsumerTest.java       |  10 +-
 .../partition/MockResultPartitionWriter.java       |   5 +
 .../partition/NoOpBufferAvailablityListener.java   |   2 +-
 .../io/network/partition/ResultPartitionTest.java  | 125 +++++++++++++++++++++
 .../flink/streaming/runtime/tasks/StreamTask.java  |   9 ++
 .../streaming/runtime/tasks/StreamTaskTest.java    |  42 +++++++
 14 files changed, 300 insertions(+), 11 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateReader.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateReader.java
index 49321cc..0753e7a 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateReader.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateReader.java
@@ -18,6 +18,7 @@ package org.apache.flink.runtime.checkpoint.channel;
 
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
 
 import java.io.IOException;
 
@@ -42,7 +43,7 @@ public interface ChannelStateReader extends AutoCloseable {
         * Put data into the supplied buffer to be injected into
         * {@link 
org.apache.flink.runtime.io.network.partition.ResultSubpartition 
ResultSubpartition}.
         */
-       ReadResult readOutputData(ResultSubpartitionInfo info, Buffer buffer) 
throws IOException;
+       ReadResult readOutputData(ResultSubpartitionInfo info, BufferBuilder 
bufferBuilder) throws IOException;
 
        @Override
        void close() throws Exception;
@@ -55,7 +56,7 @@ public interface ChannelStateReader extends AutoCloseable {
                }
 
                @Override
-               public ReadResult readOutputData(ResultSubpartitionInfo info, 
Buffer buffer) {
+               public ReadResult readOutputData(ResultSubpartitionInfo info, 
BufferBuilder bufferBuilder) {
                        return ReadResult.NO_MORE_DATA;
                }
 
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
index 75cd5fb..2c1717d 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.io.network.api.writer;
 
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
 import org.apache.flink.runtime.io.AvailabilityProvider;
 import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
 import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
@@ -42,6 +43,12 @@ public interface ResultPartitionWriter extends 
AutoCloseable, AvailabilityProvid
         */
        void setup() throws IOException;
 
+       /**
+        * Loads the previous output states with the given reader for unaligned 
checkpoint.
+        * It should be done before task processing the inputs.
+        */
+       void initializeState(ChannelStateReader stateReader) throws 
IOException, InterruptedException;
+
        ResultPartitionID getPartitionId();
 
        int getNumberOfSubpartitions();
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
index ecf6956..070089d 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
@@ -19,9 +19,12 @@
 package org.apache.flink.runtime.io.network.partition;
 
 import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
+import 
org.apache.flink.runtime.checkpoint.channel.ChannelStateReader.ReadResult;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
 import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
 import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
 
 import org.slf4j.Logger;
@@ -52,7 +55,7 @@ import static org.apache.flink.util.Preconditions.checkState;
  * {@link PipelinedSubpartitionView#notifyDataAvailable() notification} for any
  * {@link BufferConsumer} present in the queue.
  */
-class PipelinedSubpartition extends ResultSubpartition {
+public class PipelinedSubpartition extends ResultSubpartition {
 
        private static final Logger LOG = 
LoggerFactory.getLogger(PipelinedSubpartition.class);
 
@@ -90,6 +93,23 @@ class PipelinedSubpartition extends ResultSubpartition {
        }
 
        @Override
+       public void initializeState(ChannelStateReader stateReader) throws 
IOException, InterruptedException {
+               for (ReadResult readResult = ReadResult.HAS_MORE_DATA; 
readResult == ReadResult.HAS_MORE_DATA;) {
+                       BufferBuilder bufferBuilder = 
parent.getBufferPool().requestBufferBuilderBlocking();
+                       BufferConsumer bufferConsumer = 
bufferBuilder.createBufferConsumer();
+                       readResult = 
stateReader.readOutputData(subpartitionInfo, bufferBuilder);
+
+                       // check whether there are some states data filled in 
this time
+                       if (bufferConsumer.isDataAvailable()) {
+                               add(bufferConsumer);
+                               bufferBuilder.finish();
+                       } else {
+                               bufferConsumer.close();
+                       }
+               }
+       }
+
+       @Override
        public boolean add(BufferConsumer bufferConsumer) {
                return add(bufferConsumer, false);
        }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartitionView.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartitionView.java
index febbfbd..ee837d5 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartitionView.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartitionView.java
@@ -29,7 +29,7 @@ import static 
org.apache.flink.util.Preconditions.checkNotNull;
 /**
  * View over a pipelined in-memory only subpartition.
  */
-class PipelinedSubpartitionView implements ResultSubpartitionView {
+public class PipelinedSubpartitionView implements ResultSubpartitionView {
 
        /** The subpartition this view belongs to. */
        private final PipelinedSubpartition parent;
@@ -39,7 +39,7 @@ class PipelinedSubpartitionView implements 
ResultSubpartitionView {
        /** Flag indicating whether this view has been released. */
        private final AtomicBoolean isReleased;
 
-       PipelinedSubpartitionView(PipelinedSubpartition parent, 
BufferAvailabilityListener listener) {
+       public PipelinedSubpartitionView(PipelinedSubpartition parent, 
BufferAvailabilityListener listener) {
                this.parent = checkNotNull(parent);
                this.availabilityListener = checkNotNull(listener);
                this.isReleased = new AtomicBoolean();
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
index ccd3fa9..bb925fb 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.io.network.partition;
 
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
 import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
@@ -150,6 +151,13 @@ public class ResultPartition implements 
ResultPartitionWriter, BufferPoolOwner {
                partitionManager.registerResultPartition(this);
        }
 
+       @Override
+       public void initializeState(ChannelStateReader stateReader) throws 
IOException, InterruptedException {
+               for (ResultSubpartition subpartition : subpartitions) {
+                       subpartition.initializeState(stateReader);
+               }
+       }
+
        public String getOwningTaskName() {
                return owningTaskName;
        }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultSubpartition.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultSubpartition.java
index d139df0..d0256a1 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultSubpartition.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultSubpartition.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.io.network.partition;
 
 import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
 import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
@@ -76,6 +77,9 @@ public abstract class ResultSubpartition {
                parent.onConsumedSubpartition(index);
        }
 
+       public void initializeState(ChannelStateReader stateReader) throws 
IOException, InterruptedException {
+       }
+
        /**
         * Adds the given buffer.
         *
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java
index 8b1d97d..ada45cb 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.taskmanager;
 
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
@@ -89,6 +90,11 @@ public class 
ConsumableNotifyingResultPartitionWriterDecorator implements Result
        }
 
        @Override
+       public void initializeState(ChannelStateReader stateReader) throws 
IOException, InterruptedException {
+               partitionWriter.initializeState(stateReader);
+       }
+
+       @Override
        public boolean addBufferConsumer(BufferConsumer bufferConsumer, int 
subpartitionIndex) throws IOException {
                boolean success = 
partitionWriter.addBufferConsumer(bufferConsumer, subpartitionIndex);
                if (success) {
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
index 4964d93..867f591 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
@@ -25,6 +25,7 @@ import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
@@ -34,6 +35,7 @@ import 
org.apache.flink.runtime.io.network.api.serialization.RecordSerializer.Se
 import 
org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
+import org.apache.flink.runtime.io.network.buffer.BufferBuilderAndConsumerTest;
 import org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils;
 import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
 import org.apache.flink.runtime.io.network.buffer.BufferPool;
@@ -41,8 +43,15 @@ import 
org.apache.flink.runtime.io.network.buffer.BufferProvider;
 import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
 import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
 import org.apache.flink.runtime.io.network.partition.MockResultPartitionWriter;
+import 
org.apache.flink.runtime.io.network.partition.NoOpBufferAvailablityListener;
 import 
org.apache.flink.runtime.io.network.partition.NoOpResultPartitionConsumableNotifier;
+import org.apache.flink.runtime.io.network.partition.PipelinedSubpartition;
+import org.apache.flink.runtime.io.network.partition.PipelinedSubpartitionView;
+import org.apache.flink.runtime.io.network.partition.ResultPartition;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionBuilder;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionTest;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartition;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
 import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
 import org.apache.flink.runtime.io.network.util.DeserializationUtils;
 import org.apache.flink.runtime.io.network.util.TestPooledBufferProvider;
@@ -464,6 +473,59 @@ public class RecordWriterTest {
                }
        }
 
+       @Test
+       public void testEmitRecordWithPartitionStateRecovery() throws Exception 
{
+               final int totalBuffers = 10; // enough for both states and 
normal records
+               final int totalStates = 2;
+               final int[] states = {1, 2, 3, 4};
+               final int[] records = {5, 6, 7, 8};
+               final int bufferSize = states.length * Integer.BYTES;
+
+               final NetworkBufferPool globalPool = new 
NetworkBufferPool(totalBuffers, bufferSize, 1);
+               final ChannelStateReader stateReader = new 
ResultPartitionTest.FiniteChannelStateReader(totalStates, states);
+               final ResultPartition partition = new ResultPartitionBuilder()
+                       .setNetworkBufferPool(globalPool)
+                       .build();
+               final RecordWriter<IntValue> recordWriter = new 
RecordWriterBuilder<IntValue>().build(partition);
+
+               try {
+                       partition.setup();
+                       partition.initializeState(stateReader);
+
+                       for (int record: records) {
+                               // the record length 4 is also written into 
buffer for every emit
+                               recordWriter.broadcastEmit(new 
IntValue(record));
+                       }
+
+                       // every buffer can contain 2 int records with 2 int 
length(4)
+                       final int[][] expectedRecordsInBuffer = {{4, 5, 4, 6}, 
{4, 7, 4, 8}};
+
+                       for (ResultSubpartition subpartition : 
partition.getAllPartitions()) {
+                               // create the view to consume all the buffers 
with states and records
+                               final ResultSubpartitionView view = new 
PipelinedSubpartitionView(
+                                       (PipelinedSubpartition) subpartition,
+                                       new NoOpBufferAvailablityListener());
+
+                               int numConsumedBuffers = 0;
+                               ResultSubpartition.BufferAndBacklog 
bufferAndBacklog;
+                               while ((bufferAndBacklog = 
view.getNextBuffer()) != null) {
+                                       Buffer buffer = 
bufferAndBacklog.buffer();
+                                       int[] expected = numConsumedBuffers < 
totalStates ? states : expectedRecordsInBuffer[numConsumedBuffers - 
totalStates];
+                                       
BufferBuilderAndConsumerTest.assertContent(buffer, partition.getBufferPool(), 
expected);
+
+                                       buffer.recycleBuffer();
+                                       numConsumedBuffers++;
+                               }
+
+                               assertEquals(totalStates + 
expectedRecordsInBuffer.length, numConsumedBuffers);
+                       }
+               } finally {
+                       // cleanup
+                       globalPool.destroyAllBufferPools();
+                       globalPool.destroy();
+               }
+       }
+
        private void verifyBroadcastBufferOrEventIndependence(boolean 
broadcastEvent) throws Exception {
                @SuppressWarnings("unchecked")
                ArrayDeque<BufferConsumer>[] queues = new ArrayDeque[]{new 
ArrayDeque(), new ArrayDeque()};
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferBuilderAndConsumerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferBuilderAndConsumerTest.java
index 3975a71..1033c5e 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferBuilderAndConsumerTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferBuilderAndConsumerTest.java
@@ -164,7 +164,7 @@ public class BufferBuilderAndConsumerTest {
        public void buildEmptyBuffer() {
                Buffer buffer = buildSingleBuffer(createBufferBuilder());
                assertEquals(0, buffer.getSize());
-               assertContent(buffer);
+               assertContent(buffer, FreeingBufferRecycler.INSTANCE);
        }
 
        @Test
@@ -240,7 +240,7 @@ public class BufferBuilderAndConsumerTest {
                assertTrue(bufferConsumer.isFinished());
        }
 
-       private static ByteBuffer toByteBuffer(int... data) {
+       public static ByteBuffer toByteBuffer(int... data) {
                ByteBuffer byteBuffer = ByteBuffer.allocate(data.length * 
Integer.BYTES);
                byteBuffer.asIntBuffer().put(data);
                return byteBuffer;
@@ -250,18 +250,18 @@ public class BufferBuilderAndConsumerTest {
                assertFalse(actualConsumer.isFinished());
                Buffer buffer = actualConsumer.build();
                assertFalse(buffer.isRecycled());
-               assertContent(buffer, expected);
+               assertContent(buffer, FreeingBufferRecycler.INSTANCE, expected);
                assertEquals(expected.length * Integer.BYTES, buffer.getSize());
                buffer.recycleBuffer();
        }
 
-       private static void assertContent(Buffer actualBuffer, int... expected) 
{
+       public static void assertContent(Buffer actualBuffer, BufferRecycler 
recycler, int... expected) {
                IntBuffer actualIntBuffer = 
actualBuffer.getNioBufferReadable().asIntBuffer();
                int[] actual = new int[actualIntBuffer.limit()];
                actualIntBuffer.get(actual);
                assertArrayEquals(expected, actual);
 
-               assertEquals(FreeingBufferRecycler.INSTANCE, 
actualBuffer.getRecycler());
+               assertEquals(recycler, actualBuffer.getRecycler());
        }
 
        private static BufferBuilder createBufferBuilder() {
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java
index fd6c1f8..9fd8205 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.io.network.partition;
 
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
 import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
@@ -39,6 +40,10 @@ public class MockResultPartitionWriter implements 
ResultPartitionWriter {
        }
 
        @Override
+       public void initializeState(ChannelStateReader stateReader) {
+       }
+
+       @Override
        public ResultPartitionID getPartitionId() {
                return partitionId;
        }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpBufferAvailablityListener.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpBufferAvailablityListener.java
index 4162975..7fbd43e 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpBufferAvailablityListener.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpBufferAvailablityListener.java
@@ -21,7 +21,7 @@ package org.apache.flink.runtime.io.network.partition;
 /**
  * Test implementation of {@link BufferAvailabilityListener}.
  */
-class NoOpBufferAvailablityListener implements BufferAvailabilityListener {
+public class NoOpBufferAvailablityListener implements 
BufferAvailabilityListener {
        @Override
        public void notifyDataAvailable() {
        }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
index 011aa72..f3e512f 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
@@ -19,13 +19,17 @@
 package org.apache.flink.runtime.io.network.partition;
 
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
+import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
 import org.apache.flink.runtime.io.disk.FileChannelManager;
 import org.apache.flink.runtime.io.disk.FileChannelManagerImpl;
 import org.apache.flink.runtime.io.network.NettyShuffleEnvironment;
 import org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
+import org.apache.flink.runtime.io.network.buffer.BufferBuilderAndConsumerTest;
 import org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils;
 import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
 import org.apache.flink.runtime.io.network.buffer.BufferPool;
@@ -42,6 +46,11 @@ import org.junit.Test;
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
 
 import static 
org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils.createFilledFinishedBufferConsumer;
 import static 
org.apache.flink.runtime.io.network.partition.PartitionTestUtils.createPartition;
@@ -407,4 +416,120 @@ public class ResultPartitionTest {
                        jobId,
                        notifier)[0];
        }
+
+       @Test
+       public void testInitializeEmptyState() throws Exception {
+               final int totalBuffers = 2;
+               final NetworkBufferPool globalPool = new 
NetworkBufferPool(totalBuffers, 1, 1);
+               final ResultPartition partition = new ResultPartitionBuilder()
+                       .setNetworkBufferPool(globalPool)
+                       .build();
+               final ChannelStateReader stateReader = ChannelStateReader.NO_OP;
+               try {
+                       partition.setup();
+                       partition.initializeState(stateReader);
+
+                       for (ResultSubpartition subpartition : 
partition.getAllPartitions()) {
+                               // no buffers are added into the queue for 
empty states
+                               assertEquals(0, 
subpartition.getTotalNumberOfBuffers());
+                       }
+
+                       // destroy the local pool to verify that all the 
requested buffers by partition are recycled
+                       partition.getBufferPool().lazyDestroy();
+                       assertEquals(totalBuffers, 
globalPool.getNumberOfAvailableMemorySegments());
+               } finally {
+                       // cleanup
+                       globalPool.destroyAllBufferPools();
+                       globalPool.destroy();
+               }
+       }
+
+       @Test
+       public void testInitializeMoreStateThanBuffer() throws Exception {
+               final int totalBuffers = 2; // the total buffers are less than 
the requirement from total states
+               final int totalStates = 5;
+               final int[] states = {1, 2, 3, 4};
+               final int bufferSize = states.length * Integer.BYTES;
+
+               final NetworkBufferPool globalPool = new 
NetworkBufferPool(totalBuffers, bufferSize, 1);
+               final ChannelStateReader stateReader = new 
FiniteChannelStateReader(totalStates, states);
+               final ResultPartition partition = new ResultPartitionBuilder()
+                       .setNetworkBufferPool(globalPool)
+                       .build();
+               final ExecutorService executor = 
Executors.newFixedThreadPool(1);
+
+               try {
+                       final Callable<Void> partitionConsumeTask = () -> {
+                               for (ResultSubpartition subpartition : 
partition.getAllPartitions()) {
+                                       final ResultSubpartitionView view = new 
PipelinedSubpartitionView(
+                                               (PipelinedSubpartition) 
subpartition,
+                                               new 
NoOpBufferAvailablityListener());
+
+                                       int numConsumedBuffers = 0;
+                                       while (numConsumedBuffers != 
totalStates) {
+                                               
ResultSubpartition.BufferAndBacklog bufferAndBacklog = view.getNextBuffer();
+                                               if (bufferAndBacklog != null) {
+                                                       Buffer buffer = 
bufferAndBacklog.buffer();
+                                                       
BufferBuilderAndConsumerTest.assertContent(buffer, partition.getBufferPool(), 
states);
+                                                       buffer.recycleBuffer();
+                                                       numConsumedBuffers++;
+                                               } else {
+                                                       Thread.sleep(5);
+                                               }
+                                       }
+                               }
+                               return null;
+                       };
+                       Future<Void> result = 
executor.submit(partitionConsumeTask);
+
+                       partition.setup();
+                       partition.initializeState(stateReader);
+
+                       // wait the partition consume task finish
+                       result.get(20, TimeUnit.SECONDS);
+
+                       // destroy the local pool to verify that all the 
requested buffers by partition are recycled
+                       partition.getBufferPool().lazyDestroy();
+                       assertEquals(totalBuffers, 
globalPool.getNumberOfAvailableMemorySegments());
+               } finally {
+                       // cleanup
+                       executor.shutdown();
+                       globalPool.destroyAllBufferPools();
+                       globalPool.destroy();
+               }
+       }
+
+       /**
+        * The {@link ChannelStateReader} instance for restoring the specific 
number of states.
+        */
+       public static final class FiniteChannelStateReader implements 
ChannelStateReader {
+               private final int totalStates;
+               private int numRestoredStates;
+               private final int[] states;
+
+               public FiniteChannelStateReader(int totalStates, int[] states) {
+                       this.totalStates = totalStates;
+                       this.states = states;
+               }
+
+               @Override
+               public ReadResult readInputData(InputChannelInfo info, Buffer 
buffer) {
+                       return ReadResult.NO_MORE_DATA;
+               }
+
+               @Override
+               public ReadResult readOutputData(ResultSubpartitionInfo info, 
BufferBuilder bufferBuilder) {
+                       
bufferBuilder.appendAndCommit(BufferBuilderAndConsumerTest.toByteBuffer(states));
+
+                       if (++numRestoredStates < totalStates) {
+                               return ReadResult.HAS_MORE_DATA;
+                       } else {
+                               return ReadResult.NO_MORE_DATA;
+                       }
+               }
+
+               @Override
+               public void close() {
+               }
+       }
 }
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index d3a79b6..e699994 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -28,6 +28,7 @@ import 
org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
 import org.apache.flink.runtime.concurrent.FutureUtils;
 import org.apache.flink.runtime.execution.CancelTaskException;
 import org.apache.flink.runtime.execution.Environment;
@@ -434,6 +435,14 @@ public abstract class StreamTask<OUT, OP extends 
StreamOperator<OUT>>
                        // so that we avoid race conditions in the case that 
initializeState()
                        // registers a timer, that fires before the open() is 
called.
                        
operatorChain.initializeStateAndOpenOperators(createStreamTaskStateInitializer());
+
+                       ResultPartitionWriter[] writers = 
getEnvironment().getAllWriters();
+                       if (writers != null) {
+                               //TODO we should get proper state reader from 
getEnvironment().getTaskStateManager().getChannelStateReader()
+                               for (ResultPartitionWriter writer : writers) {
+                                       
writer.initializeState(ChannelStateReader.NO_OP);
+                               }
+                       }
                });
 
                isRunning = true;
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
index 3fa9508..c610a4f 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
@@ -33,6 +33,7 @@ import 
org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
 import org.apache.flink.runtime.checkpoint.StateObjectCollection;
 import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
 import org.apache.flink.runtime.concurrent.FutureUtils;
 import org.apache.flink.runtime.concurrent.TestingUncaughtExceptionHandler;
 import org.apache.flink.runtime.execution.CancelTaskException;
@@ -44,6 +45,7 @@ import 
org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder;
 import 
org.apache.flink.runtime.io.network.api.writer.AvailabilityTestResultPartitionWriter;
 import org.apache.flink.runtime.io.network.api.writer.RecordWriter;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.io.network.partition.MockResultPartitionWriter;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
@@ -895,6 +897,30 @@ public class StreamTaskTest extends TestLogger {
                }
        }
 
+       /**
+        * Verifies that {@link StreamTask#beforeInvoke()} initializes the state of every
+        * {@link ResultPartitionWriter} registered with the environment, i.e. that output recovery is
+        * done before the task starts processing.
+        */
+       @Test
+       public void testInitializeResultPartitionState() throws Exception {
+               int numWriters = 2;
+               RecoveryResultPartition[] partitions = new RecoveryResultPartition[numWriters];
+               for (int i = 0; i < numWriters; i++) {
+                       partitions[i] = new RecoveryResultPartition();
+               }
+
+               // close the mock environment afterwards so its resources are released
+               try (MockEnvironment mockEnvironment = new MockEnvironmentBuilder().build()) {
+                       mockEnvironment.addOutputs(Arrays.asList(partitions));
+                       StreamTask<?, ?> task = new MockStreamTaskBuilder(mockEnvironment).build();
+
+                       try {
+                               task.beforeInvoke();
+
+                               // output recovery should be done before task processing
+                               for (RecoveryResultPartition resultPartition : partitions) {
+                                       assertTrue(resultPartition.isStateInitialized());
+                               }
+                       } finally {
+                               task.cleanUpInvoke();
+                       }
+               }
+       }
+
        /**
         * Tests that some StreamTask methods are called only in the main 
task's thread.
         * Currently, the main task's thread is the thread that creates the 
task.
@@ -1723,4 +1749,20 @@ public class StreamTaskTest extends TestLogger {
                        throw new UnsupportedOperationException();
                }
        }
+
+       /**
+        * A {@link MockResultPartitionWriter} that only records whether
+        * {@link #initializeState(ChannelStateReader)} was called; the passed reader is not used.
+        */
+       private static class RecoveryResultPartition extends MockResultPartitionWriter {
+               /** Set to {@code true} once {@link #initializeState} has been invoked. */
+               private boolean stateInitialized;
+
+               RecoveryResultPartition() {
+               }
+
+               @Override
+               public void initializeState(ChannelStateReader stateReader) {
+                       this.stateInitialized = true;
+               }
+
+               boolean isStateInitialized() {
+                       return this.stateInitialized;
+               }
+       }
 }

Reply via email to