This is an automated email from the ASF dual-hosted git repository.

fchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 7bc34972a [CELEBORN-1177] OpenStream should register stream via ChunkStreamManager to close stream for ReusedExchange
7bc34972a is described below

commit 7bc34972ac9d5c4bc6f7d96e857f7cbe5ec23986
Author: SteNicholas <[email protected]>
AuthorDate: Fri Jan 12 00:26:59 2024 +0800

    [CELEBORN-1177] OpenStream should register stream via ChunkStreamManager to close stream for ReusedExchange
    
    ### What changes were proposed in this pull request?
    
    `OpenStream` should register the stream via `ChunkStreamManager`, which is used to obtain the disk file when closing the stream for the `ReusedExchange` operator.
    
    Follow-up to #1932.
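    As an illustration only, the registration pattern this change adopts can be sketched in plain Java. The class below is a simplified, hypothetical stand-in (not Celeborn's real `ChunkStreamManager`, whose stream state also tracks buffers and fetch-time metrics): the local/DFS read path has no chunk buffers to serve, so a short overload registers the stream id with only the shuffle key and file name by delegating to the full registration method with null buffers.
    
    ```java
    // Simplified, hypothetical sketch -- not Celeborn's actual ChunkStreamManager.
    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;
    
    public class StreamRegistry {
        // streamId -> "shuffleKey|fileName"; stands in for the real stream-state map.
        private final Map<Long, String> streams = new ConcurrentHashMap<>();
    
        // Full registration; buffers and metric are elided to Object for the sketch.
        public long registerStream(
                long streamId, String shuffleKey, Object buffers, String fileName, Object metric) {
            streams.put(streamId, shuffleKey + "|" + fileName);
            return streamId;
        }
    
        // Overload for local/DFS reads: no chunk buffers to serve, but the stream
        // must still be registered so it can later be closed (and its file resolved) by id.
        public long registerStream(long streamId, String shuffleKey, String fileName) {
            return registerStream(streamId, shuffleKey, null, fileName, null);
        }
    
        public String getShuffleKeyAndFileName(long streamId) {
            return streams.get(streamId);
        }
    
        public static void main(String[] args) {
            StreamRegistry registry = new StreamRegistry();
            registry.registerStream(42L, "app-1-2", "shuffle-file");
            System.out.println(registry.getShuffleKeyAndFileName(42L)); // app-1-2|shuffle-file
        }
    }
    ```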
    
    ### Why are the changes needed?
    
    `OpenStream` does not register a chunk stream when reading local or DFS shuffle files. `LocalPartitionReader` and `DfsPartitionReader` therefore cannot obtain the disk file from `ChunkStreamManager`, which causes the `NullPointerException` below when closing the stream.
    ```
    ERROR [fetch-server-11-11] TransportRequestHandler: Error while invoking handler#receive() on RPC id 4
    java.lang.NullPointerException
            at org.apache.celeborn.service.deploy.worker.storage.ChunkStreamManager.getShuffleKeyAndFileName(ChunkStreamManager.java:188)
            at org.apache.celeborn.service.deploy.worker.FetchHandler.handleEndStreamFromClient(FetchHandler.scala:344)
            at org.apache.celeborn.service.deploy.worker.FetchHandler.handleRpcRequest(FetchHandler.scala:137)
            at org.apache.celeborn.service.deploy.worker.FetchHandler.receive(FetchHandler.scala:94)
            at org.apache.celeborn.common.network.server.TransportRequestHandler.processRpcRequest(TransportRequestHandler.java:96)
            at org.apache.celeborn.common.network.server.TransportRequestHandler.handle(TransportRequestHandler.java:84)
            at org.apache.celeborn.common.network.server.TransportChannelHandler.channelRead(TransportChannelHandler.java:151)
            at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:444)
            at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420)
            at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:412)
            at io.netty.handler.timeout.IdleStateHandler.channelRead(IdleStateHandler.java:286)
            at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:442)
            at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420)
            at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:412)
            at org.apache.celeborn.common.network.util.TransportFrameDecoder.channelRead(TransportFrameDecoder.java:74)
            at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:444)
            at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420)
            at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:412)
            at io.netty.channel.DefaultChannelPipeline$HeadContext.channelRead(DefaultChannelPipeline.java:1410)
            at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:440)
            at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420)
            at io.netty.channel.DefaultChannelPipeline.fireChannelRead(DefaultChannelPipeline.java:919)
            at io.netty.channel.nio.AbstractNioByteChannel$NioByteUnsafe.read(AbstractNioByteChannel.java:166)
            at io.netty.channel.nio.NioEventLoop.processSelectedKey(NioEventLoop.java:788)
            at io.netty.channel.nio.NioEventLoop.processSelectedKeysOptimized(NioEventLoop.java:724)
            at io.netty.channel.nio.NioEventLoop.processSelectedKeys(NioEventLoop.java:650)
            at io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:562)
            at io.netty.util.concurrent.SingleThreadEventExecutor$4.run(SingleThreadEventExecutor.java:997)
            at io.netty.util.internal.ThreadExecutorMap$2.run(ThreadExecutorMap.java:74)
            at io.netty.util.concurrent.FastThreadLocalRunnable.run(FastThreadLocalRunnable.java:30)
            at java.lang.Thread.run(Thread.java:745)
    ```
    In summary, `FetchHandler` only closes streams that were registered via `ChunkStreamManager`. `LocalPartitionReader` and `DfsPartitionReader` should therefore rely on `ChunkStreamManager#registerStream` so that closing the stream can delete the original unsorted disk file for the `ReusedExchange` operator.
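    The failure mode can be made concrete with a small, hypothetical Java sketch (simplified stand-ins, not Celeborn's actual classes): looking up an unregistered stream id returns null, the close path dereferences that result, and that is the `NullPointerException` above. Registering the stream first, as this patch does, makes the close succeed.
    
    ```java
    // Hypothetical sketch of the close path -- not Celeborn's actual FetchHandler.
    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;
    
    public class CloseStreamDemo {
        // Stand-in for ChunkStreamManager's streamId -> state map.
        static final Map<Long, String> registered = new ConcurrentHashMap<>();
    
        // Stand-in for getShuffleKeyAndFileName: returns null for unknown stream ids.
        static String lookup(long streamId) {
            return registered.get(streamId);
        }
    
        // Stand-in for handleEndStreamFromClient: dereferences the lookup result,
        // so it throws NullPointerException when the stream was never registered.
        static String closeStream(long streamId) {
            return lookup(streamId).split("\\|")[1];
        }
    
        public static void main(String[] args) {
            try {
                closeStream(7L); // never registered
            } catch (NullPointerException e) {
                System.out.println("NPE on unregistered stream");
            }
            registered.put(7L, "app-1-0|file-0"); // register first, as the fix does
            System.out.println(closeStream(7L)); // file-0
        }
    }
    ```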
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    - `FetchHandlerSuiteJ#testLocalReadSortFileOnceOriginalFileBeDeleted`
    - `FetchHandlerSuiteJ#testDoNotDeleteOriginalFileWhenNonRangeLocalReadWorkInProgress`
    - `ReusedExchangeSuite`
    
    Closes #2209 from SteNicholas/CELEBORN-1177.
    
    Authored-by: SteNicholas <[email protected]>
    Signed-off-by: Fu Chen <[email protected]>
---
 .../celeborn/tests/spark/ReusedExchangeSuite.scala | 11 +++-
 .../deploy/worker/storage/ChunkStreamManager.java  | 14 ++++-
 .../service/deploy/worker/FetchHandler.scala       | 20 ++++---
 .../service/deploy/worker/FetchHandlerSuiteJ.java  | 67 +++++++++++++++++++---
 4 files changed, 94 insertions(+), 18 deletions(-)

diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/ReusedExchangeSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/ReusedExchangeSuite.scala
index 92984fb8f..6a3d3ad36 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/ReusedExchangeSuite.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/ReusedExchangeSuite.scala
@@ -34,10 +34,19 @@ class ReusedExchangeSuite extends AnyFunSuite
     ShuffleClient.reset()
   }
 
-  test("ReusedExchange end to end test") {
+  test("[CELEBORN-980] Asynchronously delete original files to fix ReusedExchange bug") {
+    testReusedExchange(false)
+  }
+
+  test("[CELEBORN-1177] OpenStream should register stream via ChunkStreamManager to close stream for ReusedExchange") {
+    testReusedExchange(true)
+  }
+
+  def testReusedExchange(readLocalShuffle: Boolean): Unit = {
     val sparkConf = new SparkConf().setAppName("celeborn-test").setMaster("local[2]")
       .set("spark.shuffle.manager", "org.apache.spark.shuffle.celeborn.SparkShuffleManager")
       .set(s"spark.${CelebornConf.MASTER_ENDPOINTS.key}", masterInfo._1.rpcEnv.address.toString)
+      .set(s"spark.${CelebornConf.READ_LOCAL_SHUFFLE_FILE.key}", readLocalShuffle.toString)
       .set("spark.sql.autoBroadcastJoinThreshold", "-1")
       .set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "100")
       .set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "100")
diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/ChunkStreamManager.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/ChunkStreamManager.java
index 731754345..006ecbe44 100644
--- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/ChunkStreamManager.java
+++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/ChunkStreamManager.java
@@ -25,7 +25,6 @@ import java.util.concurrent.atomic.AtomicLong;
 import scala.Tuple2;
 
 import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -62,7 +61,7 @@ public class ChunkStreamManager {
         FileManagedBuffers buffers,
         String fileName,
         TimeWindow fetchTimeMetric) {
-      this.buffers = Preconditions.checkNotNull(buffers);
+      this.buffers = buffers;
       this.shuffleKey = shuffleKey;
       this.fileName = fileName;
       this.fetchTimeMetric = fetchTimeMetric;
@@ -123,6 +122,17 @@ public class ChunkStreamManager {
     return sum;
   }
 
+  /**
+   * Registers a stream with shuffle key and disk file when reading local or dfs shuffle, which is
+   * served to obtain disk file via registered stream id to close stream.
+   *
+   * <p>This stream could be reused again when other channel of the client is reconnected. If a
+   * stream is not properly closed, it will eventually be cleaned up by `cleanupExpiredShuffleKey`.
+   */
+  public long registerStream(long streamId, String shuffleKey, String fileName) {
+    return registerStream(streamId, shuffleKey, null, fileName, null);
+  }
+
   /**
    * Registers a stream of ManagedBuffers which are served as individual chunks one at a time to
    * callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
index c7b657ba4..bea4959a4 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
@@ -230,33 +230,37 @@ class FetchHandler(
           }
           val meta = fileInfo.getFileMeta.asInstanceOf[ReduceFileMeta]
           if (readLocalShuffle) {
+            chunkStreamManager.registerStream(
+              streamId,
+              shuffleKey,
+              fileName)
             replyStreamHandler(
               client,
               rpcRequestId,
-              -1,
+              streamId,
               meta.getNumChunks,
               isLegacy,
               meta.getChunkOffsets,
               fileInfo.getFilePath)
           } else if (fileInfo.isHdfs) {
+            chunkStreamManager.registerStream(
+              streamId,
+              shuffleKey,
+              fileName)
             replyStreamHandler(client, rpcRequestId, streamId, numChunks = 0, isLegacy)
           } else {
-            val buffers =
-              new FileManagedBuffers(fileInfo, transportConf)
-            val fetchTimeMetrics =
-              storageManager.getFetchTimeMetric(fileInfo.getFile)
             chunkStreamManager.registerStream(
               streamId,
               shuffleKey,
-              buffers,
+              new FileManagedBuffers(fileInfo, transportConf),
               fileName,
-              fetchTimeMetrics)
+              storageManager.getFetchTimeMetric(fileInfo.getFile))
             if (meta.getNumChunks == 0)
               logDebug(s"StreamId $streamId, fileName $fileName, mapRange " +
                 s"[$startIndex-$endIndex] is empty. Received from client channel " +
                 s"${NettyUtils.getRemoteAddress(client.getChannel)}")
             else logDebug(
-              s"StreamId $streamId, fileName $fileName, numChunks ${meta.getNumChunks()}, " +
+              s"StreamId $streamId, fileName $fileName, numChunks ${meta.getNumChunks}, " +
                 s"mapRange [$startIndex-$endIndex]. Received from client channel " +
                 s"${NettyUtils.getRemoteAddress(client.getChannel)}")
             replyStreamHandler(
diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java
index 71ae24225..3bc4aabae 100644
--- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java
+++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java
@@ -217,7 +217,7 @@ public class FetchHandlerSuiteJ {
   }
 
   @Test
-  public void testReadSortFileOnceOriginalFileBeDeleted() throws IOException {
+  public void testWorkerReadSortFileOnceOriginalFileBeDeleted() throws IOException {
     FileInfo fileInfo = null;
     try {
       // total write size: 32 * 50 * 256k = 400m
@@ -229,9 +229,9 @@ public class FetchHandlerSuiteJ {
       PbStreamHandler rangeReadStreamHandler =
           openStreamAndCheck(client, channel, fetchHandler, 5, 10);
       checkOriginFileBeDeleted(fileInfo);
-      PbStreamHandler nonRangeReadstreamHandler =
+      PbStreamHandler nonRangeReadStreamHandler =
           openStreamAndCheck(client, channel, fetchHandler, 0, Integer.MAX_VALUE);
-      fetchChunkAndCheck(client, channel, fetchHandler, nonRangeReadstreamHandler);
+      fetchChunkAndCheck(client, channel, fetchHandler, nonRangeReadStreamHandler);
       fetchChunkAndCheck(client, channel, fetchHandler, rangeReadStreamHandler);
     } finally {
       cleanup(fileInfo);
@@ -239,7 +239,7 @@ public class FetchHandlerSuiteJ {
   }
 
   @Test
-  public void testDoNotDeleteOriginalFileWhenNonRangeReadWorkInProgress() throws IOException {
+  public void testLocalReadSortFileOnceOriginalFileBeDeleted() throws IOException {
     FileInfo fileInfo = null;
     try {
       // total write size: 32 * 50 * 256k = 400m
@@ -248,15 +248,56 @@ public class FetchHandlerSuiteJ {
       TransportClient client = new TransportClient(channel, mock(TransportResponseHandler.class));
       FetchHandler fetchHandler = mockFetchHandler(fileInfo);
 
-      PbStreamHandler nonRangeReadstreamHandler =
+      // read local shuffle
+      openStreamAndCheck(client, channel, fetchHandler, 5, 10, true);
+      checkOriginFileBeDeleted(fileInfo);
+    } finally {
+      cleanup(fileInfo);
+    }
+  }
+
+  @Test
+  public void testDoNotDeleteOriginalFileWhenNonRangeWorkerReadWorkInProgress() throws IOException {
+    FileInfo fileInfo = null;
+    try {
+      // total write size: 32 * 50 * 256k = 400m
+      fileInfo = prepare(32);
+      EmbeddedChannel channel = new EmbeddedChannel();
+      TransportClient client = new TransportClient(channel, mock(TransportResponseHandler.class));
+      FetchHandler fetchHandler = mockFetchHandler(fileInfo);
+
+      PbStreamHandler nonRangeReadStreamHandler =
           openStreamAndCheck(client, channel, fetchHandler, 0, Integer.MAX_VALUE);
       PbStreamHandler rangeReadStreamHandler =
           openStreamAndCheck(client, channel, fetchHandler, 5, 10);
-      fetchChunkAndCheck(client, channel, fetchHandler, nonRangeReadstreamHandler);
+      fetchChunkAndCheck(client, channel, fetchHandler, nonRangeReadStreamHandler);
       fetchChunkAndCheck(client, channel, fetchHandler, rangeReadStreamHandler);
 
       // non-range fetch finished.
-      bufferStreamEnd(client, fetchHandler, nonRangeReadstreamHandler.getStreamId());
+      bufferStreamEnd(client, fetchHandler, nonRangeReadStreamHandler.getStreamId());
+      checkOriginFileBeDeleted(fileInfo);
+    } finally {
+      cleanup(fileInfo);
+    }
+  }
+
+  @Test
+  public void testDoNotDeleteOriginalFileWhenNonRangeLocalReadWorkInProgress() throws IOException {
+    FileInfo fileInfo = null;
+    try {
+      // total write size: 32 * 50 * 256k = 400m
+      fileInfo = prepare(32);
+      EmbeddedChannel channel = new EmbeddedChannel();
+      TransportClient client = new TransportClient(channel, mock(TransportResponseHandler.class));
+      FetchHandler fetchHandler = mockFetchHandler(fileInfo);
+
+      // read local shuffle
+      PbStreamHandler nonRangeReadStreamHandler =
+          openStreamAndCheck(client, channel, fetchHandler, 0, Integer.MAX_VALUE, true);
+      openStreamAndCheck(client, channel, fetchHandler, 5, 10);
+
+      // non-range fetch finished.
+      bufferStreamEnd(client, fetchHandler, nonRangeReadStreamHandler.getStreamId());
       checkOriginFileBeDeleted(fileInfo);
     } finally {
       cleanup(fileInfo);
@@ -316,6 +357,17 @@ public class FetchHandlerSuiteJ {
       int startIndex,
       int endIndex)
       throws IOException {
+    return openStreamAndCheck(client, channel, fetchHandler, startIndex, endIndex, false);
+  }
+
+  private PbStreamHandler openStreamAndCheck(
+      TransportClient client,
+      EmbeddedChannel channel,
+      FetchHandler fetchHandler,
+      int startIndex,
+      int endIndex,
+      Boolean readLocalShuffle)
+      throws IOException {
     ByteBuffer openStreamByteBuffer =
         new TransportMessage(
                 MessageType.OPEN_STREAM,
@@ -324,6 +376,7 @@ public class FetchHandlerSuiteJ {
                     .setFileName(fileName)
                     .setStartIndex(startIndex)
                     .setEndIndex(endIndex)
+                    .setReadLocalShuffle(readLocalShuffle)
                     .build()
                     .toByteArray())
             .toByteBuffer();
