This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new f4dc7a839 [CELEBORN-1490][CIP-6] Impl worker read process in Flink 
Hybrid Shuffle
f4dc7a839 is described below

commit f4dc7a839bd1f51ab1d4aa6fba6eba5a385b3edf
Author: Weijie Guo <[email protected]>
AuthorDate: Tue Oct 29 16:29:16 2024 +0800

    [CELEBORN-1490][CIP-6] Impl worker read process in Flink Hybrid Shuffle
    
    ### What changes were proposed in this pull request?
    
    Impl worker read process in Flink Hybrid Shuffle
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    Closes #2820 from reswqa/cip6-8-pr.
    
    Lead-authored-by: Weijie Guo <[email protected]>
    Co-authored-by: codenohup <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../flink/readclient/CelebornBufferStream.java     |   1 -
 common/src/main/proto/TransportMessages.proto      |   1 -
 .../deploy/worker/memory/RecyclableBuffer.java     |  13 +-
 .../worker/memory/RecyclableSegmentIdBuffer.java   |  40 +++
 .../deploy/worker/storage/CreditStreamManager.java |  73 ++++-
 .../deploy/worker/storage/MapPartitionData.java    |  16 +-
 .../worker/storage/MapPartitionDataReader.java     | 115 +++++---
 .../storage/segment/SegmentMapPartitionData.java   | 109 +++++++
 .../segment/SegmentMapPartitionDataReader.java     | 325 +++++++++++++++++++++
 .../service/deploy/worker/FetchHandler.scala       |  25 +-
 10 files changed, 656 insertions(+), 62 deletions(-)

diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
index 849895fef..1fda95c01 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
@@ -253,7 +253,6 @@ public class CelebornBufferStream {
                 .setStartIndex(subIndexStart)
                 .setEndIndex(subIndexEnd)
                 .setInitialCredit(initialCredit)
-                .setRequireSubpartitionId(true)
                 .build()
                 .toByteArray());
     client.sendRpc(
diff --git a/common/src/main/proto/TransportMessages.proto 
b/common/src/main/proto/TransportMessages.proto
index d5b129b0d..245b78dc0 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -684,7 +684,6 @@ message PbOpenStream {
   int32 endIndex = 4;
   int32 initialCredit = 5;
   bool readLocalShuffle = 6;
-  bool requireSubpartitionId = 7;
 }
 
 message PbStreamHandler {
diff --git 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/RecyclableBuffer.java
 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/RecyclableBuffer.java
index efbe0ece4..45e951397 100644
--- 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/RecyclableBuffer.java
+++ 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/RecyclableBuffer.java
@@ -23,14 +23,25 @@ public class RecyclableBuffer {
 
   public final ByteBuf byteBuf;
 
+  public final int subPartitionId;
+
   public final BufferRecycler bufferRecycler;
 
-  public RecyclableBuffer(ByteBuf byteBuf, BufferRecycler bufferRecycler) {
+  public RecyclableBuffer(ByteBuf byteBuf, int subPartitionId, BufferRecycler 
bufferRecycler) {
     this.byteBuf = byteBuf;
+    this.subPartitionId = subPartitionId;
     this.bufferRecycler = bufferRecycler;
   }
 
   public void recycle() {
     bufferRecycler.recycle(byteBuf);
   }
+
+  public int getSubPartitionId() {
+    return subPartitionId;
+  }
+
+  public boolean isDataBuffer() {
+    return true;
+  }
 }
diff --git 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/RecyclableSegmentIdBuffer.java
 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/RecyclableSegmentIdBuffer.java
new file mode 100644
index 000000000..10561bdb0
--- /dev/null
+++ 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/RecyclableSegmentIdBuffer.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.worker.memory;
+
+/**
+ * Wrap {@link RecyclableBuffer} with a segmentId; it contains no data buffer, 
only the segmentId,
+ * which is used to verify whether the subsequent buffer should be sent.
+ */
+public class RecyclableSegmentIdBuffer extends RecyclableBuffer {
+  private final int segmentId;
+
+  public RecyclableSegmentIdBuffer(int subpartitionId, int segmentId) {
+    super(null, subpartitionId, new BufferRecycler(buf -> {}));
+    this.segmentId = segmentId;
+  }
+
+  public int getSegmentId() {
+    return segmentId;
+  }
+
+  public boolean isDataBuffer() {
+    return false;
+  }
+}
diff --git 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/CreditStreamManager.java
 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/CreditStreamManager.java
index eba3cc809..526998208 100644
--- 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/CreditStreamManager.java
+++ 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/CreditStreamManager.java
@@ -37,6 +37,7 @@ import org.apache.celeborn.common.meta.MapFileMeta;
 import org.apache.celeborn.common.util.JavaUtils;
 import org.apache.celeborn.common.util.ThreadUtils;
 import org.apache.celeborn.service.deploy.worker.memory.MemoryManager;
+import 
org.apache.celeborn.service.deploy.worker.storage.segment.SegmentMapPartitionData;
 
 public class CreditStreamManager {
   private static final Logger logger = 
LoggerFactory.getLogger(CreditStreamManager.class);
@@ -98,15 +99,25 @@ public class CreditStreamManager {
             (k, v) -> {
               if (v == null) {
                 try {
+                  MapFileMeta fileMeta = (MapFileMeta) fileInfo.getFileMeta();
                   v =
-                      new MapPartitionData(
-                          minReadBuffers,
-                          maxReadBuffers,
-                          storageFetcherPool,
-                          threadsPerMountPoint,
-                          fileInfo,
-                          id -> recycleStream(id),
-                          minBuffersToTriggerRead);
+                      fileMeta.isSegmentGranularityVisible()
+                          ? new SegmentMapPartitionData(
+                              minReadBuffers,
+                              maxReadBuffers,
+                              storageFetcherPool,
+                              threadsPerMountPoint,
+                              fileInfo,
+                              id -> recycleStream(id),
+                              minBuffersToTriggerRead)
+                          : new MapPartitionData(
+                              minReadBuffers,
+                              maxReadBuffers,
+                              storageFetcherPool,
+                              threadsPerMountPoint,
+                              fileInfo,
+                              id -> recycleStream(id),
+                              minBuffersToTriggerRead);
                 } catch (IOException e) {
                   exception.set(e);
                   return null;
@@ -158,11 +169,57 @@ public class CreditStreamManager {
     }
   }
 
+  private void notifyRequiredSegment(
+      MapPartitionData mapPartitionData, int requiredSegmentId, long streamId, 
int subPartitionId) {
+    logger.debug(
+        "Receive RequiredSegment from client, streamId: {}, requiredSegmentId: 
{}, subPartitionId: {}",
+        streamId,
+        requiredSegmentId,
+        subPartitionId);
+    try {
+      if (mapPartitionData instanceof SegmentMapPartitionData) {
+        ((SegmentMapPartitionData) mapPartitionData)
+            .notifyRequiredSegmentId(requiredSegmentId, streamId, 
subPartitionId);
+      } else {
+        logger.warn("Only non-null SegmentMapPartitionData is expected for 
notifyRequiredSegment.");
+      }
+    } catch (Throwable e) {
+      logger.error(
+          String.format("Fail to notify segmentId %s for stream %s.", 
requiredSegmentId, streamId),
+          e);
+      throw e;
+    }
+  }
+
   public void addCredit(int numCredit, long streamId) {
+    if (!streams.containsKey(streamId)) {
+      // In flink hybrid shuffle integration strategy, the stream may be 
released in the worker before
+      // the client receives bufferStreamEnd,
+      // and the client may send a request with an old streamId, so ignore 
non-existent streams.
+      logger.warn("Ignore AddCredit from stream {}, numCredit {}.", streamId, 
numCredit);
+      return;
+    }
     MapPartitionData mapPartitionData = 
streams.get(streamId).getMapDataPartition();
     addCredit(mapPartitionData, numCredit, streamId);
   }
 
+  public void notifyRequiredSegment(int requiredSegmentId, long streamId, int 
subPartitionId) {
+    StreamState streamState = streams.get(streamId);
+    if (streamState != null) {
+      notifyRequiredSegment(
+          streamState.getMapDataPartition(), requiredSegmentId, streamId, 
subPartitionId);
+    } else {
+      // In flink hybrid shuffle integration strategy, the stream may be 
released in the worker before
+      // the client receives bufferStreamEnd,
+      // and the client may send a request with an old streamId, so ignore 
non-existent streams.
+      logger.warn(
+          "Ignore RequiredSegment from stream {}, subPartition {}, segmentId 
{}.",
+          streamId,
+          subPartitionId,
+          requiredSegmentId);
+    }
+  }
+
   public void connectionTerminated(Channel channel) {
     for (Map.Entry<Long, StreamState> entry : streams.entrySet()) {
       if (entry.getValue().getAssociatedChannel() == channel) {
diff --git 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionData.java
 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionData.java
index e0c60c0b9..10cd0332b 100644
--- 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionData.java
+++ 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionData.java
@@ -45,12 +45,12 @@ import 
org.apache.celeborn.service.deploy.worker.memory.BufferRecycler;
 import org.apache.celeborn.service.deploy.worker.memory.MemoryManager;
 
 // this means active data partition
-class MapPartitionData implements MemoryManager.ReadBufferTargetChangeListener 
{
+public class MapPartitionData implements 
MemoryManager.ReadBufferTargetChangeListener {
   public static final Logger logger = 
LoggerFactory.getLogger(MapPartitionData.class);
-  private final DiskFileInfo diskFileInfo;
+  protected final DiskFileInfo diskFileInfo;
   private final MapFileMeta mapFileMeta;
-  private final ExecutorService readExecutor;
-  private final ConcurrentHashMap<Long, MapPartitionDataReader> readers =
+  protected final ExecutorService readExecutor;
+  protected final ConcurrentHashMap<Long, MapPartitionDataReader> readers =
       JavaUtils.newConcurrentHashMap();
   private FileChannel dataFileChanel;
   private FileChannel indexChannel;
@@ -59,7 +59,7 @@ class MapPartitionData implements 
MemoryManager.ReadBufferTargetChangeListener {
   private final BufferQueue bufferQueue = new BufferQueue();
   private AtomicBoolean bufferQueueInitialized = new AtomicBoolean(false);
   private MemoryManager memoryManager = MemoryManager.instance();
-  private Consumer<Long> recycleStream;
+  protected Consumer<Long> recycleStream;
   private int minReadBuffers;
   private int maxReadBuffers;
   private int minBuffersToTriggerRead;
@@ -181,6 +181,10 @@ class MapPartitionData implements 
MemoryManager.ReadBufferTargetChangeListener {
     bufferQueue.tryApplyNewBuffers(readers.size(), 
mapFileMeta.getBufferSize(), this::onBuffer);
   }
 
+  protected void openReader(MapPartitionDataReader reader) throws IOException {
+    reader.open(dataFileChanel, indexChannel, indexSize);
+  }
+
   public synchronized void readBuffers() {
     hasReadingTask.set(false);
     if (isReleased) {
@@ -195,7 +199,7 @@ class MapPartitionData implements 
MemoryManager.ReadBufferTargetChangeListener {
                   .filter(MapPartitionDataReader::shouldReadData)
                   .collect(Collectors.toList()));
       for (MapPartitionDataReader reader : sortedReaders) {
-        reader.open(dataFileChanel, indexChannel, indexSize);
+        openReader(reader);
       }
       while (bufferQueue.bufferAvailable() && !sortedReaders.isEmpty()) {
         BufferRecycler bufferRecycler = new 
BufferRecycler(MapPartitionData.this::recycle);
diff --git 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionDataReader.java
 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionDataReader.java
index ad13bf5e9..96d723bde 100644
--- 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionDataReader.java
+++ 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionDataReader.java
@@ -42,6 +42,7 @@ import 
org.apache.celeborn.common.network.buffer.NioManagedBuffer;
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
 import org.apache.celeborn.common.network.protocol.ReadData;
+import org.apache.celeborn.common.network.protocol.RequestMessage;
 import org.apache.celeborn.common.network.protocol.RpcRequest;
 import org.apache.celeborn.common.network.protocol.TransportMessage;
 import org.apache.celeborn.common.network.protocol.TransportableError;
@@ -68,12 +69,12 @@ public class MapPartitionDataReader implements 
Comparable<MapPartitionDataReader
   private long dataConsumingOffset;
   private volatile long currentPartitionRemainingBytes;
   private DiskFileInfo fileInfo;
-  private MapFileMeta mapFileMeta;
+  protected MapFileMeta mapFileMeta;
   private int INDEX_ENTRY_SIZE = 16;
-  private long streamId;
+  protected long streamId;
   protected final Object lock = new Object();
 
-  private final AtomicInteger credits = new AtomicInteger();
+  protected final AtomicInteger credits = new AtomicInteger();
 
   @GuardedBy("lock")
   protected final Queue<RecyclableBuffer> buffersToSend = new ArrayDeque<>();
@@ -101,7 +102,7 @@ public class MapPartitionDataReader implements 
Comparable<MapPartitionDataReader
 
   private Runnable recycleStream;
 
-  private AtomicInteger numInUseBuffers = new AtomicInteger(0);
+  protected AtomicInteger numInUseBuffers = new AtomicInteger(0);
   private boolean isOpen = false;
 
   public MapPartitionDataReader(
@@ -177,18 +178,22 @@ public class MapPartitionDataReader implements 
Comparable<MapPartitionDataReader
       closeReader();
     }
 
+    tryNotifyBacklog(numDataBuffers);
+  }
+
+  protected void tryNotifyBacklog(int numDataBuffers) {
     if (numDataBuffers > 0) {
       notifyBacklog(numDataBuffers);
     }
   }
 
-  private void addBuffer(ByteBuf buffer, BufferRecycler bufferRecycler) {
+  protected void addBuffer(ByteBuf buffer, BufferRecycler bufferRecycler) {
     if (buffer == null) {
       return;
     }
     synchronized (lock) {
       if (!isReleased) {
-        buffersToSend.add(new RecyclableBuffer(buffer, bufferRecycler));
+        buffersToSend.add(new RecyclableBuffer(buffer, -1, bufferRecycler));
       } else {
         bufferRecycler.recycle(buffer);
         numInUseBuffers.decrementAndGet();
@@ -197,7 +202,7 @@ public class MapPartitionDataReader implements 
Comparable<MapPartitionDataReader
     }
   }
 
-  private RecyclableBuffer fetchBufferToSend() {
+  protected RecyclableBuffer fetchBufferToSend() {
     synchronized (lock) {
       if (!buffersToSend.isEmpty() && credits.get() > 0 && !isReleased) {
         return buffersToSend.poll();
@@ -207,7 +212,7 @@ public class MapPartitionDataReader implements 
Comparable<MapPartitionDataReader
     }
   }
 
-  private int getNumBuffersToSend() {
+  protected int getNumBuffersToSend() {
     synchronized (lock) {
       return buffersToSend.size();
     }
@@ -216,32 +221,13 @@ public class MapPartitionDataReader implements 
Comparable<MapPartitionDataReader
   public synchronized void sendData() {
     RecyclableBuffer buffer;
     while (null != (buffer = fetchBufferToSend())) {
-      final RecyclableBuffer wrappedBuffer = buffer;
-      int readableBytes = wrappedBuffer.byteBuf.readableBytes();
-      if (logger.isDebugEnabled()) {
-        logger.debug("send data start: {}, {}, {}", streamId, readableBytes, 
getNumBuffersToSend());
-      }
-      ReadData readData = new ReadData(streamId, wrappedBuffer.byteBuf);
-      associatedChannel
-          .writeAndFlush(readData)
-          .addListener(
-              (ChannelFutureListener)
-                  future -> {
-                    try {
-                      if (!future.isSuccess()) {
-                        recycleOnError(future.cause());
-                      }
-                    } finally {
-                      logger.debug("send data end: {}, {}", streamId, 
readableBytes);
-                      wrappedBuffer.recycle();
-                      numInUseBuffers.decrementAndGet();
-                    }
-                  });
-
-      int currentCredit = credits.decrementAndGet();
-      logger.debug("stream {} credit {}", streamId, currentCredit);
+      sendDataInternal(buffer);
     }
 
+    tryRecycleReader();
+  }
+
+  public void tryRecycleReader() {
     boolean shouldRecycle = false;
     synchronized (lock) {
       if (isReleased) return;
@@ -255,11 +241,46 @@ public class MapPartitionDataReader implements 
Comparable<MapPartitionDataReader
     }
   }
 
-  private long getIndexRegionSize() {
+  public RequestMessage generateReadDataMessage(
+      long streamId, int subPartitionId, ByteBuf byteBuf) {
+    return new ReadData(streamId, byteBuf);
+  }
+
+  protected void sendDataInternal(RecyclableBuffer buffer) {
+    final RecyclableBuffer wrappedBuffer = buffer;
+    int readableBytes = wrappedBuffer.byteBuf.readableBytes();
+    if (logger.isDebugEnabled()) {
+      logger.debug("send data start: {}, {}", streamId, readableBytes);
+    }
+
+    RequestMessage readData =
+        generateReadDataMessage(streamId, wrappedBuffer.subPartitionId, 
wrappedBuffer.byteBuf);
+    associatedChannel
+        .writeAndFlush(readData)
+        .addListener(
+            (ChannelFutureListener)
+                future -> {
+                  try {
+                    if (!future.isSuccess()) {
+                      recycleOnError(future.cause());
+                    }
+                  } finally {
+                    if (logger.isDebugEnabled()) {
+                      logger.debug("send data end: {}, {}", streamId, 
readableBytes);
+                    }
+                    wrappedBuffer.recycle();
+                    numInUseBuffers.decrementAndGet();
+                  }
+                });
+    int currentCredit = credits.decrementAndGet();
+    logger.debug("Current credit is {} after stream {}", currentCredit, 
streamId);
+  }
+
+  protected long getIndexRegionSize() {
     return mapFileMeta.getNumSubpartitions() * (long) INDEX_ENTRY_SIZE;
   }
 
-  private void readHeaderOrIndexBuffer(FileChannel channel, ByteBuffer buffer, 
int length)
+  protected void readHeaderOrIndexBuffer(FileChannel channel, ByteBuffer 
buffer, int length)
       throws IOException {
     Utils.checkFileIntegrity(channel, length);
     buffer.clear();
@@ -270,7 +291,7 @@ public class MapPartitionDataReader implements 
Comparable<MapPartitionDataReader
     buffer.flip();
   }
 
-  private void readBufferIntoReadBuffer(FileChannel channel, ByteBuf buf, int 
length)
+  protected void readBufferIntoReadBuffer(FileChannel channel, ByteBuf buf, 
int length)
       throws IOException {
     Utils.checkFileIntegrity(channel, length);
     ByteBuffer tmpBuffer = ByteBuffer.allocate(length);
@@ -281,7 +302,7 @@ public class MapPartitionDataReader implements 
Comparable<MapPartitionDataReader
     buf.writeBytes(tmpBuffer);
   }
 
-  private int readBuffer(
+  protected int readBuffer(
       String filename, FileChannel channel, ByteBuffer header, ByteBuf buffer, 
int headerSize)
       throws IOException {
     readHeaderOrIndexBuffer(channel, header, headerSize);
@@ -297,7 +318,7 @@ public class MapPartitionDataReader implements 
Comparable<MapPartitionDataReader
     return bufferLength + headerSize;
   }
 
-  private void updateConsumingOffset() throws IOException {
+  protected void updateConsumingOffset() throws IOException {
     while (currentPartitionRemainingBytes == 0
         && (currentDataRegion < numRegions - 1 || numRemainingPartitions > 0)) 
{
       if (numRemainingPartitions <= 0) {
@@ -344,6 +365,7 @@ public class MapPartitionDataReader implements 
Comparable<MapPartitionDataReader
               buffer,
               headerBuffer.capacity());
       currentPartitionRemainingBytes -= readSize;
+      dataConsumingOffset = dataFileChannel.position();
 
       logger.debug(
           "readBuffer data: {}, {}, {}, {}, {}, {}",
@@ -369,8 +391,6 @@ public class MapPartitionDataReader implements 
Comparable<MapPartitionDataReader
         return prevDataRegion == currentDataRegion && 
currentPartitionRemainingBytes > 0;
       }
 
-      dataConsumingOffset = dataFileChannel.position();
-
       logger.debug(
           "readBuffer run: {}, {}, {}, {}",
           streamId,
@@ -391,7 +411,7 @@ public class MapPartitionDataReader implements 
Comparable<MapPartitionDataReader
     return currentPartitionRemainingBytes > 0;
   }
 
-  private void notifyBacklog(int backlog) {
+  protected void notifyBacklog(int backlog) {
     logger.debug("stream manager stream id {} backlog:{}", streamId, backlog);
     associatedChannel
         .writeAndFlush(new BacklogAnnouncement(streamId, backlog))
@@ -471,7 +491,7 @@ public class MapPartitionDataReader implements 
Comparable<MapPartitionDataReader
         logger.debug("release reader for stream {}", streamId);
         // old client can't support BufferStreamEnd, so for new client it 
tells client that this
         // stream is finished.
-        if (fileInfo.isPartitionSplitEnabled() && !errorNotified)
+        if (fileInfo.isPartitionSplitEnabled() && !errorNotified) {
           associatedChannel.writeAndFlush(
               new RpcRequest(
                   TransportClient.requestId(),
@@ -484,10 +504,17 @@ public class MapPartitionDataReader implements 
Comparable<MapPartitionDataReader
                                   .build()
                                   .toByteArray())
                           .toByteBuffer())));
+        }
         if (!buffersToSend.isEmpty()) {
-          numInUseBuffers.addAndGet(-1 * buffersToSend.size());
-          buffersToSend.forEach(RecyclableBuffer::recycle);
-          buffersToSend.clear();
+          int dataBufferInUse = 0;
+          RecyclableBuffer buffer;
+          while ((buffer = buffersToSend.poll()) != null) {
+            if (buffer.isDataBuffer()) {
+              dataBufferInUse++;
+            }
+            buffer.recycle();
+          }
+          numInUseBuffers.addAndGet(-1 * dataBufferInUse);
         }
         isReleased = true;
       }
diff --git 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/segment/SegmentMapPartitionData.java
 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/segment/SegmentMapPartitionData.java
new file mode 100644
index 000000000..3d3b6cfe8
--- /dev/null
+++ 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/segment/SegmentMapPartitionData.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.worker.storage.segment;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.concurrent.ExecutorService;
+import java.util.function.Consumer;
+
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.meta.DiskFileInfo;
+import org.apache.celeborn.service.deploy.worker.storage.MapPartitionData;
+import 
org.apache.celeborn.service.deploy.worker.storage.MapPartitionDataReader;
+
+public class SegmentMapPartitionData extends MapPartitionData {
+
+  public static final Logger logger = 
LoggerFactory.getLogger(SegmentMapPartitionData.class);
+
+  public SegmentMapPartitionData(
+      int minReadBuffers,
+      int maxReadBuffers,
+      HashMap<String, ExecutorService> storageFetcherPool,
+      int threadsPerMountPoint,
+      DiskFileInfo fileInfo,
+      Consumer<Long> recycleStream,
+      int minBuffersToTriggerRead)
+      throws IOException {
+    super(
+        minReadBuffers,
+        maxReadBuffers,
+        storageFetcherPool,
+        threadsPerMountPoint,
+        fileInfo,
+        recycleStream,
+        minBuffersToTriggerRead);
+  }
+
+  @Override
+  public void setupDataPartitionReader(
+      int startSubIndex, int endSubIndex, long streamId, Channel channel) {
+    SegmentMapPartitionDataReader mapDataPartitionReader =
+        new SegmentMapPartitionDataReader(
+            startSubIndex,
+            endSubIndex,
+            getDiskFileInfo(),
+            streamId,
+            channel,
+            () -> recycleStream.accept(streamId));
+    logger.debug(
+        "Setup data partition reader from {} to {} with streamId {}",
+        startSubIndex,
+        endSubIndex,
+        streamId);
+    readers.put(streamId, mapDataPartitionReader);
+  }
+
+  @Override
+  protected void openReader(MapPartitionDataReader reader) throws IOException {
+    super.openReader(reader);
+    if (reader instanceof SegmentMapPartitionDataReader) {
+      ((SegmentMapPartitionDataReader) reader).updateSegmentId();
+    } else {
+      logger.warn("openReader only expects SegmentMapPartitionDataReader.");
+    }
+  }
+
+  @Override
+  public String toString() {
+    return String.format("SegmentMapDataPartition{filePath=%s}", 
diskFileInfo.getFilePath());
+  }
+
+  public void notifyRequiredSegmentId(int segmentId, long streamId, int 
subPartitionId) {
+    MapPartitionDataReader streamReader = getStreamReader(streamId);
+    if (!(streamReader instanceof SegmentMapPartitionDataReader)) {
+      logger.warn("notifyRequiredSegmentId only expects non-null 
SegmentMapPartitionDataReader.");
+      return;
+    }
+    ((SegmentMapPartitionDataReader) streamReader)
+        .notifyRequiredSegmentId(segmentId, subPartitionId);
+    // After notifying the required segment id, we need to try to send data 
again.
+    readExecutor.submit(
+        () -> {
+          try {
+            streamReader.sendData();
+          } catch (Throwable throwable) {
+            logger.error("Failed to send data.", throwable);
+          }
+        });
+  }
+}
diff --git 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/segment/SegmentMapPartitionDataReader.java
 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/segment/SegmentMapPartitionDataReader.java
new file mode 100644
index 000000000..21b36472c
--- /dev/null
+++ 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/segment/SegmentMapPartitionDataReader.java
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.worker.storage.segment;
+
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.Map;
+
+import javax.annotation.concurrent.GuardedBy;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.meta.DiskFileInfo;
+import org.apache.celeborn.common.network.protocol.RequestMessage;
+import org.apache.celeborn.common.network.protocol.SubPartitionReadData;
+import org.apache.celeborn.service.deploy.worker.memory.BufferRecycler;
+import org.apache.celeborn.service.deploy.worker.memory.RecyclableBuffer;
+import 
org.apache.celeborn.service.deploy.worker.memory.RecyclableSegmentIdBuffer;
+import 
org.apache.celeborn.service.deploy.worker.storage.MapPartitionDataReader;
+
/**
 * A {@link MapPartitionDataReader} for Flink hybrid shuffle that is aware of segments: buffers of
 * a sub-partition are grouped into segments, and a data buffer is only sent once the client has
 * announced (via {@link #notifyRequiredSegmentId}) that it requires that segment. Segment
 * boundaries are communicated to the send loop by interleaving {@link RecyclableSegmentIdBuffer}
 * markers into {@code buffersToSend}.
 *
 * <p>All segment bookkeeping below is guarded by the inherited {@code lock}.
 */
public class SegmentMapPartitionDataReader extends MapPartitionDataReader {

  public static final Logger logger = LoggerFactory.getLogger(SegmentMapPartitionDataReader.class);

  // The Flink EndOfSegment buffer data type, which has been written
  // to the buffer header, is in the 16th bit (starting from 0).
  private static final int END_OF_SEGMENT_BUFFER_DATA_TYPE = 7;

  // First sub-partition index (inclusive) served by this reader.
  private final int startPartitionIndex;

  // Last sub-partition index (inclusive) served by this reader.
  private final int endPartitionIndex;

  // Backlog counters: each entry counts buffers queued for one segment (or one run of
  // continuous segments); entries are consumed FIFO by getBacklog().
  @GuardedBy("lock")
  private final Deque<Integer> backlogs = new LinkedList<>();

  // subpartitionId -> segmentId, current required segmentId by client per subpartition
  // (-1 until the client announces one).
  @GuardedBy("lock")
  private Map<Integer, Integer> subPartitionRequiredSegmentIds = new HashMap<>();

  // Set once updateSegmentId() has run (when the stream is opened); sendData() is a no-op
  // before that so no buffer is sent ahead of the initial segment-id markers.
  @GuardedBy("lock")
  private boolean hasUpdateSegmentId = false;

  // subpartitionId -> segmentId, segmentId of current reading buffer per subpartition
  @GuardedBy("lock")
  private Map<Integer, Integer> subPartitionLastSegmentIds = new HashMap<>();

  // subpartitionId -> buffer index, current reading buffer index per subpartition
  @GuardedBy("lock")
  private Map<Integer, Integer> subPartitionNextBufferIndex = new HashMap<>();

  /**
   * Creates a segment-aware reader for the sub-partition range
   * [startPartitionIndex, endPartitionIndex]. Bookkeeping starts at "no segment seen" (-1) and
   * buffer index 0 for every sub-partition in the range.
   */
  public SegmentMapPartitionDataReader(
      int startPartitionIndex,
      int endPartitionIndex,
      DiskFileInfo fileInfo,
      long streamId,
      Channel associatedChannel,
      Runnable recycleStream) {
    super(
        startPartitionIndex,
        endPartitionIndex,
        fileInfo,
        streamId,
        associatedChannel,
        recycleStream);
    this.startPartitionIndex = startPartitionIndex;
    this.endPartitionIndex = endPartitionIndex;
    for (int i = startPartitionIndex; i <= endPartitionIndex; i++) {
      subPartitionLastSegmentIds.put(i, -1);
      subPartitionRequiredSegmentIds.put(i, -1);
      subPartitionNextBufferIndex.put(i, 0);
    }
  }

  /** Announces the per-segment backlog instead of the raw buffer count used by the base class. */
  @Override
  protected void tryNotifyBacklog(int numDataBuffers) {
    notifyBacklog(getBacklog());
  }

  /**
   * Drains {@code buffersToSend}, sending data buffers while credit is available. Segment-id
   * markers at the head of the queue gate progress: a marker whose segment is not yet required by
   * the client stops the loop until {@link #notifyRequiredSegmentId} re-triggers it.
   */
  @Override
  public synchronized void sendData() {
    while (true) {
      synchronized (lock) {
        if (!hasUpdateSegmentId) {
          // The stream has not been opened yet; the initial segment markers are missing.
          logger.debug(
              "The required segment id is not updated for {}, skip sending data.", getStreamId());
          return;
        }

        RecyclableBuffer buffer;

        // Verify if the client requires the segmentId of the first buffer; if so, send it to the
        // worker; otherwise, wait until the client sends a new segmentId.
        boolean breakLoop = false;
        while ((buffer = buffersToSend.peek()) instanceof RecyclableSegmentIdBuffer) {
          RecyclableSegmentIdBuffer recyclableSegmentIdBuffer = (RecyclableSegmentIdBuffer) buffer;
          int subPartitionId = recyclableSegmentIdBuffer.getSubPartitionId();
          int segmentId = recyclableSegmentIdBuffer.getSegmentId();
          int requiredSegmentId = subPartitionRequiredSegmentIds.get(subPartitionId);
          if (segmentId != requiredSegmentId) {
            // If the queued head buffer is not the required segment id, we do not send it.
            logger.info(
                "The queued head buffer is not the required segment id, "
                    + "do not sent it. details: streamId {}, subPartitionId: {}, current segment id: {}, required segment id: {}, reader: {}",
                streamId,
                subPartitionId,
                segmentId,
                requiredSegmentId,
                this);
            breakLoop = true;
            break;
          } else {
            // The marker matches the required segment; consume it and look at the next entry.
            buffersToSend.poll();
          }
        }
        if (breakLoop) {
          break;
        }

        // fetch first buffer and send; null means no credit or nothing queued.
        buffer = fetchBufferToSend();
        if (buffer == null) {
          break;
        }
        if (buffer instanceof RecyclableSegmentIdBuffer) {
          // Defensive check: markers should have been drained above.
          logger.warn("Wrong type of buffer, the RecyclableSegmentIdBuffer is not expected.");
          return;
        }
        sendDataInternal(buffer);
      }
    }

    tryRecycleReader();
  }

  /**
   * Queues a buffer read from disk. The first int of the buffer header is the sub-partition id;
   * when the buffer's data type marks the end of a segment, a segment-id marker for the next
   * segment is appended behind it.
   *
   * @throws RuntimeException if the reader has already been released or failed.
   */
  @Override
  protected void addBuffer(ByteBuf buffer, BufferRecycler bufferRecycler) {
    if (buffer == null) {
      return;
    }

    // Peek at the header without disturbing the reader index seen by downstream code.
    buffer.markReaderIndex();
    int subPartitionId = buffer.readInt();
    // check the buffer type
    boolean isEndOfSegment = isEndOfSegment(buffer, subPartitionId);
    buffer.resetReaderIndex();

    synchronized (lock) {
      if (!isReleased) {
        buffersToSend.add(new RecyclableBuffer(buffer, subPartitionId, bufferRecycler));
        increaseBacklog();
      } else {
        bufferRecycler.recycle(buffer);
        numInUseBuffers.decrementAndGet();
        throw new RuntimeException("Partition reader has been failed or finished.", errorCause);
      }

      subPartitionNextBufferIndex.compute(subPartitionId, (k, v) -> v + 1);
      if (isEndOfSegment) {
        // The next buffer of this sub-partition starts a new segment; enqueue its marker now.
        updateSegmentId(subPartitionId);
      }
    }
  }

  /**
   * Polls one data buffer when credit is available. When credit is exhausted (or about to be),
   * re-announces the backlog so the client grants more credit.
   */
  @Override
  protected RecyclableBuffer fetchBufferToSend() {
    synchronized (lock) {
      if (isReleased || buffersToSend.isEmpty()) {
        return null;
      }

      RecyclableBuffer buffer = null;
      int numCredit = credits.get();
      if (numCredit > 0) {
        buffer = buffersToSend.poll();
      }
      if (numCredit <= 1) {
        // This poll consumes the last credit (or there was none); prompt the client for more.
        notifyBacklog(getBacklog());
      }
      return buffer;
    }
  }

  /** Suppresses zero-valued backlog announcements; otherwise defers to the base class. */
  @Override
  protected void notifyBacklog(int backlog) {
    if (backlog == 0) {
      return;
    }
    super.notifyBacklog(backlog);
  }

  @Override
  public String toString() {
    final StringBuilder sb = new StringBuilder("SegmentMapDataPartitionReader{");
    sb.append("startPartitionIndex=").append(startPartitionIndex);
    sb.append(", endPartitionIndex=").append(endPartitionIndex);
    sb.append(", streamId=").append(streamId);
    sb.append('}');
    return sb.toString();
  }

  /**
   * Returns whether the buffer's header marks the end of a segment. Only meaningful when the file
   * meta carries segment ids; otherwise always false. Advances the buffer's reader index — the
   * caller is expected to reset it.
   */
  private boolean isEndOfSegment(ByteBuf buffer, int subPartitionId) {
    boolean isEndOfSegment = false;
    Preconditions.checkState(
        subPartitionId >= startPartitionIndex && subPartitionId <= endPartitionIndex);
    if (mapFileMeta.hasPartitionSegmentIds()) {
      // skip another 3 int fields, the write details are in
      // FlinkShuffleClientImpl#pushDataToLocation
      buffer.skipBytes(3 * 4);
      int dataType = buffer.readByte();
      isEndOfSegment = END_OF_SEGMENT_BUFFER_DATA_TYPE == dataType;
    }
    return isEndOfSegment;
  }

  /**
   * Pops and returns the first non-empty backlog entry, discarding leading zero entries.
   * Returns 0 when no backlog is recorded.
   */
  private int getBacklog() {
    synchronized (lock) {
      Integer backlog = backlogs.peekFirst();
      while (backlog != null && backlog == 0) {
        backlogs.pollFirst();
        backlog = backlogs.peekFirst();
      }
      if (backlog != null) {
        backlogs.pollFirst();
      }
      return backlog == null ? 0 : backlog;
    }
  }

  /** Starts a fresh backlog counter for a new (non-continuous) segment. */
  @GuardedBy("lock")
  private void addNewBacklog() {
    backlogs.addLast(0);
  }

  /** Increments the most recent backlog counter, creating one if none exists yet. */
  @GuardedBy("lock")
  private void increaseBacklog() {
    Integer backlog = backlogs.pollLast();
    if (backlog == null) {
      backlogs.addLast(1);
    } else {
      backlogs.addLast(backlog + 1);
    }
  }

  /**
   * Looks up the segment that starts at the sub-partition's current buffer index and, if found,
   * enqueues a segment-id marker ahead of that segment's buffers.
   */
  private void updateSegmentId(int subPartitionId) {
    synchronized (lock) {
      // Note that only when writing buffers does the writer have the segment info; when loading
      // buffers from disk we do not know where a segment starts, so we resolve the segment id
      // from the next buffer index here.
      Integer segmentId =
          mapFileMeta.getSegmentIdByFirstBufferIndex(
              subPartitionId, subPartitionNextBufferIndex.get(subPartitionId));
      if (segmentId == null) {
        // No segment starts at this buffer index.
        return;
      }
      if (segmentId != -1) {
        // For the continuous segments (segmentId == last + 1), we reuse the same backlog;
        // otherwise a new backlog counter is started.
        if (segmentId == 0
            || !segmentId.equals(subPartitionLastSegmentIds.get(subPartitionId))
                && segmentId != (subPartitionLastSegmentIds.get(subPartitionId) + 1)) {
          addNewBacklog();
        }
        subPartitionLastSegmentIds.put(subPartitionId, segmentId);
      }
      logger.debug(
          "Insert a buffer to indicate the current segment id "
              + "subPartitionId={}, segmentId={} for {}.",
          subPartitionId,
          segmentId,
          this);
      // Before adding buffers of this segment, add a marker buffer carrying the segment id,
      // so the marker sits at the head of this segment in the send queue.
      buffersToSend.add(new RecyclableSegmentIdBuffer(subPartitionId, segmentId));
    }
  }

  /** Segment streams wrap outgoing data in a SubPartitionReadData message. */
  public RequestMessage generateReadDataMessage(
      long streamId, int subPartitionId, ByteBuf byteBuf) {
    return new SubPartitionReadData(streamId, subPartitionId, byteBuf);
  }

  /** Records the segment id the client now requires for the given sub-partition. */
  public void notifyRequiredSegmentId(int segmentId, int subPartitionId) {
    synchronized (lock) {
      logger.debug(
          "Update the required segment id to {}, {}, subPartitionId: {}",
          segmentId,
          this,
          subPartitionId);
      this.subPartitionRequiredSegmentIds.put(subPartitionId, segmentId);
    }
  }

  /**
   * Marks the reader ready to send and enqueues the initial segment-id marker for every
   * sub-partition that has not seen a segment yet. Invoked when the stream is opened
   * (see SegmentMapPartitionData#openReader).
   */
  protected void updateSegmentId() {
    synchronized (lock) {
      hasUpdateSegmentId = true;
      for (int i = startPartitionIndex; i <= endPartitionIndex; i++) {
        if (subPartitionLastSegmentIds.get(i) < 0) {
          updateSegmentId(i);
        }
      }
    }
  }
}
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
index 4d7e53895..f040f4334 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
@@ -37,7 +37,7 @@ import 
org.apache.celeborn.common.network.client.{RpcResponseCallback, Transport
 import org.apache.celeborn.common.network.protocol._
 import org.apache.celeborn.common.network.server.BaseMessageHandler
 import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf}
-import org.apache.celeborn.common.protocol.{MessageType, PbBufferStreamEnd, 
PbChunkFetchRequest, PbOpenStream, PbOpenStreamList, PbOpenStreamListResponse, 
PbReadAddCredit, PbStreamHandler, PbStreamHandlerOpt, StreamType}
+import org.apache.celeborn.common.protocol.{MessageType, PbBufferStreamEnd, 
PbChunkFetchRequest, PbNotifyRequiredSegment, PbOpenStream, PbOpenStreamList, 
PbOpenStreamListResponse, PbReadAddCredit, PbStreamHandler, PbStreamHandlerOpt, 
StreamType}
 import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.util.{ExceptionUtils, Utils}
 import org.apache.celeborn.service.deploy.worker.storage.{ChunkStreamManager, 
CreditStreamManager, PartitionFilesSorter, StorageManager}
@@ -171,6 +171,12 @@ class FetchHandler(
           bufferStreamEnd.getStreamType)
       case readAddCredit: PbReadAddCredit =>
         handleReadAddCredit(client, readAddCredit.getCredit, 
readAddCredit.getStreamId)
+      case notifyRequiredSegment: PbNotifyRequiredSegment =>
+        handleNotifyRequiredSegment(
+          client,
+          notifyRequiredSegment.getRequiredSegmentId,
+          notifyRequiredSegment.getStreamId,
+          notifyRequiredSegment.getSubPartitionId)
       case chunkFetchRequest: PbChunkFetchRequest =>
         handleChunkFetchRequest(
           client,
@@ -484,6 +490,23 @@ class FetchHandler(
     }
   }
 
+  def handleNotifyRequiredSegment(
+      client: TransportClient,
+      requiredSegmentId: Int,
+      streamId: Long,
+      subPartitionId: Int): Unit = {
+    // process NotifyRequiredSegment request here, the MapPartitionDataReader 
will send data if the segment buffer is available.
+    logDebug(
+      s"Handle NotifyRequiredSegment with streamId: $streamId, 
requiredSegmentId: $requiredSegmentId")
+    val shuffleKey = creditStreamManager.getStreamShuffleKey(streamId)
+    if (shuffleKey != null) {
+      workerSource.recordAppActiveConnection(
+        client,
+        shuffleKey)
+      creditStreamManager.notifyRequiredSegment(requiredSegmentId, streamId, 
subPartitionId)
+    }
+  }
+
   def handleChunkFetchRequest(
       client: TransportClient,
       streamChunkSlice: StreamChunkSlice,


Reply via email to