save-buffer commented on code in PR #13669:
URL: https://github.com/apache/arrow/pull/13669#discussion_r968869202


##########
cpp/src/arrow/compute/exec/accumulation_queue.cc:
##########
@@ -48,11 +50,221 @@ void AccumulationQueue::InsertBatch(ExecBatch batch) {
   batches_.emplace_back(std::move(batch));
 }
 
+void AccumulationQueue::SetBatch(size_t idx, ExecBatch batch)
+{
+    ARROW_DCHECK(idx < batches_.size());
+    arrow::util::AtomicFetchSub(&row_count_, batches_[idx].length, 
std::memory_order_relaxed);
+    arrow::util::AtomicFetchAdd(&row_count_, batch.length, 
std::memory_order_relaxed);
+    batches_[idx] = std::move(batch);
+}
+
 void AccumulationQueue::Clear() {
   row_count_ = 0;
   batches_.clear();
 }
 
-ExecBatch& AccumulationQueue::operator[](size_t i) { return batches_[i]; }
+    Status SpillingAccumulationQueue::Init(QueryContext *ctx)
+    {
+        ctx_ = ctx;
+        partition_locks_.Init(ctx_->max_concurrency(), kNumPartitions);
+        return Status::OK();
+    }
+
+    Status SpillingAccumulationQueue::InsertBatch(
+        size_t thread_index,
+        ExecBatch batch)
+    {
+        Datum &hash_datum = batch.values.back();
+        const uint64_t *hashes = reinterpret_cast<const uint64_t 
*>(hash_datum.array()->buffers[1]->data());
+        // `permutation` stores the indices of rows in the input batch sorted 
by partition.
+        std::vector<uint16_t> permutation(batch.length);
+        uint16_t part_starts[kNumPartitions + 1];
+        PartitionSort::Eval(
+            batch.length,
+            kNumPartitions,
+            part_starts,
+            [&](int64_t i)
+            {
+                return hashes[i] & (kNumPartitions - 1);
+            },
+            [&permutation](int64_t input_pos, int64_t output_pos)
+            {
+                permutation[output_pos] = static_cast<uint16_t>(input_pos);
+            });
+
+        int unprocessed_partition_ids[kNumPartitions];
+        RETURN_NOT_OK(partition_locks_.ForEachPartition(
+                          thread_index,
+                          unprocessed_partition_ids,
+                          [&](int part_id)
+                          {
+                              return part_starts[part_id + 1] == 
part_starts[part_id];
+                          },
+                          [&](int locked_part_id_int)
+                          {
+                              size_t locked_part_id = 
static_cast<size_t>(locked_part_id_int);
+                              uint64_t num_total_rows_to_append =
+                                  part_starts[locked_part_id + 1] - 
part_starts[locked_part_id];
+
+                              size_t offset = 
static_cast<size_t>(part_starts[locked_part_id]);
+                              while(num_total_rows_to_append > 0)
+                              {
+                                  int num_rows_to_append = std::min(
+                                      
static_cast<int>(num_total_rows_to_append),
+                                      
static_cast<int>(ExecBatchBuilder::num_rows_max() - 
builders_[locked_part_id].num_rows()));
+
+                                  
RETURN_NOT_OK(builders_[locked_part_id].AppendSelected(
+                                                    ctx_->memory_pool(),
+                                                    batch,
+                                                    num_rows_to_append,
+                                                    permutation.data() + 
offset,
+                                                    batch.num_values()));
+
+                                  if(builders_[locked_part_id].is_full())
+                                  {
+                                      ExecBatch batch = 
builders_[locked_part_id].Flush();
+                                      Datum hash = 
std::move(batch.values.back());
+                                      batch.values.pop_back();
+                                      ExecBatch hash_batch({ std::move(hash) 
}, batch.length);
+                                      if(locked_part_id < spilling_cursor_)
+                                          
RETURN_NOT_OK(files_[locked_part_id].SpillBatch(
+                                                            ctx_,
+                                                            std::move(batch)));
+                                      else
+                                          
queues_[locked_part_id].InsertBatch(std::move(batch));
+
+                                      if(locked_part_id < hash_cursor_)
+                                          RETURN_NOT_OK(
+                                              
hash_files_[locked_part_id].SpillBatch(
+                                                  ctx_,
+                                                  std::move(hash_batch)));
+                                      else
+                                          
hash_queues_[locked_part_id].InsertBatch(std::move(hash_batch));
+
+                                  }
+                                  offset += num_rows_to_append;
+                                  num_total_rows_to_append -= 
num_rows_to_append;
+                              }
+                              return Status::OK();
+                          }));
+        return Status::OK();
+    }
+
+    const uint64_t *SpillingAccumulationQueue::GetHashes(size_t partition, 
size_t batch_idx)
+    {
+        ARROW_DCHECK(partition >= hash_cursor_.load());
+        const Datum &datum = hash_queues_[partition][batch_idx].values[0];
+        return reinterpret_cast<const uint64_t *>(
+            datum.array()->buffers[1]->data());
+    }
+
+    Status SpillingAccumulationQueue::ZipBatchesAndHashes(size_t partition)
+    {
+        // Append the hash column to the corresponding batch in queues_
+        for(size_t i = 0; i < queues_[partition].batch_count(); i++)
+            queues_[partition][i].values.push_back(
+                std::move(
+                    hash_queues_[partition][i].values[0]));
+        return Status::OK();
+    }
+
+    Status SpillingAccumulationQueue::ReadBackHashes(
+        size_t thread_index,
+        size_t partition,
+        std::function<Status(size_t, AccumulationQueue)> on_finished)
+    {
+        if(partition >= hash_cursor_.load())
+        {
+            RETURN_NOT_OK(ZipBatchesAndHashes(partition));
+            return on_finished(thread_index, std::move(queues_[partition]));
+        }
+
+        hash_queues_[partition].Resize(hash_files_[partition].num_batches());
+        return hash_files_[partition].ReadBackBatches(
+            ctx_,
+            [this, partition](size_t idx, ExecBatch hash)
+            {
+                
queues_[partition][idx].values.push_back(std::move(hash.values[0]));
+                return Status::OK();
+            },
+            [this, partition, on_finished](size_t thread_index)
+            {
+                RETURN_NOT_OK(hash_files_[partition].Cleanup());
+                return on_finished(thread_index, 
std::move(queues_[partition]));
+            });
+    }
+
+    Status SpillingAccumulationQueue::GetPartition(
+        size_t thread_index,
+        size_t partition,
+        std::function<Status(size_t, AccumulationQueue)> on_finished)
+    {
+        if(partition >= spilling_cursor_.load())
+        {
+            ARROW_DCHECK(partition >= hash_cursor_.load());
+            RETURN_NOT_OK(ZipBatchesAndHashes(partition));
+            if(builders_[partition].num_rows() != 0)
+                queues_[partition].InsertBatch(builders_[partition].Flush());
+            return on_finished(thread_index, std::move(queues_[partition]));
+        }
+
+        queues_[partition].Resize(files_[partition].num_batches());
+        auto on_finished_insert_builder =
+            [this, on_finished, partition](size_t thread_index, 
AccumulationQueue queue)
+            {
+                if(builders_[partition].num_rows() != 0)
+                    queue.InsertBatch(builders_[partition].Flush());
+                return on_finished(thread_index, std::move(queue));
+            };
+
+        return files_[partition].ReadBackBatches(
+            ctx_,
+            [this, partition](size_t idx, ExecBatch batch)
+            {
+                queues_[partition].SetBatch(idx, std::move(batch));
+                return Status::OK();
+            },
+            [this, partition, on_finished_insert_builder](size_t thread_index)
+            {
+                RETURN_NOT_OK(files_[partition].Cleanup());
+                return ReadBackHashes(thread_index, partition, 
std::move(on_finished_insert_builder));
+            });
+    }
+
+    Result<bool> SpillingAccumulationQueue::AdvanceSpillCursor()
+    {
+        size_t to_spill = spilling_cursor_.fetch_add(1);
+        if(to_spill >= kNumPartitions)
+        {
+            // This is maybe overly paranoid, but it's conceivable to overflow 
and
+            // wrap back around to 0 on 32-bit platforms
+            spilling_cursor_.fetch_sub(1, std::memory_order_relaxed);

Review Comment:
   Added an assert. 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to