lidavidm commented on a change in pull request #12368:
URL: https://github.com/apache/arrow/pull/12368#discussion_r806838598
##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
##########
@@ -2460,6 +2461,558 @@ TEST(GroupBy, Distinct) {
}
}
+MATCHER_P(AnyOfScalar, arrow_array, "") {
Review comment:
IIRC, these convenience macros aren't always available in the CI
environments we use. See
https://github.com/apache/arrow/commit/cd30dea861d6dfd670032c655f329cb16bb99a7a
##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate.cc
##########
@@ -2451,6 +2451,333 @@ Result<std::unique_ptr<KernelState>>
GroupedDistinctInit(KernelContext* ctx,
return std::move(impl);
}
+// ----------------------------------------------------------------------
+// One implementation
+
+template <typename Type, typename Enable = void>
+struct GroupedOneImpl final : public GroupedAggregator {
+ using CType = typename TypeTraits<Type>::CType;
+ using GetSet = GroupedValueTraits<Type>;
+
+ Status Init(ExecContext* ctx, const std::vector<ValueDescr>&,
+ const FunctionOptions* options) override {
+ // out_type_ initialized by GroupedOneInit
+ ones_ = TypedBufferBuilder<CType>(ctx->memory_pool());
+ has_one_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+ has_value_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+ return Status::OK();
+ }
+
+ Status Resize(int64_t new_num_groups) override {
+ auto added_groups = new_num_groups - num_groups_;
+ num_groups_ = new_num_groups;
+ RETURN_NOT_OK(ones_.Append(added_groups, static_cast<CType>(0)));
+ RETURN_NOT_OK(has_one_.Append(added_groups, false));
+ RETURN_NOT_OK(has_value_.Append(added_groups, false));
+ return Status::OK();
+ }
+
+ Status Consume(const ExecBatch& batch) override {
+ auto raw_ones_ = ones_.mutable_data();
+
+ return VisitGroupedValues<Type>(
+ batch,
+ [&](uint32_t g, CType val) -> Status {
+ if (!bit_util::GetBit(has_one_.data(), g)) {
+ GetSet::Set(raw_ones_, g, val);
+ bit_util::SetBit(has_one_.mutable_data(), g);
+ bit_util::SetBit(has_value_.mutable_data(), g);
+ }
+ return Status::OK();
+ },
+ [&](uint32_t g) -> Status {
+ bit_util::SetBit(has_one_.mutable_data(), g);
+ return Status::OK();
+ });
+ }
+
+ Status Merge(GroupedAggregator&& raw_other,
+ const ArrayData& group_id_mapping) override {
+ auto other = checked_cast<GroupedOneImpl*>(&raw_other);
+
+ auto raw_ones = ones_.mutable_data();
+ auto other_raw_ones = other->ones_.mutable_data();
+
+ auto g = group_id_mapping.GetValues<uint32_t>(1);
+ for (uint32_t other_g = 0; static_cast<int64_t>(other_g) <
group_id_mapping.length;
+ ++other_g, ++g) {
+ if (!bit_util::GetBit(has_one_.data(), *g)) {
+ if (bit_util::GetBit(other->has_value_.data(), other_g)) {
+ GetSet::Set(raw_ones, *g, GetSet::Get(other_raw_ones, other_g));
+ bit_util::SetBit(has_value_.mutable_data(), *g);
+ }
+ bit_util::SetBit(has_one_.mutable_data(), *g);
+ }
+ }
+
+ return Status::OK();
+ }
+
+ Result<Datum> Finalize() override {
+ ARROW_ASSIGN_OR_RAISE(auto null_bitmap, has_value_.Finish());
+ ARROW_ASSIGN_OR_RAISE(auto data, ones_.Finish());
+ return ArrayData::Make(out_type_, num_groups_,
+ {std::move(null_bitmap), std::move(data)});
+ }
+
+ std::shared_ptr<DataType> out_type() const override { return out_type_; }
+
+ int64_t num_groups_;
+ TypedBufferBuilder<CType> ones_;
+ TypedBufferBuilder<bool> has_one_, has_value_;
+ std::shared_ptr<DataType> out_type_;
+};
+
+struct GroupedNullOneImpl : public GroupedAggregator {
+ Status Init(ExecContext* ctx, const std::vector<ValueDescr>&,
+ const FunctionOptions* options) override {
+ return Status::OK();
+ }
+
+ Status Resize(int64_t new_num_groups) override {
+ num_groups_ = new_num_groups;
+ return Status::OK();
+ }
+
+ Status Consume(const ExecBatch& batch) override { return Status::OK(); }
+
+ Status Merge(GroupedAggregator&& raw_other,
+ const ArrayData& group_id_mapping) override {
+ return Status::OK();
+ }
+
+ Result<Datum> Finalize() override {
+ return ArrayData::Make(null(), num_groups_, {nullptr}, num_groups_);
+ }
+
+ std::shared_ptr<DataType> out_type() const override { return null(); }
+
+ int64_t num_groups_;
+};
+
+template <typename Type>
+struct GroupedOneImpl<Type, enable_if_t<is_base_binary_type<Type>::value ||
+ std::is_same<Type,
FixedSizeBinaryType>::value>>
+ final : public GroupedAggregator {
+ using Allocator = arrow::stl::allocator<char>;
+ using StringType = std::basic_string<char, std::char_traits<char>,
Allocator>;
+
+ Status Init(ExecContext* ctx, const std::vector<ValueDescr>&,
+ const FunctionOptions* options) override {
+ ctx_ = ctx;
+ allocator_ = Allocator(ctx->memory_pool());
+ // out_type_ initialized by GroupedOneInit
+ has_value_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+ return Status::OK();
+ }
+
+ Status Resize(int64_t new_num_groups) override {
+ auto added_groups = new_num_groups - num_groups_;
+ DCHECK_GE(added_groups, 0);
+ num_groups_ = new_num_groups;
+ ones_.resize(new_num_groups);
+ RETURN_NOT_OK(has_one_.Append(added_groups, false));
+ RETURN_NOT_OK(has_value_.Append(added_groups, false));
+ return Status::OK();
+ }
+
+ Status Consume(const ExecBatch& batch) override {
+ return VisitGroupedValues<Type>(
+ batch,
+ [&](uint32_t g, util::string_view val) -> Status {
+ if (!bit_util::GetBit(has_one_.data(), g)) {
+ ones_[g].emplace(val.data(), val.size(), allocator_);
+ bit_util::SetBit(has_one_.mutable_data(), g);
+ bit_util::SetBit(has_value_.mutable_data(), g);
+ }
+ return Status::OK();
+ },
+ [&](uint32_t g) -> Status {
+ // as has_one_ is set, has_value_ will never be set, resulting in
null
+ bit_util::SetBit(has_one_.mutable_data(), g);
+ return Status::OK();
+ });
+ }
+
+ Status Merge(GroupedAggregator&& raw_other,
+ const ArrayData& group_id_mapping) override {
+ auto other = checked_cast<GroupedOneImpl*>(&raw_other);
+ auto g = group_id_mapping.GetValues<uint32_t>(1);
+ for (uint32_t other_g = 0; static_cast<int64_t>(other_g) <
group_id_mapping.length;
+ ++other_g, ++g) {
+ if (!bit_util::GetBit(has_one_.data(), *g)) {
+ if (bit_util::GetBit(other->has_value_.data(), other_g)) {
+ ones_[*g] = std::move(other->ones_[other_g]);
+ bit_util::SetBit(has_value_.mutable_data(), *g);
+ }
+ bit_util::SetBit(has_one_.mutable_data(), *g);
+ }
+ }
+ return Status::OK();
+ }
+
+ Result<Datum> Finalize() override {
+ ARROW_ASSIGN_OR_RAISE(auto null_bitmap, has_value_.Finish());
+ auto ones =
+ ArrayData::Make(out_type(), num_groups_, {std::move(null_bitmap),
nullptr});
+ RETURN_NOT_OK(MakeOffsetsValues(ones.get(), ones_));
+ return ones;
+ }
+
+ template <typename T = Type>
+ enable_if_base_binary<T, Status> MakeOffsetsValues(
+ ArrayData* array, const std::vector<util::optional<StringType>>& values)
{
+ using offset_type = typename T::offset_type;
+ ARROW_ASSIGN_OR_RAISE(
+ auto raw_offsets,
+ AllocateBuffer((1 + values.size()) * sizeof(offset_type),
ctx_->memory_pool()));
+ auto* offsets =
reinterpret_cast<offset_type*>(raw_offsets->mutable_data());
+ offsets[0] = 0;
+ offsets++;
+ const uint8_t* null_bitmap = array->buffers[0]->data();
+ offset_type total_length = 0;
+ for (size_t i = 0; i < values.size(); i++) {
+ if (bit_util::GetBit(null_bitmap, i)) {
+ const util::optional<StringType>& value = values[i];
+ DCHECK(value.has_value());
+ if (value->size() >
+ static_cast<size_t>(std::numeric_limits<offset_type>::max()) ||
+ arrow::internal::AddWithOverflow(
+ total_length, static_cast<offset_type>(value->size()),
&total_length)) {
+ return Status::Invalid("Result is too large to fit in ",
*array->type,
+ " cast to large_ variant of type");
+ }
+ }
+ offsets[i] = total_length;
+ }
+ ARROW_ASSIGN_OR_RAISE(auto data, AllocateBuffer(total_length,
ctx_->memory_pool()));
+ int64_t offset = 0;
+ for (size_t i = 0; i < values.size(); i++) {
+ if (bit_util::GetBit(null_bitmap, i)) {
+ const util::optional<StringType>& value = values[i];
+ DCHECK(value.has_value());
+ std::memcpy(data->mutable_data() + offset, value->data(),
value->size());
+ offset += value->size();
+ }
+ }
+ array->buffers[1] = std::move(raw_offsets);
+ array->buffers.push_back(std::move(data));
+ return Status::OK();
+ }
+
+ template <typename T = Type>
+ enable_if_same<T, FixedSizeBinaryType, Status> MakeOffsetsValues(
+ ArrayData* array, const std::vector<util::optional<StringType>>& values)
{
+ const uint8_t* null_bitmap = array->buffers[0]->data();
+ const int32_t slot_width =
+ checked_cast<const FixedSizeBinaryType&>(*array->type).byte_width();
+ int64_t total_length = values.size() * slot_width;
+ ARROW_ASSIGN_OR_RAISE(auto data, AllocateBuffer(total_length,
ctx_->memory_pool()));
+ int64_t offset = 0;
+ for (size_t i = 0; i < values.size(); i++) {
+ if (bit_util::GetBit(null_bitmap, i)) {
+ const util::optional<StringType>& value = values[i];
+ DCHECK(value.has_value());
+ std::memcpy(data->mutable_data() + offset, value->data(), slot_width);
+ } else {
+ std::memset(data->mutable_data() + offset, 0x00, slot_width);
+ }
+ offset += slot_width;
+ }
+ array->buffers[1] = std::move(data);
+ return Status::OK();
+ }
+
+ std::shared_ptr<DataType> out_type() const override { return out_type_; }
+
+ ExecContext* ctx_;
+ Allocator allocator_;
+ int64_t num_groups_;
+ std::vector<util::optional<StringType>> ones_;
+ TypedBufferBuilder<bool> has_one_, has_value_;
+ std::shared_ptr<DataType> out_type_;
+};
+
+template <typename T>
+Result<std::unique_ptr<KernelState>> GroupedOneInit(KernelContext* ctx,
+ const KernelInitArgs&
args) {
+ ARROW_ASSIGN_OR_RAISE(auto impl, HashAggregateInit<GroupedOneImpl<T>>(ctx,
args));
+ auto instance = static_cast<GroupedOneImpl<T>*>(impl.get());
+ instance->out_type_ = args.inputs[0].type;
+ return std::move(impl);
+}
+
+struct GroupedOneFactory {
+ template <typename T>
+ enable_if_physical_integer<T, Status> Visit(const T&) {
+ using PhysicalType = typename T::PhysicalType;
+ kernel = MakeKernel(std::move(argument_type),
GroupedOneInit<PhysicalType>);
+ return Status::OK();
+ }
+
+ // MSVC2015 apparently doesn't compile this properly if we use
Review comment:
We got rid of MSVC2015, so we can replace these two overloads with
`enable_if_floating_point`.
##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
##########
@@ -2460,6 +2461,558 @@ TEST(GroupBy, Distinct) {
}
}
+MATCHER_P(AnyOfScalar, arrow_array, "") {
+ for (int64_t i = 0; i < arrow_array->length(); ++i) {
+ auto scalar = arrow_array->GetScalar(i).ValueOrDie();
Review comment:
We could handle the error instead and report an assertion failure if
GetScalar fails.
##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate.cc
##########
@@ -2451,6 +2451,333 @@ Result<std::unique_ptr<KernelState>>
GroupedDistinctInit(KernelContext* ctx,
return std::move(impl);
}
+// ----------------------------------------------------------------------
+// One implementation
+
+template <typename Type, typename Enable = void>
+struct GroupedOneImpl final : public GroupedAggregator {
+ using CType = typename TypeTraits<Type>::CType;
+ using GetSet = GroupedValueTraits<Type>;
+
+ Status Init(ExecContext* ctx, const std::vector<ValueDescr>&,
+ const FunctionOptions* options) override {
+ // out_type_ initialized by GroupedOneInit
+ ones_ = TypedBufferBuilder<CType>(ctx->memory_pool());
+ has_one_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+ has_value_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+ return Status::OK();
+ }
+
+ Status Resize(int64_t new_num_groups) override {
+ auto added_groups = new_num_groups - num_groups_;
+ num_groups_ = new_num_groups;
+ RETURN_NOT_OK(ones_.Append(added_groups, static_cast<CType>(0)));
+ RETURN_NOT_OK(has_one_.Append(added_groups, false));
+ RETURN_NOT_OK(has_value_.Append(added_groups, false));
+ return Status::OK();
+ }
+
+ Status Consume(const ExecBatch& batch) override {
+ auto raw_ones_ = ones_.mutable_data();
+
+ return VisitGroupedValues<Type>(
+ batch,
+ [&](uint32_t g, CType val) -> Status {
+ if (!bit_util::GetBit(has_one_.data(), g)) {
+ GetSet::Set(raw_ones_, g, val);
+ bit_util::SetBit(has_one_.mutable_data(), g);
+ bit_util::SetBit(has_value_.mutable_data(), g);
+ }
+ return Status::OK();
+ },
+ [&](uint32_t g) -> Status {
+ bit_util::SetBit(has_one_.mutable_data(), g);
+ return Status::OK();
+ });
+ }
+
+ Status Merge(GroupedAggregator&& raw_other,
+ const ArrayData& group_id_mapping) override {
+ auto other = checked_cast<GroupedOneImpl*>(&raw_other);
+
+ auto raw_ones = ones_.mutable_data();
+ auto other_raw_ones = other->ones_.mutable_data();
+
+ auto g = group_id_mapping.GetValues<uint32_t>(1);
+ for (uint32_t other_g = 0; static_cast<int64_t>(other_g) <
group_id_mapping.length;
+ ++other_g, ++g) {
+ if (!bit_util::GetBit(has_one_.data(), *g)) {
+ if (bit_util::GetBit(other->has_value_.data(), other_g)) {
+ GetSet::Set(raw_ones, *g, GetSet::Get(other_raw_ones, other_g));
+ bit_util::SetBit(has_value_.mutable_data(), *g);
+ }
+ bit_util::SetBit(has_one_.mutable_data(), *g);
+ }
+ }
+
+ return Status::OK();
+ }
+
+ Result<Datum> Finalize() override {
+ ARROW_ASSIGN_OR_RAISE(auto null_bitmap, has_value_.Finish());
+ ARROW_ASSIGN_OR_RAISE(auto data, ones_.Finish());
+ return ArrayData::Make(out_type_, num_groups_,
+ {std::move(null_bitmap), std::move(data)});
+ }
+
+ std::shared_ptr<DataType> out_type() const override { return out_type_; }
+
+ int64_t num_groups_;
+ TypedBufferBuilder<CType> ones_;
+ TypedBufferBuilder<bool> has_one_, has_value_;
+ std::shared_ptr<DataType> out_type_;
+};
+
+struct GroupedNullOneImpl : public GroupedAggregator {
+ Status Init(ExecContext* ctx, const std::vector<ValueDescr>&,
+ const FunctionOptions* options) override {
+ return Status::OK();
+ }
+
+ Status Resize(int64_t new_num_groups) override {
+ num_groups_ = new_num_groups;
+ return Status::OK();
+ }
+
+ Status Consume(const ExecBatch& batch) override { return Status::OK(); }
+
+ Status Merge(GroupedAggregator&& raw_other,
+ const ArrayData& group_id_mapping) override {
+ return Status::OK();
+ }
+
+ Result<Datum> Finalize() override {
+ return ArrayData::Make(null(), num_groups_, {nullptr}, num_groups_);
+ }
+
+ std::shared_ptr<DataType> out_type() const override { return null(); }
+
+ int64_t num_groups_;
+};
+
+template <typename Type>
+struct GroupedOneImpl<Type, enable_if_t<is_base_binary_type<Type>::value ||
+ std::is_same<Type,
FixedSizeBinaryType>::value>>
+ final : public GroupedAggregator {
+ using Allocator = arrow::stl::allocator<char>;
+ using StringType = std::basic_string<char, std::char_traits<char>,
Allocator>;
+
+ Status Init(ExecContext* ctx, const std::vector<ValueDescr>&,
+ const FunctionOptions* options) override {
+ ctx_ = ctx;
+ allocator_ = Allocator(ctx->memory_pool());
+ // out_type_ initialized by GroupedOneInit
+ has_value_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+ return Status::OK();
+ }
+
+ Status Resize(int64_t new_num_groups) override {
+ auto added_groups = new_num_groups - num_groups_;
+ DCHECK_GE(added_groups, 0);
+ num_groups_ = new_num_groups;
+ ones_.resize(new_num_groups);
+ RETURN_NOT_OK(has_one_.Append(added_groups, false));
+ RETURN_NOT_OK(has_value_.Append(added_groups, false));
+ return Status::OK();
+ }
+
+ Status Consume(const ExecBatch& batch) override {
+ return VisitGroupedValues<Type>(
+ batch,
+ [&](uint32_t g, util::string_view val) -> Status {
+ if (!bit_util::GetBit(has_one_.data(), g)) {
+ ones_[g].emplace(val.data(), val.size(), allocator_);
+ bit_util::SetBit(has_one_.mutable_data(), g);
+ bit_util::SetBit(has_value_.mutable_data(), g);
+ }
+ return Status::OK();
+ },
+ [&](uint32_t g) -> Status {
+ // as has_one_ is set, has_value_ will never be set, resulting in
null
+ bit_util::SetBit(has_one_.mutable_data(), g);
+ return Status::OK();
+ });
+ }
+
+ Status Merge(GroupedAggregator&& raw_other,
+ const ArrayData& group_id_mapping) override {
+ auto other = checked_cast<GroupedOneImpl*>(&raw_other);
+ auto g = group_id_mapping.GetValues<uint32_t>(1);
+ for (uint32_t other_g = 0; static_cast<int64_t>(other_g) <
group_id_mapping.length;
+ ++other_g, ++g) {
+ if (!bit_util::GetBit(has_one_.data(), *g)) {
+ if (bit_util::GetBit(other->has_value_.data(), other_g)) {
+ ones_[*g] = std::move(other->ones_[other_g]);
+ bit_util::SetBit(has_value_.mutable_data(), *g);
+ }
+ bit_util::SetBit(has_one_.mutable_data(), *g);
+ }
+ }
+ return Status::OK();
+ }
+
+ Result<Datum> Finalize() override {
+ ARROW_ASSIGN_OR_RAISE(auto null_bitmap, has_value_.Finish());
+ auto ones =
+ ArrayData::Make(out_type(), num_groups_, {std::move(null_bitmap),
nullptr});
+ RETURN_NOT_OK(MakeOffsetsValues(ones.get(), ones_));
+ return ones;
+ }
+
+ template <typename T = Type>
+ enable_if_base_binary<T, Status> MakeOffsetsValues(
Review comment:
We could factor those out, yeah
##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
##########
@@ -2460,6 +2461,558 @@ TEST(GroupBy, Distinct) {
}
}
+MATCHER_P(AnyOfScalar, arrow_array, "") {
+ for (int64_t i = 0; i < arrow_array->length(); ++i) {
+ auto scalar = arrow_array->GetScalar(i).ValueOrDie();
+ if (scalar->Equals(arg)) return true;
+ }
+ *result_listener << "Argument scalar: '" << arg->ToString()
+ << "' matches no input scalar.";
+ return false;
+}
+
+MATCHER_P(AnyOfScalarFromUniques, unique_list, "") {
+ const auto& flatten = unique_list->Flatten().ValueOrDie();
+ const auto& offsets =
std::dynamic_pointer_cast<Int32Array>(unique_list->offsets());
+
+ for (int64_t i = 0; i < arg->length(); ++i) {
+ bool match_found = false;
+ const auto group_hash_one = arg->GetScalar(i).ValueOrDie();
+ int64_t start = offsets->Value(i);
+ int64_t end = offsets->Value(i + 1);
+ for (int64_t j = start; j < end; ++j) {
+ auto s = flatten->GetScalar(j).ValueOrDie();
+ if (s->Equals(group_hash_one)) {
+ match_found = true;
+ break;
+ }
+ }
+ if (!match_found) {
+ *result_listener << "Argument scalar: '" << group_hash_one->ToString()
+ << "' matches no input scalar.";
+ return false;
+ }
+ }
+ return true;
+}
+
+TEST(GroupBy, One) {
+ {
+ auto table =
+ TableFromJSON(schema({field("argument", int64()), field("key",
int64())}), {R"([
+ [99, 1],
+ [99, 1]
+])",
+
R"([
+ [77, 2],
+ [null, 3],
+ [null, 3]
+])",
+
R"([
+ [null, 4],
+ [null, 4]
+])",
+
R"([
+ [88, null],
+ [99, 3]
+])",
+
R"([
+ [77, 2],
+ [76, 2]
+])",
+
R"([
+ [75, null],
+ [74, 3]
+ ])",
+
R"([
+ [73, null],
+ [72, null]
+ ])"});
+
+ ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {
+ {"hash_one", nullptr},
+ },
+ false));
+ ValidateOutput(aggregated_and_grouped);
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_one", int64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [99, 1],
+ [77, 2],
+ [null, 3],
+ [null, 4],
+ [88, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+ {
+ auto table =
+ TableFromJSON(schema({field("argument", utf8()), field("key",
int64())}), {R"([
+ ["foo", 1],
+ ["foo", 1]
+ ])",
+
R"([
+ ["bar", 2],
+ [null, 3],
+ [null, 3]
+ ])",
+
R"([
+ [null, 4],
+ [null, 4]
+ ])",
+
R"([
+ ["baz", null],
+ ["foo", 3]
+ ])",
+
R"([
+ ["bar", 2],
+ ["spam", 2]
+ ])",
+
R"([
+ ["eggs", null],
+ ["ham", 3]
+ ])",
+
R"([
+ ["a", null],
+ ["b", null]
+ ])"});
+
+ ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {
+ {"hash_one", nullptr},
+ },
+ false));
+ ValidateOutput(aggregated_and_grouped);
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_one", utf8()),
+ field("key_0", int64()),
+ }),
+ R"([
+ ["foo", 1],
+ ["bar", 2],
+ [null, 3],
+ [null, 4],
+ ["baz", null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+}
+
+TEST(GroupBy, OneOnly) {
+ auto in_schema = schema({
+ field("argument0", float64()),
+ field("argument1", null()),
+ field("argument2", boolean()),
+ field("key", int64()),
+ });
+ for (bool use_exec_plan : {false, true}) {
+ for (bool use_threads : {false, true}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+ auto table = TableFromJSON(in_schema, {R"([
+ [1.0, null, true, 1],
+ [null, null, true, 1]
+])",
+ R"([
+ [0.0, null, false, 2],
+ [null, null, false, 3],
+ [4.0, null, null, null],
+ [3.25, null, true, 1],
+ [0.125, null, false, 2]
+])",
+ R"([
+ [-0.25, null, false, 2],
+ [0.75, null, true, null],
+ [null, null, true, 3]
+])"});
+
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ GroupByTest(
+ {
+ table->GetColumnByName("argument0"),
+ table->GetColumnByName("argument1"),
+ table->GetColumnByName("argument2"),
+ },
+ {table->GetColumnByName("key")},
+ {
+ {"hash_one", nullptr},
+ {"hash_one", nullptr},
+ {"hash_one", nullptr},
+ },
+ use_threads, use_exec_plan));
+ ValidateOutput(aggregated_and_grouped);
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ // AssertDatumsEqual(ArrayFromJSON(struct_({
+ // field("hash_one", float64()),
+ // field("hash_one", null()),
+ // field("hash_one", boolean()),
+ // field("key_0", int64()),
+ // }),
+ // R"([
+ // [1.0, null, true, 1],
+ // [0.0, null, false, 2],
+ // [null, null, false, 3],
+ // [4.0, null, null, null]
+ // ])"),
+ // aggregated_and_grouped,
+ // /*verbose=*/true);
+
+ const auto& struct_arr = aggregated_and_grouped.array_as<StructArray>();
+ // Check the key column
+ AssertDatumsEqual(ArrayFromJSON(int64(), "[1, 2, 3, null]"),
struct_arr->field(3));
+
+ auto type_col_0 = float64();
+ auto group_one_col_0 =
+ AnyOfScalar(ArrayFromJSON(type_col_0, R"([1.0, null, 3.25])"));
+ auto group_two_col_0 =
+ AnyOfScalar(ArrayFromJSON(type_col_0, R"([0.0, 0.125, -0.25])"));
+ auto group_three_col_0 = AnyOfScalar(ArrayFromJSON(type_col_0,
R"([null])"));
+ auto group_null_col_0 = AnyOfScalar(ArrayFromJSON(type_col_0, R"([4.0,
0.75])"));
+
+ // Check values individually
+ const auto& col0 = struct_arr->field(0);
+ ASSERT_OK_AND_ASSIGN(const auto g_one, col0->GetScalar(0));
+ EXPECT_THAT(g_one, group_one_col_0);
+ ASSERT_OK_AND_ASSIGN(const auto g_two, col0->GetScalar(1));
+ EXPECT_THAT(g_two, group_two_col_0);
+ ASSERT_OK_AND_ASSIGN(const auto g_three, col0->GetScalar(2));
+ EXPECT_THAT(g_three, group_three_col_0);
+ ASSERT_OK_AND_ASSIGN(const auto g_null, col0->GetScalar(3));
+ EXPECT_THAT(g_null, group_null_col_0);
Review comment:
ResultWith is in matchers.h:
https://github.com/apache/arrow/blob/26d6e6217ff79451a3fe366bcc88293c7ae67417/cpp/src/arrow/testing/matchers.h#L250-L254
##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate.cc
##########
@@ -2451,6 +2451,92 @@ Result<std::unique_ptr<KernelState>>
GroupedDistinctInit(KernelContext* ctx,
return std::move(impl);
}
+// ----------------------------------------------------------------------
+// One implementation
+
+struct GroupedOneImpl : public GroupedAggregator {
+ Status Init(ExecContext* ctx, const std::vector<ValueDescr>&,
+ const FunctionOptions* options) override {
+ ctx_ = ctx;
+ pool_ = ctx->memory_pool();
+ return Status::OK();
+ }
+
+ Status Resize(int64_t new_num_groups) override {
+ num_groups_ = new_num_groups;
+ return Status::OK();
+ }
+
+ Status Consume(const ExecBatch& batch) override {
+ ARROW_ASSIGN_OR_RAISE(std::ignore, grouper_->Consume(batch));
+ return Status::OK();
+ }
+
+ Status Merge(GroupedAggregator&& raw_other,
+ const ArrayData& group_id_mapping) override {
+ auto other = checked_cast<GroupedOneImpl*>(&raw_other);
+
+ // Get (value, group_id) pairs, then translate the group IDs and consume
them
+ // ourselves
+ ARROW_ASSIGN_OR_RAISE(auto uniques, other->grouper_->GetUniques());
+ ARROW_ASSIGN_OR_RAISE(auto remapped_g,
+ AllocateBuffer(uniques.length * sizeof(uint32_t),
pool_));
+
+ const auto* g_mapping = group_id_mapping.GetValues<uint32_t>(1);
+ const auto* other_g = uniques[1].array()->GetValues<uint32_t>(1);
+ auto* g = reinterpret_cast<uint32_t*>(remapped_g->mutable_data());
+
+ for (int64_t i = 0; i < uniques.length; i++) {
+ g[i] = g_mapping[other_g[i]];
+ }
+ uniques.values[1] =
+ ArrayData::Make(uint32(), uniques.length, {nullptr,
std::move(remapped_g)});
+
+ return Consume(std::move(uniques));
+ }
+
+ Result<Datum> Finalize() override {
Review comment:
Hash aggregates can be executed in parallel:
- Consume takes an input batch and updates local state.
- Merge takes two local states and combines them.
- Finalize takes a local state and produces the output array.
##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
##########
@@ -2460,6 +2460,476 @@ TEST(GroupBy, Distinct) {
}
}
+TEST(GroupBy, One) {
Review comment:
I think we can remove this, and we can consolidate test cases to be more
compact.
We can have one test for all the numeric types ("OneTypes", though maybe
let's rename it "OneNumericTypes" or something?), then one test for all the
"misc" types (write out one large input for null, boolean, decimal128,
decimal256, fixed size binary), and one test for all the binary types (iterate
through binary/large binary/string/large string).
##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate.cc
##########
@@ -2451,6 +2451,333 @@ Result<std::unique_ptr<KernelState>>
GroupedDistinctInit(KernelContext* ctx,
return std::move(impl);
}
+// ----------------------------------------------------------------------
+// One implementation
+
+template <typename Type, typename Enable = void>
+struct GroupedOneImpl final : public GroupedAggregator {
+ using CType = typename TypeTraits<Type>::CType;
+ using GetSet = GroupedValueTraits<Type>;
+
+ Status Init(ExecContext* ctx, const std::vector<ValueDescr>&,
+ const FunctionOptions* options) override {
+ // out_type_ initialized by GroupedOneInit
+ ones_ = TypedBufferBuilder<CType>(ctx->memory_pool());
+ has_one_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+ has_value_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+ return Status::OK();
+ }
+
+ Status Resize(int64_t new_num_groups) override {
+ auto added_groups = new_num_groups - num_groups_;
+ num_groups_ = new_num_groups;
+ RETURN_NOT_OK(ones_.Append(added_groups, static_cast<CType>(0)));
+ RETURN_NOT_OK(has_one_.Append(added_groups, false));
+ RETURN_NOT_OK(has_value_.Append(added_groups, false));
+ return Status::OK();
+ }
+
+ Status Consume(const ExecBatch& batch) override {
+ auto raw_ones_ = ones_.mutable_data();
+
+ return VisitGroupedValues<Type>(
+ batch,
+ [&](uint32_t g, CType val) -> Status {
+ if (!bit_util::GetBit(has_one_.data(), g)) {
+ GetSet::Set(raw_ones_, g, val);
+ bit_util::SetBit(has_one_.mutable_data(), g);
+ bit_util::SetBit(has_value_.mutable_data(), g);
+ }
+ return Status::OK();
+ },
+ [&](uint32_t g) -> Status {
+ bit_util::SetBit(has_one_.mutable_data(), g);
Review comment:
Hmm, maybe we don't want this? That is, we could remove this and "bias"
the kernel towards not returning null.
##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
##########
@@ -2460,6 +2461,558 @@ TEST(GroupBy, Distinct) {
}
}
+MATCHER_P(AnyOfScalar, arrow_array, "") {
+ for (int64_t i = 0; i < arrow_array->length(); ++i) {
+ auto scalar = arrow_array->GetScalar(i).ValueOrDie();
+ if (scalar->Equals(arg)) return true;
+ }
+ *result_listener << "Argument scalar: '" << arg->ToString()
+ << "' matches no input scalar.";
+ return false;
+}
+
+MATCHER_P(AnyOfScalarFromUniques, unique_list, "") {
+ const auto& flatten = unique_list->Flatten().ValueOrDie();
+ const auto& offsets =
std::dynamic_pointer_cast<Int32Array>(unique_list->offsets());
+
+ for (int64_t i = 0; i < arg->length(); ++i) {
+ bool match_found = false;
+ const auto group_hash_one = arg->GetScalar(i).ValueOrDie();
+ int64_t start = offsets->Value(i);
+ int64_t end = offsets->Value(i + 1);
+ for (int64_t j = start; j < end; ++j) {
+ auto s = flatten->GetScalar(j).ValueOrDie();
+ if (s->Equals(group_hash_one)) {
+ match_found = true;
+ break;
+ }
+ }
+ if (!match_found) {
+ *result_listener << "Argument scalar: '" << group_hash_one->ToString()
+ << "' matches no input scalar.";
+ return false;
+ }
+ }
+ return true;
+}
+
+TEST(GroupBy, One) {
+ {
+ auto table =
+ TableFromJSON(schema({field("argument", int64()), field("key",
int64())}), {R"([
+ [99, 1],
+ [99, 1]
+])",
+
R"([
+ [77, 2],
+ [null, 3],
+ [null, 3]
+])",
+
R"([
+ [null, 4],
+ [null, 4]
+])",
+
R"([
+ [88, null],
+ [99, 3]
+])",
+
R"([
+ [77, 2],
+ [76, 2]
+])",
+
R"([
+ [75, null],
+ [74, 3]
+ ])",
+
R"([
+ [73, null],
+ [72, null]
+ ])"});
+
+ ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {
+ {"hash_one", nullptr},
+ },
+ false));
+ ValidateOutput(aggregated_and_grouped);
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_one", int64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [99, 1],
+ [77, 2],
+ [null, 3],
+ [null, 4],
+ [88, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+ {
+ auto table =
+ TableFromJSON(schema({field("argument", utf8()), field("key",
int64())}), {R"([
+ ["foo", 1],
+ ["foo", 1]
+ ])",
+
R"([
+ ["bar", 2],
+ [null, 3],
+ [null, 3]
+ ])",
+
R"([
+ [null, 4],
+ [null, 4]
+ ])",
+
R"([
+ ["baz", null],
+ ["foo", 3]
+ ])",
+
R"([
+ ["bar", 2],
+ ["spam", 2]
+ ])",
+
R"([
+ ["eggs", null],
+ ["ham", 3]
+ ])",
+
R"([
+ ["a", null],
+ ["b", null]
+ ])"});
+
+ ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {
+ {"hash_one", nullptr},
+ },
+ false));
+ ValidateOutput(aggregated_and_grouped);
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_one", utf8()),
+ field("key_0", int64()),
+ }),
+ R"([
+ ["foo", 1],
+ ["bar", 2],
+ [null, 3],
+ [null, 4],
+ ["baz", null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+}
+
+TEST(GroupBy, OneOnly) {
+ auto in_schema = schema({
+ field("argument0", float64()),
+ field("argument1", null()),
+ field("argument2", boolean()),
+ field("key", int64()),
+ });
+ for (bool use_exec_plan : {false, true}) {
+ for (bool use_threads : {false, true}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+ auto table = TableFromJSON(in_schema, {R"([
+ [1.0, null, true, 1],
+ [null, null, true, 1]
+])",
+ R"([
+ [0.0, null, false, 2],
+ [null, null, false, 3],
+ [4.0, null, null, null],
+ [3.25, null, true, 1],
+ [0.125, null, false, 2]
+])",
+ R"([
+ [-0.25, null, false, 2],
+ [0.75, null, true, null],
+ [null, null, true, 3]
+])"});
+
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ GroupByTest(
+ {
+ table->GetColumnByName("argument0"),
+ table->GetColumnByName("argument1"),
+ table->GetColumnByName("argument2"),
+ },
+ {table->GetColumnByName("key")},
+ {
+ {"hash_one", nullptr},
+ {"hash_one", nullptr},
+ {"hash_one", nullptr},
+ },
+ use_threads, use_exec_plan));
+ ValidateOutput(aggregated_and_grouped);
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ // AssertDatumsEqual(ArrayFromJSON(struct_({
+ // field("hash_one", float64()),
+ // field("hash_one", null()),
+ // field("hash_one", boolean()),
+ // field("key_0", int64()),
+ // }),
+ // R"([
+ // [1.0, null, true, 1],
+ // [0.0, null, false, 2],
+ // [null, null, false, 3],
+ // [4.0, null, null, null]
+ // ])"),
+ // aggregated_and_grouped,
+ // /*verbose=*/true);
+
+ const auto& struct_arr = aggregated_and_grouped.array_as<StructArray>();
+ // Check the key column
+ AssertDatumsEqual(ArrayFromJSON(int64(), "[1, 2, 3, null]"),
struct_arr->field(3));
+
+ auto type_col_0 = float64();
+ auto group_one_col_0 =
+ AnyOfScalar(ArrayFromJSON(type_col_0, R"([1.0, null, 3.25])"));
+ auto group_two_col_0 =
+ AnyOfScalar(ArrayFromJSON(type_col_0, R"([0.0, 0.125, -0.25])"));
+ auto group_three_col_0 = AnyOfScalar(ArrayFromJSON(type_col_0,
R"([null])"));
+ auto group_null_col_0 = AnyOfScalar(ArrayFromJSON(type_col_0, R"([4.0,
0.75])"));
+
+ // Check values individually
+ const auto& col0 = struct_arr->field(0);
+ ASSERT_OK_AND_ASSIGN(const auto g_one, col0->GetScalar(0));
+ EXPECT_THAT(g_one, group_one_col_0);
+ ASSERT_OK_AND_ASSIGN(const auto g_two, col0->GetScalar(1));
+ EXPECT_THAT(g_two, group_two_col_0);
+ ASSERT_OK_AND_ASSIGN(const auto g_three, col0->GetScalar(2));
+ EXPECT_THAT(g_three, group_three_col_0);
+ ASSERT_OK_AND_ASSIGN(const auto g_null, col0->GetScalar(3));
+ EXPECT_THAT(g_null, group_null_col_0);
Review comment:
I think something like `EXPECT_THAT(col0->GetScalar(0),
ResultWith(AnyOfScalar(...)))` could shorten this. Also, we could make a helper
function `AnyOfJSON(type, str)` which calls
`AnyOfScalar(ArrayFromJSON(type, str))` for you.
##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
##########
@@ -2460,6 +2461,558 @@ TEST(GroupBy, Distinct) {
}
}
+// gmock matcher: accepts the scalar under test when it equals at least one
+// element of `arrow_array`; otherwise the mismatch is described through the
+// result listener.
+MATCHER_P(AnyOfScalar, arrow_array, "") {
+  const int64_t num_candidates = arrow_array->length();
+  int64_t idx = 0;
+  while (idx < num_candidates) {
+    const auto candidate = arrow_array->GetScalar(idx).ValueOrDie();
+    if (candidate->Equals(arg)) {
+      return true;
+    }
+    ++idx;
+  }
+  // No candidate compared equal: report the scalar that failed to match.
+  *result_listener << "Argument scalar: '" << arg->ToString()
+                   << "' matches no input scalar.";
+  return false;
+}
+
+// gmock matcher for an Array `arg` checked against a ListArray `unique_list`:
+// element i of `arg` must equal one of the scalars stored in list i of
+// `unique_list`. On failure the offending scalar (or a length mismatch) is
+// reported through the result listener.
+MATCHER_P(AnyOfScalarFromUniques, unique_list, "") {
+  // Without this guard a longer `arg` would index past the last list.
+  if (arg->length() != unique_list->length()) {
+    *result_listener << "Argument has " << arg->length() << " elements but "
+                     << unique_list->length() << " lists of uniques were given.";
+    return false;
+  }
+
+  // value_offset()/value_length() index into values(), so the two are used
+  // together. Flatten() yields a rebased copy that only lines up with the raw
+  // offsets when the list array is unsliced.
+  const auto& values = unique_list->values();
+
+  for (int64_t i = 0; i < arg->length(); ++i) {
+    const auto group_hash_one = arg->GetScalar(i).ValueOrDie();
+    const int64_t start = unique_list->value_offset(i);
+    const int64_t end = start + unique_list->value_length(i);
+
+    bool match_found = false;
+    for (int64_t j = start; j < end && !match_found; ++j) {
+      const auto candidate = values->GetScalar(j).ValueOrDie();
+      match_found = candidate->Equals(group_hash_one);
+    }
+    if (!match_found) {
+      *result_listener << "Argument scalar: '" << group_hash_one->ToString()
+                       << "' matches no input scalar.";
+      return false;
+    }
+  }
+  return true;
+}
+
+TEST(GroupBy, One) {  // "hash_one" grouped by an int64 "key", checked for int64 and utf8 arguments
+ {  // Case 1: int64 "argument" column, built from seven JSON chunks
+ auto table =
+ TableFromJSON(schema({field("argument", int64()), field("key",
int64())}), {R"([
+ [99, 1],
+ [99, 1]
+])",
+
R"([
+ [77, 2],
+ [null, 3],
+ [null, 3]
+])",
+
R"([
+ [null, 4],
+ [null, 4]
+])",
+
R"([
+ [88, null],
+ [99, 3]
+])",
+
R"([
+ [77, 2],
+ [76, 2]
+])",
+
R"([
+ [75, null],
+ [74, 3]
+ ])",
+
R"([
+ [73, null],
+ [72, null]
+ ])"});
+
+ ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {
+ {"hash_one", nullptr},
+ },
+ false));  // NOTE(review): presumably use_threads=false (OneOnly passes use_threads in this position) -- confirm
+ ValidateOutput(aggregated_and_grouped);
+ SortBy({"key_0"}, &aggregated_and_grouped);  // presumably orders rows by key so the comparison below is stable
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({  // expected value per key is the first argument listed for that key, nulls included
+ field("hash_one", int64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [99, 1],
+ [77, 2],
+ [null, 3],
+ [null, 4],
+ [88, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+ {  // Case 2: same key layout and null pattern, with a utf8 "argument" column
+ auto table =
+ TableFromJSON(schema({field("argument", utf8()), field("key",
int64())}), {R"([
+ ["foo", 1],
+ ["foo", 1]
+ ])",
+
R"([
+ ["bar", 2],
+ [null, 3],
+ [null, 3]
+ ])",
+
R"([
+ [null, 4],
+ [null, 4]
+ ])",
+
R"([
+ ["baz", null],
+ ["foo", 3]
+ ])",
+
R"([
+ ["bar", 2],
+ ["spam", 2]
+ ])",
+
R"([
+ ["eggs", null],
+ ["ham", 3]
+ ])",
+
R"([
+ ["a", null],
+ ["b", null]
+ ])"});
+
+ ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {
+ {"hash_one", nullptr},
+ },
+ false));  // NOTE(review): presumably use_threads=false, as in case 1 -- confirm
+ ValidateOutput(aggregated_and_grouped);
+ SortBy({"key_0"}, &aggregated_and_grouped);  // presumably orders rows by key so the comparison below is stable
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({  // expected value per key is the first argument listed for that key, nulls included
+ field("hash_one", utf8()),
+ field("key_0", int64()),
+ }),
+ R"([
+ ["foo", 1],
+ ["bar", 2],
+ [null, 3],
+ [null, 4],
+ ["baz", null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+}
+
+TEST(GroupBy, OneOnly) {
+ auto in_schema = schema({
+ field("argument0", float64()),
+ field("argument1", null()),
+ field("argument2", boolean()),
+ field("key", int64()),
+ });
+ for (bool use_exec_plan : {false, true}) {
+ for (bool use_threads : {false, true}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+ auto table = TableFromJSON(in_schema, {R"([
+ [1.0, null, true, 1],
+ [null, null, true, 1]
+])",
+ R"([
+ [0.0, null, false, 2],
+ [null, null, false, 3],
+ [4.0, null, null, null],
+ [3.25, null, true, 1],
+ [0.125, null, false, 2]
+])",
+ R"([
+ [-0.25, null, false, 2],
+ [0.75, null, true, null],
+ [null, null, true, 3]
+])"});
+
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ GroupByTest(
+ {
+ table->GetColumnByName("argument0"),
+ table->GetColumnByName("argument1"),
+ table->GetColumnByName("argument2"),
+ },
+ {table->GetColumnByName("key")},
+ {
+ {"hash_one", nullptr},
+ {"hash_one", nullptr},
+ {"hash_one", nullptr},
+ },
+ use_threads, use_exec_plan));
+ ValidateOutput(aggregated_and_grouped);
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ // AssertDatumsEqual(ArrayFromJSON(struct_({
+ // field("hash_one", float64()),
+ // field("hash_one", null()),
+ // field("hash_one", boolean()),
+ // field("key_0", int64()),
+ // }),
+ // R"([
+ // [1.0, null, true, 1],
+ // [0.0, null, false, 2],
+ // [null, null, false, 3],
+ // [4.0, null, null, null]
+ // ])"),
+ // aggregated_and_grouped,
+ // /*verbose=*/true);
+
+ const auto& struct_arr = aggregated_and_grouped.array_as<StructArray>();
+ // Check the key column
+ AssertDatumsEqual(ArrayFromJSON(int64(), "[1, 2, 3, null]"),
struct_arr->field(3));
+
+ auto type_col_0 = float64();
+ auto group_one_col_0 =
+ AnyOfScalar(ArrayFromJSON(type_col_0, R"([1.0, null, 3.25])"));
+ auto group_two_col_0 =
+ AnyOfScalar(ArrayFromJSON(type_col_0, R"([0.0, 0.125, -0.25])"));
+ auto group_three_col_0 = AnyOfScalar(ArrayFromJSON(type_col_0,
R"([null])"));
+ auto group_null_col_0 = AnyOfScalar(ArrayFromJSON(type_col_0, R"([4.0,
0.75])"));
+
+ // Check values individually
+ const auto& col0 = struct_arr->field(0);
+ ASSERT_OK_AND_ASSIGN(const auto g_one, col0->GetScalar(0));
+ EXPECT_THAT(g_one, group_one_col_0);
+ ASSERT_OK_AND_ASSIGN(const auto g_two, col0->GetScalar(1));
+ EXPECT_THAT(g_two, group_two_col_0);
+ ASSERT_OK_AND_ASSIGN(const auto g_three, col0->GetScalar(2));
+ EXPECT_THAT(g_three, group_three_col_0);
+ ASSERT_OK_AND_ASSIGN(const auto g_null, col0->GetScalar(3));
+ EXPECT_THAT(g_null, group_null_col_0);
+
+ CountOptions all(CountOptions::ALL);
+ ASSERT_OK_AND_ASSIGN(
+ auto distinct_out,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument0"),
+ table->GetColumnByName("argument1"),
+ table->GetColumnByName("argument2"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {{"hash_distinct", &all}, {"hash_distinct", &all},
{"hash_distinct", &all}},
+ use_threads));
+ ValidateOutput(distinct_out);
+ SortBy({"key_0"}, &distinct_out);
+
+ const auto& struct_arr_distinct = distinct_out.array_as<StructArray>();
+ for (int64_t col = 0; col < struct_arr_distinct->length() - 1; ++col) {
+ const auto matcher = AnyOfScalarFromUniques(
+ checked_pointer_cast<ListArray>(struct_arr_distinct->field(col)));
+ EXPECT_THAT(struct_arr->field(col), matcher);
+ }
Review comment:
We can use other kernels, but I'm not sure this is any cleaner. The
other approach is repetitive, but clear about what's going on. This requires a
lot of thought to see what's happening.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]