zanmato1984 commented on code in PR #45612:
URL: https://github.com/apache/arrow/pull/45612#discussion_r1976868818
##########
cpp/src/arrow/acero/swiss_join.cc:
##########
@@ -2593,58 +2583,94 @@ class SwissJoin : public HashJoinImpl {
hash_table_build_ = std::make_unique<SwissTableForJoinBuild>();
RETURN_NOT_OK(CancelIfNotOK(hash_table_build_->Init(
&hash_table_, num_threads_, build_side_batches_.row_count(),
- reject_duplicate_keys, no_payload, key_types, payload_types, pool_,
- hardware_flags_)));
+ build_side_batches_.batch_count(), reject_duplicate_keys, no_payload,
key_types,
+ payload_types, pool_, hardware_flags_)));
// Process all input batches
//
- return CancelIfNotOK(
- start_task_group_callback_(task_group_build_,
build_side_batches_.batch_count()));
+ return CancelIfNotOK(start_task_group_callback_(task_group_partition_,
+
build_side_batches_.batch_count()));
}
- Status BuildTask(size_t thread_id, int64_t batch_id) {
+ Status PartitionTask(size_t thread_id, int64_t batch_id) {
if (IsCancelled()) {
return Status::OK();
}
DCHECK_GT(build_side_batches_[batch_id].length, 0);
const HashJoinProjectionMaps* schema = schema_[1];
- DCHECK_NE(hash_table_build_, nullptr);
- bool no_payload = hash_table_build_->no_payload();
-
ExecBatch input_batch;
ARROW_ASSIGN_OR_RAISE(
input_batch, KeyPayloadFromInput(/*side=*/1,
&build_side_batches_[batch_id]));
- // Split batch into key batch and optional payload batch
- //
- // Input batch is key-payload batch (key columns followed by payload
- // columns). We split it into two separate batches.
- //
- // TODO: Change SwissTableForJoinBuild interface to use key-payload
- // batch instead to avoid this operation, which involves increasing
- // shared pointer ref counts.
- //
ExecBatch key_batch({}, input_batch.length);
key_batch.values.resize(schema->num_cols(HashJoinProjection::KEY));
for (size_t icol = 0; icol < key_batch.values.size(); ++icol) {
key_batch.values[icol] = input_batch.values[icol];
}
- ExecBatch payload_batch({}, input_batch.length);
+ arrow::util::TempVectorStack* temp_stack = &local_states_[thread_id].stack;
+
+ DCHECK_NE(hash_table_build_, nullptr);
+ return hash_table_build_->PartitionBatch(static_cast<int64_t>(thread_id),
batch_id,
+ key_batch, temp_stack);
+ }
+
+ Status PartitionFinished(size_t thread_id) {
+ RETURN_NOT_OK(status());
+
+ DCHECK_NE(hash_table_build_, nullptr);
+ return CancelIfNotOK(
+ start_task_group_callback_(task_group_build_,
hash_table_build_->num_prtns()));
+ }
+
+ Status BuildTask(size_t thread_id, int64_t prtn_id) {
+ if (IsCancelled()) {
+ return Status::OK();
+ }
+ const HashJoinProjectionMaps* schema = schema_[1];
+ DCHECK_NE(hash_table_build_, nullptr);
+ bool no_payload = hash_table_build_->no_payload();
+ ExecBatch key_batch, payload_batch;
+ key_batch.values.resize(schema->num_cols(HashJoinProjection::KEY));
if (!no_payload) {
payload_batch.values.resize(schema->num_cols(HashJoinProjection::PAYLOAD));
- for (size_t icol = 0; icol < payload_batch.values.size(); ++icol) {
- payload_batch.values[icol] =
- input_batch.values[schema->num_cols(HashJoinProjection::KEY) +
icol];
- }
}
arrow::util::TempVectorStack* temp_stack = &local_states_[thread_id].stack;
- DCHECK_NE(hash_table_build_, nullptr);
- RETURN_NOT_OK(CancelIfNotOK(hash_table_build_->PushNextBatch(
- static_cast<int64_t>(thread_id), key_batch, no_payload ? nullptr :
&payload_batch,
- temp_stack)));
+
+ for (int64_t batch_id = 0;
+ batch_id < static_cast<int64_t>(build_side_batches_.batch_count());
++batch_id) {
+ ExecBatch input_batch;
+ ARROW_ASSIGN_OR_RAISE(
+ input_batch, KeyPayloadFromInput(/*side=*/1,
&build_side_batches_[batch_id]));
+
+ // Split batch into key batch and optional payload batch
+ //
+ // Input batch is key-payload batch (key columns followed by payload
+ // columns). We split it into two separate batches.
+ //
+ // TODO: Change SwissTableForJoinBuild interface to use key-payload
+ // batch instead to avoid this operation, which involves increasing
+ // shared pointer ref counts.
+ //
+ key_batch.length = input_batch.length;
+ for (size_t icol = 0; icol < key_batch.values.size(); ++icol) {
+ key_batch.values[icol] = input_batch.values[icol];
+ }
+
+ if (!no_payload) {
+ payload_batch.length = input_batch.length;
+ for (size_t icol = 0; icol < payload_batch.values.size(); ++icol) {
+ payload_batch.values[icol] =
+ input_batch.values[schema->num_cols(HashJoinProjection::KEY) +
icol];
Review Comment:
Done.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]