This is an automated email from the ASF dual-hosted git repository.

ethanfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new b74e05b60 [CELEBORN-1821][CIP-14] Add controlMessages to cppClient
b74e05b60 is described below

commit b74e05b603bed45e8d719b2c0e7988a552f578c3
Author: HolyLow <[email protected]>
AuthorDate: Wed Jan 8 13:35:09 2025 +0800

    [CELEBORN-1821][CIP-14] Add controlMessages to cppClient
    
    ### What changes were proposed in this pull request?
    This PR adds ControlMessages to cppClient.
    
    ### Why are the changes needed?
    The ControlMessages are used to communicate with the CelebornServer and LifecycleManager.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Compilation and UTs.
    
    Closes #3052 from HolyLow/issue/celeborn-1821-add-control-messages-to-cpp-client.
    
    Authored-by: HolyLow <[email protected]>
    Signed-off-by: mingji <[email protected]>
---
 cpp/celeborn/protocol/CMakeLists.txt               |   3 +-
 cpp/celeborn/protocol/ControlMessages.cpp          | 182 +++++++++++++++
 cpp/celeborn/protocol/ControlMessages.h            | 107 +++++++++
 cpp/celeborn/protocol/PartitionLocation.cpp        |  31 +++
 cpp/celeborn/protocol/PartitionLocation.h          |   4 +
 cpp/celeborn/protocol/tests/CMakeLists.txt         |   3 +-
 .../protocol/tests/ControlMessagesTest.cpp         | 244 +++++++++++++++++++++
 cpp/celeborn/utils/CMakeLists.txt                  |   3 +-
 cpp/celeborn/utils/CelebornUtils.cpp               |  61 ++++++
 cpp/celeborn/utils/CelebornUtils.h                 |  57 ++++-
 10 files changed, 691 insertions(+), 4 deletions(-)

diff --git a/cpp/celeborn/protocol/CMakeLists.txt b/cpp/celeborn/protocol/CMakeLists.txt
index 6a6ef88a7..aa601c479 100644
--- a/cpp/celeborn/protocol/CMakeLists.txt
+++ b/cpp/celeborn/protocol/CMakeLists.txt
@@ -16,7 +16,8 @@ add_library(
         protocol
         STATIC
         PartitionLocation.cpp
-        TransportMessage.cpp)
+        TransportMessage.cpp
+        ControlMessages.cpp)
 
 target_include_directories(protocol PUBLIC ${CMAKE_BINARY_DIR})
 
diff --git a/cpp/celeborn/protocol/ControlMessages.cpp b/cpp/celeborn/protocol/ControlMessages.cpp
new file mode 100644
index 000000000..1a8fa6ce3
--- /dev/null
+++ b/cpp/celeborn/protocol/ControlMessages.cpp
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/protocol/ControlMessages.h"
+#include "celeborn/utils/CelebornUtils.h"
+
+namespace celeborn {
+namespace protocol {
+// Packs the shuffle id into a PbGetReducerFileGroup and wraps the serialized
+// bytes in a GET_REDUCER_FILE_GROUP transport message.
+TransportMessage GetReducerFileGroup::toTransportMessage() const {
+  PbGetReducerFileGroup request;
+  request.set_shuffleid(shuffleId);
+  return TransportMessage(GET_REDUCER_FILE_GROUP, request.SerializeAsString());
+}
+
+// Decodes a GET_REDUCER_FILE_GROUP_RESPONSE transport message.
+//
+// Only the packed PartitionLocation encoding is accepted; a file group that
+// still carries legacy `locations` fails the check. For every file group the
+// first `inputLocationSize` packed locations are the input locations, and
+// `peerIndexes[i]` is either INT_MAX (no peer) or the index of the peer of
+// input location i. A location with a peer is stored as the PRIMARY with the
+// REPLICA attached through `replicaPeer`.
+//
+// Fails a CELEBORN_CHECK when the message type does not match or when the
+// packed indexes are inconsistent with each other.
+std::unique_ptr<GetReducerFileGroupResponse>
+GetReducerFileGroupResponse::fromTransportMessage(
+    const TransportMessage& transportMessage) {
+  CELEBORN_CHECK(
+      transportMessage.type() == GET_REDUCER_FILE_GROUP_RESPONSE,
+      "transportMessageType mismatch");
+  const auto& payload = transportMessage.payload();
+  const auto pbResponse = utils::parseProto<PbGetReducerFileGroupResponse>(
+      reinterpret_cast<const uint8_t*>(payload.c_str()), payload.size());
+  auto response = std::make_unique<GetReducerFileGroupResponse>();
+  response->status = toStatusCode(pbResponse->status());
+  // Iterate the protobuf map in place instead of deep-copying it first.
+  for (const auto& kv : pbResponse->filegroups()) {
+    auto& fileGroup = response->fileGroups[kv.first];
+    // Legacy mode is deprecated.
+    CELEBORN_CHECK_EQ(
+        kv.second.locations().size(),
+        0,
+        "legacy PartitionLocation pb is deprecated");
+    // Packed mode: must use packedPartitionLocations.
+    const auto& pbLocationsPair = kv.second.partitionlocationspair();
+    const auto& pbPackedLocations = pbLocationsPair.locations();
+    const int inputLocationSize = pbLocationsPair.inputlocationsize();
+    const int numLocations = pbPackedLocations.ids_size();
+    // The sizes come off the wire; make sure every index used below exists
+    // before touching the repeated fields or the unpacked vector.
+    CELEBORN_CHECK_GE(inputLocationSize, 0);
+    CELEBORN_CHECK_GE(numLocations, inputLocationSize);
+    CELEBORN_CHECK_GE(pbLocationsPair.peerindexes_size(), inputLocationSize);
+
+    std::vector<std::unique_ptr<PartitionLocation>> partialLocations;
+    partialLocations.reserve(numLocations);
+    for (int idx = 0; idx < numLocations; idx++) {
+      partialLocations.push_back(
+          PartitionLocation::fromPackedPb(pbPackedLocations, idx));
+    }
+    for (int idx = 0; idx < inputLocationSize; idx++) {
+      const auto replicaIdx = pbLocationsPair.peerindexes(idx);
+      // has no peer
+      if (replicaIdx == INT_MAX) {
+        fileGroup.insert(std::move(partialLocations[idx]));
+        continue;
+      }
+      // has peer
+      CELEBORN_CHECK_GE(replicaIdx, inputLocationSize);
+      CELEBORN_CHECK_LT(replicaIdx, numLocations);
+      auto location = std::move(partialLocations[idx]);
+      auto peerLocation = std::move(partialLocations[replicaIdx]);
+      // A peer index referenced by two input locations would leave a
+      // moved-from (null) entry behind; reject it instead of dereferencing.
+      CELEBORN_CHECK(
+          location != nullptr && peerLocation != nullptr,
+          "peer PartitionLocation is referenced more than once");
+      // make sure the location is primary and peer location is replica
+      if (location->mode == PartitionLocation::Mode::REPLICA) {
+        std::swap(location, peerLocation);
+      }
+      CELEBORN_CHECK(location->mode == PartitionLocation::Mode::PRIMARY);
+      CELEBORN_CHECK(peerLocation->mode == PartitionLocation::Mode::REPLICA);
+      location->replicaPeer = std::move(peerLocation);
+      fileGroup.insert(std::move(location));
+    }
+  }
+  const auto& attempts = pbResponse->attempts();
+  response->attempts.assign(attempts.begin(), attempts.end());
+  const auto& partitionIds = pbResponse->partitionids();
+  response->partitionIds.insert(partitionIds.begin(), partitionIds.end());
+  // Returning the local by name lets the compiler elide the move.
+  return response;
+}
+
+// Stores copies of the shuffle key and file name together with the start/end
+// map indexes; toTransportMessage() later serializes exactly these fields.
+OpenStream::OpenStream(
+    const std::string& shuffleKey,
+    const std::string& filename,
+    int32_t startMapIndex,
+    int32_t endMapIndex)
+    : shuffleKey(shuffleKey),
+      filename(filename),
+      startMapIndex(startMapIndex),
+      endMapIndex(endMapIndex) {}
+
+// Copies the four request fields into a PbOpenStream and wraps the serialized
+// bytes in an OPEN_STREAM transport message.
+TransportMessage OpenStream::toTransportMessage() const {
+  PbOpenStream request;
+  request.set_shufflekey(shuffleKey);
+  request.set_filename(filename);
+  request.set_startindex(startMapIndex);
+  request.set_endindex(endMapIndex);
+  return TransportMessage(OPEN_STREAM, request.SerializeAsString());
+}
+
+// Decodes a STREAM_HANDLER transport message into a StreamHandler.
+// Fails a CELEBORN_CHECK when the message has any other type.
+std::unique_ptr<StreamHandler> StreamHandler::fromTransportMessage(
+    const TransportMessage& transportMessage) {
+  CELEBORN_CHECK(
+      transportMessage.type() == STREAM_HANDLER,
+      "transportMessageType should be STREAM_HANDLER");
+  // Bind by reference: correct whether payload() returns a value or a
+  // reference, and avoids a second copy in the latter case.
+  const auto& payload = transportMessage.payload();
+  const auto pbStreamHandler = utils::parseProto<PbStreamHandler>(
+      reinterpret_cast<const uint8_t*>(payload.c_str()), payload.size());
+  auto streamHandler = std::make_unique<StreamHandler>();
+  streamHandler->streamId = pbStreamHandler->streamid();
+  streamHandler->numChunks = pbStreamHandler->numchunks();
+  // assign() sizes the vector once instead of growing it per element.
+  const auto& chunkOffsets = pbStreamHandler->chunkoffsets();
+  streamHandler->chunkOffsets.assign(chunkOffsets.begin(), chunkOffsets.end());
+  streamHandler->fullPath = pbStreamHandler->fullpath();
+  // Returning the local by name lets the compiler elide the move.
+  return streamHandler;
+}
+
+// Builds a heap-allocated PbStreamChunkSlice mirroring the four fields of
+// this slice; the caller owns the result.
+std::unique_ptr<PbStreamChunkSlice> StreamChunkSlice::toProto() const {
+  auto pb = std::make_unique<PbStreamChunkSlice>();
+  pb->set_streamid(streamId);
+  pb->set_chunkindex(chunkIndex);
+  pb->set_offset(offset);
+  pb->set_len(len);
+  // No std::move: returning the local by name keeps copy elision available.
+  return pb;
+}
+
+// Reads a StreamChunkSlice from `data`, consuming 20 bytes: an 8-byte
+// streamId followed by the 4-byte chunkIndex, offset and len.
+// Fails a CELEBORN_CHECK when fewer than 20 bytes remain.
+StreamChunkSlice StreamChunkSlice::decodeFrom(
+    memory::ReadOnlyByteBuffer& data) {
+  // Read fixed-width integers so the number of bytes consumed always matches
+  // the size guard, independent of the platform's `long`/`int` widths.
+  constexpr int kEncodedSize = 20;
+  static_assert(
+      sizeof(int64_t) + 3 * sizeof(int32_t) == kEncodedSize,
+      "StreamChunkSlice wire size mismatch");
+  CELEBORN_CHECK_GE(data.remainingSize(), kEncodedSize);
+  StreamChunkSlice slice;
+  slice.streamId = data.read<int64_t>();
+  slice.chunkIndex = data.read<int32_t>();
+  slice.offset = data.read<int32_t>();
+  slice.len = data.read<int32_t>();
+  return slice;
+}
+
+// XORs the std::hash of the four fields together; each successive field's
+// hash is shifted left by one more bit (0, 1, 2, 3) before being mixed in.
+size_t StreamChunkSlice::Hasher::operator()(
+    const StreamChunkSlice& lhs) const {
+  size_t combined = std::hash<long>()(lhs.streamId);
+  combined ^= std::hash<int>()(lhs.chunkIndex) << 1;
+  combined ^= std::hash<int>()(lhs.offset) << 2;
+  combined ^= std::hash<int>()(lhs.len) << 3;
+  return combined;
+}
+
+// Embeds the slice (converted through toProto()) in a PbChunkFetchRequest and
+// wraps the serialized bytes in a CHUNK_FETCH_REQUEST transport message.
+TransportMessage ChunkFetchRequest::toTransportMessage() const {
+  PbChunkFetchRequest request;
+  // The raw pointer is released from the unique_ptr and handed to the
+  // protobuf message.
+  auto pbSlice = streamChunkSlice.toProto();
+  request.unsafe_arena_set_allocated_streamchunkslice(pbSlice.release());
+  return TransportMessage(CHUNK_FETCH_REQUEST, request.SerializeAsString());
+}
+
+// Builds a PbBufferStreamEnd with stream type ChunkStream and this stream id,
+// and wraps the serialized bytes in a BUFFER_STREAM_END transport message.
+TransportMessage BufferStreamEnd::toTransportMessage() const {
+  PbBufferStreamEnd notice;
+  notice.set_streamtype(ChunkStream);
+  notice.set_streamid(streamId);
+  return TransportMessage(BUFFER_STREAM_END, notice.SerializeAsString());
+}
+} // namespace protocol
+} // namespace celeborn
diff --git a/cpp/celeborn/protocol/ControlMessages.h b/cpp/celeborn/protocol/ControlMessages.h
new file mode 100644
index 000000000..7d74a4900
--- /dev/null
+++ b/cpp/celeborn/protocol/ControlMessages.h
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <map>
+#include <set>
+
+#include "celeborn/protocol/PartitionLocation.h"
+#include "celeborn/protocol/StatusCode.h"
+#include "celeborn/protocol/TransportMessage.h"
+
+namespace celeborn {
+namespace protocol {
+/// Request identified by a shuffle id. toTransportMessage() serializes it as
+/// a GET_REDUCER_FILE_GROUP message.
+struct GetReducerFileGroup {
+  int shuffleId;
+
+  /// Serializes this request into a TransportMessage.
+  TransportMessage toTransportMessage() const;
+};
+
+/// Decoded form of a GET_REDUCER_FILE_GROUP_RESPONSE message.
+struct GetReducerFileGroupResponse {
+  StatusCode status;
+  /// Keyed by the key of the protobuf `filegroups` map. A location that has a
+  /// peer is stored as the PRIMARY location with the REPLICA attached through
+  /// `replicaPeer`.
+  std::map<int, std::set<std::shared_ptr<const PartitionLocation>>> fileGroups;
+  std::vector<int> attempts;
+  std::set<int> partitionIds;
+
+  /// Parses `transportMessage`; the message type must be
+  /// GET_REDUCER_FILE_GROUP_RESPONSE.
+  static std::unique_ptr<GetReducerFileGroupResponse> fromTransportMessage(
+      const TransportMessage& transportMessage);
+};
+
+/// Request naming a shuffle key, a file name and a start/end map index pair.
+/// toTransportMessage() serializes it as an OPEN_STREAM message.
+struct OpenStream {
+  std::string shuffleKey;
+  std::string filename;
+  int32_t startMapIndex;
+  int32_t endMapIndex;
+
+  /// Copies the arguments into the members of the same name.
+  OpenStream(
+      const std::string& shuffleKey,
+      const std::string& filename,
+      int32_t startMapIndex,
+      int32_t endMapIndex);
+
+  /// Serializes this request into a TransportMessage.
+  TransportMessage toTransportMessage() const;
+};
+
+/// Decoded form of a STREAM_HANDLER message: the stream id, the chunk count,
+/// the chunk offsets and the full path carried by the message.
+struct StreamHandler {
+  int64_t streamId;
+  int32_t numChunks;
+  std::vector<int64_t> chunkOffsets;
+  std::string fullPath;
+
+  /// Parses `transportMessage`; the message type must be STREAM_HANDLER.
+  static std::unique_ptr<StreamHandler> fromTransportMessage(
+      const TransportMessage& transportMessage);
+};
+
+/// Identifies a slice of one chunk of a stream. `offset` and `len` default to
+/// 0 and INT_MAX.
+struct StreamChunkSlice {
+  long streamId;
+  int chunkIndex;
+  int offset{0};
+  int len{INT_MAX};
+
+  /// Converts this slice into a heap-allocated PbStreamChunkSlice.
+  std::unique_ptr<PbStreamChunkSlice> toProto() const;
+
+  /// Reads a slice from `data`.
+  static StreamChunkSlice decodeFrom(memory::ReadOnlyByteBuffer& data);
+
+  /// Renders the slice as "streamId-chunkIndex-offset-len".
+  std::string toString() const {
+    std::string text = std::to_string(streamId);
+    text += "-";
+    text += std::to_string(chunkIndex);
+    text += "-";
+    text += std::to_string(offset);
+    text += "-";
+    text += std::to_string(len);
+    return text;
+  }
+
+  /// Two slices are equal when all four fields are equal.
+  bool operator==(const StreamChunkSlice& rhs) const {
+    if (streamId != rhs.streamId || chunkIndex != rhs.chunkIndex) {
+      return false;
+    }
+    return offset == rhs.offset && len == rhs.len;
+  }
+
+  /// Hash functor over the four fields, for use in hashed containers.
+  struct Hasher {
+    size_t operator()(const StreamChunkSlice& lhs) const;
+  };
+};
+
+/// Request for one StreamChunkSlice. toTransportMessage() serializes it as a
+/// CHUNK_FETCH_REQUEST message.
+struct ChunkFetchRequest {
+  StreamChunkSlice streamChunkSlice;
+
+  /// Serializes this request into a TransportMessage.
+  TransportMessage toTransportMessage() const;
+};
+
+/// Message carrying a stream id. toTransportMessage() serializes it as a
+/// BUFFER_STREAM_END message with stream type ChunkStream.
+struct BufferStreamEnd {
+  long streamId;
+
+  /// Serializes this message into a TransportMessage.
+  TransportMessage toTransportMessage() const;
+};
+} // namespace protocol
+} // namespace celeborn
diff --git a/cpp/celeborn/protocol/PartitionLocation.cpp b/cpp/celeborn/protocol/PartitionLocation.cpp
index aee59586b..5cd9bdecb 100644
--- a/cpp/celeborn/protocol/PartitionLocation.cpp
+++ b/cpp/celeborn/protocol/PartitionLocation.cpp
@@ -53,6 +53,37 @@ std::unique_ptr<const PartitionLocation> PartitionLocation::fromPb(
   return std::move(result);
 }
 
+
+// Builds the idx-th PartitionLocation out of a column-wise ("packed")
+// PbPackedPartitionLocations message.
+//
+// The worker id and the mount point are stored as indexes into the
+// `workerIdsSet` / `mountPointsSet` dictionaries; the worker id string has
+// the form "host:rpcPort:pushPort:fetchPort:replicatePort". A non-empty file
+// path is prefixed with its mount point. The returned location has no
+// replica peer attached.
+//
+// Fails a CELEBORN_CHECK when `idx` or one of the dictionary indexes is out
+// of range.
+std::unique_ptr<PartitionLocation> PartitionLocation::fromPackedPb(
+    const PbPackedPartitionLocations& pb,
+    int idx) {
+  // The message comes off the wire and protobuf's repeated-field accessors
+  // may not bounds-check in release builds, so validate every index first.
+  CELEBORN_CHECK_GE(idx, 0);
+  CELEBORN_CHECK_LT(idx, pb.ids_size());
+  CELEBORN_CHECK_LT(idx, pb.epoches_size());
+  CELEBORN_CHECK_LT(idx, pb.workerids_size());
+  CELEBORN_CHECK_LT(idx, pb.mountpoints_size());
+  CELEBORN_CHECK_LT(idx, pb.filepaths_size());
+  CELEBORN_CHECK_LT(idx, pb.modes_size());
+  CELEBORN_CHECK_LT(idx, pb.types_size());
+  CELEBORN_CHECK_LT(idx, pb.finalresult_size());
+  CELEBORN_CHECK_LT(idx, pb.availablestoragetypes_size());
+  const int workerIdIdx = pb.workerids(idx);
+  CELEBORN_CHECK_GE(workerIdIdx, 0);
+  CELEBORN_CHECK_LT(workerIdIdx, pb.workeridsset_size());
+  const int mountPointIdx = pb.mountpoints(idx);
+  CELEBORN_CHECK_GE(mountPointIdx, 0);
+  CELEBORN_CHECK_LT(mountPointIdx, pb.mountpointsset_size());
+
+  // The parsed views alias the worker id string owned by `pb`.
+  const auto& workerIdStr = pb.workeridsset(workerIdIdx);
+  const auto workerIdParts = utils::parseColonSeparatedHostPorts(workerIdStr, 4);
+  const auto& mountPoint = pb.mountpointsset(mountPointIdx);
+  std::string filePath = pb.filepaths(idx);
+  if (!filePath.empty()) {
+    filePath = mountPoint + filePath;
+  }
+
+  auto result = std::make_unique<PartitionLocation>();
+  result->id = pb.ids(idx);
+  result->epoch = pb.epoches(idx);
+  result->host = workerIdParts[0];
+  result->rpcPort = utils::strv2val<int>(workerIdParts[1]);
+  result->pushPort = utils::strv2val<int>(workerIdParts[2]);
+  result->fetchPort = utils::strv2val<int>(workerIdParts[3]);
+  result->replicatePort = utils::strv2val<int>(workerIdParts[4]);
+  result->mode = static_cast<Mode>(pb.modes(idx));
+  result->replicaPeer = nullptr;
+  result->storageInfo = std::make_unique<StorageInfo>();
+  result->storageInfo->type = static_cast<StorageInfo::Type>(pb.types(idx));
+  result->storageInfo->mountPoint = mountPoint;
+  result->storageInfo->finalResult = pb.finalresult(idx);
+  result->storageInfo->filePath = std::move(filePath);
+  result->storageInfo->availableStorageTypes = pb.availablestoragetypes(idx);
+
+  // Returning the local by name lets the compiler elide the move.
+  return result;
+}
+
 PartitionLocation::PartitionLocation(const PartitionLocation& other)
     : id(other.id),
       epoch(other.epoch),
diff --git a/cpp/celeborn/protocol/PartitionLocation.h b/cpp/celeborn/protocol/PartitionLocation.h
index c485c1607..700773cba 100644
--- a/cpp/celeborn/protocol/PartitionLocation.h
+++ b/cpp/celeborn/protocol/PartitionLocation.h
@@ -80,6 +80,10 @@ struct PartitionLocation {
   static std::unique_ptr<const PartitionLocation> fromPb(
       const PbPartitionLocation& pb);
 
+  static std::unique_ptr<PartitionLocation> fromPackedPb(
+      const PbPackedPartitionLocations& pb,
+      int idx);
+
   PartitionLocation() = default;
 
   PartitionLocation(const PartitionLocation& other);
diff --git a/cpp/celeborn/protocol/tests/CMakeLists.txt b/cpp/celeborn/protocol/tests/CMakeLists.txt
index f5a00db2a..cb2a2378f 100644
--- a/cpp/celeborn/protocol/tests/CMakeLists.txt
+++ b/cpp/celeborn/protocol/tests/CMakeLists.txt
@@ -16,7 +16,8 @@
 add_executable(
         celeborn_protocol_test
         PartitionLocationTest.cpp
-        TransportMessageTest.cpp)
+        TransportMessageTest.cpp
+        ControlMessagesTest.cpp)
 
 add_test(NAME celeborn_protocol_test COMMAND celeborn_protocol_test)
 
diff --git a/cpp/celeborn/protocol/tests/ControlMessagesTest.cpp b/cpp/celeborn/protocol/tests/ControlMessagesTest.cpp
new file mode 100644
index 000000000..36003491a
--- /dev/null
+++ b/cpp/celeborn/protocol/tests/ControlMessagesTest.cpp
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "celeborn/proto/TransportMessagesCpp.pb.h"
+#include "celeborn/protocol/ControlMessages.h"
+
+#include "celeborn/utils/CelebornUtils.h"
+
+using namespace celeborn;
+using namespace celeborn::protocol;
+
+namespace {
+// Appends one fixed test location to every column of
+// `pbPackedPartitionLocations` (id 1, epoch 101, worker
+// "test-host:1001:1002:1003:1004", mount point "test-mountpoint/", file path
+// "test-filepath") with the given `mode`. `idx` is written as the index into
+// the workerIdsSet/mountPointsSet dictionaries; each call also appends one
+// entry to both dictionaries.
+void generatePackedPartitionLocationPb(
+    PbPackedPartitionLocations& pbPackedPartitionLocations,
+    int idx,
+    PartitionLocation::Mode mode) {
+  pbPackedPartitionLocations.add_ids(1);
+  pbPackedPartitionLocations.add_epoches(101);
+  pbPackedPartitionLocations.add_workerids(idx);
+  pbPackedPartitionLocations.add_workeridsset("test-host:1001:1002:1003:1004");
+  pbPackedPartitionLocations.add_mountpoints(idx);
+  pbPackedPartitionLocations.add_mountpointsset("test-mountpoint/");
+  pbPackedPartitionLocations.add_filepaths("test-filepath");
+  pbPackedPartitionLocations.add_types(1);
+  pbPackedPartitionLocations.add_finalresult(true);
+  pbPackedPartitionLocations.add_availablestoragetypes(1);
+  pbPackedPartitionLocations.add_modes(mode);
+}
+
+// Checks that `partitionLocation` carries the values written by
+// generatePackedPartitionLocationPb(), with the file path prefixed by the
+// mount point.
+void verifyUnpackedPartitionLocation(
+    const PartitionLocation* partitionLocation) {
+  const auto& location = *partitionLocation;
+  EXPECT_EQ(location.id, 1);
+  EXPECT_EQ(location.epoch, 101);
+  EXPECT_EQ(location.host, "test-host");
+  EXPECT_EQ(location.rpcPort, 1001);
+  EXPECT_EQ(location.pushPort, 1002);
+  EXPECT_EQ(location.fetchPort, 1003);
+  EXPECT_EQ(location.replicatePort, 1004);
+
+  const auto& storage = *location.storageInfo;
+  EXPECT_EQ(storage.type, 1);
+  EXPECT_EQ(storage.mountPoint, "test-mountpoint/");
+  EXPECT_EQ(storage.finalResult, true);
+  EXPECT_EQ(storage.filePath, "test-mountpoint/test-filepath");
+  EXPECT_EQ(storage.availableStorageTypes, 1);
+}
+} // namespace
+
+// The serialized request must be a GET_REDUCER_FILE_GROUP message whose
+// payload parses back to the same shuffle id.
+TEST(ControlMessagesTest, getReducerFileGroup) {
+  GetReducerFileGroup request{};
+  request.shuffleId = 1000;
+
+  const auto message = request.toTransportMessage();
+  EXPECT_EQ(message.type(), GET_REDUCER_FILE_GROUP);
+  const auto bytes = message.payload();
+  const auto pbRequest = utils::parseProto<PbGetReducerFileGroup>(
+      reinterpret_cast<const uint8_t*>(bytes.c_str()), bytes.size());
+  EXPECT_EQ(pbRequest->shuffleid(), 1000);
+}
+
+// A file group that still uses the legacy `locations` field must be rejected
+// with a CelebornRuntimeError.
+TEST(ControlMessagesTest, getReducerFileGroupResponseLegacyModeDeprecated) {
+  PbGetReducerFileGroupResponse pbResponse;
+  pbResponse.set_status(1);
+  for (int attempt = 0; attempt < 4; attempt++) {
+    pbResponse.add_attempts(attempt);
+  }
+  for (int partitionId = 0; partitionId < 6; partitionId++) {
+    pbResponse.add_partitionids(partitionId);
+  }
+  // One legacy location is enough to trigger the deprecation check.
+  PbFileGroup legacyFileGroup;
+  legacyFileGroup.add_locations();
+  pbResponse.mutable_filegroups()->insert({0, legacyFileGroup});
+
+  TransportMessage message(
+      GET_REDUCER_FILE_GROUP_RESPONSE, pbResponse.SerializeAsString());
+  EXPECT_THROW(
+      GetReducerFileGroupResponse::fromTransportMessage(message),
+      utils::CelebornRuntimeError);
+}
+
+// A packed file group with one primary input location and its replica peer
+// must decode to a single primary location with the replica attached.
+TEST(ControlMessagesTest, getReducerFileGroupResponsePackedMode) {
+  PbGetReducerFileGroupResponse pbResponse;
+  pbResponse.set_status(1);
+  for (int attempt = 0; attempt < 4; attempt++) {
+    pbResponse.add_attempts(attempt);
+  }
+  for (int partitionId = 0; partitionId < 6; partitionId++) {
+    pbResponse.add_partitionids(partitionId);
+  }
+
+  PbFileGroup packedFileGroup;
+  auto* pbLocationsPair = packedFileGroup.mutable_partitionlocationspair();
+  auto* pbPackedLocations = pbLocationsPair->mutable_locations();
+  // Has one inputLocation, with offset 0.
+  pbLocationsPair->set_inputlocationsize(1);
+  // The peerIndex 1 is replica.
+  pbLocationsPair->add_peerindexes(1);
+  // Add the two partitionLocations, one is primary and the other is replica.
+  generatePackedPartitionLocationPb(
+      *pbPackedLocations, 0, PartitionLocation::Mode::PRIMARY);
+  generatePackedPartitionLocationPb(
+      *pbPackedLocations, 1, PartitionLocation::Mode::REPLICA);
+
+  pbResponse.mutable_filegroups()->insert({0, packedFileGroup});
+
+  TransportMessage message(
+      GET_REDUCER_FILE_GROUP_RESPONSE, pbResponse.SerializeAsString());
+  auto response = GetReducerFileGroupResponse::fromTransportMessage(message);
+  EXPECT_EQ(response->status, 1);
+  EXPECT_EQ(response->attempts.size(), 4);
+  for (int attempt = 0; attempt < 4; attempt++) {
+    EXPECT_EQ(response->attempts[attempt], attempt);
+  }
+  EXPECT_EQ(response->partitionIds.size(), 6);
+  for (int partitionId = 0; partitionId < 6; partitionId++) {
+    EXPECT_EQ(response->partitionIds.count(partitionId), 1);
+  }
+  EXPECT_EQ(response->fileGroups.size(), 1);
+  const auto& locations = response->fileGroups[0];
+  EXPECT_EQ(locations.size(), 1);
+  auto primary = locations.begin()->get();
+  verifyUnpackedPartitionLocation(primary);
+  EXPECT_EQ(primary->mode, PartitionLocation::Mode::PRIMARY);
+  auto replica = primary->replicaPeer.get();
+  verifyUnpackedPartitionLocation(replica);
+  EXPECT_EQ(replica->mode, PartitionLocation::Mode::REPLICA);
+}
+
+// The serialized request must be an OPEN_STREAM message whose payload parses
+// back to the four constructor arguments.
+TEST(ControlMessagesTest, openStream) {
+  OpenStream request("test-shuffle-key", "test-filename", 100, 200);
+  const auto message = request.toTransportMessage();
+  EXPECT_EQ(message.type(), OPEN_STREAM);
+  const auto bytes = message.payload();
+  const auto pbRequest = utils::parseProto<PbOpenStream>(
+      reinterpret_cast<const uint8_t*>(bytes.c_str()), bytes.size());
+  EXPECT_EQ(pbRequest->shufflekey(), "test-shuffle-key");
+  EXPECT_EQ(pbRequest->filename(), "test-filename");
+  EXPECT_EQ(pbRequest->startindex(), 100);
+  EXPECT_EQ(pbRequest->endindex(), 200);
+}
+
+// A STREAM_HANDLER message must decode to the same stream id, chunk count,
+// chunk offsets and full path that were serialized.
+TEST(ControlMessagesTest, streamHandler) {
+  PbStreamHandler pbHandler;
+  pbHandler.set_streamid(100);
+  pbHandler.set_numchunks(4);
+  for (int offset = 0; offset < 4; offset++) {
+    pbHandler.add_chunkoffsets(offset);
+  }
+  pbHandler.set_fullpath("test-fullpath");
+  TransportMessage message(STREAM_HANDLER, pbHandler.SerializeAsString());
+
+  auto decoded = StreamHandler::fromTransportMessage(message);
+  EXPECT_EQ(decoded->streamId, 100);
+  EXPECT_EQ(decoded->numChunks, 4);
+  EXPECT_EQ(decoded->chunkOffsets.size(), 4);
+  for (int offset = 0; offset < 4; offset++) {
+    EXPECT_EQ(decoded->chunkOffsets[offset], offset);
+  }
+  EXPECT_EQ(decoded->fullPath, "test-fullpath");
+}
+
+// Covers both conversions of StreamChunkSlice: toProto() must mirror the four
+// fields, and decodeFrom() must rebuild the slice from its 20-byte encoding.
+TEST(ControlMessagesTest, streamChunkSlice) {
+  StreamChunkSlice streamChunkSlice;
+  streamChunkSlice.streamId = 100;
+  streamChunkSlice.chunkIndex = 1000;
+  streamChunkSlice.offset = 111;
+  streamChunkSlice.len = 1111;
+
+  auto pb = streamChunkSlice.toProto();
+  EXPECT_EQ(pb->streamid(), 100);
+  EXPECT_EQ(pb->chunkindex(), 1000);
+  EXPECT_EQ(pb->offset(), 111);
+  EXPECT_EQ(pb->len(), 1111);
+
+  // Write fixed-width integers so exactly 8 + 4 + 4 + 4 = 20 bytes are
+  // produced even on platforms where `long` is not 8 bytes wide.
+  auto writeBuffer = memory::ByteBuffer::createWriteOnly(20);
+  writeBuffer->write<int64_t>(streamChunkSlice.streamId);
+  writeBuffer->write<int32_t>(streamChunkSlice.chunkIndex);
+  writeBuffer->write<int32_t>(streamChunkSlice.offset);
+  writeBuffer->write<int32_t>(streamChunkSlice.len);
+  auto readBuffer = memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
+  auto decodedStreamChunkSlice = StreamChunkSlice::decodeFrom(*readBuffer);
+  EXPECT_EQ(readBuffer->remainingSize(), 0);
+  EXPECT_EQ(streamChunkSlice.streamId, decodedStreamChunkSlice.streamId);
+  EXPECT_EQ(streamChunkSlice.chunkIndex, decodedStreamChunkSlice.chunkIndex);
+  EXPECT_EQ(streamChunkSlice.offset, decodedStreamChunkSlice.offset);
+  EXPECT_EQ(streamChunkSlice.len, decodedStreamChunkSlice.len);
+}
+
+// The serialized request must be a CHUNK_FETCH_REQUEST message whose embedded
+// slice parses back to the original four fields.
+TEST(ControlMessagesTest, chunkFetchRequest) {
+  StreamChunkSlice slice;
+  slice.streamId = 100;
+  slice.chunkIndex = 1000;
+  slice.offset = 111;
+  slice.len = 1111;
+  ChunkFetchRequest request;
+  request.streamChunkSlice = slice;
+
+  const auto message = request.toTransportMessage();
+  EXPECT_EQ(message.type(), CHUNK_FETCH_REQUEST);
+  const auto bytes = message.payload();
+  const auto pbRequest = utils::parseProto<PbChunkFetchRequest>(
+      reinterpret_cast<const uint8_t*>(bytes.c_str()), bytes.size());
+  const auto pbSlice = pbRequest->streamchunkslice();
+  EXPECT_EQ(pbSlice.streamid(), 100);
+  EXPECT_EQ(pbSlice.chunkindex(), 1000);
+  EXPECT_EQ(pbSlice.offset(), 111);
+  EXPECT_EQ(pbSlice.len(), 1111);
+}
+
+// The serialized message must be a BUFFER_STREAM_END message whose payload
+// parses back to the same stream id.
+TEST(ControlMessagesTest, bufferStreamEnd) {
+  BufferStreamEnd streamEnd;
+  streamEnd.streamId = 111111;
+
+  const auto message = streamEnd.toTransportMessage();
+  EXPECT_EQ(message.type(), BUFFER_STREAM_END);
+  const auto bytes = message.payload();
+  const auto pbStreamEnd = utils::parseProto<PbBufferStreamEnd>(
+      reinterpret_cast<const uint8_t*>(bytes.c_str()), bytes.size());
+  EXPECT_EQ(pbStreamEnd->streamid(), 111111);
+}
diff --git a/cpp/celeborn/utils/CMakeLists.txt b/cpp/celeborn/utils/CMakeLists.txt
index 19f3efc44..600d3efaf 100644
--- a/cpp/celeborn/utils/CMakeLists.txt
+++ b/cpp/celeborn/utils/CMakeLists.txt
@@ -19,7 +19,8 @@ add_library(
         StackTrace.cpp
         CelebornException.cpp
         Exceptions.cpp
-        flags.cpp)
+        flags.cpp
+        CelebornUtils.cpp)
 
 target_link_libraries(
         utils
diff --git a/cpp/celeborn/utils/CelebornUtils.cpp b/cpp/celeborn/utils/CelebornUtils.cpp
new file mode 100644
index 000000000..c91c0ab5f
--- /dev/null
+++ b/cpp/celeborn/utils/CelebornUtils.cpp
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/utils/CelebornUtils.h"
+
+namespace celeborn {
+namespace utils {
+// Splits "Any-Host-Str:Port#1:...:Port#num" into num + 1 views: the host at
+// index 0 followed by the `num` ports in order. The ports are peeled off from
+// the right, so the host may itself contain ':' (e.g. an IPv6 address). The
+// returned views alias `s` and stay valid only as long as its buffer does.
+//
+// Fails a check when `num` is negative or `s` contains fewer than `num`
+// colons.
+std::vector<std::string_view> parseColonSeparatedHostPorts(
+    const std::string_view& s,
+    int num) {
+  CELEBORN_CHECK_GE(num, 0);
+  std::vector<std::string_view> result(num + 1);
+  std::string_view remaining = s;
+  // Searching from the right keeps a trailing ':' visible as an empty last
+  // port instead of silently shifting every field by one.
+  for (int resultIdx = num; resultIdx >= 1; resultIdx--) {
+    const auto colonPos = remaining.rfind(':');
+    CELEBORN_CHECK(
+        colonPos != std::string_view::npos,
+        "too few ':'-separated ports in host-ports string");
+    result[resultIdx] = remaining.substr(colonPos + 1);
+    remaining = remaining.substr(0, colonPos);
+  }
+  result[0] = remaining;
+  return result;
+}
+
+// Splits `s` at every occurrence of `delim`. Empty fields at the start or in
+// the middle are kept; an empty field after a trailing delimiter is not
+// produced, and an empty input yields an empty vector. The returned views
+// alias `s`.
+std::vector<std::string_view> explode(const std::string_view& s, char delim) {
+  std::vector<std::string_view> pieces;
+  for (std::string_view::size_type begin = 0; begin < s.size();) {
+    const auto hit = s.find(delim, begin);
+    const auto end = (hit == std::string_view::npos) ? s.size() : hit;
+    pieces.emplace_back(s.substr(begin, end - begin));
+    begin = end + 1;
+  }
+  return pieces;
+}
+
+// Splits `s` at the first occurrence of `delim` into {head, tail}; the
+// delimiter itself is dropped. When `delim` does not occur, the head is the
+// whole input and the tail is empty. The returned views alias `s`.
+std::tuple<std::string_view, std::string_view> split(
+    const std::string_view& s,
+    char delim) {
+  const auto pos = s.find(delim);
+  if (pos == std::string_view::npos) {
+    // `npos + 1` would wrap around to 0 and echo the whole input as the tail.
+    return std::make_tuple(s, std::string_view{});
+  }
+  return std::make_tuple(s.substr(0, pos), s.substr(pos + 1));
+}
+} // namespace utils
+} // namespace celeborn
diff --git a/cpp/celeborn/utils/CelebornUtils.h b/cpp/celeborn/utils/CelebornUtils.h
index 158b35944..4c1bd63fc 100644
--- a/cpp/celeborn/utils/CelebornUtils.h
+++ b/cpp/celeborn/utils/CelebornUtils.h
@@ -17,7 +17,10 @@
 
 #pragma once
 
+#include <google/protobuf/io/coded_stream.h>
+#include <charconv>
 #include <chrono>
+#include <vector>
 
 #include "celeborn/utils/Exceptions.h"
 
@@ -31,12 +34,64 @@ namespace utils {
 #define CELEBORN_SHUTDOWN_LOG(severity) \
   LOG(severity) << CELEBORN_SHUTDOWN_LOG_PREFIX
 
-
 using Duration = std::chrono::duration<double>;
 using Timeout = std::chrono::milliseconds;
 inline Timeout toTimeout(Duration duration) {
   return std::chrono::duration_cast<Timeout>(duration);
+}
+
+/// Parse a string like "Any-Host-Str:Port#1:Port#2:...:Port#num" and split it
+/// into {"Any-Host-Str", "Port#1", "Port#2", ..., "Port#num"}. Note that
+/// "Any-Host-Str" might itself contain ':' (e.g. an IPv6 address), so the
+/// ports are taken from the end of the string. The returned views alias `s`.
+std::vector<std::string_view> parseColonSeparatedHostPorts(
+    const std::string_view& s,
+    int num);
+
+/// Split `s` at every occurrence of `delim`. The returned views alias `s`.
+std::vector<std::string_view> explode(const std::string_view& s, char delim);
 
+/// Split `s` at the first occurrence of `delim` into {head, tail}. The
+/// returned views alias `s`.
+std::tuple<std::string_view, std::string_view> split(
+    const std::string_view& s,
+    char delim);
+
+/// Parses `s` into a `T` with std::from_chars.
+/// Fails with "Invalid argument when parsing" when from_chars reports
+/// invalid_argument, and with "Out of range when parsing" when it reports
+/// result_out_of_range.
+template <class T>
+T strv2val(const std::string_view& s) {
+  T value;
+  const std::from_chars_result outcome =
+      std::from_chars(s.data(), s.data() + s.size(), value);
+
+  // These two exceptions reflect the behavior of std::stoi.
+  if (outcome.ec == std::errc::invalid_argument) {
+    CELEBORN_FAIL("Invalid argument when parsing");
+  }
+  if (outcome.ec == std::errc::result_out_of_range) {
+    CELEBORN_FAIL("Out of range when parsing");
+  }
+  return value;
 }
+
+/// Parses `len` bytes starting at `bytes` into a newly allocated protobuf
+/// message of type `T`.
+///
+/// Fails through the CELEBORN_CHECK/CELEBORN_FAIL macros when `bytes` is
+/// null, `len` is negative, or the bytes cannot be parsed as `T`. A parse
+/// failure is reported to the caller rather than terminating the process, so
+/// a malformed message from a peer cannot take down the host application.
+template <typename T>
+std::unique_ptr<T> parseProto(const uint8_t* bytes, int len) {
+  CELEBORN_CHECK_NOT_NULL(
+      bytes, "Data for {} must be non-null", typeid(T).name());
+  CELEBORN_CHECK_GE(len, 0, "Protobuf data length must be non-negative");
+
+  auto pbObj = std::make_unique<T>();
+
+  google::protobuf::io::CodedInputStream cis(bytes, len);
+
+  // The default recursion depth is 100, which causes some test cases to fail
+  // during regression testing. By setting the recursion depth limit to 2000,
+  // it means that during the parsing process, if the recursion depth exceeds
+  // 2000 layers, the parsing process will be terminated and an error will be
+  // returned.
+  cis.SetRecursionLimit(2000);
+  if (!pbObj->ParseFromCodedStream(&cis)) {
+    CELEBORN_FAIL("Unable to parse {} protobuf", typeid(T).name());
+  }
+  return pbObj;
+}
+
 } // namespace utils
 } // namespace celeborn

Reply via email to