This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 7ebd168f8 [CELEBORN-1490][CIP-6] Support process large buffer in flink
hybrid shuffle
7ebd168f8 is described below
commit 7ebd168f808afe4cfbccaf75d074299d05eb9c50
Author: Yuxin Tan <[email protected]>
AuthorDate: Mon Nov 4 16:57:43 2024 +0800
[CELEBORN-1490][CIP-6] Support process large buffer in flink hybrid shuffle
### What changes were proposed in this pull request?
This is the last PR in the CIP-6 series.
Fix the bug where hybrid shuffle faces a buffer larger than 32K.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
UT
Closes #2873 from reswqa/11-large-buffer-10month.
Lead-authored-by: Yuxin Tan <[email protected]>
Co-authored-by: Weijie Guo <[email protected]>
Signed-off-by: Shuang <[email protected]>
---
.../celeborn/plugin/flink/buffer/BufferPacker.java | 23 ++++
.../flink/network/FlinkTransportClientFactory.java | 7 +-
.../TransportFrameDecoderWithBufferSupplier.java | 72 +++++++++++-
.../flink/readclient/FlinkShuffleClientImpl.java | 46 +++++++-
.../celeborn/plugin/flink/utils/BufferUtils.java | 12 ++
.../celeborn/plugin/flink/BufferPackSuiteJ.java | 48 ++++++++
.../plugin/flink/FlinkShuffleClientImplSuiteJ.java | 2 +-
...nsportFrameDecoderWithBufferSupplierSuiteJ.java | 123 +++++++++++++++++++++
.../flink/tiered/CelebornTierConsumerAgent.java | 9 +-
.../plugin/flink/tiered/CelebornTierFactory.java | 3 +-
.../flink/tiered/CelebornTierProducerAgent.java | 15 ++-
.../celeborn/tests/flink/HeartbeatTest.scala | 9 +-
12 files changed, 350 insertions(+), 19 deletions(-)
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java
index 76a6c2ef7..8876b6b08 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java
@@ -26,6 +26,7 @@ import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufAllocator;
+import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -157,6 +158,28 @@ public class BufferPacker {
public static Queue<Buffer> unpack(ByteBuf byteBuf) {
Queue<Buffer> buffers = new ArrayDeque<>();
try {
+ if (byteBuf instanceof CompositeByteBuf) {
+ // If the received byteBuf is a CompositeByteBuf, it indicates that
the byteBuf originates
+ // from the Flink hybrid shuffle integration strategy. This byteBuf
consists of two parts: a
+ // celeborn header and a data buffer.
+ CompositeByteBuf compositeByteBuf = (CompositeByteBuf) byteBuf;
+ ByteBuf headerBuffer = compositeByteBuf.component(0).unwrap();
+ ByteBuf dataBuffer = compositeByteBuf.component(1).unwrap();
+ dataBuffer.retain();
+ Utils.checkState(
+ dataBuffer instanceof Buffer, "Illegal data buffer type for
CompositeByteBuf.");
+ BufferHeader bufferHeader =
BufferUtils.getBufferHeaderFromByteBuf(headerBuffer, 0);
+ Buffer slice = ((Buffer) dataBuffer).readOnlySlice(0,
bufferHeader.getSize());
+ buffers.add(
+ new UnpackSlicedBuffer(
+ slice,
+ bufferHeader.getDataType(),
+ bufferHeader.isCompressed(),
+ bufferHeader.getSize()));
+
+ return buffers;
+ }
+
Utils.checkState(byteBuf instanceof Buffer, "Illegal buffer type.");
Buffer buffer = (Buffer) byteBuf;
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java
index 3cb180b3f..0bfaaf99e 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java
@@ -39,11 +39,14 @@ public class FlinkTransportClientFactory extends
TransportClientFactory {
private ConcurrentHashMap<Long, Supplier<ByteBuf>> bufferSuppliers;
+ private int bufferSizeBytes;
+
public FlinkTransportClientFactory(
- TransportContext context, List<TransportClientBootstrap> bootstraps) {
+ TransportContext context, List<TransportClientBootstrap> bootstraps, int
bufferSizeBytes) {
super(context, bootstraps);
bufferSuppliers = JavaUtils.newConcurrentHashMap();
this.pooledAllocator = new UnpooledByteBufAllocator(true);
+ this.bufferSizeBytes = bufferSizeBytes;
}
public TransportClient createClientWithRetry(String remoteHost, int
remotePort)
@@ -52,7 +55,7 @@ public class FlinkTransportClientFactory extends
TransportClientFactory {
remoteHost,
remotePort,
-1,
- () -> new TransportFrameDecoderWithBufferSupplier(bufferSuppliers));
+ () -> new TransportFrameDecoderWithBufferSupplier(bufferSuppliers,
bufferSizeBytes));
}
public void registerSupplier(long streamId, Supplier<ByteBuf> supplier) {
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
index 9140b6b23..796734f3e 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
@@ -23,6 +23,7 @@ import java.util.function.Supplier;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf;
import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -30,6 +31,7 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.common.network.protocol.Message;
import org.apache.celeborn.common.network.util.FrameDecoder;
import org.apache.celeborn.plugin.flink.protocol.ReadData;
+import org.apache.celeborn.plugin.flink.utils.BufferUtils;
public class TransportFrameDecoderWithBufferSupplier extends
ChannelInboundHandlerAdapter
implements FrameDecoder {
@@ -44,17 +46,37 @@ public class TransportFrameDecoderWithBufferSupplier
extends ChannelInboundHandl
private final ByteBuf msgBuf = Unpooled.buffer(8);
private Message curMsg = null;
private int remainingSize = -1;
+ private int totalReadBytes = 0;
+ private int largeBufferHeaderRemainingBytes = -1;
+ private boolean isReadingLargeBuffer = false;
+ private ByteBuf largeBufferHeaderBuffer;
+ public static final int DISABLE_LARGE_BUFFER_SPLIT_SIZE = -1;
+
+ /**
+ * The flink buffer size bytes. If the received buffer size large than this
value, means that we
+ * need to divide the received buffer into multiple smaller buffers, each
small than {@link
+ * #bufferSizeBytes}. And when this value set to {@link
#DISABLE_LARGE_BUFFER_SPLIT_SIZE},
+ * indicates that large buffer splitting will not be checked.
+ */
+ private final int bufferSizeBytes;
private final ConcurrentHashMap<Long, Supplier<ByteBuf>> bufferSuppliers;
public TransportFrameDecoderWithBufferSupplier(
ConcurrentHashMap<Long, Supplier<ByteBuf>> bufferSuppliers) {
+ this(bufferSuppliers, DISABLE_LARGE_BUFFER_SPLIT_SIZE);
+ }
+
+ public TransportFrameDecoderWithBufferSupplier(
+ ConcurrentHashMap<Long, Supplier<ByteBuf>> bufferSuppliers, int
bufferSizeBytes) {
this.bufferSuppliers = bufferSuppliers;
+ this.bufferSizeBytes = bufferSizeBytes;
}
- private void copyByteBuf(io.netty.buffer.ByteBuf source, ByteBuf target, int
targetSize) {
+ private int copyByteBuf(io.netty.buffer.ByteBuf source, ByteBuf target, int
targetSize) {
int bytes = Math.min(source.readableBytes(), targetSize -
target.readableBytes());
target.writeBytes(source.readSlice(bytes).nioBuffer());
+ return bytes;
}
private void decodeHeader(io.netty.buffer.ByteBuf buf, ChannelHandlerContext
ctx) {
@@ -69,6 +91,15 @@ public class TransportFrameDecoderWithBufferSupplier extends
ChannelInboundHandl
// type byte is read
headerBuf.readByte();
bodySize = headerBuf.readInt();
+ if (bufferSizeBytes != DISABLE_LARGE_BUFFER_SPLIT_SIZE && bodySize >
bufferSizeBytes) {
+ // if the message body size is larger than bufferSizeBytes, we need to
split it into two
+ // parts: celeborn header and data buffer
+ isReadingLargeBuffer = true;
+ // create a temporary buffer to store the celeborn header
+ largeBufferHeaderBuffer =
+ Unpooled.buffer(BufferUtils.HEADER_LENGTH,
BufferUtils.HEADER_LENGTH);
+ largeBufferHeaderRemainingBytes = BufferUtils.HEADER_LENGTH;
+ }
decodeMsg(buf, ctx);
}
}
@@ -138,9 +169,31 @@ public class TransportFrameDecoderWithBufferSupplier
extends ChannelInboundHandl
}
}
- copyByteBuf(buf, externalBuf, bodySize);
- if (externalBuf.readableBytes() == bodySize) {
- ((ReadData) curMsg).setFlinkBuffer(externalBuf);
+ if (largeBufferHeaderRemainingBytes > 0) {
+ // if largeBufferHeaderRemainingBytes larger than zero, means that we
are reading the celeborn
+ // header
+ int headerReadBytes = copyByteBuf(buf, largeBufferHeaderBuffer,
BufferUtils.HEADER_LENGTH);
+ largeBufferHeaderRemainingBytes -= headerReadBytes;
+ totalReadBytes += headerReadBytes;
+ } else {
+ // if largeBufferHeaderRemainingBytes less or equal to zero, means that
we are reading the
+ // data buffer
+ totalReadBytes += copyByteBuf(buf, externalBuf,
getTargetDataBufferReadSize());
+ }
+
+ if (totalReadBytes == bodySize) {
+ ByteBuf resultByteBuf;
+ if (largeBufferHeaderBuffer == null) {
+ resultByteBuf = externalBuf;
+ } else {
+ // composite the celeborn header and data buffer together
+ CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
+ compositeByteBuf.addComponent(true, largeBufferHeaderBuffer);
+ compositeByteBuf.addComponent(true, externalBuf);
+ resultByteBuf = compositeByteBuf;
+ }
+
+ ((ReadData) curMsg).setFlinkBuffer(resultByteBuf);
ctx.fireChannelRead(curMsg);
clear();
}
@@ -192,6 +245,13 @@ public class TransportFrameDecoderWithBufferSupplier
extends ChannelInboundHandl
}
}
+ private int getTargetDataBufferReadSize() {
+ if (isReadingLargeBuffer) {
+ return bodySize - BufferUtils.HEADER_LENGTH;
+ }
+ return bodySize;
+ }
+
private void clear() {
externalBuf = null;
curMsg = null;
@@ -200,6 +260,10 @@ public class TransportFrameDecoderWithBufferSupplier
extends ChannelInboundHandl
bodyBuf = null;
bodySize = -1;
remainingSize = -1;
+ totalReadBytes = 0;
+ largeBufferHeaderRemainingBytes = -1;
+ largeBufferHeaderBuffer = null;
+ isReadingLargeBuffer = false;
}
@Override
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
index efbf343ce..5602d1aac 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
@@ -68,6 +68,7 @@ import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.common.write.PushState;
import org.apache.celeborn.plugin.flink.network.FlinkTransportClientFactory;
import org.apache.celeborn.plugin.flink.network.ReadClientHandler;
+import
org.apache.celeborn.plugin.flink.network.TransportFrameDecoderWithBufferSupplier;
public class FlinkShuffleClientImpl extends ShuffleClientImpl {
public static final Logger logger =
LoggerFactory.getLogger(FlinkShuffleClientImpl.class);
@@ -81,6 +82,9 @@ public class FlinkShuffleClientImpl extends ShuffleClientImpl
{
private final TransportContext context;
+ /** The buffer size bytes in flink, default value is 32KB. */
+ private final int bufferSizeBytes;
+
public static FlinkShuffleClientImpl get(
String appUniqueId,
String driverHost,
@@ -89,18 +93,49 @@ public class FlinkShuffleClientImpl extends
ShuffleClientImpl {
CelebornConf conf,
UserIdentifier userIdentifier)
throws DriverChangedException {
+ return get(
+ appUniqueId,
+ driverHost,
+ port,
+ driverTimestamp,
+ conf,
+ userIdentifier,
+
TransportFrameDecoderWithBufferSupplier.DISABLE_LARGE_BUFFER_SPLIT_SIZE);
+ }
+
+ public static FlinkShuffleClientImpl get(
+ String appUniqueId,
+ String driverHost,
+ int port,
+ long driverTimestamp,
+ CelebornConf conf,
+ UserIdentifier userIdentifier,
+ int bufferSizeBytes)
+ throws DriverChangedException {
if (null == _instance || !initialized || _instance.driverTimestamp <
driverTimestamp) {
synchronized (FlinkShuffleClientImpl.class) {
if (null == _instance) {
_instance =
new FlinkShuffleClientImpl(
- appUniqueId, driverHost, port, driverTimestamp, conf,
userIdentifier);
+ appUniqueId,
+ driverHost,
+ port,
+ driverTimestamp,
+ conf,
+ userIdentifier,
+ bufferSizeBytes);
initialized = true;
} else if (!initialized || _instance.driverTimestamp <
driverTimestamp) {
_instance.shutdown();
_instance =
new FlinkShuffleClientImpl(
- appUniqueId, driverHost, port, driverTimestamp, conf,
userIdentifier);
+ appUniqueId,
+ driverHost,
+ port,
+ driverTimestamp,
+ conf,
+ userIdentifier,
+ bufferSizeBytes);
initialized = true;
}
}
@@ -133,8 +168,10 @@ public class FlinkShuffleClientImpl extends
ShuffleClientImpl {
int port,
long driverTimestamp,
CelebornConf conf,
- UserIdentifier userIdentifier) {
+ UserIdentifier userIdentifier,
+ int bufferSizeBytes) {
super(appUniqueId, conf, userIdentifier);
+ this.bufferSizeBytes = bufferSizeBytes;
String module = TransportModuleConstants.DATA_MODULE;
TransportConf dataTransportConf =
Utils.fromCelebornConf(conf, module, conf.getInt("celeborn." + module
+ ".io.threads", 8));
@@ -147,7 +184,8 @@ public class FlinkShuffleClientImpl extends
ShuffleClientImpl {
private void initializeTransportClientFactory() {
if (null == flinkTransportClientFactory) {
- flinkTransportClientFactory = new FlinkTransportClientFactory(context,
createBootstraps());
+ flinkTransportClientFactory =
+ new FlinkTransportClientFactory(context, createBootstraps(),
bufferSizeBytes);
}
}
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java
index 999d1eb10..b28e6f753 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java
@@ -113,6 +113,18 @@ public class BufferUtils {
}
}
+ public static BufferHeader getBufferHeaderFromByteBuf(ByteBuf byteBuf, int
position) {
+ byteBuf.readerIndex(position);
+ return new BufferHeader(
+ byteBuf.readInt(),
+ byteBuf.readInt(),
+ byteBuf.readInt(),
+ byteBuf.readInt(),
+ Buffer.DataType.values()[byteBuf.readByte()],
+ byteBuf.readBoolean(),
+ byteBuf.readInt());
+ }
+
public static void reserveNumRequiredBuffers(BufferPool bufferPool, int
numRequiredBuffers)
throws IOException {
long startTime = System.nanoTime();
diff --git
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java
index 8f3c0ce6e..acf42401a 100644
---
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java
+++
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java
@@ -26,6 +26,8 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
+import java.util.Queue;
+import java.util.Random;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.flink.core.memory.MemorySegment;
@@ -38,6 +40,7 @@ import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf;
import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;
import org.junit.After;
+import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -220,6 +223,27 @@ public class BufferPackSuiteJ {
unpacked.forEach(Buffer::recycleBuffer);
}
+ @Test
+ public void testUnpackCompositeBuffer() throws Exception {
+ Buffer dataBuffer = bufferPool.requestBuffer();
+ fillBufferWithRandomByte(dataBuffer);
+ ByteBuf bufferHeaderByteBuf = createBufferHeaderByteBuf(BUFFER_SIZE);
+ bufferHeaderByteBuf.retain();
+ CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
+ compositeByteBuf.addComponent(true, bufferHeaderByteBuf);
+ compositeByteBuf.addComponent(true, dataBuffer.asByteBuf());
+
+ Queue<Buffer> unpackedBuffers = BufferPacker.unpack(compositeByteBuf);
+ Assert.assertEquals(1, unpackedBuffers.size());
+ Assert.assertEquals(dataBuffer.readableBytes(),
unpackedBuffers.peek().readableBytes());
+ Assert.assertEquals(BUFFER_SIZE, unpackedBuffers.peek().readableBytes());
+ for (int i = 0; i < BUFFER_SIZE; ++i) {
+ Assert.assertEquals(
+ dataBuffer.getMemorySegment().get(i),
unpackedBuffers.peek().getMemorySegment().get(i));
+ }
+ dataBuffer.recycleBuffer();
+ }
+
@Test
public void testPackMultipleBuffers() throws Exception {
int numBuffers = 7;
@@ -404,4 +428,28 @@ public class BufferPackSuiteJ {
return new ReceivedNoHeaderBufferPacker(ripeBufferHandler);
}
}
+
+ public ByteBuf createBufferHeaderByteBuf(int dataBufferSize) {
+ ByteBuf headerBuf = Unpooled.directBuffer(BufferUtils.HEADER_LENGTH,
BufferUtils.HEADER_LENGTH);
+ // write celeborn buffer header (subpartitionid(4) + attemptId(4) +
nextBatchId(4) +
+ // compressedsize)
+ headerBuf.writeInt(0);
+ headerBuf.writeInt(0);
+ headerBuf.writeInt(0);
+ headerBuf.writeInt(
+ dataBufferSize + (BufferUtils.HEADER_LENGTH -
BufferUtils.HEADER_LENGTH_PREFIX));
+
+ // write flink buffer header (dataType(1) + isCompress(1) + size(4))
+ headerBuf.writeByte(DATA_BUFFER.ordinal());
+ headerBuf.writeBoolean(false);
+ headerBuf.writeInt(dataBufferSize);
+ return headerBuf;
+ }
+
+ public void fillBufferWithRandomByte(Buffer buffer) {
+ Random random = new Random();
+ for (int i = 0; i < buffer.getMaxCapacity(); i++) {
+ buffer.asByteBuf().writeByte(random.nextInt(255));
+ }
+ }
}
diff --git
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/FlinkShuffleClientImplSuiteJ.java
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/FlinkShuffleClientImplSuiteJ.java
index 60a843f4a..cb15c1e15 100644
---
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/FlinkShuffleClientImplSuiteJ.java
+++
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/FlinkShuffleClientImplSuiteJ.java
@@ -55,7 +55,7 @@ public class FlinkShuffleClientImplSuiteJ {
conf = new CelebornConf();
shuffleClient =
new FlinkShuffleClientImpl(
- "APP", "localhost", 1232, System.currentTimeMillis(), conf, null) {
+ "APP", "localhost", 1232, System.currentTimeMillis(), conf, null,
-1) {
@Override
public void setupLifecycleManagerRef(String host, int port) {}
};
diff --git
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
index c7c8440c8..431f8bc62 100644
---
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
+++
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
@@ -18,6 +18,8 @@
package org.apache.celeborn.plugin.flink.network;
import static
org.apache.celeborn.common.network.client.TransportClient.requestId;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;
import java.io.IOException;
import java.util.ArrayList;
@@ -31,6 +33,7 @@ import java.util.function.Supplier;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
+import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -46,6 +49,7 @@ import
org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
import org.apache.celeborn.common.util.JavaUtils;
+import org.apache.celeborn.plugin.flink.utils.BufferUtils;
@RunWith(Parameterized.class)
public class TransportFrameDecoderWithBufferSupplierSuiteJ {
@@ -131,6 +135,125 @@ public class
TransportFrameDecoderWithBufferSupplierSuiteJ {
Assert.assertEquals(buffers.size(), 6);
}
+ @Test(expected = IndexOutOfBoundsException.class)
+ public void testFailProcessFullBufferIfDisableLargeBufferSplit() throws
IOException {
+ int bufferSizeBytes = 10 * 1024;
+ ConcurrentHashMap<Long,
Supplier<org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf>>
+ supplier = JavaUtils.newConcurrentHashMap();
+ List<org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf> buffers = new
ArrayList<>();
+
+ supplier.put(
+ 0L,
+ () -> {
+ org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf buffer =
+ org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled.buffer(
+ bufferSizeBytes, bufferSizeBytes);
+ buffers.add(buffer);
+ return buffer;
+ });
+
+ TransportFrameDecoderWithBufferSupplier decoder =
+ new TransportFrameDecoderWithBufferSupplier(
+ supplier,
TransportFrameDecoderWithBufferSupplier.DISABLE_LARGE_BUFFER_SPLIT_SIZE);
+ ChannelHandlerContext context = Mockito.mock(ChannelHandlerContext.class);
+
+ SubPartitionReadData readData =
+ new SubPartitionReadData(0, 0, generateData(bufferSizeBytes +
BufferUtils.HEADER_LENGTH));
+
+ ByteBuf buffer = Unpooled.buffer(bufferSizeBytes * 4);
+ encodeMessage(readData, buffer);
+
+ // simulate
+ buffer.retain();
+ decoder.channelRead(context, buffer);
+ }
+
+ @Test
+ public void testProcessFullBufferIfEnableLargeBufferSplit() throws
IOException {
+ int bufferSizeBytes = 10 * 1024;
+ ConcurrentHashMap<Long,
Supplier<org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf>>
+ supplier = JavaUtils.newConcurrentHashMap();
+ List<Message> parsedMessages = new ArrayList<>();
+
+ supplier.put(
+ 0L,
+ () ->
+ org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled.buffer(
+ bufferSizeBytes, bufferSizeBytes));
+
+ TransportFrameDecoderWithBufferSupplier decoder =
+ new TransportFrameDecoderWithBufferSupplier(supplier, bufferSizeBytes);
+ ChannelHandlerContext context = Mockito.mock(ChannelHandlerContext.class);
+ when(context.fireChannelRead(any()))
+ .thenAnswer(
+ m -> {
+ Assert.assertEquals(1, m.getArguments().length);
+ parsedMessages.add(m.getArgument(0));
+ return null;
+ });
+
+ ReadData readData1 = new ReadData(0, generateData(1024));
+ // simulate the client received a large buffer which body size large than
size of given buffer
+ // in this case, the flinkBuffer of parsed message will contain two parts:
celeborn header and
+ // data buffer
+ SubPartitionReadData readData2 =
+ new SubPartitionReadData(0, 0, generateData(BufferUtils.HEADER_LENGTH
+ bufferSizeBytes));
+ SubPartitionReadData readData3 = new SubPartitionReadData(0, 0,
generateData(1024));
+
+ ByteBuf buffer = Unpooled.buffer(bufferSizeBytes * 4);
+ encodeMessage(readData1, buffer);
+ encodeMessage(readData2, buffer);
+ encodeMessage(readData3, buffer);
+
+ // simulate
+ buffer.retain();
+ decoder.channelRead(context, buffer);
+ Assert.assertEquals(parsedMessages.size(), 3);
+
+ // the parsed first message contains the readData1
+ Assert.assertTrue(
+ parsedMessages.get(0) instanceof
org.apache.celeborn.plugin.flink.protocol.ReadData);
+ Assert.assertEquals(
+ ((org.apache.celeborn.plugin.flink.protocol.ReadData)
parsedMessages.get(0))
+ .getFlinkBuffer()
+ .nioBuffer(),
+ readData1.body().nioByteBuffer());
+
+ // the parsed second message contains the readData2
+ Assert.assertTrue(
+ parsedMessages.get(1)
+ instanceof
org.apache.celeborn.plugin.flink.protocol.SubPartitionReadData);
+ org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf byteBuf2 =
+ ((org.apache.celeborn.plugin.flink.protocol.SubPartitionReadData)
parsedMessages.get(1))
+ .getFlinkBuffer();
+ // verify the flinkBuffer of parsed message contains two parts: celeborn
header and data buffer
+ Assert.assertTrue(
+ byteBuf2 instanceof
org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf);
+ CompositeByteBuf compositeByteBuf2 = (CompositeByteBuf) byteBuf2;
+ Assert.assertEquals(compositeByteBuf2.numComponents(), 2);
+ org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf inputByteBuf2 =
+ org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled.wrappedBuffer(
+ readData2.body().nioByteBuffer());
+ // the first part is celeborn header
+ Assert.assertEquals(
+ compositeByteBuf2.component(0).nioBuffer(),
+ inputByteBuf2.slice(0, BufferUtils.HEADER_LENGTH).nioBuffer());
+ // the second part is data buffer
+ Assert.assertEquals(
+ compositeByteBuf2.component(1).nioBuffer(),
+ inputByteBuf2.slice(BufferUtils.HEADER_LENGTH,
bufferSizeBytes).nioBuffer());
+
+ // the parsed third message contains the readData3
+ Assert.assertTrue(
+ parsedMessages.get(2)
+ instanceof
org.apache.celeborn.plugin.flink.protocol.SubPartitionReadData);
+ Assert.assertEquals(
+ ((org.apache.celeborn.plugin.flink.protocol.SubPartitionReadData)
parsedMessages.get(2))
+ .getFlinkBuffer()
+ .nioBuffer(),
+ readData3.body().nioByteBuffer());
+ }
+
public RpcRequest createBacklogAnnouncement(long streamId, int backlog) {
return new RpcRequest(
requestId(),
diff --git
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierConsumerAgent.java
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierConsumerAgent.java
index 0febd8bd3..8d06ba77c 100644
---
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierConsumerAgent.java
+++
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierConsumerAgent.java
@@ -117,10 +117,13 @@ public class CelebornTierConsumerAgent implements
TierConsumerAgent {
private TieredStorageMemoryManager memoryManager;
+ private final int bufferSizeBytes;
+
public CelebornTierConsumerAgent(
CelebornConf conf,
List<TieredStorageConsumerSpec> tieredStorageConsumerSpecs,
- List<TierShuffleDescriptor> shuffleDescriptors) {
+ List<TierShuffleDescriptor> shuffleDescriptors,
+ int bufferSizeBytes) {
checkArgument(!shuffleDescriptors.isEmpty(), "Wrong shuffle descriptors
size.");
checkArgument(
tieredStorageConsumerSpecs.size() == shuffleDescriptors.size(),
@@ -132,6 +135,7 @@ public class CelebornTierConsumerAgent implements
TierConsumerAgent {
this.bufferReaders = new HashMap<>();
this.receivedBuffers = new HashMap<>();
this.subPartitionsNeedNotifyAvailable = new HashSet<>();
+ this.bufferSizeBytes = bufferSizeBytes;
for (TierShuffleDescriptor shuffleDescriptor : shuffleDescriptors) {
if (shuffleDescriptor instanceof TierShuffleDescriptorImpl) {
initShuffleClient((TierShuffleDescriptorImpl) shuffleDescriptor);
@@ -326,7 +330,8 @@ public class CelebornTierConsumerAgent implements
TierConsumerAgent {
shuffleResource.getLifecycleManagerPort(),
shuffleResource.getLifecycleManagerTimestamp(),
conf,
- new UserIdentifier("default", "default"));
+ new UserIdentifier("default", "default"),
+ bufferSizeBytes);
} catch (DriverChangedException e) {
throw new RuntimeException(e.getMessage());
}
diff --git
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
index 1a86130e4..c9913d132 100644
---
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
+++
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
@@ -118,7 +118,8 @@ public class CelebornTierFactory implements TierFactory {
List<TieredStorageConsumerSpec> tieredStorageConsumerSpecs,
List<TierShuffleDescriptor> shuffleDescriptors,
TieredStorageNettyService nettyService) {
- return new CelebornTierConsumerAgent(conf, tieredStorageConsumerSpecs,
shuffleDescriptors);
+ return new CelebornTierConsumerAgent(
+ conf, tieredStorageConsumerSpecs, shuffleDescriptors, bufferSizeBytes);
}
public static String getCelebornTierName() {
diff --git
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java
index aab2b3ae5..983f24cb0 100644
---
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java
+++
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java
@@ -64,6 +64,7 @@ public class CelebornTierProducerAgent implements
TierProducerAgent {
private final int numBuffersPerSegment;
+ // The flink buffer size in bytes.
private final int bufferSizeBytes;
private final int numPartitions;
@@ -325,9 +326,18 @@ public class CelebornTierProducerAgent implements
TierProducerAgent {
try {
int remainingReviveTimes = maxReviveTimes;
while (remainingReviveTimes-- > 0 && !hasSentHandshake) {
+ // In the Flink hybrid shuffle integration strategy, the data buffer
sent to the Celeborn
+ // workers consists of two components: the Celeborn header and the
data buffers.
+ // In this scenario, the maximum byte size of the buffer received by
the Celeborn worker is
+ // equal to the sum of the Flink buffer size and the Celeborn header
size.
Optional<PartitionLocation> revivePartition =
flinkShuffleClient.pushDataHandShake(
- shuffleId, mapId, attemptId, numSubPartitions,
bufferSizeBytes, partitionLocation);
+ shuffleId,
+ mapId,
+ attemptId,
+ numSubPartitions,
+ bufferSizeBytes + BufferUtils.HEADER_LENGTH,
+ partitionLocation);
// if remainingReviveTimes == 0 and revivePartition.isPresent(), there
is no need to send
// handshake again
if (revivePartition.isPresent() && remainingReviveTimes > 0) {
@@ -478,7 +488,8 @@ public class CelebornTierProducerAgent implements
TierProducerAgent {
lifecycleManagerPort,
lifecycleManagerTimestamp,
celebornConf,
- null);
+ null,
+ bufferSizeBytes);
} catch (DriverChangedException e) {
// would generate a new attempt to retry output gate
throw new RuntimeException(e.getMessage());
diff --git
a/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HeartbeatTest.scala
b/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HeartbeatTest.scala
index a373c6fd8..dbd7e543f 100644
---
a/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HeartbeatTest.scala
+++
b/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HeartbeatTest.scala
@@ -37,7 +37,8 @@ class HeartbeatTest extends AnyFunSuite with Logging with
MiniClusterFeature wit
0,
System.currentTimeMillis(),
clientConf,
- new UserIdentifier("1", "1")) {
+ new UserIdentifier("1", "1"),
+ -1) {
override def setupLifecycleManagerRef(host: String, port: Int): Unit =
{}
}
testHeartbeatFromWorker2Client(flinkShuffleClientImpl.getDataClientFactory)
@@ -52,7 +53,8 @@ class HeartbeatTest extends AnyFunSuite with Logging with
MiniClusterFeature wit
0,
System.currentTimeMillis(),
clientConf,
- new UserIdentifier("1", "1")) {
+ new UserIdentifier("1", "1"),
+ -1) {
override def setupLifecycleManagerRef(host: String, port: Int): Unit =
{}
}
testHeartbeatFromWorker2ClientWithNoHeartbeat(flinkShuffleClientImpl.getDataClientFactory)
@@ -67,7 +69,8 @@ class HeartbeatTest extends AnyFunSuite with Logging with
MiniClusterFeature wit
0,
System.currentTimeMillis(),
clientConf,
- new UserIdentifier("1", "1")) {
+ new UserIdentifier("1", "1"),
+ -1) {
override def setupLifecycleManagerRef(host: String, port: Int): Unit =
{}
}
testHeartbeatFromWorker2ClientWithCloseChannel(flinkShuffleClientImpl.getDataClientFactory)