zanmato1984 commented on code in PR #45612:
URL: https://github.com/apache/arrow/pull/45612#discussion_r1976868677
##########
cpp/src/arrow/acero/swiss_join.cc:
##########
@@ -1154,91 +1155,73 @@ Status SwissTableForJoinBuild::Init(SwissTableForJoin*
target, int dop, int64_t
return Status::OK();
}
-Status SwissTableForJoinBuild::PushNextBatch(int64_t thread_id,
-                                             const ExecBatch& key_batch,
-                                             const ExecBatch* payload_batch_maybe_null,
-                                             arrow::util::TempVectorStack* temp_stack) {
-  ARROW_DCHECK(thread_id < dop_);
+Status SwissTableForJoinBuild::PartitionBatch(size_t thread_id, int64_t batch_id,
+                                              const ExecBatch& key_batch,
+                                              arrow::util::TempVectorStack* temp_stack) {
+  DCHECK_LE(static_cast<int64_t>(thread_id), dop_);
+  DCHECK_LE(batch_id, static_cast<int64_t>(batch_states_.size()));
Review Comment:
Something must have been wrong with my mind that I mis-used `LE` as `LT` in
all the surrounding code. Thank you for pointing this out. Addressed.
##########
cpp/src/arrow/acero/swiss_join.cc:
##########
@@ -1154,91 +1155,73 @@ Status SwissTableForJoinBuild::Init(SwissTableForJoin*
target, int dop, int64_t
return Status::OK();
}
-Status SwissTableForJoinBuild::PushNextBatch(int64_t thread_id,
-                                             const ExecBatch& key_batch,
-                                             const ExecBatch* payload_batch_maybe_null,
-                                             arrow::util::TempVectorStack* temp_stack) {
-  ARROW_DCHECK(thread_id < dop_);
+Status SwissTableForJoinBuild::PartitionBatch(size_t thread_id, int64_t batch_id,
+                                              const ExecBatch& key_batch,
+                                              arrow::util::TempVectorStack* temp_stack) {
+  DCHECK_LE(static_cast<int64_t>(thread_id), dop_);
+  DCHECK_LE(batch_id, static_cast<int64_t>(batch_states_.size()));
ThreadState& locals = thread_states_[thread_id];
+ BatchState& batch_state = batch_states_[batch_id];
+ uint16_t num_rows = static_cast<uint16_t>(key_batch.length);
// Compute hash
//
-  locals.batch_hashes.resize(key_batch.length);
-  RETURN_NOT_OK(Hashing32::HashBatch(
-      key_batch, locals.batch_hashes.data(), locals.temp_column_arrays, hardware_flags_,
-      temp_stack, /*start_row=*/0, static_cast<int>(key_batch.length)));
+  batch_state.hashes.resize(num_rows);
+  RETURN_NOT_OK(Hashing32::HashBatch(key_batch, batch_state.hashes.data(),
+                                     locals.temp_column_arrays, hardware_flags_,
+                                     temp_stack, /*start_row=*/0, num_rows));
// Partition on hash
//
- locals.batch_prtn_row_ids.resize(locals.batch_hashes.size());
- locals.batch_prtn_ranges.resize(num_prtns_ + 1);
- int num_rows = static_cast<int>(locals.batch_hashes.size());
+ batch_state.prtn_ranges.resize(num_prtns_ + 1);
+ batch_state.prtn_row_ids.resize(num_rows);
if (num_prtns_ == 1) {
// We treat single partition case separately to avoid extra checks in row
// partitioning implementation for general case.
//
- locals.batch_prtn_ranges[0] = 0;
- locals.batch_prtn_ranges[1] = num_rows;
- for (int i = 0; i < num_rows; ++i) {
- locals.batch_prtn_row_ids[i] = i;
+ batch_state.prtn_ranges[0] = 0;
+ batch_state.prtn_ranges[1] = num_rows;
+ for (uint16_t i = 0; i < num_rows; ++i) {
+ batch_state.prtn_row_ids[i] = i;
}
} else {
     PartitionSort::Eval(
-        static_cast<int>(locals.batch_hashes.size()), num_prtns_,
-        locals.batch_prtn_ranges.data(),
-        [this, &locals](int64_t i) {
+        num_rows, num_prtns_, batch_state.prtn_ranges.data(),
+        [this, &batch_state](int64_t i) {
           // SwissTable uses the highest bits of the hash for block index.
           // We want each partition to correspond to a range of block indices,
           // so we also partition on the highest bits of the hash.
           //
-          return locals.batch_hashes[i] >> (SwissTable::bits_hash_ - log_num_prtns_);
+          return batch_state.hashes[i] >> (SwissTable::bits_hash_ - log_num_prtns_);
         },
-        [&locals](int64_t i, int pos) {
-          locals.batch_prtn_row_ids[pos] = static_cast<uint16_t>(i);
+        [&batch_state](int64_t i, int pos) {
+          batch_state.prtn_row_ids[pos] = static_cast<uint16_t>(i);
         });
}
// Update hashes, shifting left to get rid of the bits that were already used
// for partitioning.
//
- for (size_t i = 0; i < locals.batch_hashes.size(); ++i) {
- locals.batch_hashes[i] <<= log_num_prtns_;
+ for (size_t i = 0; i < batch_state.hashes.size(); ++i) {
+ batch_state.hashes[i] <<= log_num_prtns_;
}
- // For each partition:
- // - map keys to unique integers using (this partition's) hash table
- // - append payloads (if present) to (this partition's) row array
- //
- locals.temp_prtn_ids.resize(num_prtns_);
-
-  RETURN_NOT_OK(prtn_locks_.ForEachPartition(
-      thread_id, locals.temp_prtn_ids.data(),
-      /*is_prtn_empty_fn=*/
-      [&](int prtn_id) {
-        return locals.batch_prtn_ranges[prtn_id + 1] == locals.batch_prtn_ranges[prtn_id];
-      },
-      /*process_prtn_fn=*/
-      [&](int prtn_id) {
-        return ProcessPartition(thread_id, key_batch, payload_batch_maybe_null,
-                                temp_stack, prtn_id);
-      }));
-
return Status::OK();
}
-Status SwissTableForJoinBuild::ProcessPartition(int64_t thread_id,
-                                                const ExecBatch& key_batch,
-                                                const ExecBatch* payload_batch_maybe_null,
-                                                arrow::util::TempVectorStack* temp_stack,
-                                                int prtn_id) {
-  ARROW_DCHECK(thread_id < dop_);
+Status SwissTableForJoinBuild::ProcessPartition(
+    size_t thread_id, int64_t batch_id, int prtn_id, const ExecBatch& key_batch,
+    const ExecBatch* payload_batch_maybe_null, arrow::util::TempVectorStack* temp_stack) {
+  DCHECK_LE(static_cast<int64_t>(thread_id), dop_);
+  DCHECK_LE(batch_id, static_cast<int64_t>(batch_states_.size()));
Review Comment:
Done.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]