zhuzhurk commented on code in PR #21570:
URL: https://github.com/apache/flink/pull/21570#discussion_r1062334935


##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java:
##########
@@ -219,58 +221,34 @@ private ParallelismAndInputInfos loadBalanceForAllToAllInputs(
         // compute subpartition ranges
         List<IndexRange> subpartitionRanges =
                 computeSubpartitionRanges(
-                        nonBroadcastBytesBySubpartition, nonBroadcastBytesPerTaskLimit);
-
-        if (subpartitionRanges.size() < minParallelism) {
-            long minSubpartitionBytes =
-                    Arrays.stream(nonBroadcastBytesBySubpartition).min().getAsLong();
-            // find a legal limit so that the computed parallelism >= minParallelism
-            long adjustLimit =
-                    BisectionSearchUtils.findMaxLegalValue(
-                            value ->
-                                    computeParallelism(nonBroadcastBytesBySubpartition, value)
-                                            >= minParallelism,
-                            minSubpartitionBytes,
-                            nonBroadcastBytesPerTaskLimit);
-
-            // the smaller the limit, the more even the distribution
-            final long expectedParallelism =
-                    computeParallelism(nonBroadcastBytesBySubpartition, adjustLimit);
-            adjustLimit =
-                    BisectionSearchUtils.findMinLegalValue(
-                            value ->
-                                    computeParallelism(nonBroadcastBytesBySubpartition, value)
-                                            <= expectedParallelism,
-                            minSubpartitionBytes,
-                            adjustLimit);
-
-            subpartitionRanges =
-                    computeSubpartitionRanges(nonBroadcastBytesBySubpartition, adjustLimit);
-        } else if (subpartitionRanges.size() > maxParallelism) {
-            // find a legal limit so that the computed parallelism <= minParallelism
-            long adjustLimit =
-                    BisectionSearchUtils.findMinLegalValue(
-                            value ->
-                                    computeParallelism(nonBroadcastBytesBySubpartition, value)
-                                            <= maxParallelism,
-                            nonBroadcastBytesPerTaskLimit,
-                            nonBroadcastBytes);
-
-            subpartitionRanges =
-                    computeSubpartitionRanges(nonBroadcastBytesBySubpartition, adjustLimit);
+                        nonBroadcastBytesBySubpartition, nonBroadcastDataVolumeLimit);
+
+        // if the parallelism is not legal, adjust ot a legal parallelism

Review Comment:
   ot -> to



##########
flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java:
##########
@@ -0,0 +1,523 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptivebatch;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.JobManagerOptions;
+import org.apache.flink.configuration.MemorySize;
+import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
+import org.apache.flink.runtime.executiongraph.IndexRange;
+import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
+import org.apache.flink.runtime.executiongraph.ParallelismAndInputInfos;
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+
+import org.apache.flink.shaded.guava30.com.google.common.collect.Iterables;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static org.apache.flink.util.Preconditions.checkState;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Test for {@link DefaultVertexParallelismAndInputInfosDecider}. */
+class DefaultVertexParallelismAndInputInfosDeciderTest {
+
+    private static final long BYTE_256_MB = 256 * 1024 * 1024L;
+    private static final long BYTE_512_MB = 512 * 1024 * 1024L;
+    private static final long BYTE_1_GB = 1024 * 1024 * 1024L;
+    private static final long BYTE_8_GB = 8 * 1024 * 1024 * 1024L;
+    private static final long BYTE_1_TB = 1024 * 1024 * 1024 * 1024L;
+
+    private static final int MAX_PARALLELISM = 100;
+    private static final int MIN_PARALLELISM = 3;
+    private static final int DEFAULT_SOURCE_PARALLELISM = 10;
+    private static final long DATA_VOLUME_PER_TASK = 1024 * 1024 * 1024L;
+
+    @Test
+    void testNormalizedMaxAndMinParallelism() {
+        final DefaultVertexParallelismAndInputInfosDecider decider =
+                createVertexParallelismAndInputInfosDecider();
+        assertThat(decider.getMaxParallelism()).isEqualTo(64);
+        assertThat(decider.getMinParallelism()).isEqualTo(4);
+    }
+
+    @Test
+    void testNormalizeParallelismDownToPowerOf2() {
+        final DefaultVertexParallelismAndInputInfosDecider decider =
+                createVertexParallelismAndInputInfosDecider();
+
+        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_256_MB);
+        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_256_MB + BYTE_8_GB);
+
+        int parallelism = decider.decideParallelism(Arrays.asList(resultInfo1, resultInfo2), -1);
+
+        assertThat(parallelism).isEqualTo(8);
+    }
+
+    @Test
+    void testNormalizeParallelismUpToPowerOf2() {
+        final DefaultVertexParallelismAndInputInfosDecider decider =
+                createVertexParallelismAndInputInfosDecider();
+
+        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_256_MB);
+        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_1_GB + BYTE_8_GB);
+
+        int parallelism = decider.decideParallelism(Arrays.asList(resultInfo1, resultInfo2), -1);
+
+        assertThat(parallelism).isEqualTo(16);
+    }
+
+    @Test
+    void testInitiallyNormalizedParallelismIsLargerThanMaxParallelism() {
+        final DefaultVertexParallelismAndInputInfosDecider decider =
+                createVertexParallelismAndInputInfosDecider();
+
+        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_256_MB);
+        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_8_GB + BYTE_1_TB);
+
+        int parallelism = decider.decideParallelism(Arrays.asList(resultInfo1, resultInfo2), -1);
+
+        assertThat(parallelism).isEqualTo(64);
+    }
+
+    @Test
+    void testInitiallyNormalizedParallelismIsSmallerThanMinParallelism() {
+        final DefaultVertexParallelismAndInputInfosDecider decider =
+                createVertexParallelismAndInputInfosDecider();
+
+        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_256_MB);
+        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_512_MB);
+
+        int parallelism = decider.decideParallelism(Arrays.asList(resultInfo1, resultInfo2), -1);
+
+        assertThat(parallelism).isEqualTo(4);
+    }
+
+    @Test
+    void testBroadcastRatioExceedsCapRatio() {
+        final DefaultVertexParallelismAndInputInfosDecider decider =
+                createVertexParallelismAndInputInfosDecider();
+
+        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_1_GB);
+        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_8_GB);
+
+        int parallelism = decider.decideParallelism(Arrays.asList(resultInfo1, resultInfo2), -1);
+
+        assertThat(parallelism).isEqualTo(16);
+    }
+
+    @Test
+    void testNonBroadcastBytesCanNotDividedEvenly() {
+        final DefaultVertexParallelismAndInputInfosDecider decider =
+                createVertexParallelismAndInputInfosDecider();
+
+        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_512_MB);
+        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_256_MB + BYTE_8_GB);
+
+        int parallelism = decider.decideParallelism(Arrays.asList(resultInfo1, resultInfo2), -1);
+
+        assertThat(parallelism).isEqualTo(16);
+    }
+
+    @Test
+    void testAllEdgesAllToAll() {
+        final DefaultVertexParallelismAndInputInfosDecider decider =
+                createVertexParallelismAndInputInfosDecider(1, 10, 60L);
+
+        AllToAllBlockingResultInfo resultInfo1 =
+                createAllToAllBlockingResultInfo(
+                        new long[] {10L, 15L, 13L, 12L, 1L, 10L, 8L, 20L, 12L, 17L});
+        AllToAllBlockingResultInfo resultInfo2 =
+                createAllToAllBlockingResultInfo(
+                        new long[] {8L, 12L, 21L, 9L, 13L, 7L, 19L, 13L, 14L, 5L});
+        ParallelismAndInputInfos parallelismAndInputInfos =
+                decider.decideParallelismAndInputInfosForVertex(
+                        Arrays.asList(resultInfo1, resultInfo2), -1);
+
+        assertThat(parallelismAndInputInfos.getParallelism()).isEqualTo(5);
+        assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(2);
+
+        List<IndexRange> subpartitionRanges =
+                Arrays.asList(
+                        new IndexRange(0, 1),
+                        new IndexRange(2, 3),
+                        new IndexRange(4, 6),
+                        new IndexRange(7, 8),
+                        new IndexRange(9, 9));
+        checkAllToAllJobVertexInputInfo(
+                parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo1.getResultId()),
+                subpartitionRanges);
+        checkAllToAllJobVertexInputInfo(
+                parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo2.getResultId()),
+                subpartitionRanges);
+    }
+
+    @Test
+    void testAllEdgesAllToAllAndDecidedParallelismIsMaxParallelism() {
+        final DefaultVertexParallelismAndInputInfosDecider decider =
+                createVertexParallelismAndInputInfosDecider(1, 2, 10L);
+
+        AllToAllBlockingResultInfo resultInfo =
+                createAllToAllBlockingResultInfo(
+                        new long[] {10L, 15L, 13L, 12L, 1L, 10L, 8L, 20L, 12L, 17L});
+        ParallelismAndInputInfos parallelismAndInputInfos =
+                decider.decideParallelismAndInputInfosForVertex(
+                        Collections.singletonList(resultInfo), -1);
+
+        assertThat(parallelismAndInputInfos.getParallelism()).isEqualTo(2);
+        assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(1);
+        checkAllToAllJobVertexInputInfo(
+                Iterables.getOnlyElement(
+                        parallelismAndInputInfos.getJobVertexInputInfos().values()),
+                Arrays.asList(new IndexRange(0, 5), new IndexRange(6, 9)));
+    }
+
+    @Test
+    void testAllEdgesAllToAllAndDecidedParallelismIsMinParallelism() {
+        final DefaultVertexParallelismAndInputInfosDecider decider =
+                createVertexParallelismAndInputInfosDecider(4, 10, 1000L);
+
+        AllToAllBlockingResultInfo resultInfo =
+                createAllToAllBlockingResultInfo(
+                        new long[] {10L, 15L, 13L, 12L, 1L, 10L, 8L, 20L, 12L, 17L});
+        ParallelismAndInputInfos parallelismAndInputInfos =
+                decider.decideParallelismAndInputInfosForVertex(
+                        Collections.singletonList(resultInfo), -1);
+
+        assertThat(parallelismAndInputInfos.getParallelism()).isEqualTo(4);
+        assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(1);
+        checkAllToAllJobVertexInputInfo(
+                Iterables.getOnlyElement(
+                        parallelismAndInputInfos.getJobVertexInputInfos().values()),
+                Arrays.asList(
+                        new IndexRange(0, 1),
+                        new IndexRange(2, 5),
+                        new IndexRange(6, 7),
+                        new IndexRange(8, 9)));
+    }
+
+    @Test
+    void testFallBackToEvenlyDistributeSubpartitions() {
+        final DefaultVertexParallelismAndInputInfosDecider decider =
+                createVertexParallelismAndInputInfosDecider(8, 8, 10L);
+
+        AllToAllBlockingResultInfo resultInfo =
+                createAllToAllBlockingResultInfo(
+                        new long[] {10L, 1L, 10L, 1L, 10L, 1L, 10L, 1L, 10L, 1L});
+        ParallelismAndInputInfos parallelismAndInputInfos =
+                decider.decideParallelismAndInputInfosForVertex(
+                        Collections.singletonList(resultInfo), -1);
+
+        assertThat(parallelismAndInputInfos.getParallelism()).isEqualTo(8);
+        assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(1);
+        checkAllToAllJobVertexInputInfo(
+                Iterables.getOnlyElement(
+                        parallelismAndInputInfos.getJobVertexInputInfos().values()),
+                Arrays.asList(
+                        new IndexRange(0, 0),
+                        new IndexRange(1, 1),
+                        new IndexRange(2, 2),
+                        new IndexRange(3, 4),
+                        new IndexRange(5, 5),
+                        new IndexRange(6, 6),
+                        new IndexRange(7, 7),
+                        new IndexRange(8, 9)));
+    }
+
+    @Test
+    void testAllEdgesAllToAllAndOneIsBroadcast() {
+        final DefaultVertexParallelismAndInputInfosDecider decider =
+                createVertexParallelismAndInputInfosDecider(1, 10, 60L);
+
+        AllToAllBlockingResultInfo resultInfo1 =
+                createAllToAllBlockingResultInfo(
+                        new long[] {10L, 15L, 13L, 12L, 1L, 10L, 8L, 20L, 12L, 17L});
+        AllToAllBlockingResultInfo resultInfo2 =
+                createAllToAllBlockingResultInfo(new long[] {10L}, true);
+
+        ParallelismAndInputInfos parallelismAndInputInfos =
+                decider.decideParallelismAndInputInfosForVertex(
+                        Arrays.asList(resultInfo1, resultInfo2), -1);
+
+        assertThat(parallelismAndInputInfos.getParallelism()).isEqualTo(3);
+        assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(2);
+
+        checkAllToAllJobVertexInputInfo(
+                parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo1.getResultId()),
+                Arrays.asList(new IndexRange(0, 3), new IndexRange(4, 7), new IndexRange(8, 9)));
+        checkAllToAllJobVertexInputInfo(
+                parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo2.getResultId()),
+                Arrays.asList(new IndexRange(0, 0), new IndexRange(0, 0), new IndexRange(0, 0)));
+    }
+
+    @Test
+    void testAllEdgesBroadcast() {
+        final DefaultVertexParallelismAndInputInfosDecider decider =
+                createVertexParallelismAndInputInfosDecider(1, 10, 60L);
+
+        AllToAllBlockingResultInfo resultInfo1 =
+                createAllToAllBlockingResultInfo(new long[] {10L}, true);
+        AllToAllBlockingResultInfo resultInfo2 =
+                createAllToAllBlockingResultInfo(new long[] {10L}, true);
+        ParallelismAndInputInfos parallelismAndInputInfos =
+                decider.decideParallelismAndInputInfosForVertex(
+                        Arrays.asList(resultInfo1, resultInfo2), -1);
+
+        assertThat(parallelismAndInputInfos.getParallelism()).isEqualTo(1);
+        assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(2);
+
+        checkAllToAllJobVertexInputInfo(
+                parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo1.getResultId()),
+                Collections.singletonList(new IndexRange(0, 0)));
+        checkAllToAllJobVertexInputInfo(
+                parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo2.getResultId()),
+                Collections.singletonList(new IndexRange(0, 0)));
+    }
+
+    @Test
+    void testHavePointwiseEdges() {
+        final DefaultVertexParallelismAndInputInfosDecider decider =
+                createVertexParallelismAndInputInfosDecider(1, 10, 60L);
+
+        AllToAllBlockingResultInfo resultInfo1 =
+                createAllToAllBlockingResultInfo(
+                        new long[] {10L, 15L, 13L, 12L, 1L, 10L, 8L, 20L, 12L, 17L});
+        PointwiseBlockingResultInfo resultInfo2 =
+                createPointwiseBlockingResultInfo(
+                        new long[] {8L, 12L, 21L, 9L, 13L}, new long[] {7L, 19L, 13L, 14L, 5L});
+        ParallelismAndInputInfos parallelismAndInputInfos =
+                decider.decideParallelismAndInputInfosForVertex(
+                        Arrays.asList(resultInfo1, resultInfo2), -1);
+
+        assertThat(parallelismAndInputInfos.getParallelism()).isEqualTo(4);
+        assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(2);
+
+        checkAllToAllJobVertexInputInfo(
+                parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo1.getResultId()),
+                Arrays.asList(
+                        new IndexRange(0, 1),
+                        new IndexRange(2, 4),
+                        new IndexRange(5, 6),
+                        new IndexRange(7, 9)));
+        checkPointwiseJobVertexInputInfo(
+                parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo2.getResultId()),
+                Arrays.asList(
+                        new IndexRange(0, 0),
+                        new IndexRange(0, 0),
+                        new IndexRange(1, 1),
+                        new IndexRange(1, 1)),
+                Arrays.asList(
+                        new IndexRange(0, 1),
+                        new IndexRange(2, 4),
+                        new IndexRange(0, 1),
+                        new IndexRange(2, 4)));
+    }
+
+    @Test
+    void testParallelismAlreadyDecided() {
+        final DefaultVertexParallelismAndInputInfosDecider decider =
+                createVertexParallelismAndInputInfosDecider();
+
+        AllToAllBlockingResultInfo allToAllBlockingResultInfo =
+                createAllToAllBlockingResultInfo(
+                        new long[] {10L, 15L, 13L, 12L, 1L, 10L, 8L, 20L, 12L, 17L});
+        ParallelismAndInputInfos parallelismAndInputInfos =
+                decider.decideParallelismAndInputInfosForVertex(
+                        Collections.singletonList(allToAllBlockingResultInfo), 3);
+
+        assertThat(parallelismAndInputInfos.getParallelism()).isEqualTo(3);
+        assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(1);
+
+        checkAllToAllJobVertexInputInfo(
+                Iterables.getOnlyElement(
+                        parallelismAndInputInfos.getJobVertexInputInfos().values()),
+                Arrays.asList(new IndexRange(0, 2), new IndexRange(3, 5), new IndexRange(6, 9)));
+    }
+
+    @Test
+    void testSourceJobVertex() {

Review Comment:
   Can we have a case to verify that when a vertex consumes a result multiple times, the result is correct?
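   For example, something along these lines (just a rough sketch following the helpers in this test class; it assumes that passing the same result info twice models a vertex consuming one result multiple times, and the exact expected numbers would need to be checked against the decider's aggregation logic):
   ```java
   @Test
   void testConsumeResultMultipleTimes() {
       final DefaultVertexParallelismAndInputInfosDecider decider =
               createVertexParallelismAndInputInfosDecider(1, 10, 60L);

       AllToAllBlockingResultInfo resultInfo =
               createAllToAllBlockingResultInfo(
                       new long[] {10L, 15L, 13L, 12L, 1L, 10L, 8L, 20L, 12L, 17L});

       // the vertex consumes the same intermediate result twice
       ParallelismAndInputInfos parallelismAndInputInfos =
               decider.decideParallelismAndInputInfosForVertex(
                       Arrays.asList(resultInfo, resultInfo), -1);

       // both consumptions share the same result id, so a single
       // JobVertexInputInfo entry is expected; the concrete subpartition
       // ranges depend on how the duplicated bytes are aggregated
       assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(1);
   }
   ```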



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDecider.java:
##########
@@ -0,0 +1,479 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptivebatch;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.JobManagerOptions;
+import org.apache.flink.configuration.MemorySize;
+import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
+import org.apache.flink.runtime.executiongraph.IndexRange;
+import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
+import org.apache.flink.runtime.executiongraph.ParallelismAndInputInfos;
+import org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.util.MathUtils;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Default implementation of {@link VertexParallelismAndInputInfosDecider}. This implementation will
+ * decide parallelism and {@link JobVertexInputInfo}s as follows:
+ *
+ * <p>1. For job vertices whose inputs are all ALL_TO_ALL edges, evenly distribute data to
+ * downstream subtasks, make different downstream subtasks consume roughly the same amount of data.
+ *
+ * <p>2. For other cases, evenly distribute subpartitions to downstream subtasks, make different
+ * downstream subtasks consume roughly the same number of subpartitions.
+ */
+public class DefaultVertexParallelismAndInputInfosDecider
+        implements VertexParallelismAndInputInfosDecider {
+
+    private static final Logger LOG =
+            LoggerFactory.getLogger(DefaultVertexParallelismAndInputInfosDecider.class);
+
+    /**
+     * The cap ratio of broadcast bytes to data volume per task. The cap ratio is 0.5 currently
+     * because we usually expect the broadcast dataset to be smaller than non-broadcast. We can make
+     * it configurable later if we see users requesting for it.
+     */
+    private static final double CAP_RATIO_OF_BROADCAST = 0.5;
+
+    private final int maxParallelism;
+    private final int minParallelism;
+    private final long dataVolumePerTask;
+    private final int defaultSourceParallelism;
+
+    private DefaultVertexParallelismAndInputInfosDecider(
+            int maxParallelism,
+            int minParallelism,
+            MemorySize dataVolumePerTask,
+            int defaultSourceParallelism) {
+
+        checkArgument(minParallelism > 0, "The minimum parallelism must be larger than 0.");
+        checkArgument(
+                maxParallelism >= minParallelism,
+                "Maximum parallelism should be greater than or equal to the minimum parallelism.");
+        checkArgument(
+                defaultSourceParallelism > 0,
+                "The default source parallelism must be larger than 0.");
+        checkNotNull(dataVolumePerTask);
+
+        this.maxParallelism = maxParallelism;
+        this.minParallelism = minParallelism;
+        this.dataVolumePerTask = dataVolumePerTask.getBytes();
+        this.defaultSourceParallelism = defaultSourceParallelism;
+    }
+
+    @Override
+    public ParallelismAndInputInfos decideParallelismAndInputInfosForVertex(
+            List<BlockingResultInfo> consumedResults, int initialParallelism) {
+        checkArgument(
+                initialParallelism == ExecutionConfig.PARALLELISM_DEFAULT
+                        || initialParallelism > 0);
+
+        if (initialParallelism == ExecutionConfig.PARALLELISM_DEFAULT
+                && areAllInputsAllToAll(consumedResults)
+                && !areAllInputsBroadcast(consumedResults)) {

Review Comment:
   I know this can work, but it is a bit weird that we rely on this: for a source vertex, `areAllInputsBroadcast == true`.
   I prefer to explicitly check whether it is a source vertex and go to the right branch.
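   Something in this direction is what I have in mind (only a sketch: `consumedResults.isEmpty()` stands in for the source check, and the branch method names and argument lists other than `areAllInputsAllToAll`/`areAllInputsBroadcast` are placeholders, not existing code):
   ```java
   if (consumedResults.isEmpty()) {
       // source vertex: decide its parallelism directly, there are no inputs to balance
       return decideParallelismAndInputInfosForSource(initialParallelism); // placeholder name
   } else if (initialParallelism == ExecutionConfig.PARALLELISM_DEFAULT
           && areAllInputsAllToAll(consumedResults)
           && !areAllInputsBroadcast(consumedResults)) {
       // all inputs are non-broadcast ALL_TO_ALL: evenly distribute data
       return decideParallelismAndEvenlyDistributeData(consumedResults); // placeholder name
   } else {
       // otherwise: evenly distribute subpartitions
       return decideParallelismAndEvenlyDistributeSubpartitions(consumedResults, initialParallelism);
   }
   ```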



##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java:
##########
@@ -219,58 +221,34 @@ private ParallelismAndInputInfos loadBalanceForAllToAllInputs(
         // compute subpartition ranges
         List<IndexRange> subpartitionRanges =
                 computeSubpartitionRanges(
-                        nonBroadcastBytesBySubpartition, nonBroadcastBytesPerTaskLimit);
-
-        if (subpartitionRanges.size() < minParallelism) {
-            long minSubpartitionBytes =
-                    Arrays.stream(nonBroadcastBytesBySubpartition).min().getAsLong();
-            // find a legal limit so that the computed parallelism >= minParallelism
-            long adjustLimit =
-                    BisectionSearchUtils.findMaxLegalValue(
-                            value ->
-                                    computeParallelism(nonBroadcastBytesBySubpartition, value)
-                                            >= minParallelism,
-                            minSubpartitionBytes,
-                            nonBroadcastBytesPerTaskLimit);
-
-            // the smaller the limit, the more even the distribution
-            final long expectedParallelism =
-                    computeParallelism(nonBroadcastBytesBySubpartition, adjustLimit);
-            adjustLimit =
-                    BisectionSearchUtils.findMinLegalValue(
-                            value ->
-                                    computeParallelism(nonBroadcastBytesBySubpartition, value)
-                                            <= expectedParallelism,
-                            minSubpartitionBytes,
-                            adjustLimit);
-
-            subpartitionRanges =
-                    computeSubpartitionRanges(nonBroadcastBytesBySubpartition, adjustLimit);
-        } else if (subpartitionRanges.size() > maxParallelism) {
-            // find a legal limit so that the computed parallelism <= minParallelism
-            long adjustLimit =
-                    BisectionSearchUtils.findMinLegalValue(
-                            value ->
-                                    computeParallelism(nonBroadcastBytesBySubpartition, value)
-                                            <= maxParallelism,
-                            nonBroadcastBytesPerTaskLimit,
-                            nonBroadcastBytes);
-
-            subpartitionRanges =
-                    computeSubpartitionRanges(nonBroadcastBytesBySubpartition, adjustLimit);
+                        nonBroadcastBytesBySubpartition, nonBroadcastDataVolumeLimit);
+
+        // if the parallelism is not legal, adjust ot a legal parallelism
+        if (!isLegalParallelism(subpartitionRanges.size())) {
+            Optional<List<IndexRange>> adjustedSubpartitionRanges =
+                    adjustToClosestLegalParallelism(
+                            nonBroadcastBytesBySubpartition,
+                            nonBroadcastDataVolumeLimit,
+                            subpartitionRanges.size());
+            if (!adjustedSubpartitionRanges.isPresent()) {
+                // can't find any legal parallelism, fall back to evenly distribute subpartitions
+                return decideParallelismAndEvenlyDistributeSubpartitions(

Review Comment:
   It's better to add a log here to show this problem.
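   Something like this would do (the wording is just a suggestion; it only uses the fields and locals already available in this method):
   ```java
   LOG.info(
           "Cannot find a legal parallelism to evenly distribute data "
                   + "(computed parallelism: {}, min parallelism: {}, max parallelism: {}), "
                   + "falling back to evenly distributing subpartitions.",
           subpartitionRanges.size(),
           minParallelism,
           maxParallelism);
   ```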



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
