rkhachatryan commented on a change in pull request #13845:
URL: https://github.com/apache/flink/pull/13845#discussion_r583490420
##########
File path:
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java
##########
@@ -45,32 +55,46 @@
BufferWithContext<Context> getBuffer(Info info) throws IOException,
InterruptedException;
- void recover(Info info, Context context) throws IOException;
+ void recover(Info info, int oldSubtaskIndex, Context context) throws
IOException;
}
class InputChannelRecoveredStateHandler
implements RecoveredChannelStateHandler<InputChannelInfo, Buffer> {
private final InputGate[] inputGates;
- InputChannelRecoveredStateHandler(InputGate[] inputGates) {
+ private final InflightDataRescalingDescriptor channelMapping;
+
+ private final Map<InputChannelInfo, List<RecoveredInputChannel>>
rescaledChannels =
+ new HashMap<>();
+
+ InputChannelRecoveredStateHandler(
+ InputGate[] inputGates, InflightDataRescalingDescriptor
channelMapping) {
this.inputGates = inputGates;
+ this.channelMapping = channelMapping;
}
@Override
public BufferWithContext<Buffer> getBuffer(InputChannelInfo channelInfo)
throws IOException, InterruptedException {
- RecoveredInputChannel channel = getChannel(channelInfo);
+ RecoveredInputChannel channel = getMappedChannels(channelInfo).get(0);
Buffer buffer = channel.requestBufferBlocking();
return new BufferWithContext<>(wrap(buffer), buffer);
}
@Override
- public void recover(InputChannelInfo channelInfo, Buffer buffer) {
+ public void recover(InputChannelInfo channelInfo, int oldSubtaskIndex,
Buffer buffer)
+ throws IOException {
if (buffer.readableBytes() > 0) {
- getChannel(channelInfo).onRecoveredStateBuffer(buffer);
- } else {
- buffer.recycleBuffer();
+ for (final RecoveredInputChannel channel :
getMappedChannels(channelInfo)) {
+ channel.onRecoveredStateBuffer(
+ EventSerializer.toBuffer(
+ new VirtualChannelSelector(
+ oldSubtaskIndex,
channelInfo.getInputChannelIdx()),
+ false));
+ channel.onRecoveredStateBuffer(buffer.retainBuffer());
+ }
}
+ buffer.recycleBuffer();
Review comment:
Should it be in `finally`? Otherwise `buffer` is not recycled if anything in
the loop above throws.
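
To illustrate the concern, here is a minimal sketch of releasing the caller's reference in `finally`. `RefCountedBuffer`, `RecoveredChannel` and `RecoverSketch` are hypothetical stand-ins made up for this example; they are not the Flink `Buffer` or `RecoveredInputChannel` APIs, and the sketch assumes a failing channel releases the reference it was handed.

```java
import java.io.IOException;
import java.util.List;

/** Hypothetical reference-counted buffer; not the Flink Buffer API. */
final class RefCountedBuffer {
    private int refCount = 1; // the creator holds the first reference

    RefCountedBuffer retain() {
        refCount++;
        return this;
    }

    void recycle() {
        if (refCount == 0) {
            throw new IllegalStateException("buffer already released");
        }
        refCount--;
    }

    int refCount() {
        return refCount;
    }
}

/** Hypothetical channel that may fail while accepting a recovered buffer. */
interface RecoveredChannel {
    void onRecoveredStateBuffer(RefCountedBuffer buffer) throws IOException;
}

final class RecoverSketch {
    static void recover(RefCountedBuffer buffer, List<RecoveredChannel> channels)
            throws IOException {
        try {
            for (RecoveredChannel channel : channels) {
                // Each channel gets its own reference to the shared buffer.
                channel.onRecoveredStateBuffer(buffer.retain());
            }
        } finally {
            // Drop the caller's reference even if a channel threw.
            buffer.recycle();
        }
    }
}
```

Without the `finally`, an exception from any channel would skip the last `recycle()` and leave the caller's reference outstanding.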
##########
File path:
flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilter.java
##########
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.io.recovery;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.io.network.api.writer.ChannelSelector;
+import org.apache.flink.runtime.plugable.SerializationDelegate;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import java.util.function.Predicate;
+
+/**
+ * Filters records for ambiguous channel mappings.
+ *
+ * <p>For example, when the downstream node of a keyed exchange is scaled from
1 to 2, the state of
+ * the output side on te upstream node needs to be replicated to both
channels. This filter then
+ * checks the deserialized records on both downstream subtasks and filters out
the irrelevant
+ * records.
+ *
+ * @param <T>
+ */
+class RecordFilter<T> implements Predicate<StreamRecord<T>> {
+ private final ChannelSelector<SerializationDelegate<StreamRecord<T>>>
partitioner;
+
+ private final SerializationDelegate<StreamRecord<T>> delegate;
+
+ private final int subtaskIndex;
+
+ public RecordFilter(
+ ChannelSelector<SerializationDelegate<StreamRecord<T>>>
partitioner,
+ TypeSerializer<T> inputSerializer,
+ int subtaskIndex) {
+ this.partitioner = partitioner;
+ delegate = new SerializationDelegate<>(new
StreamElementSerializer(inputSerializer));
+ this.subtaskIndex = subtaskIndex;
+ }
+
+ public static <T> Predicate<StreamRecord<T>> all() {
+ return record -> true;
+ }
+
+ @Override
+ public boolean test(StreamRecord<T> streamRecord) {
+ delegate.setInstance(streamRecord);
+ // check if record would have arrived at this subtask if it had been
partitioned upstream
+ return partitioner.selectChannel(delegate) == subtaskIndex;
Review comment:
I'm wondering whether a record can be discarded on every subtask because the
partitioner always chooses the "other" subtask.
For example, when up-scaling from 1 to 2 with a RoundRobin partitioner:
1. let subtask0.rrPartitioner.nextChannelToSendTo = 0
   selectChannel returns 1 - filtered out
2. let subtask1.rrPartitioner.nextChannelToSendTo = 1 (some record was
   already processed)
   selectChannel returns 0 - filtered out

In that case the record is dropped by both subtasks.
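
The scenario can be reproduced with a small standalone sketch. `RoundRobinSelector` is a hypothetical class that advances its counter and then returns it, which is the behavior the example above assumes; it is not the actual Flink partitioner and has not been checked against it here.

```java
/** Hypothetical round-robin selector: advances its counter, then returns it. */
final class RoundRobinSelector {
    private final int numberOfChannels;
    private int nextChannelToSendTo;

    RoundRobinSelector(int numberOfChannels, int initialNext) {
        this.numberOfChannels = numberOfChannels;
        this.nextChannelToSendTo = initialNext;
    }

    int selectChannel() {
        nextChannelToSendTo = (nextChannelToSendTo + 1) % numberOfChannels;
        return nextChannelToSendTo;
    }
}

final class RoundRobinFilterSketch {
    /** Same check as in the filter: keep the record only if this subtask is selected. */
    static boolean keeps(RoundRobinSelector selector, int subtaskIndex) {
        return selector.selectChannel() == subtaskIndex;
    }

    /** Subtask 0 with a fresh selector state (next = 0). */
    static boolean keptBySubtask0() {
        return keeps(new RoundRobinSelector(2, 0), 0);
    }

    /** Subtask 1 whose selector has already advanced once (next = 1). */
    static boolean keptBySubtask1() {
        return keeps(new RoundRobinSelector(2, 1), 1);
    }
}
```

Under these assumptions both `keptBySubtask0()` and `keptBySubtask1()` return `false`, so the same replicated record is filtered out on both sides.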
##########
File path:
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java
##########
@@ -94,17 +145,25 @@ private RecoveredInputChannel getChannel(InputChannelInfo
info) {
private final ResultPartitionWriter[] writers;
private final boolean notifyAndBlockOnCompletion;
+ private final InflightDataRescalingDescriptor channelMapping;
+
+ private final Map<ResultSubpartitionInfo,
List<CheckpointedResultSubpartition>>
+ rescaledChannels = new HashMap<>();
+
ResultSubpartitionRecoveredStateHandler(
- ResultPartitionWriter[] writers, boolean
notifyAndBlockOnCompletion) {
+ ResultPartitionWriter[] writers,
+ boolean notifyAndBlockOnCompletion,
+ InflightDataRescalingDescriptor channelMapping) {
this.writers = writers;
+ this.channelMapping = channelMapping;
this.notifyAndBlockOnCompletion = notifyAndBlockOnCompletion;
}
@Override
public BufferWithContext<Tuple2<BufferBuilder, BufferConsumer>> getBuffer(
ResultSubpartitionInfo subpartitionInfo) throws IOException,
InterruptedException {
- BufferBuilder bufferBuilder =
-
getSubpartition(subpartitionInfo).requestBufferBuilderBlocking();
+ final List<CheckpointedResultSubpartition> channels =
getMappedChannels(subpartitionInfo);
+ BufferBuilder bufferBuilder =
channels.get(0).requestBufferBuilderBlocking();
Review comment:
I guess `get(0)` is used because it doesn't matter which subpartition the
buffer is requested from.
If so, could you add a comment in the code saying that?
##########
File path:
flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RescalingStreamTaskNetworkInput.java
##########
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.io.recovery;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.TaskInfo;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.io.InputStatus;
+import org.apache.flink.runtime.checkpoint.CheckpointException;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
+import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
+import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.disk.iomanager.IOManager;
+import org.apache.flink.runtime.io.network.api.VirtualChannelSelector;
+import
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
+import
org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
+import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
+import org.apache.flink.runtime.plugable.DeserializationDelegate;
+import org.apache.flink.streaming.runtime.io.AbstractStreamTaskNetworkInput;
+import org.apache.flink.streaming.runtime.io.RecoverableStreamTaskInput;
+import org.apache.flink.streaming.runtime.io.StreamTaskInput;
+import org.apache.flink.streaming.runtime.io.StreamTaskNetworkInput;
+import
org.apache.flink.streaming.runtime.io.checkpointing.CheckpointedInputGate;
+import
org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.streamstatus.StatusWatermarkValve;
+
+import org.apache.flink.shaded.guava18.com.google.common.collect.Maps;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Function;
+import java.util.function.Predicate;
+
+import static
org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_DECLINED_TASK_NOT_READY;
+
+/**
+ * A {@link StreamTaskNetworkInput} implementation that demultiplexes virtual
channels.
+ *
+ * <p>The demultiplexing works in two dimensions for the following cases. *
+ *
+ * <ul>
+ * <li>Subtasks of the current operator have been collapsed in a round-robin
fashion.
+ * <li>The connected output operator has been rescaled (up and down!) and
there is an overlap of
+ * channels (mostly relevant to keyed exchanges).
+ * </ul>
+ *
+ * <p>In both cases, records from multiple old channels are received over one
new physical channel,
+ * which need to demultiplex the record to correctly restore spanning records
(similar to how
+ * StreamTaskNetworkInput works).
+ *
+ * <p>Note that when both cases occur at the same time (downscaling of several
operators), there is
+ * the cross product of channels. So if two subtasks are collapsed and two
channels overlap from the
+ * output side, there is a total of 4 virtual channels.
+ */
+@Internal
+public final class RescalingStreamTaskNetworkInput<T>
+ extends AbstractStreamTaskNetworkInput<T,
DemultiplexingRecordDeserializer<T>>
+ implements RecoverableStreamTaskInput<T> {
+
+ private static final Logger LOG =
+ LoggerFactory.getLogger(RescalingStreamTaskNetworkInput.class);
+ private final IOManager ioManager;
+
+ private RescalingStreamTaskNetworkInput(
+ CheckpointedInputGate checkpointedInputGate,
+ TypeSerializer<T> inputSerializer,
+ IOManager ioManager,
+ StatusWatermarkValve statusWatermarkValve,
+ int inputIndex,
+ InflightDataRescalingDescriptor inflightDataRescalingDescriptor,
+ Function<Integer, StreamPartitioner<?>> gatePartitioners,
+ TaskInfo taskInfo) {
+ super(
+ checkpointedInputGate,
+ inputSerializer,
+ statusWatermarkValve,
+ inputIndex,
+ getRecordDeserializers(
+ checkpointedInputGate,
+ inputSerializer,
+ ioManager,
+ inflightDataRescalingDescriptor,
+ gatePartitioners,
+ taskInfo));
+ this.ioManager = ioManager;
+
+ LOG.info(
+ "Created demultiplexer for input {} from {}",
+ inputIndex,
+ inflightDataRescalingDescriptor);
+ }
+
+ private static <T>
+ Map<InputChannelInfo, DemultiplexingRecordDeserializer<T>>
getRecordDeserializers(
+ CheckpointedInputGate checkpointedInputGate,
+ TypeSerializer<T> inputSerializer,
+ IOManager ioManager,
+ InflightDataRescalingDescriptor rescalingDescriptor,
+ Function<Integer, StreamPartitioner<?>> gatePartitioners,
+ TaskInfo taskInfo) {
+
+ RecordFilterFactory<T> recordFilterFactory =
+ new RecordFilterFactory<>(
+ taskInfo.getIndexOfThisSubtask(),
+ inputSerializer,
+ taskInfo.getNumberOfParallelSubtasks(),
+ gatePartitioners,
+ taskInfo.getMaxNumberOfParallelSubtasks());
+ final DeserializerFactory deserializerFactory = new
DeserializerFactory(ioManager);
+ Map<InputChannelInfo, DemultiplexingRecordDeserializer<T>>
deserializers =
+
Maps.newHashMapWithExpectedSize(checkpointedInputGate.getChannelInfos().size());
+ for (InputChannelInfo channelInfo :
checkpointedInputGate.getChannelInfos()) {
+ deserializers.put(
+ channelInfo,
+ DemultiplexingRecordDeserializer.create(
+ channelInfo,
+ rescalingDescriptor,
+ deserializerFactory,
+ recordFilterFactory));
+ }
+ return deserializers;
+ }
+
+ @Override
+ public StreamTaskInput<T> finishRecovery() throws IOException {
+ close();
+ return new StreamTaskNetworkInput<>(
+ checkpointedInputGate,
+ inputSerializer,
+ ioManager,
+ statusWatermarkValve,
+ inputIndex);
+ }
+
+ /**
+ * Factory method for {@link StreamTaskNetworkInput} or {@link
RescalingStreamTaskNetworkInput}
+ * depending on {@link InflightDataRescalingDescriptor}.
+ */
+ public static <T> StreamTaskInput<T> create(
+ CheckpointedInputGate checkpointedInputGate,
+ TypeSerializer<T> inputSerializer,
+ IOManager ioManager,
+ StatusWatermarkValve statusWatermarkValve,
+ int inputIndex,
+ InflightDataRescalingDescriptor
rescalingDescriptorinflightDataRescalingDescriptor,
+ Function<Integer, StreamPartitioner<?>> gatePartitioners,
+ TaskInfo taskInfo) {
+ return rescalingDescriptorinflightDataRescalingDescriptor.equals(
+ InflightDataRescalingDescriptor.NO_RESCALE)
+ ? new StreamTaskNetworkInput<>(
+ checkpointedInputGate,
+ inputSerializer,
+ ioManager,
+ statusWatermarkValve,
+ inputIndex)
+ : new RescalingStreamTaskNetworkInput<>(
+ checkpointedInputGate,
+ inputSerializer,
+ ioManager,
+ statusWatermarkValve,
+ inputIndex,
+ rescalingDescriptorinflightDataRescalingDescriptor,
+ gatePartitioners,
+ taskInfo);
+ }
+
+ protected DemultiplexingRecordDeserializer<T> getActiveSerializer(
+ InputChannelInfo channelInfo) {
+ final DemultiplexingRecordDeserializer<T> deserialier =
+ recordDeserializers.get(channelInfo);
+ if (!deserialier.hasMappings()) {
+ throw new IllegalStateException(
+ "Channel " + channelInfo + " should not receive data
during recovery.");
+ }
+ return deserialier;
+ }
+
+ protected InputStatus processEvent(BufferOrEvent bufferOrEvent) {
+ // Event received
+ final AbstractEvent event = bufferOrEvent.getEvent();
+ if (event instanceof VirtualChannelSelector) {
+ getActiveSerializer(bufferOrEvent.getChannelInfo())
+ .select((VirtualChannelSelector) event);
+ return InputStatus.MORE_AVAILABLE;
+ }
+ return super.processEvent(bufferOrEvent);
+ }
+
+ @Override
+ public CompletableFuture<Void> prepareSnapshot(
+ ChannelStateWriter channelStateWriter, long checkpointId) throws
CheckpointException {
+ throw new CheckpointException(CHECKPOINT_DECLINED_TASK_NOT_READY);
+ }
+
+ static class RecordFilterFactory<T>
+ implements Function<InputChannelInfo, Predicate<StreamRecord<T>>> {
+ private final Map<Integer, StreamPartitioner<T>> partitionerCache =
new HashMap<>(1);
+ private final Function<Integer, StreamPartitioner<?>> gatePartitioners;
+ private final TypeSerializer<T> inputSerializer;
+ private final int numberOfChannels;
+ private int subtaskIndex;
+ private int maxParallelism;
+
+ public RecordFilterFactory(
+ int subtaskIndex,
+ TypeSerializer<T> inputSerializer,
+ int numberOfChannels,
+ Function<Integer, StreamPartitioner<?>> gatePartitioners,
+ int maxParallelism) {
+ this.gatePartitioners = gatePartitioners;
+ this.inputSerializer = inputSerializer;
+ this.numberOfChannels = numberOfChannels;
+ this.subtaskIndex = subtaskIndex;
+ this.maxParallelism = maxParallelism;
+ }
+
+ @Override
+ public Predicate<StreamRecord<T>> apply(InputChannelInfo channelInfo) {
+ return new RecordFilter<>(
+ partitionerCache.computeIfAbsent(
+ channelInfo.getGateIdx(), this::createPartitioner),
+ inputSerializer,
+ subtaskIndex);
+ }
+
+ private StreamPartitioner<T> createPartitioner(Integer index) {
+ StreamPartitioner<T> partitioner = (StreamPartitioner<T>)
gatePartitioners.apply(index);
+ partitioner.setup(numberOfChannels);
+ if (partitioner instanceof ConfigurableStreamPartitioner) {
+ ((ConfigurableStreamPartitioner)
partitioner).configure(maxParallelism);
+ }
+ return partitioner;
+ }
+ }
+
+ static class DeserializerFactory
+ implements Function<
+ Integer,
RecordDeserializer<DeserializationDelegate<StreamElement>>> {
+ private final IOManager ioManager;
+
+ public DeserializerFactory(IOManager ioManager) {
+ this.ioManager = ioManager;
+ }
+
+ @Override
+ public RecordDeserializer<DeserializationDelegate<StreamElement>>
apply(
+ Integer totalChannels) {
+ return new SpillingAdaptiveSpanningRecordDeserializer<>(
+ ioManager.getSpillingDirectoriesPaths(),
+
SpillingAdaptiveSpanningRecordDeserializer.DEFAULT_THRESHOLD_FOR_SPILLING
+ / totalChannels,
+
SpillingAdaptiveSpanningRecordDeserializer.DEFAULT_FILE_BUFFER_SIZE
+ / totalChannels);
+ }
Review comment:
If the defaults in `SpillingAdaptiveSpanningRecordDeserializer` are ever
decreased, the integer division here can yield 0 with a high enough DoP
(degree of parallelism).
Should we add `Math.max(some_minimum, ...)`?
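
A minimal sketch of the clamp, with made-up numbers: `HYPOTHETICAL_DEFAULT_THRESHOLD` and `HYPOTHETICAL_MINIMUM` are placeholders for illustration only and are not the actual constants in `SpillingAdaptiveSpanningRecordDeserializer`.

```java
final class SpillingThresholdSketch {
    // Placeholder values for illustration; not the real Flink defaults.
    static final int HYPOTHETICAL_DEFAULT_THRESHOLD = 1024;
    static final int HYPOTHETICAL_MINIMUM = 64;

    /** Plain integer division: truncates to 0 once totalChannels exceeds the default. */
    static int naive(int totalChannels) {
        return HYPOTHETICAL_DEFAULT_THRESHOLD / totalChannels;
    }

    /** Same division, but never below the chosen minimum. */
    static int clamped(int totalChannels) {
        return Math.max(HYPOTHETICAL_MINIMUM, HYPOTHETICAL_DEFAULT_THRESHOLD / totalChannels);
    }
}
```

With these placeholder values, `naive(2048)` is 0 while `clamped(2048)` is 64, and `clamped(4)` is still 256 because the clamp only takes effect for small quotients.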
##########
File path:
flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RescalingStreamTaskNetworkInput.java
##########
@@ -0,0 +1,280 @@
+ protected DemultiplexingRecordDeserializer<T> getActiveSerializer(
+ InputChannelInfo channelInfo) {
+ final DemultiplexingRecordDeserializer<T> deserialier =
+ recordDeserializers.get(channelInfo);
Review comment:
nit: this lookup could be `super.getActiveSerializer(channelInfo);`
##########
File path:
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java
##########
@@ -121,27 +181,59 @@ public void recover(
"ResultSubpartitionRecoveredStateHandler#recover",
bufferBuilderAndConsumer.f1,
subpartitionInfo);
- boolean added =
- getSubpartition(subpartitionInfo)
- .add(bufferBuilderAndConsumer.f1,
Integer.MIN_VALUE);
- if (!added) {
- throw new IOException("Buffer consumer couldn't be added to
ResultSubpartition");
+ final List<CheckpointedResultSubpartition> channels =
+ getMappedChannels(subpartitionInfo);
+ for (final CheckpointedResultSubpartition channel : channels) {
+ // channel selector is created from the downstream's point of
view: the subtask of
+ // downstream = subpartition index of recovered buffer
+ final VirtualChannelSelector channelSelector =
+ new VirtualChannelSelector(
+ subpartitionInfo.getSubPartitionIdx(),
oldSubtaskIndex);
+ channel.add(
+ EventSerializer.toBufferConsumer(channelSelector,
false),
+ Integer.MIN_VALUE);
+ boolean added =
channel.add(bufferBuilderAndConsumer.f1.copy(), Integer.MIN_VALUE);
+ if (!added) {
+ throw new IOException(
+ "Buffer consumer couldn't be added to
ResultSubpartition");
+ }
}
- } else {
- bufferBuilderAndConsumer.f1.close();
}
+ bufferBuilderAndConsumer.f1.close();
Review comment:
Should it be in `finally`? Otherwise the buffer consumer is not closed when
the `IOException` above is thrown.
##########
File path:
flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RescalingStreamTaskNetworkInput.java
##########
@@ -0,0 +1,280 @@
+ /**
+ * Factory method for {@link StreamTaskNetworkInput} or {@link
RescalingStreamTaskNetworkInput}
+ * depending on {@link InflightDataRescalingDescriptor}.
+ */
+ public static <T> StreamTaskInput<T> create(
Review comment:
How about moving this (and the other) static methods into a dedicated factory
class?
That would make the responsibility of this class clearer to me.
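To illustrate the idea, here is a rough sketch of such a factory. All types are
simplified stand-ins and every name (`TaskInput`, `PlainNetworkInput`,
`RescalingNetworkInput`, `TaskNetworkInputFactory`, the `NO_RESCALE` marker) is
hypothetical; the real condition for choosing the input would come from
`InflightDataRescalingDescriptor`, which I only model as a string here.

```java
import java.util.Objects;

/** Hypothetical stand-in for StreamTaskInput; not the real Flink interface. */
interface TaskInput {
    String name();
}

/** Stand-in for the plain network input used when nothing was rescaled. */
final class PlainNetworkInput implements TaskInput {
    @Override
    public String name() {
        return "plain";
    }
}

/** Stand-in for the rescaling input; it no longer carries any static factory logic. */
final class RescalingNetworkInput implements TaskInput {
    private final String rescalingDescriptor;

    RescalingNetworkInput(String rescalingDescriptor) {
        this.rescalingDescriptor = Objects.requireNonNull(rescalingDescriptor);
    }

    @Override
    public String name() {
        return "rescaling(" + rescalingDescriptor + ")";
    }
}

/** Hypothetical factory that owns the decision which input implementation to build. */
final class TaskNetworkInputFactory {
    /** Assumed marker meaning "no in-flight data has to be rescaled". */
    static final String NO_RESCALE = "NO_RESCALE";

    private TaskNetworkInputFactory() {}

    static TaskInput create(String rescalingDescriptor) {
        return NO_RESCALE.equals(rescalingDescriptor)
                ? new PlainNetworkInput()
                : new RescalingNetworkInput(rescalingDescriptor);
    }
}
```

With something along these lines, callers only see the factory, and the input
class itself is left with the demultiplexing logic.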
##########
File path:
flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RescalingStreamTaskNetworkInput.java
##########
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.io.recovery;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.TaskInfo;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.io.InputStatus;
+import org.apache.flink.runtime.checkpoint.CheckpointException;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
+import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
+import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.disk.iomanager.IOManager;
+import org.apache.flink.runtime.io.network.api.VirtualChannelSelector;
+import
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
+import
org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
+import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
+import org.apache.flink.runtime.plugable.DeserializationDelegate;
+import org.apache.flink.streaming.runtime.io.AbstractStreamTaskNetworkInput;
+import org.apache.flink.streaming.runtime.io.RecoverableStreamTaskInput;
+import org.apache.flink.streaming.runtime.io.StreamTaskInput;
+import org.apache.flink.streaming.runtime.io.StreamTaskNetworkInput;
+import
org.apache.flink.streaming.runtime.io.checkpointing.CheckpointedInputGate;
+import
org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.streamstatus.StatusWatermarkValve;
+
+import org.apache.flink.shaded.guava18.com.google.common.collect.Maps;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Function;
+import java.util.function.Predicate;
+
+import static
org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_DECLINED_TASK_NOT_READY;
+
+/**
+ * A {@link StreamTaskNetworkInput} implementation that demultiplexes virtual
channels.
+ *
+ * <p>The demultiplexing works in two dimensions for the following cases. *
+ *
+ * <ul>
+ * <li>Subtasks of the current operator have been collapsed in a round-robin
fashion.
+ * <li>The connected output operator has been rescaled (up and down!) and
there is an overlap of
+ * channels (mostly relevant to keyed exchanges).
+ * </ul>
+ *
+ * <p>In both cases, records from multiple old channels are received over one
new physical channel,
+ * which need to demultiplex the record to correctly restore spanning records
(similar to how
+ * StreamTaskNetworkInput works).
+ *
+ * <p>Note that when both cases occur at the same time (downscaling of several
operators), there is
+ * the cross product of channels. So if two subtasks are collapsed and two
channels overlap from the
+ * output side, there is a total of 4 virtual channels.
+ */
+@Internal
+public final class RescalingStreamTaskNetworkInput<T>
+ extends AbstractStreamTaskNetworkInput<T,
DemultiplexingRecordDeserializer<T>>
+ implements RecoverableStreamTaskInput<T> {
+
+ private static final Logger LOG =
+ LoggerFactory.getLogger(RescalingStreamTaskNetworkInput.class);
+ private final IOManager ioManager;
+
+ private RescalingStreamTaskNetworkInput(
+ CheckpointedInputGate checkpointedInputGate,
+ TypeSerializer<T> inputSerializer,
+ IOManager ioManager,
+ StatusWatermarkValve statusWatermarkValve,
+ int inputIndex,
+ InflightDataRescalingDescriptor inflightDataRescalingDescriptor,
+ Function<Integer, StreamPartitioner<?>> gatePartitioners,
+ TaskInfo taskInfo) {
+ super(
+ checkpointedInputGate,
+ inputSerializer,
+ statusWatermarkValve,
+ inputIndex,
+ getRecordDeserializers(
+ checkpointedInputGate,
+ inputSerializer,
+ ioManager,
+ inflightDataRescalingDescriptor,
+ gatePartitioners,
+ taskInfo));
+ this.ioManager = ioManager;
+
+ LOG.info(
+ "Created demultiplexer for input {} from {}",
+ inputIndex,
+ inflightDataRescalingDescriptor);
+ }
+
+ private static <T>
+ Map<InputChannelInfo, DemultiplexingRecordDeserializer<T>>
getRecordDeserializers(
+ CheckpointedInputGate checkpointedInputGate,
+ TypeSerializer<T> inputSerializer,
+ IOManager ioManager,
+ InflightDataRescalingDescriptor rescalingDescriptor,
+ Function<Integer, StreamPartitioner<?>> gatePartitioners,
+ TaskInfo taskInfo) {
+
+ RecordFilterFactory<T> recordFilterFactory =
+ new RecordFilterFactory<>(
+ taskInfo.getIndexOfThisSubtask(),
+ inputSerializer,
+ taskInfo.getNumberOfParallelSubtasks(),
+ gatePartitioners,
+ taskInfo.getMaxNumberOfParallelSubtasks());
+ final DeserializerFactory deserializerFactory = new
DeserializerFactory(ioManager);
+ Map<InputChannelInfo, DemultiplexingRecordDeserializer<T>>
deserializers =
+
Maps.newHashMapWithExpectedSize(checkpointedInputGate.getChannelInfos().size());
+ for (InputChannelInfo channelInfo :
checkpointedInputGate.getChannelInfos()) {
+ deserializers.put(
+ channelInfo,
+ DemultiplexingRecordDeserializer.create(
+ channelInfo,
+ rescalingDescriptor,
+ deserializerFactory,
+ recordFilterFactory));
+ }
+ return deserializers;
+ }
+
+ @Override
+ public StreamTaskInput<T> finishRecovery() throws IOException {
+ close();
Review comment:
Do we need to make sure that all the buffers in the deserializers are
consumed before `close()` is called here?
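What I have in mind is roughly the check below. This is only a sketch with
made-up stand-in types (`RecoveryCloseCheck`, `FakeDeserializer`,
`hasUnconsumedData()`, `release()` are all hypothetical, not existing Flink
API), and whether failing fast is the right reaction is an open question.

```java
import java.util.List;

/** Hypothetical sketch; none of these types are Flink classes. */
final class RecoveryCloseCheck {

    /** Stand-in for a deserializer that may still hold bytes of a partial record. */
    static final class FakeDeserializer {
        private final int bufferedBytes;
        private boolean released;

        FakeDeserializer(int bufferedBytes) {
            this.bufferedBytes = bufferedBytes;
        }

        boolean hasUnconsumedData() {
            return bufferedBytes > 0;
        }

        void release() {
            released = true;
        }

        boolean isReleased() {
            return released;
        }
    }

    private RecoveryCloseCheck() {}

    /** Fails fast if any deserializer still holds data; otherwise releases all of them. */
    static void closeAfterRecovery(List<FakeDeserializer> deserializers) {
        // First pass: verify that nothing is left behind before releasing anything.
        for (int i = 0; i < deserializers.size(); i++) {
            if (deserializers.get(i).hasUnconsumedData()) {
                throw new IllegalStateException(
                        "Deserializer " + i + " still holds unconsumed recovered data");
            }
        }
        // Second pass: every deserializer is drained, so releasing is safe.
        for (FakeDeserializer deserializer : deserializers) {
            deserializer.release();
        }
    }
}
```

The two passes are deliberate: if one deserializer still holds data, none of
them is released, so the state stays inspectable for debugging.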
##########
File path:
flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/VirtualChannelSelector.java
##########
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.api;
+
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.runtime.event.RuntimeEvent;
+
+import java.util.Objects;
+
+/** An event that is used to demultiplex virtual channels over the same
physical channel. */
+public final class VirtualChannelSelector extends RuntimeEvent {
Review comment:
nit: To me, `SubtaskConnectionDescriptor` would be a more informative name, but
that's a matter of taste, so please ignore this if you prefer
`VirtualChannelSelector`.
nit: The terms virtual/physical channels in the Javadoc are confusing to me. How
about referring to channels before/after re-scaling instead? For example:
An event sent over a channel **after re-scaling** to signal which channel was
used **before re-scaling** for the data being sent.
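To make the naming suggestion concrete, here is a small sketch of how such a
descriptor could look as a plain value type. It is not the PR's event class:
the field names and the `crossProduct` helper are hypothetical, and I am
assuming the event carries exactly an old subtask index and an old channel
index.

```java
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;

/** Hypothetical value type; a sketch only, not the PR's VirtualChannelSelector. */
final class SubtaskConnectionDescriptor {
    /** Index of the subtask that owned the data before re-scaling. */
    private final int oldSubtaskIndex;
    /** Index of the channel the data was sent on before re-scaling. */
    private final int oldChannelIndex;

    SubtaskConnectionDescriptor(int oldSubtaskIndex, int oldChannelIndex) {
        this.oldSubtaskIndex = oldSubtaskIndex;
        this.oldChannelIndex = oldChannelIndex;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof SubtaskConnectionDescriptor)) {
            return false;
        }
        SubtaskConnectionDescriptor that = (SubtaskConnectionDescriptor) o;
        return oldSubtaskIndex == that.oldSubtaskIndex
                && oldChannelIndex == that.oldChannelIndex;
    }

    @Override
    public int hashCode() {
        return Objects.hash(oldSubtaskIndex, oldChannelIndex);
    }

    @Override
    public String toString() {
        return "SubtaskConnectionDescriptor{oldSubtask="
                + oldSubtaskIndex
                + ", oldChannel="
                + oldChannelIndex
                + "}";
    }

    /** Enumerates every descriptor that can show up on one channel after re-scaling. */
    static Set<SubtaskConnectionDescriptor> crossProduct(int[] oldSubtasks, int[] oldChannels) {
        Set<SubtaskConnectionDescriptor> descriptors = new HashSet<>();
        for (int subtask : oldSubtasks) {
            for (int channel : oldChannels) {
                descriptors.add(new SubtaskConnectionDescriptor(subtask, channel));
            }
        }
        return descriptors;
    }
}
```

For the case described in the class Javadoc (two collapsed subtasks and two
overlapping channels), `crossProduct(new int[] {0, 1}, new int[] {0, 1})`
yields four distinct descriptors, one per channel used before re-scaling.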