lidavidm commented on a change in pull request #12368:
URL: https://github.com/apache/arrow/pull/12368#discussion_r804679025



##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
##########
@@ -2460,6 +2460,476 @@ TEST(GroupBy, Distinct) {
   }
 }
 
+// Exercises hash_one on int64 and utf8 arguments grouped by an int64 key:
+// each key (including the null key) yields exactly one row in the output.
+TEST(GroupBy, One) {
+  {
+    auto table =
+        TableFromJSON(schema({field("argument", int64()), field("key", int64())}),
+                      {R"([
+    [99,  1],
+    [99,  1]
+])",
+                       R"([
+    [77,  2],
+    [null,   3],
+    [null,   3]
+])",
+                       R"([
+    [null,   4],
+    [null,   4]
+])",
+                       R"([
+    [88,  null],
+    [99,  3]
+])",
+                       R"([
+    [77,  2],
+    [76, 2]
+])",
+                       R"([
+    [75, null],
+    [74,  3]
+  ])",
+                       R"([
+    [73,    null],
+    [72,    null]
+  ])"});
+
+    ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
+                         internal::GroupBy(
+                             {
+                                 table->GetColumnByName("argument"),
+                             },
+                             {
+                                 table->GetColumnByName("key"),
+                             },
+                             {
+                                 {"hash_one", nullptr},
+                             },
+                             /*use_threads=*/false));
+    ValidateOutput(aggregated_and_grouped);
+    SortBy({"key_0"}, &aggregated_and_grouped);
+
+    // NOTE(review): keys 2 ({77, 77, 76}), 3 ({null, null, 99, 74}) and null
+    // ({88, 75, 73, 72}) have several distinct candidates; each expected row is
+    // the first value seen for that key in input order. If hash_one only
+    // promises *some* value of the group, check membership instead (e.g. with a
+    // gmock AnyOf-style matcher) rather than exact equality.
+    AssertDatumsEqual(ArrayFromJSON(struct_({
+                                        field("hash_one", int64()),
+                                        field("key_0", int64()),
+                                    }),
+                                    R"([
+      [99, 1],
+      [77, 2],
+      [null,  3],
+      [null,  4],
+      [88, null]
+    ])"),
+                      aggregated_and_grouped,
+                      /*verbose=*/true);
+  }
+  {
+    auto table =
+        TableFromJSON(schema({field("argument", utf8()), field("key", int64())}),
+                      {R"([
+     ["foo",  1],
+     ["foo",  1]
+ ])",
+                       R"([
+     ["bar",  2],
+     [null,   3],
+     [null,   3]
+ ])",
+                       R"([
+     [null,   4],
+     [null,   4]
+ ])",
+                       R"([
+     ["baz",  null],
+     ["foo",  3]
+ ])",
+                       R"([
+     ["bar",  2],
+     ["spam", 2]
+ ])",
+                       R"([
+     ["eggs", null],
+     ["ham",  3]
+   ])",
+                       R"([
+     ["a",    null],
+     ["b",    null]
+   ])"});
+
+    ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
+                         internal::GroupBy(
+                             {
+                                 table->GetColumnByName("argument"),
+                             },
+                             {
+                                 table->GetColumnByName("key"),
+                             },
+                             {
+                                 {"hash_one", nullptr},
+                             },
+                             /*use_threads=*/false));
+    ValidateOutput(aggregated_and_grouped);
+    SortBy({"key_0"}, &aggregated_and_grouped);
+
+    // NOTE(review): as above, keys 2 ({"bar", "bar", "spam"}), 3
+    // ({null, null, "foo", "ham"}) and null ({"baz", "eggs", "a", "b"}) expect
+    // the first value seen; consider a membership check if the pick is arbitrary.
+    AssertDatumsEqual(ArrayFromJSON(struct_({
+                                        field("hash_one", utf8()),
+                                        field("key_0", int64()),
+                                    }),
+                                    R"([
+       ["foo", 1],
+       ["bar", 2],
+       [null,  3],
+       [null,  4],
+       ["baz", null]
+     ])"),
+                      aggregated_and_grouped,
+                      /*verbose=*/true);
+  }
+}
+
+// Applies hash_one to float64, null and boolean arguments in a single GroupBy,
+// both with and without an exec plan, and checks one value per key.
+TEST(GroupBy, OneOnly) {
+  auto in_schema = schema({
+      field("argument0", float64()),
+      field("argument1", null()),
+      field("argument2", boolean()),
+      field("key", int64()),
+  });
+  for (bool use_exec_plan : {false, true}) {
+    // Without this trace a failure would not say which execution path failed.
+    SCOPED_TRACE(use_exec_plan ? "use_exec_plan=true" : "use_exec_plan=false");
+    // NOTE(review): only the serial path is exercised; presumably a threaded
+    // run could pick a different value per group. Confirm before adding `true`.
+    for (bool use_threads : {false}) {
+      SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+      auto table = TableFromJSON(in_schema, {R"([
+    [1.0,   null, true, 1],
+    [null,  null, true, 1]
+])",
+                                             R"([
+    [0.0,   null, false, 2],
+    [null,  null, false, 3],
+    [4.0,   null, null,  null],
+    [3.25,  null, true,  1],
+    [0.125, null, false, 2]
+])",
+                                             R"([
+    [-0.25, null, false, 2],
+    [0.75,  null, true,  null],
+    [null,  null, true,  3]
+])"});
+
+      ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+                           GroupByTest(
+                               {
+                                   table->GetColumnByName("argument0"),
+                                   table->GetColumnByName("argument1"),
+                                   table->GetColumnByName("argument2"),
+                               },
+                               {table->GetColumnByName("key")},
+                               {
+                                   {"hash_one", nullptr},
+                                   {"hash_one", nullptr},
+                                   {"hash_one", nullptr},
+                               },
+                               use_threads, use_exec_plan));
+      ValidateOutput(aggregated_and_grouped);
+      SortBy({"key_0"}, &aggregated_and_grouped);
+
+      // NOTE(review): several expectations are the first value seen for the
+      // key: argument0 for key 1 ({1.0, null, 3.25}), key 2
+      // ({0.0, 0.125, -0.25}) and key null ({4.0, 0.75}); argument2 for key 3
+      // ({false, true}) and key null ({null, true}). They are brittle if
+      // hash_one's pick is arbitrary.
+      AssertDatumsEqual(ArrayFromJSON(struct_({
+                                          field("hash_one", float64()),
+                                          field("hash_one", null()),
+                                          field("hash_one", boolean()),
+                                          field("key_0", int64()),
+                                      }),
+                                      R"([
+    [1.0,  null, true,  1],
+    [0.0,  null, false, 2],
+    [null, null, false, 3],
+    [4.0,  null, null,  null]
+  ])"),
+                        aggregated_and_grouped,
+                        /*verbose=*/true);
+    }
+  }
+}
+
+// Runs hash_one over every numeric and temporal type (plus month_interval),
+// grouped by an int64 key, through the exec-plan path with a single thread.
+TEST(GroupBy, OneTypes) {
+  // Bind each type list to a reference first, so that begin() and end() are
+  // guaranteed to come from the same container.
+  const auto& numeric_types = NumericTypes();
+  const auto& temporal_types = TemporalTypes();
+  std::vector<std::shared_ptr<DataType>> types;
+  types.reserve(numeric_types.size() + temporal_types.size() + 1);
+  types.insert(types.end(), numeric_types.begin(), numeric_types.end());
+  types.insert(types.end(), temporal_types.begin(), temporal_types.end());
+  types.push_back(month_interval());
+
+  const std::vector<std::string> default_table = {R"([
+    [1,    1],
+    [null, 1]
+])",
+                                                  R"([
+    [0,    2],
+    [null, 3],
+    [3,    4],
+    [5,    4],
+    [4,    null],
+    [3,    1],
+    [0,    2]
+])",
+                                                  R"([
+    [0,    2],
+    [1,    null],
+    [null, 3]
+])"};
+
+  // Same rows as default_table with every argument multiplied by 86400000
+  // (one day in milliseconds); presumably date64 values must be whole days to
+  // pass validation.
+  const std::vector<std::string> date64_table = {R"([
+    [86400000,  1],
+    [null,      1]
+])",
+                                                 R"([
+    [0,         2],
+    [null,      3],
+    [259200000, 4],
+    [432000000, 4],
+    [345600000, null],
+    [259200000, 1],
+    [0,         2]
+])",
+                                                 R"([
+    [0,         2],
+    [86400000,  null],
+    [null,      3]
+])"};
+
+  // NOTE(review): keys 1 ({1, null, 3}), 4 ({3, 5}) and null ({4, 1}) expect the
+  // first value seen; consider a membership check if hash_one's pick is arbitrary.
+  const std::string default_expected =
+      R"([
+    [1,    1],
+    [0,    2],
+    [null, 3],
+    [3,    4],
+    [4,    null]
+    ])";
+
+  const std::string date64_expected =
+      R"([
+    [86400000,  1],
+    [0,         2],
+    [null,      3],
+    [259200000, 4],
+    [345600000, null]
+    ])";
+
+  for (const auto& ty : types) {
+    SCOPED_TRACE(ty->ToString());
+    const bool is_date64 = ty->id() == Type::DATE64;
+    auto in_schema = schema({field("argument0", ty), field("key", int64())});
+    auto table = TableFromJSON(in_schema, is_date64 ? date64_table : default_table);
+
+    ASSERT_OK_AND_ASSIGN(
+        Datum aggregated_and_grouped,
+        GroupByTest({table->GetColumnByName("argument0")},
+                    {table->GetColumnByName("key")}, {{"hash_one", nullptr}},
+                    /*use_threads=*/false, /*use_exec_plan=*/true));
+    ValidateOutput(aggregated_and_grouped);
+    SortBy({"key_0"}, &aggregated_and_grouped);
+
+    AssertDatumsEqual(ArrayFromJSON(struct_({
+                                        field("hash_one", ty),
+                                        field("key_0", int64()),
+                                    }),
+                                    is_date64 ? date64_expected : default_expected),
+                      aggregated_and_grouped,
+                      /*verbose=*/true);
+  }
+}
+
+TEST(GroupBy, OneDecimal) {
+  auto in_schema = schema({
+      field("argument0", decimal128(3, 2)),
+      field("argument1", decimal256(3, 2)),
+      field("key", int64()),
+  });
+  for (bool use_exec_plan : {false, true}) {
+    for (bool use_threads : {/*true, */ false}) {

Review comment:
       Alternatively, you could try making `AnyOf` work with `EXPECT_THAT`
       (see the googletest composite matchers reference):
       https://github.com/google/googletest/blob/main/docs/reference/matchers.md#composite-matchers




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to