This is an automated email from the ASF dual-hosted git repository.
nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new a7735d99e [CELEBORN-2157][CIP-14] Support sending
RegisterShuffle/Revive/MapperEnd messages for NettyRpcEndpointRef in cppClient
a7735d99e is described below
commit a7735d99ec46fab102b88539e408a1afbde1ae78
Author: HolyLow <[email protected]>
AuthorDate: Fri Oct 10 15:47:00 2025 +0800
[CELEBORN-2157][CIP-14] Support sending RegisterShuffle/Revive/MapperEnd
messages for NettyRpcEndpointRef in cppClient
### What changes were proposed in this pull request?
This PR supports sending RegisterShuffle/Revive/MapperEnd messages for
NettyRpcEndpointRef in cppClient.
### Why are the changes needed?
These messages are used to communicate with CelebornMaster when writing
data with cppClient.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Verified by compilation and unit tests (UTs).
Closes #3484 from
HolyLow/issue/celeborn-2157-support-pushdata-revive-in-rpcref.
Authored-by: HolyLow <[email protected]>
Signed-off-by: SteNicholas <[email protected]>
---
cpp/celeborn/client/ShuffleClient.cpp | 12 +-
cpp/celeborn/conf/CelebornConf.cpp | 5 +
cpp/celeborn/conf/CelebornConf.h | 4 +
cpp/celeborn/network/NettyRpcEndpointRef.cpp | 45 +----
cpp/celeborn/network/NettyRpcEndpointRef.h | 49 ++++-
.../network/tests/NettyRpcEndpointRefTest.cpp | 209 +++++++++++++++++----
6 files changed, 241 insertions(+), 83 deletions(-)
diff --git a/cpp/celeborn/client/ShuffleClient.cpp b/cpp/celeborn/client/ShuffleClient.cpp
index ccf7c6dcb..92279e4f3 100644
--- a/cpp/celeborn/client/ShuffleClient.cpp
+++ b/cpp/celeborn/client/ShuffleClient.cpp
@@ -32,7 +32,13 @@ void ShuffleClientImpl::setupLifecycleManagerRef(std::string& host, int port) {
{
std::lock_guard<std::mutex> lock(mutex_);
lifecycleManagerRef_ = std::make_shared<network::NettyRpcEndpointRef>(
- "LifecycleManagerEndpoint", "dummy", 0, host, port, managerClient);
+ "LifecycleManagerEndpoint",
+ "dummy",
+ 0,
+ host,
+ port,
+ managerClient,
+ *conf_);
}
}
@@ -83,7 +89,9 @@ void ShuffleClientImpl::updateReducerFileGroup(int shuffleId) {
CELEBORN_CHECK(
lifecycleManagerRef_, "lifecycleManagerRef_ is not initialized");
// Send the query request to lifecycleManager.
- auto reducerFileGroupInfo = lifecycleManagerRef_->askSync(
+ auto reducerFileGroupInfo = lifecycleManagerRef_->askSync<
+ protocol::GetReducerFileGroup,
+ protocol::GetReducerFileGroupResponse>(
protocol::GetReducerFileGroup{shuffleId},
conf_->clientRpcGetReducerFileGroupRpcAskTimeout());
diff --git a/cpp/celeborn/conf/CelebornConf.cpp b/cpp/celeborn/conf/CelebornConf.cpp
index e21d39032..73b2b24dc 100644
--- a/cpp/celeborn/conf/CelebornConf.cpp
+++ b/cpp/celeborn/conf/CelebornConf.cpp
@@ -133,6 +133,7 @@ Duration toDuration(const std::string& str) {
const std::unordered_map<std::string, folly::Optional<std::string>>
CelebornConf::kDefaultProperties = {
+ STR_PROP(kRpcAskTimeout, "60s"),
STR_PROP(kRpcLookupTimeout, "30s"),
STR_PROP(kClientRpcGetReducerFileGroupRpcAskTimeout, "60s"),
STR_PROP(kNetworkConnectTimeout, "10s"),
@@ -175,6 +176,10 @@ void CelebornConf::registerProperty(
setValue(static_cast<std::string>(key), value);
}
+Timeout CelebornConf::rpcAskTimeout() const {
+ return utils::toTimeout(toDuration(optionalProperty(kRpcAskTimeout).value()));
+}
+
Timeout CelebornConf::rpcLookupTimeout() const {
return utils::toTimeout(
toDuration(optionalProperty(kRpcLookupTimeout).value()));
diff --git a/cpp/celeborn/conf/CelebornConf.h b/cpp/celeborn/conf/CelebornConf.h
index 5aa3c6f9e..6ee278148 100644
--- a/cpp/celeborn/conf/CelebornConf.h
+++ b/cpp/celeborn/conf/CelebornConf.h
@@ -40,6 +40,8 @@ class CelebornConf : public BaseConf {
static const std::unordered_map<std::string, folly::Optional<std::string>>
kDefaultProperties;
+ static constexpr std::string_view kRpcAskTimeout{"celeborn.rpc.askTimeout"};
+
static constexpr std::string_view kRpcLookupTimeout{
"celeborn.rpc.lookupTimeout"};
@@ -77,6 +79,8 @@ class CelebornConf : public BaseConf {
void registerProperty(const std::string_view& key, const std::string& value);
+ Timeout rpcAskTimeout() const;
+
Timeout rpcLookupTimeout() const;
Timeout clientRpcGetReducerFileGroupRpcAskTimeout() const;
diff --git a/cpp/celeborn/network/NettyRpcEndpointRef.cpp b/cpp/celeborn/network/NettyRpcEndpointRef.cpp
index 1ea4df0e7..e8ab4c951 100644
--- a/cpp/celeborn/network/NettyRpcEndpointRef.cpp
+++ b/cpp/celeborn/network/NettyRpcEndpointRef.cpp
@@ -26,51 +26,14 @@ NettyRpcEndpointRef::NettyRpcEndpointRef(
int srcPort,
const std::string& dstHost,
int dstPort,
- std::shared_ptr<TransportClient> client)
+ const std::shared_ptr<TransportClient>& client,
+ const conf::CelebornConf& conf)
: name_(name),
srcHost_(srcHost),
srcPort_(srcPort),
dstHost_(dstHost),
dstPort_(dstPort),
- client_(client) {}
-
-std::unique_ptr<protocol::GetReducerFileGroupResponse>
-NettyRpcEndpointRef::askSync(
- const protocol::GetReducerFileGroup& msg,
- Timeout timeout) {
- auto rpcRequest = buildRpcRequest(msg);
- auto rpcResponse = client_->sendRpcRequestSync(rpcRequest, timeout);
- return fromRpcResponse(std::move(rpcResponse));
-}
-
-RpcRequest NettyRpcEndpointRef::buildRpcRequest(
- const protocol::GetReducerFileGroup& msg) {
- auto transportData = msg.toTransportMessage().toReadOnlyByteBuffer();
- int size =
- srcHost_.size() + 3 + 4 + dstHost_.size() + 3 + 4 + name_.size() + 2 + 1;
- auto buffer = memory::ByteBuffer::createWriteOnly(size);
- // write srcAddr msg
- utils::writeRpcAddress(*buffer, srcHost_, srcPort_);
- // write dstAddr msg
- utils::writeRpcAddress(*buffer, dstHost_, dstPort_);
- // write srcName
- utils::writeUTF(*buffer, name_);
- // write the isTransportMessage flag
- buffer->write<uint8_t>(kNativeTransportMessageFlag);
- CELEBORN_CHECK_EQ(buffer->size(), size);
- auto result = memory::ByteBuffer::toReadOnly(std::move(buffer));
- auto combined = memory::ByteBuffer::concat(*result, *transportData);
- return RpcRequest(RpcRequest::nextRequestId(), std::move(combined));
-}
-
-std::unique_ptr<protocol::GetReducerFileGroupResponse>
-NettyRpcEndpointRef::fromRpcResponse(RpcResponse&& response) {
- auto body = response.body();
- uint8_t nativeTransportMessageFlag = body->read<uint8_t>();
- CELEBORN_CHECK_EQ(nativeTransportMessageFlag, kNativeTransportMessageFlag);
- auto transportMessage = protocol::TransportMessage(std::move(body));
- return protocol::GetReducerFileGroupResponse::fromTransportMessage(
- transportMessage);
-}
+ client_(client),
+ defaultTimeout_(conf.rpcAskTimeout()) {}
} // namespace network
} // namespace celeborn
diff --git a/cpp/celeborn/network/NettyRpcEndpointRef.h b/cpp/celeborn/network/NettyRpcEndpointRef.h
index 334d4d4f1..5145f88ff 100644
--- a/cpp/celeborn/network/NettyRpcEndpointRef.h
+++ b/cpp/celeborn/network/NettyRpcEndpointRef.h
@@ -37,18 +37,50 @@ class NettyRpcEndpointRef {
int srcPort,
const std::string& dstHost,
int dstPort,
- std::shared_ptr<TransportClient> client);
+ const std::shared_ptr<TransportClient>& client,
+ const conf::CelebornConf& conf);
- // TODO: refactor to template function when needed.
- std::unique_ptr<protocol::GetReducerFileGroupResponse> askSync(
- const protocol::GetReducerFileGroup& msg,
- Timeout timeout);
+ template <class TRequest, class TResponse>
+ std::unique_ptr<TResponse> askSync(const TRequest& msg) {
+ return askSync<TRequest, TResponse>(msg, defaultTimeout_);
+ }
+
+ template <class TRequest, class TResponse>
+ std::unique_ptr<TResponse> askSync(const TRequest& msg, Timeout timeout) {
+ auto rpcRequest = buildRpcRequest<TRequest>(msg);
+ auto rpcResponse = client_->sendRpcRequestSync(rpcRequest, timeout);
+ return fromRpcResponse<TResponse>(std::move(rpcResponse));
+ }
private:
- RpcRequest buildRpcRequest(const protocol::GetReducerFileGroup& msg);
+ template <class TRequest>
+ RpcRequest buildRpcRequest(const TRequest& msg) {
+ auto transportData = msg.toTransportMessage().toReadOnlyByteBuffer();
+ int size = srcHost_.size() + 3 + 4 + dstHost_.size() + 3 + 4 +
+ name_.size() + 2 + 1;
+ auto buffer = memory::ByteBuffer::createWriteOnly(size);
+ // write srcAddr msg
+ utils::writeRpcAddress(*buffer, srcHost_, srcPort_);
+ // write dstAddr msg
+ utils::writeRpcAddress(*buffer, dstHost_, dstPort_);
+ // write srcName
+ utils::writeUTF(*buffer, name_);
+ // write the isTransportMessage flag
+ buffer->write<uint8_t>(kNativeTransportMessageFlag);
+ CELEBORN_CHECK_EQ(buffer->size(), size);
+ auto result = memory::ByteBuffer::toReadOnly(std::move(buffer));
+ auto combined = memory::ByteBuffer::concat(*result, *transportData);
+ return RpcRequest(RpcRequest::nextRequestId(), std::move(combined));
+ }
- std::unique_ptr<protocol::GetReducerFileGroupResponse> fromRpcResponse(
- RpcResponse&& response);
+ template <class TResponse>
+ std::unique_ptr<TResponse> fromRpcResponse(RpcResponse&& response) {
+ auto body = response.body();
+ bool isTransportMessage = body->read<uint8_t>();
+ CELEBORN_CHECK(isTransportMessage);
+ auto transportMessage = protocol::TransportMessage(std::move(body));
+ return TResponse::fromTransportMessage(transportMessage);
+ }
std::string name_;
std::string srcHost_;
@@ -56,6 +88,7 @@ class NettyRpcEndpointRef {
std::string dstHost_;
int dstPort_;
std::shared_ptr<TransportClient> client_;
+ Timeout defaultTimeout_;
};
} // namespace network
} // namespace celeborn
diff --git a/cpp/celeborn/network/tests/NettyRpcEndpointRefTest.cpp b/cpp/celeborn/network/tests/NettyRpcEndpointRefTest.cpp
index 41f629364..32433c410 100644
--- a/cpp/celeborn/network/tests/NettyRpcEndpointRefTest.cpp
+++ b/cpp/celeborn/network/tests/NettyRpcEndpointRefTest.cpp
@@ -29,18 +29,18 @@ class MockTransportClient : public TransportClient {
public:
MockTransportClient()
: TransportClient(nullptr, nullptr, MS(100)),
- response_(
- RpcResponse(0, memory::ReadOnlyByteBuffer::createEmptyBuffer())),
+ respPromise_(),
+ respFuture_(respPromise_.getFuture()),
request_(
RpcRequest(0, memory::ReadOnlyByteBuffer::createEmptyBuffer())) {}
RpcResponse sendRpcRequestSync(const RpcRequest& request, Timeout timeout)
override {
request_ = request;
- return response_;
+ return std::move(respFuture_).get(timeout);
}
void setResponse(const RpcResponse& response) {
- response_ = response;
+ respPromise_.setValue(response);
}
RpcRequest getRequest() {
@@ -48,10 +48,23 @@ class MockTransportClient : public TransportClient {
}
private:
- RpcResponse response_;
+ folly::Promise<RpcResponse> respPromise_;
+ folly::Future<RpcResponse> respFuture_;
RpcRequest request_;
};
+std::unique_ptr<RpcResponse> makeResponseForNettyRpcEndpointRef(
+ const protocol::TransportMessage& transportMessage,
+ long requestId) {
+ auto msgBody = transportMessage.toReadOnlyByteBuffer();
+ auto writeBuffer = memory::ByteBuffer::createWriteOnly(sizeof(uint8_t));
+ writeBuffer->write<uint8_t>(NettyRpcEndpointRef::kNativeTransportMessageFlag);
+ auto flagBody = memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
+ auto concatBody = memory::ByteBuffer::concat(*flagBody, *msgBody);
+
+ return std::make_unique<RpcResponse>(requestId, std::move(concatBody));
+}
+
void readUTF(memory::ReadOnlyByteBuffer& buffer, std::string& host) {
int size = buffer.read<short>();
host = buffer.readToString(size);
@@ -65,54 +78,186 @@ void readRpcAddress(
readUTF(buffer, host);
port = buffer.read<int32_t>();
}
+
+void verifyRequestForNettyRpcEndpointRef(
+ const RpcRequest& request,
+ const std::string& srcName,
+ const std::string& srcHost,
+ const int srcPort,
+ const std::string& dstHost,
+ const int dstPort) {
+ auto sentBody = request.body();
+ std::string host;
+ int port;
+ readRpcAddress(*sentBody, host, port);
+ EXPECT_EQ(host, srcHost);
+ EXPECT_EQ(port, srcPort);
+ readRpcAddress(*sentBody, host, port);
+ EXPECT_EQ(host, dstHost);
+ EXPECT_EQ(port, dstPort);
+ std::string name;
+ readUTF(*sentBody, name);
+ EXPECT_EQ(name, srcName);
+ EXPECT_EQ(
+ sentBody->read<uint8_t>(),
+ NettyRpcEndpointRef::kNativeTransportMessageFlag);
+}
+
} // namespace
-TEST(NettyRpcEndpointRefTest, askSyncGetReducerFileGroup) {
+TEST(NettyRpcEndpointRefTest, askSyncRegisterShuffle) {
+ auto mockedClient = std::make_shared<MockTransportClient>();
+ const std::string srcName = "test-name";
+ const std::string srcHost = "test-src-host";
+ const int srcPort = 100;
+ const std::string dstHost = "test-dst-host";
+ const int dstPort = 101;
+ const auto conf = std::make_shared<conf::CelebornConf>();
+ auto nettyRpcEndpointRef = NettyRpcEndpointRef(
+ srcName, srcHost, srcPort, dstHost, dstPort, mockedClient, *conf);
+
+ PbRegisterShuffleResponse pbRegisterShuffleResponse;
+ const int status = 5;
+ pbRegisterShuffleResponse.set_status(5);
+ protocol::TransportMessage transportMessage(
+ REGISTER_SHUFFLE_RESPONSE, pbRegisterShuffleResponse.SerializeAsString());
+ auto rpcResponse = makeResponseForNettyRpcEndpointRef(transportMessage, 1000);
+ mockedClient->setResponse(*rpcResponse);
+
+ protocol::RegisterShuffle request{1001};
+ auto response = nettyRpcEndpointRef.askSync<
+ protocol::RegisterShuffle,
+ protocol::RegisterShuffleResponse>(request, MS(100));
+ EXPECT_EQ(response->status, status);
+
+ auto sentRequest = mockedClient->getRequest();
+ verifyRequestForNettyRpcEndpointRef(
+ sentRequest, srcName, srcHost, srcPort, dstHost, dstPort);
+}
+
+TEST(NettyRpcEndpointRefTest, askSyncRevive) {
+ auto mockedClient = std::make_shared<MockTransportClient>();
+ const std::string srcName = "test-name";
+ const std::string srcHost = "test-src-host";
+ const int srcPort = 100;
+ const std::string dstHost = "test-dst-host";
+ const int dstPort = 101;
+ const auto conf = std::make_shared<conf::CelebornConf>();
+ auto nettyRpcEndpointRef = NettyRpcEndpointRef(
+ srcName, srcHost, srcPort, dstHost, dstPort, mockedClient, *conf);
+
+ PbRevive pbRevive;
+ pbRevive.set_shuffleid(5);
+ protocol::TransportMessage transportMessage(
+ CHANGE_LOCATION_RESPONSE, pbRevive.SerializeAsString());
+ auto rpcResponse = makeResponseForNettyRpcEndpointRef(transportMessage, 1000);
+ mockedClient->setResponse(*rpcResponse);
+
+ protocol::Revive request{1001};
+ auto response =
+ nettyRpcEndpointRef
+ .askSync<protocol::Revive, protocol::ChangeLocationResponse>(
+ request, MS(100));
+
+ auto sentRequest = mockedClient->getRequest();
+ verifyRequestForNettyRpcEndpointRef(
+ sentRequest, srcName, srcHost, srcPort, dstHost, dstPort);
+}
+
+TEST(NettyRpcEndpointRefTest, askSyncMapperEnd) {
auto mockedClient = std::make_shared<MockTransportClient>();
const std::string srcName = "test-name";
const std::string srcHost = "test-src-host";
const int srcPort = 100;
const std::string dstHost = "test-dst-host";
const int dstPort = 101;
+ const auto conf = std::make_shared<conf::CelebornConf>();
auto nettyRpcEndpointRef = NettyRpcEndpointRef(
- srcName, srcHost, srcPort, dstHost, dstPort, mockedClient);
+ srcName, srcHost, srcPort, dstHost, dstPort, mockedClient, *conf);
+
+ PbMapperEndResponse pbMapperEndResponse;
+ const int status = 5;
+ pbMapperEndResponse.set_status(5);
+ protocol::TransportMessage transportMessage(
+ MAPPER_END_RESPONSE, pbMapperEndResponse.SerializeAsString());
+ auto rpcResponse = makeResponseForNettyRpcEndpointRef(transportMessage, 1000);
+ mockedClient->setResponse(*rpcResponse);
- const std::string responseBody = "test-response-body";
+ protocol::MapperEnd request{1001};
+ auto response =
+ nettyRpcEndpointRef
+ .askSync<protocol::MapperEnd, protocol::MapperEndResponse>(
+ request, MS(100));
+ EXPECT_EQ(response->status, status);
+
+ auto sentRequest = mockedClient->getRequest();
+ verifyRequestForNettyRpcEndpointRef(
+ sentRequest, srcName, srcHost, srcPort, dstHost, dstPort);
+}
+
+TEST(NettyRpcEndpointRefTest, askSyncGetReducerFileGroup) {
+ auto mockedClient = std::make_shared<MockTransportClient>();
+ const std::string srcName = "test-name";
+ const std::string srcHost = "test-src-host";
+ const int srcPort = 100;
+ const std::string dstHost = "test-dst-host";
+ const int dstPort = 101;
+ const auto conf = std::make_shared<conf::CelebornConf>();
+ auto nettyRpcEndpointRef = NettyRpcEndpointRef(
+ srcName, srcHost, srcPort, dstHost, dstPort, mockedClient, *conf);
- // rpcResponse -> transportMessage -> getReducerFileGroupResponse
PbGetReducerFileGroupResponse pbGetReducerFileGroupResponse;
const int status = 5;
pbGetReducerFileGroupResponse.set_status(5);
protocol::TransportMessage transportMessage(
GET_REDUCER_FILE_GROUP_RESPONSE,
pbGetReducerFileGroupResponse.SerializeAsString());
- auto msgBody = transportMessage.toReadOnlyByteBuffer();
- auto writeBuffer = memory::ByteBuffer::createWriteOnly(sizeof(uint8_t));
- writeBuffer->write<uint8_t>(NettyRpcEndpointRef::kNativeTransportMessageFlag);
- auto flagBody = memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
- auto concatBody = memory::ByteBuffer::concat(*flagBody, *msgBody);
-
- auto rpcResponse = std::make_unique<RpcResponse>(1000, std::move(concatBody));
+ auto rpcResponse = makeResponseForNettyRpcEndpointRef(transportMessage, 1000);
mockedClient->setResponse(*rpcResponse);
protocol::GetReducerFileGroup request{1001};
- auto response = nettyRpcEndpointRef.askSync(request, MS(100));
+ auto response = nettyRpcEndpointRef.askSync<
+ protocol::GetReducerFileGroup,
+ protocol::GetReducerFileGroupResponse>(request, MS(100));
EXPECT_EQ(response->status, status);
auto sentRequest = mockedClient->getRequest();
- auto sentBody = sentRequest.body();
- std::string host;
- int port;
- readRpcAddress(*sentBody, host, port);
- EXPECT_EQ(host, srcHost);
- EXPECT_EQ(port, srcPort);
- readRpcAddress(*sentBody, host, port);
- EXPECT_EQ(host, dstHost);
- EXPECT_EQ(port, dstPort);
- std::string name;
- readUTF(*sentBody, name);
- EXPECT_EQ(name, srcName);
- EXPECT_EQ(
- sentBody->read<uint8_t>(),
- NettyRpcEndpointRef::kNativeTransportMessageFlag);
+ verifyRequestForNettyRpcEndpointRef(
+ sentRequest, srcName, srcHost, srcPort, dstHost, dstPort);
+}
+
+TEST(NettyRpcEndpointRefTest, askSyncTimeout) {
+ auto mockedClient = std::make_shared<MockTransportClient>();
+ const std::string srcName = "test-name";
+ const std::string srcHost = "test-src-host";
+ const int srcPort = 100;
+ const std::string dstHost = "test-dst-host";
+ const int dstPort = 101;
+ const auto conf = std::make_shared<conf::CelebornConf>();
+ auto nettyRpcEndpointRef = NettyRpcEndpointRef(
+ srcName, srcHost, srcPort, dstHost, dstPort, mockedClient, *conf);
+
+ const auto timeoutInterval = MS(200);
+ const auto sleepInterval = MS(100);
+ protocol::MapperEnd request{1001};
+ bool timeoutHappened = false;
+ std::thread syncThread([&]() {
+ try {
+ auto response =
+ nettyRpcEndpointRef
+ .askSync<protocol::MapperEnd, protocol::MapperEndResponse>(
+ request, timeoutInterval);
+ } catch (std::exception e) {
+ timeoutHappened = true;
+ }
+ });
+ std::this_thread::sleep_for(sleepInterval);
+
+ auto sentRequest = mockedClient->getRequest();
+ verifyRequestForNettyRpcEndpointRef(
+ sentRequest, srcName, srcHost, srcPort, dstHost, dstPort);
+ EXPECT_FALSE(timeoutHappened);
+
+ syncThread.join();
+ EXPECT_TRUE(timeoutHappened);
}