JkSelf commented on code in PR #8931:
URL: https://github.com/apache/incubator-gluten/pull/8931#discussion_r2904457634


##########
cpp/velox/jni/VeloxJniWrapper.cc:
##########
@@ -914,18 +922,181 @@ JNIEXPORT jobject JNICALL 
Java_org_apache_gluten_execution_IcebergWriteJniWrappe
   auto writer = ObjectStore::retrieve<IcebergWriter>(writerHandle);
   auto writeStats = writer->writeStats();
   jobject writeMetrics = env->NewObject(
-    batchWriteMetricsClass,
-    batchWriteMetricsConstructor,
-    writeStats.numWrittenBytes,
-    writeStats.numWrittenFiles,
-    writeStats.writeIOTimeNs,
-    writeStats.writeWallNs);
+      batchWriteMetricsClass,
+      batchWriteMetricsConstructor,
+      writeStats.numWrittenBytes,
+      writeStats.numWrittenFiles,
+      writeStats.writeIOTimeNs,
+      writeStats.writeWallNs);
   return writeMetrics;
 
   JNI_METHOD_END(nullptr)
 }
 #endif
 
+JNIEXPORT jlong JNICALL 
Java_org_apache_gluten_vectorized_HashJoinBuilder_nativeBuild( // NOLINT
+    JNIEnv* env,
+    jclass,
+    jstring tableId,
+    jlongArray batchHandles,
+    jstring joinKey,
+    jint joinType,
+    jboolean hasMixedJoinCondition,
+    jboolean isExistenceJoin,
+    jbyteArray namedStruct,
+    jboolean isNullAwareAntiJoin,
+    jlong bloomFilterPushdownSize,
+    jint broadcastHashTableBuildThreads) {
+  JNI_METHOD_START
+  const auto hashTableId = jStringToCString(env, tableId);
+  const auto hashJoinKey = jStringToCString(env, joinKey);
+  const auto inputType = gluten::getByteArrayElementsSafe(env, namedStruct);
+  std::string structString{
+      reinterpret_cast<const char*>(inputType.elems()), 
static_cast<std::string::size_type>(inputType.length())};
+
+  substrait::NamedStruct substraitStruct;
+  substraitStruct.ParseFromString(structString);
+
+  std::vector<facebook::velox::TypePtr> veloxTypeList;
+  veloxTypeList = SubstraitParser::parseNamedStruct(substraitStruct);
+
+  const auto& substraitNames = substraitStruct.names();
+
+  std::vector<std::string> names;
+  names.reserve(substraitNames.size());
+  for (const auto& name : substraitNames) {
+    names.emplace_back(name);
+  }
+
+  std::vector<std::shared_ptr<ColumnarBatch>> cb;
+  int handleCount = env->GetArrayLength(batchHandles);
+  auto safeArray = getLongArrayElementsSafe(env, batchHandles);
+  for (int i = 0; i < handleCount; ++i) {
+    int64_t handle = safeArray.elems()[i];
+    cb.push_back(ObjectStore::retrieve<ColumnarBatch>(handle));
+  }
+
+  size_t maxThreads = broadcastHashTableBuildThreads > 0
+      ? std::min((size_t)broadcastHashTableBuildThreads, (size_t)32)
+      : std::min((size_t)std::thread::hardware_concurrency(), (size_t)32);
+
+  // Heuristic: Each thread should process at least a certain number of 
batches to justify parallelism overhead.
+  // 32 batches is roughly 128k rows, which is a reasonable granularity for a 
single thread.
+  constexpr size_t kMinBatchesPerThread = 32;
+  size_t numThreads = std::min(maxThreads, (handleCount + kMinBatchesPerThread 
- 1) / kMinBatchesPerThread);
+  numThreads = std::max((size_t)1, numThreads);
+
+  if (numThreads <= 1) {
+    auto builder = nativeHashTableBuild(
+        hashJoinKey,
+        names,
+        veloxTypeList,
+        joinType,
+        hasMixedJoinCondition,
+        isExistenceJoin,
+        isNullAwareAntiJoin,
+        bloomFilterPushdownSize,
+        cb,
+        defaultLeafVeloxMemoryPool());
+
+    auto mainTable = builder->uniqueTable();
+    mainTable->prepareJoinTable(
+        {},
+        facebook::velox::exec::BaseHashTable::kNoSpillInputStartPartitionBit,
+        1'000'000,
+        builder->dropDuplicates(),
+        nullptr);
+    builder->setHashTable(std::move(mainTable));
+
+    return gluten::hashTableObjStore->save(builder);
+  }
+
+  std::vector<std::thread> threads;
+
+  std::vector<std::shared_ptr<gluten::HashTableBuilder>> 
hashTableBuilders(numThreads);
+  std::vector<std::unique_ptr<facebook::velox::exec::BaseHashTable>> 
otherTables(numThreads);
+
+  for (size_t t = 0; t < numThreads; ++t) {
+    size_t start = (handleCount * t) / numThreads;
+    size_t end = (handleCount * (t + 1)) / numThreads;
+
+    threads.emplace_back([&, t, start, end]() {
+      std::vector<std::shared_ptr<gluten::ColumnarBatch>> threadBatches;
+      for (size_t i = start; i < end; ++i) {
+        threadBatches.push_back(cb[i]);
+      }
+
+      auto builder = nativeHashTableBuild(
+          hashJoinKey,
+          names,
+          veloxTypeList,
+          joinType,
+          hasMixedJoinCondition,
+          isExistenceJoin,
+          isNullAwareAntiJoin,
+          bloomFilterPushdownSize,
+          threadBatches,
+          defaultLeafVeloxMemoryPool());

Review Comment:
   Are you suggesting we track this memory usage and implement spilling logic 
when it exceeds the limit?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to