icexelloss commented on code in PR #34311:
URL: https://github.com/apache/arrow/pull/34311#discussion_r1128152483
##########
cpp/src/arrow/compute/kernels/hash_aggregate_test.cc:
##########
@@ -301,53 +439,249 @@ Result<Datum> GroupByTest(const std::vector<Datum>&
arguments,
{t_agg.function, t_agg.options, "agg_" + ToChars(idx),
t_agg.function});
idx = idx + 1;
}
- return RunGroupBy(arguments, keys, internal_aggregates, use_threads);
+ return group_by(arguments, keys, segment_keys, internal_aggregates,
use_threads,
+ /*naive=*/false);
}
-} // namespace
+Result<Datum> GroupByTest(GroupByFunction group_by, const std::vector<Datum>&
arguments,
+ const std::vector<Datum>& keys,
+ const std::vector<TestAggregate>& aggregates,
+ bool use_threads) {
+ return GroupByTest(group_by, arguments, keys, {}, aggregates, use_threads);
+}
-TEST(Grouper, SupportedKeys) {
- ASSERT_OK(Grouper::Make({boolean()}));
+template <typename GroupClass>
+void TestGroupClassSupportedKeys(
+ std::function<Result<std::unique_ptr<GroupClass>>(const
std::vector<TypeHolder>&)>
+ make_func) {
+ ASSERT_OK(make_func({boolean()}));
- ASSERT_OK(Grouper::Make({int8(), uint16(), int32(), uint64()}));
+ ASSERT_OK(make_func({int8(), uint16(), int32(), uint64()}));
- ASSERT_OK(Grouper::Make({dictionary(int64(), utf8())}));
+ ASSERT_OK(make_func({dictionary(int64(), utf8())}));
- ASSERT_OK(Grouper::Make({float16(), float32(), float64()}));
+ ASSERT_OK(make_func({float16(), float32(), float64()}));
- ASSERT_OK(Grouper::Make({utf8(), binary(), large_utf8(), large_binary()}));
+ ASSERT_OK(make_func({utf8(), binary(), large_utf8(), large_binary()}));
- ASSERT_OK(Grouper::Make({fixed_size_binary(16), fixed_size_binary(32)}));
+ ASSERT_OK(make_func({fixed_size_binary(16), fixed_size_binary(32)}));
- ASSERT_OK(Grouper::Make({decimal128(32, 10), decimal256(76, 20)}));
+ ASSERT_OK(make_func({decimal128(32, 10), decimal256(76, 20)}));
- ASSERT_OK(Grouper::Make({date32(), date64()}));
+ ASSERT_OK(make_func({date32(), date64()}));
for (auto unit : {
TimeUnit::SECOND,
TimeUnit::MILLI,
TimeUnit::MICRO,
TimeUnit::NANO,
}) {
- ASSERT_OK(Grouper::Make({timestamp(unit), duration(unit)}));
+ ASSERT_OK(make_func({timestamp(unit), duration(unit)}));
}
ASSERT_OK(
- Grouper::Make({day_time_interval(), month_interval(),
month_day_nano_interval()}));
+ make_func({day_time_interval(), month_interval(),
month_day_nano_interval()}));
+
+ ASSERT_OK(make_func({null()}));
- ASSERT_OK(Grouper::Make({null()}));
+ ASSERT_RAISES(NotImplemented, make_func({struct_({field("", int64())})}));
- ASSERT_RAISES(NotImplemented, Grouper::Make({struct_({field("",
int64())})}));
+ ASSERT_RAISES(NotImplemented, make_func({struct_({})}));
- ASSERT_RAISES(NotImplemented, Grouper::Make({struct_({})}));
+ ASSERT_RAISES(NotImplemented, make_func({list(int32())}));
- ASSERT_RAISES(NotImplemented, Grouper::Make({list(int32())}));
+ ASSERT_RAISES(NotImplemented, make_func({fixed_size_list(int32(), 5)}));
+
+ ASSERT_RAISES(NotImplemented, make_func({dense_union({field("",
int32())})}));
+}
- ASSERT_RAISES(NotImplemented, Grouper::Make({fixed_size_list(int32(), 5)}));
+void TestSegments(std::unique_ptr<RowSegmenter>& segmenter, const ExecSpan&
batch,
+ std::vector<Segment> expected_segments) {
+ int64_t offset = 0, segment_num = 0;
+ for (auto expected_segment : expected_segments) {
+ SCOPED_TRACE("segment #" + ToChars(segment_num++));
+ ASSERT_OK_AND_ASSIGN(auto segment, segmenter->GetNextSegment(batch,
offset));
+ ASSERT_EQ(expected_segment, segment);
+ offset = segment.offset + segment.length;
+ }
+}
+
+Result<std::unique_ptr<Grouper>> MakeGrouper(const std::vector<TypeHolder>&
key_types) {
+ return Grouper::Make(key_types, default_exec_context());
+}
+
+Result<std::unique_ptr<RowSegmenter>> MakeRowSegmenter(
+ const std::vector<TypeHolder>& key_types) {
+ return RowSegmenter::Make(key_types, /*nullable_leys=*/false,
default_exec_context());
+}
+
+Result<std::unique_ptr<RowSegmenter>> MakeGenericSegmenter(
+ const std::vector<TypeHolder>& key_types) {
+ return MakeAnyKeysSegmenter(key_types, default_exec_context());
+}
+
+} // namespace
+
+TEST(RowSegmenter, SupportedKeys) {
+ TestGroupClassSupportedKeys<RowSegmenter>(MakeRowSegmenter);
+}
+
+TEST(RowSegmenter, Basics) {
+ std::vector<TypeHolder> bad_types2 = {int32(), float32()};
+ std::vector<TypeHolder> types2 = {int32(), int32()};
+ std::vector<TypeHolder> bad_types1 = {float32()};
+ std::vector<TypeHolder> types1 = {int32()};
+ std::vector<TypeHolder> types0 = {};
+ auto batch2 = ExecBatchFromJSON(types2, "[[1, 1], [1, 2], [2, 2]]");
+ auto batch1 = ExecBatchFromJSON(types1, "[[1], [1], [2]]");
+ ExecBatch batch0({}, 3);
+ {
+ SCOPED_TRACE("offset");
+ ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types0));
+ ExecSpan span0(batch0);
+ for (int64_t offset : {-1, 4}) {
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ HasSubstr("invalid grouping segmenter
offset"),
+ segmenter->GetNextSegment(span0,
offset));
+ }
+ }
+ {
+ SCOPED_TRACE("types0 segmenting of batch2");
+ ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types0));
+ ExecSpan span2(batch2);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch size 0
"),
+ segmenter->GetNextSegment(span2, 0));
+ ExecSpan span0(batch0);
+ TestSegments(segmenter, span0, {{0, 3, true, true}, {3, 0, true, true}});
+ }
+ {
+ SCOPED_TRACE("bad_types1 segmenting of batch1");
+ ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(bad_types1));
+ ExecSpan span1(batch1);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch value 0
of type "),
+ segmenter->GetNextSegment(span1, 0));
+ }
+ {
+ SCOPED_TRACE("types1 segmenting of batch2");
+ ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types1));
+ ExecSpan span2(batch2);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch size 1
"),
+ segmenter->GetNextSegment(span2, 0));
+ ExecSpan span1(batch1);
+ TestSegments(segmenter, span1,
+ {{0, 2, false, true}, {2, 1, true, false}, {3, 0, true,
true}});
+ }
+ {
+ SCOPED_TRACE("bad_types2 segmenting of batch2");
+ ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(bad_types2));
+ ExecSpan span2(batch2);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch value 1
of type "),
+ segmenter->GetNextSegment(span2, 0));
+ }
+ {
+ SCOPED_TRACE("types2 segmenting of batch1");
+ ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types2));
+ ExecSpan span1(batch1);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch size 2
"),
+ segmenter->GetNextSegment(span1, 0));
+ ExecSpan span2(batch2);
+ TestSegments(segmenter, span2,
+ {{0, 1, false, true},
+ {1, 1, false, false},
+ {2, 1, true, false},
+ {3, 0, true, true}});
+ }
+}
+
+namespace {
+
+void test_row_segmenter_constant_batch(
Review Comment:
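Rename `test_row_segmenter_constant_batch` to `TestRowSegmenterConstantBatch` so it matches the PascalCase naming of the other test helpers in this file (e.g. `TestSegments`, `TestGroupClassSupportedKeys`).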
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]