This is an automated email from the ASF dual-hosted git repository.
marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 14ab4e828 [Gluten-5534][VL] Fix shuffle OOM if input batch is extremely large (#5536)
14ab4e828 is described below
commit 14ab4e828416edfae7e4ac3b97e52c1a31207f62
Author: guhaiyan0221 <[email protected]>
AuthorDate: Tue Apr 30 09:02:14 2024 +0800
[Gluten-5534][VL] Fix shuffle OOM if input batch is extremely large (#5536)
---
cpp/velox/shuffle/VeloxShuffleWriter.cc | 50 +++++++++++++++++++++++--------
cpp/velox/shuffle/VeloxShuffleWriter.h | 9 ++++++
cpp/velox/tests/VeloxShuffleWriterTest.cc | 16 ++++++++++
3 files changed, 62 insertions(+), 13 deletions(-)
diff --git a/cpp/velox/shuffle/VeloxShuffleWriter.cc b/cpp/velox/shuffle/VeloxShuffleWriter.cc
index ee427ab21..b304565e5 100644
--- a/cpp/velox/shuffle/VeloxShuffleWriter.cc
+++ b/cpp/velox/shuffle/VeloxShuffleWriter.cc
@@ -340,25 +340,42 @@ arrow::Status
VeloxShuffleWriter::split(std::shared_ptr<ColumnarBatch> cb, int64
START_TIMING(cpuWallTimingList_[CpuWallTimingFlattenRV]);
rv = veloxColumnBatch->getFlattenedRowVector();
END_TIMING();
- if (partitioner_->hasPid()) {
- auto pidArr = getFirstColumn(*rv);
- START_TIMING(cpuWallTimingList_[CpuWallTimingCompute]);
- RETURN_NOT_OK(partitioner_->compute(pidArr, rv->size(), row2Partition_,
partition2RowCount_));
- END_TIMING();
- auto strippedRv = getStrippedRowVector(*rv);
- RETURN_NOT_OK(initFromRowVector(*strippedRv));
- RETURN_NOT_OK(doSplit(*strippedRv, memLimit));
+ if (isExtremelyLargeBatch(rv)) {
+ auto numRows = rv->size();
+ int32_t offset = 0;
+ do {
+ auto length = std::min(maxBatchSize_, numRows);
+ auto slicedBatch =
std::dynamic_pointer_cast<facebook::velox::RowVector>(rv->slice(offset,
length));
+ RETURN_NOT_OK(partitioningAndDoSplit(std::move(slicedBatch),
memLimit));
+ offset += length;
+ numRows -= length;
+ } while (numRows);
} else {
- RETURN_NOT_OK(initFromRowVector(*rv));
- START_TIMING(cpuWallTimingList_[CpuWallTimingCompute]);
- RETURN_NOT_OK(partitioner_->compute(nullptr, rv->size(), row2Partition_,
partition2RowCount_));
- END_TIMING();
- RETURN_NOT_OK(doSplit(*rv, memLimit));
+ RETURN_NOT_OK(partitioningAndDoSplit(std::move(rv), memLimit));
}
}
return arrow::Status::OK();
}
+arrow::Status
VeloxShuffleWriter::partitioningAndDoSplit(facebook::velox::RowVectorPtr rv,
int64_t memLimit) {
+ if (partitioner_->hasPid()) {
+ auto pidArr = getFirstColumn(*rv);
+ START_TIMING(cpuWallTimingList_[CpuWallTimingCompute]);
+ RETURN_NOT_OK(partitioner_->compute(pidArr, rv->size(), row2Partition_,
partition2RowCount_));
+ END_TIMING();
+ auto strippedRv = getStrippedRowVector(*rv);
+ RETURN_NOT_OK(initFromRowVector(*strippedRv));
+ RETURN_NOT_OK(doSplit(*strippedRv, memLimit));
+ } else {
+ RETURN_NOT_OK(initFromRowVector(*rv));
+ START_TIMING(cpuWallTimingList_[CpuWallTimingCompute]);
+ RETURN_NOT_OK(partitioner_->compute(nullptr, rv->size(), row2Partition_,
partition2RowCount_));
+ END_TIMING();
+ RETURN_NOT_OK(doSplit(*rv, memLimit));
+ }
+ return arrow::Status::OK();
+}
+
arrow::Status VeloxShuffleWriter::stop() {
if (options_.partitioning != Partitioning::kSingle) {
for (auto pid = 0; pid < numPartitions_; ++pid) {
@@ -892,6 +909,8 @@ uint32_t
VeloxShuffleWriter::calculatePartitionBufferSize(const facebook::velox:
totalInputNumRows_ += numRows;
+ maxBatchSize_ = preAllocRowCnt == 0 ? numPartitions_ : preAllocRowCnt *
numPartitions_;
+
return (uint32_t)preAllocRowCnt;
}
@@ -1474,4 +1493,9 @@ arrow::Status
VeloxShuffleWriter::preAllocPartitionBuffers(uint32_t preAllocBuff
}
return arrow::Status::OK();
}
+
+bool VeloxShuffleWriter::isExtremelyLargeBatch(facebook::velox::RowVectorPtr&
rv) const {
+ return (rv->size() > maxBatchSize_ && maxBatchSize_ > 0);
+}
+
} // namespace gluten
diff --git a/cpp/velox/shuffle/VeloxShuffleWriter.h b/cpp/velox/shuffle/VeloxShuffleWriter.h
index c06cd7a0d..e699a323b 100644
--- a/cpp/velox/shuffle/VeloxShuffleWriter.h
+++ b/cpp/velox/shuffle/VeloxShuffleWriter.h
@@ -192,6 +192,10 @@ class VeloxShuffleWriter final : public ShuffleWriter {
VS_PRINT_CONTAINER(input_has_null_);
}
+ int32_t maxBatchSize() const {
+ return maxBatchSize_;
+ }
+
private:
VeloxShuffleWriter(
uint32_t numPartitions,
@@ -306,6 +310,10 @@ class VeloxShuffleWriter final : public ShuffleWriter {
arrow::Result<uint32_t> partitionBufferSizeAfterShrink(uint32_t partitionId)
const;
+ bool isExtremelyLargeBatch(facebook::velox::RowVectorPtr& rv) const;
+
+ arrow::Status partitioningAndDoSplit(facebook::velox::RowVectorPtr rv,
int64_t memLimit);
+
SplitState splitState_{kInit};
EvictState evictState_{kEvictable};
@@ -466,6 +474,7 @@ class VeloxShuffleWriter final : public ShuffleWriter {
}
facebook::velox::CpuWallTiming cpuWallTimingList_[CpuWallTimingNum];
+ int32_t maxBatchSize_{0};
}; // class VeloxShuffleWriter
} // namespace gluten
diff --git a/cpp/velox/tests/VeloxShuffleWriterTest.cc b/cpp/velox/tests/VeloxShuffleWriterTest.cc
index 2a8f12afb..ffda945b1 100644
--- a/cpp/velox/tests/VeloxShuffleWriterTest.cc
+++ b/cpp/velox/tests/VeloxShuffleWriterTest.cc
@@ -258,6 +258,22 @@ TEST_P(HashPartitioningShuffleWriter, hashPart3Vectors) {
{{blockPid2}, {blockPid1}});
}
+TEST_P(HashPartitioningShuffleWriter, hashLargeVectors) {
+ const int32_t expectedMaxBatchSize = 8;
+ ASSERT_NOT_OK(initShuffleWriterOptions());
+ auto shuffleWriter = createShuffleWriter(defaultArrowMemoryPool().get());
+ // calculate maxBatchSize_
+ ASSERT_NOT_OK(splitRowVector(*shuffleWriter, hashInputVector1_));
+ VELOX_CHECK_EQ(shuffleWriter->maxBatchSize(), expectedMaxBatchSize);
+
+ auto blockPid2 = takeRows({inputVector1_, inputVector2_, inputVector1_},
{{1, 2, 3, 4, 8}, {0, 1}, {1, 2, 3, 4, 8}});
+ auto blockPid1 = takeRows({inputVector1_}, {{0, 5, 6, 7, 9, 0, 5, 6, 7, 9}});
+
+ VELOX_CHECK(hashInputVector1_->size() > expectedMaxBatchSize);
+ testShuffleWriteMultiBlocks(
+ *shuffleWriter, {hashInputVector2_, hashInputVector1_}, 2,
inputVector1_->type(), {{blockPid2}, {blockPid1}});
+}
+
TEST_P(RangePartitioningShuffleWriter, rangePartition) {
ASSERT_NOT_OK(initShuffleWriterOptions());
auto shuffleWriter = createShuffleWriter(defaultArrowMemoryPool().get());
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]