This is an automated email from the ASF dual-hosted git repository. junrui pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 0ce55c37d8c961ee66f3a157aa2d71e65ee215f4 Author: JunRuiLee <[email protected]> AuthorDate: Thu Dec 12 17:32:42 2024 +0800 [FLINK-36067][runtime] Manually trigger aggregation of all-to-all result partition info once all consumers are created and initialized. --- .../adaptivebatch/AdaptiveBatchScheduler.java | 23 ++++++ .../adaptivebatch/AllToAllBlockingResultInfo.java | 95 ++++++++++++++++------ .../AllToAllBlockingResultInfoTest.java | 3 + ...tVertexParallelismAndInputInfosDeciderTest.java | 5 ++ 4 files changed, 100 insertions(+), 26 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java index 99df43dc807..a46a210446f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java @@ -273,6 +273,14 @@ public class AdaptiveBatchScheduler extends DefaultScheduler implements JobGraph // 4. update json plan getExecutionGraph().setJsonPlan(JsonPlanGenerator.generatePlan(getJobGraph())); + + // 5. 
try aggregate subpartition bytes + for (JobVertex newVertex : newVertices) { + for (JobEdge input : newVertex.getInputs()) { + Optional.ofNullable(blockingResultInfos.get(input.getSourceId())) + .ifPresent(this::maybeAggregateSubpartitionBytes); + } + } } @Override @@ -486,11 +494,25 @@ public class AdaptiveBatchScheduler extends DefaultScheduler implements JobGraph } resultInfo.recordPartitionInfo( partitionId.getPartitionNumber(), partitionBytes); + maybeAggregateSubpartitionBytes(resultInfo); return resultInfo; }); }); } + private void maybeAggregateSubpartitionBytes(BlockingResultInfo resultInfo) { + IntermediateResult intermediateResult = + getExecutionGraph().getAllIntermediateResults().get(resultInfo.getResultId()); + + if (resultInfo instanceof AllToAllBlockingResultInfo + && intermediateResult.areAllConsumerVerticesCreated() + && intermediateResult.getConsumerVertices().stream() + .map(this::getExecutionJobVertex) + .allMatch(ExecutionJobVertex::isInitialized)) { + ((AllToAllBlockingResultInfo) resultInfo).aggregateSubpartitionBytes(); + } + } + @Override public void allocateSlotsAndDeploy(final List<ExecutionVertexID> verticesToDeploy) { List<ExecutionVertex> executionVertices = @@ -657,6 +679,7 @@ public class AdaptiveBatchScheduler extends DefaultScheduler implements JobGraph parallelismAndInputInfos.getJobVertexInputInfos(), createTimestamp); newlyInitializedJobVertices.add(jobVertex); + consumedResultsInfo.get().forEach(this::maybeAggregateSubpartitionBytes); } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java index 9f01a1061e1..ed1e945912f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java +++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java @@ -27,6 +27,7 @@ import javax.annotation.Nullable; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; import static org.apache.flink.util.Preconditions.checkState; @@ -74,18 +75,28 @@ public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo { @Override public long getNumBytesProduced() { - checkState(aggregatedSubpartitionBytes != null, "Not all partition infos are ready"); + checkState( + aggregatedSubpartitionBytes != null + || subpartitionBytesByPartitionIndex.size() == numOfPartitions, + "Not all partition infos are ready"); + + List<Long> bytes = + Optional.ofNullable(aggregatedSubpartitionBytes) + .orElse(getAggregatedSubpartitionBytesInternal()); if (isBroadcast) { - return aggregatedSubpartitionBytes.get(0); + return bytes.get(0); } else { - return aggregatedSubpartitionBytes.stream().reduce(0L, Long::sum); + return bytes.stream().reduce(0L, Long::sum); } } @Override public long getNumBytesProduced( IndexRange partitionIndexRange, IndexRange subpartitionIndexRange) { - checkState(aggregatedSubpartitionBytes != null, "Not all partition infos are ready"); + List<Long> bytes = + Optional.ofNullable(aggregatedSubpartitionBytes) + .orElse(getAggregatedSubpartitionBytesInternal()); + checkState( partitionIndexRange.getStartIndex() == 0 && partitionIndexRange.getEndIndex() == numOfPartitions - 1, @@ -96,7 +107,7 @@ public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo { "Subpartition index %s is out of range.", subpartitionIndexRange.getEndIndex()); - return aggregatedSubpartitionBytes + return bytes .subList( subpartitionIndexRange.getStartIndex(), subpartitionIndexRange.getEndIndex() + 1) @@ -106,31 +117,56 @@ public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo { @Override public void 
recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes) { - // Once all partitions are finished, we can convert the subpartition bytes to aggregated - // value to reduce the space usage, because the distribution of source splits does not - // affect the distribution of data consumed by downstream tasks of ALL_TO_ALL edges(Hashing - // or Rebalancing, we do not consider rare cases such as custom partitions here). if (aggregatedSubpartitionBytes == null) { super.recordPartitionInfo(partitionIndex, partitionBytes); + } + } - if (subpartitionBytesByPartitionIndex.size() == numOfPartitions) { - long[] aggregatedBytes = new long[numOfSubpartitions]; - subpartitionBytesByPartitionIndex - .values() - .forEach( - subpartitionBytes -> { - checkState(subpartitionBytes.length == numOfSubpartitions); - for (int i = 0; i < subpartitionBytes.length; ++i) { - aggregatedBytes[i] += subpartitionBytes[i]; - } - }); - this.aggregatedSubpartitionBytes = - Arrays.stream(aggregatedBytes).boxed().collect(Collectors.toList()); - this.subpartitionBytesByPartitionIndex.clear(); - } + /** + * Aggregates the subpartition bytes to reduce space usage. + * + * <p>Once all partitions are finished and all consumer jobVertices are initialized, we can + * convert the subpartition bytes to aggregated value to reduce the space usage, because the + * distribution of source splits does not affect the distribution of data consumed by downstream + * tasks of ALL_TO_ALL edges(Hashing or Rebalancing, we do not consider rare cases such as + * custom partitions here). + */ + protected void aggregateSubpartitionBytes() { + if (subpartitionBytesByPartitionIndex.size() == numOfPartitions) { + this.aggregatedSubpartitionBytes = getAggregatedSubpartitionBytesInternal(); + this.subpartitionBytesByPartitionIndex.clear(); } } + /** + * Aggregates the bytes of subpartitions across all partition indices without modifying the + * existing state. This method is intended for querying purposes only. 
+ * + * <p>The method computes the sum of the bytes for each subpartition across all partitions and + * returns a list containing these summed values. + * + * <p>This method is needed in scenarios where aggregated results are required, but fine-grained + * statistics should remain not aggregated. Specifically, when not all consumer vertices of this + * result info are created or initialized, this result info could not be aggregated. And the + * existing consumer vertices of this info still require these aggregated result for scheduling. + * + * @return a list of aggregated byte counts for each subpartition. + */ + private List<Long> getAggregatedSubpartitionBytesInternal() { + long[] aggregatedBytes = new long[numOfSubpartitions]; + subpartitionBytesByPartitionIndex + .values() + .forEach( + subpartitionBytes -> { + checkState(subpartitionBytes.length == numOfSubpartitions); + for (int i = 0; i < subpartitionBytes.length; ++i) { + aggregatedBytes[i] += subpartitionBytes[i]; + } + }); + + return Arrays.stream(aggregatedBytes).boxed().collect(Collectors.toList()); + } + @Override public void resetPartitionInfo(int partitionIndex) { if (aggregatedSubpartitionBytes == null) { @@ -139,7 +175,14 @@ public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo { } public List<Long> getAggregatedSubpartitionBytes() { - checkState(aggregatedSubpartitionBytes != null, "Not all partition infos are ready"); - return Collections.unmodifiableList(aggregatedSubpartitionBytes); + checkState( + aggregatedSubpartitionBytes != null + || subpartitionBytesByPartitionIndex.size() == numOfPartitions, + "Not all partition infos are ready"); + if (aggregatedSubpartitionBytes == null) { + return getAggregatedSubpartitionBytesInternal(); + } else { + return Collections.unmodifiableList(aggregatedSubpartitionBytes); + } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java index b25665b299d..e298b4a065a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java @@ -99,6 +99,9 @@ class AllToAllBlockingResultInfoTest { // The result info should be (partitionBytes2 + partitionBytes3) assertThat(resultInfo.getNumBytesProduced()).isEqualTo(576L); assertThat(resultInfo.getAggregatedSubpartitionBytes()).containsExactly(192L, 384L); + // The raw info should not be clear + assertThat(resultInfo.getNumOfRecordedPartitions()).isGreaterThan(0); + resultInfo.aggregateSubpartitionBytes(); // The raw info should be clear assertThat(resultInfo.getNumOfRecordedPartitions()).isZero(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java index d1b24d862f4..23c70f317bd 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java @@ -609,6 +609,11 @@ class DefaultVertexParallelismAndInputInfosDeciderTest { @Override public void resetPartitionInfo(int partitionIndex) {} + + @Override + public Map<Integer, long[]> getSubpartitionBytesByPartitionIndex() { + return Map.of(); + } } private static BlockingResultInfo createFromBroadcastResult(long producedBytes) {
