zanmato1984 commented on code in PR #45612:
URL: https://github.com/apache/arrow/pull/45612#discussion_r1976867837
##########
cpp/src/arrow/acero/swiss_join.cc:
##########
@@ -1154,91 +1155,73 @@ Status SwissTableForJoinBuild::Init(SwissTableForJoin*
target, int dop, int64_t
return Status::OK();
}
-Status SwissTableForJoinBuild::PushNextBatch(int64_t thread_id,
- const ExecBatch& key_batch,
- const ExecBatch*
payload_batch_maybe_null,
- arrow::util::TempVectorStack*
temp_stack) {
- ARROW_DCHECK(thread_id < dop_);
+Status SwissTableForJoinBuild::PartitionBatch(size_t thread_id, int64_t
batch_id,
+ const ExecBatch& key_batch,
+ arrow::util::TempVectorStack*
temp_stack) {
+ DCHECK_LE(static_cast<int64_t>(thread_id), dop_);
+ DCHECK_LE(batch_id, static_cast<int64_t>(batch_states_.size()));
ThreadState& locals = thread_states_[thread_id];
+ BatchState& batch_state = batch_states_[batch_id];
+ uint16_t num_rows = static_cast<uint16_t>(key_batch.length);
// Compute hash
//
- locals.batch_hashes.resize(key_batch.length);
- RETURN_NOT_OK(Hashing32::HashBatch(
- key_batch, locals.batch_hashes.data(), locals.temp_column_arrays,
hardware_flags_,
- temp_stack, /*start_row=*/0, static_cast<int>(key_batch.length)));
+ batch_state.hashes.resize(num_rows);
+ RETURN_NOT_OK(Hashing32::HashBatch(key_batch, batch_state.hashes.data(),
+ locals.temp_column_arrays,
hardware_flags_,
+ temp_stack, /*start_row=*/0, num_rows));
// Partition on hash
//
- locals.batch_prtn_row_ids.resize(locals.batch_hashes.size());
- locals.batch_prtn_ranges.resize(num_prtns_ + 1);
- int num_rows = static_cast<int>(locals.batch_hashes.size());
+ batch_state.prtn_ranges.resize(num_prtns_ + 1);
+ batch_state.prtn_row_ids.resize(num_rows);
if (num_prtns_ == 1) {
// We treat single partition case separately to avoid extra checks in row
// partitioning implementation for general case.
//
- locals.batch_prtn_ranges[0] = 0;
- locals.batch_prtn_ranges[1] = num_rows;
- for (int i = 0; i < num_rows; ++i) {
- locals.batch_prtn_row_ids[i] = i;
+ batch_state.prtn_ranges[0] = 0;
+ batch_state.prtn_ranges[1] = num_rows;
+ for (uint16_t i = 0; i < num_rows; ++i) {
+ batch_state.prtn_row_ids[i] = i;
}
} else {
PartitionSort::Eval(
- static_cast<int>(locals.batch_hashes.size()), num_prtns_,
- locals.batch_prtn_ranges.data(),
- [this, &locals](int64_t i) {
+ num_rows, num_prtns_, batch_state.prtn_ranges.data(),
+ [this, &batch_state](int64_t i) {
// SwissTable uses the highest bits of the hash for block index.
// We want each partition to correspond to a range of block indices,
// so we also partition on the highest bits of the hash.
//
- return locals.batch_hashes[i] >> (SwissTable::bits_hash_ -
log_num_prtns_);
+ return batch_state.hashes[i] >> (SwissTable::bits_hash_ -
log_num_prtns_);
},
- [&locals](int64_t i, int pos) {
- locals.batch_prtn_row_ids[pos] = static_cast<uint16_t>(i);
+ [&batch_state](int64_t i, int pos) {
+ batch_state.prtn_row_ids[pos] = static_cast<uint16_t>(i);
});
}
// Update hashes, shifting left to get rid of the bits that were already used
// for partitioning.
//
- for (size_t i = 0; i < locals.batch_hashes.size(); ++i) {
- locals.batch_hashes[i] <<= log_num_prtns_;
+ for (size_t i = 0; i < batch_state.hashes.size(); ++i) {
+ batch_state.hashes[i] <<= log_num_prtns_;
}
Review Comment:
Good point. Updated.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]