zhuzhurk commented on code in PR #25552:
URL: https://github.com/apache/flink/pull/25552#discussion_r1907034134


##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/VertexParallelismAndInputInfosDeciderUtils.java:
##########
@@ -0,0 +1,656 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptivebatch.util;
+
+import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
+import org.apache.flink.runtime.executiongraph.IndexRange;
+import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
+import org.apache.flink.runtime.scheduler.adaptivebatch.BisectionSearchUtils;
+import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingInputInfo;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.LongStream;
+
+import static org.apache.flink.runtime.executiongraph.IndexRangeUtil.mergeIndexRanges;
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Utils class for VertexParallelismAndInputInfosDecider. */
+public class VertexParallelismAndInputInfosDeciderUtils {
+    /**
+     * Adjust the parallelism to the closest legal parallelism and return the computed subpartition
+     * ranges.
+     *
+     * @param currentDataVolumeLimit current data volume limit
+     * @param currentParallelism current parallelism
+     * @param minParallelism the min parallelism
+     * @param maxParallelism the max parallelism
+     * @param minLimit the minimum data volume limit
+     * @param maxLimit the maximum data volume limit
+     * @param parallelismComputer a function to compute the parallelism according to the data volume
+     *     limit
+     * @param subpartitionRangesComputer a function to compute the subpartition ranges according to
+     *     the data volume limit
+     * @return the computed subpartition ranges or {@link Optional#empty()} if we can't find any
+     *     legal parallelism
+     */
+    public static Optional<List<IndexRange>> adjustToClosestLegalParallelism(
+            long currentDataVolumeLimit,
+            int currentParallelism,
+            int minParallelism,
+            int maxParallelism,
+            long minLimit,
+            long maxLimit,
+            Function<Long, Integer> parallelismComputer,
+            Function<Long, List<IndexRange>> subpartitionRangesComputer) {
+        long adjustedDataVolumeLimit = currentDataVolumeLimit;
+        if (currentParallelism < minParallelism) {
+            // Current parallelism is smaller than the user-specified lower-limit of parallelism,
+            // so we need to adjust it to the closest/minimum possible legal parallelism. That is,
+            // we need to find the maximum legal dataVolumeLimit.
+            adjustedDataVolumeLimit =
+                    BisectionSearchUtils.findMaxLegalValue(
+                            value -> parallelismComputer.apply(value) >= minParallelism,
+                            minLimit,
+                            currentDataVolumeLimit);
+
+            // When we find the minimum possible legal parallelism, the dataVolumeLimit that can
+            // lead to this parallelism may be a range, and we need to find the minimum value of
+            // this range to make the data distribution as even as possible (the smaller the
+            // dataVolumeLimit, the more even the distribution)
+            final long minPossibleLegalParallelism =
+                    parallelismComputer.apply(adjustedDataVolumeLimit);
+            adjustedDataVolumeLimit =
+                    BisectionSearchUtils.findMinLegalValue(
+                            value ->
+                                    parallelismComputer.apply(value) == minPossibleLegalParallelism,
+                            minLimit,
+                            adjustedDataVolumeLimit);
+
+        } else if (currentParallelism > maxParallelism) {
+            // Current parallelism is larger than the user-specified upper-limit of parallelism,
+            // so we need to adjust it to the closest/maximum possible legal parallelism. That is,
+            // we need to find the minimum legal dataVolumeLimit.
+            adjustedDataVolumeLimit =
+                    BisectionSearchUtils.findMinLegalValue(
+                            value -> parallelismComputer.apply(value) <= maxParallelism,
+                            currentDataVolumeLimit,
+                            maxLimit);
+        }
+
+        int adjustedParallelism = parallelismComputer.apply(adjustedDataVolumeLimit);
+        if (isLegalParallelism(adjustedParallelism, minParallelism, maxParallelism)) {
+            return Optional.of(subpartitionRangesComputer.apply(adjustedDataVolumeLimit));
+        } else {
+            return Optional.empty();
+        }
+    }
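+
+    // Example for adjustToClosestLegalParallelism, with hypothetical numbers: assume
+    // parallelismComputer maps a limit to ceil(totalBytes / limit) and totalBytes = 1200. A
+    // currentDataVolumeLimit of 100 yields parallelism 12; if maxParallelism is 8, the bisection
+    // search raises the limit until the parallelism is legal again (limit 150 -> parallelism 8).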
+
+    /**
+     * Computes the Cartesian product of a list of lists.
+     *
+     * <p>The Cartesian product is a set of all possible combinations formed by picking one element
+     * from each list. For example, given input lists [[1, 2], [3, 4]], the result will be [[1, 3],
+     * [1, 4], [2, 3], [2, 4]].
+     *
+     * <p>Note: If any of the inner lists is empty, the result will be an empty list. If the outer
+     * list itself is empty, the result contains a single empty combination.
+     *
+     * @param <T> the type of elements in the lists
+     * @param lists a list of lists for which the Cartesian product is to be computed
+     * @return a list of lists representing the Cartesian product, where each inner list is a
+     *     combination
+     */
+    public static <T> List<List<T>> cartesianProduct(List<List<T>> lists) {
+        List<List<T>> resultLists = new ArrayList<>();
+        if (lists.isEmpty()) {
+            resultLists.add(new ArrayList<>());
+            return resultLists;
+        } else {
+            List<T> firstList = lists.get(0);
+            List<List<T>> remainingLists = cartesianProduct(lists.subList(1, lists.size()));
+            for (T condition : firstList) {
+                for (List<T> remainingList : remainingLists) {
+                    ArrayList<T> resultList = new ArrayList<>();
+                    resultList.add(condition);
+                    resultList.addAll(remainingList);
+                    resultLists.add(resultList);
+                }
+            }
+        }
+        return resultLists;
+    }
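+
+    // Example for cartesianProduct: cartesianProduct(List.of(List.of(1, 2), List.of(3, 4)))
+    // returns [[1, 3], [1, 4], [2, 3], [2, 4]], while cartesianProduct(List.of(List.of(1),
+    // List.of())) returns [] because an empty inner list contributes no choices.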
+
+    /**
+     * Calculates the median of a given array of long integers. If the calculated median is less
+     * than 1, it returns 1 instead.
+     *
+     * @param nums an array of long integers for which to calculate the median.
+     * @return the median value, which will be at least 1.
+     */
+    public static long median(long[] nums) {
+        int len = nums.length;
+        long[] sortedNums = LongStream.of(nums).sorted().toArray();
+        if (len % 2 == 0) {
+            return Math.max((sortedNums[len / 2] + sortedNums[len / 2 - 1]) / 2, 1L);
+        } else {
+            return Math.max(sortedNums[len / 2], 1L);
+        }
+    }
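+
+    // Example for median: median(new long[] {8, 2, 6}) returns 6, median(new long[] {4, 6})
+    // returns (4 + 6) / 2 = 5, and median(new long[] {0, 0}) is clamped up to 1.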
+
+    /**
+     * Computes the skew threshold based on the given median size and skewed factor.
+     *
+     * <p>The skew threshold is calculated as the product of the median size and the skewed factor.
+     * To ensure that the computed threshold does not fall below a specified default value, the
+     * method uses {@link Math#max} to return the larger of the calculated threshold and the
+     * default threshold.
+     *
+     * @param medianSize the median size
+     * @param skewedFactor a factor indicating the degree of skewness
+     * @param defaultSkewedThreshold the default threshold to be used if the calculated threshold is
+     *     less than this value
+     * @return the computed skew threshold, which is guaranteed to be at least the default skewed
+     *     threshold.
+     */
+    public static long computeSkewThreshold(
+            long medianSize, double skewedFactor, long defaultSkewedThreshold) {
+        return (long) Math.max(medianSize * skewedFactor, defaultSkewedThreshold);
+    }
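+
+    // Example for computeSkewThreshold: computeSkewThreshold(100, 1.5, 200) returns
+    // max(150, 200) = 200, while computeSkewThreshold(400, 1.5, 200) returns max(600, 200) = 600.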
+
+    /**
+     * Computes the target data size for each task based on the sizes of non-skewed subpartitions.
+     *
+     * <p>The target size is determined as the average size of non-skewed subpartitions and ensures
+     * that the target size is at least equal to the specified data volume per task.
+     *
+     * @param subpartitionBytes an array representing the data size of each subpartition
+     * @param skewedThreshold skewed threshold in bytes
+     * @param dataVolumePerTask the amount of data that should be allocated per task
+     * @return the computed target size for each task, which is the maximum between the average size
+     *     of non-skewed subpartitions and data volume per task.
+     */
+    public static long computeTargetSize(
+            long[] subpartitionBytes, long skewedThreshold, long dataVolumePerTask) {
+        long[] nonSkewPartitions =
+                LongStream.of(subpartitionBytes).filter(v -> v <= skewedThreshold).toArray();
+        if (nonSkewPartitions.length == 0) {
+            return dataVolumePerTask;
+        } else {
+            return Math.max(
+                    dataVolumePerTask,
+                    LongStream.of(nonSkewPartitions).sum() / nonSkewPartitions.length);
+        }
+    }
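+
+    // Example for computeTargetSize, with hypothetical sizes: given subpartitionBytes =
+    // {100, 100, 900}, skewedThreshold = 300 and dataVolumePerTask = 50, the skewed 900-byte
+    // subpartition is excluded and the target size is max(50, (100 + 100) / 2) = 100.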
+
+    public static List<BlockingInputInfo> getNonBroadcastInputInfos(
+            List<BlockingInputInfo> consumedResults) {
+        return consumedResults.stream()
+                .filter(resultInfo -> !resultInfo.isBroadcast())
+                .collect(Collectors.toList());
+    }
+
+    public static List<BlockingInputInfo> getBroadcastInputInfos(
+            List<BlockingInputInfo> consumedResults) {
+        return consumedResults.stream()
+                .filter(BlockingInputInfo::isBroadcast)
+                .collect(Collectors.toList());
+    }
+
+    public static boolean hasSameNumPartitions(List<BlockingInputInfo> inputInfos) {
+        Set<Integer> partitionNums =
+                inputInfos.stream()
+                        .map(BlockingInputInfo::getNumPartitions)
+                        .collect(Collectors.toSet());
+        return partitionNums.size() == 1;
+    }
+
+    public static int getMaxNumPartitions(List<BlockingInputInfo> consumedResults) {
+        checkArgument(!consumedResults.isEmpty());
+        return consumedResults.stream()
+                .mapToInt(BlockingInputInfo::getNumPartitions)
+                .max()
+                .getAsInt();
+    }
+
+    public static int getMaxNumSubpartitions(List<BlockingInputInfo> consumedResults) {
+        checkArgument(!consumedResults.isEmpty());
+        return consumedResults.stream()
+                .mapToInt(
+                        resultInfo ->
+                                IntStream.range(0, resultInfo.getNumPartitions())
+                                        .boxed()
+                                        .mapToInt(resultInfo::getNumSubpartitions)
+                                        .sum())
+                .max()
+                .getAsInt();
+    }
+
+    public static int checkAndGetSubpartitionNum(List<BlockingInputInfo> consumedResults) {
+        final Set<Integer> subpartitionNumSet =
+                consumedResults.stream()
+                        .flatMap(
+                                resultInfo ->
+                                        IntStream.range(0, resultInfo.getNumPartitions())
+                                                .boxed()
+                                                .map(resultInfo::getNumSubpartitions))
+                        .collect(Collectors.toSet());
+        // all partitions have the same subpartition num
+        checkState(subpartitionNumSet.size() == 1);
+        return subpartitionNumSet.iterator().next();
+    }
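+
+    // Example for checkAndGetSubpartitionNum: subpartition counts {4, 4, 4} across partitions
+    // yield 4, while mixed counts such as {4, 5} fail the checkState.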
+
+    public static boolean isLegalParallelism(
+            int parallelism, int minParallelism, int maxParallelism) {
+        return parallelism >= minParallelism && parallelism <= maxParallelism;
+    }
+
+    public static boolean checkAndGetIntraCorrelation(List<BlockingInputInfo> inputInfos) {
+        Set<Boolean> intraCorrelationSet =
+                inputInfos.stream()
+                        .map(BlockingInputInfo::isIntraInputKeyCorrelated)
+                        .collect(Collectors.toSet());
+        checkArgument(intraCorrelationSet.size() == 1);
+        return intraCorrelationSet.iterator().next();
+    }
+
+    public static int checkAndGetParallelism(Collection<JobVertexInputInfo> vertexInputInfos) {
+        final Set<Integer> parallelismSet =
+                vertexInputInfos.stream()
+                        .map(
+                                vertexInputInfo ->
+                                        vertexInputInfo.getExecutionVertexInputInfos().size())
+                        .collect(Collectors.toSet());
+        checkState(parallelismSet.size() == 1);
+        return parallelismSet.iterator().next();
+    }
+
+    /**
+     * Attempts to compute the subpartition slice ranges to ensure even distribution of data across
+     * downstream tasks.
+     *
+     * <p>This method first tries to compute the subpartition slice ranges by evenly distributing
+     * the data volume. If that fails, it attempts to compute the ranges by evenly distributing the
+     * number of subpartition slices.
+     *
+     * @param minParallelism The minimum parallelism.
+     * @param maxParallelism The maximum parallelism.
+     * @param maxSubpartitionSliceRangePerTask The maximum number of subpartition slice ranges per
+     *     task.
+     * @param maxDataVolumePerTask The maximum data volume per task.
+     * @param subpartitionSlicesByTypeNumber A map of lists of subpartition slices grouped by type
+     *     number.
+     * @return An {@code Optional} containing a list of index ranges representing the subpartition
+     *     slice ranges. Returns an empty {@code Optional} if no suitable ranges can be computed.
+     */
+    public static Optional<List<IndexRange>> tryComputeSubpartitionSliceRange(
+            int minParallelism,
+            int maxParallelism,
+            int maxSubpartitionSliceRangePerTask,
+            long maxDataVolumePerTask,
+            Map<Integer, List<SubpartitionSlice>> subpartitionSlicesByTypeNumber) {
+        Optional<List<IndexRange>> subpartitionSliceRanges =
+                tryComputeSubpartitionSliceRangeEvenlyDistributedData(
+                        minParallelism,
+                        maxParallelism,
+                        maxSubpartitionSliceRangePerTask,
+                        maxDataVolumePerTask,
+                        subpartitionSlicesByTypeNumber);
+        if (subpartitionSliceRanges.isEmpty()) {
+            subpartitionSliceRanges =
+                    tryComputeSubpartitionSliceRangeEvenlyDistributedSubpartitionSlices(
+                            minParallelism, maxParallelism, subpartitionSlicesByTypeNumber);
+        }
+        return subpartitionSliceRanges;
+    }
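+
+    // Example for tryComputeSubpartitionSliceRange, with hypothetical sizes: given three slices
+    // of {10, 10, 40} bytes and minParallelism = 2, the data-volume based attempt may produce a
+    // single range and fail the parallelism check; the fallback then splits the three slices
+    // evenly into [0, 0] and [1, 2].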
+
+    public static List<ExecutionVertexInputInfo> createdExecutionVertexInputInfosForBroadcast(
+            BlockingInputInfo inputInfo, int parallelism) {
+        checkArgument(inputInfo.isBroadcast());
+        int numPartitions = inputInfo.getNumPartitions();
+        List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>();
+        for (int i = 0; i < parallelism; ++i) {
+            ExecutionVertexInputInfo executionVertexInputInfo;
+            if (inputInfo.isSingleSubpartitionContainsAllData()) {
+                executionVertexInputInfo =
+                        new ExecutionVertexInputInfo(
+                                i, new IndexRange(0, numPartitions - 1), new IndexRange(0, 0));
+            } else {
+                // The partitions of the all-to-all result have the same number of
+                // subpartitions. So we can use the first partition's subpartition
+                // number.
+                executionVertexInputInfo =
+                        new ExecutionVertexInputInfo(
+                                i,
+                                new IndexRange(0, numPartitions - 1),
+                                new IndexRange(0, inputInfo.getNumSubpartitions(0) - 1));
+            }
+            executionVertexInputInfos.add(executionVertexInputInfo);
+        }
+        return executionVertexInputInfos;
+    }
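+
+    // Example for createdExecutionVertexInputInfosForBroadcast: with parallelism = 3 and 4
+    // upstream partitions whose data is fully contained in subpartition 0, every subtask gets
+    // partition range [0, 3] and subpartition range [0, 0].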
+
+    public static List<ExecutionVertexInputInfo> createdExecutionVertexInputInfosForNonBroadcast(
+            BlockingInputInfo inputInfo,
+            List<IndexRange> subpartitionSliceRanges,
+            List<SubpartitionSlice> subpartitionSlices) {
+        checkArgument(!inputInfo.isBroadcast());
+        int numPartitions = inputInfo.getNumPartitions();
+        List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>();
+        for (int i = 0; i < subpartitionSliceRanges.size(); ++i) {
+            IndexRange subpartitionSliceRange = subpartitionSliceRanges.get(i);
+            // Convert the subpartition slices to a map from partition range to subpartition range.
+            Map<IndexRange, IndexRange> consumedSubpartitionGroups =
+                    computeConsumedSubpartitionGroups(
+                            numPartitions,
+                            inputInfo.isPointwise(),
+                            subpartitionSliceRange,
+                            subpartitionSlices);
+            executionVertexInputInfos.add(
+                    new ExecutionVertexInputInfo(i, consumedSubpartitionGroups));
+        }
+        return executionVertexInputInfos;
+    }
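+
+    // Example for createdExecutionVertexInputInfosForNonBroadcast: a subtask assigned the
+    // subpartitionSliceRange [2, 3] consumes the merged partition-range to subpartition-range
+    // groups covered by slices 2 and 3.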
+
+    private static Optional<List<IndexRange>> tryComputeSubpartitionSliceRangeEvenlyDistributedData(
+            int minParallelism,
+            int maxParallelism,
+            int maxSubpartitionSliceRangePerTask,
+            long maxDataVolumePerTask,
+            Map<Integer, List<SubpartitionSlice>> subpartitionSlicesByTypeNumber) {
+        int subpartitionSlicesSize =
+                checkAdnGetSubpartitionSlicesSize(subpartitionSlicesByTypeNumber);
+        // Distribute the input data evenly among the downstream tasks and record the
+        // subpartition slice range for each task.
+        List<IndexRange> subpartitionSliceRanges =
+                computeSubpartitionSliceRanges(
+                        maxDataVolumePerTask,
+                        maxSubpartitionSliceRangePerTask,
+                        subpartitionSlicesSize,
+                        subpartitionSlicesByTypeNumber);
+        // if the parallelism is not legal, try to adjust to a legal parallelism
+        if (!isLegalParallelism(subpartitionSliceRanges.size(), minParallelism, maxParallelism)) {
+            long minBytesSize = maxDataVolumePerTask;
+            long sumBytesSize = 0;
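+            // The smallest per-index data volume and the total volume bound the bisection search
+            // for a legal data volume limit.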
+            for (int i = 0; i < subpartitionSlicesSize; ++i) {
+                long currentBytesSize = 0;
+                for (List<SubpartitionSlice> subpartitionSlice :
+                        subpartitionSlicesByTypeNumber.values()) {
+                    currentBytesSize += subpartitionSlice.get(i).getDataBytes();
+                }
+                minBytesSize = Math.min(minBytesSize, currentBytesSize);
+                sumBytesSize += currentBytesSize;
+            }
+            return adjustToClosestLegalParallelism(
+                    maxDataVolumePerTask,
+                    subpartitionSliceRanges.size(),
+                    minParallelism,
+                    maxParallelism,
+                    minBytesSize,
+                    sumBytesSize,
+                    limit ->
+                            computeParallelism(
+                                    limit,
+                                    maxSubpartitionSliceRangePerTask,
+                                    subpartitionSlicesSize,
+                                    subpartitionSlicesByTypeNumber),
+                    limit ->
+                            computeSubpartitionSliceRanges(
+                                    limit,
+                                    maxSubpartitionSliceRangePerTask,
+                                    subpartitionSlicesSize,
+                                    subpartitionSlicesByTypeNumber));
+        }
+        return Optional.of(subpartitionSliceRanges);
+    }
+
+    private static Optional<List<IndexRange>>
+            tryComputeSubpartitionSliceRangeEvenlyDistributedSubpartitionSlices(
+                    int minParallelism,
+                    int maxParallelism,
+                    Map<Integer, List<SubpartitionSlice>> subpartitionSlicesByTypeNumber) {
+        int subpartitionSlicesSize =
+                checkAdnGetSubpartitionSlicesSize(subpartitionSlicesByTypeNumber);
+        if (subpartitionSlicesSize < minParallelism) {
+            return Optional.empty();
+        }
+        int parallelism = Math.min(subpartitionSlicesSize, maxParallelism);
+        List<IndexRange> subpartitionSliceRanges = new ArrayList<>();
+        for (int i = 0; i < parallelism; i++) {
+            int start = i * subpartitionSlicesSize / parallelism;
+            int nextStart = (i + 1) * subpartitionSlicesSize / parallelism;
+            subpartitionSliceRanges.add(new IndexRange(start, nextStart - 1));
+        }
+        checkState(subpartitionSliceRanges.size() == parallelism);
+        return Optional.of(subpartitionSliceRanges);
+    }
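+
+    // Example: distributing 5 subpartition slices with maxParallelism = 3 produces the ranges
+    // [0, 0], [1, 2] and [3, 4].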
+
+    /**
+     * Merges the subpartition slices of the specified range into an index range map, in which the
+     * key is the partition index range and the value is the subpartition range.
+     *
+     * <p>Note: For pointwise, we prioritize that their partition ranges have no overlap, and for

Review Comment:
   Could you add this explanation to the comment? This is not intuitive for other developers.
   
   And maybe do a check for the inputs? e.g. check that the key range size is 1.


