icexelloss commented on code in PR #34885:
URL: https://github.com/apache/arrow/pull/34885#discussion_r1157361535
##########
cpp/src/arrow/acero/aggregate_node.cc:
##########
@@ -560,114 +566,40 @@ class GroupByNode : public ExecNode, public TracedNode {
const ExecNodeOptions& options) {
RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "GroupByNode"));
- auto input = inputs[0];
const auto& aggregate_options = checked_cast<const
AggregateNodeOptions&>(options);
+ auto aggregates = aggregate_options.aggregates;
const auto& keys = aggregate_options.keys;
const auto& segment_keys = aggregate_options.segment_keys;
- // Copy (need to modify options pointer below)
- auto aggs = aggregate_options.aggregates;
if (plan->query_context()->exec_context()->executor()->GetCapacity() > 1 &&
segment_keys.size() > 0) {
return Status::NotImplemented("Segmented aggregation in a multi-threaded
plan");
}
- // Get input schema
- auto input_schema = input->output_schema();
-
- // Find input field indices for key fields
- std::vector<int> key_field_ids(keys.size());
- for (size_t i = 0; i < keys.size(); ++i) {
- ARROW_ASSIGN_OR_RAISE(auto match, keys[i].FindOne(*input_schema));
- key_field_ids[i] = match[0];
- }
-
- // Find input field indices for segment key fields
- std::vector<int> segment_key_field_ids(segment_keys.size());
- for (size_t i = 0; i < segment_keys.size(); ++i) {
- ARROW_ASSIGN_OR_RAISE(auto match,
segment_keys[i].FindOne(*input_schema));
- segment_key_field_ids[i] = match[0];
- }
-
- // Check key fields and segment key fields are disjoint
- std::unordered_set<int> key_field_id_set(key_field_ids.begin(),
key_field_ids.end());
- for (const auto& segment_key_field_id : segment_key_field_ids) {
- if (key_field_id_set.find(segment_key_field_id) !=
key_field_id_set.end()) {
- return Status::Invalid("Group-by aggregation with field '",
-
input_schema->field(segment_key_field_id)->name(),
- "' as both key and segment key");
- }
- }
-
- // Find input field indices for aggregates
- std::vector<std::vector<int>> agg_src_fieldsets(aggs.size());
- for (size_t i = 0; i < aggs.size(); ++i) {
- const auto& target_fieldset = aggs[i].target;
- for (const auto& target : target_fieldset) {
- ARROW_ASSIGN_OR_RAISE(auto match, target.FindOne(*input_schema));
- agg_src_fieldsets[i].push_back(match[0]);
- }
- }
-
- // Build vector of aggregate source field data types
- std::vector<std::vector<TypeHolder>> agg_src_types(aggs.size());
- for (size_t i = 0; i < aggs.size(); ++i) {
- for (const auto& agg_src_field_id : agg_src_fieldsets[i]) {
-
agg_src_types[i].push_back(input_schema->field(agg_src_field_id)->type().get());
- }
- }
-
- // Build vector of segment key field data types
- std::vector<TypeHolder> segment_key_types(segment_keys.size());
- for (size_t i = 0; i < segment_keys.size(); ++i) {
- auto segment_key_field_id = segment_key_field_ids[i];
- segment_key_types[i] =
input_schema->field(segment_key_field_id)->type().get();
- }
-
- auto ctx = plan->query_context()->exec_context();
-
- ARROW_ASSIGN_OR_RAISE(auto segmenter,
- RowSegmenter::Make(std::move(segment_key_types),
- /*nullable_keys=*/false, ctx));
-
- // Construct aggregates
- ARROW_ASSIGN_OR_RAISE(auto agg_kernels, GetKernels(ctx, aggs,
agg_src_types));
-
- ARROW_ASSIGN_OR_RAISE(auto agg_states,
- InitKernels(agg_kernels, ctx, aggs, agg_src_types));
-
- ARROW_ASSIGN_OR_RAISE(
- FieldVector agg_result_fields,
- ResolveKernels(aggs, agg_kernels, agg_states, ctx, agg_src_types));
+ const auto& input_schema = *inputs[0]->output_schema();
+ auto exec_ctx = plan->query_context()->exec_context();
- // Build field vector for output schema
- FieldVector output_fields{keys.size() + segment_keys.size() + aggs.size()};
+ ARROW_ASSIGN_OR_RAISE(auto args,
+ aggregate::MakeAggregateNodeArgs(
+ input_schema, keys, segment_keys, aggregates,
+ plan->query_context()->max_concurrency(),
exec_ctx));
- // Aggregate fields come before key fields to match the behavior of
GroupBy function
- for (size_t i = 0; i < aggs.size(); ++i) {
- output_fields[i] =
- agg_result_fields[i]->WithName(aggregate_options.aggregates[i].name);
+ std::vector<const HashAggregateKernel*> kernels;
+ kernels.reserve(args.kernels.size());
+ for (auto kernel : args.kernels) {
+ kernels.push_back(static_cast<const HashAggregateKernel*>(kernel));
}
- size_t base = aggs.size();
- for (size_t i = 0; i < keys.size(); ++i) {
- int key_field_id = key_field_ids[i];
- output_fields[base + i] = input_schema->field(key_field_id);
- }
- base += keys.size();
- for (size_t i = 0; i < segment_keys.size(); ++i) {
- int segment_key_field_id = segment_key_field_ids[i];
- output_fields[base + i] = input_schema->field(segment_key_field_id);
- }
-
- return input->plan()->EmplaceNode<GroupByNode>(
- input, schema(std::move(output_fields)), std::move(key_field_ids),
- std::move(segment_key_field_ids), std::move(segmenter),
std::move(agg_src_types),
- std::move(agg_src_fieldsets), std::move(aggs), std::move(agg_kernels));
+ return inputs[0]->plan()->EmplaceNode<GroupByNode>(
+ inputs[0], std::move(args.output_schema),
std::move(args.grouping_key_field_ids),
+ std::move(args.segment_key_field_ids), std::move(args.segmenter),
+ std::move(args.kernel_intypes), std::move(args.target_fieldsets),
+ std::move(args.aggregates), std::move(kernels));
}
Status ResetKernelStates() {
auto ctx = plan()->query_context()->exec_context();
- ARROW_RETURN_NOT_OK(InitKernels(agg_kernels_, ctx, aggs_, agg_src_types_));
+ ARROW_RETURN_NOT_OK(InitKernels(InitHashAggregateKernel, agg_kernels_, ctx,
Review Comment:
Why do we need to pass `/*num_states_per_kernel=*/1` here?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]