This is an automated email from the ASF dual-hosted git repository.

nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new b1cbdabdf [CELEBORN-2221][CIP-14] Support writing with compression in C++ client
b1cbdabdf is described below

commit b1cbdabdf70230b7d2fff3ee0f7c44fa5a829f92
Author: afterincomparableyum <afterincomparableyum>
AuthorDate: Mon Feb 16 14:59:18 2026 +0800

    [CELEBORN-2221][CIP-14] Support writing with compression in C++ client
    
    Integrate existing compression infrastructure (LZ4 and ZSTD) into the C++
    client write path. This enables compression during pushData operations,
    matching the functionality available in the Java client.
    
    Changes:
    - Add compression support to ShuffleClientImpl:
      * Add shuffleCompressionEnabled_ flag and compressor_ member
      * Initialize compressor from CelebornConf in constructor
      * Compress data in pushData() when compression is enabled
      * Use compressed size for batchBytesSize tracking
    
    - Configuration integration:
      * Read compression codec from celeborn.client.shuffle.compression.codec
      * Read ZSTD compression level from celeborn.client.shuffle.compression.zstd.level
      * Default to NONE (compression disabled)
    
    - Retry/revive support:
      * Retry path correctly uses pre-compressed body buffer
      * No re-compression needed during retries
    
    - Testing:
      * Add CompressorFactoryTest for factory pattern and config integration
      * Add compression config tests to CelebornConfTest
      * Test offset compression support for both LZ4 and ZSTD
    
    ### How was this patch tested?
    
    Unit tests, and by compiling the code.
    
    Closes #3575 from afterincomparableyum/cpp-client/celeborn-2221.
    
    Authored-by: afterincomparableyum <afterincomparableyum>
    Signed-off-by: SteNicholas <[email protected]>
---
 cpp/celeborn/client/ShuffleClient.cpp              |  75 ++++++-
 cpp/celeborn/client/ShuffleClient.h                |   6 +
 cpp/celeborn/client/tests/CMakeLists.txt           |   3 +-
 .../client/tests/CompressorFactoryTest.cpp         | 226 +++++++++++++++++++++
 cpp/celeborn/conf/tests/CelebornConfTest.cpp       |  13 ++
 5 files changed, 315 insertions(+), 8 deletions(-)

diff --git a/cpp/celeborn/client/ShuffleClient.cpp b/cpp/celeborn/client/ShuffleClient.cpp
index 401b99e2f..35887135e 100644
--- a/cpp/celeborn/client/ShuffleClient.cpp
+++ b/cpp/celeborn/client/ShuffleClient.cpp
@@ -16,6 +16,7 @@
  */
 
 #include "celeborn/client/ShuffleClient.h"
+#include <limits>
 
 #include "celeborn/utils/CelebornUtils.h"
 
@@ -57,7 +58,16 @@ ShuffleClientImpl::ShuffleClientImpl(
     : appUniqueId_(appUniqueId),
       conf_(conf),
       clientFactory_(clientEndpoint.clientFactory()),
-      pushDataRetryPool_(clientEndpoint.pushDataRetryPool()) {
+      pushDataRetryPool_(clientEndpoint.pushDataRetryPool()),
+      shuffleCompressionEnabled_(
+          conf->shuffleCompressionCodec() != protocol::CompressionCodec::NONE),
+      compressorFactory_(
+          shuffleCompressionEnabled_
+              ? std::function<std::unique_ptr<compress::Compressor>()>(
+                    [conf]() {
+                      return compress::Compressor::createCompressor(*conf);
+                    })
+              : std::function<std::unique_ptr<compress::Compressor>()>()) {
   CELEBORN_CHECK_NOT_NULL(clientFactory_);
   CELEBORN_CHECK_NOT_NULL(pushDataRetryPool_);
 }
@@ -154,23 +164,74 @@ int ShuffleClientImpl::pushData(
   auto pushState = getPushState(mapKey);
   const int nextBatchId = pushState->nextBatchId();
 
-  // TODO: compression in writing is not supported.
+  // Validate input size fits in 32-bit int since it is required by compressor
+  // API and wire protocol
+  CELEBORN_CHECK(
+      length <= static_cast<size_t>(std::numeric_limits<int>::max()),
+      fmt::format(
+          "Data length {} exceeds maximum supported size {}",
+          length,
+          std::numeric_limits<int>::max()));
+
+  // Compression support: compress data if compression is enabled
+  const uint8_t* dataToWrite = data + offset;
+  int lengthToWrite = static_cast<int>(length);
+  std::unique_ptr<uint8_t[]> compressedBuffer;
+
+  if (shuffleCompressionEnabled_ && compressorFactory_) {
+    // Create a new compressor instance for thread-safety
+    auto compressor = compressorFactory_();
+    // Allocate buffer for compressed data
+    const size_t compressedCapacity =
+        compressor->getDstCapacity(static_cast<int>(length));
+    compressedBuffer = std::make_unique<uint8_t[]>(compressedCapacity);
+
+    // Compress the data
+    const size_t compressedSize = compressor->compress(
+        dataToWrite, 0, static_cast<int>(length), compressedBuffer.get(), 0);
+
+    CELEBORN_CHECK(
+        compressedSize <= static_cast<size_t>(std::numeric_limits<int>::max()),
+        fmt::format(
+            "Compressed size {} exceeds maximum supported size {}",
+            compressedSize,
+            std::numeric_limits<int>::max()));
+
+    lengthToWrite = static_cast<int>(compressedSize);
+    dataToWrite = compressedBuffer.get();
+  }
 
-  auto writeBuffer =
-      memory::ByteBuffer::createWriteOnly(kBatchHeaderSize + length);
+  // Validate final buffer size fits in size_t and int
+  CELEBORN_CHECK(
+      static_cast<size_t>(lengthToWrite) <=
+          std::numeric_limits<size_t>::max() - kBatchHeaderSize,
+      fmt::format(
+          "Buffer size {} + header {} would overflow",
+          lengthToWrite,
+          kBatchHeaderSize));
+
+  auto writeBuffer = memory::ByteBuffer::createWriteOnly(
+      kBatchHeaderSize + static_cast<size_t>(lengthToWrite));
   // TODO: the java side uses Platform to write the data. We simply assume
   //  littleEndian here.
   writeBuffer->writeLE<int>(mapId);
   writeBuffer->writeLE<int>(attemptId);
   writeBuffer->writeLE<int>(nextBatchId);
-  writeBuffer->writeLE<int>(length);
-  writeBuffer->writeFromBuffer(data, offset, length);
+  writeBuffer->writeLE<int>(lengthToWrite);
+  writeBuffer->writeFromBuffer(
+      dataToWrite, 0, static_cast<size_t>(lengthToWrite));
 
   auto hostAndPushPort = partitionLocation->hostAndPushPort();
   // Check limit.
   limitMaxInFlight(mapKey, *pushState, hostAndPushPort);
   // Add inFlight requests.
-  const int batchBytesSize = length + kBatchHeaderSize;
+  CELEBORN_CHECK(
+      lengthToWrite <= std::numeric_limits<int>::max() - kBatchHeaderSize,
+      fmt::format(
+          "Batch bytes size {} + header {} would overflow int",
+          lengthToWrite,
+          kBatchHeaderSize));
+  const int batchBytesSize = lengthToWrite + kBatchHeaderSize;
   pushState->addBatch(nextBatchId, batchBytesSize, hostAndPushPort);
   // Build pushData request.
   const auto shuffleKey = utils::makeShuffleKey(appUniqueId_, shuffleId);
diff --git a/cpp/celeborn/client/ShuffleClient.h b/cpp/celeborn/client/ShuffleClient.h
index 3e8cb9d37..e899ba0cd 100644
--- a/cpp/celeborn/client/ShuffleClient.h
+++ b/cpp/celeborn/client/ShuffleClient.h
@@ -17,6 +17,8 @@
 
 #pragma once
 
+#include <functional>
+#include "celeborn/client/compress/Compressor.h"
 #include "celeborn/client/reader/CelebornInputStream.h"
 #include "celeborn/client/writer/PushDataCallback.h"
 #include "celeborn/client/writer/PushState.h"
@@ -249,6 +251,7 @@ class ShuffleClientImpl
   static constexpr size_t kBatchHeaderSize = 4 * 4;
 
   const std::string appUniqueId_;
+  const bool shuffleCompressionEnabled_;
   std::shared_ptr<const conf::CelebornConf> conf_;
   std::shared_ptr<network::NettyRpcEndpointRef> lifecycleManagerRef_;
   std::shared_ptr<network::TransportClientFactory> clientFactory_;
@@ -266,6 +269,9 @@ class ShuffleClientImpl
       mapperEndSets_;
   utils::ConcurrentHashSet<int> stageEndShuffleSet_;
 
+  // Factory for creating compressor instances on demand to avoid sharing a
+  // single non-thread-safe compressor across concurrent operations.
+  std::function<std::unique_ptr<compress::Compressor>()> compressorFactory_;
   // TODO: pushExcludedWorker is not supported yet
 };
 } // namespace client
diff --git a/cpp/celeborn/client/tests/CMakeLists.txt b/cpp/celeborn/client/tests/CMakeLists.txt
index e19703f31..63c37c6e1 100644
--- a/cpp/celeborn/client/tests/CMakeLists.txt
+++ b/cpp/celeborn/client/tests/CMakeLists.txt
@@ -22,7 +22,8 @@ add_executable(
         Lz4DecompressorTest.cpp
         ZstdDecompressorTest.cpp
         Lz4CompressorTest.cpp
-        ZstdCompressorTest.cpp)
+        ZstdCompressorTest.cpp
+        CompressorFactoryTest.cpp)
 
 add_test(NAME celeborn_client_test COMMAND celeborn_client_test)
 
diff --git a/cpp/celeborn/client/tests/CompressorFactoryTest.cpp b/cpp/celeborn/client/tests/CompressorFactoryTest.cpp
new file mode 100644
index 000000000..cbfaa27e1
--- /dev/null
+++ b/cpp/celeborn/client/tests/CompressorFactoryTest.cpp
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "celeborn/client/compress/Compressor.h"
+#include "celeborn/client/compress/Decompressor.h"
+#include "celeborn/conf/CelebornConf.h"
+
+using namespace celeborn;
+using namespace celeborn::client;
+using namespace celeborn::conf;
+using namespace celeborn::protocol;
+
+TEST(CompressorFactoryTest, CreateLz4CompressorFromConf) {
+  CelebornConf conf;
+  conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "LZ4");
+
+  auto compressor = compress::Compressor::createCompressor(conf);
+  ASSERT_NE(compressor, nullptr);
+
+  const std::string testData = "Test data for compression";
+  const size_t maxLength = compressor->getDstCapacity(testData.size());
+  std::vector<uint8_t> compressedData(maxLength);
+
+  const size_t compressedSize = compressor->compress(
+      reinterpret_cast<const uint8_t*>(testData.data()),
+      0,
+      testData.size(),
+      compressedData.data(),
+      0);
+
+  ASSERT_GT(compressedSize, 8);
+  EXPECT_EQ(compressedData[0], 'L');
+  EXPECT_EQ(compressedData[1], 'Z');
+  EXPECT_EQ(compressedData[2], '4');
+  EXPECT_EQ(compressedData[3], 'B');
+  EXPECT_EQ(compressedData[4], 'l');
+  EXPECT_EQ(compressedData[5], 'o');
+  EXPECT_EQ(compressedData[6], 'c');
+  EXPECT_EQ(compressedData[7], 'k');
+}
+
+TEST(CompressorFactoryTest, CreateZstdCompressorFromConf) {
+  CelebornConf conf;
+  conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "ZSTD");
+  conf.registerProperty(
+      CelebornConf::kShuffleCompressionZstdCompressLevel, "3");
+
+  auto compressor = compress::Compressor::createCompressor(conf);
+  ASSERT_NE(compressor, nullptr);
+
+  const std::string testData = "Test data for compression";
+  const size_t maxLength = compressor->getDstCapacity(testData.size());
+  std::vector<uint8_t> compressedData(maxLength);
+
+  const size_t compressedSize = compressor->compress(
+      reinterpret_cast<const uint8_t*>(testData.data()),
+      0,
+      testData.size(),
+      compressedData.data(),
+      0);
+
+  ASSERT_GT(compressedSize, 9);
+  EXPECT_EQ(compressedData[0], 'Z');
+  EXPECT_EQ(compressedData[1], 'S');
+  EXPECT_EQ(compressedData[2], 'T');
+  EXPECT_EQ(compressedData[3], 'D');
+  EXPECT_EQ(compressedData[4], 'B');
+  EXPECT_EQ(compressedData[5], 'l');
+  EXPECT_EQ(compressedData[6], 'o');
+  EXPECT_EQ(compressedData[7], 'c');
+  EXPECT_EQ(compressedData[8], 'k');
+}
+
+TEST(CompressorFactoryTest, CompressionCodecNoneDisablesCompression) {
+  CelebornConf conf;
+  // Verify default is NONE
+  EXPECT_EQ(conf.shuffleCompressionCodec(), CompressionCodec::NONE);
+}
+
+TEST(CompressorFactoryTest, ZstdCompressionLevelFromConf) {
+  // Test that configuration correctly reads ZSTD compression levels
+  const std::string testData = "Test data for compression";
+
+  for (int level = -5; level <= 10; level++) {
+    CelebornConf conf;
+    conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "ZSTD");
+    conf.registerProperty(
+        CelebornConf::kShuffleCompressionZstdCompressLevel,
+        std::to_string(level));
+
+    // Verify the compression level is set correctly
+    EXPECT_EQ(conf.shuffleCompressionZstdCompressLevel(), level);
+
+    // Verify the compressor is created correctly and produces ZSTD output
+    auto compressor = compress::Compressor::createCompressor(conf);
+    ASSERT_NE(compressor, nullptr);
+
+    const size_t maxLength = compressor->getDstCapacity(testData.size());
+    std::vector<uint8_t> compressedData(maxLength);
+
+    const size_t compressedSize = compressor->compress(
+        reinterpret_cast<const uint8_t*>(testData.data()),
+        0,
+        testData.size(),
+        compressedData.data(),
+        0);
+
+    ASSERT_GT(compressedSize, 9);
+    EXPECT_EQ(compressedData[0], 'Z');
+    EXPECT_EQ(compressedData[1], 'S');
+    EXPECT_EQ(compressedData[2], 'T');
+    EXPECT_EQ(compressedData[3], 'D');
+  }
+}
+
+TEST(CompressorFactoryTest, CompressWithOffsetLz4) {
+  CelebornConf conf;
+  conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "LZ4");
+
+  auto compressor = compress::Compressor::createCompressor(conf);
+  ASSERT_NE(compressor, nullptr);
+
+  const std::string prefix = "SKIP_THIS_PREFIX";
+  const std::string testData =
+      "Celeborn compression offset test with structured data: "
+      "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4 "
+      "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4 "
+      "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4";
+  std::string fullData = prefix + testData;
+
+  const auto maxLength = compressor->getDstCapacity(testData.size());
+  std::vector<uint8_t> compressedData(maxLength);
+
+  // Compress with offset (simulating pushData usage pattern)
+  const size_t compressedSize = compressor->compress(
+      reinterpret_cast<const uint8_t*>(fullData.data()),
+      prefix.size(),
+      testData.size(),
+      compressedData.data(),
+      0);
+
+  ASSERT_GT(compressedSize, 0);
+  ASSERT_LE(compressedSize, maxLength);
+
+  auto decompressor =
+      compress::Decompressor::createDecompressor(CompressionCodec::LZ4);
+  ASSERT_NE(decompressor, nullptr);
+
+  const int originalLen = decompressor->getOriginalLen(compressedData.data());
+  EXPECT_EQ(originalLen, testData.size());
+
+  std::vector<uint8_t> decompressedData(originalLen);
+  const int decompressedSize = decompressor->decompress(
+      compressedData.data(), decompressedData.data(), 0);
+  EXPECT_EQ(decompressedSize, originalLen);
+
+  const std::string decompressedStr(
+      reinterpret_cast<const char*>(decompressedData.data()), decompressedSize);
+  EXPECT_EQ(decompressedStr, testData);
+  EXPECT_NE(decompressedStr, fullData);
+}
+
+TEST(CompressorFactoryTest, CompressWithOffsetZstd) {
+  CelebornConf conf;
+  conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "ZSTD");
+  conf.registerProperty(
+      CelebornConf::kShuffleCompressionZstdCompressLevel, "3");
+
+  auto compressor = compress::Compressor::createCompressor(conf);
+  ASSERT_NE(compressor, nullptr);
+
+  const std::string prefix = "SKIP_THIS_PREFIX";
+  const std::string testData =
+      "Celeborn compression offset test with structured data: "
+      "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4 "
+      "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4 "
+      "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4";
+  std::string fullData = prefix + testData;
+
+  const auto maxLength = compressor->getDstCapacity(testData.size());
+  std::vector<uint8_t> compressedData(maxLength);
+
+  // Compress with offset (simulating pushData usage pattern)
+  const size_t compressedSize = compressor->compress(
+      reinterpret_cast<const uint8_t*>(fullData.data()),
+      prefix.size(),
+      testData.size(),
+      compressedData.data(),
+      0);
+
+  ASSERT_GT(compressedSize, 0);
+  ASSERT_LE(compressedSize, maxLength);
+
+  auto decompressor =
+      compress::Decompressor::createDecompressor(CompressionCodec::ZSTD);
+  ASSERT_NE(decompressor, nullptr);
+
+  const int originalLen = decompressor->getOriginalLen(compressedData.data());
+  EXPECT_EQ(originalLen, testData.size());
+
+  std::vector<uint8_t> decompressedData(originalLen);
+  const int decompressedSize = decompressor->decompress(
+      compressedData.data(), decompressedData.data(), 0);
+  EXPECT_EQ(decompressedSize, originalLen);
+
+  const std::string decompressedStr(
+      reinterpret_cast<const char*>(decompressedData.data()), decompressedSize);
+  EXPECT_EQ(decompressedStr, testData);
+  EXPECT_NE(decompressedStr, fullData);
+}
diff --git a/cpp/celeborn/conf/tests/CelebornConfTest.cpp b/cpp/celeborn/conf/tests/CelebornConfTest.cpp
index 41efdbc00..79619d888 100644
--- a/cpp/celeborn/conf/tests/CelebornConfTest.cpp
+++ b/cpp/celeborn/conf/tests/CelebornConfTest.cpp
@@ -19,8 +19,10 @@
 #include <fstream>
 
 #include "celeborn/conf/CelebornConf.h"
+#include "celeborn/protocol/CompressionCodec.h"
 
 using namespace celeborn::conf;
+using namespace celeborn::protocol;
 
 using CelebornUserError = celeborn::utils::CelebornUserError;
 using SECOND = std::chrono::seconds;
@@ -47,6 +49,8 @@ void testDefaultValues(CelebornConf* conf) {
   EXPECT_EQ(conf->networkIoNumConnectionsPerPeer(), 1);
   EXPECT_EQ(conf->networkIoClientThreads(), 0);
   EXPECT_EQ(conf->clientFetchMaxReqsInFlight(), 3);
+  EXPECT_EQ(conf->shuffleCompressionCodec(), CompressionCodec::NONE);
+  EXPECT_EQ(conf->shuffleCompressionZstdCompressLevel(), 1);
 }
 
 TEST(CelebornConfTest, defaultValues) {
@@ -73,6 +77,15 @@ TEST(CelebornConfTest, setValues) {
   EXPECT_EQ(conf->networkIoClientThreads(), 10);
   conf->registerProperty(CelebornConf::kClientFetchMaxReqsInFlight, "10");
   EXPECT_EQ(conf->clientFetchMaxReqsInFlight(), 10);
+  conf->registerProperty(CelebornConf::kShuffleCompressionCodec, "LZ4");
+  EXPECT_EQ(conf->shuffleCompressionCodec(), CompressionCodec::LZ4);
+  conf->registerProperty(CelebornConf::kShuffleCompressionCodec, "ZSTD");
+  EXPECT_EQ(conf->shuffleCompressionCodec(), CompressionCodec::ZSTD);
+  conf->registerProperty(CelebornConf::kShuffleCompressionCodec, "NONE");
+  EXPECT_EQ(conf->shuffleCompressionCodec(), CompressionCodec::NONE);
+  conf->registerProperty(
+      CelebornConf::kShuffleCompressionZstdCompressLevel, "5");
+  EXPECT_EQ(conf->shuffleCompressionZstdCompressLevel(), 5);
 
   EXPECT_THROW(
       conf->registerProperty("non-exist-key", "non-exist-value"),

Reply via email to