This is an automated email from the ASF dual-hosted git repository.
weitingchen pushed a commit to branch branch-1.2
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/branch-1.2 by this push:
new 6e86564ecc [CORE][BRANCH-1.2] Port #7861 to fix OOM in shuffle writer
(#8078)
6e86564ecc is described below
commit 6e86564ecca1e21255ad27dc82378b3ca65dbaa3
Author: Lingfeng Zhang <[email protected]>
AuthorDate: Fri Nov 29 13:10:11 2024 +0800
[CORE][BRANCH-1.2] Port #7861 to fix OOM in shuffle writer (#8078)
* Impl MmapFileStream
* Prefetch spark.shuffle.file.buffer
* fix ut
* format
---------
Co-authored-by: Rong Ma <[email protected]>
---
cpp/core/config/GlutenConfig.h | 2 +
cpp/core/jni/JniWrapper.cc | 7 ++
cpp/core/shuffle/LocalPartitionWriter.cc | 4 +-
cpp/core/shuffle/Options.h | 3 +
cpp/core/shuffle/Payload.cc | 5 +
cpp/core/shuffle/Spill.cc | 5 +-
cpp/core/shuffle/Spill.h | 9 +-
cpp/core/shuffle/Utils.cc | 113 +++++++++++++++++++++
cpp/core/shuffle/Utils.h | 35 +++++++
.../scala/org/apache/gluten/GlutenConfig.scala | 12 ++-
10 files changed, 185 insertions(+), 10 deletions(-)
diff --git a/cpp/core/config/GlutenConfig.h b/cpp/core/config/GlutenConfig.h
index 060bbe1112..4a5d0190dd 100644
--- a/cpp/core/config/GlutenConfig.h
+++ b/cpp/core/config/GlutenConfig.h
@@ -64,6 +64,8 @@ const std::string kShuffleCompressionCodecBackend =
"spark.gluten.sql.columnar.s
const std::string kQatBackendName = "qat";
const std::string kIaaBackendName = "iaa";
+const std::string kShuffleFileBufferSize = "spark.shuffle.file.buffer";
+
std::unordered_map<std::string, std::string>
parseConfMap(JNIEnv* env, const uint8_t* planData, const int32_t
planDataLength);
diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc
index ea5c9d271c..0d466a28d1 100644
--- a/cpp/core/jni/JniWrapper.cc
+++ b/cpp/core/jni/JniWrapper.cc
@@ -824,6 +824,13 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
partitionWriterOptions.codecBackend = getCodecBackend(env,
codecBackendJstr);
partitionWriterOptions.compressionMode = getCompressionMode(env,
compressionModeJstr);
}
+ const auto& conf = ctx->getConfMap();
+ {
+ auto it = conf.find(kShuffleFileBufferSize);
+ if (it != conf.end()) {
+ partitionWriterOptions.shuffleFileBufferSize =
static_cast<int64_t>(stoi(it->second));
+ }
+ }
std::unique_ptr<PartitionWriter> partitionWriter;
diff --git a/cpp/core/shuffle/LocalPartitionWriter.cc
b/cpp/core/shuffle/LocalPartitionWriter.cc
index f56543bab5..0fd0d6e298 100644
--- a/cpp/core/shuffle/LocalPartitionWriter.cc
+++ b/cpp/core/shuffle/LocalPartitionWriter.cc
@@ -391,7 +391,8 @@ arrow::Status LocalPartitionWriter::openDataFile() {
ARROW_ASSIGN_OR_RAISE(fout, arrow::io::FileOutputStream::Open(dataFile_));
if (options_.bufferedWrite) {
// Output stream buffer is neither partition buffer memory nor ipc memory.
- ARROW_ASSIGN_OR_RAISE(dataFileOs_,
arrow::io::BufferedOutputStream::Create(16384, pool_, fout));
+ ARROW_ASSIGN_OR_RAISE(
+ dataFileOs_,
arrow::io::BufferedOutputStream::Create(options_.shuffleFileBufferSize, pool_,
fout));
} else {
dataFileOs_ = fout;
}
@@ -422,6 +423,7 @@ arrow::Status LocalPartitionWriter::mergeSpills(uint32_t
partitionId) {
auto spillIter = spills_.begin();
while (spillIter != spills_.end()) {
ARROW_ASSIGN_OR_RAISE(auto st, dataFileOs_->Tell());
+ (*spillIter)->openForRead(options_.shuffleFileBufferSize);
// Read if partition exists in the spilled file and write to the final
file.
while (auto payload = (*spillIter)->nextPayload(partitionId)) {
// May trigger spill during compression.
diff --git a/cpp/core/shuffle/Options.h b/cpp/core/shuffle/Options.h
index 4828c7c822..1f1496e489 100644
--- a/cpp/core/shuffle/Options.h
+++ b/cpp/core/shuffle/Options.h
@@ -35,6 +35,7 @@ static constexpr int32_t kDefaultBufferAlignment = 64;
static constexpr double kDefaultBufferReallocThreshold = 0.25;
static constexpr double kDefaultMergeBufferThreshold = 0.25;
static constexpr bool kEnableBufferedWrite = true;
+static constexpr int64_t kDefaultShuffleFileBufferSize = 32 << 10;
enum ShuffleWriterType { kHashShuffle, kSortShuffle };
enum PartitionWriterType { kLocal, kRss };
@@ -75,6 +76,8 @@ struct PartitionWriterOptions {
int64_t pushBufferMaxSize = kDefaultPushMemoryThreshold;
int64_t sortBufferMaxSize = kDefaultSortBufferThreshold;
+
+ int64_t shuffleFileBufferSize = kDefaultShuffleFileBufferSize;
};
struct ShuffleWriterMetrics {
diff --git a/cpp/core/shuffle/Payload.cc b/cpp/core/shuffle/Payload.cc
index fb91c326b6..86a6e96339 100644
--- a/cpp/core/shuffle/Payload.cc
+++ b/cpp/core/shuffle/Payload.cc
@@ -476,6 +476,9 @@ arrow::Result<std::shared_ptr<arrow::Buffer>>
UncompressedDiskBlockPayload::read
}
arrow::Status UncompressedDiskBlockPayload::serialize(arrow::io::OutputStream*
outputStream) {
+ ARROW_RETURN_IF(
+ inputStream_ == nullptr, arrow::Status::Invalid("inputStream_ is
uninitialized before calling serialize()."));
+
if (codec_ == nullptr || type_ == Payload::kUncompressed) {
ARROW_ASSIGN_OR_RAISE(auto block, inputStream_->Read(rawSize_));
RETURN_NOT_OK(outputStream->Write(block));
@@ -526,6 +529,8 @@ CompressedDiskBlockPayload::CompressedDiskBlockPayload(
: Payload(Type::kCompressed, numRows, isValidityBuffer),
inputStream_(inputStream), rawSize_(rawSize) {}
arrow::Status CompressedDiskBlockPayload::serialize(arrow::io::OutputStream*
outputStream) {
+ ARROW_RETURN_IF(
+ inputStream_ == nullptr, arrow::Status::Invalid("inputStream_ is
uninitialized before calling serialize()."));
ScopedTimer timer(&writeTime_);
ARROW_ASSIGN_OR_RAISE(auto block, inputStream_->Read(rawSize_));
RETURN_NOT_OK(outputStream->Write(block));
diff --git a/cpp/core/shuffle/Spill.cc b/cpp/core/shuffle/Spill.cc
index 51e07ae52e..a621a58c3f 100644
--- a/cpp/core/shuffle/Spill.cc
+++ b/cpp/core/shuffle/Spill.cc
@@ -35,7 +35,6 @@ bool Spill::hasNextPayload(uint32_t partitionId) {
}
std::unique_ptr<Payload> Spill::nextPayload(uint32_t partitionId) {
- openSpillFile();
if (!hasNextPayload(partitionId)) {
return nullptr;
}
@@ -72,9 +71,9 @@ void Spill::insertPayload(
}
}
-void Spill::openSpillFile() {
+void Spill::openForRead(uint64_t shuffleFileBufferSize) {
if (!is_) {
- GLUTEN_ASSIGN_OR_THROW(is_, arrow::io::MemoryMappedFile::Open(spillFile_,
arrow::io::FileMode::READ));
+ GLUTEN_ASSIGN_OR_THROW(is_, MmapFileStream::open(spillFile_,
shuffleFileBufferSize));
rawIs_ = is_.get();
}
}
diff --git a/cpp/core/shuffle/Spill.h b/cpp/core/shuffle/Spill.h
index 6ce247b10d..bc1f427c44 100644
--- a/cpp/core/shuffle/Spill.h
+++ b/cpp/core/shuffle/Spill.h
@@ -37,6 +37,8 @@ class Spill final {
SpillType type() const;
+ void openForRead(uint64_t shuffleFileBufferSize);
+
bool hasNextPayload(uint32_t partitionId);
std::unique_ptr<Payload> nextPayload(uint32_t partitionId);
@@ -57,13 +59,10 @@ class Spill final {
};
SpillType type_;
- std::shared_ptr<arrow::io::MemoryMappedFile> is_;
+ std::shared_ptr<gluten::MmapFileStream> is_;
std::list<PartitionPayload> partitionPayloads_{};
- std::shared_ptr<arrow::io::MemoryMappedFile> inputStream_{};
std::string spillFile_;
- arrow::io::InputStream* rawIs_;
-
- void openSpillFile();
+ arrow::io::InputStream* rawIs_{nullptr};
};
} // namespace gluten
\ No newline at end of file
diff --git a/cpp/core/shuffle/Utils.cc b/cpp/core/shuffle/Utils.cc
index 6854c19783..cb572d4e47 100644
--- a/cpp/core/shuffle/Utils.cc
+++ b/cpp/core/shuffle/Utils.cc
@@ -16,10 +16,14 @@
*/
#include "shuffle/Utils.h"
+#include <arrow/buffer.h>
#include <arrow/record_batch.h>
#include <boost/uuid/uuid_generators.hpp>
#include <boost/uuid/uuid_io.hpp>
#include <fcntl.h>
+#include <glog/logging.h>
+#include <sys/mman.h>
+#include <unistd.h>
#include <iomanip>
#include <iostream>
#include <numeric>
@@ -151,6 +155,14 @@ arrow::Status getLengthBufferAndValueBufferStream(
*compressedLengthPtr = actualLength;
return arrow::Status::OK();
}
+
+// Rounds `value` up to the next multiple of the page size reported by arrow::internal::GetPageSize().
+// The mask arithmetic only works for a power-of-two page size; the debug checks assert that.
+uint64_t roundUpToPageSize(uint64_t value) {
+  static auto kPageBytes = static_cast<size_t>(arrow::internal::GetPageSize());
+  static auto kAlignMask = ~(kPageBytes - 1);
+  DCHECK_GT(kPageBytes, 0);
+  DCHECK_EQ(kAlignMask & kPageBytes, kPageBytes);
+  // Bump into the next page unless already aligned, then clear the low bits.
+  const auto bumped = value + kPageBytes - 1;
+  return bumped & kAlignMask;
+}
} // namespace
arrow::Result<std::shared_ptr<arrow::RecordBatch>> makeCompressedRecordBatch(
@@ -212,6 +224,107 @@ arrow::Result<std::shared_ptr<arrow::RecordBatch>>
makeUncompressedRecordBatch(
}
return arrow::RecordBatch::Make(writeSchema, 1, {arrays});
}
+
+// Wraps an already established read-only mapping: `data` points at `size` mapped bytes backed by `fd`.
+//
+// `prefetchSize` is rounded up to a whole number of pages. A request of 0 (the default argument of
+// open()) is promoted to a single page, because advance() and willNeed() divide by prefetchSize_ and
+// a zero value would raise a division-by-zero fault on the first read.
+MmapFileStream::MmapFileStream(arrow::internal::FileDescriptor fd, uint8_t* data, int64_t size, uint64_t prefetchSize)
+    : prefetchSize_(roundUpToPageSize(prefetchSize == 0 ? 1 : prefetchSize)),
+      fd_(std::move(fd)),
+      data_(data),
+      size_(size) {}
+
+// Opens `path` for reading, maps the whole file privately and read-only, and hands the mapping to a new
+// MmapFileStream together with the file descriptor.
+//
+// Returns Invalid for an empty file and IOError (with the errno text) when mmap() fails. Errors from
+// resolving the path, opening the file or querying its size are propagated unchanged.
+arrow::Result<std::shared_ptr<MmapFileStream>> MmapFileStream::open(const std::string& path, uint64_t prefetchSize) {
+  ARROW_ASSIGN_OR_RAISE(auto platformPath, arrow::internal::PlatformFilename::FromString(path));
+  ARROW_ASSIGN_OR_RAISE(auto descriptor, arrow::internal::FileOpenReadable(platformPath));
+  ARROW_ASSIGN_OR_RAISE(auto fileSize, arrow::internal::FileGetSize(descriptor.fd()));
+
+  ARROW_RETURN_IF(fileSize == 0, arrow::Status::Invalid("Cannot mmap an empty file: ", path));
+
+  void* const mapping = mmap(nullptr, fileSize, PROT_READ, MAP_PRIVATE, descriptor.fd(), 0);
+  if (mapping != MAP_FAILED) {
+    return std::make_shared<MmapFileStream>(
+        std::move(descriptor), static_cast<uint8_t*>(mapping), fileSize, prefetchSize);
+  }
+
+  return arrow::Status::IOError("Memory mapping file failed: ", ::arrow::internal::ErrnoMessage(errno));
+}
+
+// Validates a read request against the current position and returns the number of bytes that can be
+// served: `nbytes` clamped to what is left between pos_ and the end of the file.
+// A negative `nbytes`, or a position past the end of the file, yields an IOError.
+arrow::Result<int64_t> MmapFileStream::actualReadSize(int64_t nbytes) {
+  const bool inRange = nbytes >= 0 && pos_ <= size_;
+  if (!inRange) {
+    return arrow::Status::IOError("Read out of range. Offset: ", pos_, " Size: ", nbytes, " File Size: ", size_);
+  }
+
+  const int64_t remaining = size_ - pos_;
+  return nbytes < remaining ? nbytes : remaining;
+}
+
+// The stream is considered closed whenever it holds no mapping, i.e. data_ is null.
+bool MmapFileStream::closed() const {
+  return !data_;
+}
+
+// Moves the read cursor forward by `length` bytes and tells the kernel which mapped pages are no
+// longer needed.
+//
+// Pages are released in whole multiples of prefetchSize_, and only up to the position held *before*
+// this call, so the bytes consumed by the current read stay resident until a later call.
+void MmapFileStream::advance(int64_t length) {
+  // A zero prefetchSize_ would divide by zero below; there is no purge granularity in that case.
+  if (prefetchSize_ > 0) {
+    // Data before pos_ is not needed any more.
+    const auto purgeLength = (pos_ - posRetain_) / prefetchSize_ * prefetchSize_;
+    if (purgeLength > 0) {
+      // Best effort: a failed hint only costs memory, so log it and carry on.
+      if (madvise(data_ + posRetain_, purgeLength, MADV_DONTNEED) != 0) {
+        LOG(WARNING) << "madvise dontneed failed: " << ::arrow::internal::ErrnoMessage(errno);
+      }
+      posRetain_ += purgeLength;
+    }
+  }
+
+  pos_ += length;
+}
+
+// Hints the kernel that the next `length` bytes starting at pos_ are about to be read.
+//
+// The hinted range starts at the end of the previously hinted region (or at pos_ rounded down to a
+// prefetchSize_ boundary if the cursor has moved past it), is extended to a multiple of prefetchSize_,
+// and is clamped so that it never reaches beyond the end of the mapping.
+void MmapFileStream::willNeed(int64_t length) {
+  const int64_t readEnd = pos_ + length;
+
+  // Skip if already fetched.
+  if (readEnd <= posFetch_) {
+    return;
+  }
+
+  // A zero prefetchSize_ would divide by zero below; without a granularity there is nothing to do.
+  if (prefetchSize_ <= 0) {
+    return;
+  }
+
+  // Never hint pages behind the cursor. Both candidates are multiples of the page-aligned
+  // prefetchSize_, so the start address stays page-aligned as madvise() requires.
+  const int64_t alignedPos = pos_ / prefetchSize_ * prefetchSize_;
+  const int64_t fetchStart = std::max(posFetch_, alignedPos);
+
+  // Cover everything up to readEnd, rounded up to a multiple of prefetchSize_, but clamp relative to
+  // the address actually passed to madvise() so the range stays inside the mapping.
+  int64_t fetchLen = (readEnd - fetchStart + prefetchSize_ - 1) / prefetchSize_ * prefetchSize_;
+  fetchLen = std::min(size_ - fetchStart, fetchLen);
+  if (fetchLen <= 0) {
+    return;
+  }
+
+  if (madvise(data_ + fetchStart, fetchLen, MADV_WILLNEED) != 0) {
+    LOG(WARNING) << "madvise willneed failed: " << ::arrow::internal::ErrnoMessage(errno);
+  }
+
+  posFetch_ = fetchStart + fetchLen;
+}
+
+// Releases the mapping (if one is still held) and then closes the file descriptor.
+// A munmap() failure is only logged; the returned status is that of closing the descriptor.
+arrow::Status MmapFileStream::Close() {
+  if (data_) {
+    if (munmap(data_, size_) != 0) {
+      LOG(WARNING) << "munmap failed";
+    }
+    // Mark the stream as closed regardless of the munmap() outcome.
+    data_ = nullptr;
+  }
+
+  return fd_.Close();
+}
+
+// Reports the current read offset (pos_) in bytes from the start of the mapping.
+arrow::Result<int64_t> MmapFileStream::Tell() const {
+  const int64_t offset = pos_;
+  return offset;
+}
+
+// Copies up to `nbytes` bytes from the current position into `out` and advances the cursor.
+// Returns the number of bytes copied, which is smaller than `nbytes` when fewer bytes remain, or the
+// error reported by actualReadSize().
+arrow::Result<int64_t> MmapFileStream::Read(int64_t nbytes, void* out) {
+  ARROW_ASSIGN_OR_RAISE(auto available, actualReadSize(nbytes));
+
+  // Nothing left to copy: leave the cursor and `out` untouched.
+  if (available <= 0) {
+    return available;
+  }
+
+  memcpy(out, data_ + pos_, available);
+  advance(available);
+  return available;
+}
+
+// Returns a buffer over up to `nbytes` bytes at the current position and advances the cursor.
+// The buffer is built directly on top of the mapped region (data_ + pos_); no bytes are copied here.
+// NOTE(review): an arrow::Buffer built from a raw pointer is presumably non-owning, so the result
+// should not be used after Close() unmaps the file - confirm against callers.
+// When no bytes are available an empty buffer is returned instead.
+arrow::Result<std::shared_ptr<arrow::Buffer>> MmapFileStream::Read(int64_t nbytes) {
+  ARROW_ASSIGN_OR_RAISE(auto available, actualReadSize(nbytes));
+
+  if (available <= 0) {
+    return std::make_shared<arrow::Buffer>(nullptr, 0);
+  }
+
+  auto view = std::make_shared<arrow::Buffer>(data_ + pos_, available);
+  // Ask for the pages first, then move the cursor (which may release older pages).
+  willNeed(available);
+  advance(available);
+  return view;
+}
} // namespace gluten
std::string gluten::generateUuid() {
diff --git a/cpp/core/shuffle/Utils.h b/cpp/core/shuffle/Utils.h
index c4e2409d2d..67d0c3be03 100644
--- a/cpp/core/shuffle/Utils.h
+++ b/cpp/core/shuffle/Utils.h
@@ -72,4 +72,39 @@ arrow::Result<std::shared_ptr<arrow::RecordBatch>>
makeUncompressedRecordBatch(
std::shared_ptr<arrow::Buffer> zeroLengthNullBuffer();
+// MmapFileStream is an arrow::io::InputStream for sequential reads over a memory-mapped file.
+// While reading it uses madvise() to request upcoming data ahead of time and to release data that
+// has already been consumed.
+class MmapFileStream : public arrow::io::InputStream {
+ public:
+  // Maps `path` and returns a stream positioned at offset 0. `prefetchSize` is the granularity, in
+  // bytes, of the prefetch/release hints; it is rounded up to a multiple of the page size.
+  static arrow::Result<std::shared_ptr<MmapFileStream>> open(const std::string& path, uint64_t prefetchSize = 0);
+
+  // Builds a stream over an existing mapping of `size` bytes at `data`, backed by `fd`.
+  MmapFileStream(arrow::internal::FileDescriptor fd, uint8_t* data, int64_t size, uint64_t prefetchSize);
+
+  // Copies up to `nbytes` bytes into `out`; returns the number of bytes copied.
+  arrow::Result<int64_t> Read(int64_t nbytes, void* out) override;
+
+  // Returns a buffer over up to `nbytes` bytes of the mapped file.
+  arrow::Result<std::shared_ptr<arrow::Buffer>> Read(int64_t nbytes) override;
+
+  // Current read offset in bytes.
+  arrow::Result<int64_t> Tell() const override;
+
+  // Unmaps the file and closes the descriptor.
+  arrow::Status Close() override;
+
+  // True once the mapping has been released.
+  bool closed() const override;
+
+ private:
+  // Clamps a read request to the bytes remaining, or fails when the request is out of range.
+  arrow::Result<int64_t> actualReadSize(int64_t nbytes);
+
+  // Moves pos_ forward by `length` and releases pages that lie before the previous position.
+  void advance(int64_t length);
+
+  // Requests that the next `length` bytes be brought into memory.
+  void willNeed(int64_t length);
+
+  // Page-aligned prefetch size.
+  const int64_t prefetchSize_;
+  // Descriptor of the mapped file.
+  arrow::internal::FileDescriptor fd_;
+  // Start of the mapping; null once closed.
+  uint8_t* data_ = nullptr;
+  // Length of the mapping in bytes.
+  int64_t size_;
+  // Current read offset.
+  int64_t pos_ = 0;
+  // End of the region already requested via MADV_WILLNEED.
+  int64_t posFetch_ = 0;
+  // Start of the region not yet released via MADV_DONTNEED.
+  int64_t posRetain_ = 0;
+};
+
} // namespace gluten
diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
index 67bd4716df..3923e7d6e1 100644
--- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -17,7 +17,7 @@
package org.apache.gluten
import org.apache.spark.internal.Logging
-import org.apache.spark.network.util.ByteUnit
+import org.apache.spark.network.util.{ByteUnit, JavaUtils}
import org.apache.spark.sql.internal.SQLConf
import com.google.common.collect.ImmutableList
@@ -523,6 +523,7 @@ object GlutenConfig {
val GLUTEN_ONHEAP_SIZE_KEY = "spark.executor.memory"
val GLUTEN_OFFHEAP_SIZE_KEY = "spark.memory.offHeap.size"
val GLUTEN_OFFHEAP_ENABLED = "spark.memory.offHeap.enabled"
+ val SPARK_SHUFFLE_FILE_BUFFER = "spark.shuffle.file.buffer"
// For Soft Affinity Scheduling
// Enable Soft Affinity Scheduling, default value is false
@@ -667,6 +668,15 @@ object GlutenConfig {
)
keyWithDefault.forEach(e => nativeConfMap.put(e._1, conf.getOrElse(e._1,
e._2)))
+ conf
+ .get(SPARK_SHUFFLE_FILE_BUFFER)
+ .foreach(
+ v =>
+ nativeConfMap
+ .put(
+ SPARK_SHUFFLE_FILE_BUFFER,
+ (JavaUtils.byteStringAs(v, ByteUnit.KiB) * 1024).toString))
+
// Backend's dynamic session conf only.
conf
.filter(entry => entry._1.startsWith(backendPrefix) &&
!SQLConf.isStaticConfigKey(entry._1))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]