afterincomparableyum commented on code in PR #3575:
URL: https://github.com/apache/celeborn/pull/3575#discussion_r2767454531
##########
cpp/celeborn/client/ShuffleClient.cpp:
##########
@@ -154,23 +164,68 @@ int ShuffleClientImpl::pushData(
auto pushState = getPushState(mapKey);
const int nextBatchId = pushState->nextBatchId();
- // TODO: compression in writing is not supported.
+ // Validate input size fits in 32-bit int since it is required by compressor
+ // API and wire protocol
+ CELEBORN_CHECK(
+ length <= static_cast<size_t>(std::numeric_limits<int>::max()),
+ fmt::format(
+ "Data length {} exceeds maximum supported size {}",
+ length,
+ std::numeric_limits<int>::max()));
+
+ // Compression support: compress data if compression is enabled
+ const uint8_t* dataToWrite = data + offset;
+ int lengthToWrite = static_cast<int>(length);
+ std::unique_ptr<uint8_t[]> compressedBuffer;
+
+ if (shuffleCompressionEnabled_ && compressorFactory_) {
+ // Create a new compressor instance for thread-safety
+ auto compressor = compressorFactory_();
+ // Allocate buffer for compressed data
+ const size_t compressedCapacity =
+ compressor->getDstCapacity(static_cast<int>(length));
+ compressedBuffer = std::make_unique<uint8_t[]>(compressedCapacity);
+
+ // Compress the data
+ const size_t compressedSize = compressor->compress(
+ dataToWrite, 0, static_cast<int>(length), compressedBuffer.get(), 0);
+
+ CELEBORN_CHECK(
+ compressedSize <= static_cast<size_t>(std::numeric_limits<int>::max()),
+ fmt::format(
+ "Compressed size {} exceeds maximum supported size {}",
+ compressedSize,
+ std::numeric_limits<int>::max()));
+
+ lengthToWrite = static_cast<int>(compressedSize);
+ dataToWrite = compressedBuffer.get();
+ }
- auto writeBuffer =
- memory::ByteBuffer::createWriteOnly(kBatchHeaderSize + length);
+ // Validate final buffer size fits in size_t and int
+ CELEBORN_CHECK(
+ static_cast<size_t>(lengthToWrite) <=
+ std::numeric_limits<size_t>::max() - kBatchHeaderSize,
+ fmt::format(
+ "Buffer size {} + header {} would overflow",
+ lengthToWrite,
+ kBatchHeaderSize));
+
+ auto writeBuffer = memory::ByteBuffer::createWriteOnly(
+ kBatchHeaderSize + static_cast<size_t>(lengthToWrite));
// TODO: the java side uses Platform to write the data. We simply assume
// littleEndian here.
writeBuffer->writeLE<int>(mapId);
writeBuffer->writeLE<int>(attemptId);
writeBuffer->writeLE<int>(nextBatchId);
- writeBuffer->writeLE<int>(length);
- writeBuffer->writeFromBuffer(data, offset, length);
+ writeBuffer->writeLE<int>(lengthToWrite);
+ writeBuffer->writeFromBuffer(
+ dataToWrite, 0, static_cast<size_t>(lengthToWrite));
auto hostAndPushPort = partitionLocation->hostAndPushPort();
// Check limit.
limitMaxInFlight(mapKey, *pushState, hostAndPushPort);
// Add inFlight requests.
Review Comment:
Sure, I can add this check as well.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]