This is an automated email from the ASF dual-hosted git repository.
ethanfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new b74e05b60 [CELEBORN-1821][CIP-14] Add controlMessages to cppClient
b74e05b60 is described below
commit b74e05b603bed45e8d719b2c0e7988a552f578c3
Author: HolyLow <[email protected]>
AuthorDate: Wed Jan 8 13:35:09 2025 +0800
[CELEBORN-1821][CIP-14] Add controlMessages to cppClient
### What changes were proposed in this pull request?
This PR adds ControlMessages to cppClient.
### Why are the changes needed?
The ControlMessages are used to communicate with the CelebornServer and
LifecycleManager.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Compilation and UTs.
Closes #3052 from HolyLow/issue/celeborn-1821-add-control-messages-to-cpp-client.
Authored-by: HolyLow <[email protected]>
Signed-off-by: mingji <[email protected]>
---
cpp/celeborn/protocol/CMakeLists.txt | 3 +-
cpp/celeborn/protocol/ControlMessages.cpp | 182 +++++++++++++++
cpp/celeborn/protocol/ControlMessages.h | 107 +++++++++
cpp/celeborn/protocol/PartitionLocation.cpp | 31 +++
cpp/celeborn/protocol/PartitionLocation.h | 4 +
cpp/celeborn/protocol/tests/CMakeLists.txt | 3 +-
.../protocol/tests/ControlMessagesTest.cpp | 244 +++++++++++++++++++++
cpp/celeborn/utils/CMakeLists.txt | 3 +-
cpp/celeborn/utils/CelebornUtils.cpp | 61 ++++++
cpp/celeborn/utils/CelebornUtils.h | 57 ++++-
10 files changed, 691 insertions(+), 4 deletions(-)
diff --git a/cpp/celeborn/protocol/CMakeLists.txt b/cpp/celeborn/protocol/CMakeLists.txt
index 6a6ef88a7..aa601c479 100644
--- a/cpp/celeborn/protocol/CMakeLists.txt
+++ b/cpp/celeborn/protocol/CMakeLists.txt
@@ -16,7 +16,8 @@ add_library(
protocol
STATIC
PartitionLocation.cpp
- TransportMessage.cpp)
+ TransportMessage.cpp
+ ControlMessages.cpp)
target_include_directories(protocol PUBLIC ${CMAKE_BINARY_DIR})
diff --git a/cpp/celeborn/protocol/ControlMessages.cpp b/cpp/celeborn/protocol/ControlMessages.cpp
new file mode 100644
index 000000000..1a8fa6ce3
--- /dev/null
+++ b/cpp/celeborn/protocol/ControlMessages.cpp
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/protocol/ControlMessages.h"
+#include "celeborn/utils/CelebornUtils.h"
+
+namespace celeborn {
+namespace protocol {
+// Encodes this request as a GET_REDUCER_FILE_GROUP transport message whose
+// payload is the serialized PbGetReducerFileGroup carrying `shuffleId`.
+TransportMessage GetReducerFileGroup::toTransportMessage() const {
+  PbGetReducerFileGroup request;
+  request.set_shuffleid(shuffleId);
+  std::string serialized = request.SerializeAsString();
+  const MessageType messageType = GET_REDUCER_FILE_GROUP;
+  return TransportMessage(messageType, std::move(serialized));
+}
+
+// Decodes a GET_REDUCER_FILE_GROUP_RESPONSE transport message.
+//
+// Only the packed PartitionLocation encoding is accepted: a file group that
+// still carries legacy `locations` entries fails the check below. For each of
+// the first `inputLocationSize` packed locations, `peerIndexes` holds either
+// INT_MAX (no peer) or the index of its peer; the peer is attached as
+// `replicaPeer` of the PRIMARY location before the PRIMARY is inserted into
+// the file group.
+//
+// Malformed input (type mismatch, out-of-range indexes, a peer referenced
+// twice, wrong PRIMARY/REPLICA pairing) fails a CELEBORN_CHECK instead of
+// indexing out of bounds or dereferencing a moved-from pointer.
+std::unique_ptr<GetReducerFileGroupResponse>
+GetReducerFileGroupResponse::fromTransportMessage(
+    const TransportMessage& transportMessage) {
+  CELEBORN_CHECK(
+      transportMessage.type() == GET_REDUCER_FILE_GROUP_RESPONSE,
+      "transportMessageType mismatch");
+  // Binding by const reference is valid whether payload() returns a value or
+  // a reference, and avoids a copy in the latter case.
+  const auto& payload = transportMessage.payload();
+  auto pbResponse = utils::parseProto<PbGetReducerFileGroupResponse>(
+      reinterpret_cast<const uint8_t*>(payload.c_str()), payload.size());
+  auto response = std::make_unique<GetReducerFileGroupResponse>();
+  response->status = toStatusCode(pbResponse->status());
+  // Iterate the protobuf containers in place; `pbResponse` owns them for the
+  // whole function, so no deep copy is needed.
+  for (const auto& kv : pbResponse->filegroups()) {
+    auto& fileGroup = response->fileGroups[kv.first];
+    // Legacy mode is deprecated.
+    CELEBORN_CHECK_EQ(
+        kv.second.locations().size(),
+        0,
+        "legecy PartitionLocation pb is deprecated");
+    // Packed mode: must use packedPartitionLocations.
+    const auto& pbPair = kv.second.partitionlocationspair();
+    const auto& pbPacked = pbPair.locations();
+    const int inputLocationSize = pbPair.inputlocationsize();
+    const int numLocations = pbPacked.ids().size();
+    // Validate every index used below up front, so a malformed message is
+    // rejected rather than read out of range.
+    CELEBORN_CHECK_GE(inputLocationSize, 0);
+    CELEBORN_CHECK_GE(numLocations, inputLocationSize);
+    CELEBORN_CHECK_GE(pbPair.peerindexes_size(), inputLocationSize);
+
+    std::vector<std::unique_ptr<PartitionLocation>> partialLocations;
+    partialLocations.reserve(numLocations);
+    for (int idx = 0; idx < numLocations; idx++) {
+      partialLocations.push_back(
+          PartitionLocation::fromPackedPb(pbPacked, idx));
+    }
+    for (int idx = 0; idx < inputLocationSize; idx++) {
+      const auto replicaIdx = pbPair.peerindexes(idx);
+      if (replicaIdx == INT_MAX) {
+        // has no peer
+        fileGroup.insert(std::move(partialLocations[idx]));
+        continue;
+      }
+      // has peer
+      CELEBORN_CHECK_GE(replicaIdx, inputLocationSize);
+      CELEBORN_CHECK_LT(replicaIdx, numLocations);
+      auto location = std::move(partialLocations[idx]);
+      auto peerLocation = std::move(partialLocations[replicaIdx]);
+      // A peer index used by an earlier input location has already been
+      // moved out of `partialLocations` and would be null here.
+      CELEBORN_CHECK(
+          peerLocation != nullptr,
+          "peer PartitionLocation is referenced more than once");
+      // make sure the location is primary and peer location is replica
+      if (location->mode == PartitionLocation::Mode::REPLICA) {
+        std::swap(location, peerLocation);
+      }
+      CELEBORN_CHECK(location->mode == PartitionLocation::Mode::PRIMARY);
+      CELEBORN_CHECK(peerLocation->mode == PartitionLocation::Mode::REPLICA);
+      location->replicaPeer = std::move(peerLocation);
+      fileGroup.insert(std::move(location));
+    }
+  }
+  const auto& attempts = pbResponse->attempts();
+  response->attempts.reserve(attempts.size());
+  for (auto attempt : attempts) {
+    response->attempts.push_back(attempt);
+  }
+  for (auto partitionId : pbResponse->partitionids()) {
+    response->partitionIds.insert(partitionId);
+  }
+  // Returning the local by name lets copy elision / implicit move apply.
+  return response;
+}
+
+OpenStream::OpenStream(
+ const std::string& shuffleKey,
+ const std::string& filename,
+ int32_t startMapIndex,
+ int32_t endMapIndex)
+ : shuffleKey(shuffleKey),
+ filename(filename),
+ startMapIndex(startMapIndex),
+ endMapIndex(endMapIndex) {}
+
+// Encodes this request as an OPEN_STREAM transport message. `startMapIndex`
+// and `endMapIndex` are written to the `startIndex` / `endIndex` fields of
+// PbOpenStream.
+TransportMessage OpenStream::toTransportMessage() const {
+  PbOpenStream request;
+  request.set_shufflekey(shuffleKey);
+  request.set_filename(filename);
+  request.set_startindex(startMapIndex);
+  request.set_endindex(endMapIndex);
+  std::string serialized = request.SerializeAsString();
+  const MessageType messageType = OPEN_STREAM;
+  return TransportMessage(messageType, std::move(serialized));
+}
+
+// Decodes a STREAM_HANDLER transport message into a StreamHandler. Fails a
+// CELEBORN_CHECK when the message has any other type.
+std::unique_ptr<StreamHandler> StreamHandler::fromTransportMessage(
+    const TransportMessage& transportMessage) {
+  CELEBORN_CHECK(
+      transportMessage.type() == STREAM_HANDLER,
+      "transportMessageType should be STREAM_HANDLER");
+  auto payload = transportMessage.payload();
+  const auto* payloadBytes = reinterpret_cast<const uint8_t*>(payload.c_str());
+  auto pb = utils::parseProto<PbStreamHandler>(payloadBytes, payload.size());
+
+  auto handler = std::make_unique<StreamHandler>();
+  handler->streamId = pb->streamid();
+  handler->numChunks = pb->numchunks();
+  // Copy the chunk offsets in their wire order.
+  const auto& pbOffsets = pb->chunkoffsets();
+  for (auto it = pbOffsets.begin(); it != pbOffsets.end(); ++it) {
+    handler->chunkOffsets.push_back(*it);
+  }
+  handler->fullPath = pb->fullpath();
+  return handler;
+}
+
+// Builds a heap-allocated PbStreamChunkSlice mirroring the four fields of
+// this slice; the caller owns the returned message.
+std::unique_ptr<PbStreamChunkSlice> StreamChunkSlice::toProto() const {
+  auto proto = std::make_unique<PbStreamChunkSlice>();
+  PbStreamChunkSlice& fields = *proto;
+  fields.set_streamid(streamId);
+  fields.set_chunkindex(chunkIndex);
+  fields.set_offset(offset);
+  fields.set_len(len);
+  return proto;
+}
+
+// Reads one StreamChunkSlice from `data`, advancing the buffer.
+//
+// The encoding consumed here is a 64-bit streamId followed by three 32-bit
+// values (chunkIndex, offset, len), i.e. 20 bytes; fewer remaining bytes fail
+// the check below. Fixed-width types are read explicitly because `long` and
+// `int` do not have the same width on every platform, which would desync the
+// decoder from the 20-byte layout.
+StreamChunkSlice StreamChunkSlice::decodeFrom(
+    memory::ReadOnlyByteBuffer& data) {
+  static_assert(
+      sizeof(int64_t) + 3 * sizeof(int32_t) == 20,
+      "encoded StreamChunkSlice must be 20 bytes");
+  CELEBORN_CHECK_GE(data.remainingSize(), 20);
+  StreamChunkSlice slice;
+  slice.streamId = data.read<int64_t>();
+  slice.chunkIndex = data.read<int32_t>();
+  slice.offset = data.read<int32_t>();
+  slice.len = data.read<int32_t>();
+  return slice;
+}
+
+// Combines the std::hash of each field, shifting the hash of chunkIndex,
+// offset and len left by 1, 2 and 3 bits respectively before XOR-ing, so that
+// equal values in different fields do not cancel out.
+size_t StreamChunkSlice::Hasher::operator()(const StreamChunkSlice& lhs) const {
+  const std::hash<int> intHash;
+  size_t combined = std::hash<long>()(lhs.streamId);
+  combined ^= intHash(lhs.chunkIndex) << 1;
+  combined ^= intHash(lhs.offset) << 2;
+  combined ^= intHash(lhs.len) << 3;
+  return combined;
+}
+
+// Encodes this request as a CHUNK_FETCH_REQUEST transport message whose
+// payload is a PbChunkFetchRequest wrapping the slice.
+TransportMessage ChunkFetchRequest::toTransportMessage() const {
+  MessageType type = CHUNK_FETCH_REQUEST;
+  PbChunkFetchRequest pb;
+  // `pb` is a plain stack message (no arena), so hand the heap-allocated
+  // sub-message over with the regular ownership-transferring setter rather
+  // than the unsafe_arena_ variant, which skips the arena consistency checks.
+  pb.set_allocated_streamchunkslice(streamChunkSlice.toProto().release());
+  std::string payload = pb.SerializeAsString();
+  return TransportMessage(type, std::move(payload));
+}
+
+// Encodes this notification as a BUFFER_STREAM_END transport message. The
+// stream type is always set to ChunkStream.
+TransportMessage BufferStreamEnd::toTransportMessage() const {
+  PbBufferStreamEnd notification;
+  notification.set_streamid(streamId);
+  notification.set_streamtype(ChunkStream);
+  std::string serialized = notification.SerializeAsString();
+  const MessageType messageType = BUFFER_STREAM_END;
+  return TransportMessage(messageType, std::move(serialized));
+}
+} // namespace protocol
+} // namespace celeborn
diff --git a/cpp/celeborn/protocol/ControlMessages.h b/cpp/celeborn/protocol/ControlMessages.h
new file mode 100644
index 000000000..7d74a4900
--- /dev/null
+++ b/cpp/celeborn/protocol/ControlMessages.h
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <map>
+#include <set>
+
+#include "celeborn/protocol/PartitionLocation.h"
+#include "celeborn/protocol/StatusCode.h"
+#include "celeborn/protocol/TransportMessage.h"
+
+namespace celeborn {
+namespace protocol {
+// Request message identified by a shuffle id.
+struct GetReducerFileGroup {
+ // Written to the `shuffleId` field of PbGetReducerFileGroup.
+ int shuffleId;
+
+ // Serializes this request into a GET_REDUCER_FILE_GROUP TransportMessage.
+ TransportMessage toTransportMessage() const;
+};
+
+// Decoded form of a GET_REDUCER_FILE_GROUP_RESPONSE message.
+struct GetReducerFileGroupResponse {
+ // Status converted from the protobuf `status` field via toStatusCode().
+ StatusCode status;
+ // Locations grouped by the integer key of the protobuf `fileGroups` map
+ // (presumably the partition id -- TODO confirm). The inner std::set uses the
+ // default shared_ptr ordering, i.e. it compares pointer values, not the
+ // contents of the locations.
+ std::map<int, std::set<std::shared_ptr<const PartitionLocation>>> fileGroups;
+ // Copy of the protobuf `attempts` field, in wire order.
+ std::vector<int> attempts;
+ // Copy of the protobuf `partitionIds` field, de-duplicated by the set.
+ std::set<int> partitionIds;
+
+ // Parses `transportMessage`; its type must be
+ // GET_REDUCER_FILE_GROUP_RESPONSE.
+ static std::unique_ptr<GetReducerFileGroupResponse> fromTransportMessage(
+ const TransportMessage& transportMessage);
+};
+
+// Request message describing which file of a shuffle to open and the map
+// index range to apply.
+struct OpenStream {
+ std::string shuffleKey;
+ std::string filename;
+ // Serialized as PbOpenStream `startIndex`.
+ int32_t startMapIndex;
+ // Serialized as PbOpenStream `endIndex`.
+ int32_t endMapIndex;
+
+ // Copies all four arguments into the corresponding members.
+ OpenStream(
+ const std::string& shuffleKey,
+ const std::string& filename,
+ int32_t startMapIndex,
+ int32_t endMapIndex);
+
+ // Serializes this request into an OPEN_STREAM TransportMessage.
+ TransportMessage toTransportMessage() const;
+};
+
+// Decoded form of a STREAM_HANDLER message; every member mirrors the
+// same-named field of PbStreamHandler.
+struct StreamHandler {
+ int64_t streamId;
+ int32_t numChunks;
+ // Chunk offsets in the order they appear in the protobuf message.
+ std::vector<int64_t> chunkOffsets;
+ std::string fullPath;
+
+ // Parses `transportMessage`; its type must be STREAM_HANDLER.
+ static std::unique_ptr<StreamHandler> fromTransportMessage(
+ const TransportMessage& transportMessage);
+};
+
+// Value type naming a slice of a stream chunk by
+// (streamId, chunkIndex, offset, len). `offset` defaults to 0 and `len` to
+// INT_MAX; `streamId` and `chunkIndex` have no default and must be assigned.
+struct StreamChunkSlice {
+  long streamId;
+  int chunkIndex;
+  int offset{0};
+  int len{INT_MAX};
+
+  // Converts this slice into a newly allocated PbStreamChunkSlice.
+  std::unique_ptr<PbStreamChunkSlice> toProto() const;
+
+  // Reads one slice from `data`, consuming its bytes.
+  static StreamChunkSlice decodeFrom(memory::ReadOnlyByteBuffer& data);
+
+  // Renders the four fields as "streamId-chunkIndex-offset-len".
+  std::string toString() const {
+    std::string text = std::to_string(streamId);
+    for (const int field : {chunkIndex, offset, len}) {
+      text.append("-").append(std::to_string(field));
+    }
+    return text;
+  }
+
+  // Two slices are equal when all four fields match.
+  bool operator==(const StreamChunkSlice& rhs) const {
+    if (streamId != rhs.streamId || chunkIndex != rhs.chunkIndex) {
+      return false;
+    }
+    return offset == rhs.offset && len == rhs.len;
+  }
+
+  // Hash functor so the slice can key unordered containers.
+  struct Hasher {
+    size_t operator()(const StreamChunkSlice& lhs) const;
+  };
+};
+
+// Request message wrapping the slice to fetch.
+struct ChunkFetchRequest {
+ StreamChunkSlice streamChunkSlice;
+
+ // Serializes this request into a CHUNK_FETCH_REQUEST TransportMessage.
+ TransportMessage toTransportMessage() const;
+};
+
+// Notification message identified by a stream id.
+struct BufferStreamEnd {
+ long streamId;
+
+ // Serializes this notification into a BUFFER_STREAM_END TransportMessage;
+ // the stream type written is always ChunkStream.
+ TransportMessage toTransportMessage() const;
+};
+} // namespace protocol
+} // namespace celeborn
diff --git a/cpp/celeborn/protocol/PartitionLocation.cpp b/cpp/celeborn/protocol/PartitionLocation.cpp
index aee59586b..5cd9bdecb 100644
--- a/cpp/celeborn/protocol/PartitionLocation.cpp
+++ b/cpp/celeborn/protocol/PartitionLocation.cpp
@@ -53,6 +53,37 @@ std::unique_ptr<const PartitionLocation> PartitionLocation::fromPb(
return std::move(result);
}
+
+// Materializes the `idx`-th entry of a packed PartitionLocation message.
+//
+// The worker id string is looked up through `workerIds(idx)` in
+// `workerIdsSet` and split into host plus four ports (rpc, push, fetch,
+// replicate). A non-empty file path is prefixed with the mount point looked
+// up through `mountPoints(idx)`; an empty file path stays empty. The returned
+// location has no replica peer.
+std::unique_ptr<PartitionLocation> PartitionLocation::fromPackedPb(
+    const PbPackedPartitionLocations& pb,
+    int idx) {
+  const auto& packedWorkerId = pb.workeridsset(pb.workerids(idx));
+  const auto hostAndPorts =
+      utils::parseColonSeparatedHostPorts(packedWorkerId, 4);
+  std::string resolvedPath = pb.filepaths(idx);
+  if (!resolvedPath.empty()) {
+    resolvedPath = pb.mountpointsset(pb.mountpoints(idx)) + pb.filepaths(idx);
+  }
+
+  auto location = std::make_unique<PartitionLocation>();
+  location->id = pb.ids(idx);
+  location->epoch = pb.epoches(idx);
+  location->host = hostAndPorts[0];
+  location->rpcPort = utils::strv2val<int>(hostAndPorts[1]);
+  location->pushPort = utils::strv2val<int>(hostAndPorts[2]);
+  location->fetchPort = utils::strv2val<int>(hostAndPorts[3]);
+  location->replicatePort = utils::strv2val<int>(hostAndPorts[4]);
+  location->mode = static_cast<Mode>(pb.modes(idx));
+  location->replicaPeer = nullptr;
+
+  location->storageInfo = std::make_unique<StorageInfo>();
+  auto& storage = *location->storageInfo;
+  storage.type = static_cast<StorageInfo::Type>(pb.types(idx));
+  storage.mountPoint = pb.mountpointsset(pb.mountpoints(idx));
+  storage.finalResult = pb.finalresult(idx);
+  storage.filePath = resolvedPath;
+  storage.availableStorageTypes = pb.availablestoragetypes(idx);
+
+  return location;
+}
+
PartitionLocation::PartitionLocation(const PartitionLocation& other)
: id(other.id),
epoch(other.epoch),
diff --git a/cpp/celeborn/protocol/PartitionLocation.h b/cpp/celeborn/protocol/PartitionLocation.h
index c485c1607..700773cba 100644
--- a/cpp/celeborn/protocol/PartitionLocation.h
+++ b/cpp/celeborn/protocol/PartitionLocation.h
@@ -80,6 +80,10 @@ struct PartitionLocation {
static std::unique_ptr<const PartitionLocation> fromPb(
const PbPartitionLocation& pb);
+ // Builds the location stored at position `idx` of the packed message `pb`.
+ static std::unique_ptr<PartitionLocation> fromPackedPb(
+ const PbPackedPartitionLocations& pb,
+ int idx);
+
PartitionLocation() = default;
PartitionLocation(const PartitionLocation& other);
diff --git a/cpp/celeborn/protocol/tests/CMakeLists.txt b/cpp/celeborn/protocol/tests/CMakeLists.txt
index f5a00db2a..cb2a2378f 100644
--- a/cpp/celeborn/protocol/tests/CMakeLists.txt
+++ b/cpp/celeborn/protocol/tests/CMakeLists.txt
@@ -16,7 +16,8 @@
add_executable(
celeborn_protocol_test
PartitionLocationTest.cpp
- TransportMessageTest.cpp)
+ TransportMessageTest.cpp
+ ControlMessagesTest.cpp)
add_test(NAME celeborn_protocol_test COMMAND celeborn_protocol_test)
diff --git a/cpp/celeborn/protocol/tests/ControlMessagesTest.cpp b/cpp/celeborn/protocol/tests/ControlMessagesTest.cpp
new file mode 100644
index 000000000..36003491a
--- /dev/null
+++ b/cpp/celeborn/protocol/tests/ControlMessagesTest.cpp
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "celeborn/proto/TransportMessagesCpp.pb.h"
+#include "celeborn/protocol/ControlMessages.h"
+
+#include "celeborn/utils/CelebornUtils.h"
+
+using namespace celeborn;
+using namespace celeborn::protocol;
+
+namespace {
+// Appends one packed PartitionLocation entry with fixed test values and the
+// given `mode`. `idx` is stored as the entry's index into both the worker-id
+// set and the mount-point set, each of which gains one element per call.
+void generatePackedPartitionLocationPb(
+    PbPackedPartitionLocations& pbPackedPartitionLocations,
+    int idx,
+    PartitionLocation::Mode mode) {
+  auto& packed = pbPackedPartitionLocations;
+  // Identity of the location.
+  packed.add_ids(1);
+  packed.add_epoches(101);
+  packed.add_modes(mode);
+  // Worker address, referenced indirectly through the worker-id set.
+  packed.add_workerids(idx);
+  packed.add_workeridsset("test-host:1001:1002:1003:1004");
+  // Storage information.
+  packed.add_mountpoints(idx);
+  packed.add_mountpointsset("test-mountpoint/");
+  packed.add_filepaths("test-filepath");
+  packed.add_types(1);
+  packed.add_finalresult(true);
+  packed.add_availablestoragetypes(1);
+}
+
+// Expects `partitionLocation` to carry exactly the values written by
+// generatePackedPartitionLocationPb (the mode is checked by the caller).
+void verifyUnpackedPartitionLocation(
+    const PartitionLocation* partitionLocation) {
+  const PartitionLocation& location = *partitionLocation;
+  EXPECT_EQ(location.id, 1);
+  EXPECT_EQ(location.epoch, 101);
+  EXPECT_EQ(location.host, "test-host");
+  EXPECT_EQ(location.rpcPort, 1001);
+  EXPECT_EQ(location.pushPort, 1002);
+  EXPECT_EQ(location.fetchPort, 1003);
+  EXPECT_EQ(location.replicatePort, 1004);
+
+  const auto* storage = location.storageInfo.get();
+  EXPECT_EQ(storage->type, 1);
+  EXPECT_EQ(storage->mountPoint, "test-mountpoint/");
+  EXPECT_EQ(storage->finalResult, true);
+  EXPECT_EQ(storage->filePath, "test-mountpoint/test-filepath");
+  EXPECT_EQ(storage->availableStorageTypes, 1);
+}
+} // namespace
+
+// The request must serialize to a GET_REDUCER_FILE_GROUP message whose
+// payload round-trips the shuffle id.
+TEST(ControlMessagesTest, getReducerFileGroup) {
+  auto request = std::make_unique<GetReducerFileGroup>();
+  request->shuffleId = 1000;
+
+  auto message = request->toTransportMessage();
+  EXPECT_EQ(message.type(), GET_REDUCER_FILE_GROUP);
+
+  auto bytes = message.payload();
+  auto decoded = utils::parseProto<PbGetReducerFileGroup>(
+      reinterpret_cast<const uint8_t*>(bytes.c_str()), bytes.size());
+  EXPECT_EQ(decoded->shuffleid(), 1000);
+}
+
+// A response whose file group still uses the legacy `locations` field must be
+// rejected with CelebornRuntimeError.
+TEST(ControlMessagesTest, getReducerFileGroupResponseLegacyModeDeprecated) {
+  PbGetReducerFileGroupResponse pbResponse;
+  pbResponse.set_status(1);
+  for (int attempt = 0; attempt < 4; attempt++) {
+    pbResponse.add_attempts(attempt);
+  }
+  for (int partitionId = 0; partitionId < 6; partitionId++) {
+    pbResponse.add_partitionids(partitionId);
+  }
+  // One file group carrying a single (empty) legacy location.
+  PbFileGroup legacyFileGroup;
+  legacyFileGroup.add_locations();
+  auto fileGroupsById = pbResponse.mutable_filegroups();
+  fileGroupsById->insert({0, legacyFileGroup});
+
+  TransportMessage message(
+      GET_REDUCER_FILE_GROUP_RESPONSE, pbResponse.SerializeAsString());
+  EXPECT_THROW(
+      GetReducerFileGroupResponse::fromTransportMessage(message),
+      utils::CelebornRuntimeError);
+}
+
+// A packed response with one primary location and its replica peer must
+// decode into a single file group entry: the primary, with the replica
+// attached as `replicaPeer`.
+TEST(ControlMessagesTest, getReducerFileGroupResponsePackedMode) {
+  PbGetReducerFileGroupResponse pbResponse;
+  pbResponse.set_status(1);
+  for (int attempt = 0; attempt < 4; attempt++) {
+    pbResponse.add_attempts(attempt);
+  }
+  for (int partitionId = 0; partitionId < 6; partitionId++) {
+    pbResponse.add_partitionids(partitionId);
+  }
+
+  PbFileGroup pbFileGroup;
+  auto pbPair = pbFileGroup.mutable_partitionlocationspair();
+  auto pbPacked = pbPair->mutable_locations();
+  // Has one inputLocation, with offset 0.
+  pbPair->set_inputlocationsize(1);
+  // The peerIndex 1 is replica.
+  pbPair->add_peerindexes(1);
+  // Add the two partitionLocations, one is primary and the other is replica.
+  generatePackedPartitionLocationPb(
+      *pbPacked, 0, PartitionLocation::Mode::PRIMARY);
+  generatePackedPartitionLocationPb(
+      *pbPacked, 1, PartitionLocation::Mode::REPLICA);
+
+  auto fileGroupsById = pbResponse.mutable_filegroups();
+  fileGroupsById->insert({0, pbFileGroup});
+
+  TransportMessage message(
+      GET_REDUCER_FILE_GROUP_RESPONSE, pbResponse.SerializeAsString());
+  auto response = GetReducerFileGroupResponse::fromTransportMessage(message);
+
+  // Scalar and repeated fields.
+  EXPECT_EQ(response->status, 1);
+  EXPECT_EQ(response->attempts.size(), 4);
+  for (int i = 0; i < 4; i++) {
+    EXPECT_EQ(response->attempts[i], i);
+  }
+  EXPECT_EQ(response->partitionIds.size(), 6);
+  for (int i = 0; i < 6; i++) {
+    EXPECT_EQ(response->partitionIds.count(i), 1);
+  }
+
+  // File group 0 holds only the primary; the replica hangs off it.
+  EXPECT_EQ(response->fileGroups.size(), 1);
+  const auto& locations = response->fileGroups[0];
+  EXPECT_EQ(locations.size(), 1);
+  auto primary = locations.begin()->get();
+  verifyUnpackedPartitionLocation(primary);
+  EXPECT_EQ(primary->mode, PartitionLocation::Mode::PRIMARY);
+  auto replica = primary->replicaPeer.get();
+  verifyUnpackedPartitionLocation(replica);
+  EXPECT_EQ(replica->mode, PartitionLocation::Mode::REPLICA);
+}
+
+// OpenStream must serialize to an OPEN_STREAM message that round-trips all
+// four constructor arguments.
+TEST(ControlMessagesTest, openStream) {
+  auto request = std::make_unique<OpenStream>(
+      "test-shuffle-key", "test-filename", 100, 200);
+
+  auto message = request->toTransportMessage();
+  EXPECT_EQ(message.type(), OPEN_STREAM);
+
+  auto bytes = message.payload();
+  auto decoded = utils::parseProto<PbOpenStream>(
+      reinterpret_cast<const uint8_t*>(bytes.c_str()), bytes.size());
+  EXPECT_EQ(decoded->shufflekey(), "test-shuffle-key");
+  EXPECT_EQ(decoded->filename(), "test-filename");
+  EXPECT_EQ(decoded->startindex(), 100);
+  EXPECT_EQ(decoded->endindex(), 200);
+}
+
+// A serialized PbStreamHandler must decode into a StreamHandler with the same
+// id, chunk count, chunk offsets and path.
+TEST(ControlMessagesTest, streamHandler) {
+  PbStreamHandler pbHandler;
+  pbHandler.set_streamid(100);
+  pbHandler.set_numchunks(4);
+  for (int offset = 0; offset < 4; offset++) {
+    pbHandler.add_chunkoffsets(offset);
+  }
+  pbHandler.set_fullpath("test-fullpath");
+  TransportMessage message(STREAM_HANDLER, pbHandler.SerializeAsString());
+
+  auto handler = StreamHandler::fromTransportMessage(message);
+  EXPECT_EQ(handler->streamId, 100);
+  EXPECT_EQ(handler->numChunks, 4);
+  EXPECT_EQ(handler->chunkOffsets.size(), 4);
+  for (int i = 0; i < 4; i++) {
+    EXPECT_EQ(handler->chunkOffsets[i], i);
+  }
+  EXPECT_EQ(handler->fullPath, "test-fullpath");
+}
+
+// Checks both conversions of StreamChunkSlice: toProto() copies every field,
+// and decodeFrom() reads back the fields written to a 20-byte buffer while
+// consuming the whole buffer.
+TEST(ControlMessagesTest, streamChunkSlice) {
+  StreamChunkSlice expected;
+  expected.streamId = 100;
+  expected.chunkIndex = 1000;
+  expected.offset = 111;
+  expected.len = 1111;
+
+  // Protobuf conversion.
+  auto proto = expected.toProto();
+  EXPECT_EQ(proto->streamid(), 100);
+  EXPECT_EQ(proto->chunkindex(), 1000);
+  EXPECT_EQ(proto->offset(), 111);
+  EXPECT_EQ(proto->len(), 1111);
+
+  // Binary decoding.
+  auto writer = memory::ByteBuffer::createWriteOnly(20);
+  writer->write<long>(expected.streamId);
+  writer->write<int>(expected.chunkIndex);
+  writer->write<int>(expected.offset);
+  writer->write<int>(expected.len);
+  auto reader = memory::ByteBuffer::toReadOnly(std::move(writer));
+  auto decoded = StreamChunkSlice::decodeFrom(*reader);
+  EXPECT_EQ(reader->remainingSize(), 0);
+  EXPECT_EQ(expected.streamId, decoded.streamId);
+  EXPECT_EQ(expected.chunkIndex, decoded.chunkIndex);
+  EXPECT_EQ(expected.offset, decoded.offset);
+  EXPECT_EQ(expected.len, decoded.len);
+}
+
+// ChunkFetchRequest must serialize to a CHUNK_FETCH_REQUEST message whose
+// embedded slice round-trips all four fields.
+TEST(ControlMessagesTest, chunkFetchRequest) {
+  StreamChunkSlice slice;
+  slice.streamId = 100;
+  slice.chunkIndex = 1000;
+  slice.offset = 111;
+  slice.len = 1111;
+  ChunkFetchRequest request;
+  request.streamChunkSlice = slice;
+
+  auto message = request.toTransportMessage();
+  EXPECT_EQ(message.type(), CHUNK_FETCH_REQUEST);
+
+  auto bytes = message.payload();
+  auto decoded = utils::parseProto<PbChunkFetchRequest>(
+      reinterpret_cast<const uint8_t*>(bytes.c_str()), bytes.size());
+  auto decodedSlice = decoded->streamchunkslice();
+  EXPECT_EQ(decodedSlice.streamid(), 100);
+  EXPECT_EQ(decodedSlice.chunkindex(), 1000);
+  EXPECT_EQ(decodedSlice.offset(), 111);
+  EXPECT_EQ(decodedSlice.len(), 1111);
+}
+
+// BufferStreamEnd must serialize to a BUFFER_STREAM_END message that
+// round-trips the stream id.
+TEST(ControlMessagesTest, bufferStreamEnd) {
+  BufferStreamEnd notification;
+  notification.streamId = 111111;
+
+  auto message = notification.toTransportMessage();
+  EXPECT_EQ(message.type(), BUFFER_STREAM_END);
+
+  auto bytes = message.payload();
+  auto decoded = utils::parseProto<PbBufferStreamEnd>(
+      reinterpret_cast<const uint8_t*>(bytes.c_str()), bytes.size());
+  EXPECT_EQ(decoded->streamid(), 111111);
+}
diff --git a/cpp/celeborn/utils/CMakeLists.txt b/cpp/celeborn/utils/CMakeLists.txt
index 19f3efc44..600d3efaf 100644
--- a/cpp/celeborn/utils/CMakeLists.txt
+++ b/cpp/celeborn/utils/CMakeLists.txt
@@ -19,7 +19,8 @@ add_library(
StackTrace.cpp
CelebornException.cpp
Exceptions.cpp
- flags.cpp)
+ flags.cpp
+ CelebornUtils.cpp)
target_link_libraries(
utils
diff --git a/cpp/celeborn/utils/CelebornUtils.cpp b/cpp/celeborn/utils/CelebornUtils.cpp
new file mode 100644
index 000000000..c91c0ab5f
--- /dev/null
+++ b/cpp/celeborn/utils/CelebornUtils.cpp
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/utils/CelebornUtils.h"
+
+namespace celeborn {
+namespace utils {
+// Splits "host:port1:...:portN" into {host, port1, ..., portN}.
+//
+// The last `num` colon-separated tokens are taken as ports; everything before
+// them is the host, which may itself contain ':' (e.g. an IPv6 address). The
+// returned views alias the storage of `s`, so that storage must outlive them.
+//
+// Fails a CELEBORN_CHECK when `num` is negative or when `s` has fewer than
+// `num + 1` tokens.
+std::vector<std::string_view> parseColonSeparatedHostPorts(
+    const std::string_view& s,
+    int num) {
+  CELEBORN_CHECK_GE(num, 0);
+  const auto parsed = explode(s, ':');
+  CELEBORN_CHECK_GT(parsed.size(), static_cast<size_t>(num));
+
+  std::vector<std::string_view> result(static_cast<size_t>(num) + 1);
+  const size_t firstPortIdx = parsed.size() - static_cast<size_t>(num);
+  for (size_t i = 0; i < static_cast<size_t>(num); i++) {
+    result[i + 1] = parsed[firstPortIdx + i];
+  }
+
+  if (num == 0) {
+    result[0] = s;
+  } else {
+    // Locate the host by position instead of summing token lengths: the
+    // tokens are sub-views of `s`, and the first port token is always
+    // preceded by one ':' (firstPortIdx >= 1). This stays correct even when
+    // `s` ends with a ':' that explode() does not report as a token.
+    const size_t firstPortOffset =
+        static_cast<size_t>(parsed[firstPortIdx].data() - s.data());
+    result[0] = s.substr(0, firstPortOffset - 1);
+  }
+  return result;
+}
+
+// Splits `s` on every occurrence of `delim`. Empty tokens between adjacent
+// delimiters (and before a leading delimiter) are kept, but a trailing
+// delimiter does not produce a trailing empty token, and an empty input
+// yields no tokens. The returned views alias the storage of `s`.
+std::vector<std::string_view> explode(const std::string_view& s, char delim) {
+  std::vector<std::string_view> tokens;
+  for (std::string_view::size_type begin = 0; begin < s.size();) {
+    const auto delimPos = s.find(delim, begin);
+    const auto end = (delimPos == std::string_view::npos) ? s.size() : delimPos;
+    tokens.emplace_back(s.substr(begin, end - begin));
+    begin = end + 1;
+  }
+  return tokens;
+}
+
+// Splits `s` at the first occurrence of `delim` and returns
+// (text before it, text after it); the delimiter itself is dropped. When
+// `delim` does not occur, returns (s, empty view). The returned views alias
+// the storage of `s`.
+std::tuple<std::string_view, std::string_view> split(
+    const std::string_view& s,
+    char delim) {
+  const auto pos = s.find(delim);
+  if (pos == std::string_view::npos) {
+    // Without this guard `pos + 1` wraps around to 0 and the second element
+    // would be the whole input again.
+    return std::make_tuple(s, std::string_view{});
+  }
+  return std::make_tuple(s.substr(0, pos), s.substr(pos + 1));
+}
+} // namespace utils
+} // namespace celeborn
diff --git a/cpp/celeborn/utils/CelebornUtils.h b/cpp/celeborn/utils/CelebornUtils.h
index 158b35944..4c1bd63fc 100644
--- a/cpp/celeborn/utils/CelebornUtils.h
+++ b/cpp/celeborn/utils/CelebornUtils.h
@@ -17,7 +17,10 @@
#pragma once
+#include <google/protobuf/io/coded_stream.h>
+#include <charconv>
#include <chrono>
+#include <vector>
#include "celeborn/utils/Exceptions.h"
@@ -31,12 +34,64 @@ namespace utils {
#define CELEBORN_SHUTDOWN_LOG(severity) \
LOG(severity) << CELEBORN_SHUTDOWN_LOG_PREFIX
-
using Duration = std::chrono::duration<double>;
using Timeout = std::chrono::milliseconds;
// Converts a Duration (seconds as double) to a Timeout (whole milliseconds)
// using std::chrono::duration_cast.
inline Timeout toTimeout(Duration duration) {
return std::chrono::duration_cast<Timeout>(duration);
+}
+
+/// Parses a string like "Any-Host-Str:Port#1:Port#2:...:Port#num" and splits
+/// it into {"Any-Host-Str", "Port#1", "Port#2", ..., "Port#num"}. Note that
+/// "Any-Host-Str" may itself contain ':', e.g. when it is an IPv6 address.
+std::vector<std::string_view> parseColonSeparatedHostPorts(
+ const std::string_view& s,
+ int num);
+
+/// Splits `s` on every `delim`. The returned views refer to the characters of
+/// `s`, so the underlying buffer must outlive them. A trailing delimiter does
+/// not produce a trailing empty token.
+std::vector<std::string_view> explode(const std::string_view& s, char delim);
+/// Splits `s` at the first `delim` into (text before, text after), without
+/// the delimiter itself.
+std::tuple<std::string_view, std::string_view> split(
+ const std::string_view& s,
+ char delim);
+
+/// Parses the leading characters of `s` as a value of type `T` with
+/// std::from_chars. Fails via CELEBORN_FAIL when no value can be parsed or
+/// when the parsed value does not fit in `T`.
+template <class T>
+T strv2val(const std::string_view& s) {
+  T value;
+  const std::from_chars_result parsed =
+      std::from_chars(s.data(), s.data() + s.size(), value);
+
+  // These two exceptions reflect the behavior of std::stoi.
+  if (parsed.ec == std::errc::invalid_argument) {
+    CELEBORN_FAIL("Invalid argument when parsing");
+  }
+  if (parsed.ec == std::errc::result_out_of_range) {
+    CELEBORN_FAIL("Out of range when parsing");
+  }
+  return value;
}
+
+/// Parses `len` bytes starting at `bytes` into a newly allocated protobuf
+/// message of type `T`.
+///
+/// `bytes` must be non-null (checked). On a parse failure the function prints
+/// a message to std::cerr and terminates the process with exit(1).
+/// NOTE(review): terminating the process from a utility function is harsh;
+/// consider raising the same error type as the null check instead. Also note
+/// that `len` is an int while callers in this change pass a size_t.
+template <typename T>
+std::unique_ptr<T> parseProto(const uint8_t* bytes, int len) {
+  CELEBORN_CHECK_NOT_NULL(
+      bytes, "Data for {} must be non-null", typeid(T).name());
+
+  auto message = std::make_unique<T>();
+
+  google::protobuf::io::CodedInputStream input(bytes, len);
+
+  // The default recursion depth is 100, which causes some test cases to fail
+  // during regression testing. Raising the limit to 2000 means parsing is
+  // only aborted with an error once nesting exceeds 2000 levels.
+  input.SetRecursionLimit(2000);
+
+  if (!message->ParseFromCodedStream(&input)) {
+    std::cerr << "Unable to parse " << typeid(T).name() << " protobuf";
+    exit(1);
+  }
+  return message;
+}
+
} // namespace utils
} // namespace celeborn