akpatnam25 commented on code in PR #2123:
URL:
https://github.com/apache/incubator-celeborn/pull/2123#discussion_r1412313514
##########
common/src/main/java/org/apache/celeborn/common/network/server/TransportRequestHandler.java:
##########
@@ -73,8 +76,64 @@ public void channelInactive() {
@Override
public void handle(RequestMessage request) {
+ logger.trace("Received request {} from {}", request.getClass().getName(),
reverseClient);
if (checkRegistered(request)) {
- msgHandler.receive(reverseClient, request);
+ if (request instanceof RpcRequest) {
Review Comment:
Should we call `request.body().retain()` before processing the request?
##########
worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala:
##########
@@ -726,17 +735,14 @@ class PushDataHandler(val workerSource: WorkerSource)
extends BaseMessageHandler
client: TransportClient,
message: RequestMessage,
requestId: Long,
- handler: () => Unit): Unit = {
+ handler: () => Unit,
+ callback: RpcResponseCallback): Unit = {
try {
handler()
} catch {
case e: Exception =>
logError(s"Error while handle${message.`type`()} $message", e)
- client.getChannel.writeAndFlush(new RpcFailure(
- requestId,
- Throwables.getStackTraceAsString(e)))
- } finally {
- message.body().release()
Review Comment:
Why were these lines deleted? The `message.body().release()` call is needed to keep the buffer's reference count balanced,
right?
##########
common/src/test/java/org/apache/celeborn/common/network/server/TransportRequestHandlerSuiteJ.java:
##########
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.server;
+
+import static org.mockito.Mockito.*;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
+import org.apache.celeborn.common.network.client.TransportClient;
+import org.apache.celeborn.common.network.protocol.OneWayMessage;
+import org.apache.celeborn.common.network.protocol.PushData;
+import org.apache.celeborn.common.network.protocol.RpcRequest;
+
+public class TransportRequestHandlerSuiteJ {
+
+ @Mock private Channel channel;
+
+ @Mock private TransportClient reverseClient;
+
+ @Mock private BaseMessageHandler msgHandler;
+
+ private TransportRequestHandler requestHandler;
+
+ @Before
+ public void setUp() {
+ MockitoAnnotations.openMocks(this);
+ when(msgHandler.checkRegistered()).thenReturn(true);
+ requestHandler = new TransportRequestHandler(channel, reverseClient,
msgHandler);
+ }
+
+ @Test
+ public void testHandleRpcRequest() {
+ ByteBuf buffer = Unpooled.wrappedBuffer(new byte[] {1});
+ RpcRequest rpcRequest = new RpcRequest(1, new NettyManagedBuffer(buffer));
+ requestHandler.handle(rpcRequest);
+ verify(msgHandler).receive(eq(reverseClient), eq(rpcRequest), any());
+ verify(msgHandler, times(0)).receive(eq(reverseClient), eq(rpcRequest));
+ assert buffer.refCnt() == 0;
+ }
+
+ @Test
+ public void testHandleOneWayMessage() {
+ when(msgHandler.checkRegistered()).thenReturn(true);
Review Comment:
Isn't this already stubbed on L49 in the `setUp` method? The same applies to L65.
##########
common/src/main/java/org/apache/celeborn/common/network/server/TransportRequestHandler.java:
##########
@@ -73,8 +76,64 @@ public void channelInactive() {
@Override
public void handle(RequestMessage request) {
+ logger.trace("Received request {} from {}", request.getClass().getName(),
reverseClient);
if (checkRegistered(request)) {
- msgHandler.receive(reverseClient, request);
+ if (request instanceof RpcRequest) {
+ processRpcRequest((RpcRequest) request);
+ } else if (request instanceof OneWayMessage) {
+ processOneWayMessage((OneWayMessage) request);
+ } else {
+ processOtherMessages(request);
+ }
+ }
+ }
+
+ private void processRpcRequest(final RpcRequest req) {
+ try {
+ logger.trace("Process rpc request {}", req.requestId);
+ msgHandler.receive(
+ reverseClient,
+ req,
+ new RpcResponseCallback() {
+ @Override
+ public void onSuccess(ByteBuffer response) {
+ respond(new RpcResponse(req.requestId, new
NioManagedBuffer(response)));
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ respond(new RpcFailure(req.requestId,
Throwables.getStackTraceAsString(e)));
+ }
+ });
+ } catch (Exception e) {
Review Comment:
Should we catch a more specific exception type here instead? The same applies
to the other places that catch `Exception`.
##########
common/src/main/java/org/apache/celeborn/common/network/server/TransportRequestHandler.java:
##########
@@ -73,8 +76,64 @@ public void channelInactive() {
@Override
public void handle(RequestMessage request) {
+ logger.trace("Received request {} from {}", request.getClass().getName(),
reverseClient);
if (checkRegistered(request)) {
- msgHandler.receive(reverseClient, request);
+ if (request instanceof RpcRequest) {
+ processRpcRequest((RpcRequest) request);
+ } else if (request instanceof OneWayMessage) {
+ processOneWayMessage((OneWayMessage) request);
+ } else {
+ processOtherMessages(request);
+ }
+ }
+ }
+
+ private void processRpcRequest(final RpcRequest req) {
+ try {
+ logger.trace("Process rpc request {}", req.requestId);
+ msgHandler.receive(
+ reverseClient,
+ req,
+ new RpcResponseCallback() {
+ @Override
+ public void onSuccess(ByteBuffer response) {
+ respond(new RpcResponse(req.requestId, new
NioManagedBuffer(response)));
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ respond(new RpcFailure(req.requestId,
Throwables.getStackTraceAsString(e)));
+ }
+ });
+ } catch (Exception e) {
+ logger.error("Error while invoking handler#receive() on RPC id " +
req.requestId, e);
+ respond(new RpcFailure(req.requestId,
Throwables.getStackTraceAsString(e)));
+ } finally {
+ req.body().release();
Review Comment:
Don't we need a null check here before releasing the body?
```
if (req.body() != null) {
```
Same with L123.
##########
common/src/main/java/org/apache/celeborn/common/network/server/TransportRequestHandler.java:
##########
@@ -73,8 +76,64 @@ public void channelInactive() {
@Override
public void handle(RequestMessage request) {
+ logger.trace("Received request {} from {}", request.getClass().getName(),
reverseClient);
if (checkRegistered(request)) {
- msgHandler.receive(reverseClient, request);
+ if (request instanceof RpcRequest) {
+ processRpcRequest((RpcRequest) request);
+ } else if (request instanceof OneWayMessage) {
+ processOneWayMessage((OneWayMessage) request);
+ } else {
+ processOtherMessages(request);
+ }
+ }
+ }
+
+ private void processRpcRequest(final RpcRequest req) {
+ try {
+ logger.trace("Process rpc request {}", req.requestId);
+ msgHandler.receive(
+ reverseClient,
+ req,
+ new RpcResponseCallback() {
+ @Override
+ public void onSuccess(ByteBuffer response) {
+ respond(new RpcResponse(req.requestId, new
NioManagedBuffer(response)));
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ respond(new RpcFailure(req.requestId,
Throwables.getStackTraceAsString(e)));
+ }
+ });
+ } catch (Exception e) {
+ logger.error("Error while invoking handler#receive() on RPC id " +
req.requestId, e);
+ respond(new RpcFailure(req.requestId,
Throwables.getStackTraceAsString(e)));
Review Comment:
Should we wrap the exception in a `CelebornIOException` before responding?
##########
worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala:
##########
@@ -95,54 +102,55 @@ class FetchHandler(
handleReadAddCredit(r.getCredit, r.getStreamId)
case r: ChunkFetchRequest =>
handleChunkFetchRequest(client, r.streamChunkSlice, r)
- case r: RpcRequest =>
- handleRpcRequest(client, r)
case unknown: RequestMessage =>
throw new IllegalArgumentException(s"Unknown message type id:
${unknown.`type`.id}")
}
}
- private def handleRpcRequest(client: TransportClient, rpcRequest:
RpcRequest): Unit = {
+ private def handleRpcRequest(
+ client: TransportClient,
+ rpcRequest: RpcRequest,
+ callback: RpcResponseCallback): Unit = {
+ var message: GeneratedMessageV3 = null
try {
- var message: GeneratedMessageV3 = null
- try {
- message =
TransportMessage.fromByteBuffer(rpcRequest.body().nioByteBuffer())
- .getParsedPayload[GeneratedMessageV3]
- } catch {
- case exception: CelebornIOException =>
- logWarning("Handle request with legacy RPCs", exception)
- return handleLegacyRpcMessage(client, rpcRequest)
- }
- message match {
- case openStream: PbOpenStream =>
- handleOpenStreamInternal(
- client,
- openStream.getShuffleKey,
- openStream.getFileName,
- openStream.getStartIndex,
- openStream.getEndIndex,
- openStream.getInitialCredit,
- rpcRequest.requestId,
- isLegacy = false,
- openStream.getReadLocalShuffle)
- case bufferStreamEnd: PbBufferStreamEnd =>
- handleEndStreamFromClient(bufferStreamEnd.getStreamId,
bufferStreamEnd.getStreamType)
- case readAddCredit: PbReadAddCredit =>
- handleReadAddCredit(readAddCredit.getCredit,
readAddCredit.getStreamId)
- case chunkFetchRequest: PbChunkFetchRequest =>
- handleChunkFetchRequest(
- client,
- StreamChunkSlice.fromProto(chunkFetchRequest.getStreamChunkSlice),
- rpcRequest)
- case message: GeneratedMessageV3 =>
- logError(s"Unknown message $message")
- }
- } finally {
- rpcRequest.body().release()
+ message =
TransportMessage.fromByteBuffer(rpcRequest.body().nioByteBuffer())
+ .getParsedPayload[GeneratedMessageV3]
+ } catch {
+ case exception: CelebornIOException =>
+ logWarning("Handle request with legacy RPCs", exception)
+ return handleLegacyRpcMessage(client, rpcRequest, callback)
+ }
+ message match {
+ case openStream: PbOpenStream =>
+ handleOpenStreamInternal(
+ client,
+ openStream.getShuffleKey,
+ openStream.getFileName,
+ openStream.getStartIndex,
+ openStream.getEndIndex,
+ openStream.getInitialCredit,
+ rpcRequest.requestId,
+ isLegacy = false,
+ openStream.getReadLocalShuffle,
+ callback)
+ case bufferStreamEnd: PbBufferStreamEnd =>
+ handleEndStreamFromClient(bufferStreamEnd.getStreamId,
bufferStreamEnd.getStreamType)
+ case readAddCredit: PbReadAddCredit =>
+ handleReadAddCredit(readAddCredit.getCredit, readAddCredit.getStreamId)
+ case chunkFetchRequest: PbChunkFetchRequest =>
+ handleChunkFetchRequest(
+ client,
+ StreamChunkSlice.fromProto(chunkFetchRequest.getStreamChunkSlice),
+ rpcRequest)
+ case message: GeneratedMessageV3 =>
+ logError(s"Unknown message $message")
Review Comment:
It would probably be good to add a colon to make the message clearer:
```
Unknown message: $message
```
##########
worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala:
##########
@@ -95,54 +102,55 @@ class FetchHandler(
handleReadAddCredit(r.getCredit, r.getStreamId)
case r: ChunkFetchRequest =>
handleChunkFetchRequest(client, r.streamChunkSlice, r)
- case r: RpcRequest =>
- handleRpcRequest(client, r)
case unknown: RequestMessage =>
throw new IllegalArgumentException(s"Unknown message type id:
${unknown.`type`.id}")
}
}
- private def handleRpcRequest(client: TransportClient, rpcRequest:
RpcRequest): Unit = {
+ private def handleRpcRequest(
+ client: TransportClient,
+ rpcRequest: RpcRequest,
+ callback: RpcResponseCallback): Unit = {
+ var message: GeneratedMessageV3 = null
try {
- var message: GeneratedMessageV3 = null
- try {
- message =
TransportMessage.fromByteBuffer(rpcRequest.body().nioByteBuffer())
- .getParsedPayload[GeneratedMessageV3]
- } catch {
- case exception: CelebornIOException =>
- logWarning("Handle request with legacy RPCs", exception)
- return handleLegacyRpcMessage(client, rpcRequest)
- }
- message match {
- case openStream: PbOpenStream =>
- handleOpenStreamInternal(
- client,
- openStream.getShuffleKey,
- openStream.getFileName,
- openStream.getStartIndex,
- openStream.getEndIndex,
- openStream.getInitialCredit,
- rpcRequest.requestId,
- isLegacy = false,
- openStream.getReadLocalShuffle)
- case bufferStreamEnd: PbBufferStreamEnd =>
- handleEndStreamFromClient(bufferStreamEnd.getStreamId,
bufferStreamEnd.getStreamType)
- case readAddCredit: PbReadAddCredit =>
- handleReadAddCredit(readAddCredit.getCredit,
readAddCredit.getStreamId)
- case chunkFetchRequest: PbChunkFetchRequest =>
- handleChunkFetchRequest(
- client,
- StreamChunkSlice.fromProto(chunkFetchRequest.getStreamChunkSlice),
- rpcRequest)
- case message: GeneratedMessageV3 =>
- logError(s"Unknown message $message")
- }
- } finally {
- rpcRequest.body().release()
+ message =
TransportMessage.fromByteBuffer(rpcRequest.body().nioByteBuffer())
+ .getParsedPayload[GeneratedMessageV3]
+ } catch {
+ case exception: CelebornIOException =>
+ logWarning("Handle request with legacy RPCs", exception)
+ return handleLegacyRpcMessage(client, rpcRequest, callback)
+ }
+ message match {
+ case openStream: PbOpenStream =>
+ handleOpenStreamInternal(
+ client,
+ openStream.getShuffleKey,
+ openStream.getFileName,
+ openStream.getStartIndex,
+ openStream.getEndIndex,
+ openStream.getInitialCredit,
+ rpcRequest.requestId,
+ isLegacy = false,
+ openStream.getReadLocalShuffle,
+ callback)
+ case bufferStreamEnd: PbBufferStreamEnd =>
+ handleEndStreamFromClient(bufferStreamEnd.getStreamId,
bufferStreamEnd.getStreamType)
+ case readAddCredit: PbReadAddCredit =>
+ handleReadAddCredit(readAddCredit.getCredit, readAddCredit.getStreamId)
+ case chunkFetchRequest: PbChunkFetchRequest =>
+ handleChunkFetchRequest(
+ client,
+ StreamChunkSlice.fromProto(chunkFetchRequest.getStreamChunkSlice),
+ rpcRequest)
+ case message: GeneratedMessageV3 =>
Review Comment:
Would it be better to make this a catch-all (wildcard) case?
```
case _ =>
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]