rkhachatryan commented on a change in pull request #13735:
URL: https://github.com/apache/flink/pull/13735#discussion_r516902951



##########
File path: 
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java
##########
@@ -327,4 +289,252 @@ public boolean hasState() {
                        || inputChannelState.hasState()
                        || resultSubpartitionState.hasState();
        }
+
+       public static Builder builder() {
+               return new Builder();
+       }
+
+       /**
+        * The builder for a new {@link OperatorSubtaskState} which can be 
obtained by {@link #builder()}.
+        */
+       public static class Builder {
+               private StateObjectCollection<OperatorStateHandle> 
managedOperatorState = StateObjectCollection.empty();
+               private StateObjectCollection<OperatorStateHandle> 
rawOperatorState = StateObjectCollection.empty();
+               private StateObjectCollection<KeyedStateHandle> 
managedKeyedState = StateObjectCollection.empty();
+               private StateObjectCollection<KeyedStateHandle> rawKeyedState = 
StateObjectCollection.empty();
+               private StateObjectCollection<InputChannelStateHandle> 
inputChannelState = StateObjectCollection.empty();
+               private StateObjectCollection<ResultSubpartitionStateHandle> 
resultSubpartitionState = StateObjectCollection.empty();
+               private VirtualChannelMapping inputChannelMappings = 
VirtualChannelMapping.NO_MAPPING;
+               private VirtualChannelMapping outputChannelMappings = 
VirtualChannelMapping.NO_MAPPING;
+
+               private Builder() {
+               }
+
+               public Builder 
setManagedOperatorState(StateObjectCollection<OperatorStateHandle> 
managedOperatorState) {
+                       this.managedOperatorState = 
checkNotNull(managedOperatorState);
+                       return this;
+               }
+
+               public Builder setManagedOperatorState(OperatorStateHandle 
managedOperatorState) {
+                       return 
setManagedOperatorState(StateObjectCollection.singleton(checkNotNull(managedOperatorState)));
+               }
+
+               public Builder 
setRawOperatorState(StateObjectCollection<OperatorStateHandle> 
rawOperatorState) {
+                       this.rawOperatorState = checkNotNull(rawOperatorState);
+                       return this;
+               }
+
+               public Builder setRawOperatorState(OperatorStateHandle 
rawOperatorState) {
+                       return 
setRawOperatorState(StateObjectCollection.singleton(checkNotNull(rawOperatorState)));
+               }
+
+               public Builder 
setManagedKeyedState(StateObjectCollection<KeyedStateHandle> managedKeyedState) 
{
+                       this.managedKeyedState = 
checkNotNull(managedKeyedState);
+                       return this;
+               }
+
+               public Builder setManagedKeyedState(KeyedStateHandle 
managedKeyedState) {
+                       return 
setManagedKeyedState(StateObjectCollection.singleton(checkNotNull(managedKeyedState)));
+               }
+
+               public Builder 
setRawKeyedState(StateObjectCollection<KeyedStateHandle> rawKeyedState) {
+                       this.rawKeyedState = checkNotNull(rawKeyedState);
+                       return this;
+               }
+
+               public Builder setRawKeyedState(KeyedStateHandle rawKeyedState) 
{
+                       return 
setRawKeyedState(StateObjectCollection.singleton(checkNotNull(rawKeyedState)));
+               }
+
+               public Builder 
setInputChannelState(StateObjectCollection<InputChannelStateHandle> 
inputChannelState) {
+                       this.inputChannelState = 
checkNotNull(inputChannelState);
+                       return this;
+               }
+
+               public Builder 
setResultSubpartitionState(StateObjectCollection<ResultSubpartitionStateHandle> 
resultSubpartitionState) {
+                       this.resultSubpartitionState = 
checkNotNull(resultSubpartitionState);
+                       return this;
+               }
+
+               public Builder setInputChannelMappings(VirtualChannelMapping 
inputChannelMappings) {
+                       this.inputChannelMappings = 
checkNotNull(inputChannelMappings);
+                       return this;
+               }
+
+               public Builder setOutputChannelMappings(VirtualChannelMapping 
outputChannelMappings) {
+                       this.outputChannelMappings = 
checkNotNull(outputChannelMappings);
+                       return this;
+               }
+
+               public OperatorSubtaskState build() {
+                       return new OperatorSubtaskState(
+                               managedOperatorState,
+                               rawOperatorState,
+                               managedKeyedState,
+                               rawKeyedState,
+                               inputChannelState,
+                               resultSubpartitionState,
+                               inputChannelMappings,
+                               outputChannelMappings);
+               }
+       }
+
+       /**
+        * Captures ambiguous mappings of old channels to new channels.
+        *
+        * <p>For inputs, this mapping implies the following:
+        * <li>
+        *     <ul>{@link #oldTaskInstances} is set when there is a rescale on 
this task potentially leading to different
+        *     key groups. Upstream task has a corresponding {@link 
#partitionMappings} where it sends data over
+        *     virtual channel while specifying the channel index in the 
VirtualChannelSelector. This subtask then
+        *     demultiplexes over the virtual subtask index.</ul>
+        *     <ul>{@link #partitionMappings} is set when there is a downscale 
of the upstream task. Upstream task has
+        *     a corresponding {@link #oldTaskInstances} where it sends data 
over virtual channel while specifying the
+        *     subtask index in the VirtualChannelSelector. This subtask then 
demultiplexes over channel indexes.</ul>
+        * </li>
+        *
+        * <p>For outputs, it's vice-versa. The information must be kept in 
sync but they are used in opposite ways for
+        * multiplexing/demultiplexing.
+        *
+        * <p>Note that in the common rescaling case both information is set 
and need to be simultaneously used. If the
+        * input subtask subsumes the state of 3 old subtasks and a channel 
corresponds to 2 old channels, then there are
+        * 6 virtual channels to be demultiplexed.
+        */
+       public static class VirtualChannelMapping implements Serializable {

Review comment:
       1. Is it actually **channel** mapping? I think it actually remaps 
operator state, doesn't it?
   2. `serialVersionUID`?
   3. I'd rather move this class to a separate file

##########
File path: 
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java
##########
@@ -327,4 +289,252 @@ public boolean hasState() {
                        || inputChannelState.hasState()
                        || resultSubpartitionState.hasState();
        }
+
+       public static Builder builder() {
+               return new Builder();
+       }
+
+       /**
+        * The builder for a new {@link OperatorSubtaskState} which can be 
obtained by {@link #builder()}.
+        */
+       public static class Builder {
+               private StateObjectCollection<OperatorStateHandle> 
managedOperatorState = StateObjectCollection.empty();
+               private StateObjectCollection<OperatorStateHandle> 
rawOperatorState = StateObjectCollection.empty();
+               private StateObjectCollection<KeyedStateHandle> 
managedKeyedState = StateObjectCollection.empty();
+               private StateObjectCollection<KeyedStateHandle> rawKeyedState = 
StateObjectCollection.empty();
+               private StateObjectCollection<InputChannelStateHandle> 
inputChannelState = StateObjectCollection.empty();
+               private StateObjectCollection<ResultSubpartitionStateHandle> 
resultSubpartitionState = StateObjectCollection.empty();
+               private VirtualChannelMapping inputChannelMappings = 
VirtualChannelMapping.NO_MAPPING;
+               private VirtualChannelMapping outputChannelMappings = 
VirtualChannelMapping.NO_MAPPING;
+
+               private Builder() {
+               }
+
+               public Builder 
setManagedOperatorState(StateObjectCollection<OperatorStateHandle> 
managedOperatorState) {
+                       this.managedOperatorState = 
checkNotNull(managedOperatorState);
+                       return this;
+               }
+
+               public Builder setManagedOperatorState(OperatorStateHandle 
managedOperatorState) {
+                       return 
setManagedOperatorState(StateObjectCollection.singleton(checkNotNull(managedOperatorState)));
+               }
+
+               public Builder 
setRawOperatorState(StateObjectCollection<OperatorStateHandle> 
rawOperatorState) {
+                       this.rawOperatorState = checkNotNull(rawOperatorState);
+                       return this;
+               }
+
+               public Builder setRawOperatorState(OperatorStateHandle 
rawOperatorState) {
+                       return 
setRawOperatorState(StateObjectCollection.singleton(checkNotNull(rawOperatorState)));
+               }
+
+               public Builder 
setManagedKeyedState(StateObjectCollection<KeyedStateHandle> managedKeyedState) 
{
+                       this.managedKeyedState = 
checkNotNull(managedKeyedState);
+                       return this;
+               }
+
+               public Builder setManagedKeyedState(KeyedStateHandle 
managedKeyedState) {
+                       return 
setManagedKeyedState(StateObjectCollection.singleton(checkNotNull(managedKeyedState)));
+               }
+
+               public Builder 
setRawKeyedState(StateObjectCollection<KeyedStateHandle> rawKeyedState) {
+                       this.rawKeyedState = checkNotNull(rawKeyedState);
+                       return this;
+               }
+
+               public Builder setRawKeyedState(KeyedStateHandle rawKeyedState) 
{
+                       return 
setRawKeyedState(StateObjectCollection.singleton(checkNotNull(rawKeyedState)));
+               }
+
+               public Builder 
setInputChannelState(StateObjectCollection<InputChannelStateHandle> 
inputChannelState) {
+                       this.inputChannelState = 
checkNotNull(inputChannelState);
+                       return this;
+               }
+
+               public Builder 
setResultSubpartitionState(StateObjectCollection<ResultSubpartitionStateHandle> 
resultSubpartitionState) {
+                       this.resultSubpartitionState = 
checkNotNull(resultSubpartitionState);
+                       return this;
+               }
+
+               public Builder setInputChannelMappings(VirtualChannelMapping 
inputChannelMappings) {
+                       this.inputChannelMappings = 
checkNotNull(inputChannelMappings);
+                       return this;
+               }
+
+               public Builder setOutputChannelMappings(VirtualChannelMapping 
outputChannelMappings) {
+                       this.outputChannelMappings = 
checkNotNull(outputChannelMappings);
+                       return this;
+               }
+
+               public OperatorSubtaskState build() {
+                       return new OperatorSubtaskState(
+                               managedOperatorState,
+                               rawOperatorState,
+                               managedKeyedState,
+                               rawKeyedState,
+                               inputChannelState,
+                               resultSubpartitionState,
+                               inputChannelMappings,
+                               outputChannelMappings);
+               }
+       }
+
+       /**
+        * Captures ambiguous mappings of old channels to new channels.
+        *
+        * <p>For inputs, this mapping implies the following:
+        * <li>
+        *     <ul>{@link #oldTaskInstances} is set when there is a rescale on 
this task potentially leading to different
+        *     key groups. Upstream task has a corresponding {@link 
#partitionMappings} where it sends data over
+        *     virtual channel while specifying the channel index in the 
VirtualChannelSelector. This subtask then
+        *     demultiplexes over the virtual subtask index.</ul>
+        *     <ul>{@link #partitionMappings} is set when there is a downscale 
of the upstream task. Upstream task has
+        *     a corresponding {@link #oldTaskInstances} where it sends data 
over virtual channel while specifying the
+        *     subtask index in the VirtualChannelSelector. This subtask then 
demultiplexes over channel indexes.</ul>
+        * </li>
+        *
+        * <p>For outputs, it's vice-versa. The information must be kept in 
sync but they are used in opposite ways for
+        * multiplexing/demultiplexing.
+        *
+        * <p>Note that in the common rescaling case both information is set 
and need to be simultaneously used. If the
+        * input subtask subsumes the state of 3 old subtasks and a channel 
corresponds to 2 old channels, then there are
+        * 6 virtual channels to be demultiplexed.
+        */
+       public static class VirtualChannelMapping implements Serializable {
+               public static final PartitionMapping NO_CHANNEL_MAPPING = new 
PartitionMapping(emptyList());
+               public static final List<PartitionMapping> NO_PARTITIONS = 
emptyList();
+               public static final BitSet NO_SUBTASKS = new BitSet();
+               public static final VirtualChannelMapping NO_MAPPING = new 
VirtualChannelMapping(NO_SUBTASKS, NO_PARTITIONS);
+
+               /**
+                * Set when several operator instances are merged into one.
+                */
+               private final BitSet oldTaskInstances;
+
+               /**
+                * Set when channels are merged because the connected operator 
has been rescaled.
+                */
+               private final List<PartitionMapping> partitionMappings;

Review comment:
       To me, list isn't an obvious choice to provide index-based access.
   Array or hashtable would be more readable to me and guarantee O(1) time.

##########
File path: 
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
##########
@@ -99,32 +119,40 @@ public void assignStates() {
                                                operatorID,
                                                
executionJobVertex.getParallelism(),
                                                
executionJobVertex.getMaxParallelism());
-                               } else if 
(operatorState.getNumberCollectedStates() > 0) {
-                                       statelessSubTasks = false;
                                }
-                               operatorStates.add(operatorState);
+                               
operatorStates.put(operatorIDPair.getGeneratedOperatorID(), operatorState);
                        }
-                       if (!statelessSubTasks) { // skip tasks where no 
operator has any state
-                               assignAttemptState(executionJobVertex, 
operatorStates);
+
+                       final TaskStateAssignment stateAssignment = new 
TaskStateAssignment(executionJobVertex, operatorStates);
+                       vertexAssignments.put(executionJobVertex, 
stateAssignment);
+                       for (final IntermediateResult producedDataSet : 
executionJobVertex.getInputs()) {
+                               consumerAssignment.put(producedDataSet, 
stateAssignment);

Review comment:
       1. I wonder if this `put` can overwrite an existing `assignment` (e.g. with 
a `UnionGate`).
   2. `IntermediateResult`, used as a key here, doesn't override equals/hashCode - is 
that intentional?

##########
File path: 
flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/ForwardPartitioner.java
##########
@@ -43,4 +44,14 @@ public int 
selectChannel(SerializationDelegate<StreamRecord<T>> record) {
        public String toString() {
                return "FORWARD";
        }
+
+       @Override
+       public ChannelStateRescaler getUpstreamChannelStateRescaler() {
+               return ChannelStateRescaler.FIRST_CHANNEL;
+       }
+
+       @Override
+       public ChannelStateRescaler getDownstreamChannelStateRescaler() {
+               return ChannelStateRescaler.ROUND_ROBIN;

Review comment:
       Why isn't channel state distributed to the same subtask as before (or 
1st if out of range)?
   (ditto upstream)
   Shouldn't recovered buffers sent from the upstream match what is on the 
downstream?

##########
File path: 
flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/GlobalPartitioner.java
##########
@@ -40,6 +41,16 @@ public int 
selectChannel(SerializationDelegate<StreamRecord<T>> record) {
                return this;
        }
 
+       @Override
+       public ChannelStateRescaler getUpstreamChannelStateRescaler() {
+               return ChannelStateRescaler.FIRST_CHANNEL;
+       }
+
+       @Override
+       public ChannelStateRescaler getDownstreamChannelStateRescaler() {
+               return ChannelStateRescaler.ROUND_ROBIN;

Review comment:
       I'm probably misunderstanding it, but it seems it should be the opposite:
   upstream: `ROUND_ROBIN` (or both FIRST)
   downstream: `FIRST_CHANNEL`
   

##########
File path: 
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
##########
@@ -620,18 +758,112 @@ private static void checkStateMappingCompleteness(
                        chainOpParallelStates,
                        oldParallelism,
                        newParallelism);
+       }
+
+       static class TaskStateAssignment {

Review comment:
       1. It would ease review if this class were introduced in a separate 
refactoring commit (though that's probably too difficult at this stage)
   2. Can we extract this class to a separate file?

##########
File path: 
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
##########
@@ -67,29 +83,33 @@
        private final long restoreCheckpointId;
        private final boolean allowNonRestoredState;
 
+       private final Map<IntermediateResult, TaskStateAssignment> 
consumerAssignment = new HashMap<>();
+       private final Map<ChannelStateRescaler, 
ChannelRescalerRepartitioner<Object>> rescalerRepartitioners =

Review comment:
       IIUC, this is a cache, right?
   I doubt that we really need it: it's only used on re/starts and can save 
maybe hundreds of ms with DOP=1K. And without it, the code would be much 
simpler.
   WDYT?

##########
File path: 
flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
##########
@@ -725,6 +727,15 @@ public void deploy() throws JobException {
                        LOG.info("Deploying {} (attempt #{}) with attempt id {} 
to {} with allocation id {}", vertex.getTaskNameWithSubtaskIndex(),
                                attemptNumber, 
vertex.getCurrentExecutionAttempt().getAttemptId(), 
getAssignedResourceLocation(), slot.getAllocationId());
 
+                       if (taskRestore != null) {
+                               
checkState(taskRestore.getTaskStateSnapshot().getSubtaskStateMappings().stream().allMatch(entry
 ->
+                                       
entry.getValue().getInputChannelMappings().stream()
+                                               .allMatch(mapping -> 
mapping.equals(OperatorSubtaskState.VirtualChannelMapping.NO_MAPPING)) &&
+                                               
entry.getValue().getOutputChannelMappings().stream()
+                                                       .allMatch(mapping -> 
mapping.equals(OperatorSubtaskState.VirtualChannelMapping.NO_MAPPING))
+                               ), "Rescaling from unaligned checkpoint is not 
yet supported.");

Review comment:
       1. This doesn't compile (cannot find symbol `getInputChannelMappings`)
   2. Probably makes sense to extract testing methods (unless it's a temporary 
method)

##########
File path: 
flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
##########
@@ -692,6 +692,8 @@ private void connect(Integer headOfChain, StreamEdge edge) {
                }
                // set strategy name so that web interface can show it.
                jobEdge.setShipStrategyName(partitioner.toString());
+               
jobEdge.setDownstreamChannelStateRescaler(partitioner.getUpstreamChannelStateRescaler());
+               
jobEdge.setUpstreamChannelStateRescaler(partitioner.getDownstreamChannelStateRescaler());

Review comment:
       OK now I see why the Rescalers used seemed inverted to me :)
   So why is it inverted here? (maybe add a comment?)

##########
File path: 
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java
##########
@@ -327,4 +289,252 @@ public boolean hasState() {
                        || inputChannelState.hasState()
                        || resultSubpartitionState.hasState();
        }
+
+       public static Builder builder() {
+               return new Builder();
+       }
+
+       /**
+        * The builder for a new {@link OperatorSubtaskState} which can be 
obtained by {@link #builder()}.
+        */
+       public static class Builder {
+               private StateObjectCollection<OperatorStateHandle> 
managedOperatorState = StateObjectCollection.empty();
+               private StateObjectCollection<OperatorStateHandle> 
rawOperatorState = StateObjectCollection.empty();
+               private StateObjectCollection<KeyedStateHandle> 
managedKeyedState = StateObjectCollection.empty();
+               private StateObjectCollection<KeyedStateHandle> rawKeyedState = 
StateObjectCollection.empty();
+               private StateObjectCollection<InputChannelStateHandle> 
inputChannelState = StateObjectCollection.empty();
+               private StateObjectCollection<ResultSubpartitionStateHandle> 
resultSubpartitionState = StateObjectCollection.empty();
+               private VirtualChannelMapping inputChannelMappings = 
VirtualChannelMapping.NO_MAPPING;
+               private VirtualChannelMapping outputChannelMappings = 
VirtualChannelMapping.NO_MAPPING;
+
+               private Builder() {
+               }
+
+               public Builder 
setManagedOperatorState(StateObjectCollection<OperatorStateHandle> 
managedOperatorState) {
+                       this.managedOperatorState = 
checkNotNull(managedOperatorState);
+                       return this;
+               }
+
+               public Builder setManagedOperatorState(OperatorStateHandle 
managedOperatorState) {
+                       return 
setManagedOperatorState(StateObjectCollection.singleton(checkNotNull(managedOperatorState)));
+               }
+
+               public Builder 
setRawOperatorState(StateObjectCollection<OperatorStateHandle> 
rawOperatorState) {
+                       this.rawOperatorState = checkNotNull(rawOperatorState);
+                       return this;
+               }
+
+               public Builder setRawOperatorState(OperatorStateHandle 
rawOperatorState) {
+                       return 
setRawOperatorState(StateObjectCollection.singleton(checkNotNull(rawOperatorState)));
+               }
+
+               public Builder 
setManagedKeyedState(StateObjectCollection<KeyedStateHandle> managedKeyedState) 
{
+                       this.managedKeyedState = 
checkNotNull(managedKeyedState);
+                       return this;
+               }
+
+               public Builder setManagedKeyedState(KeyedStateHandle 
managedKeyedState) {
+                       return 
setManagedKeyedState(StateObjectCollection.singleton(checkNotNull(managedKeyedState)));
+               }
+
+               public Builder 
setRawKeyedState(StateObjectCollection<KeyedStateHandle> rawKeyedState) {
+                       this.rawKeyedState = checkNotNull(rawKeyedState);
+                       return this;
+               }
+
+               public Builder setRawKeyedState(KeyedStateHandle rawKeyedState) 
{
+                       return 
setRawKeyedState(StateObjectCollection.singleton(checkNotNull(rawKeyedState)));
+               }
+
+               public Builder 
setInputChannelState(StateObjectCollection<InputChannelStateHandle> 
inputChannelState) {
+                       this.inputChannelState = 
checkNotNull(inputChannelState);
+                       return this;
+               }
+
+               public Builder 
setResultSubpartitionState(StateObjectCollection<ResultSubpartitionStateHandle> 
resultSubpartitionState) {
+                       this.resultSubpartitionState = 
checkNotNull(resultSubpartitionState);
+                       return this;
+               }
+
+               public Builder setInputChannelMappings(VirtualChannelMapping 
inputChannelMappings) {
+                       this.inputChannelMappings = 
checkNotNull(inputChannelMappings);
+                       return this;
+               }
+
+               public Builder setOutputChannelMappings(VirtualChannelMapping 
outputChannelMappings) {
+                       this.outputChannelMappings = 
checkNotNull(outputChannelMappings);
+                       return this;
+               }
+
+               public OperatorSubtaskState build() {
+                       return new OperatorSubtaskState(
+                               managedOperatorState,
+                               rawOperatorState,
+                               managedKeyedState,
+                               rawKeyedState,
+                               inputChannelState,
+                               resultSubpartitionState,
+                               inputChannelMappings,
+                               outputChannelMappings);
+               }
+       }
+
+       /**
+        * Captures ambiguous mappings of old channels to new channels.
+        *
+        * <p>For inputs, this mapping implies the following:
+        * <li>
+        *     <ul>{@link #oldTaskInstances} is set when there is a rescale on 
this task potentially leading to different
+        *     key groups. Upstream task has a corresponding {@link 
#partitionMappings} where it sends data over
+        *     virtual channel while specifying the channel index in the 
VirtualChannelSelector. This subtask then
+        *     demultiplexes over the virtual subtask index.</ul>
+        *     <ul>{@link #partitionMappings} is set when there is a downscale 
of the upstream task. Upstream task has
+        *     a corresponding {@link #oldTaskInstances} where it sends data 
over virtual channel while specifying the
+        *     subtask index in the VirtualChannelSelector. This subtask then 
demultiplexes over channel indexes.</ul>
+        * </li>
+        *
+        * <p>For outputs, it's vice-versa. The information must be kept in 
sync but they are used in opposite ways for
+        * multiplexing/demultiplexing.
+        *
+        * <p>Note that in the common rescaling case both information is set 
and need to be simultaneously used. If the
+        * input subtask subsumes the state of 3 old subtasks and a channel 
corresponds to 2 old channels, then there are
+        * 6 virtual channels to be demultiplexed.
+        */
+       public static class VirtualChannelMapping implements Serializable {
+               public static final PartitionMapping NO_CHANNEL_MAPPING = new 
PartitionMapping(emptyList());
+               public static final List<PartitionMapping> NO_PARTITIONS = 
emptyList();
+               public static final BitSet NO_SUBTASKS = new BitSet();
+               public static final VirtualChannelMapping NO_MAPPING = new 
VirtualChannelMapping(NO_SUBTASKS, NO_PARTITIONS);
+
+               /**
+                * Set when several operator instances are merged into one.
+                */
+               private final BitSet oldTaskInstances;

Review comment:
       Why are bitsets used here and throughout the PR?
   I think just `Set<Integer>` would be
   a) more readable
   b) more efficient (no need for extra words if only high bits are set)

##########
File path: 
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java
##########
@@ -327,4 +289,252 @@ public boolean hasState() {
                        || inputChannelState.hasState()
                        || resultSubpartitionState.hasState();
        }
+
+       public static Builder builder() {
+               return new Builder();
+       }
+
+       /**
+        * The builder for a new {@link OperatorSubtaskState} which can be 
obtained by {@link #builder()}.
+        */
+       public static class Builder {
+               private StateObjectCollection<OperatorStateHandle> 
managedOperatorState = StateObjectCollection.empty();
+               private StateObjectCollection<OperatorStateHandle> 
rawOperatorState = StateObjectCollection.empty();
+               private StateObjectCollection<KeyedStateHandle> 
managedKeyedState = StateObjectCollection.empty();
+               private StateObjectCollection<KeyedStateHandle> rawKeyedState = 
StateObjectCollection.empty();
+               private StateObjectCollection<InputChannelStateHandle> 
inputChannelState = StateObjectCollection.empty();
+               private StateObjectCollection<ResultSubpartitionStateHandle> 
resultSubpartitionState = StateObjectCollection.empty();
+               private VirtualChannelMapping inputChannelMappings = 
VirtualChannelMapping.NO_MAPPING;
+               private VirtualChannelMapping outputChannelMappings = 
VirtualChannelMapping.NO_MAPPING;
+
+               private Builder() {
+               }
+
+               public Builder 
setManagedOperatorState(StateObjectCollection<OperatorStateHandle> 
managedOperatorState) {
+                       this.managedOperatorState = 
checkNotNull(managedOperatorState);
+                       return this;
+               }
+
+               public Builder setManagedOperatorState(OperatorStateHandle 
managedOperatorState) {
+                       return 
setManagedOperatorState(StateObjectCollection.singleton(checkNotNull(managedOperatorState)));
+               }
+
+               public Builder 
setRawOperatorState(StateObjectCollection<OperatorStateHandle> 
rawOperatorState) {
+                       this.rawOperatorState = checkNotNull(rawOperatorState);
+                       return this;
+               }
+
+               public Builder setRawOperatorState(OperatorStateHandle 
rawOperatorState) {
+                       return 
setRawOperatorState(StateObjectCollection.singleton(checkNotNull(rawOperatorState)));
+               }
+
+               public Builder 
setManagedKeyedState(StateObjectCollection<KeyedStateHandle> managedKeyedState) 
{
+                       this.managedKeyedState = 
checkNotNull(managedKeyedState);
+                       return this;
+               }
+
+               public Builder setManagedKeyedState(KeyedStateHandle 
managedKeyedState) {
+                       return 
setManagedKeyedState(StateObjectCollection.singleton(checkNotNull(managedKeyedState)));
+               }
+
+               public Builder 
setRawKeyedState(StateObjectCollection<KeyedStateHandle> rawKeyedState) {
+                       this.rawKeyedState = checkNotNull(rawKeyedState);
+                       return this;
+               }
+
+               public Builder setRawKeyedState(KeyedStateHandle rawKeyedState) 
{
+                       return 
setRawKeyedState(StateObjectCollection.singleton(checkNotNull(rawKeyedState)));
+               }
+
+               public Builder 
setInputChannelState(StateObjectCollection<InputChannelStateHandle> 
inputChannelState) {
+                       this.inputChannelState = 
checkNotNull(inputChannelState);
+                       return this;
+               }
+
+               public Builder 
setResultSubpartitionState(StateObjectCollection<ResultSubpartitionStateHandle> 
resultSubpartitionState) {
+                       this.resultSubpartitionState = 
checkNotNull(resultSubpartitionState);
+                       return this;
+               }
+
+               public Builder setInputChannelMappings(VirtualChannelMapping 
inputChannelMappings) {
+                       this.inputChannelMappings = 
checkNotNull(inputChannelMappings);
+                       return this;
+               }
+
+               public Builder setOutputChannelMappings(VirtualChannelMapping 
outputChannelMappings) {
+                       this.outputChannelMappings = 
checkNotNull(outputChannelMappings);
+                       return this;
+               }
+
+               public OperatorSubtaskState build() {
+                       return new OperatorSubtaskState(
+                               managedOperatorState,
+                               rawOperatorState,
+                               managedKeyedState,
+                               rawKeyedState,
+                               inputChannelState,
+                               resultSubpartitionState,
+                               inputChannelMappings,
+                               outputChannelMappings);
+               }
+       }
+
+       /**
+        * Captures ambiguous mappings of old channels to new channels.
+        *
+        * <p>For inputs, this mapping implies the following:
+        * <li>
+        *     <ul>{@link #oldTaskInstances} is set when there is a rescale on 
this task potentially leading to different
+        *     key groups. Upstream task has a corresponding {@link 
#partitionMappings} where it sends data over
+        *     virtual channel while specifying the channel index in the 
VirtualChannelSelector. This subtask then
+        *     demultiplexes over the virtual subtask index.</ul>
+        *     <ul>{@link #partitionMappings} is set when there is a downscale 
of the upstream task. Upstream task has
+        *     a corresponding {@link #oldTaskInstances} where it sends data 
over virtual channel while specifying the
+        *     subtask index in the VirtualChannelSelector. This subtask then 
demultiplexes over channel indexes.</ul>
+        * </li>
+        *
+        * <p>For outputs, it's vice-versa. The information must be kept in 
sync but they are used in opposite ways for
+        * multiplexing/demultiplexing.
+        *
+        * <p>Note that in the common rescaling case both information is set 
and need to be simultaneously used. If the
+        * input subtask subsumes the state of 3 old subtasks and a channel 
corresponds to 2 old channels, then there are
+        * 6 virtual channels to be demultiplexed.
+        */
+       public static class VirtualChannelMapping implements Serializable {
+               public static final PartitionMapping NO_CHANNEL_MAPPING = new 
PartitionMapping(emptyList());
+               public static final List<PartitionMapping> NO_PARTITIONS = 
emptyList();
+               public static final BitSet NO_SUBTASKS = new BitSet();
+               public static final VirtualChannelMapping NO_MAPPING = new 
VirtualChannelMapping(NO_SUBTASKS, NO_PARTITIONS);
+
+               /**
+                * Set when several operator instances are merged into one.
+                */
+               private final BitSet oldTaskInstances;
+
+               /**
+                * Set when channels are merged because the connected operator 
has been rescaled.
+                */
+               private final List<PartitionMapping> partitionMappings;
+
+               public VirtualChannelMapping(BitSet oldTaskInstances, 
List<PartitionMapping> partitionMappings) {
+                       this.oldTaskInstances = oldTaskInstances;
+                       this.partitionMappings = partitionMappings;
+               }
+
+               @Override
+               public boolean equals(Object o) {
+                       if (this == o) {
+                               return true;
+                       }
+                       if (o == null || getClass() != o.getClass()) {
+                               return false;
+                       }
+                       final VirtualChannelMapping that = 
(VirtualChannelMapping) o;
+                       return oldTaskInstances.equals(that.oldTaskInstances) &&
+                               
partitionMappings.equals(that.partitionMappings);
+               }
+
+               public int[] getOldTaskInstances(int defaultSubtask) {
+                       return oldTaskInstances.equals(NO_SUBTASKS) ?
+                               new int[] {defaultSubtask} :
+                               oldTaskInstances.stream().toArray();
+               }
+
+               public PartitionMapping getPartitionMapping(int partitionIndex) 
{
+                       if (partitionMappings.isEmpty()) {
+                               return NO_CHANNEL_MAPPING;
+                       }
+                       return partitionMappings.get(partitionIndex);
+               }
+
+               @Override
+               public int hashCode() {
+                       return Objects.hash(oldTaskInstances, 
partitionMappings);
+               }
+
+               @Override
+               public String toString() {
+                       return "VirtualChannelMapping{" +
+                               "oldTaskInstances=" + oldTaskInstances +
+                               ", partitionMappings=" + partitionMappings +
+                               '}';
+               }
+       }
+
+       /**
+        * Contains the fine-grain channel mappings that occur when a connected 
operator has been rescaled.
+        */
+       public static class PartitionMapping implements Serializable {
+
+               /**
+                * For each new channel (=index), all old channels are set.
+                */
+               private final List<BitSet> newToOldChannelIndexes;

Review comment:
       Again, the choice of data structures isn't obvious. Why not 
`Map<Integer, Integer>`?
   
   (ditto `oldToNewChannelIndexes`)

##########
File path: 
flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractChannelStateHandle.java
##########
@@ -46,6 +46,11 @@
        private final List<Long> offsets;
        private final long size;
 
+       /**
+        * The original subtask index before rescaling recovery.
+        */
+       private int originalSubtaskIndex;

Review comment:
       You probably should serialize this field and add it to the 
equals/hashCode.

##########
File path: 
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java
##########
@@ -327,4 +289,252 @@ public boolean hasState() {
                        || inputChannelState.hasState()
                        || resultSubpartitionState.hasState();
        }
+
+       public static Builder builder() {
+               return new Builder();
+       }
+
+       /**
+        * The builder for a new {@link OperatorSubtaskState} which can be 
obtained by {@link #builder()}.
+        */
+       public static class Builder {
+               private StateObjectCollection<OperatorStateHandle> 
managedOperatorState = StateObjectCollection.empty();
+               private StateObjectCollection<OperatorStateHandle> 
rawOperatorState = StateObjectCollection.empty();
+               private StateObjectCollection<KeyedStateHandle> 
managedKeyedState = StateObjectCollection.empty();
+               private StateObjectCollection<KeyedStateHandle> rawKeyedState = 
StateObjectCollection.empty();
+               private StateObjectCollection<InputChannelStateHandle> 
inputChannelState = StateObjectCollection.empty();
+               private StateObjectCollection<ResultSubpartitionStateHandle> 
resultSubpartitionState = StateObjectCollection.empty();
+               private VirtualChannelMapping inputChannelMappings = 
VirtualChannelMapping.NO_MAPPING;
+               private VirtualChannelMapping outputChannelMappings = 
VirtualChannelMapping.NO_MAPPING;
+
+               private Builder() {
+               }
+
+               public Builder 
setManagedOperatorState(StateObjectCollection<OperatorStateHandle> 
managedOperatorState) {
+                       this.managedOperatorState = 
checkNotNull(managedOperatorState);
+                       return this;
+               }
+
+               public Builder setManagedOperatorState(OperatorStateHandle 
managedOperatorState) {
+                       return 
setManagedOperatorState(StateObjectCollection.singleton(checkNotNull(managedOperatorState)));
+               }
+
+               public Builder 
setRawOperatorState(StateObjectCollection<OperatorStateHandle> 
rawOperatorState) {
+                       this.rawOperatorState = checkNotNull(rawOperatorState);
+                       return this;
+               }
+
+               public Builder setRawOperatorState(OperatorStateHandle 
rawOperatorState) {
+                       return 
setRawOperatorState(StateObjectCollection.singleton(checkNotNull(rawOperatorState)));
+               }
+
+               public Builder 
setManagedKeyedState(StateObjectCollection<KeyedStateHandle> managedKeyedState) 
{
+                       this.managedKeyedState = 
checkNotNull(managedKeyedState);
+                       return this;
+               }
+
+               public Builder setManagedKeyedState(KeyedStateHandle 
managedKeyedState) {
+                       return 
setManagedKeyedState(StateObjectCollection.singleton(checkNotNull(managedKeyedState)));
+               }
+
+               public Builder 
setRawKeyedState(StateObjectCollection<KeyedStateHandle> rawKeyedState) {
+                       this.rawKeyedState = checkNotNull(rawKeyedState);
+                       return this;
+               }
+
+               public Builder setRawKeyedState(KeyedStateHandle rawKeyedState) 
{
+                       return 
setRawKeyedState(StateObjectCollection.singleton(checkNotNull(rawKeyedState)));
+               }
+
+               public Builder 
setInputChannelState(StateObjectCollection<InputChannelStateHandle> 
inputChannelState) {
+                       this.inputChannelState = 
checkNotNull(inputChannelState);
+                       return this;
+               }
+
+               public Builder 
setResultSubpartitionState(StateObjectCollection<ResultSubpartitionStateHandle> 
resultSubpartitionState) {
+                       this.resultSubpartitionState = 
checkNotNull(resultSubpartitionState);
+                       return this;
+               }
+
+               public Builder setInputChannelMappings(VirtualChannelMapping 
inputChannelMappings) {
+                       this.inputChannelMappings = 
checkNotNull(inputChannelMappings);
+                       return this;
+               }
+
+               public Builder setOutputChannelMappings(VirtualChannelMapping 
outputChannelMappings) {
+                       this.outputChannelMappings = 
checkNotNull(outputChannelMappings);
+                       return this;
+               }
+
+               public OperatorSubtaskState build() {
+                       return new OperatorSubtaskState(
+                               managedOperatorState,
+                               rawOperatorState,
+                               managedKeyedState,
+                               rawKeyedState,
+                               inputChannelState,
+                               resultSubpartitionState,
+                               inputChannelMappings,
+                               outputChannelMappings);
+               }
+       }
+
+       /**
+        * Captures ambiguous mappings of old channels to new channels.
+        *
+        * <p>For inputs, this mapping implies the following:
+        * <li>
+        *     <ul>{@link #oldTaskInstances} is set when there is a rescale on 
this task potentially leading to different
+        *     key groups. Upstream task has a corresponding {@link 
#partitionMappings} where it sends data over
+        *     virtual channel while specifying the channel index in the 
VirtualChannelSelector. This subtask then
+        *     demultiplexes over the virtual subtask index.</ul>
+        *     <ul>{@link #partitionMappings} is set when there is a downscale 
of the upstream task. Upstream task has
+        *     a corresponding {@link #oldTaskInstances} where it sends data 
over virtual channel while specifying the
+        *     subtask index in the VirtualChannelSelector. This subtask then 
demultiplexes over channel indexes.</ul>
+        * </li>
+        *
+        * <p>For outputs, it's vice-versa. The information must be kept in 
sync but they are used in opposite ways for
+        * multiplexing/demultiplexing.
+        *
+        * <p>Note that in the common rescaling case both information is set 
and need to be simultaneously used. If the
+        * input subtask subsumes the state of 3 old subtasks and a channel 
corresponds to 2 old channels, then there are
+        * 6 virtual channels to be demultiplexed.
+        */
+       public static class VirtualChannelMapping implements Serializable {
+               public static final PartitionMapping NO_CHANNEL_MAPPING = new 
PartitionMapping(emptyList());
+               public static final List<PartitionMapping> NO_PARTITIONS = 
emptyList();
+               public static final BitSet NO_SUBTASKS = new BitSet();
+               public static final VirtualChannelMapping NO_MAPPING = new 
VirtualChannelMapping(NO_SUBTASKS, NO_PARTITIONS);
+
+               /**
+                * Set when several operator instances are merged into one.
+                */
+               private final BitSet oldTaskInstances;
+
+               /**
+                * Set when channels are merged because the connected operator 
has been rescaled.
+                */
+               private final List<PartitionMapping> partitionMappings;
+
+               public VirtualChannelMapping(BitSet oldTaskInstances, 
List<PartitionMapping> partitionMappings) {
+                       this.oldTaskInstances = oldTaskInstances;
+                       this.partitionMappings = partitionMappings;
+               }
+
+               @Override
+               public boolean equals(Object o) {
+                       if (this == o) {
+                               return true;
+                       }
+                       if (o == null || getClass() != o.getClass()) {
+                               return false;
+                       }
+                       final VirtualChannelMapping that = 
(VirtualChannelMapping) o;
+                       return oldTaskInstances.equals(that.oldTaskInstances) &&
+                               
partitionMappings.equals(that.partitionMappings);
+               }
+
+               public int[] getOldTaskInstances(int defaultSubtask) {
+                       return oldTaskInstances.equals(NO_SUBTASKS) ?
+                               new int[] {defaultSubtask} :
+                               oldTaskInstances.stream().toArray();
+               }
+
+               public PartitionMapping getPartitionMapping(int partitionIndex) 
{
+                       if (partitionMappings.isEmpty()) {
+                               return NO_CHANNEL_MAPPING;
+                       }
+                       return partitionMappings.get(partitionIndex);
+               }
+
+               @Override
+               public int hashCode() {
+                       return Objects.hash(oldTaskInstances, 
partitionMappings);
+               }
+
+               @Override
+               public String toString() {
+                       return "VirtualChannelMapping{" +
+                               "oldTaskInstances=" + oldTaskInstances +
+                               ", partitionMappings=" + partitionMappings +
+                               '}';
+               }
+       }
+
+       /**
+        * Contains the fine-grain channel mappings that occur when a connected 
operator has been rescaled.
+        */
+       public static class PartitionMapping implements Serializable {

Review comment:
       1. The name `PartitionMapping` suggests an association with 
`ResultPartition`, but this class maps input channels and subPartitions, 
right? How about `RescaledChannelsMapping`?
   2. Should this class declare a `serialVersionUID`?
   3. nit: I'd extract this class too

##########
File path: 
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
##########
@@ -620,18 +758,112 @@ private static void checkStateMappingCompleteness(
                        chainOpParallelStates,
                        oldParallelism,
                        newParallelism);
+       }
+
+       static class TaskStateAssignment {
+               final ExecutionJobVertex executionJobVertex;
+               final Map<OperatorID, OperatorState> oldState;
+               final boolean hasState;
+               final int newParallelism;
+               final OperatorID inputOperatorID;
+               final OperatorID outputOperatorID;
+
+               final Map<OperatorInstanceID, List<OperatorStateHandle>> 
subManagedOperatorState;
+               final Map<OperatorInstanceID, List<OperatorStateHandle>> 
subRawOperatorState;
+               final Map<OperatorInstanceID, List<KeyedStateHandle>> 
subManagedKeyedState;
+               final Map<OperatorInstanceID, List<KeyedStateHandle>> 
subRawKeyedState;
+
+               final Map<OperatorInstanceID, List<InputChannelStateHandle>> 
inputChannelStates;
+               final Map<OperatorInstanceID, 
List<ResultSubpartitionStateHandle>> resultSubpartitionStates;
+               /** The subpartitions mappings per partition set when the 
output operator for a partition was rescaled. */
+               List<BitSet> outputOperatorInstanceMappings = emptyList();
+               /** The input channel mappings per input set when the input 
operator for a gate was rescaled. */
+               List<BitSet> inputOperatorInstanceMappings = emptyList();
+               /** The subpartitions mappings of the upstream task per input 
set when its output operator was rescaled. */
+               final Map<Integer, List<BitSet>> upstreamVirtualChannels;
+               /** The input channel mappings of the downstream task per 
partition set when its input operator was rescaled. */
+               final Map<Integer, List<BitSet>> downStreamVirtualChannels;

Review comment:
       How about storing a reference to the downstream/upstream 
`TaskStateAssignment` and getting the mappings from there?
   I think it would be more readable.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to