This is an automated email from the ASF dual-hosted git repository.

wanglijie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 963333795ae5bf0b6d6c8b1fa793be742a7c1af3 Author: Lijie Wang <[email protected]> AuthorDate: Mon Dec 19 11:03:29 2022 +0800 [FLINK-29664][runtime] Support to collect the size of subpartitions. This closes #21111. --- .../flink/runtime/executiongraph/Execution.java | 12 +- .../flink/runtime/executiongraph/IOMetrics.java | 59 ++++++-- .../executiongraph/ResultPartitionBytes.java | 37 +++++ .../metrics/ResultPartitionBytesCounter.java | 53 +++++++ .../partition/BufferWritingResultPartition.java | 4 +- .../io/network/partition/ResultPartition.java | 14 +- .../partition/SortMergeResultPartition.java | 14 +- .../partition/hybrid/HsResultPartition.java | 4 +- .../runtime/metrics/groups/TaskIOMetricGroup.java | 16 +- .../flink/runtime/scheduler/DefaultScheduler.java | 3 +- .../flink/runtime/scheduler/SchedulerBase.java | 5 +- .../adaptivebatch/AbstractBlockingResultInfo.java | 75 ++++++++++ .../adaptivebatch/AdaptiveBatchScheduler.java | 98 +++++++++++- .../adaptivebatch/AllToAllBlockingResultInfo.java | 112 ++++++++++++++ .../adaptivebatch/BlockingResultInfo.java | 113 +++++++------- .../DefaultVertexParallelismDecider.java | 10 +- .../adaptivebatch/PointwiseBlockingResultInfo.java | 53 +++++++ .../adaptivebatch/SpeculativeScheduler.java | 5 +- .../DefaultExecutionGraphDeploymentTest.java | 22 ++- .../io/network/partition/ResultPartitionTest.java | 34 +++-- .../partition/SortMergeResultPartitionTest.java | 23 ++- .../partition/hybrid/HsResultPartitionTest.java | 17 ++- .../metrics/groups/TaskIOMetricGroupTest.java | 30 ++-- .../runtime/scheduler/DefaultSchedulerTest.java | 7 +- .../runtime/scheduler/SchedulerTestingUtils.java | 34 +++++ .../adaptivebatch/AdaptiveBatchSchedulerTest.java | 166 +++++++++++++++++++-- .../AllToAllBlockingResultInfoTest.java | 108 ++++++++++++++ .../DefaultVertexParallelismDeciderTest.java | 88 +++++++---- .../PointwiseBlockingResultInfoTest.java | 69 +++++++++ .../adaptivebatch/SpeculativeSchedulerTest.java | 32 ++-- 30 files changed, 1094 insertions(+), 223 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java index 3323f453373..b957c6b467e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java @@ -1548,7 +1548,17 @@ public class Execution } } if (metrics != null) { - this.ioMetrics = metrics; + // Drop IOMetrics#resultPartitionBytes because it will not be used anymore. It can + // result in very high memory usage when there are many executions and sub-partitions. 
+ this.ioMetrics = + new IOMetrics( + metrics.getNumBytesIn(), + metrics.getNumBytesOut(), + metrics.getNumRecordsIn(), + metrics.getNumRecordsOut(), + metrics.getAccumulateIdleTime(), + metrics.getAccumulateBusyTime(), + metrics.getAccumulateBackPressuredTime()); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IOMetrics.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IOMetrics.java index e612837531a..5dd24dab563 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IOMetrics.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IOMetrics.java @@ -18,14 +18,20 @@ package org.apache.flink.runtime.executiongraph; -import org.apache.flink.metrics.Counter; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.metrics.Gauge; import org.apache.flink.metrics.Meter; +import org.apache.flink.runtime.io.network.metrics.ResultPartitionBytesCounter; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import javax.annotation.Nullable; + import java.io.Serializable; -import java.util.HashMap; +import java.util.Collections; import java.util.Map; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkNotNull; /** An instance of this class represents a snapshot of the io-related metrics of a single task. */ public class IOMetrics implements Serializable { @@ -42,18 +48,19 @@ public class IOMetrics implements Serializable { protected double accumulateBusyTime; protected long accumulateIdleTime; - protected final Map<IntermediateResultPartitionID, Long> numBytesProducedOfPartitions = - new HashMap<>(); + @Nullable + protected Map<IntermediateResultPartitionID, ResultPartitionBytes> resultPartitionBytes; public IOMetrics( Meter recordsIn, Meter recordsOut, Meter bytesIn, Meter bytesOut, - Map<IntermediateResultPartitionID, Counter> numBytesProducedCounters, Gauge<Long> accumulatedBackPressuredTime, Gauge<Long> accumulatedIdleTime, - Gauge<Double> accumulatedBusyTime) { + Gauge<Double> accumulatedBusyTime, + Map<IntermediateResultPartitionID, ResultPartitionBytesCounter> + resultPartitionBytesCounters) { this.numRecordsIn = recordsIn.getCount(); this.numRecordsOut = recordsOut.getCount(); this.numBytesIn = bytesIn.getCount(); @@ -61,11 +68,12 @@ public class IOMetrics implements Serializable { this.accumulateBackPressuredTime = accumulatedBackPressuredTime.getValue(); this.accumulateBusyTime = accumulatedBusyTime.getValue(); this.accumulateIdleTime = accumulatedIdleTime.getValue(); - - for (Map.Entry<IntermediateResultPartitionID, Counter> counter : - numBytesProducedCounters.entrySet()) { - numBytesProducedOfPartitions.put(counter.getKey(), counter.getValue().getCount()); - } + this.resultPartitionBytes = + resultPartitionBytesCounters.entrySet().stream() + .collect( + Collectors.toMap( + Map.Entry::getKey, + entry -> entry.getValue().createSnapshot())); } public IOMetrics( @@ -74,8 +82,30 @@ public class IOMetrics implements Serializable { long numRecordsIn, long numRecordsOut, long accumulateIdleTime, - long accumulateBusyTime, + double accumulateBusyTime, long accumulateBackPressuredTime) { + this( + numBytesIn, + numBytesOut, + numRecordsIn, + numRecordsOut, + accumulateIdleTime, + accumulateBusyTime, + accumulateBackPressuredTime, + null); + } + + @VisibleForTesting + public IOMetrics( + long numBytesIn, + long numBytesOut, + long numRecordsIn, + long numRecordsOut, + long accumulateIdleTime, + double 
accumulateBusyTime, + long accumulateBackPressuredTime, + @Nullable + Map<IntermediateResultPartitionID, ResultPartitionBytes> resultPartitionBytes) { this.numBytesIn = numBytesIn; this.numBytesOut = numBytesOut; this.numRecordsIn = numRecordsIn; @@ -83,6 +113,7 @@ public class IOMetrics implements Serializable { this.accumulateIdleTime = accumulateIdleTime; this.accumulateBusyTime = accumulateBusyTime; this.accumulateBackPressuredTime = accumulateBackPressuredTime; + this.resultPartitionBytes = resultPartitionBytes; } public long getNumRecordsIn() { @@ -113,7 +144,7 @@ public class IOMetrics implements Serializable { return accumulateIdleTime; } - public Map<IntermediateResultPartitionID, Long> getNumBytesProducedOfPartitions() { - return numBytesProducedOfPartitions; + public Map<IntermediateResultPartitionID, ResultPartitionBytes> getResultPartitionBytes() { + return Collections.unmodifiableMap(checkNotNull(resultPartitionBytes)); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ResultPartitionBytes.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ResultPartitionBytes.java new file mode 100644 index 00000000000..630a828c648 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ResultPartitionBytes.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.executiongraph; + +import java.io.Serializable; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** This class represents a snapshot of the result partition bytes metrics. */ +public class ResultPartitionBytes implements Serializable { + + private final long[] subpartitionBytes; + + public ResultPartitionBytes(long[] subpartitionBytes) { + this.subpartitionBytes = checkNotNull(subpartitionBytes); + } + + public long[] getSubpartitionBytes() { + return subpartitionBytes; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/metrics/ResultPartitionBytesCounter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/metrics/ResultPartitionBytesCounter.java new file mode 100644 index 00000000000..21df06b6f68 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/metrics/ResultPartitionBytesCounter.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.metrics; + +import org.apache.flink.metrics.Counter; +import org.apache.flink.metrics.SimpleCounter; +import org.apache.flink.runtime.executiongraph.ResultPartitionBytes; + +import java.util.ArrayList; +import java.util.List; + +/** This counter will count the data size of a partition. */ +public class ResultPartitionBytesCounter { + + /** The data size of each subpartition. */ + private final List<Counter> subpartitionBytes; + + public ResultPartitionBytesCounter(int numSubpartitions) { + this.subpartitionBytes = new ArrayList<>(); + for (int i = 0; i < numSubpartitions; ++i) { + subpartitionBytes.add(new SimpleCounter()); + } + } + + public void inc(int targetSubpartition, long bytes) { + subpartitionBytes.get(targetSubpartition).inc(bytes); + } + + public void incAll(long bytes) { + subpartitionBytes.forEach(counter -> counter.inc(bytes)); + } + + public ResultPartitionBytes createSnapshot() { + return new ResultPartitionBytes( + subpartitionBytes.stream().mapToLong(Counter::getCount).toArray()); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java index 8b36a7587d1..2107a24c747 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java @@ -423,7 +423,7 @@ public abstract class BufferWritingResultPartition extends ResultPartition { final BufferBuilder bufferBuilder = unicastBufferBuilders[targetSubpartition]; if (bufferBuilder != null) { int bytes = bufferBuilder.finish(); - numBytesProduced.inc(bytes); + resultPartitionBytes.inc(targetSubpartition, bytes); numBytesOut.inc(bytes); numBuffersOut.inc(); unicastBufferBuilders[targetSubpartition] = null; @@ -440,7 +440,7 @@ public abstract class BufferWritingResultPartition extends ResultPartition { private void finishBroadcastBufferBuilder() { if (broadcastBufferBuilder != null) { int bytes = broadcastBufferBuilder.finish(); - numBytesProduced.inc(bytes); + resultPartitionBytes.incAll(bytes); numBytesOut.inc(bytes * numSubpartitions); numBuffersOut.inc(numSubpartitions); broadcastBufferBuilder.close(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java index e3b2e2ff74e..8f3631de15d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java @@ -28,6 +28,7 @@ import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferCompressor; import org.apache.flink.runtime.io.network.buffer.BufferPool; 
+import org.apache.flink.runtime.io.network.metrics.ResultPartitionBytesCounter; import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel; import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; import org.apache.flink.runtime.jobgraph.DistributionPattern; @@ -112,13 +113,7 @@ public abstract class ResultPartition implements ResultPartitionWriter { protected Counter numBuffersOut = new SimpleCounter(); - /** - * The difference with {@link #numBytesOut} : numBytesProduced represents the number of bytes - * actually produced, and numBytesOut represents the number of bytes sent to downstream tasks. - * In unicast scenarios, these two values should be equal. In broadcast scenarios, numBytesOut - * should be (N * numBytesProduced), where N refers to the number of subpartitions. - */ - protected Counter numBytesProduced = new SimpleCounter(); + protected ResultPartitionBytesCounter resultPartitionBytes; public ResultPartition( String owningTaskName, @@ -141,6 +136,7 @@ public abstract class ResultPartition implements ResultPartitionWriter { this.partitionManager = checkNotNull(partitionManager); this.bufferCompressor = bufferCompressor; this.bufferPoolFactory = bufferPoolFactory; + this.resultPartitionBytes = new ResultPartitionBytesCounter(numSubpartitions); } /** @@ -301,8 +297,8 @@ public abstract class ResultPartition implements ResultPartitionWriter { public void setMetricGroup(TaskIOMetricGroup metrics) { numBytesOut = metrics.getNumBytesOutCounter(); numBuffersOut = metrics.getNumBuffersOutCounter(); - metrics.registerNumBytesProducedCounterForPartition( - partitionId.getPartitionId(), numBytesProduced); + metrics.registerResultPartitionBytesCounter( + partitionId.getPartitionId(), resultPartitionBytes); } /** diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java index 2b603a219f5..def80f40574 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java @@ -391,7 +391,7 @@ public class SortMergeResultPartition extends ResultPartition { break; } - updateStatistics(bufferWithChannel.getBuffer(), isBroadcast); + updateStatistics(bufferWithChannel, isBroadcast); toWrite.add(compressBufferIfPossible(bufferWithChannel)); } while (true); @@ -424,10 +424,14 @@ public class SortMergeResultPartition extends ResultPartition { return new BufferWithChannel(buffer, bufferWithChannel.getChannelIndex()); } - private void updateStatistics(Buffer buffer, boolean isBroadcast) { + private void updateStatistics(BufferWithChannel bufferWithChannel, boolean isBroadcast) { numBuffersOut.inc(isBroadcast ? numSubpartitions : 1); - long readableBytes = buffer.readableBytes(); - numBytesProduced.inc(readableBytes); + long readableBytes = bufferWithChannel.getBuffer().readableBytes(); + if (isBroadcast) { + resultPartitionBytes.incAll(readableBytes); + } else { + resultPartitionBytes.inc(bufferWithChannel.getChannelIndex(), readableBytes); + } numBytesOut.inc(isBroadcast ? 
readableBytes * numSubpartitions : readableBytes); } @@ -456,7 +460,7 @@ public class SortMergeResultPartition extends ResultPartition { NetworkBuffer buffer = new NetworkBuffer(writeBuffer, (buf) -> {}, dataType, toCopy); BufferWithChannel bufferWithChannel = new BufferWithChannel(buffer, targetSubpartition); - updateStatistics(buffer, isBroadcast); + updateStatistics(bufferWithChannel, isBroadcast); toWrite.add(compressBufferIfPossible(bufferWithChannel)); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartition.java index d4f2f632aff..eb915335ab1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartition.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartition.java @@ -152,7 +152,7 @@ public class HsResultPartition extends ResultPartition { @Override public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException { - numBytesProduced.inc(record.remaining()); + resultPartitionBytes.inc(targetSubpartition, record.remaining()); emit(record, targetSubpartition, Buffer.DataType.DATA_BUFFER); } @@ -173,7 +173,7 @@ public class HsResultPartition extends ResultPartition { } private void broadcast(ByteBuffer record, Buffer.DataType dataType) throws IOException { - numBytesProduced.inc(record.remaining()); + resultPartitionBytes.incAll(record.remaining()); if (isBroadcastOnly) { emit(record, BROADCAST_CHANNEL, dataType); } else { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroup.java b/flink-runtime/src/main/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroup.java index 17dca7b97b7..6903f666bda 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroup.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroup.java @@ -25,6 +25,7 @@ import org.apache.flink.metrics.Meter; import org.apache.flink.metrics.MeterView; import org.apache.flink.metrics.SimpleCounter; import org.apache.flink.runtime.executiongraph.IOMetrics; +import org.apache.flink.runtime.io.network.metrics.ResultPartitionBytesCounter; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.metrics.DescriptiveStatisticsHistogram; import org.apache.flink.runtime.metrics.MetricNames; @@ -71,8 +72,8 @@ public class TaskIOMetricGroup extends ProxyMetricGroup<TaskMetricGroup> { private long taskStartTime; - private final Map<IntermediateResultPartitionID, Counter> numBytesProducedOfPartitions = - new HashMap<>(); + private final Map<IntermediateResultPartitionID, ResultPartitionBytesCounter> + resultPartitionBytes = new HashMap<>(); public TaskIOMetricGroup(TaskMetricGroup parent) { super(parent); @@ -135,10 +136,10 @@ public class TaskIOMetricGroup extends ProxyMetricGroup<TaskMetricGroup> { numRecordsOutRate, numBytesInRate, numBytesOutRate, - numBytesProducedOfPartitions, accumulatedBackPressuredTime, accumulatedIdleTime, - accumulatedBusyTime); + accumulatedBusyTime, + resultPartitionBytes); } // ============================================================================================ @@ -238,9 +239,10 @@ public class TaskIOMetricGroup extends ProxyMetricGroup<TaskMetricGroup> { this.numRecordsOut.addCounter(numRecordsOutCounter); } - public void 
registerNumBytesProducedCounterForPartition( - IntermediateResultPartitionID resultPartitionId, Counter numBytesProducedCounter) { - this.numBytesProducedOfPartitions.put(resultPartitionId, numBytesProducedCounter); + public void registerResultPartitionBytesCounter( + IntermediateResultPartitionID resultPartitionId, + ResultPartitionBytesCounter resultPartitionBytesCounter) { + this.resultPartitionBytes.put(resultPartitionId, resultPartitionBytesCounter); } public void registerMailboxSizeSupplier(SizeGauge.SizeSupplier<Integer> supplier) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java index da59e2275fd..2136479a2c7 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java @@ -30,6 +30,7 @@ import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor; import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.executiongraph.Execution; import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.executiongraph.IOMetrics; import org.apache.flink.runtime.executiongraph.JobStatusListener; import org.apache.flink.runtime.executiongraph.failover.flip1.ExecutionFailureHandler; import org.apache.flink.runtime.executiongraph.failover.flip1.FailoverStrategy; @@ -215,7 +216,7 @@ public class DefaultScheduler extends SchedulerBase implements SchedulerOperatio } @Override - protected void onTaskFinished(final Execution execution) { + protected void onTaskFinished(final Execution execution, final IOMetrics ioMetrics) { checkState(execution.getState() == ExecutionState.FINISHED); final ExecutionVertexID executionVertexId = execution.getVertex().getID(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SchedulerBase.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SchedulerBase.java index 02c19f92391..846d0948e61 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SchedulerBase.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SchedulerBase.java @@ -53,6 +53,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.ExecutionGraph; import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.executiongraph.ExecutionVertex; +import org.apache.flink.runtime.executiongraph.IOMetrics; import org.apache.flink.runtime.executiongraph.JobStatusListener; import org.apache.flink.runtime.executiongraph.JobStatusProvider; import org.apache.flink.runtime.executiongraph.TaskExecutionStateTransition; @@ -733,7 +734,7 @@ public abstract class SchedulerBase implements SchedulerNG, CheckpointScheduling // can be refined in FLINK-14233 after the actions are factored out from ExecutionGraph. 
switch (taskExecutionState.getExecutionState()) { case FINISHED: - onTaskFinished(execution); + onTaskFinished(execution, taskExecutionState.getIOMetrics()); break; case FAILED: onTaskFailed(execution); @@ -741,7 +742,7 @@ public abstract class SchedulerBase implements SchedulerNG, CheckpointScheduling } } - protected abstract void onTaskFinished(final Execution execution); + protected abstract void onTaskFinished(final Execution execution, final IOMetrics ioMetrics); protected abstract void onTaskFailed(final Execution execution); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java new file mode 100644 index 00000000000..33147bcdc16 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AbstractBlockingResultInfo.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptivebatch; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.runtime.executiongraph.ResultPartitionBytes; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; + +import java.util.HashMap; +import java.util.Map; + +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** Base blocking result info. */ +abstract class AbstractBlockingResultInfo implements BlockingResultInfo { + + private final IntermediateDataSetID resultId; + + protected final int numOfPartitions; + + protected final int numOfSubpartitions; + + /** + * The subpartition bytes map. The key is the partition index, value is a subpartition bytes + * list. 
+ */ + protected final Map<Integer, long[]> subpartitionBytesByPartitionIndex; + + AbstractBlockingResultInfo( + IntermediateDataSetID resultId, int numOfPartitions, int numOfSubpartitions) { + this.resultId = checkNotNull(resultId); + this.numOfPartitions = numOfPartitions; + this.numOfSubpartitions = numOfSubpartitions; + this.subpartitionBytesByPartitionIndex = new HashMap<>(); + } + + @Override + public IntermediateDataSetID getResultId() { + return resultId; + } + + @Override + public void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes) { + checkState(partitionBytes.getSubpartitionBytes().length == numOfSubpartitions); + subpartitionBytesByPartitionIndex.put( + partitionIndex, partitionBytes.getSubpartitionBytes()); + } + + @Override + public void resetPartitionInfo(int partitionIndex) { + subpartitionBytesByPartitionIndex.remove(partitionIndex); + } + + @VisibleForTesting + int getNumOfRecordedPartitions() { + return subpartitionBytesByPartitionIndex.size(); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java index 6a80d7a9569..da4beee4c43 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java @@ -26,12 +26,19 @@ import org.apache.flink.runtime.JobException; import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory; import org.apache.flink.runtime.checkpoint.CheckpointsCleaner; import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor; +import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.executiongraph.Execution; import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.executiongraph.ExecutionVertex; +import org.apache.flink.runtime.executiongraph.IOMetrics; import org.apache.flink.runtime.executiongraph.IntermediateResult; import org.apache.flink.runtime.executiongraph.JobStatusListener; +import org.apache.flink.runtime.executiongraph.ResultPartitionBytes; import org.apache.flink.runtime.executiongraph.failover.flip1.FailoverStrategy; import org.apache.flink.runtime.executiongraph.failover.flip1.RestartBackoffTimeStrategy; +import org.apache.flink.runtime.jobgraph.DistributionPattern; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.jobgraph.JobEdge; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobVertex; @@ -50,6 +57,7 @@ import org.apache.flink.runtime.scheduler.ExecutionVertexVersioner; import org.apache.flink.runtime.scheduler.VertexParallelismStore; import org.apache.flink.runtime.scheduler.adaptivebatch.forwardgroup.ForwardGroup; import org.apache.flink.runtime.scheduler.adaptivebatch.forwardgroup.ForwardGroupComputeUtil; +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.runtime.scheduler.strategy.SchedulingStrategyFactory; import org.apache.flink.runtime.shuffle.ShuffleMaster; import org.apache.flink.util.concurrent.ScheduledExecutor; @@ -57,6 +65,7 @@ import org.apache.flink.util.concurrent.ScheduledExecutor; import org.slf4j.Logger; import java.util.ArrayList; +import java.util.HashMap; 
import java.util.List; import java.util.Map; import java.util.Optional; @@ -64,6 +73,7 @@ import java.util.concurrent.Executor; import java.util.function.Consumer; import java.util.function.Function; +import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkNotNull; /** @@ -78,6 +88,8 @@ public class AdaptiveBatchScheduler extends DefaultScheduler { private final Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId; + private final Map<IntermediateDataSetID, BlockingResultInfo> blockingResultInfos; + public AdaptiveBatchScheduler( final Logger log, final JobGraph jobGraph, @@ -140,6 +152,8 @@ public class AdaptiveBatchScheduler extends DefaultScheduler { ForwardGroupComputeUtil.computeForwardGroups( jobGraph.getVerticesSortedTopologicallyFromSources(), getExecutionGraph()::getJobVertex); + + this.blockingResultInfos = new HashMap<>(); } @Override @@ -150,10 +164,58 @@ public class AdaptiveBatchScheduler extends DefaultScheduler { } @Override - protected void onTaskFinished(final Execution execution) { + protected void onTaskFinished(final Execution execution, final IOMetrics ioMetrics) { + checkNotNull(ioMetrics); + updateResultPartitionBytesMetrics(ioMetrics.getResultPartitionBytes()); initializeVerticesIfPossible(); - super.onTaskFinished(execution); + super.onTaskFinished(execution, ioMetrics); + } + + private void updateResultPartitionBytesMetrics( + Map<IntermediateResultPartitionID, ResultPartitionBytes> resultPartitionBytes) { + checkNotNull(resultPartitionBytes); + resultPartitionBytes.forEach( + (partitionId, partitionBytes) -> { + IntermediateResult result = + getExecutionGraph() + .getAllIntermediateResults() + .get(partitionId.getIntermediateDataSetID()); + checkNotNull(result); + + blockingResultInfos.compute( + result.getId(), + (ignored, resultInfo) -> { + if (resultInfo == null) { + resultInfo = createFromIntermediateResult(result); + } + resultInfo.recordPartitionInfo( + partitionId.getPartitionNumber(), partitionBytes); + return resultInfo; + }); + }); + } + + @Override + protected void resetForNewExecution(final ExecutionVertexID executionVertexId) { + final ExecutionVertex executionVertex = getExecutionVertex(executionVertexId); + if (executionVertex.getExecutionState() == ExecutionState.FINISHED) { + executionVertex + .getProducedPartitions() + .values() + .forEach( + partition -> { + blockingResultInfos.computeIfPresent( + partition.getIntermediateResult().getId(), + (ignored, resultInfo) -> { + resultInfo.resetPartitionInfo( + partition.getPartitionNumber()); + return resultInfo; + }); + }); + } + + super.resetForNewExecution(executionVertexId); } void initializeVerticesIfPossible() { @@ -246,12 +308,9 @@ public class AdaptiveBatchScheduler extends DefaultScheduler { final ExecutionJobVertex producerVertex = getExecutionJobVertex(consumedResult.getProducer().getId()); if (producerVertex.isFinished()) { - IntermediateResult intermediateResult = - getExecutionGraph().getAllIntermediateResults().get(consumedResult.getId()); - checkNotNull(intermediateResult); - - consumableResultInfo.add( - BlockingResultInfo.createFromIntermediateResult(intermediateResult)); + BlockingResultInfo resultInfo = + checkNotNull(blockingResultInfos.get(consumedResult.getId())); + consumableResultInfo.add(resultInfo); } else { // not all inputs consumable, return Optional.empty() return Optional.empty(); @@ -320,4 +379,27 @@ public class AdaptiveBatchScheduler extends DefaultScheduler { }, Function.identity()); } + + private 
static BlockingResultInfo createFromIntermediateResult(IntermediateResult result) { + checkArgument(result != null); + // Note that for dynamic graph, different partitions in the same result have the same number + // of subpartitions. + if (result.getConsumingDistributionPattern() == DistributionPattern.POINTWISE) { + return new PointwiseBlockingResultInfo( + result.getId(), + result.getNumberOfAssignedPartitions(), + result.getPartitions()[0].getNumberOfSubpartitions()); + } else { + return new AllToAllBlockingResultInfo( + result.getId(), + result.getNumberOfAssignedPartitions(), + result.getPartitions()[0].getNumberOfSubpartitions(), + result.isBroadcast()); + } + } + + @VisibleForTesting + BlockingResultInfo getBlockingResultInfo(IntermediateDataSetID resultId) { + return blockingResultInfos.get(resultId); + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java new file mode 100644 index 00000000000..a4a06efc57a --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptivebatch; + +import org.apache.flink.runtime.executiongraph.ResultPartitionBytes; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkState; + +/** Information of All-To-All result. */ +public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo { + + private final boolean isBroadcast; + + /** + * Aggregated subpartition bytes, which aggregates the subpartition bytes with the same + * subpartition index in different partitions. Note that We can aggregate them because they will + * be consumed by the same downstream task. 
+ */ + @Nullable private List<Long> aggregatedSubpartitionBytes; + + AllToAllBlockingResultInfo( + IntermediateDataSetID resultId, + int numOfPartitions, + int numOfSubpartitions, + boolean isBroadcast) { + super(resultId, numOfPartitions, numOfSubpartitions); + this.isBroadcast = isBroadcast; + } + + @Override + public boolean isBroadcast() { + return isBroadcast; + } + + @Override + public boolean isPointwise() { + return false; + } + + @Override + public long getNumBytesProduced() { + checkState(aggregatedSubpartitionBytes != null, "Not all partition infos are ready"); + if (isBroadcast) { + return aggregatedSubpartitionBytes.get(0); + } else { + return aggregatedSubpartitionBytes.stream().reduce(0L, Long::sum); + } + } + + @Override + public void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes) { + // Once all partitions are finished, we can convert the subpartition bytes to aggregated + // value to reduce the space usage, because the distribution of source splits does not + // affect the distribution of data consumed by downstream tasks of ALL_TO_ALL edges(Hashing + // or Rebalancing, we do not consider rare cases such as custom partitions here). + if (aggregatedSubpartitionBytes == null) { + super.recordPartitionInfo(partitionIndex, partitionBytes); + + if (subpartitionBytesByPartitionIndex.size() == numOfPartitions) { + long[] aggregatedBytes = new long[numOfSubpartitions]; + subpartitionBytesByPartitionIndex + .values() + .forEach( + subpartitionBytes -> { + checkState(subpartitionBytes.length == numOfSubpartitions); + for (int i = 0; i < subpartitionBytes.length; ++i) { + aggregatedBytes[i] += subpartitionBytes[i]; + } + }); + this.aggregatedSubpartitionBytes = + Arrays.stream(aggregatedBytes).boxed().collect(Collectors.toList()); + this.subpartitionBytesByPartitionIndex.clear(); + } + } + } + + @Override + public void resetPartitionInfo(int partitionIndex) { + if (aggregatedSubpartitionBytes == null) { + super.resetPartitionInfo(partitionIndex); + } + } + + public List<Long> getAggregatedSubpartitionBytes() { + checkState(aggregatedSubpartitionBytes != null, "Not all partition infos are ready"); + return Collections.unmodifiableList(aggregatedSubpartitionBytes); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java index 302980ab216..0fd14e4439e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java @@ -18,63 +18,60 @@ package org.apache.flink.runtime.scheduler.adaptivebatch; -import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.runtime.executiongraph.IOMetrics; -import org.apache.flink.runtime.executiongraph.IntermediateResult; -import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; +import org.apache.flink.runtime.executiongraph.ResultPartitionBytes; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; -import java.util.ArrayList; -import java.util.List; - -import static org.apache.flink.util.Preconditions.checkArgument; -import static org.apache.flink.util.Preconditions.checkNotNull; -import static org.apache.flink.util.Preconditions.checkState; - -/** The blocking result info, which will be used to calculate the vertex parallelism. 
*/ -public class BlockingResultInfo { - - private final List<Long> blockingPartitionSizes; - - private final boolean isBroadcast; - - private BlockingResultInfo(List<Long> blockingPartitionSizes, boolean isBroadcast) { - this.blockingPartitionSizes = blockingPartitionSizes; - this.isBroadcast = isBroadcast; - } - - public List<Long> getBlockingPartitionSizes() { - return blockingPartitionSizes; - } - - public boolean isBroadcast() { - return isBroadcast; - } - - @VisibleForTesting - static BlockingResultInfo createFromBroadcastResult(List<Long> blockingPartitionSizes) { - return new BlockingResultInfo(blockingPartitionSizes, true); - } - - @VisibleForTesting - static BlockingResultInfo createFromNonBroadcastResult(List<Long> blockingPartitionSizes) { - return new BlockingResultInfo(blockingPartitionSizes, false); - } - - public static BlockingResultInfo createFromIntermediateResult( - IntermediateResult intermediateResult) { - checkArgument(intermediateResult != null); - - List<Long> blockingPartitionSizes = new ArrayList<>(); - for (IntermediateResultPartition partition : intermediateResult.getPartitions()) { - checkState(partition.isConsumable()); - - IOMetrics ioMetrics = partition.getProducer().getPartitionProducer().getIOMetrics(); - checkNotNull(ioMetrics, "IOMetrics should not be null."); - - blockingPartitionSizes.add( - ioMetrics.getNumBytesProducedOfPartitions().get(partition.getPartitionId())); - } - - return new BlockingResultInfo(blockingPartitionSizes, intermediateResult.isBroadcast()); - } +/** + * The blocking result info, which will be used to calculate the vertex parallelism and input infos. + */ +public interface BlockingResultInfo { + + /** + * Get the intermediate result id. + * + * @return the intermediate result id + */ + IntermediateDataSetID getResultId(); + + /** + * Whether it is a broadcast result. + * + * @return whether it is a broadcast result + */ + boolean isBroadcast(); + + /** + * Whether it is a pointwise result. + * + * @return whether it is a pointwise result + */ + boolean isPointwise(); + + /** + * Return the num of bytes produced(numBytesProduced) by the producer. + * + * <p>The difference between numBytesProduced and numBytesOut : numBytesProduced represents the + * number of bytes actually produced, and numBytesOut represents the number of bytes sent to + * downstream tasks. In unicast scenarios, these two values should be equal. In broadcast + * scenarios, numBytesOut should be (N * numBytesProduced), where N refers to the number of + * subpartitions. + * + * @return the num of bytes produced by the producer + */ + long getNumBytesProduced(); + + /** + * Record the information of the result partition. + * + * @param partitionIndex the intermediate result partition index + * @param partitionBytes the {@link ResultPartitionBytes} of the partition + */ + void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes); + + /** + * Reset the information of the result partition. 
+ * + * @param partitionIndex the intermediate result partition index + */ + void resetPartitionInfo(int partitionIndex); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java index 43225fd4f61..64e6998a4fa 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDecider.java @@ -92,19 +92,13 @@ public class DefaultVertexParallelismDecider implements VertexParallelismDecider long broadcastBytes = consumedResults.stream() .filter(BlockingResultInfo::isBroadcast) - .mapToLong( - consumedResult -> - consumedResult.getBlockingPartitionSizes().stream() - .reduce(0L, Long::sum)) + .mapToLong(BlockingResultInfo::getNumBytesProduced) .sum(); long nonBroadcastBytes = consumedResults.stream() .filter(consumedResult -> !consumedResult.isBroadcast()) - .mapToLong( - consumedResult -> - consumedResult.getBlockingPartitionSizes().stream() - .reduce(0L, Long::sum)) + .mapToLong(BlockingResultInfo::getNumBytesProduced) .sum(); long expectedMaxBroadcastBytes = diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java new file mode 100644 index 00000000000..287b180df00 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptivebatch; + +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; + +import java.util.Arrays; + +import static org.apache.flink.util.Preconditions.checkState; + +/** Information of Pointwise result. 
*/ +public class PointwiseBlockingResultInfo extends AbstractBlockingResultInfo { + PointwiseBlockingResultInfo( + IntermediateDataSetID resultId, int numOfPartitions, int numOfSubpartitions) { + super(resultId, numOfPartitions, numOfSubpartitions); + } + + @Override + public boolean isBroadcast() { + return false; + } + + @Override + public boolean isPointwise() { + return true; + } + + @Override + public long getNumBytesProduced() { + checkState( + subpartitionBytesByPartitionIndex.size() == numOfPartitions, + "Not all partition infos are ready"); + return subpartitionBytesByPartitionIndex.values().stream() + .flatMapToLong(Arrays::stream) + .reduce(0L, Long::sum); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/SpeculativeScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/SpeculativeScheduler.java index 7c0a6c806af..37c72aa2daa 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/SpeculativeScheduler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/SpeculativeScheduler.java @@ -35,6 +35,7 @@ import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.executiongraph.Execution; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.ExecutionVertex; +import org.apache.flink.runtime.executiongraph.IOMetrics; import org.apache.flink.runtime.executiongraph.JobStatusListener; import org.apache.flink.runtime.executiongraph.SpeculativeExecutionVertex; import org.apache.flink.runtime.executiongraph.failover.flip1.FailoverStrategy; @@ -194,7 +195,7 @@ public class SpeculativeScheduler extends AdaptiveBatchScheduler } @Override - protected void onTaskFinished(final Execution execution) { + protected void onTaskFinished(final Execution execution, final IOMetrics ioMetrics) { if (!isOriginalAttempt(execution)) { numEffectiveSpeculativeExecutionsCounter.inc(); } @@ -202,7 +203,7 @@ public class SpeculativeScheduler extends AdaptiveBatchScheduler // cancel all un-terminated executions because the execution vertex has finished FutureUtils.assertNoException(cancelPendingExecutions(execution.getVertex().getID())); - super.onTaskFinished(execution); + super.onTaskFinished(execution, ioMetrics); } private static boolean isOriginalAttempt(final Execution execution) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphDeploymentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphDeploymentTest.java index 32e6302ec88..e629fce99a8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphDeploymentTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphDeploymentTest.java @@ -336,7 +336,7 @@ class DefaultExecutionGraphDeploymentTest { scheduler.updateTaskExecutionState(state); - assertThat(execution1.getIOMetrics()).isEqualTo(ioMetrics); + assertIOMetricsEqual(execution1.getIOMetrics(), ioMetrics); assertThat(execution1.getUserAccumulators()).isNotNull(); assertThat(execution1.getUserAccumulators().get("acc").getLocalValue()).isEqualTo(4); @@ -359,7 +359,7 @@ class DefaultExecutionGraphDeploymentTest { scheduler.updateTaskExecutionState(state2); - assertThat(execution2.getIOMetrics()).isEqualTo(ioMetrics2); + assertIOMetricsEqual(execution2.getIOMetrics(), ioMetrics2); 
assertThat(execution2.getUserAccumulators()).isNotNull(); assertThat(execution2.getUserAccumulators().get("acc").getLocalValue()).isEqualTo(8); } @@ -388,13 +388,13 @@ class DefaultExecutionGraphDeploymentTest { execution1.cancel(); execution1.completeCancelling(accumulators, ioMetrics, false); - assertThat(execution1.getIOMetrics()).isEqualTo(ioMetrics); + assertIOMetricsEqual(execution1.getIOMetrics(), ioMetrics); assertThat(execution1.getUserAccumulators()).isEqualTo(accumulators); Execution execution2 = executions.values().iterator().next(); execution2.markFailed(new Throwable(), false, accumulators, ioMetrics, false, true); - assertThat(execution2.getIOMetrics()).isEqualTo(ioMetrics); + assertIOMetricsEqual(execution2.getIOMetrics(), ioMetrics); assertThat(execution2.getUserAccumulators()).isEqualTo(accumulators); } @@ -667,7 +667,7 @@ class DefaultExecutionGraphDeploymentTest { .build(EXECUTOR_EXTENSION.getExecutor()); } - private boolean isDeployedInTopologicalOrder( + private static boolean isDeployedInTopologicalOrder( List<ExecutionAttemptID> submissionOrder, List<Collection<ExecutionAttemptID>> executionStages) { final Iterator<ExecutionAttemptID> submissionIterator = submissionOrder.iterator(); @@ -688,4 +688,16 @@ class DefaultExecutionGraphDeploymentTest { return !submissionIterator.hasNext(); } + + private void assertIOMetricsEqual(IOMetrics ioMetrics1, IOMetrics ioMetrics2) { + assertThat(ioMetrics1.numBytesIn).isEqualTo(ioMetrics2.numBytesIn); + assertThat(ioMetrics1.numBytesOut).isEqualTo(ioMetrics2.numBytesOut); + assertThat(ioMetrics1.numRecordsIn).isEqualTo(ioMetrics2.numRecordsIn); + assertThat(ioMetrics1.numRecordsOut).isEqualTo(ioMetrics2.numRecordsOut); + assertThat(ioMetrics1.accumulateIdleTime).isEqualTo(ioMetrics2.accumulateIdleTime); + assertThat(ioMetrics1.accumulateBusyTime).isEqualTo(ioMetrics2.accumulateBusyTime); + assertThat(ioMetrics1.accumulateBackPressuredTime) + .isEqualTo(ioMetrics2.accumulateBackPressuredTime); + assertThat(ioMetrics1.resultPartitionBytes).isEqualTo(ioMetrics2.resultPartitionBytes); + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java index 94095c755b8..f66b3bcf9b4 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java @@ -697,30 +697,44 @@ class ResultPartitionTest { } @Test - void testNumBytesProducedCounterForUnicast() throws IOException { - testNumBytesProducedCounter(false); + void testResultPartitionBytesCounterForUnicast() throws IOException { + testResultPartitionBytesCounter(false); } @Test - void testNumBytesProducedCounterForBroadcast() throws IOException { - testNumBytesProducedCounter(true); + void testResultPartitionBytesCounterForBroadcast() throws IOException { + testResultPartitionBytesCounter(true); } - private void testNumBytesProducedCounter(boolean isBroadcast) throws IOException { + private void testResultPartitionBytesCounter(boolean isBroadcast) throws IOException { BufferWritingResultPartition bufferWritingResultPartition = createResultPartition(ResultPartitionType.BLOCKING); if (isBroadcast) { bufferWritingResultPartition.broadcastRecord(ByteBuffer.allocate(bufferSize)); - assertThat(bufferWritingResultPartition.numBytesProduced.getCount()) - .isEqualTo(bufferSize); + + long[] 
subpartitionBytes =
+                    bufferWritingResultPartition
+                            .resultPartitionBytes
+                            .createSnapshot()
+                            .getSubpartitionBytes();
+            assertThat(subpartitionBytes).containsExactly(bufferSize, bufferSize);
+
             assertThat(bufferWritingResultPartition.numBytesOut.getCount())
                     .isEqualTo(2 * bufferSize);
         } else {
             bufferWritingResultPartition.emitRecord(ByteBuffer.allocate(bufferSize), 0);
-            assertThat(bufferWritingResultPartition.numBytesProduced.getCount())
-                    .isEqualTo(bufferSize);
-            assertThat(bufferWritingResultPartition.numBytesOut.getCount()).isEqualTo(bufferSize);
+            bufferWritingResultPartition.emitRecord(ByteBuffer.allocate(2 * bufferSize), 1);
+
+            long[] subpartitionBytes =
+                    bufferWritingResultPartition
+                            .resultPartitionBytes
+                            .createSnapshot()
+                            .getSubpartitionBytes();
+            assertThat(subpartitionBytes).containsExactly(bufferSize, (long) 2 * bufferSize);
+
+            assertThat(bufferWritingResultPartition.numBytesOut.getCount())
+                    .isEqualTo(3 * bufferSize);
         }
     }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java
index 1403b60dff2..3139935ab90 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java
@@ -415,17 +415,17 @@ public class SortMergeResultPartitionTest {
 
     @TestTemplate
     void testNumBytesProducedCounterForUnicast() throws IOException {
-        testNumBytesProducedCounter(false);
+        testResultPartitionBytesCounter(false);
     }
 
     @TestTemplate
     void testNumBytesProducedCounterForBroadcast() throws IOException {
-        testNumBytesProducedCounter(true);
+        testResultPartitionBytesCounter(true);
     }
 
-    private void testNumBytesProducedCounter(boolean isBroadcast) throws IOException {
+    private void testResultPartitionBytesCounter(boolean isBroadcast) throws IOException {
         int numBuffers = useHashDataBuffer ? 100 : 15;
-        int numSubpartitions = 10;
+        int numSubpartitions = 2;
 
         BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
         SortMergeResultPartition partition =
@@ -435,16 +435,25 @@
             partition.broadcastRecord(ByteBuffer.allocate(bufferSize));
             partition.finish();
 
-            assertThat(partition.numBytesProduced.getCount()).isEqualTo(bufferSize + 4);
+            long[] subpartitionBytes =
+                    partition.resultPartitionBytes.createSnapshot().getSubpartitionBytes();
+            assertThat(subpartitionBytes)
+                    .containsExactly((long) bufferSize + 4, (long) bufferSize + 4);
+
             assertThat(partition.numBytesOut.getCount())
                     .isEqualTo(numSubpartitions * (bufferSize + 4));
         } else {
             partition.emitRecord(ByteBuffer.allocate(bufferSize), 0);
+            partition.emitRecord(ByteBuffer.allocate(2 * bufferSize), 1);
             partition.finish();
 
-            assertThat(partition.numBytesProduced.getCount()).isEqualTo(bufferSize + 4);
+            long[] subpartitionBytes =
+                    partition.resultPartitionBytes.createSnapshot().getSubpartitionBytes();
+            assertThat(subpartitionBytes)
+                    .containsExactly((long) bufferSize + 4, (long) 2 * bufferSize + 4);
+
             assertThat(partition.numBytesOut.getCount())
-                    .isEqualTo(bufferSize + numSubpartitions * 4);
+                    .isEqualTo(3 * bufferSize + numSubpartitions * 4);
         }
     }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartitionTest.java
index 71f777046e7..0ad1541f697 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartitionTest.java
@@ -24,6 +24,7 @@ import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.core.testutils.CheckedThread;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.executiongraph.IOMetrics;
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
 import org.apache.flink.runtime.io.disk.BatchShuffleReadBufferPool;
 import org.apache.flink.runtime.io.disk.FileChannelManager;
 import org.apache.flink.runtime.io.disk.FileChannelManagerImpl;
@@ -439,9 +440,11 @@ class HsResultPartitionTest {
             assertThat(taskIOMetricGroup.getNumBytesOutCounter().getCount())
                     .isEqualTo(3 * bufferSize);
             IOMetrics ioMetrics = taskIOMetricGroup.createSnapshot();
-            assertThat(ioMetrics.getNumBytesProducedOfPartitions())
-                    .hasSize(1)
-                    .containsValue((long) 2 * bufferSize);
+            assertThat(ioMetrics.getResultPartitionBytes()).hasSize(1);
+            ResultPartitionBytes partitionBytes =
+                    ioMetrics.getResultPartitionBytes().values().iterator().next();
+            assertThat(partitionBytes.getSubpartitionBytes())
+                    .containsExactly((long) 2 * bufferSize, (long) bufferSize);
         }
     }
 
@@ -498,9 +501,11 @@ class HsResultPartitionTest {
             assertThat(taskIOMetricGroup.getNumBuffersOutCounter().getCount()).isEqualTo(1);
             assertThat(taskIOMetricGroup.getNumBytesOutCounter().getCount()).isEqualTo(bufferSize);
             IOMetrics ioMetrics = taskIOMetricGroup.createSnapshot();
-            assertThat(ioMetrics.getNumBytesProducedOfPartitions())
-                    .hasSize(1)
-                    .containsValue((long) bufferSize);
+            assertThat(ioMetrics.getResultPartitionBytes()).hasSize(1);
+            ResultPartitionBytes partitionBytes =
+                    ioMetrics.getResultPartitionBytes().values().iterator().next();
+            assertThat(partitionBytes.getSubpartitionBytes())
+                    .containsExactly((long) bufferSize, (long) bufferSize);
         }
     }
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroupTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroupTest.java
index 0faab8c1815..729c54f62df 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroupTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroupTest.java
@@ -21,6 +21,8 @@ package org.apache.flink.runtime.metrics.groups;
 import org.apache.flink.metrics.Counter;
 import org.apache.flink.metrics.SimpleCounter;
 import org.apache.flink.runtime.executiongraph.IOMetrics;
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
+import org.apache.flink.runtime.io.network.metrics.ResultPartitionBytesCounter;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 
 import org.junit.jupiter.api.Test;
@@ -96,26 +98,30 @@ class TaskIOMetricGroupTest {
     }
 
     @Test
-    void testNumBytesProducedOfPartitionsMetrics() {
+    void testResultPartitionBytesMetrics() {
         TaskMetricGroup task = UnregisteredMetricGroups.createUnregisteredTaskMetricGroup();
         TaskIOMetricGroup taskIO = task.getIOMetricGroup();
 
-        Counter c1 = new SimpleCounter();
-        c1.inc(32L);
-        Counter c2 = new SimpleCounter();
-        c2.inc(64L);
+        ResultPartitionBytesCounter c1 = new ResultPartitionBytesCounter(2);
+        ResultPartitionBytesCounter c2 = new ResultPartitionBytesCounter(2);
+
+        c1.inc(0, 32L);
+        c1.inc(1, 64L);
+        c2.incAll(128L);
 
         IntermediateResultPartitionID resultPartitionID1 = new IntermediateResultPartitionID();
         IntermediateResultPartitionID resultPartitionID2 = new IntermediateResultPartitionID();
 
-        taskIO.registerNumBytesProducedCounterForPartition(resultPartitionID1, c1);
-        taskIO.registerNumBytesProducedCounterForPartition(resultPartitionID2, c2);
+        taskIO.registerResultPartitionBytesCounter(resultPartitionID1, c1);
+        taskIO.registerResultPartitionBytesCounter(resultPartitionID2, c2);
 
-        Map<IntermediateResultPartitionID, Long> numBytesProducedOfPartitions =
-                taskIO.createSnapshot().getNumBytesProducedOfPartitions();
+        Map<IntermediateResultPartitionID, ResultPartitionBytes> resultPartitionBytes =
+                taskIO.createSnapshot().getResultPartitionBytes();
 
-        assertThat(numBytesProducedOfPartitions.size()).isEqualTo(2);
-        assertThat(numBytesProducedOfPartitions.get(resultPartitionID1).longValue()).isEqualTo(32L);
-        assertThat(numBytesProducedOfPartitions.get(resultPartitionID2).longValue()).isEqualTo(64L);
+        assertThat(resultPartitionBytes.size()).isEqualTo(2);
+        assertThat(resultPartitionBytes.get(resultPartitionID1).getSubpartitionBytes())
+                .containsExactly(32L, 64L);
+        assertThat(resultPartitionBytes.get(resultPartitionID2).getSubpartitionBytes())
+                .containsExactly(128L, 128L);
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerTest.java
index 47c1ae6ff6b..59122521cb2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerTest.java
@@ -136,6 +136,7 @@ import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.cr
 import static org.apache.flink.runtime.jobmaster.slotpool.DefaultDeclarativeSlotPoolTest.createSlotOffersForResourceRequirements;
 import static org.apache.flink.runtime.jobmaster.slotpool.SlotPoolTestUtils.offerSlots;
 import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.acknowledgePendingCheckpoint;
+import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.createFailedTaskExecutionState;
 import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.enableCheckpointing;
 import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.getCheckpointCoordinator;
 import static org.apache.flink.util.ExceptionUtils.findThrowable;
@@ -1782,12 +1783,6 @@ public class DefaultSchedulerTest extends TestLogger {
         schedulerClosed.get();
     }
 
-    public static TaskExecutionState createFailedTaskExecutionState(
-            ExecutionAttemptID executionAttemptID) {
-        return new TaskExecutionState(
-                executionAttemptID, ExecutionState.FAILED, new Exception("Expected failure cause"));
-    }
-
     private static long initiateFailure(
             DefaultScheduler scheduler,
             ExecutionAttemptID executionAttemptId,
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java
index 0de11b3941e..c6497b9e96b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java
@@ -35,11 +35,14 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.executiongraph.IOMetrics;
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
 import org.apache.flink.runtime.executiongraph.failover.flip1.TestRestartBackoffTimeStrategy;
 import org.apache.flink.runtime.io.network.partition.JobMasterPartitionTracker;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobGraphBuilder;
 import org.apache.flink.runtime.jobgraph.JobVertex;
@@ -64,6 +67,7 @@ import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
@@ -434,4 +438,34 @@ public class SchedulerTestingUtils {
             vertex.deploy();
         }
     }
+
+    public static TaskExecutionState createFinishedTaskExecutionState(
+            ExecutionAttemptID attemptId,
+            Map<IntermediateResultPartitionID, ResultPartitionBytes> resultPartitionBytes) {
+        return new TaskExecutionState(
+                attemptId,
+                ExecutionState.FINISHED,
+                null,
+                null,
+                new IOMetrics(0, 0, 0, 0, 0, 0, 0, resultPartitionBytes));
+    }
+
+    public static TaskExecutionState createFinishedTaskExecutionState(
+            ExecutionAttemptID attemptId) {
+        return createFinishedTaskExecutionState(attemptId, Collections.emptyMap());
+    }
+
+    public static TaskExecutionState createFailedTaskExecutionState(
+            ExecutionAttemptID attemptId, Throwable failureCause) {
+        return new TaskExecutionState(attemptId, ExecutionState.FAILED, failureCause);
+    }
+
+    public static TaskExecutionState createFailedTaskExecutionState(ExecutionAttemptID attemptId) {
+        return createFailedTaskExecutionState(attemptId, new Exception("Expected failure cause"));
+    }
+
+    public static TaskExecutionState createCanceledTaskExecutionState(
+            ExecutionAttemptID attemptId) {
+        return new TaskExecutionState(attemptId, ExecutionState.CANCELED);
+    }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
index a6e87f9bfc2..2dc5389668a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
@@ -30,9 +30,14 @@ import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
-import org.apache.flink.runtime.executiongraph.IOMetrics;
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
+import org.apache.flink.runtime.executiongraph.failover.flip1.FixedDelayRestartBackoffTimeStrategy;
+import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.io.network.partition.TestingJobMasterPartitionTracker;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.scheduler.DefaultSchedulerBuilder;
@@ -41,16 +46,27 @@ import org.apache.flink.runtime.taskmanager.TaskExecutionState;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
 import org.apache.flink.testutils.TestingUtils;
 import org.apache.flink.testutils.executor.TestExecutorExtension;
+import org.apache.flink.util.concurrent.ManuallyTriggeredScheduledExecutor;
 
+import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.RegisterExtension;
 
+import javax.annotation.Nullable;
+
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.stream.Collectors;
+import java.util.stream.LongStream;
 
+import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.createFailedTaskExecutionState;
+import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.createFinishedTaskExecutionState;
+import static org.apache.flink.shaded.guava30.com.google.common.collect.Iterables.getOnlyElement;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /** Test for {@link AdaptiveBatchScheduler}. */
@@ -58,13 +74,20 @@ class AdaptiveBatchSchedulerTest {
 
     private static final int SOURCE_PARALLELISM_1 = 6;
     private static final int SOURCE_PARALLELISM_2 = 4;
+    private static final long SUBPARTITION_BYTES = 100L;
 
     @RegisterExtension
     static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
            TestingUtils.defaultExecutorExtension();
 
-    private static final ComponentMainThreadExecutor mainThreadExecutor =
-            ComponentMainThreadExecutorServiceAdapter.forMainThread();
+    private ComponentMainThreadExecutor mainThreadExecutor;
+    private ManuallyTriggeredScheduledExecutor taskRestartExecutor;
+
+    @BeforeEach
+    void setUp() {
+        mainThreadExecutor = ComponentMainThreadExecutorServiceAdapter.forMainThread();
+        taskRestartExecutor = new ManuallyTriggeredScheduledExecutor();
+    }
 
     @Test
     void testAdaptiveBatchScheduler() throws Exception {
@@ -122,20 +145,138 @@ class AdaptiveBatchSchedulerTest {
         assertThat(sink.getParallelism()).isEqualTo(SOURCE_PARALLELISM_1);
     }
 
+    @Test
+    void testUpdateBlockingResultInfoWhileScheduling() throws Exception {
+        JobGraph jobGraph = createJobGraph(false);
+        Iterator<JobVertex> jobVertexIterator = jobGraph.getVertices().iterator();
+        JobVertex source1 = jobVertexIterator.next();
+        JobVertex source2 = jobVertexIterator.next();
+        JobVertex sink = jobVertexIterator.next();
+
+        Configuration configuration = new Configuration();
+        configuration.set(
+                JobManagerOptions.SCHEDULER, JobManagerOptions.SchedulerType.AdaptiveBatch);
+
+        final TestingJobMasterPartitionTracker partitionTracker =
+                new TestingJobMasterPartitionTracker();
+        partitionTracker.setIsPartitionTrackedFunction(ignore -> true);
+        int maxParallelism = 6;
+
+        AdaptiveBatchScheduler scheduler =
+                new DefaultSchedulerBuilder(
+                                jobGraph, mainThreadExecutor, EXECUTOR_RESOURCE.getExecutor())
+                        .setDelayExecutor(taskRestartExecutor)
+                        .setJobMasterConfiguration(configuration)
+                        .setPartitionTracker(partitionTracker)
+                        .setRestartBackoffTimeStrategy(
+                                new FixedDelayRestartBackoffTimeStrategy
+                                                .FixedDelayRestartBackoffTimeStrategyFactory(10, 0)
+                                        .create())
+                        .setVertexParallelismDecider((ignored) -> maxParallelism)
+                        .setDefaultMaxParallelism(maxParallelism)
+                        .buildAdaptiveBatchJobScheduler();
+
+        final DefaultExecutionGraph graph = (DefaultExecutionGraph) scheduler.getExecutionGraph();
+        final ExecutionJobVertex source1ExecutionJobVertex = graph.getJobVertex(source1.getID());
+        final ExecutionJobVertex sinkExecutionJobVertex = graph.getJobVertex(sink.getID());
+
+        PointwiseBlockingResultInfo blockingResultInfo;
+
+        scheduler.startScheduling();
+        // trigger source1 finished.
+        transitionExecutionsState(scheduler, ExecutionState.FINISHED, source1);
+        blockingResultInfo =
+                (PointwiseBlockingResultInfo) getBlockingResultInfo(scheduler, source1);
+        assertThat(blockingResultInfo.getNumOfRecordedPartitions()).isEqualTo(SOURCE_PARALLELISM_1);
+
+        // trigger source2 finished.
+        transitionExecutionsState(scheduler, ExecutionState.FINISHED, source2);
+        blockingResultInfo =
+                (PointwiseBlockingResultInfo) getBlockingResultInfo(scheduler, source2);
+        assertThat(blockingResultInfo.getNumOfRecordedPartitions()).isEqualTo(SOURCE_PARALLELISM_2);
+
+        // trigger sink fail with partition not found
+        triggerFailedByPartitionNotFound(
+                scheduler,
+                source1ExecutionJobVertex.getTaskVertices()[0],
+                sinkExecutionJobVertex.getTaskVertices()[0]);
+
+        taskRestartExecutor.triggerScheduledTasks();
+
+        // check the partition info is reset
+        assertThat(
+                        ((PointwiseBlockingResultInfo) getBlockingResultInfo(scheduler, source1))
+                                .getNumOfRecordedPartitions())
+                .isEqualTo(SOURCE_PARALLELISM_1 - 1);
+    }
+
+    private BlockingResultInfo getBlockingResultInfo(
+            AdaptiveBatchScheduler scheduler, JobVertex jobVertex) {
+        return scheduler.getBlockingResultInfo(
+                getOnlyElement(jobVertex.getProducedDataSets()).getId());
+    }
+
+    private void triggerFailedByPartitionNotFound(
+            SchedulerBase scheduler,
+            ExecutionVertex producerVertex,
+            ExecutionVertex consumerVertex) {
+        final Execution execution = consumerVertex.getCurrentExecutionAttempt();
+        final IntermediateResultPartitionID partitionId =
+                getOnlyElement(producerVertex.getProducedPartitions().values()).getPartitionId();
+        // trigger execution vertex failed by partition not found.
+        transitionExecutionsState(
+                scheduler,
+                ExecutionState.FAILED,
+                Collections.singletonList(execution),
+                new PartitionNotFoundException(
+                        new ResultPartitionID(
+                                partitionId,
+                                producerVertex.getCurrentExecutionAttempt().getAttemptId())));
+    }
+
     /** Transit the state of all executions. */
     public static void transitionExecutionsState(
-            final SchedulerBase scheduler, final ExecutionState state, List<Execution> executions) {
+            final SchedulerBase scheduler,
+            final ExecutionState state,
+            List<Execution> executions,
+            @Nullable Throwable throwable) {
         for (Execution execution : executions) {
-            scheduler.updateTaskExecutionState(
-                    new TaskExecutionState(
-                            execution.getAttemptId(),
-                            state,
-                            null,
-                            null,
-                            new IOMetrics(0, 0, 0, 0, 0, 0, 0)));
+            TaskExecutionState taskExecutionState;
+            if (state == ExecutionState.FINISHED) {
+                taskExecutionState =
+                        createFinishedTaskExecutionState(
+                                execution.getAttemptId(),
+                                createResultPartitionBytesForExecution(execution));
+            } else if (state == ExecutionState.FAILED) {
+                taskExecutionState =
+                        createFailedTaskExecutionState(execution.getAttemptId(), throwable);
+            } else {
+                throw new UnsupportedOperationException("Unsupported state " + state);
+            }
+            scheduler.updateTaskExecutionState(taskExecutionState);
         }
     }
 
+    static Map<IntermediateResultPartitionID, ResultPartitionBytes>
+            createResultPartitionBytesForExecution(Execution execution) {
+        Map<IntermediateResultPartitionID, ResultPartitionBytes> partitionBytes = new HashMap<>();
+        execution
+                .getVertex()
+                .getProducedPartitions()
+                .forEach(
+                        (partitionId, partition) -> {
+                            int numOfSubpartitions = partition.getNumberOfSubpartitions();
+                            partitionBytes.put(
+                                    partitionId,
+                                    new ResultPartitionBytes(
+                                            LongStream.range(0, numOfSubpartitions)
+                                                    .boxed()
+                                                    .mapToLong(ignored -> SUBPARTITION_BYTES)
+                                                    .toArray()));
+                        });
+        return partitionBytes;
+    }
+
     /** Transit the state of all executions in the Job Vertex. */
     public static void transitionExecutionsState(
             final SchedulerBase scheduler, final ExecutionState state, final JobVertex jobVertex) {
@@ -145,7 +286,7 @@ class AdaptiveBatchSchedulerTest {
                 .stream()
                 .map(ExecutionVertex::getCurrentExecutionAttempt)
                 .collect(Collectors.toList());
-        transitionExecutionsState(scheduler, state, executions);
+        transitionExecutionsState(scheduler, state, executions, null);
     }
 
     public JobVertex createJobVertex(String jobVertexName, int parallelism) {
@@ -178,6 +319,7 @@ class AdaptiveBatchSchedulerTest {
 
         return new DefaultSchedulerBuilder(
                         jobGraph, mainThreadExecutor, EXECUTOR_RESOURCE.getExecutor())
+                .setDelayExecutor(taskRestartExecutor)
                 .setJobMasterConfiguration(configuration)
                 .setVertexParallelismDecider((ignored) -> 10)
                 .buildAdaptiveBatchJobScheduler();
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java
new file mode 100644
index 00000000000..f1c632b9ea0
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptivebatch;
+
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Test for {@link AllToAllBlockingResultInfo}. */
+class AllToAllBlockingResultInfoTest {
+
+    @Test
+    void testGetNumBytesProducedForNonBroadcast() {
+        testGetNumBytesProduced(false, 192L);
+    }
+
+    @Test
+    void testGetNumBytesProducedForBroadcast() {
+        testGetNumBytesProduced(true, 96L);
+    }
+
+    @Test
+    void testGetAggregatedSubpartitionBytes() {
+        AllToAllBlockingResultInfo resultInfo =
+                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false);
+        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));
+        resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {128L, 256L}));
+
+        assertThat(resultInfo.getAggregatedSubpartitionBytes()).containsExactly(160L, 320L);
+    }
+
+    @Test
+    void testGetBytesWithPartialPartitionInfos() {
+        AllToAllBlockingResultInfo resultInfo =
+                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false);
+        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));
+
+        assertThatThrownBy(resultInfo::getNumBytesProduced)
+                .isInstanceOf(IllegalStateException.class);
+        assertThatThrownBy(resultInfo::getAggregatedSubpartitionBytes)
+                .isInstanceOf(IllegalStateException.class);
+    }
+
+    @Test
+    void testRecordPartitionInfoMultiTimes() {
+        AllToAllBlockingResultInfo resultInfo =
+                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false);
+
+        ResultPartitionBytes partitionBytes1 = new ResultPartitionBytes(new long[] {32L, 64L});
+        ResultPartitionBytes partitionBytes2 = new ResultPartitionBytes(new long[] {64L, 128L});
+        ResultPartitionBytes partitionBytes3 = new ResultPartitionBytes(new long[] {128L, 256L});
+        ResultPartitionBytes partitionBytes4 = new ResultPartitionBytes(new long[] {256L, 512L});
+
+        // record partitionBytes1 for subtask 0 and then reset it
+        resultInfo.recordPartitionInfo(0, partitionBytes1);
+        assertThat(resultInfo.getNumOfRecordedPartitions()).isEqualTo(1);
+        resultInfo.resetPartitionInfo(0);
+        assertThat(resultInfo.getNumOfRecordedPartitions()).isEqualTo(0);
+
+        // record partitionBytes2 for subtask 0 and record partitionBytes3 for subtask 1
+        resultInfo.recordPartitionInfo(0, partitionBytes2);
+        resultInfo.recordPartitionInfo(1, partitionBytes3);
+
+        // The result info should be (partitionBytes2 + partitionBytes3)
+        assertThat(resultInfo.getNumBytesProduced()).isEqualTo(576L);
+        assertThat(resultInfo.getAggregatedSubpartitionBytes()).containsExactly(192L, 384L);
+        // The raw info should be clear
+        assertThat(resultInfo.getNumOfRecordedPartitions()).isEqualTo(0);
+
+        // reset subtask 0 and then record partitionBytes4 for subtask 0
+        resultInfo.resetPartitionInfo(0);
+        resultInfo.recordPartitionInfo(0, partitionBytes4);
+
+        // The result info should still be (partitionBytes2 + partitionBytes3)
+        assertThat(resultInfo.getNumBytesProduced()).isEqualTo(576L);
+        assertThat(resultInfo.getAggregatedSubpartitionBytes()).containsExactly(192L, 384L);
+        assertThat(resultInfo.getNumOfRecordedPartitions()).isEqualTo(0);
+    }
+
+    private void testGetNumBytesProduced(boolean isBroadcast, long expectedBytes) {
+        AllToAllBlockingResultInfo resultInfo =
+                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, isBroadcast);
+        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 32L}));
+        resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {64L, 64L}));
+
+        assertThat(resultInfo.getNumBytesProduced()).isEqualTo(expectedBytes);
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDeciderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDeciderTest.java
index 387e2bdc051..0f588bcc5e6 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDeciderTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismDeciderTest.java
@@ -21,6 +21,8 @@ package org.apache.flink.runtime.scheduler.adaptivebatch;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.JobManagerOptions;
 import org.apache.flink.configuration.MemorySize;
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -78,11 +80,8 @@ class DefaultVertexParallelismDeciderTest {
 
     @Test
     void testNormalizeParallelismDownToPowerOf2() {
-        BlockingResultInfo resultInfo1 =
-                BlockingResultInfo.createFromBroadcastResult(Arrays.asList(BYTE_256_MB));
-        BlockingResultInfo resultInfo2 =
-                BlockingResultInfo.createFromNonBroadcastResult(
-                        Arrays.asList(BYTE_256_MB, BYTE_8_GB));
+        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_256_MB);
+        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_256_MB + BYTE_8_GB);
 
         int parallelism =
                 decider.decideParallelismForVertex(Arrays.asList(resultInfo1, resultInfo2));
@@ -92,11 +91,8 @@ class DefaultVertexParallelismDeciderTest {
 
     @Test
     void testNormalizeParallelismUpToPowerOf2() {
-        BlockingResultInfo resultInfo1 =
-                BlockingResultInfo.createFromBroadcastResult(Arrays.asList(BYTE_256_MB));
-        BlockingResultInfo resultInfo2 =
-                BlockingResultInfo.createFromNonBroadcastResult(
-                        Arrays.asList(BYTE_1_GB, BYTE_8_GB));
+        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_256_MB);
+        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_1_GB + BYTE_8_GB);
 
         int parallelism =
                 decider.decideParallelismForVertex(Arrays.asList(resultInfo1, resultInfo2));
@@ -106,11 +102,8 @@ class DefaultVertexParallelismDeciderTest {
 
     @Test
     void testInitiallyNormalizedParallelismIsLargerThanMaxParallelism() {
-        BlockingResultInfo resultInfo1 =
-                BlockingResultInfo.createFromBroadcastResult(Arrays.asList(BYTE_256_MB));
-        BlockingResultInfo resultInfo2 =
-                BlockingResultInfo.createFromNonBroadcastResult(
-                        Arrays.asList(BYTE_8_GB, BYTE_1_TB));
+        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_256_MB);
+        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_8_GB + BYTE_1_TB);
 
         int parallelism =
                 decider.decideParallelismForVertex(Arrays.asList(resultInfo1, resultInfo2));
@@ -120,10 +113,8 @@ class DefaultVertexParallelismDeciderTest {
 
     @Test
     void testInitiallyNormalizedParallelismIsSmallerThanMinParallelism() {
-        BlockingResultInfo resultInfo1 =
-                BlockingResultInfo.createFromBroadcastResult(Arrays.asList(BYTE_256_MB));
-        BlockingResultInfo resultInfo2 =
-                BlockingResultInfo.createFromNonBroadcastResult(Arrays.asList(BYTE_512_MB));
+        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_256_MB);
+        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_512_MB);
 
         int parallelism =
                 decider.decideParallelismForVertex(Arrays.asList(resultInfo1, resultInfo2));
@@ -133,10 +124,8 @@ class DefaultVertexParallelismDeciderTest {
 
     @Test
     void testBroadcastRatioExceedsCapRatio() {
-        BlockingResultInfo resultInfo1 =
-                BlockingResultInfo.createFromBroadcastResult(Arrays.asList(BYTE_1_GB));
-        BlockingResultInfo resultInfo2 =
-                BlockingResultInfo.createFromNonBroadcastResult(Arrays.asList(BYTE_8_GB));
+        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_1_GB);
+        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_8_GB);
 
         int parallelism =
                 decider.decideParallelismForVertex(Arrays.asList(resultInfo1, resultInfo2));
@@ -146,15 +135,58 @@ class DefaultVertexParallelismDeciderTest {
 
     @Test
     void testNonBroadcastBytesCanNotDividedEvenly() {
-        BlockingResultInfo resultInfo1 =
-                BlockingResultInfo.createFromBroadcastResult(Arrays.asList(BYTE_512_MB));
-        BlockingResultInfo resultInfo2 =
-                BlockingResultInfo.createFromNonBroadcastResult(
-                        Arrays.asList(BYTE_256_MB, BYTE_8_GB));
+        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_512_MB);
+        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_256_MB + BYTE_8_GB);
 
         int parallelism =
                 decider.decideParallelismForVertex(Arrays.asList(resultInfo1, resultInfo2));
 
         assertThat(parallelism).isEqualTo(16);
     }
+
+    private static class TestingBlockingResultInfo implements BlockingResultInfo {
+
+        private final boolean isBroadcast;
+
+        private final long producedBytes;
+
+        private TestingBlockingResultInfo(boolean isBroadcast, long producedBytes) {
+            this.isBroadcast = isBroadcast;
+            this.producedBytes = producedBytes;
+        }
+
+        @Override
+        public IntermediateDataSetID getResultId() {
+            return new IntermediateDataSetID();
+        }
+
+        @Override
+        public boolean isBroadcast() {
+            return isBroadcast;
+        }
+
+        @Override
+        public boolean isPointwise() {
+            return false;
+        }
+
+        @Override
+        public long getNumBytesProduced() {
+            return producedBytes;
+        }
+
+        @Override
+        public void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes) {}
+
+        @Override
+        public void resetPartitionInfo(int partitionIndex) {}
+    }
+
+    private static BlockingResultInfo createFromBroadcastResult(long producedBytes) {
+        return new TestingBlockingResultInfo(true, producedBytes);
+    }
+
+    private static BlockingResultInfo createFromNonBroadcastResult(long producedBytes) {
+        return new TestingBlockingResultInfo(false, producedBytes);
+    }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfoTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfoTest.java
new file mode 100644
index 00000000000..556a2d48876
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfoTest.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptivebatch;
+
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Test for {@link PointwiseBlockingResultInfo}. */
+class PointwiseBlockingResultInfoTest {
+
+    @Test
+    void testGetNumBytesProduced() {
+        PointwiseBlockingResultInfo resultInfo =
+                new PointwiseBlockingResultInfo(new IntermediateDataSetID(), 2, 2);
+        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 32L}));
+        resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {64L, 64L}));
+
+        assertThat(resultInfo.getNumBytesProduced()).isEqualTo(192L);
+    }
+
+    @Test
+    void testGetBytesWithPartialPartitionInfos() {
+        PointwiseBlockingResultInfo resultInfo =
+                new PointwiseBlockingResultInfo(new IntermediateDataSetID(), 2, 2);
+        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));
+        assertThatThrownBy(resultInfo::getNumBytesProduced)
+                .isInstanceOf(IllegalStateException.class);
+    }
+
+    @Test
+    void testPartitionFinishedMultiTimes() {
+        PointwiseBlockingResultInfo resultInfo =
+                new PointwiseBlockingResultInfo(new IntermediateDataSetID(), 2, 2);
+
+        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));
+        resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {64L, 128L}));
+        assertThat(resultInfo.getNumOfRecordedPartitions()).isEqualTo(2);
+        assertThat(resultInfo.getNumBytesProduced()).isEqualTo(288L);
+
+        // reset partition info
+        resultInfo.resetPartitionInfo(0);
+        assertThat(resultInfo.getNumOfRecordedPartitions()).isEqualTo(1);
+
+        // record partition info again
+        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {64L, 128L}));
+        assertThat(resultInfo.getNumBytesProduced()).isEqualTo(384L);
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/SpeculativeSchedulerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/SpeculativeSchedulerTest.java
index b438240a599..189c75c321c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/SpeculativeSchedulerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/SpeculativeSchedulerTest.java
@@ -33,7 +33,6 @@ import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
-import org.apache.flink.runtime.executiongraph.IOMetrics;
 import org.apache.flink.runtime.executiongraph.failover.flip1.TestRestartBackoffTimeStrategy;
 import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
@@ -75,8 +74,11 @@ import java.util.stream.Collectors;
 
 import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.completeCancellingForAllVertices;
 import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.createNoOpVertex;
-import static org.apache.flink.runtime.scheduler.DefaultSchedulerTest.createFailedTaskExecutionState;
 import static org.apache.flink.runtime.scheduler.DefaultSchedulerTest.singleNonParallelJobVertexJobGraph;
+import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.createCanceledTaskExecutionState;
+import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.createFailedTaskExecutionState;
+import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.createFinishedTaskExecutionState;
+import static org.apache.flink.runtime.scheduler.adaptivebatch.AdaptiveBatchSchedulerTest.createResultPartitionBytesForExecution;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /** Tests for {@link SpeculativeScheduler}. */
@@ -212,9 +214,8 @@ class SpeculativeSchedulerTest {
         notifySlowTask(scheduler, attempt1);
         final Execution attempt2 = getExecution(ev, 1);
         scheduler.updateTaskExecutionState(
-                new TaskExecutionState(
+                createFailedTaskExecutionState(
                         attempt1.getAttemptId(),
-                        ExecutionState.FAILED,
                         new PartitionNotFoundException(new ResultPartitionID())));
 
         assertThat(attempt2.getState()).isEqualTo(ExecutionState.CANCELING);
@@ -234,7 +235,7 @@ class SpeculativeSchedulerTest {
         notifySlowTask(scheduler, attempt1);
         final Execution attempt2 = getExecution(ev, 1);
         scheduler.updateTaskExecutionState(
-                new TaskExecutionState(attempt1.getAttemptId(), ExecutionState.FINISHED));
+                createFinishedTaskExecutionState(attempt1.getAttemptId()));
 
         assertThat(attempt2.getState()).isEqualTo(ExecutionState.CANCELING);
     }
@@ -251,7 +252,7 @@ class SpeculativeSchedulerTest {
         notifySlowTask(scheduler, attempt1);
         final Execution attempt2 = getExecution(ev, 1);
         scheduler.updateTaskExecutionState(
-                new TaskExecutionState(attempt1.getAttemptId(), ExecutionState.FINISHED));
+                createFinishedTaskExecutionState(attempt1.getAttemptId()));
 
         assertThat(attempt2.getState()).isEqualTo(ExecutionState.CANCELED);
     }
@@ -266,9 +267,8 @@ class SpeculativeSchedulerTest {
 
         // A partition exception can result in a restart of the whole execution vertex.
         scheduler.updateTaskExecutionState(
-                new TaskExecutionState(
+                createFailedTaskExecutionState(
                         attempt1.getAttemptId(),
-                        ExecutionState.FAILED,
                         new PartitionNotFoundException(new ResultPartitionID())));
 
         completeCancellingForAllVertices(scheduler.getExecutionGraph());
@@ -316,9 +316,8 @@ class SpeculativeSchedulerTest {
         notifySlowTask(scheduler, attempt1);
 
         final TaskExecutionState failedState =
-                new TaskExecutionState(
+                createFailedTaskExecutionState(
                         attempt1.getAttemptId(),
-                        ExecutionState.FAILED,
                         new SuppressRestartsException(
                                 new Exception("Forced failure for testing.")));
         scheduler.updateTaskExecutionState(failedState);
@@ -376,12 +375,9 @@ class SpeculativeSchedulerTest {
         // Finishing any source execution attempt will finish the source execution vertex, and then
         // finish the job vertex.
         scheduler.updateTaskExecutionState(
-                new TaskExecutionState(
+                createFinishedTaskExecutionState(
                         sourceAttempt1.getAttemptId(),
-                        ExecutionState.FINISHED,
-                        null,
-                        null,
-                        new IOMetrics(0, 0, 0, 0, 0, 0, 0)));
+                        createResultPartitionBytesForExecution(sourceAttempt1)));
 
         assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(3);
 
         // trigger sink vertex speculation
@@ -421,12 +417,12 @@ class SpeculativeSchedulerTest {
         // finishes first
         final Execution attempt2 = getExecution(ev, 1);
         scheduler.updateTaskExecutionState(
-                new TaskExecutionState(attempt2.getAttemptId(), ExecutionState.FINISHED));
+                createFinishedTaskExecutionState(attempt2.getAttemptId()));
         assertThat(scheduler.getNumEffectiveSpeculativeExecutions()).isEqualTo(1);
 
         // complete cancellation
         scheduler.updateTaskExecutionState(
-                new TaskExecutionState(attempt1.getAttemptId(), ExecutionState.CANCELED));
+                createCanceledTaskExecutionState(attempt1.getAttemptId()));
 
         // trigger a global failure to reset the vertex.
         // after that, no speculative execution finishes before its original execution and the
@@ -441,7 +437,7 @@ class SpeculativeSchedulerTest {
         // numEffectiveSpeculativeExecutions will not increase if an original execution attempt
         // finishes first
         scheduler.updateTaskExecutionState(
-                new TaskExecutionState(attempt3.getAttemptId(), ExecutionState.FINISHED));
+                createFinishedTaskExecutionState(attempt3.getAttemptId()));
 
         assertThat(scheduler.getNumEffectiveSpeculativeExecutions()).isZero();
     }
