This is an automated email from the ASF dual-hosted git repository.
marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 72b8810be5 [GLUTEN-7860][CORE] In shuffle writer, replace
MemoryMappedFile to avoid OOM (#7861)
72b8810be5 is described below
commit 72b8810be57cc986a52aa83df50e31281a803733
Author: Lingfeng Zhang <[email protected]>
AuthorDate: Thu Nov 28 18:58:39 2024 +0800
[GLUTEN-7860][CORE] In shuffle writer, replace MemoryMappedFile to avoid
OOM (#7861)
---
cpp/core/config/GlutenConfig.h | 1 +
cpp/core/jni/JniWrapper.cc | 7 ++
cpp/core/shuffle/LocalPartitionWriter.cc | 5 +-
cpp/core/shuffle/Options.h | 3 +
cpp/core/shuffle/Payload.cc | 5 +
cpp/core/shuffle/Spill.cc | 5 +-
cpp/core/shuffle/Spill.h | 9 +-
cpp/core/shuffle/Utils.cc | 112 +++++++++++++++++++++
cpp/core/shuffle/Utils.h | 35 +++++++
.../scala/org/apache/gluten/GlutenConfig.scala | 12 ++-
10 files changed, 183 insertions(+), 11 deletions(-)
diff --git a/cpp/core/config/GlutenConfig.h b/cpp/core/config/GlutenConfig.h
index 3dd2f77f9f..5a61b27a80 100644
--- a/cpp/core/config/GlutenConfig.h
+++ b/cpp/core/config/GlutenConfig.h
@@ -73,6 +73,7 @@ const std::string kSparkRedactionRegex =
"spark.redaction.regex";
const std::string kSparkRedactionString = "*********(redacted)";
const std::string kSparkLegacyTimeParserPolicy =
"spark.sql.legacy.timeParserPolicy";
+const std::string kShuffleFileBufferSize = "spark.shuffle.file.buffer";
std::unordered_map<std::string, std::string>
parseConfMap(JNIEnv* env, const uint8_t* planData, const int32_t
planDataLength);
diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc
index 963440f6fc..794ca6b88f 100644
--- a/cpp/core/jni/JniWrapper.cc
+++ b/cpp/core/jni/JniWrapper.cc
@@ -824,6 +824,13 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
partitionWriterOptions.codecBackend = getCodecBackend(env,
codecBackendJstr);
partitionWriterOptions.compressionMode = getCompressionMode(env,
compressionModeJstr);
}
+  // Override the default shuffle file buffer size when the session provides
+  // `spark.shuffle.file.buffer`. The value arrives as a byte count and the option is a
+  // 64-bit field, so it is parsed with std::stoll; std::stoi would throw
+  // std::out_of_range for any value above INT_MAX.
+  const auto& conf = ctx->getConfMap();
+  {
+    auto it = conf.find(kShuffleFileBufferSize);
+    if (it != conf.end()) {
+      partitionWriterOptions.shuffleFileBufferSize = static_cast<int64_t>(std::stoll(it->second));
+    }
+  }
std::unique_ptr<PartitionWriter> partitionWriter;
diff --git a/cpp/core/shuffle/LocalPartitionWriter.cc
b/cpp/core/shuffle/LocalPartitionWriter.cc
index b7bfa19304..6692097986 100644
--- a/cpp/core/shuffle/LocalPartitionWriter.cc
+++ b/cpp/core/shuffle/LocalPartitionWriter.cc
@@ -389,9 +389,9 @@ arrow::Result<std::shared_ptr<arrow::io::OutputStream>>
LocalPartitionWriter::op
std::shared_ptr<arrow::io::FileOutputStream> fout;
ARROW_ASSIGN_OR_RAISE(fout, arrow::io::FileOutputStream::Open(file));
if (options_.bufferedWrite) {
- // The 16k bytes is a temporary allocation and will be freed with file
close.
+ // The `shuffleFileBufferSize` bytes is a temporary allocation and will be
freed with file close.
// Use default memory pool and count treat the memory as executor memory
overhead to avoid unnecessary spill.
- return arrow::io::BufferedOutputStream::Create(16384,
arrow::default_memory_pool(), fout);
+ return
arrow::io::BufferedOutputStream::Create(options_.shuffleFileBufferSize,
arrow::default_memory_pool(), fout);
}
return fout;
}
@@ -420,6 +420,7 @@ arrow::Status LocalPartitionWriter::mergeSpills(uint32_t
partitionId) {
auto spillIter = spills_.begin();
while (spillIter != spills_.end()) {
ARROW_ASSIGN_OR_RAISE(auto st, dataFileOs_->Tell());
+ (*spillIter)->openForRead(options_.shuffleFileBufferSize);
// Read if partition exists in the spilled file and write to the final
file.
while (auto payload = (*spillIter)->nextPayload(partitionId)) {
// May trigger spill during compression.
diff --git a/cpp/core/shuffle/Options.h b/cpp/core/shuffle/Options.h
index 6a9e0ec4b3..3a1efdc2ae 100644
--- a/cpp/core/shuffle/Options.h
+++ b/cpp/core/shuffle/Options.h
@@ -39,6 +39,7 @@ static constexpr bool kEnableBufferedWrite = true;
static constexpr bool kDefaultUseRadixSort = true;
static constexpr int32_t kDefaultSortBufferSize = 4096;
static constexpr int64_t kDefaultReadBufferSize = 1 << 20;
+static constexpr int64_t kDefaultShuffleFileBufferSize = 32 << 10;
enum ShuffleWriterType { kHashShuffle, kSortShuffle, kRssSortShuffle };
enum PartitionWriterType { kLocal, kRss };
@@ -86,6 +87,8 @@ struct PartitionWriterOptions {
int64_t pushBufferMaxSize = kDefaultPushMemoryThreshold;
int64_t sortBufferMaxSize = kDefaultSortBufferThreshold;
+
+ int64_t shuffleFileBufferSize = kDefaultShuffleFileBufferSize;
};
struct ShuffleWriterMetrics {
diff --git a/cpp/core/shuffle/Payload.cc b/cpp/core/shuffle/Payload.cc
index 55f3a43396..ddf4a40966 100644
--- a/cpp/core/shuffle/Payload.cc
+++ b/cpp/core/shuffle/Payload.cc
@@ -481,6 +481,9 @@ arrow::Result<std::shared_ptr<arrow::Buffer>>
UncompressedDiskBlockPayload::read
}
arrow::Status UncompressedDiskBlockPayload::serialize(arrow::io::OutputStream*
outputStream) {
+ ARROW_RETURN_IF(
+ inputStream_ == nullptr, arrow::Status::Invalid("inputStream_ is
uninitialized before calling serialize()."));
+
if (codec_ == nullptr || type_ == Payload::kUncompressed) {
ARROW_ASSIGN_OR_RAISE(auto block, inputStream_->Read(rawSize_));
RETURN_NOT_OK(outputStream->Write(block));
@@ -545,6 +548,8 @@ CompressedDiskBlockPayload::CompressedDiskBlockPayload(
: Payload(Type::kCompressed, numRows, isValidityBuffer),
inputStream_(inputStream), rawSize_(rawSize) {}
arrow::Status CompressedDiskBlockPayload::serialize(arrow::io::OutputStream*
outputStream) {
+ ARROW_RETURN_IF(
+ inputStream_ == nullptr, arrow::Status::Invalid("inputStream_ is
uninitialized before calling serialize()."));
ScopedTimer timer(&writeTime_);
ARROW_ASSIGN_OR_RAISE(auto block, inputStream_->Read(rawSize_));
RETURN_NOT_OK(outputStream->Write(block));
diff --git a/cpp/core/shuffle/Spill.cc b/cpp/core/shuffle/Spill.cc
index d8b9bc7ebf..8cc3a9d05e 100644
--- a/cpp/core/shuffle/Spill.cc
+++ b/cpp/core/shuffle/Spill.cc
@@ -34,7 +34,6 @@ bool Spill::hasNextPayload(uint32_t partitionId) {
}
std::unique_ptr<Payload> Spill::nextPayload(uint32_t partitionId) {
- openSpillFile();
if (!hasNextPayload(partitionId)) {
return nullptr;
}
@@ -71,9 +70,9 @@ void Spill::insertPayload(
}
}
-void Spill::openSpillFile() {
+void Spill::openForRead(uint64_t shuffleFileBufferSize) {
+ // Opens the spill file as an MmapFileStream the first time it is called; later calls
+ // do nothing because is_ is already set. shuffleFileBufferSize is forwarded to
+ // MmapFileStream::open as the stream's prefetch size.
if (!is_) {
- GLUTEN_ASSIGN_OR_THROW(is_, arrow::io::MemoryMappedFile::Open(spillFile_,
arrow::io::FileMode::READ));
+ GLUTEN_ASSIGN_OR_THROW(is_, MmapFileStream::open(spillFile_,
shuffleFileBufferSize));
rawIs_ = is_.get();
}
}
diff --git a/cpp/core/shuffle/Spill.h b/cpp/core/shuffle/Spill.h
index c82a60f562..fd692537c5 100644
--- a/cpp/core/shuffle/Spill.h
+++ b/cpp/core/shuffle/Spill.h
@@ -37,6 +37,8 @@ class Spill final {
SpillType type() const;
+ void openForRead(uint64_t shuffleFileBufferSize);
+
bool hasNextPayload(uint32_t partitionId);
std::unique_ptr<Payload> nextPayload(uint32_t partitionId);
@@ -69,15 +71,12 @@ class Spill final {
};
SpillType type_;
- std::shared_ptr<arrow::io::MemoryMappedFile> is_;
+ std::shared_ptr<gluten::MmapFileStream> is_;
std::list<PartitionPayload> partitionPayloads_{};
- std::shared_ptr<arrow::io::MemoryMappedFile> inputStream_{};
std::string spillFile_;
int64_t spillTime_;
int64_t compressTime_;
- arrow::io::InputStream* rawIs_;
-
- void openSpillFile();
+ arrow::io::InputStream* rawIs_{nullptr};
};
} // namespace gluten
\ No newline at end of file
diff --git a/cpp/core/shuffle/Utils.cc b/cpp/core/shuffle/Utils.cc
index a11b6b09aa..457702c0c9 100644
--- a/cpp/core/shuffle/Utils.cc
+++ b/cpp/core/shuffle/Utils.cc
@@ -16,8 +16,11 @@
*/
#include "shuffle/Utils.h"
+#include <arrow/buffer.h>
#include <arrow/record_batch.h>
#include <fcntl.h>
+#include <glog/logging.h>
+#include <sys/mman.h>
#include <unistd.h>
#include <iomanip>
#include <iostream>
@@ -151,6 +154,14 @@ arrow::Status getLengthBufferAndValueBufferStream(
*compressedLengthPtr = actualLength;
return arrow::Status::OK();
}
+
+// Rounds `value` up to the next multiple of the system page size.
+uint64_t roundUpToPageSize(uint64_t value) {
+  // Queried once. The checks below assert that the page size is a non-zero power of two,
+  // which is what makes the mask-based rounding valid.
+  static const size_t kPageSize = static_cast<size_t>(arrow::internal::GetPageSize());
+  static const size_t kPageMask = ~(kPageSize - 1);
+  DCHECK_GT(kPageSize, 0);
+  DCHECK_EQ(kPageMask & kPageSize, kPageSize);
+
+  const uint64_t bumped = value + kPageSize - 1;
+  return bumped & kPageMask;
+}
} // namespace
arrow::Result<std::shared_ptr<arrow::RecordBatch>> makeCompressedRecordBatch(
@@ -212,6 +223,107 @@ arrow::Result<std::shared_ptr<arrow::RecordBatch>>
makeUncompressedRecordBatch(
}
return arrow::RecordBatch::Make(writeSchema, 1, {arrays});
}
+
+// Wraps an existing read-only mapping of `size` bytes at `data`, backed by `fd`.
+//
+// `prefetchSize` is rounded up to a whole number of pages. advance() and willNeed()
+// divide by prefetchSize_, so it must never be zero: a requested size of 0 (the default
+// argument of open()) would round up to 0, hence the fallback to a single page.
+MmapFileStream::MmapFileStream(arrow::internal::FileDescriptor fd, uint8_t* data, int64_t size, uint64_t prefetchSize)
+    : prefetchSize_(std::max(roundUpToPageSize(prefetchSize), roundUpToPageSize(1))),
+      fd_(std::move(fd)),
+      data_(data),
+      size_(size) {}
+
+// Opens `path` read-only and maps the whole file into memory.
+// Returns Invalid for an empty file and IOError when mmap fails.
+arrow::Result<std::shared_ptr<MmapFileStream>> MmapFileStream::open(const std::string& path, uint64_t prefetchSize) {
+  // Resolve the path, open it for reading, and find out how many bytes have to be mapped.
+  ARROW_ASSIGN_OR_RAISE(auto platformPath, arrow::internal::PlatformFilename::FromString(path));
+  ARROW_ASSIGN_OR_RAISE(auto descriptor, arrow::internal::FileOpenReadable(platformPath));
+  ARROW_ASSIGN_OR_RAISE(auto fileSize, arrow::internal::FileGetSize(descriptor.fd()));
+
+  // An empty file is rejected before attempting the mapping.
+  ARROW_RETURN_IF(fileSize == 0, arrow::Status::Invalid("Cannot mmap an empty file: ", path));
+
+  void* mapped = mmap(nullptr, fileSize, PROT_READ, MAP_PRIVATE, descriptor.fd(), 0);
+  if (mapped == MAP_FAILED) {
+    return arrow::Status::IOError("Memory mapping file failed: ", ::arrow::internal::ErrnoMessage(errno));
+  }
+
+  // The stream keeps both the descriptor and the mapping from here on.
+  auto* mappedBytes = static_cast<uint8_t*>(mapped);
+  return std::make_shared<MmapFileStream>(std::move(descriptor), mappedBytes, fileSize, prefetchSize);
+}
+
+// Clamps a requested read length to the bytes remaining after pos_.
+// Returns IOError for a negative length or when pos_ lies beyond the end of the mapping.
+arrow::Result<int64_t> MmapFileStream::actualReadSize(int64_t nbytes) {
+  const bool inRange = nbytes >= 0 && pos_ <= size_;
+  if (!inRange) {
+    return arrow::Status::IOError("Read out of range. Offset: ", pos_, " Size: ", nbytes, " File Size: ", size_);
+  }
+
+  const int64_t remaining = size_ - pos_;
+  return std::min(remaining, nbytes);
+}
+
+// Reports whether the mapping has been released; Close() resets data_ to nullptr.
+bool MmapFileStream::closed() const {
+  return !data_;
+}
+
+// Moves the read position forward by `length` bytes.
+//
+// Before moving, the bytes between posRetain_ and the current pos_ have already been
+// consumed, so they are handed back to the kernel with MADV_DONTNEED in whole multiples
+// of prefetchSize_. A failure is only logged: the advice is an optimization and reading
+// can continue without it.
+void MmapFileStream::advance(int64_t length) {
+  // Largest multiple of prefetchSize_ that fits in the consumed region.
+  auto purgeLength = (pos_ - posRetain_) / prefetchSize_ * prefetchSize_;
+  if (purgeLength > 0) {
+    int ret = madvise(data_ + posRetain_, purgeLength, MADV_DONTNEED);
+    if (ret != 0) {
+      // The failing call is madvise(MADV_DONTNEED); name it correctly in the log.
+      LOG(WARNING) << "madvise dontneed failed: " << ::arrow::internal::ErrnoMessage(errno);
+    }
+    posRetain_ += purgeLength;
+  }
+
+  pos_ += length;
+}
+
+// Tells the kernel that the next `length` bytes starting at pos_ are about to be read.
+//
+// posFetch_ is the end of the region that has already been advised. The advised range is
+// extended in multiples of prefetchSize_ and never reaches past the end of the mapping.
+// A failure is only logged: the advice is an optimization.
+void MmapFileStream::willNeed(int64_t length) {
+  // Skip if the requested range [pos_, pos_ + length) has already been advised.
+  const int64_t needEnd = pos_ + length;
+  if (needEnd <= posFetch_) {
+    return;
+  }
+
+  // Continue where the previous advice ended. If the read position has moved past that
+  // point (Read(nbytes, out) does not prefetch), start from the prefetch-aligned boundary
+  // at or below pos_ instead, so the start stays aligned and the range covers the request.
+  const int64_t fetchStart = std::max(posFetch_, pos_ / prefetchSize_ * prefetchSize_);
+  // Round the end up to a multiple of prefetchSize_, clamped to the mapping size. The
+  // clamp is applied to the absolute end offset, so the range cannot overrun the mapping.
+  const int64_t fetchEnd = std::min(size_, (needEnd + prefetchSize_ - 1) / prefetchSize_ * prefetchSize_);
+
+  int ret = madvise(data_ + fetchStart, fetchEnd - fetchStart, MADV_WILLNEED);
+  if (ret != 0) {
+    LOG(WARNING) << "madvise willneed failed: " << ::arrow::internal::ErrnoMessage(errno);
+  }
+
+  posFetch_ = fetchEnd;
+}
+
+// Releases the mapping (at most once) and then closes the file descriptor.
+// A munmap failure is only logged; the returned status is that of closing the descriptor.
+arrow::Status MmapFileStream::Close() {
+  if (data_ == nullptr) {
+    // Already unmapped: only the descriptor is left to close.
+    return fd_.Close();
+  }
+
+  if (munmap(data_, size_) != 0) {
+    LOG(WARNING) << "munmap failed";
+  }
+  data_ = nullptr;
+
+  return fd_.Close();
+}
+
+// Returns the current read offset, in bytes from the start of the mapping.
+arrow::Result<int64_t> MmapFileStream::Tell() const {
+  const int64_t current = pos_;
+  return current;
+}
+
+// Copies up to `nbytes` bytes from the current position into `out` and returns the number
+// of bytes copied, which is smaller than requested when the end of the mapping is reached.
+arrow::Result<int64_t> MmapFileStream::Read(int64_t nbytes, void* out) {
+  ARROW_ASSIGN_OR_RAISE(nbytes, actualReadSize(nbytes));
+
+  // Nothing left to copy: report zero bytes without touching `out` or the position.
+  if (nbytes <= 0) {
+    return nbytes;
+  }
+
+  memcpy(out, data_ + pos_, nbytes);
+  advance(nbytes);
+  return nbytes;
+}
+
+// Returns a buffer over up to `nbytes` bytes at the current position and moves past them.
+// The buffer is constructed directly over the mapped bytes; an empty buffer is returned
+// when nothing is left to read.
+arrow::Result<std::shared_ptr<arrow::Buffer>> MmapFileStream::Read(int64_t nbytes) {
+  ARROW_ASSIGN_OR_RAISE(nbytes, actualReadSize(nbytes));
+
+  if (nbytes <= 0) {
+    return std::make_shared<arrow::Buffer>(nullptr, 0);
+  }
+
+  auto view = std::make_shared<arrow::Buffer>(data_ + pos_, nbytes);
+  // Ask for the pages behind the view before moving the position past them.
+  willNeed(nbytes);
+  advance(nbytes);
+  return view;
+}
} // namespace gluten
std::string gluten::getShuffleSpillDir(const std::string& configuredDir,
int32_t subDirId) {
diff --git a/cpp/core/shuffle/Utils.h b/cpp/core/shuffle/Utils.h
index 64b9292d9d..2e5ff58b6e 100644
--- a/cpp/core/shuffle/Utils.h
+++ b/cpp/core/shuffle/Utils.h
@@ -70,4 +70,39 @@ arrow::Result<std::shared_ptr<arrow::RecordBatch>>
makeUncompressedRecordBatch(
std::shared_ptr<arrow::Buffer> zeroLengthNullBuffer();
+// MmapFileStream is used to optimize sequential file reading. It uses madvise
+// to prefetch and release memory timely.
+//
+// The stream only moves forward: it declares Tell/Read/Close but no seek.
+// No destructor is declared here, so the mapping is released only by Close().
+// NOTE(review): confirm that every owner calls Close() before dropping the stream.
+class MmapFileStream : public arrow::io::InputStream {
+ public:
+ // Wraps an existing mapping of `size` bytes at `data`, backed by `fd`.
+ // `prefetchSize` is rounded up to a multiple of the page size.
+ MmapFileStream(arrow::internal::FileDescriptor fd, uint8_t* data, int64_t
size, uint64_t prefetchSize);
+
+ // Opens `path` read-only and maps the whole file. Returns Invalid for an empty
+ // file and IOError when mmap fails. `prefetchSize` is the granularity in bytes
+ // used for prefetching and releasing; Spill passes the shuffle file buffer size.
+ static arrow::Result<std::shared_ptr<MmapFileStream>> open(const
std::string& path, uint64_t prefetchSize = 0);
+
+ // Current read offset in bytes.
+ arrow::Result<int64_t> Tell() const override;
+
+ // Unmaps the file (if still mapped) and closes the descriptor.
+ arrow::Status Close() override;
+
+ // Copies up to `nbytes` bytes into `out`; returns the number of bytes copied.
+ arrow::Result<int64_t> Read(int64_t nbytes, void* out) override;
+
+ // Returns a Buffer constructed over the mapped bytes at the current position.
+ arrow::Result<std::shared_ptr<arrow::Buffer>> Read(int64_t nbytes) override;
+
+ // True once the mapping has been released (data_ is nullptr).
+ bool closed() const override;
+
+ private:
+ // Clamps `nbytes` to the bytes remaining; IOError for a negative length or
+ // when pos_ is beyond size_.
+ arrow::Result<int64_t> actualReadSize(int64_t nbytes);
+
+ // Releases already consumed data in multiples of prefetchSize_ (MADV_DONTNEED),
+ // then moves pos_ forward by `length`.
+ void advance(int64_t length);
+
+ // Advises the kernel (MADV_WILLNEED) that the next `length` bytes will be read.
+ void willNeed(int64_t length);
+
+ // Page-aligned prefetch size
+ const int64_t prefetchSize_;
+ // Descriptor of the mapped file; closed by Close().
+ arrow::internal::FileDescriptor fd_;
+ // Start of the mapping; nullptr after Close().
+ uint8_t* data_ = nullptr;
+ // Length of the mapping in bytes (the file size at open time).
+ int64_t size_;
+ // Current read offset.
+ int64_t pos_ = 0;
+ // End of the region already advised with MADV_WILLNEED.
+ int64_t posFetch_ = 0;
+ // Start of the region not yet released with MADV_DONTNEED.
+ int64_t posRetain_ = 0;
+};
+
} // namespace gluten
diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
index c4c67f49a5..b63a6bfecc 100644
--- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -17,7 +17,7 @@
package org.apache.gluten
import org.apache.spark.internal.Logging
-import org.apache.spark.network.util.ByteUnit
+import org.apache.spark.network.util.{ByteUnit, JavaUtils}
import org.apache.spark.sql.internal.SQLConf
import com.google.common.collect.ImmutableList
@@ -568,6 +568,7 @@ object GlutenConfig {
val SPARK_OFFHEAP_SIZE_KEY = "spark.memory.offHeap.size"
val SPARK_OFFHEAP_ENABLED = "spark.memory.offHeap.enabled"
val SPARK_REDACTION_REGEX = "spark.redaction.regex"
+ val SPARK_SHUFFLE_FILE_BUFFER = "spark.shuffle.file.buffer"
// For Soft Affinity Scheduling
// Enable Soft Affinity Scheduling, default value is false
@@ -736,6 +737,15 @@ object GlutenConfig {
)
keyWithDefault.forEach(e => nativeConfMap.put(e._1, conf.getOrElse(e._1,
e._2)))
+ conf
+ .get(SPARK_SHUFFLE_FILE_BUFFER)
+ .foreach(
+ v =>
+ nativeConfMap
+ .put(
+ SPARK_SHUFFLE_FILE_BUFFER,
+ (JavaUtils.byteStringAs(v, ByteUnit.KiB) * 1024).toString))
+
// Backend's dynamic session conf only.
val confPrefix = prefixOf(backendName)
conf
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]