This is an automated email from the ASF dual-hosted git repository.

marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 14ab4e828 [Gluten-5534][VL] Fix shuffle OOM if input batch is extremely large (#5536)
14ab4e828 is described below

commit 14ab4e828416edfae7e4ac3b97e52c1a31207f62
Author: guhaiyan0221 <[email protected]>
AuthorDate: Tue Apr 30 09:02:14 2024 +0800

    [Gluten-5534][VL] Fix shuffle OOM if input batch is extremely large (#5536)
---
 cpp/velox/shuffle/VeloxShuffleWriter.cc   | 50 +++++++++++++++++++++++--------
 cpp/velox/shuffle/VeloxShuffleWriter.h    |  9 ++++++
 cpp/velox/tests/VeloxShuffleWriterTest.cc | 16 ++++++++++
 3 files changed, 62 insertions(+), 13 deletions(-)

diff --git a/cpp/velox/shuffle/VeloxShuffleWriter.cc b/cpp/velox/shuffle/VeloxShuffleWriter.cc
index ee427ab21..b304565e5 100644
--- a/cpp/velox/shuffle/VeloxShuffleWriter.cc
+++ b/cpp/velox/shuffle/VeloxShuffleWriter.cc
@@ -340,25 +340,42 @@ arrow::Status VeloxShuffleWriter::split(std::shared_ptr<ColumnarBatch> cb, int64
     START_TIMING(cpuWallTimingList_[CpuWallTimingFlattenRV]);
     rv = veloxColumnBatch->getFlattenedRowVector();
     END_TIMING();
-    if (partitioner_->hasPid()) {
-      auto pidArr = getFirstColumn(*rv);
-      START_TIMING(cpuWallTimingList_[CpuWallTimingCompute]);
-      RETURN_NOT_OK(partitioner_->compute(pidArr, rv->size(), row2Partition_, partition2RowCount_));
-      END_TIMING();
-      auto strippedRv = getStrippedRowVector(*rv);
-      RETURN_NOT_OK(initFromRowVector(*strippedRv));
-      RETURN_NOT_OK(doSplit(*strippedRv, memLimit));
+    if (isExtremelyLargeBatch(rv)) {
+      auto numRows = rv->size();
+      int32_t offset = 0;
+      do {
+        auto length = std::min(maxBatchSize_, numRows);
+        auto slicedBatch = std::dynamic_pointer_cast<facebook::velox::RowVector>(rv->slice(offset, length));
+        RETURN_NOT_OK(partitioningAndDoSplit(std::move(slicedBatch), memLimit));
+        offset += length;
+        numRows -= length;
+      } while (numRows);
     } else {
-      RETURN_NOT_OK(initFromRowVector(*rv));
-      START_TIMING(cpuWallTimingList_[CpuWallTimingCompute]);
-      RETURN_NOT_OK(partitioner_->compute(nullptr, rv->size(), row2Partition_, partition2RowCount_));
-      END_TIMING();
-      RETURN_NOT_OK(doSplit(*rv, memLimit));
+      RETURN_NOT_OK(partitioningAndDoSplit(std::move(rv), memLimit));
     }
   }
   return arrow::Status::OK();
 }
 
+arrow::Status VeloxShuffleWriter::partitioningAndDoSplit(facebook::velox::RowVectorPtr rv, int64_t memLimit) {
+  if (partitioner_->hasPid()) {
+    auto pidArr = getFirstColumn(*rv);
+    START_TIMING(cpuWallTimingList_[CpuWallTimingCompute]);
+    RETURN_NOT_OK(partitioner_->compute(pidArr, rv->size(), row2Partition_, partition2RowCount_));
+    END_TIMING();
+    auto strippedRv = getStrippedRowVector(*rv);
+    RETURN_NOT_OK(initFromRowVector(*strippedRv));
+    RETURN_NOT_OK(doSplit(*strippedRv, memLimit));
+  } else {
+    RETURN_NOT_OK(initFromRowVector(*rv));
+    START_TIMING(cpuWallTimingList_[CpuWallTimingCompute]);
+    RETURN_NOT_OK(partitioner_->compute(nullptr, rv->size(), row2Partition_, partition2RowCount_));
+    END_TIMING();
+    RETURN_NOT_OK(doSplit(*rv, memLimit));
+  }
+  return arrow::Status::OK();
+}
+
 arrow::Status VeloxShuffleWriter::stop() {
   if (options_.partitioning != Partitioning::kSingle) {
     for (auto pid = 0; pid < numPartitions_; ++pid) {
@@ -892,6 +909,8 @@ uint32_t VeloxShuffleWriter::calculatePartitionBufferSize(const facebook::velox:
 
   totalInputNumRows_ += numRows;
 
+  maxBatchSize_ = preAllocRowCnt == 0 ? numPartitions_ : preAllocRowCnt * numPartitions_;
+
   return (uint32_t)preAllocRowCnt;
 }
 
@@ -1474,4 +1493,9 @@ arrow::Status VeloxShuffleWriter::preAllocPartitionBuffers(uint32_t preAllocBuff
   }
   return arrow::Status::OK();
 }
+
+bool VeloxShuffleWriter::isExtremelyLargeBatch(facebook::velox::RowVectorPtr& rv) const {
+  return (rv->size() > maxBatchSize_ && maxBatchSize_ > 0);
+}
+
 } // namespace gluten
diff --git a/cpp/velox/shuffle/VeloxShuffleWriter.h b/cpp/velox/shuffle/VeloxShuffleWriter.h
index c06cd7a0d..e699a323b 100644
--- a/cpp/velox/shuffle/VeloxShuffleWriter.h
+++ b/cpp/velox/shuffle/VeloxShuffleWriter.h
@@ -192,6 +192,10 @@ class VeloxShuffleWriter final : public ShuffleWriter {
     VS_PRINT_CONTAINER(input_has_null_);
   }
 
+  int32_t maxBatchSize() const {
+    return maxBatchSize_;
+  }
+
  private:
   VeloxShuffleWriter(
       uint32_t numPartitions,
@@ -306,6 +310,10 @@ class VeloxShuffleWriter final : public ShuffleWriter {
 
   arrow::Result<uint32_t> partitionBufferSizeAfterShrink(uint32_t partitionId) const;
 
+  bool isExtremelyLargeBatch(facebook::velox::RowVectorPtr& rv) const;
+
+  arrow::Status partitioningAndDoSplit(facebook::velox::RowVectorPtr rv, int64_t memLimit);
+
   SplitState splitState_{kInit};
 
   EvictState evictState_{kEvictable};
@@ -466,6 +474,7 @@ class VeloxShuffleWriter final : public ShuffleWriter {
   }
 
   facebook::velox::CpuWallTiming cpuWallTimingList_[CpuWallTimingNum];
+  int32_t maxBatchSize_{0};
 }; // class VeloxShuffleWriter
 
 } // namespace gluten
diff --git a/cpp/velox/tests/VeloxShuffleWriterTest.cc b/cpp/velox/tests/VeloxShuffleWriterTest.cc
index 2a8f12afb..ffda945b1 100644
--- a/cpp/velox/tests/VeloxShuffleWriterTest.cc
+++ b/cpp/velox/tests/VeloxShuffleWriterTest.cc
@@ -258,6 +258,22 @@ TEST_P(HashPartitioningShuffleWriter, hashPart3Vectors) {
       {{blockPid2}, {blockPid1}});
 }
 
+TEST_P(HashPartitioningShuffleWriter, hashLargeVectors) {
+  const int32_t expectedMaxBatchSize = 8;
+  ASSERT_NOT_OK(initShuffleWriterOptions());
+  auto shuffleWriter = createShuffleWriter(defaultArrowMemoryPool().get());
+  // calculate maxBatchSize_
+  ASSERT_NOT_OK(splitRowVector(*shuffleWriter, hashInputVector1_));
+  VELOX_CHECK_EQ(shuffleWriter->maxBatchSize(), expectedMaxBatchSize);
+
+  auto blockPid2 = takeRows({inputVector1_, inputVector2_, inputVector1_}, {{1, 2, 3, 4, 8}, {0, 1}, {1, 2, 3, 4, 8}});
+  auto blockPid1 = takeRows({inputVector1_}, {{0, 5, 6, 7, 9, 0, 5, 6, 7, 9}});
+
+  VELOX_CHECK(hashInputVector1_->size() > expectedMaxBatchSize);
+  testShuffleWriteMultiBlocks(
+      *shuffleWriter, {hashInputVector2_, hashInputVector1_}, 2, inputVector1_->type(), {{blockPid2}, {blockPid1}});
+}
+
 TEST_P(RangePartitioningShuffleWriter, rangePartition) {
   ASSERT_NOT_OK(initShuffleWriterOptions());
   auto shuffleWriter = createShuffleWriter(defaultArrowMemoryPool().get());


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to