This is an automated email from the ASF dual-hosted git repository.

wanglijie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 963333795ae5bf0b6d6c8b1fa793be742a7c1af3
Author: Lijie Wang <[email protected]>
AuthorDate: Mon Dec 19 11:03:29 2022 +0800

    [FLINK-29664][runtime] Support to collect the size of subpartitions.
    
    This closes #21111.
---
 .../flink/runtime/executiongraph/Execution.java    |  12 +-
 .../flink/runtime/executiongraph/IOMetrics.java    |  59 ++++++--
 .../executiongraph/ResultPartitionBytes.java       |  37 +++++
 .../metrics/ResultPartitionBytesCounter.java       |  53 +++++++
 .../partition/BufferWritingResultPartition.java    |   4 +-
 .../io/network/partition/ResultPartition.java      |  14 +-
 .../partition/SortMergeResultPartition.java        |  14 +-
 .../partition/hybrid/HsResultPartition.java        |   4 +-
 .../runtime/metrics/groups/TaskIOMetricGroup.java  |  16 +-
 .../flink/runtime/scheduler/DefaultScheduler.java  |   3 +-
 .../flink/runtime/scheduler/SchedulerBase.java     |   5 +-
 .../adaptivebatch/AbstractBlockingResultInfo.java  |  75 ++++++++++
 .../adaptivebatch/AdaptiveBatchScheduler.java      |  98 +++++++++++-
 .../adaptivebatch/AllToAllBlockingResultInfo.java  | 112 ++++++++++++++
 .../adaptivebatch/BlockingResultInfo.java          | 113 +++++++-------
 .../DefaultVertexParallelismDecider.java           |  10 +-
 .../adaptivebatch/PointwiseBlockingResultInfo.java |  53 +++++++
 .../adaptivebatch/SpeculativeScheduler.java        |   5 +-
 .../DefaultExecutionGraphDeploymentTest.java       |  22 ++-
 .../io/network/partition/ResultPartitionTest.java  |  34 +++--
 .../partition/SortMergeResultPartitionTest.java    |  23 ++-
 .../partition/hybrid/HsResultPartitionTest.java    |  17 ++-
 .../metrics/groups/TaskIOMetricGroupTest.java      |  30 ++--
 .../runtime/scheduler/DefaultSchedulerTest.java    |   7 +-
 .../runtime/scheduler/SchedulerTestingUtils.java   |  34 +++++
 .../adaptivebatch/AdaptiveBatchSchedulerTest.java  | 166 +++++++++++++++++++--
 .../AllToAllBlockingResultInfoTest.java            | 108 ++++++++++++++
 .../DefaultVertexParallelismDeciderTest.java       |  88 +++++++----
 .../PointwiseBlockingResultInfoTest.java           |  69 +++++++++
 .../adaptivebatch/SpeculativeSchedulerTest.java    |  32 ++--
 30 files changed, 1094 insertions(+), 223 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
index 3323f453373..b957c6b467e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
@@ -1548,7 +1548,17 @@ public class Execution
             }
         }
         if (metrics != null) {
-            this.ioMetrics = metrics;
+            // Drop IOMetrics#resultPartitionBytes because it will not be used anymore. It can
+            // result in very high memory usage when there are many executions and sub-partitions.
+            this.ioMetrics =
+                    new IOMetrics(
+                            metrics.getNumBytesIn(),
+                            metrics.getNumBytesOut(),
+                            metrics.getNumRecordsIn(),
+                            metrics.getNumRecordsOut(),
+                            metrics.getAccumulateIdleTime(),
+                            metrics.getAccumulateBusyTime(),
+                            metrics.getAccumulateBackPressuredTime());
         }
     }
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IOMetrics.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IOMetrics.java
index e612837531a..5dd24dab563 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IOMetrics.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IOMetrics.java
@@ -18,14 +18,20 @@
 
 package org.apache.flink.runtime.executiongraph;
 
-import org.apache.flink.metrics.Counter;
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.metrics.Gauge;
 import org.apache.flink.metrics.Meter;
+import org.apache.flink.runtime.io.network.metrics.ResultPartitionBytesCounter;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 
+import javax.annotation.Nullable;
+
 import java.io.Serializable;
-import java.util.HashMap;
+import java.util.Collections;
 import java.util.Map;
+import java.util.stream.Collectors;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /** An instance of this class represents a snapshot of the io-related metrics of a single task. */
 public class IOMetrics implements Serializable {
@@ -42,18 +48,19 @@ public class IOMetrics implements Serializable {
     protected double accumulateBusyTime;
     protected long accumulateIdleTime;
 
-    protected final Map<IntermediateResultPartitionID, Long> numBytesProducedOfPartitions =
-            new HashMap<>();
+    @Nullable
+    protected Map<IntermediateResultPartitionID, ResultPartitionBytes> resultPartitionBytes;
 
     public IOMetrics(
             Meter recordsIn,
             Meter recordsOut,
             Meter bytesIn,
             Meter bytesOut,
-            Map<IntermediateResultPartitionID, Counter> numBytesProducedCounters,
             Gauge<Long> accumulatedBackPressuredTime,
             Gauge<Long> accumulatedIdleTime,
-            Gauge<Double> accumulatedBusyTime) {
+            Gauge<Double> accumulatedBusyTime,
+            Map<IntermediateResultPartitionID, ResultPartitionBytesCounter>
+                    resultPartitionBytesCounters) {
         this.numRecordsIn = recordsIn.getCount();
         this.numRecordsOut = recordsOut.getCount();
         this.numBytesIn = bytesIn.getCount();
@@ -61,11 +68,12 @@ public class IOMetrics implements Serializable {
         this.accumulateBackPressuredTime = accumulatedBackPressuredTime.getValue();
         this.accumulateBusyTime = accumulatedBusyTime.getValue();
         this.accumulateIdleTime = accumulatedIdleTime.getValue();
-
-        for (Map.Entry<IntermediateResultPartitionID, Counter> counter :
-                numBytesProducedCounters.entrySet()) {
-            numBytesProducedOfPartitions.put(counter.getKey(), 
counter.getValue().getCount());
-        }
+        this.resultPartitionBytes =
+                resultPartitionBytesCounters.entrySet().stream()
+                        .collect(
+                                Collectors.toMap(
+                                        Map.Entry::getKey,
+                                        entry -> entry.getValue().createSnapshot()));
     }
 
     public IOMetrics(
@@ -74,8 +82,30 @@ public class IOMetrics implements Serializable {
             long numRecordsIn,
             long numRecordsOut,
             long accumulateIdleTime,
-            long accumulateBusyTime,
+            double accumulateBusyTime,
             long accumulateBackPressuredTime) {
+        this(
+                numBytesIn,
+                numBytesOut,
+                numRecordsIn,
+                numRecordsOut,
+                accumulateIdleTime,
+                accumulateBusyTime,
+                accumulateBackPressuredTime,
+                null);
+    }
+
+    @VisibleForTesting
+    public IOMetrics(
+            long numBytesIn,
+            long numBytesOut,
+            long numRecordsIn,
+            long numRecordsOut,
+            long accumulateIdleTime,
+            double accumulateBusyTime,
+            long accumulateBackPressuredTime,
+            @Nullable
+                    Map<IntermediateResultPartitionID, ResultPartitionBytes> resultPartitionBytes) {
         this.numBytesIn = numBytesIn;
         this.numBytesOut = numBytesOut;
         this.numRecordsIn = numRecordsIn;
@@ -83,6 +113,7 @@ public class IOMetrics implements Serializable {
         this.accumulateIdleTime = accumulateIdleTime;
         this.accumulateBusyTime = accumulateBusyTime;
         this.accumulateBackPressuredTime = accumulateBackPressuredTime;
+        this.resultPartitionBytes = resultPartitionBytes;
     }
 
     public long getNumRecordsIn() {
@@ -113,7 +144,7 @@ public class IOMetrics implements Serializable {
         return accumulateIdleTime;
     }
 
-    public Map<IntermediateResultPartitionID, Long> getNumBytesProducedOfPartitions() {
-        return numBytesProducedOfPartitions;
+    public Map<IntermediateResultPartitionID, ResultPartitionBytes> getResultPartitionBytes() {
+        return Collections.unmodifiableMap(checkNotNull(resultPartitionBytes));
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ResultPartitionBytes.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ResultPartitionBytes.java
new file mode 100644
index 00000000000..630a828c648
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ResultPartitionBytes.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.executiongraph;
+
+import java.io.Serializable;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/** This class represents a snapshot of the result partition bytes metrics. */
+public class ResultPartitionBytes implements Serializable {
+
+    private final long[] subpartitionBytes;
+
+    public ResultPartitionBytes(long[] subpartitionBytes) {
+        this.subpartitionBytes = checkNotNull(subpartitionBytes);
+    }
+
+    public long[] getSubpartitionBytes() {
+        return subpartitionBytes;
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/metrics/ResultPartitionBytesCounter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/metrics/ResultPartitionBytesCounter.java
new file mode 100644
index 00000000000..21df06b6f68
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/metrics/ResultPartitionBytesCounter.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.metrics;
+
+import org.apache.flink.metrics.Counter;
+import org.apache.flink.metrics.SimpleCounter;
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/** This counter will count the data size of a partition. */
+public class ResultPartitionBytesCounter {
+
+    /** The data size of each subpartition. */
+    private final List<Counter> subpartitionBytes;
+
+    public ResultPartitionBytesCounter(int numSubpartitions) {
+        this.subpartitionBytes = new ArrayList<>();
+        for (int i = 0; i < numSubpartitions; ++i) {
+            subpartitionBytes.add(new SimpleCounter());
+        }
+    }
+
+    public void inc(int targetSubpartition, long bytes) {
+        subpartitionBytes.get(targetSubpartition).inc(bytes);
+    }
+
+    public void incAll(long bytes) {
+        subpartitionBytes.forEach(counter -> counter.inc(bytes));
+    }
+
+    public ResultPartitionBytes createSnapshot() {
+        return new ResultPartitionBytes(
+                subpartitionBytes.stream().mapToLong(Counter::getCount).toArray());
+    }
+}
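
For orientation, here is a minimal sketch of how the new counter is exercised; the wiring below is illustrative, but the ResultPartitionBytesCounter and ResultPartitionBytes calls match the classes added above:

    import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
    import org.apache.flink.runtime.io.network.metrics.ResultPartitionBytesCounter;

    public class ResultPartitionBytesCounterSketch {
        public static void main(String[] args) {
            // A result partition with 2 subpartitions.
            ResultPartitionBytesCounter counter = new ResultPartitionBytesCounter(2);

            // Unicast: a 100-byte record goes to subpartition 0 only.
            counter.inc(0, 100L);

            // Broadcast: a 50-byte record is written to every subpartition.
            counter.incAll(50L);

            // An immutable per-subpartition snapshot, later shipped via IOMetrics.
            ResultPartitionBytes snapshot = counter.createSnapshot();
            long[] bytes = snapshot.getSubpartitionBytes(); // [150, 50]
        }
    }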
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java
index 8b36a7587d1..2107a24c747 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java
@@ -423,7 +423,7 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
         final BufferBuilder bufferBuilder = unicastBufferBuilders[targetSubpartition];
         if (bufferBuilder != null) {
             int bytes = bufferBuilder.finish();
-            numBytesProduced.inc(bytes);
+            resultPartitionBytes.inc(targetSubpartition, bytes);
             numBytesOut.inc(bytes);
             numBuffersOut.inc();
             unicastBufferBuilders[targetSubpartition] = null;
@@ -440,7 +440,7 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
     private void finishBroadcastBufferBuilder() {
         if (broadcastBufferBuilder != null) {
             int bytes = broadcastBufferBuilder.finish();
-            numBytesProduced.inc(bytes);
+            resultPartitionBytes.incAll(bytes);
             numBytesOut.inc(bytes * numSubpartitions);
             numBuffersOut.inc(numSubpartitions);
             broadcastBufferBuilder.close();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
index e3b2e2ff74e..8f3631de15d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
@@ -28,6 +28,7 @@ import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
 import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.metrics.ResultPartitionBytesCounter;
 import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel;
 import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
@@ -112,13 +113,7 @@ public abstract class ResultPartition implements ResultPartitionWriter {
 
     protected Counter numBuffersOut = new SimpleCounter();
 
-    /**
-     * The difference with {@link #numBytesOut} : numBytesProduced represents the number of bytes
-     * actually produced, and numBytesOut represents the number of bytes sent to downstream tasks.
-     * In unicast scenarios, these two values should be equal. In broadcast scenarios, numBytesOut
-     * should be (N * numBytesProduced), where N refers to the number of subpartitions.
-     */
-    protected Counter numBytesProduced = new SimpleCounter();
+    protected ResultPartitionBytesCounter resultPartitionBytes;
 
     public ResultPartition(
             String owningTaskName,
@@ -141,6 +136,7 @@ public abstract class ResultPartition implements ResultPartitionWriter {
         this.partitionManager = checkNotNull(partitionManager);
         this.bufferCompressor = bufferCompressor;
         this.bufferPoolFactory = bufferPoolFactory;
+        this.resultPartitionBytes = new ResultPartitionBytesCounter(numSubpartitions);
     }
 
     /**
@@ -301,8 +297,8 @@ public abstract class ResultPartition implements ResultPartitionWriter {
     public void setMetricGroup(TaskIOMetricGroup metrics) {
         numBytesOut = metrics.getNumBytesOutCounter();
         numBuffersOut = metrics.getNumBuffersOutCounter();
-        metrics.registerNumBytesProducedCounterForPartition(
-                partitionId.getPartitionId(), numBytesProduced);
+        metrics.registerResultPartitionBytesCounter(
+                partitionId.getPartitionId(), resultPartitionBytes);
     }
 
     /**
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
index 2b603a219f5..def80f40574 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
@@ -391,7 +391,7 @@ public class SortMergeResultPartition extends ResultPartition {
                 break;
             }
 
-            updateStatistics(bufferWithChannel.getBuffer(), isBroadcast);
+            updateStatistics(bufferWithChannel, isBroadcast);
             toWrite.add(compressBufferIfPossible(bufferWithChannel));
         } while (true);
 
@@ -424,10 +424,14 @@ public class SortMergeResultPartition extends ResultPartition {
         return new BufferWithChannel(buffer, bufferWithChannel.getChannelIndex());
     }
 
-    private void updateStatistics(Buffer buffer, boolean isBroadcast) {
+    private void updateStatistics(BufferWithChannel bufferWithChannel, boolean isBroadcast) {
         numBuffersOut.inc(isBroadcast ? numSubpartitions : 1);
-        long readableBytes = buffer.readableBytes();
-        numBytesProduced.inc(readableBytes);
+        long readableBytes = bufferWithChannel.getBuffer().readableBytes();
+        if (isBroadcast) {
+            resultPartitionBytes.incAll(readableBytes);
+        } else {
+            resultPartitionBytes.inc(bufferWithChannel.getChannelIndex(), readableBytes);
+        }
         numBytesOut.inc(isBroadcast ? readableBytes * numSubpartitions : readableBytes);
     }
 
@@ -456,7 +460,7 @@ public class SortMergeResultPartition extends ResultPartition {
 
             NetworkBuffer buffer = new NetworkBuffer(writeBuffer, (buf) -> {}, dataType, toCopy);
             BufferWithChannel bufferWithChannel = new BufferWithChannel(buffer, targetSubpartition);
-            updateStatistics(buffer, isBroadcast);
+            updateStatistics(bufferWithChannel, isBroadcast);
             toWrite.add(compressBufferIfPossible(bufferWithChannel));
         }
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartition.java
index d4f2f632aff..eb915335ab1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartition.java
@@ -152,7 +152,7 @@ public class HsResultPartition extends ResultPartition {
 
     @Override
     public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException {
-        numBytesProduced.inc(record.remaining());
+        resultPartitionBytes.inc(targetSubpartition, record.remaining());
         emit(record, targetSubpartition, Buffer.DataType.DATA_BUFFER);
     }
 
@@ -173,7 +173,7 @@ public class HsResultPartition extends ResultPartition {
     }
 
     private void broadcast(ByteBuffer record, Buffer.DataType dataType) throws IOException {
-        numBytesProduced.inc(record.remaining());
+        resultPartitionBytes.incAll(record.remaining());
         if (isBroadcastOnly) {
             emit(record, BROADCAST_CHANNEL, dataType);
         } else {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroup.java b/flink-runtime/src/main/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroup.java
index 17dca7b97b7..6903f666bda 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroup.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroup.java
@@ -25,6 +25,7 @@ import org.apache.flink.metrics.Meter;
 import org.apache.flink.metrics.MeterView;
 import org.apache.flink.metrics.SimpleCounter;
 import org.apache.flink.runtime.executiongraph.IOMetrics;
+import org.apache.flink.runtime.io.network.metrics.ResultPartitionBytesCounter;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.metrics.DescriptiveStatisticsHistogram;
 import org.apache.flink.runtime.metrics.MetricNames;
@@ -71,8 +72,8 @@ public class TaskIOMetricGroup extends ProxyMetricGroup<TaskMetricGroup> {
 
     private long taskStartTime;
 
-    private final Map<IntermediateResultPartitionID, Counter> numBytesProducedOfPartitions =
-            new HashMap<>();
+    private final Map<IntermediateResultPartitionID, ResultPartitionBytesCounter>
+            resultPartitionBytes = new HashMap<>();
 
     public TaskIOMetricGroup(TaskMetricGroup parent) {
         super(parent);
@@ -135,10 +136,10 @@ public class TaskIOMetricGroup extends ProxyMetricGroup<TaskMetricGroup> {
                 numRecordsOutRate,
                 numBytesInRate,
                 numBytesOutRate,
-                numBytesProducedOfPartitions,
                 accumulatedBackPressuredTime,
                 accumulatedIdleTime,
-                accumulatedBusyTime);
+                accumulatedBusyTime,
+                resultPartitionBytes);
     }
 
     // ============================================================================================
@@ -238,9 +239,10 @@ public class TaskIOMetricGroup extends ProxyMetricGroup<TaskMetricGroup> {
         this.numRecordsOut.addCounter(numRecordsOutCounter);
     }
 
-    public void registerNumBytesProducedCounterForPartition(
-            IntermediateResultPartitionID resultPartitionId, Counter numBytesProducedCounter) {
-        this.numBytesProducedOfPartitions.put(resultPartitionId, numBytesProducedCounter);
+    public void registerResultPartitionBytesCounter(
+            IntermediateResultPartitionID resultPartitionId,
+            ResultPartitionBytesCounter resultPartitionBytesCounter) {
+        this.resultPartitionBytes.put(resultPartitionId, resultPartitionBytesCounter);
     }
 
     public void registerMailboxSizeSupplier(SizeGauge.SizeSupplier<Integer> supplier) {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java
index da59e2275fd..2136479a2c7 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java
@@ -30,6 +30,7 @@ import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
 import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.executiongraph.IOMetrics;
 import org.apache.flink.runtime.executiongraph.JobStatusListener;
 import org.apache.flink.runtime.executiongraph.failover.flip1.ExecutionFailureHandler;
 import org.apache.flink.runtime.executiongraph.failover.flip1.FailoverStrategy;
@@ -215,7 +216,7 @@ public class DefaultScheduler extends SchedulerBase implements SchedulerOperatio
     }
 
     @Override
-    protected void onTaskFinished(final Execution execution) {
+    protected void onTaskFinished(final Execution execution, final IOMetrics ioMetrics) {
         checkState(execution.getState() == ExecutionState.FINISHED);
 
         final ExecutionVertexID executionVertexId = execution.getVertex().getID();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SchedulerBase.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SchedulerBase.java
index 02c19f92391..846d0948e61 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SchedulerBase.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SchedulerBase.java
@@ -53,6 +53,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.executiongraph.IOMetrics;
 import org.apache.flink.runtime.executiongraph.JobStatusListener;
 import org.apache.flink.runtime.executiongraph.JobStatusProvider;
 import org.apache.flink.runtime.executiongraph.TaskExecutionStateTransition;
@@ -733,7 +734,7 @@ public abstract class SchedulerBase implements SchedulerNG, CheckpointScheduling
         // can be refined in FLINK-14233 after the actions are factored out from ExecutionGraph.
         switch (taskExecutionState.getExecutionState()) {
             case FINISHED:
-                onTaskFinished(execution);
+                onTaskFinished(execution, taskExecutionState.getIOMetrics());
                 break;
             case FAILED:
                 onTaskFailed(execution);
@@ -741,7 +742,7 @@ public abstract class SchedulerBase implements SchedulerNG, CheckpointScheduling
         }
     }
 
-    protected abstract void onTaskFinished(final Execution execution);
+    protected abstract void onTaskFinished(final Execution execution, final IOMetrics ioMetrics);
 
     protected abstract void onTaskFailed(final Execution execution);
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java
new file mode 100644
index 00000000000..33147bcdc16
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptivebatch;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Base blocking result info. */
+abstract class AbstractBlockingResultInfo implements BlockingResultInfo {
+
+    private final IntermediateDataSetID resultId;
+
+    protected final int numOfPartitions;
+
+    protected final int numOfSubpartitions;
+
+    /**
+     * The subpartition bytes map. The key is the partition index, value is a subpartition bytes
+     * list.
+     */
+    protected final Map<Integer, long[]> subpartitionBytesByPartitionIndex;
+
+    AbstractBlockingResultInfo(
+            IntermediateDataSetID resultId, int numOfPartitions, int numOfSubpartitions) {
+        this.resultId = checkNotNull(resultId);
+        this.numOfPartitions = numOfPartitions;
+        this.numOfSubpartitions = numOfSubpartitions;
+        this.subpartitionBytesByPartitionIndex = new HashMap<>();
+    }
+
+    @Override
+    public IntermediateDataSetID getResultId() {
+        return resultId;
+    }
+
+    @Override
+    public void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes) {
+        checkState(partitionBytes.getSubpartitionBytes().length == numOfSubpartitions);
+        subpartitionBytesByPartitionIndex.put(
+                partitionIndex, partitionBytes.getSubpartitionBytes());
+    }
+
+    @Override
+    public void resetPartitionInfo(int partitionIndex) {
+        subpartitionBytesByPartitionIndex.remove(partitionIndex);
+    }
+
+    @VisibleForTesting
+    int getNumOfRecordedPartitions() {
+        return subpartitionBytesByPartitionIndex.size();
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
index 6a80d7a9569..da4beee4c43 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
@@ -26,12 +26,19 @@ import org.apache.flink.runtime.JobException;
 import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory;
 import org.apache.flink.runtime.checkpoint.CheckpointsCleaner;
 import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
+import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.executiongraph.IOMetrics;
 import org.apache.flink.runtime.executiongraph.IntermediateResult;
 import org.apache.flink.runtime.executiongraph.JobStatusListener;
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
 import org.apache.flink.runtime.executiongraph.failover.flip1.FailoverStrategy;
 import org.apache.flink.runtime.executiongraph.failover.flip1.RestartBackoffTimeStrategy;
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.jobgraph.JobEdge;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobVertex;
@@ -50,6 +57,7 @@ import org.apache.flink.runtime.scheduler.ExecutionVertexVersioner;
 import org.apache.flink.runtime.scheduler.VertexParallelismStore;
 import org.apache.flink.runtime.scheduler.adaptivebatch.forwardgroup.ForwardGroup;
 import org.apache.flink.runtime.scheduler.adaptivebatch.forwardgroup.ForwardGroupComputeUtil;
+import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
 import org.apache.flink.runtime.scheduler.strategy.SchedulingStrategyFactory;
 import org.apache.flink.runtime.shuffle.ShuffleMaster;
 import org.apache.flink.util.concurrent.ScheduledExecutor;
@@ -57,6 +65,7 @@ import org.apache.flink.util.concurrent.ScheduledExecutor;
 import org.slf4j.Logger;
 
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -64,6 +73,7 @@ import java.util.concurrent.Executor;
 import java.util.function.Consumer;
 import java.util.function.Function;
 
+import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
@@ -78,6 +88,8 @@ public class AdaptiveBatchScheduler extends DefaultScheduler {
 
     private final Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId;
 
+    private final Map<IntermediateDataSetID, BlockingResultInfo> blockingResultInfos;
+
     public AdaptiveBatchScheduler(
             final Logger log,
             final JobGraph jobGraph,
@@ -140,6 +152,8 @@ public class AdaptiveBatchScheduler extends DefaultScheduler {
                 ForwardGroupComputeUtil.computeForwardGroups(
                         jobGraph.getVerticesSortedTopologicallyFromSources(),
                         getExecutionGraph()::getJobVertex);
+
+        this.blockingResultInfos = new HashMap<>();
     }
 
     @Override
@@ -150,10 +164,58 @@ public class AdaptiveBatchScheduler extends DefaultScheduler {
     }
 
     @Override
-    protected void onTaskFinished(final Execution execution) {
+    protected void onTaskFinished(final Execution execution, final IOMetrics ioMetrics) {
+        checkNotNull(ioMetrics);
+        updateResultPartitionBytesMetrics(ioMetrics.getResultPartitionBytes());
         initializeVerticesIfPossible();
 
-        super.onTaskFinished(execution);
+        super.onTaskFinished(execution, ioMetrics);
+    }
+
+    private void updateResultPartitionBytesMetrics(
+            Map<IntermediateResultPartitionID, ResultPartitionBytes> resultPartitionBytes) {
+        checkNotNull(resultPartitionBytes);
+        resultPartitionBytes.forEach(
+                (partitionId, partitionBytes) -> {
+                    IntermediateResult result =
+                            getExecutionGraph()
+                                    .getAllIntermediateResults()
+                                    .get(partitionId.getIntermediateDataSetID());
+                    checkNotNull(result);
+
+                    blockingResultInfos.compute(
+                            result.getId(),
+                            (ignored, resultInfo) -> {
+                                if (resultInfo == null) {
+                                    resultInfo = createFromIntermediateResult(result);
+                                }
+                                resultInfo.recordPartitionInfo(
+                                        partitionId.getPartitionNumber(), partitionBytes);
+                                return resultInfo;
+                            });
+                });
+    }
+
+    @Override
+    protected void resetForNewExecution(final ExecutionVertexID executionVertexId) {
+        final ExecutionVertex executionVertex = getExecutionVertex(executionVertexId);
+        if (executionVertex.getExecutionState() == ExecutionState.FINISHED) {
+            executionVertex
+                    .getProducedPartitions()
+                    .values()
+                    .forEach(
+                            partition -> {
+                                blockingResultInfos.computeIfPresent(
+                                        partition.getIntermediateResult().getId(),
+                                        (ignored, resultInfo) -> {
+                                            resultInfo.resetPartitionInfo(
+                                                    
partition.getPartitionNumber());
+                                            return resultInfo;
+                                        });
+                            });
+        }
+
+        super.resetForNewExecution(executionVertexId);
     }
 
     void initializeVerticesIfPossible() {
@@ -246,12 +308,9 @@ public class AdaptiveBatchScheduler extends DefaultScheduler {
             final ExecutionJobVertex producerVertex =
                     getExecutionJobVertex(consumedResult.getProducer().getId());
             if (producerVertex.isFinished()) {
-                IntermediateResult intermediateResult =
-                        getExecutionGraph().getAllIntermediateResults().get(consumedResult.getId());
-                checkNotNull(intermediateResult);
-
-                consumableResultInfo.add(
-                        BlockingResultInfo.createFromIntermediateResult(intermediateResult));
+                BlockingResultInfo resultInfo =
+                        checkNotNull(blockingResultInfos.get(consumedResult.getId()));
+                consumableResultInfo.add(resultInfo);
             } else {
                 // not all inputs consumable, return Optional.empty()
                 return Optional.empty();
@@ -320,4 +379,27 @@ public class AdaptiveBatchScheduler extends DefaultScheduler {
                 },
                 Function.identity());
     }
+
+    private static BlockingResultInfo createFromIntermediateResult(IntermediateResult result) {
+        checkArgument(result != null);
+        // Note that for dynamic graph, different partitions in the same result have the same number
+        // of subpartitions.
+        if (result.getConsumingDistributionPattern() == DistributionPattern.POINTWISE) {
+            return new PointwiseBlockingResultInfo(
+                    result.getId(),
+                    result.getNumberOfAssignedPartitions(),
+                    result.getPartitions()[0].getNumberOfSubpartitions());
+        } else {
+            return new AllToAllBlockingResultInfo(
+                    result.getId(),
+                    result.getNumberOfAssignedPartitions(),
+                    result.getPartitions()[0].getNumberOfSubpartitions(),
+                    result.isBroadcast());
+        }
+    }
+
+    @VisibleForTesting
+    BlockingResultInfo getBlockingResultInfo(IntermediateDataSetID resultId) {
+        return blockingResultInfos.get(resultId);
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java
new file mode 100644
index 00000000000..a4a06efc57a
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptivebatch;
+
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+
+import javax.annotation.Nullable;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Information of All-To-All result. */
+public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo {
+
+    private final boolean isBroadcast;
+
+    /**
+     * Aggregated subpartition bytes, which aggregates the subpartition bytes with the same
+     * subpartition index in different partitions. Note that we can aggregate them because they will
+     * be consumed by the same downstream task.
+     */
+    @Nullable private List<Long> aggregatedSubpartitionBytes;
+
+    AllToAllBlockingResultInfo(
+            IntermediateDataSetID resultId,
+            int numOfPartitions,
+            int numOfSubpartitions,
+            boolean isBroadcast) {
+        super(resultId, numOfPartitions, numOfSubpartitions);
+        this.isBroadcast = isBroadcast;
+    }
+
+    @Override
+    public boolean isBroadcast() {
+        return isBroadcast;
+    }
+
+    @Override
+    public boolean isPointwise() {
+        return false;
+    }
+
+    @Override
+    public long getNumBytesProduced() {
+        checkState(aggregatedSubpartitionBytes != null, "Not all partition infos are ready");
+        if (isBroadcast) {
+            return aggregatedSubpartitionBytes.get(0);
+        } else {
+            return aggregatedSubpartitionBytes.stream().reduce(0L, Long::sum);
+        }
+    }
+
+    @Override
+    public void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes) {
+        // Once all partitions are finished, we can convert the subpartition bytes to aggregated
+        // values to reduce the space usage, because the distribution of source splits does not
+        // affect the distribution of data consumed by downstream tasks of ALL_TO_ALL edges (Hashing
+        // or Rebalancing, we do not consider rare cases such as custom partitions here).
+        if (aggregatedSubpartitionBytes == null) {
+            super.recordPartitionInfo(partitionIndex, partitionBytes);
+
+            if (subpartitionBytesByPartitionIndex.size() == numOfPartitions) {
+                long[] aggregatedBytes = new long[numOfSubpartitions];
+                subpartitionBytesByPartitionIndex
+                        .values()
+                        .forEach(
+                                subpartitionBytes -> {
+                                    checkState(subpartitionBytes.length == numOfSubpartitions);
+                                    for (int i = 0; i < subpartitionBytes.length; ++i) {
+                                        aggregatedBytes[i] += subpartitionBytes[i];
+                                    }
+                                });
+                this.aggregatedSubpartitionBytes =
+                        Arrays.stream(aggregatedBytes).boxed().collect(Collectors.toList());
+                this.subpartitionBytesByPartitionIndex.clear();
+            }
+        }
+    }
+
+    @Override
+    public void resetPartitionInfo(int partitionIndex) {
+        if (aggregatedSubpartitionBytes == null) {
+            super.resetPartitionInfo(partitionIndex);
+        }
+    }
+
+    public List<Long> getAggregatedSubpartitionBytes() {
+        checkState(aggregatedSubpartitionBytes != null, "Not all partition infos are ready");
+        return Collections.unmodifiableList(aggregatedSubpartitionBytes);
+    }
+}
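
The aggregation step above can be restated as a standalone sketch (the class and helper names here are illustrative, not part of the patch): subpartition bytes with the same index are summed across partitions, since each index is consumed by the same downstream task.

    import java.util.Arrays;
    import java.util.Map;

    public class AggregationSketch {
        // Sum the bytes of equally-indexed subpartitions across all partitions.
        static long[] aggregate(Map<Integer, long[]> bytesByPartitionIndex, int numSubpartitions) {
            long[] aggregated = new long[numSubpartitions];
            for (long[] subpartitionBytes : bytesByPartitionIndex.values()) {
                for (int i = 0; i < numSubpartitions; i++) {
                    aggregated[i] += subpartitionBytes[i];
                }
            }
            return aggregated;
        }

        public static void main(String[] args) {
            // Two finished partitions, three subpartitions each.
            Map<Integer, long[]> bytes =
                    Map.of(0, new long[] {10, 20, 30}, 1, new long[] {1, 2, 3});
            System.out.println(Arrays.toString(aggregate(bytes, 3))); // [11, 22, 33]
        }
    }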
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java
index 302980ab216..0fd14e4439e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java
@@ -18,63 +18,60 @@
 
 package org.apache.flink.runtime.scheduler.adaptivebatch;
 
-import org.apache.flink.annotation.VisibleForTesting;
-import org.apache.flink.runtime.executiongraph.IOMetrics;
-import org.apache.flink.runtime.executiongraph.IntermediateResult;
-import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 
-import java.util.ArrayList;
-import java.util.List;
-
-import static org.apache.flink.util.Preconditions.checkArgument;
-import static org.apache.flink.util.Preconditions.checkNotNull;
-import static org.apache.flink.util.Preconditions.checkState;
-
-/** The blocking result info, which will be used to calculate the vertex parallelism. */
-public class BlockingResultInfo {
-
-    private final List<Long> blockingPartitionSizes;
-
-    private final boolean isBroadcast;
-
-    private BlockingResultInfo(List<Long> blockingPartitionSizes, boolean isBroadcast) {
-        this.blockingPartitionSizes = blockingPartitionSizes;
-        this.isBroadcast = isBroadcast;
-    }
-
-    public List<Long> getBlockingPartitionSizes() {
-        return blockingPartitionSizes;
-    }
-
-    public boolean isBroadcast() {
-        return isBroadcast;
-    }
-
-    @VisibleForTesting
-    static BlockingResultInfo createFromBroadcastResult(List<Long> blockingPartitionSizes) {
-        return new BlockingResultInfo(blockingPartitionSizes, true);
-    }
-
-    @VisibleForTesting
-    static BlockingResultInfo createFromNonBroadcastResult(List<Long> blockingPartitionSizes) {
-        return new BlockingResultInfo(blockingPartitionSizes, false);
-    }
-
-    public static BlockingResultInfo createFromIntermediateResult(
-            IntermediateResult intermediateResult) {
-        checkArgument(intermediateResult != null);
-
-        List<Long> blockingPartitionSizes = new ArrayList<>();
-        for (IntermediateResultPartition partition : intermediateResult.getPartitions()) {
-            checkState(partition.isConsumable());
-
-            IOMetrics ioMetrics = partition.getProducer().getPartitionProducer().getIOMetrics();
-            checkNotNull(ioMetrics, "IOMetrics should not be null.");
-
-            blockingPartitionSizes.add(
-                    ioMetrics.getNumBytesProducedOfPartitions().get(partition.getPartitionId()));
-        }
-
-        return new BlockingResultInfo(blockingPartitionSizes, intermediateResult.isBroadcast());
-    }
+/**
+ * The blocking result info, which will be used to calculate the vertex parallelism and input infos.
+ */
+public interface BlockingResultInfo {
+
+    /**
+     * Get the intermediate result id.
+     *
+     * @return the intermediate result id
+     */
+    IntermediateDataSetID getResultId();
+
+    /**
+     * Whether it is a broadcast result.
+     *
+     * @return whether it is a broadcast result
+     */
+    boolean isBroadcast();
+
+    /**
+     * Whether it is a pointwise result.
+     *
+     * @return whether it is a pointwise result
+     */
+    boolean isPointwise();
+
+    /**
+     * Return the num of bytes produced (numBytesProduced) by the producer.
+     *
+     * <p>The difference between numBytesProduced and numBytesOut: numBytesProduced represents the
+     * number of bytes actually produced, and numBytesOut represents the number of bytes sent to
+     * downstream tasks. In unicast scenarios, these two values should be equal. In broadcast
+     * scenarios, numBytesOut should be (N * numBytesProduced), where N refers to the number of
+     * subpartitions.
+     *
+     * @return the num of bytes produced by the producer
+     */
+    long getNumBytesProduced();
+
+    /**
+     * Record the information of the result partition.
+     *
+     * @param partitionIndex the intermediate result partition index
+     * @param partitionBytes the {@link ResultPartitionBytes} of the partition
+     */
+    void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes);
+
+    /**
+     * Reset the information of the result partition.
+     *
+     * @param partitionIndex the intermediate result partition index
+     */
+    void resetPartitionInfo(int partitionIndex);
 }
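
To make the numBytesProduced/numBytesOut contract documented above concrete, a small worked example (the figures are made up):

    // Broadcasting one 100-byte record to N = 4 subpartitions:
    long numBytesProduced = 100L;                            // produced once
    int numSubpartitions = 4;
    long numBytesOut = numSubpartitions * numBytesProduced;  // 400 bytes sent downstream

    // Emitting the same record to a single subpartition (unicast):
    // numBytesProduced == numBytesOut == 100.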
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java
index 43225fd4f61..64e6998a4fa 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java
@@ -92,19 +92,13 @@ public class DefaultVertexParallelismDecider implements VertexParallelismDecider
         long broadcastBytes =
                 consumedResults.stream()
                         .filter(BlockingResultInfo::isBroadcast)
-                        .mapToLong(
-                                consumedResult ->
-                                        consumedResult.getBlockingPartitionSizes().stream()
-                                                .reduce(0L, Long::sum))
+                        .mapToLong(BlockingResultInfo::getNumBytesProduced)
                         .sum();
 
         long nonBroadcastBytes =
                 consumedResults.stream()
                         .filter(consumedResult -> !consumedResult.isBroadcast())
-                        .mapToLong(
-                                consumedResult ->
-                                        consumedResult.getBlockingPartitionSizes().stream()
-                                                .reduce(0L, Long::sum))
+                        .mapToLong(BlockingResultInfo::getNumBytesProduced)
                         .sum();
 
         long expectedMaxBroadcastBytes =
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java
new file mode 100644
index 00000000000..287b180df00
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptivebatch;
+
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+
+import java.util.Arrays;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Information of Pointwise result. */
+public class PointwiseBlockingResultInfo extends AbstractBlockingResultInfo {
+    PointwiseBlockingResultInfo(
+            IntermediateDataSetID resultId, int numOfPartitions, int numOfSubpartitions) {
+        super(resultId, numOfPartitions, numOfSubpartitions);
+    }
+
+    @Override
+    public boolean isBroadcast() {
+        return false;
+    }
+
+    @Override
+    public boolean isPointwise() {
+        return true;
+    }
+
+    @Override
+    public long getNumBytesProduced() {
+        checkState(
+                subpartitionBytesByPartitionIndex.size() == numOfPartitions,
+                "Not all partition infos are ready");
+        return subpartitionBytesByPartitionIndex.values().stream()
+                .flatMapToLong(Arrays::stream)
+                .reduce(0L, Long::sum);
+    }
+}
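
Unlike the all-to-all case, a pointwise result is not aggregated per subpartition index; its total is simply the sum of every recorded subpartition value, as in this standalone restatement (names are illustrative, not part of the patch):

    import java.util.Arrays;
    import java.util.Map;

    public class PointwiseSumSketch {
        static long numBytesProduced(Map<Integer, long[]> bytesByPartitionIndex) {
            return bytesByPartitionIndex.values().stream()
                    .flatMapToLong(Arrays::stream)
                    .reduce(0L, Long::sum);
        }

        public static void main(String[] args) {
            // Two partitions with two subpartitions each: 10 + 20 + 5 + 5 = 40.
            Map<Integer, long[]> bytes =
                    Map.of(0, new long[] {10, 20}, 1, new long[] {5, 5});
            System.out.println(numBytesProduced(bytes)); // 40
        }
    }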
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/SpeculativeScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/SpeculativeScheduler.java
index 7c0a6c806af..37c72aa2daa 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/SpeculativeScheduler.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/SpeculativeScheduler.java
@@ -35,6 +35,7 @@ import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.executiongraph.IOMetrics;
 import org.apache.flink.runtime.executiongraph.JobStatusListener;
 import org.apache.flink.runtime.executiongraph.SpeculativeExecutionVertex;
 import org.apache.flink.runtime.executiongraph.failover.flip1.FailoverStrategy;
@@ -194,7 +195,7 @@ public class SpeculativeScheduler extends AdaptiveBatchScheduler
     }
 
     @Override
-    protected void onTaskFinished(final Execution execution) {
+    protected void onTaskFinished(final Execution execution, final IOMetrics ioMetrics) {
         if (!isOriginalAttempt(execution)) {
             numEffectiveSpeculativeExecutionsCounter.inc();
         }
@@ -202,7 +203,7 @@ public class SpeculativeScheduler extends AdaptiveBatchScheduler
         // cancel all un-terminated executions because the execution vertex has finished
         FutureUtils.assertNoException(cancelPendingExecutions(execution.getVertex().getID()));
 
-        super.onTaskFinished(execution);
+        super.onTaskFinished(execution, ioMetrics);
     }
 
     private static boolean isOriginalAttempt(final Execution execution) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphDeploymentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphDeploymentTest.java
index 32e6302ec88..e629fce99a8 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphDeploymentTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphDeploymentTest.java
@@ -336,7 +336,7 @@ class DefaultExecutionGraphDeploymentTest {
 
         scheduler.updateTaskExecutionState(state);
 
-        assertThat(execution1.getIOMetrics()).isEqualTo(ioMetrics);
+        assertIOMetricsEqual(execution1.getIOMetrics(), ioMetrics);
         assertThat(execution1.getUserAccumulators()).isNotNull();
         assertThat(execution1.getUserAccumulators().get("acc").getLocalValue()).isEqualTo(4);
 
@@ -359,7 +359,7 @@ class DefaultExecutionGraphDeploymentTest {
 
         scheduler.updateTaskExecutionState(state2);
 
-        assertThat(execution2.getIOMetrics()).isEqualTo(ioMetrics2);
+        assertIOMetricsEqual(execution2.getIOMetrics(), ioMetrics2);
         assertThat(execution2.getUserAccumulators()).isNotNull();
         assertThat(execution2.getUserAccumulators().get("acc").getLocalValue()).isEqualTo(8);
     }
@@ -388,13 +388,13 @@ class DefaultExecutionGraphDeploymentTest {
         execution1.cancel();
         execution1.completeCancelling(accumulators, ioMetrics, false);
 
-        assertThat(execution1.getIOMetrics()).isEqualTo(ioMetrics);
+        assertIOMetricsEqual(execution1.getIOMetrics(), ioMetrics);
         assertThat(execution1.getUserAccumulators()).isEqualTo(accumulators);
 
         Execution execution2 = executions.values().iterator().next();
         execution2.markFailed(new Throwable(), false, accumulators, ioMetrics, false, true);
 
-        assertThat(execution2.getIOMetrics()).isEqualTo(ioMetrics);
+        assertIOMetricsEqual(execution2.getIOMetrics(), ioMetrics);
         assertThat(execution2.getUserAccumulators()).isEqualTo(accumulators);
     }
 
@@ -667,7 +667,7 @@ class DefaultExecutionGraphDeploymentTest {
                 .build(EXECUTOR_EXTENSION.getExecutor());
     }
 
-    private boolean isDeployedInTopologicalOrder(
+    private static boolean isDeployedInTopologicalOrder(
             List<ExecutionAttemptID> submissionOrder,
             List<Collection<ExecutionAttemptID>> executionStages) {
         final Iterator<ExecutionAttemptID> submissionIterator = submissionOrder.iterator();
@@ -688,4 +688,16 @@ class DefaultExecutionGraphDeploymentTest {
 
         return !submissionIterator.hasNext();
     }
+
+    private void assertIOMetricsEqual(IOMetrics ioMetrics1, IOMetrics ioMetrics2) {
+        assertThat(ioMetrics1.numBytesIn).isEqualTo(ioMetrics2.numBytesIn);
+        assertThat(ioMetrics1.numBytesOut).isEqualTo(ioMetrics2.numBytesOut);
+        assertThat(ioMetrics1.numRecordsIn).isEqualTo(ioMetrics2.numRecordsIn);
+        assertThat(ioMetrics1.numRecordsOut).isEqualTo(ioMetrics2.numRecordsOut);
+        assertThat(ioMetrics1.accumulateIdleTime).isEqualTo(ioMetrics2.accumulateIdleTime);
+        assertThat(ioMetrics1.accumulateBusyTime).isEqualTo(ioMetrics2.accumulateBusyTime);
+        assertThat(ioMetrics1.accumulateBackPressuredTime)
+                .isEqualTo(ioMetrics2.accumulateBackPressuredTime);
+        assertThat(ioMetrics1.resultPartitionBytes).isEqualTo(ioMetrics2.resultPartitionBytes);
+    }
 }
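
For reference, the ResultPartitionBytes value that assertIOMetricsEqual compares above is, as far as the test usage shows, a thin wrapper around a per-subpartition byte array. A minimal sketch of the shape implied by these tests (not the actual Flink class in org.apache.flink.runtime.executiongraph, which may differ in detail):

    // Sketch only: constructor and getter are inferred from the test code above.
    public class ResultPartitionBytes {

        // bytes written to each subpartition, indexed by subpartition number
        private final long[] subpartitionBytes;

        public ResultPartitionBytes(long[] subpartitionBytes) {
            this.subpartitionBytes = subpartitionBytes;
        }

        public long[] getSubpartitionBytes() {
            return subpartitionBytes;
        }
    }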
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
index 94095c755b8..f66b3bcf9b4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
@@ -697,30 +697,44 @@ class ResultPartitionTest {
     }
 
     @Test
-    void testNumBytesProducedCounterForUnicast() throws IOException {
-        testNumBytesProducedCounter(false);
+    void testResultPartitionBytesCounterForUnicast() throws IOException {
+        testResultPartitionBytesCounter(false);
     }
 
     @Test
-    void testNumBytesProducedCounterForBroadcast() throws IOException {
-        testNumBytesProducedCounter(true);
+    void testResultPartitionBytesCounterForBroadcast() throws IOException {
+        testResultPartitionBytesCounter(true);
     }
 
-    private void testNumBytesProducedCounter(boolean isBroadcast) throws IOException {
+    private void testResultPartitionBytesCounter(boolean isBroadcast) throws IOException {
         BufferWritingResultPartition bufferWritingResultPartition =
                 createResultPartition(ResultPartitionType.BLOCKING);
 
         if (isBroadcast) {
             bufferWritingResultPartition.broadcastRecord(ByteBuffer.allocate(bufferSize));
-            assertThat(bufferWritingResultPartition.numBytesProduced.getCount())
-                    .isEqualTo(bufferSize);
+
+            long[] subpartitionBytes =
+                    bufferWritingResultPartition
+                            .resultPartitionBytes
+                            .createSnapshot()
+                            .getSubpartitionBytes();
+            assertThat(subpartitionBytes).containsExactly(bufferSize, bufferSize);
+
             assertThat(bufferWritingResultPartition.numBytesOut.getCount())
                     .isEqualTo(2 * bufferSize);
         } else {
             bufferWritingResultPartition.emitRecord(ByteBuffer.allocate(bufferSize), 0);
-            assertThat(bufferWritingResultPartition.numBytesProduced.getCount())
-                    .isEqualTo(bufferSize);
-            assertThat(bufferWritingResultPartition.numBytesOut.getCount()).isEqualTo(bufferSize);
+            bufferWritingResultPartition.emitRecord(ByteBuffer.allocate(2 * bufferSize), 1);
+
+            long[] subpartitionBytes =
+                    bufferWritingResultPartition
+                            .resultPartitionBytes
+                            .createSnapshot()
+                            .getSubpartitionBytes();
+            assertThat(subpartitionBytes).containsExactly(bufferSize, (long) 2 * bufferSize);
+
+            assertThat(bufferWritingResultPartition.numBytesOut.getCount())
+                    .isEqualTo(3 * bufferSize);
         }
     }
 
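The unicast/broadcast expectations above follow a simple accounting rule: a broadcast record is charged to every subpartition, a unicast record only to its target. A self-contained sketch of that rule, with bufferSize as a stand-in for the test's buffer size (an illustration, not the ResultPartition implementation itself):

    // Sketch of the accounting verified above, assuming two subpartitions.
    public class SubpartitionAccountingSketch {
        public static void main(String[] args) {
            int bufferSize = 4096; // stand-in for the test's buffer size
            long[] subpartitionBytes = new long[2];
            long numBytesOut = 0;

            // broadcastRecord: every subpartition receives a copy of the record
            for (int i = 0; i < subpartitionBytes.length; i++) {
                subpartitionBytes[i] += bufferSize;
                numBytesOut += bufferSize;
            }
            // subpartitionBytes == {bufferSize, bufferSize}; numBytesOut == 2 * bufferSize

            // emitRecord to subpartition 0, then a record twice as large to subpartition 1
            subpartitionBytes = new long[2];
            numBytesOut = 0;
            subpartitionBytes[0] += bufferSize;
            subpartitionBytes[1] += 2L * bufferSize;
            numBytesOut += 3L * bufferSize;
            // subpartitionBytes == {bufferSize, 2 * bufferSize}; numBytesOut == 3 * bufferSize
            System.out.println(numBytesOut);
        }
    }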
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java
index 1403b60dff2..3139935ab90 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java
@@ -415,17 +415,17 @@ public class SortMergeResultPartitionTest {
 
     @TestTemplate
     void testNumBytesProducedCounterForUnicast() throws IOException {
-        testNumBytesProducedCounter(false);
+        testResultPartitionBytesCounter(false);
     }
 
     @TestTemplate
     void testNumBytesProducedCounterForBroadcast() throws IOException {
-        testNumBytesProducedCounter(true);
+        testResultPartitionBytesCounter(true);
     }
 
-    private void testNumBytesProducedCounter(boolean isBroadcast) throws IOException {
+    private void testResultPartitionBytesCounter(boolean isBroadcast) throws IOException {
         int numBuffers = useHashDataBuffer ? 100 : 15;
-        int numSubpartitions = 10;
+        int numSubpartitions = 2;
 
         BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
         SortMergeResultPartition partition =
@@ -435,16 +435,25 @@ public class SortMergeResultPartitionTest {
             partition.broadcastRecord(ByteBuffer.allocate(bufferSize));
             partition.finish();
 
-            assertThat(partition.numBytesProduced.getCount()).isEqualTo(bufferSize + 4);
+            long[] subpartitionBytes =
+                    partition.resultPartitionBytes.createSnapshot().getSubpartitionBytes();
+            assertThat(subpartitionBytes)
+                    .containsExactly((long) bufferSize + 4, (long) bufferSize + 4);
+
             assertThat(partition.numBytesOut.getCount())
                     .isEqualTo(numSubpartitions * (bufferSize + 4));
         } else {
             partition.emitRecord(ByteBuffer.allocate(bufferSize), 0);
+            partition.emitRecord(ByteBuffer.allocate(2 * bufferSize), 1);
             partition.finish();
 
-            assertThat(partition.numBytesProduced.getCount()).isEqualTo(bufferSize + 4);
+            long[] subpartitionBytes =
+                    partition.resultPartitionBytes.createSnapshot().getSubpartitionBytes();
+            assertThat(subpartitionBytes)
+                    .containsExactly((long) bufferSize + 4, (long) 2 * bufferSize + 4);
+
             assertThat(partition.numBytesOut.getCount())
-                    .isEqualTo(bufferSize + numSubpartitions * 4);
+                    .isEqualTo(3 * bufferSize + numSubpartitions * 4);
         }
     }
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartitionTest.java
index 71f777046e7..0ad1541f697 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartitionTest.java
@@ -24,6 +24,7 @@ import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.core.testutils.CheckedThread;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.executiongraph.IOMetrics;
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
 import org.apache.flink.runtime.io.disk.BatchShuffleReadBufferPool;
 import org.apache.flink.runtime.io.disk.FileChannelManager;
 import org.apache.flink.runtime.io.disk.FileChannelManagerImpl;
@@ -439,9 +440,11 @@ class HsResultPartitionTest {
             assertThat(taskIOMetricGroup.getNumBytesOutCounter().getCount())
                     .isEqualTo(3 * bufferSize);
             IOMetrics ioMetrics = taskIOMetricGroup.createSnapshot();
-            assertThat(ioMetrics.getNumBytesProducedOfPartitions())
-                    .hasSize(1)
-                    .containsValue((long) 2 * bufferSize);
+            assertThat(ioMetrics.getResultPartitionBytes()).hasSize(1);
+            ResultPartitionBytes partitionBytes =
+                    ioMetrics.getResultPartitionBytes().values().iterator().next();
+            assertThat(partitionBytes.getSubpartitionBytes())
+                    .containsExactly((long) 2 * bufferSize, (long) bufferSize);
         }
     }
 
@@ -498,9 +501,11 @@
             assertThat(taskIOMetricGroup.getNumBuffersOutCounter().getCount()).isEqualTo(1);
             assertThat(taskIOMetricGroup.getNumBytesOutCounter().getCount()).isEqualTo(bufferSize);
             IOMetrics ioMetrics = taskIOMetricGroup.createSnapshot();
-            assertThat(ioMetrics.getNumBytesProducedOfPartitions())
-                    .hasSize(1)
-                    .containsValue((long) bufferSize);
+            assertThat(ioMetrics.getResultPartitionBytes()).hasSize(1);
+            ResultPartitionBytes partitionBytes =
+                    ioMetrics.getResultPartitionBytes().values().iterator().next();
+            assertThat(partitionBytes.getSubpartitionBytes())
+                    .containsExactly((long) bufferSize, (long) bufferSize);
         }
     }
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroupTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroupTest.java
index 0faab8c1815..729c54f62df 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroupTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroupTest.java
@@ -21,6 +21,8 @@ package org.apache.flink.runtime.metrics.groups;
 import org.apache.flink.metrics.Counter;
 import org.apache.flink.metrics.SimpleCounter;
 import org.apache.flink.runtime.executiongraph.IOMetrics;
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
+import org.apache.flink.runtime.io.network.metrics.ResultPartitionBytesCounter;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 
 import org.junit.jupiter.api.Test;
@@ -96,26 +98,30 @@ class TaskIOMetricGroupTest {
     }
 
     @Test
-    void testNumBytesProducedOfPartitionsMetrics() {
+    void testResultPartitionBytesMetrics() {
         TaskMetricGroup task = UnregisteredMetricGroups.createUnregisteredTaskMetricGroup();
         TaskIOMetricGroup taskIO = task.getIOMetricGroup();
 
-        Counter c1 = new SimpleCounter();
-        c1.inc(32L);
-        Counter c2 = new SimpleCounter();
-        c2.inc(64L);
+        ResultPartitionBytesCounter c1 = new ResultPartitionBytesCounter(2);
+        ResultPartitionBytesCounter c2 = new ResultPartitionBytesCounter(2);
+
+        c1.inc(0, 32L);
+        c1.inc(1, 64L);
+        c2.incAll(128L);
 
         IntermediateResultPartitionID resultPartitionID1 = new IntermediateResultPartitionID();
         IntermediateResultPartitionID resultPartitionID2 = new IntermediateResultPartitionID();
 
-        taskIO.registerNumBytesProducedCounterForPartition(resultPartitionID1, c1);
-        taskIO.registerNumBytesProducedCounterForPartition(resultPartitionID2, c2);
+        taskIO.registerResultPartitionBytesCounter(resultPartitionID1, c1);
+        taskIO.registerResultPartitionBytesCounter(resultPartitionID2, c2);
 
-        Map<IntermediateResultPartitionID, Long> numBytesProducedOfPartitions =
-                taskIO.createSnapshot().getNumBytesProducedOfPartitions();
+        Map<IntermediateResultPartitionID, ResultPartitionBytes> resultPartitionBytes =
+                taskIO.createSnapshot().getResultPartitionBytes();
 
-        assertThat(numBytesProducedOfPartitions.size()).isEqualTo(2);
-        assertThat(numBytesProducedOfPartitions.get(resultPartitionID1).longValue()).isEqualTo(32L);
-        assertThat(numBytesProducedOfPartitions.get(resultPartitionID2).longValue()).isEqualTo(64L);
+        assertThat(resultPartitionBytes.size()).isEqualTo(2);
+        assertThat(resultPartitionBytes.get(resultPartitionID1).getSubpartitionBytes())
+                .containsExactly(32L, 64L);
+        assertThat(resultPartitionBytes.get(resultPartitionID2).getSubpartitionBytes())
+                .containsExactly(128L, 128L);
     }
 }
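
The counter exercised above behaves as the test pins down: inc(i, n) charges a single subpartition, incAll(n) charges every subpartition, and createSnapshot() freezes the current values. A minimal sketch of those semantics, inferred from the test usage (the real class is org.apache.flink.runtime.io.network.metrics.ResultPartitionBytesCounter; this is an illustration, not its source):

    // Sketch of ResultPartitionBytesCounter semantics inferred from the test above.
    public class ResultPartitionBytesCounterSketch {

        private final long[] subpartitionBytes;

        public ResultPartitionBytesCounterSketch(int numSubpartitions) {
            this.subpartitionBytes = new long[numSubpartitions];
        }

        // unicast write: charge a single subpartition
        public void inc(int subpartitionIndex, long value) {
            subpartitionBytes[subpartitionIndex] += value;
        }

        // broadcast write: charge every subpartition
        public void incAll(long value) {
            for (int i = 0; i < subpartitionBytes.length; i++) {
                subpartitionBytes[i] += value;
            }
        }

        // immutable snapshot reported through IOMetrics, e.g. {32, 64} or {128, 128}
        public long[] createSnapshot() {
            return subpartitionBytes.clone();
        }
    }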
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerTest.java
index 47c1ae6ff6b..59122521cb2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerTest.java
@@ -136,6 +136,7 @@ import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.cr
 import static org.apache.flink.runtime.jobmaster.slotpool.DefaultDeclarativeSlotPoolTest.createSlotOffersForResourceRequirements;
 import static org.apache.flink.runtime.jobmaster.slotpool.SlotPoolTestUtils.offerSlots;
 import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.acknowledgePendingCheckpoint;
+import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.createFailedTaskExecutionState;
 import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.enableCheckpointing;
 import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.getCheckpointCoordinator;
 import static org.apache.flink.util.ExceptionUtils.findThrowable;
@@ -1782,12 +1783,6 @@ public class DefaultSchedulerTest extends TestLogger {
         schedulerClosed.get();
     }
 
-    public static TaskExecutionState createFailedTaskExecutionState(
-            ExecutionAttemptID executionAttemptID) {
-        return new TaskExecutionState(
-                executionAttemptID, ExecutionState.FAILED, new Exception("Expected failure cause"));
-    }
-
     private static long initiateFailure(
             DefaultScheduler scheduler,
             ExecutionAttemptID executionAttemptId,
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java
index 0de11b3941e..c6497b9e96b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java
@@ -35,11 +35,14 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.executiongraph.IOMetrics;
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
 import org.apache.flink.runtime.executiongraph.failover.flip1.TestRestartBackoffTimeStrategy;
 import org.apache.flink.runtime.io.network.partition.JobMasterPartitionTracker;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobGraphBuilder;
 import org.apache.flink.runtime.jobgraph.JobVertex;
@@ -64,6 +67,7 @@ import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
@@ -434,4 +438,34 @@ public class SchedulerTestingUtils {
             vertex.deploy();
         }
     }
+
+    public static TaskExecutionState createFinishedTaskExecutionState(
+            ExecutionAttemptID attemptId,
+            Map<IntermediateResultPartitionID, ResultPartitionBytes> resultPartitionBytes) {
+        return new TaskExecutionState(
+                attemptId,
+                ExecutionState.FINISHED,
+                null,
+                null,
+                new IOMetrics(0, 0, 0, 0, 0, 0, 0, resultPartitionBytes));
+    }
+
+    public static TaskExecutionState createFinishedTaskExecutionState(
+            ExecutionAttemptID attemptId) {
+        return createFinishedTaskExecutionState(attemptId, Collections.emptyMap());
+    }
+
+    public static TaskExecutionState createFailedTaskExecutionState(
+            ExecutionAttemptID attemptId, Throwable failureCause) {
+        return new TaskExecutionState(attemptId, ExecutionState.FAILED, failureCause);
+    }
+
+    public static TaskExecutionState createFailedTaskExecutionState(ExecutionAttemptID attemptId) {
+        return createFailedTaskExecutionState(attemptId, new Exception("Expected failure cause"));
+    }
+
+    public static TaskExecutionState createCanceledTaskExecutionState(
+            ExecutionAttemptID attemptId) {
+        return new TaskExecutionState(attemptId, ExecutionState.CANCELED);
+    }
 }
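
Tests elsewhere in this commit call these helpers instead of constructing TaskExecutionState by hand; a typical call site might look like the hypothetical snippet below, where scheduler and attemptId are assumed to come from the surrounding test fixture:

    // Hypothetical usage of the factory methods added above.
    Map<IntermediateResultPartitionID, ResultPartitionBytes> resultPartitionBytes =
            Collections.singletonMap(
                    new IntermediateResultPartitionID(),
                    new ResultPartitionBytes(new long[] {32L, 64L}));

    // FINISHED, carrying the per-subpartition result sizes in IOMetrics
    scheduler.updateTaskExecutionState(
            createFinishedTaskExecutionState(attemptId, resultPartitionBytes));

    // FAILED, with an explicit failure cause
    scheduler.updateTaskExecutionState(
            createFailedTaskExecutionState(attemptId, new Exception("test failure")));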
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
index a6e87f9bfc2..2dc5389668a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
@@ -30,9 +30,14 @@ import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
-import org.apache.flink.runtime.executiongraph.IOMetrics;
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
+import org.apache.flink.runtime.executiongraph.failover.flip1.FixedDelayRestartBackoffTimeStrategy;
+import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.io.network.partition.TestingJobMasterPartitionTracker;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.scheduler.DefaultSchedulerBuilder;
@@ -41,16 +46,27 @@ import org.apache.flink.runtime.taskmanager.TaskExecutionState;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
 import org.apache.flink.testutils.TestingUtils;
 import org.apache.flink.testutils.executor.TestExecutorExtension;
+import org.apache.flink.util.concurrent.ManuallyTriggeredScheduledExecutor;
 
+import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.RegisterExtension;
 
+import javax.annotation.Nullable;
+
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.stream.Collectors;
+import java.util.stream.LongStream;
 
+import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.createFailedTaskExecutionState;
+import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.createFinishedTaskExecutionState;
+import static org.apache.flink.shaded.guava30.com.google.common.collect.Iterables.getOnlyElement;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /** Test for {@link AdaptiveBatchScheduler}. */
@@ -58,13 +74,20 @@ class AdaptiveBatchSchedulerTest {
 
     private static final int SOURCE_PARALLELISM_1 = 6;
     private static final int SOURCE_PARALLELISM_2 = 4;
+    private static final long SUBPARTITION_BYTES = 100L;
 
     @RegisterExtension
     static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
             TestingUtils.defaultExecutorExtension();
 
-    private static final ComponentMainThreadExecutor mainThreadExecutor =
-            ComponentMainThreadExecutorServiceAdapter.forMainThread();
+    private ComponentMainThreadExecutor mainThreadExecutor;
+    private ManuallyTriggeredScheduledExecutor taskRestartExecutor;
+
+    @BeforeEach
+    void setUp() {
+        mainThreadExecutor = ComponentMainThreadExecutorServiceAdapter.forMainThread();
+        taskRestartExecutor = new ManuallyTriggeredScheduledExecutor();
+    }
 
     @Test
     void testAdaptiveBatchScheduler() throws Exception {
@@ -122,20 +145,138 @@ class AdaptiveBatchSchedulerTest {
         assertThat(sink.getParallelism()).isEqualTo(SOURCE_PARALLELISM_1);
     }
 
+    @Test
+    void testUpdateBlockingResultInfoWhileScheduling() throws Exception {
+        JobGraph jobGraph = createJobGraph(false);
+        Iterator<JobVertex> jobVertexIterator = jobGraph.getVertices().iterator();
+        JobVertex source1 = jobVertexIterator.next();
+        JobVertex source2 = jobVertexIterator.next();
+        JobVertex sink = jobVertexIterator.next();
+
+        Configuration configuration = new Configuration();
+        configuration.set(
+                JobManagerOptions.SCHEDULER, JobManagerOptions.SchedulerType.AdaptiveBatch);
+
+        final TestingJobMasterPartitionTracker partitionTracker =
+                new TestingJobMasterPartitionTracker();
+        partitionTracker.setIsPartitionTrackedFunction(ignore -> true);
+        int maxParallelism = 6;
+
+        AdaptiveBatchScheduler scheduler =
+                new DefaultSchedulerBuilder(
+                                jobGraph, mainThreadExecutor, EXECUTOR_RESOURCE.getExecutor())
+                        .setDelayExecutor(taskRestartExecutor)
+                        .setJobMasterConfiguration(configuration)
+                        .setPartitionTracker(partitionTracker)
+                        .setRestartBackoffTimeStrategy(
+                                new FixedDelayRestartBackoffTimeStrategy
+                                                .FixedDelayRestartBackoffTimeStrategyFactory(10, 0)
+                                        .create())
+                        .setVertexParallelismDecider((ignored) -> maxParallelism)
+                        .setDefaultMaxParallelism(maxParallelism)
+                        .buildAdaptiveBatchJobScheduler();
+
+        final DefaultExecutionGraph graph = (DefaultExecutionGraph) scheduler.getExecutionGraph();
+        final ExecutionJobVertex source1ExecutionJobVertex = graph.getJobVertex(source1.getID());
+        final ExecutionJobVertex sinkExecutionJobVertex = graph.getJobVertex(sink.getID());
+
+        PointwiseBlockingResultInfo blockingResultInfo;
+
+        scheduler.startScheduling();
+        // trigger source1 finished.
+        transitionExecutionsState(scheduler, ExecutionState.FINISHED, source1);
+        blockingResultInfo =
+                (PointwiseBlockingResultInfo) getBlockingResultInfo(scheduler, source1);
+        assertThat(blockingResultInfo.getNumOfRecordedPartitions()).isEqualTo(SOURCE_PARALLELISM_1);
+
+        // trigger source2 finished.
+        transitionExecutionsState(scheduler, ExecutionState.FINISHED, source2);
+        blockingResultInfo =
+                (PointwiseBlockingResultInfo) getBlockingResultInfo(scheduler, source2);
+        assertThat(blockingResultInfo.getNumOfRecordedPartitions()).isEqualTo(SOURCE_PARALLELISM_2);
+
+        // trigger sink fail with partition not found
+        triggerFailedByPartitionNotFound(
+                scheduler,
+                source1ExecutionJobVertex.getTaskVertices()[0],
+                sinkExecutionJobVertex.getTaskVertices()[0]);
+
+        taskRestartExecutor.triggerScheduledTasks();
+
+        // check the partition info is reset
+        assertThat(
+                        ((PointwiseBlockingResultInfo) getBlockingResultInfo(scheduler, source1))
+                                .getNumOfRecordedPartitions())
+                .isEqualTo(SOURCE_PARALLELISM_1 - 1);
+    }
+
+    private BlockingResultInfo getBlockingResultInfo(
+            AdaptiveBatchScheduler scheduler, JobVertex jobVertex) {
+        return scheduler.getBlockingResultInfo(
+                getOnlyElement(jobVertex.getProducedDataSets()).getId());
+    }
+
+    private void triggerFailedByPartitionNotFound(
+            SchedulerBase scheduler,
+            ExecutionVertex producerVertex,
+            ExecutionVertex consumerVertex) {
+        final Execution execution = consumerVertex.getCurrentExecutionAttempt();
+        final IntermediateResultPartitionID partitionId =
+                getOnlyElement(producerVertex.getProducedPartitions().values()).getPartitionId();
+        // trigger execution vertex failed by partition not found.
+        transitionExecutionsState(
+                scheduler,
+                ExecutionState.FAILED,
+                Collections.singletonList(execution),
+                new PartitionNotFoundException(
+                        new ResultPartitionID(
+                                partitionId,
+                                producerVertex.getCurrentExecutionAttempt().getAttemptId())));
+    }
+
     /** Transit the state of all executions. */
     public static void transitionExecutionsState(
-            final SchedulerBase scheduler, final ExecutionState state, List<Execution> executions) {
+            final SchedulerBase scheduler,
+            final ExecutionState state,
+            List<Execution> executions,
+            @Nullable Throwable throwable) {
         for (Execution execution : executions) {
-            scheduler.updateTaskExecutionState(
-                    new TaskExecutionState(
-                            execution.getAttemptId(),
-                            state,
-                            null,
-                            null,
-                            new IOMetrics(0, 0, 0, 0, 0, 0, 0)));
+            TaskExecutionState taskExecutionState;
+            if (state == ExecutionState.FINISHED) {
+                taskExecutionState =
+                        createFinishedTaskExecutionState(
+                                execution.getAttemptId(),
+                                createResultPartitionBytesForExecution(execution));
+            } else if (state == ExecutionState.FAILED) {
+                taskExecutionState =
+                        createFailedTaskExecutionState(execution.getAttemptId(), throwable);
+            } else {
+                throw new UnsupportedOperationException("Unsupported state " + state);
+            }
+            scheduler.updateTaskExecutionState(taskExecutionState);
         }
     }
 
+    static Map<IntermediateResultPartitionID, ResultPartitionBytes>
+            createResultPartitionBytesForExecution(Execution execution) {
+        Map<IntermediateResultPartitionID, ResultPartitionBytes> partitionBytes = new HashMap<>();
+        execution
+                .getVertex()
+                .getProducedPartitions()
+                .forEach(
+                        (partitionId, partition) -> {
+                            int numOfSubpartitions = partition.getNumberOfSubpartitions();
+                            partitionBytes.put(
+                                    partitionId,
+                                    new ResultPartitionBytes(
+                                            LongStream.range(0, numOfSubpartitions)
+                                                    .boxed()
+                                                    .mapToLong(ignored -> SUBPARTITION_BYTES)
+                                                    .toArray()));
+                        });
+        return partitionBytes;
+    }
+
     /** Transit the state of all executions in the Job Vertex. */
     public static void transitionExecutionsState(
             final SchedulerBase scheduler, final ExecutionState state, final JobVertex jobVertex) {
@@ -145,7 +286,7 @@ class AdaptiveBatchSchedulerTest {
                         .stream()
                         .map(ExecutionVertex::getCurrentExecutionAttempt)
                         .collect(Collectors.toList());
-        transitionExecutionsState(scheduler, state, executions);
+        transitionExecutionsState(scheduler, state, executions, null);
     }
 
     public JobVertex createJobVertex(String jobVertexName, int parallelism) {
@@ -178,6 +319,7 @@ class AdaptiveBatchSchedulerTest {
 
         return new DefaultSchedulerBuilder(
                         jobGraph, mainThreadExecutor, EXECUTOR_RESOURCE.getExecutor())
+                .setDelayExecutor(taskRestartExecutor)
                 .setJobMasterConfiguration(configuration)
                 .setVertexParallelismDecider((ignored) -> 10)
                 .buildAdaptiveBatchJobScheduler();
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java
new file mode 100644
index 00000000000..f1c632b9ea0
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptivebatch;
+
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Test for {@link AllToAllBlockingResultInfo}. */
+class AllToAllBlockingResultInfoTest {
+
+    @Test
+    void testGetNumBytesProducedForNonBroadcast() {
+        testGetNumBytesProduced(false, 192L);
+    }
+
+    @Test
+    void testGetNumBytesProducedForBroadcast() {
+        testGetNumBytesProduced(true, 96L);
+    }
+
+    @Test
+    void testGetAggregatedSubpartitionBytes() {
+        AllToAllBlockingResultInfo resultInfo =
+                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false);
+        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));
+        resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {128L, 256L}));
+
+        assertThat(resultInfo.getAggregatedSubpartitionBytes()).containsExactly(160L, 320L);
+    }
+
+    @Test
+    void testGetBytesWithPartialPartitionInfos() {
+        AllToAllBlockingResultInfo resultInfo =
+                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false);
+        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));
+
+        assertThatThrownBy(resultInfo::getNumBytesProduced)
+                .isInstanceOf(IllegalStateException.class);
+        assertThatThrownBy(resultInfo::getAggregatedSubpartitionBytes)
+                .isInstanceOf(IllegalStateException.class);
+    }
+
+    @Test
+    void testRecordPartitionInfoMultiTimes() {
+        AllToAllBlockingResultInfo resultInfo =
+                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false);
+
+        ResultPartitionBytes partitionBytes1 = new ResultPartitionBytes(new long[] {32L, 64L});
+        ResultPartitionBytes partitionBytes2 = new ResultPartitionBytes(new long[] {64L, 128L});
+        ResultPartitionBytes partitionBytes3 = new ResultPartitionBytes(new long[] {128L, 256L});
+        ResultPartitionBytes partitionBytes4 = new ResultPartitionBytes(new long[] {256L, 512L});
+
+        // record partitionBytes1 for subtask 0 and then reset it
+        resultInfo.recordPartitionInfo(0, partitionBytes1);
+        assertThat(resultInfo.getNumOfRecordedPartitions()).isEqualTo(1);
+        resultInfo.resetPartitionInfo(0);
+        assertThat(resultInfo.getNumOfRecordedPartitions()).isEqualTo(0);
+
+        // record partitionBytes2 for subtask 0 and record partitionBytes3 for subtask 1
+        resultInfo.recordPartitionInfo(0, partitionBytes2);
+        resultInfo.recordPartitionInfo(1, partitionBytes3);
+
+        // The result info should be (partitionBytes2 + partitionBytes3)
+        assertThat(resultInfo.getNumBytesProduced()).isEqualTo(576L);
+        assertThat(resultInfo.getAggregatedSubpartitionBytes()).containsExactly(192L, 384L);
+        // The raw info should be clear
+        assertThat(resultInfo.getNumOfRecordedPartitions()).isEqualTo(0);
+
+        // reset subtask 0 and then record partitionBytes4 for subtask 0
+        resultInfo.resetPartitionInfo(0);
+        resultInfo.recordPartitionInfo(0, partitionBytes4);
+
+        // The result info should still be (partitionBytes2 + partitionBytes3)
+        assertThat(resultInfo.getNumBytesProduced()).isEqualTo(576L);
+        assertThat(resultInfo.getAggregatedSubpartitionBytes()).containsExactly(192L, 384L);
+        assertThat(resultInfo.getNumOfRecordedPartitions()).isEqualTo(0);
+    }
+
+    private void testGetNumBytesProduced(boolean isBroadcast, long expectedBytes) {
+        AllToAllBlockingResultInfo resultInfo =
+                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, isBroadcast);
+        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 32L}));
+        resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {64L, 64L}));
+
+        assertThat(resultInfo.getNumBytesProduced()).isEqualTo(expectedBytes);
+    }
+}
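
The expected values in these tests follow one aggregation rule: for an all-to-all result, subpartition i of every producer partition feeds the same consumer, so sizes are summed per subpartition index across partitions; for a broadcast result every partition carries identical data, so produced bytes count only one subpartition per partition (hence 192 vs 96 above). A self-contained sketch of the non-broadcast aggregation (an illustration, not the AllToAllBlockingResultInfo source):

    // Sketch: aggregate per-partition subpartition sizes for an all-to-all result.
    public class AllToAllAggregationSketch {
        static long[] aggregate(long[][] bytesPerPartition, int numSubpartitions) {
            long[] aggregated = new long[numSubpartitions];
            for (long[] partition : bytesPerPartition) {
                for (int i = 0; i < numSubpartitions; i++) {
                    aggregated[i] += partition[i];
                }
            }
            return aggregated;
        }

        public static void main(String[] args) {
            // {32, 64} and {128, 256} aggregate to {160, 320}, as asserted above
            long[] result = aggregate(new long[][] {{32L, 64L}, {128L, 256L}}, 2);
            System.out.println(java.util.Arrays.toString(result));
        }
    }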
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDeciderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDeciderTest.java
index 387e2bdc051..0f588bcc5e6 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDeciderTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDeciderTest.java
@@ -21,6 +21,8 @@ package org.apache.flink.runtime.scheduler.adaptivebatch;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.JobManagerOptions;
 import org.apache.flink.configuration.MemorySize;
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -78,11 +80,8 @@ class DefaultVertexParallelismDeciderTest {
 
     @Test
     void testNormalizeParallelismDownToPowerOf2() {
-        BlockingResultInfo resultInfo1 =
-                BlockingResultInfo.createFromBroadcastResult(Arrays.asList(BYTE_256_MB));
-        BlockingResultInfo resultInfo2 =
-                BlockingResultInfo.createFromNonBroadcastResult(
-                        Arrays.asList(BYTE_256_MB, BYTE_8_GB));
+        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_256_MB);
+        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_256_MB + BYTE_8_GB);
 
         int parallelism =
                 decider.decideParallelismForVertex(Arrays.asList(resultInfo1, resultInfo2));
@@ -92,11 +91,8 @@ class DefaultVertexParallelismDeciderTest {
 
     @Test
     void testNormalizeParallelismUpToPowerOf2() {
-        BlockingResultInfo resultInfo1 =
-                BlockingResultInfo.createFromBroadcastResult(Arrays.asList(BYTE_256_MB));
-        BlockingResultInfo resultInfo2 =
-                BlockingResultInfo.createFromNonBroadcastResult(
-                        Arrays.asList(BYTE_1_GB, BYTE_8_GB));
+        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_256_MB);
+        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_1_GB + BYTE_8_GB);
 
         int parallelism =
                 decider.decideParallelismForVertex(Arrays.asList(resultInfo1, resultInfo2));
@@ -106,11 +102,8 @@ class DefaultVertexParallelismDeciderTest {
 
     @Test
     void testInitiallyNormalizedParallelismIsLargerThanMaxParallelism() {
-        BlockingResultInfo resultInfo1 =
-                BlockingResultInfo.createFromBroadcastResult(Arrays.asList(BYTE_256_MB));
-        BlockingResultInfo resultInfo2 =
-                BlockingResultInfo.createFromNonBroadcastResult(
-                        Arrays.asList(BYTE_8_GB, BYTE_1_TB));
+        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_256_MB);
+        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_8_GB + BYTE_1_TB);
 
         int parallelism =
                 decider.decideParallelismForVertex(Arrays.asList(resultInfo1, resultInfo2));
@@ -120,10 +113,8 @@ class DefaultVertexParallelismDeciderTest {
 
     @Test
     void testInitiallyNormalizedParallelismIsSmallerThanMinParallelism() {
-        BlockingResultInfo resultInfo1 =
-                BlockingResultInfo.createFromBroadcastResult(Arrays.asList(BYTE_256_MB));
-        BlockingResultInfo resultInfo2 =
-                BlockingResultInfo.createFromNonBroadcastResult(Arrays.asList(BYTE_512_MB));
+        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_256_MB);
+        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_512_MB);
 
         int parallelism =
                 decider.decideParallelismForVertex(Arrays.asList(resultInfo1, resultInfo2));
@@ -133,10 +124,8 @@ class DefaultVertexParallelismDeciderTest {
 
     @Test
     void testBroadcastRatioExceedsCapRatio() {
-        BlockingResultInfo resultInfo1 =
-                BlockingResultInfo.createFromBroadcastResult(Arrays.asList(BYTE_1_GB));
-        BlockingResultInfo resultInfo2 =
-                BlockingResultInfo.createFromNonBroadcastResult(Arrays.asList(BYTE_8_GB));
+        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_1_GB);
+        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_8_GB);
 
         int parallelism =
                 decider.decideParallelismForVertex(Arrays.asList(resultInfo1, resultInfo2));
@@ -146,15 +135,58 @@ class DefaultVertexParallelismDeciderTest {
 
     @Test
     void testNonBroadcastBytesCanNotDividedEvenly() {
-        BlockingResultInfo resultInfo1 =
-                BlockingResultInfo.createFromBroadcastResult(Arrays.asList(BYTE_512_MB));
-        BlockingResultInfo resultInfo2 =
-                BlockingResultInfo.createFromNonBroadcastResult(
-                        Arrays.asList(BYTE_256_MB, BYTE_8_GB));
+        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_512_MB);
+        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_256_MB + BYTE_8_GB);
 
         int parallelism =
                 decider.decideParallelismForVertex(Arrays.asList(resultInfo1, resultInfo2));
 
         assertThat(parallelism).isEqualTo(16);
     }
+
+    private static class TestingBlockingResultInfo implements BlockingResultInfo {
+
+        private final boolean isBroadcast;
+
+        private final long producedBytes;
+
+        private TestingBlockingResultInfo(boolean isBroadcast, long producedBytes) {
+            this.isBroadcast = isBroadcast;
+            this.producedBytes = producedBytes;
+        }
+
+        @Override
+        public IntermediateDataSetID getResultId() {
+            return new IntermediateDataSetID();
+        }
+
+        @Override
+        public boolean isBroadcast() {
+            return isBroadcast;
+        }
+
+        @Override
+        public boolean isPointwise() {
+            return false;
+        }
+
+        @Override
+        public long getNumBytesProduced() {
+            return producedBytes;
+        }
+
+        @Override
+        public void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes) {}
+
+        @Override
+        public void resetPartitionInfo(int partitionIndex) {}
+    }
+
+    private static BlockingResultInfo createFromBroadcastResult(long producedBytes) {
+        return new TestingBlockingResultInfo(true, producedBytes);
+    }
+
+    private static BlockingResultInfo createFromNonBroadcastResult(long producedBytes) {
+        return new TestingBlockingResultInfo(false, producedBytes);
+    }
 }
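
The two "NormalizeParallelism...PowerOf2" tests above pin the decider to power-of-two parallelisms. One way such a normalization can be implemented is sketched below; the exact rounding rule (nearest power of two in linear distance, ties rounding up) is an assumption for illustration, not taken from the DefaultVertexParallelismDecider source:

    // Assumed sketch of rounding a computed parallelism to a power of two.
    static int normalizeToPowerOf2(int parallelism) {
        int down = Integer.highestOneBit(parallelism);     // round down
        int up = (down == parallelism) ? down : down << 1; // round up
        return (parallelism - down) < (up - parallelism) ? down : up;
    }
    // e.g. normalizeToPowerOf2(9) == 8, normalizeToPowerOf2(13) == 16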
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfoTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfoTest.java
new file mode 100644
index 00000000000..556a2d48876
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfoTest.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptivebatch;
+
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Test for {@link PointwiseBlockingResultInfo}. */
+class PointwiseBlockingResultInfoTest {
+
+    @Test
+    void testGetNumBytesProduced() {
+        PointwiseBlockingResultInfo resultInfo =
+                new PointwiseBlockingResultInfo(new IntermediateDataSetID(), 2, 2);
+        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 32L}));
+        resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {64L, 64L}));
+
+        assertThat(resultInfo.getNumBytesProduced()).isEqualTo(192L);
+    }
+
+    @Test
+    void testGetBytesWithPartialPartitionInfos() {
+        PointwiseBlockingResultInfo resultInfo =
+                new PointwiseBlockingResultInfo(new IntermediateDataSetID(), 2, 2);
+        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));
+        assertThatThrownBy(resultInfo::getNumBytesProduced)
+                .isInstanceOf(IllegalStateException.class);
+    }
+
+    @Test
+    void testPartitionFinishedMultiTimes() {
+        PointwiseBlockingResultInfo resultInfo =
+                new PointwiseBlockingResultInfo(new IntermediateDataSetID(), 2, 2);
+
+        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));
+        resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {64L, 128L}));
+        assertThat(resultInfo.getNumOfRecordedPartitions()).isEqualTo(2);
+        assertThat(resultInfo.getNumBytesProduced()).isEqualTo(288L);
+
+        // reset partition info
+        resultInfo.resetPartitionInfo(0);
+        assertThat(resultInfo.getNumOfRecordedPartitions()).isEqualTo(1);
+
+        // record partition info again
+        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {64L, 128L}));
+        assertThat(resultInfo.getNumBytesProduced()).isEqualTo(384L);
+    }
+}
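
For the pointwise case there is no cross-partition fan-out, so produced bytes are just the sum over all recorded subpartitions, and the sum is only valid once every partition has reported. A sketch of that totalling rule (illustration only, not the PointwiseBlockingResultInfo source):

    // Sketch of the pointwise totalling rule the test above verifies.
    static long numBytesProduced(long[][] recordedPartitions, int expectedPartitions) {
        int recorded = 0;
        long total = 0;
        for (long[] partition : recordedPartitions) {
            if (partition == null) {
                continue; // this partition has not reported yet
            }
            recorded++;
            for (long bytes : partition) {
                total += bytes;
            }
        }
        if (recorded != expectedPartitions) {
            throw new IllegalStateException("Not all partition infos are ready");
        }
        return total;
    }
    // {32, 64} + {64, 128} -> 288, matching testPartitionFinishedMultiTimes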
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/SpeculativeSchedulerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/SpeculativeSchedulerTest.java
index b438240a599..189c75c321c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/SpeculativeSchedulerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/SpeculativeSchedulerTest.java
@@ -33,7 +33,6 @@ import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
-import org.apache.flink.runtime.executiongraph.IOMetrics;
 import org.apache.flink.runtime.executiongraph.failover.flip1.TestRestartBackoffTimeStrategy;
 import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
@@ -75,8 +74,11 @@ import java.util.stream.Collectors;
 
 import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.completeCancellingForAllVertices;
 import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.createNoOpVertex;
-import static org.apache.flink.runtime.scheduler.DefaultSchedulerTest.createFailedTaskExecutionState;
 import static org.apache.flink.runtime.scheduler.DefaultSchedulerTest.singleNonParallelJobVertexJobGraph;
+import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.createCanceledTaskExecutionState;
+import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.createFailedTaskExecutionState;
+import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.createFinishedTaskExecutionState;
+import static org.apache.flink.runtime.scheduler.adaptivebatch.AdaptiveBatchSchedulerTest.createResultPartitionBytesForExecution;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /** Tests for {@link SpeculativeScheduler}. */
@@ -212,9 +214,8 @@ class SpeculativeSchedulerTest {
         notifySlowTask(scheduler, attempt1);
         final Execution attempt2 = getExecution(ev, 1);
         scheduler.updateTaskExecutionState(
-                new TaskExecutionState(
+                createFailedTaskExecutionState(
                         attempt1.getAttemptId(),
-                        ExecutionState.FAILED,
                         new PartitionNotFoundException(new ResultPartitionID())));
 
         assertThat(attempt2.getState()).isEqualTo(ExecutionState.CANCELING);
@@ -234,7 +235,7 @@ class SpeculativeSchedulerTest {
         notifySlowTask(scheduler, attempt1);
         final Execution attempt2 = getExecution(ev, 1);
         scheduler.updateTaskExecutionState(
-                new TaskExecutionState(attempt1.getAttemptId(), ExecutionState.FINISHED));
+                createFinishedTaskExecutionState(attempt1.getAttemptId()));
 
         assertThat(attempt2.getState()).isEqualTo(ExecutionState.CANCELING);
     }
@@ -251,7 +252,7 @@ class SpeculativeSchedulerTest {
         notifySlowTask(scheduler, attempt1);
         final Execution attempt2 = getExecution(ev, 1);
         scheduler.updateTaskExecutionState(
-                new TaskExecutionState(attempt1.getAttemptId(), ExecutionState.FINISHED));
+                createFinishedTaskExecutionState(attempt1.getAttemptId()));
 
         assertThat(attempt2.getState()).isEqualTo(ExecutionState.CANCELED);
     }
@@ -266,9 +267,8 @@ class SpeculativeSchedulerTest {
 
         // A partition exception can result in a restart of the whole execution vertex.
         scheduler.updateTaskExecutionState(
-                new TaskExecutionState(
+                createFailedTaskExecutionState(
                         attempt1.getAttemptId(),
-                        ExecutionState.FAILED,
                         new PartitionNotFoundException(new ResultPartitionID())));
 
         completeCancellingForAllVertices(scheduler.getExecutionGraph());
@@ -316,9 +316,8 @@ class SpeculativeSchedulerTest {
         notifySlowTask(scheduler, attempt1);
 
         final TaskExecutionState failedState =
-                new TaskExecutionState(
+                createFailedTaskExecutionState(
                         attempt1.getAttemptId(),
-                        ExecutionState.FAILED,
                         new SuppressRestartsException(
                                 new Exception("Forced failure for testing.")));
         scheduler.updateTaskExecutionState(failedState);
@@ -376,12 +375,9 @@ class SpeculativeSchedulerTest {
         // Finishing any source execution attempt will finish the source execution vertex, and then
         // finish the job vertex.
         scheduler.updateTaskExecutionState(
-                new TaskExecutionState(
+                createFinishedTaskExecutionState(
                         sourceAttempt1.getAttemptId(),
-                        ExecutionState.FINISHED,
-                        null,
-                        null,
-                        new IOMetrics(0, 0, 0, 0, 0, 0, 0)));
+                        createResultPartitionBytesForExecution(sourceAttempt1)));
         assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(3);
 
         // trigger sink vertex speculation
@@ -421,12 +417,12 @@ class SpeculativeSchedulerTest {
         // finishes first
         final Execution attempt2 = getExecution(ev, 1);
         scheduler.updateTaskExecutionState(
-                new TaskExecutionState(attempt2.getAttemptId(), ExecutionState.FINISHED));
+                createFinishedTaskExecutionState(attempt2.getAttemptId()));
         assertThat(scheduler.getNumEffectiveSpeculativeExecutions()).isEqualTo(1);
 
         // complete cancellation
         scheduler.updateTaskExecutionState(
-                new TaskExecutionState(attempt1.getAttemptId(), ExecutionState.CANCELED));
+                createCanceledTaskExecutionState(attempt1.getAttemptId()));
 
         // trigger a global failure to reset the vertex.
         // after that, no speculative execution finishes before its original execution and the
@@ -441,7 +437,7 @@ class SpeculativeSchedulerTest {
         // numEffectiveSpeculativeExecutions will not increase if an original execution attempt
         // finishes first
         scheduler.updateTaskExecutionState(
-                new TaskExecutionState(attempt3.getAttemptId(), ExecutionState.FINISHED));
+                createFinishedTaskExecutionState(attempt3.getAttemptId()));
         assertThat(scheduler.getNumEffectiveSpeculativeExecutions()).isZero();
     }
 
