This is an automated email from the ASF dual-hosted git repository. pnowojski pushed a commit to branch release-1.11 in repository https://gitbox.apache.org/repos/asf/flink.git
commit 1d211b93c4d7a180f3bb5ab131929201af29b8dc Author: Arvid Heise <[email protected]> AuthorDate: Mon Jun 15 21:43:32 2020 +0200 [FLINK-18094][network] Buffers are only addressed through InputChannelInfo. This removes the need to translate the InputChannelInfo back and forth to flattened indexes across all InputGates. All index-based data structures are replaced by maps that associate a certain state to a given InputChannelInfo. For performance reasons, these maps are fully initialized upon construction, such that no nodes need to be added/removed during runtime and only values are updated. Additionally, this commit unifies the creation of BarrierHandlers (similar signature) and removes the error-prone offset handling from CheckpointedInputGate. --- .../network/api/reader/AbstractRecordReader.java | 22 +++--- .../network/partition/consumer/BufferOrEvent.java | 33 ++++---- .../partition/consumer/SingleInputGate.java | 4 +- .../network/partition/consumer/UnionInputGate.java | 22 ++---- .../io/network/api/writer/RecordWriterTest.java | 5 +- .../partition/consumer/LocalInputChannelTest.java | 4 +- .../partition/consumer/SingleInputGateBuilder.java | 24 +++++- .../partition/consumer/SingleInputGateTest.java | 19 +++-- .../io/AlternatingCheckpointBarrierHandler.java | 8 +- .../runtime/io/CheckpointBarrierAligner.java | 81 ++++++++++--------- .../runtime/io/CheckpointBarrierHandler.java | 6 +- .../runtime/io/CheckpointBarrierTracker.java | 5 +- .../runtime/io/CheckpointBarrierUnaligner.java | 80 +++++++------------ .../runtime/io/CheckpointedInputGate.java | 43 +--------- .../streaming/runtime/io/InputProcessorUtil.java | 91 +++++----------------- .../runtime/io/StreamTaskNetworkInput.java | 25 +++++- .../AlternatingCheckpointBarrierHandlerTest.java | 45 ++++++----- .../CheckpointBarrierAlignerMassiveRandomTest.java | 18 ++++- .../io/CheckpointBarrierAlignerTestBase.java | 12 +-- .../runtime/io/CheckpointBarrierTrackerTest.java | 7 +- 
...CheckpointBarrierUnalignerCancellationTest.java | 5 +- .../runtime/io/CheckpointBarrierUnalignerTest.java | 50 ++++++------ .../CreditBasedCheckpointBarrierAlignerTest.java | 2 +- .../runtime/io/InputProcessorUtilTest.java | 31 -------- .../flink/streaming/runtime/io/MockInputGate.java | 2 +- .../runtime/io/StreamTaskNetworkInputTest.java | 27 ++++--- 26 files changed, 287 insertions(+), 384 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractRecordReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractRecordReader.java index 1c98d0c..5632370 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractRecordReader.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractRecordReader.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.io.network.api.reader; import org.apache.flink.core.io.IOReadableWritable; +import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult; import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer; @@ -27,6 +28,9 @@ import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import java.io.IOException; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; /** * A record-oriented reader. 
@@ -37,7 +41,7 @@ import java.io.IOException; */ abstract class AbstractRecordReader<T extends IOReadableWritable> extends AbstractReader implements ReaderBase { - private final RecordDeserializer<T>[] recordDeserializers; + private final Map<InputChannelInfo, RecordDeserializer<T>> recordDeserializers; private RecordDeserializer<T> currentRecordDeserializer; @@ -58,10 +62,10 @@ abstract class AbstractRecordReader<T extends IOReadableWritable> extends Abstra super(inputGate); // Initialize one deserializer per input channel - this.recordDeserializers = new SpillingAdaptiveSpanningRecordDeserializer[inputGate.getNumberOfInputChannels()]; - for (int i = 0; i < recordDeserializers.length; i++) { - recordDeserializers[i] = new SpillingAdaptiveSpanningRecordDeserializer<T>(tmpDirectories); - } + recordDeserializers = inputGate.getChannelInfos().stream() + .collect(Collectors.toMap( + Function.identity(), + channelInfo -> new SpillingAdaptiveSpanningRecordDeserializer<>(tmpDirectories))); } protected boolean getNextRecord(T target) throws IOException, InterruptedException { @@ -96,15 +100,15 @@ abstract class AbstractRecordReader<T extends IOReadableWritable> extends Abstra final BufferOrEvent bufferOrEvent = inputGate.getNext().orElseThrow(IllegalStateException::new); if (bufferOrEvent.isBuffer()) { - currentRecordDeserializer = recordDeserializers[bufferOrEvent.getChannelIndex()]; + currentRecordDeserializer = recordDeserializers.get(bufferOrEvent.getChannelInfo()); currentRecordDeserializer.setNextBuffer(bufferOrEvent.getBuffer()); } else { // sanity check for leftover data in deserializers. 
events should only come between // records, not in the middle of a fragment - if (recordDeserializers[bufferOrEvent.getChannelIndex()].hasUnfinishedData()) { + if (recordDeserializers.get(bufferOrEvent.getChannelInfo()).hasUnfinishedData()) { throw new IOException( - "Received an event in channel " + bufferOrEvent.getChannelIndex() + " while still having " + "Received an event in channel " + bufferOrEvent.getChannelInfo() + " while still having " + "data from a record. This indicates broken serialization logic. " + "If you are using custom serialization code (Writable or Value types), check their " + "serialization routines. In the case of Kryo, check the respective Kryo serializer."); @@ -125,7 +129,7 @@ abstract class AbstractRecordReader<T extends IOReadableWritable> extends Abstra } public void clearBuffers() { - for (RecordDeserializer<?> deserializer : recordDeserializers) { + for (RecordDeserializer<?> deserializer : recordDeserializers.values()) { Buffer buffer = deserializer.getCurrentBuffer(); if (buffer != null && !buffer.isRecycled()) { buffer.recycleBuffer(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferOrEvent.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferOrEvent.java index 1ec864d..498e338 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferOrEvent.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferOrEvent.java @@ -19,10 +19,10 @@ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; import org.apache.flink.runtime.event.AbstractEvent; import org.apache.flink.runtime.io.network.buffer.Buffer; -import static org.apache.flink.util.Preconditions.checkArgument; import static 
org.apache.flink.util.Preconditions.checkNotNull; /** @@ -42,34 +42,34 @@ public class BufferOrEvent { */ private boolean moreAvailable; - private int channelIndex; + private InputChannelInfo channelInfo; private final int size; - public BufferOrEvent(Buffer buffer, int channelIndex, boolean moreAvailable) { + public BufferOrEvent(Buffer buffer, InputChannelInfo channelInfo, boolean moreAvailable) { this.buffer = checkNotNull(buffer); this.event = null; - this.channelIndex = channelIndex; + this.channelInfo = channelInfo; this.moreAvailable = moreAvailable; this.size = buffer.getSize(); } - public BufferOrEvent(AbstractEvent event, int channelIndex, boolean moreAvailable, int size) { + public BufferOrEvent(AbstractEvent event, InputChannelInfo channelInfo, boolean moreAvailable, int size) { this.buffer = null; this.event = checkNotNull(event); - this.channelIndex = channelIndex; + this.channelInfo = channelInfo; this.moreAvailable = moreAvailable; this.size = size; } @VisibleForTesting - public BufferOrEvent(Buffer buffer, int channelIndex) { - this(buffer, channelIndex, true); + public BufferOrEvent(Buffer buffer, InputChannelInfo channelInfo) { + this(buffer, channelInfo, true); } @VisibleForTesting - public BufferOrEvent(AbstractEvent event, int channelIndex) { - this(event, channelIndex, true, 0); + public BufferOrEvent(AbstractEvent event, InputChannelInfo channelInfo) { + this(event, channelInfo, true, 0); } public boolean isBuffer() { @@ -88,13 +88,12 @@ public class BufferOrEvent { return event; } - public int getChannelIndex() { - return channelIndex; + public InputChannelInfo getChannelInfo() { + return channelInfo; } - public void setChannelIndex(int channelIndex) { - checkArgument(channelIndex >= 0); - this.channelIndex = channelIndex; + public void setChannelInfo(InputChannelInfo channelInfo) { + this.channelInfo = channelInfo; } public boolean moreAvailable() { @@ -103,8 +102,8 @@ public class BufferOrEvent { @Override public String toString() { - 
return String.format("BufferOrEvent [%s, channelIndex = %d, size = %d]", - isBuffer() ? buffer : event, channelIndex, size); + return String.format("BufferOrEvent [%s, channelInfo = %s, size = %d]", + isBuffer() ? buffer : event, channelInfo, size); } public void setMoreAvailable(boolean moreAvailable) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java index 0f227f9..0bd06c0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java @@ -669,7 +669,7 @@ public class SingleInputGate extends IndexedInputGate { } private BufferOrEvent transformBuffer(Buffer buffer, boolean moreAvailable, InputChannel currentChannel) { - return new BufferOrEvent(decompressBufferIfNeeded(buffer), currentChannel.getChannelIndex(), moreAvailable); + return new BufferOrEvent(decompressBufferIfNeeded(buffer), currentChannel.getChannelInfo(), moreAvailable); } private BufferOrEvent transformEvent( @@ -700,7 +700,7 @@ public class SingleInputGate extends IndexedInputGate { currentChannel.releaseAllResources(); } - return new BufferOrEvent(event, currentChannel.getChannelIndex(), moreAvailable, buffer.getSize()); + return new BufferOrEvent(event, currentChannel.getChannelInfo(), moreAvailable, buffer.getSize()); } private Buffer decompressBufferIfNeeded(Buffer buffer) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java index ad8361c..c05eef7 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java +++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java @@ -176,10 +176,11 @@ public class UnionInputGate extends InputGate { InputWithData<IndexedInputGate, BufferOrEvent> inputWithData = next.get(); handleEndOfPartitionEvent(inputWithData.data, inputWithData.input); - return Optional.of(adjustForUnionInputGate( - inputWithData.data, - inputWithData.input, - inputWithData.moreAvailable)); + if (!inputWithData.data.moreAvailable()) { + inputWithData.data.setMoreAvailable(inputWithData.moreAvailable); + } + + return Optional.of(inputWithData.data); } private Optional<InputWithData<IndexedInputGate, BufferOrEvent>> waitAndGetNextData(boolean blocking) @@ -217,19 +218,6 @@ public class UnionInputGate extends InputGate { } } - private BufferOrEvent adjustForUnionInputGate( - BufferOrEvent bufferOrEvent, - IndexedInputGate inputGate, - boolean moreInputGatesAvailable) { - // Set the channel index to identify the input channel (across all unioned input gates) - final int channelIndexOffset = inputGateChannelIndexOffsets[inputGate.getGateIndex()]; - - bufferOrEvent.setChannelIndex(channelIndexOffset + bufferOrEvent.getChannelIndex()); - bufferOrEvent.setMoreAvailable(bufferOrEvent.moreAvailable() || moreInputGatesAvailable); - - return bufferOrEvent; - } - private void handleEndOfPartitionEvent(BufferOrEvent bufferOrEvent, InputGate inputGate) { if (bufferOrEvent.isEvent() && bufferOrEvent.getEvent().getClass() == EndOfPartitionEvent.class diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java index 6e75dff..caa934a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java @@ -25,6 +25,7 @@ import 
org.apache.flink.core.memory.DataOutputView; import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader; +import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; import org.apache.flink.runtime.event.AbstractEvent; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; @@ -657,12 +658,12 @@ public class RecordWriterTest { static BufferOrEvent parseBuffer(BufferConsumer bufferConsumer, int targetChannel) throws IOException { Buffer buffer = buildSingleBuffer(bufferConsumer); if (buffer.isBuffer()) { - return new BufferOrEvent(buffer, targetChannel); + return new BufferOrEvent(buffer, new InputChannelInfo(0, targetChannel)); } else { // is event: AbstractEvent event = EventSerializer.fromBuffer(buffer, RecordWriterTest.class.getClassLoader()); buffer.recycleBuffer(); // the buffer is not needed anymore - return new BufferOrEvent(event, targetChannel); + return new BufferOrEvent(event, new InputChannelInfo(0, targetChannel)); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java index e9e16db..b6c0f48 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java @@ -629,11 +629,11 @@ public class LocalInputChannelTest { boe.get().getBuffer().recycleBuffer(); // Check that we don't receive too many buffers - if (++numberOfBuffersPerChannel[boe.get().getChannelIndex()] + if (++numberOfBuffersPerChannel[boe.get().getChannelInfo().getInputChannelIdx()] > numberOfExpectedBuffersPerChannel) { throw new 
IllegalStateException("Received more buffers than expected " + - "on channel " + boe.get().getChannelIndex() + "."); + "on channel " + boe.get().getChannelInfo() + "."); } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java index b279998..8ae12cfc 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java @@ -29,7 +29,11 @@ import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.taskmanager.NettyShuffleEnvironmentConfiguration; import org.apache.flink.util.function.SupplierWithException; +import javax.annotation.Nullable; + import java.io.IOException; +import java.util.function.BiFunction; +import java.util.stream.IntStream; /** * Utility class to encapsulate the logic of building a {@link SingleInputGate} instance. @@ -54,6 +58,9 @@ public class SingleInputGateBuilder { private MemorySegmentProvider segmentProvider = InputChannelTestUtils.StubMemorySegmentProvider.getInstance(); + @Nullable + private BiFunction<InputChannelBuilder, SingleInputGate, InputChannel> channelFactory = null; + private SupplierWithException<BufferPool, IOException> bufferPoolFactory = () -> { throw new UnsupportedOperationException(); }; @@ -112,8 +119,17 @@ public class SingleInputGateBuilder { return this; } + /** + * Adds automatic initialization of all channels with the given factory. 
+ */ + public SingleInputGateBuilder setChannelFactory( + BiFunction<InputChannelBuilder, SingleInputGate, InputChannel> channelFactory) { + this.channelFactory = channelFactory; + return this; + } + public SingleInputGate build() { - return new SingleInputGate( + SingleInputGate gate = new SingleInputGate( "Single Input Gate", gateIndex, intermediateDataSetID, @@ -124,5 +140,11 @@ public class SingleInputGateBuilder { bufferPoolFactory, bufferDecompressor, segmentProvider); + if (channelFactory != null) { + gate.setInputChannels(IntStream.range(0, numberOfChannels) + .mapToObj(index -> channelFactory.apply(InputChannelBuilder.newBuilder().setChannelIndex(index), gate)) + .toArray(InputChannel[]::new)); + } + return gate; } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java index 0bd9a41..ea3f668 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java @@ -821,14 +821,13 @@ public class SingleInputGateTest extends InputGateTestBase { // Setup final SingleInputGate inputGate = createInputGate(network, 2, ResultPartitionType.PIPELINED); - final int channelIndex1 = 0, channelIndex2 = 1; final RemoteInputChannel remoteInputChannel1 = InputChannelBuilder.newBuilder() - .setChannelIndex(channelIndex1) + .setChannelIndex(0) .setupFromNettyShuffleEnvironment(network) .setConnectionManager(new TestingConnectionManager()) .buildRemoteChannel(inputGate); final RemoteInputChannel remoteInputChannel2 = InputChannelBuilder.newBuilder() - .setChannelIndex(channelIndex2) + .setChannelIndex(1) .setupFromNettyShuffleEnvironment(network) .setConnectionManager(new TestingConnectionManager()) .buildRemoteChannel(inputGate); @@ 
-838,12 +837,12 @@ public class SingleInputGateTest extends InputGateTestBase { inputGate.registerBufferReceivedListener(new BufferReceivedListener() { @Override public void notifyBufferReceived(Buffer buffer, InputChannelInfo channelInfo) { - notifications.add(new BufferOrEvent(buffer, channelInfo.getInputChannelIdx())); + notifications.add(new BufferOrEvent(buffer, channelInfo)); } @Override public void notifyBarrierReceived(CheckpointBarrier barrier, InputChannelInfo channelInfo) { - notifications.add(new BufferOrEvent(barrier, channelInfo.getInputChannelIdx())); + notifications.add(new BufferOrEvent(barrier, channelInfo)); } }); setupInputGate(inputGate, remoteInputChannel1, remoteInputChannel2); @@ -873,10 +872,10 @@ public class SingleInputGateTest extends InputGateTestBase { } assertEquals(getIds(asList( - new BufferOrEvent(new CheckpointBarrier(0, 0, options), channelIndex2), - new BufferOrEvent(createBuffer(11), channelIndex1), - new BufferOrEvent(new CheckpointBarrier(1, 0, options), channelIndex1), - new BufferOrEvent(createBuffer(22), channelIndex2) + new BufferOrEvent(new CheckpointBarrier(0, 0, options), remoteInputChannel2.getChannelInfo()), + new BufferOrEvent(createBuffer(11), remoteInputChannel1.getChannelInfo()), + new BufferOrEvent(new CheckpointBarrier(1, 0, options), remoteInputChannel1.getChannelInfo()), + new BufferOrEvent(createBuffer(22), remoteInputChannel2.getChannelInfo()) )), getIds(notifications)); } @@ -1071,7 +1070,7 @@ public class SingleInputGateTest extends InputGateTestBase { final Optional<BufferOrEvent> bufferOrEvent = inputGate.getNext(); assertTrue(bufferOrEvent.isPresent()); assertEquals(expectedIsBuffer, bufferOrEvent.get().isBuffer()); - assertEquals(expectedChannelIndex, bufferOrEvent.get().getChannelIndex()); + assertEquals(inputGate.getChannel(expectedChannelIndex).getChannelInfo(), bufferOrEvent.get().getChannelInfo()); assertEquals(expectedMoreAvailable, bufferOrEvent.get().moreAvailable()); if 
(!expectedMoreAvailable) { assertFalse(inputGate.pollNext().isPresent()); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandler.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandler.java index 3b27d95..2fb6e72 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandler.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandler.java @@ -51,12 +51,12 @@ class AlternatingCheckpointBarrierHandler extends CheckpointBarrierHandler { } @Override - public boolean isBlocked(int channelIndex) { - return activeHandler.isBlocked(channelIndex); + public boolean isBlocked(InputChannelInfo channelInfo) { + return activeHandler.isBlocked(channelInfo); } @Override - public void processBarrier(CheckpointBarrier receivedBarrier, int channelIndex) throws Exception { + public void processBarrier(CheckpointBarrier receivedBarrier, InputChannelInfo channelInfo) throws Exception { if (receivedBarrier.getId() < lastSeenBarrierId) { return; } @@ -70,7 +70,7 @@ class AlternatingCheckpointBarrierHandler extends CheckpointBarrierHandler { new CheckpointException(format("checkpoint subsumed by %d", lastSeenBarrierId), CHECKPOINT_DECLINED_SUBSUMED)); } - activeHandler.processBarrier(receivedBarrier, channelIndex); + activeHandler.processBarrier(receivedBarrier, channelInfo); } @Override diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAligner.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAligner.java index 01892bc..a052a70 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAligner.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAligner.java 
@@ -22,6 +22,7 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.checkpoint.CheckpointFailureReason; +import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; @@ -31,9 +32,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; +import java.util.Arrays; import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; -import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; /** @@ -46,14 +49,8 @@ public class CheckpointBarrierAligner extends CheckpointBarrierHandler { private static final Logger LOG = LoggerFactory.getLogger(CheckpointBarrierAligner.class); - /** Used to get InputGate by channel index. */ - private final InputGate[] channelIndexToInputGate; - - /** Used to get channel index offset by InputGate. */ - private final Map<InputGate, Integer> inputGateToChannelIndexOffset; - /** Flags that indicate whether a channel is currently blocked/buffered. */ - private final boolean[] blockedChannels; + private final Map<InputChannelInfo, Boolean> blockedChannels; /** The total number of channels that this buffer handles data from. */ private final int totalNumberOfInputChannels; @@ -78,18 +75,20 @@ public class CheckpointBarrierAligner extends CheckpointBarrierHandler { /** The time (in nanoseconds) that the latest alignment took. 
*/ private long latestAlignmentDurationNanos; + private final InputGate[] inputGates; + CheckpointBarrierAligner( String taskName, - InputGate[] channelIndexToInputGate, - Map<InputGate, Integer> inputGateToChannelIndexOffset, - AbstractInvokable toNotifyOnCheckpoint) { + AbstractInvokable toNotifyOnCheckpoint, + InputGate... inputGates) { super(toNotifyOnCheckpoint); - this.taskName = taskName; - this.channelIndexToInputGate = checkNotNull(channelIndexToInputGate); - this.inputGateToChannelIndexOffset = checkNotNull(inputGateToChannelIndexOffset); - this.totalNumberOfInputChannels = channelIndexToInputGate.length; - this.blockedChannels = new boolean[totalNumberOfInputChannels]; + this.taskName = taskName; + this.inputGates = inputGates; + blockedChannels = Arrays.stream(inputGates) + .flatMap(gate -> gate.getChannelInfos().stream()) + .collect(Collectors.toMap(Function.identity(), info -> false)); + totalNumberOfInputChannels = blockedChannels.size(); } @Override @@ -104,12 +103,12 @@ public class CheckpointBarrierAligner extends CheckpointBarrierHandler { public void releaseBlocksAndResetBarriers() { LOG.debug("{}: End of stream alignment, feeding buffered data back.", taskName); - for (int i = 0; i < blockedChannels.length; i++) { - if (blockedChannels[i]) { - resumeConsumption(i); + blockedChannels.entrySet().forEach(blockedChannel -> { + if (blockedChannel.getValue()) { + resumeConsumption(blockedChannel.getKey()); } - blockedChannels[i] = false; - } + blockedChannel.setValue(false); + }); // the next barrier that comes must assume it is the first numBarriersReceived = 0; @@ -121,17 +120,17 @@ public class CheckpointBarrierAligner extends CheckpointBarrierHandler { } @Override - public boolean isBlocked(int channelIndex) { - return blockedChannels[channelIndex]; + public boolean isBlocked(InputChannelInfo channelInfo) { + return blockedChannels.get(channelInfo); } @Override - public void processBarrier(CheckpointBarrier receivedBarrier, int channelIndex) 
throws Exception { + public void processBarrier(CheckpointBarrier receivedBarrier, InputChannelInfo channelInfo) throws Exception { final long barrierId = receivedBarrier.getId(); // fast path for single channel cases if (totalNumberOfInputChannels == 1) { - resumeConsumption(channelIndex); + resumeConsumption(channelInfo); if (barrierId > currentCheckpointId) { // new checkpoint currentCheckpointId = barrierId; @@ -147,7 +146,7 @@ public class CheckpointBarrierAligner extends CheckpointBarrierHandler { if (barrierId == currentCheckpointId) { // regular case - onBarrier(channelIndex); + onBarrier(channelInfo); } else if (barrierId > currentCheckpointId) { // we did not complete the current checkpoint, another started before @@ -167,21 +166,21 @@ public class CheckpointBarrierAligner extends CheckpointBarrierHandler { releaseBlocksAndResetBarriers(); // begin a new checkpoint - beginNewAlignment(barrierId, channelIndex, receivedBarrier.getTimestamp()); + beginNewAlignment(barrierId, channelInfo, receivedBarrier.getTimestamp()); } else { // ignore trailing barrier from an earlier checkpoint (obsolete now) - resumeConsumption(channelIndex); + resumeConsumption(channelInfo); } } else if (barrierId > currentCheckpointId) { // first barrier of a new checkpoint - beginNewAlignment(barrierId, channelIndex, receivedBarrier.getTimestamp()); + beginNewAlignment(barrierId, channelInfo, receivedBarrier.getTimestamp()); } else { // either the current checkpoint was canceled (numBarriers == 0) or // this barrier is from an old subsumed checkpoint - resumeConsumption(channelIndex); + resumeConsumption(channelInfo); } // check if we have all barriers - since canceled checkpoints always have zero barriers @@ -202,11 +201,11 @@ public class CheckpointBarrierAligner extends CheckpointBarrierHandler { protected void beginNewAlignment( long checkpointId, - int channelIndex, + InputChannelInfo channelInfo, long checkpointTimestamp) throws IOException { 
markCheckpointStart(checkpointTimestamp); currentCheckpointId = checkpointId; - onBarrier(channelIndex); + onBarrier(channelInfo); startOfAlignmentTimestamp = System.nanoTime(); @@ -218,20 +217,20 @@ public class CheckpointBarrierAligner extends CheckpointBarrierHandler { /** * Blocks the given channel index, from which a barrier has been received. * - * @param channelIndex The channel index to block. + * @param channelInfo The channel to block. */ - protected void onBarrier(int channelIndex) throws IOException { - if (!blockedChannels[channelIndex]) { - blockedChannels[channelIndex] = true; + protected void onBarrier(InputChannelInfo channelInfo) throws IOException { + if (!blockedChannels.get(channelInfo)) { + blockedChannels.put(channelInfo, true); numBarriersReceived++; if (LOG.isDebugEnabled()) { - LOG.debug("{}: Received barrier from channel {}.", taskName, channelIndex); + LOG.debug("{}: Received barrier from channel {}.", taskName, channelInfo); } } else { - throw new IOException("Stream corrupt: Repeated barrier for same checkpoint on input " + channelIndex); + throw new IOException("Stream corrupt: Repeated barrier for same checkpoint on input " + channelInfo); } } @@ -339,11 +338,11 @@ public class CheckpointBarrierAligner extends CheckpointBarrierHandler { return numBarriersReceived > 0; } - private void resumeConsumption(int channelIndex) { - InputGate inputGate = channelIndexToInputGate[channelIndex]; + private void resumeConsumption(InputChannelInfo channelInfo) { + InputGate inputGate = inputGates[channelInfo.getGateIdx()]; checkState(!inputGate.isFinished(), "InputGate already finished."); - inputGate.resumeConsumption(channelIndex - inputGateToChannelIndexOffset.get(inputGate)); + inputGate.resumeConsumption(channelInfo.getInputChannelIdx()); } @VisibleForTesting diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierHandler.java 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierHandler.java index 15382ee..fb0a319 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierHandler.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierHandler.java @@ -58,10 +58,10 @@ public abstract class CheckpointBarrierHandler implements Closeable { /** * Checks whether the channel with the given index is blocked. * - * @param channelIndex The channel index to check. + * @param channelInfo The channel index to check. * @return True if the channel is blocked, false if not. */ - public boolean isBlocked(int channelIndex) { + public boolean isBlocked(InputChannelInfo channelInfo) { return false; } @@ -69,7 +69,7 @@ public abstract class CheckpointBarrierHandler implements Closeable { public void close() throws IOException { } - public abstract void processBarrier(CheckpointBarrier receivedBarrier, int channelIndex) throws Exception; + public abstract void processBarrier(CheckpointBarrier receivedBarrier, InputChannelInfo channelInfo) throws Exception; public abstract void processCancellationBarrier(CancelCheckpointMarker cancelBarrier) throws Exception; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierTracker.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierTracker.java index 7dbfbaa..6b89854 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierTracker.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierTracker.java @@ -21,6 +21,7 @@ package org.apache.flink.streaming.runtime.io; import org.apache.flink.annotation.Internal; import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.checkpoint.CheckpointFailureReason; +import 
org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; @@ -75,7 +76,7 @@ public class CheckpointBarrierTracker extends CheckpointBarrierHandler { this.pendingCheckpoints = new ArrayDeque<>(); } - public void processBarrier(CheckpointBarrier receivedBarrier, int channelIndex) throws Exception { + public void processBarrier(CheckpointBarrier receivedBarrier, InputChannelInfo channelInfo) throws Exception { final long barrierId = receivedBarrier.getId(); // fast path for single channel trackers @@ -86,7 +87,7 @@ public class CheckpointBarrierTracker extends CheckpointBarrierHandler { // general path for multiple input channels if (LOG.isDebugEnabled()) { - LOG.debug("Received barrier for checkpoint {} from channel {}", barrierId, channelIndex); + LOG.debug("Received barrier for checkpoint {} from channel {}", barrierId, channelInfo); } // find the checkpoint barrier in the queue of pending barriers diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnaligner.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnaligner.java index 114c03e..8d7cdbf 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnaligner.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnaligner.java @@ -29,6 +29,7 @@ import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferReceivedListener; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import 
org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.streaming.runtime.tasks.SubtaskCheckpointCoordinator; @@ -41,10 +42,11 @@ import javax.annotation.concurrent.ThreadSafe; import java.io.Closeable; import java.io.IOException; import java.util.Arrays; +import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.function.Function; -import java.util.stream.IntStream; +import java.util.stream.Collectors; import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_DECLINED_SUBSUMED; import static org.apache.flink.util.CloseableIterator.ofElement; @@ -66,20 +68,11 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler { * Tag the state of which input channel has pending in-flight buffers; that is, already received buffers that * predate the checkpoint barrier of the current checkpoint. */ - private final boolean[] hasInflightBuffers; + private final Map<InputChannelInfo, Boolean> hasInflightBuffers; private int numBarrierConsumed; /** - * Contains the offsets of the channel indices for each gate when flattening the channels of all gates. - * - * <p>For example, consider 3 gates with 4 channels, {@code gateChannelOffsets = [0, 4, 8]}. - */ - private final int[] gateChannelOffsets; - - private final InputChannelInfo[] channelInfos; - - /** * The checkpoint id to guarantee that we would trigger only one checkpoint when reading the same barrier from * different channels. * @@ -92,31 +85,17 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler { private final ThreadSafeUnaligner threadSafeUnaligner; CheckpointBarrierUnaligner( - int[] numberOfInputChannelsPerGate, SubtaskCheckpointCoordinator checkpointCoordinator, String taskName, - AbstractInvokable toNotifyOnCheckpoint) { + AbstractInvokable toNotifyOnCheckpoint, + InputGate... 
inputGates) { super(toNotifyOnCheckpoint); this.taskName = taskName; - - final int numGates = numberOfInputChannelsPerGate.length; - - gateChannelOffsets = new int[numGates]; - for (int index = 1; index < numGates; index++) { - gateChannelOffsets[index] = gateChannelOffsets[index - 1] + numberOfInputChannelsPerGate[index - 1]; - } - - final int totalNumChannels = gateChannelOffsets[numGates - 1] + numberOfInputChannelsPerGate[numGates - 1]; - hasInflightBuffers = new boolean[totalNumChannels]; - - channelInfos = IntStream.range(0, numGates) - .mapToObj(gateIndex -> IntStream.range(0, numberOfInputChannelsPerGate[gateIndex]) - .mapToObj(channelIndex -> new InputChannelInfo(gateIndex, channelIndex))) - .flatMap(Function.identity()) - .toArray(InputChannelInfo[]::new); - - threadSafeUnaligner = new ThreadSafeUnaligner(totalNumChannels, checkNotNull(checkpointCoordinator), this); + hasInflightBuffers = Arrays.stream(inputGates) + .flatMap(gate -> gate.getChannelInfos().stream()) + .collect(Collectors.toMap(Function.identity(), info -> false)); + threadSafeUnaligner = new ThreadSafeUnaligner(checkNotNull(checkpointCoordinator), this, inputGates); } /** @@ -127,7 +106,7 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler { * <p>Note this is also suitable for the trigger case of local input channel. 
*/ @Override - public void processBarrier(CheckpointBarrier receivedBarrier, int channelIndex) throws Exception { + public void processBarrier(CheckpointBarrier receivedBarrier, InputChannelInfo channelInfo) throws Exception { long barrierId = receivedBarrier.getId(); if (currentConsumedCheckpointId > barrierId || (currentConsumedCheckpointId == barrierId && !isCheckpointPending())) { // ignore old and cancelled barriers @@ -136,13 +115,13 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler { if (currentConsumedCheckpointId < barrierId) { currentConsumedCheckpointId = barrierId; numBarrierConsumed = 0; - Arrays.fill(hasInflightBuffers, true); + hasInflightBuffers.entrySet().forEach(hasInflightBuffer -> hasInflightBuffer.setValue(true)); } if (currentConsumedCheckpointId == barrierId) { - hasInflightBuffers[channelIndex] = false; + hasInflightBuffers.put(channelInfo, false); numBarrierConsumed++; } - threadSafeUnaligner.notifyBarrierReceived(receivedBarrier, channelInfos[channelIndex]); + threadSafeUnaligner.notifyBarrierReceived(receivedBarrier, channelInfo); } @Override @@ -184,7 +163,7 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler { checkpointId, currentConsumedCheckpointId); - Arrays.fill(hasInflightBuffers, false); + hasInflightBuffers.entrySet().forEach(hasInflightBuffer -> hasInflightBuffer.setValue(false)); numBarrierConsumed = 0; } } @@ -213,7 +192,7 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler { if (checkpointId > currentConsumedCheckpointId) { return true; } - return hasInflightBuffers[getFlattenedChannelIndex(channelInfo)]; + return hasInflightBuffers.get(channelInfo); } @Override @@ -231,10 +210,6 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler { return numBarrierConsumed > 0; } - private int getFlattenedChannelIndex(InputChannelInfo channelInfo) { - return gateChannelOffsets[channelInfo.getGateIdx()] + channelInfo.getInputChannelIdx(); - } - 
@VisibleForTesting int getNumOpenChannels() { return threadSafeUnaligner.getNumOpenChannels(); @@ -259,7 +234,7 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler { * Tag the state of which input channel has not received the barrier, such that newly arriving buffers need * to be written in the unaligned checkpoint. */ - private final boolean[] storeNewBuffers; + private final Map<InputChannelInfo, Boolean> storeNewBuffers; /** The number of input channels which has received or processed the barrier. */ private int numBarriersReceived; @@ -282,9 +257,11 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler { private final CheckpointBarrierUnaligner handler; - ThreadSafeUnaligner(int totalNumChannels, SubtaskCheckpointCoordinator checkpointCoordinator, CheckpointBarrierUnaligner handler) { - this.numOpenChannels = totalNumChannels; - this.storeNewBuffers = new boolean[totalNumChannels]; + ThreadSafeUnaligner(SubtaskCheckpointCoordinator checkpointCoordinator, CheckpointBarrierUnaligner handler, InputGate... 
inputGates) { + storeNewBuffers = Arrays.stream(inputGates) + .flatMap(gate -> gate.getChannelInfos().stream()) + .collect(Collectors.toMap(Function.identity(), info -> false)); + numOpenChannels = storeNewBuffers.size(); this.checkpointCoordinator = checkpointCoordinator; this.handler = handler; } @@ -298,13 +275,12 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler { handler.executeInTaskThread(() -> handler.notifyCheckpoint(barrier), "notifyCheckpoint"); } - int channelIndex = handler.getFlattenedChannelIndex(channelInfo); - if (barrierId == currentReceivedCheckpointId && storeNewBuffers[channelIndex]) { + if (barrierId == currentReceivedCheckpointId && storeNewBuffers.get(channelInfo)) { if (LOG.isDebugEnabled()) { - LOG.debug("{}: Received barrier from channel {} @ {}.", handler.taskName, channelIndex, barrierId); + LOG.debug("{}: Received barrier from channel {} @ {}.", handler.taskName, channelInfo, barrierId); } - storeNewBuffers[channelIndex] = false; + storeNewBuffers.put(channelInfo, false); if (++numBarriersReceived == numOpenChannels) { allBarriersReceivedFuture.complete(null); @@ -314,7 +290,7 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler { @Override public synchronized void notifyBufferReceived(Buffer buffer, InputChannelInfo channelInfo) { - if (storeNewBuffers[handler.getFlattenedChannelIndex(channelInfo)]) { + if (storeNewBuffers.get(channelInfo)) { checkpointCoordinator.getChannelStateWriter().addInputData( currentReceivedCheckpointId, channelInfo, @@ -350,7 +326,7 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler { } currentReceivedCheckpointId = barrierId; - Arrays.fill(storeNewBuffers, true); + storeNewBuffers.entrySet().forEach(storeNewBuffer -> storeNewBuffer.setValue(true)); numBarriersReceived = 0; allBarriersReceivedFuture = new CompletableFuture<>(); checkpointCoordinator.initCheckpoint(barrierId, barrier.getCheckpointOptions()); @@ -397,7 +373,7 @@ public 
class CheckpointBarrierUnaligner extends CheckpointBarrierHandler { return false; } - Arrays.fill(storeNewBuffers, false); + storeNewBuffers.entrySet().forEach(storeNewBuffer -> storeNewBuffer.setValue(false)); numBarriersReceived = 0; return true; } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointedInputGate.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointedInputGate.java index 9428515..cb503b1 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointedInputGate.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointedInputGate.java @@ -28,10 +28,6 @@ import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; import org.apache.flink.runtime.io.network.partition.consumer.InputChannel; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; -import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.io.Closeable; import java.io.IOException; @@ -47,39 +43,14 @@ import static org.apache.flink.util.Preconditions.checkState; */ @Internal public class CheckpointedInputGate implements PullingAsyncDataInput<BufferOrEvent>, Closeable { - - private static final Logger LOG = LoggerFactory.getLogger(CheckpointedInputGate.class); - private final CheckpointBarrierHandler barrierHandler; /** The gate that the buffer draws its input from. */ private final InputGate inputGate; - private final int channelIndexOffset; - /** Indicate end of the input. 
*/ private boolean isFinished; - public CheckpointedInputGate( - InputGate inputGate, - String taskName, - AbstractInvokable toNotifyOnCheckpoint) { - this( - inputGate, - new CheckpointBarrierAligner( - taskName, - InputProcessorUtil.generateChannelIndexToInputGateMap(inputGate), - InputProcessorUtil.generateInputGateToChannelIndexOffsetMap(inputGate), - toNotifyOnCheckpoint) - ); - } - - public CheckpointedInputGate( - InputGate inputGate, - CheckpointBarrierHandler barrierHandler) { - this(inputGate, barrierHandler, 0); - } - /** * Creates a new checkpoint stream aligner. * @@ -89,15 +60,11 @@ public class CheckpointedInputGate implements PullingAsyncDataInput<BufferOrEven * * @param inputGate The input gate to draw the buffers and events from. * @param barrierHandler Handler that controls which channels are blocked. - * @param channelIndexOffset Optional offset added to channelIndex returned from the inputGate - * before passing it to the barrierHandler. */ public CheckpointedInputGate( InputGate inputGate, - CheckpointBarrierHandler barrierHandler, - int channelIndexOffset) { + CheckpointBarrierHandler barrierHandler) { this.inputGate = inputGate; - this.channelIndexOffset = channelIndexOffset; this.barrierHandler = barrierHandler; } @@ -116,14 +83,14 @@ public class CheckpointedInputGate implements PullingAsyncDataInput<BufferOrEven } BufferOrEvent bufferOrEvent = next.get(); - checkState(!barrierHandler.isBlocked(offsetChannelIndex(bufferOrEvent.getChannelIndex()))); + checkState(!barrierHandler.isBlocked(bufferOrEvent.getChannelInfo())); if (bufferOrEvent.isBuffer()) { return next; } else if (bufferOrEvent.getEvent().getClass() == CheckpointBarrier.class) { CheckpointBarrier checkpointBarrier = (CheckpointBarrier) bufferOrEvent.getEvent(); - barrierHandler.processBarrier(checkpointBarrier, offsetChannelIndex(bufferOrEvent.getChannelIndex())); + barrierHandler.processBarrier(checkpointBarrier, bufferOrEvent.getChannelInfo()); return next; } else if 
(bufferOrEvent.getEvent().getClass() == CancelCheckpointMarker.class) { @@ -152,10 +119,6 @@ public class CheckpointedInputGate implements PullingAsyncDataInput<BufferOrEven return barrierHandler.getAllBarriersReceivedFuture(checkpointId); } - private int offsetChannelIndex(int channelIndex) { - return channelIndex + channelIndexOffset; - } - private Optional<BufferOrEvent> handleEmptyBuffer() { if (inputGate.isFinished()) { isFinished = true; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/InputProcessorUtil.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/InputProcessorUtil.java index 762761b..3ed8584 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/InputProcessorUtil.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/InputProcessorUtil.java @@ -31,9 +31,7 @@ import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables; import java.util.Arrays; import java.util.Collection; import java.util.Comparator; -import java.util.HashMap; -import java.util.Map; -import java.util.stream.IntStream; +import java.util.List; /** * Utility for creating {@link CheckpointedInputGate} based on checkpoint mode @@ -71,108 +69,55 @@ public class InputProcessorUtil { String taskName, List<IndexedInputGate>... inputGates) { - IntStream numberOfInputChannelsPerGate = - Arrays - .stream(inputGates) - .flatMap(collection -> collection.stream()) - .sorted(Comparator.comparingInt(IndexedInputGate::getGateIndex)) - .mapToInt(InputGate::getNumberOfInputChannels); - - Map<InputGate, Integer> inputGateToChannelIndexOffset = generateInputGateToChannelIndexOffsetMap(unionedInputGates); - // Note that numberOfInputChannelsPerGate and inputGateToChannelIndexOffset have a bit different - // indexing and purposes. 
- // - // The numberOfInputChannelsPerGate is indexed based on flattened input gates, and sorted based on GateIndex, - // so that it can be used in combination with InputChannelInfo class. - // - // The inputGateToChannelIndexOffset is based upon unioned input gates and it's use for translating channel - // indexes from perspective of UnionInputGate to perspective of SingleInputGate. - + IndexedInputGate[] sortedInputGates = Arrays.stream(inputGates) + .flatMap(Collection::stream) + .sorted(Comparator.comparing(IndexedInputGate::getGateIndex)) + .toArray(IndexedInputGate[]::new); CheckpointBarrierHandler barrierHandler = createCheckpointBarrierHandler( config, - numberOfInputChannelsPerGate, + sortedInputGates, checkpointCoordinator, taskName, - generateChannelIndexToInputGateMap(unionedInputGates), - inputGateToChannelIndexOffset, toNotifyOnCheckpoint); registerCheckpointMetrics(taskIOMetricGroup, barrierHandler); + InputGate[] unionedInputGates = Arrays.stream(inputGates) + .map(InputGateUtil::createInputGate) + .toArray(InputGate[]::new); barrierHandler.getBufferReceivedListener().ifPresent(listener -> { for (final InputGate inputGate : unionedInputGates) { inputGate.registerBufferReceivedListener(listener); } }); - CheckpointedInputGate[] checkpointedInputGates = new CheckpointedInputGate[unionedInputGates.length]; - - for (int i = 0; i < unionedInputGates.length; i++) { - checkpointedInputGates[i] = new CheckpointedInputGate( - unionedInputGates[i], barrierHandler, inputGateToChannelIndexOffset.get(unionedInputGates[i])); - } - - return checkpointedInputGates; + return Arrays.stream(unionedInputGates) + .map(unionedInputGate -> new CheckpointedInputGate(unionedInputGate, barrierHandler)) + .toArray(CheckpointedInputGate[]::new); } private static CheckpointBarrierHandler createCheckpointBarrierHandler( StreamConfig config, - IntStream numberOfInputChannelsPerGate, + InputGate[] inputGates, SubtaskCheckpointCoordinator checkpointCoordinator, String taskName, - 
InputGate[] channelIndexToInputGate, - Map<InputGate, Integer> inputGateToChannelIndexOffset, AbstractInvokable toNotifyOnCheckpoint) { switch (config.getCheckpointMode()) { case EXACTLY_ONCE: if (config.isUnalignedCheckpointsEnabled()) { return new AlternatingCheckpointBarrierHandler( - new CheckpointBarrierAligner( - taskName, - channelIndexToInputGate, - inputGateToChannelIndexOffset, - toNotifyOnCheckpoint), - new CheckpointBarrierUnaligner( - numberOfInputChannelsPerGate.toArray(), - checkpointCoordinator, - taskName, - toNotifyOnCheckpoint), + new CheckpointBarrierAligner(taskName, toNotifyOnCheckpoint, inputGates), + new CheckpointBarrierUnaligner(checkpointCoordinator, taskName, toNotifyOnCheckpoint, inputGates), toNotifyOnCheckpoint); } - return new CheckpointBarrierAligner( - taskName, - channelIndexToInputGate, - inputGateToChannelIndexOffset, - toNotifyOnCheckpoint); + return new CheckpointBarrierAligner(taskName, toNotifyOnCheckpoint, inputGates); case AT_LEAST_ONCE: - return new CheckpointBarrierTracker(numberOfInputChannelsPerGate.sum(), toNotifyOnCheckpoint); + int numInputChannels = Arrays.stream(inputGates).mapToInt(InputGate::getNumberOfInputChannels).sum(); + return new CheckpointBarrierTracker(numInputChannels, toNotifyOnCheckpoint); default: throw new UnsupportedOperationException("Unrecognized Checkpointing Mode: " + config.getCheckpointMode()); } } - static InputGate[] generateChannelIndexToInputGateMap(InputGate ...inputGates) { - int numberOfInputChannels = Arrays.stream(inputGates).mapToInt(InputGate::getNumberOfInputChannels).sum(); - InputGate[] channelIndexToInputGate = new InputGate[numberOfInputChannels]; - int channelIndexOffset = 0; - for (InputGate inputGate: inputGates) { - for (int i = 0; i < inputGate.getNumberOfInputChannels(); ++i) { - channelIndexToInputGate[channelIndexOffset + i] = inputGate; - } - channelIndexOffset += inputGate.getNumberOfInputChannels(); - } - return channelIndexToInputGate; - } - - static 
Map<InputGate, Integer> generateInputGateToChannelIndexOffsetMap(InputGate ...inputGates) { - Map<InputGate, Integer> inputGateToChannelIndexOffset = new HashMap<>(); - int channelIndexOffset = 0; - for (InputGate inputGate: inputGates) { - inputGateToChannelIndexOffset.put(inputGate, channelIndexOffset); - channelIndexOffset += inputGate.getNumberOfInputChannels(); - } - return inputGateToChannelIndexOffset; - } - private static void registerCheckpointMetrics(TaskIOMetricGroup taskIOMetricGroup, CheckpointBarrierHandler barrierHandler) { taskIOMetricGroup.gauge(MetricNames.CHECKPOINT_ALIGNMENT_TIME, barrierHandler::getAlignmentDurationNanos); taskIOMetricGroup.gauge(MetricNames.CHECKPOINT_START_DELAY_TIME, barrierHandler::getCheckpointStartDelayNanos); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInput.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInput.java index 7134b44..8585dab 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInput.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInput.java @@ -23,6 +23,7 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.io.InputStatus; import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; import org.apache.flink.runtime.event.AbstractEvent; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; @@ -41,7 +42,12 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer; import org.apache.flink.streaming.runtime.streamstatus.StatusWatermarkValve; import org.apache.flink.streaming.runtime.streamstatus.StreamStatus; +import 
javax.annotation.Nonnull; + import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -74,6 +80,8 @@ public final class StreamTaskNetworkInput<T> implements StreamTaskInput<T> { private final int inputIndex; + private final Map<InputChannelInfo, Integer> channelIndexes; + private int lastChannel = UNSPECIFIED; private RecordDeserializer<DeserializationDelegate<StreamElement>> currentRecordDeserializer = null; @@ -98,6 +106,18 @@ public final class StreamTaskNetworkInput<T> implements StreamTaskInput<T> { this.statusWatermarkValve = checkNotNull(statusWatermarkValve); this.inputIndex = inputIndex; + this.channelIndexes = getChannelIndexes(checkpointedInputGate); + } + + @Nonnull + private static Map<InputChannelInfo, Integer> getChannelIndexes(CheckpointedInputGate checkpointedInputGate) { + int index = 0; + List<InputChannelInfo> channelInfos = checkpointedInputGate.getChannelInfos(); + Map<InputChannelInfo, Integer> channelIndexes = new HashMap<>(channelInfos.size()); + for (InputChannelInfo channelInfo : channelInfos) { + channelIndexes.put(channelInfo, index++); + } + return channelIndexes; } @VisibleForTesting @@ -114,6 +134,7 @@ public final class StreamTaskNetworkInput<T> implements StreamTaskInput<T> { this.recordDeserializers = recordDeserializers; this.statusWatermarkValve = statusWatermarkValve; this.inputIndex = inputIndex; + this.channelIndexes = getChannelIndexes(checkpointedInputGate); } @Override @@ -168,7 +189,7 @@ public final class StreamTaskNetworkInput<T> implements StreamTaskInput<T> { private void processBufferOrEvent(BufferOrEvent bufferOrEvent) throws IOException { if (bufferOrEvent.isBuffer()) { - lastChannel = bufferOrEvent.getChannelIndex(); + lastChannel = channelIndexes.get(bufferOrEvent.getChannelInfo()); checkState(lastChannel != StreamTaskInput.UNSPECIFIED); currentRecordDeserializer = 
recordDeserializers[lastChannel]; checkState(currentRecordDeserializer != null, @@ -186,7 +207,7 @@ public final class StreamTaskNetworkInput<T> implements StreamTaskInput<T> { // release the record deserializer immediately, // which is very valuable in case of bounded stream - releaseDeserializer(bufferOrEvent.getChannelIndex()); + releaseDeserializer(channelIndexes.get(bufferOrEvent.getChannelInfo())); } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandlerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandlerTest.java index 16a6bf2..ea147df 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandlerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandlerTest.java @@ -25,6 +25,7 @@ import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.partition.consumer.InputChannel; +import org.apache.flink.runtime.io.network.partition.consumer.InputChannelBuilder; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder; @@ -43,7 +44,6 @@ import java.util.Arrays; import java.util.List; import static java.util.Collections.singletonList; -import static java.util.Collections.singletonMap; import static org.apache.flink.runtime.checkpoint.CheckpointType.CHECKPOINT; import static org.apache.flink.runtime.checkpoint.CheckpointType.SAVEPOINT; import static org.apache.flink.runtime.io.network.api.serialization.EventSerializer.toBuffer; @@ -88,14 +88,14 @@ public class 
AlternatingCheckpointBarrierHandlerTest { SingleInputGate inputGate = new SingleInputGateBuilder().setNumberOfChannels(2).build(); inputGate.setInputChannels(new TestInputChannel(inputGate, 0), new TestInputChannel(inputGate, 1)); TestInvokable target = new TestInvokable(); - CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", new InputGate[]{inputGate, inputGate}, singletonMap(inputGate, 0), target); - CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, TestSubtaskCheckpointCoordinator.INSTANCE, "test", target); + CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", target, inputGate); + CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", target, inputGate); AlternatingCheckpointBarrierHandler barrierHandler = new AlternatingCheckpointBarrierHandler(alignedHandler, unalignedHandler, target); for (int i = 0; i < 4; i++) { int channel = i % 2; CheckpointType type = channel == 0 ? 
CHECKPOINT : SAVEPOINT; - barrierHandler.processBarrier(new CheckpointBarrier(i, 0, new CheckpointOptions(type, CheckpointStorageLocationReference.getDefault())), channel); + barrierHandler.processBarrier(new CheckpointBarrier(i, 0, new CheckpointOptions(type, CheckpointStorageLocationReference.getDefault())), new InputChannelInfo(0, channel)); assertEquals(type.isSavepoint(), alignedHandler.isCheckpointPending()); assertNotEquals(alignedHandler.isCheckpointPending(), unalignedHandler.isCheckpointPending()); @@ -118,12 +118,12 @@ public class AlternatingCheckpointBarrierHandlerTest { SingleInputGate inputGate = new SingleInputGateBuilder().setNumberOfChannels(2).build(); inputGate.setInputChannels(new TestInputChannel(inputGate, 0), new TestInputChannel(inputGate, 1)); TestInvokable target = new TestInvokable(); - CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", new InputGate[]{inputGate, inputGate}, singletonMap(inputGate, 0), target); - CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, TestSubtaskCheckpointCoordinator.INSTANCE, "test", target); + CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", target, inputGate); + CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", target, inputGate); AlternatingCheckpointBarrierHandler barrierHandler = new AlternatingCheckpointBarrierHandler(alignedHandler, unalignedHandler, target); final long id = 1; - unalignedHandler.processBarrier(new CheckpointBarrier(id, 0, new CheckpointOptions(CHECKPOINT, CheckpointStorageLocationReference.getDefault())), 0); + unalignedHandler.processBarrier(new CheckpointBarrier(id, 0, new CheckpointOptions(CHECKPOINT, CheckpointStorageLocationReference.getDefault())), new InputChannelInfo(0, 0)); assertInflightDataEquals(unalignedHandler, barrierHandler, id, 
inputGate.getNumberOfInputChannels()); assertFalse(barrierHandler.getAllBarriersReceivedFuture(id).isDone()); @@ -134,16 +134,16 @@ public class AlternatingCheckpointBarrierHandlerTest { SingleInputGate inputGate = new SingleInputGateBuilder().setNumberOfChannels(2).build(); inputGate.setInputChannels(new TestInputChannel(inputGate, 0), new TestInputChannel(inputGate, 1)); TestInvokable target = new TestInvokable(); - CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", new InputGate[]{inputGate, inputGate}, singletonMap(inputGate, 0), target); - CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, TestSubtaskCheckpointCoordinator.INSTANCE, "test", target); + CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", target, inputGate); + CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", target, inputGate); AlternatingCheckpointBarrierHandler barrierHandler = new AlternatingCheckpointBarrierHandler(alignedHandler, unalignedHandler, target); long checkpointId = 10; long outOfOrderSavepointId = 5; long initialAlignedCheckpointId = alignedHandler.getLatestCheckpointId(); - barrierHandler.processBarrier(new CheckpointBarrier(checkpointId, 0, new CheckpointOptions(CHECKPOINT, CheckpointStorageLocationReference.getDefault())), 0); - barrierHandler.processBarrier(new CheckpointBarrier(outOfOrderSavepointId, 0, new CheckpointOptions(SAVEPOINT, CheckpointStorageLocationReference.getDefault())), 1); + barrierHandler.processBarrier(new CheckpointBarrier(checkpointId, 0, new CheckpointOptions(CHECKPOINT, CheckpointStorageLocationReference.getDefault())), new InputChannelInfo(0, 0)); + barrierHandler.processBarrier(new CheckpointBarrier(outOfOrderSavepointId, 0, new CheckpointOptions(SAVEPOINT, CheckpointStorageLocationReference.getDefault())), new InputChannelInfo(0, 1)); 
assertEquals(checkpointId, barrierHandler.getLatestCheckpointId()); assertInflightDataEquals(unalignedHandler, barrierHandler, checkpointId, inputGate.getNumberOfInputChannels()); @@ -154,10 +154,13 @@ public class AlternatingCheckpointBarrierHandlerTest { public void testEndOfPartition() throws Exception { int totalChannels = 5; int closedChannels = 2; - SingleInputGate inputGate = new SingleInputGateBuilder().setNumberOfChannels(totalChannels).build(); + SingleInputGate inputGate = new SingleInputGateBuilder() + .setNumberOfChannels(totalChannels) + .setChannelFactory(InputChannelBuilder::buildLocalChannel) + .build(); TestInvokable target = new TestInvokable(); - CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", new InputGate[]{inputGate}, singletonMap(inputGate, 0), target); - CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, TestSubtaskCheckpointCoordinator.INSTANCE, "test", target); + CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", target, inputGate); + CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", target, inputGate); AlternatingCheckpointBarrierHandler barrierHandler = new AlternatingCheckpointBarrierHandler(alignedHandler, unalignedHandler, target); for (int i = 0; i < closedChannels; i++) { barrierHandler.processEndOfPartition(); @@ -174,19 +177,19 @@ public class AlternatingCheckpointBarrierHandlerTest { TestInputChannel slow = new TestInputChannel(gate, 1, false, true); gate.setInputChannels(fast, slow); AlternatingCheckpointBarrierHandler barrierHandler = barrierHandler(gate, target); - CheckpointedInputGate checkpointedGate = new CheckpointedInputGate(gate, barrierHandler, 0 /* offset */); + CheckpointedInputGate checkpointedGate = new CheckpointedInputGate(gate, barrierHandler /* offset */); sendBarrier(barrierId, checkpointType, fast, 
checkpointedGate); assertEquals(checkpointType.isSavepoint(), target.triggeredCheckpoints.isEmpty()); - assertEquals(checkpointType.isSavepoint(), barrierHandler.isBlocked(fast.getChannelIndex())); - assertFalse(barrierHandler.isBlocked(slow.getChannelIndex())); + assertEquals(checkpointType.isSavepoint(), barrierHandler.isBlocked(fast.getChannelInfo())); + assertFalse(barrierHandler.isBlocked(slow.getChannelInfo())); sendBarrier(barrierId, checkpointType, slow, checkpointedGate); assertEquals(singletonList(barrierId), target.triggeredCheckpoints); for (InputChannel channel : gate.getInputChannels().values()) { - assertFalse(barrierHandler.isBlocked(channel.getChannelIndex())); + assertFalse(barrierHandler.isBlocked(channel.getChannelInfo())); assertEquals( String.format("channel %d should be resumed", channel.getChannelIndex()), checkpointType.isSavepoint(), @@ -205,8 +208,8 @@ public class AlternatingCheckpointBarrierHandlerTest { InputGate[] channelIndexToInputGate = new InputGate[inputGate.getNumberOfInputChannels()]; Arrays.fill(channelIndexToInputGate, inputGate); return new AlternatingCheckpointBarrierHandler( - new CheckpointBarrierAligner(taskName, channelIndexToInputGate, singletonMap(inputGate, 0), target), - new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, TestSubtaskCheckpointCoordinator.INSTANCE, taskName, target), + new CheckpointBarrierAligner(taskName, target, inputGate), + new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, taskName, target, inputGate), target); } @@ -250,7 +253,7 @@ public class AlternatingCheckpointBarrierHandlerTest { channels[i] = new TestInputChannel(gate, i, false, true); } gate.setInputChannels(channels); - return new CheckpointedInputGate(gate, barrierHandler(gate, target), 0); + return new CheckpointedInputGate(gate, barrierHandler(gate, target)); } } diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAlignerMassiveRandomTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAlignerMassiveRandomTest.java index a653916..f415e8f 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAlignerMassiveRandomTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAlignerMassiveRandomTest.java @@ -19,6 +19,7 @@ package org.apache.flink.streaming.runtime.io; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader; +import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; import org.apache.flink.runtime.event.TaskEvent; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.buffer.Buffer; @@ -34,10 +35,13 @@ import org.junit.Test; import java.io.IOException; import java.util.Arrays; +import java.util.List; import java.util.Optional; import java.util.Random; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; +import java.util.stream.Collectors; +import java.util.stream.IntStream; /** * The test generates two random streams (input channels) which independently @@ -65,8 +69,7 @@ public class CheckpointBarrierAlignerMassiveRandomTest { CheckpointedInputGate checkpointedInputGate = new CheckpointedInputGate( myIG, - "Testing: No task associated", - new DummyCheckpointInvokable()); + new CheckpointBarrierAligner("Testing: No task associated", new DummyCheckpointInvokable(), myIG)); for (int i = 0; i < 2000000; i++) { BufferOrEvent boe = checkpointedInputGate.pollNext().get(); @@ -161,6 +164,13 @@ public class CheckpointBarrierAlignerMassiveRandomTest { } @Override + public List<InputChannelInfo> getChannelInfos() { + return IntStream.range(0, numberOfChannels) + 
.mapToObj(channelIndex -> new InputChannelInfo(0, channelIndex)) + .collect(Collectors.toList()); + } + + @Override public Optional<BufferOrEvent> getNext() throws IOException { currentChannel = (currentChannel + 1) % numberOfChannels; if (channelBlocked[currentChannel]) { @@ -179,7 +189,7 @@ public class CheckpointBarrierAlignerMassiveRandomTest { ++currentBarriers[currentChannel], System.currentTimeMillis(), CheckpointOptions.forCheckpointWithDefaultLocation()), - currentChannel)); + new InputChannelInfo(0, currentChannel))); } else { Buffer buffer = bufferPools[currentChannel].requestBuffer(); if (buffer == null) { @@ -188,7 +198,7 @@ public class CheckpointBarrierAlignerMassiveRandomTest { return getNext(); } buffer.getMemorySegment().putLong(0, c++); - return Optional.of(new BufferOrEvent(buffer, currentChannel)); + return Optional.of(new BufferOrEvent(buffer, new InputChannelInfo(0, currentChannel))); } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAlignerTestBase.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAlignerTestBase.java index ea4e004..0c478a0 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAlignerTestBase.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierAlignerTestBase.java @@ -25,6 +25,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointFailureReason; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; @@ -835,12 +836,13 @@ 
public abstract class CheckpointBarrierAlignerTestBase { // ------------------------------------------------------------------------ private static BufferOrEvent createBarrier(long checkpointId, int channel) { - return new BufferOrEvent(new CheckpointBarrier( - checkpointId, System.currentTimeMillis(), CheckpointOptions.forCheckpointWithDefaultLocation()), channel); + return new BufferOrEvent( + new CheckpointBarrier(checkpointId, System.currentTimeMillis(), CheckpointOptions.forCheckpointWithDefaultLocation()), + new InputChannelInfo(0, channel)); } private static BufferOrEvent createCancellationBarrier(long checkpointId, int channel) { - return new BufferOrEvent(new CancelCheckpointMarker(checkpointId), channel); + return new BufferOrEvent(new CancelCheckpointMarker(checkpointId), new InputChannelInfo(0, channel)); } private static BufferOrEvent createBuffer(int channel) { @@ -857,11 +859,11 @@ public abstract class CheckpointBarrierAlignerTestBase { // retain an additional time so it does not get disposed after being read by the input gate buf.retainBuffer(); - return new BufferOrEvent(buf, channel); + return new BufferOrEvent(buf, new InputChannelInfo(0, channel)); } private static BufferOrEvent createEndOfPartition(int channel) { - return new BufferOrEvent(EndOfPartitionEvent.INSTANCE, channel); + return new BufferOrEvent(EndOfPartitionEvent.INSTANCE, new InputChannelInfo(0, channel)); } private static void check(BufferOrEvent expected, BufferOrEvent present, int pageSize) { diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierTrackerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierTrackerTest.java index 1deca6a..5a3eafc 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierTrackerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierTrackerTest.java @@ -20,6 
+20,7 @@ package org.apache.flink.streaming.runtime.io; import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; @@ -369,16 +370,16 @@ public class CheckpointBarrierTrackerTest { } private static BufferOrEvent createBarrier(long id, int channel) { - return new BufferOrEvent(new CheckpointBarrier(id, System.currentTimeMillis(), CheckpointOptions.forCheckpointWithDefaultLocation()), channel); + return new BufferOrEvent(new CheckpointBarrier(id, System.currentTimeMillis(), CheckpointOptions.forCheckpointWithDefaultLocation()), new InputChannelInfo(0, channel)); } private static BufferOrEvent createCancellationBarrier(long id, int channel) { - return new BufferOrEvent(new CancelCheckpointMarker(id), channel); + return new BufferOrEvent(new CancelCheckpointMarker(id), new InputChannelInfo(0, channel)); } private static BufferOrEvent createBuffer(int channel) { return new BufferOrEvent( - new NetworkBuffer(MemorySegmentFactory.wrap(new byte[]{1, 2}), FreeingBufferRecycler.INSTANCE), channel); + new NetworkBuffer(MemorySegmentFactory.wrap(new byte[]{1, 2}), FreeingBufferRecycler.INSTANCE), new InputChannelInfo(0, channel)); } // ------------------------------------------------------------------------ diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerCancellationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerCancellationTest.java index 1bb4972..37d4865 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerCancellationTest.java +++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerCancellationTest.java @@ -20,6 +20,7 @@ package org.apache.flink.streaming.runtime.io; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; import org.apache.flink.runtime.event.RuntimeEvent; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; @@ -77,13 +78,13 @@ public class CheckpointBarrierUnalignerCancellationTest { @Test public void test() throws Exception { TestInvokable invokable = new TestInvokable(); - CheckpointBarrierUnaligner unaligner = new CheckpointBarrierUnaligner(new int[]{numChannels}, TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable); + CheckpointBarrierUnaligner unaligner = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable, new MockIndexedInputGate(0, numChannels)); for (RuntimeEvent e : events) { if (e instanceof CancelCheckpointMarker) { unaligner.processCancellationBarrier((CancelCheckpointMarker) e); } else if (e instanceof CheckpointBarrier) { - unaligner.processBarrier((CheckpointBarrier) e, channel); + unaligner.processBarrier((CheckpointBarrier) e, new InputChannelInfo(0, channel)); } else { throw new IllegalArgumentException("unexpected event type: " + e); } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerTest.java index 0ab8ee2..35e5b08 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerTest.java +++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerTest.java @@ -30,8 +30,8 @@ import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; +import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate; import org.apache.flink.runtime.io.network.partition.consumer.InputChannelBuilder; -import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder; @@ -479,7 +479,7 @@ public class CheckpointBarrierUnalignerTest { } /** - * Tests the race condition between {@link CheckpointBarrierUnaligner#processBarrier(CheckpointBarrier, int)} + * Tests the race condition between {@link CheckpointBarrierHandler#processBarrier(CheckpointBarrier, InputChannelInfo)} * and {@link ThreadSafeUnaligner#notifyBarrierReceived(CheckpointBarrier, InputChannelInfo)}. The barrier * notification will trigger an async checkpoint (ch1) via mailbox, and meanwhile the barrier processing will * execute the next checkpoint (ch2) directly in advance. 
When the ch1 action is taken from mailbox to execute, @@ -488,7 +488,7 @@ public class CheckpointBarrierUnalignerTest { @Test public void testConcurrentProcessBarrierAndNotifyBarrierReceived() throws Exception { final ValidatingCheckpointInvokable invokable = new ValidatingCheckpointInvokable(); - final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(new int[] { 1 }, TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable); + final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable, new MockIndexedInputGate()); final InputChannelInfo channelInfo = new InputChannelInfo(0, 0); final ExecutorService executor = Executors.newFixedThreadPool(1); @@ -502,7 +502,7 @@ public class CheckpointBarrierUnalignerTest { result.get(); // Execute the checkpoint (ch1) directly because it is triggered by main thread. - handler.processBarrier(buildCheckpointBarrier(1), 0); + handler.processBarrier(buildCheckpointBarrier(1), new InputChannelInfo(0, 0)); // Run the previous queued mailbox action to execute ch0. 
invokable.runMailboxStep(); @@ -523,8 +523,7 @@ public class CheckpointBarrierUnalignerTest { @Test public void testProcessCancellationBarrierAfterNotifyBarrierReceived() throws Exception { final ValidatingCheckpointInvokable invokable = new ValidatingCheckpointInvokable(); - final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner( - new int[] { 1 }, TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable); + final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable, new MockIndexedInputGate()); ThreadSafeUnaligner unaligner = handler.getThreadSafeUnaligner(); // should trigger respective checkpoint @@ -541,16 +540,15 @@ public class CheckpointBarrierUnalignerTest { /** * Tests {@link CheckpointBarrierUnaligner#processCancellationBarrier(CancelCheckpointMarker)} * abort the current pending checkpoint triggered by - * {@link CheckpointBarrierUnaligner#processBarrier(CheckpointBarrier, int)}. + * {@link CheckpointBarrierHandler#processBarrier(CheckpointBarrier, InputChannelInfo)}. 
*/ @Test public void testProcessCancellationBarrierAfterProcessBarrier() throws Exception { final ValidatingCheckpointInvokable invokable = new ValidatingCheckpointInvokable(); - final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner( - new int[] { 1 }, TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable); + final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable, new MockIndexedInputGate()); // should trigger respective checkpoint - handler.processBarrier(buildCheckpointBarrier(DEFAULT_CHECKPOINT_ID), 0); + handler.processBarrier(buildCheckpointBarrier(DEFAULT_CHECKPOINT_ID), new InputChannelInfo(0, 0)); assertTrue(handler.isCheckpointPending()); assertTrue(handler.getThreadSafeUnaligner().isCheckpointPending()); @@ -563,15 +561,14 @@ public class CheckpointBarrierUnalignerTest { @Test public void testProcessCancellationBarrierBeforeProcessAndReceiveBarrier() throws Exception { final ValidatingCheckpointInvokable invokable = new ValidatingCheckpointInvokable(); - final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner( - new int[] { 1 }, TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable); + final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable, new MockIndexedInputGate()); handler.processCancellationBarrier(new CancelCheckpointMarker(DEFAULT_CHECKPOINT_ID)); verifyTriggeredCheckpoint(handler, invokable, DEFAULT_CHECKPOINT_ID); // it would not trigger checkpoint since the respective cancellation barrier already happened before - handler.processBarrier(buildCheckpointBarrier(DEFAULT_CHECKPOINT_ID), 0); + handler.processBarrier(buildCheckpointBarrier(DEFAULT_CHECKPOINT_ID), new InputChannelInfo(0, 0)); handler.getThreadSafeUnaligner().notifyBarrierReceived(buildCheckpointBarrier(DEFAULT_CHECKPOINT_ID), new InputChannelInfo(0, 0)); 
verifyTriggeredCheckpoint(handler, invokable, DEFAULT_CHECKPOINT_ID); @@ -608,8 +605,7 @@ public class CheckpointBarrierUnalignerTest { public void testEndOfStreamWithPendingCheckpoint() throws Exception { final int numberOfChannels = 2; final ValidatingCheckpointInvokable invokable = new ValidatingCheckpointInvokable(); - final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner( - new int[] { numberOfChannels }, TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable); + final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable, new MockIndexedInputGate(0, numberOfChannels)); ThreadSafeUnaligner unaligner = handler.getThreadSafeUnaligner(); // should trigger respective checkpoint @@ -639,26 +635,26 @@ public class CheckpointBarrierUnalignerTest { checkpointId, System.currentTimeMillis(), CheckpointOptions.forCheckpointWithDefaultLocation()), - channel); + new InputChannelInfo(0, channel)); } private BufferOrEvent createCancellationBarrier(long checkpointId, int channel) { sizeCounter++; - return new BufferOrEvent(new CancelCheckpointMarker(checkpointId), channel); + return new BufferOrEvent(new CancelCheckpointMarker(checkpointId), new InputChannelInfo(0, channel)); } private BufferOrEvent createBuffer(int channel) { final int size = sizeCounter++; - return new BufferOrEvent(TestBufferFactory.createBuffer(size), channel); + return new BufferOrEvent(TestBufferFactory.createBuffer(size), new InputChannelInfo(0, channel)); } private static BufferOrEvent createEndOfPartition(int channel) { - return new BufferOrEvent(EndOfPartitionEvent.INSTANCE, channel); + return new BufferOrEvent(EndOfPartitionEvent.INSTANCE, new InputChannelInfo(0, channel)); } private CheckpointedInputGate createInputGate( int numberOfChannels, - AbstractInvokable toNotify) throws IOException, InterruptedException { + AbstractInvokable toNotify) throws IOException { final NettyShuffleEnvironment 
environment = new NettyShuffleEnvironmentBuilder().build(); SingleInputGate gate = new SingleInputGateBuilder() .setNumberOfChannels(numberOfChannels) @@ -687,12 +683,12 @@ public class CheckpointBarrierUnalignerTest { if (bufferOrEvent.isEvent()) { bufferOrEvent = new BufferOrEvent( EventSerializer.toBuffer(bufferOrEvent.getEvent()), - bufferOrEvent.getChannelIndex(), + bufferOrEvent.getChannelInfo(), bufferOrEvent.moreAvailable()); } - ((RemoteInputChannel) inputGate.getChannel(bufferOrEvent.getChannelIndex())).onBuffer( + ((RemoteInputChannel) inputGate.getChannel(bufferOrEvent.getChannelInfo().getInputChannelIdx())).onBuffer( bufferOrEvent.getBuffer(), - sequenceNumbers[bufferOrEvent.getChannelIndex()]++, + sequenceNumbers[bufferOrEvent.getChannelInfo().getInputChannelIdx()]++, 0); while (inputGate.pollNext().map(output::add).isPresent()) { @@ -702,12 +698,12 @@ public class CheckpointBarrierUnalignerTest { return sequence; } - private CheckpointedInputGate createCheckpointedInputGate(InputGate gate, AbstractInvokable toNotify) { + private CheckpointedInputGate createCheckpointedInputGate(IndexedInputGate gate, AbstractInvokable toNotify) { final CheckpointBarrierUnaligner barrierHandler = new CheckpointBarrierUnaligner( - new int[]{ gate.getNumberOfInputChannels() }, new TestSubtaskCheckpointCoordinator(channelStateWriter), "Test", - toNotify); + toNotify, + gate); barrierHandler.getBufferReceivedListener().ifPresent(gate::registerBufferReceivedListener); return new CheckpointedInputGate(gate, barrierHandler); } @@ -719,7 +715,7 @@ public class CheckpointBarrierUnalignerTest { private Collection<BufferOrEvent> getAndResetInflightData() { final List<BufferOrEvent> inflightData = channelStateWriter.getAddedInput().entries().stream() - .map(entry -> new BufferOrEvent(entry.getValue(), entry.getKey().getInputChannelIdx())) + .map(entry -> new BufferOrEvent(entry.getValue(), entry.getKey())) .collect(Collectors.toList()); channelStateWriter.reset(); return 
inflightData; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CreditBasedCheckpointBarrierAlignerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CreditBasedCheckpointBarrierAlignerTest.java index 2837890..eb1b403 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CreditBasedCheckpointBarrierAlignerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CreditBasedCheckpointBarrierAlignerTest.java @@ -28,6 +28,6 @@ public class CreditBasedCheckpointBarrierAlignerTest extends CheckpointBarrierAl @Override CheckpointedInputGate createBarrierBuffer(InputGate gate, AbstractInvokable toNotify) { - return new CheckpointedInputGate(gate, "Testing", toNotify); + return new CheckpointedInputGate(gate, new CheckpointBarrierAligner("Testing", toNotify, gate)); } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/InputProcessorUtilTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/InputProcessorUtilTest.java index a17ead0..9112bf0 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/InputProcessorUtilTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/InputProcessorUtilTest.java @@ -25,9 +25,6 @@ import org.apache.flink.runtime.checkpoint.channel.MockChannelStateWriter; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.buffer.BufferReceivedListener; import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate; -import org.apache.flink.runtime.io.network.partition.consumer.InputGate; -import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; -import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder; import org.apache.flink.runtime.operators.testutils.MockEnvironment; import 
org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder; import org.apache.flink.streaming.api.CheckpointingMode; @@ -41,10 +38,8 @@ import org.junit.Test; import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.stream.Collectors; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** @@ -53,32 +48,6 @@ import static org.junit.Assert.assertTrue; public class InputProcessorUtilTest { @Test - public void testGenerateChannelIndexToInputGateMap() { - SingleInputGate ig1 = new SingleInputGateBuilder().setNumberOfChannels(2).build(); - SingleInputGate ig2 = new SingleInputGateBuilder().setNumberOfChannels(3).build(); - - InputGate[] channelIndexToInputGateMap = InputProcessorUtil.generateChannelIndexToInputGateMap(ig1, ig2); - assertEquals(5, channelIndexToInputGateMap.length); - assertEquals(ig1, channelIndexToInputGateMap[0]); - assertEquals(ig1, channelIndexToInputGateMap[1]); - assertEquals(ig2, channelIndexToInputGateMap[2]); - assertEquals(ig2, channelIndexToInputGateMap[3]); - assertEquals(ig2, channelIndexToInputGateMap[4]); - } - - @Test - public void testGenerateInputGateToChannelIndexOffsetMap() { - SingleInputGate ig1 = new SingleInputGateBuilder().setNumberOfChannels(3).build(); - SingleInputGate ig2 = new SingleInputGateBuilder().setNumberOfChannels(2).build(); - - Map<InputGate, Integer> inputGateToChannelIndexOffsetMap = - InputProcessorUtil.generateInputGateToChannelIndexOffsetMap(ig1, ig2); - assertEquals(2, inputGateToChannelIndexOffsetMap.size()); - assertEquals(0, inputGateToChannelIndexOffsetMap.get(ig1).intValue()); - assertEquals(3, inputGateToChannelIndexOffsetMap.get(ig2).intValue()); - } - - @Test public void testCreateCheckpointedMultipleInputGate() throws Exception { try (CloseableRegistry registry = new CloseableRegistry()) { MockEnvironment environment = new MockEnvironmentBuilder().build(); diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java index a536cbd..8779200 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java @@ -113,7 +113,7 @@ public class MockInputGate extends InputGate { return Optional.empty(); } - int channelIdx = next.getChannelIndex(); + int channelIdx = next.getChannelInfo().getInputChannelIdx(); if (closed[channelIdx]) { throw new RuntimeException("Inconsistent: Channel " + channelIdx + " has data even though it is already closed."); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputTest.java index b61da52..5fd2aef 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputTest.java @@ -22,6 +22,7 @@ import org.apache.flink.api.common.typeutils.base.LongSerializer; import org.apache.flink.core.io.InputStatus; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; import org.apache.flink.runtime.checkpoint.channel.RecordingChannelStateWriter; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync; @@ -109,7 +110,7 @@ public class StreamTaskNetworkInputTest { CheckpointBarrier barrier = new CheckpointBarrier(0, 0, CheckpointOptions.forCheckpointWithDefaultLocation()); List<BufferOrEvent> buffers = new ArrayList<>(2); - 
buffers.add(new BufferOrEvent(barrier, 0)); + buffers.add(new BufferOrEvent(barrier, new InputChannelInfo(0, 0))); buffers.add(createDataBuffer()); VerifyRecordsDataOutput output = new VerifyRecordsDataOutput<>(); @@ -121,22 +122,24 @@ public class StreamTaskNetworkInputTest { @Test public void testSnapshotWithTwoInputGates() throws Exception { - CheckpointBarrierUnaligner unaligner = new CheckpointBarrierUnaligner( - new int[]{ 1, 1 }, - TestSubtaskCheckpointCoordinator.INSTANCE, - "test", - new DummyCheckpointInvokable()); - SingleInputGate inputGate1 = new SingleInputGateBuilder().setSingleInputGateIndex(0).build(); RemoteInputChannel channel1 = InputChannelBuilder.newBuilder().buildRemoteChannel(inputGate1); inputGate1.setInputChannels(channel1); - inputGate1.registerBufferReceivedListener(unaligner.getBufferReceivedListener().get()); - StreamTaskNetworkInput<Long> input1 = createInput(unaligner, inputGate1); SingleInputGate inputGate2 = new SingleInputGateBuilder().setSingleInputGateIndex(1).build(); RemoteInputChannel channel2 = InputChannelBuilder.newBuilder().buildRemoteChannel(inputGate2); inputGate2.setInputChannels(channel2); + + CheckpointBarrierUnaligner unaligner = new CheckpointBarrierUnaligner( + TestSubtaskCheckpointCoordinator.INSTANCE, + "test", + new DummyCheckpointInvokable(), + inputGate1, + inputGate2); + inputGate1.registerBufferReceivedListener(unaligner.getBufferReceivedListener().get()); inputGate2.registerBufferReceivedListener(unaligner.getBufferReceivedListener().get()); + + StreamTaskNetworkInput<Long> input1 = createInput(unaligner, inputGate1); StreamTaskNetworkInput<Long> input2 = createInput(unaligner, inputGate2); CheckpointBarrier barrier = new CheckpointBarrier(0, 0L, CheckpointOptions.forCheckpointWithDefaultLocation()); @@ -194,10 +197,10 @@ public class StreamTaskNetworkInputTest { new CheckpointedInputGate( inputGate.getInputGate(), new CheckpointBarrierUnaligner( - new int[] { numInputChannels }, 
TestSubtaskCheckpointCoordinator.INSTANCE, "test", - new DummyCheckpointInvokable())), + new DummyCheckpointInvokable(), + inputGate.getInputGate())), inSerializer, new StatusWatermarkValve(numInputChannels, output), 0, @@ -261,7 +264,7 @@ public class StreamTaskNetworkInputTest { serializeRecord(42L, bufferBuilder); serializeRecord(44L, bufferBuilder); - return new BufferOrEvent(bufferConsumer.build(), 0, false); + return new BufferOrEvent(bufferConsumer.build(), new InputChannelInfo(0, 0), false); } private StreamTaskNetworkInput createStreamTaskNetworkInput(List<BufferOrEvent> buffers, DataOutput output) {
