icexelloss commented on code in PR #34311:
URL: https://github.com/apache/arrow/pull/34311#discussion_r1128375752
##########
cpp/src/arrow/compute/kernels/hash_aggregate_test.cc:
##########
@@ -301,53 +437,249 @@ Result<Datum> GroupByTest(const std::vector<Datum>&
arguments,
{t_agg.function, t_agg.options, "agg_" + ToChars(idx),
t_agg.function});
idx = idx + 1;
}
- return RunGroupBy(arguments, keys, internal_aggregates, use_threads);
+ return group_by(arguments, keys, segment_keys, internal_aggregates,
use_threads,
+ /*naive=*/false);
}
-} // namespace
+Result<Datum> GroupByTest(GroupByFunction group_by, const std::vector<Datum>&
arguments,
+ const std::vector<Datum>& keys,
+ const std::vector<TestAggregate>& aggregates,
+ bool use_threads) {
+ return GroupByTest(group_by, arguments, keys, {}, aggregates, use_threads);
+}
-TEST(Grouper, SupportedKeys) {
- ASSERT_OK(Grouper::Make({boolean()}));
+template <typename GroupClass>
+void TestGroupClassSupportedKeys(
+ std::function<Result<std::unique_ptr<GroupClass>>(const
std::vector<TypeHolder>&)>
+ make_func) {
+ ASSERT_OK(make_func({boolean()}));
- ASSERT_OK(Grouper::Make({int8(), uint16(), int32(), uint64()}));
+ ASSERT_OK(make_func({int8(), uint16(), int32(), uint64()}));
- ASSERT_OK(Grouper::Make({dictionary(int64(), utf8())}));
+ ASSERT_OK(make_func({dictionary(int64(), utf8())}));
- ASSERT_OK(Grouper::Make({float16(), float32(), float64()}));
+ ASSERT_OK(make_func({float16(), float32(), float64()}));
- ASSERT_OK(Grouper::Make({utf8(), binary(), large_utf8(), large_binary()}));
+ ASSERT_OK(make_func({utf8(), binary(), large_utf8(), large_binary()}));
- ASSERT_OK(Grouper::Make({fixed_size_binary(16), fixed_size_binary(32)}));
+ ASSERT_OK(make_func({fixed_size_binary(16), fixed_size_binary(32)}));
- ASSERT_OK(Grouper::Make({decimal128(32, 10), decimal256(76, 20)}));
+ ASSERT_OK(make_func({decimal128(32, 10), decimal256(76, 20)}));
- ASSERT_OK(Grouper::Make({date32(), date64()}));
+ ASSERT_OK(make_func({date32(), date64()}));
for (auto unit : {
TimeUnit::SECOND,
TimeUnit::MILLI,
TimeUnit::MICRO,
TimeUnit::NANO,
}) {
- ASSERT_OK(Grouper::Make({timestamp(unit), duration(unit)}));
+ ASSERT_OK(make_func({timestamp(unit), duration(unit)}));
}
ASSERT_OK(
- Grouper::Make({day_time_interval(), month_interval(),
month_day_nano_interval()}));
+ make_func({day_time_interval(), month_interval(),
month_day_nano_interval()}));
+
+ ASSERT_OK(make_func({null()}));
+
+ ASSERT_RAISES(NotImplemented, make_func({struct_({field("", int64())})}));
+
+ ASSERT_RAISES(NotImplemented, make_func({struct_({})}));
+
+ ASSERT_RAISES(NotImplemented, make_func({list(int32())}));
+
+ ASSERT_RAISES(NotImplemented, make_func({fixed_size_list(int32(), 5)}));
+
+ ASSERT_RAISES(NotImplemented, make_func({dense_union({field("",
int32())})}));
+}
+
+void TestSegments(std::unique_ptr<RowSegmenter>& segmenter, const ExecSpan&
batch,
+ std::vector<Segment> expected_segments) {
+ int64_t offset = 0, segment_num = 0;
+ for (auto expected_segment : expected_segments) {
+ SCOPED_TRACE("segment #" + ToChars(segment_num++));
+ ASSERT_OK_AND_ASSIGN(auto segment, segmenter->GetNextSegment(batch,
offset));
+ ASSERT_EQ(expected_segment, segment);
+ offset = segment.offset + segment.length;
+ }
+}
+
+Result<std::unique_ptr<Grouper>> MakeGrouper(const std::vector<TypeHolder>&
key_types) {
+ return Grouper::Make(key_types, default_exec_context());
+}
+
+Result<std::unique_ptr<RowSegmenter>> MakeRowSegmenter(
+ const std::vector<TypeHolder>& key_types) {
+ return RowSegmenter::Make(key_types, /*nullable_leys=*/false,
default_exec_context());
+}
+
+Result<std::unique_ptr<RowSegmenter>> MakeGenericSegmenter(
+ const std::vector<TypeHolder>& key_types) {
+ return MakeAnyKeysSegmenter(key_types, default_exec_context());
+}
+
+} // namespace
+
+TEST(RowSegmenter, SupportedKeys) {
+ TestGroupClassSupportedKeys<RowSegmenter>(MakeRowSegmenter);
+}
+
+TEST(RowSegmenter, Basics) {
+ std::vector<TypeHolder> bad_types2 = {int32(), float32()};
+ std::vector<TypeHolder> types2 = {int32(), int32()};
+ std::vector<TypeHolder> bad_types1 = {float32()};
+ std::vector<TypeHolder> types1 = {int32()};
+ std::vector<TypeHolder> types0 = {};
+ auto batch2 = ExecBatchFromJSON(types2, "[[1, 1], [1, 2], [2, 2]]");
Review Comment:
Can you add a few more test cases here?
A few ideas:
(1) Non-ordered segment case, e.g., `1, 1, 2, 2, 1, 1`
(2) Empty batches (This should be included in the end-to-end test)
(3) More than 2 segments inside one partition
I suggest trying to cover as many edge cases as you can here.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]