save-buffer commented on code in PR #13332:
URL: https://github.com/apache/arrow/pull/13332#discussion_r892762033
##########
cpp/src/arrow/compute/exec/hash_join_node.cc:
##########
@@ -454,6 +456,172 @@ Status HashJoinSchema::CollectFilterColumns(std::vector<FieldRef>& left_filter,
   return Status::OK();
 }
+class HashJoinNode;
+
+struct BloomFilterPushdownContext {
+  using BuildFinishedCallback = std::function<Status(size_t, AccumulationQueue)>;
+  using FiltersReceivedCallback = std::function<Status()>;
+  using FilterFinishedCallback = std::function<Status(size_t, AccumulationQueue)>;
+  void Init(HashJoinNode* owner, size_t num_threads, TaskScheduler* scheduler,
+            FiltersReceivedCallback on_bloom_filters_received,
+            bool disable_bloom_filter, bool use_sync_execution);
+
+  Status StartProducing();
+
+  void ExpectBloomFilter() { eval_.num_expected_bloom_filters_ += 1; }
+
+  Status BuildBloomFilter(size_t thread_index, AccumulationQueue batches,
+                          BuildFinishedCallback on_finished);
+
+  Status PushBloomFilter();
+
+  Status ReceiveBloomFilter(std::unique_ptr<BlockedBloomFilter> filter,
+                            std::vector<int> column_map) {
+    bool proceed;
+    {
+      std::lock_guard<std::mutex> guard(eval_.receive_mutex_);
+      eval_.received_filters_.emplace_back(std::move(filter));
+      eval_.received_maps_.emplace_back(std::move(column_map));
+      proceed = eval_.num_expected_bloom_filters_ == eval_.received_filters_.size();
+
+      ARROW_DCHECK(eval_.received_filters_.size() == eval_.received_maps_.size());
+      ARROW_DCHECK(eval_.received_filters_.size() <= eval_.num_expected_bloom_filters_);
+    }
+    if (proceed) {
+      return push_.all_received_callback_();
+    }
+    return Status::OK();
+  }
+
+  Status FilterBatches(size_t thread_index, AccumulationQueue batches,
+                       FilterFinishedCallback on_finished) {
+    eval_.batches_ = std::move(batches);
+    eval_.on_finished_ = std::move(on_finished);
+
+    if (eval_.num_expected_bloom_filters_ == 0)
+      return eval_.on_finished_(thread_index, std::move(eval_.batches_));
+
+    return scheduler_->StartTaskGroup(thread_index, eval_.task_id_,
+                                      /*num_tasks=*/eval_.batches_.batch_count());
+  }
+
+  Status FilterSingleBatch(size_t thread_index, ExecBatch& batch) {
+    if (disable_bloom_filter_ || batch.length == 0) return Status::OK();
+    int64_t bit_vector_bytes = bit_util::BytesForBits(batch.length);
+    std::vector<uint8_t> selected(bit_vector_bytes);
+    std::vector<uint32_t> hashes(batch.length);
+    std::vector<uint8_t> bv(bit_vector_bytes);
+
+    ARROW_ASSIGN_OR_RAISE(util::TempVectorStack * stack, GetStack(thread_index));
+
+    // Start with full selection for the current batch
+    memset(selected.data(), 0xff, bit_vector_bytes);
+    for (size_t ifilter = 0; ifilter < eval_.num_expected_bloom_filters_; ifilter++) {
+      std::vector<Datum> keys(eval_.received_maps_[ifilter].size());
+      for (size_t i = 0; i < keys.size(); i++) {
+        int input_idx = eval_.received_maps_[ifilter][i];
+        keys[i] = batch[input_idx];
+        if (keys[i].is_scalar()) {
+          ARROW_ASSIGN_OR_RAISE(
+              keys[i],
+              MakeArrayFromScalar(*keys[i].scalar(), batch.length, ctx_->memory_pool()));
+        }
+      }
+      ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(std::move(keys)));
+      RETURN_NOT_OK(Hashing32::HashBatch(key_batch, hashes.data(),
+                                         ctx_->cpu_info()->hardware_flags(), stack, 0,
+                                         key_batch.length));
+
+      eval_.received_filters_[ifilter]->Find(ctx_->cpu_info()->hardware_flags(),
+                                             key_batch.length, hashes.data(), bv.data());
+      arrow::internal::BitmapAnd(bv.data(), 0, selected.data(), 0, key_batch.length, 0,
+                                 selected.data());
+    }
+    auto selected_buffer =
+        arrow::internal::make_unique<Buffer>(selected.data(), bit_vector_bytes);
+    ArrayData selected_arraydata(boolean(), batch.length,
+                                 {nullptr, std::move(selected_buffer)});
+    Datum selected_datum(selected_arraydata);
+    FilterOptions options;
+    size_t first_nonscalar = batch.values.size();
+    for (size_t i = 0; i < batch.values.size(); i++) {
+      if (!batch.values[i].is_scalar()) {
+        ARROW_ASSIGN_OR_RAISE(batch.values[i],
+                              Filter(batch.values[i], selected_datum, options, ctx_));
+        first_nonscalar = std::min(first_nonscalar, i);
+        ARROW_DCHECK_EQ(batch.values[i].length(),
+                        batch.values[first_nonscalar].length());
+      }
+    }
+    // If they're all Scalar, then the length of the batch is the number of set bits
+    if (first_nonscalar == batch.values.size())
+      batch.length = arrow::internal::CountSetBits(selected.data(), 0, batch.length);
+    else
+      batch.length = batch.values[first_nonscalar].length();
+    return Status::OK();
+  }
+
+  Status BuildBloomFilter_exec_task(size_t thread_index, int64_t task_id);
+
+  Status BuildBloomFilter_on_finished(size_t thread_index) {
+    return build_.on_finished_(thread_index, std::move(build_.batches_));
+  }
+
+  // The Bloom filter is built on the build side of some upstream join. For a join to
+  // evaluate the Bloom filter on its input columns, it has to rearrange its input
+  // columns to match the column order of the Bloom filter.
+  //
+  // The first part of the pair is the HashJoin to actually perform the pushdown into.
+  // The second part is a mapping such that column_map[i] is the index of key i in
+  // the first part's input.
+  // If we should disable Bloom filter, returns nullptr and an empty vector, and sets
+  // the disable_bloom_filter_ flag.
+  std::pair<HashJoinNode*, std::vector<int>> GetPushdownTarget(HashJoinNode* start);
+
+  Result<util::TempVectorStack*> GetStack(size_t thread_index) {
+    if (!tld_[thread_index].is_init) {
+      RETURN_NOT_OK(tld_[thread_index].stack.Init(
+          ctx_->memory_pool(), 4 * util::MiniBatch::kMiniBatchLength * sizeof(uint32_t)));
+      tld_[thread_index].is_init = true;
+    }
+    return &tld_[thread_index].stack;
+  }
+
+  bool disable_bloom_filter_;
+  HashJoinNode* owner_;
+  ExecContext* ctx_;
+  TaskScheduler* scheduler_;
+
+  struct ThreadLocalData {
+    bool is_init = false;
+    util::TempVectorStack stack;
+  };
+  std::vector<ThreadLocalData> tld_;
+
+  struct {
+    int task_id_;
+    std::unique_ptr<BloomFilterBuilder> builder_;
+    AccumulationQueue batches_;
+    BuildFinishedCallback on_finished_;
+  } build_;
+
+  struct {
+    std::unique_ptr<BlockedBloomFilter> bloom_filter_;
+    HashJoinNode* pushdown_target_;
+    std::vector<int> column_map_;
+    FiltersReceivedCallback all_received_callback_;
Review Comment:
I agree, it belongs on the receiver not the pusher. I've moved it.
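For reference, a rough sketch of that layout, with the callback kept next to the rest
of the receiver-side (eval_) state instead of in push_. Member names and types are
taken or inferred from the diff above; this is only an illustration, not the exact
revision:

    // Hypothetical sketch: the "all Bloom filters received" callback sits with the
    // receiver-side state, so ReceiveBloomFilter() can call eval_.all_received_callback_()
    // as soon as received_filters_.size() reaches num_expected_bloom_filters_.
    struct {
      FiltersReceivedCallback all_received_callback_;  // moved here from push_
      size_t num_expected_bloom_filters_ = 0;
      std::mutex receive_mutex_;
      std::vector<std::unique_ptr<BlockedBloomFilter>> received_filters_;
      std::vector<std::vector<int>> received_maps_;
      AccumulationQueue batches_;
      FilterFinishedCallback on_finished_;
      int task_id_;
    } eval_;

Keeping it there means only the receiving side has to know when its last expected
filter has arrived; the pushing side just hands filters over.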