save-buffer commented on code in PR #12289:
URL: https://github.com/apache/arrow/pull/12289#discussion_r859065009


##########
cpp/src/arrow/compute/exec/bloom_filter_test.cc:
##########
@@ -32,39 +33,107 @@
 namespace arrow {
 namespace compute {
 
-Status BuildBloomFilter(BloomFilterBuildStrategy strategy, int64_t 
hardware_flags,
-                        MemoryPool* pool, int64_t num_rows,
-                        std::function<void(int64_t, int, uint32_t*)> 
get_hash32_impl,
-                        std::function<void(int64_t, int, uint64_t*)> 
get_hash64_impl,
-                        BlockedBloomFilter* target) {
-  constexpr int batch_size_max = 32 * 1024;
-  int64_t num_batches = bit_util::CeilDiv(num_rows, batch_size_max);
-
-  auto builder = BloomFilterBuilder::Make(strategy);
-
-  std::vector<uint32_t> thread_local_hashes32;
-  std::vector<uint64_t> thread_local_hashes64;
-  thread_local_hashes32.resize(batch_size_max);
-  thread_local_hashes64.resize(batch_size_max);
-
-  RETURN_NOT_OK(builder->Begin(/*num_threads=*/1, hardware_flags, pool, 
num_rows,
-                               bit_util::CeilDiv(num_rows, batch_size_max), 
target));
-
-  for (int64_t i = 0; i < num_batches; ++i) {
+constexpr int kBatchSizeMax = 32 * 1024;
+Status BuildBloomFilter_Serial(
+    std::unique_ptr<BloomFilterBuilder>& builder, int64_t num_rows, int64_t 
num_batches,
+    std::function<void(int64_t, int, uint32_t*)> get_hash32_impl,
+    std::function<void(int64_t, int, uint64_t*)> get_hash64_impl,
+    BlockedBloomFilter* target) {
+  std::vector<uint32_t> hashes32(kBatchSizeMax);
+  std::vector<uint64_t> hashes64(kBatchSizeMax);
+  for (int64_t i = 0; i < num_batches; i++) {
     size_t thread_index = 0;
     int batch_size = static_cast<int>(
-        std::min(num_rows - i * batch_size_max, 
static_cast<int64_t>(batch_size_max)));
+        std::min(num_rows - i * kBatchSizeMax, 
static_cast<int64_t>(kBatchSizeMax)));
     if (target->NumHashBitsUsed() > 32) {
-      uint64_t* hashes = thread_local_hashes64.data();
-      get_hash64_impl(i * batch_size_max, batch_size, hashes);
+      uint64_t* hashes = hashes64.data();
+      get_hash64_impl(i * kBatchSizeMax, batch_size, hashes);
       RETURN_NOT_OK(builder->PushNextBatch(thread_index, batch_size, hashes));
     } else {
-      uint32_t* hashes = thread_local_hashes32.data();
-      get_hash32_impl(i * batch_size_max, batch_size, hashes);
+      uint32_t* hashes = hashes32.data();
+      get_hash32_impl(i * kBatchSizeMax, batch_size, hashes);
       RETURN_NOT_OK(builder->PushNextBatch(thread_index, batch_size, hashes));
     }
   }
+  return Status::OK();
+}
+
+Status BuildBloomFilter_Parallel(
+    std::unique_ptr<BloomFilterBuilder>& builder, size_t num_threads, int64_t 
num_rows,
+    int64_t num_batches, std::function<void(int64_t, int, uint32_t*)> 
get_hash32_impl,
+    std::function<void(int64_t, int, uint64_t*)> get_hash64_impl,
+    BlockedBloomFilter* target) {
+  std::mutex mutex;
+  ThreadIndexer thread_indexer;
+  std::unique_ptr<TaskScheduler> scheduler = TaskScheduler::Make();
+  std::vector<std::vector<uint32_t>> thread_local_hashes32(num_threads);
+  std::vector<std::vector<uint64_t>> thread_local_hashes64(num_threads);
+  for (std::vector<uint32_t>& h : thread_local_hashes32) 
h.resize(kBatchSizeMax);
+  for (std::vector<uint64_t>& h : thread_local_hashes64) 
h.resize(kBatchSizeMax);
+
+  std::condition_variable cv;
+  std::unique_lock<std::mutex> lk(mutex, std::defer_lock);
+  auto group = scheduler->RegisterTaskGroup(
+      [&](size_t thread_index, int64_t task_id) -> Status {
+        int batch_size = static_cast<int>(std::min(num_rows - task_id * 
kBatchSizeMax,
+                                                   
static_cast<int64_t>(kBatchSizeMax)));
+        if (target->NumHashBitsUsed() > 32) {
+          uint64_t* hashes = thread_local_hashes64[thread_index].data();
+          get_hash64_impl(task_id * kBatchSizeMax, batch_size, hashes);
+          RETURN_NOT_OK(builder->PushNextBatch(thread_index, batch_size, 
hashes));
+        } else {
+          uint32_t* hashes = thread_local_hashes32[thread_index].data();
+          get_hash32_impl(task_id * kBatchSizeMax, batch_size, hashes);
+          RETURN_NOT_OK(builder->PushNextBatch(thread_index, batch_size, 
hashes));
+        }
+        return Status::OK();
+      },
+      [&](size_t thread_index) -> Status {
+        lk.unlock();
+        cv.notify_one();
+        return Status::OK();
+      });
+  scheduler->RegisterEnd();
+  auto tp = arrow::internal::GetCpuThreadPool();
+  RETURN_NOT_OK(scheduler->StartScheduling(
+      0,
+      [&](std::function<Status(size_t)> func) -> Status {
+        return tp->Spawn([&, func] {
+          size_t tid = thread_indexer();
+          std::ignore = func(tid);
+        });
+      },
+      static_cast<int>(2 * num_threads), false));

Review Comment:
   Not sure; it's 2x everywhere else. I can change it.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to