This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 074597d05 [CELEBORN-1147] Added a dedicated API for RPC messages which also accepts an RpcResponseCallback instance
074597d05 is described below

commit 074597d05e15bbe031938795d2079f567c7851e8
Author: Chandni Singh <[email protected]>
AuthorDate: Sat Dec 9 02:02:15 2023 +0800

    [CELEBORN-1147] Added a dedicated API for RPC messages which also accepts an RpcResponseCallback instance
    
    ### What changes were proposed in this pull request?
    
    Currently `BaseMessageHandler` has a single `receive` API that is used for all messages. This makes message handling messy when multiple handlers are added:
    
    - `req.body().release()` is only invoked when the handler actually processes the message rather than delegating it.
    - every handler has to create its own `RpcResponseCallback` instance for RPC messages, even though the callback is identical in every handler.
    
    Instead, releasing the message body and creating the callback for RPC messages can be done once in `TransportRequestHandler`. This avoids:
    
    - code duplication related to RpcResponseCallback in every RPC handler
    - the need for every new request handler to release the request body; this is always done in `TransportRequestHandler`.
    
    Please note that this mirrors how Apache Spark handles it. With SASL authentication, we will add a `SaslRpcHandler` (https://github.com/apache/incubator-celeborn/pull/2105) that wraps the underlying message handler.
    
    ### Why are the changes needed?
    
    These changes are needed to add authentication to Celeborn. See [CELEBORN-1011](https://issues.apache.org/jira/browse/CELEBORN-1011).
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing UTs, plus some newly added UTs.
    
    Closes #2123 from otterc/CELEBORN-1147.
    
    Authored-by: Chandni Singh <[email protected]>
    Signed-off-by: zky.zhoukeyong <[email protected]>
---
 .../plugin/flink/network/ReadClientHandler.java    |   6 ++
 .../common/network/server/BaseMessageHandler.java  |   5 +
 .../network/server/TransportRequestHandler.java    |  63 +++++++++++-
 .../celeborn/common/rpc/netty/NettyRpcEnv.scala    |  50 +++-------
 .../common/network/RpcIntegrationSuiteJ.java       |  53 ++++------
 .../server/TransportRequestHandlerSuiteJ.java      |  83 +++++++++++++++
 .../service/deploy/worker/FetchHandler.scala       | 111 +++++++++++----------
 .../service/deploy/worker/PushDataHandler.scala    |  46 +++++----
 .../service/deploy/worker/FetchHandlerSuiteJ.java  |  32 +++++-
 .../network/RequestTimeoutIntegrationSuiteJ.java   |  18 ++++
 .../storage/ChunkFetchIntegrationSuiteJ.java       |  24 ++++-
 11 files changed, 341 insertions(+), 150 deletions(-)

diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
index b55e3a618..d773edb88 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
@@ -29,6 +29,7 @@ import java.util.function.Consumer;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.celeborn.common.network.client.RpcResponseCallback;
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
 import org.apache.celeborn.common.network.protocol.BufferStreamEnd;
@@ -70,6 +71,11 @@ public class ReadClientHandler extends BaseMessageHandler {
     }
   }
 
+  @Override
+  public void receive(TransportClient client, RequestMessage msg, 
RpcResponseCallback callback) {
+    receive(client, msg);
+  }
+
   @Override
   public void receive(TransportClient client, RequestMessage msg) {
     switch (msg.type()) {
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
 
b/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
index e0f7f085d..d975dc482 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/server/BaseMessageHandler.java
@@ -17,6 +17,7 @@
 
 package org.apache.celeborn.common.network.server;
 
+import org.apache.celeborn.common.network.client.RpcResponseCallback;
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.protocol.RequestMessage;
 
@@ -27,6 +28,10 @@ public class BaseMessageHandler {
     throw new UnsupportedOperationException();
   }
 
+  public void receive(TransportClient client, RequestMessage msg, 
RpcResponseCallback callback) {
+    throw new UnsupportedOperationException();
+  }
+
   public boolean checkRegistered() {
     throw new UnsupportedOperationException();
   }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/server/TransportRequestHandler.java
 
b/common/src/main/java/org/apache/celeborn/common/network/server/TransportRequestHandler.java
index 3098d413f..fde1f56e2 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/server/TransportRequestHandler.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/server/TransportRequestHandler.java
@@ -19,6 +19,7 @@ package org.apache.celeborn.common.network.server;
 
 import java.io.IOException;
 import java.net.SocketAddress;
+import java.nio.ByteBuffer;
 
 import com.google.common.base.Throwables;
 import io.netty.channel.Channel;
@@ -26,6 +27,8 @@ import io.netty.channel.ChannelFuture;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
+import org.apache.celeborn.common.network.client.RpcResponseCallback;
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.protocol.*;
 
@@ -73,8 +76,66 @@ public class TransportRequestHandler extends 
MessageHandler<RequestMessage> {
 
   @Override
   public void handle(RequestMessage request) {
+    if (logger.isTraceEnabled()) {
+      logger.trace("Received request {} from {}", 
request.getClass().getName(), reverseClient);
+    }
     if (checkRegistered(request)) {
-      msgHandler.receive(reverseClient, request);
+      if (request instanceof RpcRequest) {
+        processRpcRequest((RpcRequest) request);
+      } else if (request instanceof OneWayMessage) {
+        processOneWayMessage((OneWayMessage) request);
+      } else {
+        processOtherMessages(request);
+      }
+    }
+  }
+
+  private void processRpcRequest(final RpcRequest req) {
+    try {
+      logger.trace("Process rpc request {}", req.requestId);
+      msgHandler.receive(
+          reverseClient,
+          req,
+          new RpcResponseCallback() {
+            @Override
+            public void onSuccess(ByteBuffer response) {
+              respond(new RpcResponse(req.requestId, new 
NioManagedBuffer(response)));
+            }
+
+            @Override
+            public void onFailure(Throwable e) {
+              respond(new RpcFailure(req.requestId, 
Throwables.getStackTraceAsString(e)));
+            }
+          });
+    } catch (Exception e) {
+      logger.error("Error while invoking handler#receive() on RPC id " + 
req.requestId, e);
+      respond(new RpcFailure(req.requestId, 
Throwables.getStackTraceAsString(e)));
+    } finally {
+      req.body().release();
+    }
+  }
+
+  private void processOneWayMessage(OneWayMessage req) {
+    try {
+      logger.trace("Process one way request");
+      msgHandler.receive(reverseClient, req);
+    } catch (Exception e) {
+      logger.error("Error while invoking handler#receive() for one-way 
message.", e);
+    } finally {
+      req.body().release();
+    }
+  }
+
+  private void processOtherMessages(RequestMessage req) {
+    try {
+      logger.trace("delegating to handler to process other request");
+      msgHandler.receive(reverseClient, req);
+    } catch (Exception e) {
+      logger.error("Error while invoking handler#receive() for other 
message.", e);
+    } finally {
+      if (req.body() != null) {
+        req.body().release();
+      }
     }
   }
 
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala 
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
index 2d467cc9b..153751808 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
@@ -36,7 +36,7 @@ import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.network.TransportContext
 import org.apache.celeborn.common.network.buffer.NioManagedBuffer
 import org.apache.celeborn.common.network.client._
-import org.apache.celeborn.common.network.protocol.{OneWayMessage => 
NOneWayMessage, RequestMessage => NRequestMessage, RpcFailure => NRpcFailure, 
RpcRequest, RpcResponse}
+import org.apache.celeborn.common.network.protocol.{RequestMessage => 
NRequestMessage, RpcFailure => NRpcFailure, RpcRequest}
 import org.apache.celeborn.common.network.server._
 import org.apache.celeborn.common.protocol.{RpcNameConstants, 
TransportModuleConstants}
 import org.apache.celeborn.common.rpc._
@@ -530,53 +530,29 @@ private[celeborn] class NettyRpcHandler(
   override def receive(
       client: TransportClient,
       requestMessage: NRequestMessage): Unit = {
-    requestMessage match {
-      case r: RpcRequest =>
-        processRpc(client, r)
-      case r: NOneWayMessage =>
-        processOnewayMessage(client, r)
-    }
-  }
-
-  private def processRpc(client: TransportClient, r: RpcRequest): Unit = {
-    val callback = new RpcResponseCallback {
-      override def onSuccess(response: ByteBuffer): Unit = {
-        client.getChannel.writeAndFlush(new RpcResponse(
-          r.requestId,
-          new NioManagedBuffer(response)))
-      }
-
-      override def onFailure(e: Throwable): Unit = {
-        client.getChannel.writeAndFlush(new NRpcFailure(
-          r.requestId,
-          Throwables.getStackTraceAsString(e)))
-      }
-    }
     try {
-      val message = r.body().nioByteBuffer()
+      val message = requestMessage.body().nioByteBuffer()
       val messageToDispatch = internalReceive(client, message)
-      dispatcher.postRemoteMessage(messageToDispatch, callback)
+      dispatcher.postOneWayMessage(messageToDispatch)
     } catch {
       case e: Exception =>
-        logError("Error while invoking RpcHandler#receive() on RPC id " + 
r.requestId, e)
-        client.getChannel.writeAndFlush(new NRpcFailure(
-          r.requestId,
-          Throwables.getStackTraceAsString(e)))
-    } finally {
-      r.body().release()
+        logError("Error while invoking NettyRpcHandler#receive() for one-way 
message.", e)
     }
   }
 
-  private def processOnewayMessage(client: TransportClient, r: 
NOneWayMessage): Unit = {
+  override def receive(
+      client: TransportClient,
+      requestMessage: NRequestMessage,
+      callback: RpcResponseCallback): Unit = {
     try {
-      val message = r.body().nioByteBuffer()
+      val message = requestMessage.body().nioByteBuffer()
       val messageToDispatch = internalReceive(client, message)
-      dispatcher.postOneWayMessage(messageToDispatch)
+      dispatcher.postRemoteMessage(messageToDispatch, callback)
     } catch {
       case e: Exception =>
-        logError("Error while invoking RpcHandler#receive() for one-way 
message.", e)
-    } finally {
-      r.body().release()
+        val rpcReq = requestMessage.asInstanceOf[RpcRequest]
+        logError("Error while invoking NettyRpcHandler#receive() on RPC id " + 
rpcReq.requestId, e)
+        callback.onFailure(e)
     }
   }
 
diff --git 
a/common/src/test/java/org/apache/celeborn/common/network/RpcIntegrationSuiteJ.java
 
b/common/src/test/java/org/apache/celeborn/common/network/RpcIntegrationSuiteJ.java
index e3748cae6..ec1834b4d 100644
--- 
a/common/src/test/java/org/apache/celeborn/common/network/RpcIntegrationSuiteJ.java
+++ 
b/common/src/test/java/org/apache/celeborn/common/network/RpcIntegrationSuiteJ.java
@@ -25,7 +25,6 @@ import java.util.*;
 import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
 
-import com.google.common.base.Throwables;
 import com.google.common.collect.Sets;
 import org.apache.commons.lang3.tuple.ImmutablePair;
 import org.apache.commons.lang3.tuple.Pair;
@@ -34,7 +33,6 @@ import org.junit.BeforeClass;
 import org.junit.Test;
 
 import org.apache.celeborn.common.CelebornConf;
-import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
 import org.apache.celeborn.common.network.client.RpcResponseCallback;
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
@@ -60,32 +58,21 @@ public class RpcIntegrationSuiteJ {
         new BaseMessageHandler() {
           @Override
           public void receive(TransportClient client, RequestMessage message) {
-            if (message instanceof RpcRequest) {
-              String msg;
-              RpcRequest r = (RpcRequest) message;
-              RpcResponseCallback callback =
-                  new RpcResponseCallback() {
-                    @Override
-                    public void onSuccess(ByteBuffer response) {
-                      client
-                          .getChannel()
-                          .writeAndFlush(
-                              new RpcResponse(r.requestId, new 
NioManagedBuffer(response)));
-                    }
-
-                    @Override
-                    public void onFailure(Throwable e) {
-                      client
-                          .getChannel()
-                          .writeAndFlush(
-                              new RpcFailure(r.requestId, 
Throwables.getStackTraceAsString(e)));
-                    }
-                  };
-              try {
-                msg = JavaUtils.bytesToString(message.body().nioByteBuffer());
-              } catch (Exception e) {
-                throw new RuntimeException(e);
-              }
+            assertTrue(message instanceof OneWayMessage);
+            String msg;
+            try {
+              msg = JavaUtils.bytesToString(message.body().nioByteBuffer());
+            } catch (Exception e) {
+              throw new RuntimeException(e);
+            }
+            oneWayMsgs.add(msg);
+          }
+
+          @Override
+          public void receive(
+              TransportClient client, RequestMessage requestMessage, 
RpcResponseCallback callback) {
+            try {
+              String msg = 
JavaUtils.bytesToString(requestMessage.body().nioByteBuffer());
               String[] parts = msg.split("/");
               if (parts[0].equals("hello")) {
                 callback.onSuccess(JavaUtils.stringToBytes("Hello, " + 
parts[1] + "!"));
@@ -94,14 +81,8 @@ public class RpcIntegrationSuiteJ {
               } else if (parts[0].equals("throw error")) {
                 callback.onFailure(new RuntimeException("Thrown: " + 
parts[1]));
               }
-            } else if (message instanceof OneWayMessage) {
-              String msg;
-              try {
-                msg = JavaUtils.bytesToString(message.body().nioByteBuffer());
-              } catch (Exception e) {
-                throw new RuntimeException(e);
-              }
-              oneWayMsgs.add(msg);
+            } catch (Exception e) {
+              throw new RuntimeException(e);
             }
           }
 
diff --git 
a/common/src/test/java/org/apache/celeborn/common/network/server/TransportRequestHandlerSuiteJ.java
 
b/common/src/test/java/org/apache/celeborn/common/network/server/TransportRequestHandlerSuiteJ.java
new file mode 100644
index 000000000..e1baafd2e
--- /dev/null
+++ 
b/common/src/test/java/org/apache/celeborn/common/network/server/TransportRequestHandlerSuiteJ.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.server;
+
+import static org.mockito.Mockito.*;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
+import org.apache.celeborn.common.network.client.TransportClient;
+import org.apache.celeborn.common.network.protocol.OneWayMessage;
+import org.apache.celeborn.common.network.protocol.PushData;
+import org.apache.celeborn.common.network.protocol.RpcRequest;
+
+public class TransportRequestHandlerSuiteJ {
+
+  @Mock private Channel channel;
+
+  @Mock private TransportClient reverseClient;
+
+  @Mock private BaseMessageHandler msgHandler;
+
+  private TransportRequestHandler requestHandler;
+
+  @Before
+  public void setUp() {
+    MockitoAnnotations.openMocks(this);
+    when(msgHandler.checkRegistered()).thenReturn(true);
+    requestHandler = new TransportRequestHandler(channel, reverseClient, 
msgHandler);
+  }
+
+  @Test
+  public void testHandleRpcRequest() {
+    ByteBuf buffer = Unpooled.wrappedBuffer(new byte[] {1});
+    RpcRequest rpcRequest = new RpcRequest(1, new NettyManagedBuffer(buffer));
+    requestHandler.handle(rpcRequest);
+    verify(msgHandler).receive(eq(reverseClient), eq(rpcRequest), any());
+    verify(msgHandler, times(0)).receive(eq(reverseClient), eq(rpcRequest));
+    assert buffer.refCnt() == 0;
+  }
+
+  @Test
+  public void testHandleOneWayMessage() {
+    ByteBuf buffer = Unpooled.wrappedBuffer(new byte[] {1});
+    OneWayMessage oneWayMessage = new OneWayMessage(new 
NettyManagedBuffer(buffer));
+    requestHandler.handle(oneWayMessage);
+    verify(msgHandler).receive(eq(reverseClient), eq(oneWayMessage));
+    verify(msgHandler, times(0)).receive(eq(reverseClient), eq(oneWayMessage), 
any());
+    assert buffer.refCnt() == 0;
+  }
+
+  @Test
+  public void testHandleOtherMessage() {
+    ByteBuf buffer = Unpooled.wrappedBuffer(new byte[] {1});
+    PushData pushData =
+        new PushData((byte) 0, "shuffleKey", "partitionId", new 
NettyManagedBuffer(buffer));
+    requestHandler.handle(pushData);
+    verify(msgHandler).receive(eq(reverseClient), eq(pushData));
+    verify(msgHandler, times(0)).receive(eq(reverseClient), eq(pushData), 
any());
+    assert buffer.refCnt() == 0;
+  }
+}
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
index f28927fe7..fa163d4f6 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
@@ -33,7 +33,7 @@ import 
org.apache.celeborn.common.exception.CelebornIOException
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.meta.{FileInfo, FileManagedBuffers}
 import org.apache.celeborn.common.network.buffer.NioManagedBuffer
-import org.apache.celeborn.common.network.client.TransportClient
+import org.apache.celeborn.common.network.client.{RpcResponseCallback, 
TransportClient}
 import org.apache.celeborn.common.network.protocol._
 import org.apache.celeborn.common.network.server.BaseMessageHandler
 import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf}
@@ -87,6 +87,13 @@ class FetchHandler(
     fileInfo
   }
 
+  override def receive(
+      client: TransportClient,
+      msg: RequestMessage,
+      callback: RpcResponseCallback): Unit = {
+    handleRpcRequest(client, msg.asInstanceOf[RpcRequest], callback)
+  }
+
   override def receive(client: TransportClient, msg: RequestMessage): Unit = {
     msg match {
       case r: BufferStreamEnd =>
@@ -95,54 +102,55 @@ class FetchHandler(
         handleReadAddCredit(r.getCredit, r.getStreamId)
       case r: ChunkFetchRequest =>
         handleChunkFetchRequest(client, r.streamChunkSlice, r)
-      case r: RpcRequest =>
-        handleRpcRequest(client, r)
       case unknown: RequestMessage =>
         throw new IllegalArgumentException(s"Unknown message type id: 
${unknown.`type`.id}")
     }
   }
 
-  private def handleRpcRequest(client: TransportClient, rpcRequest: 
RpcRequest): Unit = {
+  private def handleRpcRequest(
+      client: TransportClient,
+      rpcRequest: RpcRequest,
+      callback: RpcResponseCallback): Unit = {
+    var message: GeneratedMessageV3 = null
     try {
-      var message: GeneratedMessageV3 = null
-      try {
-        message = 
TransportMessage.fromByteBuffer(rpcRequest.body().nioByteBuffer())
-          .getParsedPayload[GeneratedMessageV3]
-      } catch {
-        case exception: CelebornIOException =>
-          logWarning("Handle request with legacy RPCs", exception)
-          return handleLegacyRpcMessage(client, rpcRequest)
-      }
-      message match {
-        case openStream: PbOpenStream =>
-          handleOpenStreamInternal(
-            client,
-            openStream.getShuffleKey,
-            openStream.getFileName,
-            openStream.getStartIndex,
-            openStream.getEndIndex,
-            openStream.getInitialCredit,
-            rpcRequest.requestId,
-            isLegacy = false,
-            openStream.getReadLocalShuffle)
-        case bufferStreamEnd: PbBufferStreamEnd =>
-          handleEndStreamFromClient(bufferStreamEnd.getStreamId, 
bufferStreamEnd.getStreamType)
-        case readAddCredit: PbReadAddCredit =>
-          handleReadAddCredit(readAddCredit.getCredit, 
readAddCredit.getStreamId)
-        case chunkFetchRequest: PbChunkFetchRequest =>
-          handleChunkFetchRequest(
-            client,
-            StreamChunkSlice.fromProto(chunkFetchRequest.getStreamChunkSlice),
-            rpcRequest)
-        case message: GeneratedMessageV3 =>
-          logError(s"Unknown message $message")
-      }
-    } finally {
-      rpcRequest.body().release()
+      message = 
TransportMessage.fromByteBuffer(rpcRequest.body().nioByteBuffer())
+        .getParsedPayload[GeneratedMessageV3]
+    } catch {
+      case exception: CelebornIOException =>
+        logWarning("Handle request with legacy RPCs", exception)
+        return handleLegacyRpcMessage(client, rpcRequest, callback)
+    }
+    message match {
+      case openStream: PbOpenStream =>
+        handleOpenStreamInternal(
+          client,
+          openStream.getShuffleKey,
+          openStream.getFileName,
+          openStream.getStartIndex,
+          openStream.getEndIndex,
+          openStream.getInitialCredit,
+          rpcRequest.requestId,
+          isLegacy = false,
+          openStream.getReadLocalShuffle,
+          callback)
+      case bufferStreamEnd: PbBufferStreamEnd =>
+        handleEndStreamFromClient(bufferStreamEnd.getStreamId, 
bufferStreamEnd.getStreamType)
+      case readAddCredit: PbReadAddCredit =>
+        handleReadAddCredit(readAddCredit.getCredit, readAddCredit.getStreamId)
+      case chunkFetchRequest: PbChunkFetchRequest =>
+        handleChunkFetchRequest(
+          client,
+          StreamChunkSlice.fromProto(chunkFetchRequest.getStreamChunkSlice),
+          rpcRequest)
+      case message: GeneratedMessageV3 =>
+        logError(s"Unknown message $message")
     }
   }
 
-  private def handleLegacyRpcMessage(client: TransportClient, rpcRequest: 
RpcRequest): Unit = {
+  private def handleLegacyRpcMessage(
+      client: TransportClient,
+      rpcRequest: RpcRequest,
+      callback: RpcResponseCallback): Unit = {
     try {
       val message = Message.decode(rpcRequest.body().nioByteBuffer())
       message.`type`() match {
@@ -158,7 +166,8 @@ class FetchHandler(
             rpcRequestId = rpcRequest.requestId,
             isLegacy = true,
             // legacy [[OpenStream]] doesn't support read local shuffle
-            readLocalShuffle = false)
+            readLocalShuffle = false,
+            callback)
         case Message.Type.OPEN_STREAM_WITH_CREDIT =>
           val openStreamWithCredit = message.asInstanceOf[OpenStreamWithCredit]
           handleOpenStreamInternal(
@@ -170,7 +179,8 @@ class FetchHandler(
             openStreamWithCredit.initialCredit,
             rpcRequestId = rpcRequest.requestId,
             isLegacy = true,
-            readLocalShuffle = false)
+            readLocalShuffle = false,
+            callback)
         case _ =>
           logError(s"Received an unknown message type id: 
${message.`type`.id}")
       }
@@ -190,7 +200,8 @@ class FetchHandler(
       initialCredit: Int,
       rpcRequestId: Long,
       isLegacy: Boolean,
-      readLocalShuffle: Boolean = false): Unit = {
+      readLocalShuffle: Boolean = false,
+      callback: RpcResponseCallback): Unit = {
     workerSource.startTimer(WorkerSource.OPEN_STREAM_TIME, shuffleKey)
     try {
       var fileInfo = getRawFileInfo(shuffleKey, fileName)
@@ -263,7 +274,7 @@ class FetchHandler(
       }
     } catch {
       case e: IOException =>
-        handleRpcIOException(client, rpcRequestId, shuffleKey, fileName, e)
+        handleRpcIOException(client, rpcRequestId, shuffleKey, fileName, e, 
callback)
     } finally {
       workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, shuffleKey)
     }
@@ -304,23 +315,23 @@ class FetchHandler(
       requestId: Long,
       shuffleKey: String,
       fileName: String,
-      ioe: IOException): Unit = {
+      ioe: IOException,
+      rpcCallback: RpcResponseCallback): Unit = {
     // if open stream rpc failed, this IOException actually should be 
FileNotFoundException
     // we wrapper this IOException(Other place may have other exception like 
FileCorruptException) unify to
     // PartitionUnRetryableException for reader can give up this partition and 
choose to regenerate the partition data
     logError(
       s"Read file: $fileName with shuffleKey: $shuffleKey error from 
${NettyUtils.getRemoteAddress(client.getChannel)}",
       ioe)
-    handleRpcException(client, requestId, ioe)
+    handleRpcException(client, requestId, ioe, rpcCallback)
   }
 
   private def handleRpcException(
       client: TransportClient,
       requestId: Long,
-      ioe: IOException): Unit = {
-    client.getChannel.writeAndFlush(new RpcFailure(
-      requestId,
-      
Throwables.getStackTraceAsString(ExceptionUtils.wrapIOExceptionToUnRetryable(ioe))))
+      ioe: IOException,
+      rpcResponseCallback: RpcResponseCallback): Unit = {
+    
rpcResponseCallback.onFailure(ExceptionUtils.wrapIOExceptionToUnRetryable(ioe))
   }
 
   def handleEndStreamFromClient(streamId: Long): Unit = {
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
index 208ff8bca..9beca4c00 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
@@ -100,18 +100,25 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
       s"diskReserveSize ${Utils.bytesToString(diskReserveSize)}, 
diskReserveRatio ${diskReserveRatio.orNull}")
   }
 
+  override def receive(
+      client: TransportClient,
+      msg: RequestMessage,
+      callback: RpcResponseCallback): Unit = {
+    handleRpcRequest(client, msg.asInstanceOf[RpcRequest], callback)
+  }
+
   override def receive(client: TransportClient, msg: RequestMessage): Unit =
     msg match {
       case pushData: PushData =>
+        val callback = new SimpleRpcResponseCallback(
+          client,
+          pushData.requestId,
+          pushData.shuffleKey)
         handleCore(
           client,
           pushData,
           pushData.requestId,
           () => {
-            val callback = new SimpleRpcResponseCallback(
-              client,
-              pushData.requestId,
-              pushData.shuffleKey)
             val partitionType =
               shufflePartitionType.getOrDefault(pushData.shuffleKey, 
PartitionType.REDUCE)
             partitionType match {
@@ -123,8 +130,13 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
                   callback)
               case _ => throw new UnsupportedOperationException(s"Not support 
$partitionType yet")
             }
-          })
+          },
+          callback)
       case pushMergedData: PushMergedData =>
+        val callback = new SimpleRpcResponseCallback(
+          client,
+          pushMergedData.requestId,
+          pushMergedData.shuffleKey)
         handleCore(
           client,
           pushMergedData,
@@ -132,11 +144,8 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
           () =>
             handlePushMergedData(
               pushMergedData,
-              new SimpleRpcResponseCallback(
-                client,
-                pushMergedData.requestId,
-                pushMergedData.shuffleKey)))
-      case rpcRequest: RpcRequest => handleRpcRequest(client, rpcRequest)
+              callback),
+          callback)
     }
 
   def handlePushData(pushData: PushData, callback: RpcResponseCallback): Unit 
= {
@@ -726,17 +735,14 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
       client: TransportClient,
       message: RequestMessage,
       requestId: Long,
-      handler: () => Unit): Unit = {
+      handler: () => Unit,
+      callback: RpcResponseCallback): Unit = {
     try {
       handler()
     } catch {
       case e: Exception =>
         logError(s"Error while handle${message.`type`()} $message", e)
-        client.getChannel.writeAndFlush(new RpcFailure(
-          requestId,
-          Throwables.getStackTraceAsString(e)))
-    } finally {
-      message.body().release()
+        callback.onFailure(e)
     }
   }
 
@@ -812,7 +818,10 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
     }
   }
 
-  private def handleRpcRequest(client: TransportClient, rpcRequest: 
RpcRequest): Unit = {
+  private def handleRpcRequest(
+      client: TransportClient,
+      rpcRequest: RpcRequest,
+      callback: RpcResponseCallback): Unit = {
     val requestId = rpcRequest.requestId
     val (pbMsg, msg, isLegacy, messageType, mode, shuffleKey, 
partitionUniqueId, checkSplit) =
       mapPartitionRpcRequest(rpcRequest)
@@ -834,7 +843,8 @@ class PushDataHandler(val workerSource: WorkerSource) 
extends BaseMessageHandler
           new SimpleRpcResponseCallback(
             client,
             requestId,
-            shuffleKey)))
+            shuffleKey)),
+      callback)
   }
 
   private def mapPartitionRpcRequest(rpcRequest: RpcRequest)
diff --git 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java
 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java
index 7da20b2ea..504aae14b 100644
--- 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java
+++ 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java
@@ -33,6 +33,8 @@ import java.util.Map;
 import java.util.Random;
 import java.util.UUID;
 
+import com.google.common.base.Throwables;
+import io.netty.channel.Channel;
 import io.netty.channel.embedded.EmbeddedChannel;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
@@ -45,11 +47,13 @@ import org.apache.celeborn.common.CelebornConf;
 import org.apache.celeborn.common.identity.UserIdentifier;
 import org.apache.celeborn.common.meta.FileInfo;
 import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
+import org.apache.celeborn.common.network.client.RpcResponseCallback;
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.client.TransportResponseHandler;
 import org.apache.celeborn.common.network.protocol.ChunkFetchSuccess;
 import org.apache.celeborn.common.network.protocol.Message;
 import org.apache.celeborn.common.network.protocol.OpenStream;
+import org.apache.celeborn.common.network.protocol.RpcFailure;
 import org.apache.celeborn.common.network.protocol.RpcRequest;
 import org.apache.celeborn.common.network.protocol.RpcResponse;
 import org.apache.celeborn.common.network.protocol.StreamHandle;
@@ -289,7 +293,9 @@ public class FetchHandlerSuiteJ {
     ByteBuffer openStreamByteBuffer =
         new OpenStream(shuffleKey, fileName, startIndex, 
endIndex).toByteBuffer();
     fetchHandler.receive(
-        client, new RpcRequest(dummyRequestId, new 
NioManagedBuffer(openStreamByteBuffer)));
+        client,
+        new RpcRequest(dummyRequestId, new 
NioManagedBuffer(openStreamByteBuffer)),
+        createRpcResponseCallback(channel));
     RpcResponse result = channel.readOutbound();
     StreamHandle streamHandler = (StreamHandle) 
Message.decode(result.body().nioByteBuffer());
     if (endIndex == Integer.MAX_VALUE) {
@@ -318,7 +324,9 @@ public class FetchHandlerSuiteJ {
                     .toByteArray())
             .toByteBuffer();
     fetchHandler.receive(
-        client, new RpcRequest(dummyRequestId, new 
NioManagedBuffer(openStreamByteBuffer)));
+        client,
+        new RpcRequest(dummyRequestId, new 
NioManagedBuffer(openStreamByteBuffer)),
+        createRpcResponseCallback(channel));
     RpcResponse result = channel.readOutbound();
     PbStreamHandler streamHandler =
         
TransportMessage.fromByteBuffer(result.body().nioByteBuffer()).getParsedPayload();
@@ -352,7 +360,8 @@ public class FetchHandlerSuiteJ {
                                       .setLen(Integer.MAX_VALUE))
                               .build()
                               .toByteArray())
-                      .toByteBuffer())));
+                      .toByteBuffer())),
+          createRpcResponseCallback(channel));
       ChunkFetchSuccess chunkFetchSuccess = channel.readOutbound();
       chunkFetchSuccess.body().retain();
       // chunk size 8m
@@ -372,7 +381,8 @@ public class FetchHandlerSuiteJ {
                 .toByteArray());
     fetchHandler.receive(
         client,
-        new RpcRequest(dummyRequestId, new 
NioManagedBuffer(bufferStreamEnd.toByteBuffer())));
+        new RpcRequest(dummyRequestId, new 
NioManagedBuffer(bufferStreamEnd.toByteBuffer())),
+        createRpcResponseCallback(client.getChannel()));
   }
 
   private void checkOriginFileBeDeleted(FileInfo fileInfo) {
@@ -403,4 +413,18 @@ public class FetchHandlerSuiteJ {
     Collections.shuffle(ids);
     return ids;
   }
+
+  private RpcResponseCallback createRpcResponseCallback(Channel channel) { // Builds the callback passed to fetchHandler.receive in these tests; results are written to `channel` tagged with dummyRequestId.
+    return new RpcResponseCallback() {
+      @Override
+      public void onSuccess(ByteBuffer response) { // wraps the response bytes in a NioManagedBuffer inside an RpcResponse
+        channel.writeAndFlush(new RpcResponse(dummyRequestId, new 
NioManagedBuffer(response)));
+      }
+
+      @Override
+      public void onFailure(Throwable e) { // sends the exception's full stack trace text as an RpcFailure
+        channel.writeAndFlush(new RpcFailure(dummyRequestId, 
Throwables.getStackTraceAsString(e)));
+      }
+    };
+  }
 }
diff --git 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java
 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java
index 85ce86367..fb083d392 100644
--- 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java
+++ 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java
@@ -107,6 +107,12 @@ public class RequestTimeoutIntegrationSuiteJ {
             }
           }
 
+          @Override
+          public void receive(
+              TransportClient client, RequestMessage msg, RpcResponseCallback 
callback) {
+            receive(client, msg); // callback is not used: delegates to the two-arg receive(client, msg)
+          }
+
           @Override
           public boolean checkRegistered() {
             return true;
@@ -157,6 +163,12 @@ public class RequestTimeoutIntegrationSuiteJ {
             }
           }
 
+          @Override
+          public void receive(
+              TransportClient client, RequestMessage msg, RpcResponseCallback 
callback) {
+            receive(client, msg); // callback is not used: delegates to the two-arg receive(client, msg)
+          }
+
           @Override
           public boolean checkRegistered() {
             return true;
@@ -216,6 +228,12 @@ public class RequestTimeoutIntegrationSuiteJ {
             client.getChannel().writeAndFlush(new ChunkFetchSuccess(slice, 
buf));
           }
 
+          @Override
+          public void receive(
+              TransportClient client, RequestMessage msg, RpcResponseCallback 
callback) {
+            receive(client, msg); // callback is not used: delegates to the two-arg receive, which writes ChunkFetchSuccess to the client channel
+          }
+
           @Override
           public boolean checkRegistered() {
             return true;
diff --git 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java
 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java
index 6a1df6747..36e12ca4c 100644
--- 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java
+++ 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java
@@ -28,6 +28,7 @@ import java.util.*;
 import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
 
+import com.google.common.base.Throwables;
 import com.google.common.collect.Sets;
 import com.google.common.io.Closeables;
 import org.junit.AfterClass;
@@ -40,8 +41,10 @@ import 
org.apache.celeborn.common.network.buffer.FileSegmentManagedBuffer;
 import org.apache.celeborn.common.network.buffer.ManagedBuffer;
 import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
 import org.apache.celeborn.common.network.client.ChunkReceivedCallback;
+import org.apache.celeborn.common.network.client.RpcResponseCallback;
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
+import org.apache.celeborn.common.network.protocol.ChunkFetchFailure;
 import org.apache.celeborn.common.network.protocol.ChunkFetchSuccess;
 import org.apache.celeborn.common.network.protocol.RequestMessage;
 import org.apache.celeborn.common.network.protocol.StreamChunkSlice;
@@ -117,10 +120,23 @@ public class ChunkFetchIntegrationSuiteJ {
             }
             StreamChunkSlice slice =
                 
StreamChunkSlice.fromProto(chunkFetchRequest.getStreamChunkSlice());
-            ManagedBuffer buf =
-                chunkStreamManager.getChunk(
-                    slice.streamId, slice.chunkIndex, slice.offset, slice.len);
-            client.getChannel().writeAndFlush(new ChunkFetchSuccess(slice, 
buf));
+            ManagedBuffer buf;
+            try {
+              buf =
+                  chunkStreamManager.getChunk(
+                      slice.streamId, slice.chunkIndex, slice.offset, 
slice.len);
+              client.getChannel().writeAndFlush(new ChunkFetchSuccess(slice, 
buf));
+            } catch (Exception e) {
+              client
+                  .getChannel()
+                  .writeAndFlush(new ChunkFetchFailure(slice, 
Throwables.getStackTraceAsString(e)));
+            }
+          }
+
+          @Override
+          public void receive(
+              TransportClient client, RequestMessage msg, RpcResponseCallback 
callback) {
+            receive(client, msg); // callback is not used: the two-arg receive writes ChunkFetchSuccess/ChunkFetchFailure to the client channel itself
           }
 
           @Override


Reply via email to