Repository: flink
Updated Branches:
  refs/heads/release-1.1 e2c53cf85 -> 5ebd7c844


[FLINK-5169] [network] Add tests for channel consumption

This closes #2882.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/5ebd7c84
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/5ebd7c84
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/5ebd7c84

Branch: refs/heads/release-1.1
Commit: 5ebd7c8443df38cddd9463a4609821518cd1a9cb
Parents: 8d97eaa
Author: Stephan Ewen <[email protected]>
Authored: Sun Nov 27 18:15:40 2016 +0100
Committer: Ufuk Celebi <[email protected]>
Committed: Mon Nov 28 21:05:00 2016 +0100

----------------------------------------------------------------------
 .../partition/PipelinedSubpartition.java        |   8 +
 .../partition/consumer/LocalInputChannel.java   |   4 +-
 .../partition/consumer/SingleInputGate.java     |   4 +-
 .../partition/consumer/UnionInputGate.java      |   2 +-
 .../partition/InputChannelTestUtils.java        |  89 +++++
 .../partition/InputGateConcurrentTest.java      | 325 +++++++++++++++
 .../partition/InputGateFairnessTest.java        | 397 +++++++++++++++++++
 7 files changed, 824 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/5ebd7c84/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
index 4d5e378..9d88ff0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
@@ -183,6 +183,14 @@ class PipelinedSubpartition extends ResultSubpartition {
                return readView;
        }
 
+       // ------------------------------------------------------------------------
+
+       /**
+        * Returns {@code buffers.size()}, i.e. the number of buffers currently held by this
+        * subpartition. Package-private, presumably so that tests in this package can check
+        * how many buffers are held.
+        *
+        * <p>NOTE(review): this reads {@code buffers} without taking any lock; confirm that
+        * callers do not invoke it concurrently with a producer adding buffers.
+        */
+       int getCurrentNumberOfBuffers() {
+               return buffers.size();
+       }
+
+       // ------------------------------------------------------------------------
+
        @Override
        public String toString() {
                final long numBuffers;

http://git-wip-us.apache.org/repos/asf/flink/blob/5ebd7c84/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java
index b34dbff..0a02ea1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java
@@ -65,7 +65,7 @@ public class LocalInputChannel extends InputChannel implements BufferAvailabilit
 
        private volatile boolean isReleased;
 
-       LocalInputChannel(
+       public LocalInputChannel(
                SingleInputGate inputGate,
                int channelIndex,
                ResultPartitionID partitionId,
@@ -77,7 +77,7 @@ public class LocalInputChannel extends InputChannel implements BufferAvailabilit
                        new Tuple2<Integer, Integer>(0, 0), metrics);
        }
 
-       LocalInputChannel(
+       public LocalInputChannel(
                SingleInputGate inputGate,
                int channelIndex,
                ResultPartitionID partitionId,

http://git-wip-us.apache.org/repos/asf/flink/blob/5ebd7c84/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
index 105d28b..8f44fbc 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
@@ -242,7 +242,7 @@ public class SingleInputGate implements InputGate {
                this.bufferPool = checkNotNull(bufferPool);
        }
 
-       void setInputChannel(IntermediateResultPartitionID partitionId, 
InputChannel inputChannel) {
+       public void setInputChannel(IntermediateResultPartitionID partitionId, 
InputChannel inputChannel) {
                synchronized (requestLock) {
                        if (inputChannels.put(checkNotNull(partitionId), 
checkNotNull(inputChannel)) == null
                                        && inputChannel.getClass() == 
UnknownInputChannel.class) {
@@ -527,7 +527,7 @@ public class SingleInputGate implements InputGate {
                        inputChannelsWithData.add(channel);
 
                        if (availableChannels == 0) {
-                               inputChannelsWithData.notify();
+                               inputChannelsWithData.notifyAll();
                        }
                }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/5ebd7c84/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
index e8ccbb4..55c78af 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
@@ -225,7 +225,7 @@ public class UnionInputGate implements InputGate, InputGateListener {
                        inputGatesWithData.add(inputGate);
 
                        if (availableInputGates == 0) {
-                               inputGatesWithData.notify();
+                               inputGatesWithData.notifyAll();
                        }
                }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/5ebd7c84/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java
new file mode 100644
index 0000000..e292576
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition;
+
+import org.apache.flink.runtime.io.network.ConnectionID;
+import org.apache.flink.runtime.io.network.ConnectionManager;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferProvider;
+import org.apache.flink.runtime.io.network.netty.PartitionRequestClient;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Some utility methods used for testing InputChannels and InputGates.
+ */
+class InputChannelTestUtils {
+
+       /**
+        * Creates a simple Buffer that is not recycled (never will be) of the given size.
+        */
+       public static Buffer createMockBuffer(int size) {
+               final Buffer mockBuffer = mock(Buffer.class);
+               when(mockBuffer.isBuffer()).thenReturn(true);
+               when(mockBuffer.getSize()).thenReturn(size);
+               when(mockBuffer.isRecycled()).thenReturn(false);
+
+               return mockBuffer;
+       }
+
+       /**
+        * Creates a result partition manager that ignores all IDs, and simply returns the given
+        * subpartitions in sequence.
+        *
+        * <p>The array is only read when a view is requested, so callers may fill it after this
+        * method returns. Requesting more views than there are subpartitions fails with an
+        * {@link IllegalStateException}.
+        */
+       public static ResultPartitionManager createResultPartitionManager(final ResultSubpartition[] sources) throws Exception {
+
+               final Answer<ResultSubpartitionView> viewCreator = new Answer<ResultSubpartitionView>() {
+
+                       private int num = 0;
+
+                       @Override
+                       public ResultSubpartitionView answer(InvocationOnMock invocation) throws Throwable {
+                               if (num >= sources.length) {
+                                       throw new IllegalStateException("Requested subpartition view #" + (num + 1)
+                                                       + ", but only " + sources.length + " subpartitions were given.");
+                               }
+
+                               // the 4th argument of createSubpartitionView(...) is the availability listener
+                               BufferAvailabilityListener channel = (BufferAvailabilityListener) invocation.getArguments()[3];
+                               return sources[num++].createReadView(null, channel);
+                       }
+               };
+
+               ResultPartitionManager manager = mock(ResultPartitionManager.class);
+               when(manager.createSubpartitionView(
+                               any(ResultPartitionID.class), anyInt(), any(BufferProvider.class), any(BufferAvailabilityListener.class)))
+                               .thenAnswer(viewCreator);
+
+               return manager;
+       }
+
+       /**
+        * Creates a mocked connection manager that hands out the same mocked
+        * {@link PartitionRequestClient} for any connection ID.
+        */
+       public static ConnectionManager createDummyConnectionManager() throws Exception {
+               final PartitionRequestClient mockClient = mock(PartitionRequestClient.class);
+
+               final ConnectionManager connManager = mock(ConnectionManager.class);
+               when(connManager.createPartitionRequestClient(any(ConnectionID.class))).thenReturn(mockClient);
+
+               return connManager;
+       }
+
+       // ------------------------------------------------------------------------
+
+       /** This class is not meant to be instantiated. */
+       private InputChannelTestUtils() {}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/5ebd7c84/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
new file mode 100644
index 0000000..a5f4c7d
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.io.network.ConnectionID;
+import org.apache.flink.runtime.io.network.ConnectionManager;
+import org.apache.flink.runtime.io.network.TaskEventDispatcher;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.netty.PartitionStateChecker;
+import 
org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel;
+import 
org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+import 
org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
+import 
org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup.DummyIOMetricGroup;
+import org.junit.Test;
+import scala.Tuple2;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+
+import static 
org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createDummyConnectionManager;
+import static 
org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createResultPartitionManager;
+import static org.junit.Assert.assertNotNull;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests that a {@link SingleInputGate} delivers every buffer while a producer thread
+ * concurrently adds buffers to its local, remote, or mixed local/remote channels.
+ */
+public class InputGateConcurrentTest {
+
+       /**
+        * Upper bound for every test in this class. If buffers go missing (for example through a
+        * lost wake-up of the waiting consumer), the consumer thread could otherwise block
+        * indefinitely and hang the build instead of failing the test.
+        */
+       private static final long TEST_TIMEOUT_MILLIS = 120000L;
+
+       @Test(timeout = TEST_TIMEOUT_MILLIS)
+       public void testConsumptionWithLocalChannels() throws Exception {
+               final int numChannels = 11;
+               final int buffersPerChannel = 1000;
+
+               final ResultPartition resultPartition = mock(ResultPartition.class);
+
+               final PipelinedSubpartition[] partitions = new PipelinedSubpartition[numChannels];
+               final Source[] sources = new Source[numChannels];
+
+               // the manager reads the array only when views are requested, so it is filled below
+               final ResultPartitionManager resultPartitionManager = createResultPartitionManager(partitions);
+
+               final SingleInputGate gate = createInputGate(numChannels);
+
+               for (int i = 0; i < numChannels; i++) {
+                       addLocalChannel(gate, i, resultPartitionManager);
+
+                       partitions[i] = new PipelinedSubpartition(0, resultPartition);
+                       sources[i] = new PipelinedSubpartitionSource(partitions[i]);
+               }
+
+               runProducerAndConsumer(sources, gate, numChannels * buffersPerChannel);
+       }
+
+       @Test(timeout = TEST_TIMEOUT_MILLIS)
+       public void testConsumptionWithRemoteChannels() throws Exception {
+               final int numChannels = 11;
+               final int buffersPerChannel = 1000;
+
+               final ConnectionManager connManager = createDummyConnectionManager();
+               final Source[] sources = new Source[numChannels];
+
+               final SingleInputGate gate = createInputGate(numChannels);
+
+               for (int i = 0; i < numChannels; i++) {
+                       sources[i] = new RemoteChannelSource(addRemoteChannel(gate, i, connManager));
+               }
+
+               runProducerAndConsumer(sources, gate, numChannels * buffersPerChannel);
+       }
+
+       @Test(timeout = TEST_TIMEOUT_MILLIS)
+       public void testConsumptionWithMixedChannels() throws Exception {
+               final int numChannels = 61;
+               final int numLocalChannels = 20;
+               final int buffersPerChannel = 1000;
+
+               // fill the local/remote decision
+               List<Boolean> localOrRemote = new ArrayList<>(numChannels);
+               for (int i = 0; i < numChannels; i++) {
+                       localOrRemote.add(i < numLocalChannels);
+               }
+               Collections.shuffle(localOrRemote);
+
+               final ConnectionManager connManager = createDummyConnectionManager();
+               final ResultPartition resultPartition = mock(ResultPartition.class);
+
+               final PipelinedSubpartition[] localPartitions = new PipelinedSubpartition[numLocalChannels];
+               final ResultPartitionManager resultPartitionManager = createResultPartitionManager(localPartitions);
+
+               final Source[] sources = new Source[numChannels];
+
+               final SingleInputGate gate = createInputGate(numChannels);
+
+               for (int i = 0, local = 0; i < numChannels; i++) {
+                       if (localOrRemote.get(i)) {
+                               // local channel
+                               PipelinedSubpartition psp = new PipelinedSubpartition(0, resultPartition);
+                               localPartitions[local++] = psp;
+                               sources[i] = new PipelinedSubpartitionSource(psp);
+
+                               addLocalChannel(gate, i, resultPartitionManager);
+                       }
+                       else {
+                               // remote channel
+                               sources[i] = new RemoteChannelSource(addRemoteChannel(gate, i, connManager));
+                       }
+               }
+
+               runProducerAndConsumer(sources, gate, numChannels * buffersPerChannel);
+       }
+
+       // ------------------------------------------------------------------------
+       //  test setup utilities
+       // ------------------------------------------------------------------------
+
+       /**
+        * Creates an input gate for the given number of channels. The tests register the
+        * channels themselves afterwards via {@link SingleInputGate#setInputChannel}.
+        */
+       private static SingleInputGate createInputGate(int numChannels) {
+               return new SingleInputGate(
+                               "Test Task Name",
+                               new JobID(),
+                               new ExecutionAttemptID(),
+                               new IntermediateDataSetID(),
+                               0,
+                               numChannels,
+                               mock(PartitionStateChecker.class),
+                               new UnregisteredTaskMetricsGroup.DummyIOMetricGroup());
+       }
+
+       /**
+        * Creates a local channel at the given index that reads through the given partition
+        * manager, and registers it with the gate under a fresh partition ID.
+        */
+       private static void addLocalChannel(
+                       SingleInputGate gate, int channelIndex, ResultPartitionManager partitionManager) {
+
+               LocalInputChannel channel = new LocalInputChannel(gate, channelIndex, new ResultPartitionID(),
+                               partitionManager, mock(TaskEventDispatcher.class), new DummyIOMetricGroup());
+               gate.setInputChannel(new IntermediateResultPartitionID(), channel);
+       }
+
+       /**
+        * Creates a remote channel at the given index, registers it with the gate under a fresh
+        * partition ID, and returns it so that buffers can be fed into it directly.
+        */
+       private static RemoteInputChannel addRemoteChannel(
+                       SingleInputGate gate, int channelIndex, ConnectionManager connManager) {
+
+               RemoteInputChannel channel = new RemoteInputChannel(
+                               gate, channelIndex, new ResultPartitionID(), mock(ConnectionID.class),
+                               connManager, new Tuple2<>(0, 0), new DummyIOMetricGroup());
+               gate.setInputChannel(new IntermediateResultPartitionID(), channel);
+               return channel;
+       }
+
+       /**
+        * Starts a producer thread that spreads {@code numBuffers} buffers randomly over the
+        * sources and a consumer thread that reads the same number of buffers from the gate.
+        * Then waits for both and rethrows any error, checking the producer first.
+        */
+       private static void runProducerAndConsumer(
+                       Source[] sources, SingleInputGate gate, int numBuffers) throws Exception {
+
+               ProducerThread producer = new ProducerThread(sources, numBuffers, 4, 10);
+               ConsumerThread consumer = new ConsumerThread(gate, numBuffers);
+               producer.start();
+               consumer.start();
+
+               try {
+                       // the 'sync()' call checks for exceptions and failed assertions
+                       producer.sync();
+                       consumer.sync();
+               }
+               finally {
+                       // if one thread failed or the test timed out, do not leave the other one
+                       // running; interrupting an already finished thread has no effect
+                       producer.interrupt();
+                       consumer.interrupt();
+               }
+       }
+
+       // ------------------------------------------------------------------------
+       //  testing sources
+       // ------------------------------------------------------------------------
+
+       /** Something the producer thread pushes buffers into, feeding one input channel. */
+       private static abstract class Source {
+
+               abstract void addBuffer(Buffer buffer) throws Exception;
+       }
+
+       /** Adds buffers to a pipelined subpartition that one of the local channels reads from. */
+       private static class PipelinedSubpartitionSource extends Source {
+
+               final PipelinedSubpartition partition;
+
+               PipelinedSubpartitionSource(PipelinedSubpartition partition) {
+                       this.partition = partition;
+               }
+
+               @Override
+               void addBuffer(Buffer buffer) throws Exception {
+                       partition.add(buffer);
+               }
+       }
+
+       /** Feeds a remote channel directly via {@code onBuffer}, with sequence numbers 0, 1, 2, ... */
+       private static class RemoteChannelSource extends Source {
+
+               final RemoteInputChannel channel;
+               private int seq = 0;
+
+               RemoteChannelSource(RemoteInputChannel channel) {
+                       this.channel = channel;
+               }
+
+               @Override
+               void addBuffer(Buffer buffer) throws Exception {
+                       channel.onBuffer(buffer, seq++);
+               }
+       }
+
+       // ------------------------------------------------------------------------
+       //  testing threads
+       // ------------------------------------------------------------------------
+
+       /**
+        * A thread that records anything thrown by {@link #go()}, so that {@link #sync()} can
+        * rethrow it in the thread that joins it.
+        */
+       private static abstract class CheckedThread extends Thread {
+
+               private volatile Throwable error;
+
+               CheckedThread(String name) {
+                       super(name);
+                       // a thread left blocked after a failure must not keep the JVM alive
+                       setDaemon(true);
+               }
+
+               public abstract void go() throws Exception;
+
+               @Override
+               public void run() {
+                       try {
+                               go();
+                       }
+                       catch (Throwable t) {
+                               error = t;
+                       }
+               }
+
+               public void sync() throws Exception {
+                       join();
+
+                       // propagate the error
+                       if (error != null) {
+                               if (error instanceof Error) {
+                                       throw (Error) error;
+                               }
+                               else if (error instanceof Exception) {
+                                       throw (Exception) error;
+                               }
+                               else {
+                                       throw new Exception(error.getMessage(), error);
+                               }
+                       }
+               }
+       }
+
+       /**
+        * Adds {@code numTotal} buffers in random chunks of 1 to {@code maxChunk} buffers to
+        * randomly chosen sources, yielding about every {@code yieldAfter} buffers.
+        */
+       private static class ProducerThread extends CheckedThread {
+
+               private final Random rnd = new Random();
+               private final Source[] sources;
+               private final int numTotal;
+               private final int maxChunk;
+               private final int yieldAfter;
+
+               ProducerThread(Source[] sources, int numTotal, int maxChunk, int yieldAfter) {
+                       super("producer");
+                       this.sources = sources;
+                       this.numTotal = numTotal;
+                       this.maxChunk = maxChunk;
+                       this.yieldAfter = yieldAfter;
+               }
+
+               @Override
+               public void go() throws Exception {
+                       // every source receives the same mock buffer, which reports itself as not recycled
+                       final Buffer buffer = InputChannelTestUtils.createMockBuffer(100);
+                       int nextYield = numTotal - yieldAfter;
+
+                       for (int i = numTotal; i > 0;) {
+                               final int nextChannel = rnd.nextInt(sources.length);
+                               final int chunk = Math.min(i, rnd.nextInt(maxChunk) + 1);
+
+                               final Source next = sources[nextChannel];
+
+                               for (int k = chunk; k > 0; --k) {
+                                       next.addBuffer(buffer);
+                               }
+
+                               i -= chunk;
+
+                               if (i <= nextYield) {
+                                       nextYield -= yieldAfter;
+                                       //noinspection CallToThreadYield
+                                       Thread.yield();
+                               }
+                       }
+               }
+       }
+
+       /** Reads exactly {@code numBuffers} buffers or events from the gate, asserting none is null. */
+       private static class ConsumerThread extends CheckedThread {
+
+               private final SingleInputGate gate;
+               private final int numBuffers;
+
+               ConsumerThread(SingleInputGate gate, int numBuffers) {
+                       super("consumer");
+                       this.gate = gate;
+                       this.numBuffers = numBuffers;
+               }
+
+               @Override
+               public void go() throws Exception {
+                       for (int i = numBuffers; i > 0; --i) {
+                               assertNotNull(gate.getNextBufferOrEvent());
+                       }
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/5ebd7c84/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
new file mode 100644
index 0000000..192b0eb
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
@@ -0,0 +1,397 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.io.network.ConnectionID;
+import org.apache.flink.runtime.io.network.ConnectionManager;
+import org.apache.flink.runtime.io.network.TaskEventDispatcher;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.netty.PartitionStateChecker;
+import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
+import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
+import 
org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel;
+import 
org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+import org.apache.flink.runtime.metrics.groups.IOMetricGroup;
+import 
org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
+import 
org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup.DummyIOMetricGroup;
+import org.junit.Test;
+import scala.Tuple2;
+
+import java.io.IOException;
+import java.lang.reflect.Field;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+
+import static 
org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createDummyConnectionManager;
+import static 
org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createMockBuffer;
+import static 
org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createResultPartitionManager;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+
+public class InputGateFairnessTest {
+
+       @Test
+       public void testFairConsumptionLocalChannelsPreFilled() throws 
Exception {
+               final int numChannels = 37;
+               final int buffersPerChannel = 27;
+
+               final ResultPartition resultPartition = 
mock(ResultPartition.class);
+               final Buffer mockBuffer = createMockBuffer(42);
+
+               // ----- create some source channels and fill them with buffers 
-----
+
+               final PipelinedSubpartition[] sources = new 
PipelinedSubpartition[numChannels];
+
+               for (int i = 0; i < numChannels; i++) {
+                       PipelinedSubpartition partition = new 
PipelinedSubpartition(0, resultPartition);
+
+                       for (int p = 0; p < buffersPerChannel; p++) {
+                               partition.add(mockBuffer);
+                       }
+
+                       partition.finish();
+                       sources[i] = partition;
+               }
+
+               // ----- create reading side -----
+
+               ResultPartitionManager resultPartitionManager = 
createResultPartitionManager(sources);
+
+               SingleInputGate gate = new FairnessVerifyingInputGate(
+                               "Test Task Name",
+                               new JobID(),
+                               new ExecutionAttemptID(),
+                               new IntermediateDataSetID(),
+                               0, numChannels,
+                               mock(PartitionStateChecker.class),
+                               new 
UnregisteredTaskMetricsGroup.DummyIOMetricGroup());
+
+               for (int i = 0; i < numChannels; i++) {
+                       LocalInputChannel channel = new LocalInputChannel(gate, 
i, new ResultPartitionID(),
+                                       resultPartitionManager, 
mock(TaskEventDispatcher.class), new DummyIOMetricGroup());
+                       gate.setInputChannel(new 
IntermediateResultPartitionID(), channel);
+               }
+
+               // read all the buffers and the EOF event
+               for (int i = numChannels * (buffersPerChannel + 1); i > 0; --i) 
{
+                       assertNotNull(gate.getNextBufferOrEvent());
+
+                       int min = Integer.MAX_VALUE;
+                       int max = 0;
+
+                       for (PipelinedSubpartition source : sources) {
+                               int size = source.getCurrentNumberOfBuffers();
+                               min = Math.min(min, size);
+                               max = Math.max(max, size);
+                       }
+
+                       assertTrue(max == min || max == min+1);
+               }
+
+               assertNull(gate.getNextBufferOrEvent());
+       }
+
+       @Test
+       public void testFairConsumptionLocalChannels() throws Exception {
+               final int numChannels = 37;
+               final int buffersPerChannel = 27;
+
+               final ResultPartition resultPartition = 
mock(ResultPartition.class);
+               final Buffer mockBuffer = createMockBuffer(42);
+
+               // ----- create some source channels and fill them with one 
buffer each -----
+
+               final PipelinedSubpartition[] sources = new 
PipelinedSubpartition[numChannels];
+
+               for (int i = 0; i < numChannels; i++) {
+                       sources[i] = new PipelinedSubpartition(0, 
resultPartition);
+               }
+
+               // ----- create reading side -----
+
+               ResultPartitionManager resultPartitionManager = 
createResultPartitionManager(sources);
+
+               SingleInputGate gate = new FairnessVerifyingInputGate(
+                               "Test Task Name",
+                               new JobID(),
+                               new ExecutionAttemptID(),
+                               new IntermediateDataSetID(),
+                               0, numChannels,
+                               mock(PartitionStateChecker.class),
+                               new 
UnregisteredTaskMetricsGroup.DummyIOMetricGroup());
+
+               for (int i = 0; i < numChannels; i++) {
+                       LocalInputChannel channel = new LocalInputChannel(gate, 
i, new ResultPartitionID(),
+                                       resultPartitionManager, 
mock(TaskEventDispatcher.class), new DummyIOMetricGroup());
+                       gate.setInputChannel(new 
IntermediateResultPartitionID(), channel);
+               }
+
+               // seed one initial buffer
+               sources[12].add(mockBuffer);
+
+               // read all the buffers and the EOF event
+               for (int i = 0; i < numChannels * buffersPerChannel; i++) {
+                       assertNotNull(gate.getNextBufferOrEvent());
+
+                       int min = Integer.MAX_VALUE;
+                       int max = 0;
+
+                       for (PipelinedSubpartition source : sources) {
+                               int size = source.getCurrentNumberOfBuffers();
+                               min = Math.min(min, size);
+                               max = Math.max(max, size);
+                       }
+
+                       assertTrue(max == min || max == min+1);
+
+                       if (i % (2 * numChannels) == 0) {
+                               // add three buffers to each channel, in random 
order
+                               fillRandom(sources, 3, mockBuffer);
+                       }
+               }
+
+               // there is still more in the queues
+       }
+
+       @Test
+       public void testFairConsumptionRemoteChannelsPreFilled() throws 
Exception {
+               final int numChannels = 37;
+               final int buffersPerChannel = 27;
+
+               final Buffer mockBuffer = createMockBuffer(42);
+
+               // ----- create some source channels and fill them with buffers 
-----
+
+               SingleInputGate gate = new FairnessVerifyingInputGate(
+                               "Test Task Name",
+                               new JobID(),
+                               new ExecutionAttemptID(),
+                               new IntermediateDataSetID(),
+                               0, numChannels,
+                               mock(PartitionStateChecker.class),
+                               new 
UnregisteredTaskMetricsGroup.DummyIOMetricGroup());
+
+               final ConnectionManager connManager = 
createDummyConnectionManager();
+
+               final RemoteInputChannel[] channels = new 
RemoteInputChannel[numChannels];
+
+               for (int i = 0; i < numChannels; i++) {
+                       RemoteInputChannel channel = new RemoteInputChannel(
+                                       gate, i, new ResultPartitionID(), 
mock(ConnectionID.class), 
+                                       connManager, new Tuple2<>(0, 0), new 
DummyIOMetricGroup());
+
+                       channels[i] = channel;
+                       
+                       for (int p = 0; p < buffersPerChannel; p++) {
+                               channel.onBuffer(mockBuffer, p);
+                       }
+                       
channel.onBuffer(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), 
buffersPerChannel);
+
+                       gate.setInputChannel(new 
IntermediateResultPartitionID(), channel);
+               }
+
+               // read all the buffers and the EOF event
+               for (int i = numChannels * (buffersPerChannel + 1); i > 0; --i) 
{
+                       assertNotNull(gate.getNextBufferOrEvent());
+
+                       int min = Integer.MAX_VALUE;
+                       int max = 0;
+
+                       for (RemoteInputChannel channel : channels) {
+                               int size = channel.getNumberOfQueuedBuffers();
+                               min = Math.min(min, size);
+                               max = Math.max(max, size);
+                       }
+
+                       assertTrue(max == min || max == min+1);
+               }
+
+               assertNull(gate.getNextBufferOrEvent());
+       }
+
+       @Test
+       public void testFairConsumptionRemoteChannels() throws Exception {
+               final int numChannels = 37;
+               final int buffersPerChannel = 27;
+
+               final Buffer mockBuffer = createMockBuffer(42);
+
+               // ----- create some source channels and fill them with buffers 
-----
+
+               SingleInputGate gate = new FairnessVerifyingInputGate(
+                               "Test Task Name",
+                               new JobID(),
+                               new ExecutionAttemptID(),
+                               new IntermediateDataSetID(),
+                               0, numChannels,
+                               mock(PartitionStateChecker.class),
+                               new 
UnregisteredTaskMetricsGroup.DummyIOMetricGroup());
+
+               final ConnectionManager connManager = 
createDummyConnectionManager();
+
+               final RemoteInputChannel[] channels = new 
RemoteInputChannel[numChannels];
+               final int[] channelSequenceNums = new int[numChannels];
+
+               for (int i = 0; i < numChannels; i++) {
+                       RemoteInputChannel channel = new RemoteInputChannel(
+                                       gate, i, new ResultPartitionID(), 
mock(ConnectionID.class),
+                                       connManager, new Tuple2<>(0, 0), new 
DummyIOMetricGroup());
+
+                       channels[i] = channel;
+                       gate.setInputChannel(new 
IntermediateResultPartitionID(), channel);
+               }
+
+               channels[11].onBuffer(mockBuffer, 0);
+               channelSequenceNums[11]++;
+
+               // read all the buffers and the EOF event
+               for (int i = 0; i < numChannels * buffersPerChannel; i++) {
+                       assertNotNull(gate.getNextBufferOrEvent());
+
+                       int min = Integer.MAX_VALUE;
+                       int max = 0;
+
+                       for (RemoteInputChannel channel : channels) {
+                               int size = channel.getNumberOfQueuedBuffers();
+                               min = Math.min(min, size);
+                               max = Math.max(max, size);
+                       }
+
+                       assertTrue(max == min || max == min+1);
+
+                       if (i % (2 * numChannels) == 0) {
+                               // add three buffers to each channel, in random 
order
+                               fillRandom(channels, channelSequenceNums, 3, 
mockBuffer);
+                       }
+               }
+       }
+
+       // 
------------------------------------------------------------------------
+       //  Utilities
+       // 
------------------------------------------------------------------------
+
+       private void fillRandom(PipelinedSubpartition[] partitions, int 
numPerPartition, Buffer buffer) throws Exception {
+               ArrayList<Integer> poss = new ArrayList<>(partitions.length * 
numPerPartition);
+
+               for (int i = 0; i < partitions.length; i++) {
+                       for (int k = 0; k < numPerPartition; k++) {
+                               poss.add(i);
+                       }
+               }
+
+               Collections.shuffle(poss);
+
+               for (Integer i : poss) {
+                       partitions[i].add(buffer);
+               }
+       }
+
+       private void fillRandom(
+                       RemoteInputChannel[] partitions,
+                       int[] sequenceNumbers,
+                       int numPerPartition,
+                       Buffer buffer) throws Exception {
+
+               ArrayList<Integer> poss = new ArrayList<>(partitions.length * 
numPerPartition);
+
+               for (int i = 0; i < partitions.length; i++) {
+                       for (int k = 0; k < numPerPartition; k++) {
+                               poss.add(i);
+                       }
+               }
+
+               Collections.shuffle(poss);
+
+               for (int i : poss) {
+                       partitions[i].onBuffer(buffer, sequenceNumbers[i]++);
+               }
+       }
+       
+       // 
------------------------------------------------------------------------
+
+       private static class FairnessVerifyingInputGate extends SingleInputGate 
{
+
+               private final ArrayDeque<InputChannel> channelsWithData;
+
+               private final HashSet<InputChannel> uniquenessChecker;
+
+               @SuppressWarnings("unchecked")
+               public FairnessVerifyingInputGate(
+                               String owningTaskName,
+                               JobID jobId,
+                               ExecutionAttemptID executionId,
+                               IntermediateDataSetID consumedResultId,
+                               int consumedSubpartitionIndex,
+                               int numberOfInputChannels,
+                               PartitionStateChecker partitionStateChecker,
+                               IOMetricGroup metrics) {
+
+                       super(owningTaskName, jobId, executionId, 
consumedResultId, consumedSubpartitionIndex,
+                                       numberOfInputChannels, 
partitionStateChecker, metrics);
+
+                       try {
+                               Field f = 
SingleInputGate.class.getDeclaredField("inputChannelsWithData");
+                               f.setAccessible(true);
+                               channelsWithData = (ArrayDeque<InputChannel>) 
f.get(this);
+                       }
+                       catch (Exception e) {
+                               throw new RuntimeException(e);
+                       }
+
+                       this.uniquenessChecker = new HashSet<>();
+               }
+
+
+               @Override
+               public BufferOrEvent getNextBufferOrEvent() throws IOException, 
InterruptedException {
+                       synchronized (channelsWithData) {
+                               assertTrue("too many input channels", 
channelsWithData.size() <= getNumberOfInputChannels());
+                               ensureUnique(channelsWithData);
+                       }
+
+                       return super.getNextBufferOrEvent();
+               }
+
+               private void ensureUnique(Collection<InputChannel> channels) {
+                       HashSet<InputChannel> uniquenessChecker = 
this.uniquenessChecker;
+
+                       for (InputChannel channel : channels) {
+                               if (!uniquenessChecker.add(channel)) {
+                                       fail("Duplicate channel in input gate: 
" + channel);
+                               }
+                       }
+
+                       assertTrue("found duplicate input channels", 
uniquenessChecker.size() == channels.size());
+                       uniquenessChecker.clear();
+               }
+       }
+}

Reply via email to