Copilot commented on code in PR #3575:
URL: https://github.com/apache/celeborn/pull/3575#discussion_r2752353901
##########
cpp/celeborn/client/ShuffleClient.cpp:
##########
@@ -57,7 +58,16 @@ ShuffleClientImpl::ShuffleClientImpl(
: appUniqueId_(appUniqueId),
conf_(conf),
clientFactory_(clientEndpoint.clientFactory()),
- pushDataRetryPool_(clientEndpoint.pushDataRetryPool()) {
+ pushDataRetryPool_(clientEndpoint.pushDataRetryPool()),
+ shuffleCompressionEnabled_(
+ conf->shuffleCompressionCodec() != protocol::CompressionCodec::NONE),
+ compressorFactory_(
+ shuffleCompressionEnabled_
+ ? std::function<std::unique_ptr<compress::Compressor>()>(
+ [conf]() {
+ return compress::Compressor::createCompressor(*conf);
+ })
+ : std::function<std::unique_ptr<compress::Compressor>()>()) {
Review Comment:
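The constructor initializer list order does not match the member declaration
order in ShuffleClientImpl (e.g., conf_ is listed before
shuffleCompressionEnabled_ even though shuffleCompressionEnabled_ is declared
first). This can trigger -Wreorder (an error when building with -Werror) and is
misleading. Members are always initialized in declaration order, not in the
order they are written in the initializer list. That matters here because the
compressorFactory_ initializer reads shuffleCompressionEnabled_. As a result,
shuffleCompressionEnabled_ must be declared before compressorFactory_, or it
will be read uninitialized. Reorder the initializer list to match the
declaration order in ShuffleClient.h.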
##########
cpp/celeborn/client/ShuffleClient.h:
##########
@@ -266,6 +268,9 @@ class ShuffleClientImpl
mapperEndSets_;
utils::ConcurrentHashSet<int> stageEndShuffleSet_;
+ // Factory for creating compressor instances on demand to avoid sharing a
+ // single non-thread-safe compressor across concurrent operations.
+ std::function<std::unique_ptr<compress::Compressor>()> compressorFactory_;
Review Comment:
This header now introduces a `std::function` member (`compressorFactory_`)
but does not include `<functional>`. Relying on transitive includes is brittle
and can break builds when include graphs change; add an explicit `<functional>`
include in this header.
##########
cpp/celeborn/client/ShuffleClient.cpp:
##########
@@ -154,23 +164,68 @@ int ShuffleClientImpl::pushData(
auto pushState = getPushState(mapKey);
const int nextBatchId = pushState->nextBatchId();
- // TODO: compression in writing is not supported.
+ // Validate input size fits in 32-bit int since it is required by compressor
+ // API and wire protocol
+ CELEBORN_CHECK(
+ length <= static_cast<size_t>(std::numeric_limits<int>::max()),
+ fmt::format(
+ "Data length {} exceeds maximum supported size {}",
+ length,
+ std::numeric_limits<int>::max()));
+
+ // Compression support: compress data if compression is enabled
+ const uint8_t* dataToWrite = data + offset;
+ int lengthToWrite = static_cast<int>(length);
+ std::unique_ptr<uint8_t[]> compressedBuffer;
+
+ if (shuffleCompressionEnabled_ && compressorFactory_) {
+ // Create a new compressor instance for thread-safety
+ auto compressor = compressorFactory_();
+ // Allocate buffer for compressed data
+ const size_t compressedCapacity =
+ compressor->getDstCapacity(static_cast<int>(length));
+ compressedBuffer = std::make_unique<uint8_t[]>(compressedCapacity);
+
+ // Compress the data
+ const size_t compressedSize = compressor->compress(
+ dataToWrite, 0, static_cast<int>(length), compressedBuffer.get(), 0);
+
+ CELEBORN_CHECK(
+ compressedSize <= static_cast<size_t>(std::numeric_limits<int>::max()),
+ fmt::format(
+ "Compressed size {} exceeds maximum supported size {}",
+ compressedSize,
+ std::numeric_limits<int>::max()));
+
+ lengthToWrite = static_cast<int>(compressedSize);
+ dataToWrite = compressedBuffer.get();
+ }
- auto writeBuffer =
- memory::ByteBuffer::createWriteOnly(kBatchHeaderSize + length);
+ // Validate final buffer size fits in size_t and int
+ CELEBORN_CHECK(
+ static_cast<size_t>(lengthToWrite) <=
+ std::numeric_limits<size_t>::max() - kBatchHeaderSize,
+ fmt::format(
+ "Buffer size {} + header {} would overflow",
+ lengthToWrite,
+ kBatchHeaderSize));
+
+ auto writeBuffer = memory::ByteBuffer::createWriteOnly(
+ kBatchHeaderSize + static_cast<size_t>(lengthToWrite));
// TODO: the java side uses Platform to write the data. We simply assume
// littleEndian here.
writeBuffer->writeLE<int>(mapId);
writeBuffer->writeLE<int>(attemptId);
writeBuffer->writeLE<int>(nextBatchId);
- writeBuffer->writeLE<int>(length);
- writeBuffer->writeFromBuffer(data, offset, length);
+ writeBuffer->writeLE<int>(lengthToWrite);
+ writeBuffer->writeFromBuffer(
+ dataToWrite, 0, static_cast<size_t>(lengthToWrite));
auto hostAndPushPort = partitionLocation->hostAndPushPort();
// Check limit.
limitMaxInFlight(mapKey, *pushState, hostAndPushPort);
// Add inFlight requests.
Review Comment:
`batchBytesSize` is computed as `lengthToWrite + kBatchHeaderSize` and
stored in an `int`. `lengthToWrite` can be as large as `INT_MAX`, because it is
only checked against `INT_MAX`. The sum can therefore exceed `INT_MAX` and
produce a negative or otherwise incorrect in-flight byte count. If the addition
is performed in `int`, the overflow is also undefined behavior. Add a guard that
`lengthToWrite <= INT_MAX - kBatchHeaderSize` (or change the accounting type to
a wider integer) before computing/storing `batchBytesSize`.
```suggestion
// Add inFlight requests.
CELEBORN_CHECK(
lengthToWrite <= std::numeric_limits<int>::max() - kBatchHeaderSize,
fmt::format(
"Batch bytes size {} + header {} would overflow int",
lengthToWrite,
kBatchHeaderSize));
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]