This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit fc7defb14d11e270d539ee0d80a2076ae55a4ea2
Author: Yuxin Tan <[email protected]>
AuthorDate: Mon Dec 12 21:39:20 2022 +0800

    [FLINK-30471][network] Optimize the enriching network memory process in SsgNetworkMemoryCalculationUtils
---
 .../SsgNetworkMemoryCalculationUtils.java          | 67 +++++++++-------------
 .../SsgNetworkMemoryCalculationUtilsTest.java      |  6 +-
 2 files changed, 31 insertions(+), 42 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
index bca093c7219..c1e58745aad 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
@@ -93,28 +93,27 @@ public class SsgNetworkMemoryCalculationUtils {
     private static TaskInputsOutputsDescriptor buildTaskInputsOutputsDescriptor(
             ExecutionJobVertex ejv, Function<JobVertexID, ExecutionJobVertex> ejvs) {
 
-        Map<IntermediateDataSetID, Integer> maxInputChannelNums;
-        Map<IntermediateDataSetID, Integer> maxSubpartitionNums;
+        Map<IntermediateDataSetID, Integer> maxInputChannelNums = new HashMap<>();
+        Map<IntermediateDataSetID, Integer> maxSubpartitionNums = new HashMap<>();
+        Map<IntermediateDataSetID, ResultPartitionType> partitionTypes = new HashMap<>();
 
         if (ejv.getGraph().isDynamic()) {
-            maxInputChannelNums = getMaxInputChannelNumsForDynamicGraph(ejv);
-            maxSubpartitionNums = getMaxSubpartitionNumsForDynamicGraph(ejv);
+            getMaxInputChannelInfoForDynamicGraph(ejv, maxInputChannelNums);
+            getMaxSubpartitionInfoForDynamicGraph(ejv, maxSubpartitionNums, partitionTypes);
         } else {
-            maxInputChannelNums = getMaxInputChannelNums(ejv);
-            maxSubpartitionNums = getMaxSubpartitionNums(ejv, ejvs);
+            getMaxInputChannelInfo(ejv, maxInputChannelNums);
+            getMaxSubpartitionInfo(ejv, maxSubpartitionNums, partitionTypes, ejvs);
         }
 
         JobVertex jv = ejv.getJobVertex();
-        Map<IntermediateDataSetID, ResultPartitionType> partitionTypes = getPartitionTypes(jv);
 
         return TaskInputsOutputsDescriptor.from(
                 jv.getNumberOfInputs(), maxInputChannelNums, maxSubpartitionNums, partitionTypes);
     }
 
-    private static Map<IntermediateDataSetID, Integer> getMaxInputChannelNums(
-            ExecutionJobVertex ejv) {
+    private static void getMaxInputChannelInfo(
+            ExecutionJobVertex ejv, Map<IntermediateDataSetID, Integer> maxInputChannelNums) {
 
-        Map<IntermediateDataSetID, Integer> ret = new HashMap<>();
         List<JobEdge> inputEdges = ejv.getJobVertex().getInputs();
 
         for (int i = 0; i < inputEdges.size(); i++) {
@@ -129,16 +128,15 @@ public class SsgNetworkMemoryCalculationUtils {
                             ejv.getParallelism(),
                             consumedResult.getNumberOfAssignedPartitions(),
                             inputEdge.getDistributionPattern());
-            ret.merge(consumedResult.getId(), maxNum, Integer::sum);
+            maxInputChannelNums.merge(consumedResult.getId(), maxNum, Integer::sum);
         }
-
-        return ret;
     }
 
-    private static Map<IntermediateDataSetID, Integer> getMaxSubpartitionNums(
-            ExecutionJobVertex ejv, Function<JobVertexID, ExecutionJobVertex> ejvs) {
-
-        Map<IntermediateDataSetID, Integer> ret = new HashMap<>();
+    private static void getMaxSubpartitionInfo(
+            ExecutionJobVertex ejv,
+            Map<IntermediateDataSetID, Integer> maxSubpartitionNums,
+            Map<IntermediateDataSetID, ResultPartitionType> partitionTypes,
+            Function<JobVertexID, ExecutionJobVertex> ejvs) {
         List<IntermediateDataSet> producedDataSets = ejv.getJobVertex().getProducedDataSets();
 
         checkState(!ejv.getGraph().isDynamic(), "Only support non-dynamic graph.");
@@ -157,23 +155,14 @@ public class SsgNetworkMemoryCalculationUtils {
                                 consumerJobVertex.getParallelism(),
                                 outputEdge.getDistributionPattern());
             }
-            ret.put(producedDataSet.getId(), maxNum);
+            maxSubpartitionNums.put(producedDataSet.getId(), maxNum);
+            partitionTypes.putIfAbsent(producedDataSet.getId(), producedDataSet.getResultType());
         }
-
-        return ret;
-    }
-
-    private static Map<IntermediateDataSetID, ResultPartitionType> getPartitionTypes(JobVertex jv) {
-        Map<IntermediateDataSetID, ResultPartitionType> ret = new HashMap<>();
-        jv.getProducedDataSets().forEach(ds -> ret.putIfAbsent(ds.getId(), ds.getResultType()));
-        return ret;
     }
 
     @VisibleForTesting
-    static Map<IntermediateDataSetID, Integer> getMaxInputChannelNumsForDynamicGraph(
-            ExecutionJobVertex ejv) {
-
-        Map<IntermediateDataSetID, Integer> ret = new HashMap<>();
+    static void getMaxInputChannelInfoForDynamicGraph(
+            ExecutionJobVertex ejv, Map<IntermediateDataSetID, Integer> maxInputChannelNums) {
 
         for (ExecutionVertex vertex : ejv.getTaskVertices()) {
             Map<IntermediateDataSetID, Integer> tmp = new HashMap<>();
@@ -194,27 +183,25 @@ public class SsgNetworkMemoryCalculationUtils {
             }
 
             for (Map.Entry<IntermediateDataSetID, Integer> entry : tmp.entrySet()) {
-                ret.merge(entry.getKey(), entry.getValue(), Integer::max);
+                maxInputChannelNums.merge(entry.getKey(), entry.getValue(), Integer::max);
             }
         }
-
-        return ret;
     }
 
-    private static Map<IntermediateDataSetID, Integer> getMaxSubpartitionNumsForDynamicGraph(
-            ExecutionJobVertex ejv) {
-
-        Map<IntermediateDataSetID, Integer> ret = new HashMap<>();
+    private static void getMaxSubpartitionInfoForDynamicGraph(
+            ExecutionJobVertex ejv,
+            Map<IntermediateDataSetID, Integer> maxSubpartitionNums,
+            Map<IntermediateDataSetID, ResultPartitionType> partitionTypes) {
 
         for (IntermediateResult intermediateResult : ejv.getProducedDataSets()) {
             final int maxNum =
                     Arrays.stream(intermediateResult.getPartitions())
                             .map(IntermediateResultPartition::getNumberOfSubpartitions)
                             .reduce(0, Integer::max);
-            ret.put(intermediateResult.getId(), maxNum);
+            maxSubpartitionNums.put(intermediateResult.getId(), maxNum);
+            partitionTypes.putIfAbsent(
+                    intermediateResult.getId(), intermediateResult.getResultType());
         }
-
-        return ret;
     }
 
     /** Private default constructor to avoid being instantiated. */
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java
index b81824b89bd..45d71b8066f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java
@@ -48,6 +48,7 @@ import org.junit.ClassRule;
 import org.junit.Test;
 
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
@@ -231,8 +232,9 @@ public class SsgNetworkMemoryCalculationUtilsTest {
         consumer.setParallelism(decidedConsumerParallelism);
         eg.initializeJobVertex(consumer, 0L);
 
-        Map<IntermediateDataSetID, Integer> maxInputChannelNums =
-                SsgNetworkMemoryCalculationUtils.getMaxInputChannelNumsForDynamicGraph(consumer);
+        Map<IntermediateDataSetID, Integer> maxInputChannelNums = new HashMap<>();
+        SsgNetworkMemoryCalculationUtils.getMaxInputChannelInfoForDynamicGraph(
+                consumer, maxInputChannelNums);
 
         assertThat(maxInputChannelNums.size(), is(1));
         assertThat(maxInputChannelNums.get(result.getId()), is(expectedNumChannels));

Reply via email to