kou commented on code in PR #45612:
URL: https://github.com/apache/arrow/pull/45612#discussion_r1976310060


##########
cpp/src/arrow/acero/swiss_join.cc:
##########
@@ -1154,91 +1155,73 @@ Status SwissTableForJoinBuild::Init(SwissTableForJoin* 
target, int dop, int64_t
   return Status::OK();
 }
 
-Status SwissTableForJoinBuild::PushNextBatch(int64_t thread_id,
-                                             const ExecBatch& key_batch,
-                                             const ExecBatch* 
payload_batch_maybe_null,
-                                             arrow::util::TempVectorStack* 
temp_stack) {
-  ARROW_DCHECK(thread_id < dop_);
+Status SwissTableForJoinBuild::PartitionBatch(size_t thread_id, int64_t 
batch_id,
+                                              const ExecBatch& key_batch,
+                                              arrow::util::TempVectorStack* 
temp_stack) {
+  DCHECK_LE(static_cast<int64_t>(thread_id), dop_);
+  DCHECK_LE(batch_id, static_cast<int64_t>(batch_states_.size()));

Review Comment:
   Should this be `LT`?
   
   ```suggestion
     DCHECK_LT(batch_id, static_cast<int64_t>(batch_states_.size()));
   ```



##########
cpp/src/arrow/acero/swiss_join.cc:
##########
@@ -1154,91 +1155,73 @@ Status SwissTableForJoinBuild::Init(SwissTableForJoin* 
target, int dop, int64_t
   return Status::OK();
 }
 
-Status SwissTableForJoinBuild::PushNextBatch(int64_t thread_id,
-                                             const ExecBatch& key_batch,
-                                             const ExecBatch* 
payload_batch_maybe_null,
-                                             arrow::util::TempVectorStack* 
temp_stack) {
-  ARROW_DCHECK(thread_id < dop_);
+Status SwissTableForJoinBuild::PartitionBatch(size_t thread_id, int64_t 
batch_id,
+                                              const ExecBatch& key_batch,
+                                              arrow::util::TempVectorStack* 
temp_stack) {
+  DCHECK_LE(static_cast<int64_t>(thread_id), dop_);
+  DCHECK_LE(batch_id, static_cast<int64_t>(batch_states_.size()));
   ThreadState& locals = thread_states_[thread_id];
+  BatchState& batch_state = batch_states_[batch_id];
+  uint16_t num_rows = static_cast<uint16_t>(key_batch.length);
 
   // Compute hash
   //
-  locals.batch_hashes.resize(key_batch.length);
-  RETURN_NOT_OK(Hashing32::HashBatch(
-      key_batch, locals.batch_hashes.data(), locals.temp_column_arrays, 
hardware_flags_,
-      temp_stack, /*start_row=*/0, static_cast<int>(key_batch.length)));
+  batch_state.hashes.resize(num_rows);
+  RETURN_NOT_OK(Hashing32::HashBatch(key_batch, batch_state.hashes.data(),
+                                     locals.temp_column_arrays, 
hardware_flags_,
+                                     temp_stack, /*start_row=*/0, num_rows));
 
   // Partition on hash
   //
-  locals.batch_prtn_row_ids.resize(locals.batch_hashes.size());
-  locals.batch_prtn_ranges.resize(num_prtns_ + 1);
-  int num_rows = static_cast<int>(locals.batch_hashes.size());
+  batch_state.prtn_ranges.resize(num_prtns_ + 1);
+  batch_state.prtn_row_ids.resize(num_rows);
   if (num_prtns_ == 1) {
     // We treat single partition case separately to avoid extra checks in row
     // partitioning implementation for general case.
     //
-    locals.batch_prtn_ranges[0] = 0;
-    locals.batch_prtn_ranges[1] = num_rows;
-    for (int i = 0; i < num_rows; ++i) {
-      locals.batch_prtn_row_ids[i] = i;
+    batch_state.prtn_ranges[0] = 0;
+    batch_state.prtn_ranges[1] = num_rows;
+    for (uint16_t i = 0; i < num_rows; ++i) {
+      batch_state.prtn_row_ids[i] = i;
     }
   } else {
     PartitionSort::Eval(
-        static_cast<int>(locals.batch_hashes.size()), num_prtns_,
-        locals.batch_prtn_ranges.data(),
-        [this, &locals](int64_t i) {
+        num_rows, num_prtns_, batch_state.prtn_ranges.data(),
+        [this, &batch_state](int64_t i) {
           // SwissTable uses the highest bits of the hash for block index.
           // We want each partition to correspond to a range of block indices,
           // so we also partition on the highest bits of the hash.
           //
-          return locals.batch_hashes[i] >> (SwissTable::bits_hash_ - 
log_num_prtns_);
+          return batch_state.hashes[i] >> (SwissTable::bits_hash_ - 
log_num_prtns_);
         },
-        [&locals](int64_t i, int pos) {
-          locals.batch_prtn_row_ids[pos] = static_cast<uint16_t>(i);
+        [&batch_state](int64_t i, int pos) {
+          batch_state.prtn_row_ids[pos] = static_cast<uint16_t>(i);
         });
   }
 
   // Update hashes, shifting left to get rid of the bits that were already used
   // for partitioning.
   //
-  for (size_t i = 0; i < locals.batch_hashes.size(); ++i) {
-    locals.batch_hashes[i] <<= log_num_prtns_;
+  for (size_t i = 0; i < batch_state.hashes.size(); ++i) {
+    batch_state.hashes[i] <<= log_num_prtns_;
   }

Review Comment:
   This is not related to this PR, but we don't need to do this when
   `num_prtns_ == 1`.



##########
cpp/src/arrow/acero/swiss_join.cc:
##########
@@ -2593,58 +2583,94 @@ class SwissJoin : public HashJoinImpl {
     hash_table_build_ = std::make_unique<SwissTableForJoinBuild>();
     RETURN_NOT_OK(CancelIfNotOK(hash_table_build_->Init(
         &hash_table_, num_threads_, build_side_batches_.row_count(),
-        reject_duplicate_keys, no_payload, key_types, payload_types, pool_,
-        hardware_flags_)));
+        build_side_batches_.batch_count(), reject_duplicate_keys, no_payload, 
key_types,
+        payload_types, pool_, hardware_flags_)));
 
     // Process all input batches
     //
-    return CancelIfNotOK(
-        start_task_group_callback_(task_group_build_, 
build_side_batches_.batch_count()));
+    return CancelIfNotOK(start_task_group_callback_(task_group_partition_,
+                                                    
build_side_batches_.batch_count()));
   }
 
-  Status BuildTask(size_t thread_id, int64_t batch_id) {
+  Status PartitionTask(size_t thread_id, int64_t batch_id) {
     if (IsCancelled()) {
       return Status::OK();
     }
 
     DCHECK_GT(build_side_batches_[batch_id].length, 0);
 
     const HashJoinProjectionMaps* schema = schema_[1];
-    DCHECK_NE(hash_table_build_, nullptr);
-    bool no_payload = hash_table_build_->no_payload();
-
     ExecBatch input_batch;
     ARROW_ASSIGN_OR_RAISE(
         input_batch, KeyPayloadFromInput(/*side=*/1, 
&build_side_batches_[batch_id]));
 
-    // Split batch into key batch and optional payload batch
-    //
-    // Input batch is key-payload batch (key columns followed by payload
-    // columns). We split it into two separate batches.
-    //
-    // TODO: Change SwissTableForJoinBuild interface to use key-payload
-    // batch instead to avoid this operation, which involves increasing
-    // shared pointer ref counts.
-    //
     ExecBatch key_batch({}, input_batch.length);
     key_batch.values.resize(schema->num_cols(HashJoinProjection::KEY));
     for (size_t icol = 0; icol < key_batch.values.size(); ++icol) {
       key_batch.values[icol] = input_batch.values[icol];
     }
-    ExecBatch payload_batch({}, input_batch.length);
+    arrow::util::TempVectorStack* temp_stack = &local_states_[thread_id].stack;
+
+    DCHECK_NE(hash_table_build_, nullptr);
+    return hash_table_build_->PartitionBatch(static_cast<int64_t>(thread_id), 
batch_id,
+                                             key_batch, temp_stack);
+  }
+
+  Status PartitionFinished(size_t thread_id) {
+    RETURN_NOT_OK(status());
+
+    DCHECK_NE(hash_table_build_, nullptr);
+    return CancelIfNotOK(
+        start_task_group_callback_(task_group_build_, 
hash_table_build_->num_prtns()));
+  }
+
+  Status BuildTask(size_t thread_id, int64_t prtn_id) {
+    if (IsCancelled()) {
+      return Status::OK();
+    }
 
+    const HashJoinProjectionMaps* schema = schema_[1];
+    DCHECK_NE(hash_table_build_, nullptr);
+    bool no_payload = hash_table_build_->no_payload();
+    ExecBatch key_batch, payload_batch;
+    key_batch.values.resize(schema->num_cols(HashJoinProjection::KEY));
     if (!no_payload) {
       
payload_batch.values.resize(schema->num_cols(HashJoinProjection::PAYLOAD));
-      for (size_t icol = 0; icol < payload_batch.values.size(); ++icol) {
-        payload_batch.values[icol] =
-            input_batch.values[schema->num_cols(HashJoinProjection::KEY) + 
icol];
-      }
     }
     arrow::util::TempVectorStack* temp_stack = &local_states_[thread_id].stack;
-    DCHECK_NE(hash_table_build_, nullptr);
-    RETURN_NOT_OK(CancelIfNotOK(hash_table_build_->PushNextBatch(
-        static_cast<int64_t>(thread_id), key_batch, no_payload ? nullptr : 
&payload_batch,
-        temp_stack)));
+
+    for (int64_t batch_id = 0;
+         batch_id < static_cast<int64_t>(build_side_batches_.batch_count()); 
++batch_id) {
+      ExecBatch input_batch;
+      ARROW_ASSIGN_OR_RAISE(
+          input_batch, KeyPayloadFromInput(/*side=*/1, 
&build_side_batches_[batch_id]));
+
+      // Split batch into key batch and optional payload batch
+      //
+      // Input batch is key-payload batch (key columns followed by payload
+      // columns). We split it into two separate batches.
+      //
+      // TODO: Change SwissTableForJoinBuild interface to use key-payload
+      // batch instead to avoid this operation, which involves increasing
+      // shared pointer ref counts.
+      //
+      key_batch.length = input_batch.length;
+      for (size_t icol = 0; icol < key_batch.values.size(); ++icol) {
+        key_batch.values[icol] = input_batch.values[icol];
+      }
+
+      if (!no_payload) {
+        payload_batch.length = input_batch.length;
+        for (size_t icol = 0; icol < payload_batch.values.size(); ++icol) {
+          payload_batch.values[icol] =
+              input_batch.values[schema->num_cols(HashJoinProjection::KEY) + 
icol];

Review Comment:
   This is not related to this PR, but can we avoid calling `schema->num_cols()`
   in this loop? (I'm not sure whether this code is performance-critical, but it
   seems that we can use a pre-computed value.)



##########
cpp/src/arrow/acero/swiss_join.cc:
##########
@@ -1154,91 +1155,73 @@ Status SwissTableForJoinBuild::Init(SwissTableForJoin* 
target, int dop, int64_t
   return Status::OK();
 }
 
-Status SwissTableForJoinBuild::PushNextBatch(int64_t thread_id,
-                                             const ExecBatch& key_batch,
-                                             const ExecBatch* 
payload_batch_maybe_null,
-                                             arrow::util::TempVectorStack* 
temp_stack) {
-  ARROW_DCHECK(thread_id < dop_);
+Status SwissTableForJoinBuild::PartitionBatch(size_t thread_id, int64_t 
batch_id,
+                                              const ExecBatch& key_batch,
+                                              arrow::util::TempVectorStack* 
temp_stack) {
+  DCHECK_LE(static_cast<int64_t>(thread_id), dop_);

Review Comment:
   It seems that the max `thread_id` is `dop_ - 1`:
   
   ```suggestion
     DCHECK_LT(static_cast<int64_t>(thread_id), dop_);
   ```
   
   (`thread_states_.size()` may be better than `dop_`.)



##########
cpp/src/arrow/acero/swiss_join.cc:
##########
@@ -1154,91 +1155,73 @@ Status SwissTableForJoinBuild::Init(SwissTableForJoin* 
target, int dop, int64_t
   return Status::OK();
 }
 
-Status SwissTableForJoinBuild::PushNextBatch(int64_t thread_id,
-                                             const ExecBatch& key_batch,
-                                             const ExecBatch* 
payload_batch_maybe_null,
-                                             arrow::util::TempVectorStack* 
temp_stack) {
-  ARROW_DCHECK(thread_id < dop_);
+Status SwissTableForJoinBuild::PartitionBatch(size_t thread_id, int64_t 
batch_id,
+                                              const ExecBatch& key_batch,
+                                              arrow::util::TempVectorStack* 
temp_stack) {
+  DCHECK_LE(static_cast<int64_t>(thread_id), dop_);
+  DCHECK_LE(batch_id, static_cast<int64_t>(batch_states_.size()));
   ThreadState& locals = thread_states_[thread_id];
+  BatchState& batch_state = batch_states_[batch_id];
+  uint16_t num_rows = static_cast<uint16_t>(key_batch.length);
 
   // Compute hash
   //
-  locals.batch_hashes.resize(key_batch.length);
-  RETURN_NOT_OK(Hashing32::HashBatch(
-      key_batch, locals.batch_hashes.data(), locals.temp_column_arrays, 
hardware_flags_,
-      temp_stack, /*start_row=*/0, static_cast<int>(key_batch.length)));
+  batch_state.hashes.resize(num_rows);
+  RETURN_NOT_OK(Hashing32::HashBatch(key_batch, batch_state.hashes.data(),
+                                     locals.temp_column_arrays, 
hardware_flags_,
+                                     temp_stack, /*start_row=*/0, num_rows));
 
   // Partition on hash
   //
-  locals.batch_prtn_row_ids.resize(locals.batch_hashes.size());
-  locals.batch_prtn_ranges.resize(num_prtns_ + 1);
-  int num_rows = static_cast<int>(locals.batch_hashes.size());
+  batch_state.prtn_ranges.resize(num_prtns_ + 1);
+  batch_state.prtn_row_ids.resize(num_rows);
   if (num_prtns_ == 1) {
     // We treat single partition case separately to avoid extra checks in row
     // partitioning implementation for general case.
     //
-    locals.batch_prtn_ranges[0] = 0;
-    locals.batch_prtn_ranges[1] = num_rows;
-    for (int i = 0; i < num_rows; ++i) {
-      locals.batch_prtn_row_ids[i] = i;
+    batch_state.prtn_ranges[0] = 0;
+    batch_state.prtn_ranges[1] = num_rows;
+    for (uint16_t i = 0; i < num_rows; ++i) {
+      batch_state.prtn_row_ids[i] = i;
     }
   } else {
     PartitionSort::Eval(
-        static_cast<int>(locals.batch_hashes.size()), num_prtns_,
-        locals.batch_prtn_ranges.data(),
-        [this, &locals](int64_t i) {
+        num_rows, num_prtns_, batch_state.prtn_ranges.data(),
+        [this, &batch_state](int64_t i) {
           // SwissTable uses the highest bits of the hash for block index.
           // We want each partition to correspond to a range of block indices,
           // so we also partition on the highest bits of the hash.
           //
-          return locals.batch_hashes[i] >> (SwissTable::bits_hash_ - 
log_num_prtns_);
+          return batch_state.hashes[i] >> (SwissTable::bits_hash_ - 
log_num_prtns_);
         },
-        [&locals](int64_t i, int pos) {
-          locals.batch_prtn_row_ids[pos] = static_cast<uint16_t>(i);
+        [&batch_state](int64_t i, int pos) {
+          batch_state.prtn_row_ids[pos] = static_cast<uint16_t>(i);
         });
   }
 
   // Update hashes, shifting left to get rid of the bits that were already used
   // for partitioning.
   //
-  for (size_t i = 0; i < locals.batch_hashes.size(); ++i) {
-    locals.batch_hashes[i] <<= log_num_prtns_;
+  for (size_t i = 0; i < batch_state.hashes.size(); ++i) {
+    batch_state.hashes[i] <<= log_num_prtns_;
   }
 
-  // For each partition:
-  // - map keys to unique integers using (this partition's) hash table
-  // - append payloads (if present) to (this partition's) row array
-  //
-  locals.temp_prtn_ids.resize(num_prtns_);
-
-  RETURN_NOT_OK(prtn_locks_.ForEachPartition(
-      thread_id, locals.temp_prtn_ids.data(),
-      /*is_prtn_empty_fn=*/
-      [&](int prtn_id) {
-        return locals.batch_prtn_ranges[prtn_id + 1] == 
locals.batch_prtn_ranges[prtn_id];
-      },
-      /*process_prtn_fn=*/
-      [&](int prtn_id) {
-        return ProcessPartition(thread_id, key_batch, payload_batch_maybe_null,
-                                temp_stack, prtn_id);
-      }));
-
   return Status::OK();
 }
 
-Status SwissTableForJoinBuild::ProcessPartition(int64_t thread_id,
-                                                const ExecBatch& key_batch,
-                                                const ExecBatch* 
payload_batch_maybe_null,
-                                                arrow::util::TempVectorStack* 
temp_stack,
-                                                int prtn_id) {
-  ARROW_DCHECK(thread_id < dop_);
+Status SwissTableForJoinBuild::ProcessPartition(
+    size_t thread_id, int64_t batch_id, int prtn_id, const ExecBatch& 
key_batch,
+    const ExecBatch* payload_batch_maybe_null, arrow::util::TempVectorStack* 
temp_stack) {
+  DCHECK_LE(static_cast<int64_t>(thread_id), dop_);
+  DCHECK_LE(batch_id, static_cast<int64_t>(batch_states_.size()));

Review Comment:
   Ditto: these checks should also use `DCHECK_LT` for both `thread_id` and `batch_id`.
   
   (We may want to add `DCHECK_LT()` for `prtn_id` too.)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to