zhuzhurk commented on code in PR #25552:
URL: https://github.com/apache/flink/pull/25552#discussion_r1907034134


##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/VertexParallelismAndInputInfosDeciderUtils.java:
##########

@@ -0,0 +1,656 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptivebatch.util;
+
+import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
+import org.apache.flink.runtime.executiongraph.IndexRange;
+import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
+import org.apache.flink.runtime.scheduler.adaptivebatch.BisectionSearchUtils;
+import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingInputInfo;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.LongStream;
+
+import static org.apache.flink.runtime.executiongraph.IndexRangeUtil.mergeIndexRanges;
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Utils class for VertexParallelismAndInputInfosDecider. */
+public class VertexParallelismAndInputInfosDeciderUtils {
+    /**
+     * Adjust the parallelism to the closest legal parallelism and return the computed subpartition
+     * ranges.
+     *
+     * @param currentDataVolumeLimit current data volume limit
+     * @param currentParallelism current parallelism
+     * @param minParallelism the min parallelism
+     * @param maxParallelism the max parallelism
+     * @param minLimit the minimum data volume limit
+     * @param maxLimit the maximum data volume limit
+     * @param parallelismComputer a function to compute the parallelism according to the data
+     *     volume limit
+     * @param subpartitionRangesComputer a function to compute the subpartition ranges according to
+     *     the data volume limit
+     * @return the computed subpartition ranges or {@link Optional#empty()} if we can't find any
+     *     legal parallelism
+     */
+    public static Optional<List<IndexRange>> adjustToClosestLegalParallelism(
+            long currentDataVolumeLimit,
+            int currentParallelism,
+            int minParallelism,
+            int maxParallelism,
+            long minLimit,
+            long maxLimit,
+            Function<Long, Integer> parallelismComputer,
+            Function<Long, List<IndexRange>> subpartitionRangesComputer) {
+        long adjustedDataVolumeLimit = currentDataVolumeLimit;
+        if (currentParallelism < minParallelism) {
+            // Current parallelism is smaller than the user-specified lower limit of parallelism,
+            // we need to adjust it to the closest/minimum possible legal parallelism.
+            // That is, we need to find the maximum legal dataVolumeLimit.
+            adjustedDataVolumeLimit =
+                    BisectionSearchUtils.findMaxLegalValue(
+                            value -> parallelismComputer.apply(value) >= minParallelism,
+                            minLimit,
+                            currentDataVolumeLimit);
+
+            // When we find the minimum possible legal parallelism, the dataVolumeLimit that can
+            // lead to this parallelism may be a range, and we need to find the minimum value of
+            // this range to make the data distribution as even as possible (the smaller the
+            // dataVolumeLimit, the more even the distribution).
+            final long minPossibleLegalParallelism =
+                    parallelismComputer.apply(adjustedDataVolumeLimit);
+            adjustedDataVolumeLimit =
+                    BisectionSearchUtils.findMinLegalValue(
+                            value ->
+                                    parallelismComputer.apply(value) == minPossibleLegalParallelism,
+                            minLimit,
+                            adjustedDataVolumeLimit);
+        } else if (currentParallelism > maxParallelism) {
+            // Current parallelism is larger than the user-specified upper limit of parallelism,
+            // we need to adjust it to the closest/maximum possible legal parallelism. That is, we
+            // need to find the minimum legal dataVolumeLimit.
+            adjustedDataVolumeLimit =
+                    BisectionSearchUtils.findMinLegalValue(
+                            value -> parallelismComputer.apply(value) <= maxParallelism,
+                            currentDataVolumeLimit,
+                            maxLimit);
+        }
+
+        int adjustedParallelism = parallelismComputer.apply(adjustedDataVolumeLimit);
+        if (isLegalParallelism(adjustedParallelism, minParallelism, maxParallelism)) {
+            return Optional.of(subpartitionRangesComputer.apply(adjustedDataVolumeLimit));
+        } else {
+            return Optional.empty();
+        }
+    }
+
+    /**
+     * Computes the Cartesian product of a list of lists.
+     *
+     * <p>The Cartesian product is a set of all possible combinations formed by picking one element
+     * from each list. For example, given input lists [[1, 2], [3, 4]], the result will be [[1, 3],
+     * [1, 4], [2, 3], [2, 4]].
+     *
+     * <p>Note: If the input contains an empty list, the result will be an empty list. If the input
+     * itself is empty, the result is a list containing a single empty combination.
+     *
+     * @param <T> the type of elements in the lists
+     * @param lists a list of lists for which the Cartesian product is to be computed
+     * @return a list of lists representing the Cartesian product, where each inner list is a
+     *     combination
+     */
+    public static <T> List<List<T>> cartesianProduct(List<List<T>> lists) {
+        List<List<T>> resultLists = new ArrayList<>();
+        if (lists.isEmpty()) {
+            resultLists.add(new ArrayList<>());
+            return resultLists;
+        } else {
+            List<T> firstList = lists.get(0);
+            List<List<T>> remainingLists = cartesianProduct(lists.subList(1, lists.size()));
+            for (T condition : firstList) {
+                for (List<T> remainingList : remainingLists) {
+                    ArrayList<T> resultList = new ArrayList<>();
+                    resultList.add(condition);
+                    resultList.addAll(remainingList);
+                    resultLists.add(resultList);
+                }
+            }
+        }
+        return resultLists;
+    }
+
+    /**
+     * Calculates the median of a given array of long integers. If the calculated median is less
+     * than 1, it returns 1 instead.
+     *
+     * @param nums an array of long integers for which to calculate the median
+     * @return the median value, which will be at least 1
+     */
+    public static long median(long[] nums) {
+        int len = nums.length;
+        long[] sortedNums = LongStream.of(nums).sorted().toArray();
+        if (len % 2 == 0) {
+            return Math.max((sortedNums[len / 2] + sortedNums[len / 2 - 1]) / 2, 1L);
+        } else {
+            return Math.max(sortedNums[len / 2], 1L);
+        }
+    }
+
+    /**
+     * Computes the skew threshold based on the given median size and skewed factor.
+     *
+     * <p>The skew threshold is calculated as the product of the median size and the skewed factor.
+     * To ensure that the computed threshold does not fall below a specified default value, the
+     * method uses {@link Math#max} to return the larger of the calculated threshold and the
+     * default threshold.
+     *
+     * @param medianSize the size of the median
+     * @param skewedFactor a factor indicating the degree of skewness
+     * @param defaultSkewedThreshold the default threshold to be used if the calculated threshold
+     *     is less than this value
+     * @return the computed skew threshold, which is guaranteed to be at least the default skewed
+     *     threshold
+     */
+    public static long computeSkewThreshold(
+            long medianSize, double skewedFactor, long defaultSkewedThreshold) {
+        return (long) Math.max(medianSize * skewedFactor, defaultSkewedThreshold);
+    }
+
+    /**
+     * Computes the target data size for each task based on the sizes of non-skewed subpartitions.
+     *
+     * <p>The target size is determined as the average size of non-skewed subpartitions and ensures
+     * that the target size is at least equal to the specified data volume per task.
+     *
+     * @param subpartitionBytes an array representing the data size of each subpartition
+     * @param skewedThreshold skewed threshold in bytes
+     * @param dataVolumePerTask the amount of data that should be allocated per task
+     * @return the computed target size for each task, which is the maximum of the average size of
+     *     non-skewed subpartitions and the data volume per task
+     */
+    public static long computeTargetSize(
+            long[] subpartitionBytes, long skewedThreshold, long dataVolumePerTask) {
+        long[] nonSkewPartitions =
+                LongStream.of(subpartitionBytes).filter(v -> v <= skewedThreshold).toArray();
+        if (nonSkewPartitions.length == 0) {
+            return dataVolumePerTask;
+        } else {
+            return Math.max(
+                    dataVolumePerTask,
+                    LongStream.of(nonSkewPartitions).sum() / nonSkewPartitions.length);
+        }
+    }
+
+    public static List<BlockingInputInfo> getNonBroadcastInputInfos(
+            List<BlockingInputInfo> consumedResults) {
+        return consumedResults.stream()
+                .filter(resultInfo -> !resultInfo.isBroadcast())
+                .collect(Collectors.toList());
+    }
+
+    public static List<BlockingInputInfo> getBroadcastInputInfos(
+            List<BlockingInputInfo> consumedResults) {
+        return consumedResults.stream()
+                .filter(BlockingInputInfo::isBroadcast)
+                .collect(Collectors.toList());
+    }
+
+    public static boolean hasSameNumPartitions(List<BlockingInputInfo> inputInfos) {
+        Set<Integer> partitionNums =
+                inputInfos.stream()
+                        .map(BlockingInputInfo::getNumPartitions)
+                        .collect(Collectors.toSet());
+        return partitionNums.size() == 1;
+    }
+
+    public static int getMaxNumPartitions(List<BlockingInputInfo> consumedResults) {
+        checkArgument(!consumedResults.isEmpty());
+        return consumedResults.stream()
+                .mapToInt(BlockingInputInfo::getNumPartitions)
+                .max()
+                .getAsInt();
+    }
+
+    public static int getMaxNumSubpartitions(List<BlockingInputInfo> consumedResults) {
+        checkArgument(!consumedResults.isEmpty());
+        return consumedResults.stream()
+                .mapToInt(
+                        resultInfo ->
+                                IntStream.range(0, resultInfo.getNumPartitions())
+                                        .boxed()
+                                        .mapToInt(resultInfo::getNumSubpartitions)
+                                        .sum())
+                .max()
+                .getAsInt();
+    }
+
+    public static int checkAndGetSubpartitionNum(List<BlockingInputInfo> consumedResults) {
+        final Set<Integer> subpartitionNumSet =
+                consumedResults.stream()
+                        .flatMap(
+                                resultInfo ->
+                                        IntStream.range(0, resultInfo.getNumPartitions())
+                                                .boxed()
+                                                .map(resultInfo::getNumSubpartitions))
+                        .collect(Collectors.toSet());
+        // all partitions have the same subpartition num
+        checkState(subpartitionNumSet.size() == 1);
+        return subpartitionNumSet.iterator().next();
+    }
+
+    public static boolean isLegalParallelism(
+            int parallelism, int minParallelism, int maxParallelism) {
+        return parallelism >= minParallelism && parallelism <= maxParallelism;
+    }
+
+    public static boolean checkAndGetIntraCorrelation(List<BlockingInputInfo> inputInfos) {
+        Set<Boolean> intraCorrelationSet =
+                inputInfos.stream()
+                        .map(BlockingInputInfo::isIntraInputKeyCorrelated)
+                        .collect(Collectors.toSet());
+        checkArgument(intraCorrelationSet.size() == 1);
+        return intraCorrelationSet.iterator().next();
+    }
+
+    public static int checkAndGetParallelism(Collection<JobVertexInputInfo> vertexInputInfos) {
+        final Set<Integer> parallelismSet =
+                vertexInputInfos.stream()
+                        .map(
+                                vertexInputInfo ->
+                                        vertexInputInfo.getExecutionVertexInputInfos().size())
+                        .collect(Collectors.toSet());
+        checkState(parallelismSet.size() == 1);
+        return parallelismSet.iterator().next();
+    }
+
+    /**
+     * Attempts to compute the subpartition slice ranges to ensure even distribution of data across
+     * downstream tasks.
+     *
+     * <p>This method first tries to compute the subpartition slice ranges by evenly distributing
+     * the data volume. If that fails, it attempts to compute the ranges by evenly distributing the
+     * number of subpartition slices.
+     *
+     * @param minParallelism The minimum parallelism.
+     * @param maxParallelism The maximum parallelism.
+     * @param maxSubpartitionSliceRangePerTask The maximum number of subpartition slice ranges per
+     *     task.
+     * @param maxDataVolumePerTask The maximum data volume per task.
+     * @param subpartitionSlicesByTypeNumber A map of lists of subpartition slices grouped by type
+     *     number.
+     * @return An {@code Optional} containing a list of index ranges representing the subpartition
+     *     slice ranges. Returns an empty {@code Optional} if no suitable ranges can be computed.
+     */
+    public static Optional<List<IndexRange>> tryComputeSubpartitionSliceRange(
+            int minParallelism,
+            int maxParallelism,
+            int maxSubpartitionSliceRangePerTask,
+            long maxDataVolumePerTask,
+            Map<Integer, List<SubpartitionSlice>> subpartitionSlicesByTypeNumber) {
+        Optional<List<IndexRange>> subpartitionSliceRanges =
+                tryComputeSubpartitionSliceRangeEvenlyDistributedData(
+                        minParallelism,
+                        maxParallelism,
+                        maxSubpartitionSliceRangePerTask,
+                        maxDataVolumePerTask,
+                        subpartitionSlicesByTypeNumber);
+        if (subpartitionSliceRanges.isEmpty()) {
+            subpartitionSliceRanges =
+                    tryComputeSubpartitionSliceRangeEvenlyDistributedSubpartitionSlices(
+                            minParallelism, maxParallelism, subpartitionSlicesByTypeNumber);
+        }
+        return subpartitionSliceRanges;
+    }
+
+    public static List<ExecutionVertexInputInfo> createdExecutionVertexInputInfosForBroadcast(
+            BlockingInputInfo inputInfo, int parallelism) {
+        checkArgument(inputInfo.isBroadcast());
+        int numPartitions = inputInfo.getNumPartitions();
+        List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>();
+        for (int i = 0; i < parallelism; ++i) {
+            ExecutionVertexInputInfo executionVertexInputInfo;
+            if (inputInfo.isSingleSubpartitionContainsAllData()) {
+                executionVertexInputInfo =
+                        new ExecutionVertexInputInfo(
+                                i, new IndexRange(0, numPartitions - 1), new IndexRange(0, 0));
+            } else {
+                // The partitions of the all-to-all result have the same number of
+                // subpartitions. So we can use the first partition's subpartition
+                // number.
+                executionVertexInputInfo =
+                        new ExecutionVertexInputInfo(
+                                i,
+                                new IndexRange(0, numPartitions - 1),
+                                new IndexRange(0, inputInfo.getNumSubpartitions(0) - 1));
+            }
+            executionVertexInputInfos.add(executionVertexInputInfo);
+        }
+        return executionVertexInputInfos;
+    }
+
+    public static List<ExecutionVertexInputInfo> createdExecutionVertexInputInfosForNonBroadcast(
+            BlockingInputInfo inputInfo,
+            List<IndexRange> subpartitionSliceRanges,
+            List<SubpartitionSlice> subpartitionSlices) {
+        checkArgument(!inputInfo.isBroadcast());
+        int numPartitions = inputInfo.getNumPartitions();
+        List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>();
+        for (int i = 0; i < subpartitionSliceRanges.size(); ++i) {
+            IndexRange subpartitionSliceRange = subpartitionSliceRanges.get(i);
+            // Convert the subpartition slices in the range to a mapping from partition range
+            // to subpartition range.
+            Map<IndexRange, IndexRange> consumedSubpartitionGroups =
+                    computeConsumedSubpartitionGroups(
+                            numPartitions,
+                            inputInfo.isPointwise(),
+                            subpartitionSliceRange,
+                            subpartitionSlices);
+            executionVertexInputInfos.add(
+                    new ExecutionVertexInputInfo(i, consumedSubpartitionGroups));
+        }
+        return executionVertexInputInfos;
+    }
+
+    private static Optional<List<IndexRange>> tryComputeSubpartitionSliceRangeEvenlyDistributedData(
+            int minParallelism,
+            int maxParallelism,
+            int maxSubpartitionSliceRangePerTask,
+            long maxDataVolumePerTask,
+            Map<Integer, List<SubpartitionSlice>> subpartitionSlicesByTypeNumber) {
+        int subpartitionSlicesSize =
+                checkAndGetSubpartitionSlicesSize(subpartitionSlicesByTypeNumber);
+        // Distribute the input data evenly among the downstream tasks and record the
+        // subpartition slice range for each task.
+        List<IndexRange> subpartitionSliceRanges =
+                computeSubpartitionSliceRanges(
+                        maxDataVolumePerTask,
+                        maxSubpartitionSliceRangePerTask,
+                        subpartitionSlicesSize,
+                        subpartitionSlicesByTypeNumber);
+        // If the parallelism is not legal, try to adjust it to a legal parallelism.
+        if (!isLegalParallelism(subpartitionSliceRanges.size(), minParallelism, maxParallelism)) {
+            long minBytesSize = maxDataVolumePerTask;
+            long sumBytesSize = 0;
+            for (int i = 0; i < subpartitionSlicesSize; ++i) {
+                long currentBytesSize = 0;
+                for (List<SubpartitionSlice> subpartitionSlice :
+                        subpartitionSlicesByTypeNumber.values()) {
+                    currentBytesSize += subpartitionSlice.get(i).getDataBytes();
+                }
+                minBytesSize = Math.min(minBytesSize, currentBytesSize);
+                sumBytesSize += currentBytesSize;
+            }
+            return adjustToClosestLegalParallelism(
+                    maxDataVolumePerTask,
+                    subpartitionSliceRanges.size(),
+                    minParallelism,
+                    maxParallelism,
+                    minBytesSize,
+                    sumBytesSize,
+                    limit ->
+                            computeParallelism(
+                                    limit,
+                                    maxSubpartitionSliceRangePerTask,
+                                    subpartitionSlicesSize,
+                                    subpartitionSlicesByTypeNumber),
+                    limit ->
+                            computeSubpartitionSliceRanges(
+                                    limit,
+                                    maxSubpartitionSliceRangePerTask,
+                                    subpartitionSlicesSize,
+                                    subpartitionSlicesByTypeNumber));
+        }
+        return Optional.of(subpartitionSliceRanges);
+    }
+
+    private static Optional<List<IndexRange>>
+            tryComputeSubpartitionSliceRangeEvenlyDistributedSubpartitionSlices(
+                    int minParallelism,
+                    int maxParallelism,
+                    Map<Integer, List<SubpartitionSlice>> subpartitionSlicesByTypeNumber) {
+        int subpartitionSlicesSize =
+                checkAndGetSubpartitionSlicesSize(subpartitionSlicesByTypeNumber);
+        if (subpartitionSlicesSize < minParallelism) {
+            return Optional.empty();
+        }
+        int parallelism = Math.min(subpartitionSlicesSize, maxParallelism);
+        List<IndexRange> subpartitionSliceRanges = new ArrayList<>();
+        for (int i = 0; i < parallelism; i++) {
+            int start = i * subpartitionSlicesSize / parallelism;
+            int nextStart = (i + 1) * subpartitionSlicesSize / parallelism;
+            subpartitionSliceRanges.add(new IndexRange(start, nextStart - 1));
+        }
+        checkState(subpartitionSliceRanges.size() == parallelism);
+        return Optional.of(subpartitionSliceRanges);
+    }
+
+    /**
+     * Merge the subpartition slices of the specified range into an index range map, in which the
+     * key is the partition index range and the value is the subpartition range.
+     *
+     * <p>Note: For pointwise, we prioritize that their partition ranges have no overlap, and for

Review Comment:
   Could you add these explanations to the comment? This is not intuitive for other developers.
   And maybe do a check for the inputs? E.g. check that the key range size is 1.
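   For example, roughly along these lines (just a sketch to illustrate the idea; `checkConsumedSubpartitionGroups` is a hypothetical helper name, to be applied in whichever branch relies on that assumption):
   ```java
   // Hypothetical precondition: fail fast if a key (partition index range) of the
   // consumed subpartition groups covers more than one partition.
   private static void checkConsumedSubpartitionGroups(
           Map<IndexRange, IndexRange> consumedSubpartitionGroups) {
       for (IndexRange partitionRange : consumedSubpartitionGroups.keySet()) {
           checkState(
                   partitionRange.getStartIndex() == partitionRange.getEndIndex(),
                   "Expected a partition index range of size 1, but got %s.",
                   partitionRange);
       }
   }
   ```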

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]