This is an automated email from the ASF dual-hosted git repository.
felipecrv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 4f89097765 GH-41301: [C++] Extract the kernel loops used for
PrimitiveTakeExec and generalize to any fixed-width type (#41373)
4f89097765 is described below
commit 4f890977650a36abaaec74ad2eaac31c04b5bf76
Author: Felipe Oliveira Carvalho <[email protected]>
AuthorDate: Mon Jun 10 12:32:46 2024 -0300
GH-41301: [C++] Extract the kernel loops used for PrimitiveTakeExec and
generalize to any fixed-width type (#41373)
### Rationale for this change
I want to instantiate this primitive operation in other scenarios (e.g. the
optimized version of `Take` that handles chunked arrays) and extend the
sub-classes of `GatherBaseCRTP` with different member functions that re-use the
`WriteValue` function generically (any fixed-width type and even bit-wide
booleans).
When taking these improvements to `Filter` I will also re-use the "gather"
concept and parameterize it by bitmaps/boolean-arrays instead of selection
vectors (`indices`) like `take` does. So gather is not a "renaming of take" but
rather a generalization of what `take` and `filter` do in Arrow, with different
representations of what should be gathered from the `values` array.
### What changes are included in this PR?
- Introduce the Gather class helper to delegate fixed-width memory
gathering: both statically and dynamically sized (size known at compile time or
size known at runtime)
- Specialized `Take` implementation for values/indices without nulls
- Fold the Boolean, Primitive, and Fixed-Width Binary implementations of
`Take` into a single one
- Skip validity bitmap allocation when inputs (values and indices) have no
nulls
### Are these changes tested?
- Existing tests
- New test assertions that check that `Take` guarantees null values are
zeroed out
* GitHub Issue: #41301
Authored-by: Felipe Oliveira Carvalho <[email protected]>
Signed-off-by: Felipe Oliveira Carvalho <[email protected]>
---
cpp/src/arrow/compute/kernels/gather_internal.h | 306 +++++++++++++++++
.../compute/kernels/vector_selection_internal.cc | 68 +---
.../compute/kernels/vector_selection_internal.h | 3 +-
.../kernels/vector_selection_take_internal.cc | 375 ++++++---------------
.../arrow/compute/kernels/vector_selection_test.cc | 17 +-
5 files changed, 435 insertions(+), 334 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/gather_internal.h
b/cpp/src/arrow/compute/kernels/gather_internal.h
new file mode 100644
index 0000000000..4c161533a7
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/gather_internal.h
@@ -0,0 +1,306 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "arrow/array/data.h"
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/macros.h"
+
+// Implementation helpers for kernels that need to load/gather fixed-width
+// data from multiple, arbitrary indices.
+//
+// https://en.wikipedia.org/wiki/Gather/scatter_(vector_addressing)
+
+namespace arrow::internal {
+
+// CRTP [1] base class for Gather that provides a gathering loop in terms of
+// Write*() methods that must be implemented by the derived class.
+//
+// [1] https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
+template <class GatherImpl>
+class GatherBaseCRTP {
+ public:
+ // Output offset is not supported by Gather and idx is supposed to have
offset
+ // pre-applied. idx_validity parameters on functions can use the offset they
+ // carry to read the validity bitmap as bitmaps can't have pre-applied
offsets
+ // (they might not align to byte boundaries).
+
+ GatherBaseCRTP() = default;
+ ARROW_DISALLOW_COPY_AND_ASSIGN(GatherBaseCRTP);
+ ARROW_DEFAULT_MOVE_AND_ASSIGN(GatherBaseCRTP);
+
+ protected:
+ ARROW_FORCE_INLINE int64_t ExecuteNoNulls(int64_t idx_length) {
+ auto* self = static_cast<GatherImpl*>(this);
+ for (int64_t position = 0; position < idx_length; position++) {
+ self->WriteValue(position);
+ }
+ return idx_length;
+ }
+
+ // See derived Gather classes below for the meaning of the parameters, pre
and
+ // post-conditions.
+ //
+ // src_validity is not necessarily the source of the values that are being
+ // gathered (e.g. the source could be a nested fixed-size list array and the
+ // values being gathered are from the innermost buffer), so the ArraySpan is
+ // used solely to check for nulls in the source values and nothing else.
+ //
+ // idx_length is the number of elements in idx and consequently the number of
+ // bits that might be written to out_is_valid. Member `Write*()` functions
will be
+ // called with positions from 0 to idx_length - 1.
+ //
+ // If `kOutputIsZeroInitialized` is true, then `WriteZero()` or
`WriteZeroSegment()`
+ // doesn't have to be called for resulting null positions. A position is
+ // considered null if either the index or the source value is null at that
+ // position.
+ template <bool kOutputIsZeroInitialized, typename IndexCType>
+ ARROW_FORCE_INLINE int64_t ExecuteWithNulls(const ArraySpan& src_validity,
+ int64_t idx_length, const
IndexCType* idx,
+ const ArraySpan& idx_validity,
+ uint8_t* out_is_valid) {
+ auto* self = static_cast<GatherImpl*>(this);
+ OptionalBitBlockCounter indices_bit_counter(idx_validity.buffers[0].data,
+ idx_validity.offset,
idx_length);
+ int64_t position = 0;
+ int64_t valid_count = 0;
+ while (position < idx_length) {
+ BitBlockCount block = indices_bit_counter.NextBlock();
+ if (!src_validity.MayHaveNulls()) {
+ // Source values are never null, so things are easier
+ valid_count += block.popcount;
+ if (block.popcount == block.length) {
+ // Fastest path: neither source values nor index nulls
+ bit_util::SetBitsTo(out_is_valid, position, block.length, true);
+ for (int64_t i = 0; i < block.length; ++i) {
+ self->WriteValue(position);
+ ++position;
+ }
+ } else if (block.popcount > 0) {
+ // Slow path: some indices but not all are null
+ for (int64_t i = 0; i < block.length; ++i) {
+ ARROW_COMPILER_ASSUME(idx_validity.buffers[0].data != nullptr);
+ if (idx_validity.IsValid(position)) {
+ // index is not null
+ bit_util::SetBit(out_is_valid, position);
+ self->WriteValue(position);
+ } else if constexpr (!kOutputIsZeroInitialized) {
+ self->WriteZero(position);
+ }
+ ++position;
+ }
+ } else {
+ self->WriteZeroSegment(position, block.length);
+ position += block.length;
+ }
+ } else {
+ // Source values may be null, so we must do random access into
src_validity
+ if (block.popcount == block.length) {
+ // Faster path: indices are not null but source values may be
+ for (int64_t i = 0; i < block.length; ++i) {
+ ARROW_COMPILER_ASSUME(src_validity.buffers[0].data != nullptr);
+ if (src_validity.IsValid(idx[position])) {
+ // value is not null
+ self->WriteValue(position);
+ bit_util::SetBit(out_is_valid, position);
+ ++valid_count;
+ } else if constexpr (!kOutputIsZeroInitialized) {
+ self->WriteZero(position);
+ }
+ ++position;
+ }
+ } else if (block.popcount > 0) {
+ // Slow path: some but not all indices are null. Since we are doing
+ // random access in general we have to check the value nullness one
by
+ // one.
+ for (int64_t i = 0; i < block.length; ++i) {
+ ARROW_COMPILER_ASSUME(src_validity.buffers[0].data != nullptr);
+ ARROW_COMPILER_ASSUME(idx_validity.buffers[0].data != nullptr);
+ if (idx_validity.IsValid(position) &&
src_validity.IsValid(idx[position])) {
+ // index is not null && value is not null
+ self->WriteValue(position);
+ bit_util::SetBit(out_is_valid, position);
+ ++valid_count;
+ } else if constexpr (!kOutputIsZeroInitialized) {
+ self->WriteZero(position);
+ }
+ ++position;
+ }
+ } else {
+ if constexpr (!kOutputIsZeroInitialized) {
+ self->WriteZeroSegment(position, block.length);
+ }
+ position += block.length;
+ }
+ }
+ }
+ return valid_count;
+ }
+};
+
+// A gather primitive for primitive fixed-width types with an integral byte
width. If
+// `kWithFactor` is true, the actual width is a runtime multiple of
`kValueWidthInBits`
+// (this can be useful for fixed-size list inputs and other input types with
unusual byte
+// widths that don't deserve value specialization).
+template <int kValueWidthInBits, typename IndexCType, bool kWithFactor>
+class Gather : public GatherBaseCRTP<Gather<kValueWidthInBits, IndexCType,
kWithFactor>> {
+ public:
+ static_assert(kValueWidthInBits >= 0 && kValueWidthInBits % 8 == 0);
+ static constexpr int kValueWidth = kValueWidthInBits / 8;
+
+ private:
+ const int64_t src_length_; // number of elements of kValueWidth bytes in
src_
+ const uint8_t* src_;
+ const int64_t idx_length_; // number of IndexCType elements in idx_
+ const IndexCType* idx_;
+ uint8_t* out_;
+ int64_t factor_;
+
+ public:
+ void WriteValue(int64_t position) {
+ if constexpr (kWithFactor) {
+ const int64_t scaled_factor = kValueWidth * factor_;
+ memcpy(out_ + position * scaled_factor, src_ + idx_[position] *
scaled_factor,
+ scaled_factor);
+ } else {
+ memcpy(out_ + position * kValueWidth, src_ + idx_[position] *
kValueWidth,
+ kValueWidth);
+ }
+ }
+
+ void WriteZero(int64_t position) {
+ if constexpr (kWithFactor) {
+ const int64_t scaled_factor = kValueWidth * factor_;
+ memset(out_ + position * scaled_factor, 0, scaled_factor);
+ } else {
+ memset(out_ + position * kValueWidth, 0, kValueWidth);
+ }
+ }
+
+ void WriteZeroSegment(int64_t position, int64_t length) {
+ if constexpr (kWithFactor) {
+ const int64_t scaled_factor = kValueWidth * factor_;
+ memset(out_ + position * scaled_factor, 0, length * scaled_factor);
+ } else {
+ memset(out_ + position * kValueWidth, 0, length * kValueWidth);
+ }
+ }
+
+ public:
+ Gather(int64_t src_length, const uint8_t* src, int64_t zero_src_offset,
+ int64_t idx_length, const IndexCType* idx, uint8_t* out, int64_t
factor)
+ : src_length_(src_length),
+ src_(src),
+ idx_length_(idx_length),
+ idx_(idx),
+ out_(out),
+ factor_(factor) {
+ assert(zero_src_offset == 0);
+ assert(src && idx && out);
+ assert((kWithFactor || factor == 1) &&
+ "When kWithFactor is false, the factor is assumed to be 1 at
compile time");
+ }
+
+ ARROW_FORCE_INLINE int64_t Execute() { return
this->ExecuteNoNulls(idx_length_); }
+
+ /// \pre If kOutputIsZeroInitialized, then this->out_ has to be zero
initialized.
+ /// \pre Bits in out_is_valid have to always be zero initialized.
+ /// \post The bits for the valid elements (and only those) are set in
out_is_valid.
+ /// \post If !kOutputIsZeroInitialized, then positions in this->out_
containing null
+ /// elements have 0s written to them. This might be less efficient than
+ /// zero-initializing first and calling this->Execute() afterwards.
+ /// \return The number of valid elements in out.
+ template <bool kOutputIsZeroInitialized = false>
+ ARROW_FORCE_INLINE int64_t Execute(const ArraySpan& src_validity,
+ const ArraySpan& idx_validity,
+ uint8_t* out_is_valid) {
+ assert(src_length_ == src_validity.length);
+ assert(idx_length_ == idx_validity.length);
+ assert(out_is_valid);
+ return this->template ExecuteWithNulls<kOutputIsZeroInitialized>(
+ src_validity, idx_length_, idx_, idx_validity, out_is_valid);
+ }
+};
+
+// A gather primitive for boolean inputs. Unlike its counterpart above,
+// this does not support passing a non-trivial factor parameter.
+template <typename IndexCType>
+class Gather</*kValueWidthInBits=*/1, IndexCType, /*kWithFactor=*/false>
+ : public GatherBaseCRTP<Gather<1, IndexCType, false>> {
+ private:
+ const int64_t src_length_; // number of bit-wide elements in src_
after offset
+ const uint8_t* src_; // the boolean array data buffer in bits
+ const int64_t src_offset_; // offset in bits
+ const int64_t idx_length_; // number of IndexCType elements in idx_
+ const IndexCType* idx_;
+ uint8_t* out_; // output boolean array data buffer in bits
+
+ public:
+ Gather(int64_t src_length, const uint8_t* src, int64_t src_offset, int64_t
idx_length,
+ const IndexCType* idx, uint8_t* out, int64_t factor)
+ : src_length_(src_length),
+ src_(src),
+ src_offset_(src_offset),
+ idx_length_(idx_length),
+ idx_(idx),
+ out_(out) {
+ assert(src && idx && out);
+ assert(factor == 1 &&
+ "factor != 1 is not supported when Gather is used to gather
bits/booleans");
+ }
+
+ void WriteValue(int64_t position) {
+ bit_util::SetBitTo(out_, position,
+ bit_util::GetBit(src_, src_offset_ + idx_[position]));
+ }
+
+ void WriteZero(int64_t position) { bit_util::ClearBit(out_, position); }
+
+ void WriteZeroSegment(int64_t position, int64_t block_length) {
+ bit_util::SetBitsTo(out_, position, block_length, false);
+ }
+
+ ARROW_FORCE_INLINE int64_t Execute() { return
this->ExecuteNoNulls(idx_length_); }
+
+ /// \pre If kOutputIsZeroInitialized, then this->out_ has to be zero
initialized.
+ /// \pre Bits in out_is_valid have to always be zero initialized.
+ /// \post The bits for the valid elements (and only those) are set in
out_is_valid.
+ /// \post If !kOutputIsZeroInitialized, then positions in this->out_
containing null
+ /// elements have 0s written to them. This might be less efficient than
+ /// zero-initializing first and calling this->Execute() afterwards.
+ /// \return The number of valid elements in out.
+ template <bool kOutputIsZeroInitialized = false>
+ ARROW_FORCE_INLINE int64_t Execute(const ArraySpan& src_validity,
+ const ArraySpan& idx_validity,
+ uint8_t* out_is_valid) {
+ assert(src_length_ == src_validity.length);
+ assert(idx_length_ == idx_validity.length);
+ assert(out_is_valid);
+ return this->template ExecuteWithNulls<kOutputIsZeroInitialized>(
+ src_validity, idx_length_, idx_, idx_validity, out_is_valid);
+ }
+};
+
+} // namespace arrow::internal
diff --git a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc
b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc
index 2ba660e49a..1009bea5e7 100644
--- a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc
+++ b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc
@@ -547,39 +547,6 @@ struct VarBinarySelectionImpl : public
Selection<VarBinarySelectionImpl<Type>, T
}
};
-struct FSBSelectionImpl : public Selection<FSBSelectionImpl,
FixedSizeBinaryType> {
- using Base = Selection<FSBSelectionImpl, FixedSizeBinaryType>;
- LIFT_BASE_MEMBERS();
-
- TypedBufferBuilder<uint8_t> data_builder;
-
- FSBSelectionImpl(KernelContext* ctx, const ExecSpan& batch, int64_t
output_length,
- ExecResult* out)
- : Base(ctx, batch, output_length, out), data_builder(ctx->memory_pool())
{}
-
- template <typename Adapter>
- Status GenerateOutput() {
- FixedSizeBinaryArray typed_values(this->values.ToArrayData());
- int32_t value_size = typed_values.byte_width();
-
- RETURN_NOT_OK(data_builder.Reserve(value_size * output_length));
- Adapter adapter(this);
- return adapter.Generate(
- [&](int64_t index) {
- auto val = typed_values.GetView(index);
- data_builder.UnsafeAppend(reinterpret_cast<const
uint8_t*>(val.data()),
- value_size);
- return Status::OK();
- },
- [&]() {
- data_builder.UnsafeAppend(value_size, static_cast<uint8_t>(0x00));
- return Status::OK();
- });
- }
-
- Status Finish() override { return data_builder.Finish(&out->buffers[1]); }
-};
-
template <typename Type>
struct ListSelectionImpl : public Selection<ListSelectionImpl<Type>, Type> {
using offset_type = typename Type::offset_type;
@@ -939,23 +906,6 @@ Status LargeVarBinaryTakeExec(KernelContext* ctx, const
ExecSpan& batch,
return TakeExec<VarBinarySelectionImpl<LargeBinaryType>>(ctx, batch, out);
}
-Status FSBTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out)
{
- const ArraySpan& values = batch[0].array;
- const auto byte_width = values.type->byte_width();
- // Use primitive Take implementation (presumably faster) for some byte widths
- switch (byte_width) {
- case 1:
- case 2:
- case 4:
- case 8:
- case 16:
- case 32:
- return PrimitiveTakeExec(ctx, batch, out);
- default:
- return TakeExec<FSBSelectionImpl>(ctx, batch, out);
- }
-}
-
Status ListTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult*
out) {
return TakeExec<ListSelectionImpl<ListType>>(ctx, batch, out);
}
@@ -968,26 +918,12 @@ Status FSLTakeExec(KernelContext* ctx, const ExecSpan&
batch, ExecResult* out) {
const ArraySpan& values = batch[0].array;
// If a FixedSizeList wraps a fixed-width type we can, in some cases, use
- // PrimitiveTakeExec for a fixed-size list array.
+ // FixedWidthTakeExec for a fixed-size list array.
if (util::IsFixedWidthLike(values,
/*force_null_count=*/true,
/*exclude_bool_and_dictionary=*/true)) {
- const auto byte_width = util::FixedWidthInBytes(*values.type);
- // Additionally, PrimitiveTakeExec is only implemented for specific byte
widths.
- // TODO(GH-41301): Extend PrimitiveTakeExec for any fixed-width type.
- switch (byte_width) {
- case 1:
- case 2:
- case 4:
- case 8:
- case 16:
- case 32:
- return PrimitiveTakeExec(ctx, batch, out);
- default:
- break; // fallback to TakeExec<FSBSelectionImpl>
- }
+ return FixedWidthTakeExec(ctx, batch, out);
}
-
return TakeExec<FSLSelectionImpl>(ctx, batch, out);
}
diff --git a/cpp/src/arrow/compute/kernels/vector_selection_internal.h
b/cpp/src/arrow/compute/kernels/vector_selection_internal.h
index a169f4b38a..c5075d6dfe 100644
--- a/cpp/src/arrow/compute/kernels/vector_selection_internal.h
+++ b/cpp/src/arrow/compute/kernels/vector_selection_internal.h
@@ -73,8 +73,7 @@ Status MapFilterExec(KernelContext*, const ExecSpan&,
ExecResult*);
Status VarBinaryTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status LargeVarBinaryTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
-Status PrimitiveTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
-Status FSBTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
+Status FixedWidthTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status ListTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status LargeListTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status FSLTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
diff --git a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc
b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc
index 1a9af0efcd..dee80e9d25 100644
--- a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc
+++ b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc
@@ -19,6 +19,7 @@
#include <cstring>
#include <limits>
#include <memory>
+#include <utility>
#include <vector>
#include "arrow/array/builder_primitive.h"
@@ -27,6 +28,7 @@
#include "arrow/chunked_array.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/compute/kernels/gather_internal.h"
#include "arrow/compute/kernels/vector_selection_internal.h"
#include "arrow/compute/kernels/vector_selection_take_internal.h"
#include "arrow/memory_pool.h"
@@ -324,238 +326,79 @@ namespace {
using TakeState = OptionsWrapper<TakeOptions>;
// ----------------------------------------------------------------------
-// Implement optimized take for primitive types from boolean to
1/2/4/8/16/32-byte
-// C-type based types. Use common implementation for every byte width and only
-// generate code for unsigned integer indices, since after boundschecking to
-// check for negative numbers in the indices we can safely reinterpret_cast
-// signed integers as unsigned.
-
-/// \brief The Take implementation for primitive (fixed-width) types does not
-/// use the logical Arrow type but rather the physical C type. This way we
-/// only generate one take function for each byte width.
+// Implement optimized take for primitive types from boolean to
+// 1/2/4/8/16/32-byte C-type based types and fixed-size binary (0 or more
+// bytes).
+//
+// Use one specialization for each of these primitive byte-widths so the
+// compiler can specialize the memcpy to dedicated CPU instructions and for
+// fixed-width binary use the 1-byte specialization but pass WithFactor=true
+// that makes the kernel consider the factor parameter provided at runtime.
+//
+// Only unsigned index types need to be instantiated since after
+// boundschecking to check for negative numbers in the indices we can safely
+// reinterpret_cast signed integers as unsigned.
+
+/// \brief The Take implementation for primitive types and fixed-width binary.
///
/// Also note that this function can also handle fixed-size-list arrays if
/// they fit the criteria described in fixed_width_internal.h, so use the
/// function defined in that file to access values and destination pointers
/// and DO NOT ASSUME `values.type()` is a primitive type.
///
+/// NOTE: Template parameters are types instead of values to let
+/// `TakeIndexDispatch<>` forward `typename... Args` after the index type.
+///
/// \pre the indices have been boundschecked
-template <typename IndexCType, typename ValueWidthConstant>
-struct PrimitiveTakeImpl {
- static constexpr int kValueWidth = ValueWidthConstant::value;
-
- static void Exec(const ArraySpan& values, const ArraySpan& indices,
- ArrayData* out_arr) {
- DCHECK_EQ(util::FixedWidthInBytes(*values.type), kValueWidth);
- const auto* values_data =
util::OffsetPointerOfFixedByteWidthValues(values);
- const uint8_t* values_is_valid = values.buffers[0].data;
- auto values_offset = values.offset;
-
- const auto* indices_data = indices.GetValues<IndexCType>(1);
- const uint8_t* indices_is_valid = indices.buffers[0].data;
- auto indices_offset = indices.offset;
-
- DCHECK_EQ(out_arr->offset, 0);
- auto* out = util::MutableFixedWidthValuesPointer(out_arr);
- auto out_is_valid = out_arr->buffers[0]->mutable_data();
-
- // If either the values or indices have nulls, we preemptively zero out the
- // out validity bitmap so that we don't have to use ClearBit in each
- // iteration for nulls.
- if (values.null_count != 0 || indices.null_count != 0) {
- bit_util::SetBitsTo(out_is_valid, 0, indices.length, false);
- }
-
- auto WriteValue = [&](int64_t position) {
- memcpy(out + position * kValueWidth,
- values_data + indices_data[position] * kValueWidth, kValueWidth);
- };
-
- auto WriteZero = [&](int64_t position) {
- memset(out + position * kValueWidth, 0, kValueWidth);
- };
-
- auto WriteZeroSegment = [&](int64_t position, int64_t length) {
- memset(out + position * kValueWidth, 0, kValueWidth * length);
- };
-
- OptionalBitBlockCounter indices_bit_counter(indices_is_valid,
indices_offset,
- indices.length);
- int64_t position = 0;
+template <typename IndexCType, typename ValueBitWidthConstant,
+ typename OutputIsZeroInitialized = std::false_type,
+ typename WithFactor = std::false_type>
+struct FixedWidthTakeImpl {
+ static constexpr int kValueWidthInBits = ValueBitWidthConstant::value;
+
+ static Status Exec(KernelContext* ctx, const ArraySpan& values,
+ const ArraySpan& indices, ArrayData* out_arr, int64_t
factor) {
+#ifndef NDEBUG
+ int64_t bit_width = util::FixedWidthInBits(*values.type);
+ DCHECK(WithFactor::value || (kValueWidthInBits == bit_width && factor ==
1));
+ DCHECK(!WithFactor::value ||
+ (factor > 0 && kValueWidthInBits == 8 && // factors are used with
bytes
+ static_cast<int64_t>(factor * kValueWidthInBits) == bit_width));
+#endif
+ const bool out_has_validity = values.MayHaveNulls() ||
indices.MayHaveNulls();
+
+ const uint8_t* src;
+ int64_t src_offset;
+ std::tie(src_offset, src) =
util::OffsetPointerOfFixedBitWidthValues(values);
+ uint8_t* out = util::MutableFixedWidthValuesPointer(out_arr);
int64_t valid_count = 0;
- while (position < indices.length) {
- BitBlockCount block = indices_bit_counter.NextBlock();
- if (values.null_count == 0) {
- // Values are never null, so things are easier
- valid_count += block.popcount;
- if (block.popcount == block.length) {
- // Fastest path: neither values nor index nulls
- bit_util::SetBitsTo(out_is_valid, position, block.length, true);
- for (int64_t i = 0; i < block.length; ++i) {
- WriteValue(position);
- ++position;
- }
- } else if (block.popcount > 0) {
- // Slow path: some indices but not all are null
- for (int64_t i = 0; i < block.length; ++i) {
- if (bit_util::GetBit(indices_is_valid, indices_offset + position))
{
- // index is not null
- bit_util::SetBit(out_is_valid, position);
- WriteValue(position);
- } else {
- WriteZero(position);
- }
- ++position;
- }
- } else {
- WriteZeroSegment(position, block.length);
- position += block.length;
- }
- } else {
- // Values have nulls, so we must do random access into the values
bitmap
- if (block.popcount == block.length) {
- // Faster path: indices are not null but values may be
- for (int64_t i = 0; i < block.length; ++i) {
- if (bit_util::GetBit(values_is_valid,
- values_offset + indices_data[position])) {
- // value is not null
- WriteValue(position);
- bit_util::SetBit(out_is_valid, position);
- ++valid_count;
- } else {
- WriteZero(position);
- }
- ++position;
- }
- } else if (block.popcount > 0) {
- // Slow path: some but not all indices are null. Since we are doing
- // random access in general we have to check the value nullness one
by
- // one.
- for (int64_t i = 0; i < block.length; ++i) {
- if (bit_util::GetBit(indices_is_valid, indices_offset + position)
&&
- bit_util::GetBit(values_is_valid,
- values_offset + indices_data[position])) {
- // index is not null && value is not null
- WriteValue(position);
- bit_util::SetBit(out_is_valid, position);
- ++valid_count;
- } else {
- WriteZero(position);
- }
- ++position;
- }
- } else {
- WriteZeroSegment(position, block.length);
- position += block.length;
- }
- }
- }
- out_arr->null_count = out_arr->length - valid_count;
- }
-};
-
-template <typename IndexCType>
-struct BooleanTakeImpl {
- static void Exec(const ArraySpan& values, const ArraySpan& indices,
- ArrayData* out_arr) {
- const uint8_t* values_data = values.buffers[1].data;
- const uint8_t* values_is_valid = values.buffers[0].data;
- auto values_offset = values.offset;
-
- const auto* indices_data = indices.GetValues<IndexCType>(1);
- const uint8_t* indices_is_valid = indices.buffers[0].data;
- auto indices_offset = indices.offset;
-
- auto out = out_arr->buffers[1]->mutable_data();
- auto out_is_valid = out_arr->buffers[0]->mutable_data();
- auto out_offset = out_arr->offset;
-
- // If either the values or indices have nulls, we preemptively zero out the
- // out validity bitmap so that we don't have to use ClearBit in each
- // iteration for nulls.
- if (values.null_count != 0 || indices.null_count != 0) {
- bit_util::SetBitsTo(out_is_valid, out_offset, indices.length, false);
- }
- // Avoid uninitialized data in values array
- bit_util::SetBitsTo(out, out_offset, indices.length, false);
-
- auto PlaceDataBit = [&](int64_t loc, IndexCType index) {
- bit_util::SetBitTo(out, out_offset + loc,
- bit_util::GetBit(values_data, values_offset + index));
- };
-
- OptionalBitBlockCounter indices_bit_counter(indices_is_valid,
indices_offset,
- indices.length);
- int64_t position = 0;
- int64_t valid_count = 0;
- while (position < indices.length) {
- BitBlockCount block = indices_bit_counter.NextBlock();
- if (values.null_count == 0) {
- // Values are never null, so things are easier
- valid_count += block.popcount;
- if (block.popcount == block.length) {
- // Fastest path: neither values nor index nulls
- bit_util::SetBitsTo(out_is_valid, out_offset + position,
block.length, true);
- for (int64_t i = 0; i < block.length; ++i) {
- PlaceDataBit(position, indices_data[position]);
- ++position;
- }
- } else if (block.popcount > 0) {
- // Slow path: some but not all indices are null
- for (int64_t i = 0; i < block.length; ++i) {
- if (bit_util::GetBit(indices_is_valid, indices_offset + position))
{
- // index is not null
- bit_util::SetBit(out_is_valid, out_offset + position);
- PlaceDataBit(position, indices_data[position]);
- }
- ++position;
- }
- } else {
- position += block.length;
- }
- } else {
- // Values have nulls, so we must do random access into the values
bitmap
- if (block.popcount == block.length) {
- // Faster path: indices are not null but values may be
- for (int64_t i = 0; i < block.length; ++i) {
- if (bit_util::GetBit(values_is_valid,
- values_offset + indices_data[position])) {
- // value is not null
- bit_util::SetBit(out_is_valid, out_offset + position);
- PlaceDataBit(position, indices_data[position]);
- ++valid_count;
- }
- ++position;
- }
- } else if (block.popcount > 0) {
- // Slow path: some but not all indices are null. Since we are doing
- // random access in general we have to check the value nullness one
by
- // one.
- for (int64_t i = 0; i < block.length; ++i) {
- if (bit_util::GetBit(indices_is_valid, indices_offset + position))
{
- // index is not null
- if (bit_util::GetBit(values_is_valid,
- values_offset + indices_data[position])) {
- // value is not null
- PlaceDataBit(position, indices_data[position]);
- bit_util::SetBit(out_is_valid, out_offset + position);
- ++valid_count;
- }
- }
- ++position;
- }
- } else {
- position += block.length;
- }
- }
+ arrow::internal::Gather<kValueWidthInBits, IndexCType, WithFactor::value>
gather{
+ /*src_length=*/values.length,
+ src,
+ src_offset,
+ /*idx_length=*/indices.length,
+ /*idx=*/indices.GetValues<IndexCType>(1),
+ out,
+ factor};
+ if (out_has_validity) {
+ DCHECK_EQ(out_arr->offset, 0);
+ // out_is_valid must be zero-initialized, because Gather::Execute
+ // saves time by not having to ClearBit on every null element.
+ auto out_is_valid = out_arr->GetMutableValues<uint8_t>(0);
+ memset(out_is_valid, 0, bit_util::BytesForBits(out_arr->length));
+ valid_count = gather.template Execute<OutputIsZeroInitialized::value>(
+ /*src_validity=*/values, /*idx_validity=*/indices, out_is_valid);
+ } else {
+ valid_count = gather.Execute();
}
out_arr->null_count = out_arr->length - valid_count;
+ return Status::OK();
}
};
template <template <typename...> class TakeImpl, typename... Args>
-void TakeIndexDispatch(const ArraySpan& values, const ArraySpan& indices,
- ArrayData* out) {
+Status TakeIndexDispatch(KernelContext* ctx, const ArraySpan& values,
+ const ArraySpan& indices, ArrayData* out, int64_t
factor = 1) {
// With the simplifying assumption that boundschecking has taken place
// already at a higher level, we can now assume that the index values are all
// non-negative. Thus, we can interpret signed integers as unsigned and avoid
@@ -563,22 +406,20 @@ void TakeIndexDispatch(const ArraySpan& values, const
ArraySpan& indices,
// width.
switch (indices.type->byte_width()) {
case 1:
- return TakeImpl<uint8_t, Args...>::Exec(values, indices, out);
+ return TakeImpl<uint8_t, Args...>::Exec(ctx, values, indices, out,
factor);
case 2:
- return TakeImpl<uint16_t, Args...>::Exec(values, indices, out);
+ return TakeImpl<uint16_t, Args...>::Exec(ctx, values, indices, out,
factor);
case 4:
- return TakeImpl<uint32_t, Args...>::Exec(values, indices, out);
- case 8:
- return TakeImpl<uint64_t, Args...>::Exec(values, indices, out);
+ return TakeImpl<uint32_t, Args...>::Exec(ctx, values, indices, out,
factor);
default:
- DCHECK(false) << "Invalid indices byte width";
- break;
+ DCHECK_EQ(indices.type->byte_width(), 8);
+ return TakeImpl<uint64_t, Args...>::Exec(ctx, values, indices, out,
factor);
}
}
} // namespace
-Status PrimitiveTakeExec(KernelContext* ctx, const ExecSpan& batch,
ExecResult* out) {
+Status FixedWidthTakeExec(KernelContext* ctx, const ExecSpan& batch,
ExecResult* out) {
const ArraySpan& values = batch[0].array;
const ArraySpan& indices = batch[1].array;
@@ -587,52 +428,60 @@ Status PrimitiveTakeExec(KernelContext* ctx, const
ExecSpan& batch, ExecResult*
}
ArrayData* out_arr = out->array_data().get();
-
DCHECK(util::IsFixedWidthLike(values));
- const int64_t bit_width = util::FixedWidthInBits(*values.type);
-
- // TODO: When neither values nor indices contain nulls, we can skip
- // allocating the validity bitmap altogether and save time and space. A
- // streamlined PrimitiveTakeImpl would need to be written that skips all
- // interactions with the output validity bitmap, though.
+ // When we know for sure that neither values nor indices contain nulls, we can skip
+ // allocating the validity bitmap altogether and save time and space.
+ const bool allocate_validity = values.MayHaveNulls() ||
indices.MayHaveNulls();
RETURN_NOT_OK(util::internal::PreallocateFixedWidthArrayData(
- ctx, indices.length, /*source=*/values,
- /*allocate_validity=*/true, out_arr));
- switch (bit_width) {
+ ctx, indices.length, /*source=*/values, allocate_validity, out_arr));
+ switch (util::FixedWidthInBits(*values.type)) {
+ case 0:
+ DCHECK(values.type->id() == Type::FIXED_SIZE_BINARY ||
+ values.type->id() == Type::FIXED_SIZE_LIST);
+ return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int,
0>>(
+ ctx, values, indices, out_arr);
case 1:
- TakeIndexDispatch<BooleanTakeImpl>(values, indices, out_arr);
- break;
+ // Zero-initialize the data buffer for the output array when the bit-width is 1
+ // (e.g. Boolean array) to avoid having to ClearBit on every null element.
+ // This might be profitable for other types as well, but we take the most
+ // conservative approach for now.
+ memset(out_arr->buffers[1]->mutable_data(), 0,
out_arr->buffers[1]->size());
+ return TakeIndexDispatch<
+ FixedWidthTakeImpl, std::integral_constant<int, 1>,
/*OutputIsZeroInitialized=*/
+ std::true_type>(ctx, values, indices, out_arr);
case 8:
- TakeIndexDispatch<PrimitiveTakeImpl, std::integral_constant<int, 1>>(
- values, indices, out_arr);
- break;
+ return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int,
8>>(
+ ctx, values, indices, out_arr);
case 16:
- TakeIndexDispatch<PrimitiveTakeImpl, std::integral_constant<int, 2>>(
- values, indices, out_arr);
- break;
+ return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int,
16>>(
+ ctx, values, indices, out_arr);
case 32:
- TakeIndexDispatch<PrimitiveTakeImpl, std::integral_constant<int, 4>>(
- values, indices, out_arr);
- break;
+ return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int,
32>>(
+ ctx, values, indices, out_arr);
case 64:
- TakeIndexDispatch<PrimitiveTakeImpl, std::integral_constant<int, 8>>(
- values, indices, out_arr);
- break;
+ return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int,
64>>(
+ ctx, values, indices, out_arr);
case 128:
// For INTERVAL_MONTH_DAY_NANO, DECIMAL128
- TakeIndexDispatch<PrimitiveTakeImpl, std::integral_constant<int, 16>>(
- values, indices, out_arr);
- break;
+ return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int,
128>>(
+ ctx, values, indices, out_arr);
case 256:
// For DECIMAL256
- TakeIndexDispatch<PrimitiveTakeImpl, std::integral_constant<int, 32>>(
- values, indices, out_arr);
- break;
- default:
- return Status::NotImplemented("Unsupported primitive type for take: ",
- *values.type);
+ return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int,
256>>(
+ ctx, values, indices, out_arr);
}
- return Status::OK();
+ if (ARROW_PREDICT_TRUE(values.type->id() == Type::FIXED_SIZE_BINARY ||
+ values.type->id() == Type::FIXED_SIZE_LIST)) {
+ int64_t byte_width = util::FixedWidthInBytes(*values.type);
+ // 0-length fixed-size binary or lists were handled above on `case 0`
+ DCHECK_GT(byte_width, 0);
+ return TakeIndexDispatch<FixedWidthTakeImpl,
+ /*ValueBitWidth=*/std::integral_constant<int, 8>,
+ /*OutputIsZeroInitialized=*/std::false_type,
+ /*WithFactor=*/std::true_type>(ctx, values,
indices, out_arr,
+
/*factor=*/byte_width);
+ }
+ return Status::NotImplemented("Unsupported primitive type for take: ",
*values.type);
}
namespace {
@@ -883,13 +732,11 @@ void
PopulateTakeKernels(std::vector<SelectionKernelData>* out) {
auto take_indices = match::Integer();
*out = {
- {InputType(match::Primitive()), take_indices, PrimitiveTakeExec},
+ {InputType(match::Primitive()), take_indices, FixedWidthTakeExec},
{InputType(match::BinaryLike()), take_indices, VarBinaryTakeExec},
{InputType(match::LargeBinaryLike()), take_indices,
LargeVarBinaryTakeExec},
- {InputType(Type::FIXED_SIZE_BINARY), take_indices, FSBTakeExec},
+ {InputType(match::FixedSizeBinaryLike()), take_indices,
FixedWidthTakeExec},
{InputType(null()), take_indices, NullTakeExec},
- {InputType(Type::DECIMAL128), take_indices, PrimitiveTakeExec},
- {InputType(Type::DECIMAL256), take_indices, PrimitiveTakeExec},
{InputType(Type::DICTIONARY), take_indices, DictionaryTake},
{InputType(Type::EXTENSION), take_indices, ExtensionTake},
{InputType(Type::LIST), take_indices, ListTakeExec},
diff --git a/cpp/src/arrow/compute/kernels/vector_selection_test.cc
b/cpp/src/arrow/compute/kernels/vector_selection_test.cc
index 6261fa2dae..cafd889015 100644
--- a/cpp/src/arrow/compute/kernels/vector_selection_test.cc
+++ b/cpp/src/arrow/compute/kernels/vector_selection_test.cc
@@ -1146,6 +1146,15 @@ void ValidateTakeImpl(const std::shared_ptr<Array>&
values,
for (int64_t i = 0; i < indices->length(); ++i) {
if (typed_indices->IsNull(i) ||
typed_values->IsNull(typed_indices->Value(i))) {
ASSERT_TRUE(result->IsNull(i)) << i;
+ // The value of a null element is undefined, but right
+ // out of the Take kernel it is expected to be 0.
+ if constexpr (is_primitive(ValuesType::type_id)) {
+ if constexpr (ValuesType::type_id == Type::BOOL) {
+ ASSERT_EQ(typed_result->Value(i), false);
+ } else {
+ ASSERT_EQ(typed_result->Value(i), 0);
+ }
+ }
} else {
ASSERT_FALSE(result->IsNull(i)) << i;
ASSERT_EQ(typed_result->GetView(i),
typed_values->GetView(typed_indices->Value(i)))
@@ -1522,9 +1531,13 @@ TEST_F(TestTakeKernelWithFixedSizeList,
TakeFixedSizeListInt32) {
CheckTake(fixed_size_list(int32(), 3), list_json, "[3, 0, 0, 3]",
"[[7, 8, null], null, null, [7, 8, null]]");
CheckTake(fixed_size_list(int32(), 3), list_json, "[0, 1, 2, 3]", list_json);
+
+ // No nulls in inner list values trigger the use of FixedWidthTakeExec() in
+ // FSLTakeExec()
+ std::string no_nulls_list_json = "[[0, 0, 0], [1, 2, 3], [4, 5, 6], [7, 8,
9]]";
CheckTake(
- fixed_size_list(int32(), 3), list_json, "[2, 2, 2, 2, 2, 2, 1]",
- "[[4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [1,
null, 3]]");
+ fixed_size_list(int32(), 3), no_nulls_list_json, "[2, 2, 2, 2, 2, 2, 1]",
+ "[[4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [1,
2, 3]]");
this->TestNoValidityBitmapButUnknownNullCount(fixed_size_list(int32(), 3),
"[[1, null, 3], [4, 5, 6], [7,
8, null]]",