pitrou commented on code in PR #45562:
URL: https://github.com/apache/arrow/pull/45562#discussion_r1977602793
##########
cpp/src/arrow/compute/kernels/hash_aggregate.cc:
##########
@@ -3319,9 +3324,401 @@ struct GroupedListFactory {
HashAggregateKernel kernel;
InputType argument_type;
};
-} // namespace
-namespace {
+// ----------------------------------------------------------------------
+// Pivot implementation
+
+struct GroupedPivotAccumulator {
+ Status Init(ExecContext* ctx, std::shared_ptr<DataType> value_type,
+ const PivotWiderOptions* options) {
+ ctx_ = ctx;
+ value_type_ = std::move(value_type);
+ num_keys_ = static_cast<int>(options->key_names.size());
+ num_groups_ = 0;
+ columns_.resize(num_keys_);
+ scratch_buffer_ = BufferBuilder(ctx_->memory_pool());
+ return Status::OK();
+ }
+
+ Status Consume(span<const uint32_t> groups, span<const PivotWiderKeyIndex>
keys,
+ const ArraySpan& values) {
+ // To dispatch the values into the right (group, key) coordinates,
+ // we first compute a vector of take indices for each output column.
+ //
+ // For each index #i, we set take_indices[keys[#i]][groups[#i]] = #i.
+ // Unpopulated take_indices entries are null.
+ //
+ // For example, assuming we get:
+ // groups | keys
+ // ===================
+ // 1 | 0
+ // 3 | 1
+ // 1 | 1
+ // 0 | 1
+ //
+ // We are going to compute:
+ // - take_indices[key = 0] = [null, 0, null, null]
+ // - take_indices[key = 1] = [3, 2, null, 1]
+ //
+ // Then each output column is computed by taking the values with the
+ // respective take_indices for the column's keys.
+ //
+
+ DCHECK_EQ(groups.size(), keys.size());
+ DCHECK_EQ(groups.size(), static_cast<size_t>(values.length));
+
+ std::shared_ptr<DataType> take_index_type;
+ std::vector<std::shared_ptr<Buffer>> take_indices(num_keys_);
+ std::vector<std::shared_ptr<Buffer>> take_bitmaps(num_keys_);
+
+ // A generic lambda that computes the take indices with the desired
integer width
+ auto compute_take_indices = [&](auto typed_index) {
+ ARROW_UNUSED(typed_index);
+ using TakeIndex = std::decay_t<decltype(typed_index)>;
+ take_index_type = CTypeTraits<TakeIndex>::type_singleton();
+
+ const int64_t take_indices_size =
+ bit_util::RoundUpToMultipleOf64(num_groups_ * sizeof(TakeIndex));
+ const int64_t take_bitmap_size =
+ bit_util::RoundUpToMultipleOf64(bit_util::BytesForBits(num_groups_));
+ const int64_t total_scratch_size =
+ num_keys_ * (take_indices_size + take_bitmap_size);
+ RETURN_NOT_OK(scratch_buffer_.Resize(total_scratch_size,
/*shrink_to_fit=*/false));
+
+ // Slice the scratch space into individual buffers for each output
column's
+ // take_indices array.
+ std::vector<TakeIndex*> take_indices_data(num_keys_);
+ std::vector<uint8_t*> take_bitmap_data(num_keys_);
+ int64_t offset = 0;
+ for (int i = 0; i < num_keys_; ++i) {
+ take_indices[i] = std::make_shared<MutableBuffer>(
+ scratch_buffer_.mutable_data() + offset, take_indices_size);
+ take_indices_data[i] = take_indices[i]->mutable_data_as<TakeIndex>();
+ offset += take_indices_size;
+ take_bitmaps[i] = std::make_shared<MutableBuffer>(
+ scratch_buffer_.mutable_data() + offset, take_bitmap_size);
+ take_bitmap_data[i] = take_bitmaps[i]->mutable_data();
+ memset(take_bitmap_data[i], 0, take_bitmap_size);
+ offset += take_bitmap_size;
+ }
+ DCHECK_LE(offset, scratch_buffer_.capacity());
+
+ // Populate the take_indices for each output column
+ for (int64_t i = 0; i < values.length; ++i) {
+ const PivotWiderKeyIndex key = keys[i];
+ if (key != kNullPivotKey && !values.IsNull(i)) {
+ DCHECK_LT(static_cast<int>(key), num_keys_);
+ const uint32_t group = groups[i];
+ if (bit_util::GetBit(take_bitmap_data[key], group)) {
+ return DuplicateValue();
+ }
+ // For row #group in column #key, we are going to take the value at
index #i
+ bit_util::SetBit(take_bitmap_data[key], group);
+ take_indices_data[key][group] = static_cast<TakeIndex>(i);
+ }
+ }
+ return Status::OK();
+ };
+
+ // Call compute_take_indices with the optimal integer width
+ if (values.length <=
static_cast<int64_t>(std::numeric_limits<uint8_t>::max())) {
+ RETURN_NOT_OK(compute_take_indices(uint8_t{}));
+ } else if (values.length <=
+ static_cast<int64_t>(std::numeric_limits<uint16_t>::max())) {
+ RETURN_NOT_OK(compute_take_indices(uint16_t{}));
+ } else if (values.length <=
+ static_cast<int64_t>(std::numeric_limits<uint32_t>::max())) {
+ RETURN_NOT_OK(compute_take_indices(uint32_t{}));
+ } else {
+ RETURN_NOT_OK(compute_take_indices(uint64_t{}));
+ }
+
+ // Use take_indices to compute the output columns for this batch
+ auto values_data = values.ToArrayData();
+ ArrayVector new_columns(num_keys_);
+ for (int i = 0; i < num_keys_; ++i) {
+ auto indices_data =
+ ArrayData::Make(take_index_type, num_groups_,
+ {std::move(take_bitmaps[i]),
std::move(take_indices[i])});
+ // If indices_data is all nulls, we can just ignore this column.
+ if (indices_data->GetNullCount() != indices_data->length) {
+ ARROW_ASSIGN_OR_RAISE(Datum grouped_column, Take(values_data,
indices_data,
+
TakeOptions::Defaults(), ctx_));
+ new_columns[i] = grouped_column.make_array();
+ }
+ }
+ // Merge them with the previous columns
+ return MergeColumns(std::move(new_columns));
+ }
+
+ Status Consume(span<const uint32_t> groups, const PivotWiderKeyIndex key,
+ const ArraySpan& values) {
+ if (key == kNullPivotKey) {
+ // Nothing to update
+ return Status::OK();
+ }
+ DCHECK_LT(static_cast<int>(key), num_keys_);
+ DCHECK_EQ(groups.size(), static_cast<size_t>(values.length));
+
+    // The algorithm is simpler than in the array-taking version of Consume()
+    // above, since only the column #key needs to be updated.
+ std::shared_ptr<DataType> take_index_type;
+ std::shared_ptr<Buffer> take_indices;
+ std::shared_ptr<Buffer> take_bitmap;
+
+ // A generic lambda that computes the take indices with the desired
integer width
+ auto compute_take_indices = [&](auto typed_index) {
+ ARROW_UNUSED(typed_index);
+ using TakeIndex = std::decay_t<decltype(typed_index)>;
+ take_index_type = CTypeTraits<TakeIndex>::type_singleton();
+
+ const int64_t take_indices_size =
+ bit_util::RoundUpToMultipleOf64(num_groups_ * sizeof(TakeIndex));
+ const int64_t take_bitmap_size =
+ bit_util::RoundUpToMultipleOf64(bit_util::BytesForBits(num_groups_));
+ const int64_t total_scratch_size = take_indices_size + take_bitmap_size;
+ RETURN_NOT_OK(scratch_buffer_.Resize(total_scratch_size,
/*shrink_to_fit=*/false));
+
+ take_indices =
std::make_shared<MutableBuffer>(scratch_buffer_.mutable_data(),
+ take_indices_size);
+ take_bitmap = std::make_shared<MutableBuffer>(
+ scratch_buffer_.mutable_data() + take_indices_size,
take_bitmap_size);
+ auto take_indices_data = take_indices->mutable_data_as<TakeIndex>();
+ auto take_bitmap_data = take_bitmap->mutable_data();
+ memset(take_bitmap_data, 0, take_bitmap_size);
+
+ for (int64_t i = 0; i < values.length; ++i) {
+ const uint32_t group = groups[i];
+ if (!values.IsNull(i)) {
+ if (bit_util::GetBit(take_bitmap_data, group)) {
+ return DuplicateValue();
+ }
+ bit_util::SetBit(take_bitmap_data, group);
+ take_indices_data[group] = static_cast<TakeIndex>(i);
+ }
+ }
+ return Status::OK();
+ };
+
+ // Call compute_take_indices with the optimal integer width
+ if (values.length <=
static_cast<int64_t>(std::numeric_limits<uint8_t>::max())) {
+ RETURN_NOT_OK(compute_take_indices(uint8_t{}));
+ } else if (values.length <=
+ static_cast<int64_t>(std::numeric_limits<uint16_t>::max())) {
+ RETURN_NOT_OK(compute_take_indices(uint16_t{}));
+ } else if (values.length <=
+ static_cast<int64_t>(std::numeric_limits<uint32_t>::max())) {
+ RETURN_NOT_OK(compute_take_indices(uint32_t{}));
+ } else {
+ RETURN_NOT_OK(compute_take_indices(uint64_t{}));
+ }
+
+ // Use take_indices to update column #key
+ auto values_data = values.ToArrayData();
+ auto indices_data = ArrayData::Make(
+ take_index_type, num_groups_, {std::move(take_bitmap),
std::move(take_indices)});
+ ARROW_ASSIGN_OR_RAISE(Datum grouped_column,
+ Take(values_data, indices_data,
TakeOptions::Defaults(), ctx_));
+ return MergeColumn(&columns_[key], grouped_column.make_array());
+ }
+
+ Status Resize(int64_t new_num_groups) {
+ if (new_num_groups > std::numeric_limits<int32_t>::max()) {
+      return Status::NotImplemented("Pivot with more than 2**31 groups");
+ }
+ return ResizeColumns(new_num_groups);
+ }
+
+ Status Merge(GroupedPivotAccumulator&& other, const ArrayData&
group_id_mapping) {
+ // To merge `other` into `*this`, we simply merge their respective columns.
+ // However, we must first transpose `other`'s rows using
`group_id_mapping`.
+ // This is a logical "scatter" operation.
+ //
+ // Since `scatter(indices)` is implemented as
`take(inverse_permutation(indices))`,
+ // we can save time by computing `inverse_permutation(indices)` once for
all
+ // columns.
+
+ // Scatter/InversePermutation only accept signed indices. We checked
+    // in Resize() above that we were inside the limits for int32.
+ auto scatter_indices = group_id_mapping.Copy();
+ scatter_indices->type = int32();
+ std::shared_ptr<DataType> take_indices_type;
+ if (num_groups_ - 1 <= std::numeric_limits<int8_t>::max()) {
+ take_indices_type = int8();
+ } else if (num_groups_ - 1 <= std::numeric_limits<int16_t>::max()) {
+ take_indices_type = int16();
+ } else {
+      DCHECK_LE(num_groups_ - 1, std::numeric_limits<int32_t>::max());
+ take_indices_type = int32();
+ }
+ InversePermutationOptions options(/*max_index=*/num_groups_ - 1,
take_indices_type);
+ ARROW_ASSIGN_OR_RAISE(auto take_indices,
+ InversePermutation(scatter_indices, options, ctx_));
+ auto scatter_column =
+ [&](const std::shared_ptr<Array>& column) ->
Result<std::shared_ptr<Array>> {
+ ARROW_ASSIGN_OR_RAISE(auto scattered,
+ CallFunction("take", {column, take_indices},
&options, ctx_));
+ return scattered.make_array();
+ };
+ return MergeColumns(std::move(other.columns_), std::move(scatter_column));
+ }
+
+ Result<ArrayVector> Finalize() {
+ // Ensure that columns are allocated even if num_groups_ == 0
+ RETURN_NOT_OK(ResizeColumns(num_groups_));
+ return std::move(columns_);
+ }
+
+ protected:
+ Status ResizeColumns(int64_t new_num_groups) {
+ if (new_num_groups == num_groups_ && num_groups_ != 0) {
+ return Status::OK();
+ }
+ ARROW_ASSIGN_OR_RAISE(
+ auto array_suffix,
+ MakeArrayOfNull(value_type_, new_num_groups - num_groups_,
ctx_->memory_pool()));
+ for (auto& column : columns_) {
+ if (num_groups_ != 0) {
+ DCHECK_NE(column, nullptr);
+ ARROW_ASSIGN_OR_RAISE(
+ column, Concatenate({std::move(column), array_suffix},
ctx_->memory_pool()));
+ } else {
+ column = array_suffix;
+ }
+ DCHECK_EQ(column->length(), new_num_groups);
+ }
+ num_groups_ = new_num_groups;
+ return Status::OK();
+ }
+
+ using ColumnTransform =
+ std::function<Result<std::shared_ptr<Array>>(const
std::shared_ptr<Array>&)>;
+
+ Status MergeColumns(ArrayVector&& other_columns,
+ const ColumnTransform& transform = {}) {
+ DCHECK_EQ(columns_.size(), other_columns.size());
+ for (int i = 0; i < num_keys_; ++i) {
+ if (other_columns[i]) {
+ RETURN_NOT_OK(MergeColumn(&columns_[i], std::move(other_columns[i]),
transform));
+ }
+ }
+ return Status::OK();
+ }
+
+ Status MergeColumn(std::shared_ptr<Array>* column, std::shared_ptr<Array>
other_column,
+ const ColumnTransform& transform = {}) {
+ if (other_column->null_count() == other_column->length()) {
+ // Avoid paying for the transform step below, since merging will be a
no-op anyway.
+ return Status::OK();
+ }
+ if (transform) {
+ ARROW_ASSIGN_OR_RAISE(other_column, transform(other_column));
+ }
+ DCHECK_EQ(num_groups_, other_column->length());
+ if (!*column) {
+ *column = other_column;
+ return Status::OK();
+ }
+ if ((*column)->null_count() == (*column)->length()) {
+ *column = other_column;
+ return Status::OK();
+ }
+ int64_t expected_non_nulls = (num_groups_ - (*column)->null_count()) +
+ (num_groups_ - other_column->null_count());
+ ARROW_ASSIGN_OR_RAISE(auto coalesced,
+ CallFunction("coalesce", {*column, other_column},
ctx_));
+ // Check that all non-null values in other_column and column were kept in
the result.
+ if (expected_non_nulls != num_groups_ - coalesced.null_count()) {
+ DCHECK_GT(expected_non_nulls, num_groups_ - coalesced.null_count());
+ return DuplicateValue();
+ }
+ *column = coalesced.make_array();
+ return Status::OK();
+ }
+
+ Status DuplicateValue() {
+ return Status::Invalid(
+ "Encountered more than one non-null value for the same grouped pivot
key");
+ }
+
+ ExecContext* ctx_;
+ std::shared_ptr<DataType> value_type_;
+ int num_keys_;
+ int64_t num_groups_;
+ ArrayVector columns_;
+ // A persistent scratch buffer to store the take indices in Consume
+ BufferBuilder scratch_buffer_;
+};
+
+struct GroupedPivotImpl : public GroupedAggregator {
+ Status Init(ExecContext* ctx, const KernelInitArgs& args) override {
+ DCHECK_EQ(args.inputs.size(), 3);
+ key_type_ = args.inputs[0].GetSharedPtr();
+ options_ = checked_cast<const PivotWiderOptions*>(args.options);
+ DCHECK_NE(options_, nullptr);
+ auto value_type = args.inputs[1].GetSharedPtr();
+ FieldVector fields;
+ fields.reserve(options_->key_names.size());
+ for (const auto& key_name : options_->key_names) {
+ fields.push_back(field(key_name, value_type));
+ }
+ out_type_ = struct_(std::move(fields));
+ out_struct_type_ = checked_cast<const StructType*>(out_type_.get());
+ ARROW_ASSIGN_OR_RAISE(key_mapper_, PivotWiderKeyMapper::Make(*key_type_,
options_));
+ RETURN_NOT_OK(accumulator_.Init(ctx, value_type, options_));
+ return Status::OK();
+ }
+
+ Status Resize(int64_t new_num_groups) override {
+ num_groups_ = new_num_groups;
+ return accumulator_.Resize(new_num_groups);
+ }
+
+ Status Merge(GroupedAggregator&& raw_other,
+ const ArrayData& group_id_mapping) override {
+ auto other = checked_cast<GroupedPivotImpl*>(&raw_other);
+ return accumulator_.Merge(std::move(other->accumulator_),
group_id_mapping);
+ }
+
+ Status Consume(const ExecSpan& batch) override {
+ DCHECK_EQ(batch.values.size(), 3);
+ auto groups = batch[2].array.GetSpan<const uint32_t>(1, batch.length);
+ if (!batch[1].is_array()) {
+ return Status::NotImplemented("Consuming scalar pivot value");
Review Comment:
Good question. It might, depending on whether that is easily testable.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]