StephanEwen commented on a change in pull request #13595:
URL: https://github.com/apache/flink/pull/13595#discussion_r510248527
##########
File path: flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferReaderWriterUtil.java
##########
@@ -201,7 +201,7 @@ private static boolean tryReadByteBuffer(FileChannel channel, ByteBuffer b) throws IOException {
	}
}
-	private static void readByteBufferFully(FileChannel channel, ByteBuffer b) throws IOException {
+	public static void readByteBufferFully(FileChannel channel, ByteBuffer b) throws IOException {
Review comment:
Make this package-private instead, to be consistent with the visibility
of the class.
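For illustration, a minimal sketch of the suggested change (no access modifier means package-private in Java); the loop body is an assumption based on the method name and the surrounding `tryReadByteBuffer` context:

	// assumes java.io.EOFException is imported for the premature-end case
	static void readByteBufferFully(FileChannel channel, ByteBuffer b) throws IOException {
		// keep reading until the buffer is full; a single read() may return early
		while (b.hasRemaining()) {
			if (channel.read(b) == -1) {
				throw new EOFException();
			}
		}
	}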
##########
File path: flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PartitionSortedBuffer.java
##########
@@ -0,0 +1,390 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition;
+
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+
+import javax.annotation.concurrent.NotThreadSafe;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+
+import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType;
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * A {@link SortBuffer} implementation which sorts all appended records only by subpartition index. Records of the
+ * same subpartition keep the appended order.
+ *
+ * <p>It maintains a list of {@link MemorySegment}s as a joint buffer. Data will be appended to the joint buffer
+ * sequentially. When writing a record, an index entry is appended first. Each index entry has 3 fields: 4 bytes
+ * for the record length, 4 bytes for the {@link DataType} and 8 bytes for the address of the next index entry of
+ * the same channel, which is used to locate the next record to read when copying data from this {@link SortBuffer}.
+ * For simplicity, no index entry can span multiple segments. The corresponding record data sits right after its
+ * index entry; unlike index entries, records have variable length and thus may span multiple segments.
+ */
+@NotThreadSafe
+public class PartitionSortedBuffer implements SortBuffer {
+
+	/**
+	 * Size of an index entry: 4 bytes for record length, 4 bytes for data type and 8 bytes
+	 * for pointer to next entry.
+	 */
+ private static final int INDEX_ENTRY_SIZE = 4 + 4 + 8;
+
+ /** A buffer pool to request memory segments from. */
+ private final BufferPool bufferPool;
+
+ /** A segment list as a joint buffer which stores all records and index
entries. */
+ private final ArrayList<MemorySegment> buffers = new ArrayList<>();
+
+ /** Addresses of the first record's index entry for each subpartition.
*/
+ private final long[] firstIndexEntryAddresses;
+
+ /** Addresses of the last record's index entry for each subpartition. */
+ private final long[] lastIndexEntryAddresses;
+
+ /** Size of buffers requested from buffer pool. All buffers must be of
the same size. */
+ private final int bufferSize;
+
+	// ----------------------------------------------------------------------------------------------
+	// Statistics and states
+	// ----------------------------------------------------------------------------------------------
+
+ /** Total number of bytes already appended to this sort buffer. */
+ private long numTotalBytes;
+
+ /** Total number of records already appended to this sort buffer. */
+ private long numTotalRecords;
+
+ /** Total number of bytes already read from this sort buffer. */
+ private long numTotalBytesRead;
+
+ /** Whether this sort buffer is finished. One can only read a finished
sort buffer. */
+ private boolean isFinished;
+
+	/** Whether this sort buffer is released. A released sort buffer cannot be used. */
+ private boolean isReleased;
+
+	// ----------------------------------------------------------------------------------------------
+	// For writing
+	// ----------------------------------------------------------------------------------------------
+
+ /** Array index in the segment list of the current available buffer for
writing. */
+ private int writeSegmentIndex;
+
+ /** Next position in the current available buffer for writing. */
+ private int writeSegmentOffset;
+
+	// ----------------------------------------------------------------------------------------------
+	// For reading
+	// ----------------------------------------------------------------------------------------------
+
+ /** Index entry address of the current record or event to be read. */
+ private long readIndexEntryAddress;
+
+ /** Record bytes remaining after last copy, which must be read first in
next copy. */
+ private int recordRemainingBytes;
+
+ /** Current available channel to read data from. */
+ private int readChannelIndex = -1;
+
+ public PartitionSortedBuffer(BufferPool bufferPool, int
numSubpartitions, int bufferSize) {
+ checkArgument(bufferSize > INDEX_ENTRY_SIZE, "Buffer size is
too small.");
+
+ this.bufferPool = checkNotNull(bufferPool);
+ this.bufferSize = bufferSize;
+ this.firstIndexEntryAddresses = new long[numSubpartitions];
+ this.lastIndexEntryAddresses = new long[numSubpartitions];
+
+		// initialize with -1 to indicate that the corresponding channel has no data
+ Arrays.fill(firstIndexEntryAddresses, -1L);
+ Arrays.fill(lastIndexEntryAddresses, -1L);
+ }
+
+ @Override
+ public boolean append(ByteBuffer source, int targetChannel, DataType
dataType) throws IOException {
+ checkState(!isFinished, "Sort buffer is already finished.");
+ checkState(!isReleased, "Sort buffer is already released.");
+
+ int totalBytes = source.remaining();
+ if (totalBytes == 0) {
+ return true;
+ }
+
+		// return false directly if we cannot allocate enough buffers for the given record
+ if (!allocateBuffersForRecord(totalBytes)) {
+ return false;
+ }
+
+ // write the index entry and record or event data
+ writeIndex(targetChannel, totalBytes, dataType);
+ writeRecord(source);
+
+ ++numTotalRecords;
+ numTotalBytes += totalBytes;
+
+ return true;
+ }
+
+ private void writeIndex(int channelIndex, int numRecordBytes,
Buffer.DataType dataType) {
+ MemorySegment segment = buffers.get(writeSegmentIndex);
+
+ // record length takes the high 32 bits and data type takes the
low 32 bits
+ segment.putLong(writeSegmentOffset, ((long) numRecordBytes <<
32) | dataType.ordinal());
+
+ // segment index takes the high 32 bits and segment offset
takes the low 32 bits
+ long indexEntryAddress = ((long) writeSegmentIndex << 32) |
writeSegmentOffset;
+
+ long lastIndexEntryAddress =
lastIndexEntryAddresses[channelIndex];
+ lastIndexEntryAddresses[channelIndex] = indexEntryAddress;
+
+ if (lastIndexEntryAddress >= 0) {
+ // link the previous index entry of the given channel
to the new index entry
+ segment =
buffers.get(getHigh32BitsFromLongAsInteger(lastIndexEntryAddress));
+
segment.putLong(getLow32BitsFromLongAsInteger(lastIndexEntryAddress) + 8,
indexEntryAddress);
+ } else {
+ firstIndexEntryAddresses[channelIndex] =
indexEntryAddress;
+ }
+
+ // move the write position forward so as to write the
corresponding record
+ updateWriteSegmentIndexAndOffset(INDEX_ENTRY_SIZE);
+ }
+
+ private void writeRecord(ByteBuffer source) {
+ while (source.hasRemaining()) {
+ MemorySegment segment = buffers.get(writeSegmentIndex);
+ int toCopy = Math.min(bufferSize - writeSegmentOffset,
source.remaining());
+ segment.put(writeSegmentOffset, source, toCopy);
+
+ // move the write position forward so as to write the
remaining bytes or next record
+ updateWriteSegmentIndexAndOffset(toCopy);
+ }
+ }
+
+ private boolean allocateBuffersForRecord(int numRecordBytes) throws
IOException {
+ int numBytesRequired = INDEX_ENTRY_SIZE + numRecordBytes;
+ int availableBytes = writeSegmentIndex == buffers.size() ? 0 :
bufferSize - writeSegmentOffset;
+
+		// return directly if the currently available bytes are adequate
+ if (availableBytes >= numBytesRequired) {
+ return true;
+ }
+
+		// skip the remaining free space if the available bytes are not enough for an index entry
+ if (availableBytes < INDEX_ENTRY_SIZE) {
+ updateWriteSegmentIndexAndOffset(availableBytes);
+ availableBytes = 0;
+ }
+
+ // allocate exactly enough buffers for the appended record
+ do {
+ MemorySegment segment = requestBufferFromPool();
+ if (segment == null) {
+				// return false if we cannot allocate enough buffers for the appended record
+ return false;
+ }
+
+ assert segment.size() == bufferSize;
+ availableBytes += bufferSize;
+ buffers.add(segment);
+ } while (availableBytes < numBytesRequired);
+
+ return true;
+ }
+
+ private MemorySegment requestBufferFromPool() throws IOException {
+ try {
+			// request buffers in a blocking way if there is still guaranteed memory
+ if (buffers.size() <
bufferPool.getNumberOfRequiredMemorySegments()) {
+ return
bufferPool.requestBufferBuilderBlocking().getMemorySegment();
+ }
+ } catch (InterruptedException e) {
+ throw new IOException("Interrupted while requesting
buffer.");
+ }
+
+ BufferBuilder buffer = bufferPool.requestBufferBuilder();
+ return buffer != null ? buffer.getMemorySegment() : null;
+ }
+
+ private void updateWriteSegmentIndexAndOffset(int numBytes) {
+ writeSegmentOffset += numBytes;
+
+		// use the next available free buffer if the current one is full
+ if (writeSegmentOffset == bufferSize) {
+ ++writeSegmentIndex;
+ writeSegmentOffset = 0;
+ }
+ }
+
+ @Override
+ public BufferWithChannel copyData(MemorySegment target) {
+ checkState(hasRemaining(), "No data remaining.");
+		checkState(isFinished, "Should finish the sort buffer first before copying any data.");
+ checkState(!isReleased, "Sort buffer is already released.");
+
+ int numBytesCopied = 0;
+ DataType bufferDataType = DataType.DATA_BUFFER;
+ int channelIndex = readChannelIndex;
+
+ do {
+ int sourceSegmentIndex =
getHigh32BitsFromLongAsInteger(readIndexEntryAddress);
+ int sourceSegmentOffset =
getLow32BitsFromLongAsInteger(readIndexEntryAddress);
+ MemorySegment sourceSegment =
buffers.get(sourceSegmentIndex);
+
+ long lengthAndDataType =
sourceSegment.getLong(sourceSegmentOffset);
+ int length =
getHigh32BitsFromLongAsInteger(lengthAndDataType);
+ DataType dataType =
DataType.values()[getLow32BitsFromLongAsInteger(lengthAndDataType)];
+
+			// return the data read so far if the next entry to read is an event
+ if (dataType.isEvent() && numBytesCopied > 0) {
+ break;
+ }
+ bufferDataType = dataType;
+
+ // get the next index entry address and move the read
position forward
+ long nextReadIndexEntryAddress =
sourceSegment.getLong(sourceSegmentOffset + 8);
+ sourceSegmentOffset += INDEX_ENTRY_SIZE;
+
+ // allocate a temp buffer for the event if the target
buffer is not big enough
+ if (bufferDataType.isEvent() && target.size() < length)
{
+ target =
MemorySegmentFactory.allocateUnpooledSegment(length);
+ }
+
+ numBytesCopied += copyRecordOrEvent(
+ target, numBytesCopied, sourceSegmentIndex,
sourceSegmentOffset, length);
+
+ if (recordRemainingBytes == 0) {
+ // move to next channel if the current channel
has been finished
+ if (readIndexEntryAddress ==
lastIndexEntryAddresses[channelIndex]) {
+ updateReadChannelAndIndexEntryAddress();
+ break;
+ }
+ readIndexEntryAddress =
nextReadIndexEntryAddress;
+ }
+ } while (numBytesCopied < target.size() &&
bufferDataType.isBuffer());
+
+ numTotalBytesRead += numBytesCopied;
+ Buffer buffer = new NetworkBuffer(target, (buf) -> {},
bufferDataType, numBytesCopied);
+ return new BufferWithChannel(buffer, channelIndex);
+ }
+
+ private int copyRecordOrEvent(
+ MemorySegment targetSegment,
+ int targetSegmentOffset,
+ int sourceSegmentIndex,
+ int sourceSegmentOffset,
+ int recordLength) {
+ if (recordRemainingBytes > 0) {
+ // skip the data already read if there is remaining
partial record after the previous copy
+ long position = (long) sourceSegmentOffset +
(recordLength - recordRemainingBytes);
+ sourceSegmentIndex += (position / bufferSize);
+ sourceSegmentOffset = (int) (position % bufferSize);
+ } else {
+ recordRemainingBytes = recordLength;
+ }
+
+ int targetSegmentSize = targetSegment.size();
+ int numBytesToCopy = Math.min(targetSegmentSize -
targetSegmentOffset, recordRemainingBytes);
+ do {
+ // move to next data buffer if all data of the current
buffer has been copied
+ if (sourceSegmentOffset == bufferSize) {
+ ++sourceSegmentIndex;
+ sourceSegmentOffset = 0;
+ }
+
+ int sourceRemainingBytes = Math.min(bufferSize -
sourceSegmentOffset, recordRemainingBytes);
+ int numBytes = Math.min(targetSegmentSize -
targetSegmentOffset, sourceRemainingBytes);
+ MemorySegment sourceSegment =
buffers.get(sourceSegmentIndex);
+ sourceSegment.copyTo(sourceSegmentOffset,
targetSegment, targetSegmentOffset, numBytes);
+
+ recordRemainingBytes -= numBytes;
+ targetSegmentOffset += numBytes;
+ sourceSegmentOffset += numBytes;
+ } while ((recordRemainingBytes > 0 && targetSegmentOffset <
targetSegmentSize));
+
+ return numBytesToCopy;
+ }
+
+ private void updateReadChannelAndIndexEntryAddress() {
+ // skip the channels without any data
+ while (++readChannelIndex < firstIndexEntryAddresses.length) {
+ if ((readIndexEntryAddress =
firstIndexEntryAddresses[readChannelIndex]) >= 0) {
+ break;
+ }
+ }
+ }
+
+ private int getHigh32BitsFromLongAsInteger(long value) {
Review comment:
Maybe rename this to `getSegmentIndexFromPointer`?
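A sketch of the suggested rename, together with the matching rename suggested below for the low-32-bits helper; the bit layout follows the `writeIndex` code above (segment index in the high 32 bits, segment offset in the low 32 bits):

	private int getSegmentIndexFromPointer(long value) {
		// the segment index is stored in the high 32 bits of the index entry address
		return (int) (value >>> 32);
	}

	private int getSegmentOffsetFromPointer(long value) {
		// the segment offset is stored in the low 32 bits of the index entry address
		return (int) (value);
	}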
##########
File path: flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
##########
@@ -173,6 +173,28 @@
" help relieve back-pressure caused by
unbalanced data distribution among the subpartitions. This value should be" +
" increased in case of higher round trip times
between nodes and/or larger number of machines in the cluster.");
+	/**
+	 * Maximum number of network buffers that can be used per sort-merge blocking result partition.
+	 */
+ @Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
+ public static final ConfigOption<Integer>
NETWORK_MAX_BUFFERS_PER_SORT_MERGE_PARTITION =
+
key("taskmanager.network.sort-merge-blocking-shuffle.max-buffers-per-partition")
+ .defaultValue(2048)
+			.withDescription("Maximum number of network buffers that can be used per sort-merge blocking result partition. " +
+				"This value is only an upper bound and does not mean that the sort-merge blocking result partition" +
+				" will always use this many network buffers.");
+
+ /**
+ * Parallelism threshold to switch between sort-merge based blocking
shuffle and the default hash-based blocking shuffle.
+ */
+ @Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
+ public static final ConfigOption<Integer>
NETWORK_SORT_MERGE_SHUFFLE_MIN_PARALLELISM =
+
key("taskmanager.network.sort-merge-blocking-shuffle.min-parallelism")
Review comment:
Similarly, maybe use `taskmanager.network.sort-shuffle.min-parallelism` here.
##########
File path: flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
##########
@@ -173,6 +173,28 @@
" help relieve back-pressure caused by
unbalanced data distribution among the subpartitions. This value should be" +
" increased in case of higher round trip times
between nodes and/or larger number of machines in the cluster.");
+	/**
+	 * Maximum number of network buffers that can be used per sort-merge blocking result partition.
+	 */
+ @Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
+ public static final ConfigOption<Integer>
NETWORK_MAX_BUFFERS_PER_SORT_MERGE_PARTITION =
+
key("taskmanager.network.sort-merge-blocking-shuffle.max-buffers-per-partition")
+ .defaultValue(2048)
+			.withDescription("Maximum number of network buffers that can be used per sort-merge blocking result partition. " +
Review comment:
Please add `intType()` here to make the option type-safe at runtime (and avoid the deprecation warnings).
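A sketch of the option with the typed builder applied (description text as in the quoted diff):

	@Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
	public static final ConfigOption<Integer> NETWORK_MAX_BUFFERS_PER_SORT_MERGE_PARTITION =
		key("taskmanager.network.sort-merge-blocking-shuffle.max-buffers-per-partition")
			.intType()          // typed builder: values are validated as integers and the raw defaultValue(...) deprecation warning goes away
			.defaultValue(2048)
			.withDescription("Maximum number of network buffers that can be used per sort-merge" +
				" blocking result partition.");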
##########
File path: flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
##########
@@ -0,0 +1,348 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition;
+
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.disk.FileChannelManager;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.util.function.SupplierWithException;
+
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
+import javax.annotation.concurrent.NotThreadSafe;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+
+import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType;
+import static org.apache.flink.runtime.io.network.partition.SortBuffer.BufferWithChannel;
+import static org.apache.flink.util.Preconditions.checkElementIndex;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * {@link SortMergeResultPartition} appends records and events to a {@link SortBuffer}. Once the {@link SortBuffer}
+ * is full, all data in it is copied and spilled to a {@link PartitionedFile} sequentially, in subpartition index
+ * order. Large records that cannot be appended to an empty {@link SortBuffer} are spilled to the
+ * {@link PartitionedFile} separately.
+ */
+@NotThreadSafe
+public class SortMergeResultPartition extends ResultPartition {
+
+ private final Object lock = new Object();
+
+ /** All active readers which are consuming data from this result
partition now. */
+ @GuardedBy("lock")
+ private final Set<SortMergeSubpartitionReader> readers = new
HashSet<>();
+
+ /** {@link PartitionedFile} produced by this result partition. */
+ @GuardedBy("lock")
+ private PartitionedFile resultFile;
+
+ /** Used to generate random file channel ID. */
+ private final FileChannelManager channelManager;
+
+ /** Number of data buffers (excluding events) written for each
subpartition. */
+ private final int[] numDataBuffers;
+
+ /** A piece of unmanaged memory for data writing. */
+ private final MemorySegment writeBuffer;
+
+ /** Size of network buffer and write buffer. */
+ private final int networkBufferSize;
+
+ /** Current {@link SortBuffer} to append records to. */
+ private SortBuffer currentSortBuffer;
+
+ /** File writer for this result partition. */
+ private PartitionedFileWriter fileWriter;
+
+ public SortMergeResultPartition(
+ String owningTaskName,
+ int partitionIndex,
+ ResultPartitionID partitionId,
+ ResultPartitionType partitionType,
+ int numSubpartitions,
+ int numTargetKeyGroups,
+ int networkBufferSize,
+ ResultPartitionManager partitionManager,
+ FileChannelManager channelManager,
+ @Nullable BufferCompressor bufferCompressor,
+ SupplierWithException<BufferPool, IOException>
bufferPoolFactory) {
+
+ super(
+ owningTaskName,
+ partitionIndex,
+ partitionId,
+ partitionType,
+ numSubpartitions,
+ numTargetKeyGroups,
+ partitionManager,
+ bufferCompressor,
+ bufferPoolFactory);
+
+ this.channelManager = checkNotNull(channelManager);
+ this.networkBufferSize = networkBufferSize;
+ this.numDataBuffers = new int[numSubpartitions];
+ this.writeBuffer =
MemorySegmentFactory.allocateUnpooledOffHeapMemory(networkBufferSize);
+ }
+
+ @Override
+ protected void releaseInternal() {
+ synchronized (lock) {
+ isFinished = true; // to fail writing faster
+
+ // delete the produced file only when no reader is
reading now
+ if (readers.isEmpty()) {
+ if (resultFile != null) {
+ resultFile.deleteQuietly();
+ resultFile = null;
+ }
+ }
+ }
+ }
+
+ @Override
+ public void emitRecord(ByteBuffer record, int targetSubpartition)
throws IOException {
+ emit(record, targetSubpartition, DataType.DATA_BUFFER);
+ }
+
+ @Override
+ public void broadcastRecord(ByteBuffer record) throws IOException {
+ broadcast(record, DataType.DATA_BUFFER);
+ }
+
+ @Override
+ public void broadcastEvent(AbstractEvent event, boolean
isPriorityEvent) throws IOException {
+ Buffer buffer = EventSerializer.toBuffer(event,
isPriorityEvent);
+ try {
+ ByteBuffer serializedEvent =
buffer.getNioBufferReadable();
+ broadcast(serializedEvent, buffer.getDataType());
+ } finally {
+ buffer.recycleBuffer();
+ }
+ }
+
+ private void broadcast(ByteBuffer record, DataType dataType) throws
IOException {
+ for (int channelIndex = 0; channelIndex < numSubpartitions;
++channelIndex) {
+ record.rewind();
+ emit(record, channelIndex, dataType);
+ }
+ }
+
+ private void emit(ByteBuffer record, int targetSubpartition, DataType
dataType) throws IOException {
+ checkInProduceState();
+
+ SortBuffer sortBuffer = getSortBuffer();
+ if (sortBuffer.append(record, targetSubpartition, dataType)) {
+ return;
+ }
+
+ if (!sortBuffer.hasRemaining()) {
+			// the record cannot be appended to an empty sort buffer because it is too large
+ releaseCurrentSortBuffer();
+ writeLargeRecord(record, targetSubpartition, dataType);
+ return;
+ }
+
+ flushCurrentSortBuffer();
+ emit(record, targetSubpartition, dataType);
+ }
+
+ private void releaseCurrentSortBuffer() {
+ if (currentSortBuffer != null) {
+ currentSortBuffer.release();
+ currentSortBuffer = null;
+ }
+ }
+
+ private SortBuffer getSortBuffer() {
+ if (currentSortBuffer != null) {
+ return currentSortBuffer;
+ }
+
+ currentSortBuffer = new PartitionSortedBuffer(bufferPool,
numSubpartitions, networkBufferSize);
+ return currentSortBuffer;
+ }
+
+ private void flushCurrentSortBuffer() throws IOException {
+ if (currentSortBuffer == null ||
!currentSortBuffer.hasRemaining()) {
+ releaseCurrentSortBuffer();
+ return;
+ }
+
+ currentSortBuffer.finish();
+ PartitionedFileWriter fileWriter = getPartitionedFileWriter();
+
+ while (currentSortBuffer.hasRemaining()) {
+ BufferWithChannel bufferWithChannel =
currentSortBuffer.copyData(writeBuffer);
+ Buffer buffer = bufferWithChannel.getBuffer();
+ int subpartitionIndex =
bufferWithChannel.getChannelIndex();
+
+ fileWriter.writeBuffer(buffer, subpartitionIndex);
+ updateStatistics(buffer, subpartitionIndex);
+ }
+
+ releaseCurrentSortBuffer();
+ }
+
+ private PartitionedFileWriter getPartitionedFileWriter() throws
IOException {
+ if (fileWriter == null) {
+ String basePath =
channelManager.createChannel().getPath();
+ fileWriter = new PartitionedFileWriter(basePath,
numSubpartitions);
+ fileWriter.open();
+ }
+
+ fileWriter.startNewRegion();
+ return fileWriter;
+ }
+
+ private void updateStatistics(Buffer buffer, int subpartitionIndex) {
+ numBuffersOut.inc();
+ numBytesOut.inc(buffer.readableBytes());
+ if (buffer.isBuffer()) {
+ ++numDataBuffers[subpartitionIndex];
+ }
+ }
+
+ /**
+ * Spills the large record into the target {@link PartitionedFile} as a
separate data region.
+ */
+ private void writeLargeRecord(ByteBuffer record, int
targetSubpartition, DataType dataType) throws IOException {
+ PartitionedFileWriter fileWriter = getPartitionedFileWriter();
+
+ while (record.hasRemaining()) {
+ int toCopy = Math.min(record.remaining(),
writeBuffer.size());
+ writeBuffer.put(0, record, toCopy);
+
+ NetworkBuffer buffer = new NetworkBuffer(writeBuffer,
(buf) -> {}, dataType, toCopy);
+ fileWriter.writeBuffer(buffer, targetSubpartition);
+ }
+ }
+
+ void releaseReader(SortMergeSubpartitionReader reader) {
+ synchronized (lock) {
+ readers.remove(reader);
+
+ // release the result partition if it has been marked
as released
+ if (readers.isEmpty() && isReleased()) {
+ releaseInternal();
+ }
+ }
+ }
+
+ @Override
+ public void finish() throws IOException {
+ checkInProduceState();
+
+ broadcastEvent(EndOfPartitionEvent.INSTANCE, false);
+ flushCurrentSortBuffer();
+
+ synchronized (lock) {
+ checkState(!isReleased());
+
+ resultFile = fileWriter.finish();
+ fileWriter = null;
+
+ LOG.info("New partitioned file produced: {}.",
resultFile);
+ }
+
+ super.finish();
+ }
+
+ @Override
+ public void close() {
+ releaseCurrentSortBuffer();
+
+ if (fileWriter != null) {
+ fileWriter.releaseQuietly();
+ }
+
+ super.close();
+ }
+
+ @Override
+ public ResultSubpartitionView createSubpartitionView(
+ int subpartitionIndex,
+ BufferAvailabilityListener availabilityListener) throws
IOException {
+ synchronized (lock) {
+ checkElementIndex(subpartitionIndex, numSubpartitions,
"Subpartition not found.");
+ checkState(!isReleased(), "Partition released.");
+ checkState(isFinished(), "Trying to read unfinished
blocking partition.");
+
+ SortMergeSubpartitionReader reader = new
SortMergeSubpartitionReader(
+ this,
+ availabilityListener,
+ subpartitionIndex,
+ numDataBuffers[subpartitionIndex],
+ networkBufferSize);
+ readers.add(reader);
+ availabilityListener.notifyDataAvailable();
+
+ return reader;
+ }
+ }
+
+ @Override
+ public void flushAll() {
+ try {
+ flushCurrentSortBuffer();
+ } catch (IOException e) {
+ LOG.error("Failed to flush the current sort buffer.",
e);
+ }
+ }
+
+ @Override
+ public void flush(int subpartitionIndex) {
+ try {
+ flushCurrentSortBuffer();
+ } catch (IOException e) {
+ LOG.error("Failed to flush the current sort buffer.",
e);
+ }
+ }
+
+ @Override
+ public CompletableFuture<?> getAvailableFuture() {
+ return AVAILABLE;
+ }
+
+ @Override
+ public int getNumberOfQueuedBuffers() {
+ return 0;
+ }
+
+ @Override
+ public int getNumberOfQueuedBuffers(int targetSubpartition) {
+ return 0;
+ }
+
+ public PartitionedFile getResultFile() {
Review comment:
This method publicly exposes a field that is otherwise lock-guarded.
It would be good to avoid that, or at least keep it out of the production scope.
You can do one of the following (see the sketch after this list):
- pass this directly to the constructor of `SortMergeSubpartitionReader`
- reduce visibility to package-private and annotate it as
`@VisibleForTesting`.
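A sketch of the second option; the annotation is `org.apache.flink.annotation.VisibleForTesting`. Synchronizing on `lock` inside the getter is an assumption, kept for consistency with the other guarded accesses:

	@VisibleForTesting
	PartitionedFile getResultFile() {
		synchronized (lock) {
			return resultFile;
		}
	}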
##########
File path: flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
##########
@@ -173,6 +173,28 @@
" help relieve back-pressure caused by
unbalanced data distribution among the subpartitions. This value should be" +
" increased in case of higher round trip times
between nodes and/or larger number of machines in the cluster.");
+	/**
+	 * Maximum number of network buffers that can be used per sort-merge blocking result partition.
+	 */
+ @Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
+ public static final ConfigOption<Integer>
NETWORK_MAX_BUFFERS_PER_SORT_MERGE_PARTITION =
+
key("taskmanager.network.sort-merge-blocking-shuffle.max-buffers-per-partition")
+ .defaultValue(2048)
+			.withDescription("Maximum number of network buffers that can be used per sort-merge blocking result partition. " +
+				"This value is only an upper bound and does not mean that the sort-merge blocking result partition" +
+				" will always use this many network buffers.");
+
+ /**
+ * Parallelism threshold to switch between sort-merge based blocking
shuffle and the default hash-based blocking shuffle.
+ */
+ @Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
+ public static final ConfigOption<Integer>
NETWORK_SORT_MERGE_SHUFFLE_MIN_PARALLELISM =
+
key("taskmanager.network.sort-merge-blocking-shuffle.min-parallelism")
+ .defaultValue(Integer.MAX_VALUE)
Review comment:
Please add `intType()` here as well, to make the option type-safe at runtime (and avoid the deprecation warnings). A sketch follows.
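The same fix sketched for this option:

	@Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
	public static final ConfigOption<Integer> NETWORK_SORT_MERGE_SHUFFLE_MIN_PARALLELISM =
		key("taskmanager.network.sort-merge-blocking-shuffle.min-parallelism")
			.intType()
			.defaultValue(Integer.MAX_VALUE)
			.withDescription("Parallelism threshold to switch between sort-merge based blocking" +
				" shuffle and the default hash-based blocking shuffle.");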
##########
File path: flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PartitionSortedBuffer.java
##########
@@ -0,0 +1,390 @@
[... identical PartitionSortedBuffer.java diff content as quoted above, elided ...]
+ private int getHigh32BitsFromLongAsInteger(long value) {
+ return (int) (value >>> 32);
+ }
+
+ private int getLow32BitsFromLongAsInteger(long value) {
Review comment:
Maybe rename this to `getSegmentOffsetFromPointer`?
##########
File path: flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionFactory.java
##########
@@ -166,6 +189,14 @@ else if (type == ResultPartitionType.BLOCKING || type == ResultPartitionType.BLOCKING_PERSISTENT) {
return partition;
}
+ private boolean isBlockingShuffle(ResultPartitionType type) {
Review comment:
You don't need this, you can just say `!type.isPipelined()`.
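That is, the call sites can test the existing predicate directly instead of introducing a helper:

	// equivalent to isBlockingShuffle(type): every non-pipelined result partition type is blocking
	if (!type.isPipelined()) {
		// blocking shuffle setup path
	}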
##########
File path: flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
##########
@@ -173,6 +173,28 @@
" help relieve back-pressure caused by
unbalanced data distribution among the subpartitions. This value should be" +
" increased in case of higher round trip times
between nodes and/or larger number of machines in the cluster.");
+	/**
+	 * Maximum number of network buffers that can be used per sort-merge blocking result partition.
+	 */
+ @Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
+ public static final ConfigOption<Integer>
NETWORK_MAX_BUFFERS_PER_SORT_MERGE_PARTITION =
+
key("taskmanager.network.sort-merge-blocking-shuffle.max-buffers-per-partition")
Review comment:
Can we make the keys a bit shorter? For example use
`taskmanager.network.sort-shuffle.max-buffers`.
##########
File path: flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionFactory.java
##########
@@ -215,10 +246,19 @@ private static void releasePartitionsQuietly(ResultSubpartition[] partitions, in
return () -> {
int maxNumberOfMemorySegments = type.isBounded() ?
numberOfSubpartitions *
networkBuffersPerChannel + floatingNetworkBuffersPerGate : Integer.MAX_VALUE;
+ int numRequiredBuffers = numberOfSubpartitions + 1;
+
+ if (isSortMergeBlockingShuffle(type,
numberOfSubpartitions)) {
Review comment:
Maybe rewrite this so the sizing is not computed twice in the sort-shuffle mode (once with the non-sort-shuffle values, then again for sort shuffle). One possible shape is sketched below.
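One possible shape for that rewrite, branching once so each mode computes its sizing exactly once. The sort-shuffle values are hypothetical placeholders, since the quoted diff cuts off before them:

	final int numRequiredBuffers;
	final int maxNumberOfMemorySegments;
	if (isSortMergeBlockingShuffle(type, numberOfSubpartitions)) {
		// hypothetical sort-shuffle sizing, computed only in this branch
		numRequiredBuffers = sortShuffleMinBuffers;
		maxNumberOfMemorySegments = sortShuffleMaxBuffers;
	} else {
		numRequiredBuffers = numberOfSubpartitions + 1;
		maxNumberOfMemorySegments = type.isBounded()
			? numberOfSubpartitions * networkBuffersPerChannel + floatingNetworkBuffersPerGate
			: Integer.MAX_VALUE;
	}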
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]