westonpace commented on code in PR #34311:
URL: https://github.com/apache/arrow/pull/34311#discussion_r1119120080


##########
cpp/src/arrow/compute/kernels/hash_aggregate_test.cc:
##########
@@ -135,22 +141,84 @@ Result<Datum> NaiveGroupBy(std::vector<Datum> arguments, 
std::vector<Datum> keys
   return Take(struct_arr, sorted_indices);
 }
 
+Result<Datum> MakeGroupByOutput(const std::vector<ExecBatch>& output_batches,
+                                const std::shared_ptr<Schema> output_schema,
+                                size_t num_aggregates, size_t num_keys, bool 
naive) {
+  ArrayVector out_arrays(num_aggregates + num_keys);
+  for (size_t i = 0; i < out_arrays.size(); ++i) {
+    std::vector<std::shared_ptr<Array>> arrays(output_batches.size());
+    for (size_t j = 0; j < output_batches.size(); ++j) {
+      arrays[j] = output_batches[j].values[i].make_array();
+    }
+    if (arrays.empty()) {
+      ARROW_ASSIGN_OR_RAISE(
+          out_arrays[i],
+          MakeArrayOfNull(output_schema->field(static_cast<int>(i))->type(),
+                          /*length=*/0));
+    } else {
+      ARROW_ASSIGN_OR_RAISE(out_arrays[i], Concatenate(arrays));
+    }
+  }
+
+  ARROW_ASSIGN_OR_RAISE(
+      std::shared_ptr<Array> struct_arr,
+      StructArray::Make(std::move(out_arrays), output_schema->fields()));
+
+  bool need_sort = !naive;
+  for (size_t i = num_aggregates; need_sort && i < out_arrays.size(); i++) {
+    if (output_schema->field(i)->type()->id() == Type::DICTIONARY) {

Review Comment:
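   Windows compilers will also emit a warning in this case, because `i` is a
`size_t` that is implicitly narrowed to the `int` parameter of
`Schema::field` (the earlier call in this function already uses
`static_cast<int>(i)`). We could suppress that warning, but I think an explicit
cast is useful: it makes the conversion visible and should remind the author to
at least do a basic sanity check that an overflow cannot happen here.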



##########
cpp/src/arrow/compute/exec/aggregate_node.cc:
##########
@@ -169,35 +186,117 @@ void AggregatesToString(std::stringstream* ss, const 
Schema& input_schema,
   *ss << ']';
 }
 
+template <typename BatchHandler>
+Status HandleSegments(std::unique_ptr<GroupingSegmenter>& segmenter,
+                      const ExecBatch& batch, const std::vector<int>& ids,
+                      const BatchHandler& handle_batch) {
+  int64_t offset = 0;
+  ARROW_ASSIGN_OR_RAISE(auto segment_exec_batch, batch.SelectValues(ids));
+  ExecSpan segment_batch(segment_exec_batch);
+  while (true) {
+    ARROW_ASSIGN_OR_RAISE(auto segment, 
segmenter->GetNextSegment(segment_batch, offset));
+    if (segment.offset >= segment_batch.length) break;  // condition of 
no-next-segment
+    ARROW_RETURN_NOT_OK(handle_batch(batch, segment));
+    offset = segment.offset + segment.length;
+  }
+  return Status::OK();
+}
+
+Status GetScalarFields(std::vector<Datum>* values_ptr, const ExecBatch& 
input_batch,
+                       const std::vector<int>& field_ids) {
+  DCHECK_GT(input_batch.length, 0);
+  std::vector<Datum>& values = *values_ptr;
+  int64_t row = input_batch.length - 1;
+  values.clear();
+  values.resize(field_ids.size());
+  for (size_t i = 0; i < field_ids.size(); i++) {
+    const Datum& value = input_batch.values[field_ids[i]];
+    if (value.is_scalar()) {
+      values[i] = value;
+    } else if (value.is_array()) {
+      ARROW_ASSIGN_OR_RAISE(auto scalar, value.make_array()->GetScalar(row));
+      values[i] = scalar;
+    } else {
+      DCHECK(false);
+    }
+  }
+  return Status::OK();
+}
+
+void PlaceFields(ExecBatch& batch, size_t base, std::vector<Datum>& values) {
+  DCHECK_LE(base + values.size(), batch.values.size());
+  for (size_t i = 0; i < values.size(); i++) {
+    batch.values[base + i] = values[i];
+  }
+}
+
 class ScalarAggregateNode : public ExecNode, public TracedNode {
  public:
   ScalarAggregateNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
                       std::shared_ptr<Schema> output_schema,
+                      std::unique_ptr<GroupingSegmenter> segmenter,
+                      std::vector<int> segment_field_ids,
                       std::vector<std::vector<int>> target_fieldsets,
                       std::vector<Aggregate> aggs,
                       std::vector<const ScalarAggregateKernel*> kernels,
                       std::vector<std::vector<std::unique_ptr<KernelState>>> 
states)
       : ExecNode(plan, std::move(inputs), {"target"},
                  /*output_schema=*/std::move(output_schema)),
         TracedNode(this),
+        segmenter_(std::move(segmenter)),
+        segment_field_ids_(std::move(segment_field_ids)),
         target_fieldsets_(std::move(target_fieldsets)),
         aggs_(std::move(aggs)),
         kernels_(std::move(kernels)),
-        states_(std::move(states)) {}
+        states_(std::move(states)) {
+    const auto& input_schema = *this->inputs()[0]->output_schema();
+    for (size_t i = 0; i < kernels_.size(); ++i) {
+      std::vector<TypeHolder> in_types;
+      for (const auto& target : target_fieldsets_[i]) {
+        in_types.emplace_back(input_schema.field(target)->type().get());
+      }
+      in_typesets_.push_back(std::move(in_types));
+    }
+  }
 
   static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
                                 const ExecNodeOptions& options) {
     RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, 
"ScalarAggregateNode"));
 
     const auto& aggregate_options = checked_cast<const 
AggregateNodeOptions&>(options);
     auto aggregates = aggregate_options.aggregates;
+    const auto& keys = aggregate_options.keys;
+    const auto& segment_keys = aggregate_options.segment_keys;
+
+    if (keys.size() > 0) {
+      return Status::Invalid("Scalar aggregation with some key");
+    }
+    if (plan->query_context()->exec_context()->executor()->GetCapacity() > 1 &&
+        segment_keys.size() > 0) {
+      return Status::NotImplemented("Segmented aggregation in a multi-threaded 
plan");
+    }
 
     const auto& input_schema = *inputs[0]->output_schema();
     auto exec_ctx = plan->query_context()->exec_context();
 
+    std::vector<int> segment_field_ids(segment_keys.size());
+    std::vector<TypeHolder> segment_key_types(segment_keys.size());
+    for (size_t i = 0; i < segment_keys.size(); i++) {
+      ARROW_ASSIGN_OR_RAISE(auto match, segment_keys[i].FindOne(input_schema));
+      if (match.indices().size() > 1) {
+        // ARROW-18369: Support nested references as segment ids
+        return Status::Invalid("Nested references cannot be used as segment 
ids");
+      }
+      segment_field_ids[i] = match[0];
+      segment_key_types[i] = input_schema.field(match[0])->type().get();
+    }
+
+    ARROW_ASSIGN_OR_RAISE(
+        auto segmenter, GroupingSegmenter::Make(std::move(segment_key_types), 
exec_ctx));

Review Comment:
   We are asserting here that the segment keys will not be null, since
`nullable_keys` defaults to `false`. At the very least, we should document this
in `options.h`.
   
   On the other hand, why can't the grouping implementation tolerate nulls? It
seems we could pick some meaning for `null`. Either:
   
    1. null represents a key of its own, or
    2. null means the value is unknown, and we assume the previous key is still
in effect.
   
   Do you know how group-by handles null keys today (and, by extension,
`AnyKeysGroupingSegmenter`)? I think it is option 1, but I could be mistaken. If
that is the case, it probably wouldn't add much complexity to handle nulls the
same way in `SimpleKeyGroupingSegmenter`.
   
   We need to document the constraint now. However, I'd be fine deferring any
extended implementation to a follow-up.
   
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to