This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch release-1.11
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 75c0b5b1f4e0545ee4c55349dd6633bbd13cf128
Author: Arvid Heise <[email protected]>
AuthorDate: Tue Jun 16 09:21:41 2020 +0200

    [FLINK-18094][network] Fixed UnionInputGate#getChannel.
    
    The method assumed that the gates have consecutive indexes starting at 0.
---
 .../io/network/partition/consumer/UnionInputGate.java | 19 +++++++++++--------
 .../network/partition/consumer/InputGateTestBase.java |  2 +-
 .../partition/consumer/UnionInputGateTest.java        | 18 ++++++++++++++++++
 3 files changed, 30 insertions(+), 9 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
index e863e10..ad8361c 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
@@ -29,6 +29,7 @@ import java.io.IOException;
 import java.util.Arrays;
 import java.util.Iterator;
 import java.util.LinkedHashSet;
+import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
@@ -70,7 +71,7 @@ import static org.apache.flink.util.Preconditions.checkState;
 public class UnionInputGate extends InputGate {
 
        /** The input gates to union. */
-       private final InputGate[] inputGates;
+       private final Map<Integer, InputGate> inputGatesByGateIndex;
 
        private final Set<IndexedInputGate> inputGatesWithRemainingData;
 
@@ -89,7 +90,7 @@ public class UnionInputGate extends InputGate {
        private final int[] inputGateChannelIndexOffsets;
 
        public UnionInputGate(IndexedInputGate... inputGates) {
-               this.inputGates = checkNotNull(inputGates);
+               inputGatesByGateIndex = 
Arrays.stream(inputGates).collect(Collectors.toMap(IndexedInputGate::getGateIndex,
 ig -> ig));
                checkArgument(inputGates.length > 1, "Union input gate should 
union at least two input gates.");
 
                if 
(Arrays.stream(inputGates).map(IndexedInputGate::getGateIndex).distinct().count()
 != inputGates.length) {
@@ -100,8 +101,9 @@ public class UnionInputGate extends InputGate {
                this.inputGatesWithRemainingData = 
Sets.newHashSetWithExpectedSize(inputGates.length);
 
                final int maxGateIndex = 
Arrays.stream(inputGates).mapToInt(IndexedInputGate::getGateIndex).max().orElse(0);
-               inputGateChannelIndexOffsets = new int[maxGateIndex + 1];
                int totalNumberOfInputChannels = 
Arrays.stream(inputGates).mapToInt(IndexedInputGate::getNumberOfInputChannels).sum();
+
+               inputGateChannelIndexOffsets = new int[maxGateIndex + 1];
                inputChannelToInputGateIndex = new 
int[totalNumberOfInputChannels];
 
                int currentNumberOfInputChannels = 0;
@@ -141,8 +143,9 @@ public class UnionInputGate extends InputGate {
 
        @Override
        public InputChannel getChannel(int channelIndex) {
-               int gateIndex = this.inputChannelToInputGateIndex[channelIndex];
-               return inputGates[gateIndex].getChannel(channelIndex - 
inputGateChannelIndexOffsets[gateIndex]);
+               int gateIndex = inputChannelToInputGateIndex[channelIndex];
+               return inputGatesByGateIndex.get(gateIndex)
+                       .getChannel(channelIndex - 
inputGateChannelIndexOffsets[gateIndex]);
        }
 
        @Override
@@ -253,7 +256,7 @@ public class UnionInputGate extends InputGate {
 
        @Override
        public void sendTaskEvent(TaskEvent event) throws IOException {
-               for (InputGate inputGate : inputGates) {
+               for (InputGate inputGate : inputGatesByGateIndex.values()) {
                        inputGate.sendTaskEvent(event);
                }
        }
@@ -277,7 +280,7 @@ public class UnionInputGate extends InputGate {
 
        @Override
        public void requestPartitions() throws IOException {
-               for (InputGate inputGate : inputGates) {
+               for (InputGate inputGate : inputGatesByGateIndex.values()) {
                        inputGate.requestPartitions();
                }
        }
@@ -332,7 +335,7 @@ public class UnionInputGate extends InputGate {
 
        @Override
        public void registerBufferReceivedListener(BufferReceivedListener 
listener) {
-               for (InputGate inputGate : inputGates) {
+               for (InputGate inputGate : inputGatesByGateIndex.values()) {
                        inputGate.registerBufferReceivedListener(listener);
                }
        }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateTestBase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateTestBase.java
index 14f42f4..ae52a3c 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateTestBase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateTestBase.java
@@ -35,7 +35,7 @@ import static org.junit.Assert.assertTrue;
  */
 public abstract class InputGateTestBase {
 
-       private int gateIndex;
+       int gateIndex;
 
        @Before
        public void resetGateIndex() {
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
index 8b4a7df..93b131e 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
@@ -164,4 +164,22 @@ public class UnionInputGateTest extends InputGateTestBase {
                // Check that updated input channel is visible via 
UnionInputGate
                assertThat(unionInputGate.getChannel(1), 
Matchers.is(inputGate2.getChannel(0)));
        }
+
+       @Test
+       public void testGetChannelWithShiftedGateIndexes() {
+               gateIndex = 2;
+               final SingleInputGate inputGate1 = createInputGate(1);
+               TestInputChannel inputChannel1 = new 
TestInputChannel(inputGate1, 0);
+               inputGate1.setInputChannels(inputChannel1);
+
+               final SingleInputGate inputGate2 = createInputGate(1);
+               TestInputChannel inputChannel2 = new 
TestInputChannel(inputGate2, 0);
+               inputGate2.setInputChannels(inputChannel2);
+
+               UnionInputGate unionInputGate = new UnionInputGate(inputGate1, 
inputGate2);
+
+               assertThat(unionInputGate.getChannel(0), 
Matchers.is(inputChannel1));
+               // Check that the second gate's channel is also resolved when the 
gate indexes do not start at 0
+               assertThat(unionInputGate.getChannel(1), 
Matchers.is(inputChannel2));
+       }
 }

Reply via email to