pitrou commented on a change in pull request #11257:
URL: https://github.com/apache/arrow/pull/11257#discussion_r721311367
##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -121,6 +122,137 @@ Result<std::unique_ptr<KernelState>>
CountInit(KernelContext*,
static_cast<const CountOptions&>(*args.options));
}
+// ----------------------------------------------------------------------
+// Distinct Count implementation
+
+template <typename Type, typename VisitorArgType>
+struct CountDistinctImpl : public ScalarAggregator {
+ using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType;
+
+ explicit CountDistinctImpl(MemoryPool* memory_pool, CountOptions options)
+ : options(std::move(options)), memo_table_(new MemoTable(memory_pool,
0)) {}
+
+ Status Consume(KernelContext*, const ExecBatch& batch) override {
+ if (batch[0].is_array()) {
+ const ArrayData& arr = *batch[0].array();
+ auto visit_null = []() { return Status::OK(); };
+ auto visit_value = [&](VisitorArgType arg) {
+ int y;
+ RETURN_NOT_OK(memo_table_->GetOrInsert(arg, &y));
+ return Status::OK();
Review comment:
Nit: can be simplified to `return memo_table_->GetOrInsert(arg, &y);`
##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -121,6 +122,137 @@ Result<std::unique_ptr<KernelState>>
CountInit(KernelContext*,
static_cast<const CountOptions&>(*args.options));
}
+// ----------------------------------------------------------------------
+// Distinct Count implementation
+
+template <typename Type, typename VisitorArgType>
+struct CountDistinctImpl : public ScalarAggregator {
+ using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType;
+
+ explicit CountDistinctImpl(MemoryPool* memory_pool, CountOptions options)
+ : options(std::move(options)), memo_table_(new MemoTable(memory_pool,
0)) {}
+
+ Status Consume(KernelContext*, const ExecBatch& batch) override {
+ if (batch[0].is_array()) {
+ const ArrayData& arr = *batch[0].array();
+ auto visit_null = []() { return Status::OK(); };
+ auto visit_value = [&](VisitorArgType arg) {
+ int y;
+ RETURN_NOT_OK(memo_table_->GetOrInsert(arg, &y));
+ return Status::OK();
+ };
+ RETURN_NOT_OK(VisitArrayDataInline<Type>(arr, visit_value, visit_null));
+ this->non_nulls += memo_table_->size();
+ this->has_nulls = arr.GetNullCount() > 0;
+ } else {
+ const Scalar& input = *batch[0].scalar();
+ this->has_nulls = !input.is_valid;
+ this->non_nulls += batch.length;
Review comment:
Should `non_nulls` be incremented even if `input.is_valid` is false?
##########
File path: cpp/src/arrow/compute/kernels/aggregate_test.cc
##########
@@ -873,6 +874,98 @@ TYPED_TEST(TestRandomNumericCountKernel, RandomArrayCount)
{
}
}
+//
+// Count Distinct
+//
+
+class TestCountDistinctKernel : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ only_valid = CountOptions(CountOptions::ONLY_VALID);
+ only_null = CountOptions(CountOptions::ONLY_NULL);
+ all = CountOptions(CountOptions::ALL);
+ }
+
+ Datum Expected(int64_t value) { return
MakeScalar(static_cast<int64_t>(value)); }
+
+ void Check(Datum input, int64_t expected_all, bool has_nulls = true) {
+ int64_t expected_valid = has_nulls ? expected_all - 1 : expected_all;
+ int64_t expected_null = has_nulls ? 1 : 0;
+ CheckScalar("count_distinct", {input}, Expected(expected_valid),
&only_valid);
+ CheckScalar("count_distinct", {input}, Expected(expected_null),
&only_null);
+ CheckScalar("count_distinct", {input}, Expected(expected_all), &all);
+ }
+
+ void Check(const std::shared_ptr<DataType>& type, util::string_view json,
+ int64_t expected_all, bool has_nulls = true) {
+ Check(ArrayFromJSON(type, json), expected_all, has_nulls);
+ }
+
+ CountOptions only_valid;
+ CountOptions only_null;
+ CountOptions all;
+};
+
+TEST_F(TestCountDistinctKernel, AllArrowTypesWithNulls) {
+ // Boolean
+ Check(boolean(), "[true, null, false, null, false, true]", 3);
+ // Number
+ for (auto ty : NumericTypes()) {
+ Check(ty, "[1, 1, null, 2, 5, 8, 9, 9, null, 10, 6, 6]", 8);
+ Check(ty, "[1, 1, 8, 2, 5, 8, 9, 9, 10, 10, 6, 6]", 7, false);
+ }
+ // Date
+ Check(date32(), "[0, 11016, 0, null, 14241, 14241, null]", 4);
+ Check(date64(), "[0, null, 0, null, 0, 0, 1262217600000]", 3);
+ // Time
+ Check(time32(TimeUnit::SECOND), "[0, 11, 0, null, 14, 14, null]", 4);
+ Check(time32(TimeUnit::MILLI), "[0, 11000, 0, null, 11000, 11000]", 3);
+ Check(time64(TimeUnit::MICRO), "[84203999999, 0, null, 84203999999, 0]", 3);
+ Check(time64(TimeUnit::NANO), "[11715003000000, 0, null, 0, 0]", 3);
+ // Timestamp & Duration
+ for (auto u : TimeUnit::values()) {
+ Check(duration(u), "[123456789, null, 987654321, 123456789, null]", 3);
+ Check(duration(u), "[123456789, 987654321, 123456789, 123456789]", 2,
false);
+ auto ts = R"(["2009-12-31T04:20:20", "2020-01-01", null,
"2009-12-31T04:20:20"])";
+ Check(timestamp(u), ts, 3);
+ Check(timestamp(u, "Pacific/Marquesas"), ts, 3);
+ }
+ // Interval
+ Check(month_interval(), "[9012, 5678, null, 9012, 5678, null, 9012]", 3);
+ Check(day_time_interval(), "[[0, 1], [0, 1], null, [0, 1], [1234, 5678]]",
3);
+ Check(month_day_nano_interval(), "[[0, 1, 2], [0, 1, 2], null, [0, 1, 2]]",
2);
+ // Binary & String & Fixed binary
+ auto samples = R"([null, "abc", null, "abc", "abc", "cba", "bca", "cba",
null])";
+ Check(binary(), samples, 4);
+ Check(large_binary(), samples, 4);
+ Check(utf8(), samples, 4);
+ Check(large_utf8(), samples, 4);
+ Check(fixed_size_binary(3), samples, 4);
+ // Decimal
+ samples = R"(["12345.679", "98765.421", null, "12345.679", "98765.421"])";
+ Check(decimal(15, 3), samples, 3);
Review comment:
`decimal` is a deprecated alias for `decimal128`, so this is redundant.
##########
File path: r/src/compute.cpp
##########
@@ -208,7 +208,7 @@ std::shared_ptr<arrow::compute::FunctionOptions>
make_compute_options(
return out;
}
- if (func_name == "hash_count_distinct") {
+ if (func_name == "count_distinct" || func_name == "hash_count_distinct") {
Review comment:
Unrelated to this PR, but it's amusing that R tests for function names
even though the name of the options class is available as
`FunctionDoc::options_class`. @nealrichardson @jonkeane @thisisnic
##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -121,6 +122,137 @@ Result<std::unique_ptr<KernelState>>
CountInit(KernelContext*,
static_cast<const CountOptions&>(*args.options));
}
+// ----------------------------------------------------------------------
+// Distinct Count implementation
+
+template <typename Type, typename VisitorArgType>
+struct CountDistinctImpl : public ScalarAggregator {
+ using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType;
+
+ explicit CountDistinctImpl(MemoryPool* memory_pool, CountOptions options)
+ : options(std::move(options)), memo_table_(new MemoTable(memory_pool,
0)) {}
+
+ Status Consume(KernelContext*, const ExecBatch& batch) override {
+ if (batch[0].is_array()) {
+ const ArrayData& arr = *batch[0].array();
+ auto visit_null = []() { return Status::OK(); };
+ auto visit_value = [&](VisitorArgType arg) {
+ int y;
+ RETURN_NOT_OK(memo_table_->GetOrInsert(arg, &y));
+ return Status::OK();
+ };
+ RETURN_NOT_OK(VisitArrayDataInline<Type>(arr, visit_value, visit_null));
+ this->non_nulls += memo_table_->size();
+ this->has_nulls = arr.GetNullCount() > 0;
+ } else {
+ const Scalar& input = *batch[0].scalar();
+ this->has_nulls = !input.is_valid;
+ this->non_nulls += batch.length;
+ }
+ return Status::OK();
+ }
+
+ Status MergeFrom(KernelContext*, KernelState&& src) override {
+ const auto& other_state = checked_cast<const CountDistinctImpl&>(src);
+ this->non_nulls += other_state.non_nulls;
+ this->has_nulls = this->has_nulls || other_state.has_nulls;
+ return Status::OK();
+ }
+
+ Status Finalize(KernelContext* ctx, Datum* out) override {
+ const auto& state = checked_cast<const CountDistinctImpl&>(*ctx->state());
+ const int64_t nulls = state.has_nulls ? 1 : 0;
+ switch (state.options.mode) {
+ case CountOptions::ONLY_VALID:
+ *out = Datum(state.non_nulls);
+ break;
+ case CountOptions::ALL:
+ *out = Datum(state.non_nulls + nulls);
+ break;
+ case CountOptions::ONLY_NULL:
+ *out = Datum(nulls);
+ break;
+ default:
+ DCHECK(false) << "unreachable";
+ }
+ return Status::OK();
+ }
+
+ const CountOptions options;
+ int64_t non_nulls = 0;
+ bool has_nulls = false;
+ std::unique_ptr<MemoTable> memo_table_;
+};
+
+template <typename Type, typename VisitorArgType>
+Result<std::unique_ptr<KernelState>> CountDistinctInit(KernelContext* ctx,
+ const KernelInitArgs&
args) {
+ return ::arrow::internal::make_unique<CountDistinctImpl<Type,
VisitorArgType>>(
+ ctx->memory_pool(), static_cast<const CountOptions&>(*args.options));
+}
+
+template <typename Type, typename VisitorArgType>
+void AddCountDistinctKernel(InputType type, ScalarAggregateFunction* func) {
+ AddAggKernel(KernelSignature::Make({type}, ValueDescr::Scalar(int64())),
+ aggregate::CountDistinctInit<Type, VisitorArgType>, func);
+}
+
+template <typename Type, typename VisitorArgType = typename Type::c_type>
+struct CountDistinctKernel {
+ static void Add(InputType type, ScalarAggregateFunction* func) {
+ using PhysicalType = typename Type::PhysicalType;
+ AddCountDistinctKernel<PhysicalType, VisitorArgType>(type, func);
+ }
+};
+
+template <>
+struct CountDistinctKernel<FixedSizeBinaryType, util::string_view> {
+ static void Add(InputType type, ScalarAggregateFunction* func) {
+ AddCountDistinctKernel<FixedSizeBinaryType, util::string_view>(type, func);
+ }
+};
+
+void AddCountDistinctKernels(ScalarAggregateFunction* func) {
+ // Boolean
+ aggregate::CountDistinctKernel<BooleanType>::Add(boolean(), func);
+ // Number
+ aggregate::CountDistinctKernel<Int8Type>::Add(int8(), func);
+ aggregate::CountDistinctKernel<Int16Type>::Add(int16(), func);
+ aggregate::CountDistinctKernel<Int32Type>::Add(int32(), func);
+ aggregate::CountDistinctKernel<Int64Type>::Add(int64(), func);
+ aggregate::CountDistinctKernel<UInt8Type>::Add(uint8(), func);
+ aggregate::CountDistinctKernel<UInt16Type>::Add(uint16(), func);
+ aggregate::CountDistinctKernel<UInt32Type>::Add(uint32(), func);
+ aggregate::CountDistinctKernel<UInt64Type>::Add(uint64(), func);
+ aggregate::CountDistinctKernel<HalfFloatType>::Add(float16(), func);
+ aggregate::CountDistinctKernel<FloatType>::Add(float32(), func);
+ aggregate::CountDistinctKernel<DoubleType>::Add(float64(), func);
+ // Date
+ aggregate::CountDistinctKernel<Date32Type>::Add(date32(), func);
+ aggregate::CountDistinctKernel<Date64Type>::Add(date64(), func);
+ // Time
+
aggregate::CountDistinctKernel<Time32Type>::Add(match::SameTypeId(Type::TIME32),
func);
+
aggregate::CountDistinctKernel<Time64Type>::Add(match::SameTypeId(Type::TIME64),
func);
+ // Timestamp & Duration
+
aggregate::CountDistinctKernel<TimestampType>::Add(match::SameTypeId(Type::TIMESTAMP),
+ func);
+
aggregate::CountDistinctKernel<DurationType>::Add(match::SameTypeId(Type::DURATION),
+ func);
+ // Interval
+ aggregate::CountDistinctKernel<MonthIntervalType>::Add(month_interval(),
func);
+
aggregate::CountDistinctKernel<DayTimeIntervalType>::Add(day_time_interval(),
func);
+
aggregate::CountDistinctKernel<MonthDayNanoIntervalType>::Add(month_day_nano_interval(),
+ func);
+ // Binary & String
+ aggregate::CountDistinctKernel<BinaryType,
util::string_view>::Add(match::BinaryLike(),
+ func);
+ aggregate::CountDistinctKernel<LargeBinaryType, util::string_view>::Add(
+ match::LargeBinaryLike(), func);
+ // Fixed binary & Decimal
+ aggregate::CountDistinctKernel<FixedSizeBinaryType, util::string_view>::Add(
+ match::FixedSizeBinaryLike(), func);
Review comment:
It would be nice to avoid enumerating types explicitly. Can you take a
look at the helpers exposed by `codegen_internal.h`, for example
`GeneratePhysicalNumeric` and `GenerateTypeAgnosticVarBinaryBase`? One might also
add a `GenerateTypeAgnosticFixedBinaryBase`.
##########
File path: cpp/src/arrow/compute/kernels/aggregate_test.cc
##########
@@ -873,6 +874,98 @@ TYPED_TEST(TestRandomNumericCountKernel, RandomArrayCount)
{
}
}
+//
+// Count Distinct
+//
+
+class TestCountDistinctKernel : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ only_valid = CountOptions(CountOptions::ONLY_VALID);
+ only_null = CountOptions(CountOptions::ONLY_NULL);
+ all = CountOptions(CountOptions::ALL);
+ }
+
+ Datum Expected(int64_t value) { return
MakeScalar(static_cast<int64_t>(value)); }
+
+ void Check(Datum input, int64_t expected_all, bool has_nulls = true) {
+ int64_t expected_valid = has_nulls ? expected_all - 1 : expected_all;
+ int64_t expected_null = has_nulls ? 1 : 0;
+ CheckScalar("count_distinct", {input}, Expected(expected_valid),
&only_valid);
+ CheckScalar("count_distinct", {input}, Expected(expected_null),
&only_null);
+ CheckScalar("count_distinct", {input}, Expected(expected_all), &all);
+ }
+
+ void Check(const std::shared_ptr<DataType>& type, util::string_view json,
+ int64_t expected_all, bool has_nulls = true) {
+ Check(ArrayFromJSON(type, json), expected_all, has_nulls);
+ }
+
+ CountOptions only_valid;
+ CountOptions only_null;
+ CountOptions all;
+};
+
+TEST_F(TestCountDistinctKernel, AllArrowTypesWithNulls) {
Review comment:
Can you add a test for scalars?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]