This is an automated email from the ASF dual-hosted git repository.
nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new b1cbdabdf [CELEBORN-2221][CIP-14] Support writing with compression in C++ client
b1cbdabdf is described below
commit b1cbdabdf70230b7d2fff3ee0f7c44fa5a829f92
Author: afterincomparableyum <afterincomparableyum>
AuthorDate: Mon Feb 16 14:59:18 2026 +0800
[CELEBORN-2221][CIP-14] Support writing with compression in C++ client
Integrate existing compression infrastructure (LZ4 and ZSTD) into the C++
client write path. This enables compression during pushData operations,
matching the functionality available in the Java client.
Changes:
- Add compression support to ShuffleClientImpl:
* Add shuffleCompressionEnabled_ flag and compressorFactory_ member
* Initialize the compressor factory from CelebornConf in constructor
* Compress data in pushData() when compression is enabled
* Use compressed size for batchBytesSize tracking
- Configuration integration:
* Read compression codec from celeborn.client.shuffle.compression.codec
* Read ZSTD compression level from
celeborn.client.shuffle.compression.zstd.level
* Default to NONE (compression disabled)
- Retry/revive support:
* Retry path correctly uses pre-compressed body buffer
* No re-compression needed during retries
- Testing:
* Add CompressorFactoryTest for factory pattern and config integration
* Add compression config tests to CelebornConfTest
* Test offset compression support for both LZ4 and ZSTD
### How was this patch tested?
Unit tests, plus verifying that the code compiles.
Closes #3575 from afterincomparableyum/cpp-client/celeborn-2221.
Authored-by: afterincomparableyum <afterincomparableyum>
Signed-off-by: SteNicholas <[email protected]>
---
cpp/celeborn/client/ShuffleClient.cpp | 75 ++++++-
cpp/celeborn/client/ShuffleClient.h | 6 +
cpp/celeborn/client/tests/CMakeLists.txt | 3 +-
.../client/tests/CompressorFactoryTest.cpp | 226 +++++++++++++++++++++
cpp/celeborn/conf/tests/CelebornConfTest.cpp | 13 ++
5 files changed, 315 insertions(+), 8 deletions(-)
diff --git a/cpp/celeborn/client/ShuffleClient.cpp b/cpp/celeborn/client/ShuffleClient.cpp
index 401b99e2f..35887135e 100644
--- a/cpp/celeborn/client/ShuffleClient.cpp
+++ b/cpp/celeborn/client/ShuffleClient.cpp
@@ -16,6 +16,7 @@
*/
#include "celeborn/client/ShuffleClient.h"
+#include <limits>
#include "celeborn/utils/CelebornUtils.h"
@@ -57,7 +58,16 @@ ShuffleClientImpl::ShuffleClientImpl(
: appUniqueId_(appUniqueId),
conf_(conf),
clientFactory_(clientEndpoint.clientFactory()),
- pushDataRetryPool_(clientEndpoint.pushDataRetryPool()) {
+ pushDataRetryPool_(clientEndpoint.pushDataRetryPool()),
+ shuffleCompressionEnabled_(
+ conf->shuffleCompressionCodec() != protocol::CompressionCodec::NONE),
+ compressorFactory_(
+ shuffleCompressionEnabled_
+ ? std::function<std::unique_ptr<compress::Compressor>()>(
+ [conf]() {
+ return compress::Compressor::createCompressor(*conf);
+ })
+ : std::function<std::unique_ptr<compress::Compressor>()>()) {
CELEBORN_CHECK_NOT_NULL(clientFactory_);
CELEBORN_CHECK_NOT_NULL(pushDataRetryPool_);
}
@@ -154,23 +164,74 @@ int ShuffleClientImpl::pushData(
auto pushState = getPushState(mapKey);
const int nextBatchId = pushState->nextBatchId();
- // TODO: compression in writing is not supported.
+ // Validate input size fits in 32-bit int since it is required by compressor
+ // API and wire protocol
+ CELEBORN_CHECK(
+ length <= static_cast<size_t>(std::numeric_limits<int>::max()),
+ fmt::format(
+ "Data length {} exceeds maximum supported size {}",
+ length,
+ std::numeric_limits<int>::max()));
+
+ // Compression support: compress data if compression is enabled
+ const uint8_t* dataToWrite = data + offset;
+ int lengthToWrite = static_cast<int>(length);
+ std::unique_ptr<uint8_t[]> compressedBuffer;
+
+ if (shuffleCompressionEnabled_ && compressorFactory_) {
+ // Create a new compressor instance for thread-safety
+ auto compressor = compressorFactory_();
+ // Allocate buffer for compressed data
+ const size_t compressedCapacity =
+ compressor->getDstCapacity(static_cast<int>(length));
+ compressedBuffer = std::make_unique<uint8_t[]>(compressedCapacity);
+
+ // Compress the data
+ const size_t compressedSize = compressor->compress(
+ dataToWrite, 0, static_cast<int>(length), compressedBuffer.get(), 0);
+
+ CELEBORN_CHECK(
+ compressedSize <= static_cast<size_t>(std::numeric_limits<int>::max()),
+ fmt::format(
+ "Compressed size {} exceeds maximum supported size {}",
+ compressedSize,
+ std::numeric_limits<int>::max()));
+
+ lengthToWrite = static_cast<int>(compressedSize);
+ dataToWrite = compressedBuffer.get();
+ }
- auto writeBuffer =
- memory::ByteBuffer::createWriteOnly(kBatchHeaderSize + length);
+ // Validate final buffer size fits in size_t and int
+ CELEBORN_CHECK(
+ static_cast<size_t>(lengthToWrite) <=
+ std::numeric_limits<size_t>::max() - kBatchHeaderSize,
+ fmt::format(
+ "Buffer size {} + header {} would overflow",
+ lengthToWrite,
+ kBatchHeaderSize));
+
+ auto writeBuffer = memory::ByteBuffer::createWriteOnly(
+ kBatchHeaderSize + static_cast<size_t>(lengthToWrite));
// TODO: the java side uses Platform to write the data. We simply assume
// littleEndian here.
writeBuffer->writeLE<int>(mapId);
writeBuffer->writeLE<int>(attemptId);
writeBuffer->writeLE<int>(nextBatchId);
- writeBuffer->writeLE<int>(length);
- writeBuffer->writeFromBuffer(data, offset, length);
+ writeBuffer->writeLE<int>(lengthToWrite);
+ writeBuffer->writeFromBuffer(
+ dataToWrite, 0, static_cast<size_t>(lengthToWrite));
auto hostAndPushPort = partitionLocation->hostAndPushPort();
// Check limit.
limitMaxInFlight(mapKey, *pushState, hostAndPushPort);
// Add inFlight requests.
- const int batchBytesSize = length + kBatchHeaderSize;
+ CELEBORN_CHECK(
+ lengthToWrite <= std::numeric_limits<int>::max() - kBatchHeaderSize,
+ fmt::format(
+ "Batch bytes size {} + header {} would overflow int",
+ lengthToWrite,
+ kBatchHeaderSize));
+ const int batchBytesSize = lengthToWrite + kBatchHeaderSize;
pushState->addBatch(nextBatchId, batchBytesSize, hostAndPushPort);
// Build pushData request.
const auto shuffleKey = utils::makeShuffleKey(appUniqueId_, shuffleId);
diff --git a/cpp/celeborn/client/ShuffleClient.h b/cpp/celeborn/client/ShuffleClient.h
index 3e8cb9d37..e899ba0cd 100644
--- a/cpp/celeborn/client/ShuffleClient.h
+++ b/cpp/celeborn/client/ShuffleClient.h
@@ -17,6 +17,8 @@
#pragma once
+#include <functional>
+#include "celeborn/client/compress/Compressor.h"
#include "celeborn/client/reader/CelebornInputStream.h"
#include "celeborn/client/writer/PushDataCallback.h"
#include "celeborn/client/writer/PushState.h"
@@ -249,6 +251,7 @@ class ShuffleClientImpl
static constexpr size_t kBatchHeaderSize = 4 * 4;
const std::string appUniqueId_;
+ const bool shuffleCompressionEnabled_;
std::shared_ptr<const conf::CelebornConf> conf_;
std::shared_ptr<network::NettyRpcEndpointRef> lifecycleManagerRef_;
std::shared_ptr<network::TransportClientFactory> clientFactory_;
@@ -266,6 +269,9 @@ class ShuffleClientImpl
mapperEndSets_;
utils::ConcurrentHashSet<int> stageEndShuffleSet_;
+ // Factory for creating compressor instances on demand to avoid sharing a
+ // single non-thread-safe compressor across concurrent operations.
+ std::function<std::unique_ptr<compress::Compressor>()> compressorFactory_;
// TODO: pushExcludedWorker is not supported yet
};
} // namespace client
diff --git a/cpp/celeborn/client/tests/CMakeLists.txt b/cpp/celeborn/client/tests/CMakeLists.txt
index e19703f31..63c37c6e1 100644
--- a/cpp/celeborn/client/tests/CMakeLists.txt
+++ b/cpp/celeborn/client/tests/CMakeLists.txt
@@ -22,7 +22,8 @@ add_executable(
Lz4DecompressorTest.cpp
ZstdDecompressorTest.cpp
Lz4CompressorTest.cpp
- ZstdCompressorTest.cpp)
+ ZstdCompressorTest.cpp
+ CompressorFactoryTest.cpp)
add_test(NAME celeborn_client_test COMMAND celeborn_client_test)
diff --git a/cpp/celeborn/client/tests/CompressorFactoryTest.cpp b/cpp/celeborn/client/tests/CompressorFactoryTest.cpp
new file mode 100644
index 000000000..cbfaa27e1
--- /dev/null
+++ b/cpp/celeborn/client/tests/CompressorFactoryTest.cpp
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "celeborn/client/compress/Compressor.h"
+#include "celeborn/client/compress/Decompressor.h"
+#include "celeborn/conf/CelebornConf.h"
+
+using namespace celeborn;
+using namespace celeborn::client;
+using namespace celeborn::conf;
+using namespace celeborn::protocol;
+
+TEST(CompressorFactoryTest, CreateLz4CompressorFromConf) {
+ CelebornConf conf;
+ conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "LZ4");
+
+ auto compressor = compress::Compressor::createCompressor(conf);
+ ASSERT_NE(compressor, nullptr);
+
+ const std::string testData = "Test data for compression";
+ const size_t maxLength = compressor->getDstCapacity(testData.size());
+ std::vector<uint8_t> compressedData(maxLength);
+
+ const size_t compressedSize = compressor->compress(
+ reinterpret_cast<const uint8_t*>(testData.data()),
+ 0,
+ testData.size(),
+ compressedData.data(),
+ 0);
+
+ ASSERT_GT(compressedSize, 8);
+ EXPECT_EQ(compressedData[0], 'L');
+ EXPECT_EQ(compressedData[1], 'Z');
+ EXPECT_EQ(compressedData[2], '4');
+ EXPECT_EQ(compressedData[3], 'B');
+ EXPECT_EQ(compressedData[4], 'l');
+ EXPECT_EQ(compressedData[5], 'o');
+ EXPECT_EQ(compressedData[6], 'c');
+ EXPECT_EQ(compressedData[7], 'k');
+}
+
+TEST(CompressorFactoryTest, CreateZstdCompressorFromConf) {
+ CelebornConf conf;
+ conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "ZSTD");
+ conf.registerProperty(
+ CelebornConf::kShuffleCompressionZstdCompressLevel, "3");
+
+ auto compressor = compress::Compressor::createCompressor(conf);
+ ASSERT_NE(compressor, nullptr);
+
+ const std::string testData = "Test data for compression";
+ const size_t maxLength = compressor->getDstCapacity(testData.size());
+ std::vector<uint8_t> compressedData(maxLength);
+
+ const size_t compressedSize = compressor->compress(
+ reinterpret_cast<const uint8_t*>(testData.data()),
+ 0,
+ testData.size(),
+ compressedData.data(),
+ 0);
+
+ ASSERT_GT(compressedSize, 9);
+ EXPECT_EQ(compressedData[0], 'Z');
+ EXPECT_EQ(compressedData[1], 'S');
+ EXPECT_EQ(compressedData[2], 'T');
+ EXPECT_EQ(compressedData[3], 'D');
+ EXPECT_EQ(compressedData[4], 'B');
+ EXPECT_EQ(compressedData[5], 'l');
+ EXPECT_EQ(compressedData[6], 'o');
+ EXPECT_EQ(compressedData[7], 'c');
+ EXPECT_EQ(compressedData[8], 'k');
+}
+
+TEST(CompressorFactoryTest, CompressionCodecNoneDisablesCompression) {
+ CelebornConf conf;
+ // Verify default is NONE
+ EXPECT_EQ(conf.shuffleCompressionCodec(), CompressionCodec::NONE);
+}
+
+TEST(CompressorFactoryTest, ZstdCompressionLevelFromConf) {
+ // Test that configuration correctly reads ZSTD compression levels
+ const std::string testData = "Test data for compression";
+
+ for (int level = -5; level <= 10; level++) {
+ CelebornConf conf;
+ conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "ZSTD");
+ conf.registerProperty(
+ CelebornConf::kShuffleCompressionZstdCompressLevel,
+ std::to_string(level));
+
+ // Verify the compression level is set correctly
+ EXPECT_EQ(conf.shuffleCompressionZstdCompressLevel(), level);
+
+ // Verify the compressor is created correctly and produces ZSTD output
+ auto compressor = compress::Compressor::createCompressor(conf);
+ ASSERT_NE(compressor, nullptr);
+
+ const size_t maxLength = compressor->getDstCapacity(testData.size());
+ std::vector<uint8_t> compressedData(maxLength);
+
+ const size_t compressedSize = compressor->compress(
+ reinterpret_cast<const uint8_t*>(testData.data()),
+ 0,
+ testData.size(),
+ compressedData.data(),
+ 0);
+
+ ASSERT_GT(compressedSize, 9);
+ EXPECT_EQ(compressedData[0], 'Z');
+ EXPECT_EQ(compressedData[1], 'S');
+ EXPECT_EQ(compressedData[2], 'T');
+ EXPECT_EQ(compressedData[3], 'D');
+ }
+}
+
+TEST(CompressorFactoryTest, CompressWithOffsetLz4) {
+ CelebornConf conf;
+ conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "LZ4");
+
+ auto compressor = compress::Compressor::createCompressor(conf);
+ ASSERT_NE(compressor, nullptr);
+
+ const std::string prefix = "SKIP_THIS_PREFIX";
+ const std::string testData =
+ "Celeborn compression offset test with structured data: "
+ "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4 "
+ "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4 "
+ "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4";
+ std::string fullData = prefix + testData;
+
+ const auto maxLength = compressor->getDstCapacity(testData.size());
+ std::vector<uint8_t> compressedData(maxLength);
+
+ // Compress with offset (simulating pushData usage pattern)
+ const size_t compressedSize = compressor->compress(
+ reinterpret_cast<const uint8_t*>(fullData.data()),
+ prefix.size(),
+ testData.size(),
+ compressedData.data(),
+ 0);
+
+ ASSERT_GT(compressedSize, 0);
+ ASSERT_LE(compressedSize, maxLength);
+
+ auto decompressor =
+ compress::Decompressor::createDecompressor(CompressionCodec::LZ4);
+ ASSERT_NE(decompressor, nullptr);
+
+ const int originalLen = decompressor->getOriginalLen(compressedData.data());
+ EXPECT_EQ(originalLen, testData.size());
+
+ std::vector<uint8_t> decompressedData(originalLen);
+ const int decompressedSize = decompressor->decompress(
+ compressedData.data(), decompressedData.data(), 0);
+ EXPECT_EQ(decompressedSize, originalLen);
+
+ const std::string decompressedStr(
+ reinterpret_cast<const char*>(decompressedData.data()), decompressedSize);
+ EXPECT_EQ(decompressedStr, testData);
+ EXPECT_NE(decompressedStr, fullData);
+}
+
+TEST(CompressorFactoryTest, CompressWithOffsetZstd) {
+ CelebornConf conf;
+ conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "ZSTD");
+ conf.registerProperty(
+ CelebornConf::kShuffleCompressionZstdCompressLevel, "3");
+
+ auto compressor = compress::Compressor::createCompressor(conf);
+ ASSERT_NE(compressor, nullptr);
+
+ const std::string prefix = "SKIP_THIS_PREFIX";
+ const std::string testData =
+ "Celeborn compression offset test with structured data: "
+ "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4 "
+ "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4 "
+ "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4";
+ std::string fullData = prefix + testData;
+
+ const auto maxLength = compressor->getDstCapacity(testData.size());
+ std::vector<uint8_t> compressedData(maxLength);
+
+ // Compress with offset (simulating pushData usage pattern)
+ const size_t compressedSize = compressor->compress(
+ reinterpret_cast<const uint8_t*>(fullData.data()),
+ prefix.size(),
+ testData.size(),
+ compressedData.data(),
+ 0);
+
+ ASSERT_GT(compressedSize, 0);
+ ASSERT_LE(compressedSize, maxLength);
+
+ auto decompressor =
+ compress::Decompressor::createDecompressor(CompressionCodec::ZSTD);
+ ASSERT_NE(decompressor, nullptr);
+
+ const int originalLen = decompressor->getOriginalLen(compressedData.data());
+ EXPECT_EQ(originalLen, testData.size());
+
+ std::vector<uint8_t> decompressedData(originalLen);
+ const int decompressedSize = decompressor->decompress(
+ compressedData.data(), decompressedData.data(), 0);
+ EXPECT_EQ(decompressedSize, originalLen);
+
+ const std::string decompressedStr(
+ reinterpret_cast<const char*>(decompressedData.data()), decompressedSize);
+ EXPECT_EQ(decompressedStr, testData);
+ EXPECT_NE(decompressedStr, fullData);
+}
diff --git a/cpp/celeborn/conf/tests/CelebornConfTest.cpp b/cpp/celeborn/conf/tests/CelebornConfTest.cpp
index 41efdbc00..79619d888 100644
--- a/cpp/celeborn/conf/tests/CelebornConfTest.cpp
+++ b/cpp/celeborn/conf/tests/CelebornConfTest.cpp
@@ -19,8 +19,10 @@
#include <fstream>
#include "celeborn/conf/CelebornConf.h"
+#include "celeborn/protocol/CompressionCodec.h"
using namespace celeborn::conf;
+using namespace celeborn::protocol;
using CelebornUserError = celeborn::utils::CelebornUserError;
using SECOND = std::chrono::seconds;
@@ -47,6 +49,8 @@ void testDefaultValues(CelebornConf* conf) {
EXPECT_EQ(conf->networkIoNumConnectionsPerPeer(), 1);
EXPECT_EQ(conf->networkIoClientThreads(), 0);
EXPECT_EQ(conf->clientFetchMaxReqsInFlight(), 3);
+ EXPECT_EQ(conf->shuffleCompressionCodec(), CompressionCodec::NONE);
+ EXPECT_EQ(conf->shuffleCompressionZstdCompressLevel(), 1);
}
TEST(CelebornConfTest, defaultValues) {
@@ -73,6 +77,15 @@ TEST(CelebornConfTest, setValues) {
EXPECT_EQ(conf->networkIoClientThreads(), 10);
conf->registerProperty(CelebornConf::kClientFetchMaxReqsInFlight, "10");
EXPECT_EQ(conf->clientFetchMaxReqsInFlight(), 10);
+ conf->registerProperty(CelebornConf::kShuffleCompressionCodec, "LZ4");
+ EXPECT_EQ(conf->shuffleCompressionCodec(), CompressionCodec::LZ4);
+ conf->registerProperty(CelebornConf::kShuffleCompressionCodec, "ZSTD");
+ EXPECT_EQ(conf->shuffleCompressionCodec(), CompressionCodec::ZSTD);
+ conf->registerProperty(CelebornConf::kShuffleCompressionCodec, "NONE");
+ EXPECT_EQ(conf->shuffleCompressionCodec(), CompressionCodec::NONE);
+ conf->registerProperty(
+ CelebornConf::kShuffleCompressionZstdCompressLevel, "5");
+ EXPECT_EQ(conf->shuffleCompressionZstdCompressLevel(), 5);
EXPECT_THROW(
conf->registerProperty("non-exist-key", "non-exist-value"),