This is an automated email from the ASF dual-hosted git repository.
icexelloss pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 9baefea1bc GH-32884: [C++] Add ordered aggregation (#34311)
9baefea1bc is described below
commit 9baefea1bca62e390219bd321b5915b4fba99279
Author: rtpsw <[email protected]>
AuthorDate: Fri Mar 10 16:58:36 2023 +0200
GH-32884: [C++] Add ordered aggregation (#34311)
This PR adds "Segmented Aggregation" to the existing aggregation
node to improve aggregation on ordered data.
A segment group is defined as "a contiguous chunk of data that has the
same segment-key value". E.g., if the input data looks like
```
[0, 0, 0, 1, 2, 2]
```
then there are three segment groups: `[0, 0, 0]`, `[1]`, `[2, 2]`.
(Note the "group" in "segment group" here is added to differentiate it from
"segment", which is defined as "a contiguous chunk of data within an
ExecBatch".)
Segmented aggregation can replace the existing hash aggregation when the
data are ordered. The benefits are:
(1) We can output aggregation results earlier (as soon as a segment group
is fully consumed).
(2) We only need to hold the partial aggregation state for one segment
group, reducing memory usage.
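For illustration, a minimal plan declaration using segment keys (a sketch
mirroring the new tests in this PR; `data` is a `BatchesWithSchema` with
hypothetical fields "a", "b", "c", where "a" is ordered):
```cpp
// Hash-aggregate grouped by "b", segmented by the ordered column "a";
// one aggregated batch is emitted per segment group of "a".
Declaration plan = Declaration::Sequence(
    {{"source",
      SourceNodeOptions{data.schema, data.gen(/*parallel=*/false, /*slow=*/false)}},
     {"aggregate",
      AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr, "c", "sum(c)"}},
                           /*keys=*/{"b"}, /*segment_keys=*/{"a"}}}});
```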
See https://issues.apache.org/jira/browse/ARROW-17642
Replaces #14352
* Closes: #32884
Follow-ups
=======
* #34475
* #34529
---------
Co-authored-by: Li Jin <[email protected]>
---
cpp/src/arrow/compute/exec.cc | 12 +
cpp/src/arrow/compute/exec.h | 8 +
cpp/src/arrow/compute/exec/aggregate_node.cc | 309 +++++-
cpp/src/arrow/compute/exec/exec_plan.h | 3 +-
cpp/src/arrow/compute/exec/options.h | 24 +-
cpp/src/arrow/compute/exec/plan_test.cc | 103 ++
.../arrow/compute/kernels/hash_aggregate_test.cc | 1002 +++++++++++++++++---
cpp/src/arrow/compute/row/grouper.cc | 347 ++++++-
cpp/src/arrow/compute/row/grouper.h | 78 +-
cpp/src/arrow/compute/row/grouper_internal.h | 27 +
cpp/src/arrow/scalar.h | 13 +
11 files changed, 1731 insertions(+), 195 deletions(-)
diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc
index 15f8b263ed..c18dfa0952 100644
--- a/cpp/src/arrow/compute/exec.cc
+++ b/cpp/src/arrow/compute/exec.cc
@@ -147,6 +147,18 @@ ExecBatch ExecBatch::Slice(int64_t offset, int64_t length) const {
return out;
}
+Result<ExecBatch> ExecBatch::SelectValues(const std::vector<int>& ids) const {
+ std::vector<Datum> selected_values;
+ selected_values.reserve(ids.size());
+ for (int id : ids) {
+ if (id < 0 || static_cast<size_t>(id) >= values.size()) {
+ return Status::Invalid("ExecBatch invalid value selection: ", id);
+ }
+ selected_values.push_back(values[id]);
+ }
+ return ExecBatch(std::move(selected_values), length);
+}
+
namespace {
enum LengthInferenceError {
diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h
index 8128d84a12..338740f066 100644
--- a/cpp/src/arrow/compute/exec.h
+++ b/cpp/src/arrow/compute/exec.h
@@ -181,6 +181,12 @@ struct ARROW_EXPORT ExecBatch {
/// \brief Infer the ExecBatch length from values.
static Result<int64_t> InferLength(const std::vector<Datum>& values);
+  /// Creates an ExecBatch with length-validation.
+  ///
+  /// If any value is given, then all values must have a common length. If the given
+  /// length is negative, then the length of the ExecBatch is set to this common length,
+  /// or to 1 if no values are given. Otherwise, the given length must equal the common
+  /// length, if any value is given.
   static Result<ExecBatch> Make(std::vector<Datum> values, int64_t length = -1);
Result<std::shared_ptr<RecordBatch>> ToRecordBatch(
@@ -240,6 +246,8 @@ struct ARROW_EXPORT ExecBatch {
ExecBatch Slice(int64_t offset, int64_t length) const;
+ Result<ExecBatch> SelectValues(const std::vector<int>& ids) const;
+
/// \brief A convenience for returning the types from the batch.
std::vector<TypeHolder> GetTypes() const {
std::vector<TypeHolder> result;
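The `Make` contract documented above, together with the new `SelectValues`, in a
minimal sketch (assuming `a` and `b` are equal-length array `Datum`s, inside a
function that can use `ARROW_ASSIGN_OR_RAISE`):
```cpp
// Length is inferred as the common length of the values.
ARROW_ASSIGN_OR_RAISE(ExecBatch batch, ExecBatch::Make({a, b}));
// With no values and the default negative length, the batch length is 1.
ARROW_ASSIGN_OR_RAISE(ExecBatch unit, ExecBatch::Make({}));
// SelectValues keeps the batch length and picks values by index;
// an out-of-range id yields Status::Invalid.
ARROW_ASSIGN_OR_RAISE(ExecBatch only_b, batch.SelectValues({1}));
```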
diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc b/cpp/src/arrow/compute/exec/aggregate_node.cc
index aa9d832f90..62d4ac81d7 100644
--- a/cpp/src/arrow/compute/exec/aggregate_node.cc
+++ b/cpp/src/arrow/compute/exec/aggregate_node.cc
@@ -19,6 +19,7 @@
#include <sstream>
#include <thread>
#include <unordered_map>
+#include <unordered_set>
#include "arrow/compute/exec.h"
#include "arrow/compute/exec/exec_plan.h"
@@ -35,6 +36,25 @@
#include "arrow/util/thread_pool.h"
#include "arrow/util/tracing_internal.h"
+// This file implements both regular and segmented group-by aggregation, which is a
+// generalization of ordered aggregation in which the key columns are not required to be
+// ordered.
+//
+// In (regular) group-by aggregation, the input rows are partitioned into groups using a
+// set of columns called keys, where in a given group each row has the same values for
+// these columns. In segmented group-by aggregation, a second set of columns called
+// segment-keys is used to refine the partitioning. However, segment-keys are different in
+// that they partition only consecutive rows into a single group. Such a partition of
+// consecutive rows is called a segment group. For example, consider a column X with
+// values [A, A, B, A] at row-indices [0, 1, 2, 3]. A regular group-by aggregation with
+// keys [X] yields a row-index partitioning [[0, 1, 3], [2]] whereas a segmented group-by
+// aggregation with segment-keys [X] yields [[0, 1], [2], [3]].
+//
+// The implementation first segments the input using the segment-keys, then groups by the
+// keys. When the end of a segment group is reached while scanning the input, output is
+// pushed and the accumulating state is cleared. If no segment-keys are given, then the
+// entire input is taken as one segment group. One batch per segment group is sent to
+// output.
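+// For instance, with segment-keys [X] and a single input batch whose X column is
+// [A, A, B, A], three batches are sent to output: one when the [A, A] segment group
+// closes, one when [B] closes, and one for the trailing [A] group when the input
+// finishes.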
+
namespace arrow {
using internal::checked_cast;
@@ -43,8 +63,6 @@ namespace compute {
namespace {
-namespace {
-
 std::vector<TypeHolder> ExtendWithGroupIdType(const std::vector<TypeHolder>& in_types) {
std::vector<TypeHolder> aggr_in_types;
aggr_in_types.reserve(in_types.size() + 1);
@@ -141,8 +159,6 @@ Result<FieldVector> ResolveKernels(
return fields;
}
-} // namespace
-
void AggregatesToString(std::stringstream* ss, const Schema& input_schema,
const std::vector<Aggregate>& aggs,
const std::vector<std::vector<int>>& target_fieldsets,
@@ -169,20 +185,79 @@ void AggregatesToString(std::stringstream* ss, const Schema& input_schema,
*ss << ']';
}
+// Extract segments from a batch and run the given handler on them. Note that the
+// handler may be called on open segments, which are not yet finished. Typically a
+// handler should accumulate those open segments until a closed segment is reached.
+template <typename BatchHandler>
+Status HandleSegments(RowSegmenter* segmenter, const ExecBatch& batch,
+                      const std::vector<int>& ids, const BatchHandler& handle_batch) {
+ int64_t offset = 0;
+ ARROW_ASSIGN_OR_RAISE(auto segment_exec_batch, batch.SelectValues(ids));
+ ExecSpan segment_batch(segment_exec_batch);
+
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(compute::Segment segment,
+ segmenter->GetNextSegment(segment_batch, offset));
+    if (segment.offset >= segment_batch.length) break;  // no next segment
+ ARROW_RETURN_NOT_OK(handle_batch(batch, segment));
+ offset = segment.offset + segment.length;
+ }
+ return Status::OK();
+}
+
+/// @brief Extract values of segment keys from a segment batch
+/// @param[out] values_ptr Vector to store the extracted segment key values
+/// @param[in] input_batch Segment batch. Must have a constant value for each segment key
+/// @param[in] field_ids Segment key field ids
+Status ExtractSegmenterValues(std::vector<Datum>* values_ptr,
+ const ExecBatch& input_batch,
+ const std::vector<int>& field_ids) {
+ DCHECK_GT(input_batch.length, 0);
+ std::vector<Datum>& values = *values_ptr;
+ int64_t row = input_batch.length - 1;
+ values.clear();
+ values.resize(field_ids.size());
+ for (size_t i = 0; i < field_ids.size(); i++) {
+ const Datum& value = input_batch.values[field_ids[i]];
+ if (value.is_scalar()) {
+ values[i] = value;
+ } else if (value.is_array()) {
+ ARROW_ASSIGN_OR_RAISE(auto scalar, value.make_array()->GetScalar(row));
+ values[i] = scalar;
+ } else {
+ DCHECK(false);
+ }
+ }
+ return Status::OK();
+}
+
+void PlaceFields(ExecBatch& batch, size_t base, std::vector<Datum>& values) {
+ DCHECK_LE(base + values.size(), batch.values.size());
+ for (size_t i = 0; i < values.size(); i++) {
+ batch.values[base + i] = values[i];
+ }
+}
+
class ScalarAggregateNode : public ExecNode, public TracedNode {
public:
ScalarAggregateNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
std::shared_ptr<Schema> output_schema,
+ std::unique_ptr<RowSegmenter> segmenter,
+ std::vector<int> segment_field_ids,
std::vector<std::vector<int>> target_fieldsets,
std::vector<Aggregate> aggs,
std::vector<const ScalarAggregateKernel*> kernels,
+ std::vector<std::vector<TypeHolder>> kernel_intypes,
                      std::vector<std::vector<std::unique_ptr<KernelState>>> states)
: ExecNode(plan, std::move(inputs), {"target"},
/*output_schema=*/std::move(output_schema)),
TracedNode(this),
+ segmenter_(std::move(segmenter)),
+ segment_field_ids_(std::move(segment_field_ids)),
target_fieldsets_(std::move(target_fieldsets)),
aggs_(std::move(aggs)),
kernels_(std::move(kernels)),
+ kernel_intypes_(std::move(kernel_intypes)),
states_(std::move(states)) {}
static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
@@ -191,13 +266,40 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
     const auto& aggregate_options = checked_cast<const AggregateNodeOptions&>(options);
auto aggregates = aggregate_options.aggregates;
+ const auto& keys = aggregate_options.keys;
+ const auto& segment_keys = aggregate_options.segment_keys;
+
+ if (keys.size() > 0) {
+ return Status::Invalid("Scalar aggregation with some key");
+ }
+ if (plan->query_context()->exec_context()->executor()->GetCapacity() > 1 &&
+ segment_keys.size() > 0) {
+      return Status::NotImplemented("Segmented aggregation in a multi-threaded plan");
+ }
const auto& input_schema = *inputs[0]->output_schema();
auto exec_ctx = plan->query_context()->exec_context();
+ std::vector<int> segment_field_ids(segment_keys.size());
+ std::vector<TypeHolder> segment_key_types(segment_keys.size());
+ for (size_t i = 0; i < segment_keys.size(); i++) {
+      ARROW_ASSIGN_OR_RAISE(FieldPath match, segment_keys[i].FindOne(input_schema));
+ if (match.indices().size() > 1) {
+ // ARROW-18369: Support nested references as segment ids
+        return Status::Invalid("Nested references cannot be used as segment ids");
+ }
+ segment_field_ids[i] = match[0];
+ segment_key_types[i] = input_schema.field(match[0])->type().get();
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto segmenter,
+ RowSegmenter::Make(std::move(segment_key_types),
+                                             /*nullable_keys=*/false, exec_ctx));
+
+ std::vector<std::vector<TypeHolder>> kernel_intypes(aggregates.size());
std::vector<const ScalarAggregateKernel*> kernels(aggregates.size());
    std::vector<std::vector<std::unique_ptr<KernelState>>> states(kernels.size());
- FieldVector fields(kernels.size());
+ FieldVector fields(kernels.size() + segment_keys.size());
std::vector<std::vector<int>> target_fieldsets(kernels.size());
for (size_t i = 0; i < kernels.size(); ++i) {
@@ -225,7 +327,9 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
for (const auto& target : target_fieldsets[i]) {
in_types.emplace_back(input_schema.field(target)->type().get());
}
-      ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, function->DispatchExact(in_types));
+ kernel_intypes[i] = in_types;
+ ARROW_ASSIGN_OR_RAISE(const Kernel* kernel,
+ function->DispatchExact(kernel_intypes[i]));
kernels[i] = static_cast<const ScalarAggregateKernel*>(kernel);
if (aggregates[i].options == nullptr) {
@@ -239,20 +343,26 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
KernelContext kernel_ctx{exec_ctx};
states[i].resize(plan->query_context()->max_concurrency());
RETURN_NOT_OK(Kernel::InitAll(
-          &kernel_ctx, KernelInitArgs{kernels[i], in_types, aggregates[i].options.get()},
+          &kernel_ctx,
+          KernelInitArgs{kernels[i], kernel_intypes[i], aggregates[i].options.get()},
&states[i]));
// pick one to resolve the kernel signature
kernel_ctx.SetState(states[i][0].get());
      ARROW_ASSIGN_OR_RAISE(auto out_type, kernels[i]->signature->out_type().Resolve(
-                                              &kernel_ctx, in_types));
+                                              &kernel_ctx, kernel_intypes[i]));
      fields[i] = field(aggregate_options.aggregates[i].name, out_type.GetSharedPtr());
}
+ for (size_t i = 0; i < segment_keys.size(); ++i) {
+      ARROW_ASSIGN_OR_RAISE(fields[kernels.size() + i],
+                            segment_keys[i].GetOne(*inputs[0]->output_schema()));
+ }
return plan->EmplaceNode<ScalarAggregateNode>(
-        plan, std::move(inputs), schema(std::move(fields)), std::move(target_fieldsets),
-        std::move(aggregates), std::move(kernels), std::move(states));
+        plan, std::move(inputs), schema(std::move(fields)), std::move(segmenter),
+        std::move(segment_field_ids), std::move(target_fieldsets), std::move(aggregates),
+        std::move(kernels), std::move(kernel_intypes), std::move(states));
}
const char* kind_name() const override { return "ScalarAggregateNode"; }
@@ -283,28 +393,46 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
DCHECK_EQ(input, inputs_[0]);
auto thread_index = plan_->query_context()->GetThreadIndex();
-
- ARROW_RETURN_NOT_OK(DoConsume(ExecSpan(batch), thread_index));
+    auto handler = [this, thread_index](const ExecBatch& full_batch,
+                                        const Segment& segment) {
+      // (1) If the segment starts a new segment group at the beginning of the batch,
+      // then no data in this batch belongs to the current segment group; output and
+      // reset the kernel states.
+      if (!segment.extends && segment.offset == 0) RETURN_NOT_OK(OutputResult(false));
+
+      // (2) Add the segment to the current segment group aggregation.
+      auto exec_batch = full_batch.Slice(segment.offset, segment.length);
+      RETURN_NOT_OK(DoConsume(ExecSpan(exec_batch), thread_index));
+      RETURN_NOT_OK(
+          ExtractSegmenterValues(&segmenter_values_, exec_batch, segment_field_ids_));
+
+      // (3) If the segment closes the current segment group, output the segment group
+      // aggregation.
+      if (!segment.is_open) RETURN_NOT_OK(OutputResult(false));
+
+      return Status::OK();
+    };
+    RETURN_NOT_OK(HandleSegments(segmenter_.get(), batch, segment_field_ids_, handler));
if (input_counter_.Increment()) {
- return Finish();
+ RETURN_NOT_OK(OutputResult(/*is_last=*/true));
}
return Status::OK();
}
Status InputFinished(ExecNode* input, int total_batches) override {
+ auto scope = TraceFinish();
     EVENT_ON_CURRENT_SPAN("InputFinished", {{"batches.length", total_batches}});
DCHECK_EQ(input, inputs_[0]);
if (input_counter_.SetTotal(total_batches)) {
- return Finish();
+ RETURN_NOT_OK(OutputResult(/*is_last=*/true));
}
return Status::OK();
}
Status StartProducing() override {
NoteStartProducing(ToStringExtra());
- // Scalar aggregates will only output a single batch
- return output_->InputFinished(this, 1);
+ return Status::OK();
}
void PauseProducing(ExecNode* output, int32_t counter) override {
@@ -326,10 +454,22 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
}
private:
- Status Finish() {
- auto scope = TraceFinish();
+ Status ResetKernelStates() {
+ auto exec_ctx = plan()->query_context()->exec_context();
+ for (size_t i = 0; i < kernels_.size(); ++i) {
+ states_[i].resize(plan()->query_context()->max_concurrency());
+ KernelContext kernel_ctx{exec_ctx};
+      RETURN_NOT_OK(Kernel::InitAll(
+          &kernel_ctx,
+          KernelInitArgs{kernels_[i], kernel_intypes_[i], aggs_[i].options.get()},
+          &states_[i]));
+ }
+ return Status::OK();
+ }
+
+ Status OutputResult(bool is_last) {
ExecBatch batch{{}, 1};
- batch.values.resize(kernels_.size());
+ batch.values.resize(kernels_.size() + segment_field_ids_.size());
for (size_t i = 0; i < kernels_.size(); ++i) {
util::tracing::Span span;
@@ -343,29 +483,54 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
                          kernels_[i], &ctx, std::move(states_[i])));
RETURN_NOT_OK(kernels_[i]->finalize(&ctx, &batch.values[i]));
}
+ PlaceFields(batch, kernels_.size(), segmenter_values_);
- return output_->InputReceived(this, std::move(batch));
+ ARROW_RETURN_NOT_OK(output_->InputReceived(this, std::move(batch)));
+ total_output_batches_++;
+ if (is_last) {
+ ARROW_RETURN_NOT_OK(output_->InputFinished(this, total_output_batches_));
+ } else {
+ ARROW_RETURN_NOT_OK(ResetKernelStates());
+ }
+ return Status::OK();
}
+ // A segmenter for the segment-keys
+ std::unique_ptr<RowSegmenter> segmenter_;
+ // Field indices corresponding to the segment-keys
+ const std::vector<int> segment_field_ids_;
+  // Holds the values of the segment keys of the most recent input batch.
+  // The values are updated every time an input batch is processed.
+ std::vector<Datum> segmenter_values_;
+
const std::vector<std::vector<int>> target_fieldsets_;
const std::vector<Aggregate> aggs_;
const std::vector<const ScalarAggregateKernel*> kernels_;
+ // Input type holders for each kernel, used for state initialization
+ std::vector<std::vector<TypeHolder>> kernel_intypes_;
std::vector<std::vector<std::unique_ptr<KernelState>>> states_;
AtomicCounter input_counter_;
+ /// \brief Total number of output batches produced
+ int total_output_batches_ = 0;
};
class GroupByNode : public ExecNode, public TracedNode {
public:
GroupByNode(ExecNode* input, std::shared_ptr<Schema> output_schema,
- std::vector<int> key_field_ids,
+              std::vector<int> key_field_ids, std::vector<int> segment_key_field_ids,
+ std::unique_ptr<RowSegmenter> segmenter,
+ std::vector<std::vector<TypeHolder>> agg_src_types,
std::vector<std::vector<int>> agg_src_fieldsets,
std::vector<Aggregate> aggs,
std::vector<const HashAggregateKernel*> agg_kernels)
      : ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema)),
TracedNode(this),
+ segmenter_(std::move(segmenter)),
key_field_ids_(std::move(key_field_ids)),
+ segment_key_field_ids_(std::move(segment_key_field_ids)),
+ agg_src_types_(std::move(agg_src_types)),
agg_src_fieldsets_(std::move(agg_src_fieldsets)),
aggs_(std::move(aggs)),
agg_kernels_(std::move(agg_kernels)) {}
@@ -384,9 +549,15 @@ class GroupByNode : public ExecNode, public TracedNode {
auto input = inputs[0];
     const auto& aggregate_options = checked_cast<const AggregateNodeOptions&>(options);
const auto& keys = aggregate_options.keys;
+ const auto& segment_keys = aggregate_options.segment_keys;
// Copy (need to modify options pointer below)
auto aggs = aggregate_options.aggregates;
+ if (plan->query_context()->exec_context()->executor()->GetCapacity() > 1 &&
+ segment_keys.size() > 0) {
+      return Status::NotImplemented("Segmented aggregation in a multi-threaded plan");
+ }
+
// Get input schema
auto input_schema = input->output_schema();
@@ -397,6 +568,23 @@ class GroupByNode : public ExecNode, public TracedNode {
key_field_ids[i] = match[0];
}
+ // Find input field indices for segment key fields
+ std::vector<int> segment_key_field_ids(segment_keys.size());
+ for (size_t i = 0; i < segment_keys.size(); ++i) {
+      ARROW_ASSIGN_OR_RAISE(auto match, segment_keys[i].FindOne(*input_schema));
+ segment_key_field_ids[i] = match[0];
+ }
+
+ // Check key fields and segment key fields are disjoint
+    std::unordered_set<int> key_field_id_set(key_field_ids.begin(), key_field_ids.end());
+    for (const auto& segment_key_field_id : segment_key_field_ids) {
+      if (key_field_id_set.find(segment_key_field_id) != key_field_id_set.end()) {
+        return Status::Invalid("Group-by aggregation with field '",
+                               input_schema->field(segment_key_field_id)->name(),
+                               "' as both key and segment key");
+      }
+    }
+
// Find input field indices for aggregates
std::vector<std::vector<int>> agg_src_fieldsets(aggs.size());
for (size_t i = 0; i < aggs.size(); ++i) {
@@ -415,8 +603,19 @@ class GroupByNode : public ExecNode, public TracedNode {
}
}
+ // Build vector of segment key field data types
+ std::vector<TypeHolder> segment_key_types(segment_keys.size());
+ for (size_t i = 0; i < segment_keys.size(); ++i) {
+ auto segment_key_field_id = segment_key_field_ids[i];
+      segment_key_types[i] = input_schema->field(segment_key_field_id)->type().get();
+ }
+
auto ctx = plan->query_context()->exec_context();
+ ARROW_ASSIGN_OR_RAISE(auto segmenter,
+ RowSegmenter::Make(std::move(segment_key_types),
+ /*nullable_keys=*/false, ctx));
+
// Construct aggregates
    ARROW_ASSIGN_OR_RAISE(auto agg_kernels, GetKernels(ctx, aggs, agg_src_types));
@@ -428,7 +627,7 @@ class GroupByNode : public ExecNode, public TracedNode {
ResolveKernels(aggs, agg_kernels, agg_states, ctx, agg_src_types));
// Build field vector for output schema
- FieldVector output_fields{keys.size() + aggs.size()};
+ FieldVector output_fields{keys.size() + segment_keys.size() + aggs.size()};
     // Aggregate fields come before key fields to match the behavior of GroupBy function
for (size_t i = 0; i < aggs.size(); ++i) {
@@ -440,12 +639,24 @@ class GroupByNode : public ExecNode, public TracedNode {
int key_field_id = key_field_ids[i];
output_fields[base + i] = input_schema->field(key_field_id);
}
+ base += keys.size();
+ for (size_t i = 0; i < segment_keys.size(); ++i) {
+ int segment_key_field_id = segment_key_field_ids[i];
+ output_fields[base + i] = input_schema->field(segment_key_field_id);
+ }
return input->plan()->EmplaceNode<GroupByNode>(
input, schema(std::move(output_fields)), std::move(key_field_ids),
+        std::move(segment_key_field_ids), std::move(segmenter), std::move(agg_src_types),
std::move(agg_src_fieldsets), std::move(aggs), std::move(agg_kernels));
}
+ Status ResetKernelStates() {
+ auto ctx = plan()->query_context()->exec_context();
+ ARROW_RETURN_NOT_OK(InitKernels(agg_kernels_, ctx, aggs_, agg_src_types_));
+ return Status::OK();
+ }
+
const char* kind_name() const override { return "GroupByNode"; }
Status Consume(ExecSpan batch) {
@@ -542,7 +753,8 @@ class GroupByNode : public ExecNode, public TracedNode {
RETURN_NOT_OK(InitLocalStateIfNeeded(state));
ExecBatch out_data{{}, state->grouper->num_groups()};
- out_data.values.resize(agg_kernels_.size() + key_field_ids_.size());
+ out_data.values.resize(agg_kernels_.size() + key_field_ids_.size() +
+ segment_key_field_ids_.size());
     // Aggregate fields come before key fields to match the behavior of GroupBy function
for (size_t i = 0; i < agg_kernels_.size(); ++i) {
@@ -561,6 +773,7 @@ class GroupByNode : public ExecNode, public TracedNode {
ARROW_ASSIGN_OR_RAISE(ExecBatch out_keys, state->grouper->GetUniques());
std::move(out_keys.values.begin(), out_keys.values.end(),
out_data.values.begin() + agg_kernels_.size());
+    PlaceFields(out_data, agg_kernels_.size() + key_field_ids_.size(), segmenter_values_);
state->grouper.reset();
return out_data;
}
@@ -570,8 +783,7 @@ class GroupByNode : public ExecNode, public TracedNode {
     return output_->InputReceived(this, out_data_.Slice(batch_size * n, batch_size));
}
- Status OutputResult() {
- auto scope = TraceFinish();
+ Status OutputResult(bool is_last) {
// To simplify merging, ensure that the first grouper is nonempty
for (size_t i = 0; i < local_states_.size(); i++) {
if (local_states_[i].grouper) {
@@ -584,9 +796,18 @@ class GroupByNode : public ExecNode, public TracedNode {
ARROW_ASSIGN_OR_RAISE(out_data_, Finalize());
     int64_t num_output_batches = bit_util::CeilDiv(out_data_.length, output_batch_size());
-    RETURN_NOT_OK(output_->InputFinished(this, static_cast<int>(num_output_batches)));
-    return plan_->query_context()->StartTaskGroup(output_task_group_id_,
-                                                  num_output_batches);
+ total_output_batches_ += static_cast<int>(num_output_batches);
+ if (is_last) {
+ ARROW_RETURN_NOT_OK(output_->InputFinished(this, total_output_batches_));
+      RETURN_NOT_OK(plan_->query_context()->StartTaskGroup(output_task_group_id_,
+                                                           num_output_batches));
+ } else {
+ for (int64_t i = 0; i < num_output_batches; i++) {
+ ARROW_RETURN_NOT_OK(OutputNthBatch(i));
+ }
+ ARROW_RETURN_NOT_OK(ResetKernelStates());
+ }
+ return Status::OK();
}
Status InputReceived(ExecNode* input, ExecBatch batch) override {
@@ -594,19 +815,31 @@ class GroupByNode : public ExecNode, public TracedNode {
DCHECK_EQ(input, inputs_[0]);
- ARROW_RETURN_NOT_OK(Consume(ExecSpan(batch)));
+    auto handler = [this](const ExecBatch& full_batch, const Segment& segment) {
+      if (!segment.extends && segment.offset == 0) RETURN_NOT_OK(OutputResult(false));
+      auto exec_batch = full_batch.Slice(segment.offset, segment.length);
+      auto batch = ExecSpan(exec_batch);
+      RETURN_NOT_OK(Consume(batch));
+      RETURN_NOT_OK(
+          ExtractSegmenterValues(&segmenter_values_, exec_batch, segment_key_field_ids_));
+      if (!segment.is_open) RETURN_NOT_OK(OutputResult(false));
+      return Status::OK();
+    };
+    ARROW_RETURN_NOT_OK(
+        HandleSegments(segmenter_.get(), batch, segment_key_field_ids_, handler));
if (input_counter_.Increment()) {
- return OutputResult();
+ ARROW_RETURN_NOT_OK(OutputResult(/*is_last=*/true));
}
return Status::OK();
}
Status InputFinished(ExecNode* input, int total_batches) override {
+ auto scope = TraceFinish();
DCHECK_EQ(input, inputs_[0]);
if (input_counter_.SetTotal(total_batches)) {
- return OutputResult();
+ RETURN_NOT_OK(OutputResult(/*is_last=*/true));
}
return Status::OK();
}
@@ -619,12 +852,12 @@ class GroupByNode : public ExecNode, public TracedNode {
void PauseProducing(ExecNode* output, int32_t counter) override {
// TODO(ARROW-16260)
- // Without spillover there is way to handle backpressure in this node
+ // Without spillover there is no way to handle backpressure in this node
}
void ResumeProducing(ExecNode* output, int32_t counter) override {
// TODO(ARROW-16260)
- // Without spillover there is way to handle backpressure in this node
+ // Without spillover there is no way to handle backpressure in this node
}
Status StopProducingImpl() override { return Status::OK(); }
@@ -697,13 +930,23 @@ class GroupByNode : public ExecNode, public TracedNode {
}
int output_task_group_id_;
+ /// \brief A segmenter for the segment-keys
+ std::unique_ptr<RowSegmenter> segmenter_;
+  /// \brief Holds values of the current batch that were selected for the segment-keys
+ std::vector<Datum> segmenter_values_;
const std::vector<int> key_field_ids_;
+ /// \brief Field indices corresponding to the segment-keys
+ const std::vector<int> segment_key_field_ids_;
+ /// \brief Types of input fields per aggregate
+ const std::vector<std::vector<TypeHolder>> agg_src_types_;
const std::vector<std::vector<int>> agg_src_fieldsets_;
const std::vector<Aggregate> aggs_;
const std::vector<const HashAggregateKernel*> agg_kernels_;
AtomicCounter input_counter_;
+ /// \brief Total number of output batches produced
+ int total_output_batches_ = 0;
std::vector<ThreadLocalState> local_states_;
ExecBatch out_data_;
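Both node implementations above drive their segmenter through HandleSegments; the
underlying RowSegmenter contract, as a minimal standalone sketch (assuming a
ready-made `segmenter` and a key-columns-only `ExecSpan span`):
```cpp
int64_t offset = 0;
while (true) {
  // Next run of equal segment-key values at or after `offset`; an empty
  // segment at the end of the span means the batch is fully consumed.
  ARROW_ASSIGN_OR_RAISE(Segment segment, segmenter->GetNextSegment(span, offset));
  if (segment.offset >= span.length) break;
  // segment.is_open: the run reaches the end of this span and may continue in
  // the next batch; segment.extends: it continues the previous run's key.
  offset = segment.offset + segment.length;
}
```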
diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h
index 83b9248eb6..7f9c19938d 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.h
+++ b/cpp/src/arrow/compute/exec/exec_plan.h
@@ -241,8 +241,7 @@ class ARROW_EXPORT ExecNode {
/// concurrently, potentially even before the call to StartProducing
/// has finished.
/// - PauseProducing(), ResumeProducing(), StopProducing() may be called
-  ///   by the downstream nodes' InputReceived(), ErrorReceived(), InputFinished()
-  ///   methods
+  ///   by the downstream nodes' InputReceived(), InputFinished() methods
///
 /// StopProducing may be called due to an error, by the user (e.g. cancel), or
 /// because a node has all the data it needs (e.g. limit, top-k on sorted data).
diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h
index f532dd1c09..b0057d73da 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -221,21 +221,37 @@ class ARROW_EXPORT ProjectNodeOptions : public ExecNodeOptions {
std::vector<std::string> names;
};
-/// \brief Make a node which aggregates input batches, optionally grouped by keys.
+/// \brief Make a node which aggregates input batches, optionally grouped by keys and
+/// optionally segmented by segment-keys. Both keys and segment-keys determine the group.
+/// However, segment-keys are also used for determining grouping segments, which should be
+/// large, and allow streaming a partial aggregation result after processing each segment.
+/// One common use-case for segment-keys is ordered aggregation, in which the segment-key
+/// attribute specifies a column with non-decreasing values or a lexicographically-ordered
+/// set of such columns.
 ///
 /// If the keys attribute is a non-empty vector, then each aggregate in `aggregates` is
 /// expected to be a HashAggregate function. If the keys attribute is an empty vector,
 /// then each aggregate is assumed to be a ScalarAggregate function.
+///
+/// If the segment_keys attribute is a non-empty vector, then segmented aggregation, as
+/// described above, applies.
+///
+/// The keys and segment_keys vectors must be disjoint.
class ARROW_EXPORT AggregateNodeOptions : public ExecNodeOptions {
public:
explicit AggregateNodeOptions(std::vector<Aggregate> aggregates,
- std::vector<FieldRef> keys = {})
- : aggregates(std::move(aggregates)), keys(std::move(keys)) {}
+ std::vector<FieldRef> keys = {},
+ std::vector<FieldRef> segment_keys = {})
+ : aggregates(std::move(aggregates)),
+ keys(std::move(keys)),
+ segment_keys(std::move(segment_keys)) {}
   // aggregations which will be applied to the targeted fields
std::vector<Aggregate> aggregates;
- // keys by which aggregations will be grouped
+ // keys by which aggregations will be grouped (optional)
std::vector<FieldRef> keys;
+ // keys by which aggregations will be segmented (optional)
+ std::vector<FieldRef> segment_keys;
};
constexpr int32_t kDefaultBackpressureHighBytes = 1 << 30; // 1GiB
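For reference, a minimal sketch of constructing these options (the "a"/"b"/"c"
field names are hypothetical; keys and segment_keys must be disjoint):
```cpp
// Hash aggregation grouped by "b" and segmented by "a".
AggregateNodeOptions hash_opts(/*aggregates=*/{{"hash_sum", nullptr, "c", "sum(c)"}},
                               /*keys=*/{"b"}, /*segment_keys=*/{"a"});
// Scalar aggregation (empty keys), still segmented by "a".
AggregateNodeOptions scalar_opts(/*aggregates=*/{{"sum", nullptr, "c", "sum(c)"}},
                                 /*keys=*/{}, /*segment_keys=*/{"a"});
```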
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index eac4d12a06..9eed429b3d 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -1578,5 +1578,108 @@ TEST(ExecPlan, SourceEnforcesBatchLimit) {
}
}
+TEST(ExecPlanExecution, SegmentedAggregationWithMultiThreading) {
+ BatchesWithSchema data;
+ data.batches = {ExecBatchFromJSON({int32()}, "[[1]]")};
+ data.schema = schema({field("i32", int32())});
+  Declaration plan = Declaration::Sequence(
+      {{"source",
+        SourceNodeOptions{data.schema, data.gen(/*parallel=*/false, /*slow=*/false)}},
+       {"aggregate", AggregateNodeOptions{/*aggregates=*/{
+                                              {"count", nullptr, "i32", "count(i32)"},
+                                          },
+                                          /*keys=*/{"i32"}, /*segment_keys=*/{"i32"}}}});
+  EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, HasSubstr("multi-threaded"),
+                                  DeclarationToExecBatches(std::move(plan)));
+}
+
+TEST(ExecPlanExecution, SegmentedAggregationWithOneSegment) {
+ BatchesWithSchema data;
+  data.batches = {
+      ExecBatchFromJSON({int32(), int32(), int32()}, "[[1, 1, 1], [1, 2, 1], [1, 1, 2]]"),
+      ExecBatchFromJSON({int32(), int32(), int32()},
+                        "[[1, 2, 2], [1, 1, 3], [1, 2, 3]]")};
+ data.schema = schema({
+ field("a", int32()),
+ field("b", int32()),
+ field("c", int32()),
+ });
+
+  Declaration plan = Declaration::Sequence(
+      {{"source",
+        SourceNodeOptions{data.schema, data.gen(/*parallel=*/false, /*slow=*/false)}},
+       {"aggregate", AggregateNodeOptions{/*aggregates=*/{
+                                              {"hash_sum", nullptr, "c", "sum(c)"},
+                                              {"hash_mean", nullptr, "c", "mean(c)"},
+                                          },
+                                          /*keys=*/{"b"}, /*segment_keys=*/{"a"}}}});
+  ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema actual_batches,
+                       DeclarationToExecBatches(std::move(plan), /*use_threads=*/false));
+
+  auto expected = ExecBatchFromJSON({int64(), float64(), int32(), int32()},
+                                    R"([[6, 2, 1, 1], [6, 2, 2, 1]])");
+  AssertExecBatchesEqualIgnoringOrder(actual_batches.schema, actual_batches.batches,
+                                      {expected});
+}
+
+TEST(ExecPlanExecution, SegmentedAggregationWithTwoSegments) {
+ BatchesWithSchema data;
+  data.batches = {
+      ExecBatchFromJSON({int32(), int32(), int32()}, "[[1, 1, 1], [1, 2, 1], [1, 1, 2]]"),
+      ExecBatchFromJSON({int32(), int32(), int32()},
+                        "[[2, 2, 2], [2, 1, 3], [2, 2, 3]]")};
+ data.schema = schema({
+ field("a", int32()),
+ field("b", int32()),
+ field("c", int32()),
+ });
+
+  Declaration plan = Declaration::Sequence(
+      {{"source",
+        SourceNodeOptions{data.schema, data.gen(/*parallel=*/false, /*slow=*/false)}},
+       {"aggregate", AggregateNodeOptions{/*aggregates=*/{
+                                              {"hash_sum", nullptr, "c", "sum(c)"},
+                                              {"hash_mean", nullptr, "c", "mean(c)"},
+                                          },
+                                          /*keys=*/{"b"}, /*segment_keys=*/{"a"}}}});
+  ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema actual_batches,
+                       DeclarationToExecBatches(std::move(plan), /*use_threads=*/false));
+
+  auto expected = ExecBatchFromJSON(
+      {int64(), float64(), int32(), int32()},
+      R"([[3, 1.5, 1, 1], [1, 1, 2, 1], [3, 3, 1, 2], [5, 2.5, 2, 2]])");
+  AssertExecBatchesEqualIgnoringOrder(actual_batches.schema, actual_batches.batches,
+                                      {expected});
+}
+
+TEST(ExecPlanExecution, SegmentedAggregationWithBatchCrossingSegment) {
+ BatchesWithSchema data;
+  data.batches = {
+      ExecBatchFromJSON({int32(), int32(), int32()}, "[[1, 1, 1], [1, 1, 1], [2, 2, 2]]"),
+      ExecBatchFromJSON({int32(), int32(), int32()},
+                        "[[2, 2, 2], [3, 3, 3], [3, 3, 3]]")};
+ data.schema = schema({
+ field("a", int32()),
+ field("b", int32()),
+ field("c", int32()),
+ });
+
+  Declaration plan = Declaration::Sequence(
+      {{"source",
+        SourceNodeOptions{data.schema, data.gen(/*parallel=*/false, /*slow=*/false)}},
+       {"aggregate", AggregateNodeOptions{/*aggregates=*/{
+                                              {"hash_sum", nullptr, "c", "sum(c)"},
+                                              {"hash_mean", nullptr, "c", "mean(c)"},
+                                          },
+                                          /*keys=*/{"b"}, /*segment_keys=*/{"a"}}}});
+  ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema actual_batches,
+                       DeclarationToExecBatches(std::move(plan), /*use_threads=*/false));
+
+  auto expected = ExecBatchFromJSON({int64(), float64(), int32(), int32()},
+                                    R"([[2, 1, 1, 1], [4, 2, 2, 2], [6, 3, 3, 3]])");
+  AssertExecBatchesEqualIgnoringOrder(actual_batches.schema, actual_batches.batches,
+                                      {expected});
+}
+
} // namespace compute
} // namespace arrow
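As a hand check of SegmentedAggregationWithTwoSegments above (output columns are
[sum(c), mean(c), key "b", segment key "a"]):
```
segment a=1: rows (b=1,c=1), (b=2,c=1), (b=1,c=2)
  group b=1: sum(c) = 1+2 = 3, mean(c) = 1.5  ->  [3, 1.5, 1, 1]
  group b=2: sum(c) = 1,       mean(c) = 1    ->  [1, 1,   2, 1]
segment a=2: rows (b=2,c=2), (b=1,c=3), (b=2,c=3)
  group b=1: sum(c) = 3,       mean(c) = 3    ->  [3, 3,   1, 2]
  group b=2: sum(c) = 2+3 = 5, mean(c) = 2.5  ->  [5, 2.5, 2, 2]
```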
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
index 3c3476d62d..fd631e0dc5 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
@@ -42,6 +42,7 @@
#include "arrow/compute/kernels/test_util.h"
#include "arrow/compute/registry.h"
#include "arrow/compute/row/grouper.h"
+#include "arrow/compute/row/grouper_internal.h"
#include "arrow/table.h"
#include "arrow/testing/generator.h"
#include "arrow/testing/gtest_util.h"
@@ -72,6 +73,10 @@ using internal::ToChars;
namespace compute {
namespace {
+using GroupByFunction = std::function<Result<Datum>(
+    const std::vector<Datum>&, const std::vector<Datum>&, const std::vector<Datum>&,
+ const std::vector<Aggregate>&, bool, bool)>;
+
 Result<Datum> NaiveGroupBy(std::vector<Datum> arguments, std::vector<Datum> keys,
const std::vector<Aggregate>& aggregates) {
ARROW_ASSIGN_OR_RAISE(auto key_batch, ExecBatch::Make(std::move(keys)));
@@ -135,22 +140,99 @@ Result<Datum> NaiveGroupBy(std::vector<Datum> arguments, std::vector<Datum> keys
return Take(struct_arr, sorted_indices);
}
+Result<Datum> MakeGroupByOutput(const std::vector<ExecBatch>& output_batches,
+ const std::shared_ptr<Schema> output_schema,
+                                size_t num_aggregates, size_t num_keys, bool naive) {
+ ArrayVector out_arrays(num_aggregates + num_keys);
+ for (size_t i = 0; i < out_arrays.size(); ++i) {
+ std::vector<std::shared_ptr<Array>> arrays(output_batches.size());
+ for (size_t j = 0; j < output_batches.size(); ++j) {
+ arrays[j] = output_batches[j].values[i].make_array();
+ }
+ if (arrays.empty()) {
+ ARROW_ASSIGN_OR_RAISE(
+ out_arrays[i],
+ MakeArrayOfNull(output_schema->field(static_cast<int>(i))->type(),
+ /*length=*/0));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(out_arrays[i], Concatenate(arrays));
+ }
+ }
+
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr<Array> struct_arr,
+ StructArray::Make(std::move(out_arrays), output_schema->fields()));
+
+ bool need_sort = !naive;
+ for (size_t i = num_aggregates; need_sort && i < out_arrays.size(); i++) {
+    if (output_schema->field(static_cast<int>(i))->type()->id() == Type::DICTIONARY) {
+ need_sort = false;
+ }
+ }
+ if (!need_sort) {
+ return struct_arr;
+ }
+
+  // The exec plan may reorder the output rows. The tests are all set up to expect output
+  // in ascending order of keys. So we need to sort the result by the key columns. To do
+  // that we create a table using the key columns, calculate the sort indices from that
+  // table (sorting on all fields) and then use those indices to calculate our result.
+ std::vector<std::shared_ptr<Field>> key_fields;
+ std::vector<std::shared_ptr<Array>> key_columns;
+ std::vector<SortKey> sort_keys;
+ for (std::size_t i = 0; i < num_keys; i++) {
+ const std::shared_ptr<Array>& arr = out_arrays[i + num_aggregates];
+ key_columns.push_back(arr);
+ key_fields.push_back(field("name_does_not_matter", arr->type()));
+ sort_keys.emplace_back(static_cast<int>(i));
+ }
+ std::shared_ptr<Schema> key_schema = schema(std::move(key_fields));
+  std::shared_ptr<Table> key_table = Table::Make(std::move(key_schema), key_columns);
+ SortOptions sort_options(std::move(sort_keys));
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> sort_indices,
+ SortIndices(key_table, sort_options));
+
+ return Take(struct_arr, sort_indices);
+}
+
Result<Datum> RunGroupBy(const BatchesWithSchema& input,
const std::vector<std::string>& key_names,
-                         const std::vector<Aggregate>& aggregates, bool use_threads) {
+                         const std::vector<std::string>& segment_key_names,
+                         const std::vector<Aggregate>& aggregates, ExecContext* ctx,
+                         bool use_threads, bool segmented = false, bool naive = false) {
+  // The `use_threads` flag determines whether threads are used in generating the input
+  // to the group-by.
+  //
+  // When segment_keys is non-empty the `segmented` flag is always true; otherwise (when
+  // empty), it may still be set to true. In this case, the tester restructures (without
+  // changing the data of) the result of RunGroupBy from `std::vector<ExecBatch>`
+  // (output_batches) to `std::vector<ArrayVector>` (out_arrays), which have the structure
+  // typical of the case of a non-empty segment_keys (with multiple arrays per column, one
+  // array per segment) but only one array per column (because, technically, there is only
+  // one segment in this case). Thus, this case focuses on the structure of the result.
+  //
+  // The `naive` flag means that the output is expected to be like that of `NaiveGroupBy`,
+  // which in particular doesn't require sorting. The reason for the naive flag is that
+  // the expected output of some test-cases is naive and of some others it is not. The
+  // current `RunGroupBy` function deals with both kinds of expected output.
std::vector<FieldRef> keys(key_names.size());
for (size_t i = 0; i < key_names.size(); ++i) {
keys[i] = FieldRef(key_names[i]);
}
+ std::vector<FieldRef> segment_keys(segment_key_names.size());
+ for (size_t i = 0; i < segment_key_names.size(); ++i) {
+ segment_keys[i] = FieldRef(segment_key_names[i]);
+ }
- ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make(*threaded_exec_context()));
+ ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make(*ctx));
AsyncGenerator<std::optional<ExecBatch>> sink_gen;
RETURN_NOT_OK(
Declaration::Sequence(
{
{"source",
SourceNodeOptions{input.schema, input.gen(use_threads,
/*slow=*/false)}},
- {"aggregate", AggregateNodeOptions{std::move(aggregates),
std::move(keys)}},
+ {"aggregate", AggregateNodeOptions{std::move(aggregates),
std::move(keys),
+ std::move(segment_keys)}},
{"sink", SinkNodeOptions{&sink_gen}},
})
.AddToPlan(plan.get()));
@@ -174,81 +256,117 @@ Result<Datum> RunGroupBy(const BatchesWithSchema& input,
ARROW_ASSIGN_OR_RAISE(std::vector<ExecBatch> output_batches,
start_and_collect.MoveResult());
- ArrayVector out_arrays(aggregates.size() + key_names.size());
const auto& output_schema = plan->nodes()[0]->output()->output_schema();
+ if (!segmented) {
+ return MakeGroupByOutput(output_batches, output_schema, aggregates.size(),
+ key_names.size(), naive);
+ }
+
+ std::vector<ArrayVector> out_arrays(aggregates.size() + key_names.size() +
+ segment_key_names.size());
for (size_t i = 0; i < out_arrays.size(); ++i) {
std::vector<std::shared_ptr<Array>> arrays(output_batches.size());
for (size_t j = 0; j < output_batches.size(); ++j) {
- arrays[j] = output_batches[j].values[i].make_array();
+ auto& value = output_batches[j].values[i];
+ if (value.is_scalar()) {
+ ARROW_ASSIGN_OR_RAISE(
+            arrays[j], MakeArrayFromScalar(*value.scalar(), output_batches[j].length));
+ } else if (value.is_array()) {
+ arrays[j] = value.make_array();
+ } else {
+ return Status::Invalid("GroupByUsingExecPlan unsupported value kind ",
+ ToString(value.kind()));
+ }
}
if (arrays.empty()) {
+ arrays.resize(1);
ARROW_ASSIGN_OR_RAISE(
- out_arrays[i],
- MakeArrayOfNull(output_schema->field(static_cast<int>(i))->type(),
- /*length=*/0));
- } else {
- ARROW_ASSIGN_OR_RAISE(out_arrays[i], Concatenate(arrays));
+          arrays[0], MakeArrayOfNull(output_schema->field(static_cast<int>(i))->type(),
+                                     /*length=*/0));
}
+ out_arrays[i] = {std::move(arrays)};
}
-  // The exec plan may reorder the output rows. The tests are all setup to expect ouptut
-  // in ascending order of keys. So we need to sort the result by the key columns. To do
-  // that we create a table using the key columns, calculate the sort indices from that
-  // table (sorting on all fields) and then use those indices to calculate our result.
- std::vector<std::shared_ptr<Field>> key_fields;
- std::vector<std::shared_ptr<Array>> key_columns;
- std::vector<SortKey> sort_keys;
- for (std::size_t i = 0; i < key_names.size(); i++) {
- const std::shared_ptr<Array>& arr = out_arrays[i + aggregates.size()];
- if (arr->type_id() == Type::DICTIONARY) {
- // Can't sort dictionary columns so need to decode
- auto dict_arr = checked_pointer_cast<DictionaryArray>(arr);
- ARROW_ASSIGN_OR_RAISE(auto decoded_arr,
-                            Take(*dict_arr->dictionary(), *dict_arr->indices()));
- key_columns.push_back(decoded_arr);
- key_fields.push_back(
- field("name_does_not_matter", dict_arr->dict_type()->value_type()));
- } else {
- key_columns.push_back(arr);
- key_fields.push_back(field("name_does_not_matter", arr->type()));
+ if (segmented && segment_key_names.size() > 0) {
+ ArrayVector struct_arrays;
+ struct_arrays.reserve(output_batches.size());
+ for (size_t j = 0; j < output_batches.size(); ++j) {
+ ArrayVector struct_fields;
+ struct_fields.reserve(out_arrays.size());
+ for (auto out_array : out_arrays) {
+ struct_fields.push_back(out_array[j]);
+ }
+ ARROW_ASSIGN_OR_RAISE(auto struct_array,
+                            StructArray::Make(struct_fields, output_schema->fields()));
+ struct_arrays.push_back(struct_array);
}
- sort_keys.emplace_back(static_cast<int>(i));
+ return ChunkedArray::Make(struct_arrays);
+ } else {
+ ArrayVector struct_fields(out_arrays.size());
+ for (size_t i = 0; i < out_arrays.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(struct_fields[i], Concatenate(out_arrays[i]));
+ }
+    return StructArray::Make(std::move(struct_fields), output_schema->fields());
}
- std::shared_ptr<Schema> key_schema = schema(std::move(key_fields));
-  std::shared_ptr<Table> key_table = Table::Make(std::move(key_schema), key_columns);
- SortOptions sort_options(std::move(sort_keys));
- ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> sort_indices,
- SortIndices(key_table, sort_options));
+}
- ARROW_ASSIGN_OR_RAISE(
- std::shared_ptr<Array> struct_arr,
- StructArray::Make(std::move(out_arrays), output_schema->fields()));
+Result<Datum> RunGroupBy(const BatchesWithSchema& input,
+ const std::vector<std::string>& key_names,
+ const std::vector<std::string>& segment_key_names,
+                         const std::vector<Aggregate>& aggregates, bool use_threads,
+ bool segmented = false, bool naive = false) {
+ if (segment_key_names.size() > 0) {
+    ARROW_ASSIGN_OR_RAISE(auto thread_pool, arrow::internal::ThreadPool::Make(1));
+ ExecContext seq_ctx(default_memory_pool(), thread_pool.get());
+    return RunGroupBy(input, key_names, segment_key_names, aggregates, &seq_ctx,
+ use_threads, segmented, naive);
+ } else {
+ return RunGroupBy(input, key_names, segment_key_names, aggregates,
+ threaded_exec_context(), use_threads, segmented, naive);
+ }
+}
- return Take(struct_arr, sort_indices);
+Result<Datum> RunGroupBy(const BatchesWithSchema& input,
+ const std::vector<std::string>& key_names,
+                         const std::vector<Aggregate>& aggregates, bool use_threads,
+ bool segmented = false, bool naive = false) {
+ return RunGroupBy(input, key_names, {}, aggregates, use_threads, segmented);
}
/// Simpler overload where you can give the columns as datums
Result<Datum> RunGroupBy(const std::vector<Datum>& arguments,
const std::vector<Datum>& keys,
- const std::vector<Aggregate>& aggregates,
- bool use_threads = false) {
+ const std::vector<Datum>& segment_keys,
+                         const std::vector<Aggregate>& aggregates, bool use_threads,
+ bool segmented = false, bool naive = false) {
using arrow::compute::detail::ExecSpanIterator;
- FieldVector scan_fields(arguments.size() + keys.size());
+  FieldVector scan_fields(arguments.size() + keys.size() + segment_keys.size());
std::vector<std::string> key_names(keys.size());
+ std::vector<std::string> segment_key_names(segment_keys.size());
for (size_t i = 0; i < arguments.size(); ++i) {
auto name = std::string("agg_") + ToChars(i);
scan_fields[i] = field(name, arguments[i].type());
}
+ size_t base = arguments.size();
for (size_t i = 0; i < keys.size(); ++i) {
auto name = std::string("key_") + ToChars(i);
- scan_fields[arguments.size() + i] = field(name, keys[i].type());
+ scan_fields[base + i] = field(name, keys[i].type());
key_names[i] = std::move(name);
}
+ base += keys.size();
+ size_t j = keys.size();
+ std::string prefix("key_");
+ for (size_t i = 0; i < segment_keys.size(); ++i) {
+ auto name = prefix + std::to_string(j++);
+ scan_fields[base + i] = field(name, segment_keys[i].type());
+ segment_key_names[i] = std::move(name);
+ }
std::vector<Datum> inputs = arguments;
- inputs.reserve(inputs.size() + keys.size());
+ inputs.reserve(inputs.size() + keys.size() + segment_keys.size());
inputs.insert(inputs.end(), keys.begin(), keys.end());
+ inputs.insert(inputs.end(), segment_keys.begin(), segment_keys.end());
ExecSpanIterator span_iterator;
ARROW_ASSIGN_OR_RAISE(auto batch, ExecBatch::Make(inputs));
@@ -261,15 +379,35 @@ Result<Datum> RunGroupBy(const std::vector<Datum>& arguments,
input.batches.push_back(span.ToExecBatch());
}
- return RunGroupBy(input, key_names, aggregates, use_threads);
+  return RunGroupBy(input, key_names, segment_key_names, aggregates, use_threads,
+                    segmented, naive);
+}
+
+Result<Datum> RunGroupByImpl(const std::vector<Datum>& arguments,
+ const std::vector<Datum>& keys,
+ const std::vector<Datum>& segment_keys,
+                             const std::vector<Aggregate>& aggregates, bool use_threads,
+ bool naive = false) {
+ return RunGroupBy(arguments, keys, segment_keys, aggregates, use_threads,
+ /*segmented=*/false, naive);
}
-void ValidateGroupBy(const std::vector<Aggregate>& aggregates,
- std::vector<Datum> arguments, std::vector<Datum> keys) {
+Result<Datum> RunSegmentedGroupByImpl(const std::vector<Datum>& arguments,
+ const std::vector<Datum>& keys,
+ const std::vector<Datum>& segment_keys,
+ const std::vector<Aggregate>& aggregates,
+ bool use_threads, bool naive = false) {
+ return RunGroupBy(arguments, keys, segment_keys, aggregates, use_threads,
+ /*segmented=*/true, naive);
+}
+
+void ValidateGroupBy(GroupByFunction group_by, const std::vector<Aggregate>& aggregates,
+ std::vector<Datum> arguments, std::vector<Datum> keys,
+ bool naive = true) {
   ASSERT_OK_AND_ASSIGN(Datum expected, NaiveGroupBy(arguments, keys, aggregates));
- ASSERT_OK_AND_ASSIGN(Datum actual, RunGroupBy(arguments, keys, aggregates,
- /*use_threads=*/false));
+ ASSERT_OK_AND_ASSIGN(Datum actual, group_by(arguments, keys, {}, aggregates,
+ /*use_threads=*/false, naive));
ASSERT_OK(expected.make_array()->ValidateFull());
ValidateOutput(actual);
@@ -290,8 +428,9 @@ struct TestAggregate {
std::shared_ptr<FunctionOptions> options;
};
-Result<Datum> GroupByTest(const std::vector<Datum>& arguments,
+Result<Datum> GroupByTest(GroupByFunction group_by, const std::vector<Datum>& arguments,
const std::vector<Datum>& keys,
+ const std::vector<Datum>& segment_keys,
const std::vector<TestAggregate>& aggregates,
bool use_threads) {
std::vector<Aggregate> internal_aggregates;
@@ -301,27 +440,36 @@ Result<Datum> GroupByTest(const std::vector<Datum>& arguments,
{t_agg.function, t_agg.options, "agg_" + ToChars(idx),
t_agg.function});
idx = idx + 1;
}
- return RunGroupBy(arguments, keys, internal_aggregates, use_threads);
+  return group_by(arguments, keys, segment_keys, internal_aggregates, use_threads,
+ /*naive=*/false);
}
-} // namespace
+Result<Datum> GroupByTest(GroupByFunction group_by, const std::vector<Datum>& arguments,
+ const std::vector<Datum>& keys,
+ const std::vector<TestAggregate>& aggregates,
+ bool use_threads) {
+ return GroupByTest(group_by, arguments, keys, {}, aggregates, use_threads);
+}
-TEST(Grouper, SupportedKeys) {
- ASSERT_OK(Grouper::Make({boolean()}));
+template <typename GroupClass>
+void TestGroupClassSupportedKeys(
+    std::function<Result<std::unique_ptr<GroupClass>>(const std::vector<TypeHolder>&)>
+ make_func) {
+ ASSERT_OK(make_func({boolean()}));
- ASSERT_OK(Grouper::Make({int8(), uint16(), int32(), uint64()}));
+ ASSERT_OK(make_func({int8(), uint16(), int32(), uint64()}));
- ASSERT_OK(Grouper::Make({dictionary(int64(), utf8())}));
+ ASSERT_OK(make_func({dictionary(int64(), utf8())}));
- ASSERT_OK(Grouper::Make({float16(), float32(), float64()}));
+ ASSERT_OK(make_func({float16(), float32(), float64()}));
- ASSERT_OK(Grouper::Make({utf8(), binary(), large_utf8(), large_binary()}));
+ ASSERT_OK(make_func({utf8(), binary(), large_utf8(), large_binary()}));
- ASSERT_OK(Grouper::Make({fixed_size_binary(16), fixed_size_binary(32)}));
+ ASSERT_OK(make_func({fixed_size_binary(16), fixed_size_binary(32)}));
- ASSERT_OK(Grouper::Make({decimal128(32, 10), decimal256(76, 20)}));
+ ASSERT_OK(make_func({decimal128(32, 10), decimal256(76, 20)}));
- ASSERT_OK(Grouper::Make({date32(), date64()}));
+ ASSERT_OK(make_func({date32(), date64()}));
for (auto unit : {
TimeUnit::SECOND,
@@ -329,25 +477,257 @@ TEST(Grouper, SupportedKeys) {
TimeUnit::MICRO,
TimeUnit::NANO,
}) {
- ASSERT_OK(Grouper::Make({timestamp(unit), duration(unit)}));
+ ASSERT_OK(make_func({timestamp(unit), duration(unit)}));
}
ASSERT_OK(
-      Grouper::Make({day_time_interval(), month_interval(), month_day_nano_interval()}));
+      make_func({day_time_interval(), month_interval(), month_day_nano_interval()}));
+
+ ASSERT_OK(make_func({null()}));
- ASSERT_OK(Grouper::Make({null()}));
+ ASSERT_RAISES(NotImplemented, make_func({struct_({field("", int64())})}));
- ASSERT_RAISES(NotImplemented, Grouper::Make({struct_({field("",
int64())})}));
+ ASSERT_RAISES(NotImplemented, make_func({struct_({})}));
- ASSERT_RAISES(NotImplemented, Grouper::Make({struct_({})}));
+ ASSERT_RAISES(NotImplemented, make_func({list(int32())}));
- ASSERT_RAISES(NotImplemented, Grouper::Make({list(int32())}));
+ ASSERT_RAISES(NotImplemented, make_func({fixed_size_list(int32(), 5)}));
- ASSERT_RAISES(NotImplemented, Grouper::Make({fixed_size_list(int32(), 5)}));
+ ASSERT_RAISES(NotImplemented, make_func({dense_union({field("",
int32())})}));
+}
+
+void TestSegments(std::unique_ptr<RowSegmenter>& segmenter, const ExecSpan& batch,
+                  std::vector<Segment> expected_segments) {
+ int64_t offset = 0, segment_num = 0;
+ for (auto expected_segment : expected_segments) {
+ SCOPED_TRACE("segment #" + ToChars(segment_num++));
+    ASSERT_OK_AND_ASSIGN(auto segment, segmenter->GetNextSegment(batch, offset));
+ ASSERT_EQ(expected_segment, segment);
+ offset = segment.offset + segment.length;
+ }
+}
- ASSERT_RAISES(NotImplemented, Grouper::Make({dense_union({field("",
int32())})}));
+Result<std::unique_ptr<Grouper>> MakeGrouper(const std::vector<TypeHolder>&
key_types) {
+ return Grouper::Make(key_types, default_exec_context());
+}
+
+Result<std::unique_ptr<RowSegmenter>> MakeRowSegmenter(
+ const std::vector<TypeHolder>& key_types) {
+  return RowSegmenter::Make(key_types, /*nullable_keys=*/false, default_exec_context());
+}
+
+Result<std::unique_ptr<RowSegmenter>> MakeGenericSegmenter(
+ const std::vector<TypeHolder>& key_types) {
+ return MakeAnyKeysSegmenter(key_types, default_exec_context());
+}
+
+} // namespace
+
+TEST(RowSegmenter, SupportedKeys) {
+ TestGroupClassSupportedKeys<RowSegmenter>(MakeRowSegmenter);
+}
+
+TEST(RowSegmenter, Basics) {
+ std::vector<TypeHolder> bad_types2 = {int32(), float32()};
+ std::vector<TypeHolder> types2 = {int32(), int32()};
+ std::vector<TypeHolder> bad_types1 = {float32()};
+ std::vector<TypeHolder> types1 = {int32()};
+ std::vector<TypeHolder> types0 = {};
+ auto batch2 = ExecBatchFromJSON(types2, "[[1, 1], [1, 2], [2, 2]]");
+ auto batch1 = ExecBatchFromJSON(types1, "[[1], [1], [2]]");
+ ExecBatch batch0({}, 3);
+ {
+ SCOPED_TRACE("offset");
+ ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types0));
+ ExecSpan span0(batch0);
+ for (int64_t offset : {-1, 4}) {
+      EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+                                      HasSubstr("invalid grouping segmenter offset"),
+                                      segmenter->GetNextSegment(span0, offset));
+ }
+ }
+ {
+ SCOPED_TRACE("types0 segmenting of batch2");
+ ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types0));
+ ExecSpan span2(batch2);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch size 0
"),
+ segmenter->GetNextSegment(span2, 0));
+ ExecSpan span0(batch0);
+ TestSegments(segmenter, span0, {{0, 3, true, true}, {3, 0, true, true}});
+ }
+ {
+ SCOPED_TRACE("bad_types1 segmenting of batch1");
+ ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(bad_types1));
+ ExecSpan span1(batch1);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch value 0
of type "),
+ segmenter->GetNextSegment(span1, 0));
+ }
+ {
+ SCOPED_TRACE("types1 segmenting of batch2");
+ ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types1));
+ ExecSpan span2(batch2);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch size 1
"),
+ segmenter->GetNextSegment(span2, 0));
+ ExecSpan span1(batch1);
+    TestSegments(segmenter, span1,
+                 {{0, 2, false, true}, {2, 1, true, false}, {3, 0, true, true}});
+ }
+ {
+ SCOPED_TRACE("bad_types2 segmenting of batch2");
+ ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(bad_types2));
+ ExecSpan span2(batch2);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch value 1
of type "),
+ segmenter->GetNextSegment(span2, 0));
+ }
+ {
+ SCOPED_TRACE("types2 segmenting of batch1");
+ ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types2));
+ ExecSpan span1(batch1);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch size 2
"),
+ segmenter->GetNextSegment(span1, 0));
+ ExecSpan span2(batch2);
+ TestSegments(segmenter, span2,
+ {{0, 1, false, true},
+ {1, 1, false, false},
+ {2, 1, true, false},
+ {3, 0, true, true}});
+ }
+}
+
+TEST(RowSegmenter, NonOrdered) {
+ std::vector<TypeHolder> types = {int32()};
+ auto batch = ExecBatchFromJSON(types, "[[1], [1], [2], [1], [2]]");
+ ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types));
+ TestSegments(segmenter, ExecSpan(batch),
+ {{0, 2, false, true},
+ {2, 1, false, false},
+ {3, 1, false, false},
+ {4, 1, true, false},
+ {5, 0, true, true}});
+}
+
+TEST(RowSegmenter, EmptyBatches) {
+ std::vector<TypeHolder> types = {int32()};
+ std::vector<ExecBatch> batches = {
+ ExecBatchFromJSON(types, "[]"), ExecBatchFromJSON(types, "[]"),
+ ExecBatchFromJSON(types, "[[1]]"), ExecBatchFromJSON(types, "[]"),
+ ExecBatchFromJSON(types, "[[1]]"), ExecBatchFromJSON(types, "[]"),
+ ExecBatchFromJSON(types, "[[2], [2]]"), ExecBatchFromJSON(types, "[]"),
+ };
+ ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types));
+ TestSegments(segmenter, ExecSpan(batches[0]), {});
+ TestSegments(segmenter, ExecSpan(batches[1]), {});
+ TestSegments(segmenter, ExecSpan(batches[2]), {{0, 1, true, true}});
+ TestSegments(segmenter, ExecSpan(batches[3]), {});
+ TestSegments(segmenter, ExecSpan(batches[4]), {{0, 1, true, true}});
+ TestSegments(segmenter, ExecSpan(batches[5]), {});
+ TestSegments(segmenter, ExecSpan(batches[6]), {{0, 2, true, false}});
+ TestSegments(segmenter, ExecSpan(batches[7]), {});
+}
+
+TEST(RowSegmenter, MultipleSegments) {
+ std::vector<TypeHolder> types = {int32()};
+ auto batch = ExecBatchFromJSON(types, "[[1], [1], [2], [5], [3], [3], [5],
[5], [4]]");
+ ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types));
+ TestSegments(segmenter, ExecSpan(batch),
+ {{0, 2, false, true},
+ {2, 1, false, false},
+ {3, 1, false, false},
+ {4, 2, false, false},
+ {6, 2, false, false},
+ {8, 1, true, false},
+ {9, 0, true, true}});
+}
+
+namespace {
+
+void TestRowSegmenterConstantBatch(
+ std::function<ArgShape(size_t i)> shape_func,
+    std::function<Result<std::unique_ptr<RowSegmenter>>(const std::vector<TypeHolder>&)>
+ make_segmenter) {
+ constexpr size_t n = 3, repetitions = 3;
+ std::vector<TypeHolder> types = {int32(), int32(), int32()};
+ std::vector<ArgShape> shapes(n);
+ for (size_t i = 0; i < n; i++) shapes[i] = shape_func(i);
+  auto full_batch = ExecBatchFromJSON(types, shapes, "[[1, 1, 1], [1, 1, 1], [1, 1, 1]]");
+ auto test_by_size = [&](size_t size) -> Status {
+ SCOPED_TRACE("constant-batch with " + ToChars(size) + " key(s)");
+ std::vector<Datum> values(full_batch.values.begin(),
+ full_batch.values.begin() + size);
+ ExecBatch batch(values, full_batch.length);
+ std::vector<TypeHolder> key_types(types.begin(), types.begin() + size);
+ ARROW_ASSIGN_OR_RAISE(auto segmenter, make_segmenter(key_types));
+ for (size_t i = 0; i < repetitions; i++) {
+      TestSegments(segmenter, ExecSpan(batch), {{0, 3, true, true}, {3, 0, true, true}});
+ ARROW_RETURN_NOT_OK(segmenter->Reset());
+ }
+ return Status::OK();
+ };
+ for (size_t i = 0; i <= 3; i++) {
+ ASSERT_OK(test_by_size(i));
+ }
}
+} // namespace
+
+TEST(RowSegmenter, ConstantArrayBatch) {
+ TestRowSegmenterConstantBatch([](size_t i) { return ArgShape::ARRAY; },
+ MakeRowSegmenter);
+}
+
+TEST(RowSegmenter, ConstantScalarBatch) {
+ TestRowSegmenterConstantBatch([](size_t i) { return ArgShape::SCALAR; },
+ MakeRowSegmenter);
+}
+
+TEST(RowSegmenter, ConstantMixedBatch) {
+ TestRowSegmenterConstantBatch(
+ [](size_t i) { return i % 2 == 0 ? ArgShape::SCALAR : ArgShape::ARRAY; },
+ MakeRowSegmenter);
+}
+
+TEST(RowSegmenter, ConstantArrayBatchWithAnyKeysSegmenter) {
+ TestRowSegmenterConstantBatch([](size_t i) { return ArgShape::ARRAY; },
+ MakeGenericSegmenter);
+}
+
+TEST(RowSegmenter, ConstantScalarBatchWithAnyKeysSegmenter) {
+ TestRowSegmenterConstantBatch([](size_t i) { return ArgShape::SCALAR; },
+ MakeGenericSegmenter);
+}
+
+TEST(RowSegmenter, ConstantMixedBatchWithAnyKeysSegmenter) {
+ TestRowSegmenterConstantBatch(
+ [](size_t i) { return i % 2 == 0 ? ArgShape::SCALAR : ArgShape::ARRAY; },
+ MakeGenericSegmenter);
+}
+
+TEST(RowSegmenter, RowConstantBatch) {
+ constexpr size_t n = 3;
+ std::vector<TypeHolder> types = {int32(), int32(), int32()};
+  auto full_batch = ExecBatchFromJSON(types, "[[1, 1, 1], [2, 2, 2], [3, 3, 3]]");
+ std::vector<Segment> expected_segments_for_size_0 = {{0, 3, true, true},
+ {3, 0, true, true}};
+ std::vector<Segment> expected_segments = {
+      {0, 1, false, true}, {1, 1, false, false}, {2, 1, true, false}, {3, 0, true, true}};
+ auto test_by_size = [&](size_t size) -> Status {
+ SCOPED_TRACE("constant-batch with " + ToChars(size) + " key(s)");
+ std::vector<Datum> values(full_batch.values.begin(),
+ full_batch.values.begin() + size);
+ ExecBatch batch(values, full_batch.length);
+ std::vector<TypeHolder> key_types(types.begin(), types.begin() + size);
+ ARROW_ASSIGN_OR_RAISE(auto segmenter, MakeRowSegmenter(key_types));
+ TestSegments(segmenter, ExecSpan(batch),
+ size == 0 ? expected_segments_for_size_0 : expected_segments);
+ return Status::OK();
+ };
+ for (size_t i = 0; i <= n; i++) {
+ ASSERT_OK(test_by_size(i));
+ }
+}
+
+TEST(Grouper, SupportedKeys) { TestGroupClassSupportedKeys<Grouper>(MakeGrouper); }
+
struct TestGrouper {
   explicit TestGrouper(std::vector<TypeHolder> types, std::vector<ArgShape> shapes = {})
: types_(std::move(types)), shapes_(std::move(shapes)) {
@@ -783,7 +1163,49 @@ TEST(Grouper, ScalarValues) {
}
}
-TEST(GroupBy, Errors) {
+void TestSegmentKey(GroupByFunction group_by, const std::shared_ptr<Table>& table,
+ Datum output, const std::vector<Datum>& segment_keys);
+
+class GroupBy : public ::testing::TestWithParam<GroupByFunction> {
+ public:
+ void ValidateGroupBy(const std::vector<Aggregate>& aggregates,
+ std::vector<Datum> arguments, std::vector<Datum> keys,
+ bool naive = true) {
+ compute::ValidateGroupBy(GetParam(), aggregates, arguments, keys, naive);
+ }
+
+ Result<Datum> GroupByTest(const std::vector<Datum>& arguments,
+ const std::vector<Datum>& keys,
+ const std::vector<Datum>& segment_keys,
+ const std::vector<TestAggregate>& aggregates,
+ bool use_threads) {
+    return compute::GroupByTest(GetParam(), arguments, keys, segment_keys, aggregates,
+ use_threads);
+ }
+
+ Result<Datum> GroupByTest(const std::vector<Datum>& arguments,
+ const std::vector<Datum>& keys,
+ const std::vector<TestAggregate>& aggregates,
+ bool use_threads) {
+    return compute::GroupByTest(GetParam(), arguments, keys, aggregates, use_threads);
+ }
+
+ Result<Datum> AltGroupBy(const std::vector<Datum>& arguments,
+ const std::vector<Datum>& keys,
+ const std::vector<Datum>& segment_keys,
+ const std::vector<Aggregate>& aggregates,
+ bool use_threads = false) {
+    return GetParam()(arguments, keys, segment_keys, aggregates, use_threads,
+                      /*naive=*/false);
+ }
+
+ void TestSegmentKey(const std::shared_ptr<Table>& table, Datum output,
+ const std::vector<Datum>& segment_keys) {
+ return compute::TestSegmentKey(GetParam(), table, output, segment_keys);
+ }
+};
+
+TEST_P(GroupBy, Errors) {
auto batch = RecordBatchFromJSON(
schema({field("argument", float64()), field("group_id", uint32())}), R"([
[1.0, 1],
@@ -804,7 +1226,7 @@ TEST(GroupBy, Errors) {
                           HasSubstr("Direct execution of HASH_AGGREGATE functions")));
}
-TEST(GroupBy, NoBatches) {
+TEST_P(GroupBy, NoBatches) {
// Regression test for ARROW-14583: handle when no batches are
// passed to the group by node before finalizing
auto table =
@@ -851,7 +1273,7 @@ void SortBy(std::vector<std::string> names, Datum* aggregated_and_grouped) {
}
} // namespace
-TEST(GroupBy, CountOnly) {
+TEST_P(GroupBy, CountOnly) {
for (bool use_threads : {true, false}) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
@@ -897,7 +1319,7 @@ TEST(GroupBy, CountOnly) {
}
}
-TEST(GroupBy, CountScalar) {
+TEST_P(GroupBy, CountScalar) {
BatchesWithSchema input;
input.batches = {
       ExecBatchFromJSON({int32(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY},
@@ -937,7 +1359,7 @@ TEST(GroupBy, CountScalar) {
}
}
-TEST(GroupBy, SumOnly) {
+TEST_P(GroupBy, SumOnly) {
for (bool use_threads : {true, false}) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
@@ -983,7 +1405,7 @@ TEST(GroupBy, SumOnly) {
}
}
-TEST(GroupBy, SumMeanProductDecimal) {
+TEST_P(GroupBy, SumMeanProductDecimal) {
auto in_schema = schema({
field("argument0", decimal128(3, 2)),
field("argument1", decimal256(3, 2)),
@@ -1057,7 +1479,7 @@ TEST(GroupBy, SumMeanProductDecimal) {
}
}
-TEST(GroupBy, MeanOnly) {
+TEST_P(GroupBy, MeanOnly) {
for (bool use_threads : {true, false}) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
@@ -1108,7 +1530,7 @@ TEST(GroupBy, MeanOnly) {
}
}
-TEST(GroupBy, SumMeanProductScalar) {
+TEST_P(GroupBy, SumMeanProductScalar) {
BatchesWithSchema input;
input.batches = {
       ExecBatchFromJSON({int32(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY},
@@ -1146,7 +1568,7 @@ TEST(GroupBy, SumMeanProductScalar) {
}
}
-TEST(GroupBy, VarianceAndStddev) {
+TEST_P(GroupBy, VarianceAndStddev) {
auto batch = RecordBatchFromJSON(
schema({field("argument", int32()), field("key", int64())}), R"([
[1, 1],
@@ -1170,6 +1592,7 @@ TEST(GroupBy, VarianceAndStddev) {
{
batch->GetColumnByName("key"),
},
+ {},
{
{"hash_variance", nullptr},
{"hash_stddev", nullptr},
@@ -1212,6 +1635,7 @@ TEST(GroupBy, VarianceAndStddev) {
{
batch->GetColumnByName("key"),
},
+ {},
{
                             {"hash_variance", nullptr},
                             {"hash_stddev", nullptr},
@@ -1243,6 +1667,7 @@ TEST(GroupBy, VarianceAndStddev) {
{
batch->GetColumnByName("key"),
},
+ {},
{
{"hash_variance", variance_options},
{"hash_stddev", variance_options},
@@ -1264,7 +1689,7 @@ TEST(GroupBy, VarianceAndStddev) {
/*verbose=*/true);
}
-TEST(GroupBy, VarianceAndStddevDecimal) {
+TEST_P(GroupBy, VarianceAndStddevDecimal) {
auto batch = RecordBatchFromJSON(
       schema({field("argument0", decimal128(3, 2)), field("argument1", decimal128(3, 2)),
field("key", int64())}),
@@ -1290,6 +1715,7 @@ TEST(GroupBy, VarianceAndStddevDecimal) {
{
batch->GetColumnByName("key"),
},
+ {},
{
{"hash_variance", nullptr},
{"hash_stddev", nullptr},
@@ -1314,7 +1740,7 @@ TEST(GroupBy, VarianceAndStddevDecimal) {
/*verbose=*/true);
}
-TEST(GroupBy, TDigest) {
+TEST_P(GroupBy, TDigest) {
auto batch = RecordBatchFromJSON(
schema({field("argument", float64()), field("key", int64())}), R"([
[1, 1],
@@ -1359,6 +1785,7 @@ TEST(GroupBy, TDigest) {
{
batch->GetColumnByName("key"),
},
+ {},
{
{"hash_tdigest", nullptr},
{"hash_tdigest", options1},
@@ -1390,7 +1817,7 @@ TEST(GroupBy, TDigest) {
/*verbose=*/true);
}
-TEST(GroupBy, TDigestDecimal) {
+TEST_P(GroupBy, TDigestDecimal) {
auto batch = RecordBatchFromJSON(
       schema({field("argument0", decimal128(3, 2)), field("argument1", decimal256(3, 2)),
field("key", int64())}),
@@ -1433,7 +1860,7 @@ TEST(GroupBy, TDigestDecimal) {
/*verbose=*/true);
}
-TEST(GroupBy, ApproximateMedian) {
+TEST_P(GroupBy, ApproximateMedian) {
for (const auto& type : {float64(), int8()}) {
auto batch =
         RecordBatchFromJSON(schema({field("argument", type), field("key", int64())}), R"([
@@ -1471,6 +1898,7 @@ TEST(GroupBy, ApproximateMedian) {
{
batch->GetColumnByName("key"),
},
+ {},
{
{"hash_approximate_median", options},
{"hash_approximate_median", keep_nulls},
@@ -1498,7 +1926,7 @@ TEST(GroupBy, ApproximateMedian) {
}
}
-TEST(GroupBy, StddevVarianceTDigestScalar) {
+TEST_P(GroupBy, StddevVarianceTDigestScalar) {
BatchesWithSchema input;
input.batches = {
ExecBatchFromJSON({int32(), float32(), int64()},
@@ -1547,7 +1975,7 @@ TEST(GroupBy, StddevVarianceTDigestScalar) {
}
}
-TEST(GroupBy, VarianceOptions) {
+TEST_P(GroupBy, VarianceOptions) {
BatchesWithSchema input;
input.batches = {
ExecBatchFromJSON(
@@ -1641,7 +2069,7 @@ TEST(GroupBy, VarianceOptions) {
}
}
-TEST(GroupBy, MinMaxOnly) {
+TEST_P(GroupBy, MinMaxOnly) {
auto in_schema = schema({
field("argument", float64()),
field("argument1", null()),
@@ -1711,7 +2139,7 @@ TEST(GroupBy, MinMaxOnly) {
}
}
-TEST(GroupBy, MinMaxTypes) {
+TEST_P(GroupBy, MinMaxTypes) {
std::vector<std::shared_ptr<DataType>> types;
types.insert(types.end(), NumericTypes().begin(), NumericTypes().end());
types.insert(types.end(), TemporalTypes().begin(), TemporalTypes().end());
@@ -1799,7 +2227,7 @@ TEST(GroupBy, MinMaxTypes) {
}
}
-TEST(GroupBy, MinMaxDecimal) {
+TEST_P(GroupBy, MinMaxDecimal) {
auto in_schema = schema({
field("argument0", decimal128(3, 2)),
field("argument1", decimal256(3, 2)),
@@ -1866,7 +2294,7 @@ TEST(GroupBy, MinMaxDecimal) {
}
}
-TEST(GroupBy, MinMaxBinary) {
+TEST_P(GroupBy, MinMaxBinary) {
for (bool use_threads : {true, false}) {
for (const auto& ty : BaseBinaryTypes()) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
@@ -1917,7 +2345,7 @@ TEST(GroupBy, MinMaxBinary) {
}
}
-TEST(GroupBy, MinMaxFixedSizeBinary) {
+TEST_P(GroupBy, MinMaxFixedSizeBinary) {
const auto ty = fixed_size_binary(3);
for (bool use_threads : {true, false}) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
@@ -1967,7 +2395,7 @@ TEST(GroupBy, MinMaxFixedSizeBinary) {
}
}
-TEST(GroupBy, MinOrMax) {
+TEST_P(GroupBy, MinOrMax) {
auto table =
       TableFromJSON(schema({field("argument", float64()), field("key", int64())}), {R"([
[1.0, 1],
@@ -2020,7 +2448,7 @@ TEST(GroupBy, MinOrMax) {
/*verbose=*/true);
}
-TEST(GroupBy, MinMaxScalar) {
+TEST_P(GroupBy, MinMaxScalar) {
BatchesWithSchema input;
input.batches = {
       ExecBatchFromJSON({int32(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY},
@@ -2053,7 +2481,7 @@ TEST(GroupBy, MinMaxScalar) {
}
}
-TEST(GroupBy, AnyAndAll) {
+TEST_P(GroupBy, AnyAndAll) {
for (bool use_threads : {true, false}) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
@@ -2087,7 +2515,7 @@ TEST(GroupBy, AnyAndAll) {
auto keep_nulls_min_count =
       std::make_shared<ScalarAggregateOptions>(/*skip_nulls=*/false, /*min_count=*/3);
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
- RunGroupBy(
+ AltGroupBy(
{
table->GetColumnByName("argument"),
table->GetColumnByName("argument"),
@@ -2098,7 +2526,7 @@ TEST(GroupBy, AnyAndAll) {
table->GetColumnByName("argument"),
table->GetColumnByName("argument"),
},
- {table->GetColumnByName("key")},
+ {table->GetColumnByName("key")}, {},
{
{"hash_any", no_min, "agg_0", "hash_any"},
{"hash_any", min_count, "agg_1", "hash_any"},
@@ -2142,7 +2570,7 @@ TEST(GroupBy, AnyAndAll) {
}
}
-TEST(GroupBy, AnyAllScalar) {
+TEST_P(GroupBy, AnyAllScalar) {
BatchesWithSchema input;
input.batches = {
       ExecBatchFromJSON({boolean(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY},
@@ -2183,7 +2611,7 @@ TEST(GroupBy, AnyAllScalar) {
}
}
-TEST(GroupBy, CountDistinct) {
+TEST_P(GroupBy, CountDistinct) {
auto all = std::make_shared<CountOptions>(CountOptions::ALL);
auto only_valid = std::make_shared<CountOptions>(CountOptions::ONLY_VALID);
auto only_null = std::make_shared<CountOptions>(CountOptions::ONLY_NULL);
@@ -2223,7 +2651,7 @@ TEST(GroupBy, CountDistinct) {
ASSERT_OK_AND_ASSIGN(
Datum aggregated_and_grouped,
- RunGroupBy(
+ AltGroupBy(
{
table->GetColumnByName("argument"),
table->GetColumnByName("argument"),
@@ -2232,6 +2660,7 @@ TEST(GroupBy, CountDistinct) {
{
table->GetColumnByName("key"),
},
+ {},
{
{"hash_count_distinct", all, "agg_0", "hash_count_distinct"},
              {"hash_count_distinct", only_valid, "agg_1", "hash_count_distinct"},
@@ -2290,7 +2719,7 @@ TEST(GroupBy, CountDistinct) {
ASSERT_OK_AND_ASSIGN(
aggregated_and_grouped,
- RunGroupBy(
+ AltGroupBy(
{
table->GetColumnByName("argument"),
table->GetColumnByName("argument"),
@@ -2299,6 +2728,7 @@ TEST(GroupBy, CountDistinct) {
{
table->GetColumnByName("key"),
},
+ {},
{
{"hash_count_distinct", all, "agg_0", "hash_count_distinct"},
              {"hash_count_distinct", only_valid, "agg_1", "hash_count_distinct"},
@@ -2337,7 +2767,7 @@ TEST(GroupBy, CountDistinct) {
ASSERT_OK_AND_ASSIGN(
aggregated_and_grouped,
- RunGroupBy(
+ AltGroupBy(
{
table->GetColumnByName("argument"),
table->GetColumnByName("argument"),
@@ -2346,6 +2776,7 @@ TEST(GroupBy, CountDistinct) {
{
table->GetColumnByName("key"),
},
+ {},
{
{"hash_count_distinct", all, "agg_0", "hash_count_distinct"},
              {"hash_count_distinct", only_valid, "agg_1", "hash_count_distinct"},
@@ -2370,7 +2801,7 @@ TEST(GroupBy, CountDistinct) {
}
}
-TEST(GroupBy, Distinct) {
+TEST_P(GroupBy, Distinct) {
auto all = std::make_shared<CountOptions>(CountOptions::ALL);
auto only_valid = std::make_shared<CountOptions>(CountOptions::ONLY_VALID);
auto only_null = std::make_shared<CountOptions>(CountOptions::ONLY_NULL);
@@ -2409,7 +2840,7 @@ TEST(GroupBy, Distinct) {
])"});
ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
- RunGroupBy(
+ AltGroupBy(
{
table->GetColumnByName("argument"),
table->GetColumnByName("argument"),
@@ -2418,6 +2849,7 @@ TEST(GroupBy, Distinct) {
{
table->GetColumnByName("key"),
},
+ {},
{
                               {"hash_distinct", all, "agg_0", "hash_distinct"},
                               {"hash_distinct", only_valid, "agg_1", "hash_distinct"},
@@ -2482,7 +2914,7 @@ TEST(GroupBy, Distinct) {
])",
});
ASSERT_OK_AND_ASSIGN(aggregated_and_grouped,
- RunGroupBy(
+ AltGroupBy(
{
table->GetColumnByName("argument"),
table->GetColumnByName("argument"),
@@ -2491,6 +2923,7 @@ TEST(GroupBy, Distinct) {
{
table->GetColumnByName("key"),
},
+ {},
{
                               {"hash_distinct", all, "agg_0", "hash_distinct"},
                               {"hash_distinct", only_valid, "agg_1", "hash_distinct"},
@@ -2513,7 +2946,7 @@ TEST(GroupBy, Distinct) {
}
}
-TEST(GroupBy, OneMiscTypes) {
+TEST_P(GroupBy, OneMiscTypes) {
auto in_schema = schema({
field("floats", float64()),
field("nulls", null()),
@@ -2628,7 +3061,7 @@ TEST(GroupBy, OneMiscTypes) {
}
}
-TEST(GroupBy, OneNumericTypes) {
+TEST_P(GroupBy, OneNumericTypes) {
std::vector<std::shared_ptr<DataType>> types;
types.insert(types.end(), NumericTypes().begin(), NumericTypes().end());
types.insert(types.end(), TemporalTypes().begin(), TemporalTypes().end());
@@ -2713,7 +3146,7 @@ TEST(GroupBy, OneNumericTypes) {
}
}
-TEST(GroupBy, OneBinaryTypes) {
+TEST_P(GroupBy, OneBinaryTypes) {
for (bool use_threads : {true, false}) {
for (const auto& type : BaseBinaryTypes()) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
@@ -2761,7 +3194,7 @@ TEST(GroupBy, OneBinaryTypes) {
}
}
-TEST(GroupBy, OneScalar) {
+TEST_P(GroupBy, OneScalar) {
BatchesWithSchema input;
input.batches = {
       ExecBatchFromJSON({int32(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY},
@@ -2791,7 +3224,7 @@ TEST(GroupBy, OneScalar) {
}
}
-TEST(GroupBy, ListNumeric) {
+TEST_P(GroupBy, ListNumeric) {
for (const auto& type : NumericTypes()) {
for (auto use_threads : {true, false}) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
@@ -2829,13 +3262,14 @@ TEST(GroupBy, ListNumeric) {
])"});
ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
- RunGroupBy(
+ AltGroupBy(
{
table->GetColumnByName("argument"),
},
{
table->GetColumnByName("key"),
},
+ {},
{
                           {"hash_list", nullptr, "agg_0", "hash_list"},
},
@@ -2900,13 +3334,14 @@ TEST(GroupBy, ListNumeric) {
])"});
ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
- RunGroupBy(
+ AltGroupBy(
{
table->GetColumnByName("argument"),
},
{
table->GetColumnByName("key"),
},
+ {},
{
                           {"hash_list", nullptr, "agg_0", "hash_list"},
},
@@ -2941,7 +3376,7 @@ TEST(GroupBy, ListNumeric) {
}
}
-TEST(GroupBy, ListBinaryTypes) {
+TEST_P(GroupBy, ListBinaryTypes) {
for (bool use_threads : {true, false}) {
for (const auto& type : BaseBinaryTypes()) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
@@ -2969,13 +3404,14 @@ TEST(GroupBy, ListBinaryTypes) {
])"});
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
- RunGroupBy(
+ AltGroupBy(
{
table->GetColumnByName("argument0"),
},
{
table->GetColumnByName("key"),
},
+ {},
{
                           {"hash_list", nullptr, "agg_0", "hash_list"},
},
@@ -3031,13 +3467,14 @@ TEST(GroupBy, ListBinaryTypes) {
])"});
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
- RunGroupBy(
+ AltGroupBy(
{
table->GetColumnByName("argument0"),
},
{
table->GetColumnByName("key"),
},
+ {},
{
                           {"hash_list", nullptr, "agg_0", "hash_list"},
},
@@ -3073,7 +3510,7 @@ TEST(GroupBy, ListBinaryTypes) {
}
}
-TEST(GroupBy, ListMiscTypes) {
+TEST_P(GroupBy, ListMiscTypes) {
auto in_schema = schema({
field("floats", float64()),
field("nulls", null()),
@@ -3231,7 +3668,7 @@ TEST(GroupBy, ListMiscTypes) {
}
}
-TEST(GroupBy, CountAndSum) {
+TEST_P(GroupBy, CountAndSum) {
auto batch = RecordBatchFromJSON(
schema({field("argument", float64()), field("key", int64())}), R"([
[1.0, 1],
@@ -3253,7 +3690,7 @@ TEST(GroupBy, CountAndSum) {
       std::make_shared<ScalarAggregateOptions>(/*skip_nulls=*/true, /*min_count=*/3);
ASSERT_OK_AND_ASSIGN(
Datum aggregated_and_grouped,
- RunGroupBy(
+ AltGroupBy(
{
              // NB: passing an argument twice or also using it as a key is legal
batch->GetColumnByName("argument"),
@@ -3266,6 +3703,7 @@ TEST(GroupBy, CountAndSum) {
{
batch->GetColumnByName("key"),
},
+ {},
{
{"hash_count", count_opts, "agg_0", "hash_count"},
{"hash_count", count_nulls_opts, "agg_1", "hash_count"},
@@ -3298,7 +3736,7 @@ TEST(GroupBy, CountAndSum) {
/*verbose=*/true);
}
-TEST(GroupBy, StandAloneNullaryCount) {
+TEST_P(GroupBy, StandAloneNullaryCount) {
auto batch = RecordBatchFromJSON(
schema({field("argument", float64()), field("key", int64())}), R"([
[1.0, 1],
@@ -3314,13 +3752,14 @@ TEST(GroupBy, StandAloneNullaryCount) {
])");
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
- RunGroupBy(
+ AltGroupBy(
// zero arguments for aggregations because only the
// nullary hash_count_all aggregation is present
{},
{
batch->GetColumnByName("key"),
},
+ {},
{
{"hash_count_all", "hash_count_all"},
}));
@@ -3339,7 +3778,7 @@ TEST(GroupBy, StandAloneNullaryCount) {
/*verbose=*/true);
}
-TEST(GroupBy, Product) {
+TEST_P(GroupBy, Product) {
auto batch = RecordBatchFromJSON(
schema({field("argument", float64()), field("key", int64())}), R"([
[-1.0, 1],
@@ -3357,7 +3796,7 @@ TEST(GroupBy, Product) {
auto min_count =
       std::make_shared<ScalarAggregateOptions>(/*skip_nulls=*/true, /*min_count=*/3);
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
- RunGroupBy(
+ AltGroupBy(
{
batch->GetColumnByName("argument"),
batch->GetColumnByName("key"),
@@ -3366,6 +3805,7 @@ TEST(GroupBy, Product) {
{
batch->GetColumnByName("key"),
},
+ {},
{
                           {"hash_product", nullptr, "agg_0", "hash_product"},
                           {"hash_product", nullptr, "agg_1", "hash_product"},
@@ -3395,13 +3835,14 @@ TEST(GroupBy, Product) {
])");
ASSERT_OK_AND_ASSIGN(aggregated_and_grouped,
- RunGroupBy(
+ AltGroupBy(
{
batch->GetColumnByName("argument"),
},
{
batch->GetColumnByName("key"),
},
+ {},
{
                           {"hash_product", nullptr, "agg_0", "hash_product"},
}));
@@ -3415,7 +3856,7 @@ TEST(GroupBy, Product) {
/*verbose=*/true);
}
-TEST(GroupBy, SumMeanProductKeepNulls) {
+TEST_P(GroupBy, SumMeanProductKeepNulls) {
auto batch = RecordBatchFromJSON(
schema({field("argument", float64()), field("key", int64())}), R"([
[-1.0, 1],
@@ -3434,7 +3875,7 @@ TEST(GroupBy, SumMeanProductKeepNulls) {
auto min_count =
       std::make_shared<ScalarAggregateOptions>(/*skip_nulls=*/false, /*min_count=*/3);
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
- RunGroupBy(
+ AltGroupBy(
{
batch->GetColumnByName("argument"),
batch->GetColumnByName("argument"),
@@ -3446,6 +3887,7 @@ TEST(GroupBy, SumMeanProductKeepNulls) {
{
batch->GetColumnByName("key"),
},
+ {},
{
{"hash_sum", keep_nulls, "agg_0", "hash_sum"},
{"hash_sum", min_count, "agg_1", "hash_sum"},
@@ -3474,7 +3916,7 @@ TEST(GroupBy, SumMeanProductKeepNulls) {
/*verbose=*/true);
}
-TEST(GroupBy, SumOnlyStringAndDictKeys) {
+TEST_P(GroupBy, SumOnlyStringAndDictKeys) {
for (auto key_type : {utf8(), dictionary(int32(), utf8())}) {
SCOPED_TRACE("key type: " + key_type->ToString());
@@ -3494,7 +3936,7 @@ TEST(GroupBy, SumOnlyStringAndDictKeys) {
ASSERT_OK_AND_ASSIGN(
Datum aggregated_and_grouped,
-      RunGroupBy({batch->GetColumnByName("agg_0")}, {batch->GetColumnByName("key")},
+      AltGroupBy({batch->GetColumnByName("agg_0")}, {batch->GetColumnByName("key")}, {},
{
{"hash_sum", nullptr, "agg_0", "hash_sum"},
}));
@@ -3515,7 +3957,7 @@ TEST(GroupBy, SumOnlyStringAndDictKeys) {
}
}
-TEST(GroupBy, ConcreteCaseWithValidateGroupBy) {
+TEST_P(GroupBy, ConcreteCaseWithValidateGroupBy) {
auto batch =
       RecordBatchFromJSON(schema({field("agg_0", float64()), field("key", utf8())}), R"([
[1.0, "alfa"],
@@ -3551,7 +3993,7 @@ TEST(GroupBy, ConcreteCaseWithValidateGroupBy) {
}
// Count nulls/non_nulls from record batch with no nulls
-TEST(GroupBy, CountNull) {
+TEST_P(GroupBy, CountNull) {
auto batch =
       RecordBatchFromJSON(schema({field("agg_0", float64()), field("key", utf8())}), R"([
[1.0, "alfa"],
@@ -3574,7 +4016,7 @@ TEST(GroupBy, CountNull) {
}
}
-TEST(GroupBy, RandomArraySum) {
+TEST_P(GroupBy, RandomArraySum) {
std::shared_ptr<ScalarAggregateOptions> options =
       std::make_shared<ScalarAggregateOptions>(/*skip_nulls=*/true, /*min_count=*/0);
for (int64_t length : {1 << 10, 1 << 12, 1 << 15}) {
@@ -3592,12 +4034,13 @@ TEST(GroupBy, RandomArraySum) {
{
{"hash_sum", options, "agg_0", "hash_sum"},
},
- {batch->GetColumnByName("agg_0")}, {batch->GetColumnByName("key")});
+ {batch->GetColumnByName("agg_0")}, {batch->GetColumnByName("key")},
+ /*naive=*/false);
}
}
}
-TEST(GroupBy, WithChunkedArray) {
+TEST_P(GroupBy, WithChunkedArray) {
auto table =
       TableFromJSON(schema({field("argument", float64()), field("key", int64())}),
{R"([{"argument": 1.0, "key": 1},
@@ -3613,7 +4056,7 @@ TEST(GroupBy, WithChunkedArray) {
{"argument": null, "key": 3}
])"});
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
- RunGroupBy(
+ AltGroupBy(
{
table->GetColumnByName("argument"),
table->GetColumnByName("argument"),
@@ -3622,6 +4065,7 @@ TEST(GroupBy, WithChunkedArray) {
{
table->GetColumnByName("key"),
},
+ {},
{
{"hash_count", nullptr, "agg_0", "hash_count"},
{"hash_sum", nullptr, "agg_1", "hash_sum"},
@@ -3647,19 +4091,20 @@ TEST(GroupBy, WithChunkedArray) {
/*verbose=*/true);
}
-TEST(GroupBy, MinMaxWithNewGroupsInChunkedArray) {
+TEST_P(GroupBy, MinMaxWithNewGroupsInChunkedArray) {
auto table = TableFromJSON(
schema({field("argument", int64()), field("key", int64())}),
{R"([{"argument": 1, "key": 0}])", R"([{"argument": 0, "key": 1}])"});
ScalarAggregateOptions count_options;
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
- RunGroupBy(
+ AltGroupBy(
{
table->GetColumnByName("argument"),
},
{
table->GetColumnByName("key"),
},
+ {},
{
{"hash_min_max", nullptr, "agg_0",
"hash_min_max"},
}));
@@ -3679,7 +4124,7 @@ TEST(GroupBy, MinMaxWithNewGroupsInChunkedArray) {
/*verbose=*/true);
}
-TEST(GroupBy, SmallChunkSizeSumOnly) {
+TEST_P(GroupBy, SmallChunkSizeSumOnly) {
auto batch = RecordBatchFromJSON(
schema({field("argument", float64()), field("key", int64())}), R"([
[1.0, 1],
@@ -3693,13 +4138,13 @@ TEST(GroupBy, SmallChunkSizeSumOnly) {
[0.75, null],
[null, 3]
])");
- ASSERT_OK_AND_ASSIGN(
- Datum aggregated_and_grouped,
-      RunGroupBy({batch->GetColumnByName("argument")}, {batch->GetColumnByName("key")},
- {
- {"hash_sum", nullptr, "agg_0", "hash_sum"},
- },
- small_chunksize_context()));
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ AltGroupBy({batch->GetColumnByName("argument")},
+ {batch->GetColumnByName("key")}, {},
+ {
+                                      {"hash_sum", nullptr, "agg_0", "hash_sum"},
+ },
+ small_chunksize_context()));
AssertDatumsEqual(ArrayFromJSON(struct_({
field("hash_sum", float64()),
field("key_0", int64()),
@@ -3714,7 +4159,7 @@ TEST(GroupBy, SmallChunkSizeSumOnly) {
/*verbose=*/true);
}
-TEST(GroupBy, CountWithNullType) {
+TEST_P(GroupBy, CountWithNullType) {
auto table =
       TableFromJSON(schema({field("argument", null()), field("key", int64())}), {R"([
[null, 1],
@@ -3772,7 +4217,7 @@ TEST(GroupBy, CountWithNullType) {
}
}
-TEST(GroupBy, CountWithNullTypeEmptyTable) {
+TEST_P(GroupBy, CountWithNullTypeEmptyTable) {
   auto table = TableFromJSON(schema({field("argument", null()), field("key", int64())}),
{R"([])"});
@@ -3803,7 +4248,7 @@ TEST(GroupBy, CountWithNullTypeEmptyTable) {
}
}
-TEST(GroupBy, SingleNullTypeKey) {
+TEST_P(GroupBy, SingleNullTypeKey) {
auto table =
       TableFromJSON(schema({field("argument", int64()), field("key", null())}), {R"([
[1, null],
@@ -3860,7 +4305,7 @@ TEST(GroupBy, SingleNullTypeKey) {
}
}
-TEST(GroupBy, MultipleKeysIncludesNullType) {
+TEST_P(GroupBy, MultipleKeysIncludesNullType) {
   auto table = TableFromJSON(schema({field("argument", float64()), field("key_0", utf8()),
field("key_1", null())}),
{R"([
@@ -3920,7 +4365,7 @@ TEST(GroupBy, MultipleKeysIncludesNullType) {
}
}
-TEST(GroupBy, SumNullType) {
+TEST_P(GroupBy, SumNullType) {
auto table =
       TableFromJSON(schema({field("argument", null()), field("key", int64())}), {R"([
[null, 1],
@@ -3986,7 +4431,7 @@ TEST(GroupBy, SumNullType) {
}
}
-TEST(GroupBy, ProductNullType) {
+TEST_P(GroupBy, ProductNullType) {
auto table =
       TableFromJSON(schema({field("argument", null()), field("key", int64())}), {R"([
[null, 1],
@@ -4052,7 +4497,7 @@ TEST(GroupBy, ProductNullType) {
}
}
-TEST(GroupBy, MeanNullType) {
+TEST_P(GroupBy, MeanNullType) {
auto table =
       TableFromJSON(schema({field("argument", null()), field("key", int64())}), {R"([
[null, 1],
@@ -4118,7 +4563,7 @@ TEST(GroupBy, MeanNullType) {
}
}
-TEST(GroupBy, NullTypeEmptyTable) {
+TEST_P(GroupBy, NullTypeEmptyTable) {
   auto table = TableFromJSON(schema({field("argument", null()), field("key", int64())}),
{R"([])"});
@@ -4157,7 +4602,7 @@ TEST(GroupBy, NullTypeEmptyTable) {
}
}
-TEST(GroupBy, OnlyKeys) {
+TEST_P(GroupBy, OnlyKeys) {
auto table =
TableFromJSON(schema({field("key_0", int64()), field("key_1", utf8())}),
{R"([
[1, "a"],
@@ -4202,5 +4647,262 @@ TEST(GroupBy, OnlyKeys) {
/*verbose=*/true);
}
}
+
+INSTANTIATE_TEST_SUITE_P(GroupBy, GroupBy, ::testing::Values(RunGroupByImpl));
+
+class SegmentedScalarGroupBy : public GroupBy {};
+
+class SegmentedKeyGroupBy : public GroupBy {};
+
+void TestSegment(GroupByFunction group_by, const std::shared_ptr<Table>& table,
+ Datum output, const std::vector<Datum>& keys,
+                 const std::vector<Datum>& segment_keys, bool is_scalar_aggregate) {
+ const char* names[] = {
+ is_scalar_aggregate ? "count" : "hash_count",
+ is_scalar_aggregate ? "sum" : "hash_sum",
+ is_scalar_aggregate ? "min_max" : "hash_min_max",
+ };
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ group_by(
+ {
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ },
+ keys, segment_keys,
+ {
+ {names[0], nullptr, "agg_0", names[0]},
+ {names[1], nullptr, "agg_1", names[1]},
+ {names[2], nullptr, "agg_2", names[2]},
+ },
+ /*use_threads=*/false, /*naive=*/false));
+
+ AssertDatumsEqual(output, aggregated_and_grouped, /*verbose=*/true);
+}
+
+// test with empty keys, covering code in ScalarAggregateNode
+void TestSegmentScalar(GroupByFunction group_by, const std::shared_ptr<Table>& table,
+ Datum output, const std::vector<Datum>& segment_keys) {
+  TestSegment(group_by, table, output, {}, segment_keys, /*is_scalar_aggregate=*/true);
+}
+
+// test with given segment-keys and keys set to `{"key"}`, covering code in GroupByNode
+void TestSegmentKey(GroupByFunction group_by, const std::shared_ptr<Table>& table,
+                    Datum output, const std::vector<Datum>& segment_keys) {
+  TestSegment(group_by, table, output, {table->GetColumnByName("key")}, segment_keys,
+              /*is_scalar_aggregate=*/false);
+}
+
+Result<std::shared_ptr<Table>> GetSingleSegmentInputAsChunked() {
+  auto table = TableFromJSON(schema({field("argument", float64()), field("key", int64()),
+ field("segment_key", int64())}),
+                             {R"([{"argument": 1.0, "key": 1, "segment_key": 1},
+ {"argument": null, "key": 1, "segment_key": 1}
+ ])",
+                              R"([{"argument": 0.0, "key": 2, "segment_key": 1},
+ {"argument": null, "key": 3, "segment_key": 1},
+ {"argument": 4.0, "key": null, "segment_key": 1},
+ {"argument": 3.25, "key": 1, "segment_key": 1},
+ {"argument": 0.125, "key": 2, "segment_key": 1},
+ {"argument": -0.25, "key": 2, "segment_key": 1},
+ {"argument": 0.75, "key": null, "segment_key": 1},
+ {"argument": null, "key": 3, "segment_key": 1}
+ ])",
+                              R"([{"argument": 1.0, "key": 1, "segment_key": 0},
+ {"argument": null, "key": 1, "segment_key": 0}
+ ])",
+                              R"([{"argument": 0.0, "key": 2, "segment_key": 0},
+ {"argument": null, "key": 3, "segment_key": 0},
+ {"argument": 4.0, "key": null, "segment_key": 0},
+ {"argument": 3.25, "key": 1, "segment_key": 0},
+ {"argument": 0.125, "key": 2, "segment_key": 0},
+ {"argument": -0.25, "key": 2, "segment_key": 0},
+ {"argument": 0.75, "key": null, "segment_key": 0},
+ {"argument": null, "key": 3, "segment_key": 0}
+ ])"});
+ return table;
+}
+
+Result<std::shared_ptr<Table>> GetSingleSegmentInputAsCombined() {
+ ARROW_ASSIGN_OR_RAISE(auto table, GetSingleSegmentInputAsChunked());
+ return table->CombineChunks();
+}
+
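+// Expected scalar-aggregate output: one chunk per segment group (segment_key 1, then 0),
+// each holding that group's count/sum/min_max results.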
+Result<std::shared_ptr<ChunkedArray>> GetSingleSegmentScalarOutput() {
+ return ChunkedArrayFromJSON(struct_({
+ field("count", int64()),
+ field("sum", float64()),
+ field("min_max", struct_({
+ field("min", float64()),
+ field("max", float64()),
+ })),
+ field("key_0", int64()),
+ }),
+ {R"([
+ [7, 8.875, {"min": -0.25, "max": 4.0}, 1]
+ ])",
+ R"([
+ [7, 8.875, {"min": -0.25, "max": 4.0}, 0]
+ ])"});
+}
+
+Result<std::shared_ptr<ChunkedArray>> GetSingleSegmentKeyOutput() {
+ return ChunkedArrayFromJSON(struct_({
+ field("hash_count", int64()),
+ field("hash_sum", float64()),
+ field("hash_min_max", struct_({
+                                                            field("min", float64()),
+                                                            field("max", float64()),
+ })),
+ field("key_0", int64()),
+ field("key_1", int64()),
+ }),
+ {R"([
+ [2, 4.25, {"min": 1.0, "max": 3.25}, 1, 1],
+ [3, -0.125, {"min": -0.25, "max": 0.125}, 2, 1],
+ [0, null, {"min": null, "max": null}, 3, 1],
+ [2, 4.75, {"min": 0.75, "max": 4.0}, null, 1]
+ ])",
+ R"([
+ [2, 4.25, {"min": 1.0, "max": 3.25}, 1, 0],
+ [3, -0.125, {"min": -0.25, "max": 0.125}, 2, 0],
+ [0, null, {"min": null, "max": null}, 3, 0],
+ [2, 4.75, {"min": 0.75, "max": 4.0}, null, 0]
+ ])"});
+}
+
+void TestSingleSegmentScalar(GroupByFunction group_by,
+                             std::function<Result<std::shared_ptr<Table>>()> get_table) {
+ ASSERT_OK_AND_ASSIGN(auto table, get_table());
+ ASSERT_OK_AND_ASSIGN(auto output, GetSingleSegmentScalarOutput());
+  TestSegmentScalar(group_by, table, output, {table->GetColumnByName("segment_key")});
+}
+
+void TestSingleSegmentKey(GroupByFunction group_by,
+                          std::function<Result<std::shared_ptr<Table>>()> get_table) {
+ ASSERT_OK_AND_ASSIGN(auto table, get_table());
+ ASSERT_OK_AND_ASSIGN(auto output, GetSingleSegmentKeyOutput());
+  TestSegmentKey(group_by, table, output, {table->GetColumnByName("segment_key")});
+}
+
+TEST_P(SegmentedScalarGroupBy, SingleSegmentScalarChunked) {
+ TestSingleSegmentScalar(GetParam(), GetSingleSegmentInputAsChunked);
+}
+
+TEST_P(SegmentedScalarGroupBy, SingleSegmentScalarCombined) {
+ TestSingleSegmentScalar(GetParam(), GetSingleSegmentInputAsCombined);
+}
+
+TEST_P(SegmentedKeyGroupBy, SingleSegmentKeyChunked) {
+ TestSingleSegmentKey(GetParam(), GetSingleSegmentInputAsChunked);
+}
+
+TEST_P(SegmentedKeyGroupBy, SingleSegmentKeyCombined) {
+ TestSingleSegmentKey(GetParam(), GetSingleSegmentInputAsCombined);
+}
+
+// extracts one segment of the obtained (single-segment-key) table
+Result<std::shared_ptr<Table>> GetEmptySegmentKeysInput(
+ std::function<Result<std::shared_ptr<Table>>()> get_table) {
+ ARROW_ASSIGN_OR_RAISE(auto table, get_table());
+ auto sliced = table->Slice(0, 10);
+ ARROW_ASSIGN_OR_RAISE(auto batch, sliced->CombineChunksToBatch());
+ ARROW_ASSIGN_OR_RAISE(auto array, batch->ToStructArray());
+  ARROW_ASSIGN_OR_RAISE(auto chunked, ChunkedArray::Make({array}, array->type()));
+ return Table::FromChunkedStructArray(chunked);
+}
+
+Result<std::shared_ptr<Table>> GetEmptySegmentKeysInputAsChunked() {
+ return GetEmptySegmentKeysInput(GetSingleSegmentInputAsChunked);
+}
+
+Result<std::shared_ptr<Table>> GetEmptySegmentKeysInputAsCombined() {
+ return GetEmptySegmentKeysInput(GetSingleSegmentInputAsCombined);
+}
+
+// extracts the expected output for one segment
+Result<std::shared_ptr<Array>> GetEmptySegmentKeyOutput() {
+ ARROW_ASSIGN_OR_RAISE(auto chunked, GetSingleSegmentKeyOutput());
+ ARROW_ASSIGN_OR_RAISE(auto table, Table::FromChunkedStructArray(chunked));
+  ARROW_ASSIGN_OR_RAISE(auto removed, table->RemoveColumn(table->num_columns() - 1));
+ auto sliced = removed->Slice(0, 4);
+ ARROW_ASSIGN_OR_RAISE(auto batch, sliced->CombineChunksToBatch());
+ return batch->ToStructArray();
+}
+
+void TestEmptySegmentKey(GroupByFunction group_by,
+                         std::function<Result<std::shared_ptr<Table>>()> get_table) {
+ ASSERT_OK_AND_ASSIGN(auto table, get_table());
+ ASSERT_OK_AND_ASSIGN(auto output, GetEmptySegmentKeyOutput());
+ TestSegmentKey(group_by, table, output, {});
+}
+
+TEST_P(SegmentedKeyGroupBy, EmptySegmentKeyChunked) {
+ TestEmptySegmentKey(GetParam(), GetEmptySegmentKeysInputAsChunked);
+}
+
+TEST_P(SegmentedKeyGroupBy, EmptySegmentKeyCombined) {
+ TestEmptySegmentKey(GetParam(), GetEmptySegmentKeysInputAsCombined);
+}
+
+// adds a named copy of the last (single-segment-key) column to the obtained table
+Result<std::shared_ptr<Table>> GetMultiSegmentInput(
+ std::function<Result<std::shared_ptr<Table>>()> get_table,
+ const std::string& add_name) {
+ ARROW_ASSIGN_OR_RAISE(auto table, get_table());
+ int last = table->num_columns() - 1;
+ auto add_field = field(add_name, table->schema()->field(last)->type());
+  return table->AddColumn(table->num_columns(), add_field, table->column(last));
+}
+
+Result<std::shared_ptr<Table>> GetMultiSegmentInputAsChunked(
+ const std::string& add_name) {
+ return GetMultiSegmentInput(GetSingleSegmentInputAsChunked, add_name);
+}
+
+Result<std::shared_ptr<Table>> GetMultiSegmentInputAsCombined(
+ const std::string& add_name) {
+ return GetMultiSegmentInput(GetSingleSegmentInputAsCombined, add_name);
+}
+
+// adds a named copy of the last (single-segment-key) column to the expected output table
+Result<std::shared_ptr<ChunkedArray>> GetMultiSegmentKeyOutput(
+ const std::string& add_name) {
+ ARROW_ASSIGN_OR_RAISE(auto chunked, GetSingleSegmentKeyOutput());
+ ARROW_ASSIGN_OR_RAISE(auto table, Table::FromChunkedStructArray(chunked));
+ int last = table->num_columns() - 1;
+ auto add_field = field(add_name, table->schema()->field(last)->type());
+ ARROW_ASSIGN_OR_RAISE(auto added,
+                        table->AddColumn(last + 1, add_field, table->column(last)));
+ ARROW_ASSIGN_OR_RAISE(auto batch, added->CombineChunksToBatch());
+ ARROW_ASSIGN_OR_RAISE(auto array, batch->ToStructArray());
+  return ChunkedArray::Make({array->Slice(0, 4), array->Slice(4, 4)}, array->type());
+}
+
+void TestMultiSegmentKey(
+ GroupByFunction group_by,
+    std::function<Result<std::shared_ptr<Table>>(const std::string&)> get_table) {
+ std::string add_name = "segment_key2";
+ ASSERT_OK_AND_ASSIGN(auto table, get_table(add_name));
+ ASSERT_OK_AND_ASSIGN(auto output, GetMultiSegmentKeyOutput("key_2"));
+ TestSegmentKey(
+ group_by, table, output,
+      {table->GetColumnByName("segment_key"), table->GetColumnByName(add_name)});
+}
+
+TEST_P(SegmentedKeyGroupBy, MultiSegmentKeyChunked) {
+ TestMultiSegmentKey(GetParam(), GetMultiSegmentInputAsChunked);
+}
+
+TEST_P(SegmentedKeyGroupBy, MultiSegmentKeyCombined) {
+ TestMultiSegmentKey(GetParam(), GetMultiSegmentInputAsCombined);
+}
+
+INSTANTIATE_TEST_SUITE_P(SegmentedScalarGroupBy, SegmentedScalarGroupBy,
+ ::testing::Values(RunSegmentedGroupByImpl));
+
+INSTANTIATE_TEST_SUITE_P(SegmentedKeyGroupBy, SegmentedKeyGroupBy,
+ ::testing::Values(RunSegmentedGroupByImpl));
+
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/row/grouper.cc b/cpp/src/arrow/compute/row/grouper.cc
index d003137d3e..75df42abd0 100644
--- a/cpp/src/arrow/compute/row/grouper.cc
+++ b/cpp/src/arrow/compute/row/grouper.cc
@@ -19,6 +19,9 @@
#include <memory>
#include <mutex>
+#include <type_traits>
+
+#include "arrow/array/builder_primitive.h"
#include "arrow/compute/exec/key_hash.h"
#include "arrow/compute/exec/key_map.h"
@@ -29,7 +32,9 @@
#include "arrow/compute/light_array.h"
#include "arrow/compute/registry.h"
#include "arrow/compute/row/compare_internal.h"
+#include "arrow/compute/row/grouper_internal.h"
#include "arrow/type.h"
+#include "arrow/type_traits.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/cpu_info.h"
@@ -39,12 +44,333 @@
namespace arrow {
using internal::checked_cast;
+using internal::PrimitiveScalarBase;
namespace compute {
namespace {
-struct GrouperImpl : Grouper {
+constexpr uint32_t kNoGroupId = std::numeric_limits<uint32_t>::max();
+
+using group_id_t = std::remove_const<decltype(kNoGroupId)>::type;
+using GroupIdType = CTypeTraits<group_id_t>::ArrowType;
+auto g_group_id_type = std::make_shared<GroupIdType>();
+
+inline const uint8_t* GetValuesAsBytes(const ArraySpan& data, int64_t offset = 0) {
+ DCHECK_GT(data.type->byte_width(), 0);
+  int64_t absolute_byte_offset = (data.offset + offset) * data.type->byte_width();
+ return data.GetValues<uint8_t>(1, absolute_byte_offset);
+}
+
+template <typename Value>
+Status CheckForGetNextSegment(const std::vector<Value>& values, int64_t length,
+                              int64_t offset, const std::vector<TypeHolder>& key_types) {
+ if (offset < 0 || offset > length) {
+ return Status::Invalid("invalid grouping segmenter offset: ", offset);
+ }
+ if (values.size() != key_types.size()) {
+    return Status::Invalid("expected batch size ", key_types.size(), " but got ",
+ values.size());
+ }
+ for (size_t i = 0; i < key_types.size(); i++) {
+ const auto& value = values[i];
+ const auto& key_type = key_types[i];
+ if (*value.type() != *key_type.type) {
+      return Status::Invalid("expected batch value ", i, " of type ", *key_type.type,
+ " but got ", *value.type());
+ }
+ }
+ return Status::OK();
+}
+
+template <typename Batch>
+enable_if_t<std::is_same<Batch, ExecSpan>::value || std::is_same<Batch, ExecBatch>::value,
+ Status>
+CheckForGetNextSegment(const Batch& batch, int64_t offset,
+ const std::vector<TypeHolder>& key_types) {
+ return CheckForGetNextSegment(batch.values, batch.length, offset, key_types);
+}
+
+struct BaseRowSegmenter : public RowSegmenter {
+ explicit BaseRowSegmenter(const std::vector<TypeHolder>& key_types)
+ : key_types_(key_types) {}
+
+  const std::vector<TypeHolder>& key_types() const override { return key_types_; }
+
+ std::vector<TypeHolder> key_types_;
+};
+
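+// A segment is "open" exactly when it reaches the end of its batch, i.e. it may be
+// extended by the next batch.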
+Segment MakeSegment(int64_t batch_length, int64_t offset, int64_t length, bool extends) {
+ return Segment{offset, length, offset + length >= batch_length, extends};
+}
+
+// Used by SimpleKeySegmenter::GetNextSegment to find the match-length of a value within a
+// fixed-width buffer
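+// (For example, for int32 values [7, 7, 7, 9] with offset 0 and match_bytes pointing at
+// the first 7, the match-length is 3.)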
+int64_t GetMatchLength(const uint8_t* match_bytes, int64_t match_width,
+                       const uint8_t* array_bytes, int64_t offset, int64_t length) {
+ int64_t cursor, byte_cursor;
+ for (cursor = offset, byte_cursor = match_width * cursor; cursor < length;
+ cursor++, byte_cursor += match_width) {
+ if (memcmp(match_bytes, array_bytes + byte_cursor,
+ static_cast<size_t>(match_width)) != 0) {
+ break;
+ }
+ }
+ return std::min(cursor, length) - offset;
+}
+
+using ExtendFunc = std::function<bool(const void*)>;
+constexpr bool kDefaultExtends = true;  // by default, the first segment extends
+constexpr bool kEmptyExtends = true; // an empty segment extends too
+
+struct NoKeysSegmenter : public BaseRowSegmenter {
+ static std::unique_ptr<RowSegmenter> Make() {
+ return std::make_unique<NoKeysSegmenter>();
+ }
+
+ NoKeysSegmenter() : BaseRowSegmenter({}) {}
+
+ Status Reset() override { return Status::OK(); }
+
+  Result<Segment> GetNextSegment(const ExecSpan& batch, int64_t offset) override {
+ ARROW_RETURN_NOT_OK(CheckForGetNextSegment(batch, offset, {}));
+    return MakeSegment(batch.length, offset, batch.length - offset, kDefaultExtends);
+ }
+};
+
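+// Segments on a single fixed-width, non-nullable key by comparing raw value bytes.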
+struct SimpleKeySegmenter : public BaseRowSegmenter {
+ static Result<std::unique_ptr<RowSegmenter>> Make(TypeHolder key_type) {
+ return std::make_unique<SimpleKeySegmenter>(key_type);
+ }
+
+ explicit SimpleKeySegmenter(TypeHolder key_type)
+ : BaseRowSegmenter({key_type}),
+ key_type_(key_types_[0]),
+ save_key_data_(static_cast<size_t>(key_type_.type->byte_width())),
+ extend_was_called_(false) {}
+
+ Status CheckType(const DataType& type) {
+ if (!is_fixed_width(type)) {
+      return Status::Invalid("SimpleKeySegmenter does not support type ", type);
+ }
+ return Status::OK();
+ }
+
+ Status Reset() override {
+ extend_was_called_ = false;
+ return Status::OK();
+ }
+
+  // Checks whether the given grouping data extends the current segment, i.e., is equal to
+ // previously seen grouping data, which is updated with each invocation.
+ bool Extend(const void* data) {
+ bool extends = !extend_was_called_
+ ? kDefaultExtends
+ : 0 == memcmp(save_key_data_.data(), data,
save_key_data_.size());
+ extend_was_called_ = true;
+ memcpy(save_key_data_.data(), data, save_key_data_.size());
+ return extends;
+ }
+
+  Result<Segment> GetNextSegment(const Scalar& scalar, int64_t offset, int64_t length) {
+ ARROW_RETURN_NOT_OK(CheckType(*scalar.type));
+ if (!scalar.is_valid) {
+ return Status::Invalid("segmenting an invalid scalar");
+ }
+ auto data = checked_cast<const PrimitiveScalarBase&>(scalar).data();
+ bool extends = length > 0 ? Extend(data) : kEmptyExtends;
+ return MakeSegment(length, offset, length, extends);
+ }
+
+  Result<Segment> GetNextSegment(const DataType& array_type, const uint8_t* array_bytes,
+ int64_t offset, int64_t length) {
+ RETURN_NOT_OK(CheckType(array_type));
+ DCHECK_LE(offset, length);
+ int64_t byte_width = array_type.byte_width();
+    int64_t match_length = GetMatchLength(array_bytes + offset * byte_width, byte_width,
+ array_bytes, offset, length);
+    bool extends = length > 0 ? Extend(array_bytes + offset * byte_width) : kEmptyExtends;
+ return MakeSegment(length, offset, match_length, extends);
+ }
+
+  Result<Segment> GetNextSegment(const ExecSpan& batch, int64_t offset) override {
+ ARROW_RETURN_NOT_OK(CheckForGetNextSegment(batch, offset, {key_type_}));
+ if (offset == batch.length) {
+ return MakeSegment(batch.length, offset, 0, kEmptyExtends);
+ }
+ const auto& value = batch.values[0];
+ if (value.is_scalar()) {
+ return GetNextSegment(*value.scalar, offset, batch.length);
+ }
+ ARROW_DCHECK(value.is_array());
+ const auto& array = value.array;
+ if (array.GetNullCount() > 0) {
+ return Status::NotImplemented("segmenting a nullable array");
+ }
+    return GetNextSegment(*array.type, GetValuesAsBytes(array), offset, batch.length);
+ }
+
+ private:
+ TypeHolder key_type_;
+  std::vector<uint8_t> save_key_data_;  // previously seen segment-key grouping data
+ bool extend_was_called_;
+};
+
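+// Segments on arbitrary key types by mapping rows to group ids with a Grouper; a segment
+// ends where the group id changes.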
+struct AnyKeysSegmenter : public BaseRowSegmenter {
+ static Result<std::unique_ptr<RowSegmenter>> Make(
+ const std::vector<TypeHolder>& key_types, ExecContext* ctx) {
+ ARROW_RETURN_NOT_OK(Grouper::Make(key_types, ctx)); // check types
+ return std::make_unique<AnyKeysSegmenter>(key_types, ctx);
+ }
+
+ AnyKeysSegmenter(const std::vector<TypeHolder>& key_types, ExecContext* ctx)
+ : BaseRowSegmenter(key_types),
+ ctx_(ctx),
+ grouper_(nullptr),
+ save_group_id_(kNoGroupId) {}
+
+ Status Reset() override {
+ grouper_ = nullptr;
+ save_group_id_ = kNoGroupId;
+ return Status::OK();
+ }
+
+ bool Extend(const void* data) {
+ auto group_id = *static_cast<const group_id_t*>(data);
+ bool extends =
+        save_group_id_ == kNoGroupId ? kDefaultExtends : save_group_id_ == group_id;
+ save_group_id_ = group_id;
+ return extends;
+ }
+
+  // Runs the grouper on a single row. This is used to determine the group id of the
+ // first row of a new segment to see if it extends the previous segment.
+ template <typename Batch>
+ Result<group_id_t> MapGroupIdAt(const Batch& batch, int64_t offset) {
+ if (!grouper_) return kNoGroupId;
+    ARROW_ASSIGN_OR_RAISE(auto datum, grouper_->Consume(batch, offset, /*length=*/1));
+ if (!datum.is_array()) {
+      return Status::Invalid("accessing unsupported datum kind ", datum.kind());
+ }
+ const std::shared_ptr<ArrayData>& data = datum.array();
+ ARROW_DCHECK(data->GetNullCount() == 0);
+ DCHECK_EQ(data->type->id(), GroupIdType::type_id);
+ DCHECK_EQ(1, data->length);
+ const group_id_t* values = data->GetValues<group_id_t>(1);
+ return values[0];
+ }
+
+  Result<Segment> GetNextSegment(const ExecSpan& batch, int64_t offset) override {
+ ARROW_RETURN_NOT_OK(CheckForGetNextSegment(batch, offset, key_types_));
+ if (offset == batch.length) {
+ return MakeSegment(batch.length, offset, 0, kEmptyExtends);
+ }
+ // ARROW-18311: make Grouper support Reset()
+ // so it can be reset instead of recreated below
+ //
+    // the group id must be computed prior to resetting the grouper, since it is compared
+    // to save_group_id_, and after resetting the grouper produces incomparable group ids
+ ARROW_ASSIGN_OR_RAISE(auto group_id, MapGroupIdAt(batch, offset));
+ ExtendFunc bound_extend = [this, group_id](const void* data) {
+ bool extends = Extend(&group_id);
+ save_group_id_ = *static_cast<const group_id_t*>(data);
+ return extends;
+ };
+    // resetting drops the grouper's group ids, freeing up memory for the next segment
+    ARROW_ASSIGN_OR_RAISE(grouper_, Grouper::Make(key_types_, ctx_));  // TODO: reset it
+    // GH-34475: cache the grouper-consume result across invocations of GetNextSegment
+ ARROW_ASSIGN_OR_RAISE(auto datum, grouper_->Consume(batch, offset));
+ if (datum.is_array()) {
+      // `data` is an array whose index-0 corresponds to index `offset` of `batch`
+ const std::shared_ptr<ArrayData>& data = datum.array();
+ DCHECK_EQ(data->length, batch.length - offset);
+ ARROW_DCHECK(data->GetNullCount() == 0);
+ DCHECK_EQ(data->type->id(), GroupIdType::type_id);
+ const group_id_t* values = data->GetValues<group_id_t>(1);
+ int64_t cursor;
+ for (cursor = 1; cursor < data->length; cursor++) {
+ if (values[0] != values[cursor]) break;
+ }
+ int64_t length = cursor;
+ bool extends = length > 0 ? bound_extend(values) : kEmptyExtends;
+ return MakeSegment(batch.length, offset, length, extends);
+ } else {
+      return Status::Invalid("segmenting unsupported datum kind ", datum.kind());
+ }
+ }
+
+ private:
+ ExecContext* const ctx_;
+ std::unique_ptr<Grouper> grouper_;
+ group_id_t save_group_id_;
+};
+
+Status CheckAndCapLengthForConsume(int64_t batch_length, int64_t& consume_offset,
+ int64_t* consume_length) {
+ if (consume_offset < 0) {
+ return Status::Invalid("invalid grouper consume offset: ", consume_offset);
+ }
+ if (*consume_length < 0) {
+ *consume_length = batch_length - consume_offset;
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Result<std::unique_ptr<RowSegmenter>> MakeAnyKeysSegmenter(
+ const std::vector<TypeHolder>& key_types, ExecContext* ctx) {
+ return AnyKeysSegmenter::Make(key_types, ctx);
+}
+
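+// Dispatches on the key types: no keys uses the trivial NoKeysSegmenter, a single
+// non-nullable fixed-width key uses the memcmp-based SimpleKeySegmenter, and anything
+// else falls back to the Grouper-backed AnyKeysSegmenter.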
+Result<std::unique_ptr<RowSegmenter>> RowSegmenter::Make(
+    const std::vector<TypeHolder>& key_types, bool nullable_keys, ExecContext* ctx) {
+ if (key_types.size() == 0) {
+ return NoKeysSegmenter::Make();
+ } else if (!nullable_keys && key_types.size() == 1) {
+ const DataType* type = key_types[0].type;
+ if (type != NULLPTR && is_fixed_width(*type)) {
+ return SimpleKeySegmenter::Make(key_types[0]);
+ }
+ }
+ return AnyKeysSegmenter::Make(key_types, ctx);
+}
+
+namespace {
+
+struct GrouperNoKeysImpl : Grouper {
+ Result<std::shared_ptr<Array>> MakeConstantGroupIdArray(int64_t length,
+ group_id_t value) {
+ std::unique_ptr<ArrayBuilder> a_builder;
+    RETURN_NOT_OK(MakeBuilder(default_memory_pool(), g_group_id_type, &a_builder));
+ using GroupIdBuilder = typename TypeTraits<GroupIdType>::BuilderType;
+ auto builder = checked_cast<GroupIdBuilder*>(a_builder.get());
+ if (length != 0) {
+ RETURN_NOT_OK(builder->Resize(length));
+ }
+ for (int64_t i = 0; i < length; i++) {
+ builder->UnsafeAppend(value);
+ }
+ std::shared_ptr<Array> array;
+ RETURN_NOT_OK(builder->Finish(&array));
+ return std::move(array);
+ }
+  Result<Datum> Consume(const ExecSpan& batch, int64_t offset, int64_t length) override {
+ ARROW_ASSIGN_OR_RAISE(auto array, MakeConstantGroupIdArray(length, 0));
+ return Datum(array);
+ }
+ Result<ExecBatch> GetUniques() override {
+ auto data = ArrayData::Make(uint32(), 1, 0);
+ auto values = data->GetMutableValues<uint32_t>(0);
+ values[0] = 0;
+ ExecBatch out({Datum(data)}, 1);
+ return std::move(out);
+ }
+ uint32_t num_groups() const override { return 1; }
+};
+
+struct GrouperImpl : public Grouper {
static Result<std::unique_ptr<GrouperImpl>> Make(
const std::vector<TypeHolder>& key_types, ExecContext* ctx) {
auto impl = std::make_unique<GrouperImpl>();
@@ -95,7 +421,12 @@ struct GrouperImpl : Grouper {
return std::move(impl);
}
- Result<Datum> Consume(const ExecSpan& batch) override {
+ Result<Datum> Consume(const ExecSpan& batch, int64_t offset, int64_t length)
override {
+ ARROW_RETURN_NOT_OK(CheckAndCapLengthForConsume(batch.length, offset,
&length));
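+    // for a proper slice, delegate to the full-batch path on a sliced copy of the batch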
+ if (offset != 0 || length != batch.length) {
+ auto batch_slice = batch.ToExecBatch().Slice(offset, length);
+ return Consume(ExecSpan(batch_slice), 0, -1);
+ }
std::vector<int32_t> offsets_batch(batch.length + 1);
for (int i = 0; i < batch.num_values(); ++i) {
encoders_[i]->AddLength(batch[i], batch.length, offsets_batch.data());
@@ -179,11 +510,14 @@ struct GrouperImpl : Grouper {
std::vector<std::unique_ptr<internal::KeyEncoder>> encoders_;
};
-struct GrouperFastImpl : Grouper {
+struct GrouperFastImpl : public Grouper {
static constexpr int kBitmapPaddingForSIMD = 64; // bits
static constexpr int kPaddingForSIMD = 32; // bytes
static bool CanUse(const std::vector<TypeHolder>& key_types) {
+ if (key_types.size() == 0) {
+ return false;
+ }
#if ARROW_LITTLE_ENDIAN
for (size_t i = 0; i < key_types.size(); ++i) {
if (is_large_binary_like(key_types[i].id())) {
@@ -265,7 +599,12 @@ struct GrouperFastImpl : Grouper {
~GrouperFastImpl() { map_.cleanup(); }
- Result<Datum> Consume(const ExecSpan& batch) override {
+ Result<Datum> Consume(const ExecSpan& batch, int64_t offset, int64_t length)
override {
+ ARROW_RETURN_NOT_OK(CheckAndCapLengthForConsume(batch.length, offset,
&length));
+ if (offset != 0 || length != batch.length) {
+ auto batch_slice = batch.ToExecBatch().Slice(offset, length);
+ return Consume(ExecSpan(batch_slice), 0, -1);
+ }
// ARROW-14027: broadcast scalar arguments for now
for (int i = 0; i < batch.num_values(); i++) {
if (batch[i].is_scalar()) {
diff --git a/cpp/src/arrow/compute/row/grouper.h b/cpp/src/arrow/compute/row/grouper.h
index ce09adf09b..f9e7e2e97e 100644
--- a/cpp/src/arrow/compute/row/grouper.h
+++ b/cpp/src/arrow/compute/row/grouper.h
@@ -30,6 +30,78 @@
namespace arrow {
namespace compute {
+/// \brief A segment
+/// A segment group is a chunk of continuous rows that have the same segment key. (For
+/// example, in ordered time series processing, the segment key can be "date", and a
+/// segment group can be all the rows that belong to the same date.) A segment group can
+/// span across multiple exec batches. A segment is a chunk of continuous rows that has
+/// the same segment key within a given batch. When a segment group spans across batches,
+/// it will have multiple segments. A segment never spans across batches. The segment
+/// data structure only makes sense when used along with an exec batch.
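+///
+/// For example, segmenting a batch whose single key column holds [1, 1, 2] yields a
+/// segment of length 2, a segment of length 1, and finally an empty end-of-batch
+/// segment.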
+struct ARROW_EXPORT Segment {
+ /// \brief the offset into the batch where the segment starts
+ int64_t offset;
+ /// \brief the length of the segment
+ int64_t length;
+ /// \brief whether the segment may be extended by a next one
+ bool is_open;
+  /// \brief whether the segment extends a preceding one
+ bool extends;
+};
+
+inline bool operator==(const Segment& segment1, const Segment& segment2) {
+  return segment1.offset == segment2.offset && segment1.length == segment2.length &&
+         segment1.is_open == segment2.is_open && segment1.extends == segment2.extends;
+}
+inline bool operator!=(const Segment& segment1, const Segment& segment2) {
+ return !(segment1 == segment2);
+}
+
+/// \brief a helper class to divide a batch into segments of equal values
+///
+/// For example, given a batch of five rows with two key columns:
+///
+///   A A
+///   A A
+///   A B
+///   A B
+///   A A
+///
+/// the batch could be divided into 3 segments.  The first would be rows 0 & 1,
+/// the second would be rows 2 & 3, and the third would be row 4.
+///
+/// Further, a segmenter keeps track of the last value seen.  This allows it to calculate
+/// segments which span batches.  In our above example the last segment we emit would set
+/// the "open" flag, which indicates whether the segment may extend into the next batch.
+///
+/// If the next call to the segmenter starts with `A A` then that segment would set the
+/// "extends" flag, which indicates whether the segment continues the last open segment.
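+///
+/// A typical consumption loop (an illustrative sketch):
+///
+///   int64_t offset = 0;
+///   while (offset < batch.length) {
+///     ARROW_ASSIGN_OR_RAISE(Segment segment, segmenter->GetNextSegment(batch, offset));
+///     // process rows [segment.offset, segment.offset + segment.length)
+///     offset = segment.offset + segment.length;
+///   }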
+class ARROW_EXPORT RowSegmenter {
+ public:
+ virtual ~RowSegmenter() = default;
+
+ /// \brief Construct a Segmenter which segments on the specified key types
+ ///
+ /// \param[in] key_types the specified key types
+ /// \param[in] nullable_keys whether values of the specified keys may be null
+ /// \param[in] ctx the execution context to use
+ static Result<std::unique_ptr<RowSegmenter>> Make(
+      const std::vector<TypeHolder>& key_types, bool nullable_keys, ExecContext* ctx);
+
+ /// \brief Return the key types of this segmenter
+ virtual const std::vector<TypeHolder>& key_types() const = 0;
+
+ /// \brief Reset this segmenter
+ ///
+  /// A segmenter normally extends (see `Segment`) a segment from one batch to the next.
+  /// If segment-extension is undesirable, for example when each batch is processed
+  /// independently, then `Reset` should be invoked before processing the next batch.
+ virtual Status Reset() = 0;
+
+  /// \brief Get the next segment for the given batch starting from the given offset
+  virtual Result<Segment> GetNextSegment(const ExecSpan& batch, int64_t offset) = 0;
+};
+
/// Consumes batches of keys and yields batches of the group ids.
class ARROW_EXPORT Grouper {
public:
@@ -39,10 +111,12 @@ class ARROW_EXPORT Grouper {
   static Result<std::unique_ptr<Grouper>> Make(const std::vector<TypeHolder>& key_types,
                                                ExecContext* ctx = default_exec_context());
-  /// Consume a batch of keys, producing the corresponding group ids as an integer array.
+  /// Consume a batch of keys, producing the corresponding group ids as an integer array,
+  /// over a slice defined by an offset and length, which defaults to the batch length.
   /// Currently only uint32 indices will be produced, eventually the bit width will only
   /// be as wide as necessary.
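+  /// For example, Consume(batch, /*offset=*/2, /*length=*/3) produces group ids for rows
+  /// [2, 5) of the batch only.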
- virtual Result<Datum> Consume(const ExecSpan& batch) = 0;
+ virtual Result<Datum> Consume(const ExecSpan& batch, int64_t offset = 0,
+ int64_t length = -1) = 0;
/// Get current unique keys. May be called multiple times.
virtual Result<ExecBatch> GetUniques() = 0;
diff --git a/cpp/src/arrow/compute/row/grouper_internal.h b/cpp/src/arrow/compute/row/grouper_internal.h
new file mode 100644
index 0000000000..eb3dfe8ba1
--- /dev/null
+++ b/cpp/src/arrow/compute/row/grouper_internal.h
@@ -0,0 +1,27 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+namespace arrow {
+namespace compute {
+
+ARROW_EXPORT Result<std::unique_ptr<RowSegmenter>> MakeAnyKeysSegmenter(
+ const std::vector<TypeHolder>& key_types, ExecContext* ctx);
+
+} // namespace compute
+} // namespace arrow
diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h
index 31dfdcbc84..d23b33e28f 100644
--- a/cpp/src/arrow/scalar.h
+++ b/cpp/src/arrow/scalar.h
@@ -136,6 +136,8 @@ struct ARROW_EXPORT PrimitiveScalarBase : public Scalar {
: Scalar(std::move(type), false) {}
using Scalar::Scalar;
+ /// \brief Get a const pointer to the value of this scalar. May be null.
+ virtual const void* data() const = 0;
/// \brief Get a mutable pointer to the value of this scalar. May be null.
virtual void* mutable_data() = 0;
/// \brief Get an immutable view of the value of this scalar as bytes.
@@ -157,6 +159,7 @@ struct ARROW_EXPORT PrimitiveScalar : public PrimitiveScalarBase {
ValueType value{};
+ const void* data() const override { return &value; }
void* mutable_data() override { return &value; }
std::string_view view() const override {
     return std::string_view(reinterpret_cast<const char*>(&value), sizeof(ValueType));
@@ -241,6 +244,9 @@ struct ARROW_EXPORT BaseBinaryScalar : public internal::PrimitiveScalarBase {
std::shared_ptr<Buffer> value;
+ const void* data() const override {
+ return value ? reinterpret_cast<const void*>(value->data()) : NULLPTR;
+ }
void* mutable_data() override {
return value ? reinterpret_cast<void*>(value->mutable_data()) : NULLPTR;
}
@@ -434,6 +440,10 @@ struct ARROW_EXPORT DecimalScalar : public internal::PrimitiveScalarBase {
DecimalScalar(ValueType value, std::shared_ptr<DataType> type)
: internal::PrimitiveScalarBase(std::move(type), true), value(value) {}
+ const void* data() const override {
+ return reinterpret_cast<const void*>(value.native_endian_bytes());
+ }
+
void* mutable_data() override {
return reinterpret_cast<void*>(value.mutable_native_endian_bytes());
}
@@ -603,6 +613,9 @@ struct ARROW_EXPORT DictionaryScalar : public internal::PrimitiveScalarBase {
Result<std::shared_ptr<Scalar>> GetEncodedValue() const;
+ const void* data() const override {
+    return internal::checked_cast<internal::PrimitiveScalarBase&>(*value.index).data();
+ }
void* mutable_data() override {
return internal::checked_cast<internal::PrimitiveScalarBase&>(*value.index)
.mutable_data();