westonpace commented on code in PR #36748:
URL: https://github.com/apache/arrow/pull/36748#discussion_r1272489252
##########
python/pyarrow/table.pxi:
##########
@@ -5307,8 +5307,14 @@ list[tuple(str, str, FunctionOptions)]
         # Ensure aggregate function is hash_ if needed
         if len(self.keys) > 0 and not func.startswith("hash_"):
             func = "hash_" + func
+        import pyarrow.compute as pc
         if len(self.keys) == 0 and func.startswith("hash_"):
-            func = func[5:]
+            scalar_func = func[5:]
+            try:
+                pc.get_function(scalar_func)
+                func = scalar_func
+            except:
+                pass
Review Comment:
This changes the failure mode to an error raised from C++. What does that error look like?
##########
python/pyarrow/table.pxi:
##########
@@ -5307,8 +5307,14 @@ list[tuple(str, str, FunctionOptions)]
         # Ensure aggregate function is hash_ if needed
         if len(self.keys) > 0 and not func.startswith("hash_"):
             func = "hash_" + func
+        import pyarrow.compute as pc
Review Comment:
Let's use `_pac()`
##########
cpp/src/arrow/compute/kernels/aggregate_basic.cc:
##########
@@ -150,6 +157,85 @@ Result<std::unique_ptr<KernelState>> CountInit(KernelContext*,
   return std::make_unique<CountImpl>(static_cast<const CountOptions&>(*args.options));
 }
+// ----------------------------------------------------------------------
+// Distinct implementations
+
+struct DistinctImpl : public ScalarAggregator {
+  Status Consume(KernelContext* ctx, const ExecSpan& batch) override {
+    if (batch[0].is_array()) {
+      const ArraySpan& input = batch[0].array;
+      this->arrays.push_back(input.ToArray());
+    } else {
+      const Scalar& input = *batch[0].scalar;
+      std::shared_ptr<arrow::Array> scalar_array;
+      ARROW_ASSIGN_OR_RAISE(scalar_array,
+                            arrow::MakeArrayFromScalar(input, 1, ctx->memory_pool()));
+      this->arrays.push_back(scalar_array);
+    }
+    return Status::OK();
+  }
Review Comment:
This approach is "ok" but I think we can do better. It will accumulate the
entire array in memory. We could instead, compute distinct items here and only
add those that aren't already in the set.
We could use a memo table. In consume we could update the memo table with
the new values. In the end we just dump out the memo table.
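Roughly this shape (a minimal standalone sketch of the idea only; `DistinctSketch` and its methods are hypothetical names mirroring the aggregator lifecycle, and `std::unordered_set` stands in for the typed memo tables in `arrow/util/hashing.h`):
```cpp
// Illustrative sketch only: DistinctSketch is a hypothetical standalone class;
// std::unordered_set stands in for Arrow's typed memo tables.
#include <cstdint>
#include <iostream>
#include <unordered_set>
#include <vector>

struct DistinctSketch {
  // Consume: fold each batch into the set right away, so memory is bounded
  // by the number of distinct values rather than the total input size.
  void Consume(const std::vector<int64_t>& batch) {
    seen.insert(batch.begin(), batch.end());
  }

  // MergeFrom: merging two partial states is just a set union.
  void MergeFrom(const DistinctSketch& other) {
    seen.insert(other.seen.begin(), other.seen.end());
  }

  // Finalize: dump out the accumulated distinct values.
  std::vector<int64_t> Finalize() const {
    return {seen.begin(), seen.end()};
  }

  std::unordered_set<int64_t> seen;
};

int main() {
  DistinctSketch agg;
  agg.Consume({1, 2, 2, 3});
  agg.Consume({3, 3, 4});
  for (int64_t v : agg.Finalize()) std::cout << v << " ";  // 1 2 3 4 (any order)
}
```
A memo table gives the same set semantics while also remembering insertion order and handling Arrow types and nulls.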
##########
python/pyarrow/tests/test_acero.py:
##########
@@ -196,6 +196,11 @@ def test_aggregate_scalar(table_source):
     with pytest.raises(ValueError, match="is a hash aggregate function"):
         _ = decl.to_table()
+    aggr_opts = AggregateNodeOptions([("a", "hash_list", None, "a_list")])
+    decl = Declaration.from_sequence(
+        [table_source, Declaration("aggregate", aggr_opts)]
+    )
Review Comment:
What is being tested here?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]