This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 7ebd168f8 [CELEBORN-1490][CIP-6] Support process large buffer in flink 
hybrid shuffle
7ebd168f8 is described below

commit 7ebd168f808afe4cfbccaf75d074299d05eb9c50
Author: Yuxin Tan <[email protected]>
AuthorDate: Mon Nov 4 16:57:43 2024 +0800

    [CELEBORN-1490][CIP-6] Support process large buffer in flink hybrid shuffle
    
    ### What changes were proposed in this pull request?
    
    This is the last PR in the CIP-6 series.
    
    Fix the bug that occurs when hybrid shuffle encounters a buffer larger than 32K.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    UT
    
    Closes #2873 from reswqa/11-large-buffer-10month.
    
    Lead-authored-by: Yuxin Tan <[email protected]>
    Co-authored-by: Weijie Guo <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../celeborn/plugin/flink/buffer/BufferPacker.java |  23 ++++
 .../flink/network/FlinkTransportClientFactory.java |   7 +-
 .../TransportFrameDecoderWithBufferSupplier.java   |  72 +++++++++++-
 .../flink/readclient/FlinkShuffleClientImpl.java   |  46 +++++++-
 .../celeborn/plugin/flink/utils/BufferUtils.java   |  12 ++
 .../celeborn/plugin/flink/BufferPackSuiteJ.java    |  48 ++++++++
 .../plugin/flink/FlinkShuffleClientImplSuiteJ.java |   2 +-
 ...nsportFrameDecoderWithBufferSupplierSuiteJ.java | 123 +++++++++++++++++++++
 .../flink/tiered/CelebornTierConsumerAgent.java    |   9 +-
 .../plugin/flink/tiered/CelebornTierFactory.java   |   3 +-
 .../flink/tiered/CelebornTierProducerAgent.java    |  15 ++-
 .../celeborn/tests/flink/HeartbeatTest.scala       |   9 +-
 12 files changed, 350 insertions(+), 19 deletions(-)

diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java
index 76a6c2ef7..8876b6b08 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java
@@ -26,6 +26,7 @@ import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
 import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
 import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufAllocator;
+import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -157,6 +158,28 @@ public class BufferPacker {
   public static Queue<Buffer> unpack(ByteBuf byteBuf) {
     Queue<Buffer> buffers = new ArrayDeque<>();
     try {
+      if (byteBuf instanceof CompositeByteBuf) {
+        // If the received byteBuf is a CompositeByteBuf, it indicates that 
the byteBuf originates
+        // from the Flink hybrid shuffle integration strategy. This byteBuf 
consists of two parts: a
+        // celeborn header and a data buffer.
+        CompositeByteBuf compositeByteBuf = (CompositeByteBuf) byteBuf;
+        ByteBuf headerBuffer = compositeByteBuf.component(0).unwrap();
+        ByteBuf dataBuffer = compositeByteBuf.component(1).unwrap();
+        dataBuffer.retain();
+        Utils.checkState(
+            dataBuffer instanceof Buffer, "Illegal data buffer type for 
CompositeByteBuf.");
+        BufferHeader bufferHeader = 
BufferUtils.getBufferHeaderFromByteBuf(headerBuffer, 0);
+        Buffer slice = ((Buffer) dataBuffer).readOnlySlice(0, 
bufferHeader.getSize());
+        buffers.add(
+            new UnpackSlicedBuffer(
+                slice,
+                bufferHeader.getDataType(),
+                bufferHeader.isCompressed(),
+                bufferHeader.getSize()));
+
+        return buffers;
+      }
+
       Utils.checkState(byteBuf instanceof Buffer, "Illegal buffer type.");
 
       Buffer buffer = (Buffer) byteBuf;
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java
index 3cb180b3f..0bfaaf99e 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java
@@ -39,11 +39,14 @@ public class FlinkTransportClientFactory extends 
TransportClientFactory {
 
   private ConcurrentHashMap<Long, Supplier<ByteBuf>> bufferSuppliers;
 
+  private int bufferSizeBytes;
+
   public FlinkTransportClientFactory(
-      TransportContext context, List<TransportClientBootstrap> bootstraps) {
+      TransportContext context, List<TransportClientBootstrap> bootstraps, int 
bufferSizeBytes) {
     super(context, bootstraps);
     bufferSuppliers = JavaUtils.newConcurrentHashMap();
     this.pooledAllocator = new UnpooledByteBufAllocator(true);
+    this.bufferSizeBytes = bufferSizeBytes;
   }
 
   public TransportClient createClientWithRetry(String remoteHost, int 
remotePort)
@@ -52,7 +55,7 @@ public class FlinkTransportClientFactory extends 
TransportClientFactory {
         remoteHost,
         remotePort,
         -1,
-        () -> new TransportFrameDecoderWithBufferSupplier(bufferSuppliers));
+        () -> new TransportFrameDecoderWithBufferSupplier(bufferSuppliers, 
bufferSizeBytes));
   }
 
   public void registerSupplier(long streamId, Supplier<ByteBuf> supplier) {
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
index 9140b6b23..796734f3e 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
@@ -23,6 +23,7 @@ import java.util.function.Supplier;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelInboundHandlerAdapter;
 import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf;
 import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -30,6 +31,7 @@ import org.slf4j.LoggerFactory;
 import org.apache.celeborn.common.network.protocol.Message;
 import org.apache.celeborn.common.network.util.FrameDecoder;
 import org.apache.celeborn.plugin.flink.protocol.ReadData;
+import org.apache.celeborn.plugin.flink.utils.BufferUtils;
 
 public class TransportFrameDecoderWithBufferSupplier extends 
ChannelInboundHandlerAdapter
     implements FrameDecoder {
@@ -44,17 +46,37 @@ public class TransportFrameDecoderWithBufferSupplier 
extends ChannelInboundHandl
   private final ByteBuf msgBuf = Unpooled.buffer(8);
   private Message curMsg = null;
   private int remainingSize = -1;
+  private int totalReadBytes = 0;
+  private int largeBufferHeaderRemainingBytes = -1;
+  private boolean isReadingLargeBuffer = false;
+  private ByteBuf largeBufferHeaderBuffer;
+  public static final int DISABLE_LARGE_BUFFER_SPLIT_SIZE = -1;
+
+  /**
+   * The flink buffer size bytes. If the received buffer size large than this 
value, means that we
+   * need to divide the received buffer into multiple smaller buffers, each 
small than {@link
+   * #bufferSizeBytes}. And when this value set to {@link 
#DISABLE_LARGE_BUFFER_SPLIT_SIZE},
+   * indicates that large buffer splitting will not be checked.
+   */
+  private final int bufferSizeBytes;
 
   private final ConcurrentHashMap<Long, Supplier<ByteBuf>> bufferSuppliers;
 
   public TransportFrameDecoderWithBufferSupplier(
       ConcurrentHashMap<Long, Supplier<ByteBuf>> bufferSuppliers) {
+    this(bufferSuppliers, DISABLE_LARGE_BUFFER_SPLIT_SIZE);
+  }
+
+  public TransportFrameDecoderWithBufferSupplier(
+      ConcurrentHashMap<Long, Supplier<ByteBuf>> bufferSuppliers, int 
bufferSizeBytes) {
     this.bufferSuppliers = bufferSuppliers;
+    this.bufferSizeBytes = bufferSizeBytes;
   }
 
-  private void copyByteBuf(io.netty.buffer.ByteBuf source, ByteBuf target, int 
targetSize) {
+  private int copyByteBuf(io.netty.buffer.ByteBuf source, ByteBuf target, int 
targetSize) {
     int bytes = Math.min(source.readableBytes(), targetSize - 
target.readableBytes());
     target.writeBytes(source.readSlice(bytes).nioBuffer());
+    return bytes;
   }
 
   private void decodeHeader(io.netty.buffer.ByteBuf buf, ChannelHandlerContext 
ctx) {
@@ -69,6 +91,15 @@ public class TransportFrameDecoderWithBufferSupplier extends 
ChannelInboundHandl
       // type byte is read
       headerBuf.readByte();
       bodySize = headerBuf.readInt();
+      if (bufferSizeBytes != DISABLE_LARGE_BUFFER_SPLIT_SIZE && bodySize > 
bufferSizeBytes) {
+        // if the message body size is larger than bufferSizeBytes, we need to 
split it into two
+        // parts: celeborn header and data buffer
+        isReadingLargeBuffer = true;
+        // create a temporary buffer to store the celeborn header
+        largeBufferHeaderBuffer =
+            Unpooled.buffer(BufferUtils.HEADER_LENGTH, 
BufferUtils.HEADER_LENGTH);
+        largeBufferHeaderRemainingBytes = BufferUtils.HEADER_LENGTH;
+      }
       decodeMsg(buf, ctx);
     }
   }
@@ -138,9 +169,31 @@ public class TransportFrameDecoderWithBufferSupplier 
extends ChannelInboundHandl
       }
     }
 
-    copyByteBuf(buf, externalBuf, bodySize);
-    if (externalBuf.readableBytes() == bodySize) {
-      ((ReadData) curMsg).setFlinkBuffer(externalBuf);
+    if (largeBufferHeaderRemainingBytes > 0) {
+      // if largeBufferHeaderRemainingBytes larger than zero, means that we 
are reading the celeborn
+      // header
+      int headerReadBytes = copyByteBuf(buf, largeBufferHeaderBuffer, 
BufferUtils.HEADER_LENGTH);
+      largeBufferHeaderRemainingBytes -= headerReadBytes;
+      totalReadBytes += headerReadBytes;
+    } else {
+      // if largeBufferHeaderRemainingBytes less or equal to zero, means that 
we are reading the
+      // data buffer
+      totalReadBytes += copyByteBuf(buf, externalBuf, 
getTargetDataBufferReadSize());
+    }
+
+    if (totalReadBytes == bodySize) {
+      ByteBuf resultByteBuf;
+      if (largeBufferHeaderBuffer == null) {
+        resultByteBuf = externalBuf;
+      } else {
+        // composite the celeborn header and data buffer together
+        CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
+        compositeByteBuf.addComponent(true, largeBufferHeaderBuffer);
+        compositeByteBuf.addComponent(true, externalBuf);
+        resultByteBuf = compositeByteBuf;
+      }
+
+      ((ReadData) curMsg).setFlinkBuffer(resultByteBuf);
       ctx.fireChannelRead(curMsg);
       clear();
     }
@@ -192,6 +245,13 @@ public class TransportFrameDecoderWithBufferSupplier 
extends ChannelInboundHandl
     }
   }
 
+  private int getTargetDataBufferReadSize() {
+    if (isReadingLargeBuffer) {
+      return bodySize - BufferUtils.HEADER_LENGTH;
+    }
+    return bodySize;
+  }
+
   private void clear() {
     externalBuf = null;
     curMsg = null;
@@ -200,6 +260,10 @@ public class TransportFrameDecoderWithBufferSupplier 
extends ChannelInboundHandl
     bodyBuf = null;
     bodySize = -1;
     remainingSize = -1;
+    totalReadBytes = 0;
+    largeBufferHeaderRemainingBytes = -1;
+    largeBufferHeaderBuffer = null;
+    isReadingLargeBuffer = false;
   }
 
   @Override
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
index efbf343ce..5602d1aac 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
@@ -68,6 +68,7 @@ import org.apache.celeborn.common.util.Utils;
 import org.apache.celeborn.common.write.PushState;
 import org.apache.celeborn.plugin.flink.network.FlinkTransportClientFactory;
 import org.apache.celeborn.plugin.flink.network.ReadClientHandler;
+import 
org.apache.celeborn.plugin.flink.network.TransportFrameDecoderWithBufferSupplier;
 
 public class FlinkShuffleClientImpl extends ShuffleClientImpl {
   public static final Logger logger = 
LoggerFactory.getLogger(FlinkShuffleClientImpl.class);
@@ -81,6 +82,9 @@ public class FlinkShuffleClientImpl extends ShuffleClientImpl 
{
 
   private final TransportContext context;
 
+  /** The buffer size bytes in flink, default value is 32KB. */
+  private final int bufferSizeBytes;
+
   public static FlinkShuffleClientImpl get(
       String appUniqueId,
       String driverHost,
@@ -89,18 +93,49 @@ public class FlinkShuffleClientImpl extends 
ShuffleClientImpl {
       CelebornConf conf,
       UserIdentifier userIdentifier)
       throws DriverChangedException {
+    return get(
+        appUniqueId,
+        driverHost,
+        port,
+        driverTimestamp,
+        conf,
+        userIdentifier,
+        
TransportFrameDecoderWithBufferSupplier.DISABLE_LARGE_BUFFER_SPLIT_SIZE);
+  }
+
+  public static FlinkShuffleClientImpl get(
+      String appUniqueId,
+      String driverHost,
+      int port,
+      long driverTimestamp,
+      CelebornConf conf,
+      UserIdentifier userIdentifier,
+      int bufferSizeBytes)
+      throws DriverChangedException {
     if (null == _instance || !initialized || _instance.driverTimestamp < 
driverTimestamp) {
       synchronized (FlinkShuffleClientImpl.class) {
         if (null == _instance) {
           _instance =
               new FlinkShuffleClientImpl(
-                  appUniqueId, driverHost, port, driverTimestamp, conf, 
userIdentifier);
+                  appUniqueId,
+                  driverHost,
+                  port,
+                  driverTimestamp,
+                  conf,
+                  userIdentifier,
+                  bufferSizeBytes);
           initialized = true;
         } else if (!initialized || _instance.driverTimestamp < 
driverTimestamp) {
           _instance.shutdown();
           _instance =
               new FlinkShuffleClientImpl(
-                  appUniqueId, driverHost, port, driverTimestamp, conf, 
userIdentifier);
+                  appUniqueId,
+                  driverHost,
+                  port,
+                  driverTimestamp,
+                  conf,
+                  userIdentifier,
+                  bufferSizeBytes);
           initialized = true;
         }
       }
@@ -133,8 +168,10 @@ public class FlinkShuffleClientImpl extends 
ShuffleClientImpl {
       int port,
       long driverTimestamp,
       CelebornConf conf,
-      UserIdentifier userIdentifier) {
+      UserIdentifier userIdentifier,
+      int bufferSizeBytes) {
     super(appUniqueId, conf, userIdentifier);
+    this.bufferSizeBytes = bufferSizeBytes;
     String module = TransportModuleConstants.DATA_MODULE;
     TransportConf dataTransportConf =
         Utils.fromCelebornConf(conf, module, conf.getInt("celeborn." + module 
+ ".io.threads", 8));
@@ -147,7 +184,8 @@ public class FlinkShuffleClientImpl extends 
ShuffleClientImpl {
 
   private void initializeTransportClientFactory() {
     if (null == flinkTransportClientFactory) {
-      flinkTransportClientFactory = new FlinkTransportClientFactory(context, 
createBootstraps());
+      flinkTransportClientFactory =
+          new FlinkTransportClientFactory(context, createBootstraps(), 
bufferSizeBytes);
     }
   }
 
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java
index 999d1eb10..b28e6f753 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java
@@ -113,6 +113,18 @@ public class BufferUtils {
     }
   }
 
+  public static BufferHeader getBufferHeaderFromByteBuf(ByteBuf byteBuf, int 
position) {
+    byteBuf.readerIndex(position);
+    return new BufferHeader(
+        byteBuf.readInt(),
+        byteBuf.readInt(),
+        byteBuf.readInt(),
+        byteBuf.readInt(),
+        Buffer.DataType.values()[byteBuf.readByte()],
+        byteBuf.readBoolean(),
+        byteBuf.readInt());
+  }
+
   public static void reserveNumRequiredBuffers(BufferPool bufferPool, int 
numRequiredBuffers)
       throws IOException {
     long startTime = System.nanoTime();
diff --git 
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java
 
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java
index 8f3c0ce6e..acf42401a 100644
--- 
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java
+++ 
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java
@@ -26,6 +26,8 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
+import java.util.Queue;
+import java.util.Random;
 
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.flink.core.memory.MemorySegment;
@@ -38,6 +40,7 @@ import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
 import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf;
 import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;
 import org.junit.After;
+import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -220,6 +223,27 @@ public class BufferPackSuiteJ {
     unpacked.forEach(Buffer::recycleBuffer);
   }
 
+  @Test
+  public void testUnpackCompositeBuffer() throws Exception {
+    Buffer dataBuffer = bufferPool.requestBuffer();
+    fillBufferWithRandomByte(dataBuffer);
+    ByteBuf bufferHeaderByteBuf = createBufferHeaderByteBuf(BUFFER_SIZE);
+    bufferHeaderByteBuf.retain();
+    CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
+    compositeByteBuf.addComponent(true, bufferHeaderByteBuf);
+    compositeByteBuf.addComponent(true, dataBuffer.asByteBuf());
+
+    Queue<Buffer> unpackedBuffers = BufferPacker.unpack(compositeByteBuf);
+    Assert.assertEquals(1, unpackedBuffers.size());
+    Assert.assertEquals(dataBuffer.readableBytes(), 
unpackedBuffers.peek().readableBytes());
+    Assert.assertEquals(BUFFER_SIZE, unpackedBuffers.peek().readableBytes());
+    for (int i = 0; i < BUFFER_SIZE; ++i) {
+      Assert.assertEquals(
+          dataBuffer.getMemorySegment().get(i), 
unpackedBuffers.peek().getMemorySegment().get(i));
+    }
+    dataBuffer.recycleBuffer();
+  }
+
   @Test
   public void testPackMultipleBuffers() throws Exception {
     int numBuffers = 7;
@@ -404,4 +428,28 @@ public class BufferPackSuiteJ {
       return new ReceivedNoHeaderBufferPacker(ripeBufferHandler);
     }
   }
+
+  public ByteBuf createBufferHeaderByteBuf(int dataBufferSize) {
+    ByteBuf headerBuf = Unpooled.directBuffer(BufferUtils.HEADER_LENGTH, 
BufferUtils.HEADER_LENGTH);
+    // write celeborn buffer header (subpartitionid(4) + attemptId(4) + 
nextBatchId(4) +
+    // compressedsize)
+    headerBuf.writeInt(0);
+    headerBuf.writeInt(0);
+    headerBuf.writeInt(0);
+    headerBuf.writeInt(
+        dataBufferSize + (BufferUtils.HEADER_LENGTH - 
BufferUtils.HEADER_LENGTH_PREFIX));
+
+    // write flink buffer header (dataType(1) + isCompress(1) + size(4))
+    headerBuf.writeByte(DATA_BUFFER.ordinal());
+    headerBuf.writeBoolean(false);
+    headerBuf.writeInt(dataBufferSize);
+    return headerBuf;
+  }
+
+  public void fillBufferWithRandomByte(Buffer buffer) {
+    Random random = new Random();
+    for (int i = 0; i < buffer.getMaxCapacity(); i++) {
+      buffer.asByteBuf().writeByte(random.nextInt(255));
+    }
+  }
 }
diff --git 
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/FlinkShuffleClientImplSuiteJ.java
 
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/FlinkShuffleClientImplSuiteJ.java
index 60a843f4a..cb15c1e15 100644
--- 
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/FlinkShuffleClientImplSuiteJ.java
+++ 
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/FlinkShuffleClientImplSuiteJ.java
@@ -55,7 +55,7 @@ public class FlinkShuffleClientImplSuiteJ {
     conf = new CelebornConf();
     shuffleClient =
         new FlinkShuffleClientImpl(
-            "APP", "localhost", 1232, System.currentTimeMillis(), conf, null) {
+            "APP", "localhost", 1232, System.currentTimeMillis(), conf, null, 
-1) {
           @Override
           public void setupLifecycleManagerRef(String host, int port) {}
         };
diff --git 
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
 
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
index c7c8440c8..431f8bc62 100644
--- 
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
+++ 
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
@@ -18,6 +18,8 @@
 package org.apache.celeborn.plugin.flink.network;
 
 import static 
org.apache.celeborn.common.network.client.TransportClient.requestId;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -31,6 +33,7 @@ import java.util.function.Supplier;
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.Unpooled;
 import io.netty.channel.ChannelHandlerContext;
+import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf;
 import org.junit.Assert;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -46,6 +49,7 @@ import 
org.apache.celeborn.common.network.protocol.TransportMessage;
 import org.apache.celeborn.common.protocol.MessageType;
 import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
 import org.apache.celeborn.common.util.JavaUtils;
+import org.apache.celeborn.plugin.flink.utils.BufferUtils;
 
 @RunWith(Parameterized.class)
 public class TransportFrameDecoderWithBufferSupplierSuiteJ {
@@ -131,6 +135,125 @@ public class 
TransportFrameDecoderWithBufferSupplierSuiteJ {
     Assert.assertEquals(buffers.size(), 6);
   }
 
+  @Test(expected = IndexOutOfBoundsException.class)
+  public void testFailProcessFullBufferIfDisableLargeBufferSplit() throws 
IOException {
+    int bufferSizeBytes = 10 * 1024;
+    ConcurrentHashMap<Long, 
Supplier<org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf>>
+        supplier = JavaUtils.newConcurrentHashMap();
+    List<org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf> buffers = new 
ArrayList<>();
+
+    supplier.put(
+        0L,
+        () -> {
+          org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf buffer =
+              org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled.buffer(
+                  bufferSizeBytes, bufferSizeBytes);
+          buffers.add(buffer);
+          return buffer;
+        });
+
+    TransportFrameDecoderWithBufferSupplier decoder =
+        new TransportFrameDecoderWithBufferSupplier(
+            supplier, 
TransportFrameDecoderWithBufferSupplier.DISABLE_LARGE_BUFFER_SPLIT_SIZE);
+    ChannelHandlerContext context = Mockito.mock(ChannelHandlerContext.class);
+
+    SubPartitionReadData readData =
+        new SubPartitionReadData(0, 0, generateData(bufferSizeBytes + 
BufferUtils.HEADER_LENGTH));
+
+    ByteBuf buffer = Unpooled.buffer(bufferSizeBytes * 4);
+    encodeMessage(readData, buffer);
+
+    // simulate
+    buffer.retain();
+    decoder.channelRead(context, buffer);
+  }
+
+  @Test
+  public void testProcessFullBufferIfEnableLargeBufferSplit() throws 
IOException {
+    int bufferSizeBytes = 10 * 1024;
+    ConcurrentHashMap<Long, 
Supplier<org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf>>
+        supplier = JavaUtils.newConcurrentHashMap();
+    List<Message> parsedMessages = new ArrayList<>();
+
+    supplier.put(
+        0L,
+        () ->
+            org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled.buffer(
+                bufferSizeBytes, bufferSizeBytes));
+
+    TransportFrameDecoderWithBufferSupplier decoder =
+        new TransportFrameDecoderWithBufferSupplier(supplier, bufferSizeBytes);
+    ChannelHandlerContext context = Mockito.mock(ChannelHandlerContext.class);
+    when(context.fireChannelRead(any()))
+        .thenAnswer(
+            m -> {
+              Assert.assertEquals(1, m.getArguments().length);
+              parsedMessages.add(m.getArgument(0));
+              return null;
+            });
+
+    ReadData readData1 = new ReadData(0, generateData(1024));
+    // simulate the client received a large buffer which body size large than 
size of given buffer
+    // in this case, the flinkBuffer of parsed message will contain two parts: 
celeborn header and
+    // data buffer
+    SubPartitionReadData readData2 =
+        new SubPartitionReadData(0, 0, generateData(BufferUtils.HEADER_LENGTH 
+ bufferSizeBytes));
+    SubPartitionReadData readData3 = new SubPartitionReadData(0, 0, 
generateData(1024));
+
+    ByteBuf buffer = Unpooled.buffer(bufferSizeBytes * 4);
+    encodeMessage(readData1, buffer);
+    encodeMessage(readData2, buffer);
+    encodeMessage(readData3, buffer);
+
+    // simulate
+    buffer.retain();
+    decoder.channelRead(context, buffer);
+    Assert.assertEquals(parsedMessages.size(), 3);
+
+    // the parsed first message contains the readData1
+    Assert.assertTrue(
+        parsedMessages.get(0) instanceof 
org.apache.celeborn.plugin.flink.protocol.ReadData);
+    Assert.assertEquals(
+        ((org.apache.celeborn.plugin.flink.protocol.ReadData) 
parsedMessages.get(0))
+            .getFlinkBuffer()
+            .nioBuffer(),
+        readData1.body().nioByteBuffer());
+
+    // the parsed second message contains the readData2
+    Assert.assertTrue(
+        parsedMessages.get(1)
+            instanceof 
org.apache.celeborn.plugin.flink.protocol.SubPartitionReadData);
+    org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf byteBuf2 =
+        ((org.apache.celeborn.plugin.flink.protocol.SubPartitionReadData) 
parsedMessages.get(1))
+            .getFlinkBuffer();
+    // verify the flinkBuffer of parsed message contains two parts: celeborn 
header and data buffer
+    Assert.assertTrue(
+        byteBuf2 instanceof 
org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf);
+    CompositeByteBuf compositeByteBuf2 = (CompositeByteBuf) byteBuf2;
+    Assert.assertEquals(compositeByteBuf2.numComponents(), 2);
+    org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf inputByteBuf2 =
+        org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled.wrappedBuffer(
+            readData2.body().nioByteBuffer());
+    // the first part is celeborn header
+    Assert.assertEquals(
+        compositeByteBuf2.component(0).nioBuffer(),
+        inputByteBuf2.slice(0, BufferUtils.HEADER_LENGTH).nioBuffer());
+    // the second part is data buffer
+    Assert.assertEquals(
+        compositeByteBuf2.component(1).nioBuffer(),
+        inputByteBuf2.slice(BufferUtils.HEADER_LENGTH, 
bufferSizeBytes).nioBuffer());
+
+    // the parsed third message contains the readData3
+    Assert.assertTrue(
+        parsedMessages.get(2)
+            instanceof 
org.apache.celeborn.plugin.flink.protocol.SubPartitionReadData);
+    Assert.assertEquals(
+        ((org.apache.celeborn.plugin.flink.protocol.SubPartitionReadData) 
parsedMessages.get(2))
+            .getFlinkBuffer()
+            .nioBuffer(),
+        readData3.body().nioByteBuffer());
+  }
+
   public RpcRequest createBacklogAnnouncement(long streamId, int backlog) {
     return new RpcRequest(
         requestId(),
diff --git 
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierConsumerAgent.java
 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierConsumerAgent.java
index 0febd8bd3..8d06ba77c 100644
--- 
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierConsumerAgent.java
+++ 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierConsumerAgent.java
@@ -117,10 +117,13 @@ public class CelebornTierConsumerAgent implements 
TierConsumerAgent {
 
   private TieredStorageMemoryManager memoryManager;
 
+  private final int bufferSizeBytes;
+
   public CelebornTierConsumerAgent(
       CelebornConf conf,
       List<TieredStorageConsumerSpec> tieredStorageConsumerSpecs,
-      List<TierShuffleDescriptor> shuffleDescriptors) {
+      List<TierShuffleDescriptor> shuffleDescriptors,
+      int bufferSizeBytes) {
     checkArgument(!shuffleDescriptors.isEmpty(), "Wrong shuffle descriptors 
size.");
     checkArgument(
         tieredStorageConsumerSpecs.size() == shuffleDescriptors.size(),
@@ -132,6 +135,7 @@ public class CelebornTierConsumerAgent implements 
TierConsumerAgent {
     this.bufferReaders = new HashMap<>();
     this.receivedBuffers = new HashMap<>();
     this.subPartitionsNeedNotifyAvailable = new HashSet<>();
+    this.bufferSizeBytes = bufferSizeBytes;
     for (TierShuffleDescriptor shuffleDescriptor : shuffleDescriptors) {
       if (shuffleDescriptor instanceof TierShuffleDescriptorImpl) {
         initShuffleClient((TierShuffleDescriptorImpl) shuffleDescriptor);
@@ -326,7 +330,8 @@ public class CelebornTierConsumerAgent implements 
TierConsumerAgent {
               shuffleResource.getLifecycleManagerPort(),
               shuffleResource.getLifecycleManagerTimestamp(),
               conf,
-              new UserIdentifier("default", "default"));
+              new UserIdentifier("default", "default"),
+              bufferSizeBytes);
     } catch (DriverChangedException e) {
       throw new RuntimeException(e.getMessage());
     }
diff --git 
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
index 1a86130e4..c9913d132 100644
--- 
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
+++ 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
@@ -118,7 +118,8 @@ public class CelebornTierFactory implements TierFactory {
       List<TieredStorageConsumerSpec> tieredStorageConsumerSpecs,
       List<TierShuffleDescriptor> shuffleDescriptors,
       TieredStorageNettyService nettyService) {
-    return new CelebornTierConsumerAgent(conf, tieredStorageConsumerSpecs, 
shuffleDescriptors);
+    return new CelebornTierConsumerAgent(
+        conf, tieredStorageConsumerSpecs, shuffleDescriptors, bufferSizeBytes);
   }
 
   public static String getCelebornTierName() {
diff --git 
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java
 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java
index aab2b3ae5..983f24cb0 100644
--- 
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java
+++ 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java
@@ -64,6 +64,7 @@ public class CelebornTierProducerAgent implements 
TierProducerAgent {
 
   private final int numBuffersPerSegment;
 
+  // The flink buffer size in bytes.
   private final int bufferSizeBytes;
 
   private final int numPartitions;
@@ -325,9 +326,18 @@ public class CelebornTierProducerAgent implements 
TierProducerAgent {
     try {
       int remainingReviveTimes = maxReviveTimes;
       while (remainingReviveTimes-- > 0 && !hasSentHandshake) {
+        // In the Flink hybrid shuffle integration strategy, the data buffer 
sent to the Celeborn
+        // workers consists of two components: the Celeborn header and the 
data buffers.
+        // In this scenario, the maximum byte size of the buffer received by 
the Celeborn worker is
+        // equal to the sum of the Flink buffer size and the Celeborn header 
size.
         Optional<PartitionLocation> revivePartition =
             flinkShuffleClient.pushDataHandShake(
-                shuffleId, mapId, attemptId, numSubPartitions, 
bufferSizeBytes, partitionLocation);
+                shuffleId,
+                mapId,
+                attemptId,
+                numSubPartitions,
+                bufferSizeBytes + BufferUtils.HEADER_LENGTH,
+                partitionLocation);
         // if remainingReviveTimes == 0 and revivePartition.isPresent(), there 
is no need to send
         // handshake again
         if (revivePartition.isPresent() && remainingReviveTimes > 0) {
@@ -478,7 +488,8 @@ public class CelebornTierProducerAgent implements 
TierProducerAgent {
           lifecycleManagerPort,
           lifecycleManagerTimestamp,
           celebornConf,
-          null);
+          null,
+          bufferSizeBytes);
     } catch (DriverChangedException e) {
       // would generate a new attempt to retry output gate
       throw new RuntimeException(e.getMessage());
diff --git 
a/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HeartbeatTest.scala
 
b/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HeartbeatTest.scala
index a373c6fd8..dbd7e543f 100644
--- 
a/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HeartbeatTest.scala
+++ 
b/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HeartbeatTest.scala
@@ -37,7 +37,8 @@ class HeartbeatTest extends AnyFunSuite with Logging with 
MiniClusterFeature wit
         0,
         System.currentTimeMillis(),
         clientConf,
-        new UserIdentifier("1", "1")) {
+        new UserIdentifier("1", "1"),
+        -1) {
         override def setupLifecycleManagerRef(host: String, port: Int): Unit = 
{}
       }
     testHeartbeatFromWorker2Client(flinkShuffleClientImpl.getDataClientFactory)
@@ -52,7 +53,8 @@ class HeartbeatTest extends AnyFunSuite with Logging with 
MiniClusterFeature wit
         0,
         System.currentTimeMillis(),
         clientConf,
-        new UserIdentifier("1", "1")) {
+        new UserIdentifier("1", "1"),
+        -1) {
         override def setupLifecycleManagerRef(host: String, port: Int): Unit = 
{}
       }
     
testHeartbeatFromWorker2ClientWithNoHeartbeat(flinkShuffleClientImpl.getDataClientFactory)
@@ -67,7 +69,8 @@ class HeartbeatTest extends AnyFunSuite with Logging with 
MiniClusterFeature wit
         0,
         System.currentTimeMillis(),
         clientConf,
-        new UserIdentifier("1", "1")) {
+        new UserIdentifier("1", "1"),
+        -1) {
         override def setupLifecycleManagerRef(host: String, port: Int): Unit = 
{}
       }
     
testHeartbeatFromWorker2ClientWithCloseChannel(flinkShuffleClientImpl.getDataClientFactory)


Reply via email to