This is an automated email from the ASF dual-hosted git repository.
kejia pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new b88d83a4cc [GLUTEN-7548][VL] Follow up hash join optimization PR 8931 to resolve comments (#11728)
b88d83a4cc is described below
commit b88d83a4cc10e8c3392b2af8f1135609ee497178
Author: JiaKe <[email protected]>
AuthorDate: Tue Mar 10 14:03:51 2026 +0000
[GLUTEN-7548][VL] Follow up hash join optimization PR 8931 to resolve comments (#11728)
---
.../apache/gluten/vectorized/HashJoinBuilder.java | 2 +-
.../org/apache/spark/rpc/GlutenRpcMessages.scala | 16 ------
.../sql/execution/ColumnarBuildSideRelation.scala | 14 +++--
.../unsafe/UnsafeColumnarBuildSideRelation.scala | 14 +++--
cpp/velox/compute/VeloxBackend.cc | 1 -
cpp/velox/jni/JniHashTable.cc | 48 ++++++++---------
cpp/velox/jni/JniHashTable.h | 63 +++++++++++++++++++---
cpp/velox/jni/VeloxJniWrapper.cc | 28 ++++++----
.../gluten/extension/columnar/FallbackRules.scala | 6 +--
9 files changed, 111 insertions(+), 81 deletions(-)
diff --git a/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java b/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java
index e54909054c..ebfd47669c 100644
--- a/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java
+++ b/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java
@@ -42,7 +42,7 @@ public class HashJoinBuilder implements RuntimeAware {
public static native long nativeBuild(
String buildHashTableId,
long[] batchHandlers,
- String joinKeys,
+ String[] joinKeys,
int joinType,
boolean hasMixedFiltCondition,
boolean isExistenceJoin,
diff --git a/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala
index 8127c324b7..dec67eed78 100644
--- a/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala
@@ -34,20 +34,4 @@ object GlutenRpcMessages {
case class GlutenCleanExecutionResource(executionId: String,
broadcastHashIds: util.Set[String])
extends GlutenRpcMessage
-
- // for mergetree cache
- case class GlutenMergeTreeCacheLoad(
- mergeTreeTable: String,
- columns: util.Set[String],
- onlyMetaCache: Boolean)
- extends GlutenRpcMessage
-
- case class GlutenCacheLoadStatus(jobId: String)
-
- case class CacheJobInfo(status: Boolean, jobId: String, reason: String = "")
- extends GlutenRpcMessage
-
- case class GlutenFilesCacheLoad(files: Array[Byte]) extends GlutenRpcMessage
-
- case class GlutenFilesCacheLoadStatus(jobId: String)
}
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
index 6429f8bb3f..b106319e81 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
@@ -197,20 +197,18 @@ case class ColumnarBuildSideRelation(
)
}
- val joinKey = keys.asScala
- .map {
- key =>
- val attr = ConverterUtils.getAttrFromExpr(key)
- ConverterUtils.genColumnNameWithExprId(attr)
- }
- .mkString(",")
+ val joinKeys = keys.asScala.map {
+ key =>
+ val attr = ConverterUtils.getAttrFromExpr(key)
+ ConverterUtils.genColumnNameWithExprId(attr)
+ }.toArray
// Build the hash table
hashTableData = HashJoinBuilder
.nativeBuild(
broadcastContext.buildHashTableId,
batchArray.toArray,
- joinKey,
+ joinKeys,
broadcastContext.substraitJoinType.ordinal(),
broadcastContext.hasMixedFiltCondition,
broadcastContext.isExistenceJoin,
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
index fc7516c4b3..01fbb86bee 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
@@ -167,20 +167,18 @@ class UnsafeColumnarBuildSideRelation(
)
}
- val joinKey = keys.asScala
- .map {
- key =>
- val attr = ConverterUtils.getAttrFromExpr(key)
- ConverterUtils.genColumnNameWithExprId(attr)
- }
- .mkString(",")
+ val joinKeys = keys.asScala.map {
+ key =>
+ val attr = ConverterUtils.getAttrFromExpr(key)
+ ConverterUtils.genColumnNameWithExprId(attr)
+ }.toArray
// Build the hash table
hashTableData = HashJoinBuilder
.nativeBuild(
broadcastContext.buildHashTableId,
batchArray.toArray,
- joinKey,
+ joinKeys,
broadcastContext.substraitJoinType.ordinal(),
broadcastContext.hasMixedFiltCondition,
broadcastContext.isExistenceJoin,
diff --git a/cpp/velox/compute/VeloxBackend.cc b/cpp/velox/compute/VeloxBackend.cc
index 0232da48da..de9e9385f8 100644
--- a/cpp/velox/compute/VeloxBackend.cc
+++ b/cpp/velox/compute/VeloxBackend.cc
@@ -362,7 +362,6 @@ void VeloxBackend::tearDown() {
filesystem->close();
}
#endif
- gluten::hashTableObjStore.reset();
// Destruct IOThreadPoolExecutor will join all threads.
// On threads exit, thread local variables can be constructed with
referencing global variables.
diff --git a/cpp/velox/jni/JniHashTable.cc b/cpp/velox/jni/JniHashTable.cc
index 77cd78ff6a..8af60a5534 100644
--- a/cpp/velox/jni/JniHashTable.cc
+++ b/cpp/velox/jni/JniHashTable.cc
@@ -29,24 +29,34 @@
namespace gluten {
-static jclass jniVeloxBroadcastBuildSideCache = nullptr;
-static jmethodID jniGet = nullptr;
+void JniHashTableContext::initialize(JNIEnv* env, JavaVM* javaVm) {
+ vm_ = javaVm;
+ const char* classSig =
"Lorg/apache/gluten/execution/VeloxBroadcastBuildSideCache;";
+ jniVeloxBroadcastBuildSideCache_ = createGlobalClassReferenceOrError(env,
classSig);
+ jniGet_ = getStaticMethodId(env, jniVeloxBroadcastBuildSideCache_, "get",
"(Ljava/lang/String;)J");
+}
-jlong callJavaGet(const std::string& id) {
+void JniHashTableContext::finalize(JNIEnv* env) {
+ if (jniVeloxBroadcastBuildSideCache_ != nullptr) {
+ env->DeleteGlobalRef(jniVeloxBroadcastBuildSideCache_);
+ jniVeloxBroadcastBuildSideCache_ = nullptr;
+ }
+}
+
+jlong JniHashTableContext::callJavaGet(const std::string& id) const {
JNIEnv* env;
- if (vm->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) {
+ if (vm_->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) {
throw gluten::GlutenException("JNIEnv was not attached to current thread");
}
const jstring s = env->NewStringUTF(id.c_str());
-
- auto result = env->CallStaticLongMethod(jniVeloxBroadcastBuildSideCache,
jniGet, s);
+ auto result = env->CallStaticLongMethod(jniVeloxBroadcastBuildSideCache_,
jniGet_, s);
return result;
}
// Return the velox's hash table.
std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
- const std::string& joinKeys,
+ const std::vector<std::string>& joinKeys,
std::vector<std::string> names,
std::vector<facebook::velox::TypePtr> veloxTypeList,
int joinType,
@@ -98,12 +108,9 @@ std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
VELOX_NYI("Unsupported Join type: {}", std::to_string(sJoin));
}
- std::vector<std::string> joinKeyNames;
- folly::split(',', joinKeys, joinKeyNames);
-
std::vector<std::shared_ptr<const
facebook::velox::core::FieldAccessTypedExpr>> joinKeyTypes;
- joinKeyTypes.reserve(joinKeyNames.size());
- for (const auto& name : joinKeyNames) {
+ joinKeyTypes.reserve(joinKeys.size());
+ for (const auto& name : joinKeys) {
joinKeyTypes.emplace_back(
std::make_shared<facebook::velox::core::FieldAccessTypedExpr>(rowType->findChild(name),
name));
}
@@ -125,21 +132,8 @@ std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
return hashTableBuilder;
}
-long getJoin(std::string hashTableId) {
- return callJavaGet(hashTableId);
-}
-
-void initVeloxJniHashTable(JNIEnv* env) {
- if (env->GetJavaVM(&vm) != JNI_OK) {
- throw gluten::GlutenException("Unable to get JavaVM instance");
- }
- const char* classSig =
"Lorg/apache/gluten/execution/VeloxBroadcastBuildSideCache;";
- jniVeloxBroadcastBuildSideCache = createGlobalClassReferenceOrError(env,
classSig);
- jniGet = getStaticMethodId(env, jniVeloxBroadcastBuildSideCache, "get",
"(Ljava/lang/String;)J");
-}
-
-void finalizeVeloxJniHashTable(JNIEnv* env) {
- env->DeleteGlobalRef(jniVeloxBroadcastBuildSideCache);
+long getJoin(const std::string& hashTableId) {
+ return JniHashTableContext::getInstance().callJavaGet(hashTableId);
}
} // namespace gluten
diff --git a/cpp/velox/jni/JniHashTable.h b/cpp/velox/jni/JniHashTable.h
index c0d9227840..27061e1778 100644
--- a/cpp/velox/jni/JniHashTable.h
+++ b/cpp/velox/jni/JniHashTable.h
@@ -26,13 +26,53 @@
namespace gluten {
-inline static JavaVM* vm = nullptr;
+// Wrapper class to encapsulate JNI-related static objects for hash table
operations.
+// This avoids exposing global variables in the gluten namespace.
+class JniHashTableContext {
+ public:
+ static JniHashTableContext& getInstance() {
+ static JniHashTableContext instance;
+ return instance;
+ }
-inline static std::unique_ptr<ObjectStore> hashTableObjStore =
ObjectStore::create();
+ // Delete copy and move constructors/operators
+ JniHashTableContext(const JniHashTableContext&) = delete;
+ JniHashTableContext& operator=(const JniHashTableContext&) = delete;
+ JniHashTableContext(JniHashTableContext&&) = delete;
+ JniHashTableContext& operator=(JniHashTableContext&&) = delete;
+
+ void initialize(JNIEnv* env, JavaVM* javaVm);
+ void finalize(JNIEnv* env);
+
+ JavaVM* getJavaVM() const {
+ return vm_;
+ }
+
+ ObjectStore* getHashTableObjStore() const {
+ return hashTableObjStore_.get();
+ }
+
+ jlong callJavaGet(const std::string& id) const;
+
+ private:
+ JniHashTableContext() : hashTableObjStore_(ObjectStore::create()) {}
+
+ ~JniHashTableContext() {
+ // Note: The destructor is called at program exit (after main() returns).
+ // By this time, JNI_OnUnload should have already been called, which
invokes
+ // finalize() to clean up JNI global references while the JVM is still
valid.
+ // The singleton itself (including hashTableObjStore_) will be destroyed
here.
+ }
+
+ JavaVM* vm_{nullptr};
+ std::unique_ptr<ObjectStore> hashTableObjStore_;
+ jclass jniVeloxBroadcastBuildSideCache_{nullptr};
+ jmethodID jniGet_{nullptr};
+};
// Return the hash table builder address.
std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
- const std::string& joinKeys,
+ const std::vector<std::string>& joinKeys,
std::vector<std::string> names,
std::vector<facebook::velox::TypePtr> veloxTypeList,
int joinType,
@@ -43,12 +83,21 @@ std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
std::vector<std::shared_ptr<ColumnarBatch>>& batches,
std::shared_ptr<facebook::velox::memory::MemoryPool> memoryPool);
-long getJoin(std::string hashTableId);
+long getJoin(const std::string& hashTableId);
-void initVeloxJniHashTable(JNIEnv* env);
+// Initialize the JNI hash table context
+inline void initVeloxJniHashTable(JNIEnv* env, JavaVM* javaVm) {
+ JniHashTableContext::getInstance().initialize(env, javaVm);
+}
-void finalizeVeloxJniHashTable(JNIEnv* env);
+// Finalize the JNI hash table context
+inline void finalizeVeloxJniHashTable(JNIEnv* env) {
+ JniHashTableContext::getInstance().finalize(env);
+}
-jlong callJavaGet(const std::string& id);
+// Get hash table object store
+inline ObjectStore* getHashTableObjStore() {
+ return JniHashTableContext::getInstance().getHashTableObjStore();
+}
} // namespace gluten
diff --git a/cpp/velox/jni/VeloxJniWrapper.cc b/cpp/velox/jni/VeloxJniWrapper.cc
index e488274e97..ed1cd5e85d 100644
--- a/cpp/velox/jni/VeloxJniWrapper.cc
+++ b/cpp/velox/jni/VeloxJniWrapper.cc
@@ -80,7 +80,7 @@ jint JNI_OnLoad(JavaVM* vm, void*) {
getJniErrorState()->ensureInitialized(env);
initVeloxJniFileSystem(env);
initVeloxJniUDF(env);
- initVeloxJniHashTable(env);
+ initVeloxJniHashTable(env, vm);
infoCls = createGlobalClassReferenceOrError(env,
"Lorg/apache/gluten/validate/NativePlanValidationInfo;");
infoClsInitMethod = getMethodIdOrError(env, infoCls, "<init>",
"(ILjava/lang/String;)V");
@@ -94,8 +94,6 @@ jint JNI_OnLoad(JavaVM* vm, void*) {
DLOG(INFO) << "Loaded Velox backend.";
- gluten::vm = vm;
-
return jniVersion;
}
@@ -108,6 +106,7 @@ void JNI_OnUnload(JavaVM* vm, void*) {
finalizeVeloxJniUDF(env);
finalizeVeloxJniFileSystem(env);
+ finalizeVeloxJniHashTable(env);
getJniErrorState()->close();
getJniCommonState()->close();
google::ShutdownGoogleLogging();
@@ -939,7 +938,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
jclass,
jstring tableId,
jlongArray batchHandles,
- jstring joinKey,
+ jobjectArray joinKeys,
jint joinType,
jboolean hasMixedJoinCondition,
jboolean isExistenceJoin,
@@ -949,7 +948,16 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
jint broadcastHashTableBuildThreads) {
JNI_METHOD_START
const auto hashTableId = jStringToCString(env, tableId);
- const auto hashJoinKey = jStringToCString(env, joinKey);
+
+ // Convert Java String array to C++ vector<string>
+ std::vector<std::string> hashJoinKeys;
+ jsize joinKeysCount = env->GetArrayLength(joinKeys);
+ hashJoinKeys.reserve(joinKeysCount);
+ for (jsize i = 0; i < joinKeysCount; ++i) {
+ jstring jkey = (jstring)env->GetObjectArrayElement(joinKeys, i);
+ hashJoinKeys.emplace_back(jStringToCString(env, jkey));
+ }
+
const auto inputType = gluten::getByteArrayElementsSafe(env, namedStruct);
std::string structString{
reinterpret_cast<const char*>(inputType.elems()),
static_cast<std::string::size_type>(inputType.length())};
@@ -988,7 +996,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
if (numThreads <= 1) {
auto builder = nativeHashTableBuild(
- hashJoinKey,
+ hashJoinKeys,
names,
veloxTypeList,
joinType,
@@ -1008,7 +1016,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
nullptr);
builder->setHashTable(std::move(mainTable));
- return gluten::hashTableObjStore->save(builder);
+ return gluten::getHashTableObjStore()->save(builder);
}
std::vector<std::thread> threads;
@@ -1027,7 +1035,7 @@ JNIEXPORT jlong JNICALL
Java_org_apache_gluten_vectorized_HashJoinBuilder_native
}
auto builder = nativeHashTableBuild(
- hashJoinKey,
+ hashJoinKeys,
names,
veloxTypeList,
joinType,
@@ -1073,7 +1081,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
}
hashTableBuilders[0]->setHashTable(std::move(mainTable));
- return gluten::hashTableObjStore->save(hashTableBuilders[0]);
+ return gluten::getHashTableObjStore()->save(hashTableBuilders[0]);
JNI_METHOD_END(kInvalidObjectHandle)
}
@@ -1083,7 +1091,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_cloneH
jlong tableHandler) {
JNI_METHOD_START
auto hashTableHandler =
ObjectStore::retrieve<gluten::HashTableBuilder>(tableHandler);
- return gluten::hashTableObjStore->save(hashTableHandler);
+ return gluten::getHashTableObjStore()->save(hashTableHandler);
JNI_METHOD_END(kInvalidObjectHandle)
}
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
index 5e6c777922..76d8a50ccd 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
@@ -44,14 +44,14 @@ case class FallbackMultiCodegens(session: SparkSession) extends Rule[SparkPlan]
plan match {
case plan: CodegenSupport if plan.supportCodegen =>
if (
- (count + 1) >= optimizeLevel &&
plan.output.map(_.dataType.defaultSize).sum == outputSize
+ (count + 1) >= optimizeLevel &&
plan.output.map(_.dataType.defaultSize).sum >= outputSize
) {
return true
}
plan.children.exists(existsMultiCodegens(_, count + 1))
case plan: ShuffledHashJoinExec =>
if (
- (count + 1) >= optimizeLevel &&
plan.output.map(_.dataType.defaultSize).sum == outputSize
+ (count + 1) >= optimizeLevel &&
plan.output.map(_.dataType.defaultSize).sum >= outputSize
) {
return true
}
@@ -59,7 +59,7 @@ case class FallbackMultiCodegens(session: SparkSession) extends Rule[SparkPlan]
plan.children.exists(existsMultiCodegens(_, count + 1))
case plan: SortMergeJoinExec if GlutenConfig.get.forceShuffledHashJoin =>
if (
- (count + 1) >= optimizeLevel &&
plan.output.map(_.dataType.defaultSize).sum == outputSize
+ (count + 1) >= optimizeLevel &&
plan.output.map(_.dataType.defaultSize).sum >= outputSize
) {
return true
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]