This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new af626c2802 [GLUTEN-7509][VL] Memory management: Release all native
memory managers after all Velox tasks were released (#7478)
af626c2802 is described below
commit af626c2802b5e4c3a9526402dd93c9820e0d08c8
Author: Hongze Zhang <[email protected]>
AuthorDate: Mon Oct 14 12:58:42 2024 +0800
[GLUTEN-7509][VL] Memory management: Release all native memory managers
after all Velox tasks were released (#7478)
Closes #7509
---
.../spark/shuffle/ColumnarShuffleWriter.scala | 26 +--
cpp/core/CMakeLists.txt | 1 +
cpp/core/compute/Runtime.cc | 39 +----
cpp/core/compute/Runtime.h | 12 +-
cpp/core/jni/JniWrapper.cc | 190 ++++++++++++---------
.../memory/{MemoryManager.h => MemoryManager.cc} | 43 +++--
cpp/core/memory/MemoryManager.h | 6 +
.../{memory/MemoryManager.h => utils/Registry.h} | 47 ++---
cpp/velox/benchmarks/GenericBenchmark.cc | 14 +-
cpp/velox/benchmarks/ParquetWriteBenchmark.cc | 2 +-
cpp/velox/compute/VeloxBackend.cc | 23 ++-
cpp/velox/compute/VeloxRuntime.cc | 37 ++--
cpp/velox/compute/VeloxRuntime.h | 7 +-
cpp/velox/tests/RuntimeTest.cc | 19 +--
.../gluten/columnarbatch/IndicatorVectorPool.java | 2 +-
.../NativeMemoryManagerJniWrapper.java} | 18 +-
.../apache/gluten/runtime/RuntimeJniWrapper.java | 12 +-
.../vectorized/ColumnarBatchOutIterator.java | 2 +-
.../gluten/vectorized/NativePlanEvaluator.java | 22 +--
.../NativeMemoryManager.scala} | 61 +++----
.../scala/org/apache/gluten/runtime/Runtime.scala | 74 +-------
.../VeloxCelebornColumnarShuffleWriter.scala | 24 +--
.../writer/VeloxUniffleColumnarShuffleWriter.java | 28 +--
23 files changed, 327 insertions(+), 382 deletions(-)
diff --git
a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
index 4dcd7181bc..6a6d1c57a3 100644
---
a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
@@ -167,19 +167,21 @@ class ColumnarShuffleWriter[K, V](
GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning,
taskContext.partitionId),
shuffleWriterType
)
- runtime.addSpiller(new Spiller() {
- override def spill(self: MemoryTarget, phase: Spiller.Phase, size:
Long): Long = {
- if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
- return 0L
+ runtime
+ .memoryManager()
+ .addSpiller(new Spiller() {
+ override def spill(self: MemoryTarget, phase: Spiller.Phase,
size: Long): Long = {
+ if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
+ return 0L
+ }
+ logInfo(s"Gluten shuffle writer: Trying to spill $size bytes
of data")
+ // fixme pass true when being called by self
+ val spilled =
+ jniWrapper.nativeEvict(nativeShuffleWriter, size, false)
+ logInfo(s"Gluten shuffle writer: Spilled $spilled / $size
bytes of data")
+ spilled
}
- logInfo(s"Gluten shuffle writer: Trying to spill $size bytes of
data")
- // fixme pass true when being called by self
- val spilled =
- jniWrapper.nativeEvict(nativeShuffleWriter, size, false)
- logInfo(s"Gluten shuffle writer: Spilled $spilled / $size bytes
of data")
- spilled
- }
- })
+ })
}
val startTime = System.nanoTime()
jniWrapper.write(nativeShuffleWriter, rows, handle,
availableOffHeapPerTask())
diff --git a/cpp/core/CMakeLists.txt b/cpp/core/CMakeLists.txt
index 1dd31c1a1c..eee834eb7d 100644
--- a/cpp/core/CMakeLists.txt
+++ b/cpp/core/CMakeLists.txt
@@ -185,6 +185,7 @@ set(SPARK_COLUMNAR_PLUGIN_SRCS
jni/JniWrapper.cc
memory/AllocationListener.cc
memory/MemoryAllocator.cc
+ memory/MemoryManager.cc
memory/ArrowMemoryPool.cc
memory/ColumnarBatch.cc
operators/writer/ArrowWriter.cc
diff --git a/cpp/core/compute/Runtime.cc b/cpp/core/compute/Runtime.cc
index 49565fac69..3336412447 100644
--- a/cpp/core/compute/Runtime.cc
+++ b/cpp/core/compute/Runtime.cc
@@ -16,52 +16,27 @@
*/
#include "Runtime.h"
-#include "utils/Print.h"
+#include "utils/Registry.h"
namespace gluten {
namespace {
-class FactoryRegistry {
- public:
- void registerFactory(const std::string& kind, Runtime::Factory factory) {
- std::lock_guard<std::mutex> l(mutex_);
- GLUTEN_CHECK(map_.find(kind) == map_.end(), "Runtime factory already
registered for " + kind);
- map_[kind] = std::move(factory);
- }
-
- Runtime::Factory& getFactory(const std::string& kind) {
- std::lock_guard<std::mutex> l(mutex_);
- GLUTEN_CHECK(map_.find(kind) != map_.end(), "Runtime factory not
registered for " + kind);
- return map_[kind];
- }
-
- bool unregisterFactory(const std::string& kind) {
- std::lock_guard<std::mutex> l(mutex_);
- GLUTEN_CHECK(map_.find(kind) != map_.end(), "Runtime factory not
registered for " + kind);
- return map_.erase(kind);
- }
-
- private:
- std::mutex mutex_;
- std::unordered_map<std::string, Runtime::Factory> map_;
-};
-
-FactoryRegistry& runtimeFactories() {
- static FactoryRegistry registry;
+Registry<Runtime::Factory>& runtimeFactories() {
+ static Registry<Runtime::Factory> registry;
return registry;
}
} // namespace
void Runtime::registerFactory(const std::string& kind, Runtime::Factory
factory) {
- runtimeFactories().registerFactory(kind, std::move(factory));
+ runtimeFactories().registerObj(kind, std::move(factory));
}
Runtime* Runtime::create(
const std::string& kind,
- std::unique_ptr<AllocationListener> listener,
+ MemoryManager* memoryManager,
const std::unordered_map<std::string, std::string>& sessionConf) {
- auto& factory = runtimeFactories().getFactory(kind);
- return factory(std::move(listener), sessionConf);
+ auto& factory = runtimeFactories().get(kind);
+ return factory(std::move(memoryManager), sessionConf);
}
void Runtime::release(Runtime* runtime) {
diff --git a/cpp/core/compute/Runtime.h b/cpp/core/compute/Runtime.h
index 2901a22b0b..a49e0688a6 100644
--- a/cpp/core/compute/Runtime.h
+++ b/cpp/core/compute/Runtime.h
@@ -54,17 +54,17 @@ struct SparkTaskInfo {
class Runtime : public std::enable_shared_from_this<Runtime> {
public:
- using Factory = std::function<
- Runtime*(std::unique_ptr<AllocationListener> listener, const
std::unordered_map<std::string, std::string>&)>;
+ using Factory =
+ std::function<Runtime*(MemoryManager* memoryManager, const
std::unordered_map<std::string, std::string>&)>;
static void registerFactory(const std::string& kind, Factory factory);
static Runtime* create(
const std::string& kind,
- std::unique_ptr<AllocationListener> listener,
+ MemoryManager* memoryManager,
const std::unordered_map<std::string, std::string>& sessionConf = {});
static void release(Runtime*);
static std::optional<std::string>* localWriteFilesTempPath();
- Runtime(std::shared_ptr<MemoryManager> memoryManager, const
std::unordered_map<std::string, std::string>& confMap)
+ Runtime(MemoryManager* memoryManager, const std::unordered_map<std::string,
std::string>& confMap)
: memoryManager_(memoryManager), confMap_(confMap) {}
virtual ~Runtime() = default;
@@ -90,7 +90,7 @@ class Runtime : public std::enable_shared_from_this<Runtime> {
virtual std::shared_ptr<ColumnarBatch>
select(std::shared_ptr<ColumnarBatch>, const std::vector<int32_t>&) = 0;
virtual MemoryManager* memoryManager() {
- return memoryManager_.get();
+ return memoryManager_;
};
/// This function is used to create certain converter from the format used by
@@ -127,7 +127,7 @@ class Runtime : public
std::enable_shared_from_this<Runtime> {
}
protected:
- std::shared_ptr<MemoryManager> memoryManager_;
+ MemoryManager* memoryManager_;
std::unique_ptr<ObjectStore> objStore_ = ObjectStore::create();
std::unordered_map<std::string, std::string> confMap_; // Session conf map
diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc
index e6c87b7bd0..f5c105e974 100644
--- a/cpp/core/jni/JniWrapper.cc
+++ b/cpp/core/jni/JniWrapper.cc
@@ -74,7 +74,7 @@ class JavaInputStreamAdaptor final : public
arrow::io::InputStream {
// IMPORTANT: DO NOT USE LOCAL REF IN DIFFERENT THREAD
if (env->GetJavaVM(&vm_) != JNI_OK) {
std::string errorMessage = "Unable to get JavaVM instance";
- throw gluten::GlutenException(errorMessage);
+ throw GlutenException(errorMessage);
}
jniIn_ = env->NewGlobalRef(jniIn);
}
@@ -149,8 +149,8 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
if (vm->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) {
return JNI_ERR;
}
- gluten::getJniCommonState()->ensureInitialized(env);
- gluten::getJniErrorState()->ensureInitialized(env);
+ getJniCommonState()->ensureInitialized(env);
+ getJniErrorState()->ensureInitialized(env);
byteArrayClass = createGlobalClassReferenceOrError(env, "[B");
@@ -204,49 +204,75 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) {
env->DeleteGlobalRef(byteArrayClass);
env->DeleteGlobalRef(shuffleReaderMetricsClass);
- gluten::getJniErrorState()->close();
- gluten::getJniCommonState()->close();
+ getJniErrorState()->close();
+ getJniCommonState()->close();
+}
+
+JNIEXPORT jlong JNICALL
Java_org_apache_gluten_runtime_RuntimeJniWrapper_createRuntime( // NOLINT
+ JNIEnv* env,
+ jclass,
+ jstring jBackendType,
+ jlong nmmHandle,
+ jbyteArray sessionConf) {
+ JNI_METHOD_START
+ MemoryManager* memoryManager = jniCastOrThrow<MemoryManager>(nmmHandle);
+ auto safeArray = getByteArrayElementsSafe(env, sessionConf);
+ auto sparkConf = parseConfMap(env, safeArray.elems(), safeArray.length());
+ auto backendType = jStringToCString(env, jBackendType);
+
+ auto runtime = Runtime::create(backendType, memoryManager, sparkConf);
+ return reinterpret_cast<jlong>(runtime);
+ JNI_METHOD_END(kInvalidObjectHandle)
+}
+
+JNIEXPORT void JNICALL
Java_org_apache_gluten_runtime_RuntimeJniWrapper_releaseRuntime( // NOLINT
+ JNIEnv* env,
+ jclass,
+ jlong ctxHandle) {
+ JNI_METHOD_START
+ auto runtime = jniCastOrThrow<Runtime>(ctxHandle);
+
+ Runtime::release(runtime);
+ JNI_METHOD_END()
}
namespace {
const std::string kBacktraceAllocation =
"spark.gluten.memory.backtrace.allocation";
}
-JNIEXPORT jlong JNICALL
Java_org_apache_gluten_runtime_RuntimeJniWrapper_createRuntime( // NOLINT
+JNIEXPORT jlong JNICALL
Java_org_apache_gluten_memory_NativeMemoryManagerJniWrapper_create( // NOLINT
JNIEnv* env,
jclass,
- jstring jbackendType,
- jobject jlistener,
+ jstring jBackendType,
+ jobject jListener,
jbyteArray sessionConf) {
JNI_METHOD_START
JavaVM* vm;
if (env->GetJavaVM(&vm) != JNI_OK) {
- throw gluten::GlutenException("Unable to get JavaVM instance");
+ throw GlutenException("Unable to get JavaVM instance");
}
- auto safeArray = gluten::getByteArrayElementsSafe(env, sessionConf);
- auto sparkConf = gluten::parseConfMap(env, safeArray.elems(),
safeArray.length());
- auto backendType = jStringToCString(env, jbackendType);
-
+ auto backendType = jStringToCString(env, jBackendType);
+ auto safeArray = getByteArrayElementsSafe(env, sessionConf);
+ auto sparkConf = parseConfMap(env, safeArray.elems(), safeArray.length());
std::unique_ptr<AllocationListener> listener =
- std::make_unique<SparkAllocationListener>(vm, jlistener,
reserveMemoryMethod, unreserveMemoryMethod);
+ std::make_unique<SparkAllocationListener>(vm, jListener,
reserveMemoryMethod, unreserveMemoryMethod);
bool backtrace = sparkConf.at(kBacktraceAllocation) == "true";
if (backtrace) {
listener =
std::make_unique<BacktraceAllocationListener>(std::move(listener));
}
-
- auto runtime = gluten::Runtime::create(backendType, std::move(listener),
sparkConf);
- return reinterpret_cast<jlong>(runtime);
- JNI_METHOD_END(kInvalidObjectHandle)
+ MemoryManager* mm = MemoryManager::create(backendType, std::move(listener));
+ return reinterpret_cast<jlong>(mm);
+ JNI_METHOD_END(-1L)
}
-JNIEXPORT jbyteArray JNICALL
Java_org_apache_gluten_runtime_RuntimeJniWrapper_collectMemoryUsage( // NOLINT
+JNIEXPORT jbyteArray JNICALL
Java_org_apache_gluten_memory_NativeMemoryManagerJniWrapper_collectUsage( //
NOLINT
JNIEnv* env,
jclass,
- jlong ctxHandle) {
+ jlong nmmHandle) {
JNI_METHOD_START
- auto runtime = jniCastOrThrow<Runtime>(ctxHandle);
+ auto* memoryManager = jniCastOrThrow<MemoryManager>(nmmHandle);
- const MemoryUsageStats& stats =
runtime->memoryManager()->collectMemoryUsageStats();
+ const MemoryUsageStats& stats = memoryManager->collectMemoryUsageStats();
auto size = stats.ByteSizeLong();
jbyteArray out = env->NewByteArray(size);
uint8_t buffer[size];
@@ -258,35 +284,34 @@ JNIEXPORT jbyteArray JNICALL
Java_org_apache_gluten_runtime_RuntimeJniWrapper_co
JNI_METHOD_END(nullptr)
}
-JNIEXPORT jlong JNICALL
Java_org_apache_gluten_runtime_RuntimeJniWrapper_shrinkMemory( // NOLINT
+JNIEXPORT jlong JNICALL
Java_org_apache_gluten_memory_NativeMemoryManagerJniWrapper_shrink( // NOLINT
JNIEnv* env,
jclass,
- jlong ctxHandle,
+ jlong nmmHandle,
jlong size) {
JNI_METHOD_START
- auto runtime = jniCastOrThrow<Runtime>(ctxHandle);
- return runtime->memoryManager()->shrink(static_cast<int64_t>(size));
+ auto* memoryManager = jniCastOrThrow<MemoryManager>(nmmHandle);
+ return memoryManager->shrink(static_cast<int64_t>(size));
JNI_METHOD_END(kInvalidObjectHandle)
}
-JNIEXPORT void JNICALL
Java_org_apache_gluten_runtime_RuntimeJniWrapper_holdMemory( // NOLINT
+JNIEXPORT void JNICALL
Java_org_apache_gluten_memory_NativeMemoryManagerJniWrapper_hold( // NOLINT
JNIEnv* env,
jclass,
- jlong ctxHandle) {
+ jlong nmmHandle) {
JNI_METHOD_START
- auto runtime = jniCastOrThrow<Runtime>(ctxHandle);
- runtime->memoryManager()->hold();
+ auto* memoryManager = jniCastOrThrow<MemoryManager>(nmmHandle);
+ memoryManager->hold();
JNI_METHOD_END()
}
-JNIEXPORT void JNICALL
Java_org_apache_gluten_runtime_RuntimeJniWrapper_releaseRuntime( // NOLINT
+JNIEXPORT void JNICALL
Java_org_apache_gluten_memory_NativeMemoryManagerJniWrapper_release( // NOLINT
JNIEnv* env,
jclass,
- jlong ctxHandle) {
+ jlong nmmHandle) {
JNI_METHOD_START
- auto runtime = jniCastOrThrow<Runtime>(ctxHandle);
-
- gluten::Runtime::release(runtime);
+ auto* memoryManager = jniCastOrThrow<MemoryManager>(nmmHandle);
+ MemoryManager::release(memoryManager);
JNI_METHOD_END()
}
@@ -297,10 +322,10 @@ JNIEXPORT jstring JNICALL
Java_org_apache_gluten_vectorized_PlanEvaluatorJniWrap
jboolean details) {
JNI_METHOD_START
- auto safeArray = gluten::getByteArrayElementsSafe(env, planArray);
+ auto safeArray = getByteArrayElementsSafe(env, planArray);
auto planData = safeArray.elems();
auto planSize = env->GetArrayLength(planArray);
- auto ctx = gluten::getRuntime(env, wrapper);
+ auto ctx = getRuntime(env, wrapper);
ctx->parsePlan(planData, planSize, std::nullopt);
auto& conf = ctx->getConfMap();
auto planString = ctx->planString(details, conf);
@@ -315,9 +340,9 @@ JNIEXPORT void JNICALL
Java_org_apache_gluten_vectorized_PlanEvaluatorJniWrapper
jbyteArray path) {
JNI_METHOD_START
auto len = env->GetArrayLength(path);
- auto safeArray = gluten::getByteArrayElementsSafe(env, path);
+ auto safeArray = getByteArrayElementsSafe(env, path);
std::string pathStr(reinterpret_cast<char*>(safeArray.elems()), len);
- *gluten::Runtime::localWriteFilesTempPath() = pathStr;
+ *Runtime::localWriteFilesTempPath() = pathStr;
JNI_METHOD_END()
}
@@ -335,7 +360,7 @@
Java_org_apache_gluten_vectorized_PlanEvaluatorJniWrapper_nativeCreateKernelWith
jstring spillDir) {
JNI_METHOD_START
- auto ctx = gluten::getRuntime(env, wrapper);
+ auto ctx = getRuntime(env, wrapper);
auto& conf = ctx->getConfMap();
ctx->setSparkTaskInfo({stageId, partitionId, taskId});
@@ -344,19 +369,19 @@
Java_org_apache_gluten_vectorized_PlanEvaluatorJniWrapper_nativeCreateKernelWith
std::string fileIdentifier = "_" + std::to_string(stageId) + "_" +
std::to_string(partitionId);
if (saveInput) {
if (conf.find(kGlutenSaveDir) == conf.end()) {
- throw gluten::GlutenException(kGlutenSaveDir + " is not configured.");
+ throw GlutenException(kGlutenSaveDir + " is not configured.");
}
saveDir = conf.at(kGlutenSaveDir);
std::filesystem::path f{saveDir};
if (!std::filesystem::exists(f)) {
- throw gluten::GlutenException("Save input path " + saveDir + " does not
exists");
+ throw GlutenException("Save input path " + saveDir + " does not exists");
}
ctx->dumpConf(saveDir + "/conf" + fileIdentifier + ".ini");
}
auto spillDirStr = jStringToCString(env, spillDir);
- auto safePlanArray = gluten::getByteArrayElementsSafe(env, planArr);
+ auto safePlanArray = getByteArrayElementsSafe(env, planArr);
auto planSize = env->GetArrayLength(planArr);
ctx->parsePlan(
safePlanArray.elems(),
@@ -366,7 +391,7 @@
Java_org_apache_gluten_vectorized_PlanEvaluatorJniWrapper_nativeCreateKernelWith
for (jsize i = 0, splitInfoArraySize = env->GetArrayLength(splitInfosArr); i
< splitInfoArraySize; i++) {
jbyteArray splitInfoArray =
static_cast<jbyteArray>(env->GetObjectArrayElement(splitInfosArr, i));
jsize splitInfoSize = env->GetArrayLength(splitInfoArray);
- auto safeSplitArray = gluten::getByteArrayElementsSafe(env,
splitInfoArray);
+ auto safeSplitArray = getByteArrayElementsSafe(env, splitInfoArray);
auto splitInfoData = safeSplitArray.elems();
ctx->parseSplitInfo(
splitInfoData,
@@ -403,7 +428,7 @@ JNIEXPORT jboolean JNICALL
Java_org_apache_gluten_vectorized_ColumnarBatchOutIte
if (iter == nullptr) {
std::string errorMessage =
"When hasNext() is called on a closed iterator, an exception is
thrown. To prevent this, consider using the protectInvocationFlow() method when
creating the iterator in scala side. This will allow the hasNext() method to be
called multiple times without issue.";
- throw gluten::GlutenException(errorMessage);
+ throw GlutenException(errorMessage);
}
return iter->hasNext();
JNI_METHOD_END(false)
@@ -414,7 +439,7 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_ColumnarBatchOutIterat
jobject wrapper,
jlong iterHandle) {
JNI_METHOD_START
- auto ctx = gluten::getRuntime(env, wrapper);
+ auto ctx = getRuntime(env, wrapper);
auto iter = ObjectStore::retrieve<ResultIterator>(iterHandle);
if (!iter->hasNext()) {
@@ -498,7 +523,7 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_ColumnarBatchOutIterat
auto it = ObjectStore::retrieve<ResultIterator>(iterHandle);
if (it == nullptr) {
std::string errorMessage = "Invalid result iter handle " +
std::to_string(iterHandle);
- throw gluten::GlutenException(errorMessage);
+ throw GlutenException(errorMessage);
}
return it->spillFixedSize(size);
JNI_METHOD_END(kInvalidObjectHandle)
@@ -518,7 +543,7 @@
Java_org_apache_gluten_vectorized_NativeColumnarToRowJniWrapper_nativeColumnarTo
JNIEnv* env,
jobject wrapper) {
JNI_METHOD_START
- auto ctx = gluten::getRuntime(env, wrapper);
+ auto ctx = getRuntime(env, wrapper);
auto& conf = ctx->getConfMap();
int64_t column2RowMemThreshold;
@@ -576,7 +601,7 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_NativeRowToColumnarJni
jobject wrapper,
jlong cSchema) {
JNI_METHOD_START
- auto ctx = gluten::getRuntime(env, wrapper);
+ auto ctx = getRuntime(env, wrapper);
return
ctx->saveObject(ctx->createRow2ColumnarConverter(reinterpret_cast<struct
ArrowSchema*>(cSchema)));
JNI_METHOD_END(kInvalidObjectHandle)
@@ -590,13 +615,13 @@
Java_org_apache_gluten_vectorized_NativeRowToColumnarJniWrapper_nativeConvertRow
jlongArray rowLength,
jlong memoryAddress) {
JNI_METHOD_START
- auto ctx = gluten::getRuntime(env, wrapper);
+ auto ctx = getRuntime(env, wrapper);
if (rowLength == nullptr) {
- throw gluten::GlutenException("Native convert row to columnar: buf_addrs
can't be null");
+ throw GlutenException("Native convert row to columnar: buf_addrs can't be
null");
}
int numRows = env->GetArrayLength(rowLength);
- auto safeArray = gluten::getLongArrayElementsSafe(env, rowLength);
+ auto safeArray = getLongArrayElementsSafe(env, rowLength);
uint8_t* address = reinterpret_cast<uint8_t*>(memoryAddress);
auto converter = ObjectStore::retrieve<RowToColumnarConverter>(r2cHandle);
@@ -675,7 +700,7 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_columnarbatch_ColumnarBatchJniWra
jlong cSchema,
jlong cArray) {
JNI_METHOD_START
- auto ctx = gluten::getRuntime(env, wrapper);
+ auto ctx = getRuntime(env, wrapper);
std::unique_ptr<ArrowSchema> targetSchema = std::make_unique<ArrowSchema>();
std::unique_ptr<ArrowArray> targetArray = std::make_unique<ArrowArray>();
@@ -694,7 +719,7 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_columnarbatch_ColumnarBatchJniWra
jobject wrapper,
jint numRows) {
JNI_METHOD_START
- auto ctx = gluten::getRuntime(env, wrapper);
+ auto ctx = getRuntime(env, wrapper);
return
ctx->saveObject(ctx->createOrGetEmptySchemaBatch(static_cast<int32_t>(numRows)));
JNI_METHOD_END(kInvalidObjectHandle)
}
@@ -705,9 +730,9 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_columnarbatch_ColumnarBatchJniWra
jlong batchHandle,
jintArray jcolumnIndices) {
JNI_METHOD_START
- auto ctx = gluten::getRuntime(env, wrapper);
+ auto ctx = getRuntime(env, wrapper);
- auto safeArray = gluten::getIntArrayElementsSafe(env, jcolumnIndices);
+ auto safeArray = getIntArrayElementsSafe(env, jcolumnIndices);
int size = env->GetArrayLength(jcolumnIndices);
std::vector<int32_t> columnIndices;
for (int32_t i = 0; i < size; i++) {
@@ -757,19 +782,19 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
jstring partitionWriterTypeJstr,
jstring shuffleWriterTypeJstr) {
JNI_METHOD_START
- auto ctx = gluten::getRuntime(env, wrapper);
+ auto ctx = getRuntime(env, wrapper);
if (partitioningNameJstr == nullptr) {
- throw gluten::GlutenException(std::string("Short partitioning name can't
be null"));
+ throw GlutenException(std::string("Short partitioning name can't be
null"));
}
// Build ShuffleWriterOptions.
auto shuffleWriterOptions = ShuffleWriterOptions{
.bufferSize = bufferSize,
.bufferReallocThreshold = reallocThreshold,
- .partitioning = gluten::toPartitioning(jStringToCString(env,
partitioningNameJstr)),
+ .partitioning = toPartitioning(jStringToCString(env,
partitioningNameJstr)),
.taskAttemptId = (int64_t)taskAttemptId,
.startPartitionId = startPartitionId,
- .shuffleWriterType =
gluten::ShuffleWriter::stringToType(jStringToCString(env,
shuffleWriterTypeJstr)),
+ .shuffleWriterType = ShuffleWriter::stringToType(jStringToCString(env,
shuffleWriterTypeJstr)),
.sortBufferInitialSize = sortBufferInitialSize,
.compressionBufferSize = compressionBufferSize,
.useRadixSort = static_cast<bool>(useRadixSort)};
@@ -799,17 +824,17 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
if (partitionWriterType == "local") {
if (dataFileJstr == NULL) {
- throw gluten::GlutenException(std::string("Shuffle DataFile can't be
null"));
+ throw GlutenException(std::string("Shuffle DataFile can't be null"));
}
if (localDirsJstr == NULL) {
- throw gluten::GlutenException(std::string("Shuffle DataFile can't be
null"));
+ throw GlutenException(std::string("Shuffle DataFile can't be null"));
}
auto dataFileC = env->GetStringUTFChars(dataFileJstr, JNI_FALSE);
auto dataFile = std::string(dataFileC);
env->ReleaseStringUTFChars(dataFileJstr, dataFileC);
auto localDirsC = env->GetStringUTFChars(localDirsJstr, JNI_FALSE);
- auto configuredDirs = gluten::splitPaths(std::string(localDirsC));
+ auto configuredDirs = splitPaths(std::string(localDirsC));
env->ReleaseStringUTFChars(localDirsJstr, localDirsC);
partitionWriter = std::make_unique<LocalPartitionWriter>(
@@ -825,7 +850,7 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
getMethodIdOrError(env, celebornPartitionPusherClass,
"pushPartitionData", "(I[BI)I");
JavaVM* vm;
if (env->GetJavaVM(&vm) != JNI_OK) {
- throw gluten::GlutenException("Unable to get JavaVM instance");
+ throw GlutenException("Unable to get JavaVM instance");
}
std::shared_ptr<JavaRssClient> celebornClient =
std::make_shared<JavaRssClient>(vm, partitionPusher,
celebornPushPartitionDataMethod);
@@ -841,7 +866,7 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
getMethodIdOrError(env, unifflePartitionPusherClass,
"pushPartitionData", "(I[BI)I");
JavaVM* vm;
if (env->GetJavaVM(&vm) != JNI_OK) {
- throw gluten::GlutenException("Unable to get JavaVM instance");
+ throw GlutenException("Unable to get JavaVM instance");
}
std::shared_ptr<JavaRssClient> uniffleClient =
std::make_shared<JavaRssClient>(vm, partitionPusher,
unifflePushPartitionDataMethod);
@@ -851,7 +876,7 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
ctx->memoryManager()->getArrowMemoryPool(),
std::move(uniffleClient));
} else {
- throw gluten::GlutenException("Unrecognizable partition writer type: " +
partitionWriterType);
+ throw GlutenException("Unrecognizable partition writer type: " +
partitionWriterType);
}
return ctx->saveObject(
@@ -869,11 +894,10 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
auto shuffleWriter =
ObjectStore::retrieve<ShuffleWriter>(shuffleWriterHandle);
if (!shuffleWriter) {
std::string errorMessage = "Invalid shuffle writer handle " +
std::to_string(shuffleWriterHandle);
- throw gluten::GlutenException(errorMessage);
+ throw GlutenException(errorMessage);
}
int64_t evictedSize;
- gluten::arrowAssertOkOrThrow(
- shuffleWriter->reclaimFixedSize(size, &evictedSize), "(shuffle)
nativeEvict: evict failed");
+ arrowAssertOkOrThrow(shuffleWriter->reclaimFixedSize(size, &evictedSize),
"(shuffle) nativeEvict: evict failed");
return (jlong)evictedSize;
JNI_METHOD_END(kInvalidObjectHandle)
}
@@ -889,13 +913,13 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
auto shuffleWriter =
ObjectStore::retrieve<ShuffleWriter>(shuffleWriterHandle);
if (!shuffleWriter) {
std::string errorMessage = "Invalid shuffle writer handle " +
std::to_string(shuffleWriterHandle);
- throw gluten::GlutenException(errorMessage);
+ throw GlutenException(errorMessage);
}
// The column batch maybe VeloxColumnBatch or
ArrowCStructColumnarBatch(FallbackRangeShuffleWriter)
auto batch = ObjectStore::retrieve<ColumnarBatch>(batchHandle);
auto numBytes = batch->numBytes();
- gluten::arrowAssertOkOrThrow(shuffleWriter->write(batch, memLimit), "Native
write: shuffle writer failed");
+ arrowAssertOkOrThrow(shuffleWriter->write(batch, memLimit), "Native write:
shuffle writer failed");
return numBytes;
JNI_METHOD_END(kInvalidObjectHandle)
}
@@ -908,10 +932,10 @@ JNIEXPORT jobject JNICALL
Java_org_apache_gluten_vectorized_ShuffleWriterJniWrap
auto shuffleWriter =
ObjectStore::retrieve<ShuffleWriter>(shuffleWriterHandle);
if (!shuffleWriter) {
std::string errorMessage = "Invalid shuffle writer handle " +
std::to_string(shuffleWriterHandle);
- throw gluten::GlutenException(errorMessage);
+ throw GlutenException(errorMessage);
}
- gluten::arrowAssertOkOrThrow(shuffleWriter->stop(), "Native shuffle write:
ShuffleWriter stop failed");
+ arrowAssertOkOrThrow(shuffleWriter->stop(), "Native shuffle write:
ShuffleWriter stop failed");
const auto& partitionLengths = shuffleWriter->partitionLengths();
auto partitionLengthArr = env->NewLongArray(partitionLengths.size());
@@ -959,7 +983,7 @@ JNIEXPORT void JNICALL
Java_org_apache_gluten_vectorized_OnHeapJniByteInputStrea
jlong destAddress,
jint size) {
JNI_METHOD_START
- auto safeArray = gluten::getByteArrayElementsSafe(env, source);
+ auto safeArray = getByteArrayElementsSafe(env, source);
std::memcpy(reinterpret_cast<void*>(destAddress), safeArray.elems(), size);
JNI_METHOD_END()
}
@@ -973,7 +997,7 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_ShuffleReaderJniWrappe
jint batchSize,
jstring shuffleWriterType) {
JNI_METHOD_START
- auto ctx = gluten::getRuntime(env, wrapper);
+ auto ctx = getRuntime(env, wrapper);
ShuffleReaderOptions options = ShuffleReaderOptions{};
options.compressionType = getCompressionType(env, compressionType);
@@ -984,9 +1008,9 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_ShuffleReaderJniWrappe
options.batchSize = batchSize;
// TODO: Add coalesce option and maximum coalesced size.
- options.shuffleWriterType =
gluten::ShuffleWriter::stringToType(jStringToCString(env, shuffleWriterType));
+ options.shuffleWriterType =
ShuffleWriter::stringToType(jStringToCString(env, shuffleWriterType));
std::shared_ptr<arrow::Schema> schema =
- gluten::arrowGetOrThrow(arrow::ImportSchema(reinterpret_cast<struct
ArrowSchema*>(cSchema)));
+ arrowGetOrThrow(arrow::ImportSchema(reinterpret_cast<struct
ArrowSchema*>(cSchema)));
return ctx->saveObject(ctx->createShuffleReader(schema, options));
JNI_METHOD_END(kInvalidObjectHandle)
@@ -998,7 +1022,7 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_ShuffleReaderJniWrappe
jlong shuffleReaderHandle,
jobject jniIn) {
JNI_METHOD_START
- auto ctx = gluten::getRuntime(env, wrapper);
+ auto ctx = getRuntime(env, wrapper);
auto reader = ObjectStore::retrieve<ShuffleReader>(shuffleReaderHandle);
std::shared_ptr<arrow::io::InputStream> in =
std::make_shared<JavaInputStreamAdaptor>(env, reader->getPool(), jniIn);
auto outItr = reader->readStream(in);
@@ -1036,10 +1060,10 @@ JNIEXPORT jobject JNICALL
Java_org_apache_gluten_vectorized_ColumnarBatchSeriali
jobject wrapper,
jlongArray handles) {
JNI_METHOD_START
- auto ctx = gluten::getRuntime(env, wrapper);
+ auto ctx = getRuntime(env, wrapper);
int32_t numBatches = env->GetArrayLength(handles);
- auto safeArray = gluten::getLongArrayElementsSafe(env, handles);
+ auto safeArray = getLongArrayElementsSafe(env, handles);
std::vector<std::shared_ptr<ColumnarBatch>> batches;
int64_t numRows = 0L;
@@ -1068,7 +1092,7 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_ColumnarBatchSerialize
jobject wrapper,
jlong cSchema) {
JNI_METHOD_START
- auto ctx = gluten::getRuntime(env, wrapper);
+ auto ctx = getRuntime(env, wrapper);
return
ctx->saveObject(ctx->createColumnarBatchSerializer(reinterpret_cast<struct
ArrowSchema*>(cSchema)));
JNI_METHOD_END(kInvalidObjectHandle)
}
@@ -1079,12 +1103,12 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_ColumnarBatchSerialize
jlong serializerHandle,
jbyteArray data) {
JNI_METHOD_START
- auto ctx = gluten::getRuntime(env, wrapper);
+ auto ctx = getRuntime(env, wrapper);
auto serializer =
ObjectStore::retrieve<ColumnarBatchSerializer>(serializerHandle);
GLUTEN_DCHECK(serializer != nullptr, "ColumnarBatchSerializer cannot be
null");
int32_t size = env->GetArrayLength(data);
- auto safeArray = gluten::getByteArrayElementsSafe(env, data);
+ auto safeArray = getByteArrayElementsSafe(env, data);
auto batch = serializer->deserialize(safeArray.elems(), size);
return ctx->saveObject(batch);
JNI_METHOD_END(kInvalidObjectHandle)
diff --git a/cpp/core/memory/MemoryManager.h b/cpp/core/memory/MemoryManager.cc
similarity index 54%
copy from cpp/core/memory/MemoryManager.h
copy to cpp/core/memory/MemoryManager.cc
index 5ec5213051..e4fcc4fb09 100644
--- a/cpp/core/memory/MemoryManager.h
+++ b/cpp/core/memory/MemoryManager.cc
@@ -14,30 +14,29 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
-#pragma once
-
-#include "arrow/memory_pool.h"
-#include "memory.pb.h"
+#include "MemoryManager.h"
+#include "utils/Registry.h"
namespace gluten {
-class MemoryManager {
- public:
- MemoryManager() = default;
-
- virtual ~MemoryManager() = default;
-
- virtual arrow::MemoryPool* getArrowMemoryPool() = 0;
-
- virtual const MemoryUsageStats collectMemoryUsageStats() const = 0;
-
- virtual const int64_t shrink(int64_t size) = 0;
-
- // Hold this memory manager. The underlying memory pools will be released as
lately as this memory manager gets
- // destroyed. Which means, a call to this function would make sure the
memory blocks directly or indirectly managed
- // by this manager, be guaranteed safe to access during the period that this
manager is alive.
- virtual void hold() = 0;
-};
+namespace {
+Registry<MemoryManager::Factory>& memoryManagerFactories() {
+ static Registry<MemoryManager::Factory> registry;
+ return registry;
+}
+} // namespace
+
+void MemoryManager::registerFactory(const std::string& kind,
MemoryManager::Factory factory) {
+ memoryManagerFactories().registerObj(kind, std::move(factory));
+}
+
+MemoryManager* MemoryManager::create(const std::string& kind,
std::unique_ptr<AllocationListener> listener) {
+ auto& factory = memoryManagerFactories().get(kind);
+ return factory(std::move(listener));
+}
+
+void MemoryManager::release(MemoryManager* memoryManager) {
+ delete memoryManager;
+}
} // namespace gluten
diff --git a/cpp/core/memory/MemoryManager.h b/cpp/core/memory/MemoryManager.h
index 5ec5213051..a3a2d4322c 100644
--- a/cpp/core/memory/MemoryManager.h
+++ b/cpp/core/memory/MemoryManager.h
@@ -19,11 +19,17 @@
#include "arrow/memory_pool.h"
#include "memory.pb.h"
+#include "memory/AllocationListener.h"
namespace gluten {
class MemoryManager {
public:
+ using Factory =
std::function<MemoryManager*(std::unique_ptr<AllocationListener> listener)>;
+ static void registerFactory(const std::string& kind, Factory factory);
+ static MemoryManager* create(const std::string& kind,
std::unique_ptr<AllocationListener> listener);
+ static void release(MemoryManager*);
+
MemoryManager() = default;
virtual ~MemoryManager() = default;
diff --git a/cpp/core/memory/MemoryManager.h b/cpp/core/utils/Registry.h
similarity index 52%
copy from cpp/core/memory/MemoryManager.h
copy to cpp/core/utils/Registry.h
index 5ec5213051..e50eb6763d 100644
--- a/cpp/core/memory/MemoryManager.h
+++ b/cpp/core/utils/Registry.h
@@ -17,27 +17,36 @@
#pragma once
-#include "arrow/memory_pool.h"
-#include "memory.pb.h"
+#include <mutex>
+#include <string>
+#include <unordered_map>
-namespace gluten {
+#include "utils/Exception.h"
-class MemoryManager {
+namespace gluten {
+template <typename T>
+class Registry {
public:
- MemoryManager() = default;
-
- virtual ~MemoryManager() = default;
-
- virtual arrow::MemoryPool* getArrowMemoryPool() = 0;
-
- virtual const MemoryUsageStats collectMemoryUsageStats() const = 0;
-
- virtual const int64_t shrink(int64_t size) = 0;
-
- // Hold this memory manager. The underlying memory pools will be released as
lately as this memory manager gets
- // destroyed. Which means, a call to this function would make sure the
memory blocks directly or indirectly managed
- // by this manager, be guaranteed safe to access during the period that this
manager is alive.
- virtual void hold() = 0;
+ void registerObj(const std::string& kind, T t) {
+ std::lock_guard<std::mutex> l(mutex_);
+ GLUTEN_CHECK(map_.find(kind) == map_.end(), "Already registered for " +
kind);
+ map_[kind] = std::move(t);
+ }
+
+ T& get(const std::string& kind) {
+ std::lock_guard<std::mutex> l(mutex_);
+ GLUTEN_CHECK(map_.find(kind) != map_.end(), "Not registered for " + kind);
+ return map_[kind];
+ }
+
+ bool unregisterObj(const std::string& kind) {
+ std::lock_guard<std::mutex> l(mutex_);
+ GLUTEN_CHECK(map_.find(kind) != map_.end(), "Not registered for " + kind);
+ return map_.erase(kind);
+ }
+
+ private:
+ std::mutex mutex_;
+ std::unordered_map<std::string, T> map_;
};
-
} // namespace gluten
diff --git a/cpp/velox/benchmarks/GenericBenchmark.cc
b/cpp/velox/benchmarks/GenericBenchmark.cc
index ebbb0e0ea8..e42aed9f21 100644
--- a/cpp/velox/benchmarks/GenericBenchmark.cc
+++ b/cpp/velox/benchmarks/GenericBenchmark.cc
@@ -295,7 +295,7 @@ void updateBenchmarkMetrics(
} // namespace
-using RuntimeFactory =
std::function<VeloxRuntime*(std::unique_ptr<AllocationListener> listener)>;
+using RuntimeFactory = std::function<VeloxRuntime*(MemoryManager*
memoryManager)>;
auto BM_Generic = [](::benchmark::State& state,
const std::string& planFile,
@@ -307,7 +307,8 @@ auto BM_Generic = [](::benchmark::State& state,
auto listener =
std::make_unique<BenchmarkAllocationListener>(FLAGS_memory_limit);
auto* listenerPtr = listener.get();
- auto runtime = runtimeFactory(std::move(listener));
+ auto* memoryManager = MemoryManager::create(kVeloxBackendKind,
std::move(listener));
+ auto runtime = runtimeFactory(memoryManager);
auto plan = getPlanFromFile("Plan", planFile);
std::vector<std::string> splits{};
@@ -399,6 +400,7 @@ auto BM_Generic = [](::benchmark::State& state,
updateBenchmarkMetrics(state, elapsedTime, readInputTime, writerMetrics,
readerMetrics);
Runtime::release(runtime);
+ MemoryManager::release(memoryManager);
};
auto BM_ShuffleWriteRead = [](::benchmark::State& state,
@@ -409,7 +411,8 @@ auto BM_ShuffleWriteRead = [](::benchmark::State& state,
auto listener =
std::make_unique<BenchmarkAllocationListener>(FLAGS_memory_limit);
auto* listenerPtr = listener.get();
- auto runtime = runtimeFactory(std::move(listener));
+ auto* memoryManager = MemoryManager::create(kVeloxBackendKind,
std::move(listener));
+ auto runtime = runtimeFactory(memoryManager);
WriterMetrics writerMetrics{};
ReaderMetrics readerMetrics{};
@@ -428,6 +431,7 @@ auto BM_ShuffleWriteRead = [](::benchmark::State& state,
updateBenchmarkMetrics(state, elapsedTime, readInputTime, writerMetrics,
readerMetrics);
Runtime::release(runtime);
+ MemoryManager::release(memoryManager);
};
int main(int argc, char** argv) {
@@ -592,8 +596,8 @@ int main(int argc, char** argv) {
}
}
- RuntimeFactory runtimeFactory = [=](std::unique_ptr<AllocationListener>
listener) {
- return dynamic_cast<VeloxRuntime*>(Runtime::create(kVeloxRuntimeKind,
std::move(listener), sessionConf));
+ RuntimeFactory runtimeFactory = [=](MemoryManager* memoryManager) {
+ return dynamic_cast<VeloxRuntime*>(Runtime::create(kVeloxBackendKind,
memoryManager, sessionConf));
};
#define GENERIC_BENCHMARK(READER_TYPE)
\
diff --git a/cpp/velox/benchmarks/ParquetWriteBenchmark.cc
b/cpp/velox/benchmarks/ParquetWriteBenchmark.cc
index 2369a13bd2..e60efd5505 100644
--- a/cpp/velox/benchmarks/ParquetWriteBenchmark.cc
+++ b/cpp/velox/benchmarks/ParquetWriteBenchmark.cc
@@ -257,8 +257,8 @@ class GoogleBenchmarkVeloxParquetWriteCacheScanBenchmark :
public GoogleBenchmar
// reuse the ParquetWriteConverter for batches caused system % increase a
lot
auto fileName = "velox_parquet_write.parquet";
- auto runtime = Runtime::create(kVeloxRuntimeKind,
AllocationListener::noop());
auto memoryManager = getDefaultMemoryManager();
+ auto runtime = Runtime::create(kVeloxBackendKind, memoryManager.get());
auto veloxPool = memoryManager->getAggregateMemoryPool();
for (auto _ : state) {
diff --git a/cpp/velox/compute/VeloxBackend.cc
b/cpp/velox/compute/VeloxBackend.cc
index 609ae6fce3..387ad8c431 100644
--- a/cpp/velox/compute/VeloxBackend.cc
+++ b/cpp/velox/compute/VeloxBackend.cc
@@ -62,10 +62,16 @@ using namespace facebook;
namespace gluten {
namespace {
-gluten::Runtime* veloxRuntimeFactory(
- std::unique_ptr<AllocationListener> listener,
+MemoryManager* veloxMemoryManagerFactory(std::unique_ptr<AllocationListener>
listener) {
+ return new VeloxMemoryManager(std::move(listener));
+}
+
+Runtime* veloxRuntimeFactory(
+ MemoryManager* memoryManager,
const std::unordered_map<std::string, std::string>& sessionConf) {
- return new gluten::VeloxRuntime(std::move(listener), sessionConf);
+ auto* vmm = dynamic_cast<VeloxMemoryManager*>(memoryManager);
+ GLUTEN_CHECK(vmm != nullptr, "Not a Velox memory manager");
+ return new VeloxRuntime(vmm, sessionConf);
}
} // namespace
@@ -73,8 +79,9 @@ void VeloxBackend::init(const std::unordered_map<std::string,
std::string>& conf
backendConf_ =
std::make_shared<facebook::velox::config::ConfigBase>(std::unordered_map<std::string,
std::string>(conf));
- // Register Velox runtime factory
- gluten::Runtime::registerFactory(gluten::kVeloxRuntimeKind,
veloxRuntimeFactory);
+ // Register factories.
+ MemoryManager::registerFactory(kVeloxBackendKind, veloxMemoryManagerFactory);
+ Runtime::registerFactory(kVeloxBackendKind, veloxRuntimeFactory);
if (backendConf_->get<bool>(kDebugModeEnabled, false)) {
LOG(INFO) << "VeloxBackend config:" <<
printConfig(backendConf_->rawConfigs());
@@ -175,7 +182,7 @@ void VeloxBackend::initJolFilesystem() {
// FIXME It's known that if spill compression is disabled, the actual spill
file size may
// in crease beyond this limit a little (maximum 64 rows which is by
default
// one compression page)
- gluten::registerJolFileSystem(maxSpillFileSize);
+ registerJolFileSystem(maxSpillFileSize);
}
void VeloxBackend::initCache() {
@@ -284,7 +291,7 @@ void VeloxBackend::initConnector() {
void VeloxBackend::initUdf() {
auto got = backendConf_->get<std::string>(kVeloxUdfLibraryPaths, "");
if (!got.empty()) {
- auto udfLoader = gluten::UdfLoader::getInstance();
+ auto udfLoader = UdfLoader::getInstance();
udfLoader->loadUdfLibraries(got);
udfLoader->registerUdf();
}
@@ -293,7 +300,7 @@ void VeloxBackend::initUdf() {
std::unique_ptr<VeloxBackend> VeloxBackend::instance_ = nullptr;
void VeloxBackend::create(const std::unordered_map<std::string, std::string>&
conf) {
- instance_ = std::unique_ptr<VeloxBackend>(new gluten::VeloxBackend(conf));
+ instance_ = std::unique_ptr<VeloxBackend>(new VeloxBackend(conf));
}
VeloxBackend* VeloxBackend::get() {
diff --git a/cpp/velox/compute/VeloxRuntime.cc
b/cpp/velox/compute/VeloxRuntime.cc
index 398597f9b4..931ac82c5d 100644
--- a/cpp/velox/compute/VeloxRuntime.cc
+++ b/cpp/velox/compute/VeloxRuntime.cc
@@ -56,12 +56,9 @@ using namespace facebook;
namespace gluten {
-VeloxRuntime::VeloxRuntime(
- std::unique_ptr<AllocationListener> listener,
- const std::unordered_map<std::string, std::string>& confMap)
- : Runtime(std::make_shared<VeloxMemoryManager>(std::move(listener)),
confMap) {
+VeloxRuntime::VeloxRuntime(VeloxMemoryManager* vmm, const
std::unordered_map<std::string, std::string>& confMap)
+ : Runtime(vmm, confMap) {
// Refresh session config.
- vmm_ = dynamic_cast<VeloxMemoryManager*>(memoryManager_.get());
veloxCfg_ =
std::make_shared<facebook::velox::config::ConfigBase>(std::unordered_map<std::string,
std::string>(confMap_));
debugModeEnabled_ = veloxCfg_->get<bool>(kDebugModeEnabled, false);
@@ -129,7 +126,9 @@ std::string VeloxRuntime::planString(bool details, const
std::unordered_map<std:
}
VeloxMemoryManager* VeloxRuntime::memoryManager() {
- return vmm_;
+ auto vmm = dynamic_cast<VeloxMemoryManager*>(memoryManager_);
+ GLUTEN_CHECK(vmm != nullptr, "Not a Velox memory manager");
+ return vmm;
}
std::shared_ptr<ResultIterator> VeloxRuntime::createResultIterator(
@@ -139,7 +138,7 @@ std::shared_ptr<ResultIterator>
VeloxRuntime::createResultIterator(
LOG_IF(INFO, debugModeEnabled_) << "VeloxRuntime session config:" <<
printConfig(confMap_);
VeloxPlanConverter veloxPlanConverter(
- inputs, vmm_->getLeafMemoryPool().get(), sessionConf,
*localWriteFilesTempPath());
+ inputs, memoryManager()->getLeafMemoryPool().get(), sessionConf,
*localWriteFilesTempPath());
veloxPlan_ = veloxPlanConverter.toVeloxPlan(substraitPlan_,
std::move(localFiles_));
// Scan node can be required.
@@ -151,12 +150,12 @@ std::shared_ptr<ResultIterator>
VeloxRuntime::createResultIterator(
getInfoAndIds(veloxPlanConverter.splitInfos(),
veloxPlan_->leafPlanNodeIds(), scanInfos, scanIds, streamIds);
auto wholestageIter = std::make_unique<WholeStageResultIterator>(
- vmm_, veloxPlan_, scanIds, scanInfos, streamIds, spillDir, sessionConf,
taskInfo_);
+ memoryManager(), veloxPlan_, scanIds, scanInfos, streamIds, spillDir,
sessionConf, taskInfo_);
return std::make_shared<ResultIterator>(std::move(wholestageIter), this);
}
std::shared_ptr<ColumnarToRowConverter>
VeloxRuntime::createColumnar2RowConverter(int64_t column2RowMemThreshold) {
- auto veloxPool = vmm_->getLeafMemoryPool();
+ auto veloxPool = memoryManager()->getLeafMemoryPool();
return std::make_shared<VeloxColumnarToRowConverter>(veloxPool,
column2RowMemThreshold);
}
@@ -172,14 +171,14 @@ std::shared_ptr<ColumnarBatch>
VeloxRuntime::createOrGetEmptySchemaBatch(int32_t
std::shared_ptr<ColumnarBatch> VeloxRuntime::select(
std::shared_ptr<ColumnarBatch> batch,
const std::vector<int32_t>& columnIndices) {
- auto veloxPool = vmm_->getLeafMemoryPool();
+ auto veloxPool = memoryManager()->getLeafMemoryPool();
auto veloxBatch = gluten::VeloxColumnarBatch::from(veloxPool.get(), batch);
auto outputBatch = veloxBatch->select(veloxPool.get(),
std::move(columnIndices));
return outputBatch;
}
std::shared_ptr<RowToColumnarConverter>
VeloxRuntime::createRow2ColumnarConverter(struct ArrowSchema* cSchema) {
- auto veloxPool = vmm_->getLeafMemoryPool();
+ auto veloxPool = memoryManager()->getLeafMemoryPool();
return std::make_shared<VeloxRowToColumnarConverter>(cSchema, veloxPool);
}
@@ -187,8 +186,8 @@ std::shared_ptr<ShuffleWriter>
VeloxRuntime::createShuffleWriter(
int numPartitions,
std::unique_ptr<PartitionWriter> partitionWriter,
ShuffleWriterOptions options) {
- auto veloxPool = vmm_->getLeafMemoryPool();
- auto arrowPool = vmm_->getArrowMemoryPool();
+ auto veloxPool = memoryManager()->getLeafMemoryPool();
+ auto arrowPool = memoryManager()->getArrowMemoryPool();
GLUTEN_ASSIGN_OR_THROW(
std::shared_ptr<ShuffleWriter> shuffleWriter,
VeloxShuffleWriter::create(
@@ -205,11 +204,11 @@ std::shared_ptr<VeloxDataSource>
VeloxRuntime::createDataSource(
const std::string& filePath,
std::shared_ptr<arrow::Schema> schema) {
static std::atomic_uint32_t id{0UL};
- auto veloxPool =
vmm_->getAggregateMemoryPool()->addAggregateChild("datasource." +
std::to_string(id++));
+ auto veloxPool =
memoryManager()->getAggregateMemoryPool()->addAggregateChild("datasource." +
std::to_string(id++));
// Pass a dedicate pool for S3 and GCS sinks as can't share veloxPool
// with parquet writer.
// FIXME: Check file formats?
- auto sinkPool = vmm_->getLeafMemoryPool();
+ auto sinkPool = memoryManager()->getLeafMemoryPool();
if (isSupportedHDFSPath(filePath)) {
#ifdef ENABLE_HDFS
return std::make_shared<VeloxParquetDataSourceHDFS>(filePath, veloxPool,
sinkPool, schema);
@@ -247,7 +246,7 @@ std::shared_ptr<ShuffleReader>
VeloxRuntime::createShuffleReader(
ShuffleReaderOptions options) {
auto rowType = facebook::velox::asRowType(gluten::fromArrowSchema(schema));
auto codec = gluten::createArrowIpcCodec(options.compressionType,
options.codecBackend);
- auto ctxVeloxPool = vmm_->getLeafMemoryPool();
+ auto ctxVeloxPool = memoryManager()->getLeafMemoryPool();
auto veloxCompressionType =
facebook::velox::common::stringToCompressionKind(options.compressionTypeStr);
auto deserializerFactory =
std::make_unique<gluten::VeloxColumnarBatchDeserializerFactory>(
schema,
@@ -255,7 +254,7 @@ std::shared_ptr<ShuffleReader>
VeloxRuntime::createShuffleReader(
veloxCompressionType,
rowType,
options.batchSize,
- vmm_->getArrowMemoryPool(),
+ memoryManager()->getArrowMemoryPool(),
ctxVeloxPool,
options.shuffleWriterType);
auto reader =
std::make_shared<VeloxShuffleReader>(std::move(deserializerFactory));
@@ -263,8 +262,8 @@ std::shared_ptr<ShuffleReader>
VeloxRuntime::createShuffleReader(
}
std::unique_ptr<ColumnarBatchSerializer>
VeloxRuntime::createColumnarBatchSerializer(struct ArrowSchema* cSchema) {
- auto arrowPool = vmm_->getArrowMemoryPool();
- auto veloxPool = vmm_->getLeafMemoryPool();
+ auto arrowPool = memoryManager()->getArrowMemoryPool();
+ auto veloxPool = memoryManager()->getLeafMemoryPool();
return std::make_unique<VeloxColumnarBatchSerializer>(arrowPool, veloxPool,
cSchema);
}
diff --git a/cpp/velox/compute/VeloxRuntime.h b/cpp/velox/compute/VeloxRuntime.h
index 8101426859..6df47aeffb 100644
--- a/cpp/velox/compute/VeloxRuntime.h
+++ b/cpp/velox/compute/VeloxRuntime.h
@@ -29,13 +29,11 @@
namespace gluten {
// This kind string must be same with VeloxBackend#name in java side.
-inline static const std::string kVeloxRuntimeKind{"velox"};
+inline static const std::string kVeloxBackendKind{"velox"};
class VeloxRuntime final : public Runtime {
public:
- explicit VeloxRuntime(
- std::unique_ptr<AllocationListener> listener,
- const std::unordered_map<std::string, std::string>& confMap);
+ explicit VeloxRuntime(VeloxMemoryManager* vmm, const
std::unordered_map<std::string, std::string>& confMap);
void parsePlan(const uint8_t* data, int32_t size, std::optional<std::string>
dumpFile) override;
@@ -96,7 +94,6 @@ class VeloxRuntime final : public Runtime {
std::vector<facebook::velox::core::PlanNodeId>& streamIds);
private:
- VeloxMemoryManager* vmm_;
std::shared_ptr<const facebook::velox::core::PlanNode> veloxPlan_;
std::shared_ptr<facebook::velox::config::ConfigBase> veloxCfg_;
bool debugModeEnabled_{false};
diff --git a/cpp/velox/tests/RuntimeTest.cc b/cpp/velox/tests/RuntimeTest.cc
index e978f2eec7..9b40f6c78e 100644
--- a/cpp/velox/tests/RuntimeTest.cc
+++ b/cpp/velox/tests/RuntimeTest.cc
@@ -40,8 +40,7 @@ class DummyMemoryManager final : public MemoryManager {
class DummyRuntime final : public Runtime {
public:
- DummyRuntime(std::unique_ptr<AllocationListener> listener, const
std::unordered_map<std::string, std::string>& conf)
- : Runtime(std::make_shared<DummyMemoryManager>(), conf) {}
+ DummyRuntime(DummyMemoryManager* mm, const std::unordered_map<std::string,
std::string>& conf) : Runtime(mm, conf) {}
void parsePlan(const uint8_t* data, int32_t size, std::optional<std::string>
dumpFile) override {}
@@ -113,29 +112,29 @@ class DummyRuntime final : public Runtime {
};
};
-static Runtime* dummyRuntimeFactory(
- std::unique_ptr<AllocationListener> listener,
- const std::unordered_map<std::string, std::string> conf) {
- return new DummyRuntime(std::move(listener), conf);
+static Runtime* dummyRuntimeFactory(MemoryManager* mm, const
std::unordered_map<std::string, std::string> conf) {
+ return new DummyRuntime(dynamic_cast<DummyMemoryManager*>(mm), conf);
}
TEST(TestRuntime, CreateRuntime) {
Runtime::registerFactory("DUMMY", dummyRuntimeFactory);
- auto runtime = Runtime::create("DUMMY", AllocationListener::noop());
+ DummyMemoryManager mm;
+ auto runtime = Runtime::create("DUMMY", &mm);
ASSERT_EQ(typeid(*runtime), typeid(DummyRuntime));
Runtime::release(runtime);
}
TEST(TestRuntime, CreateVeloxRuntime) {
VeloxBackend::create({});
- auto runtime = Runtime::create(kVeloxRuntimeKind,
AllocationListener::noop());
+ VeloxMemoryManager mm(AllocationListener::noop());
+ auto runtime = Runtime::create(kVeloxBackendKind, &mm);
ASSERT_EQ(typeid(*runtime), typeid(VeloxRuntime));
Runtime::release(runtime);
}
TEST(TestRuntime, GetResultIterator) {
- auto runtime =
- std::make_shared<DummyRuntime>(AllocationListener::noop(),
std::unordered_map<std::string, std::string>());
+ DummyMemoryManager mm;
+ auto runtime = std::make_shared<DummyRuntime>(&mm,
std::unordered_map<std::string, std::string>());
auto iter = runtime->createResultIterator("/tmp/test-spill", {}, {});
ASSERT_TRUE(iter->hasNext());
auto next = iter->next();
diff --git
a/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/IndicatorVectorPool.java
b/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/IndicatorVectorPool.java
index c122cc1cca..41c8cbdd43 100644
---
a/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/IndicatorVectorPool.java
+++
b/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/IndicatorVectorPool.java
@@ -56,7 +56,7 @@ public class IndicatorVectorPool implements TaskResource {
@Override
public int priority() {
- return 0;
+ return 10;
}
@Override
diff --git
a/gluten-arrow/src/main/java/org/apache/gluten/runtime/RuntimeJniWrapper.java
b/gluten-arrow/src/main/java/org/apache/gluten/memory/NativeMemoryManagerJniWrapper.java
similarity index 69%
copy from
gluten-arrow/src/main/java/org/apache/gluten/runtime/RuntimeJniWrapper.java
copy to
gluten-arrow/src/main/java/org/apache/gluten/memory/NativeMemoryManagerJniWrapper.java
index 80f9509d9e..c23b9704ec 100644
---
a/gluten-arrow/src/main/java/org/apache/gluten/runtime/RuntimeJniWrapper.java
+++
b/gluten-arrow/src/main/java/org/apache/gluten/memory/NativeMemoryManagerJniWrapper.java
@@ -14,23 +14,21 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.runtime;
+package org.apache.gluten.memory;
import org.apache.gluten.memory.listener.ReservationListener;
-public class RuntimeJniWrapper {
+public class NativeMemoryManagerJniWrapper {
+ private NativeMemoryManagerJniWrapper() {}
- private RuntimeJniWrapper() {}
-
- public static native long createRuntime(
+ public static native long create(
String backendType, ReservationListener listener, byte[] sessionConf);
- // Memory management.
- public static native byte[] collectMemoryUsage(long handle);
+ public static native byte[] collectUsage(long handle);
- public static native long shrinkMemory(long handle, long size);
+ public static native long shrink(long handle, long size);
- public static native void holdMemory(long handle);
+ public static native void hold(long handle);
- public static native void releaseRuntime(long handle);
+ public static native void release(long handle);
}
diff --git
a/gluten-arrow/src/main/java/org/apache/gluten/runtime/RuntimeJniWrapper.java
b/gluten-arrow/src/main/java/org/apache/gluten/runtime/RuntimeJniWrapper.java
index 80f9509d9e..8c4280a3b5 100644
---
a/gluten-arrow/src/main/java/org/apache/gluten/runtime/RuntimeJniWrapper.java
+++
b/gluten-arrow/src/main/java/org/apache/gluten/runtime/RuntimeJniWrapper.java
@@ -16,21 +16,11 @@
*/
package org.apache.gluten.runtime;
-import org.apache.gluten.memory.listener.ReservationListener;
-
public class RuntimeJniWrapper {
private RuntimeJniWrapper() {}
- public static native long createRuntime(
- String backendType, ReservationListener listener, byte[] sessionConf);
-
- // Memory management.
- public static native byte[] collectMemoryUsage(long handle);
-
- public static native long shrinkMemory(long handle, long size);
-
- public static native void holdMemory(long handle);
+ public static native long createRuntime(String backendType, long nmm, byte[]
sessionConf);
public static native void releaseRuntime(long handle);
}
diff --git
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchOutIterator.java
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchOutIterator.java
index b253485c90..4293b2abf8 100644
---
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchOutIterator.java
+++
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchOutIterator.java
@@ -78,7 +78,7 @@ public class ColumnarBatchOutIterator extends
ClosableIterator implements Runtim
public void close0() {
// To make sure the outputted batches are still accessible after the
iterator is closed.
// TODO: Remove this API if we have other choice, e.g., hold the pools in
native code.
- runtime.holdMemory();
+ runtime.memoryManager().hold();
nativeClose(iterHandle);
}
}
diff --git
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/NativePlanEvaluator.java
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/NativePlanEvaluator.java
index 2a3d6013a9..1c03c415ed 100644
---
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/NativePlanEvaluator.java
+++
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/NativePlanEvaluator.java
@@ -73,16 +73,18 @@ public class NativePlanEvaluator {
DebugUtil.saveInputToFile(),
spillDirPath);
final ColumnarBatchOutIterator out = createOutIterator(runtime, itrHandle);
- runtime.addSpiller(
- new Spiller() {
- @Override
- public long spill(MemoryTarget self, Spiller.Phase phase, long size)
{
- if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
- return 0L;
- }
- return out.spill(size);
- }
- });
+ runtime
+ .memoryManager()
+ .addSpiller(
+ new Spiller() {
+ @Override
+ public long spill(MemoryTarget self, Spiller.Phase phase, long
size) {
+ if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
+ return 0L;
+ }
+ return out.spill(size);
+ }
+ });
return out;
}
diff --git
a/gluten-arrow/src/main/scala/org/apache/gluten/runtime/Runtime.scala
b/gluten-arrow/src/main/scala/org/apache/gluten/memory/NativeMemoryManager.scala
similarity index 69%
copy from gluten-arrow/src/main/scala/org/apache/gluten/runtime/Runtime.scala
copy to
gluten-arrow/src/main/scala/org/apache/gluten/memory/NativeMemoryManager.scala
index 8053e886f2..409c4297fa 100644
--- a/gluten-arrow/src/main/scala/org/apache/gluten/runtime/Runtime.scala
+++
b/gluten-arrow/src/main/scala/org/apache/gluten/memory/NativeMemoryManager.scala
@@ -14,12 +14,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.runtime
+package org.apache.gluten.memory
import org.apache.gluten.GlutenConfig
import org.apache.gluten.backend.Backend
import org.apache.gluten.exception.GlutenException
-import org.apache.gluten.memory.MemoryUsageStatsBuilder
import org.apache.gluten.memory.listener.ReservationListeners
import org.apache.gluten.memory.memtarget.{KnownNameAndStats, MemoryTarget,
Spiller, Spillers}
import org.apache.gluten.proto.MemoryUsageStats
@@ -27,7 +26,7 @@ import org.apache.gluten.utils.ConfigUtil
import org.apache.spark.memory.SparkMemoryUtil
import org.apache.spark.sql.internal.{GlutenConfigUtil, SQLConf}
-import org.apache.spark.task.TaskResource
+import org.apache.spark.task.{TaskResource, TaskResources}
import org.slf4j.LoggerFactory
@@ -36,25 +35,19 @@ import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.JavaConverters._
import scala.collection.mutable
-trait Runtime {
+trait NativeMemoryManager {
def addSpiller(spiller: Spiller): Unit
- def holdMemory(): Unit
- def collectMemoryUsage(): MemoryUsageStats
+ def hold(): Unit
def getHandle(): Long
}
-object Runtime {
- private[runtime] def apply(name: String): Runtime with TaskResource = {
- new RuntimeImpl(name)
- }
-
- private class RuntimeImpl(name: String) extends Runtime with TaskResource {
- private val LOGGER = LoggerFactory.getLogger(classOf[Runtime])
-
+object NativeMemoryManager {
+ private class Impl(name: String) extends NativeMemoryManager with
TaskResource {
+ private val LOGGER = LoggerFactory.getLogger(classOf[NativeMemoryManager])
private val spillers = Spillers.appendable()
private val mutableStats: mutable.Map[String, MemoryUsageStatsBuilder] =
mutable.Map()
- private val rl = ReservationListeners.create(resourceName(), spillers,
mutableStats.asJava)
- private val handle = RuntimeJniWrapper.createRuntime(
+ private val rl = ReservationListeners.create(name, spillers,
mutableStats.asJava)
+ private val handle = NativeMemoryManagerJniWrapper.create(
Backend.get().name(),
rl,
ConfigUtil.serialize(
@@ -62,40 +55,32 @@ object Runtime {
Backend.get().name(),
GlutenConfigUtil.parseConfig(SQLConf.get.getAllConfs)))
)
-
spillers.append(new Spiller() {
override def spill(self: MemoryTarget, phase: Spiller.Phase, size:
Long): Long = {
if (!Spillers.PHASE_SET_SHRINK_ONLY.contains(phase)) {
// Only respond for shrinking.
return 0L
}
- RuntimeJniWrapper.shrinkMemory(handle, size)
+ NativeMemoryManagerJniWrapper.shrink(handle, size)
}
})
mutableStats += "single" -> new MemoryUsageStatsBuilder {
- override def toStats: MemoryUsageStats = collectMemoryUsage()
+ override def toStats: MemoryUsageStats = collectUsage()
}
- private val released: AtomicBoolean = new AtomicBoolean(false)
-
- def getHandle(): Long = handle
-
- def addSpiller(spiller: Spiller): Unit = {
- spillers.append(spiller)
+ private def collectUsage() = {
+
MemoryUsageStats.parseFrom(NativeMemoryManagerJniWrapper.collectUsage(handle))
}
- def holdMemory(): Unit = {
- RuntimeJniWrapper.holdMemory(handle)
- }
-
- def collectMemoryUsage(): MemoryUsageStats = {
- MemoryUsageStats.parseFrom(RuntimeJniWrapper.collectMemoryUsage(handle))
- }
+ private val released: AtomicBoolean = new AtomicBoolean(false)
+ override def addSpiller(spiller: Spiller): Unit = spillers.append(spiller)
+ override def hold(): Unit = NativeMemoryManagerJniWrapper.hold(handle)
+ override def getHandle(): Long = handle
override def release(): Unit = {
if (!released.compareAndSet(false, true)) {
throw new GlutenException(
- s"Runtime instance already released: $handle, ${resourceName()},
${priority()}")
+ s"Memory manager instance already released: $handle,
${resourceName()}, ${priority()}")
}
def dump(): String = {
@@ -103,7 +88,7 @@ object Runtime {
s"[${resourceName()}]",
new KnownNameAndStats() {
override def name: String = resourceName()
- override def stats: MemoryUsageStats = collectMemoryUsage()
+ override def stats: MemoryUsageStats = collectUsage()
})
}
@@ -111,7 +96,7 @@ object Runtime {
LOGGER.debug("About to release memory manager, " + dump())
}
- RuntimeJniWrapper.releaseRuntime(handle)
+ NativeMemoryManagerJniWrapper.release(handle)
if (rl.getUsedBytes != 0) {
LOGGER.warn(
@@ -124,9 +109,11 @@ object Runtime {
))
}
}
-
override def priority(): Int = 0
+ override def resourceName(): String = "nmm"
+ }
- override def resourceName(): String = name
+ def apply(name: String): NativeMemoryManager = {
+ TaskResources.addAnonymousResource(new Impl(name))
}
}
diff --git
a/gluten-arrow/src/main/scala/org/apache/gluten/runtime/Runtime.scala
b/gluten-arrow/src/main/scala/org/apache/gluten/runtime/Runtime.scala
index 8053e886f2..efd4928cfd 100644
--- a/gluten-arrow/src/main/scala/org/apache/gluten/runtime/Runtime.scala
+++ b/gluten-arrow/src/main/scala/org/apache/gluten/runtime/Runtime.scala
@@ -19,13 +19,9 @@ package org.apache.gluten.runtime
import org.apache.gluten.GlutenConfig
import org.apache.gluten.backend.Backend
import org.apache.gluten.exception.GlutenException
-import org.apache.gluten.memory.MemoryUsageStatsBuilder
-import org.apache.gluten.memory.listener.ReservationListeners
-import org.apache.gluten.memory.memtarget.{KnownNameAndStats, MemoryTarget,
Spiller, Spillers}
-import org.apache.gluten.proto.MemoryUsageStats
+import org.apache.gluten.memory.NativeMemoryManager
import org.apache.gluten.utils.ConfigUtil
-import org.apache.spark.memory.SparkMemoryUtil
import org.apache.spark.sql.internal.{GlutenConfigUtil, SQLConf}
import org.apache.spark.task.TaskResource
@@ -33,13 +29,8 @@ import org.slf4j.LoggerFactory
import java.util.concurrent.atomic.AtomicBoolean
-import scala.collection.JavaConverters._
-import scala.collection.mutable
-
trait Runtime {
- def addSpiller(spiller: Spiller): Unit
- def holdMemory(): Unit
- def collectMemoryUsage(): MemoryUsageStats
+ def memoryManager(): NativeMemoryManager
def getHandle(): Long
}
@@ -51,82 +42,33 @@ object Runtime {
private class RuntimeImpl(name: String) extends Runtime with TaskResource {
private val LOGGER = LoggerFactory.getLogger(classOf[Runtime])
- private val spillers = Spillers.appendable()
- private val mutableStats: mutable.Map[String, MemoryUsageStatsBuilder] =
mutable.Map()
- private val rl = ReservationListeners.create(resourceName(), spillers,
mutableStats.asJava)
+ private val nmm: NativeMemoryManager = NativeMemoryManager(name)
private val handle = RuntimeJniWrapper.createRuntime(
Backend.get().name(),
- rl,
+ nmm.getHandle(),
ConfigUtil.serialize(
GlutenConfig.getNativeSessionConf(
Backend.get().name(),
GlutenConfigUtil.parseConfig(SQLConf.get.getAllConfs)))
)
- spillers.append(new Spiller() {
- override def spill(self: MemoryTarget, phase: Spiller.Phase, size:
Long): Long = {
- if (!Spillers.PHASE_SET_SHRINK_ONLY.contains(phase)) {
- // Only respond for shrinking.
- return 0L
- }
- RuntimeJniWrapper.shrinkMemory(handle, size)
- }
- })
- mutableStats += "single" -> new MemoryUsageStatsBuilder {
- override def toStats: MemoryUsageStats = collectMemoryUsage()
- }
-
private val released: AtomicBoolean = new AtomicBoolean(false)
- def getHandle(): Long = handle
+ override def getHandle(): Long = handle
- def addSpiller(spiller: Spiller): Unit = {
- spillers.append(spiller)
- }
-
- def holdMemory(): Unit = {
- RuntimeJniWrapper.holdMemory(handle)
- }
-
- def collectMemoryUsage(): MemoryUsageStats = {
- MemoryUsageStats.parseFrom(RuntimeJniWrapper.collectMemoryUsage(handle))
- }
+ override def memoryManager(): NativeMemoryManager = nmm
override def release(): Unit = {
if (!released.compareAndSet(false, true)) {
throw new GlutenException(
s"Runtime instance already released: $handle, ${resourceName()},
${priority()}")
}
-
- def dump(): String = {
- SparkMemoryUtil.prettyPrintStats(
- s"[${resourceName()}]",
- new KnownNameAndStats() {
- override def name: String = resourceName()
- override def stats: MemoryUsageStats = collectMemoryUsage()
- })
- }
-
- if (LOGGER.isDebugEnabled) {
- LOGGER.debug("About to release memory manager, " + dump())
- }
-
RuntimeJniWrapper.releaseRuntime(handle)
- if (rl.getUsedBytes != 0) {
- LOGGER.warn(
- String.format(
- "%s Reservation listener %s still reserved non-zero bytes, which
may cause memory" +
- " leak, size: %s.",
- name,
- rl.toString,
- SparkMemoryUtil.bytesToString(rl.getUsedBytes)
- ))
- }
}
- override def priority(): Int = 0
+ override def priority(): Int = 20
- override def resourceName(): String = name
+ override def resourceName(): String = s"runtime"
}
}
diff --git
a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
index 0f479a6095..1a0bc475d3 100644
---
a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
+++
b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
@@ -138,18 +138,20 @@ class VeloxCelebornColumnarShuffleWriter[K, V](
shuffleWriterType,
GlutenConfig.getConf.columnarShuffleReallocThreshold
)
- runtime.addSpiller(new Spiller() {
- override def spill(self: MemoryTarget, phase: Spiller.Phase, size: Long): Long = {
- if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
- return 0L
+ runtime
+ .memoryManager()
+ .addSpiller(new Spiller() {
+ override def spill(self: MemoryTarget, phase: Spiller.Phase, size: Long): Long = {
+ if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
+ return 0L
+ }
+ logInfo(s"Gluten shuffle writer: Trying to push $size bytes of data")
+ // fixme pass true when being called by self
+ val pushed = jniWrapper.nativeEvict(nativeShuffleWriter, size, false)
+ logInfo(s"Gluten shuffle writer: Pushed $pushed / $size bytes of data")
+ pushed
}
- logInfo(s"Gluten shuffle writer: Trying to push $size bytes of data")
- // fixme pass true when being called by self
- val pushed = jniWrapper.nativeEvict(nativeShuffleWriter, size, false)
- logInfo(s"Gluten shuffle writer: Pushed $pushed / $size bytes of data")
- pushed
- }
- })
+ })
}
override def closeShuffleWriter(): Unit = {
diff --git a/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java b/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java
index 8deeab1683..84fac8ace6 100644
--- a/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java
+++ b/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java
@@ -171,19 +171,21 @@ public class VeloxUniffleColumnarShuffleWriter<K, V> extends RssShuffleWriter<K,
? GlutenConfig.GLUTEN_SORT_SHUFFLE_WRITER()
: GlutenConfig.GLUTEN_HASH_SHUFFLE_WRITER(),
reallocThreshold);
- runtime.addSpiller(
- new Spiller() {
- @Override
- public long spill(MemoryTarget self, Spiller.Phase phase, long size) {
- if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
- return 0L;
- }
- LOG.info("Gluten shuffle writer: Trying to push {} bytes of data", size);
- long pushed = jniWrapper.nativeEvict(nativeShuffleWriter, size, false);
- LOG.info("Gluten shuffle writer: Pushed {} / {} bytes of data", pushed, size);
- return pushed;
- }
- });
+ runtime
+ .memoryManager()
+ .addSpiller(
+ new Spiller() {
+ @Override
+ public long spill(MemoryTarget self, Spiller.Phase phase, long size) {
+ if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
+ return 0L;
+ }
+ LOG.info("Gluten shuffle writer: Trying to push {} bytes of data", size);
+ long pushed = jniWrapper.nativeEvict(nativeShuffleWriter, size, false);
+ LOG.info("Gluten shuffle writer: Pushed {} / {} bytes of data", pushed, size);
+ return pushed;
+ }
+ });
}
long startTime = System.nanoTime();
long bytes =
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]