R-JunmingChen commented on code in PR #37100:
URL: https://github.com/apache/arrow/pull/37100#discussion_r1364753755
##########
cpp/src/arrow/compute/kernels/aggregate_basic_internal.h:
##########
@@ -912,6 +912,122 @@ struct NullMinMaxImpl : public ScalarAggregator {
}
};
+template <SimdLevel::type SimdLevel>
+struct DictionaryMinMaxImpl : public ScalarAggregator {
+ using ThisType = DictionaryMinMaxImpl<SimdLevel>;
+
+  /// Creates the aggregator state for min/max over a dictionary-encoded input.
+  ///
+  /// `options.min_count` is raised to at least 1 so that Finalize never reports
+  /// a min/max derived from zero non-null values.
+  DictionaryMinMaxImpl(std::shared_ptr<DataType> out_type,
+                       ScalarAggregateOptions options)
+      : options(std::move(options)),
+        out_type(std::move(out_type)),
+        has_nulls(false),
+        count(0) {
+    // `min` and `max` are default-constructed shared_ptrs (nullptr), which the
+    // rest of the state machine treats as "no value seen yet".
+    if (this->options.min_count < 1) {
+      this->options.min_count = 1;
+    }
+  }
+
+  /// Folds one dictionary-encoded batch into the running min/max state.
+  ///
+  /// The batch is compacted first, so only dictionary values that are actually
+  /// referenced by an index take part in the comparison. The min/max of the
+  /// compacted dictionary is then the min/max of the batch's non-null values.
+  ///
+  /// Returns NotImplemented for scalar input; propagates errors from Compact,
+  /// MinMax and the state update.
+  Status Consume(KernelContext* ctx, const ExecSpan& batch) override {
+    if (batch[0].is_scalar()) {
+      return Status::NotImplemented("No min/max implemented for DictionaryScalar");
+    }
+
+    DictionaryArray dict_arr(batch[0].array.ToArrayData());
+    ARROW_ASSIGN_OR_RAISE(auto compacted_arr, dict_arr.Compact(ctx->memory_pool()));
+    const auto& compacted_dict_arr =
+        checked_cast<const DictionaryArray&>(*compacted_arr);
+
+    const int64_t null_count = compacted_dict_arr.null_count();
+    const int64_t non_null_count = compacted_dict_arr.length() - null_count;
+    // Record nulls *before* the early return below: an all-null batch must still
+    // make Finalize emit null when skip_nulls is false.
+    this->has_nulls |= null_count > 0;
+    if (non_null_count == 0) {
+      return Status::OK();
+    }
+    this->count += non_null_count;
+
+    // NOTE(review): has_nulls/count are derived from the array's null_count()
+    // only; if null entries inside the dictionary values should also count as
+    // nulls, they are not accounted for here -- confirm intent.
+    const auto& dictionary = compacted_dict_arr.dictionary();
+    std::shared_ptr<Scalar> dict_min;
+    std::shared_ptr<Scalar> dict_max;
+    if (dictionary->length() == 1) {
+      // Every non-null index refers to the only dictionary entry, so it is both
+      // the minimum and the maximum; no need to run the MinMax kernel.
+      ARROW_ASSIGN_OR_RAISE(dict_min, dictionary->GetScalar(0));
+      dict_max = dict_min;
+    } else {
+      ARROW_ASSIGN_OR_RAISE(Datum result,
+                            MinMax(Datum(dictionary), ScalarAggregateOptions::Defaults(),
+                                   ctx->exec_context()));
+      const auto& struct_result = checked_cast<const StructScalar&>(*result.scalar());
+      ARROW_ASSIGN_OR_RAISE(dict_min, struct_result.field(FieldRef("min")));
+      ARROW_ASSIGN_OR_RAISE(dict_max, struct_result.field(FieldRef("max")));
+    }
+    return UpdateMinMaxState(dict_min, dict_max, ctx);
+  }
+
+  /// Combines another instance's partial state into this one: counts are
+  /// summed, the null flags are OR-ed, and the extremes are reconciled through
+  /// UpdateMinMaxState. Returns any error from that comparison.
+  Status MergeFrom(KernelContext* ctx, KernelState&& src) override {
+    const ThisType& rhs = checked_cast<const ThisType&>(src);
+    count += rhs.count;
+    has_nulls = has_nulls || rhs.has_nulls;
+    // rhs.min / rhs.max stay nullptr if the other side never saw a value.
+    return UpdateMinMaxState(rhs.min, rhs.max, ctx);
+  }
+
+  /// Emits the result as a StructScalar of `out_type` holding {min, max}.
+  ///
+  /// Both fields are null when nulls were seen and skip_nulls is false, or when
+  /// fewer than min_count non-null values were consumed. Otherwise the running
+  /// min/max are moved out of the state into the result.
+  Status Finalize(KernelContext*, Datum* out) override {
+    const auto& value_type =
+        checked_cast<const StructType&>(*out_type).field(0)->type();
+    const bool nulls_not_skipped = this->has_nulls && !options.skip_nulls;
+    const bool too_few_values = this->count < options.min_count;
+
+    std::vector<std::shared_ptr<Scalar>> fields;
+    if (nulls_not_skipped || too_few_values) {
+      // (null, null)
+      auto null_value = MakeNullScalar(value_type);
+      fields = {null_value, null_value};
+    } else {
+      fields = {std::move(this->min), std::move(this->max)};
+    }
+
+    out->value = std::make_shared<StructScalar>(std::move(fields), this->out_type);
+    return Status::OK();
+  }
+
+ ScalarAggregateOptions options;  // min_count is raised to at least 1 in the constructor
+ std::shared_ptr<DataType> out_type;  // struct type of the result; field(0) is the value type
+ bool has_nulls;  // set once any consumed or merged input reported nulls
+ int64_t count;  // non-null values consumed so far, including merged states
+ std::shared_ptr<Scalar> min;  // running minimum; nullptr until a value has been seen
+ std::shared_ptr<Scalar> max;  // running maximum; nullptr until a value has been seen
+
+ private:
+ Status UpdateMinMaxState(const std::shared_ptr<Scalar>& other_min,
+ const std::shared_ptr<Scalar>& other_max,
KernelContext* ctx) {
+ if (this->min == nullptr || this->min->type->id() == Type::NA) {
+ this->min = other_min;
+ } else if (other_min != nullptr && other_min->type->id() != Type::NA) {
+ ARROW_ASSIGN_OR_RAISE(
+ Datum greater_result,
+ CallFunction("greater", {this->min, other_min},
ctx->exec_context()));
+ const BooleanScalar& greater_scalar =
+ checked_cast<const BooleanScalar&>(*greater_result.scalar());
+
+ if (greater_scalar.value) {
+ this->min = other_min;
+ }
+ }
+
+ if (this->max == nullptr || this->max->type->id() == Type::NA) {
+ this->max = other_max;
+ } else if (other_max != nullptr && other_max->type->id() != Type::NA) {
+ ARROW_ASSIGN_OR_RAISE(
+ Datum less_result,
+ CallFunction("less", {this->max, other_max}, ctx->exec_context()));
+ const BooleanScalar& less_scalar =
+ checked_cast<const BooleanScalar&>(*less_result.scalar());
+
+ if (less_scalar.value) {
+ this->max = other_max;
+ }
+ }
Review Comment:
Sure, it's a good idea; @js8544 also suggested using kernels directly. I
will try this plan.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]