Repository: flink Updated Branches: refs/heads/release-1.1 e2c53cf85 -> 5ebd7c844
[FLINK-5169] [network] Add tests for channel consumption This closes #2882. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/5ebd7c84 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/5ebd7c84 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/5ebd7c84 Branch: refs/heads/release-1.1 Commit: 5ebd7c8443df38cddd9463a4609821518cd1a9cb Parents: 8d97eaa Author: Stephan Ewen <[email protected]> Authored: Sun Nov 27 18:15:40 2016 +0100 Committer: Ufuk Celebi <[email protected]> Committed: Mon Nov 28 21:05:00 2016 +0100 ---------------------------------------------------------------------- .../partition/PipelinedSubpartition.java | 8 + .../partition/consumer/LocalInputChannel.java | 4 +- .../partition/consumer/SingleInputGate.java | 4 +- .../partition/consumer/UnionInputGate.java | 2 +- .../partition/InputChannelTestUtils.java | 89 +++++ .../partition/InputGateConcurrentTest.java | 325 +++++++++++++++ .../partition/InputGateFairnessTest.java | 397 +++++++++++++++++++ 7 files changed, 824 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/5ebd7c84/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java index 4d5e378..9d88ff0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java @@ -183,6 +183,14 @@ class PipelinedSubpartition extends ResultSubpartition { return readView; } + // 
------------------------------------------------------------------------ + + int getCurrentNumberOfBuffers() { + return buffers.size(); + } + + // ------------------------------------------------------------------------ + @Override public String toString() { final long numBuffers; http://git-wip-us.apache.org/repos/asf/flink/blob/5ebd7c84/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java index b34dbff..0a02ea1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java @@ -65,7 +65,7 @@ public class LocalInputChannel extends InputChannel implements BufferAvailabilit private volatile boolean isReleased; - LocalInputChannel( + public LocalInputChannel( SingleInputGate inputGate, int channelIndex, ResultPartitionID partitionId, @@ -77,7 +77,7 @@ public class LocalInputChannel extends InputChannel implements BufferAvailabilit new Tuple2<Integer, Integer>(0, 0), metrics); } - LocalInputChannel( + public LocalInputChannel( SingleInputGate inputGate, int channelIndex, ResultPartitionID partitionId, http://git-wip-us.apache.org/repos/asf/flink/blob/5ebd7c84/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java index 
105d28b..8f44fbc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java @@ -242,7 +242,7 @@ public class SingleInputGate implements InputGate { this.bufferPool = checkNotNull(bufferPool); } - void setInputChannel(IntermediateResultPartitionID partitionId, InputChannel inputChannel) { + public void setInputChannel(IntermediateResultPartitionID partitionId, InputChannel inputChannel) { synchronized (requestLock) { if (inputChannels.put(checkNotNull(partitionId), checkNotNull(inputChannel)) == null && inputChannel.getClass() == UnknownInputChannel.class) { @@ -527,7 +527,7 @@ public class SingleInputGate implements InputGate { inputChannelsWithData.add(channel); if (availableChannels == 0) { - inputChannelsWithData.notify(); + inputChannelsWithData.notifyAll(); } } http://git-wip-us.apache.org/repos/asf/flink/blob/5ebd7c84/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java index e8ccbb4..55c78af 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java @@ -225,7 +225,7 @@ public class UnionInputGate implements InputGate, InputGateListener { inputGatesWithData.add(inputGate); if (availableInputGates == 0) { - inputGatesWithData.notify(); + inputGatesWithData.notifyAll(); } } 
http://git-wip-us.apache.org/repos/asf/flink/blob/5ebd7c84/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java new file mode 100644 index 0000000..e292576 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.io.network.partition; + +import org.apache.flink.runtime.io.network.ConnectionID; +import org.apache.flink.runtime.io.network.ConnectionManager; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferProvider; +import org.apache.flink.runtime.io.network.netty.PartitionRequestClient; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Some utility methods used for testing InputChannels and InputGates. + */ +class InputChannelTestUtils { + + /** + * Creates a simple Buffer that is not recycled (never will be) of the given size. + */ + public static Buffer createMockBuffer(int size) { + final Buffer mockBuffer = mock(Buffer.class); + when(mockBuffer.isBuffer()).thenReturn(true); + when(mockBuffer.getSize()).thenReturn(size); + when(mockBuffer.isRecycled()).thenReturn(false); + + return mockBuffer; + } + + /** + * Creates a result partition manager that ignores all IDs, and simply returns the given + * subpartitions in sequence. 
+ */ + public static ResultPartitionManager createResultPartitionManager(final ResultSubpartition[] sources) throws Exception { + + final Answer<ResultSubpartitionView> viewCreator = new Answer<ResultSubpartitionView>() { + + private int num = 0; + + @Override + public ResultSubpartitionView answer(InvocationOnMock invocation) throws Throwable { + BufferAvailabilityListener channel = (BufferAvailabilityListener) invocation.getArguments()[3]; + return sources[num++].createReadView(null, channel); + } + }; + + ResultPartitionManager manager = mock(ResultPartitionManager.class); + when(manager.createSubpartitionView( + any(ResultPartitionID.class), anyInt(), any(BufferProvider.class), any(BufferAvailabilityListener.class))) + .thenAnswer(viewCreator); + + return manager; + } + + public static ConnectionManager createDummyConnectionManager() throws Exception { + final PartitionRequestClient mockClient = mock(PartitionRequestClient.class); + + final ConnectionManager connManager = mock(ConnectionManager.class); + when(connManager.createPartitionRequestClient(any(ConnectionID.class))).thenReturn(mockClient); + + return connManager; + } + + // ------------------------------------------------------------------------ + + /** This class is not meant to be instantiated */ + private InputChannelTestUtils() {} +} http://git-wip-us.apache.org/repos/asf/flink/blob/5ebd7c84/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java new file mode 100644 index 0000000..a5f4c7d --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.io.network.ConnectionID; +import org.apache.flink.runtime.io.network.ConnectionManager; +import org.apache.flink.runtime.io.network.TaskEventDispatcher; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.netty.PartitionStateChecker; +import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel; +import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; +import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup.DummyIOMetricGroup; +import org.junit.Test; +import scala.Tuple2; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Random; 
+ +import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createDummyConnectionManager; +import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createResultPartitionManager; +import static org.junit.Assert.assertNotNull; +import static org.mockito.Mockito.mock; + +public class InputGateConcurrentTest { + + @Test + public void testConsumptionWithLocalChannels() throws Exception { + final int numChannels = 11; + final int buffersPerChannel = 1000; + + final ResultPartition resultPartition = mock(ResultPartition.class); + + final PipelinedSubpartition[] partitions = new PipelinedSubpartition[numChannels]; + final Source[] sources = new Source[numChannels]; + + final ResultPartitionManager resultPartitionManager = createResultPartitionManager(partitions); + + final SingleInputGate gate = new SingleInputGate( + "Test Task Name", + new JobID(), + new ExecutionAttemptID(), + new IntermediateDataSetID(), + 0, numChannels, + mock(PartitionStateChecker.class), + new UnregisteredTaskMetricsGroup.DummyIOMetricGroup()); + + for (int i = 0; i < numChannels; i++) { + LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(), + resultPartitionManager, mock(TaskEventDispatcher.class), new DummyIOMetricGroup()); + gate.setInputChannel(new IntermediateResultPartitionID(), channel); + + partitions[i] = new PipelinedSubpartition(0, resultPartition); + sources[i] = new PipelinedSubpartitionSource(partitions[i]); + } + + ProducerThread producer = new ProducerThread(sources, numChannels * buffersPerChannel, 4, 10); + ConsumerThread consumer = new ConsumerThread(gate, numChannels * buffersPerChannel); + producer.start(); + consumer.start(); + + // the 'sync()' call checks for exceptions and failed assertions + producer.sync(); + consumer.sync(); + } + + @Test + public void testConsumptionWithRemoteChannels() throws Exception { + final int numChannels = 11; + final int buffersPerChannel = 1000; + + final 
ConnectionManager connManager = createDummyConnectionManager(); + final Source[] sources = new Source[numChannels]; + + final SingleInputGate gate = new SingleInputGate( + "Test Task Name", + new JobID(), + new ExecutionAttemptID(), + new IntermediateDataSetID(), + 0, + numChannels, + mock(PartitionStateChecker.class), + new UnregisteredTaskMetricsGroup.DummyIOMetricGroup()); + + for (int i = 0; i < numChannels; i++) { + RemoteInputChannel channel = new RemoteInputChannel( + gate, i, new ResultPartitionID(), mock(ConnectionID.class), + connManager, new Tuple2<>(0, 0), new DummyIOMetricGroup()); + gate.setInputChannel(new IntermediateResultPartitionID(), channel); + + sources[i] = new RemoteChannelSource(channel); + } + + ProducerThread producer = new ProducerThread(sources, numChannels * buffersPerChannel, 4, 10); + ConsumerThread consumer = new ConsumerThread(gate, numChannels * buffersPerChannel); + producer.start(); + consumer.start(); + + // the 'sync()' call checks for exceptions and failed assertions + producer.sync(); + consumer.sync(); + } + + @Test + public void testConsumptionWithMixedChannels() throws Exception { + final int numChannels = 61; + final int numLocalChannels = 20; + final int buffersPerChannel = 1000; + + // fill the local/remote decision + List<Boolean> localOrRemote = new ArrayList<>(numChannels); + for (int i = 0; i < numChannels; i++) { + localOrRemote.add(i < numLocalChannels); + } + Collections.shuffle(localOrRemote); + + final ConnectionManager connManager = createDummyConnectionManager(); + final ResultPartition resultPartition = mock(ResultPartition.class); + + final PipelinedSubpartition[] localPartitions = new PipelinedSubpartition[numLocalChannels]; + final ResultPartitionManager resultPartitionManager = createResultPartitionManager(localPartitions); + + final Source[] sources = new Source[numChannels]; + + final SingleInputGate gate = new SingleInputGate( + "Test Task Name", + new JobID(), + new ExecutionAttemptID(), + new 
IntermediateDataSetID(), + 0, + numChannels, + mock(PartitionStateChecker.class), + new UnregisteredTaskMetricsGroup.DummyIOMetricGroup()); + + for (int i = 0, local = 0; i < numChannels; i++) { + if (localOrRemote.get(i)) { + // local channel + PipelinedSubpartition psp = new PipelinedSubpartition(0, resultPartition); + localPartitions[local++] = psp; + sources[i] = new PipelinedSubpartitionSource(psp); + + LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(), + resultPartitionManager, mock(TaskEventDispatcher.class), new DummyIOMetricGroup()); + gate.setInputChannel(new IntermediateResultPartitionID(), channel); + } + else { + // remote channel + RemoteInputChannel channel = new RemoteInputChannel( + gate, i, new ResultPartitionID(), mock(ConnectionID.class), + connManager, new Tuple2<>(0, 0), new DummyIOMetricGroup()); + gate.setInputChannel(new IntermediateResultPartitionID(), channel); + + sources[i] = new RemoteChannelSource(channel); + } + } + + ProducerThread producer = new ProducerThread(sources, numChannels * buffersPerChannel, 4, 10); + ConsumerThread consumer = new ConsumerThread(gate, numChannels * buffersPerChannel); + producer.start(); + consumer.start(); + + // the 'sync()' call checks for exceptions and failed assertions + producer.sync(); + consumer.sync(); + } + + // ------------------------------------------------------------------------ + // testing sources + // ------------------------------------------------------------------------ + + private static abstract class Source { + + abstract void addBuffer(Buffer buffer) throws Exception; + } + + private static class PipelinedSubpartitionSource extends Source { + + final PipelinedSubpartition partition; + + PipelinedSubpartitionSource(PipelinedSubpartition partition) { + this.partition = partition; + } + + @Override + void addBuffer(Buffer buffer) throws Exception { + partition.add(buffer); + } + } + + private static class RemoteChannelSource extends Source { + + final
RemoteInputChannel channel; + private int seq = 0; + + RemoteChannelSource(RemoteInputChannel channel) { + this.channel = channel; + } + + @Override + void addBuffer(Buffer buffer) throws Exception { + channel.onBuffer(buffer, seq++); + } + } + + // ------------------------------------------------------------------------ + // testing threads + // ------------------------------------------------------------------------ + + private static abstract class CheckedThread extends Thread { + + private volatile Throwable error; + + public abstract void go() throws Exception; + + @Override + public void run() { + try { + go(); + } + catch (Throwable t) { + error = t; + } + } + + public void sync() throws Exception { + join(); + + // propagate the error + if (error != null) { + if (error instanceof Error) { + throw (Error) error; + } + else if (error instanceof Exception) { + throw (Exception) error; + } + else { + throw new Exception(error.getMessage(), error); + } + } + } + } + + private static class ProducerThread extends CheckedThread { + + private final Random rnd = new Random(); + private final Source[] sources; + private final int numTotal; + private final int maxChunk; + private final int yieldAfter; + + ProducerThread(Source[] sources, int numTotal, int maxChunk, int yieldAfter) { + this.sources = sources; + this.numTotal = numTotal; + this.maxChunk = maxChunk; + this.yieldAfter = yieldAfter; + } + + @Override + public void go() throws Exception { + final Buffer buffer = InputChannelTestUtils.createMockBuffer(100); + int nextYield = numTotal - yieldAfter; + + for (int i = numTotal; i > 0;) { + final int nextChannel = rnd.nextInt(sources.length); + final int chunk = Math.min(i, rnd.nextInt(maxChunk) + 1); + + final Source next = sources[nextChannel]; + + for (int k = chunk; k > 0; --k) { + next.addBuffer(buffer); + } + + i -= chunk; + + if (i <= nextYield) { + nextYield -= yieldAfter; + //noinspection CallToThreadYield + Thread.yield(); + } + + } + } + } + + private 
static class ConsumerThread extends CheckedThread { + + private final SingleInputGate gate; + private final int numBuffers; + + ConsumerThread(SingleInputGate gate, int numBuffers) { + this.gate = gate; + this.numBuffers = numBuffers; + } + + @Override + public void go() throws Exception { + for (int i = numBuffers; i > 0; --i) { + assertNotNull(gate.getNextBufferOrEvent()); + } + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/5ebd7c84/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java new file mode 100644 index 0000000..192b0eb --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.io.network.partition; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.io.network.ConnectionID; +import org.apache.flink.runtime.io.network.ConnectionManager; +import org.apache.flink.runtime.io.network.TaskEventDispatcher; +import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; +import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.netty.PartitionStateChecker; +import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; +import org.apache.flink.runtime.io.network.partition.consumer.InputChannel; +import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel; +import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.metrics.groups.IOMetricGroup; +import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; +import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup.DummyIOMetricGroup; +import org.junit.Test; +import scala.Tuple2; + +import java.io.IOException; +import java.lang.reflect.Field; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; + +import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createDummyConnectionManager; +import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createMockBuffer; +import static 
org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createResultPartitionManager; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; + +public class InputGateFairnessTest { + + @Test + public void testFairConsumptionLocalChannelsPreFilled() throws Exception { + final int numChannels = 37; + final int buffersPerChannel = 27; + + final ResultPartition resultPartition = mock(ResultPartition.class); + final Buffer mockBuffer = createMockBuffer(42); + + // ----- create some source channels and fill them with buffers ----- + + final PipelinedSubpartition[] sources = new PipelinedSubpartition[numChannels]; + + for (int i = 0; i < numChannels; i++) { + PipelinedSubpartition partition = new PipelinedSubpartition(0, resultPartition); + + for (int p = 0; p < buffersPerChannel; p++) { + partition.add(mockBuffer); + } + + partition.finish(); + sources[i] = partition; + } + + // ----- create reading side ----- + + ResultPartitionManager resultPartitionManager = createResultPartitionManager(sources); + + SingleInputGate gate = new FairnessVerifyingInputGate( + "Test Task Name", + new JobID(), + new ExecutionAttemptID(), + new IntermediateDataSetID(), + 0, numChannels, + mock(PartitionStateChecker.class), + new UnregisteredTaskMetricsGroup.DummyIOMetricGroup()); + + for (int i = 0; i < numChannels; i++) { + LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(), + resultPartitionManager, mock(TaskEventDispatcher.class), new DummyIOMetricGroup()); + gate.setInputChannel(new IntermediateResultPartitionID(), channel); + } + + // read all the buffers and the EOF event + for (int i = numChannels * (buffersPerChannel + 1); i > 0; --i) { + assertNotNull(gate.getNextBufferOrEvent()); + + int min = Integer.MAX_VALUE; + int max = 0; + + for (PipelinedSubpartition source : sources) { + 
int size = source.getCurrentNumberOfBuffers(); + min = Math.min(min, size); + max = Math.max(max, size); + } + + assertTrue(max == min || max == min+1); + } + + assertNull(gate.getNextBufferOrEvent()); + } + + @Test + public void testFairConsumptionLocalChannels() throws Exception { + final int numChannels = 37; + final int buffersPerChannel = 27; + + final ResultPartition resultPartition = mock(ResultPartition.class); + final Buffer mockBuffer = createMockBuffer(42); + + // ----- create some source channels; they start empty and are filled during the read loop ----- + + final PipelinedSubpartition[] sources = new PipelinedSubpartition[numChannels]; + + for (int i = 0; i < numChannels; i++) { + sources[i] = new PipelinedSubpartition(0, resultPartition); + } + + // ----- create reading side ----- + + ResultPartitionManager resultPartitionManager = createResultPartitionManager(sources); + + SingleInputGate gate = new FairnessVerifyingInputGate( + "Test Task Name", + new JobID(), + new ExecutionAttemptID(), + new IntermediateDataSetID(), + 0, numChannels, + mock(PartitionStateChecker.class), + new UnregisteredTaskMetricsGroup.DummyIOMetricGroup()); + + for (int i = 0; i < numChannels; i++) { + LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(), + resultPartitionManager, mock(TaskEventDispatcher.class), new DummyIOMetricGroup()); + gate.setInputChannel(new IntermediateResultPartitionID(), channel); + } + + // seed one initial buffer + sources[12].add(mockBuffer); + + // read buffers, periodically refilling all channels (no EOF event is produced in this test) + for (int i = 0; i < numChannels * buffersPerChannel; i++) { + assertNotNull(gate.getNextBufferOrEvent()); + + int min = Integer.MAX_VALUE; + int max = 0; + + for (PipelinedSubpartition source : sources) { + int size = source.getCurrentNumberOfBuffers(); + min = Math.min(min, size); + max = Math.max(max, size); + } + + assertTrue(max == min || max == min+1); + + if (i % (2 * numChannels) == 0) { + // add three buffers to each channel, in random order
+ fillRandom(sources, 3, mockBuffer); + } + } + + // there is still more in the queues + } + + @Test + public void testFairConsumptionRemoteChannelsPreFilled() throws Exception { + final int numChannels = 37; + final int buffersPerChannel = 27; + + final Buffer mockBuffer = createMockBuffer(42); + + // ----- create some source channels and fill them with buffers ----- + + SingleInputGate gate = new FairnessVerifyingInputGate( + "Test Task Name", + new JobID(), + new ExecutionAttemptID(), + new IntermediateDataSetID(), + 0, numChannels, + mock(PartitionStateChecker.class), + new UnregisteredTaskMetricsGroup.DummyIOMetricGroup()); + + final ConnectionManager connManager = createDummyConnectionManager(); + + final RemoteInputChannel[] channels = new RemoteInputChannel[numChannels]; + + for (int i = 0; i < numChannels; i++) { + RemoteInputChannel channel = new RemoteInputChannel( + gate, i, new ResultPartitionID(), mock(ConnectionID.class), + connManager, new Tuple2<>(0, 0), new DummyIOMetricGroup()); + + channels[i] = channel; + + for (int p = 0; p < buffersPerChannel; p++) { + channel.onBuffer(mockBuffer, p); + } + channel.onBuffer(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), buffersPerChannel); + + gate.setInputChannel(new IntermediateResultPartitionID(), channel); + } + + // read all the buffers and the EOF event + for (int i = numChannels * (buffersPerChannel + 1); i > 0; --i) { + assertNotNull(gate.getNextBufferOrEvent()); + + int min = Integer.MAX_VALUE; + int max = 0; + + for (RemoteInputChannel channel : channels) { + int size = channel.getNumberOfQueuedBuffers(); + min = Math.min(min, size); + max = Math.max(max, size); + } + + assertTrue(max == min || max == min+1); + } + + assertNull(gate.getNextBufferOrEvent()); + } + + @Test + public void testFairConsumptionRemoteChannels() throws Exception { + final int numChannels = 37; + final int buffersPerChannel = 27; + + final Buffer mockBuffer = createMockBuffer(42); + + // ----- create some source 
channels; they start empty and are filled during the read loop ----- + + SingleInputGate gate = new FairnessVerifyingInputGate( + "Test Task Name", + new JobID(), + new ExecutionAttemptID(), + new IntermediateDataSetID(), + 0, numChannels, + mock(PartitionStateChecker.class), + new UnregisteredTaskMetricsGroup.DummyIOMetricGroup()); + + final ConnectionManager connManager = createDummyConnectionManager(); + + final RemoteInputChannel[] channels = new RemoteInputChannel[numChannels]; + final int[] channelSequenceNums = new int[numChannels]; + + for (int i = 0; i < numChannels; i++) { + RemoteInputChannel channel = new RemoteInputChannel( + gate, i, new ResultPartitionID(), mock(ConnectionID.class), + connManager, new Tuple2<>(0, 0), new DummyIOMetricGroup()); + + channels[i] = channel; + gate.setInputChannel(new IntermediateResultPartitionID(), channel); + } + + channels[11].onBuffer(mockBuffer, 0); + channelSequenceNums[11]++; + + // read buffers, periodically refilling all channels (no EOF event is sent in this test) + for (int i = 0; i < numChannels * buffersPerChannel; i++) { + assertNotNull(gate.getNextBufferOrEvent()); + + int min = Integer.MAX_VALUE; + int max = 0; + + for (RemoteInputChannel channel : channels) { + int size = channel.getNumberOfQueuedBuffers(); + min = Math.min(min, size); + max = Math.max(max, size); + } + + assertTrue(max == min || max == min+1); + + if (i % (2 * numChannels) == 0) { + // add three buffers to each channel, in random order + fillRandom(channels, channelSequenceNums, 3, mockBuffer); + } + } + } + + // ------------------------------------------------------------------------ + // Utilities + // ------------------------------------------------------------------------ + + private void fillRandom(PipelinedSubpartition[] partitions, int numPerPartition, Buffer buffer) throws Exception { + ArrayList<Integer> poss = new ArrayList<>(partitions.length * numPerPartition); + + for (int i = 0; i < partitions.length; i++) { + for (int k = 0; k < numPerPartition; k++) { + poss.add(i); + } + } + +
Collections.shuffle(poss); + + for (Integer i : poss) { + partitions[i].add(buffer); + } + } + + private void fillRandom( + RemoteInputChannel[] partitions, + int[] sequenceNumbers, + int numPerPartition, + Buffer buffer) throws Exception { + + ArrayList<Integer> poss = new ArrayList<>(partitions.length * numPerPartition); + + for (int i = 0; i < partitions.length; i++) { + for (int k = 0; k < numPerPartition; k++) { + poss.add(i); + } + } + + Collections.shuffle(poss); + + for (int i : poss) { + partitions[i].onBuffer(buffer, sequenceNumbers[i]++); + } + } + + // ------------------------------------------------------------------------ + + private static class FairnessVerifyingInputGate extends SingleInputGate { + + private final ArrayDeque<InputChannel> channelsWithData; + + private final HashSet<InputChannel> uniquenessChecker; + + @SuppressWarnings("unchecked") + public FairnessVerifyingInputGate( + String owningTaskName, + JobID jobId, + ExecutionAttemptID executionId, + IntermediateDataSetID consumedResultId, + int consumedSubpartitionIndex, + int numberOfInputChannels, + PartitionStateChecker partitionStateChecker, + IOMetricGroup metrics) { + + super(owningTaskName, jobId, executionId, consumedResultId, consumedSubpartitionIndex, + numberOfInputChannels, partitionStateChecker, metrics); + + try { + Field f = SingleInputGate.class.getDeclaredField("inputChannelsWithData"); + f.setAccessible(true); + channelsWithData = (ArrayDeque<InputChannel>) f.get(this); + } + catch (Exception e) { + throw new RuntimeException(e); + } + + this.uniquenessChecker = new HashSet<>(); + } + + + @Override + public BufferOrEvent getNextBufferOrEvent() throws IOException, InterruptedException { + synchronized (channelsWithData) { + assertTrue("too many input channels", channelsWithData.size() <= getNumberOfInputChannels()); + ensureUnique(channelsWithData); + } + + return super.getNextBufferOrEvent(); + } + + private void ensureUnique(Collection<InputChannel> channels) { + 
HashSet<InputChannel> uniquenessChecker = this.uniquenessChecker; + + for (InputChannel channel : channels) { + if (!uniquenessChecker.add(channel)) { + fail("Duplicate channel in input gate: " + channel); + } + } + + assertTrue("found duplicate input channels", uniquenessChecker.size() == channels.size()); + uniquenessChecker.clear(); + } + } +}
