This is an automated email from the ASF dual-hosted git repository.

kejia pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new b88d83a4cc [GLUTEN-7548][VL] Follow up hash join optimization PR 8931 to resolve comments (#11728)
b88d83a4cc is described below

commit b88d83a4cc10e8c3392b2af8f1135609ee497178
Author: JiaKe <[email protected]>
AuthorDate: Tue Mar 10 14:03:51 2026 +0000

    [GLUTEN-7548][VL] Follow up hash join optimization PR 8931 to resolve comments (#11728)
---
 .../apache/gluten/vectorized/HashJoinBuilder.java  |  2 +-
 .../org/apache/spark/rpc/GlutenRpcMessages.scala   | 16 ------
 .../sql/execution/ColumnarBuildSideRelation.scala  | 14 +++--
 .../unsafe/UnsafeColumnarBuildSideRelation.scala   | 14 +++--
 cpp/velox/compute/VeloxBackend.cc                  |  1 -
 cpp/velox/jni/JniHashTable.cc                      | 48 ++++++++---------
 cpp/velox/jni/JniHashTable.h                       | 63 +++++++++++++++++++---
 cpp/velox/jni/VeloxJniWrapper.cc                   | 28 ++++++----
 .../gluten/extension/columnar/FallbackRules.scala  |  6 +--
 9 files changed, 111 insertions(+), 81 deletions(-)

diff --git a/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java b/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java
index e54909054c..ebfd47669c 100644
--- a/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java
+++ b/backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java
@@ -42,7 +42,7 @@ public class HashJoinBuilder implements RuntimeAware {
   public static native long nativeBuild(
       String buildHashTableId,
       long[] batchHandlers,
-      String joinKeys,
+      String[] joinKeys,
       int joinType,
       boolean hasMixedFiltCondition,
       boolean isExistenceJoin,
diff --git a/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala
index 8127c324b7..dec67eed78 100644
--- a/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala
@@ -34,20 +34,4 @@ object GlutenRpcMessages {
 
   case class GlutenCleanExecutionResource(executionId: String, 
broadcastHashIds: util.Set[String])
     extends GlutenRpcMessage
-
-  // for mergetree cache
-  case class GlutenMergeTreeCacheLoad(
-      mergeTreeTable: String,
-      columns: util.Set[String],
-      onlyMetaCache: Boolean)
-    extends GlutenRpcMessage
-
-  case class GlutenCacheLoadStatus(jobId: String)
-
-  case class CacheJobInfo(status: Boolean, jobId: String, reason: String = "")
-    extends GlutenRpcMessage
-
-  case class GlutenFilesCacheLoad(files: Array[Byte]) extends GlutenRpcMessage
-
-  case class GlutenFilesCacheLoadStatus(jobId: String)
 }
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
index 6429f8bb3f..b106319e81 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
@@ -197,20 +197,18 @@ case class ColumnarBuildSideRelation(
           )
         }
 
-        val joinKey = keys.asScala
-          .map {
-            key =>
-              val attr = ConverterUtils.getAttrFromExpr(key)
-              ConverterUtils.genColumnNameWithExprId(attr)
-          }
-          .mkString(",")
+        val joinKeys = keys.asScala.map {
+          key =>
+            val attr = ConverterUtils.getAttrFromExpr(key)
+            ConverterUtils.genColumnNameWithExprId(attr)
+        }.toArray
 
         // Build the hash table
         hashTableData = HashJoinBuilder
           .nativeBuild(
             broadcastContext.buildHashTableId,
             batchArray.toArray,
-            joinKey,
+            joinKeys,
             broadcastContext.substraitJoinType.ordinal(),
             broadcastContext.hasMixedFiltCondition,
             broadcastContext.isExistenceJoin,
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
index fc7516c4b3..01fbb86bee 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
@@ -167,20 +167,18 @@ class UnsafeColumnarBuildSideRelation(
           )
         }
 
-        val joinKey = keys.asScala
-          .map {
-            key =>
-              val attr = ConverterUtils.getAttrFromExpr(key)
-              ConverterUtils.genColumnNameWithExprId(attr)
-          }
-          .mkString(",")
+        val joinKeys = keys.asScala.map {
+          key =>
+            val attr = ConverterUtils.getAttrFromExpr(key)
+            ConverterUtils.genColumnNameWithExprId(attr)
+        }.toArray
 
         // Build the hash table
         hashTableData = HashJoinBuilder
           .nativeBuild(
             broadcastContext.buildHashTableId,
             batchArray.toArray,
-            joinKey,
+            joinKeys,
             broadcastContext.substraitJoinType.ordinal(),
             broadcastContext.hasMixedFiltCondition,
             broadcastContext.isExistenceJoin,
diff --git a/cpp/velox/compute/VeloxBackend.cc b/cpp/velox/compute/VeloxBackend.cc
index 0232da48da..de9e9385f8 100644
--- a/cpp/velox/compute/VeloxBackend.cc
+++ b/cpp/velox/compute/VeloxBackend.cc
@@ -362,7 +362,6 @@ void VeloxBackend::tearDown() {
     filesystem->close();
   }
 #endif
-  gluten::hashTableObjStore.reset();
 
   // Destruct IOThreadPoolExecutor will join all threads.
   // On threads exit, thread local variables can be constructed with 
referencing global variables.
diff --git a/cpp/velox/jni/JniHashTable.cc b/cpp/velox/jni/JniHashTable.cc
index 77cd78ff6a..8af60a5534 100644
--- a/cpp/velox/jni/JniHashTable.cc
+++ b/cpp/velox/jni/JniHashTable.cc
@@ -29,24 +29,34 @@
 
 namespace gluten {
 
-static jclass jniVeloxBroadcastBuildSideCache = nullptr;
-static jmethodID jniGet = nullptr;
+void JniHashTableContext::initialize(JNIEnv* env, JavaVM* javaVm) {
+  vm_ = javaVm;
+  const char* classSig = 
"Lorg/apache/gluten/execution/VeloxBroadcastBuildSideCache;";
+  jniVeloxBroadcastBuildSideCache_ = createGlobalClassReferenceOrError(env, 
classSig);
+  jniGet_ = getStaticMethodId(env, jniVeloxBroadcastBuildSideCache_, "get", 
"(Ljava/lang/String;)J");
+}
 
-jlong callJavaGet(const std::string& id) {
+void JniHashTableContext::finalize(JNIEnv* env) {
+  if (jniVeloxBroadcastBuildSideCache_ != nullptr) {
+    env->DeleteGlobalRef(jniVeloxBroadcastBuildSideCache_);
+    jniVeloxBroadcastBuildSideCache_ = nullptr;
+  }
+}
+
+jlong JniHashTableContext::callJavaGet(const std::string& id) const {
   JNIEnv* env;
-  if (vm->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) {
+  if (vm_->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) {
     throw gluten::GlutenException("JNIEnv was not attached to current thread");
   }
 
   const jstring s = env->NewStringUTF(id.c_str());
-
-  auto result = env->CallStaticLongMethod(jniVeloxBroadcastBuildSideCache, 
jniGet, s);
+  auto result = env->CallStaticLongMethod(jniVeloxBroadcastBuildSideCache_, 
jniGet_, s);
   return result;
 }
 
 // Return the velox's hash table.
 std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
-    const std::string& joinKeys,
+    const std::vector<std::string>& joinKeys,
     std::vector<std::string> names,
     std::vector<facebook::velox::TypePtr> veloxTypeList,
     int joinType,
@@ -98,12 +108,9 @@ std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
       VELOX_NYI("Unsupported Join type: {}", std::to_string(sJoin));
   }
 
-  std::vector<std::string> joinKeyNames;
-  folly::split(',', joinKeys, joinKeyNames);
-
   std::vector<std::shared_ptr<const 
facebook::velox::core::FieldAccessTypedExpr>> joinKeyTypes;
-  joinKeyTypes.reserve(joinKeyNames.size());
-  for (const auto& name : joinKeyNames) {
+  joinKeyTypes.reserve(joinKeys.size());
+  for (const auto& name : joinKeys) {
     joinKeyTypes.emplace_back(
         
std::make_shared<facebook::velox::core::FieldAccessTypedExpr>(rowType->findChild(name),
 name));
   }
@@ -125,21 +132,8 @@ std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
   return hashTableBuilder;
 }
 
-long getJoin(std::string hashTableId) {
-  return callJavaGet(hashTableId);
-}
-
-void initVeloxJniHashTable(JNIEnv* env) {
-  if (env->GetJavaVM(&vm) != JNI_OK) {
-    throw gluten::GlutenException("Unable to get JavaVM instance");
-  }
-  const char* classSig = 
"Lorg/apache/gluten/execution/VeloxBroadcastBuildSideCache;";
-  jniVeloxBroadcastBuildSideCache = createGlobalClassReferenceOrError(env, 
classSig);
-  jniGet = getStaticMethodId(env, jniVeloxBroadcastBuildSideCache, "get", 
"(Ljava/lang/String;)J");
-}
-
-void finalizeVeloxJniHashTable(JNIEnv* env) {
-  env->DeleteGlobalRef(jniVeloxBroadcastBuildSideCache);
+long getJoin(const std::string& hashTableId) {
+  return JniHashTableContext::getInstance().callJavaGet(hashTableId);
 }
 
 } // namespace gluten
diff --git a/cpp/velox/jni/JniHashTable.h b/cpp/velox/jni/JniHashTable.h
index c0d9227840..27061e1778 100644
--- a/cpp/velox/jni/JniHashTable.h
+++ b/cpp/velox/jni/JniHashTable.h
@@ -26,13 +26,53 @@
 
 namespace gluten {
 
-inline static JavaVM* vm = nullptr;
+// Wrapper class to encapsulate JNI-related static objects for hash table 
operations.
+// This avoids exposing global variables in the gluten namespace.
+class JniHashTableContext {
+ public:
+  static JniHashTableContext& getInstance() {
+    static JniHashTableContext instance;
+    return instance;
+  }
 
-inline static std::unique_ptr<ObjectStore> hashTableObjStore = 
ObjectStore::create();
+  // Delete copy and move constructors/operators
+  JniHashTableContext(const JniHashTableContext&) = delete;
+  JniHashTableContext& operator=(const JniHashTableContext&) = delete;
+  JniHashTableContext(JniHashTableContext&&) = delete;
+  JniHashTableContext& operator=(JniHashTableContext&&) = delete;
+
+  void initialize(JNIEnv* env, JavaVM* javaVm);
+  void finalize(JNIEnv* env);
+
+  JavaVM* getJavaVM() const {
+    return vm_;
+  }
+
+  ObjectStore* getHashTableObjStore() const {
+    return hashTableObjStore_.get();
+  }
+
+  jlong callJavaGet(const std::string& id) const;
+
+ private:
+  JniHashTableContext() : hashTableObjStore_(ObjectStore::create()) {}
+  
+  ~JniHashTableContext() {
+    // Note: The destructor is called at program exit (after main() returns).
+    // By this time, JNI_OnUnload should have already been called, which 
invokes
+    // finalize() to clean up JNI global references while the JVM is still 
valid.
+    // The singleton itself (including hashTableObjStore_) will be destroyed 
here.
+  }
+
+  JavaVM* vm_{nullptr};
+  std::unique_ptr<ObjectStore> hashTableObjStore_;
+  jclass jniVeloxBroadcastBuildSideCache_{nullptr};
+  jmethodID jniGet_{nullptr};
+};
 
 // Return the hash table builder address.
 std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
-    const std::string& joinKeys,
+    const std::vector<std::string>& joinKeys,
     std::vector<std::string> names,
     std::vector<facebook::velox::TypePtr> veloxTypeList,
     int joinType,
@@ -43,12 +83,21 @@ std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
     std::vector<std::shared_ptr<ColumnarBatch>>& batches,
     std::shared_ptr<facebook::velox::memory::MemoryPool> memoryPool);
 
-long getJoin(std::string hashTableId);
+long getJoin(const std::string& hashTableId);
 
-void initVeloxJniHashTable(JNIEnv* env);
+// Initialize the JNI hash table context
+inline void initVeloxJniHashTable(JNIEnv* env, JavaVM* javaVm) {
+  JniHashTableContext::getInstance().initialize(env, javaVm);
+}
 
-void finalizeVeloxJniHashTable(JNIEnv* env);
+// Finalize the JNI hash table context
+inline void finalizeVeloxJniHashTable(JNIEnv* env) {
+  JniHashTableContext::getInstance().finalize(env);
+}
 
-jlong callJavaGet(const std::string& id);
+// Get hash table object store
+inline ObjectStore* getHashTableObjStore() {
+  return JniHashTableContext::getInstance().getHashTableObjStore();
+}
 
 } // namespace gluten
diff --git a/cpp/velox/jni/VeloxJniWrapper.cc b/cpp/velox/jni/VeloxJniWrapper.cc
index e488274e97..ed1cd5e85d 100644
--- a/cpp/velox/jni/VeloxJniWrapper.cc
+++ b/cpp/velox/jni/VeloxJniWrapper.cc
@@ -80,7 +80,7 @@ jint JNI_OnLoad(JavaVM* vm, void*) {
   getJniErrorState()->ensureInitialized(env);
   initVeloxJniFileSystem(env);
   initVeloxJniUDF(env);
-  initVeloxJniHashTable(env);
+  initVeloxJniHashTable(env, vm);
 
   infoCls = createGlobalClassReferenceOrError(env, 
"Lorg/apache/gluten/validate/NativePlanValidationInfo;");
   infoClsInitMethod = getMethodIdOrError(env, infoCls, "<init>", 
"(ILjava/lang/String;)V");
@@ -94,8 +94,6 @@ jint JNI_OnLoad(JavaVM* vm, void*) {
 
   DLOG(INFO) << "Loaded Velox backend.";
 
-  gluten::vm = vm;
-
   return jniVersion;
 }
 
@@ -108,6 +106,7 @@ void JNI_OnUnload(JavaVM* vm, void*) {
 
   finalizeVeloxJniUDF(env);
   finalizeVeloxJniFileSystem(env);
+  finalizeVeloxJniHashTable(env);
   getJniErrorState()->close();
   getJniCommonState()->close();
   google::ShutdownGoogleLogging();
@@ -939,7 +938,7 @@ JNIEXPORT jlong JNICALL 
Java_org_apache_gluten_vectorized_HashJoinBuilder_native
     jclass,
     jstring tableId,
     jlongArray batchHandles,
-    jstring joinKey,
+    jobjectArray joinKeys,
     jint joinType,
     jboolean hasMixedJoinCondition,
     jboolean isExistenceJoin,
@@ -949,7 +948,16 @@ JNIEXPORT jlong JNICALL 
Java_org_apache_gluten_vectorized_HashJoinBuilder_native
     jint broadcastHashTableBuildThreads) {
   JNI_METHOD_START
   const auto hashTableId = jStringToCString(env, tableId);
-  const auto hashJoinKey = jStringToCString(env, joinKey);
+
+  // Convert Java String array to C++ vector<string>
+  std::vector<std::string> hashJoinKeys;
+  jsize joinKeysCount = env->GetArrayLength(joinKeys);
+  hashJoinKeys.reserve(joinKeysCount);
+  for (jsize i = 0; i < joinKeysCount; ++i) {
+    jstring jkey = (jstring)env->GetObjectArrayElement(joinKeys, i);
+    hashJoinKeys.emplace_back(jStringToCString(env, jkey));
+  }
+
   const auto inputType = gluten::getByteArrayElementsSafe(env, namedStruct);
   std::string structString{
       reinterpret_cast<const char*>(inputType.elems()), 
static_cast<std::string::size_type>(inputType.length())};
@@ -988,7 +996,7 @@ JNIEXPORT jlong JNICALL 
Java_org_apache_gluten_vectorized_HashJoinBuilder_native
 
   if (numThreads <= 1) {
     auto builder = nativeHashTableBuild(
-        hashJoinKey,
+        hashJoinKeys,
         names,
         veloxTypeList,
         joinType,
@@ -1008,7 +1016,7 @@ JNIEXPORT jlong JNICALL 
Java_org_apache_gluten_vectorized_HashJoinBuilder_native
         nullptr);
     builder->setHashTable(std::move(mainTable));
 
-    return gluten::hashTableObjStore->save(builder);
+    return gluten::getHashTableObjStore()->save(builder);
   }
 
   std::vector<std::thread> threads;
@@ -1027,7 +1035,7 @@ JNIEXPORT jlong JNICALL 
Java_org_apache_gluten_vectorized_HashJoinBuilder_native
       }
 
       auto builder = nativeHashTableBuild(
-          hashJoinKey,
+          hashJoinKeys,
           names,
           veloxTypeList,
           joinType,
@@ -1073,7 +1081,7 @@ JNIEXPORT jlong JNICALL 
Java_org_apache_gluten_vectorized_HashJoinBuilder_native
   }
 
   hashTableBuilders[0]->setHashTable(std::move(mainTable));
-  return gluten::hashTableObjStore->save(hashTableBuilders[0]);
+  return gluten::getHashTableObjStore()->save(hashTableBuilders[0]);
   JNI_METHOD_END(kInvalidObjectHandle)
 }
 
@@ -1083,7 +1091,7 @@ JNIEXPORT jlong JNICALL 
Java_org_apache_gluten_vectorized_HashJoinBuilder_cloneH
     jlong tableHandler) {
   JNI_METHOD_START
   auto hashTableHandler = 
ObjectStore::retrieve<gluten::HashTableBuilder>(tableHandler);
-  return gluten::hashTableObjStore->save(hashTableHandler);
+  return gluten::getHashTableObjStore()->save(hashTableHandler);
   JNI_METHOD_END(kInvalidObjectHandle)
 }
 
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
index 5e6c777922..76d8a50ccd 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
@@ -44,14 +44,14 @@ case class FallbackMultiCodegens(session: SparkSession) 
extends Rule[SparkPlan]
     plan match {
       case plan: CodegenSupport if plan.supportCodegen =>
         if (
-          (count + 1) >= optimizeLevel && 
plan.output.map(_.dataType.defaultSize).sum == outputSize
+          (count + 1) >= optimizeLevel && 
plan.output.map(_.dataType.defaultSize).sum >= outputSize
         ) {
           return true
         }
         plan.children.exists(existsMultiCodegens(_, count + 1))
       case plan: ShuffledHashJoinExec =>
         if (
-          (count + 1) >= optimizeLevel && 
plan.output.map(_.dataType.defaultSize).sum == outputSize
+          (count + 1) >= optimizeLevel && 
plan.output.map(_.dataType.defaultSize).sum >= outputSize
         ) {
           return true
         }
@@ -59,7 +59,7 @@ case class FallbackMultiCodegens(session: SparkSession) 
extends Rule[SparkPlan]
         plan.children.exists(existsMultiCodegens(_, count + 1))
       case plan: SortMergeJoinExec if GlutenConfig.get.forceShuffledHashJoin =>
         if (
-          (count + 1) >= optimizeLevel && 
plan.output.map(_.dataType.defaultSize).sum == outputSize
+          (count + 1) >= optimizeLevel && 
plan.output.map(_.dataType.defaultSize).sum >= outputSize
         ) {
           return true
         }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to