This is an automated email from the ASF dual-hosted git repository.

nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new a7735d99e [CELEBORN-2157][CIP-14] Support sending RegisterShuffle/Revive/MapperEnd messages for NettyRpcEndpointRef in cppClient
a7735d99e is described below

commit a7735d99ec46fab102b88539e408a1afbde1ae78
Author: HolyLow <[email protected]>
AuthorDate: Fri Oct 10 15:47:00 2025 +0800

    [CELEBORN-2157][CIP-14] Support sending RegisterShuffle/Revive/MapperEnd messages for NettyRpcEndpointRef in cppClient
    
    ### What changes were proposed in this pull request?
    This PR supports sending RegisterShuffle/Revive/MapperEnd messages for NettyRpcEndpointRef in cppClient.
    
    ### Why are the changes needed?
    These messages are used to communicate with the LifecycleManager when writing data with cppClient.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Compilation and UTs.
    
    Closes #3484 from HolyLow/issue/celeborn-2157-support-pushdata-revive-in-rpcref.
    
    Authored-by: HolyLow <[email protected]>
    Signed-off-by: SteNicholas <[email protected]>
---
 cpp/celeborn/client/ShuffleClient.cpp              |  12 +-
 cpp/celeborn/conf/CelebornConf.cpp                 |   5 +
 cpp/celeborn/conf/CelebornConf.h                   |   4 +
 cpp/celeborn/network/NettyRpcEndpointRef.cpp       |  45 +----
 cpp/celeborn/network/NettyRpcEndpointRef.h         |  49 ++++-
 .../network/tests/NettyRpcEndpointRefTest.cpp      | 209 +++++++++++++++++----
 6 files changed, 241 insertions(+), 83 deletions(-)

diff --git a/cpp/celeborn/client/ShuffleClient.cpp b/cpp/celeborn/client/ShuffleClient.cpp
index ccf7c6dcb..92279e4f3 100644
--- a/cpp/celeborn/client/ShuffleClient.cpp
+++ b/cpp/celeborn/client/ShuffleClient.cpp
@@ -32,7 +32,13 @@ void ShuffleClientImpl::setupLifecycleManagerRef(std::string& host, int port) {
   {
     std::lock_guard<std::mutex> lock(mutex_);
     lifecycleManagerRef_ = std::make_shared<network::NettyRpcEndpointRef>(
-        "LifecycleManagerEndpoint", "dummy", 0, host, port, managerClient);
+        "LifecycleManagerEndpoint",
+        "dummy",
+        0,
+        host,
+        port,
+        managerClient,
+        *conf_);
   }
 }
 
@@ -83,7 +89,9 @@ void ShuffleClientImpl::updateReducerFileGroup(int shuffleId) {
   CELEBORN_CHECK(
       lifecycleManagerRef_, "lifecycleManagerRef_ is not initialized");
   // Send the query request to lifecycleManager.
-  auto reducerFileGroupInfo = lifecycleManagerRef_->askSync(
+  auto reducerFileGroupInfo = lifecycleManagerRef_->askSync<
+      protocol::GetReducerFileGroup,
+      protocol::GetReducerFileGroupResponse>(
       protocol::GetReducerFileGroup{shuffleId},
       conf_->clientRpcGetReducerFileGroupRpcAskTimeout());
 
diff --git a/cpp/celeborn/conf/CelebornConf.cpp b/cpp/celeborn/conf/CelebornConf.cpp
index e21d39032..73b2b24dc 100644
--- a/cpp/celeborn/conf/CelebornConf.cpp
+++ b/cpp/celeborn/conf/CelebornConf.cpp
@@ -133,6 +133,7 @@ Duration toDuration(const std::string& str) {
 
 const std::unordered_map<std::string, folly::Optional<std::string>>
     CelebornConf::kDefaultProperties = {
+        STR_PROP(kRpcAskTimeout, "60s"),
         STR_PROP(kRpcLookupTimeout, "30s"),
         STR_PROP(kClientRpcGetReducerFileGroupRpcAskTimeout, "60s"),
         STR_PROP(kNetworkConnectTimeout, "10s"),
@@ -175,6 +176,10 @@ void CelebornConf::registerProperty(
   setValue(static_cast<std::string>(key), value);
 }
 
+Timeout CelebornConf::rpcAskTimeout() const {
+  return utils::toTimeout(toDuration(optionalProperty(kRpcAskTimeout).value()));
+}
+
 Timeout CelebornConf::rpcLookupTimeout() const {
   return utils::toTimeout(
       toDuration(optionalProperty(kRpcLookupTimeout).value()));
diff --git a/cpp/celeborn/conf/CelebornConf.h b/cpp/celeborn/conf/CelebornConf.h
index 5aa3c6f9e..6ee278148 100644
--- a/cpp/celeborn/conf/CelebornConf.h
+++ b/cpp/celeborn/conf/CelebornConf.h
@@ -40,6 +40,8 @@ class CelebornConf : public BaseConf {
   static const std::unordered_map<std::string, folly::Optional<std::string>>
       kDefaultProperties;
 
+  static constexpr std::string_view kRpcAskTimeout{"celeborn.rpc.askTimeout"};
+
   static constexpr std::string_view kRpcLookupTimeout{
       "celeborn.rpc.lookupTimeout"};
 
@@ -77,6 +79,8 @@ class CelebornConf : public BaseConf {
 
   void registerProperty(const std::string_view& key, const std::string& value);
 
+  Timeout rpcAskTimeout() const;
+
   Timeout rpcLookupTimeout() const;
 
   Timeout clientRpcGetReducerFileGroupRpcAskTimeout() const;
diff --git a/cpp/celeborn/network/NettyRpcEndpointRef.cpp b/cpp/celeborn/network/NettyRpcEndpointRef.cpp
index 1ea4df0e7..e8ab4c951 100644
--- a/cpp/celeborn/network/NettyRpcEndpointRef.cpp
+++ b/cpp/celeborn/network/NettyRpcEndpointRef.cpp
@@ -26,51 +26,14 @@ NettyRpcEndpointRef::NettyRpcEndpointRef(
     int srcPort,
     const std::string& dstHost,
     int dstPort,
-    std::shared_ptr<TransportClient> client)
+    const std::shared_ptr<TransportClient>& client,
+    const conf::CelebornConf& conf)
     : name_(name),
       srcHost_(srcHost),
       srcPort_(srcPort),
       dstHost_(dstHost),
       dstPort_(dstPort),
-      client_(client) {}
-
-std::unique_ptr<protocol::GetReducerFileGroupResponse>
-NettyRpcEndpointRef::askSync(
-    const protocol::GetReducerFileGroup& msg,
-    Timeout timeout) {
-  auto rpcRequest = buildRpcRequest(msg);
-  auto rpcResponse = client_->sendRpcRequestSync(rpcRequest, timeout);
-  return fromRpcResponse(std::move(rpcResponse));
-}
-
-RpcRequest NettyRpcEndpointRef::buildRpcRequest(
-    const protocol::GetReducerFileGroup& msg) {
-  auto transportData = msg.toTransportMessage().toReadOnlyByteBuffer();
-  int size =
-      srcHost_.size() + 3 + 4 + dstHost_.size() + 3 + 4 + name_.size() + 2 + 1;
-  auto buffer = memory::ByteBuffer::createWriteOnly(size);
-  // write srcAddr msg
-  utils::writeRpcAddress(*buffer, srcHost_, srcPort_);
-  // write dstAddr msg
-  utils::writeRpcAddress(*buffer, dstHost_, dstPort_);
-  // write srcName
-  utils::writeUTF(*buffer, name_);
-  // write the isTransportMessage flag
-  buffer->write<uint8_t>(kNativeTransportMessageFlag);
-  CELEBORN_CHECK_EQ(buffer->size(), size);
-  auto result = memory::ByteBuffer::toReadOnly(std::move(buffer));
-  auto combined = memory::ByteBuffer::concat(*result, *transportData);
-  return RpcRequest(RpcRequest::nextRequestId(), std::move(combined));
-}
-
-std::unique_ptr<protocol::GetReducerFileGroupResponse>
-NettyRpcEndpointRef::fromRpcResponse(RpcResponse&& response) {
-  auto body = response.body();
-  uint8_t nativeTransportMessageFlag = body->read<uint8_t>();
-  CELEBORN_CHECK_EQ(nativeTransportMessageFlag, kNativeTransportMessageFlag);
-  auto transportMessage = protocol::TransportMessage(std::move(body));
-  return protocol::GetReducerFileGroupResponse::fromTransportMessage(
-      transportMessage);
-}
+      client_(client),
+      defaultTimeout_(conf.rpcAskTimeout()) {}
 } // namespace network
 } // namespace celeborn
diff --git a/cpp/celeborn/network/NettyRpcEndpointRef.h b/cpp/celeborn/network/NettyRpcEndpointRef.h
index 334d4d4f1..5145f88ff 100644
--- a/cpp/celeborn/network/NettyRpcEndpointRef.h
+++ b/cpp/celeborn/network/NettyRpcEndpointRef.h
@@ -37,18 +37,50 @@ class NettyRpcEndpointRef {
       int srcPort,
       const std::string& dstHost,
       int dstPort,
-      std::shared_ptr<TransportClient> client);
+      const std::shared_ptr<TransportClient>& client,
+      const conf::CelebornConf& conf);
 
-  // TODO: refactor to template function when needed.
-  std::unique_ptr<protocol::GetReducerFileGroupResponse> askSync(
-      const protocol::GetReducerFileGroup& msg,
-      Timeout timeout);
+  template <class TRequest, class TResponse>
+  std::unique_ptr<TResponse> askSync(const TRequest& msg) {
+    return askSync<TRequest, TResponse>(msg, defaultTimeout_);
+  }
+
+  template <class TRequest, class TResponse>
+  std::unique_ptr<TResponse> askSync(const TRequest& msg, Timeout timeout) {
+    auto rpcRequest = buildRpcRequest<TRequest>(msg);
+    auto rpcResponse = client_->sendRpcRequestSync(rpcRequest, timeout);
+    return fromRpcResponse<TResponse>(std::move(rpcResponse));
+  }
 
  private:
-  RpcRequest buildRpcRequest(const protocol::GetReducerFileGroup& msg);
+  template <class TRequest>
+  RpcRequest buildRpcRequest(const TRequest& msg) {
+    auto transportData = msg.toTransportMessage().toReadOnlyByteBuffer();
+    int size = srcHost_.size() + 3 + 4 + dstHost_.size() + 3 + 4 +
+        name_.size() + 2 + 1;
+    auto buffer = memory::ByteBuffer::createWriteOnly(size);
+    // write srcAddr msg
+    utils::writeRpcAddress(*buffer, srcHost_, srcPort_);
+    // write dstAddr msg
+    utils::writeRpcAddress(*buffer, dstHost_, dstPort_);
+    // write srcName
+    utils::writeUTF(*buffer, name_);
+    // write the isTransportMessage flag
+    buffer->write<uint8_t>(kNativeTransportMessageFlag);
+    CELEBORN_CHECK_EQ(buffer->size(), size);
+    auto result = memory::ByteBuffer::toReadOnly(std::move(buffer));
+    auto combined = memory::ByteBuffer::concat(*result, *transportData);
+    return RpcRequest(RpcRequest::nextRequestId(), std::move(combined));
+  }
 
-  std::unique_ptr<protocol::GetReducerFileGroupResponse> fromRpcResponse(
-      RpcResponse&& response);
+  template <class TResponse>
+  std::unique_ptr<TResponse> fromRpcResponse(RpcResponse&& response) {
+    auto body = response.body();
+    uint8_t nativeTransportMessageFlag = body->read<uint8_t>();
+    CELEBORN_CHECK_EQ(nativeTransportMessageFlag, kNativeTransportMessageFlag);
+    auto transportMessage = protocol::TransportMessage(std::move(body));
+    return TResponse::fromTransportMessage(transportMessage);
+  }
 
   std::string name_;
   std::string srcHost_;
@@ -56,6 +88,7 @@ class NettyRpcEndpointRef {
   std::string dstHost_;
   int dstPort_;
   std::shared_ptr<TransportClient> client_;
+  Timeout defaultTimeout_;
 };
 } // namespace network
 } // namespace celeborn
diff --git a/cpp/celeborn/network/tests/NettyRpcEndpointRefTest.cpp b/cpp/celeborn/network/tests/NettyRpcEndpointRefTest.cpp
index 41f629364..32433c410 100644
--- a/cpp/celeborn/network/tests/NettyRpcEndpointRefTest.cpp
+++ b/cpp/celeborn/network/tests/NettyRpcEndpointRefTest.cpp
@@ -29,18 +29,18 @@ class MockTransportClient : public TransportClient {
  public:
   MockTransportClient()
       : TransportClient(nullptr, nullptr, MS(100)),
-        response_(
-            RpcResponse(0, memory::ReadOnlyByteBuffer::createEmptyBuffer())),
+        respPromise_(),
+        respFuture_(respPromise_.getFuture()),
         request_(
             RpcRequest(0, memory::ReadOnlyByteBuffer::createEmptyBuffer())) {}
   RpcResponse sendRpcRequestSync(const RpcRequest& request, Timeout timeout)
       override {
     request_ = request;
-    return response_;
+    return std::move(respFuture_).get(timeout);
   }
 
   void setResponse(const RpcResponse& response) {
-    response_ = response;
+    respPromise_.setValue(response);
   }
 
   RpcRequest getRequest() {
@@ -48,10 +48,23 @@ class MockTransportClient : public TransportClient {
   }
 
  private:
-  RpcResponse response_;
+  folly::Promise<RpcResponse> respPromise_;
+  folly::Future<RpcResponse> respFuture_;
   RpcRequest request_;
 };
 
+std::unique_ptr<RpcResponse> makeResponseForNettyRpcEndpointRef(
+    const protocol::TransportMessage& transportMessage,
+    long requestId) {
+  auto msgBody = transportMessage.toReadOnlyByteBuffer();
+  auto writeBuffer = memory::ByteBuffer::createWriteOnly(sizeof(uint8_t));
+  writeBuffer->write<uint8_t>(NettyRpcEndpointRef::kNativeTransportMessageFlag);
+  auto flagBody = memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
+  auto concatBody = memory::ByteBuffer::concat(*flagBody, *msgBody);
+
+  return std::make_unique<RpcResponse>(requestId, std::move(concatBody));
+}
+
 void readUTF(memory::ReadOnlyByteBuffer& buffer, std::string& host) {
   int size = buffer.read<short>();
   host = buffer.readToString(size);
@@ -65,54 +78,186 @@ void readRpcAddress(
   readUTF(buffer, host);
   port = buffer.read<int32_t>();
 }
+
+void verifyRequestForNettyRpcEndpointRef(
+    const RpcRequest& request,
+    const std::string& srcName,
+    const std::string& srcHost,
+    const int srcPort,
+    const std::string& dstHost,
+    const int dstPort) {
+  auto sentBody = request.body();
+  std::string host;
+  int port;
+  readRpcAddress(*sentBody, host, port);
+  EXPECT_EQ(host, srcHost);
+  EXPECT_EQ(port, srcPort);
+  readRpcAddress(*sentBody, host, port);
+  EXPECT_EQ(host, dstHost);
+  EXPECT_EQ(port, dstPort);
+  std::string name;
+  readUTF(*sentBody, name);
+  EXPECT_EQ(name, srcName);
+  EXPECT_EQ(
+      sentBody->read<uint8_t>(),
+      NettyRpcEndpointRef::kNativeTransportMessageFlag);
+}
+
 } // namespace
 
-TEST(NettyRpcEndpointRefTest, askSyncGetReducerFileGroup) {
+TEST(NettyRpcEndpointRefTest, askSyncRegisterShuffle) {
+  auto mockedClient = std::make_shared<MockTransportClient>();
+  const std::string srcName = "test-name";
+  const std::string srcHost = "test-src-host";
+  const int srcPort = 100;
+  const std::string dstHost = "test-dst-host";
+  const int dstPort = 101;
+  const auto conf = std::make_shared<conf::CelebornConf>();
+  auto nettyRpcEndpointRef = NettyRpcEndpointRef(
+      srcName, srcHost, srcPort, dstHost, dstPort, mockedClient, *conf);
+
+  PbRegisterShuffleResponse pbRegisterShuffleResponse;
+  const int status = 5;
+  pbRegisterShuffleResponse.set_status(status);
+  protocol::TransportMessage transportMessage(
+      REGISTER_SHUFFLE_RESPONSE, pbRegisterShuffleResponse.SerializeAsString());
+  auto rpcResponse = makeResponseForNettyRpcEndpointRef(transportMessage, 1000);
+  mockedClient->setResponse(*rpcResponse);
+
+  protocol::RegisterShuffle request{1001};
+  auto response = nettyRpcEndpointRef.askSync<
+      protocol::RegisterShuffle,
+      protocol::RegisterShuffleResponse>(request, MS(100));
+  EXPECT_EQ(response->status, status);
+
+  auto sentRequest = mockedClient->getRequest();
+  verifyRequestForNettyRpcEndpointRef(
+      sentRequest, srcName, srcHost, srcPort, dstHost, dstPort);
+}
+
+TEST(NettyRpcEndpointRefTest, askSyncRevive) {
+  auto mockedClient = std::make_shared<MockTransportClient>();
+  const std::string srcName = "test-name";
+  const std::string srcHost = "test-src-host";
+  const int srcPort = 100;
+  const std::string dstHost = "test-dst-host";
+  const int dstPort = 101;
+  const auto conf = std::make_shared<conf::CelebornConf>();
+  auto nettyRpcEndpointRef = NettyRpcEndpointRef(
+      srcName, srcHost, srcPort, dstHost, dstPort, mockedClient, *conf);
+
+  PbChangeLocationResponse pbChangeLocationResponse;
+  protocol::TransportMessage transportMessage(
+      CHANGE_LOCATION_RESPONSE, pbChangeLocationResponse.SerializeAsString());
+  auto rpcResponse = makeResponseForNettyRpcEndpointRef(transportMessage, 1000);
+  mockedClient->setResponse(*rpcResponse);
+
+  protocol::Revive request{1001};
+  auto response =
+      nettyRpcEndpointRef
+          .askSync<protocol::Revive, protocol::ChangeLocationResponse>(
+              request, MS(100));
+  ASSERT_TRUE(response != nullptr);
+
+  auto sentRequest = mockedClient->getRequest();
+  verifyRequestForNettyRpcEndpointRef(
+      sentRequest, srcName, srcHost, srcPort, dstHost, dstPort);
+}
+
+TEST(NettyRpcEndpointRefTest, askSyncMapperEnd) {
   auto mockedClient = std::make_shared<MockTransportClient>();
   const std::string srcName = "test-name";
   const std::string srcHost = "test-src-host";
   const int srcPort = 100;
   const std::string dstHost = "test-dst-host";
   const int dstPort = 101;
+  const auto conf = std::make_shared<conf::CelebornConf>();
   auto nettyRpcEndpointRef = NettyRpcEndpointRef(
-      srcName, srcHost, srcPort, dstHost, dstPort, mockedClient);
+      srcName, srcHost, srcPort, dstHost, dstPort, mockedClient, *conf);
+
+  PbMapperEndResponse pbMapperEndResponse;
+  const int status = 5;
+  pbMapperEndResponse.set_status(status);
+  protocol::TransportMessage transportMessage(
+      MAPPER_END_RESPONSE, pbMapperEndResponse.SerializeAsString());
+  auto rpcResponse = makeResponseForNettyRpcEndpointRef(transportMessage, 1000);
+  mockedClient->setResponse(*rpcResponse);
 
-  const std::string responseBody = "test-response-body";
+  protocol::MapperEnd request{1001};
+  auto response =
+      nettyRpcEndpointRef
+          .askSync<protocol::MapperEnd, protocol::MapperEndResponse>(
+              request, MS(100));
+  EXPECT_EQ(response->status, status);
+
+  auto sentRequest = mockedClient->getRequest();
+  verifyRequestForNettyRpcEndpointRef(
+      sentRequest, srcName, srcHost, srcPort, dstHost, dstPort);
+}
+
+TEST(NettyRpcEndpointRefTest, askSyncGetReducerFileGroup) {
+  auto mockedClient = std::make_shared<MockTransportClient>();
+  const std::string srcName = "test-name";
+  const std::string srcHost = "test-src-host";
+  const int srcPort = 100;
+  const std::string dstHost = "test-dst-host";
+  const int dstPort = 101;
+  const auto conf = std::make_shared<conf::CelebornConf>();
+  auto nettyRpcEndpointRef = NettyRpcEndpointRef(
+      srcName, srcHost, srcPort, dstHost, dstPort, mockedClient, *conf);
 
-  // rpcResponse -> transportMessage -> getReducerFileGroupResponse
   PbGetReducerFileGroupResponse pbGetReducerFileGroupResponse;
   const int status = 5;
   pbGetReducerFileGroupResponse.set_status(5);
   protocol::TransportMessage transportMessage(
       GET_REDUCER_FILE_GROUP_RESPONSE,
       pbGetReducerFileGroupResponse.SerializeAsString());
-  auto msgBody = transportMessage.toReadOnlyByteBuffer();
-  auto writeBuffer = memory::ByteBuffer::createWriteOnly(sizeof(uint8_t));
-  writeBuffer->write<uint8_t>(NettyRpcEndpointRef::kNativeTransportMessageFlag);
-  auto flagBody = memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
-  auto concatBody = memory::ByteBuffer::concat(*flagBody, *msgBody);
-
-  auto rpcResponse = std::make_unique<RpcResponse>(1000, std::move(concatBody));
+  auto rpcResponse = makeResponseForNettyRpcEndpointRef(transportMessage, 1000);
   mockedClient->setResponse(*rpcResponse);
 
   protocol::GetReducerFileGroup request{1001};
-  auto response = nettyRpcEndpointRef.askSync(request, MS(100));
+  auto response = nettyRpcEndpointRef.askSync<
+      protocol::GetReducerFileGroup,
+      protocol::GetReducerFileGroupResponse>(request, MS(100));
   EXPECT_EQ(response->status, status);
 
   auto sentRequest = mockedClient->getRequest();
-  auto sentBody = sentRequest.body();
-  std::string host;
-  int port;
-  readRpcAddress(*sentBody, host, port);
-  EXPECT_EQ(host, srcHost);
-  EXPECT_EQ(port, srcPort);
-  readRpcAddress(*sentBody, host, port);
-  EXPECT_EQ(host, dstHost);
-  EXPECT_EQ(port, dstPort);
-  std::string name;
-  readUTF(*sentBody, name);
-  EXPECT_EQ(name, srcName);
-  EXPECT_EQ(
-      sentBody->read<uint8_t>(),
-      NettyRpcEndpointRef::kNativeTransportMessageFlag);
+  verifyRequestForNettyRpcEndpointRef(
+      sentRequest, srcName, srcHost, srcPort, dstHost, dstPort);
+}
+
+TEST(NettyRpcEndpointRefTest, askSyncTimeout) {
+  auto mockedClient = std::make_shared<MockTransportClient>();
+  const std::string srcName = "test-name";
+  const std::string srcHost = "test-src-host";
+  const int srcPort = 100;
+  const std::string dstHost = "test-dst-host";
+  const int dstPort = 101;
+  const auto conf = std::make_shared<conf::CelebornConf>();
+  auto nettyRpcEndpointRef = NettyRpcEndpointRef(
+      srcName, srcHost, srcPort, dstHost, dstPort, mockedClient, *conf);
+
+  const auto timeoutInterval = MS(200);
+  const auto sleepInterval = MS(100);
+  protocol::MapperEnd request{1001};
+  std::atomic<bool> timeoutHappened{false};
+  std::thread syncThread([&]() {
+    try {
+      auto response =
+          nettyRpcEndpointRef
+              .askSync<protocol::MapperEnd, protocol::MapperEndResponse>(
+                  request, timeoutInterval);
+    } catch (const std::exception&) {
+      timeoutHappened = true;
+    }
+  });
+  std::this_thread::sleep_for(sleepInterval);
+
+  auto sentRequest = mockedClient->getRequest();
+  verifyRequestForNettyRpcEndpointRef(
+      sentRequest, srcName, srcHost, srcPort, dstHost, dstPort);
+  EXPECT_FALSE(timeoutHappened.load());
+
+  syncThread.join();
+  EXPECT_TRUE(timeoutHappened.load());
 }

Reply via email to