dhruv9vats commented on a change in pull request #12368:
URL: https://github.com/apache/arrow/pull/12368#discussion_r806444672
##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
##########
@@ -2460,6 +2461,558 @@ TEST(GroupBy, Distinct) {
}
}
+// Succeeds when the argument scalar compares equal (Scalar::Equals) to at
+// least one element of `arrow_array`; otherwise explains the mismatch.
+MATCHER_P(AnyOfScalar, arrow_array, "") {
+  const int64_t num_candidates = arrow_array->length();
+  int64_t idx = 0;
+  while (idx < num_candidates) {
+    const auto candidate = arrow_array->GetScalar(idx++).ValueOrDie();
+    if (candidate->Equals(arg)) {
+      return true;
+    }
+  }
+  *result_listener << "Argument scalar: '" << arg->ToString()
+                   << "' matches no input scalar.";
+  return false;
+}
+
+// Succeeds when, for every row i of the argument array, the argument's value
+// at i equals (Scalar::Equals) at least one element of list i of
+// `unique_list`. Both arrays must have the same number of rows.
+//
+// List boundaries come from value_offset(), which indexes directly into
+// values() and honours any slice offset. Flatten() must NOT be paired with
+// raw offsets: it re-bases positions to value_offset(0) and drops null lists
+// that have non-empty backing storage, which misaligns the two.
+MATCHER_P(AnyOfScalarFromUniques, unique_list, "") {
+  if (arg->length() != unique_list->length()) {
+    *result_listener << "Argument has " << arg->length()
+                     << " rows but the list array has "
+                     << unique_list->length() << " rows.";
+    return false;
+  }
+  const auto& values = unique_list->values();
+  for (int64_t i = 0; i < arg->length(); ++i) {
+    bool match_found = false;
+    const auto group_hash_one = arg->GetScalar(i).ValueOrDie();
+    const int64_t start = unique_list->value_offset(i);
+    // A null list slot offers no candidates, whatever its backing storage holds.
+    const int64_t end =
+        unique_list->IsNull(i) ? start : unique_list->value_offset(i + 1);
+    for (int64_t j = start; j < end; ++j) {
+      auto s = values->GetScalar(j).ValueOrDie();
+      if (s->Equals(group_hash_one)) {
+        match_found = true;
+        break;
+      }
+    }
+    if (!match_found) {
+      *result_listener << "Argument scalar: '" << group_hash_one->ToString()
+                       << "' matches no input scalar.";
+      return false;
+    }
+  }
+  return true;
+}
+
+// Exercises the hash_one aggregation on an int64 and a utf8 argument column,
+// grouped by an int64 key, including groups whose values are all null and a
+// null key group. internal::GroupBy is called with use_threads = false. Every
+// expected value is the first value its group sees in input order, so these
+// assertions presumably rely on hash_one being deterministic on the
+// single-threaded path -- confirm against the kernel implementation.
+TEST(GroupBy, One) {
+ {
+ // Case 1: int64 argument column.
+ auto table =
+ TableFromJSON(schema({field("argument", int64()), field("key",
int64())}), {R"([
+ [99, 1],
+ [99, 1]
+])",
+
R"([
+ [77, 2],
+ [null, 3],
+ [null, 3]
+])",
+
R"([
+ [null, 4],
+ [null, 4]
+])",
+
R"([
+ [88, null],
+ [99, 3]
+])",
+
R"([
+ [77, 2],
+ [76, 2]
+])",
+
R"([
+ [75, null],
+ [74, 3]
+ ])",
+
R"([
+ [73, null],
+ [72, null]
+ ])"});
+
+ // Aggregate "argument" with hash_one per "key"; the trailing `false` is
+ // use_threads.
+ ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {
+ {"hash_one", nullptr},
+ },
+ false));
+ ValidateOutput(aggregated_and_grouped);
+ // Sort by the key column so rows can be compared against a fixed order.
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_one", int64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [99, 1],
+ [77, 2],
+ [null, 3],
+ [null, 4],
+ [88, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+ {
+ // Case 2: the same grouping scenario with a utf8 argument column.
+ auto table =
+ TableFromJSON(schema({field("argument", utf8()), field("key",
int64())}), {R"([
+ ["foo", 1],
+ ["foo", 1]
+ ])",
+
R"([
+ ["bar", 2],
+ [null, 3],
+ [null, 3]
+ ])",
+
R"([
+ [null, 4],
+ [null, 4]
+ ])",
+
R"([
+ ["baz", null],
+ ["foo", 3]
+ ])",
+
R"([
+ ["bar", 2],
+ ["spam", 2]
+ ])",
+
R"([
+ ["eggs", null],
+ ["ham", 3]
+ ])",
+
R"([
+ ["a", null],
+ ["b", null]
+ ])"});
+
+ // Aggregate "argument" with hash_one per "key"; the trailing `false` is
+ // use_threads.
+ ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {
+ {"hash_one", nullptr},
+ },
+ false));
+ ValidateOutput(aggregated_and_grouped);
+ // Sort by the key column so rows can be compared against a fixed order.
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_one", utf8()),
+ field("key_0", int64()),
+ }),
+ R"([
+ ["foo", 1],
+ ["bar", 2],
+ [null, 3],
+ [null, 4],
+ ["baz", null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+}
+
+TEST(GroupBy, OneOnly) {
+ auto in_schema = schema({
+ field("argument0", float64()),
+ field("argument1", null()),
+ field("argument2", boolean()),
+ field("key", int64()),
+ });
+ for (bool use_exec_plan : {false, true}) {
+ for (bool use_threads : {false, true}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+ auto table = TableFromJSON(in_schema, {R"([
+ [1.0, null, true, 1],
+ [null, null, true, 1]
+])",
+ R"([
+ [0.0, null, false, 2],
+ [null, null, false, 3],
+ [4.0, null, null, null],
+ [3.25, null, true, 1],
+ [0.125, null, false, 2]
+])",
+ R"([
+ [-0.25, null, false, 2],
+ [0.75, null, true, null],
+ [null, null, true, 3]
+])"});
+
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ GroupByTest(
+ {
+ table->GetColumnByName("argument0"),
+ table->GetColumnByName("argument1"),
+ table->GetColumnByName("argument2"),
+ },
+ {table->GetColumnByName("key")},
+ {
+ {"hash_one", nullptr},
+ {"hash_one", nullptr},
+ {"hash_one", nullptr},
+ },
+ use_threads, use_exec_plan));
+ ValidateOutput(aggregated_and_grouped);
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ // AssertDatumsEqual(ArrayFromJSON(struct_({
+ // field("hash_one", float64()),
+ // field("hash_one", null()),
+ // field("hash_one", boolean()),
+ // field("key_0", int64()),
+ // }),
+ // R"([
+ // [1.0, null, true, 1],
+ // [0.0, null, false, 2],
+ // [null, null, false, 3],
+ // [4.0, null, null, null]
+ // ])"),
+ // aggregated_and_grouped,
+ // /*verbose=*/true);
+
+ const auto& struct_arr = aggregated_and_grouped.array_as<StructArray>();
+ // Check the key column
+ AssertDatumsEqual(ArrayFromJSON(int64(), "[1, 2, 3, null]"),
struct_arr->field(3));
+
+ auto type_col_0 = float64();
+ auto group_one_col_0 =
+ AnyOfScalar(ArrayFromJSON(type_col_0, R"([1.0, null, 3.25])"));
+ auto group_two_col_0 =
+ AnyOfScalar(ArrayFromJSON(type_col_0, R"([0.0, 0.125, -0.25])"));
+ auto group_three_col_0 = AnyOfScalar(ArrayFromJSON(type_col_0,
R"([null])"));
+ auto group_null_col_0 = AnyOfScalar(ArrayFromJSON(type_col_0, R"([4.0,
0.75])"));
+
+ // Check values individually
+ const auto& col0 = struct_arr->field(0);
+ ASSERT_OK_AND_ASSIGN(const auto g_one, col0->GetScalar(0));
+ EXPECT_THAT(g_one, group_one_col_0);
+ ASSERT_OK_AND_ASSIGN(const auto g_two, col0->GetScalar(1));
+ EXPECT_THAT(g_two, group_two_col_0);
+ ASSERT_OK_AND_ASSIGN(const auto g_three, col0->GetScalar(2));
+ EXPECT_THAT(g_three, group_three_col_0);
+ ASSERT_OK_AND_ASSIGN(const auto g_null, col0->GetScalar(3));
+ EXPECT_THAT(g_null, group_null_col_0);
+
+ CountOptions all(CountOptions::ALL);
+ ASSERT_OK_AND_ASSIGN(
+ auto distinct_out,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument0"),
+ table->GetColumnByName("argument1"),
+ table->GetColumnByName("argument2"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {{"hash_distinct", &all}, {"hash_distinct", &all},
{"hash_distinct", &all}},
+ use_threads));
+ ValidateOutput(distinct_out);
+ SortBy({"key_0"}, &distinct_out);
+
+ const auto& struct_arr_distinct = distinct_out.array_as<StructArray>();
+ for (int64_t col = 0; col < struct_arr_distinct->length() - 1; ++col) {
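Review comment:
```suggestion
for (int64_t col = 0; col < struct_arr_distinct->num_fields() - 1; ++col) {
```
`length()` is the number of rows (groups), but this loop walks the aggregated columns, so it should be bounded by `num_fields()` (excluding the trailing `key_0` column). The two only coincide here because there happen to be four groups and four fields.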
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]