This is an automated email from the ASF dual-hosted git repository.
panyuepeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
     new 7b1e2129a65 [FLINK-38943][runtime] Support Adaptive Partition Selection for RescalePartitioner & RebalancePartitioner (#27446)
7b1e2129a65 is described below
commit 7b1e2129a6565b0edcce6f6990dea74df23df9b3
Author: Yuepeng Pan <[email protected]>
AuthorDate: Tue Feb 10 11:19:05 2026 +0800
    [FLINK-38943][runtime] Support Adaptive Partition Selection for RescalePartitioner & RebalancePartitioner (#27446)
Co-authored-by: Tartarus0zm <[email protected]>
Co-authored-by: 1996fanrui <[email protected]>
---
.../generated/all_taskmanager_network_section.html | 12 ++
.../netty_shuffle_environment_configuration.html | 12 ++
.../NettyShuffleEnvironmentOptions.java | 28 +++
.../api/writer/AdaptiveLoadBasedRecordWriter.java | 139 +++++++++++++
.../io/network/api/writer/RecordWriterBuilder.java | 23 ++-
.../network/api/writer/ResultPartitionWriter.java | 8 +
.../runtime/io/network/buffer/BufferPool.java | 5 +
.../runtime/io/network/buffer/LocalBufferPool.java | 5 +
.../partition/BufferWritingResultPartition.java | 32 ++-
.../io/network/partition/ResultPartition.java | 4 +
.../flink/streaming/runtime/tasks/StreamTask.java | 27 +++
.../writer/AdaptiveLoadBasedRecordWriterTest.java | 224 +++++++++++++++++++++
.../streaming/runtime/tasks/StreamTaskTest.java | 23 +++
13 files changed, 531 insertions(+), 11 deletions(-)
diff --git a/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html b/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html
index 0036d781c12..7299bbab326 100644
--- a/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html
+++ b/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html
@@ -8,6 +8,18 @@
</tr>
</thead>
<tbody>
+ <tr>
+ <td><h5>taskmanager.network.adaptive-partitioner.enabled</h5></td>
+ <td style="word-wrap: break-word;">false</td>
+ <td>Boolean</td>
+            <td>Whether to enable the adaptive partitioner feature for rescale and rebalance partitioners, based on the load of the downstream tasks.</td>
+ </tr>
+ <tr>
+            <td><h5>taskmanager.network.adaptive-partitioner.max-traverse-size</h5></td>
+ <td style="word-wrap: break-word;">4</td>
+ <td>Integer</td>
+            <td>Maximum number of channels to traverse when looking for the most idle channel for rescale and rebalance partitioners when <code class="highlighter-rouge">taskmanager.network.adaptive-partitioner.enabled</code> is enabled.<br />Note that the value of this configuration option must be greater than `1`.</td>
+ </tr>
<tr>
<td><h5>taskmanager.network.compression.codec</h5></td>
<td style="word-wrap: break-word;">LZ4</td>
diff --git a/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html b/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html
index 3e6012bea1d..7e851456b0d 100644
--- a/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html
+++ b/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html
@@ -26,6 +26,18 @@
<td>Boolean</td>
            <td>Enable SSL support for the taskmanager data transport. This is applicable only when the global flag for internal SSL (security.ssl.internal.enabled) is set to true</td>
</tr>
+ <tr>
+ <td><h5>taskmanager.network.adaptive-partitioner.enabled</h5></td>
+ <td style="word-wrap: break-word;">false</td>
+ <td>Boolean</td>
+            <td>Whether to enable the adaptive partitioner feature for rescale and rebalance partitioners, based on the load of the downstream tasks.</td>
+ </tr>
+ <tr>
+            <td><h5>taskmanager.network.adaptive-partitioner.max-traverse-size</h5></td>
+ <td style="word-wrap: break-word;">4</td>
+ <td>Integer</td>
+            <td>Maximum number of channels to traverse when looking for the most idle channel for rescale and rebalance partitioners when <code class="highlighter-rouge">taskmanager.network.adaptive-partitioner.enabled</code> is enabled.<br />Note that the value of this configuration option must be greater than `1`.</td>
+ </tr>
<tr>
<td><h5>taskmanager.network.compression.codec</h5></td>
<td style="word-wrap: break-word;">LZ4</td>
diff --git a/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java b/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
index 018b11e5ecd..d9be6cb1b27 100644
--- a/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
+++ b/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
@@ -325,6 +325,34 @@ public class NettyShuffleEnvironmentOptions {
code(NETWORK_REQUEST_BACKOFF_MAX.key()))
.build());
+    /** Whether to enable adaptive partition selection for the rebalance and rescale partitioners. */
+ @Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
+ public static final ConfigOption<Boolean> ADAPTIVE_PARTITIONER_ENABLED =
+ key("taskmanager.network.adaptive-partitioner.enabled")
+ .booleanType()
+ .defaultValue(false)
+ .withDescription(
+                                "Whether to enable the adaptive partitioner feature for rescale and rebalance partitioners, based on the load of the downstream tasks.");
+
+ /**
+     * Maximum number of channels to traverse when looking for the most idle channel for rescale and
+     * rebalance partitioners when {@link #ADAPTIVE_PARTITIONER_ENABLED} is true.
+ */
+ @Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
+    public static final ConfigOption<Integer> ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE =
+ key("taskmanager.network.adaptive-partitioner.max-traverse-size")
+ .intType()
+ .defaultValue(4)
+ .withDescription(
+ Description.builder()
+ .text(
+                                        "Maximum number of channels to traverse when looking for the most idle channel for rescale and rebalance partitioners when %s is enabled.",
+                                        code(ADAPTIVE_PARTITIONER_ENABLED.key()))
+ .linebreak()
+ .text(
+                                        "Note that the value of this configuration option must be greater than `1`.")
+ .build());
+
// ------------------------------------------------------------------------
/** Not intended to be instantiated. */
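
For anyone who wants to try the new options, here is a minimal sketch of enabling them programmatically; the wrapper class name is invented, while Configuration.set and the two ConfigOptions are the real APIs shown in this diff. The same keys can also be set in the Flink configuration file.

import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;

public class AdaptivePartitionerConfigExample {
    public static Configuration buildConfig() {
        Configuration conf = new Configuration();
        // Turn on adaptive partition selection for rescale/rebalance partitioners.
        conf.set(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_ENABLED, true);
        // Compare at most 8 candidate channels per record; must be greater than 1.
        conf.set(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE, 8);
        return conf;
    }
}
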
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriter.java
new file mode 100644
index 00000000000..19eb08e6571
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriter.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.api.writer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.core.io.IOReadableWritable;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+/**
+ * A record writer based on load of downstream tasks for {@link
+ * org.apache.flink.streaming.runtime.partitioner.RescalePartitioner} and {@link
+ * org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner}.
+ *
+ * <pre>
+ *
+ * The following notes clarify a few design points.
+ *
+ * - Two new immutable attributes are introduced in this class:
+ * -- `numberOfSubpartitions` represents the number of downstream partitions that can be written to.
+ * -- `maxTraverseSize` represents the maximum number of partitions that the current partition selector can compare when performing rescale or rebalance.
+ *
+ * - Why do `maxTraverseSize` and `numberOfSubpartitions` not share a common attribute?
+ * If the same field were shared and `maxTraverseSize` were less than `numberOfSubpartitions` (e.g., 2 < 6), some downstream partitions (4 in this case) would never be written to, which is incorrect behavior.
+ *
+ * - Why are users not allowed to configure `maxTraverseSize` as 1?
+ * Setting it to 1 would mean no load comparison is performed, effectively disabling the adaptive partitioning feature.
+ *
+ * - Why the internal value of `maxTraverseSize` may become 1:
+ * This is reasonable if and only if the number of downstream partitions is exactly 1 (since no comparison is needed). This situation can arise from framework behaviors such as the {@link org.apache.flink.runtime.scheduler.adaptive.AdaptiveScheduler}, which are not directly controlled by users.
+ * For example, when the following job enables the AdaptiveScheduler before rescaling:
+ *
+ * JobVertexA(parallelism=4, slotSharingGroup=SSG-A) --(rescale)--> JobVertexB(parallelism=5, slotSharingGroup=SSG-B)
+ *
+ * If the job scales down and only 2 slots are available, the parallelism configuration of the job changes to:
+ *
+ * JobVertexA(parallelism=1, slotSharingGroup=SSG-A) --(rescale)--> JobVertexB(parallelism=1, slotSharingGroup=SSG-B)
+ *
+ * In this case, the task of JobVertexA has only one writable downstream partition, so a `maxTraverseSize` of 1 is reasonable and meaningful.
+ *
+ * </pre>
+ *
+ * @param <T> The type of IOReadableWritable records.
+ */
+@Internal
+public final class AdaptiveLoadBasedRecordWriter<T extends IOReadableWritable>
+ extends RecordWriter<T> {
+
+ private final int maxTraverseSize;
+ private final int numberOfSubpartitions;
+ private int currentChannel = -1;
+
+ AdaptiveLoadBasedRecordWriter(
+            ResultPartitionWriter writer, long timeout, String taskName, int maxTraverseSize) {
+        super(writer, timeout, taskName);
+        this.numberOfSubpartitions = writer.getNumberOfSubpartitions();
+        this.maxTraverseSize = Math.min(maxTraverseSize, numberOfSubpartitions);
+ }
+
+ @Override
+ public void emit(T record) throws IOException {
+ checkErroneous();
+
+ currentChannel = getIdlestChannelIndex();
+
+ ByteBuffer byteBuffer = serializeRecord(serializer, record);
+ targetPartition.emitRecord(byteBuffer, currentChannel);
+
+ if (flushAlways) {
+ targetPartition.flush(currentChannel);
+ }
+ }
+
+ @VisibleForTesting
+ int getIdlestChannelIndex() {
+ int bestChannelBuffersCount = Integer.MAX_VALUE;
+ long bestChannelBytesInQueue = Long.MAX_VALUE;
+ int bestChannel = 0;
+ for (int i = 1; i <= maxTraverseSize; i++) {
+            int candidateChannel = (currentChannel + i) % numberOfSubpartitions;
+ int candidateChannelBuffersCount =
+ targetPartition.getBuffersCountUnsafe(candidateChannel);
+ long candidateChannelBytesInQueue =
+ targetPartition.getBytesInQueueUnsafe(candidateChannel);
+
+ if (candidateChannelBuffersCount == 0) {
+                // If there isn't any pending data in the candidate channel, choose this channel
+                // directly.
+ return candidateChannel;
+ }
+
+ if (candidateChannelBuffersCount < bestChannelBuffersCount
+ || (candidateChannelBuffersCount == bestChannelBuffersCount
+                            && candidateChannelBytesInQueue < bestChannelBytesInQueue)) {
+ bestChannel = candidateChannel;
+ bestChannelBuffersCount = candidateChannelBuffersCount;
+ bestChannelBytesInQueue = candidateChannelBytesInQueue;
+ }
+ }
+ return bestChannel;
+ }
+
+    /** Copied from {@link ChannelSelectorRecordWriter#broadcastEmit}. */
+ @Override
+ public void broadcastEmit(T record) throws IOException {
+ checkErroneous();
+
+ // Emitting to all channels in a for loop can be better than calling
+ // ResultPartitionWriter#broadcastRecord because the broadcastRecord
+ // method incurs extra overhead.
+ ByteBuffer serializedRecord = serializeRecord(serializer, record);
+        for (int channelIndex = 0; channelIndex < numberOfSubpartitions; channelIndex++) {
+ serializedRecord.rewind();
+ emit(record, channelIndex);
+ }
+
+ if (flushAlways) {
+ flushAll();
+ }
+ }
+}
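
To make the selection loop above concrete, here is a standalone sketch of the same heuristic: an empty channel wins immediately, otherwise the fewest queued buffers wins with queued bytes as the tie-breaker, and the search starts round-robin after the last chosen channel. All names and numbers are illustrative, not part of the commit.

public class IdlestChannelExample {
    public static void main(String[] args) {
        int[] buffersPerChannel = {2, 3, 0, 4};
        long[] bytesPerChannel = {10L, 30L, 0L, 40L};
        int maxTraverseSize = Math.min(2, buffersPerChannel.length);
        int currentChannel = -1; // no channel chosen yet

        int bestChannel = 0;
        int bestBuffers = Integer.MAX_VALUE;
        long bestBytes = Long.MAX_VALUE;
        for (int i = 1; i <= maxTraverseSize; i++) {
            int candidate = (currentChannel + i) % buffersPerChannel.length;
            if (buffersPerChannel[candidate] == 0) {
                bestChannel = candidate; // an empty channel wins immediately
                break;
            }
            if (buffersPerChannel[candidate] < bestBuffers
                    || (buffersPerChannel[candidate] == bestBuffers
                            && bytesPerChannel[candidate] < bestBytes)) {
                bestChannel = candidate;
                bestBuffers = buffersPerChannel[candidate];
                bestBytes = bytesPerChannel[candidate];
            }
        }
        // Only channels 0 and 1 fall inside the 2-channel window, so channel 0
        // wins (2 buffers < 3 buffers) and this prints 0. Channel 2 is empty but
        // outside the window, illustrating how maxTraverseSize bounds the search.
        System.out.println(bestChannel);
    }
}
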
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterBuilder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterBuilder.java
index 78e6424844d..d730a73a7fe 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterBuilder.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterBuilder.java
@@ -18,6 +18,7 @@
package org.apache.flink.runtime.io.network.api.writer;
+import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
import org.apache.flink.core.io.IOReadableWritable;
/** Utility class to encapsulate the logic of building a {@link RecordWriter} instance. */
@@ -29,6 +30,11 @@ public class RecordWriterBuilder<T extends IOReadableWritable> {
private String taskName = "test";
+ private boolean enabledAdaptivePartitioner = false;
+
+ private int maxTraverseSize =
+            NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE.defaultValue();
+
    public RecordWriterBuilder<T> setChannelSelector(ChannelSelector<T> selector) {
this.selector = selector;
return this;
@@ -44,11 +50,24 @@ public class RecordWriterBuilder<T extends IOReadableWritable> {
return this;
}
+ public RecordWriterBuilder<T> setEnabledAdaptivePartitioner(
+ boolean enabledAdaptivePartitioner) {
+ this.enabledAdaptivePartitioner = enabledAdaptivePartitioner;
+ return this;
+ }
+
+ public RecordWriterBuilder<T> setMaxTraverseSize(int maxTraverseSize) {
+ this.maxTraverseSize = maxTraverseSize;
+ return this;
+ }
+
public RecordWriter<T> build(ResultPartitionWriter writer) {
if (selector.isBroadcast()) {
return new BroadcastRecordWriter<>(writer, timeout, taskName);
- } else {
-            return new ChannelSelectorRecordWriter<>(writer, selector, timeout, taskName);
}
+ if (enabledAdaptivePartitioner) {
+            return new AdaptiveLoadBasedRecordWriter<>(writer, timeout, taskName, maxTraverseSize);
+        }
+        return new ChannelSelectorRecordWriter<>(writer, selector, timeout, taskName);
}
}
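
A hedged sketch of using the new builder flags directly; `selector` and `partitionWriter` are assumed to be an existing ChannelSelector and ResultPartitionWriter in scope, so this is a fragment rather than a complete program. Note that, per the build() order above, a broadcast selector still takes precedence over the adaptive flag.

RecordWriter<IOReadableWritable> recordWriter =
        new RecordWriterBuilder<IOReadableWritable>()
                .setChannelSelector(selector) // ignored when the adaptive path is taken
                .setTimeout(100) // flush timeout in milliseconds
                .setTaskName("example-task")
                .setEnabledAdaptivePartitioner(true) // opt in to AdaptiveLoadBasedRecordWriter
                .setMaxTraverseSize(4)
                .build(partitionWriter);
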
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
index e283fac596f..04cfa0ad33d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
@@ -60,6 +60,14 @@ public interface ResultPartitionWriter extends AutoCloseable, AvailabilityProvid
/** Writes the given serialized record to the target subpartition. */
    void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException;
+    /** Returns the number of bytes currently queued in the target subpartition, read without synchronization. */
+    default long getBytesInQueueUnsafe(int targetSubpartition) {
+        return 0;
+    }
+
+    /** Returns the number of buffers currently requested for the target subpartition, read without synchronization. */
+    default int getBuffersCountUnsafe(int targetSubpartition) {
+        return 0;
+    }
+
/**
     * Writes the given serialized record to all subpartitions. One can also achieve the same effect
     * by emitting the same record to all subpartitions one by one, however, this method can have
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferPool.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferPool.java
index c574607e28e..5061d08353d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferPool.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferPool.java
@@ -75,4 +75,9 @@ public interface BufferPool extends BufferProvider, BufferRecycler {
/** Returns the number of used buffers of this buffer pool. */
int bestEffortGetNumOfUsedBuffers();
+
+    /** Returns the number of requested buffers for the target channel. */
+ default int getBuffersCountUnsafe(int targetChannel) {
+ return 0;
+ }
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java
index 873414c6fe2..f31bd95f1a1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java
@@ -824,4 +824,9 @@ public class LocalBufferPool implements BufferPool {
}
}
}
+
+ @Override
+ public int getBuffersCountUnsafe(int targetChannel) {
+ return subpartitionBuffersCount[targetChannel];
+ }
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java
index eae9260642a..334647a367b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java
@@ -65,7 +65,7 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
private TimerGauge hardBackPressuredTimeMsPerSecond = new TimerGauge();
- private long totalWrittenBytes;
+ private final long[] writtenBytesPerSubpartition;
public BufferWritingResultPartition(
String owningTaskName,
@@ -91,6 +91,7 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
this.subpartitions = checkNotNull(subpartitions);
this.unicastBufferBuilders = new BufferBuilder[subpartitions.length];
+ this.writtenBytesPerSubpartition = new long[subpartitions.length];
}
@Override
@@ -114,6 +115,11 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
@Override
public long getSizeOfQueuedBuffersUnsafe() {
+ long totalWrittenBytes = 0;
+ for (int i = 0; i < subpartitions.length; i++) {
+ totalWrittenBytes += writtenBytesPerSubpartition[i];
+ }
+
long totalNumberOfBytes = 0;
for (ResultSubpartition subpartition : subpartitions) {
@@ -123,6 +129,12 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
return totalWrittenBytes - totalNumberOfBytes;
}
+ @Override
+ public long getBytesInQueueUnsafe(int targetSubpartition) {
+ return writtenBytesPerSubpartition[targetSubpartition]
+                - subpartitions[targetSubpartition].getTotalNumberOfBytesUnsafe();
+ }
+
@Override
public int getNumberOfQueuedBuffers(int targetSubpartition) {
        checkArgument(targetSubpartition >= 0 && targetSubpartition < numSubpartitions);
@@ -151,7 +163,7 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
@Override
    public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException {
- totalWrittenBytes += record.remaining();
+ writtenBytesPerSubpartition[targetSubpartition] += record.remaining();
        BufferBuilder buffer = appendUnicastDataForNewRecord(record, targetSubpartition);
@@ -171,7 +183,9 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
@Override
public void broadcastRecord(ByteBuffer record) throws IOException {
- totalWrittenBytes += ((long) record.remaining() * numSubpartitions);
+ for (int i = 0; i < subpartitions.length; i++) {
+ writtenBytesPerSubpartition[i] += record.remaining();
+ }
BufferBuilder buffer = appendBroadcastDataForNewRecord(record);
@@ -197,11 +211,11 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
try (BufferConsumer eventBufferConsumer =
EventSerializer.toBufferConsumer(event, isPriorityEvent)) {
-            totalWrittenBytes += ((long) eventBufferConsumer.getWrittenBytes() * numSubpartitions);
-            for (ResultSubpartition subpartition : subpartitions) {
+            for (int i = 0; i < subpartitions.length; i++) {
                 // Retain the buffer so that it can be recycled by each subpartition of
                 // targetPartition
-                subpartition.add(eventBufferConsumer.copy(), 0);
+                subpartitions[i].add(eventBufferConsumer.copy(), 0);
+                writtenBytesPerSubpartition[i] += eventBufferConsumer.getWrittenBytes();
}
}
}
@@ -246,8 +260,8 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
finishBroadcastBufferBuilder();
finishUnicastBufferBuilders();
- for (ResultSubpartition subpartition : subpartitions) {
- totalWrittenBytes += subpartition.finish();
+ for (int i = 0; i < subpartitions.length; i++) {
+ writtenBytesPerSubpartition[i] += subpartitions[i].finish();
}
super.finish();
@@ -340,7 +354,7 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
protected int addToSubpartition(
            int targetSubpartition, BufferConsumer bufferConsumer, int partialRecordLength)
            throws IOException {
-        totalWrittenBytes += bufferConsumer.getWrittenBytes();
+        writtenBytesPerSubpartition[targetSubpartition] += bufferConsumer.getWrittenBytes();
        return subpartitions[targetSubpartition].add(bufferConsumer, partialRecordLength);
}
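
The switch from a single totalWrittenBytes counter to writtenBytesPerSubpartition makes the queue-size arithmetic local to each channel: bytes currently queued in subpartition i are the bytes ever written to i minus the bytes already drained from i. A toy illustration with invented numbers:

// Invented numbers, mirroring getBytesInQueueUnsafe(i) above.
long[] writtenBytesPerSubpartition = {100L, 250L};
long[] drainedBytes = {40L, 250L}; // stand-in for getTotalNumberOfBytesUnsafe()
long queued0 = writtenBytesPerSubpartition[0] - drainedBytes[0]; // 60 bytes still queued
long queued1 = writtenBytesPerSubpartition[1] - drainedBytes[1]; // 0 bytes: fully drained
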
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
index 6cbcfc0c598..47b52caa8d6 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
@@ -202,6 +202,10 @@ public abstract class ResultPartition implements ResultPartitionWriter {
    /** Returns the number of queued buffers of the given target subpartition. */
public abstract int getNumberOfQueuedBuffers(int targetSubpartition);
+ public int getBuffersCountUnsafe(int targetSubpartition) {
+ return bufferPool.getBuffersCountUnsafe(targetSubpartition);
+ }
+
public void setMaxOverdraftBuffersPerGate(int maxOverdraftBuffersPerGate) {
this.bufferPool.setMaxOverdraftBuffersPerGate(maxOverdraftBuffersPerGate);
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 505e67a30d7..07167659cb0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.operators.MailboxExecutor;
import org.apache.flink.api.common.operators.ProcessingTimeService.ProcessingTimeCallback;
import org.apache.flink.configuration.CheckpointingOptions;
import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
import org.apache.flink.configuration.TaskManagerOptions;
import org.apache.flink.core.execution.RecoveryClaimMode;
import org.apache.flink.core.fs.AutoCloseableRegistry;
@@ -80,6 +81,7 @@ import org.apache.flink.runtime.taskmanager.AsyncExceptionHandler;
import org.apache.flink.runtime.taskmanager.AsynchronousException;
import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory;
import org.apache.flink.runtime.taskmanager.Task;
+import org.apache.flink.runtime.util.ConfigurationParserUtils;
import org.apache.flink.streaming.api.graph.NonChainedOutput;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
@@ -95,6 +97,7 @@ import org.apache.flink.streaming.runtime.io.checkpointing.CheckpointBarrierHand
import org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner;
import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
+import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.mailbox.GaugePeriodTimer;
@@ -1830,17 +1833,41 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
                ((ConfigurableStreamPartitioner) outputPartitioner).configure(numKeyGroups);
}
}
+ Configuration conf = environment.getJobConfiguration();
+ final boolean enabledAdaptivePartitioner =
+ (outputPartitioner instanceof RebalancePartitioner
+                        || outputPartitioner instanceof RescalePartitioner)
+                        && conf.get(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_ENABLED)
+ && bufferWriter.getNumberOfSubpartitions() > 1;
+ final int maxTraverseSize = getAndCheckMaxTraverseSize(conf);
RecordWriter<SerializationDelegate<StreamRecord<OUT>>> output =
                new RecordWriterBuilder<SerializationDelegate<StreamRecord<OUT>>>()
.setChannelSelector(outputPartitioner)
.setTimeout(bufferTimeout)
.setTaskName(taskNameWithSubtask)
+                        .setEnabledAdaptivePartitioner(enabledAdaptivePartitioner)
+ .setMaxTraverseSize(maxTraverseSize)
.build(bufferWriter);
output.setMetricGroup(environment.getMetricGroup().getIOMetricGroup());
return output;
}
+ @VisibleForTesting
+ static int getAndCheckMaxTraverseSize(Configuration jobConf) {
+ final int maxTraverseSize =
+                jobConf.get(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE);
+        ConfigurationParserUtils.checkConfigParameter(
+                maxTraverseSize > 1,
+                maxTraverseSize,
+                NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE.key(),
+                String.format(
+                        "The value of '%s' must be greater than 1 when '%s' is enabled.",
+                        NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE.key(),
+                        NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_ENABLED.key()));
+ return maxTraverseSize;
+ }
+
private void handleTimerException(Exception ex) {
        handleAsyncException("Caught exception while processing timer.", new TimerException(ex));
}
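
A minimal sketch of the fail-fast validation above, mirroring the unit test added later in this commit. The class name is invented, and it would need to live in the org.apache.flink.streaming.runtime.tasks package because the helper is package-private.

package org.apache.flink.streaming.runtime.tasks;

import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.IllegalConfigurationException;
import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;

public class MaxTraverseSizeCheckExample {
    public static void main(String[] args) {
        Configuration conf = new Configuration();
        conf.set(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE, 1);
        try {
            StreamTask.getAndCheckMaxTraverseSize(conf);
        } catch (IllegalConfigurationException e) {
            // A traverse size of 1 performs no load comparison, so it is rejected up front.
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
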
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriterTest.java
new file mode 100644
index 00000000000..b835a31a5da
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriterTest.java
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.api.writer;
+
+import org.apache.flink.core.io.IOReadableWritable;
+import org.apache.flink.runtime.checkpoint.CheckpointException;
+import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.network.api.StopMode;
+import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
+import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
+
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.stream.Stream;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Test for {@link AdaptiveLoadBasedRecordWriter}. */
+class AdaptiveLoadBasedRecordWriterTest {
+
+ static Stream<Arguments> getTestingParams() {
+ return Stream.of(
+ // maxTraverseSize, bytesPerPartition, bufferPerPartition,
+ // targetResultPartitionIndex
+                Arguments.of(2, new long[] {1L, 2L, 3L}, new int[] {2, 3, 4}, 0),
+                Arguments.of(2, new long[] {0L, 0L, 0L}, new int[] {2, 3, 4}, 0),
+                Arguments.of(2, new long[] {0L, 0L, 0L}, new int[] {0, 0, 0}, 0),
+                Arguments.of(3, new long[] {1L, 2L, 3L}, new int[] {2, 3, 4}, 0),
+                Arguments.of(3, new long[] {0L, 0L, 0L}, new int[] {2, 3, 4}, 0),
+                Arguments.of(3, new long[] {0L, 0L, 0L}, new int[] {0, 0, 0}, 0),
+                Arguments.of(
+                        2, new long[] {1L, 2L, 3L, 1L, 2L, 3L}, new int[] {2, 3, 4, 2, 3, 4}, 0),
+                Arguments.of(
+                        2, new long[] {0L, 0L, 3L, 1L, 2L, 3L}, new int[] {3, 2, 4, 2, 3, 4}, 1),
+                Arguments.of(
+                        2, new long[] {0L, 0L, 3L, 1L, 2L, 3L}, new int[] {0, 0, 4, 2, 3, 4}, 0),
+                Arguments.of(
+                        4, new long[] {1L, 2L, 3L, 0L, 2L, 3L}, new int[] {2, 3, 4, 2, 3, 4}, 3),
+                Arguments.of(
+                        4, new long[] {1L, 1L, 1L, 1L, 2L, 3L}, new int[] {2, 3, 4, 0, 3, 4}, 3),
+                Arguments.of(
+                        4, new long[] {0L, 0L, 0L, 0L, 2L, 3L}, new int[] {2, 3, 0, 2, 3, 4}, 2));
+ }
+
+ @ParameterizedTest(
+ name =
+                    "maxTraverseSize: {0}, bytesPerPartition: {1}, bufferPerPartition: {2}, targetResultPartitionIndex: {3}")
+ @MethodSource("getTestingParams")
+ void testGetIdlestChannelIndex(
+ int maxTraverseSize,
+ long[] bytesPerPartition,
+ int[] buffersPerPartition,
+ int targetResultPartitionIndex) {
+ TestingResultPartitionWriter resultPartitionWriter =
+                getTestingResultPartitionWriter(bytesPerPartition, buffersPerPartition);
+
+        AdaptiveLoadBasedRecordWriter<IOReadableWritable> adaptiveLoadBasedRecordWriter =
+                new AdaptiveLoadBasedRecordWriter<>(
+                        resultPartitionWriter, 5L, "testingTask", maxTraverseSize);
+ assertThat(adaptiveLoadBasedRecordWriter.getIdlestChannelIndex())
+ .isEqualTo(targetResultPartitionIndex);
+ }
+
+    private static TestingResultPartitionWriter getTestingResultPartitionWriter(
+ long[] bytesPerPartition, int[] buffersPerPartition) {
+ final Map<Integer, Long> bytesPerPartitionMap = new HashMap<>();
+ final Map<Integer, Integer> bufferPerPartitionMap = new HashMap<>();
+ for (int i = 0; i < bytesPerPartition.length; i++) {
+ bytesPerPartitionMap.put(i, bytesPerPartition[i]);
+ bufferPerPartitionMap.put(i, buffersPerPartition[i]);
+ }
+
+ return new TestingResultPartitionWriter(
+                buffersPerPartition.length, bytesPerPartitionMap, bufferPerPartitionMap);
+ }
+
+ /** Test utils class to simulate {@link ResultPartitionWriter}. */
+    static final class TestingResultPartitionWriter implements ResultPartitionWriter {
+
+ private final int numberOfSubpartitions;
+ private final Map<Integer, Long> bytesPerPartition;
+ private final Map<Integer, Integer> bufferPerPartition;
+
+ TestingResultPartitionWriter(
+ int numberOfSubpartitions,
+ Map<Integer, Long> bytesPerPartition,
+ Map<Integer, Integer> bufferPerPartition) {
+ this.numberOfSubpartitions = numberOfSubpartitions;
+ this.bytesPerPartition = bytesPerPartition;
+ this.bufferPerPartition = bufferPerPartition;
+ }
+
+ // The methods that are used in the testing.
+
+ @Override
+ public long getBytesInQueueUnsafe(int targetSubpartition) {
+ return bytesPerPartition.getOrDefault(targetSubpartition, 0L);
+ }
+
+ @Override
+ public int getBuffersCountUnsafe(int targetSubpartition) {
+ return bufferPerPartition.getOrDefault(targetSubpartition, 0);
+ }
+
+ @Override
+ public int getNumberOfSubpartitions() {
+ return numberOfSubpartitions;
+ }
+
+ // The methods that are not used.
+
+ @Override
+ public void setup() throws IOException {}
+
+ @Override
+ public ResultPartitionID getPartitionId() {
+ return null;
+ }
+
+ @Override
+ public int getNumTargetKeyGroups() {
+ return 0;
+ }
+
+ @Override
+        public void setMaxOverdraftBuffersPerGate(int maxOverdraftBuffersPerGate) {}
+
+ @Override
+        public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException {}
+
+ @Override
+ public void broadcastRecord(ByteBuffer record) throws IOException {}
+
+ @Override
+        public void broadcastEvent(AbstractEvent event, boolean isPriorityEvent)
+                throws IOException {}
+
+ @Override
+        public void alignedBarrierTimeout(long checkpointId) throws IOException {}
+
+ @Override
+        public void abortCheckpoint(long checkpointId, CheckpointException cause) {}
+
+ @Override
+ public void notifyEndOfData(StopMode mode) throws IOException {}
+
+ @Override
+ public CompletableFuture<Void> getAllDataProcessedFuture() {
+ return null;
+ }
+
+ @Override
+ public void setMetricGroup(TaskIOMetricGroup metrics) {}
+
+ @Override
+ public ResultSubpartitionView createSubpartitionView(
+ ResultSubpartitionIndexSet indexSet,
+ BufferAvailabilityListener availabilityListener)
+ throws IOException {
+ return null;
+ }
+
+ @Override
+ public void flushAll() {}
+
+ @Override
+ public void flush(int subpartitionIndex) {}
+
+ @Override
+ public void fail(@Nullable Throwable throwable) {}
+
+ @Override
+ public void finish() throws IOException {}
+
+ @Override
+ public boolean isFinished() {
+ return false;
+ }
+
+ @Override
+ public void release(Throwable cause) {}
+
+ @Override
+ public boolean isReleased() {
+ return false;
+ }
+
+ @Override
+ public void close() throws Exception {}
+
+ @Override
+ public CompletableFuture<?> getAvailableFuture() {
+ return null;
+ }
+ }
+}
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
index de7c3d6f5a8..1b42449f539 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
@@ -28,6 +28,8 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.configuration.CheckpointingOptions;
import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.IllegalConfigurationException;
+import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.core.execution.SavepointFormatType;
import org.apache.flink.core.fs.FSDataInputStream;
@@ -1885,6 +1887,27 @@ public class StreamTaskTest {
}
}
+ @Test
+ void testGetAndCheckMaxTraverseSize() {
+ Configuration config = new Configuration();
+ assertThat(StreamTask.getAndCheckMaxTraverseSize(config)).isEqualTo(4);
+
+        config.set(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE, 2);
+        assertThat(StreamTask.getAndCheckMaxTraverseSize(config)).isEqualTo(2);
+
+        config.set(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE, -1);
+        assertThatThrownBy(() -> StreamTask.getAndCheckMaxTraverseSize(config))
+                .isInstanceOf(IllegalConfigurationException.class);
+
+        config.set(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE, 0);
+        assertThatThrownBy(() -> StreamTask.getAndCheckMaxTraverseSize(config))
+                .isInstanceOf(IllegalConfigurationException.class);
+
+        config.set(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE, 1);
+ assertThatThrownBy(() -> StreamTask.getAndCheckMaxTraverseSize(config))
+ .isInstanceOf(IllegalConfigurationException.class);
+ }
+
private int getCurrentBufferSize(InputGate inputGate) {
return getTestChannel(inputGate, 0).getCurrentBufferSize();
}