SteNicholas commented on code in PR #3484:
URL: https://github.com/apache/celeborn/pull/3484#discussion_r2374819451


##########
cpp/celeborn/network/tests/NettyRpcEndpointRefTest.cpp:
##########
@@ -65,54 +77,150 @@ void readRpcAddress(
   readUTF(buffer, host);
   port = buffer.read<int32_t>();
 }
+
+void verifyRequestForNettyRpcEndpointRef(
+    const RpcRequest& request,
+    const std::string& srcName,
+    const std::string& srcHost,
+    const int srcPort,
+    const std::string& dstHost,
+    const int dstPort) {
+  auto sentBody = request.body();
+  std::string host;
+  int port;
+  readRpcAddress(*sentBody, host, port);
+  EXPECT_EQ(host, srcHost);
+  EXPECT_EQ(port, srcPort);
+  readRpcAddress(*sentBody, host, port);
+  EXPECT_EQ(host, dstHost);
+  EXPECT_EQ(port, dstPort);
+  std::string name;
+  readUTF(*sentBody, name);
+  EXPECT_EQ(name, srcName);
+  EXPECT_EQ(
+      sentBody->read<uint8_t>(),
+      NettyRpcEndpointRef::kNativeTransportMessageFlag);
+}
+
 } // namespace
 
-TEST(NettyRpcEndpointRefTest, askSyncGetReducerFileGroup) {
+TEST(NettyRpcEndpointRefTest, askSyncRegisterShuffle) {
+  auto mockedClient = std::make_shared<MockTransportClient>();
+  const std::string srcName = "test-name";
+  const std::string srcHost = "test-src-host";
+  const int srcPort = 100;
+  const std::string dstHost = "test-dst-host";
+  const int dstPort = 101;
+  const auto conf = std::make_shared<conf::CelebornConf>();
+  auto nettyRpcEndpointRef = NettyRpcEndpointRef(
+      srcName, srcHost, srcPort, dstHost, dstPort, mockedClient, *conf);
+
+  PbRegisterShuffleResponse pbRegisterShuffleResponse;
+  const int status = 5;
+  pbRegisterShuffleResponse.set_status(5);
+  protocol::TransportMessage transportMessage(
+      REGISTER_SHUFFLE_RESPONSE, pbRegisterShuffleResponse.SerializeAsString());
+  auto rpcResponse = makeResponseForNettyRpcEndpointRef(transportMessage, 1000);
+  mockedClient->setResponse(*rpcResponse);
+
+  protocol::RegisterShuffle request{1001};
+  auto response = nettyRpcEndpointRef.askSync<
+      protocol::RegisterShuffle,
+      protocol::RegisterShuffleResponse>(request, MS(100));
+  EXPECT_EQ(response->status, status);
+
+  auto sentRequest = mockedClient->getRequest();
+  verifyRequestForNettyRpcEndpointRef(
+      sentRequest, srcName, srcHost, srcPort, dstHost, dstPort);
+}
+
+TEST(NettyRpcEndpointRefTest, askSyncRevive) {
   auto mockedClient = std::make_shared<MockTransportClient>();
   const std::string srcName = "test-name";
   const std::string srcHost = "test-src-host";
   const int srcPort = 100;
   const std::string dstHost = "test-dst-host";
   const int dstPort = 101;
+  const auto conf = std::make_shared<conf::CelebornConf>();
   auto nettyRpcEndpointRef = NettyRpcEndpointRef(
-      srcName, srcHost, srcPort, dstHost, dstPort, mockedClient);
+      srcName, srcHost, srcPort, dstHost, dstPort, mockedClient, *conf);
+
+  PbRevive pbRevive;
+  pbRevive.set_shuffleid(5);
+  protocol::TransportMessage transportMessage(
+      CHANGE_LOCATION_RESPONSE, pbRevive.SerializeAsString());
+  auto rpcResponse = makeResponseForNettyRpcEndpointRef(transportMessage, 1000);
+  mockedClient->setResponse(*rpcResponse);
 
-  const std::string responseBody = "test-response-body";
+  protocol::Revive request{1001};
+  auto response =
+      nettyRpcEndpointRef
+          .askSync<protocol::Revive, protocol::ChangeLocationResponse>(
+              request, MS(100));
+
+  auto sentRequest = mockedClient->getRequest();
+  verifyRequestForNettyRpcEndpointRef(
+      sentRequest, srcName, srcHost, srcPort, dstHost, dstPort);
+}
+
+TEST(NettyRpcEndpointRefTest, askSyncMapperEnd) {
+  auto mockedClient = std::make_shared<MockTransportClient>();
+  const std::string srcName = "test-name";
+  const std::string srcHost = "test-src-host";
+  const int srcPort = 100;
+  const std::string dstHost = "test-dst-host";
+  const int dstPort = 101;
+  const auto conf = std::make_shared<conf::CelebornConf>();
+  auto nettyRpcEndpointRef = NettyRpcEndpointRef(
+      srcName, srcHost, srcPort, dstHost, dstPort, mockedClient, *conf);
+
+  PbMapperEndResponse pbMapperEndResponse;
+  const int status = 5;
+  pbMapperEndResponse.set_status(5);
+  protocol::TransportMessage transportMessage(
+      MAPPER_END_RESPONSE, pbMapperEndResponse.SerializeAsString());
+  auto rpcResponse = makeResponseForNettyRpcEndpointRef(transportMessage, 1000);
+  mockedClient->setResponse(*rpcResponse);
+
+  protocol::MapperEnd request{1001};
+  auto response =
+      nettyRpcEndpointRef
+          .askSync<protocol::MapperEnd, protocol::MapperEndResponse>(
+              request, MS(100));
+  EXPECT_EQ(response->status, status);
+
+  auto sentRequest = mockedClient->getRequest();
+  verifyRequestForNettyRpcEndpointRef(
+      sentRequest, srcName, srcHost, srcPort, dstHost, dstPort);
+}
+
+TEST(NettyRpcEndpointRefTest, askSyncGetReducerFileGroup) {
+  auto mockedClient = std::make_shared<MockTransportClient>();
+  const std::string srcName = "test-name";
+  const std::string srcHost = "test-src-host";
+  const int srcPort = 100;
+  const std::string dstHost = "test-dst-host";
+  const int dstPort = 101;
+  const auto conf = std::make_shared<conf::CelebornConf>();
+  auto nettyRpcEndpointRef = NettyRpcEndpointRef(
+      srcName, srcHost, srcPort, dstHost, dstPort, mockedClient, *conf);
 
-  // rpcResponse -> transportMessage -> getReducerFileGroupResponse
   PbGetReducerFileGroupResponse pbGetReducerFileGroupResponse;
   const int status = 5;
   pbGetReducerFileGroupResponse.set_status(5);
   protocol::TransportMessage transportMessage(
       GET_REDUCER_FILE_GROUP_RESPONSE,
       pbGetReducerFileGroupResponse.SerializeAsString());
-  auto msgBody = transportMessage.toReadOnlyByteBuffer();
-  auto writeBuffer = memory::ByteBuffer::createWriteOnly(sizeof(uint8_t));
-  writeBuffer->write<uint8_t>(NettyRpcEndpointRef::kNativeTransportMessageFlag);
-  auto flagBody = memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
-  auto concatBody = memory::ByteBuffer::concat(*flagBody, *msgBody);
-
-  auto rpcResponse = std::make_unique<RpcResponse>(1000, std::move(concatBody));
+  auto rpcResponse = makeResponseForNettyRpcEndpointRef(transportMessage, 1000);
   mockedClient->setResponse(*rpcResponse);
 
   protocol::GetReducerFileGroup request{1001};
-  auto response = nettyRpcEndpointRef.askSync(request, MS(100));
+  auto response = nettyRpcEndpointRef.askSync<
+      protocol::GetReducerFileGroup,
+      protocol::GetReducerFileGroupResponse>(request, MS(100));
   EXPECT_EQ(response->status, status);
 
   auto sentRequest = mockedClient->getRequest();
-  auto sentBody = sentRequest.body();
-  std::string host;
-  int port;
-  readRpcAddress(*sentBody, host, port);
-  EXPECT_EQ(host, srcHost);
-  EXPECT_EQ(port, srcPort);
-  readRpcAddress(*sentBody, host, port);
-  EXPECT_EQ(host, dstHost);
-  EXPECT_EQ(port, dstPort);
-  std::string name;
-  readUTF(*sentBody, name);
-  EXPECT_EQ(name, srcName);
-  EXPECT_EQ(
-      sentBody->read<uint8_t>(),
-      NettyRpcEndpointRef::kNativeTransportMessageFlag);
+  verifyRequestForNettyRpcEndpointRef(

Review Comment:
   Could you add some test cases for RPC timeouts?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to