This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new bd29da836 [CELEBORN-1490][CIP-6] Introduce tier consumer for hybrid 
shuffle
bd29da836 is described below

commit bd29da83635bd97b690174d1d0551fd114dbde50
Author: Weijie Guo <[email protected]>
AuthorDate: Thu Oct 17 10:46:35 2024 +0800

    [CELEBORN-1490][CIP-6] Introduce tier consumer for hybrid shuffle
    
    ### What changes were proposed in this pull request?
    
    Introduce tier consumer for hybrid shuffle
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    unit test
    
    Closes #2786 from reswqa/cip6-7-pr-new.
    
    Authored-by: Weijie Guo <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../plugin/flink/network/ReadClientHandler.java    |   2 -
 .../celeborn/plugin/flink/protocol/ReadData.java   |   6 +-
 .../flink/protocol/SubPartitionReadData.java       |  36 +-
 .../flink/readclient/CelebornBufferStream.java     |  59 ++-
 .../flink/readclient/FlinkShuffleClientImpl.java   |  14 +-
 ...nsportFrameDecoderWithBufferSupplierSuiteJ.java |  39 +-
 .../flink/tiered/CelebornChannelBufferManager.java | 164 ++++++
 .../flink/tiered/CelebornChannelBufferReader.java  | 329 ++++++++++++
 .../flink/tiered/CelebornTierConsumerAgent.java    | 550 +++++++++++++++++++++
 .../plugin/flink/tiered/CelebornTierFactory.java   |   3 +-
 .../celeborn/common/network/protocol/ReadData.java |   2 +-
 .../network/protocol/SubPartitionReadData.java     |  20 +-
 12 files changed, 1157 insertions(+), 67 deletions(-)

diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
index c1d0982fb..1ef2b7781 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
@@ -66,8 +66,6 @@ public class ReadClientHandler extends BaseMessageHandler {
     } else {
       if (msg != null && msg instanceof ReadData) {
         ((ReadData) msg).getFlinkBuffer().release();
-      } else if (msg != null && msg instanceof SubPartitionReadData) {
-        ((SubPartitionReadData) msg).getFlinkBuffer().release();
       }
 
       logger.warn("Unexpected streamId received: {}", streamId);
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/ReadData.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/ReadData.java
index a8f7fcf84..ba507e88e 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/ReadData.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/ReadData.java
@@ -23,9 +23,9 @@ import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
 
 import org.apache.celeborn.common.network.protocol.RequestMessage;
 
-public final class ReadData extends RequestMessage {
-  private final long streamId;
-  private ByteBuf flinkBuffer;
+public class ReadData extends RequestMessage {
+  protected final long streamId;
+  protected ByteBuf flinkBuffer;
 
   @Override
   public boolean needCopyOut() {
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/SubPartitionReadData.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/SubPartitionReadData.java
index b12f24d9b..fc08de75f 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/SubPartitionReadData.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/SubPartitionReadData.java
@@ -20,46 +20,30 @@ package org.apache.celeborn.plugin.flink.protocol;
 
 import java.util.Objects;
 
-import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
-
-import org.apache.celeborn.common.network.protocol.ReadData;
-import org.apache.celeborn.common.network.protocol.RequestMessage;
-
 /**
  * Comparing {@link ReadData}, this class has an additional field of 
subpartitionId. This class is
  * added to keep the backward compatibility.
  */
-public class SubPartitionReadData extends RequestMessage {
-  private final long streamId;
+public class SubPartitionReadData extends ReadData {
   private final int subPartitionId;
-  private ByteBuf flinkBuffer;
-
-  @Override
-  public boolean needCopyOut() {
-    return true;
-  }
 
   public SubPartitionReadData(long streamId, int subPartitionId) {
+    super(streamId);
     this.subPartitionId = subPartitionId;
-    this.streamId = streamId;
   }
 
   @Override
   public int encodedLength() {
-    return 8 + 4;
+    return super.encodedLength() + 4;
   }
 
   // This method will not be called because ReadData won't be created at flink 
client.
   @Override
   public void encode(io.netty.buffer.ByteBuf buf) {
-    buf.writeLong(streamId);
+    super.encode(buf);
     buf.writeInt(subPartitionId);
   }
 
-  public long getStreamId() {
-    return streamId;
-  }
-
   public int getSubPartitionId() {
     return subPartitionId;
   }
@@ -74,8 +58,8 @@ public class SubPartitionReadData extends RequestMessage {
     if (this == o) return true;
     if (o == null || getClass() != o.getClass()) return false;
     SubPartitionReadData readData = (SubPartitionReadData) o;
-    return streamId == readData.streamId
-        && subPartitionId == readData.subPartitionId
+    return streamId == readData.getStreamId()
+        && subPartitionId == readData.getSubPartitionId()
         && Objects.equals(flinkBuffer, readData.flinkBuffer);
   }
 
@@ -95,12 +79,4 @@ public class SubPartitionReadData extends RequestMessage {
         + flinkBuffer
         + '}';
   }
-
-  public ByteBuf getFlinkBuffer() {
-    return flinkBuffer;
-  }
-
-  public void setFlinkBuffer(ByteBuf flinkBuffer) {
-    this.flinkBuffer = flinkBuffer;
-  }
 }
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
index 7f478143d..849895fef 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
@@ -20,9 +20,12 @@ package org.apache.celeborn.plugin.flink.readclient;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
 
+import javax.annotation.Nullable;
+
 import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -36,6 +39,7 @@ import org.apache.celeborn.common.network.util.NettyUtils;
 import org.apache.celeborn.common.protocol.MessageType;
 import org.apache.celeborn.common.protocol.PartitionLocation;
 import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
+import org.apache.celeborn.common.protocol.PbNotifyRequiredSegment;
 import org.apache.celeborn.common.protocol.PbOpenStream;
 import org.apache.celeborn.common.protocol.PbReadAddCredit;
 import org.apache.celeborn.common.protocol.PbStreamHandler;
@@ -110,10 +114,38 @@ public class CelebornBufferStream {
         });
   }
 
+  public void notifyRequiredSegment(PbNotifyRequiredSegment 
pbNotifyRequiredSegment) {
+    this.client.sendRpc(
+        new TransportMessage(
+                MessageType.NOTIFY_REQUIRED_SEGMENT, 
pbNotifyRequiredSegment.toByteArray())
+            .toByteBuffer(),
+        new RpcResponseCallback() {
+
+          @Override
+          public void onSuccess(ByteBuffer response) {
+            // Sending PbNotifyRequiredSegment does not expect a response.
+          }
+
+          @Override
+          public void onFailure(Throwable e) {
+            logger.error(
+                "Send PbNotifyRequiredSegment to {} failed, streamId {}, 
detail {}",
+                NettyUtils.getRemoteAddress(client.getChannel()),
+                streamId,
+                e.getCause());
+            messageConsumer.accept(new TransportableError(streamId, e));
+          }
+        });
+  }
+
   public static CelebornBufferStream empty() {
     return EMPTY_CELEBORN_BUFFER_STREAM;
   }
 
+  public static boolean isEmptyStream(CelebornBufferStream stream) {
+    return stream == null || stream == EMPTY_CELEBORN_BUFFER_STREAM;
+  }
+
   public long getStreamId() {
     return streamId;
   }
@@ -167,6 +199,11 @@ public class CelebornBufferStream {
   }
 
   public void moveToNextPartitionIfPossible(long endedStreamId) {
+    moveToNextPartitionIfPossible(endedStreamId, null);
+  }
+
+  public void moveToNextPartitionIfPossible(
+      long endedStreamId, @Nullable BiConsumer<Long, Integer> 
requiredSegmentIdConsumer) {
     logger.debug(
         "MoveToNextPartitionIfPossible in this:{},  endedStreamId: {}, 
currentLocationIndex: {}, currentSteamId:{}, locationsLength:{}",
         this,
@@ -178,9 +215,10 @@ public class CelebornBufferStream {
       logger.debug("Get end streamId {}", endedStreamId);
       cleanStream(endedStreamId);
     }
+
     if (currentLocationIndex.get() < locations.length) {
       try {
-        openStreamInternal();
+        openStreamInternal(requiredSegmentIdConsumer);
         logger.debug(
             "MoveToNextPartitionIfPossible after openStream this:{},  
endedStreamId: {}, currentLocationIndex: {}, currentSteamId:{}, 
locationsLength:{}",
             this,
@@ -195,7 +233,12 @@ public class CelebornBufferStream {
     }
   }
 
-  private void openStreamInternal() throws IOException, InterruptedException {
+  /**
+   * Open the stream. Note that if requiredSegmentIdConsumer is not null, it 
will be invoked for
+   * every subPartition once the stream is opened successfully.
+   */
+  private void openStreamInternal(@Nullable BiConsumer<Long, Integer> 
requiredSegmentIdConsumer)
+      throws IOException, InterruptedException {
     this.client =
         clientFactory.createClientWithRetry(
             locations[currentLocationIndex.get()].getHost(),
@@ -210,6 +253,7 @@ public class CelebornBufferStream {
                 .setStartIndex(subIndexStart)
                 .setEndIndex(subIndexEnd)
                 .setInitialCredit(initialCredit)
+                .setRequireSubpartitionId(true)
                 .build()
                 .toByteArray());
     client.sendRpc(
@@ -230,6 +274,13 @@ public class CelebornBufferStream {
                       .getReadClientHandler()
                       .registerHandler(streamId, messageConsumer, client);
                   isOpenSuccess = true;
+                  if (requiredSegmentIdConsumer != null) {
+                    for (int subPartitionId = subIndexStart;
+                        subPartitionId <= subIndexEnd;
+                        subPartitionId++) {
+                      requiredSegmentIdConsumer.accept(streamId, 
subPartitionId);
+                    }
+                  }
                   logger.debug(
                       "open stream success from remote:{}, stream id:{}, 
fileName: {}",
                       client.getSocketAddress(),
@@ -269,4 +320,8 @@ public class CelebornBufferStream {
   public TransportClient getClient() {
     return client;
   }
+
+  public boolean isOpened() {
+    return isOpenSuccess;
+  }
 }
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
index 4fb65777e..a3f8a1a9b 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
@@ -181,15 +181,11 @@ public class FlinkShuffleClientImpl extends 
ShuffleClientImpl {
           shuffleId,
           partitionId,
           isSegmentGranularityVisible);
-      if (isSegmentGranularityVisible) {
-        // When the downstream reduce tasks start early than upstream map 
tasks, the shuffle
-        // partition locations may be found empty, should retry until the 
upstream task started
-        return CelebornBufferStream.empty();
-      } else {
-        throw new PartitionUnRetryAbleException(
-            String.format(
-                "Shuffle data lost for shuffle %d partition %d.", shuffleId, 
partitionId));
-      }
+      // TODO: in segment granularity visible scenarios, when the downstream 
reduce tasks start earlier
+      // than upstream map tasks, the shuffle
+      // partition locations may be found empty, should retry until the 
upstream task started
+      throw new PartitionUnRetryAbleException(
+          String.format("Shuffle data lost for shuffle %d partition %d.", 
shuffleId, partitionId));
     } else {
       Arrays.sort(partitionLocations, 
Comparator.comparingInt(PartitionLocation::getEpoch));
       logger.debug(
diff --git 
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
 
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
index 64696abec..c7c8440c8 100644
--- 
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
+++ 
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
@@ -21,6 +21,8 @@ import static 
org.apache.celeborn.common.network.client.TransportClient.requestI
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.List;
 import java.util.Random;
 import java.util.concurrent.ConcurrentHashMap;
@@ -31,19 +33,40 @@ import io.netty.buffer.Unpooled;
 import io.netty.channel.ChannelHandlerContext;
 import org.junit.Assert;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 import org.mockito.Mockito;
 
 import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
 import org.apache.celeborn.common.network.protocol.Message;
 import org.apache.celeborn.common.network.protocol.ReadData;
 import org.apache.celeborn.common.network.protocol.RpcRequest;
+import org.apache.celeborn.common.network.protocol.SubPartitionReadData;
 import org.apache.celeborn.common.network.protocol.TransportMessage;
 import org.apache.celeborn.common.protocol.MessageType;
 import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
 import org.apache.celeborn.common.util.JavaUtils;
 
+@RunWith(Parameterized.class)
 public class TransportFrameDecoderWithBufferSupplierSuiteJ {
 
+  enum TestReadDataType {
+    READ_DATA,
+    SUBPARTITION_READ_DATA,
+  }
+
+  private TestReadDataType testReadDataType;
+
+  public TransportFrameDecoderWithBufferSupplierSuiteJ(TestReadDataType 
testReadDataType) {
+    this.testReadDataType = testReadDataType;
+  }
+
+  @Parameterized.Parameters
+  public static Collection prepareData() {
+    Object[][] object = {{TestReadDataType.READ_DATA}, 
{TestReadDataType.SUBPARTITION_READ_DATA}};
+    return Arrays.asList(object);
+  }
+
   @Test
   public void testDropUnusedBytes() throws IOException {
     ConcurrentHashMap<Long, 
Supplier<org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf>>
@@ -64,11 +87,11 @@ public class TransportFrameDecoderWithBufferSupplierSuiteJ {
     ChannelHandlerContext context = Mockito.mock(ChannelHandlerContext.class);
 
     RpcRequest announcement = createBacklogAnnouncement(0, 0);
-    ReadData unUsedReadData = new ReadData(1, generateData(1024));
-    ReadData readData = new ReadData(2, generateData(1024));
+    ReadData unUsedReadData = generateReadDataMessage(1, 0, 
generateData(1024));
+    ReadData readData = generateReadDataMessage(2, 0, generateData(1024));
     RpcRequest announcement1 = createBacklogAnnouncement(0, 0);
-    ReadData unUsedReadData1 = new ReadData(1, generateData(1024));
-    ReadData readData1 = new ReadData(2, generateData(8));
+    ReadData unUsedReadData1 = generateReadDataMessage(1, 0, 
generateData(1024));
+    ReadData readData1 = generateReadDataMessage(2, 0, generateData(8));
 
     ByteBuf buffer = Unpooled.buffer(5000);
     encodeMessage(announcement, buffer);
@@ -145,4 +168,12 @@ public class TransportFrameDecoderWithBufferSupplierSuiteJ 
{
 
     return data;
   }
+
+  private ReadData generateReadDataMessage(long streamId, int subPartitionId, 
ByteBuf buf) {
+    if (testReadDataType == TestReadDataType.READ_DATA) {
+      return new ReadData(streamId, buf);
+    } else {
+      return new SubPartitionReadData(streamId, subPartitionId, buf);
+    }
+  }
 }
diff --git 
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornChannelBufferManager.java
 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornChannelBufferManager.java
new file mode 100644
index 000000000..8fe9e4bbb
--- /dev/null
+++ 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornChannelBufferManager.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.celeborn.plugin.flink.tiered;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+import java.util.LinkedList;
+import java.util.Queue;
+
+import javax.annotation.concurrent.GuardedBy;
+
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferListener;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageMemoryManager;
+import org.apache.flink.util.ExceptionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class CelebornChannelBufferManager implements BufferListener, 
BufferRecycler {
+
+  private static Logger logger = 
LoggerFactory.getLogger(CelebornChannelBufferManager.class);
+
+  /** The queue to hold the available buffer when the reader is waiting for 
buffers. */
+  private final Queue<Buffer> bufferQueue;
+
+  private final TieredStorageMemoryManager memoryManager;
+
+  private final CelebornChannelBufferReader bufferReader;
+
+  /** The tag indicates whether it is waiting for buffers from the buffer 
pool. */
+  @GuardedBy("bufferQueue")
+  private boolean isWaitingForFloatingBuffers;
+
+  /** The total number of required buffers for the respective input channel. */
+  @GuardedBy("bufferQueue")
+  private int numRequiredBuffers = 0;
+
+  public CelebornChannelBufferManager(
+      TieredStorageMemoryManager memoryManager, CelebornChannelBufferReader 
bufferReader) {
+    this.memoryManager = checkNotNull(memoryManager);
+    this.bufferReader = checkNotNull(bufferReader);
+    this.bufferQueue = new LinkedList<>();
+  }
+
+  @Override
+  public boolean notifyBufferAvailable(Buffer buffer) {
+    if (bufferReader.isClosed()) {
+      return false;
+    }
+    int numBuffers = 0;
+    boolean isBufferUsed = false;
+    try {
+      synchronized (bufferQueue) {
+        if (!isWaitingForFloatingBuffers) {
+          logger.warn("This channel should be waiting for floating buffers.");
+          return false;
+        }
+        isWaitingForFloatingBuffers = false;
+        if (bufferReader.isClosed() || bufferQueue.size() >= 
numRequiredBuffers) {
+          return false;
+        }
+        bufferQueue.add(buffer);
+        isBufferUsed = true;
+        numBuffers = 1 + tryRequestBuffers();
+      }
+      bufferReader.notifyAvailableCredits(numBuffers);
+    } catch (Throwable t) {
+      bufferReader.errorReceived(t.getLocalizedMessage());
+    }
+    return isBufferUsed;
+  }
+
+  public void decreaseRequiredCredits(int numCredits) {
+    synchronized (bufferQueue) {
+      numRequiredBuffers -= numCredits;
+    }
+  }
+
+  @Override
+  public void notifyBufferDestroyed() {
+    // noop
+  }
+
+  @Override
+  public void recycle(MemorySegment segment) {
+    try {
+      memoryManager.getBufferPool().recycle(segment);
+    } catch (Throwable t) {
+      ExceptionUtils.rethrow(t);
+    }
+  }
+
+  Buffer requestBuffer() {
+    synchronized (bufferQueue) {
+      return bufferQueue.poll();
+    }
+  }
+
+  int requestBuffers(int numRequired) {
+    int numRequestedBuffers = 0;
+    synchronized (bufferQueue) {
+      if (bufferReader.isClosed()) {
+        return numRequestedBuffers;
+      }
+      numRequiredBuffers += numRequired;
+      numRequestedBuffers = tryRequestBuffers();
+    }
+    return numRequestedBuffers;
+  }
+
+  int tryRequestBuffersIfNeeded() {
+    synchronized (bufferQueue) {
+      if (numRequiredBuffers > 0 && !isWaitingForFloatingBuffers && 
bufferQueue.isEmpty()) {
+        return tryRequestBuffers();
+      }
+      return 0;
+    }
+  }
+
+  void close() {
+    synchronized (bufferQueue) {
+      for (Buffer buffer : bufferQueue) {
+        buffer.recycleBuffer();
+      }
+      bufferQueue.clear();
+    }
+  }
+
+  @GuardedBy("bufferQueue")
+  private int tryRequestBuffers() {
+    assert Thread.holdsLock(bufferQueue);
+    int numRequestedBuffers = 0;
+    while (bufferQueue.size() < numRequiredBuffers && 
!isWaitingForFloatingBuffers) {
+      BufferPool bufferPool = memoryManager.getBufferPool();
+      Buffer buffer = bufferPool.requestBuffer();
+      if (buffer != null) {
+        bufferQueue.add(buffer);
+        numRequestedBuffers++;
+      } else if (bufferPool.addBufferListener(this)) {
+        isWaitingForFloatingBuffers = true;
+        break;
+      }
+    }
+    return numRequestedBuffers;
+  }
+}
diff --git 
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornChannelBufferReader.java
 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornChannelBufferReader.java
new file mode 100644
index 000000000..527617c96
--- /dev/null
+++ 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornChannelBufferReader.java
@@ -0,0 +1,329 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.tiered;
+
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkState;
+
+import java.io.IOException;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageInputChannelId;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageMemoryManager;
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
+import org.apache.celeborn.common.network.protocol.BufferStreamEnd;
+import org.apache.celeborn.common.network.protocol.RequestMessage;
+import org.apache.celeborn.common.network.protocol.TransportableError;
+import org.apache.celeborn.common.network.util.NettyUtils;
+import org.apache.celeborn.common.protocol.PbNotifyRequiredSegment;
+import org.apache.celeborn.common.protocol.PbReadAddCredit;
+import org.apache.celeborn.common.util.JavaUtils;
+import org.apache.celeborn.plugin.flink.ShuffleResourceDescriptor;
+import org.apache.celeborn.plugin.flink.protocol.SubPartitionReadData;
+import org.apache.celeborn.plugin.flink.readclient.CelebornBufferStream;
+import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl;
+
+/**
+ * Wrap the {@link CelebornBufferStream}, utilize in flink hybrid shuffle 
integration strategy now.
+ */
+public class CelebornChannelBufferReader {
+  private static final Logger LOG = 
LoggerFactory.getLogger(CelebornChannelBufferReader.class);
+
+  private CelebornChannelBufferManager bufferManager;
+
+  private final FlinkShuffleClientImpl client;
+
+  private final int shuffleId;
+
+  private final int partitionId;
+
+  private final TieredStorageInputChannelId inputChannelId;
+
+  private final int subPartitionIndexStart;
+
+  private final int subPartitionIndexEnd;
+
+  private final BiConsumer<ByteBuf, TieredStorageSubpartitionId> dataListener;
+
+  private final BiConsumer<Throwable, TieredStorageSubpartitionId> 
failureListener;
+
+  private final Consumer<RequestMessage> messageConsumer;
+
+  private CelebornBufferStream bufferStream;
+
+  private boolean isOpened;
+
+  private volatile boolean closed = false;
+
+  private volatile ConcurrentHashMap<Integer, Integer> 
subPartitionRequiredSegmentIds;
+
+  /** Note this field is to record the number of backlogs before the reader 
is set up. */
+  private int numBackLog = 0;
+
+  public CelebornChannelBufferReader(
+      FlinkShuffleClientImpl client,
+      ShuffleResourceDescriptor shuffleDescriptor,
+      TieredStorageInputChannelId inputChannelId,
+      int startSubIdx,
+      int endSubIdx,
+      BiConsumer<ByteBuf, TieredStorageSubpartitionId> dataListener,
+      BiConsumer<Throwable, TieredStorageSubpartitionId> failureListener) {
+    this.client = client;
+    this.shuffleId = shuffleDescriptor.getShuffleId();
+    this.partitionId = shuffleDescriptor.getPartitionId();
+    this.inputChannelId = inputChannelId;
+    this.subPartitionIndexStart = startSubIdx;
+    this.subPartitionIndexEnd = endSubIdx;
+    this.dataListener = dataListener;
+    this.failureListener = failureListener;
+    this.subPartitionRequiredSegmentIds = JavaUtils.newConcurrentHashMap();
+    for (int subPartitionId = subPartitionIndexStart;
+        subPartitionId <= subPartitionIndexEnd;
+        subPartitionId++) {
+      subPartitionRequiredSegmentIds.put(subPartitionId, -1);
+    }
+    this.messageConsumer =
+        requestMessage -> {
+          // Note that we need to use SubPartitionReadData because the 
isSegmentGranularityVisible
+          // is set as true when opening stream
+          if (requestMessage instanceof SubPartitionReadData) {
+            dataReceived((SubPartitionReadData) requestMessage);
+          } else if (requestMessage instanceof BacklogAnnouncement) {
+            backlogReceived(((BacklogAnnouncement) 
requestMessage).getBacklog());
+          } else if (requestMessage instanceof TransportableError) {
+            errorReceived(((TransportableError) 
requestMessage).getErrorMessage());
+          } else if (requestMessage instanceof BufferStreamEnd) {
+            onStreamEnd((BufferStreamEnd) requestMessage);
+          }
+        };
+  }
+
+  public void setup(TieredStorageMemoryManager memoryManager) {
+    this.bufferManager = new CelebornChannelBufferManager(memoryManager, this);
+    if (numBackLog > 0) {
+      notifyAvailableCredits(bufferManager.requestBuffers(numBackLog));
+      numBackLog = 0;
+    }
+  }
+
+  public void open(int initialCredit) {
+    try {
+      bufferStream =
+          client.readBufferedPartition(
+              shuffleId, partitionId, subPartitionIndexStart, 
subPartitionIndexEnd, true);
+      bufferStream.open(this::requestBuffer, initialCredit, messageConsumer);
+      this.isOpened = bufferStream.isOpened();
+    } catch (Exception e) {
+      messageConsumer.accept(new TransportableError(0L, e));
+      LOG.error("Failed to open reader", e);
+    }
+  }
+
+  public void close() {
+    // It may be called multiple times because subPartitions can share the 
same reader, as a single
+    // reader can consume multiple subPartitions
+    if (closed) {
+      return;
+    }
+
+    // need set closed first before remove Handler
+    closed = true;
+    if (!CelebornBufferStream.isEmptyStream(bufferStream)) {
+      bufferStream.close();
+      bufferStream = null;
+    } else {
+      LOG.warn(
+          "bufferStream is null when closed, shuffleId: {}, partitionId: {}",
+          shuffleId,
+          partitionId);
+    }
+
+    try {
+      if (bufferManager != null) {
+        bufferManager.close();
+        bufferManager = null;
+      }
+    } catch (Throwable throwable) {
+      LOG.warn("Failed to close buffer manager.", throwable);
+    }
+
+    subPartitionRequiredSegmentIds.clear();
+    subPartitionRequiredSegmentIds = null;
+  }
+
+  public boolean isOpened() {
+    return isOpened;
+  }
+
+  boolean isClosed() {
+    return closed;
+  }
+
+  public void notifyAvailableCredits(int numCredits) {
+    if (numCredits <= 0) {
+      return;
+    }
+    if (!closed && !CelebornBufferStream.isEmptyStream(bufferStream)) {
+      bufferStream.addCredit(
+          PbReadAddCredit.newBuilder()
+              .setStreamId(bufferStream.getStreamId())
+              .setCredit(numCredits)
+              .build());
+      bufferManager.decreaseRequiredCredits(numCredits);
+      return;
+    }
+    LOG.warn(
+        "The buffer stream is null or closed, ignore the credits for 
shuffleId: {}, partitionId: {}",
+        shuffleId,
+        partitionId);
+  }
+
+  public void notifyRequiredSegmentIfNeeded(int requiredSegmentId, int 
subPartitionId) {
+    Integer lastRequiredSegmentId =
+        subPartitionRequiredSegmentIds.computeIfAbsent(subPartitionId, id -> 
-1);
+    if (requiredSegmentId >= 0 && requiredSegmentId != lastRequiredSegmentId) {
+      LOG.debug(
+          "Notify required segment id {} for {} {}, the last segment id is {}",
+          requiredSegmentId,
+          partitionId,
+          subPartitionId,
+          lastRequiredSegmentId);
+      subPartitionRequiredSegmentIds.put(subPartitionId, requiredSegmentId);
+      if (!this.notifyRequiredSegment(requiredSegmentId, subPartitionId)) {
+        // if fail to notify reader segment, restore the last required segment 
id
+        subPartitionRequiredSegmentIds.putIfAbsent(subPartitionId, 
lastRequiredSegmentId);
+      }
+    }
+  }
+
+  public boolean notifyRequiredSegment(int requiredSegmentId, int 
subPartitionId) {
+    this.subPartitionRequiredSegmentIds.put(subPartitionId, requiredSegmentId);
+    if (!closed && !CelebornBufferStream.isEmptyStream(bufferStream)) {
+      LOG.debug(
+          "Notify required segmentId {} for {} {} {}",
+          requiredSegmentId,
+          partitionId,
+          subPartitionId,
+          shuffleId);
+      PbNotifyRequiredSegment notifyRequiredSegment =
+          PbNotifyRequiredSegment.newBuilder()
+              .setStreamId(bufferStream.getStreamId())
+              .setRequiredSegmentId(requiredSegmentId)
+              .setSubPartitionId(subPartitionId)
+              .build();
+      bufferStream.notifyRequiredSegment(notifyRequiredSegment);
+      return true;
+    }
+    return false;
+  }
+
+  public ByteBuf requestBuffer() {
+    Buffer buffer = bufferManager.requestBuffer();
+    return buffer == null ? null : buffer.asByteBuf();
+  }
+
+  public void backlogReceived(int backlog) {
+    if (!closed) {
+      if (bufferManager == null) {
+        numBackLog += backlog;
+        return;
+      }
+      int numRequestedBuffers = bufferManager.requestBuffers(backlog);
+      if (numRequestedBuffers > 0) {
+        notifyAvailableCredits(numRequestedBuffers);
+      }
+      numBackLog = 0;
+      return;
+    }
+    LOG.warn(
+        "The buffer stream {} is null or closed, ignore the backlog for 
shuffleId: {}, partitionId: {}",
+        bufferStream.getStreamId(),
+        shuffleId,
+        partitionId);
+  }
+
+  public void errorReceived(String errorMsg) {
+    if (!closed) {
+      closed = true;
+      LOG.debug("Error received, " + errorMsg);
+      if (!CelebornBufferStream.isEmptyStream(bufferStream) && 
bufferStream.getClient() != null) {
+        LOG.error(
+            "Received error from {} message {}",
+            NettyUtils.getRemoteAddress(bufferStream.getClient().getChannel()),
+            errorMsg);
+      }
+      for (int subPartitionId = subPartitionIndexStart;
+          subPartitionId <= subPartitionIndexEnd;
+          subPartitionId++) {
+        failureListener.accept(
+            new IOException(errorMsg), new 
TieredStorageSubpartitionId(subPartitionId));
+      }
+    }
+  }
+
  /**
   * Handles a data frame from the worker: validates that its sub partition id lies inside this
   * reader's range, hands the buffer to the data listener, then converts any newly requested
   * buffers into credits.
   */
  public void dataReceived(SubPartitionReadData readData) {
    LOG.debug(
        "Remote buffer stream reader get stream id {} subPartitionId {} received readable bytes {}.",
        readData.getStreamId(),
        readData.getSubPartitionId(),
        readData.getFlinkBuffer().readableBytes());
    checkState(
        readData.getSubPartitionId() >= subPartitionIndexStart
            && readData.getSubPartitionId() <= subPartitionIndexEnd,
        "Wrong sub partition id: " + readData.getSubPartitionId());
    dataListener.accept(
        readData.getFlinkBuffer(), new TieredStorageSubpartitionId(readData.getSubPartitionId()));
    // Top up credits: any buffers the manager could pre-request now become credits for the worker.
    int numRequested = bufferManager.tryRequestBuffersIfNeeded();
    notifyAvailableCredits(numRequested);
  }
+
  /**
   * Handles the end of a buffer stream; if this reader is still active, moves the stream on to the
   * next partition (re-sending the required segment id via {@link #sendRequireSegmentId}).
   */
  public void onStreamEnd(BufferStreamEnd streamEnd) {
    long streamId = streamEnd.getStreamId();
    LOG.debug("Buffer stream reader get stream end for {}", streamId);
    if (!closed && !CelebornBufferStream.isEmptyStream(bufferStream)) {
      // TODO: Update the partition locations here if support reading and writing shuffle data
      // simultaneously
      bufferStream.moveToNextPartitionIfPossible(streamId, this::sendRequireSegmentId);
    }
  }
+
  /** Returns the id of the input channel this reader feeds. */
  public TieredStorageInputChannelId getInputChannelId() {
    return inputChannelId;
  }
+
+  private void sendRequireSegmentId(long streamId, int subPartitionId) {
+    if (subPartitionRequiredSegmentIds.containsKey(subPartitionId)) {
+      int currentSegmentId = 
subPartitionRequiredSegmentIds.get(subPartitionId);
+      if (bufferStream.isOpened() && currentSegmentId >= 0) {
+        LOG.debug(
+            "Buffer stream {} is opened, notify required segment id {} ",
+            streamId,
+            currentSegmentId);
+        notifyRequiredSegment(currentSegmentId, subPartitionId);
+      }
+    }
+  }
+}
diff --git 
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierConsumerAgent.java
 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierConsumerAgent.java
new file mode 100644
index 000000000..d858ae891
--- /dev/null
+++ 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierConsumerAgent.java
@@ -0,0 +1,550 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.tiered;
+
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkArgument;
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkNotNull;
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkState;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Queue;
+import java.util.Set;
+import java.util.function.BiConsumer;
+import java.util.stream.Collectors;
+
+import javax.annotation.concurrent.GuardedBy;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import 
org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import 
org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageIdMappingUtils;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageInputChannelId;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStoragePartitionId;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.AvailabilityNotifier;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageConsumerSpec;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageMemoryManager;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierConsumerAgent;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierShuffleDescriptor;
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.util.ExceptionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.exception.DriverChangedException;
+import org.apache.celeborn.common.exception.PartitionUnRetryAbleException;
+import org.apache.celeborn.common.identity.UserIdentifier;
+import org.apache.celeborn.plugin.flink.RemoteShuffleResource;
+import org.apache.celeborn.plugin.flink.ShuffleResourceDescriptor;
+import org.apache.celeborn.plugin.flink.buffer.ReceivedNoHeaderBufferPacker;
+import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl;
+
/**
 * The Celeborn tier consumer agent for hybrid shuffle. It owns one {@link
 * CelebornChannelBufferReader} per (partition, sub-partition range) of an input gate, buffers the
 * data received from Celeborn workers in {@link #receivedBuffers}, and serves it to the upper
 * computing task through {@link #getNextBuffer}.
 */
public class CelebornTierConsumerAgent implements TierConsumerAgent {

  private static final Logger LOG = LoggerFactory.getLogger(CelebornTierConsumerAgent.class);

  private final CelebornConf conf;

  // Index of the input gate this agent serves; used for logging.
  private final int gateIndex;

  private final List<TieredStorageConsumerSpec> consumerSpecs;

  private final List<TierShuffleDescriptor> shuffleDescriptors;

  /**
   * partitionId -> subPartitionId -> reader. Note that subPartitions may share the same reader, as
   * a single reader can consume multiple subPartitions to improve performance.
   */
  private final Map<
          TieredStoragePartitionId, Map<TieredStorageSubpartitionId, CelebornChannelBufferReader>>
      bufferReaders;

  /** Lock to protect {@link #receivedBuffers} and {@link #cause} and {@link #closed}, etc. */
  private final Object lock = new Object();

  /** Received buffers from remote shuffle worker. It's consumed by upper computing task. */
  @GuardedBy("lock")
  private final Map<TieredStoragePartitionId, Map<TieredStorageSubpartitionId, Queue<Buffer>>>
      receivedBuffers;

  // Sub partitions that became available before start(); their availability notifications are
  // deferred until start() runs.
  @GuardedBy("lock")
  private final Set<Tuple2<TieredStoragePartitionId, TieredStorageSubpartitionId>>
      subPartitionsNeedNotifyAvailable;

  @GuardedBy("lock")
  private boolean started = false;

  // First failure observed; only the first one is recorded and later rethrown by healthCheck().
  @GuardedBy("lock")
  private Throwable cause;

  /** Whether this remote input gate has been closed or not. */
  @GuardedBy("lock")
  private boolean closed;

  private FlinkShuffleClientImpl shuffleClient;

  /**
   * The notify target is the Flink input gate; used to notify the input gate which subPartition
   * contains shuffle data that can be read.
   */
  private AvailabilityNotifier availabilityNotifier;

  private TieredStorageMemoryManager memoryManager;

  public CelebornTierConsumerAgent(
      CelebornConf conf,
      List<TieredStorageConsumerSpec> tieredStorageConsumerSpecs,
      List<TierShuffleDescriptor> shuffleDescriptors) {
    checkArgument(!shuffleDescriptors.isEmpty(), "Wrong shuffle descriptors size.");
    checkArgument(
        tieredStorageConsumerSpecs.size() == shuffleDescriptors.size(),
        "Wrong consumer spec size.");
    this.conf = conf;
    this.gateIndex = tieredStorageConsumerSpecs.get(0).getGateIndex();
    this.consumerSpecs = tieredStorageConsumerSpecs;
    this.shuffleDescriptors = shuffleDescriptors;
    this.bufferReaders = new HashMap<>();
    this.receivedBuffers = new HashMap<>();
    this.subPartitionsNeedNotifyAvailable = new HashSet<>();
    // A single shuffle client is shared by all readers; the first Celeborn descriptor carries the
    // connection information needed to create it.
    for (TierShuffleDescriptor shuffleDescriptor : shuffleDescriptors) {
      if (shuffleDescriptor instanceof TierShuffleDescriptorImpl) {
        initShuffleClient((TierShuffleDescriptorImpl) shuffleDescriptor);
        break;
      }
    }
    checkNotNull(this.shuffleClient);
    initBufferReaders();
  }

  @Override
  public void setup(TieredStorageMemoryManager memoryManager) {
    // Propagate the memory manager to every reader so they can request read buffers from it.
    this.memoryManager = memoryManager;
    for (Map<TieredStorageSubpartitionId, CelebornChannelBufferReader> subPartitionReaders :
        bufferReaders.values()) {
      subPartitionReaders.forEach((partitionId, reader) -> reader.setup(memoryManager));
    }
  }

  @Override
  public void start() {
    // notify input gate that some sub partitions are available
    Set<Tuple2<TieredStoragePartitionId, TieredStorageSubpartitionId>> needNotifyAvailable;
    synchronized (lock) {
      // Snapshot and clear the deferred notifications under the lock; notify outside it.
      needNotifyAvailable = new HashSet<>(subPartitionsNeedNotifyAvailable);
      subPartitionsNeedNotifyAvailable.clear();
      started = true;
    }
    try {
      needNotifyAvailable.forEach(
          partitionIdTuple -> notifyAvailable(partitionIdTuple.f0, partitionIdTuple.f1));
    } catch (Throwable t) {
      LOG.error("Error occurred when notifying sub partitions available", t);
      recycleAllResources();
      ExceptionUtils.rethrow(t);
    }
    needNotifyAvailable.clear();

    // Require segment 0 when starting the client
    for (TieredStorageConsumerSpec spec : consumerSpecs) {
      for (int subpartitionId : spec.getSubpartitionIds().values()) {
        CelebornChannelBufferReader bufferReader =
            getBufferReader(spec.getPartitionId(), new TieredStorageSubpartitionId(subpartitionId));
        if (bufferReader == null) {
          continue;
        }
        // TODO: if opening the reader fails, the downstream task may have started before the
        // upstream task; we should retry opening the reader rather than throw an exception
        boolean openReaderSuccess = openReader(bufferReader);
        if (!openReaderSuccess) {
          LOG.error("Failed to open reader.");
          recycleAllResources();
          ExceptionUtils.rethrow(new IOException("Failed to open reader."));
        }
        bufferReader.notifyRequiredSegmentIfNeeded(0, subpartitionId);
      }
    }
  }

  @Override
  public int peekNextBufferSubpartitionId(
      TieredStoragePartitionId tieredStoragePartitionId,
      ResultSubpartitionIndexSet resultSubpartitionIndexSet) {
    synchronized (lock) {
      // check health
      healthCheck();

      // return the subPartitionId if already receive buffer from corresponding subpartition
      Map<TieredStorageSubpartitionId, Queue<Buffer>> subPartitionReceivedBuffers =
          receivedBuffers.get(tieredStoragePartitionId);
      if (subPartitionReceivedBuffers == null) {
        return -1;
      }
      for (int subPartitionIndex = resultSubpartitionIndexSet.getStartIndex();
          subPartitionIndex <= resultSubpartitionIndexSet.getEndIndex();
          subPartitionIndex++) {
        Queue<Buffer> buffers =
            subPartitionReceivedBuffers.get(new TieredStorageSubpartitionId(subPartitionIndex));
        if (buffers != null && !buffers.isEmpty()) {
          return subPartitionIndex;
        }
      }
    }
    // -1 means no buffered data for any sub partition in the requested range.
    return -1;
  }

  @Override
  public Optional<Buffer> getNextBuffer(
      TieredStoragePartitionId tieredStoragePartitionId,
      TieredStorageSubpartitionId tieredStorageSubpartitionId,
      int segmentId) {
    synchronized (lock) {
      // check health
      healthCheck();
    }

    // check reader status
    if (!bufferReaders.containsKey(tieredStoragePartitionId)
        || !bufferReaders.get(tieredStoragePartitionId).containsKey(tieredStorageSubpartitionId)) {
      return Optional.empty();
    }
    try {
      // Lazily (re)open the reader; an unopened reader yields no data yet.
      boolean openReaderSuccess = openReader(tieredStoragePartitionId, tieredStorageSubpartitionId);
      if (!openReaderSuccess) {
        return Optional.empty();
      }
    } catch (Throwable throwable) {
      LOG.error("Failed to open reader.", throwable);
      recycleAllResources();
      ExceptionUtils.rethrow(throwable);
    }

    synchronized (lock) {
      CelebornChannelBufferReader bufferReader =
          getBufferReader(tieredStoragePartitionId, tieredStorageSubpartitionId);
      bufferReader.notifyRequiredSegmentIfNeeded(
          segmentId, tieredStorageSubpartitionId.getSubpartitionId());
      Map<TieredStorageSubpartitionId, Queue<Buffer>> partitionBuffers =
          receivedBuffers.get(tieredStoragePartitionId);
      if (partitionBuffers == null || partitionBuffers.isEmpty()) {
        return Optional.empty();
      }
      Queue<Buffer> subPartitionBuffers = partitionBuffers.get(tieredStorageSubpartitionId);
      if (subPartitionBuffers == null || subPartitionBuffers.isEmpty()) {
        return Optional.empty();
      }
      return Optional.ofNullable(subPartitionBuffers.poll());
    }
  }

  @Override
  public void registerAvailabilityNotifier(AvailabilityNotifier availabilityNotifier) {
    this.availabilityNotifier = availabilityNotifier;
    LOG.info("Registered availability notifier for gate {}.", gateIndex);
  }

  @Override
  public void updateTierShuffleDescriptor(
      TieredStoragePartitionId tieredStoragePartitionId,
      TieredStorageInputChannelId tieredStorageInputChannelId,
      TieredStorageSubpartitionId subpartitionId,
      TierShuffleDescriptor tierShuffleDescriptor) {
    // Non-Celeborn descriptors are not ours to handle.
    if (!(tierShuffleDescriptor instanceof TierShuffleDescriptorImpl)) {
      return;
    }
    TierShuffleDescriptorImpl shuffleDescriptor = (TierShuffleDescriptorImpl) tierShuffleDescriptor;
    checkState(
        shuffleDescriptor.getResultPartitionID().equals(tieredStoragePartitionId.getPartitionID()),
        "Wrong result partition id: " + shuffleDescriptor.getResultPartitionID());
    ResultSubpartitionIndexSet subpartitionIndexSet =
        new ResultSubpartitionIndexSet(subpartitionId.getSubpartitionId());
    // Create, set up and open a reader only if none exists yet for this (partition, subPartition).
    if (!bufferReaders.containsKey(tieredStoragePartitionId)
        || !bufferReaders.get(tieredStoragePartitionId).containsKey(subpartitionId)) {
      ShuffleResourceDescriptor shuffleResourceDescriptor =
          shuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor();
      createBufferReader(
          shuffleResourceDescriptor,
          tieredStoragePartitionId,
          tieredStorageInputChannelId,
          subpartitionIndexSet);
      CelebornChannelBufferReader bufferReader =
          checkNotNull(getBufferReader(tieredStoragePartitionId, subpartitionId));
      bufferReader.setup(checkNotNull(memoryManager));
      openReader(bufferReader);
    }
  }

  @Override
  public void close() {
    Throwable closeException = null;
    // Do not check closed flag, thus to allow calling this method from both task thread and
    // cancel thread.
    try {
      recycleAllResources();
    } catch (Throwable throwable) {
      closeException = throwable;
      LOG.error("Failed to recycle all resources.", throwable);
    }
    if (closeException != null) {
      ExceptionUtils.rethrow(closeException);
    }
  }

  /** Creates (or fetches the cached) shuffle client from the descriptor's connection info. */
  private void initShuffleClient(TierShuffleDescriptorImpl remoteShuffleDescriptor) {
    RemoteShuffleResource shuffleResource = remoteShuffleDescriptor.getShuffleResource();
    try {
      String appUniqueId = remoteShuffleDescriptor.getCelebornAppId();
      this.shuffleClient =
          FlinkShuffleClientImpl.get(
              appUniqueId,
              shuffleResource.getLifecycleManagerHost(),
              shuffleResource.getLifecycleManagerPort(),
              shuffleResource.getLifecycleManagerTimestamp(),
              conf,
              new UserIdentifier("default", "default"));
    } catch (DriverChangedException e) {
      throw new RuntimeException(e.getMessage());
    }
  }

  // NOTE(review): assumes the partition id is present in bufferReaders; callers check or
  // checkNotNull the result — confirm before calling on arbitrary ids.
  private CelebornChannelBufferReader getBufferReader(
      TieredStoragePartitionId tieredStoragePartitionId,
      TieredStorageSubpartitionId tieredStorageSubpartitionId) {
    return bufferReaders.get(tieredStoragePartitionId).get(tieredStorageSubpartitionId);
  }

  /** Closes every reader, recycles all buffered data, and marks this agent closed. */
  private void recycleAllResources() {
    List<Buffer> buffersToRecycle = new ArrayList<>();
    for (Map<TieredStorageSubpartitionId, CelebornChannelBufferReader> subPartitionReaders :
        bufferReaders.values()) {
      subPartitionReaders.values().forEach(CelebornChannelBufferReader::close);
    }
    synchronized (lock) {
      // Collect the buffers under the lock but recycle them outside it.
      for (Map<TieredStorageSubpartitionId, Queue<Buffer>> subPartitionMap :
          receivedBuffers.values()) {
        buffersToRecycle.addAll(
            subPartitionMap.values().stream()
                .flatMap(Queue::stream)
                .collect(Collectors.toCollection(LinkedList::new)));
      }
      receivedBuffers.clear();
      bufferReaders.clear();
      availabilityNotifier = null;
      closed = true;
    }
    try {
      buffersToRecycle.forEach(Buffer::recycleBuffer);
    } catch (Throwable throwable) {
      LOG.error("Failed to recycle buffers.", throwable);
      throw throwable;
    }
  }

  private boolean openReader(
      TieredStoragePartitionId partitionId, TieredStorageSubpartitionId subPartitionId) {
    CelebornChannelBufferReader bufferReader =
        checkNotNull(checkNotNull(bufferReaders.get(partitionId)).get(subPartitionId));
    return openReader(bufferReader);
  }

  /**
   * Opens the reader if it is not open yet; returns whether it is open afterwards. On failure all
   * resources are recycled and the exception is rethrown.
   */
  private boolean openReader(CelebornChannelBufferReader bufferReader) {
    if (!bufferReader.isOpened()) {
      try {
        bufferReader.open(0);
      } catch (Exception e) {
        // may throw PartitionUnRetryAbleException
        recycleAllResources();
        ExceptionUtils.rethrow(e);
      }
    }

    return bufferReader.isOpened();
  }

  /** Creates one buffer reader per consumer spec/descriptor pair (Celeborn descriptors only). */
  private void initBufferReaders() {
    for (int i = 0; i < shuffleDescriptors.size(); i++) {
      if (!(shuffleDescriptors.get(i) instanceof TierShuffleDescriptorImpl)) {
        continue;
      }
      TierShuffleDescriptorImpl shuffleDescriptor =
          (TierShuffleDescriptorImpl) shuffleDescriptors.get(i);
      ResultPartitionID resultPartitionID = shuffleDescriptor.getResultPartitionID();
      ShuffleResourceDescriptor shuffleResourceDescriptor =
          shuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor();
      TieredStoragePartitionId partitionId = new TieredStoragePartitionId(resultPartitionID);
      checkState(consumerSpecs.get(i).getPartitionId().equals(partitionId), "Wrong partition id.");
      ResultSubpartitionIndexSet subPartitionIdSet = consumerSpecs.get(i).getSubpartitionIds();
      LOG.debug(
          "create shuffle reader for gate {} descriptor {} partitionId {}, subPartitionId start {} and end {}",
          gateIndex,
          shuffleResourceDescriptor,
          partitionId,
          subPartitionIdSet.getStartIndex(),
          subPartitionIdSet.getEndIndex());
      createBufferReader(
          shuffleResourceDescriptor,
          partitionId,
          consumerSpecs.get(i).getInputChannelId(),
          subPartitionIdSet);
    }
  }

  private void createBufferReader(
      ShuffleResourceDescriptor shuffleDescriptor,
      TieredStoragePartitionId partitionId,
      TieredStorageInputChannelId inputChannelId,
      ResultSubpartitionIndexSet subPartitionIdSet) {
    // create a single reader for multiple subPartitions to improve performance
    CelebornChannelBufferReader reader =
        new CelebornChannelBufferReader(
            shuffleClient,
            shuffleDescriptor,
            inputChannelId,
            subPartitionIdSet.getStartIndex(),
            subPartitionIdSet.getEndIndex(),
            getDataListener(partitionId),
            getFailureListener(partitionId));

    // Register the same reader instance for every sub partition in the range.
    for (int id = subPartitionIdSet.getStartIndex(); id <= subPartitionIdSet.getEndIndex(); id++) {
      TieredStorageSubpartitionId subPartitionId = new TieredStorageSubpartitionId(id);
      checkState(
          !bufferReaders.containsKey(partitionId)
              || !bufferReaders.get(partitionId).containsKey(subPartitionId),
          "Duplicate shuffle reader.");
      bufferReaders
          .computeIfAbsent(partitionId, partition -> new HashMap<>())
          .put(subPartitionId, reader);
    }
  }

  // Rethrows once the agent is closed or a failure was recorded; must be called under lock.
  @GuardedBy("lock")
  private void healthCheck() {
    if (closed || cause != null) {
      Exception e;
      if (closed) {
        e = new IOException("Celeborn consumer agent already closed.");
      } else {
        e = new IOException(cause);
      }
      recycleAllResources();
      LOG.error("Failed to check health.", e);
      ExceptionUtils.rethrow(e);
    }
  }

  /**
   * Enqueues a received buffer; if the queue transitioned from empty, either defers the
   * availability notification (before start()) or notifies the input gate immediately.
   */
  private void onBuffer(
      TieredStoragePartitionId partitionId,
      TieredStorageSubpartitionId subPartitionId,
      Buffer buffer) {
    boolean wasEmpty;
    synchronized (lock) {
      if (closed || cause != null) {
        buffer.recycleBuffer();
        recycleAllResources();
        throw new IllegalStateException("Input gate already closed or failed.");
      }
      Queue<Buffer> buffers =
          receivedBuffers
              .computeIfAbsent(partitionId, partition -> new HashMap<>())
              .computeIfAbsent(subPartitionId, subpartition -> new LinkedList<>());
      wasEmpty = buffers.isEmpty();
      buffers.add(buffer);
      if (wasEmpty && !started) {
        subPartitionsNeedNotifyAvailable.add(Tuple2.of(partitionId, subPartitionId));
        return;
      }
    }
    // Notify outside the lock to avoid calling listener code while holding it.
    if (wasEmpty) {
      notifyAvailable(partitionId, subPartitionId);
    }
  }

  /** Builds the data callback for a partition: unpacks packed buffers and enqueues each one. */
  private BiConsumer<ByteBuf, TieredStorageSubpartitionId> getDataListener(
      TieredStoragePartitionId partitionId) {
    return (byteBuf, subPartitionId) -> {
      Queue<Buffer> unpackedBuffers = null;
      try {
        unpackedBuffers = ReceivedNoHeaderBufferPacker.unpack(byteBuf);
        while (!unpackedBuffers.isEmpty()) {
          onBuffer(partitionId, subPartitionId, unpackedBuffers.poll());
        }
      } catch (Throwable throwable) {
        synchronized (lock) {
          LOG.error(
              "Failed to process the received buffer, cause: {} throwable {}.",
              cause == null ? "" : cause,
              throwable);
          if (cause == null) {
            cause = throwable;
          }
        }
        // Wake up the input gate so it observes the recorded failure.
        notifyAvailable(partitionId, subPartitionId);
        if (unpackedBuffers != null) {
          unpackedBuffers.forEach(Buffer::recycleBuffer);
        }
        recycleAllResources();
      }
    };
  }

  /** Builds the failure callback for a partition; only the first failure is recorded. */
  private BiConsumer<Throwable, TieredStorageSubpartitionId> getFailureListener(
      TieredStoragePartitionId partitionId) {
    return (throwable, subPartitionId) -> {
      synchronized (lock) {
        // only record and process the first exception
        if (cause != null) {
          return;
        }
        // A PartitionUnRetryAbleException is surfaced as Flink's PartitionNotFoundException.
        Class<?> clazz = PartitionUnRetryAbleException.class;
        if (throwable.getMessage() != null && throwable.getMessage().contains(clazz.getName())) {
          cause =
              new PartitionNotFoundException(TieredStorageIdMappingUtils.convertId(partitionId));
          LOG.error("The consumer agent received an PartitionUnRetryAbleException.", throwable);
        } else {
          LOG.error("The consumer agent received an exception.", throwable);
          cause = throwable;
        }
      }
      // notify the input gate; it will call peekNextBufferSubpartitionId or getNextBuffer,
      // and process the exception
      notifyAvailable(partitionId, subPartitionId);
    };
  }

  /** Notifies the input gate that the given (partition, subPartition) has readable data. */
  private void notifyAvailable(
      TieredStoragePartitionId partitionId, TieredStorageSubpartitionId subPartitionId) {
    Map<TieredStorageSubpartitionId, CelebornChannelBufferReader> subPartitionReaders =
        bufferReaders.get(partitionId);
    if (subPartitionReaders != null) {
      CelebornChannelBufferReader channelBufferReader = subPartitionReaders.get(subPartitionId);
      if (channelBufferReader != null) {
        availabilityNotifier.notifyAvailable(partitionId, channelBufferReader.getInputChannelId());
      }
    }
  }
}
diff --git 
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
index 02306a5ad..1a86130e4 100644
--- 
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
+++ 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
@@ -118,8 +118,7 @@ public class CelebornTierFactory implements TierFactory {
       List<TieredStorageConsumerSpec> tieredStorageConsumerSpecs,
       List<TierShuffleDescriptor> shuffleDescriptors,
       TieredStorageNettyService nettyService) {
-    // TODO impl this in the follow-up PR.
-    return null;
+    return new CelebornTierConsumerAgent(conf, tieredStorageConsumerSpecs, 
shuffleDescriptors);
   }
 
   public static String getCelebornTierName() {
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadData.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadData.java
index 6e465ccc5..1be7b46ff 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadData.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadData.java
@@ -26,7 +26,7 @@ import 
org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
 // This is buffer wrapper used in celeborn worker only
 // It doesn't need decode in worker.
 public class ReadData extends RequestMessage {
-  private long streamId;
+  protected long streamId;
 
   public ReadData(long streamId, ByteBuf buf) {
     super(new NettyManagedBuffer(buf));
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/SubPartitionReadData.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/SubPartitionReadData.java
index bdea84989..11a13118d 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/SubPartitionReadData.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/SubPartitionReadData.java
@@ -22,38 +22,30 @@ import java.util.Objects;
 
 import io.netty.buffer.ByteBuf;
 
-import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
-
 /**
  * Comparing {@link ReadData}, this class has an additional field of 
subpartitionId. This class is
  * added to keep the backward compatibility.
  */
-public class SubPartitionReadData extends RequestMessage {
-  private long streamId;
+public class SubPartitionReadData extends ReadData {
 
   private int subPartitionId;
 
   public SubPartitionReadData(long streamId, int subPartitionId, ByteBuf buf) {
-    super(new NettyManagedBuffer(buf));
-    this.streamId = streamId;
+    super(streamId, buf);
     this.subPartitionId = subPartitionId;
   }
 
   @Override
   public int encodedLength() {
-    return 8 + 4;
+    return super.encodedLength() + 4;
   }
 
   @Override
   public void encode(ByteBuf buf) {
-    buf.writeLong(streamId);
+    super.encode(buf);
     buf.writeInt(subPartitionId);
   }
 
-  public long getStreamId() {
-    return streamId;
-  }
-
   public int getSubPartitionId() {
     return subPartitionId;
   }
@@ -68,8 +60,8 @@ public class SubPartitionReadData extends RequestMessage {
     if (this == o) return true;
     if (o == null || getClass() != o.getClass()) return false;
     SubPartitionReadData readData = (SubPartitionReadData) o;
-    return streamId == readData.streamId
-        && subPartitionId == readData.subPartitionId
+    return streamId == readData.getStreamId()
+        && subPartitionId == readData.getSubPartitionId()
         && super.equals(o);
   }
 

Reply via email to