This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new bb67f8d128 ARROW-16549: [C++] Simplify AggregateNodeOptions
aggregates/targets (#13150)
bb67f8d128 is described below
commit bb67f8d128408b31e38eb603a03367e0d9a803c3
Author: Vibhatha Lakmal Abeykoon <[email protected]>
AuthorDate: Mon Jun 27 20:10:30 2022 +0530
ARROW-16549: [C++] Simplify AggregateNodeOptions aggregates/targets (#13150)
This PR simplifies the existing `AggregateNodeOptions` usage. This work
is still in progress, and the existing refactor and usage still need to be evaluated.
Todos
- [x] Test
- [ ] Update documentation
- [ ] Update function docs
- [x] Evaluate CI failures (only tested on Mac M1 with C++/Python; need to
check if the change breaks other language bindings)
Authored-by: Vibhatha Abeykoon <[email protected]>
Signed-off-by: Weston Pace <[email protected]>
---
c_glib/arrow-glib/compute.cpp | 22 +-
c_glib/arrow-glib/compute.h | 2 +-
.../arrow/execution_plan_documentation_examples.cc | 9 +-
cpp/src/arrow/compute/api_aggregate.h | 9 +-
cpp/src/arrow/compute/exec/aggregate.cc | 5 +-
cpp/src/arrow/compute/exec/aggregate.h | 7 +-
cpp/src/arrow/compute/exec/aggregate_node.cc | 27 +-
cpp/src/arrow/compute/exec/ir_consumer.cc | 7 +-
cpp/src/arrow/compute/exec/ir_test.cc | 51 +--
cpp/src/arrow/compute/exec/options.h | 16 +-
cpp/src/arrow/compute/exec/plan_test.cc | 230 ++++++-----
cpp/src/arrow/compute/exec/test_util.cc | 22 +-
cpp/src/arrow/compute/exec/tpch_benchmark.cc | 23 +-
.../arrow/compute/kernels/aggregate_benchmark.cc | 5 +-
.../arrow/compute/kernels/hash_aggregate_test.cc | 440 ++++++++++-----------
cpp/src/arrow/dataset/scanner.cc | 6 +-
cpp/src/arrow/dataset/scanner_test.cc | 19 +-
python/pyarrow/includes/libarrow.pxd | 2 +-
r/R/arrowExports.R | 4 +-
r/R/query-engine.R | 19 +-
r/src/arrowExports.cpp | 10 +-
r/src/compute-exec.cpp | 21 +-
22 files changed, 451 insertions(+), 505 deletions(-)
diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp
index ca5bc1a76f..193ba2837b 100644
--- a/c_glib/arrow-glib/compute.cpp
+++ b/c_glib/arrow-glib/compute.cpp
@@ -1254,9 +1254,7 @@ garrow_aggregate_node_options_new(GList *aggregations,
gsize n_keys,
GError **error)
{
- std::vector<arrow::compute::internal::Aggregate> arrow_aggregates;
- std::vector<arrow::FieldRef> arrow_targets;
- std::vector<std::string> arrow_names;
+ std::vector<arrow::compute::Aggregate> arrow_aggregates;
std::vector<arrow::FieldRef> arrow_keys;
for (auto node = aggregations; node; node = node->next) {
auto aggregation_priv = GARROW_AGGREGATION_GET_PRIVATE(node->data);
@@ -1265,21 +1263,19 @@ garrow_aggregate_node_options_new(GList *aggregations,
function_options =
garrow_function_options_get_raw(aggregation_priv->options);
};
- if (function_options) {
- arrow_aggregates.push_back({
- aggregation_priv->function,
- function_options->Copy(),
- });
- } else {
- arrow_aggregates.push_back({aggregation_priv->function, nullptr});
- };
+ std::vector<arrow::FieldRef> arrow_targets;
if (!garrow_field_refs_add(arrow_targets,
aggregation_priv->input,
error,
"[aggregate-node-options][new][input]")) {
return NULL;
}
- arrow_names.emplace_back(aggregation_priv->output);
+ arrow_aggregates.push_back({
+ aggregation_priv->function,
+ function_options ? function_options->Copy() : nullptr,
+ arrow_targets[0],
+ aggregation_priv->output,
+ });
}
for (gsize i = 0; i < n_keys; ++i) {
if (!garrow_field_refs_add(arrow_keys,
@@ -1291,8 +1287,6 @@ garrow_aggregate_node_options_new(GList *aggregations,
}
auto arrow_options =
new arrow::compute::AggregateNodeOptions(std::move(arrow_aggregates),
- std::move(arrow_targets),
- std::move(arrow_names),
std::move(arrow_keys));
auto options = g_object_new(GARROW_TYPE_AGGREGATE_NODE_OPTIONS,
"options", arrow_options,
diff --git a/c_glib/arrow-glib/compute.h b/c_glib/arrow-glib/compute.h
index 818b76ed72..32db15be8b 100644
--- a/c_glib/arrow-glib/compute.h
+++ b/c_glib/arrow-glib/compute.h
@@ -755,7 +755,7 @@ garrow_quantile_options_get_qs(GArrowQuantileOptions
*options,
GARROW_AVAILABLE_IN_9_0
void
garrow_quantile_options_set_q(GArrowQuantileOptions *options,
- gdouble quantile);
+ gdouble q);
GARROW_AVAILABLE_IN_9_0
void
garrow_quantile_options_set_qs(GArrowQuantileOptions *options,
diff --git a/cpp/examples/arrow/execution_plan_documentation_examples.cc
b/cpp/examples/arrow/execution_plan_documentation_examples.cc
index 5f80119bbb..331388a529 100644
--- a/cpp/examples/arrow/execution_plan_documentation_examples.cc
+++ b/cpp/examples/arrow/execution_plan_documentation_examples.cc
@@ -502,9 +502,8 @@ arrow::Status
SourceScalarAggregateSinkExample(cp::ExecContext& exec_context) {
ARROW_ASSIGN_OR_RAISE(cp::ExecNode * source,
cp::MakeExecNode("source", plan.get(), {},
source_node_options));
- auto aggregate_options = cp::AggregateNodeOptions{/*aggregates=*/{{"sum",
nullptr}},
- /*targets=*/{"a"},
- /*names=*/{"sum(a)"}};
+ auto aggregate_options =
+ cp::AggregateNodeOptions{/*aggregates=*/{{"sum", nullptr, "a",
"sum(a)"}}};
ARROW_ASSIGN_OR_RAISE(
cp::ExecNode * aggregate,
cp::MakeExecNode("aggregate", plan.get(), {source},
std::move(aggregate_options)));
@@ -541,9 +540,7 @@ arrow::Status
SourceGroupAggregateSinkExample(cp::ExecContext& exec_context) {
cp::MakeExecNode("source", plan.get(), {},
source_node_options));
auto options =
std::make_shared<cp::CountOptions>(cp::CountOptions::ONLY_VALID);
auto aggregate_options =
- cp::AggregateNodeOptions{/*aggregates=*/{{"hash_count", options}},
- /*targets=*/{"a"},
- /*names=*/{"count(a)"},
+ cp::AggregateNodeOptions{/*aggregates=*/{{"hash_count", options, "a",
"count(a)"}},
/*keys=*/{"b"}};
ARROW_ASSIGN_OR_RAISE(
cp::ExecNode * aggregate,
diff --git a/cpp/src/arrow/compute/api_aggregate.h
b/cpp/src/arrow/compute/api_aggregate.h
index becd5a7414..55f434c665 100644
--- a/cpp/src/arrow/compute/api_aggregate.h
+++ b/cpp/src/arrow/compute/api_aggregate.h
@@ -393,8 +393,6 @@ ARROW_EXPORT
Result<Datum> Index(const Datum& value, const IndexOptions& options,
ExecContext* ctx = NULLPTR);
-namespace internal {
-
/// \brief Configure a grouped aggregation
struct ARROW_EXPORT Aggregate {
/// the name of the aggregation function
@@ -402,8 +400,13 @@ struct ARROW_EXPORT Aggregate {
/// options for the aggregation function
std::shared_ptr<FunctionOptions> options;
+
+ // fields to which aggregations will be applied
+ FieldRef target;
+
+ // output field name for aggregations
+ std::string name;
};
-} // namespace internal
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/aggregate.cc
b/cpp/src/arrow/compute/exec/aggregate.cc
index 934cdd4d1f..41b5bb75b6 100644
--- a/cpp/src/arrow/compute/exec/aggregate.cc
+++ b/cpp/src/arrow/compute/exec/aggregate.cc
@@ -22,6 +22,7 @@
#include "arrow/compute/exec_internal.h"
#include "arrow/compute/registry.h"
#include "arrow/compute/row/grouper.h"
+#include "arrow/util/checked_cast.h"
#include "arrow/util/task_group.h"
namespace arrow {
@@ -55,7 +56,9 @@ Result<std::vector<std::unique_ptr<KernelState>>> InitKernels(
std::vector<std::unique_ptr<KernelState>> states(kernels.size());
for (size_t i = 0; i < aggregates.size(); ++i) {
- const FunctionOptions* options = aggregates[i].options.get();
+ const FunctionOptions* options =
+ arrow::internal::checked_cast<const FunctionOptions*>(
+ aggregates[i].options.get());
if (options == nullptr) {
// use known default options for the named function if possible
diff --git a/cpp/src/arrow/compute/exec/aggregate.h
b/cpp/src/arrow/compute/exec/aggregate.h
index 2c62acf231..753b0a8c47 100644
--- a/cpp/src/arrow/compute/exec/aggregate.h
+++ b/cpp/src/arrow/compute/exec/aggregate.h
@@ -41,16 +41,15 @@ Result<Datum> GroupBy(const std::vector<Datum>& arguments,
const std::vector<Dat
ExecContext* ctx = default_exec_context());
Result<std::vector<const HashAggregateKernel*>> GetKernels(
- ExecContext* ctx, const std::vector<internal::Aggregate>& aggregates,
+ ExecContext* ctx, const std::vector<Aggregate>& aggregates,
const std::vector<ValueDescr>& in_descrs);
Result<std::vector<std::unique_ptr<KernelState>>> InitKernels(
const std::vector<const HashAggregateKernel*>& kernels, ExecContext* ctx,
- const std::vector<internal::Aggregate>& aggregates,
- const std::vector<ValueDescr>& in_descrs);
+ const std::vector<Aggregate>& aggregates, const std::vector<ValueDescr>&
in_descrs);
Result<FieldVector> ResolveKernels(
- const std::vector<internal::Aggregate>& aggregates,
+ const std::vector<Aggregate>& aggregates,
const std::vector<const HashAggregateKernel*>& kernels,
const std::vector<std::unique_ptr<KernelState>>& states, ExecContext* ctx,
const std::vector<ValueDescr>& descrs);
diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc
b/cpp/src/arrow/compute/exec/aggregate_node.cc
index c5c5d3efcf..8c7899c41e 100644
--- a/cpp/src/arrow/compute/exec/aggregate_node.cc
+++ b/cpp/src/arrow/compute/exec/aggregate_node.cc
@@ -44,7 +44,7 @@ namespace compute {
namespace {
void AggregatesToString(std::stringstream* ss, const Schema& input_schema,
- const std::vector<internal::Aggregate>& aggs,
+ const std::vector<Aggregate>& aggs,
const std::vector<int>& target_field_ids, int indent =
0) {
*ss << "aggregates=[" << std::endl;
for (size_t i = 0; i < aggs.size(); i++) {
@@ -64,8 +64,7 @@ class ScalarAggregateNode : public ExecNode {
public:
ScalarAggregateNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
std::shared_ptr<Schema> output_schema,
- std::vector<int> target_field_ids,
- std::vector<internal::Aggregate> aggs,
+ std::vector<int> target_field_ids,
std::vector<Aggregate> aggs,
std::vector<const ScalarAggregateKernel*> kernels,
std::vector<std::vector<std::unique_ptr<KernelState>>>
states)
: ExecNode(plan, std::move(inputs), {"target"},
@@ -89,12 +88,12 @@ class ScalarAggregateNode : public ExecNode {
std::vector<const ScalarAggregateKernel*> kernels(aggregates.size());
std::vector<std::vector<std::unique_ptr<KernelState>>>
states(kernels.size());
FieldVector fields(kernels.size());
- const auto& field_names = aggregate_options.names;
std::vector<int> target_field_ids(kernels.size());
for (size_t i = 0; i < kernels.size(); ++i) {
- ARROW_ASSIGN_OR_RAISE(auto match,
-
FieldRef(aggregate_options.targets[i]).FindOne(input_schema));
+ ARROW_ASSIGN_OR_RAISE(
+ auto match,
+
FieldRef(aggregate_options.aggregates[i].target).FindOne(input_schema));
target_field_ids[i] = match[0];
ARROW_ASSIGN_OR_RAISE(
@@ -129,7 +128,7 @@ class ScalarAggregateNode : public ExecNode {
ARROW_ASSIGN_OR_RAISE(
auto descr, kernels[i]->signature->out_type().Resolve(&kernel_ctx,
{in_type}));
- fields[i] = field(field_names[i], std::move(descr.type));
+ fields[i] = field(aggregate_options.aggregates[i].name,
std::move(descr.type));
}
return plan->EmplaceNode<ScalarAggregateNode>(
@@ -263,7 +262,7 @@ class ScalarAggregateNode : public ExecNode {
}
const std::vector<int> target_field_ids_;
- const std::vector<internal::Aggregate> aggs_;
+ const std::vector<Aggregate> aggs_;
const std::vector<const ScalarAggregateKernel*> kernels_;
std::vector<std::vector<std::unique_ptr<KernelState>>> states_;
@@ -276,7 +275,7 @@ class GroupByNode : public ExecNode {
public:
GroupByNode(ExecNode* input, std::shared_ptr<Schema> output_schema,
ExecContext* ctx,
std::vector<int> key_field_ids, std::vector<int>
agg_src_field_ids,
- std::vector<internal::Aggregate> aggs,
+ std::vector<Aggregate> aggs,
std::vector<const HashAggregateKernel*> agg_kernels)
: ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema),
/*num_outputs=*/1),
@@ -295,7 +294,6 @@ class GroupByNode : public ExecNode {
const auto& keys = aggregate_options.keys;
// Copy (need to modify options pointer below)
auto aggs = aggregate_options.aggregates;
- const auto& field_names = aggregate_options.names;
// Get input schema
auto input_schema = input->output_schema();
@@ -310,13 +308,11 @@ class GroupByNode : public ExecNode {
// Find input field indices for aggregates
std::vector<int> agg_src_field_ids(aggs.size());
for (size_t i = 0; i < aggs.size(); ++i) {
- ARROW_ASSIGN_OR_RAISE(auto match,
-
aggregate_options.targets[i].FindOne(*input_schema));
+ ARROW_ASSIGN_OR_RAISE(auto match, aggs[i].target.FindOne(*input_schema));
agg_src_field_ids[i] = match[0];
}
// Build vector of aggregate source field data types
- DCHECK_EQ(aggregate_options.targets.size(), aggs.size());
std::vector<ValueDescr> agg_src_descrs(aggs.size());
for (size_t i = 0; i < aggs.size(); ++i) {
auto agg_src_field_id = agg_src_field_ids[i];
@@ -342,7 +338,8 @@ class GroupByNode : public ExecNode {
// Aggregate fields come before key fields to match the behavior of
GroupBy function
for (size_t i = 0; i < aggs.size(); ++i) {
- output_fields[i] = agg_result_fields[i]->WithName(field_names[i]);
+ output_fields[i] =
+ agg_result_fields[i]->WithName(aggregate_options.aggregates[i].name);
}
size_t base = aggs.size();
for (size_t i = 0; i < keys.size(); ++i) {
@@ -660,7 +657,7 @@ class GroupByNode : public ExecNode {
const std::vector<int> key_field_ids_;
const std::vector<int> agg_src_field_ids_;
- const std::vector<internal::Aggregate> aggs_;
+ const std::vector<Aggregate> aggs_;
const std::vector<const HashAggregateKernel*> agg_kernels_;
ThreadIndexer get_thread_index_;
diff --git a/cpp/src/arrow/compute/exec/ir_consumer.cc
b/cpp/src/arrow/compute/exec/ir_consumer.cc
index 0aafa2c281..f17dbf1ed7 100644
--- a/cpp/src/arrow/compute/exec/ir_consumer.cc
+++ b/cpp/src/arrow/compute/exec/ir_consumer.cc
@@ -531,7 +531,7 @@ Result<Declaration> Convert(const ir::Relation& rel) {
ARROW_ASSIGN_OR_RAISE(auto arg,
Convert(*aggregate->rel()).As<Declaration::Input>());
- AggregateNodeOptions opts{{}, {}, {}};
+ AggregateNodeOptions opts{{}, {}};
if (!aggregate->measures()) return
UnexpectedNullField("Aggregate.measures");
for (const ir::Expression* m : *aggregate->measures()) {
@@ -550,9 +550,8 @@ Result<Declaration> Convert(const ir::Relation& rel) {
"Support for non-FieldRef arguments to Aggregate.measures");
}
- opts.aggregates.push_back({call->function_name, nullptr});
- opts.targets.push_back(*target);
- opts.names.push_back(call->function_name + " " + target->ToString());
+ opts.aggregates.push_back({call->function_name, nullptr, *target,
+ call->function_name + " " +
target->ToString()});
}
if (!aggregate->groupings()) return
UnexpectedNullField("Aggregate.groupings");
diff --git a/cpp/src/arrow/compute/exec/ir_test.cc
b/cpp/src/arrow/compute/exec/ir_test.cc
index 847f555c69..d7eb37c185 100644
--- a/cpp/src/arrow/compute/exec/ir_test.cc
+++ b/cpp/src/arrow/compute/exec/ir_test.cc
@@ -249,7 +249,8 @@ TEST(Relation, Filter) {
}
TEST(Relation, AggregateSimple) {
- ASSERT_THAT(ConvertJSON<ir::Relation>(R"({
+ ASSERT_THAT(
+ ConvertJSON<ir::Relation>(R"({
"impl": {
id: {id: 1},
"groupings": [
@@ -347,28 +348,22 @@ TEST(Relation, AggregateSimple) {
},
"impl_type": "Aggregate"
})"),
- ResultWith(Eq(Declaration::Sequence({
- {"catalog_source",
- CatalogSourceNodeOptions{"tbl", schema({
- field("foo", int32()),
- field("bar", int64()),
- field("baz", float64()),
- })},
- "0"},
- {"aggregate",
- AggregateNodeOptions{/*aggregates=*/{
- {"sum", nullptr},
- {"mean", nullptr},
- },
- /*targets=*/{1, 2},
- /*names=*/
- {
- "sum FieldRef.FieldPath(1)",
- "mean FieldRef.FieldPath(2)",
- },
- /*keys=*/{0}},
- "1"},
- }))));
+ ResultWith(Eq(Declaration::Sequence({
+ {"catalog_source",
+ CatalogSourceNodeOptions{"tbl", schema({
+ field("foo", int32()),
+ field("bar", int64()),
+ field("baz", float64()),
+ })},
+ "0"},
+ {"aggregate",
+ AggregateNodeOptions{/*aggregates=*/{
+ {"sum", nullptr, 1, "sum
FieldRef.FieldPath(1)"},
+ {"mean", nullptr, 2, "mean
FieldRef.FieldPath(2)"},
+ },
+ /*keys=*/{0}},
+ "1"},
+ }))));
}
TEST(Relation, AggregateWithHaving) {
@@ -564,14 +559,8 @@ TEST(Relation, AggregateWithHaving) {
{"filter", FilterNodeOptions{less(field_ref(0),
literal<int8_t>(3))}, "1"},
{"aggregate",
AggregateNodeOptions{/*aggregates=*/{
- {"sum", nullptr},
- {"mean", nullptr},
- },
- /*targets=*/{1, 2},
- /*names=*/
- {
- "sum FieldRef.FieldPath(1)",
- "mean FieldRef.FieldPath(2)",
+ {"sum", nullptr, 1, "sum
FieldRef.FieldPath(1)"},
+ {"mean", nullptr, 2, "mean
FieldRef.FieldPath(2)"},
},
/*keys=*/{0}},
"2"},
diff --git a/cpp/src/arrow/compute/exec/options.h
b/cpp/src/arrow/compute/exec/options.h
index e0fb31963c..a86b6c63d3 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -111,20 +111,12 @@ class ARROW_EXPORT ProjectNodeOptions : public
ExecNodeOptions {
/// \brief Make a node which aggregates input batches, optionally grouped by
keys.
class ARROW_EXPORT AggregateNodeOptions : public ExecNodeOptions {
public:
- AggregateNodeOptions(std::vector<internal::Aggregate> aggregates,
- std::vector<FieldRef> targets, std::vector<std::string>
names,
- std::vector<FieldRef> keys = {})
- : aggregates(std::move(aggregates)),
- targets(std::move(targets)),
- names(std::move(names)),
- keys(std::move(keys)) {}
+ explicit AggregateNodeOptions(std::vector<Aggregate> aggregates,
+ std::vector<FieldRef> keys = {})
+ : aggregates(std::move(aggregates)), keys(std::move(keys)) {}
// aggregations which will be applied to the targetted fields
- std::vector<internal::Aggregate> aggregates;
- // fields to which aggregations will be applied
- std::vector<FieldRef> targets;
- // output field names for aggregations
- std::vector<std::string> names;
+ std::vector<Aggregate> aggregates;
// keys by which aggregations will be grouped
std::vector<FieldRef> keys;
};
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc
b/cpp/src/arrow/compute/exec/plan_test.cc
index 2df3c5e915..9efa6623e5 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -391,9 +391,10 @@ TEST(ExecPlan, ToString) {
}}},
{"aggregate",
AggregateNodeOptions{
- /*aggregates=*/{{"hash_sum", nullptr}, {"hash_count",
options}},
- /*targets=*/{"multiply(i32, 2)", "multiply(i32, 2)"},
- /*names=*/{"sum(multiply(i32, 2))", "count(multiply(i32,
2))"},
+ /*aggregates=*/{
+ {"hash_sum", nullptr, "multiply(i32, 2)",
"sum(multiply(i32, 2))"},
+ {"hash_count", options, "multiply(i32, 2)",
+ "count(multiply(i32, 2))"}},
/*keys=*/{"bool"}}},
{"filter",
FilterNodeOptions{greater(field_ref("sum(multiply(i32, 2))"),
literal(10))}},
@@ -429,17 +430,16 @@
custom_sink_label:OrderBySinkNode{by={sort_keys=[FieldRef.Name(sum(multiply(i32,
rhs.label = "rhs";
union_node.inputs.emplace_back(lhs);
union_node.inputs.emplace_back(rhs);
- ASSERT_OK(Declaration::Sequence(
- {
- union_node,
- {"aggregate",
- AggregateNodeOptions{/*aggregates=*/{{"count",
std::move(options)}},
- /*targets=*/{"i32"},
- /*names=*/{"count(i32)"},
- /*keys=*/{}}},
- {"sink", SinkNodeOptions{&sink_gen}},
- })
- .AddToPlan(plan.get()));
+ ASSERT_OK(
+ Declaration::Sequence(
+ {
+ union_node,
+ {"aggregate", AggregateNodeOptions{
+ /*aggregates=*/{{"count", options, "i32",
"count(i32)"}},
+ /*keys=*/{}}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
EXPECT_EQ(plan->ToString(), R"a(ExecPlan with 5 nodes:
:SinkNode{}
:ScalarAggregateNode{aggregates=[
@@ -765,17 +765,17 @@ TEST(ExecPlanExecution, StressSourceGroupedSumStop) {
auto random_data = MakeRandomBatches(input_schema, num_batches);
SortOptions options({SortKey("a", SortOrder::Ascending)});
- ASSERT_OK(Declaration::Sequence(
- {
- {"source", SourceNodeOptions{random_data.schema,
- random_data.gen(parallel,
slow)}},
- {"aggregate",
- AggregateNodeOptions{/*aggregates=*/{{"hash_sum",
nullptr}},
- /*targets=*/{"a"},
/*names=*/{"sum(a)"},
- /*keys=*/{"b"}}},
- {"sink", SinkNodeOptions{&sink_gen}},
- })
- .AddToPlan(plan.get()));
+ ASSERT_OK(
+ Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{random_data.schema,
+ random_data.gen(parallel,
slow)}},
+ {"aggregate", AggregateNodeOptions{
+ /*aggregates=*/{{"hash_sum", nullptr, "a",
"sum(a)"}},
+ /*keys=*/{"b"}}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
ASSERT_OK(plan->Validate());
ASSERT_OK(plan->StartProducing());
@@ -913,17 +913,17 @@ TEST(ExecPlanExecution, SourceGroupedSum) {
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
AsyncGenerator<util::optional<ExecBatch>> sink_gen;
- ASSERT_OK(Declaration::Sequence(
- {
- {"source", SourceNodeOptions{input.schema,
- input.gen(parallel,
/*slow=*/false)}},
- {"aggregate",
- AggregateNodeOptions{/*aggregates=*/{{"hash_sum",
nullptr}},
- /*targets=*/{"i32"},
/*names=*/{"sum(i32)"},
- /*keys=*/{"str"}}},
- {"sink", SinkNodeOptions{&sink_gen}},
- })
- .AddToPlan(plan.get()));
+ ASSERT_OK(
+ Declaration::Sequence(
+ {
+ {"source",
+ SourceNodeOptions{input.schema, input.gen(parallel,
/*slow=*/false)}},
+ {"aggregate",
AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr,
+ "i32",
"sum(i32)"}},
+ /*keys=*/{"str"}}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
Finishes(ResultWith(UnorderedElementsAreArray({ExecBatchFromJSON(
@@ -987,9 +987,8 @@ TEST(ExecPlanExecution, NestedSourceProjectGroupedSum) {
field_ref(FieldRef("struct", "bool")),
},
{"i32", "bool"}}},
- {"aggregate",
AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr}},
- /*targets=*/{"i32"},
- /*names=*/{"sum(i32)"},
+ {"aggregate",
AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr,
+ "i32",
"sum(i32)"}},
/*keys=*/{"bool"}}},
{"sink", SinkNodeOptions{&sink_gen}},
})
@@ -1021,10 +1020,11 @@ TEST(ExecPlanExecution,
SourceFilterProjectGroupedSumFilter) {
field_ref("str"),
call("multiply", {field_ref("i32"),
literal(2)}),
}}},
- {"aggregate",
AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr}},
- /*targets=*/{"multiply(i32,
2)"},
-
/*names=*/{"sum(multiply(i32, 2))"},
- /*keys=*/{"str"}}},
+ {"aggregate",
+ AggregateNodeOptions{
+ /*aggregates=*/{{"hash_sum", nullptr, "multiply(i32, 2)",
+ "sum(multiply(i32, 2))"}},
+ /*keys=*/{"str"}}},
{"filter",
FilterNodeOptions{greater(field_ref("sum(multiply(i32, 2))"),
literal(10 *
batch_multiplicity))}},
{"sink", SinkNodeOptions{&sink_gen}},
@@ -1060,10 +1060,11 @@ TEST(ExecPlanExecution,
SourceFilterProjectGroupedSumOrderBy) {
field_ref("str"),
call("multiply", {field_ref("i32"),
literal(2)}),
}}},
- {"aggregate",
AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr}},
- /*targets=*/{"multiply(i32,
2)"},
-
/*names=*/{"sum(multiply(i32, 2))"},
- /*keys=*/{"str"}}},
+ {"aggregate",
+ AggregateNodeOptions{
+ /*aggregates=*/{{"hash_sum", nullptr, "multiply(i32, 2)",
+ "sum(multiply(i32, 2))"}},
+ /*keys=*/{"str"}}},
{"filter",
FilterNodeOptions{greater(field_ref("sum(multiply(i32, 2))"),
literal(10 *
batch_multiplicity))}},
{"order_by_sink", OrderBySinkNodeOptions{options, &sink_gen}},
@@ -1088,22 +1089,22 @@ TEST(ExecPlanExecution,
SourceFilterProjectGroupedSumTopK) {
AsyncGenerator<util::optional<ExecBatch>> sink_gen;
SelectKOptions options = SelectKOptions::TopKDefault(/*k=*/1, {"str"});
- ASSERT_OK(
- Declaration::Sequence(
- {
- {"source",
- SourceNodeOptions{input.schema, input.gen(parallel,
/*slow=*/false)}},
- {"project", ProjectNodeOptions{{
- field_ref("str"),
- call("multiply", {field_ref("i32"),
literal(2)}),
- }}},
- {"aggregate",
AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr}},
- /*targets=*/{"multiply(i32,
2)"},
-
/*names=*/{"sum(multiply(i32, 2))"},
- /*keys=*/{"str"}}},
- {"select_k_sink", SelectKSinkNodeOptions{options, &sink_gen}},
- })
- .AddToPlan(plan.get()));
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{input.schema,
+ input.gen(parallel,
/*slow=*/false)}},
+ {"project", ProjectNodeOptions{{
+ field_ref("str"),
+ call("multiply", {field_ref("i32"),
literal(2)}),
+ }}},
+ {"aggregate",
+ AggregateNodeOptions{
+ /*aggregates=*/{{"hash_sum", nullptr,
"multiply(i32, 2)",
+ "sum(multiply(i32, 2))"}},
+ /*keys=*/{"str"}}},
+ {"select_k_sink", SelectKSinkNodeOptions{options,
&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
ASSERT_THAT(
StartAndCollect(plan.get(), sink_gen),
@@ -1118,18 +1119,19 @@ TEST(ExecPlanExecution, SourceScalarAggSink) {
auto basic_data = MakeBasicBatches();
- ASSERT_OK(Declaration::Sequence(
- {
- {"source", SourceNodeOptions{basic_data.schema,
-
basic_data.gen(/*parallel=*/false,
-
/*slow=*/false)}},
- {"aggregate", AggregateNodeOptions{
- /*aggregates=*/{{"sum", nullptr},
{"any", nullptr}},
- /*targets=*/{"i32", "bool"},
- /*names=*/{"sum(i32)", "any(bool)"}}},
- {"sink", SinkNodeOptions{&sink_gen}},
- })
- .AddToPlan(plan.get()));
+ ASSERT_OK(
+ Declaration::Sequence(
+ {
+ {"source",
+ SourceNodeOptions{basic_data.schema,
basic_data.gen(/*parallel=*/false,
+
/*slow=*/false)}},
+ {"aggregate", AggregateNodeOptions{
+ /*aggregates=*/{{"sum", nullptr, "i32",
"sum(i32)"},
+ {"any", nullptr, "bool",
"any(bool)"}},
+ }},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
ASSERT_THAT(
StartAndCollect(plan.get(), sink_gen),
@@ -1156,9 +1158,9 @@ TEST(ExecPlanExecution, AggregationPreservesOptions) {
basic_data.gen(/*parallel=*/false,
/*slow=*/false)}},
{"aggregate",
- AggregateNodeOptions{/*aggregates=*/{{"tdigest",
options}},
- /*targets=*/{"i32"},
- /*names=*/{"tdigest(i32)"}}},
+ AggregateNodeOptions{
+ /*aggregates=*/{{"tdigest", options, "i32",
"tdigest(i32)"}},
+ }},
{"sink", SinkNodeOptions{&sink_gen}},
})
.AddToPlan(plan.get()));
@@ -1183,10 +1185,9 @@ TEST(ExecPlanExecution, AggregationPreservesOptions) {
{"source", SourceNodeOptions{data.schema,
data.gen(/*parallel=*/false,
/*slow=*/false)}},
{"aggregate",
- AggregateNodeOptions{/*aggregates=*/{{"hash_count",
options}},
- /*targets=*/{"i32"},
- /*names=*/{"count(i32)"},
- /*keys=*/{"str"}}},
+ AggregateNodeOptions{
+ /*aggregates=*/{{"hash_count", options, "i32",
"count(i32)"}},
+ /*keys=*/{"str"}}},
{"sink", SinkNodeOptions{&sink_gen}},
})
.AddToPlan(plan.get()));
@@ -1215,29 +1216,24 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) {
// index can't be tested as it's order-dependent
// mode/quantile can't be tested as they're technically vector kernels
- ASSERT_OK(
- Declaration::Sequence(
- {
- {"source",
- SourceNodeOptions{scalar_data.schema,
scalar_data.gen(/*parallel=*/false,
-
/*slow=*/false)}},
- {"aggregate", AggregateNodeOptions{
- /*aggregates=*/{{"all", nullptr},
- {"any", nullptr},
- {"count", nullptr},
- {"mean", nullptr},
- {"product", nullptr},
- {"stddev", nullptr},
- {"sum", nullptr},
- {"tdigest", nullptr},
- {"variance", nullptr}},
- /*targets=*/{"b", "b", "a", "a", "a", "a",
"a", "a", "a"},
- /*names=*/
- {"all(b)", "any(b)", "count(a)", "mean(a)",
"product(a)",
- "stddev(a)", "sum(a)", "tdigest(a)",
"variance(a)"}}},
- {"sink", SinkNodeOptions{&sink_gen}},
- })
- .AddToPlan(plan.get()));
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{scalar_data.schema,
+
scalar_data.gen(/*parallel=*/false,
+
/*slow=*/false)}},
+ {"aggregate", AggregateNodeOptions{/*aggregates=*/{
+ {"all", nullptr, "b", "all(b)"},
+ {"any", nullptr, "b", "any(b)"},
+ {"count", nullptr, "a", "count(a)"},
+ {"mean", nullptr, "a", "mean(a)"},
+ {"product", nullptr, "a", "product(a)"},
+ {"stddev", nullptr, "a", "stddev(a)"},
+ {"sum", nullptr, "a", "sum(a)"},
+ {"tdigest", nullptr, "a", "tdigest(a)"},
+ {"variance", nullptr, "a",
"variance(a)"}}}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
ASSERT_THAT(
StartAndCollect(plan.get(), sink_gen),
@@ -1267,18 +1263,18 @@ TEST(ExecPlanExecution, ScalarSourceGroupedSum) {
scalar_data.schema = schema({field("a", int32()), field("b", boolean())});
SortOptions options({SortKey("b", SortOrder::Descending)});
- ASSERT_OK(Declaration::Sequence(
- {
- {"source", SourceNodeOptions{scalar_data.schema,
-
scalar_data.gen(/*parallel=*/false,
-
/*slow=*/false)}},
- {"aggregate",
- AggregateNodeOptions{/*aggregates=*/{{"hash_sum",
nullptr}},
- /*targets=*/{"a"},
/*names=*/{"hash_sum(a)"},
- /*keys=*/{"b"}}},
- {"order_by_sink", OrderBySinkNodeOptions{options,
&sink_gen}},
- })
- .AddToPlan(plan.get()));
+ ASSERT_OK(
+ Declaration::Sequence(
+ {
+ {"source",
+ SourceNodeOptions{scalar_data.schema,
scalar_data.gen(/*parallel=*/false,
+
/*slow=*/false)}},
+ {"aggregate", AggregateNodeOptions{/*aggregates=*/{{"hash_sum",
nullptr,
+ "a",
"hash_sum(a)"}},
+ /*keys=*/{"b"}}},
+ {"order_by_sink", OrderBySinkNodeOptions{options, &sink_gen}},
+ })
+ .AddToPlan(plan.get()));
ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
Finishes(ResultWith(UnorderedElementsAreArray({
diff --git a/cpp/src/arrow/compute/exec/test_util.cc
b/cpp/src/arrow/compute/exec/test_util.cc
index 40512f868c..1e09cb742f 100644
--- a/cpp/src/arrow/compute/exec/test_util.cc
+++ b/cpp/src/arrow/compute/exec/test_util.cc
@@ -337,10 +337,12 @@ bool operator==(const Declaration& l, const Declaration&
r) {
if (l_agg->options == nullptr || r_agg->options == nullptr) return false;
if (!l_agg->options->Equals(*r_agg->options)) return false;
+
+ if (l_agg->target != r_agg->target) return false;
+ if (l_agg->name != r_agg->name) return false;
}
- return l_opts->targets == r_opts->targets && l_opts->names ==
r_opts->names &&
- l_opts->keys == r_opts->keys;
+ return l_opts->keys == r_opts->keys;
}
if (l.factory_name == "order_by_sink") {
@@ -400,24 +402,14 @@ static inline void PrintToImpl(const std::string&
factory_name,
*os << "aggregates={";
for (const auto& agg : o->aggregates) {
- *os << agg.function << "<";
+ *os << "function=" << agg.function << "<";
if (agg.options) PrintTo(*agg.options, os);
*os << ">,";
+ *os << "target=" << agg.target.ToString() << ",";
+ *os << "name=" << agg.name;
}
*os << "},";
- *os << "targets={";
- for (const auto& target : o->targets) {
- *os << target.ToString() << ",";
- }
- *os << "},";
-
- *os << "names={";
- for (const auto& name : o->names) {
- *os << name << ",";
- }
- *os << "}";
-
if (!o->keys.empty()) {
*os << ",keys={";
for (const auto& key : o->keys) {
diff --git a/cpp/src/arrow/compute/exec/tpch_benchmark.cc
b/cpp/src/arrow/compute/exec/tpch_benchmark.cc
index 82584f58e9..54ac7cbdbf 100644
--- a/cpp/src/arrow/compute/exec/tpch_benchmark.cc
+++ b/cpp/src/arrow/compute/exec/tpch_benchmark.cc
@@ -77,21 +77,18 @@ std::shared_ptr<ExecPlan>
Plan_Q1(AsyncGenerator<util::optional<ExecBatch>>* sin
auto sum_opts =
std::make_shared<ScalarAggregateOptions>(ScalarAggregateOptions::Defaults());
auto count_opts =
std::make_shared<CountOptions>(CountOptions::CountMode::ALL);
- std::vector<arrow::compute::internal::Aggregate> aggs = {
- {"hash_sum", sum_opts}, {"hash_sum", sum_opts}, {"hash_sum",
sum_opts},
- {"hash_sum", sum_opts}, {"hash_mean", sum_opts}, {"hash_mean",
sum_opts},
- {"hash_mean", sum_opts}, {"hash_count", count_opts}};
-
- std::vector<FieldRef> to_aggregate = {"sum_qty", "sum_base_price",
"sum_disc_price",
- "sum_charge", "avg_qty",
"avg_price",
- "avg_disc", "sum_qty"};
-
- std::vector<std::string> names = {"sum_qty", "sum_base_price",
"sum_disc_price",
- "sum_charge", "avg_qty",
"avg_price",
- "avg_disc", "count_order"};
+ std::vector<arrow::compute::Aggregate> aggs = {
+ {"hash_sum", sum_opts, "sum_qty", "sum_qty"},
+ {"hash_sum", sum_opts, "sum_base_price", "sum_base_price"},
+ {"hash_sum", sum_opts, "sum_disc_price", "sum_disc_price"},
+ {"hash_sum", sum_opts, "sum_charge", "sum_charge"},
+ {"hash_mean", sum_opts, "avg_qty", "avg_qty"},
+ {"hash_mean", sum_opts, "avg_price", "avg_price"},
+ {"hash_mean", sum_opts, "avg_disc", "avg_disc"},
+ {"hash_count", count_opts, "sum_qty", "count_order"}};
std::vector<FieldRef> keys = {"l_returnflag", "l_linestatus"};
- AggregateNodeOptions agg_opts(aggs, to_aggregate, names, keys);
+ AggregateNodeOptions agg_opts(aggs, keys);
SortKey l_returnflag_key("l_returnflag");
SortKey l_linestatus_key("l_linestatus");
diff --git a/cpp/src/arrow/compute/kernels/aggregate_benchmark.cc
b/cpp/src/arrow/compute/kernels/aggregate_benchmark.cc
index c271285434..a8cae5b50c 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_benchmark.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_benchmark.cc
@@ -306,8 +306,9 @@ BENCHMARK_TEMPLATE(ReferenceSum,
SumBitmapVectorizeUnroll<int64_t>)
// GroupBy
//
-static void BenchmarkGroupBy(benchmark::State& state,
- std::vector<internal::Aggregate> aggregates,
+using arrow::compute::internal::GroupBy;
+
+static void BenchmarkGroupBy(benchmark::State& state, std::vector<Aggregate>
aggregates,
std::vector<Datum> arguments, std::vector<Datum>
keys) {
for (auto _ : state) {
ABORT_NOT_OK(GroupBy(arguments, keys, aggregates).status());
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
index 7b47845f23..82d40aba94 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
@@ -72,7 +72,7 @@ namespace compute {
namespace {
Result<Datum> NaiveGroupBy(std::vector<Datum> arguments, std::vector<Datum>
keys,
- const std::vector<internal::Aggregate>& aggregates)
{
+ const std::vector<Aggregate>& aggregates) {
ARROW_ASSIGN_OR_RAISE(auto key_batch, ExecBatch::Make(std::move(keys)));
ARROW_ASSIGN_OR_RAISE(auto grouper,
Grouper::Make(key_batch.GetDescriptors()));
@@ -123,16 +123,9 @@ Result<Datum> NaiveGroupBy(std::vector<Datum> arguments,
std::vector<Datum> keys
Result<Datum> GroupByUsingExecPlan(const BatchesWithSchema& input,
const std::vector<std::string>& key_names,
- const std::vector<std::string>& arg_names,
- const std::vector<internal::Aggregate>&
aggregates,
+ const std::vector<Aggregate>& aggregates,
bool use_threads, ExecContext* ctx) {
std::vector<FieldRef> keys(key_names.size());
- std::vector<FieldRef> targets(aggregates.size());
- std::vector<std::string> names(aggregates.size());
- for (size_t i = 0; i < aggregates.size(); ++i) {
- names[i] = aggregates[i].function;
- targets[i] = FieldRef(arg_names[i]);
- }
for (size_t i = 0; i < key_names.size(); ++i) {
keys[i] = FieldRef(key_names[i]);
}
@@ -144,9 +137,7 @@ Result<Datum> GroupByUsingExecPlan(const BatchesWithSchema&
input,
{
{"source",
SourceNodeOptions{input.schema, input.gen(use_threads,
/*slow=*/false)}},
- {"aggregate",
- AggregateNodeOptions{std::move(aggregates), std::move(targets),
- std::move(names), std::move(keys)}},
+ {"aggregate", AggregateNodeOptions{std::move(aggregates),
std::move(keys)}},
{"sink", SinkNodeOptions{&sink_gen}},
})
.AddToPlan(plan.get()));
@@ -191,17 +182,15 @@ Result<Datum> GroupByUsingExecPlan(const
BatchesWithSchema& input,
/// Simpler overload where you can give the columns as datums
Result<Datum> GroupByUsingExecPlan(const std::vector<Datum>& arguments,
const std::vector<Datum>& keys,
- const std::vector<internal::Aggregate>&
aggregates,
+ const std::vector<Aggregate>& aggregates,
bool use_threads, ExecContext* ctx) {
using arrow::compute::detail::ExecBatchIterator;
FieldVector scan_fields(arguments.size() + keys.size());
std::vector<std::string> key_names(keys.size());
- std::vector<std::string> arg_names(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) {
auto name = std::string("agg_") + std::to_string(i);
scan_fields[i] = field(name, arguments[i].type());
- arg_names[i] = std::move(name);
}
for (size_t i = 0; i < keys.size(); ++i) {
auto name = std::string("key_") + std::to_string(i);
@@ -223,14 +212,14 @@ Result<Datum> GroupByUsingExecPlan(const
std::vector<Datum>& arguments,
input.batches.push_back(std::move(batch));
}
- return GroupByUsingExecPlan(input, key_names, arg_names, aggregates,
use_threads, ctx);
+ return GroupByUsingExecPlan(input, key_names, aggregates, use_threads, ctx);
}
-void ValidateGroupBy(const std::vector<internal::Aggregate>& aggregates,
+void ValidateGroupBy(const std::vector<Aggregate>& aggregates,
std::vector<Datum> arguments, std::vector<Datum> keys) {
ASSERT_OK_AND_ASSIGN(Datum expected, NaiveGroupBy(arguments, keys,
aggregates));
- ASSERT_OK_AND_ASSIGN(Datum actual, GroupBy(arguments, keys, aggregates));
+ ASSERT_OK_AND_ASSIGN(Datum actual, internal::GroupBy(arguments, keys,
aggregates));
ASSERT_OK(expected.make_array()->ValidateFull());
ValidateOutput(actual);
@@ -246,15 +235,27 @@ ExecContext* small_chunksize_context(bool use_threads =
false) {
return use_threads ? &ctx_with_threads : &ctx;
}
-Result<Datum> GroupByTest(
- const std::vector<Datum>& arguments, const std::vector<Datum>& keys,
- const std::vector<::arrow::compute::internal::Aggregate>& aggregates,
- bool use_threads, bool use_exec_plan) {
+struct TestAggregate {
+ std::string function;
+ std::shared_ptr<FunctionOptions> options;
+};
+
+Result<Datum> GroupByTest(const std::vector<Datum>& arguments,
+ const std::vector<Datum>& keys,
+ const std::vector<TestAggregate>& aggregates, bool
use_threads,
+ bool use_exec_plan = false) {
+ std::vector<Aggregate> internal_aggregates;
+ int idx = 0;
+ for (auto t_agg : aggregates) {
+ internal_aggregates.push_back(
+ {t_agg.function, t_agg.options, "agg_" + std::to_string(idx),
t_agg.function});
+ idx = idx + 1;
+ }
if (use_exec_plan) {
- return GroupByUsingExecPlan(arguments, keys, aggregates, use_threads,
+ return GroupByUsingExecPlan(arguments, keys, internal_aggregates,
use_threads,
small_chunksize_context(use_threads));
} else {
- return internal::GroupBy(arguments, keys, aggregates, use_threads,
+ return internal::GroupBy(arguments, keys, internal_aggregates, use_threads,
default_exec_context());
}
}
@@ -860,11 +861,11 @@ TEST(GroupBy, CountScalar) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
ASSERT_OK_AND_ASSIGN(
Datum actual,
- GroupByUsingExecPlan(input, {"key"}, {"argument", "argument",
"argument"},
+ GroupByUsingExecPlan(input, {"key"},
{
- {"hash_count", skip_nulls},
- {"hash_count", keep_nulls},
- {"hash_count", count_all},
+ {"hash_count", skip_nulls, "argument",
"hash_count"},
+ {"hash_count", keep_nulls, "argument",
"hash_count"},
+ {"hash_count", count_all, "argument",
"hash_count"},
},
use_threads, default_exec_context()));
Datum expected = ArrayFromJSON(struct_({
@@ -1031,14 +1032,14 @@ TEST(GroupBy, MeanOnly) {
auto min_count =
std::make_shared<ScalarAggregateOptions>(/*skip_nulls=*/true,
/*min_count=*/3);
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
- internal::GroupBy({table->GetColumnByName("argument"),
-
table->GetColumnByName("argument")},
- {table->GetColumnByName("key")},
- {
- {"hash_mean", nullptr},
- {"hash_mean", min_count},
- },
- use_threads));
+ GroupByTest({table->GetColumnByName("argument"),
+ table->GetColumnByName("argument")},
+ {table->GetColumnByName("key")},
+ {
+ {"hash_mean", nullptr},
+ {"hash_mean", min_count},
+ },
+ use_threads));
SortBy({"key_0"}, &aggregated_and_grouped);
AssertDatumsApproxEqual(ArrayFromJSON(struct_({
@@ -1072,11 +1073,11 @@ TEST(GroupBy, SumMeanProductScalar) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
ASSERT_OK_AND_ASSIGN(
Datum actual,
- GroupByUsingExecPlan(input, {"key"}, {"argument", "argument",
"argument"},
+ GroupByUsingExecPlan(input, {"key"},
{
- {"hash_sum", nullptr},
- {"hash_mean", nullptr},
- {"hash_product", nullptr},
+ {"hash_sum", nullptr, "argument", "hash_sum"},
+ {"hash_mean", nullptr, "argument",
"hash_mean"},
+ {"hash_product", nullptr, "argument",
"hash_product"},
},
use_threads, default_exec_context()));
Datum expected = ArrayFromJSON(struct_({
@@ -1110,7 +1111,7 @@ TEST(GroupBy, VarianceAndStddev) {
])");
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
- internal::GroupBy(
+ GroupByTest(
{
batch->GetColumnByName("argument"),
batch->GetColumnByName("argument"),
@@ -1121,7 +1122,8 @@ TEST(GroupBy, VarianceAndStddev) {
{
{"hash_variance", nullptr},
{"hash_stddev", nullptr},
- }));
+ },
+ false));
AssertDatumsApproxEqual(ArrayFromJSON(struct_({
field("hash_variance", float64()),
@@ -1151,7 +1153,7 @@ TEST(GroupBy, VarianceAndStddev) {
[null, 3]
])");
- ASSERT_OK_AND_ASSIGN(aggregated_and_grouped, internal::GroupBy(
+ ASSERT_OK_AND_ASSIGN(aggregated_and_grouped, GroupByTest(
{
batch->GetColumnByName("argument"),
batch->GetColumnByName("argument"),
@@ -1162,7 +1164,8 @@ TEST(GroupBy, VarianceAndStddev) {
{
{"hash_variance",
nullptr},
{"hash_stddev",
nullptr},
- }));
+ },
+ false));
AssertDatumsApproxEqual(ArrayFromJSON(struct_({
field("hash_variance", float64()),
@@ -1181,7 +1184,7 @@ TEST(GroupBy, VarianceAndStddev) {
// Test ddof
auto variance_options = std::make_shared<VarianceOptions>(/*ddof=*/2);
ASSERT_OK_AND_ASSIGN(aggregated_and_grouped,
- internal::GroupBy(
+ GroupByTest(
{
batch->GetColumnByName("argument"),
batch->GetColumnByName("argument"),
@@ -1192,7 +1195,8 @@ TEST(GroupBy, VarianceAndStddev) {
{
{"hash_variance", variance_options},
{"hash_stddev", variance_options},
- }));
+ },
+ false));
AssertDatumsApproxEqual(ArrayFromJSON(struct_({
field("hash_variance", float64()),
@@ -1225,7 +1229,7 @@ TEST(GroupBy, VarianceAndStddevDecimal) {
])");
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
- internal::GroupBy(
+ GroupByTest(
{
batch->GetColumnByName("argument0"),
batch->GetColumnByName("argument0"),
@@ -1240,7 +1244,8 @@ TEST(GroupBy, VarianceAndStddevDecimal) {
{"hash_stddev", nullptr},
{"hash_variance", nullptr},
{"hash_stddev", nullptr},
- }));
+ },
+ false));
AssertDatumsApproxEqual(ArrayFromJSON(struct_({
field("hash_variance", float64()),
@@ -1291,7 +1296,7 @@ TEST(GroupBy, TDigest) {
std::make_shared<TDigestOptions>(/*q=*/0.5, /*delta=*/100,
/*buffer_size=*/500,
/*skip_nulls=*/false, /*min_count=*/3);
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
- internal::GroupBy(
+ GroupByTest(
{
batch->GetColumnByName("argument"),
batch->GetColumnByName("argument"),
@@ -1310,7 +1315,8 @@ TEST(GroupBy, TDigest) {
{"hash_tdigest", keep_nulls},
{"hash_tdigest", min_count},
{"hash_tdigest", keep_nulls_min_count},
- }));
+ },
+ false));
AssertDatumsApproxEqual(
ArrayFromJSON(struct_({
@@ -1349,7 +1355,7 @@ TEST(GroupBy, TDigestDecimal) {
])");
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
- internal::GroupBy(
+ GroupByTest(
{
batch->GetColumnByName("argument0"),
batch->GetColumnByName("argument1"),
@@ -1358,7 +1364,8 @@ TEST(GroupBy, TDigestDecimal) {
{
{"hash_tdigest", nullptr},
{"hash_tdigest", nullptr},
- }));
+ },
+ false));
AssertDatumsApproxEqual(
ArrayFromJSON(struct_({
@@ -1403,7 +1410,7 @@ TEST(GroupBy, ApproximateMedian) {
auto keep_nulls_min_count = std::make_shared<ScalarAggregateOptions>(
/*skip_nulls=*/false, /*min_count=*/3);
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
- internal::GroupBy(
+ GroupByTest(
{
batch->GetColumnByName("argument"),
batch->GetColumnByName("argument"),
@@ -1418,7 +1425,8 @@ TEST(GroupBy, ApproximateMedian) {
{"hash_approximate_median", keep_nulls},
{"hash_approximate_median", min_count},
{"hash_approximate_median",
keep_nulls_min_count},
- }));
+ },
+ false));
AssertDatumsApproxEqual(ArrayFromJSON(struct_({
field("hash_approximate_median",
float64()),
@@ -1456,19 +1464,18 @@ TEST(GroupBy, StddevVarianceTDigestScalar) {
for (bool use_threads : {false}) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
- ASSERT_OK_AND_ASSIGN(Datum actual,
- GroupByUsingExecPlan(input, {"key"},
- {"argument", "argument",
"argument",
- "argument1", "argument1",
"argument1"},
- {
- {"hash_stddev", nullptr},
- {"hash_variance", nullptr},
- {"hash_tdigest", nullptr},
- {"hash_stddev", nullptr},
- {"hash_variance", nullptr},
- {"hash_tdigest", nullptr},
- },
- use_threads,
default_exec_context()));
+ ASSERT_OK_AND_ASSIGN(
+ Datum actual,
+ GroupByUsingExecPlan(input, {"key"},
+ {
+ {"hash_stddev", nullptr, "argument",
"hash_stddev"},
+ {"hash_variance", nullptr, "argument",
"hash_variance"},
+ {"hash_tdigest", nullptr, "argument",
"hash_tdigest"},
+ {"hash_stddev", nullptr, "argument1",
"hash_stddev"},
+ {"hash_variance", nullptr, "argument1",
"hash_variance"},
+ {"hash_tdigest", nullptr, "argument1",
"hash_tdigest"},
+ },
+ use_threads, default_exec_context()));
Datum expected =
ArrayFromJSON(struct_({
field("hash_stddev", float64()),
@@ -1516,25 +1523,19 @@ TEST(GroupBy, VarianceOptions) {
for (bool use_threads : {false}) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
- ASSERT_OK_AND_ASSIGN(Datum actual,
- GroupByUsingExecPlan(input, {"key"},
- {
- "argument",
- "argument",
- "argument",
- "argument",
- "argument",
- "argument",
- },
- {
- {"hash_stddev", keep_nulls},
- {"hash_stddev", min_count},
- {"hash_stddev",
keep_nulls_min_count},
- {"hash_variance",
keep_nulls},
- {"hash_variance", min_count},
- {"hash_variance",
keep_nulls_min_count},
- },
- use_threads,
default_exec_context()));
+ ASSERT_OK_AND_ASSIGN(
+ Datum actual,
+ GroupByUsingExecPlan(
+ input, {"key"},
+ {
+ {"hash_stddev", keep_nulls, "argument", "hash_stddev"},
+ {"hash_stddev", min_count, "argument", "hash_stddev"},
+ {"hash_stddev", keep_nulls_min_count, "argument",
"hash_stddev"},
+ {"hash_variance", keep_nulls, "argument", "hash_variance"},
+ {"hash_variance", min_count, "argument", "hash_variance"},
+ {"hash_variance", keep_nulls_min_count, "argument",
"hash_variance"},
+ },
+ use_threads, default_exec_context()));
Datum expected = ArrayFromJSON(struct_({
field("hash_stddev", float64()),
field("hash_stddev", float64()),
@@ -1553,25 +1554,19 @@ TEST(GroupBy, VarianceOptions) {
ValidateOutput(expected);
AssertDatumsApproxEqual(expected, actual, /*verbose=*/true);
- ASSERT_OK_AND_ASSIGN(actual,
- GroupByUsingExecPlan(input, {"key"},
- {
- "argument1",
- "argument1",
- "argument1",
- "argument1",
- "argument1",
- "argument1",
- },
- {
- {"hash_stddev", keep_nulls},
- {"hash_stddev", min_count},
- {"hash_stddev",
keep_nulls_min_count},
- {"hash_variance",
keep_nulls},
- {"hash_variance", min_count},
- {"hash_variance",
keep_nulls_min_count},
- },
- use_threads,
default_exec_context()));
+ ASSERT_OK_AND_ASSIGN(
+ actual,
+ GroupByUsingExecPlan(
+ input, {"key"},
+ {
+ {"hash_stddev", keep_nulls, "argument1", "hash_stddev"},
+ {"hash_stddev", min_count, "argument1", "hash_stddev"},
+ {"hash_stddev", keep_nulls_min_count, "argument1",
"hash_stddev"},
+ {"hash_variance", keep_nulls, "argument1", "hash_variance"},
+ {"hash_variance", min_count, "argument1", "hash_variance"},
+ {"hash_variance", keep_nulls_min_count, "argument1",
"hash_variance"},
+ },
+ use_threads, default_exec_context()));
expected = ArrayFromJSON(struct_({
field("hash_stddev", float64()),
field("hash_stddev", float64()),
@@ -1997,9 +1992,9 @@ TEST(GroupBy, MinMaxScalar) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
ASSERT_OK_AND_ASSIGN(
Datum actual,
- GroupByUsingExecPlan(input, {"key"}, {"argument", "argument",
"argument"},
- {{"hash_min_max", nullptr}}, use_threads,
- default_exec_context()));
+ GroupByUsingExecPlan(input, {"key"},
+ {{"hash_min_max", nullptr, "argument",
"hash_min_max"}},
+ use_threads, default_exec_context()));
Datum expected =
ArrayFromJSON(struct_({
field("hash_min_max",
@@ -2062,14 +2057,14 @@ TEST(GroupBy, AnyAndAll) {
},
{table->GetColumnByName("key")},
{
- {"hash_any", no_min},
- {"hash_any", min_count},
- {"hash_any", keep_nulls},
- {"hash_any", keep_nulls_min_count},
- {"hash_all", no_min},
- {"hash_all", min_count},
- {"hash_all", keep_nulls},
- {"hash_all", keep_nulls_min_count},
+ {"hash_any", no_min, "agg_0", "hash_any"},
+ {"hash_any", min_count, "agg_1", "hash_any"},
+ {"hash_any", keep_nulls, "agg_2", "hash_any"},
+ {"hash_any", keep_nulls_min_count, "agg_3",
"hash_any"},
+ {"hash_all", no_min, "agg_4", "hash_all"},
+ {"hash_all", min_count, "agg_5", "hash_all"},
+ {"hash_all", keep_nulls, "agg_6", "hash_all"},
+ {"hash_all", keep_nulls_min_count, "agg_7",
"hash_all"},
},
use_threads));
SortBy({"key_0"}, &aggregated_and_grouped);
@@ -2122,12 +2117,11 @@ TEST(GroupBy, AnyAllScalar) {
ASSERT_OK_AND_ASSIGN(
Datum actual,
GroupByUsingExecPlan(input, {"key"},
- {"argument", "argument", "argument", "argument"},
{
- {"hash_any", nullptr},
- {"hash_all", nullptr},
- {"hash_any", keep_nulls},
- {"hash_all", keep_nulls},
+ {"hash_any", nullptr, "argument", "hash_any"},
+ {"hash_all", nullptr, "argument", "hash_all"},
+ {"hash_any", keep_nulls, "argument",
"hash_any"},
+ {"hash_all", keep_nulls, "argument",
"hash_all"},
},
use_threads, default_exec_context()));
Datum expected = ArrayFromJSON(struct_({
@@ -2184,22 +2178,23 @@ TEST(GroupBy, CountDistinct) {
[3, null]
])"});
- ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
- internal::GroupBy(
- {
- table->GetColumnByName("argument"),
- table->GetColumnByName("argument"),
- table->GetColumnByName("argument"),
- },
- {
- table->GetColumnByName("key"),
- },
- {
- {"hash_count_distinct", all},
- {"hash_count_distinct", only_valid},
- {"hash_count_distinct", only_null},
- },
- use_threads));
+ ASSERT_OK_AND_ASSIGN(
+ Datum aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {
+ {"hash_count_distinct", all, "agg_0", "hash_count_distinct"},
+ {"hash_count_distinct", only_valid, "agg_1",
"hash_count_distinct"},
+ {"hash_count_distinct", only_null, "agg_2",
"hash_count_distinct"},
+ },
+ use_threads));
SortBy({"key_0"}, &aggregated_and_grouped);
ValidateOutput(aggregated_and_grouped);
@@ -2250,22 +2245,23 @@ TEST(GroupBy, CountDistinct) {
["b", null]
])"});
- ASSERT_OK_AND_ASSIGN(aggregated_and_grouped,
- internal::GroupBy(
- {
- table->GetColumnByName("argument"),
- table->GetColumnByName("argument"),
- table->GetColumnByName("argument"),
- },
- {
- table->GetColumnByName("key"),
- },
- {
- {"hash_count_distinct", all},
- {"hash_count_distinct", only_valid},
- {"hash_count_distinct", only_null},
- },
- use_threads));
+ ASSERT_OK_AND_ASSIGN(
+ aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {
+ {"hash_count_distinct", all, "agg_0", "hash_count_distinct"},
+ {"hash_count_distinct", only_valid, "agg_1",
"hash_count_distinct"},
+ {"hash_count_distinct", only_null, "agg_2",
"hash_count_distinct"},
+ },
+ use_threads));
ValidateOutput(aggregated_and_grouped);
SortBy({"key_0"}, &aggregated_and_grouped);
@@ -2296,22 +2292,23 @@ TEST(GroupBy, CountDistinct) {
])",
});
- ASSERT_OK_AND_ASSIGN(aggregated_and_grouped,
- internal::GroupBy(
- {
- table->GetColumnByName("argument"),
- table->GetColumnByName("argument"),
- table->GetColumnByName("argument"),
- },
- {
- table->GetColumnByName("key"),
- },
- {
- {"hash_count_distinct", all},
- {"hash_count_distinct", only_valid},
- {"hash_count_distinct", only_null},
- },
- use_threads));
+ ASSERT_OK_AND_ASSIGN(
+ aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {
+ {"hash_count_distinct", all, "agg_0", "hash_count_distinct"},
+ {"hash_count_distinct", only_valid, "agg_1",
"hash_count_distinct"},
+ {"hash_count_distinct", only_null, "agg_2",
"hash_count_distinct"},
+ },
+ use_threads));
ValidateOutput(aggregated_and_grouped);
SortBy({"key_0"}, &aggregated_and_grouped);
@@ -2379,9 +2376,9 @@ TEST(GroupBy, Distinct) {
table->GetColumnByName("key"),
},
{
- {"hash_distinct", all},
- {"hash_distinct", only_valid},
- {"hash_distinct", only_null},
+ {"hash_distinct", all, "agg_0",
"hash_distinct"},
+ {"hash_distinct", only_valid, "agg_1",
"hash_distinct"},
+ {"hash_distinct", only_null, "agg_2",
"hash_distinct"},
},
use_threads));
ValidateOutput(aggregated_and_grouped);
@@ -2452,9 +2449,9 @@ TEST(GroupBy, Distinct) {
table->GetColumnByName("key"),
},
{
- {"hash_distinct", all},
- {"hash_distinct", only_valid},
- {"hash_distinct", only_null},
+ {"hash_distinct", all, "agg_0",
"hash_distinct"},
+ {"hash_distinct", only_valid, "agg_1",
"hash_distinct"},
+ {"hash_distinct", only_null, "agg_2",
"hash_distinct"},
},
use_threads));
ValidateOutput(aggregated_and_grouped);
@@ -2744,8 +2741,8 @@ TEST(GroupBy, OneScalar) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
ASSERT_OK_AND_ASSIGN(
Datum actual, GroupByUsingExecPlan(
- input, {"key"}, {"argument", "argument", "argument"},
- {{"hash_one", nullptr}}, use_threads,
default_exec_context()));
+ input, {"key"}, {{"hash_one", nullptr, "argument",
"hash_one"}},
+ use_threads, default_exec_context()));
const auto& struct_arr = actual.array_as<StructArray>();
// Check the key column
@@ -2805,7 +2802,7 @@ TEST(GroupBy, ListNumeric) {
table->GetColumnByName("key"),
},
{
- {"hash_list", nullptr},
+ {"hash_list", nullptr, "agg_0",
"hash_list"},
},
use_threads));
ValidateOutput(aggregated_and_grouped);
@@ -2876,7 +2873,7 @@ TEST(GroupBy, ListNumeric) {
table->GetColumnByName("key"),
},
{
- {"hash_list", nullptr},
+ {"hash_list", nullptr, "agg_0",
"hash_list"},
},
use_threads));
ValidateOutput(aggregated_and_grouped);
@@ -2945,7 +2942,7 @@ TEST(GroupBy, ListBinaryTypes) {
table->GetColumnByName("key"),
},
{
- {"hash_list", nullptr},
+ {"hash_list", nullptr, "agg_0",
"hash_list"},
},
use_threads));
ValidateOutput(aggregated_and_grouped);
@@ -3007,7 +3004,7 @@ TEST(GroupBy, ListBinaryTypes) {
table->GetColumnByName("key"),
},
{
- {"hash_list", nullptr},
+ {"hash_list", nullptr, "agg_0",
"hash_list"},
},
use_threads));
ValidateOutput(aggregated_and_grouped);
@@ -3238,12 +3235,12 @@ TEST(GroupBy, CountAndSum) {
batch->GetColumnByName("key"),
},
{
- {"hash_count", count_options},
- {"hash_count", count_nulls},
- {"hash_count", count_all},
- {"hash_sum", nullptr},
- {"hash_sum", min_count},
- {"hash_sum", nullptr},
+ {"hash_count", count_options, "agg_0", "hash_count"},
+ {"hash_count", count_nulls, "agg_1", "hash_count"},
+ {"hash_count", count_all, "agg_2", "hash_count"},
+ {"hash_sum", nullptr, "agg_3", "hash_sum"},
+ {"hash_sum", min_count, "agg_4", "hash_sum"},
+ {"hash_sum", nullptr, "agg_5", "hash_sum"},
}));
AssertDatumsEqual(
@@ -3295,9 +3292,9 @@ TEST(GroupBy, Product) {
batch->GetColumnByName("key"),
},
{
- {"hash_product", nullptr},
- {"hash_product", nullptr},
- {"hash_product", min_count},
+ {"hash_product", nullptr, "agg_0",
"hash_product"},
+ {"hash_product", nullptr, "agg_1",
"hash_product"},
+ {"hash_product", min_count, "agg_2",
"hash_product"},
}));
AssertDatumsApproxEqual(ArrayFromJSON(struct_({
@@ -3322,16 +3319,17 @@ TEST(GroupBy, Product) {
[8589934593, 1]
])");
- ASSERT_OK_AND_ASSIGN(aggregated_and_grouped, internal::GroupBy(
- {
-
batch->GetColumnByName("argument"),
- },
- {
-
batch->GetColumnByName("key"),
- },
- {
- {"hash_product",
nullptr},
- }));
+ ASSERT_OK_AND_ASSIGN(aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ batch->GetColumnByName("argument"),
+ },
+ {
+ batch->GetColumnByName("key"),
+ },
+ {
+ {"hash_product", nullptr, "agg_0",
"hash_product"},
+ }));
AssertDatumsApproxEqual(ArrayFromJSON(struct_({
field("hash_product", int64()),
@@ -3374,12 +3372,12 @@ TEST(GroupBy, SumMeanProductKeepNulls) {
batch->GetColumnByName("key"),
},
{
- {"hash_sum", keep_nulls},
- {"hash_sum", min_count},
- {"hash_mean", keep_nulls},
- {"hash_mean", min_count},
- {"hash_product", keep_nulls},
- {"hash_product", min_count},
+ {"hash_sum", keep_nulls, "agg_0", "hash_sum"},
+ {"hash_sum", min_count, "agg_1", "hash_sum"},
+ {"hash_mean", keep_nulls, "agg_2", "hash_mean"},
+ {"hash_mean", min_count, "agg_3", "hash_mean"},
+ {"hash_product", keep_nulls, "agg_4",
"hash_product"},
+ {"hash_product", min_count, "agg_5",
"hash_product"},
}));
AssertDatumsApproxEqual(ArrayFromJSON(struct_({
@@ -3423,7 +3421,7 @@ TEST(GroupBy, SumOnlyStringAndDictKeys) {
internal::GroupBy({batch->GetColumnByName("argument")},
{batch->GetColumnByName("key")},
{
- {"hash_sum", nullptr},
+ {"hash_sum", nullptr, "agg_0",
"hash_sum"},
}));
SortBy({"key_0"}, &aggregated_and_grouped);
@@ -3464,13 +3462,12 @@ TEST(GroupBy, ConcreteCaseWithValidateGroupBy) {
std::shared_ptr<CountOptions> non_null =
std::make_shared<CountOptions>(CountOptions::ONLY_VALID);
- using internal::Aggregate;
for (auto agg : {
- Aggregate{"hash_sum", nullptr},
- Aggregate{"hash_count", non_null},
- Aggregate{"hash_count", nulls},
- Aggregate{"hash_min_max", nullptr},
- Aggregate{"hash_min_max", keepna},
+ Aggregate{"hash_sum", nullptr, "agg_0", "hash_sum"},
+ Aggregate{"hash_count", non_null, "agg_1", "hash_count"},
+ Aggregate{"hash_count", nulls, "agg_2", "hash_count"},
+ Aggregate{"hash_min_max", nullptr, "agg_3", "hash_min_max"},
+ Aggregate{"hash_min_max", keepna, "agg_4", "hash_min_max"},
}) {
SCOPED_TRACE(agg.function);
ValidateGroupBy({agg}, {batch->GetColumnByName("argument")},
@@ -3492,10 +3489,9 @@ TEST(GroupBy, CountNull) {
std::shared_ptr<CountOptions> skipna =
std::make_shared<CountOptions>(CountOptions::ONLY_VALID);
- using internal::Aggregate;
for (auto agg : {
- Aggregate{"hash_count", keepna},
- Aggregate{"hash_count", skipna},
+ Aggregate{"hash_count", keepna, "agg_0", "hash_count"},
+ Aggregate{"hash_count", skipna, "agg_1", "hash_count"},
}) {
SCOPED_TRACE(agg.function);
ValidateGroupBy({agg}, {batch->GetColumnByName("argument")},
@@ -3519,7 +3515,7 @@ TEST(GroupBy, RandomArraySum) {
ValidateGroupBy(
{
- {"hash_sum", options},
+ {"hash_sum", options, "agg_0", "hash_sum"},
},
{batch->GetColumnByName("argument")},
{batch->GetColumnByName("key")});
}
@@ -3552,9 +3548,9 @@ TEST(GroupBy, WithChunkedArray) {
table->GetColumnByName("key"),
},
{
- {"hash_count", nullptr},
- {"hash_sum", nullptr},
- {"hash_min_max", nullptr},
+ {"hash_count", nullptr, "agg_0", "hash_count"},
+ {"hash_sum", nullptr, "agg_1", "hash_sum"},
+ {"hash_min_max", nullptr, "agg_2",
"hash_min_max"},
}));
AssertDatumsEqual(ArrayFromJSON(struct_({
@@ -3590,7 +3586,7 @@ TEST(GroupBy, MinMaxWithNewGroupsInChunkedArray) {
table->GetColumnByName("key"),
},
{
- {"hash_min_max", nullptr},
+ {"hash_min_max", nullptr, "agg_1",
"hash_min_max"},
}));
AssertDatumsEqual(ArrayFromJSON(struct_({
@@ -3626,7 +3622,7 @@ TEST(GroupBy, SmallChunkSizeSumOnly) {
internal::GroupBy({batch->GetColumnByName("argument")},
{batch->GetColumnByName("key")},
{
- {"hash_sum", nullptr},
+ {"hash_sum", nullptr, "agg_0",
"hash_sum"},
},
small_chunksize_context()));
AssertDatumsEqual(ArrayFromJSON(struct_({
diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index 02f658181c..3cd5f1fcc2 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -675,10 +675,8 @@ Result<int64_t> AsyncScanner::CountRows() {
std::move(fragment_gen)),
options}},
{"project", compute::ProjectNodeOptions{{options->filter},
{"mask"}}},
- {"aggregate",
compute::AggregateNodeOptions{{compute::internal::Aggregate{
- "sum", nullptr}},
- /*targets=*/{"mask"},
-
/*names=*/{"selected_count"}}},
+ {"aggregate", compute::AggregateNodeOptions{{compute::Aggregate{
+ "sum", nullptr, "mask", "selected_count"}}}},
{"sink", compute::SinkNodeOptions{&sink_gen}},
})
.AddToPlan(plan.get()));
diff --git a/cpp/src/arrow/dataset/scanner_test.cc
b/cpp/src/arrow/dataset/scanner_test.cc
index b7dcb8b18d..5316f63d08 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -1837,11 +1837,9 @@ TEST(ScanNode, MinimalScalarAggEndToEnd) {
// pipe the projection into a scalar aggregate node
ASSERT_OK_AND_ASSIGN(
compute::ExecNode * aggregate,
- compute::MakeExecNode(
- "aggregate", plan.get(), {project},
- compute::AggregateNodeOptions{{compute::internal::Aggregate{"sum",
nullptr}},
- /*targets=*/{"a * 2"},
- /*names=*/{"sum(a * 2)"}}));
+ compute::MakeExecNode("aggregate", plan.get(), {project},
+ compute::AggregateNodeOptions{{compute::Aggregate{
+ "sum", nullptr, "a * 2", "sum(a * 2)"}}}));
// finally, pipe the aggregate node into a sink node
AsyncGenerator<util::optional<compute::ExecBatch>> sink_gen;
@@ -1927,12 +1925,11 @@ TEST(ScanNode, MinimalGroupedAggEndToEnd) {
// pipe the projection into a grouped aggregate node
ASSERT_OK_AND_ASSIGN(
compute::ExecNode * aggregate,
- compute::MakeExecNode("aggregate", plan.get(), {project},
- compute::AggregateNodeOptions{
- {compute::internal::Aggregate{"hash_sum",
nullptr}},
- /*targets=*/{"a * 2"},
- /*names=*/{"sum(a * 2)"},
- /*keys=*/{"b"}}));
+ compute::MakeExecNode(
+ "aggregate", plan.get(), {project},
+ compute::AggregateNodeOptions{
+ {compute::Aggregate{"hash_sum", nullptr, "a * 2", "sum(a * 2)"}},
+ /*keys=*/{"b"}}));
// finally, pipe the aggregate node into a sink node
AsyncGenerator<util::optional<compute::ExecBatch>> sink_gen;
diff --git a/python/pyarrow/includes/libarrow.pxd
b/python/pyarrow/includes/libarrow.pxd
index 302ac99c36..9e43eb4eb9 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -2408,7 +2408,7 @@ cdef extern from * namespace "arrow::compute":
cdef extern from "arrow/compute/exec/aggregate.h" namespace \
"arrow::compute::internal" nogil:
- cdef cppclass CAggregate "arrow::compute::internal::Aggregate":
+ cdef cppclass CAggregate "arrow::compute::Aggregate":
c_string function
shared_ptr[CFunctionOptions] options
diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R
index bf5a8d0682..84f6ee54fc 100644
--- a/r/R/arrowExports.R
+++ b/r/R/arrowExports.R
@@ -432,8 +432,8 @@ ExecNode_Project <- function(input, exprs, names) {
.Call(`_arrow_ExecNode_Project`, input, exprs, names)
}
-ExecNode_Aggregate <- function(input, options, target_names, out_field_names,
key_names) {
- .Call(`_arrow_ExecNode_Aggregate`, input, options, target_names,
out_field_names, key_names)
+ExecNode_Aggregate <- function(input, options, key_names) {
+ .Call(`_arrow_ExecNode_Aggregate`, input, options, key_names)
}
ExecNode_Join <- function(input, type, right_data, left_keys, right_keys,
left_output, right_output, output_suffix_for_left, output_suffix_for_right) {
diff --git a/r/R/query-engine.R b/r/R/query-engine.R
index c40c61e98a..513b861d41 100644
--- a/r/R/query-engine.R
+++ b/r/R/query-engine.R
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
ExecPlan <- R6Class("ExecPlan",
inherit = ArrowObject,
public = list(
@@ -73,8 +72,6 @@ ExecPlan <- R6Class("ExecPlan",
group_vars <- dplyr::group_vars(.data)
grouped <- length(group_vars) > 0
- # Collect the target names first because we have to add back the group
vars
- target_names <- names(.data)
.data <- ensure_group_vars(.data)
.data <- ensure_arrange_vars(.data) # this sets .data$temp_columns
@@ -115,10 +112,15 @@ ExecPlan <- R6Class("ExecPlan",
})
}
+ .data$aggregations <- imap(.data$aggregations, function(x, name) {
+ # Embed the name inside the aggregation objects. `target` and `name`
+ # are the same because we just Project()ed the data that way above
+ x[["name"]] <- x[["target"]] <- name
+ x
+ })
+
node <- node$Aggregate(
- options = map(.data$aggregations, ~ .[c("fun", "options")]),
- target_names = names(.data$aggregations),
- out_field_names = names(.data$aggregations),
+ options = .data$aggregations,
key_names = group_vars
)
@@ -179,7 +181,6 @@ ExecPlan <- R6Class("ExecPlan",
temp_columns = names(.data$temp_columns)
)
}
-
# This is only safe because we are going to evaluate queries that end
# with head/tail first, then evaluate any subsequent query as a new query
if (!is.null(.data$head)) {
@@ -304,9 +305,9 @@ ExecNode <- R6Class("ExecNode",
assert_is(expr, "Expression")
self$preserve_extras(ExecNode_Filter(self, expr))
},
- Aggregate = function(options, target_names, out_field_names, key_names) {
+ Aggregate = function(options, key_names) {
out <- self$preserve_extras(
- ExecNode_Aggregate(self, options, target_names, out_field_names,
key_names)
+ ExecNode_Aggregate(self, options, key_names)
)
# dplyr drops top-level attributes when you call summarize()
out$extras$source_schema$metadata[["r"]]$attributes <- NULL
diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp
index 947270199a..62c8b6695c 100644
--- a/r/src/arrowExports.cpp
+++ b/r/src/arrowExports.cpp
@@ -964,15 +964,13 @@ BEGIN_CPP11
END_CPP11
}
// compute-exec.cpp
-std::shared_ptr<compute::ExecNode> ExecNode_Aggregate(const std::shared_ptr<compute::ExecNode>& input, cpp11::list options, std::vector<std::string> target_names, std::vector<std::string> out_field_names, std::vector<std::string> key_names);
-extern "C" SEXP _arrow_ExecNode_Aggregate(SEXP input_sexp, SEXP options_sexp, SEXP target_names_sexp, SEXP out_field_names_sexp, SEXP key_names_sexp){
+std::shared_ptr<compute::ExecNode> ExecNode_Aggregate(const std::shared_ptr<compute::ExecNode>& input, cpp11::list options, std::vector<std::string> key_names);
+extern "C" SEXP _arrow_ExecNode_Aggregate(SEXP input_sexp, SEXP options_sexp, SEXP key_names_sexp){
BEGIN_CPP11
	arrow::r::Input<const std::shared_ptr<compute::ExecNode>&>::type input(input_sexp);
	arrow::r::Input<cpp11::list>::type options(options_sexp);
-	arrow::r::Input<std::vector<std::string>>::type target_names(target_names_sexp);
-	arrow::r::Input<std::vector<std::string>>::type out_field_names(out_field_names_sexp);
	arrow::r::Input<std::vector<std::string>>::type key_names(key_names_sexp);
-	return cpp11::as_sexp(ExecNode_Aggregate(input, options, target_names, out_field_names, key_names));
+	return cpp11::as_sexp(ExecNode_Aggregate(input, options, key_names));
END_CPP11
}
// compute-exec.cpp
@@ -5248,7 +5246,7 @@ static const R_CallMethodDef CallEntries[] = {
		{ "_arrow_ExecPlan_Write", (DL_FUNC) &_arrow_ExecPlan_Write, 14},
		{ "_arrow_ExecNode_Filter", (DL_FUNC) &_arrow_ExecNode_Filter, 2},
		{ "_arrow_ExecNode_Project", (DL_FUNC) &_arrow_ExecNode_Project, 3},
-		{ "_arrow_ExecNode_Aggregate", (DL_FUNC) &_arrow_ExecNode_Aggregate, 5},
+		{ "_arrow_ExecNode_Aggregate", (DL_FUNC) &_arrow_ExecNode_Aggregate, 3},
		{ "_arrow_ExecNode_Join", (DL_FUNC) &_arrow_ExecNode_Join, 9},
		{ "_arrow_ExecNode_Union", (DL_FUNC) &_arrow_ExecNode_Union, 2},
		{ "_arrow_ExecNode_SourceNode", (DL_FUNC) &_arrow_ExecNode_SourceNode, 2},
diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp
index 089d1e71eb..76112b4cef 100644
--- a/r/src/compute-exec.cpp
+++ b/r/src/compute-exec.cpp
@@ -228,29 +228,26 @@ std::shared_ptr<compute::ExecNode> ExecNode_Project(
// [[arrow::export]]
std::shared_ptr<compute::ExecNode> ExecNode_Aggregate(
const std::shared_ptr<compute::ExecNode>& input, cpp11::list options,
-    std::vector<std::string> target_names, std::vector<std::string> out_field_names,
     std::vector<std::string> key_names) {
- std::vector<arrow::compute::internal::Aggregate> aggregates;
+ std::vector<arrow::compute::Aggregate> aggregates;
for (cpp11::list name_opts : options) {
- auto name = cpp11::as_cpp<std::string>(name_opts[0]);
- auto opts = make_compute_options(name, name_opts[1]);
+ auto function = cpp11::as_cpp<std::string>(name_opts["fun"]);
+ auto opts = make_compute_options(function, name_opts["options"]);
+ auto target = cpp11::as_cpp<std::string>(name_opts["target"]);
+ auto name = cpp11::as_cpp<std::string>(name_opts["name"]);
- aggregates.push_back(
- arrow::compute::internal::Aggregate{std::move(name), std::move(opts)});
+    aggregates.push_back(arrow::compute::Aggregate{std::move(function), opts,
+                                                   std::move(target), std::move(name)});
}
- std::vector<arrow::FieldRef> targets, keys;
- for (auto&& name : target_names) {
- targets.emplace_back(std::move(name));
- }
+ std::vector<arrow::FieldRef> keys;
for (auto&& name : key_names) {
keys.emplace_back(std::move(name));
}
return MakeExecNodeOrStop(
"aggregate", input->plan(), {input.get()},
- compute::AggregateNodeOptions{std::move(aggregates), std::move(targets),
-                                    std::move(out_field_names), std::move(keys)});
+ compute::AggregateNodeOptions{std::move(aggregates), std::move(keys)});
}
// [[arrow::export]]