This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit fc7defb14d11e270d539ee0d80a2076ae55a4ea2 Author: Yuxin Tan <[email protected]> AuthorDate: Mon Dec 12 21:39:20 2022 +0800 [FLINK-30471][network] Optimize the enriching network memory process in SsgNetworkMemoryCalculationUtils --- .../SsgNetworkMemoryCalculationUtils.java | 67 +++++++++------------- .../SsgNetworkMemoryCalculationUtilsTest.java | 6 +- 2 files changed, 31 insertions(+), 42 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java index bca093c7219..c1e58745aad 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java @@ -93,28 +93,27 @@ public class SsgNetworkMemoryCalculationUtils { private static TaskInputsOutputsDescriptor buildTaskInputsOutputsDescriptor( ExecutionJobVertex ejv, Function<JobVertexID, ExecutionJobVertex> ejvs) { - Map<IntermediateDataSetID, Integer> maxInputChannelNums; - Map<IntermediateDataSetID, Integer> maxSubpartitionNums; + Map<IntermediateDataSetID, Integer> maxInputChannelNums = new HashMap<>(); + Map<IntermediateDataSetID, Integer> maxSubpartitionNums = new HashMap<>(); + Map<IntermediateDataSetID, ResultPartitionType> partitionTypes = new HashMap<>(); if (ejv.getGraph().isDynamic()) { - maxInputChannelNums = getMaxInputChannelNumsForDynamicGraph(ejv); - maxSubpartitionNums = getMaxSubpartitionNumsForDynamicGraph(ejv); + getMaxInputChannelInfoForDynamicGraph(ejv, maxInputChannelNums); + getMaxSubpartitionInfoForDynamicGraph(ejv, maxSubpartitionNums, partitionTypes); } else { - maxInputChannelNums = getMaxInputChannelNums(ejv); - maxSubpartitionNums = getMaxSubpartitionNums(ejv, ejvs); + getMaxInputChannelInfo(ejv, maxInputChannelNums); + getMaxSubpartitionInfo(ejv, maxSubpartitionNums, partitionTypes, ejvs); } JobVertex jv = ejv.getJobVertex(); - Map<IntermediateDataSetID, ResultPartitionType> partitionTypes = getPartitionTypes(jv); return TaskInputsOutputsDescriptor.from( jv.getNumberOfInputs(), maxInputChannelNums, maxSubpartitionNums, partitionTypes); } - private static Map<IntermediateDataSetID, Integer> getMaxInputChannelNums( - ExecutionJobVertex ejv) { + private static void getMaxInputChannelInfo( + ExecutionJobVertex ejv, Map<IntermediateDataSetID, Integer> maxInputChannelNums) { - Map<IntermediateDataSetID, Integer> ret = new HashMap<>(); List<JobEdge> inputEdges = ejv.getJobVertex().getInputs(); for (int i = 0; i < inputEdges.size(); i++) { @@ -129,16 +128,15 @@ public class SsgNetworkMemoryCalculationUtils { ejv.getParallelism(), consumedResult.getNumberOfAssignedPartitions(), inputEdge.getDistributionPattern()); - ret.merge(consumedResult.getId(), maxNum, Integer::sum); + maxInputChannelNums.merge(consumedResult.getId(), maxNum, Integer::sum); } - - return ret; } - private static Map<IntermediateDataSetID, Integer> getMaxSubpartitionNums( - ExecutionJobVertex ejv, Function<JobVertexID, ExecutionJobVertex> ejvs) { - - Map<IntermediateDataSetID, Integer> ret = new HashMap<>(); + private static void getMaxSubpartitionInfo( + ExecutionJobVertex ejv, + Map<IntermediateDataSetID, Integer> maxSubpartitionNums, + Map<IntermediateDataSetID, ResultPartitionType> partitionTypes, + Function<JobVertexID, ExecutionJobVertex> ejvs) { List<IntermediateDataSet> producedDataSets = ejv.getJobVertex().getProducedDataSets(); checkState(!ejv.getGraph().isDynamic(), "Only support non-dynamic graph."); @@ -157,23 +155,14 @@ public class SsgNetworkMemoryCalculationUtils { consumerJobVertex.getParallelism(), outputEdge.getDistributionPattern()); } - ret.put(producedDataSet.getId(), maxNum); + maxSubpartitionNums.put(producedDataSet.getId(), maxNum); + partitionTypes.putIfAbsent(producedDataSet.getId(), producedDataSet.getResultType()); } - - return ret; - } - - private static Map<IntermediateDataSetID, ResultPartitionType> getPartitionTypes(JobVertex jv) { - Map<IntermediateDataSetID, ResultPartitionType> ret = new HashMap<>(); - jv.getProducedDataSets().forEach(ds -> ret.putIfAbsent(ds.getId(), ds.getResultType())); - return ret; } @VisibleForTesting - static Map<IntermediateDataSetID, Integer> getMaxInputChannelNumsForDynamicGraph( - ExecutionJobVertex ejv) { - - Map<IntermediateDataSetID, Integer> ret = new HashMap<>(); + static void getMaxInputChannelInfoForDynamicGraph( + ExecutionJobVertex ejv, Map<IntermediateDataSetID, Integer> maxInputChannelNums) { for (ExecutionVertex vertex : ejv.getTaskVertices()) { Map<IntermediateDataSetID, Integer> tmp = new HashMap<>(); @@ -194,27 +183,25 @@ public class SsgNetworkMemoryCalculationUtils { } for (Map.Entry<IntermediateDataSetID, Integer> entry : tmp.entrySet()) { - ret.merge(entry.getKey(), entry.getValue(), Integer::max); + maxInputChannelNums.merge(entry.getKey(), entry.getValue(), Integer::max); } } - - return ret; } - private static Map<IntermediateDataSetID, Integer> getMaxSubpartitionNumsForDynamicGraph( - ExecutionJobVertex ejv) { - - Map<IntermediateDataSetID, Integer> ret = new HashMap<>(); + private static void getMaxSubpartitionInfoForDynamicGraph( + ExecutionJobVertex ejv, + Map<IntermediateDataSetID, Integer> maxSubpartitionNums, + Map<IntermediateDataSetID, ResultPartitionType> partitionTypes) { for (IntermediateResult intermediateResult : ejv.getProducedDataSets()) { final int maxNum = Arrays.stream(intermediateResult.getPartitions()) .map(IntermediateResultPartition::getNumberOfSubpartitions) .reduce(0, Integer::max); - ret.put(intermediateResult.getId(), maxNum); + maxSubpartitionNums.put(intermediateResult.getId(), maxNum); + partitionTypes.putIfAbsent( + intermediateResult.getId(), intermediateResult.getResultType()); } - - return ret; } /** Private default constructor to avoid being instantiated. */ diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java index b81824b89bd..45d71b8066f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java @@ -48,6 +48,7 @@ import org.junit.ClassRule; import org.junit.Test; import java.util.Arrays; +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -231,8 +232,9 @@ public class SsgNetworkMemoryCalculationUtilsTest { consumer.setParallelism(decidedConsumerParallelism); eg.initializeJobVertex(consumer, 0L); - Map<IntermediateDataSetID, Integer> maxInputChannelNums = - SsgNetworkMemoryCalculationUtils.getMaxInputChannelNumsForDynamicGraph(consumer); + Map<IntermediateDataSetID, Integer> maxInputChannelNums = new HashMap<>(); + SsgNetworkMemoryCalculationUtils.getMaxInputChannelInfoForDynamicGraph( + consumer, maxInputChannelNums); assertThat(maxInputChannelNums.size(), is(1)); assertThat(maxInputChannelNums.get(result.getId()), is(expectedNumChannels));
