This is an automated email from the ASF dual-hosted git repository.

junrui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 0ce55c37d8c961ee66f3a157aa2d71e65ee215f4
Author: JunRuiLee <[email protected]>
AuthorDate: Thu Dec 12 17:32:42 2024 +0800

    [FLINK-36067][runtime] Manually trigger aggregation of all-to-all result partition info when all consumers are created and initialized.
---
 .../adaptivebatch/AdaptiveBatchScheduler.java      | 23 ++++++
 .../adaptivebatch/AllToAllBlockingResultInfo.java  | 95 ++++++++++++++++------
 .../AllToAllBlockingResultInfoTest.java            |  3 +
 ...tVertexParallelismAndInputInfosDeciderTest.java |  5 ++
 4 files changed, 100 insertions(+), 26 deletions(-)
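
For readers skimming the patch: the aggregation this change defers is a per-subpartition sum of bytes across all partitions of an ALL_TO_ALL result. Below is a minimal, self-contained sketch of that computation (hypothetical class and sample values, not Flink's actual API; the real logic lives in AllToAllBlockingResultInfo):

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.stream.Collectors;

    class SubpartitionBytesAggregationSketch {
        public static void main(String[] args) {
            int numOfSubpartitions = 2;
            // partition index -> bytes produced per subpartition
            Map<Integer, long[]> subpartitionBytesByPartitionIndex = new HashMap<>();
            subpartitionBytesByPartitionIndex.put(0, new long[] {64L, 128L});
            subpartitionBytesByPartitionIndex.put(1, new long[] {128L, 256L});

            // Sum bytes per subpartition index across all partitions.
            long[] aggregatedBytes = new long[numOfSubpartitions];
            for (long[] subpartitionBytes : subpartitionBytesByPartitionIndex.values()) {
                for (int i = 0; i < subpartitionBytes.length; i++) {
                    aggregatedBytes[i] += subpartitionBytes[i];
                }
            }
            List<Long> aggregated =
                    Arrays.stream(aggregatedBytes).boxed().collect(Collectors.toList());
            System.out.println(aggregated); // prints [192, 384]
        }
    }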

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
index 99df43dc807..a46a210446f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
@@ -273,6 +273,14 @@ public class AdaptiveBatchScheduler extends DefaultScheduler implements JobGraph
 
        // 4. update json plan
        getExecutionGraph().setJsonPlan(JsonPlanGenerator.generatePlan(getJobGraph()));
+
+        // 5. try aggregate subpartition bytes
+        for (JobVertex newVertex : newVertices) {
+            for (JobEdge input : newVertex.getInputs()) {
+                Optional.ofNullable(blockingResultInfos.get(input.getSourceId()))
+                        .ifPresent(this::maybeAggregateSubpartitionBytes);
+            }
+        }
     }
 
     @Override
@@ -486,11 +494,25 @@ public class AdaptiveBatchScheduler extends DefaultScheduler implements JobGraph
                                 }
                                 resultInfo.recordPartitionInfo(
                                        partitionId.getPartitionNumber(), partitionBytes);
+                                maybeAggregateSubpartitionBytes(resultInfo);
                                 return resultInfo;
                             });
                 });
     }
 
+    private void maybeAggregateSubpartitionBytes(BlockingResultInfo resultInfo) {
+        IntermediateResult intermediateResult =
+                getExecutionGraph().getAllIntermediateResults().get(resultInfo.getResultId());
+
+        if (resultInfo instanceof AllToAllBlockingResultInfo
+                && intermediateResult.areAllConsumerVerticesCreated()
+                && intermediateResult.getConsumerVertices().stream()
+                        .map(this::getExecutionJobVertex)
+                        .allMatch(ExecutionJobVertex::isInitialized)) {
+            ((AllToAllBlockingResultInfo) resultInfo).aggregateSubpartitionBytes();
+        }
+    }
+
     @Override
    public void allocateSlotsAndDeploy(final List<ExecutionVertexID> verticesToDeploy) {
         List<ExecutionVertex> executionVertices =
@@ -657,6 +679,7 @@ public class AdaptiveBatchScheduler extends DefaultScheduler implements JobGraph
                                parallelismAndInputInfos.getJobVertexInputInfos(),
                                createTimestamp);
                         newlyInitializedJobVertices.add(jobVertex);
+                        consumedResultsInfo.get().forEach(this::maybeAggregateSubpartitionBytes);
                     }
                 }
             }
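
The scheduler changes above establish when aggregation is safe: only for ALL_TO_ALL results whose consumer vertices have all been created and initialized. A hedged sketch of that guard, using simplified stand-in types rather than the real ExecutionGraph API:

    import java.util.List;

    class AggregationGuardSketch {
        interface ConsumerVertex {
            boolean isInitialized();
        }

        // Until this returns true, callers fall back to computing aggregates
        // on the fly instead of caching them and discarding the raw stats.
        static boolean canAggregate(
                boolean isAllToAllResult,
                boolean allConsumerVerticesCreated,
                List<ConsumerVertex> consumerVertices) {
            return isAllToAllResult
                    && allConsumerVerticesCreated
                    && consumerVertices.stream().allMatch(ConsumerVertex::isInitialized);
        }
    }
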
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java
index 9f01a1061e1..ed1e945912f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java
@@ -27,6 +27,7 @@ import javax.annotation.Nullable;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.Optional;
 import java.util.stream.Collectors;
 
 import static org.apache.flink.util.Preconditions.checkState;
@@ -74,18 +75,28 @@ public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo {
 
     @Override
     public long getNumBytesProduced() {
-        checkState(aggregatedSubpartitionBytes != null, "Not all partition infos are ready");
+        checkState(
+                aggregatedSubpartitionBytes != null
+                        || subpartitionBytesByPartitionIndex.size() == numOfPartitions,
+                "Not all partition infos are ready");
+
+        List<Long> bytes =
+                Optional.ofNullable(aggregatedSubpartitionBytes)
+                        .orElse(getAggregatedSubpartitionBytesInternal());
         if (isBroadcast) {
-            return aggregatedSubpartitionBytes.get(0);
+            return bytes.get(0);
         } else {
-            return aggregatedSubpartitionBytes.stream().reduce(0L, Long::sum);
+            return bytes.stream().reduce(0L, Long::sum);
         }
     }
 
     @Override
     public long getNumBytesProduced(
            IndexRange partitionIndexRange, IndexRange subpartitionIndexRange) {
-        checkState(aggregatedSubpartitionBytes != null, "Not all partition infos are ready");
+        List<Long> bytes =
+                Optional.ofNullable(aggregatedSubpartitionBytes)
+                        .orElse(getAggregatedSubpartitionBytesInternal());
+
         checkState(
                 partitionIndexRange.getStartIndex() == 0
                        && partitionIndexRange.getEndIndex() == numOfPartitions - 1,
@@ -96,7 +107,7 @@ public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo {
                 "Subpartition index %s is out of range.",
                 subpartitionIndexRange.getEndIndex());
 
-        return aggregatedSubpartitionBytes
+        return bytes
                 .subList(
                         subpartitionIndexRange.getStartIndex(),
                         subpartitionIndexRange.getEndIndex() + 1)
@@ -106,31 +117,56 @@ public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo {
 
     @Override
    public void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes) {
-        // Once all partitions are finished, we can convert the subpartition bytes to aggregated
-        // value to reduce the space usage, because the distribution of source splits does not
-        // affect the distribution of data consumed by downstream tasks of ALL_TO_ALL edges(Hashing
-        // or Rebalancing, we do not consider rare cases such as custom partitions here).
         if (aggregatedSubpartitionBytes == null) {
             super.recordPartitionInfo(partitionIndex, partitionBytes);
+        }
+    }
 
-            if (subpartitionBytesByPartitionIndex.size() == numOfPartitions) {
-                long[] aggregatedBytes = new long[numOfSubpartitions];
-                subpartitionBytesByPartitionIndex
-                        .values()
-                        .forEach(
-                                subpartitionBytes -> {
-                                    checkState(subpartitionBytes.length == numOfSubpartitions);
-                                    for (int i = 0; i < subpartitionBytes.length; ++i) {
-                                        aggregatedBytes[i] += subpartitionBytes[i];
-                                    }
-                                });
-                this.aggregatedSubpartitionBytes =
-                        Arrays.stream(aggregatedBytes).boxed().collect(Collectors.toList());
-                this.subpartitionBytesByPartitionIndex.clear();
-            }
+    /**
+     * Aggregates the subpartition bytes to reduce space usage.
+     *
+     * <p>Once all partitions are finished and all consumer jobVertices are initialized, we can
+     * convert the subpartition bytes to an aggregated value to reduce the space usage, because the
+     * distribution of source splits does not affect the distribution of data consumed by downstream
+     * tasks of ALL_TO_ALL edges (Hashing or Rebalancing; we do not consider rare cases such as
+     * custom partitions here).
+     */
+    protected void aggregateSubpartitionBytes() {
+        if (subpartitionBytesByPartitionIndex.size() == numOfPartitions) {
+            this.aggregatedSubpartitionBytes = getAggregatedSubpartitionBytesInternal();
+            this.subpartitionBytesByPartitionIndex.clear();
         }
     }
 
+    /**
+     * Aggregates the bytes of subpartitions across all partition indices without modifying the
+     * existing state. This method is intended for querying purposes only.
+     *
+     * <p>The method computes the sum of the bytes for each subpartition across all partitions and
+     * returns a list containing these summed values.
+     *
+     * <p>This method is needed in scenarios where aggregated results are required but fine-grained
+     * statistics should remain unaggregated. Specifically, when not all consumer vertices of this
+     * result info have been created or initialized, this result info cannot be aggregated yet, but
+     * its already-created consumer vertices still require the aggregated results for scheduling.
+     *
+     * @return a list of aggregated byte counts for each subpartition.
+     */
+    private List<Long> getAggregatedSubpartitionBytesInternal() {
+        long[] aggregatedBytes = new long[numOfSubpartitions];
+        subpartitionBytesByPartitionIndex
+                .values()
+                .forEach(
+                        subpartitionBytes -> {
+                            checkState(subpartitionBytes.length == numOfSubpartitions);
+                            for (int i = 0; i < subpartitionBytes.length; ++i) {
+                                aggregatedBytes[i] += subpartitionBytes[i];
+                            }
+                        });
+
+        return Arrays.stream(aggregatedBytes).boxed().collect(Collectors.toList());
+    }
+
     @Override
     public void resetPartitionInfo(int partitionIndex) {
         if (aggregatedSubpartitionBytes == null) {
@@ -139,7 +175,14 @@ public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo {
     }
 
     public List<Long> getAggregatedSubpartitionBytes() {
-        checkState(aggregatedSubpartitionBytes != null, "Not all partition infos are ready");
-        return Collections.unmodifiableList(aggregatedSubpartitionBytes);
+        checkState(
+                aggregatedSubpartitionBytes != null
+                        || subpartitionBytesByPartitionIndex.size() == numOfPartitions,
+                "Not all partition infos are ready");
+        if (aggregatedSubpartitionBytes == null) {
+            return getAggregatedSubpartitionBytesInternal();
+        } else {
+            return Collections.unmodifiableList(aggregatedSubpartitionBytes);
+        }
     }
 }
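
The key pattern introduced in this file is querying without mutating: when the cached aggregate is absent but every partition has reported, the aggregate is computed on demand and the fine-grained state stays intact. A minimal sketch of that shape (illustrative names, not the real class; note it uses orElseGet to defer the computation, whereas Optional.orElse, as used in the patch, evaluates its argument even when the cached value is present):

    import java.util.List;
    import java.util.Optional;

    class LazyAggregateSketch {
        // null until an explicit aggregation call caches the result
        private List<Long> cachedAggregate;

        // Stand-in for getAggregatedSubpartitionBytesInternal().
        private List<Long> computeAggregate() {
            return List.of(192L, 384L);
        }

        long numBytesProduced(boolean isBroadcast) {
            List<Long> bytes =
                    Optional.ofNullable(cachedAggregate).orElseGet(this::computeAggregate);
            // A broadcast result sends the same data to every consumer, so one
            // subpartition's size is representative; otherwise sum them all.
            return isBroadcast ? bytes.get(0) : bytes.stream().reduce(0L, Long::sum);
        }
    }
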
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java
index b25665b299d..e298b4a065a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java
@@ -99,6 +99,9 @@ class AllToAllBlockingResultInfoTest {
         // The result info should be (partitionBytes2 + partitionBytes3)
        assertThat(resultInfo.getNumBytesProduced()).isEqualTo(576L);
        assertThat(resultInfo.getAggregatedSubpartitionBytes()).containsExactly(192L, 384L);
+        // The raw info should not be cleared yet
+        assertThat(resultInfo.getNumOfRecordedPartitions()).isGreaterThan(0);
+        resultInfo.aggregateSubpartitionBytes();
         // The raw info should be clear
         assertThat(resultInfo.getNumOfRecordedPartitions()).isZero();
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java
index d1b24d862f4..23c70f317bd 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java
@@ -609,6 +609,11 @@ class DefaultVertexParallelismAndInputInfosDeciderTest {
 
         @Override
         public void resetPartitionInfo(int partitionIndex) {}
+
+        @Override
+        public Map<Integer, long[]> getSubpartitionBytesByPartitionIndex() {
+            return Map.of();
+        }
     }
 
    private static BlockingResultInfo createFromBroadcastResult(long producedBytes) {

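Taken together, the tests pin down a two-phase lifecycle: recording and querying keep the raw per-partition statistics, and only an explicit aggregation call caches the sums and releases the raw state. A hedged end-to-end sketch of that lifecycle (hypothetical class, not Flink's AllToAllBlockingResultInfo):

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.stream.Collectors;

    class ResultInfoLifecycleSketch {
        private final int numOfPartitions = 2;
        private final int numOfSubpartitions = 2;
        private final Map<Integer, long[]> rawBytes = new HashMap<>();
        private List<Long> aggregated; // null until aggregate() succeeds

        // Raw stats are only recorded while no aggregate has been cached.
        void record(int partitionIndex, long[] subpartitionBytes) {
            if (aggregated == null) {
                rawBytes.put(partitionIndex, subpartitionBytes);
            }
        }

        // Query path: compute on the fly without clearing the raw state.
        List<Long> query() {
            if (aggregated != null) {
                return aggregated;
            }
            long[] sums = new long[numOfSubpartitions];
            for (long[] bytes : rawBytes.values()) {
                for (int i = 0; i < bytes.length; i++) {
                    sums[i] += bytes[i];
                }
            }
            return Arrays.stream(sums).boxed().collect(Collectors.toList());
        }

        // Explicit aggregation: cache the sums and release the raw state.
        void aggregate() {
            if (rawBytes.size() == numOfPartitions) {
                aggregated = query();
                rawBytes.clear();
            }
        }
    }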