This is an automated email from the ASF dual-hosted git repository.

marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 72b8810be5 [GLUTEN-7860][CORE] In shuffle writer, replace MemoryMappedFile to avoid OOM (#7861)
72b8810be5 is described below

commit 72b8810be57cc986a52aa83df50e31281a803733
Author: Lingfeng Zhang <[email protected]>
AuthorDate: Thu Nov 28 18:58:39 2024 +0800

    [GLUTEN-7860][CORE] In shuffle writer, replace MemoryMappedFile to avoid OOM (#7861)
---
 cpp/core/config/GlutenConfig.h                     |   1 +
 cpp/core/jni/JniWrapper.cc                         |   7 ++
 cpp/core/shuffle/LocalPartitionWriter.cc           |   5 +-
 cpp/core/shuffle/Options.h                         |   3 +
 cpp/core/shuffle/Payload.cc                        |   5 +
 cpp/core/shuffle/Spill.cc                          |   5 +-
 cpp/core/shuffle/Spill.h                           |   9 +-
 cpp/core/shuffle/Utils.cc                          | 112 +++++++++++++++++++++
 cpp/core/shuffle/Utils.h                           |  35 +++++++
 .../scala/org/apache/gluten/GlutenConfig.scala     |  12 ++-
 10 files changed, 183 insertions(+), 11 deletions(-)

diff --git a/cpp/core/config/GlutenConfig.h b/cpp/core/config/GlutenConfig.h
index 3dd2f77f9f..5a61b27a80 100644
--- a/cpp/core/config/GlutenConfig.h
+++ b/cpp/core/config/GlutenConfig.h
@@ -73,6 +73,7 @@ const std::string kSparkRedactionRegex = "spark.redaction.regex";
 const std::string kSparkRedactionString = "*********(redacted)";
 
 const std::string kSparkLegacyTimeParserPolicy = "spark.sql.legacy.timeParserPolicy";
+const std::string kShuffleFileBufferSize = "spark.shuffle.file.buffer";
 
 std::unordered_map<std::string, std::string>
 parseConfMap(JNIEnv* env, const uint8_t* planData, const int32_t planDataLength);
diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc
index 963440f6fc..794ca6b88f 100644
--- a/cpp/core/jni/JniWrapper.cc
+++ b/cpp/core/jni/JniWrapper.cc
@@ -824,6 +824,13 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
     partitionWriterOptions.codecBackend = getCodecBackend(env, codecBackendJstr);
     partitionWriterOptions.compressionMode = getCompressionMode(env, compressionModeJstr);
   }
+  const auto& conf = ctx->getConfMap();
+  {
+    auto it = conf.find(kShuffleFileBufferSize);
+    if (it != conf.end()) {
+      partitionWriterOptions.shuffleFileBufferSize = static_cast<int64_t>(stoi(it->second));
+    }
+  }
 
   std::unique_ptr<PartitionWriter> partitionWriter;
 
diff --git a/cpp/core/shuffle/LocalPartitionWriter.cc b/cpp/core/shuffle/LocalPartitionWriter.cc
index b7bfa19304..6692097986 100644
--- a/cpp/core/shuffle/LocalPartitionWriter.cc
+++ b/cpp/core/shuffle/LocalPartitionWriter.cc
@@ -389,9 +389,9 @@ arrow::Result<std::shared_ptr<arrow::io::OutputStream>> LocalPartitionWriter::op
   std::shared_ptr<arrow::io::FileOutputStream> fout;
   ARROW_ASSIGN_OR_RAISE(fout, arrow::io::FileOutputStream::Open(file));
   if (options_.bufferedWrite) {
-    // The 16k bytes is a temporary allocation and will be freed with file close.
+    // The `shuffleFileBufferSize` bytes is a temporary allocation and will be freed with file close.
     // Use default memory pool and count treat the memory as executor memory overhead to avoid unnecessary spill.
-    return arrow::io::BufferedOutputStream::Create(16384, arrow::default_memory_pool(), fout);
+    return arrow::io::BufferedOutputStream::Create(options_.shuffleFileBufferSize, arrow::default_memory_pool(), fout);
   }
   return fout;
 }
@@ -420,6 +420,7 @@ arrow::Status LocalPartitionWriter::mergeSpills(uint32_t partitionId) {
   auto spillIter = spills_.begin();
   while (spillIter != spills_.end()) {
     ARROW_ASSIGN_OR_RAISE(auto st, dataFileOs_->Tell());
+    (*spillIter)->openForRead(options_.shuffleFileBufferSize);
     // Read if partition exists in the spilled file and write to the final file.
     while (auto payload = (*spillIter)->nextPayload(partitionId)) {
       // May trigger spill during compression.
diff --git a/cpp/core/shuffle/Options.h b/cpp/core/shuffle/Options.h
index 6a9e0ec4b3..3a1efdc2ae 100644
--- a/cpp/core/shuffle/Options.h
+++ b/cpp/core/shuffle/Options.h
@@ -39,6 +39,7 @@ static constexpr bool kEnableBufferedWrite = true;
 static constexpr bool kDefaultUseRadixSort = true;
 static constexpr int32_t kDefaultSortBufferSize = 4096;
 static constexpr int64_t kDefaultReadBufferSize = 1 << 20;
+static constexpr int64_t kDefaultShuffleFileBufferSize = 32 << 10;
 
 enum ShuffleWriterType { kHashShuffle, kSortShuffle, kRssSortShuffle };
 enum PartitionWriterType { kLocal, kRss };
@@ -86,6 +87,8 @@ struct PartitionWriterOptions {
   int64_t pushBufferMaxSize = kDefaultPushMemoryThreshold;
 
   int64_t sortBufferMaxSize = kDefaultSortBufferThreshold;
+
+  int64_t shuffleFileBufferSize = kDefaultShuffleFileBufferSize;
 };
 
 struct ShuffleWriterMetrics {
diff --git a/cpp/core/shuffle/Payload.cc b/cpp/core/shuffle/Payload.cc
index 55f3a43396..ddf4a40966 100644
--- a/cpp/core/shuffle/Payload.cc
+++ b/cpp/core/shuffle/Payload.cc
@@ -481,6 +481,9 @@ arrow::Result<std::shared_ptr<arrow::Buffer>> UncompressedDiskBlockPayload::read
 }
 
 arrow::Status UncompressedDiskBlockPayload::serialize(arrow::io::OutputStream* outputStream) {
+  ARROW_RETURN_IF(
+      inputStream_ == nullptr, arrow::Status::Invalid("inputStream_ is uninitialized before calling serialize()."));
+
   if (codec_ == nullptr || type_ == Payload::kUncompressed) {
     ARROW_ASSIGN_OR_RAISE(auto block, inputStream_->Read(rawSize_));
     RETURN_NOT_OK(outputStream->Write(block));
@@ -545,6 +548,8 @@ CompressedDiskBlockPayload::CompressedDiskBlockPayload(
     : Payload(Type::kCompressed, numRows, isValidityBuffer), inputStream_(inputStream), rawSize_(rawSize) {}
 
 arrow::Status CompressedDiskBlockPayload::serialize(arrow::io::OutputStream* outputStream) {
+  ARROW_RETURN_IF(
+      inputStream_ == nullptr, arrow::Status::Invalid("inputStream_ is uninitialized before calling serialize()."));
   ScopedTimer timer(&writeTime_);
   ARROW_ASSIGN_OR_RAISE(auto block, inputStream_->Read(rawSize_));
   RETURN_NOT_OK(outputStream->Write(block));
diff --git a/cpp/core/shuffle/Spill.cc b/cpp/core/shuffle/Spill.cc
index d8b9bc7ebf..8cc3a9d05e 100644
--- a/cpp/core/shuffle/Spill.cc
+++ b/cpp/core/shuffle/Spill.cc
@@ -34,7 +34,6 @@ bool Spill::hasNextPayload(uint32_t partitionId) {
 }
 
 std::unique_ptr<Payload> Spill::nextPayload(uint32_t partitionId) {
-  openSpillFile();
   if (!hasNextPayload(partitionId)) {
     return nullptr;
   }
@@ -71,9 +70,9 @@ void Spill::insertPayload(
   }
 }
 
-void Spill::openSpillFile() {
+void Spill::openForRead(uint64_t shuffleFileBufferSize) {
   if (!is_) {
-    GLUTEN_ASSIGN_OR_THROW(is_, arrow::io::MemoryMappedFile::Open(spillFile_, arrow::io::FileMode::READ));
+    GLUTEN_ASSIGN_OR_THROW(is_, MmapFileStream::open(spillFile_, shuffleFileBufferSize));
     rawIs_ = is_.get();
   }
 }
diff --git a/cpp/core/shuffle/Spill.h b/cpp/core/shuffle/Spill.h
index c82a60f562..fd692537c5 100644
--- a/cpp/core/shuffle/Spill.h
+++ b/cpp/core/shuffle/Spill.h
@@ -37,6 +37,8 @@ class Spill final {
 
   SpillType type() const;
 
+  void openForRead(uint64_t shuffleFileBufferSize);
+
   bool hasNextPayload(uint32_t partitionId);
 
   std::unique_ptr<Payload> nextPayload(uint32_t partitionId);
@@ -69,15 +71,12 @@ class Spill final {
   };
 
   SpillType type_;
-  std::shared_ptr<arrow::io::MemoryMappedFile> is_;
+  std::shared_ptr<gluten::MmapFileStream> is_;
   std::list<PartitionPayload> partitionPayloads_{};
-  std::shared_ptr<arrow::io::MemoryMappedFile> inputStream_{};
   std::string spillFile_;
   int64_t spillTime_;
   int64_t compressTime_;
 
-  arrow::io::InputStream* rawIs_;
-
-  void openSpillFile();
+  arrow::io::InputStream* rawIs_{nullptr};
 };
 } // namespace gluten
\ No newline at end of file
diff --git a/cpp/core/shuffle/Utils.cc b/cpp/core/shuffle/Utils.cc
index a11b6b09aa..457702c0c9 100644
--- a/cpp/core/shuffle/Utils.cc
+++ b/cpp/core/shuffle/Utils.cc
@@ -16,8 +16,11 @@
  */
 
 #include "shuffle/Utils.h"
+#include <arrow/buffer.h>
 #include <arrow/record_batch.h>
 #include <fcntl.h>
+#include <glog/logging.h>
+#include <sys/mman.h>
 #include <unistd.h>
 #include <iomanip>
 #include <iostream>
@@ -151,6 +154,14 @@ arrow::Status getLengthBufferAndValueBufferStream(
   *compressedLengthPtr = actualLength;
   return arrow::Status::OK();
 }
+
+uint64_t roundUpToPageSize(uint64_t value) {
+  static auto pageSize = static_cast<size_t>(arrow::internal::GetPageSize());
+  static auto pageMask = ~(pageSize - 1);
+  DCHECK_GT(pageSize, 0);
+  DCHECK_EQ(pageMask & pageSize, pageSize);
+  return (value + pageSize - 1) & pageMask;
+}
 } // namespace
 
 arrow::Result<std::shared_ptr<arrow::RecordBatch>> makeCompressedRecordBatch(
@@ -212,6 +223,107 @@ arrow::Result<std::shared_ptr<arrow::RecordBatch>> makeUncompressedRecordBatch(
   }
   return arrow::RecordBatch::Make(writeSchema, 1, {arrays});
 }
+
+MmapFileStream::MmapFileStream(arrow::internal::FileDescriptor fd, uint8_t* data, int64_t size, uint64_t prefetchSize)
+    : prefetchSize_(roundUpToPageSize(prefetchSize)), fd_(std::move(fd)), data_(data), size_(size){};
+
+arrow::Result<std::shared_ptr<MmapFileStream>> MmapFileStream::open(const std::string& path, uint64_t prefetchSize) {
+  ARROW_ASSIGN_OR_RAISE(auto fileName, arrow::internal::PlatformFilename::FromString(path));
+
+  ARROW_ASSIGN_OR_RAISE(auto fd, arrow::internal::FileOpenReadable(fileName));
+  ARROW_ASSIGN_OR_RAISE(auto size, arrow::internal::FileGetSize(fd.fd()));
+
+  ARROW_RETURN_IF(size == 0, arrow::Status::Invalid("Cannot mmap an empty file: ", path));
+
+  void* result = mmap(nullptr, size, PROT_READ, MAP_PRIVATE, fd.fd(), 0);
+  if (result == MAP_FAILED) {
+    return arrow::Status::IOError("Memory mapping file failed: ", ::arrow::internal::ErrnoMessage(errno));
+  }
+
+  return std::make_shared<MmapFileStream>(std::move(fd), static_cast<uint8_t*>(result), size, prefetchSize);
+}
+
+arrow::Result<int64_t> MmapFileStream::actualReadSize(int64_t nbytes) {
+  if (nbytes < 0 || pos_ > size_) {
+    return arrow::Status::IOError("Read out of range. Offset: ", pos_, " Size: ", nbytes, " File Size: ", size_);
+  }
+  return std::min(size_ - pos_, nbytes);
+}
+
+bool MmapFileStream::closed() const {
+  return data_ == nullptr;
+};
+
+void MmapFileStream::advance(int64_t length) {
+  // Dont need data before pos
+  auto purgeLength = (pos_ - posRetain_) / prefetchSize_ * prefetchSize_;
+  if (purgeLength > 0) {
+    int ret = madvise(data_ + posRetain_, purgeLength, MADV_DONTNEED);
+    if (ret != 0) {
+      LOG(WARNING) << "fadvise failed " << ::arrow::internal::ErrnoMessage(errno);
+    }
+    posRetain_ += purgeLength;
+  }
+
+  pos_ += length;
+}
+
+void MmapFileStream::willNeed(int64_t length) {
+  // Skip if already fetched
+  if (pos_ + length <= posFetch_) {
+    return;
+  }
+
+  // Round up to multiple of prefetchSize
+  auto fetchLen = ((length + prefetchSize_ - 1) / prefetchSize_) * prefetchSize_;
+  fetchLen = std::min(size_ - pos_, fetchLen);
+  int ret = madvise(data_ + posFetch_, fetchLen, MADV_WILLNEED);
+  if (ret != 0) {
+    LOG(WARNING) << "madvise willneed failed: " << ::arrow::internal::ErrnoMessage(errno);
+  }
+
+  posFetch_ += fetchLen;
+}
+
+arrow::Status MmapFileStream::Close() {
+  if (data_ != nullptr) {
+    int result = munmap(data_, size_);
+    if (result != 0) {
+      LOG(WARNING) << "munmap failed";
+    }
+    data_ = nullptr;
+  }
+
+  return fd_.Close();
+}
+
+arrow::Result<int64_t> MmapFileStream::Tell() const {
+  return pos_;
+}
+
+arrow::Result<int64_t> MmapFileStream::Read(int64_t nbytes, void* out) {
+  ARROW_ASSIGN_OR_RAISE(nbytes, actualReadSize(nbytes));
+
+  if (nbytes > 0) {
+    memcpy(out, data_ + pos_, nbytes);
+    advance(nbytes);
+  }
+
+  return nbytes;
+}
+
+arrow::Result<std::shared_ptr<arrow::Buffer>> MmapFileStream::Read(int64_t nbytes) {
+  ARROW_ASSIGN_OR_RAISE(nbytes, actualReadSize(nbytes));
+
+  if (nbytes > 0) {
+    auto buffer = std::make_shared<arrow::Buffer>(data_ + pos_, nbytes);
+    willNeed(nbytes);
+    advance(nbytes);
+    return buffer;
+  } else {
+    return std::make_shared<arrow::Buffer>(nullptr, 0);
+  }
+}
 } // namespace gluten
 
 std::string gluten::getShuffleSpillDir(const std::string& configuredDir, int32_t subDirId) {
diff --git a/cpp/core/shuffle/Utils.h b/cpp/core/shuffle/Utils.h
index 64b9292d9d..2e5ff58b6e 100644
--- a/cpp/core/shuffle/Utils.h
+++ b/cpp/core/shuffle/Utils.h
@@ -70,4 +70,39 @@ arrow::Result<std::shared_ptr<arrow::RecordBatch>> makeUncompressedRecordBatch(
 
 std::shared_ptr<arrow::Buffer> zeroLengthNullBuffer();
 
+// MmapFileStream is used to optimize sequential file reading. It uses madvise
+// to prefetch and release memory timely.
+class MmapFileStream : public arrow::io::InputStream {
+ public:
+  MmapFileStream(arrow::internal::FileDescriptor fd, uint8_t* data, int64_t size, uint64_t prefetchSize);
+
+  static arrow::Result<std::shared_ptr<MmapFileStream>> open(const std::string& path, uint64_t prefetchSize = 0);
+
+  arrow::Result<int64_t> Tell() const override;
+
+  arrow::Status Close() override;
+
+  arrow::Result<int64_t> Read(int64_t nbytes, void* out) override;
+
+  arrow::Result<std::shared_ptr<arrow::Buffer>> Read(int64_t nbytes) override;
+
+  bool closed() const override;
+
+ private:
+  arrow::Result<int64_t> actualReadSize(int64_t nbytes);
+
+  void advance(int64_t length);
+
+  void willNeed(int64_t length);
+
+  // Page-aligned prefetch size
+  const int64_t prefetchSize_;
+  arrow::internal::FileDescriptor fd_;
+  uint8_t* data_ = nullptr;
+  int64_t size_;
+  int64_t pos_ = 0;
+  int64_t posFetch_ = 0;
+  int64_t posRetain_ = 0;
+};
+
 } // namespace gluten
diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
index c4c67f49a5..b63a6bfecc 100644
--- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -17,7 +17,7 @@
 package org.apache.gluten
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.network.util.ByteUnit
+import org.apache.spark.network.util.{ByteUnit, JavaUtils}
 import org.apache.spark.sql.internal.SQLConf
 
 import com.google.common.collect.ImmutableList
@@ -568,6 +568,7 @@ object GlutenConfig {
   val SPARK_OFFHEAP_SIZE_KEY = "spark.memory.offHeap.size"
   val SPARK_OFFHEAP_ENABLED = "spark.memory.offHeap.enabled"
   val SPARK_REDACTION_REGEX = "spark.redaction.regex"
+  val SPARK_SHUFFLE_FILE_BUFFER = "spark.shuffle.file.buffer"
 
   // For Soft Affinity Scheduling
   // Enable Soft Affinity Scheduling, default value is false
@@ -736,6 +737,15 @@ object GlutenConfig {
     )
     keyWithDefault.forEach(e => nativeConfMap.put(e._1, conf.getOrElse(e._1, e._2)))
 
+    conf
+      .get(SPARK_SHUFFLE_FILE_BUFFER)
+      .foreach(
+        v =>
+          nativeConfMap
+            .put(
+              SPARK_SHUFFLE_FILE_BUFFER,
+              (JavaUtils.byteStringAs(v, ByteUnit.KiB) * 1024).toString))
+
     // Backend's dynamic session conf only.
     val confPrefix = prefixOf(backendName)
     conf


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to