dhruv9vats commented on a change in pull request #12368:
URL: https://github.com/apache/arrow/pull/12368#discussion_r804579744



##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
##########
@@ -2460,6 +2460,476 @@ TEST(GroupBy, Distinct) {
   }
 }
 
+TEST(GroupBy, One) {

Review comment:
       This test can probably be deleted, as the tests below already cover it.

##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
##########
@@ -2460,6 +2460,476 @@ TEST(GroupBy, Distinct) {
   }
 }
 
+TEST(GroupBy, One) {
+  {
+    auto table =
+        TableFromJSON(schema({field("argument", int64()), field("key", 
int64())}), {R"([
+    [99,  1],
+    [99,  1]
+])",
+                                                                               
     R"([
+    [77,  2],
+    [null,   3],
+    [null,   3]
+])",
+                                                                               
     R"([
+    [null,   4],
+    [null,   4]
+])",
+                                                                               
   R"([
+    [88,  null],
+    [99,  3]
+])",
+                                                                               
   R"([
+    [77,  2],
+    [76, 2]
+])",
+                                                                               
   R"([
+    [75, null],
+    [74,  3]
+  ])",
+                                                                               
   R"([
+    [73,    null],
+    [72,    null]
+  ])"});
+
+  ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
+                       internal::GroupBy(
+                           {
+                               table->GetColumnByName("argument"),
+                           },
+                           {
+                               table->GetColumnByName("key"),
+                           },
+                           {
+                               {"hash_one", nullptr},
+                           },
+                           false));
+  ValidateOutput(aggregated_and_grouped);
+  SortBy({"key_0"}, &aggregated_and_grouped);
+
+  AssertDatumsEqual(ArrayFromJSON(struct_({
+                                      field("hash_one", int64()),
+                                      field("key_0", int64()),
+                                  }),
+                                  R"([
+      [99, 1],
+      [77, 2],
+      [null,  3],
+      [null,  4],
+      [88, null]
+    ])"),
+                    aggregated_and_grouped,
+                    /*verbose=*/true);
+  }
+  {
+    auto table =
+        TableFromJSON(schema({field("argument", utf8()), field("key", 
int64())}), {R"([
+     ["foo",  1],
+     ["foo",  1]
+ ])",
+                                                                               
    R"([
+     ["bar",  2],
+     [null,   3],
+     [null,   3]
+ ])",
+                                                                               
    R"([
+     [null,   4],
+     [null,   4]
+ ])",
+                                                                               
    R"([
+     ["baz",  null],
+     ["foo",  3]
+ ])",
+                                                                               
    R"([
+     ["bar",  2],
+     ["spam", 2]
+ ])",
+                                                                               
    R"([
+     ["eggs", null],
+     ["ham",  3]
+   ])",
+                                                                               
    R"([
+     ["a",    null],
+     ["b",    null]
+   ])"});
+
+    ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
+                         internal::GroupBy(
+                             {
+                                 table->GetColumnByName("argument"),
+                             },
+                             {
+                                 table->GetColumnByName("key"),
+                             },
+                             {
+                                 {"hash_one", nullptr},
+                             },
+                             false));
+    ValidateOutput(aggregated_and_grouped);
+    SortBy({"key_0"}, &aggregated_and_grouped);
+
+    AssertDatumsEqual(ArrayFromJSON(struct_({
+                                        field("hash_one", utf8()),
+                                        field("key_0", int64()),
+                                    }),
+                                    R"([
+       ["foo", 1],
+       ["bar", 2],
+       [null,  3],
+       [null,  4],
+       ["baz", null]
+     ])"),
+                      aggregated_and_grouped,
+                      /*verbose=*/true);
+  }
+}
+
+TEST(GroupBy, OneOnly) {
+  auto in_schema = schema({
+      field("argument0", float64()),
+      field("argument1", null()),
+      field("argument2", boolean()),
+      field("key", int64()),
+  });
+  for (bool use_exec_plan : {false, true}) {
+    for (bool use_threads : {false}) {
+      SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+      auto table = TableFromJSON(in_schema, {R"([
+    [1.0,   null, true, 1],
+    [null,  null, true, 1]
+])",
+                                             R"([
+    [0.0,   null, false, 2],
+    [null,  null, false, 3],
+    [4.0,   null, null,  null],
+    [3.25,  null, true,  1],
+    [0.125, null, false, 2]
+])",
+                                             R"([
+    [-0.25, null, false, 2],
+    [0.75,  null, true,  null],
+    [null,  null, true,  3]
+])"});
+
+      ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+                           GroupByTest(
+                               {
+                                   table->GetColumnByName("argument0"),
+                                   table->GetColumnByName("argument1"),
+                                   table->GetColumnByName("argument2"),
+                               },
+                               {table->GetColumnByName("key")},
+                               {
+                                   {"hash_one", nullptr},
+                                   {"hash_one", nullptr},
+                                   {"hash_one", nullptr},
+                               },
+                               use_threads, use_exec_plan));
+      ValidateOutput(aggregated_and_grouped);
+      SortBy({"key_0"}, &aggregated_and_grouped);
+
+      AssertDatumsEqual(ArrayFromJSON(struct_({
+                                          field("hash_one", float64()),
+                                          field("hash_one", null()),
+                                          field("hash_one", boolean()),
+                                          field("key_0", int64()),
+                                      }),
+                                      R"([
+    [1.0,  null, true,  1],
+    [0.0,  null, false, 2],
+    [null, null, false, 3],
+    [4.0,  null, null,  null]
+  ])"),
+                        aggregated_and_grouped,
+                        /*verbose=*/true);
+    }
+  }
+}
+
+TEST(GroupBy, OneTypes) {
+  std::vector<std::shared_ptr<DataType>> types;
+  types.insert(types.end(), NumericTypes().begin(), NumericTypes().end());
+  types.insert(types.end(), TemporalTypes().begin(), TemporalTypes().end());
+  types.push_back(month_interval());
+
+  const std::vector<std::string> default_table = {R"([
+    [1,    1],
+    [null, 1]
+])",
+                                                  R"([
+    [0,    2],
+    [null, 3],
+    [3,    4],
+    [5,    4],
+    [4,    null],
+    [3,    1],
+    [0,    2]
+])",
+                                                  R"([
+    [0,    2],
+    [1,    null],
+    [null, 3]
+])"};
+
+  const std::vector<std::string> date64_table = {R"([
+    [86400000,  1],
+    [null,      1]
+])",
+                                                 R"([
+    [0,         2],
+    [null,      3],
+    [259200000, 4],
+    [432000000, 4],
+    [345600000, null],
+    [259200000, 1],
+    [0,         2]
+])",
+                                                 R"([
+    [0,         2],
+    [86400000,  null],
+    [null,      3]
+])"};
+
+  const std::string default_expected =
+      R"([
+    [1,    1],
+    [0,    2],
+    [null, 3],
+    [3,    4],
+    [4,    null]
+    ])";
+
+  const std::string date64_expected =
+      R"([
+    [86400000,  1],
+    [0,         2],
+    [null,      3],
+    [259200000, 4],
+    [345600000, null]
+    ])";
+
+  for (const auto& ty : types) {
+    SCOPED_TRACE(ty->ToString());
+    auto in_schema = schema({field("argument0", ty), field("key", int64())});
+    auto table =
+        TableFromJSON(in_schema, (ty->name() == "date64") ? date64_table : 
default_table);
+
+    ASSERT_OK_AND_ASSIGN(
+        Datum aggregated_and_grouped,
+        GroupByTest({table->GetColumnByName("argument0")},
+                    {table->GetColumnByName("key")}, {{"hash_one", nullptr}},
+                    /*use_threads=*/false, /*use_exec_plan=*/true));
+    ValidateOutput(aggregated_and_grouped);
+    SortBy({"key_0"}, &aggregated_and_grouped);
+
+    AssertDatumsEqual(
+        ArrayFromJSON(struct_({
+                          field("hash_one", ty),
+                          field("key_0", int64()),
+                      }),
+                      (ty->name() == "date64") ? date64_expected : 
default_expected),
+        aggregated_and_grouped,
+        /*verbose=*/true);
+  }
+}
+
+TEST(GroupBy, OneDecimal) {
+  auto in_schema = schema({
+      field("argument0", decimal128(3, 2)),
+      field("argument1", decimal256(3, 2)),
+      field("key", int64()),
+  });
+  for (bool use_exec_plan : {false, true}) {
+    for (bool use_threads : {/*true, */ false}) {

Review comment:
       The `true` option for `use_threads` is commented out in all these tests 
because they would then run in parallel, and the way `hash_one` is currently 
implemented, when it encounters a value (null or otherwise) for a group for the 
first time, it stores that value and never updates it.
   For example, in the table below, 
   - If we go serially, `["4.01", "4.01",   null]` is the first row encountered 
which has `null` as its key. But,
   - If we go in parallel, `["0.75",  "0.75",  null]` is the first row 
encountered which has `null` as its key.
   
   While this is totally expected, it is a bit tricky to test for in a 
single go. So should we test explicitly for both `use_threads` = `true` and 
`false`?

##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
##########
@@ -2460,6 +2460,476 @@ TEST(GroupBy, Distinct) {
   }
 }
 
+TEST(GroupBy, One) {
+  {
+    auto table =
+        TableFromJSON(schema({field("argument", int64()), field("key", 
int64())}), {R"([
+    [99,  1],
+    [99,  1]
+])",
+                                                                               
     R"([
+    [77,  2],
+    [null,   3],
+    [null,   3]
+])",
+                                                                               
     R"([
+    [null,   4],
+    [null,   4]
+])",
+                                                                               
   R"([
+    [88,  null],
+    [99,  3]
+])",
+                                                                               
   R"([
+    [77,  2],
+    [76, 2]
+])",
+                                                                               
   R"([
+    [75, null],
+    [74,  3]
+  ])",
+                                                                               
   R"([
+    [73,    null],
+    [72,    null]
+  ])"});
+
+  ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
+                       internal::GroupBy(
+                           {
+                               table->GetColumnByName("argument"),
+                           },
+                           {
+                               table->GetColumnByName("key"),
+                           },
+                           {
+                               {"hash_one", nullptr},
+                           },
+                           false));
+  ValidateOutput(aggregated_and_grouped);
+  SortBy({"key_0"}, &aggregated_and_grouped);
+
+  AssertDatumsEqual(ArrayFromJSON(struct_({
+                                      field("hash_one", int64()),
+                                      field("key_0", int64()),
+                                  }),
+                                  R"([
+      [99, 1],
+      [77, 2],
+      [null,  3],
+      [null,  4],
+      [88, null]
+    ])"),
+                    aggregated_and_grouped,
+                    /*verbose=*/true);
+  }
+  {
+    auto table =
+        TableFromJSON(schema({field("argument", utf8()), field("key", 
int64())}), {R"([
+     ["foo",  1],
+     ["foo",  1]
+ ])",
+                                                                               
    R"([
+     ["bar",  2],
+     [null,   3],
+     [null,   3]
+ ])",
+                                                                               
    R"([
+     [null,   4],
+     [null,   4]
+ ])",
+                                                                               
    R"([
+     ["baz",  null],
+     ["foo",  3]
+ ])",
+                                                                               
    R"([
+     ["bar",  2],
+     ["spam", 2]
+ ])",
+                                                                               
    R"([
+     ["eggs", null],
+     ["ham",  3]
+   ])",
+                                                                               
    R"([
+     ["a",    null],
+     ["b",    null]
+   ])"});
+
+    ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
+                         internal::GroupBy(
+                             {
+                                 table->GetColumnByName("argument"),
+                             },
+                             {
+                                 table->GetColumnByName("key"),
+                             },
+                             {
+                                 {"hash_one", nullptr},
+                             },
+                             false));
+    ValidateOutput(aggregated_and_grouped);
+    SortBy({"key_0"}, &aggregated_and_grouped);
+
+    AssertDatumsEqual(ArrayFromJSON(struct_({
+                                        field("hash_one", utf8()),
+                                        field("key_0", int64()),
+                                    }),
+                                    R"([
+       ["foo", 1],
+       ["bar", 2],
+       [null,  3],
+       [null,  4],
+       ["baz", null]
+     ])"),
+                      aggregated_and_grouped,
+                      /*verbose=*/true);
+  }
+}
+
+TEST(GroupBy, OneOnly) {
+  auto in_schema = schema({
+      field("argument0", float64()),
+      field("argument1", null()),
+      field("argument2", boolean()),
+      field("key", int64()),
+  });
+  for (bool use_exec_plan : {false, true}) {
+    for (bool use_threads : {false}) {
+      SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+      auto table = TableFromJSON(in_schema, {R"([
+    [1.0,   null, true, 1],
+    [null,  null, true, 1]
+])",
+                                             R"([
+    [0.0,   null, false, 2],
+    [null,  null, false, 3],
+    [4.0,   null, null,  null],
+    [3.25,  null, true,  1],
+    [0.125, null, false, 2]
+])",
+                                             R"([
+    [-0.25, null, false, 2],
+    [0.75,  null, true,  null],
+    [null,  null, true,  3]
+])"});
+
+      ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+                           GroupByTest(
+                               {
+                                   table->GetColumnByName("argument0"),
+                                   table->GetColumnByName("argument1"),
+                                   table->GetColumnByName("argument2"),
+                               },
+                               {table->GetColumnByName("key")},
+                               {
+                                   {"hash_one", nullptr},
+                                   {"hash_one", nullptr},
+                                   {"hash_one", nullptr},
+                               },
+                               use_threads, use_exec_plan));
+      ValidateOutput(aggregated_and_grouped);
+      SortBy({"key_0"}, &aggregated_and_grouped);
+
+      AssertDatumsEqual(ArrayFromJSON(struct_({
+                                          field("hash_one", float64()),
+                                          field("hash_one", null()),
+                                          field("hash_one", boolean()),
+                                          field("key_0", int64()),
+                                      }),
+                                      R"([
+    [1.0,  null, true,  1],
+    [0.0,  null, false, 2],
+    [null, null, false, 3],
+    [4.0,  null, null,  null]
+  ])"),
+                        aggregated_and_grouped,
+                        /*verbose=*/true);
+    }
+  }
+}
+
+TEST(GroupBy, OneTypes) {

Review comment:
       Most of these tests are carried over from the MinMax tests with 
appropriate modifications. Is that okay? Should more tests also be added?

##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate.cc
##########
@@ -2451,6 +2451,333 @@ Result<std::unique_ptr<KernelState>> 
GroupedDistinctInit(KernelContext* ctx,
   return std::move(impl);
 }
 
+// ----------------------------------------------------------------------
+// One implementation
+
+template <typename Type, typename Enable = void>
+struct GroupedOneImpl final : public GroupedAggregator {
+  using CType = typename TypeTraits<Type>::CType;
+  using GetSet = GroupedValueTraits<Type>;
+
+  Status Init(ExecContext* ctx, const std::vector<ValueDescr>&,
+              const FunctionOptions* options) override {
+    // out_type_ initialized by GroupedOneInit
+    ones_ = TypedBufferBuilder<CType>(ctx->memory_pool());
+    has_one_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+    has_value_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+    return Status::OK();
+  }
+
+  Status Resize(int64_t new_num_groups) override {
+    auto added_groups = new_num_groups - num_groups_;
+    num_groups_ = new_num_groups;
+    RETURN_NOT_OK(ones_.Append(added_groups, static_cast<CType>(0)));
+    RETURN_NOT_OK(has_one_.Append(added_groups, false));
+    RETURN_NOT_OK(has_value_.Append(added_groups, false));
+    return Status::OK();
+  }
+
+  Status Consume(const ExecBatch& batch) override {
+    auto raw_ones_ = ones_.mutable_data();
+
+    return VisitGroupedValues<Type>(
+        batch,
+        [&](uint32_t g, CType val) -> Status {
+          if (!bit_util::GetBit(has_one_.data(), g)) {
+            GetSet::Set(raw_ones_, g, val);
+            bit_util::SetBit(has_one_.mutable_data(), g);
+            bit_util::SetBit(has_value_.mutable_data(), g);
+          }
+          return Status::OK();
+        },
+        [&](uint32_t g) -> Status {
+          bit_util::SetBit(has_one_.mutable_data(), g);
+          return Status::OK();
+        });
+  }
+
+  Status Merge(GroupedAggregator&& raw_other,
+               const ArrayData& group_id_mapping) override {
+    auto other = checked_cast<GroupedOneImpl*>(&raw_other);
+
+    auto raw_ones = ones_.mutable_data();
+    auto other_raw_ones = other->ones_.mutable_data();
+
+    auto g = group_id_mapping.GetValues<uint32_t>(1);
+    for (uint32_t other_g = 0; static_cast<int64_t>(other_g) < 
group_id_mapping.length;
+         ++other_g, ++g) {
+      if (!bit_util::GetBit(has_one_.data(), *g)) {
+        if (bit_util::GetBit(other->has_value_.data(), other_g)) {
+          GetSet::Set(raw_ones, *g, GetSet::Get(other_raw_ones, other_g));
+          bit_util::SetBit(has_value_.mutable_data(), *g);
+        }
+        bit_util::SetBit(has_one_.mutable_data(), *g);
+      }
+    }
+
+    return Status::OK();
+  }
+
+  Result<Datum> Finalize() override {
+    ARROW_ASSIGN_OR_RAISE(auto null_bitmap, has_value_.Finish());
+    ARROW_ASSIGN_OR_RAISE(auto data, ones_.Finish());
+    return ArrayData::Make(out_type_, num_groups_,
+                           {std::move(null_bitmap), std::move(data)});
+  }
+
+  std::shared_ptr<DataType> out_type() const override { return out_type_; }
+
+  int64_t num_groups_;
+  TypedBufferBuilder<CType> ones_;
+  TypedBufferBuilder<bool> has_one_, has_value_;
+  std::shared_ptr<DataType> out_type_;
+};
+
+struct GroupedNullOneImpl : public GroupedAggregator {
+  Status Init(ExecContext* ctx, const std::vector<ValueDescr>&,
+              const FunctionOptions* options) override {
+    return Status::OK();
+  }
+
+  Status Resize(int64_t new_num_groups) override {
+    num_groups_ = new_num_groups;
+    return Status::OK();
+  }
+
+  Status Consume(const ExecBatch& batch) override { return Status::OK(); }
+
+  Status Merge(GroupedAggregator&& raw_other,
+               const ArrayData& group_id_mapping) override {
+    return Status::OK();
+  }
+
+  Result<Datum> Finalize() override {
+    return ArrayData::Make(null(), num_groups_, {nullptr}, num_groups_);
+  }
+
+  std::shared_ptr<DataType> out_type() const override { return null(); }
+
+  int64_t num_groups_;
+};
+
+template <typename Type>
+struct GroupedOneImpl<Type, enable_if_t<is_base_binary_type<Type>::value ||
+                                        std::is_same<Type, 
FixedSizeBinaryType>::value>>
+    final : public GroupedAggregator {
+  using Allocator = arrow::stl::allocator<char>;
+  using StringType = std::basic_string<char, std::char_traits<char>, 
Allocator>;
+
+  Status Init(ExecContext* ctx, const std::vector<ValueDescr>&,
+              const FunctionOptions* options) override {
+    ctx_ = ctx;
+    allocator_ = Allocator(ctx->memory_pool());
+    // out_type_ initialized by GroupedOneInit
+    has_value_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+    return Status::OK();
+  }
+
+  Status Resize(int64_t new_num_groups) override {
+    auto added_groups = new_num_groups - num_groups_;
+    DCHECK_GE(added_groups, 0);
+    num_groups_ = new_num_groups;
+    ones_.resize(new_num_groups);
+    RETURN_NOT_OK(has_one_.Append(added_groups, false));
+    RETURN_NOT_OK(has_value_.Append(added_groups, false));
+    return Status::OK();
+  }
+
+  Status Consume(const ExecBatch& batch) override {
+    return VisitGroupedValues<Type>(
+        batch,
+        [&](uint32_t g, util::string_view val) -> Status {
+          if (!bit_util::GetBit(has_one_.data(), g)) {
+            ones_[g].emplace(val.data(), val.size(), allocator_);
+            bit_util::SetBit(has_one_.mutable_data(), g);
+            bit_util::SetBit(has_value_.mutable_data(), g);
+          }
+          return Status::OK();
+        },
+        [&](uint32_t g) -> Status {
+          // as has_one_ is set, has_value_ will never be set, resulting in 
null
+          bit_util::SetBit(has_one_.mutable_data(), g);
+          return Status::OK();
+        });
+  }
+
+  Status Merge(GroupedAggregator&& raw_other,
+               const ArrayData& group_id_mapping) override {
+    auto other = checked_cast<GroupedOneImpl*>(&raw_other);
+    auto g = group_id_mapping.GetValues<uint32_t>(1);
+    for (uint32_t other_g = 0; static_cast<int64_t>(other_g) < 
group_id_mapping.length;
+         ++other_g, ++g) {
+      if (!bit_util::GetBit(has_one_.data(), *g)) {
+        if (bit_util::GetBit(other->has_value_.data(), other_g)) {
+          ones_[*g] = std::move(other->ones_[other_g]);
+          bit_util::SetBit(has_value_.mutable_data(), *g);
+        }
+        bit_util::SetBit(has_one_.mutable_data(), *g);
+      }
+    }
+    return Status::OK();
+  }
+
+  Result<Datum> Finalize() override {
+    ARROW_ASSIGN_OR_RAISE(auto null_bitmap, has_value_.Finish());
+    auto ones =
+        ArrayData::Make(out_type(), num_groups_, {std::move(null_bitmap), 
nullptr});
+    RETURN_NOT_OK(MakeOffsetsValues(ones.get(), ones_));
+    return ones;
+  }
+
+  template <typename T = Type>
+  enable_if_base_binary<T, Status> MakeOffsetsValues(

Review comment:
       These overloaded `MakeOffsetsValues` methods are identical between 
`GroupedOneImpl` and `GroupedMinMaxImpl`; should we consider making them helper 
functions?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to