This is an automated email from the ASF dual-hosted git repository. pnowojski pushed a commit to branch release-1.11 in repository https://gitbox.apache.org/repos/asf/flink.git
commit 75c0b5b1f4e0545ee4c55349dd6633bbd13cf128 Author: Arvid Heise <[email protected]> AuthorDate: Tue Jun 16 09:21:41 2020 +0200 [FLINK-18094][network] Fixed UnionInputGate#getChannel. The method assumed that the gates have consecutive indexes starting at 0. --- .../io/network/partition/consumer/UnionInputGate.java | 19 +++++++++++-------- .../network/partition/consumer/InputGateTestBase.java | 2 +- .../partition/consumer/UnionInputGateTest.java | 18 ++++++++++++++++++ 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java index e863e10..ad8361c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java @@ -29,6 +29,7 @@ import java.io.IOException; import java.util.Arrays; import java.util.Iterator; import java.util.LinkedHashSet; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.CompletableFuture; @@ -70,7 +71,7 @@ import static org.apache.flink.util.Preconditions.checkState; public class UnionInputGate extends InputGate { /** The input gates to union. */ - private final InputGate[] inputGates; + private final Map<Integer, InputGate> inputGatesByGateIndex; private final Set<IndexedInputGate> inputGatesWithRemainingData; @@ -89,7 +90,7 @@ public class UnionInputGate extends InputGate { private final int[] inputGateChannelIndexOffsets; public UnionInputGate(IndexedInputGate... inputGates) { - this.inputGates = checkNotNull(inputGates); + inputGatesByGateIndex = Arrays.stream(inputGates).collect(Collectors.toMap(IndexedInputGate::getGateIndex, ig -> ig)); checkArgument(inputGates.length > 1, "Union input gate should union at least two input gates."); if (Arrays.stream(inputGates).map(IndexedInputGate::getGateIndex).distinct().count() != inputGates.length) { @@ -100,8 +101,9 @@ public class UnionInputGate extends InputGate { this.inputGatesWithRemainingData = Sets.newHashSetWithExpectedSize(inputGates.length); final int maxGateIndex = Arrays.stream(inputGates).mapToInt(IndexedInputGate::getGateIndex).max().orElse(0); - inputGateChannelIndexOffsets = new int[maxGateIndex + 1]; int totalNumberOfInputChannels = Arrays.stream(inputGates).mapToInt(IndexedInputGate::getNumberOfInputChannels).sum(); + + inputGateChannelIndexOffsets = new int[maxGateIndex + 1]; inputChannelToInputGateIndex = new int[totalNumberOfInputChannels]; int currentNumberOfInputChannels = 0; @@ -141,8 +143,9 @@ public class UnionInputGate extends InputGate { @Override public InputChannel getChannel(int channelIndex) { - int gateIndex = this.inputChannelToInputGateIndex[channelIndex]; - return inputGates[gateIndex].getChannel(channelIndex - inputGateChannelIndexOffsets[gateIndex]); + int gateIndex = inputChannelToInputGateIndex[channelIndex]; + return inputGatesByGateIndex.get(gateIndex) + .getChannel(channelIndex - inputGateChannelIndexOffsets[gateIndex]); } @Override @@ -253,7 +256,7 @@ public class UnionInputGate extends InputGate { @Override public void sendTaskEvent(TaskEvent event) throws IOException { - for (InputGate inputGate : inputGates) { + for (InputGate inputGate : inputGatesByGateIndex.values()) { inputGate.sendTaskEvent(event); } } @@ -277,7 +280,7 @@ public class UnionInputGate extends InputGate { @Override public void requestPartitions() throws IOException { - for (InputGate inputGate : inputGates) { + for (InputGate inputGate : inputGatesByGateIndex.values()) { inputGate.requestPartitions(); } } @@ -332,7 +335,7 @@ public class UnionInputGate extends InputGate { @Override public void registerBufferReceivedListener(BufferReceivedListener listener) { - for (InputGate inputGate : inputGates) { + for (InputGate inputGate : inputGatesByGateIndex.values()) { inputGate.registerBufferReceivedListener(listener); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateTestBase.java index 14f42f4..ae52a3c 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateTestBase.java @@ -35,7 +35,7 @@ import static org.junit.Assert.assertTrue; */ public abstract class InputGateTestBase { - private int gateIndex; + int gateIndex; @Before public void resetGateIndex() { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java index 8b4a7df..93b131e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java @@ -164,4 +164,22 @@ public class UnionInputGateTest extends InputGateTestBase { // Check that updated input channel is visible via UnionInputGate assertThat(unionInputGate.getChannel(1), Matchers.is(inputGate2.getChannel(0))); } + + @Test + public void testGetChannelWithShiftedGateIndexes() { + gateIndex = 2; + final SingleInputGate inputGate1 = createInputGate(1); + TestInputChannel inputChannel1 = new TestInputChannel(inputGate1, 0); + inputGate1.setInputChannels(inputChannel1); + + final SingleInputGate inputGate2 = createInputGate(1); + TestInputChannel inputChannel2 = new TestInputChannel(inputGate2, 0); + inputGate2.setInputChannels(inputChannel2); + + UnionInputGate unionInputGate = new UnionInputGate(inputGate1, inputGate2); + + assertThat(unionInputGate.getChannel(0), Matchers.is(inputChannel1)); + // Check that updated input channel is visible via UnionInputGate + assertThat(unionInputGate.getChannel(1), Matchers.is(inputChannel2)); + } }
