zhuzhurk commented on code in PR #21570:
URL: https://github.com/apache/flink/pull/21570#discussion_r1060266079


##########
flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertexInputInfo.java:
##########
@@ -68,4 +68,9 @@ public boolean equals(Object obj) {
             return false;
         }
     }
+
+    @Override
+    public String toString() {
+        return subtaskIndex + " " + partitionIndexRange + " " + subpartitionIndexRange;

Review Comment:
   The string is not easy to understand. Maybe "subtask index: {}, consumed partition index range: {}, consumed subpartition index range: {}"?
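A minimal sketch of the suggested format, for illustration only. `RangeSketch` and `InputInfoSketch` are hypothetical stand-ins for `IndexRange` and `ExecutionVertexInputInfo` (the real `IndexRange#toString()` may print differently), and the sketch uses `String.format` with `%d`/`%s` in place of the `{}` placeholders from the suggestion.

```java
// Hypothetical stand-in for Flink's IndexRange; the real toString() format may differ.
final class RangeSketch {
    private final int startIndex;
    private final int endIndex;

    RangeSketch(int startIndex, int endIndex) {
        this.startIndex = startIndex;
        this.endIndex = endIndex;
    }

    @Override
    public String toString() {
        return "[" + startIndex + ", " + endIndex + "]";
    }
}

// Hypothetical stand-in for ExecutionVertexInputInfo, reduced to its three fields and toString().
final class InputInfoSketch {
    private final int subtaskIndex;
    private final RangeSketch partitionIndexRange;
    private final RangeSketch subpartitionIndexRange;

    InputInfoSketch(
            int subtaskIndex, RangeSketch partitionIndexRange, RangeSketch subpartitionIndexRange) {
        this.subtaskIndex = subtaskIndex;
        this.partitionIndexRange = partitionIndexRange;
        this.subpartitionIndexRange = subpartitionIndexRange;
    }

    @Override
    public String toString() {
        // Label every field so a log line can be read without knowing the field order.
        return String.format(
                "subtask index: %d, consumed partition index range: %s, consumed subpartition index range: %s",
                subtaskIndex, partitionIndexRange, subpartitionIndexRange);
    }
}
```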



##########
flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java:
##########
@@ -167,7 +167,7 @@ static JobVertexInputInfo computeVertexInputInfoForPointwise(
      * @return the computed {@link JobVertexInputInfo}
      */
     @VisibleForTesting

Review Comment:
   The `@VisibleForTesting` is no longer needed.



##########
flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ParallelismAndInputInfos.java:
##########
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.executiongraph;
+
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+
+import java.util.Map;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/** The parallelism and {@link JobVertexInputInfo}s of a job vertex. */
+public class ParallelismAndInputInfos {
+
+    private final int parallelism;
+    private final Map<IntermediateDataSetID, JobVertexInputInfo> jobVertexInputInfos;
+
+    public ParallelismAndInputInfos(
+            int parallelism, Map<IntermediateDataSetID, JobVertexInputInfo> jobVertexInputInfos) {
+        checkArgument(parallelism > 0);
+        this.parallelism = parallelism;
+        this.jobVertexInputInfos = checkNotNull(jobVertexInputInfos);
+    }
+
+    public int getParallelism() {
+        return parallelism;
+    }
+
+    public Map<IntermediateDataSetID, JobVertexInputInfo> getJobVertexInputInfos() {
+        return jobVertexInputInfos;

Review Comment:
   Better to make it unmodifiable.
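A minimal sketch of one way to do this, using a hypothetical simplified class in which `String` keys and values stand in for `IntermediateDataSetID` and `JobVertexInputInfo`. It copies the map before wrapping it with `Collections.unmodifiableMap`.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

// Hypothetical, simplified ParallelismAndInputInfos: String keys/values stand in for
// IntermediateDataSetID and JobVertexInputInfo.
final class UnmodifiableInfosSketch {
    private final int parallelism;
    private final Map<String, String> jobVertexInputInfos;

    UnmodifiableInfosSketch(int parallelism, Map<String, String> jobVertexInputInfos) {
        if (parallelism <= 0) {
            throw new IllegalArgumentException("parallelism must be positive: " + parallelism);
        }
        this.parallelism = parallelism;
        // Copy first, then wrap: callers cannot mutate the map through the getter, and later
        // changes to the caller's original map are not visible here either.
        this.jobVertexInputInfos =
                Collections.unmodifiableMap(
                        new HashMap<>(Objects.requireNonNull(jobVertexInputInfos)));
    }

    int getParallelism() {
        return parallelism;
    }

    Map<String, String> getJobVertexInputInfos() {
        return jobVertexInputInfos;
    }
}
```

Wrapping without the copy would also satisfy the comment, but `Collections.unmodifiableMap` returns a view, so the map passed to the constructor could still change underneath.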



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BisectionSearchUtils.java:
##########
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptivebatch;
+
+import java.util.function.Function;
+
+/** Utility class for bisection search. */
+public class BisectionSearchUtils {

Review Comment:
   can be package private
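A minimal package-private sketch of what the helper could look like. The signatures are inferred from the calls later in this diff (`findMaxLegalValue(predicate, low, high)` and `findMinLegalValue(predicate, low, high)`), the class name `BisectionSearchSketch` is hypothetical, and the sketch assumes the predicate is monotonic over the searched range; the PR's actual implementation may differ.

```java
import java.util.function.Function;

/** Hypothetical package-private bisection helper (no public modifier on the class). */
final class BisectionSearchSketch {

    private BisectionSearchSketch() {}

    /**
     * Returns the largest value in [low, high] for which isLegal is true. Assumes low is legal
     * and that legality only goes from true to false as the value grows.
     */
    static long findMaxLegalValue(Function<Long, Boolean> isLegal, long low, long high) {
        if (!isLegal.apply(low)) {
            throw new IllegalArgumentException("low must be a legal value");
        }
        while (low < high) {
            // Round up so the loop still makes progress when high == low + 1.
            long mid = low + (high - low + 1) / 2;
            if (isLegal.apply(mid)) {
                low = mid;
            } else {
                high = mid - 1;
            }
        }
        return low;
    }

    /**
     * Returns the smallest value in [low, high] for which isLegal is true. Assumes high is legal
     * and that legality only goes from false to true as the value grows.
     */
    static long findMinLegalValue(Function<Long, Boolean> isLegal, long low, long high) {
        if (!isLegal.apply(high)) {
            throw new IllegalArgumentException("high must be a legal value");
        }
        while (low < high) {
            long mid = low + (high - low) / 2;
            if (isLegal.apply(mid)) {
                high = mid;
            } else {
                low = mid + 1;
            }
        }
        return high;
    }
}
```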



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java:
##########
@@ -77,46 +92,74 @@ private DefaultVertexParallelismDecider(
     }
 
     @Override
-    public int decideParallelismForVertex(List<BlockingResultInfo> consumedResults) {
+    public ParallelismAndInputInfos decideParallelismAndInputInfosForVertex(
+            List<BlockingResultInfo> inputs, int parallelism) {
+        checkArgument(parallelism == ExecutionConfig.PARALLELISM_DEFAULT || parallelism > 0);
 
-        if (consumedResults.isEmpty()) {
+        if (inputs.isEmpty()) {
             // source job vertex
-            return defaultSourceParallelism;
+            return new ParallelismAndInputInfos(defaultSourceParallelism, Collections.emptyMap());
+        } else if (parallelism == ExecutionConfig.PARALLELISM_DEFAULT
+                && areAllInputsAllToAll(inputs)
+                && !areAllInputsBroadcast(inputs)) {
+            // load balance for ALL_TO_ALL inputs
+            return loadBalanceForAllToAllInputs(inputs, parallelism);
         } else {
-            return calculateParallelism(consumedResults);
+            if (parallelism == ExecutionConfig.PARALLELISM_DEFAULT) {
+                parallelism = decideParallelism(inputs);
+            }
+            return new ParallelismAndInputInfos(parallelism, decideInputInfos(inputs, parallelism));
         }
     }
 
-    private int calculateParallelism(List<BlockingResultInfo> consumedResults) {
+    private static boolean areAllInputsAllToAll(List<BlockingResultInfo> inputs) {
+        return inputs.stream().noneMatch(BlockingResultInfo::isPointwise);
+    }
 
-        long broadcastBytes =
-                consumedResults.stream()
-                        .filter(BlockingResultInfo::isBroadcast)
-                        .mapToLong(BlockingResultInfo::getNumBytesProduced)
-                        .sum();
+    private static boolean areAllInputsBroadcast(List<BlockingResultInfo> inputs) {
+        return inputs.stream().allMatch(BlockingResultInfo::isBroadcast);
+    }
 
-        long nonBroadcastBytes =
-                consumedResults.stream()
-                        .filter(consumedResult -> !consumedResult.isBroadcast())
-                        .mapToLong(BlockingResultInfo::getNumBytesProduced)
-                        .sum();
+    static Map<IntermediateDataSetID, JobVertexInputInfo> decideInputInfos(

Review Comment:
   can be private



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java:
##########
@@ -77,46 +92,74 @@ private DefaultVertexParallelismDecider(
     }
 
     @Override
-    public int decideParallelismForVertex(List<BlockingResultInfo> consumedResults) {
+    public ParallelismAndInputInfos decideParallelismAndInputInfosForVertex(
+            List<BlockingResultInfo> inputs, int parallelism) {
+        checkArgument(parallelism == ExecutionConfig.PARALLELISM_DEFAULT || parallelism > 0);
 
-        if (consumedResults.isEmpty()) {
+        if (inputs.isEmpty()) {
             // source job vertex
-            return defaultSourceParallelism;
+            return new ParallelismAndInputInfos(defaultSourceParallelism, Collections.emptyMap());
+        } else if (parallelism == ExecutionConfig.PARALLELISM_DEFAULT
+                && areAllInputsAllToAll(inputs)
+                && !areAllInputsBroadcast(inputs)) {
+            // load balance for ALL_TO_ALL inputs
+            return loadBalanceForAllToAllInputs(inputs, parallelism);
         } else {

Review Comment:
   Can we have some more detailed comments on the differences between these options? I.e., in which case we take which action, and why?
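One possible shape for those comments, as a hypothetical sketch: the branch conditions are copied from `decideParallelismAndInputInfosForVertex` in this diff, while `DecisionBranchSketch`, `Strategy`, and `Input` (a stand-in for `BlockingResultInfo`) are illustrative names. The explanations in the comments are a reading of this diff, not the author's stated rationale.

```java
import java.util.List;

/** Hypothetical sketch of how the three branches could be documented. */
final class DecisionBranchSketch {

    /** Same value as ExecutionConfig.PARALLELISM_DEFAULT: the user did not set a parallelism. */
    static final int PARALLELISM_DEFAULT = -1;

    enum Strategy { SOURCE_DEFAULT, LOAD_BALANCE_ALL_TO_ALL, DECIDE_THEN_DELEGATE }

    /** Simplified stand-in for BlockingResultInfo. */
    static final class Input {
        final boolean pointwise;
        final boolean broadcast;

        Input(boolean pointwise, boolean broadcast) {
            this.pointwise = pointwise;
            this.broadcast = broadcast;
        }
    }

    static Strategy choose(List<Input> inputs, int parallelism) {
        if (inputs.isEmpty()) {
            // Case 1: no consumed results, i.e. a source vertex. There is no upstream data
            // size to base a decision on, so use the configured default source parallelism.
            return Strategy.SOURCE_DEFAULT;
        }
        boolean allAllToAll = inputs.stream().noneMatch(i -> i.pointwise);
        boolean allBroadcast = inputs.stream().allMatch(i -> i.broadcast);
        if (parallelism == PARALLELISM_DEFAULT && allAllToAll && !allBroadcast) {
            // Case 2: parallelism is not set, every input is ALL_TO_ALL, and at least one input
            // is not broadcast. Subpartition ranges can then be sized by their actual bytes,
            // which balances the load. POINTWISE inputs are presumably excluded because each
            // subtask is tied to specific upstream partitions; all-broadcast inputs are
            // excluded because there would be no non-broadcast bytes to balance.
            return Strategy.LOAD_BALANCE_ALL_TO_ALL;
        }
        // Case 3: some input is POINTWISE, all inputs are broadcast, or the parallelism is
        // already set. Decide the parallelism from the data volume if it is unset, then
        // delegate the input infos to VertexInputInfoComputationUtils.
        return Strategy.DECIDE_THEN_DELEGATE;
    }
}
```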



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java:
##########
@@ -150,6 +193,202 @@ private int calculateParallelism(List<BlockingResultInfo> consumedResults) {
         return parallelism;
     }
 
+    private ParallelismAndInputInfos loadBalanceForAllToAllInputs(
+            List<BlockingResultInfo> inputs, int parallelism) {
+        checkArgument(parallelism == ExecutionConfig.PARALLELISM_DEFAULT);
+        checkArgument(!inputs.isEmpty());
+        inputs.forEach(resultInfo -> checkState(!resultInfo.isPointwise()));
+
+        final List<BlockingResultInfo> nonBroadcastInputs = getNonBroadcastResultInfos(inputs);
+        long broadcastBytes = getReasonableBroadcastBytes(inputs);
+        long nonBroadcastBytes = getNonBroadcastBytes(inputs);
+
+        int subpartitionNum = checkAndGetSubpartitionNum(nonBroadcastInputs);
+
+        long nonBroadcastBytesPerTaskLimit = dataVolumePerTask - broadcastBytes;
+        long[] nonBroadcastBytesBySubpartition = new long[subpartitionNum];
+        Arrays.fill(nonBroadcastBytesBySubpartition, 0L);
+        for (BlockingResultInfo resultInfo : nonBroadcastInputs) {
+            List<Long> subpartitionBytes =
+                    ((AllToAllBlockingResultInfo) resultInfo).getAggregatedSubpartitionBytes();
+            for (int i = 0; i < subpartitionNum; ++i) {
+                nonBroadcastBytesBySubpartition[i] += subpartitionBytes.get(i);
+            }
+        }
+
+        // compute subpartition ranges
+        List<IndexRange> subpartitionRanges =
+                computeSubpartitionRanges(
+                        nonBroadcastBytesBySubpartition, nonBroadcastBytesPerTaskLimit);
+
+        if (subpartitionRanges.size() < minParallelism) {
+            long minSubpartitionBytes =
+                    Arrays.stream(nonBroadcastBytesBySubpartition).min().getAsLong();
+            // find a legal limit so that the computed parallelism >= minParallelism
+            long adjustLimit =
+                    BisectionSearchUtils.findMaxLegalValue(
+                            value ->
+                                    computeParallelism(nonBroadcastBytesBySubpartition, value)
+                                            >= minParallelism,
+                            minSubpartitionBytes,
+                            nonBroadcastBytesPerTaskLimit);
+
+            // the smaller the limit, the more even the distribution
+            final long expectedParallelism =
+                    computeParallelism(nonBroadcastBytesBySubpartition, adjustLimit);
+            adjustLimit =
+                    BisectionSearchUtils.findMinLegalValue(
+                            value ->
+                                    computeParallelism(nonBroadcastBytesBySubpartition, value)
+                                            <= expectedParallelism,
+                            minSubpartitionBytes,
+                            adjustLimit);
+
+            subpartitionRanges =
+                    computeSubpartitionRanges(nonBroadcastBytesBySubpartition, adjustLimit);
+        } else if (subpartitionRanges.size() > maxParallelism) {
+            // find a legal limit so that the computed parallelism <= maxParallelism
+            long adjustLimit =
+                    BisectionSearchUtils.findMinLegalValue(
+                            value ->
+                                    computeParallelism(nonBroadcastBytesBySubpartition, value)
+                                            <= maxParallelism,
+                            nonBroadcastBytesPerTaskLimit,
+                            nonBroadcastBytes);

Review Comment:
   It looks like the new limit could result in a parallelism that is smaller than the minParallelism?
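A small, self-contained check of this concern. It assumes the range-counting logic of `computeSubpartitionRanges(...)` shown in this diff and, purely for illustration, `minParallelism == maxParallelism == 3`. With four subpartitions of 5 bytes each, every limit from 5 to 9 yields 4 ranges, while a limit of 10 already yields 2, so the smallest limit satisfying `<= maxParallelism` skips 3 and lands below `minParallelism`. `ParallelismUndershootSketch` and `minLimitWithin` are hypothetical names; the linear scan returns the same value a bisection over a monotonic predicate would.

```java
/** Hypothetical reproduction of the greedy range counting used in this diff. */
final class ParallelismUndershootSketch {

    /** Counts ranges the same way computeSubpartitionRanges(...) builds them. */
    static int computeParallelism(long[] nums, long limit) {
        int closedRanges = 0;
        long tmpSum = 0;
        for (long num : nums) {
            if (tmpSum == 0 || tmpSum + num <= limit) {
                tmpSum += num;
            } else {
                // Close the current range and start a new one at this subpartition.
                closedRanges++;
                tmpSum = num;
            }
        }
        // The last open range is always emitted.
        return closedRanges + 1;
    }

    /** Smallest limit in [low, high] whose parallelism is <= maxParallelism. */
    static long minLimitWithin(long[] nums, int maxParallelism, long low, long high) {
        for (long limit = low; limit <= high; limit++) {
            if (computeParallelism(nums, limit) <= maxParallelism) {
                return limit;
            }
        }
        throw new IllegalStateException("no limit in range satisfies maxParallelism");
    }
}
```

With `nums = {5, 5, 5, 5}`, a search over `[5, 20]` (per-task limit to total bytes) returns 10, whose parallelism is 2, so the `checkState(isLegalParallelism(...))` that follows would fail in this scenario.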



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java:
##########
@@ -77,46 +92,74 @@ private DefaultVertexParallelismDecider(
     }
 
     @Override
-    public int decideParallelismForVertex(List<BlockingResultInfo> consumedResults) {
+    public ParallelismAndInputInfos decideParallelismAndInputInfosForVertex(
+            List<BlockingResultInfo> inputs, int parallelism) {
+        checkArgument(parallelism == ExecutionConfig.PARALLELISM_DEFAULT || parallelism > 0);
 
-        if (consumedResults.isEmpty()) {
+        if (inputs.isEmpty()) {
             // source job vertex
-            return defaultSourceParallelism;
+            return new ParallelismAndInputInfos(defaultSourceParallelism, Collections.emptyMap());
+        } else if (parallelism == ExecutionConfig.PARALLELISM_DEFAULT
+                && areAllInputsAllToAll(inputs)
+                && !areAllInputsBroadcast(inputs)) {
+            // load balance for ALL_TO_ALL inputs
+            return loadBalanceForAllToAllInputs(inputs, parallelism);
         } else {
-            return calculateParallelism(consumedResults);
+            if (parallelism == ExecutionConfig.PARALLELISM_DEFAULT) {
+                parallelism = decideParallelism(inputs);
+            }
+            return new ParallelismAndInputInfos(parallelism, decideInputInfos(inputs, parallelism));
         }
     }
 
-    private int calculateParallelism(List<BlockingResultInfo> consumedResults) {
+    private static boolean areAllInputsAllToAll(List<BlockingResultInfo> inputs) {
+        return inputs.stream().noneMatch(BlockingResultInfo::isPointwise);
+    }
 
-        long broadcastBytes =
-                consumedResults.stream()
-                        .filter(BlockingResultInfo::isBroadcast)
-                        .mapToLong(BlockingResultInfo::getNumBytesProduced)
-                        .sum();
+    private static boolean areAllInputsBroadcast(List<BlockingResultInfo> inputs) {
+        return inputs.stream().allMatch(BlockingResultInfo::isBroadcast);
+    }
 
-        long nonBroadcastBytes =
-                consumedResults.stream()
-                        .filter(consumedResult -> !consumedResult.isBroadcast())
-                        .mapToLong(BlockingResultInfo::getNumBytesProduced)
-                        .sum();
+    static Map<IntermediateDataSetID, JobVertexInputInfo> decideInputInfos(

Review Comment:
   Can we just use `VertexInputInfoComputationUtils.computeVertexInputInfos(...)`?



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java:
##########
@@ -55,7 +70,7 @@ public class DefaultVertexParallelismDecider implements VertexParallelismDecider
     private final long dataVolumePerTask;
     private final int defaultSourceParallelism;
 
-    private DefaultVertexParallelismDecider(
+    protected DefaultVertexParallelismDecider(

Review Comment:
   What's this change for?



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java:
##########
@@ -43,12 +43,21 @@ abstract class AbstractBlockingResultInfo implements BlockingResultInfo {
      */
     protected final Map<Integer, long[]> subpartitionBytesByPartitionIndex;
 
+    @VisibleForTesting

Review Comment:
   Is it possible to use `recordPartitionInfo(...)` in tests so that we do not need to introduce these testing methods?



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java:
##########
@@ -77,46 +92,74 @@ private DefaultVertexParallelismDecider(
     }
 
     @Override
-    public int decideParallelismForVertex(List<BlockingResultInfo> consumedResults) {
+    public ParallelismAndInputInfos decideParallelismAndInputInfosForVertex(
+            List<BlockingResultInfo> inputs, int parallelism) {
+        checkArgument(parallelism == ExecutionConfig.PARALLELISM_DEFAULT || parallelism > 0);
 
-        if (consumedResults.isEmpty()) {
+        if (inputs.isEmpty()) {
             // source job vertex
-            return defaultSourceParallelism;
+            return new ParallelismAndInputInfos(defaultSourceParallelism, Collections.emptyMap());
+        } else if (parallelism == ExecutionConfig.PARALLELISM_DEFAULT
+                && areAllInputsAllToAll(inputs)
+                && !areAllInputsBroadcast(inputs)) {
+            // load balance for ALL_TO_ALL inputs
+            return loadBalanceForAllToAllInputs(inputs, parallelism);
         } else {
-            return calculateParallelism(consumedResults);
+            if (parallelism == ExecutionConfig.PARALLELISM_DEFAULT) {
+                parallelism = decideParallelism(inputs);
+            }
+            return new ParallelismAndInputInfos(parallelism, decideInputInfos(inputs, parallelism));
         }
     }
 
-    private int calculateParallelism(List<BlockingResultInfo> consumedResults) {
+    private static boolean areAllInputsAllToAll(List<BlockingResultInfo> inputs) {
+        return inputs.stream().noneMatch(BlockingResultInfo::isPointwise);
+    }
 
-        long broadcastBytes =
-                consumedResults.stream()
-                        .filter(BlockingResultInfo::isBroadcast)
-                        .mapToLong(BlockingResultInfo::getNumBytesProduced)
-                        .sum();
+    private static boolean areAllInputsBroadcast(List<BlockingResultInfo> inputs) {
+        return inputs.stream().allMatch(BlockingResultInfo::isBroadcast);
+    }
 
-        long nonBroadcastBytes =
-                consumedResults.stream()
-                        .filter(consumedResult -> !consumedResult.isBroadcast())
-                        .mapToLong(BlockingResultInfo::getNumBytesProduced)
-                        .sum();
+    static Map<IntermediateDataSetID, JobVertexInputInfo> decideInputInfos(
+            List<BlockingResultInfo> inputs, int parallelism) {
+        checkArgument(!inputs.isEmpty());
 
-        long expectedMaxBroadcastBytes =
-                (long) Math.ceil((dataVolumePerTask * CAP_RATIO_OF_BROADCAST));
+        final Map<IntermediateDataSetID, JobVertexInputInfo> jobVertexInputInfos = new HashMap<>();
 
-        if (broadcastBytes > expectedMaxBroadcastBytes) {
-            LOG.info(
-                    "The size of broadcast data {} is larger than the expected maximum value {} ('{}' * {})."
-                            + " Use {} as the size of broadcast data to decide the parallelism.",
-                    new MemorySize(broadcastBytes),
-                    new MemorySize(expectedMaxBroadcastBytes),
-                    JobManagerOptions.ADAPTIVE_BATCH_SCHEDULER_AVG_DATA_VOLUME_PER_TASK.key(),
-                    CAP_RATIO_OF_BROADCAST,
-                    new MemorySize(expectedMaxBroadcastBytes));
-
-            broadcastBytes = expectedMaxBroadcastBytes;
+        for (BlockingResultInfo resultInfo : inputs) {
+            IntermediateDataSetID resultId = resultInfo.getResultId();
+            int sourceParallelism = resultInfo.getNumPartitions();
+            boolean isBroadcast = resultInfo.isBroadcast();
+            if (resultInfo.isPointwise()) {
+                jobVertexInputInfos.putIfAbsent(
+                        resultId,
+                        VertexInputInfoComputationUtils.computeVertexInputInfoForPointwise(
+                                sourceParallelism,
+                                parallelism,
+                                resultInfo::getNumSubpartitions,
+                                true));
+            } else {
+                jobVertexInputInfos.putIfAbsent(
+                        resultId,
+                        VertexInputInfoComputationUtils.computeVertexInputInfoForAllToAll(
+                                sourceParallelism,
+                                parallelism,
+                                resultInfo::getNumSubpartitions,
+                                true,
+                                isBroadcast));
+            }
         }
 
+        return jobVertexInputInfos;
+    }
+
+    int decideParallelism(List<BlockingResultInfo> consumedResults) {
+

Review Comment:
   useless empty line.



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/VertexParallelismDecider.java:
##########
@@ -18,19 +18,27 @@
 
 package org.apache.flink.runtime.scheduler.adaptivebatch;
 
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
+import org.apache.flink.runtime.executiongraph.ParallelismAndInputInfos;
+
 import java.util.List;
 
 /**
- * {@link VertexParallelismDecider} is responsible for determining the parallelism of a job vertex,
- * based on the size of the consumed blocking results.
+ * {@link VertexParallelismDecider} is responsible for deciding the parallelism and {@link
+ * JobVertexInputInfo}s of a job vertex, based on the information of the consumed blocking results.
  */
 public interface VertexParallelismDecider {
 
     /**
-     * Computing the parallelism.
+     * Decide the parallelism and {@link JobVertexInputInfo}s for this job vertex.
      *
-     * @param consumedResults The information of consumed blocking results.
-     * @return the parallelism of the job vertex.
+     * @param inputs The information of consumed blocking results

Review Comment:
   Not sure why we changed the name? The previous one looks better to me: `consumedResults` corresponds to results, while `inputs` corresponds to edges.



##########
flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ParallelismAndInputInfos.java:
##########
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.executiongraph;
+
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+
+import java.util.Map;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/** The parallelism and {@link JobVertexInputInfo}s of a job vertex. */
+public class ParallelismAndInputInfos {
+
+    private final int parallelism;
+    private final Map<IntermediateDataSetID, JobVertexInputInfo> jobVertexInputInfos;
+
+    public ParallelismAndInputInfos(
+            int parallelism, Map<IntermediateDataSetID, JobVertexInputInfo> jobVertexInputInfos) {
+        checkArgument(parallelism > 0);
+        this.parallelism = parallelism;
+        this.jobVertexInputInfos = checkNotNull(jobVertexInputInfos);

Review Comment:
   Maybe verify that the number of `ExecutionVertexInputInfo`s in each `JobVertexInputInfo` equals the parallelism?
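A minimal sketch of such a check, using a hypothetical simplified class: `Map<String, List<String>>` stands in for the real map, with each list playing the role of one `JobVertexInputInfo`'s `ExecutionVertexInputInfo`s.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

// Hypothetical, simplified ParallelismAndInputInfos with a per-result size check.
final class ValidatedInfosSketch {
    private final int parallelism;
    private final Map<String, List<String>> inputInfos;

    ValidatedInfosSketch(int parallelism, Map<String, List<String>> inputInfos) {
        if (parallelism <= 0) {
            throw new IllegalArgumentException("parallelism must be positive: " + parallelism);
        }
        Objects.requireNonNull(inputInfos);
        for (Map.Entry<String, List<String>> entry : inputInfos.entrySet()) {
            // Every consumed result must provide exactly one input info per subtask.
            if (entry.getValue().size() != parallelism) {
                throw new IllegalArgumentException(
                        String.format(
                                "Result %s has %d execution vertex input infos, but parallelism is %d.",
                                entry.getKey(), entry.getValue().size(), parallelism));
            }
        }
        this.parallelism = parallelism;
        this.inputInfos = Collections.unmodifiableMap(new HashMap<>(inputInfos));
    }

    int getParallelism() {
        return parallelism;
    }

    Map<String, List<String>> getInputInfos() {
        return inputInfos;
    }
}
```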



##########
flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java:
##########
@@ -102,7 +102,7 @@ public static Map<IntermediateDataSetID, JobVertexInputInfo> computeVertexInputI
      * @return the computed {@link JobVertexInputInfo}
      */
     @VisibleForTesting
-    static JobVertexInputInfo computeVertexInputInfoForPointwise(

Review Comment:
   `@VisibleForTesting` is no longer needed because it's now required to be public for production code.



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java:
##########
@@ -19,15 +19,30 @@
 package org.apache.flink.runtime.scheduler.adaptivebatch;
 
 import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.JobManagerOptions;
 import org.apache.flink.configuration.MemorySize;
+import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
+import org.apache.flink.runtime.executiongraph.IndexRange;
+import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
+import org.apache.flink.runtime.executiongraph.ParallelismAndInputInfos;
+import org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.util.MathUtils;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;

Review Comment:
   Could you improve the documentation of this class to describe what it does in the latest version?



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java:
##########
@@ -150,6 +193,202 @@ private int calculateParallelism(List<BlockingResultInfo> consumedResults) {
         return parallelism;
     }
 
+    private ParallelismAndInputInfos loadBalanceForAllToAllInputs(
+            List<BlockingResultInfo> inputs, int parallelism) {
+        checkArgument(parallelism == ExecutionConfig.PARALLELISM_DEFAULT);
+        checkArgument(!inputs.isEmpty());
+        inputs.forEach(resultInfo -> checkState(!resultInfo.isPointwise()));
+
+        final List<BlockingResultInfo> nonBroadcastInputs = getNonBroadcastResultInfos(inputs);
+        long broadcastBytes = getReasonableBroadcastBytes(inputs);
+        long nonBroadcastBytes = getNonBroadcastBytes(inputs);
+
+        int subpartitionNum = checkAndGetSubpartitionNum(nonBroadcastInputs);
+
+        long nonBroadcastBytesPerTaskLimit = dataVolumePerTask - broadcastBytes;
+        long[] nonBroadcastBytesBySubpartition = new long[subpartitionNum];
+        Arrays.fill(nonBroadcastBytesBySubpartition, 0L);
+        for (BlockingResultInfo resultInfo : nonBroadcastInputs) {
+            List<Long> subpartitionBytes =
+                    ((AllToAllBlockingResultInfo) resultInfo).getAggregatedSubpartitionBytes();
+            for (int i = 0; i < subpartitionNum; ++i) {
+                nonBroadcastBytesBySubpartition[i] += subpartitionBytes.get(i);
+            }
+        }
+
+        // compute subpartition ranges
+        List<IndexRange> subpartitionRanges =
+                computeSubpartitionRanges(
+                        nonBroadcastBytesBySubpartition, nonBroadcastBytesPerTaskLimit);
+
+        if (subpartitionRanges.size() < minParallelism) {
+            long minSubpartitionBytes =
+                    Arrays.stream(nonBroadcastBytesBySubpartition).min().getAsLong();
+            // find a legal limit so that the computed parallelism >= minParallelism
+            long adjustLimit =
+                    BisectionSearchUtils.findMaxLegalValue(
+                            value ->
+                                    computeParallelism(nonBroadcastBytesBySubpartition, value)
+                                            >= minParallelism,
+                            minSubpartitionBytes,
+                            nonBroadcastBytesPerTaskLimit);
+
+            // the smaller the limit, the more even the distribution
+            final long expectedParallelism =
+                    computeParallelism(nonBroadcastBytesBySubpartition, adjustLimit);
+            adjustLimit =
+                    BisectionSearchUtils.findMinLegalValue(
+                            value ->
+                                    computeParallelism(nonBroadcastBytesBySubpartition, value)
+                                            <= expectedParallelism,
+                            minSubpartitionBytes,
+                            adjustLimit);
+
+            subpartitionRanges =
+                    computeSubpartitionRanges(nonBroadcastBytesBySubpartition, adjustLimit);
+        } else if (subpartitionRanges.size() > maxParallelism) {
+            // find a legal limit so that the computed parallelism <= maxParallelism
+            long adjustLimit =
+                    BisectionSearchUtils.findMinLegalValue(
+                            value ->
+                                    computeParallelism(nonBroadcastBytesBySubpartition, value)
+                                            <= maxParallelism,
+                            nonBroadcastBytesPerTaskLimit,
+                            nonBroadcastBytes);
+
+            subpartitionRanges =
+                    computeSubpartitionRanges(nonBroadcastBytesBySubpartition, adjustLimit);
+        }
+
+        checkState(isLegalParallelism(subpartitionRanges.size()));
+        return createParallelismAndInputInfos(inputs, subpartitionRanges);
+    }
+
+    private boolean isLegalParallelism(int parallelism) {
+        return parallelism >= minParallelism && parallelism <= maxParallelism;
+    }
+
+    private static int checkAndGetSubpartitionNum(List<BlockingResultInfo> inputs) {
+        final Set<Integer> subpartitionNumSet =
+                inputs.stream()
+                        .flatMap(
+                                resultInfo ->
+                                        IntStream.range(0, resultInfo.getNumPartitions())
+                                                .boxed()
+                                                .map(resultInfo::getNumSubpartitions))
+                        .collect(Collectors.toSet());
+        // all partitions have the same subpartition num
+        checkState(subpartitionNumSet.size() == 1);
+        return subpartitionNumSet.iterator().next();
+    }
+
+    private static ParallelismAndInputInfos createParallelismAndInputInfos(
+            List<BlockingResultInfo> inputs, List<IndexRange> subpartitionRanges) {
+
+        final Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputInfos = new HashMap<>();
+        inputs.forEach(
+                resultInfo -> {
+                    int sourceParallelism = resultInfo.getNumPartitions();
+                    IndexRange partitionRange = new IndexRange(0, sourceParallelism - 1);
+
+                    List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>();
+                    for (int i = 0; i < subpartitionRanges.size(); ++i) {
+                        IndexRange subpartitionRange;
+                        if (resultInfo.isBroadcast()) {
+                            subpartitionRange = new IndexRange(0, 0);
+                        } else {
+                            subpartitionRange = subpartitionRanges.get(i);
+                        }
+                        ExecutionVertexInputInfo executionVertexInputInfo =
+                                new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange);
+                        executionVertexInputInfos.add(executionVertexInputInfo);
+                    }
+
+                    vertexInputInfos.put(
+                            resultInfo.getResultId(),
+                            new JobVertexInputInfo(executionVertexInputInfos));
+                });
+        return new ParallelismAndInputInfos(subpartitionRanges.size(), vertexInputInfos);
+    }
+
+    private static List<IndexRange> computeSubpartitionRanges(long[] nums, long limit) {
+        List<IndexRange> subpartitionRanges = new ArrayList<>();
+        long tmpSum = 0;
+        int startIndex = 0;
+        for (int i = 0; i < nums.length; ++i) {
+            long num = nums[i];
+            if (tmpSum == 0 || tmpSum + num <= limit) {
+                tmpSum += num;
+            } else {
+                subpartitionRanges.add(new IndexRange(startIndex, i - 1));
+                startIndex = i;
+                tmpSum = num;
+            }
+        }
+        subpartitionRanges.add(new IndexRange(startIndex, nums.length - 1));
+        return subpartitionRanges;
+    }
+
+    private static int computeParallelism(long[] nums, long limit) {
+

Review Comment:
   computeParallelism
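For reference while reviewing `computeParallelism`: a minimal, self-contained sketch of the greedy grouping in `computeSubpartitionRanges` and of how `createParallelismAndInputInfos` turns the ranges into per-subtask inputs. `IndexRange` and `SubtaskInput` below are simplified stand-ins (hypothetical, not the Flink classes), and `computeParallelism` is assumed to equal the number of ranges the greedy pass produces, since its body is not shown in this excerpt.

```java
import java.util.ArrayList;
import java.util.List;

/** Sketch of the greedy subpartition grouping; IndexRange and SubtaskInput are stand-ins. */
final class SubpartitionRangeSketch {

    /** Inclusive index range [startIndex, endIndex], standing in for Flink's IndexRange. */
    static final class IndexRange {
        final int startIndex;
        final int endIndex;

        IndexRange(int startIndex, int endIndex) {
            this.startIndex = startIndex;
            this.endIndex = endIndex;
        }

        @Override
        public String toString() {
            return "[" + startIndex + ", " + endIndex + "]";
        }
    }

    /** Hypothetical stand-in for the three values passed to ExecutionVertexInputInfo. */
    static final class SubtaskInput {
        final int subtaskIndex;
        final IndexRange partitionRange;
        final IndexRange subpartitionRange;

        SubtaskInput(int subtaskIndex, IndexRange partitionRange, IndexRange subpartitionRange) {
            this.subtaskIndex = subtaskIndex;
            this.partitionRange = partitionRange;
            this.subpartitionRange = subpartitionRange;
        }
    }

    /** Packs consecutive subpartitions into ranges whose byte sum stays within limit. */
    static List<IndexRange> computeSubpartitionRanges(long[] nums, long limit) {
        List<IndexRange> ranges = new ArrayList<>();
        long tmpSum = 0;
        int startIndex = 0;
        for (int i = 0; i < nums.length; ++i) {
            // An empty range always takes the next subpartition, even one larger than limit.
            if (tmpSum == 0 || tmpSum + nums[i] <= limit) {
                tmpSum += nums[i];
            } else {
                ranges.add(new IndexRange(startIndex, i - 1));
                startIndex = i;
                tmpSum = nums[i];
            }
        }
        ranges.add(new IndexRange(startIndex, nums.length - 1));
        return ranges;
    }

    /** Assumed definition: the parallelism implied by a limit is the number of greedy ranges. */
    static int computeParallelism(long[] nums, long limit) {
        return computeSubpartitionRanges(nums, limit).size();
    }

    /** One input per subtask; every subtask reads all partitions of the result. */
    static List<SubtaskInput> subtaskInputs(
            int numPartitions, boolean isBroadcast, List<IndexRange> subpartitionRanges) {
        IndexRange partitionRange = new IndexRange(0, numPartitions - 1);
        List<SubtaskInput> inputs = new ArrayList<>();
        for (int i = 0; i < subpartitionRanges.size(); ++i) {
            // As in the diff, a broadcast input is read from subpartition 0 only.
            IndexRange subpartitionRange =
                    isBroadcast ? new IndexRange(0, 0) : subpartitionRanges.get(i);
            inputs.add(new SubtaskInput(i, partitionRange, subpartitionRange));
        }
        return inputs;
    }
}
```

With per-subpartition bytes `{10, 20, 30, 40}` and a limit of 50, the greedy pass yields `[0, 1]`, `[2, 2]`, `[3, 3]`, i.e. a parallelism of 3.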



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java:
##########
@@ -258,46 +267,44 @@ void initializeVerticesIfPossible() {
         }
     }
 
-    private void maybeSetParallelism(final ExecutionJobVertex jobVertex) {
-        if (jobVertex.isParallelismDecided()) {
-            return;
-        }
-
-        Optional<List<BlockingResultInfo>> consumedResultsInfo =
-                tryGetConsumedResultsInfo(jobVertex);
-        if (!consumedResultsInfo.isPresent()) {
-            return;
-        }
-
+    private ParallelismAndInputInfos tryDecideParallelismAndInputInfos(
+            final ExecutionJobVertex jobVertex, List<BlockingResultInfo> inputs) {
+        int parallelism = jobVertex.getParallelism();
         ForwardGroup forwardGroup = forwardGroupsByJobVertexId.get(jobVertex.getJobVertexId());
-        int parallelism;
-
         if (forwardGroup != null && forwardGroup.isParallelismDecided()) {

Review Comment:
   Better to do this only if the parallelism of the vertex is not set.
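One way to read this suggestion, as a minimal sketch: inherit the forward group's decided parallelism only while the vertex's own parallelism is still `ExecutionConfig.PARALLELISM_DEFAULT` (-1 in Flink). `ForwardGroup` below is a hypothetical stand-in, and `resolveInitialParallelism` is an illustrative name, not a method in the PR.

```java
/** Sketch: a parallelism set on the vertex wins over the forward group's decided one. */
final class ForwardGroupGuardSketch {

    /** Mirrors Flink's ExecutionConfig.PARALLELISM_DEFAULT, which marks "not set". */
    static final int PARALLELISM_DEFAULT = -1;

    /** Hypothetical stand-in for a forward group: only its (possibly undecided) parallelism. */
    static final class ForwardGroup {
        private final int parallelism; // PARALLELISM_DEFAULT while undecided

        ForwardGroup(int parallelism) {
            this.parallelism = parallelism;
        }

        boolean isParallelismDecided() {
            return parallelism != PARALLELISM_DEFAULT;
        }

        int getParallelism() {
            return parallelism;
        }
    }

    /** Only an unset vertex inherits a decided forward-group parallelism. */
    static int resolveInitialParallelism(int vertexParallelism, ForwardGroup forwardGroup) {
        if (vertexParallelism == PARALLELISM_DEFAULT
                && forwardGroup != null
                && forwardGroup.isParallelismDecided()) {
            return forwardGroup.getParallelism();
        }
        return vertexParallelism;
    }
}
```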



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java:
##########
@@ -150,6 +193,202 @@ private int calculateParallelism(List<BlockingResultInfo> consumedResults) {
         return parallelism;
     }
 
+    private ParallelismAndInputInfos loadBalanceForAllToAllInputs(
+            List<BlockingResultInfo> inputs, int parallelism) {
+        checkArgument(parallelism == ExecutionConfig.PARALLELISM_DEFAULT);
+        checkArgument(!inputs.isEmpty());
+        inputs.forEach(resultInfo -> checkState(!resultInfo.isPointwise()));
+
+        final List<BlockingResultInfo> nonBroadcastInputs = getNonBroadcastResultInfos(inputs);
+        long broadcastBytes = getReasonableBroadcastBytes(inputs);
+        long nonBroadcastBytes = getNonBroadcastBytes(inputs);
+
+        int subpartitionNum = checkAndGetSubpartitionNum(nonBroadcastInputs);
+
+        long nonBroadcastBytesPerTaskLimit = dataVolumePerTask - broadcastBytes;
+        long[] nonBroadcastBytesBySubpartition = new long[subpartitionNum];
+        Arrays.fill(nonBroadcastBytesBySubpartition, 0L);
+        for (BlockingResultInfo resultInfo : nonBroadcastInputs) {
+            List<Long> subpartitionBytes =
+                    ((AllToAllBlockingResultInfo) resultInfo).getAggregatedSubpartitionBytes();
+            for (int i = 0; i < subpartitionNum; ++i) {
+                nonBroadcastBytesBySubpartition[i] += subpartitionBytes.get(i);
+            }
+        }
+
+        // compute subpartition ranges
+        List<IndexRange> subpartitionRanges =
+                computeSubpartitionRanges(
+                        nonBroadcastBytesBySubpartition, nonBroadcastBytesPerTaskLimit);
+
+        if (subpartitionRanges.size() < minParallelism) {
+            long minSubpartitionBytes =
+                    Arrays.stream(nonBroadcastBytesBySubpartition).min().getAsLong();
+            // find a legal limit so that the computed parallelism >= minParallelism
+            long adjustLimit =
+                    BisectionSearchUtils.findMaxLegalValue(
+                            value ->
+                                    computeParallelism(nonBroadcastBytesBySubpartition, value)
+                                            >= minParallelism,
+                            minSubpartitionBytes,
+                            nonBroadcastBytesPerTaskLimit);
+
+            // the smaller the limit, the more even the distribution
+            final long expectedParallelism =
+                    computeParallelism(nonBroadcastBytesBySubpartition, adjustLimit);
+            adjustLimit =
+                    BisectionSearchUtils.findMinLegalValue(
+                            value ->
+                                    computeParallelism(nonBroadcastBytesBySubpartition, value)
+                                            <= expectedParallelism,
+                            minSubpartitionBytes,
+                            adjustLimit);
+
+            subpartitionRanges =
+                    computeSubpartitionRanges(nonBroadcastBytesBySubpartition, adjustLimit);
+        } else if (subpartitionRanges.size() > maxParallelism) {
+            // find a legal limit so that the computed parallelism <= minParallelism

Review Comment:
   minParallelism -> maxParallelism



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/VertexParallelismDecider.java:
##########
@@ -18,19 +18,27 @@
 
 package org.apache.flink.runtime.scheduler.adaptivebatch;
 
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
+import org.apache.flink.runtime.executiongraph.ParallelismAndInputInfos;
+
 import java.util.List;
 
 /**
- * {@link VertexParallelismDecider} is responsible for determining the parallelism of a job vertex,
- * based on the size of the consumed blocking results.
+ * {@link VertexParallelismDecider} is responsible for deciding the parallelism and {@link
+ * JobVertexInputInfo}s of a job vertex, based on the information of the consumed blocking results.
  */
 public interface VertexParallelismDecider {
 
     /**
-     * Computing the parallelism.
+     * Decide the parallelism and {@link JobVertexInputInfo}s for this job vertex.
      *
-     * @param consumedResults The information of consumed blocking results.
-     * @return the parallelism of the job vertex.
+     * @param inputs The information of consumed blocking results
+     * @param parallelism The original parallelism of the job vertex, used to determine whether the

Review Comment:
   Maybe `initialParallelism` to make it easier to understand?



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BisectionSearchUtils.java:
##########
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptivebatch;
+
+import java.util.function.Function;
+
+/** Utility class for bisection search. */
+public class BisectionSearchUtils {
+
+    public static long findMinLegalValue(
+            Function<Long, Boolean> legalChecker, long low, long high) {
+        if (!legalChecker.apply(high)) {
+            return -1;
+        }
+        while (low <= high) {
+            long mid = low + (high - low) / 2;

Review Comment:
   ```suggestion
               long mid = (high + low) / 2;
   ```
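For context on the two helpers, a minimal sketch of `findMinLegalValue`/`findMaxLegalValue`, assuming the checker is monotone over `[low, high]` and that -1 signals "no legal value" as in the diff. The loop body past `mid` is not shown in this excerpt, so this is one possible formulation, and it uses `java.util.function.LongPredicate` instead of `Function<Long, Boolean>` to avoid boxing. On the suggestion itself: for non-negative bounds, `(high + low) / 2` and `low + (high - low) / 2` agree unless `high + low` exceeds `Long.MAX_VALUE`, which byte counts are unlikely to reach.

```java
import java.util.function.LongPredicate;

/** Sketch of bisection over a monotone legality predicate on long values. */
final class BisectionSearchSketch {

    /** Smallest legal value in [low, high], or -1 if high is illegal (legal values form a suffix). */
    static long findMinLegalValue(LongPredicate legalChecker, long low, long high) {
        if (!legalChecker.test(high)) {
            return -1;
        }
        while (low < high) {
            long mid = low + (high - low) / 2;
            if (legalChecker.test(mid)) {
                high = mid; // mid is legal: the answer is mid or smaller
            } else {
                low = mid + 1; // mid is illegal: the answer is larger
            }
        }
        return high;
    }

    /** Largest legal value in [low, high], or -1 if low is illegal (legal values form a prefix). */
    static long findMaxLegalValue(LongPredicate legalChecker, long low, long high) {
        if (!legalChecker.test(low)) {
            return -1;
        }
        while (low < high) {
            long mid = low + (high - low + 1) / 2; // round up so the range always shrinks
            if (legalChecker.test(mid)) {
                low = mid; // mid is legal: the answer is mid or larger
            } else {
                high = mid - 1; // mid is illegal: the answer is smaller
            }
        }
        return low;
    }
}
```

For example, `findMinLegalValue(v -> v * v >= 50, 0, 100)` returns 8 and `findMaxLegalValue(v -> v * v <= 50, 0, 100)` returns 7.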



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java:
##########
@@ -150,6 +193,202 @@ private int calculateParallelism(List<BlockingResultInfo> consumedResults) {
         return parallelism;
     }
 
+    private ParallelismAndInputInfos loadBalanceForAllToAllInputs(
+            List<BlockingResultInfo> inputs, int parallelism) {
+        checkArgument(parallelism == ExecutionConfig.PARALLELISM_DEFAULT);
+        checkArgument(!inputs.isEmpty());
+        inputs.forEach(resultInfo -> checkState(!resultInfo.isPointwise()));
+
+        final List<BlockingResultInfo> nonBroadcastInputs = getNonBroadcastResultInfos(inputs);
+        long broadcastBytes = getReasonableBroadcastBytes(inputs);
+        long nonBroadcastBytes = getNonBroadcastBytes(inputs);
+
+        int subpartitionNum = checkAndGetSubpartitionNum(nonBroadcastInputs);
+
+        long nonBroadcastBytesPerTaskLimit = dataVolumePerTask - broadcastBytes;
+        long[] nonBroadcastBytesBySubpartition = new long[subpartitionNum];
+        Arrays.fill(nonBroadcastBytesBySubpartition, 0L);
+        for (BlockingResultInfo resultInfo : nonBroadcastInputs) {
+            List<Long> subpartitionBytes =
+                    ((AllToAllBlockingResultInfo) resultInfo).getAggregatedSubpartitionBytes();
+            for (int i = 0; i < subpartitionNum; ++i) {
+                nonBroadcastBytesBySubpartition[i] += subpartitionBytes.get(i);
+            }
+        }
+
+        // compute subpartition ranges
+        List<IndexRange> subpartitionRanges =
+                computeSubpartitionRanges(
+                        nonBroadcastBytesBySubpartition, nonBroadcastBytesPerTaskLimit);
+
+        if (subpartitionRanges.size() < minParallelism) {

Review Comment:
   Maybe extract the adjust logic into a standalone method and add some comments to explain it?
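A minimal sketch of what the extracted method could look like, with the reasoning in comments. `adjustLimit` is a hypothetical name; the greedy count and the two bisection helpers are compact re-implementations so the block compiles on its own. It assumes non-negative per-subpartition byte counts and a non-empty array, and it assumes (the excerpt cuts off before this) that the `maxParallelism` branch searches up to the total byte count, where everything fits into one range.

```java
import java.util.Arrays;
import java.util.function.LongPredicate;

/** Hypothetical extraction of the limit-adjustment step from loadBalanceForAllToAllInputs. */
final class AdjustLimitSketch {

    /** Ranges produced by the greedy grouping for a limit; never grows as the limit grows. */
    static int computeParallelism(long[] nums, long limit) {
        int count = 1;
        long tmpSum = 0;
        for (long num : nums) {
            if (tmpSum == 0 || tmpSum + num <= limit) {
                tmpSum += num;
            } else {
                count++;
                tmpSum = num;
            }
        }
        return count;
    }

    /** Smallest legal value in [low, high], or -1 if high is illegal. */
    static long findMinLegalValue(LongPredicate legal, long low, long high) {
        if (!legal.test(high)) {
            return -1;
        }
        while (low < high) {
            long mid = low + (high - low) / 2;
            if (legal.test(mid)) {
                high = mid;
            } else {
                low = mid + 1;
            }
        }
        return high;
    }

    /** Largest legal value in [low, high], or -1 if low is illegal. */
    static long findMaxLegalValue(LongPredicate legal, long low, long high) {
        if (!legal.test(low)) {
            return -1;
        }
        while (low < high) {
            long mid = low + (high - low + 1) / 2;
            if (legal.test(mid)) {
                low = mid;
            } else {
                high = mid - 1;
            }
        }
        return low;
    }

    /** Adjusts the per-task byte limit so the parallelism lands in [min, max] where possible. */
    static long adjustLimit(long[] nums, long limit, int minParallelism, int maxParallelism) {
        int parallelism = computeParallelism(nums, limit);
        long minSubpartitionBytes = Arrays.stream(nums).min().getAsLong();
        if (parallelism < minParallelism) {
            // Too few tasks: the largest limit that still yields >= minParallelism ranges.
            // If none exists, fall back to the lower bound, which yields the most ranges.
            long maxLimit = findMaxLegalValue(
                    v -> computeParallelism(nums, v) >= minParallelism,
                    minSubpartitionBytes, limit);
            if (maxLimit < 0) {
                maxLimit = minSubpartitionBytes;
            }
            // Keep that task count but take the smallest limit achieving it: as the diff's
            // comment notes, a smaller limit spreads the bytes more evenly.
            int expectedParallelism = computeParallelism(nums, maxLimit);
            return findMinLegalValue(
                    v -> computeParallelism(nums, v) <= expectedParallelism,
                    minSubpartitionBytes, maxLimit);
        } else if (parallelism > maxParallelism) {
            // Too many tasks: the smallest limit that yields <= maxParallelism ranges.
            long totalBytes = Arrays.stream(nums).sum();
            return findMinLegalValue(
                    v -> computeParallelism(nums, v) <= maxParallelism, limit, totalBytes);
        }
        return limit;
    }
}
```

With `{10, 20, 30, 40}`, a starting limit of 100 and `minParallelism = 3`, the first search stops at 59 and the second tightens it to 30, giving ranges of 30, 30 and 40 bytes.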



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/VertexParallelismDecider.java:
##########
@@ -18,19 +18,27 @@
 
 package org.apache.flink.runtime.scheduler.adaptivebatch;
 
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
+import org.apache.flink.runtime.executiongraph.ParallelismAndInputInfos;
+
 import java.util.List;
 
 /**
- * {@link VertexParallelismDecider} is responsible for determining the parallelism of a job vertex,
- * based on the size of the consumed blocking results.
+ * {@link VertexParallelismDecider} is responsible for deciding the parallelism and {@link
+ * JobVertexInputInfo}s of a job vertex, based on the information of the consumed blocking results.
  */
 public interface VertexParallelismDecider {
 
     /**
-     * Computing the parallelism.
+     * Decide the parallelism and {@link JobVertexInputInfo}s for this job vertex.
      *
-     * @param consumedResults The information of consumed blocking results.
-     * @return the parallelism of the job vertex.
+     * @param inputs The information of consumed blocking results
+     * @param parallelism The original parallelism of the job vertex, used to determine whether the

Review Comment:
   And it's better to rephrase it, e.g. "The initial parallelism of the job vertex. If it's a positive number, it will be respected. If it's not set (equal to ExecutionConfig#PARALLELISM_DEFAULT), a parallelism will be automatically decided for the vertex."
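The proposed wording maps onto a simple contract. A minimal sketch, where `decideParallelism` and the `IntSupplier` standing in for the automatic decision are hypothetical placeholders; -1 mirrors `ExecutionConfig.PARALLELISM_DEFAULT`, and rejecting other non-positive values is an assumption, not something the comment states.

```java
import java.util.function.IntSupplier;

/** Sketch of the initialParallelism contract described in the suggested Javadoc. */
final class InitialParallelismSketch {

    /** Mirrors Flink's ExecutionConfig.PARALLELISM_DEFAULT ("not set"). */
    static final int PARALLELISM_DEFAULT = -1;

    /** A positive value is respected; PARALLELISM_DEFAULT triggers the automatic decision. */
    static int decideParallelism(int initialParallelism, IntSupplier autoDecide) {
        if (initialParallelism > 0) {
            return initialParallelism;
        }
        // Assumption: any other non-positive value is a caller error.
        if (initialParallelism != PARALLELISM_DEFAULT) {
            throw new IllegalArgumentException(
                    "Unexpected initial parallelism: " + initialParallelism);
        }
        return autoDecide.getAsInt();
    }
}
```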



-- 
This is an automated message from the Apache Git Service.