icexelloss commented on code in PR #36253:
URL: https://github.com/apache/arrow/pull/36253#discussion_r1242545830
##########
python/pyarrow/src/arrow/python/udf.cc:
##########
@@ -215,9 +249,162 @@ struct PythonUdfScalarAggregatorImpl : public
ScalarUdfAggregator {
return Status::OK();
}
- UdfWrapperCallback agg_cb;
+ std::shared_ptr<OwnedRefNoGIL> function;
+ UdfWrapperCallback cb;
+ std::vector<std::shared_ptr<RecordBatch>> values;
+ std::shared_ptr<Schema> input_schema;
+ std::shared_ptr<DataType> output_type;
+};
+
+struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
+  /// @brief Construct a hash aggregator for a Python UDF.
+  ///
+  /// @param function the Python callable invoked (through `cb`) at Finalize
+  /// @param cb wrapper used to call `function`
+  /// @param input_types argument types followed by the group-id type; the last
+  ///                    entry is the group id column
+  /// @param output_type type of the value produced for each group
+  PythonUdfHashAggregatorImpl(std::shared_ptr<OwnedRefNoGIL> function,
+                              UdfWrapperCallback cb,
+                              std::vector<std::shared_ptr<DataType>> input_types,
+                              std::shared_ptr<DataType> output_type)
+      : function(std::move(function)),
+        cb(std::move(cb)),
+        output_type(std::move(output_type)) {
+    // The `function` parameter has been moved into the member, so use the member.
+    // NOTE(review): Py_INCREF needs the GIL and adds a reference on top of the
+    // one OwnedRefNoGIL holds -- confirm the GIL is held here and that this
+    // extra reference is released somewhere.
+    Py_INCREF(this->function->obj());
+    std::vector<std::shared_ptr<Field>> fields;
+    fields.reserve(input_types.size());
+    for (auto& type : input_types) {
+      fields.push_back(field("", std::move(type)));
+    }
+    input_schema = schema(std::move(fields));
+  }
+
+  /// @brief Destroy the aggregator, taking care during interpreter shutdown.
+  ~PythonUdfHashAggregatorImpl() override {
+    // Outside of interpreter finalization nothing special is needed.
+    if (!_Py_IsFinalizing()) {
+      return;
+    }
+    // The interpreter is shutting down: detach the owned reference rather than
+    // releasing it normally.
+    // NOTE(review): presumably detach() drops ownership without a Py_DECREF;
+    // confirm against OwnedRefNoGIL. _Py_IsFinalizing is a private CPython
+    // API; newer Python versions expose Py_IsFinalizing -- check compatibility.
+    function->detach();
+  }
+
+  /// @brief Split `batch` into one RecordBatch per group.
+  ///
+  /// Same as ApplyGroupings in partition.cc; replicated here to avoid
+  /// complicating the dependencies.
+  ///
+  /// @param groupings list array whose element i holds the row indices of
+  ///                  `batch` belonging to group i
+  /// @param batch the rows to split
+  /// @return one slice per group, in group order
+  static Result<RecordBatchVector> ApplyGroupings(
+      const ListArray& groupings, const std::shared_ptr<RecordBatch>& batch) {
+    // The flattened list values are row indices; taking them reorders the
+    // batch so that each group's rows become contiguous.
+    const auto& row_indices = groupings.data()->child_data[0];
+    ARROW_ASSIGN_OR_RAISE(Datum reordered, compute::Take(batch, row_indices));
+    const RecordBatch& reordered_batch = *reordered.record_batch();
+
+    // Group g occupies the same offset/length range as list element g.
+    const int64_t group_count = groupings.length();
+    RecordBatchVector per_group;
+    per_group.reserve(static_cast<size_t>(group_count));
+    for (int64_t g = 0; g < group_count; ++g) {
+      per_group.push_back(
+          reordered_batch.Slice(groupings.value_offset(g), groupings.value_length(g)));
+    }
+    return per_group;
+  }
+
+  /// @brief Record the new number of groups.
+  ///
+  /// Per-row group ids and input batches are buffered by Consume/Merge, so no
+  /// per-group storage has to grow here (as in other hash aggregate kernels).
+  Status Resize(KernelContext* ctx, int64_t new_num_groups) {
+    this->num_groups = new_num_groups;
+    return Status::OK();
+  }
+
+  /// @brief Buffer one input batch together with the group id of each row.
+  ///
+  /// The whole batch (including the trailing group-id column) is kept as a
+  /// RecordBatch in `values`; the group ids are appended to `groups`.
+  /// Modeled on GroupedListImpl.
+  Status Consume(KernelContext* ctx, const ExecSpan& batch) {
+    ARROW_ASSIGN_OR_RAISE(
+        std::shared_ptr<RecordBatch> rb,
+        batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool()));
+
+    // The last value in the span is the group-id array.
+    const ArraySpan& group_ids = batch[batch.num_values() - 1].array;
+    const int64_t num_rows = group_ids.length;
+    // GetValues<T>(1) applies group_ids.offset, so a sliced group-id array is
+    // read correctly; reading at absolute offset 0 was only guarded by a DCHECK.
+    const uint32_t* row_groups = group_ids.GetValues<uint32_t>(1);
+    RETURN_NOT_OK(groups.Append(row_groups, num_rows));
+    values.push_back(std::move(rb));
+    num_values += num_rows;
+    return Status::OK();
+  }
+  /// @brief Absorb the buffered batches and group ids of another state.
+  ///
+  /// Modeled on GroupedListImpl.
+  ///
+  /// @param other_state another PythonUdfHashAggregatorImpl; its batches are
+  ///                    moved out
+  /// @param group_id_mapping uint32 array translating the other state's group
+  ///                         ids into this state's group ids
+  Status Merge(KernelContext* ctx, KernelState&& other_state,
+               const ArrayData& group_id_mapping) {
+    auto& other = checked_cast<PythonUdfHashAggregatorImpl&>(other_state);
+    values.insert(values.end(), std::make_move_iterator(other.values.begin()),
+                  std::make_move_iterator(other.values.end()));
+
+    // States number their groups independently, so each of the other state's
+    // per-row group ids is translated through the mapping.
+    const uint32_t* other_row_groups = other.groups.data();
+    const uint32_t* id_mapping = group_id_mapping.GetValues<uint32_t>(1);
+    // Use an int64_t row index: a uint32_t counter wraps after 2^32 - 1 and
+    // would never reach a larger row count, looping forever.
+    for (int64_t row = 0; row < other.num_values; ++row) {
+      RETURN_NOT_OK(groups.Append(id_mapping[other_row_groups[row]]));
+    }
+
+    num_values += other.num_values;
+    return Status::OK();
+  }
+
+ Status Finalize(KernelContext* ctx, Datum* out) {
+ // Exclude the last column which is the group id
+ const int num_args = input_schema->num_fields() - 1;
+
+ ARROW_ASSIGN_OR_RAISE(auto groups_buffer, groups.Finish());
+ ARROW_ASSIGN_OR_RAISE(auto groupings,
+ Grouper::MakeGroupings(UInt32Array(num_values,
groups_buffer),
+
static_cast<uint32_t>(num_groups)));
+
+ ARROW_ASSIGN_OR_RAISE(auto table,
+ arrow::Table::FromRecordBatches(input_schema,
values));
+ ARROW_ASSIGN_OR_RAISE(auto rb,
table->CombineChunksToBatch(ctx->memory_pool()));
Review Comment:
Added note in the user doc:
https://github.com/apache/arrow/pull/36253/files#diff-439f91c435cc8136d40eaba8c168aaa1cec00f08b10ba92ce376e47f29f81814R2770
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]