This is an automated email from the ASF dual-hosted git repository.

felipecrv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4f89097765 GH-41301: [C++] Extract the kernel loops used for 
PrimitiveTakeExec and generalize to any fixed-width type (#41373)
4f89097765 is described below

commit 4f890977650a36abaaec74ad2eaac31c04b5bf76
Author: Felipe Oliveira Carvalho <[email protected]>
AuthorDate: Mon Jun 10 12:32:46 2024 -0300

    GH-41301: [C++] Extract the kernel loops used for PrimitiveTakeExec and 
generalize to any fixed-width type (#41373)
    
    ### Rationale for this change
    
    I want to instantiate this primitive operation in other scenarios (e.g. the optimized version of `Take` that handles chunked arrays) and extend the sub-classes of `GatherBaseCRTP` with different member functions that re-use the `WriteValue` function generically (any fixed-width type and even bit-wide booleans).
    
    When taking these improvements to `Filter`, I will also re-use the "gather" concept and parameterize it by bitmaps/boolean-arrays instead of selection vectors (`indices`) like `take` does. So gather is not a "renaming of take" but rather a generalization of what `take` and `filter` do in Arrow, with different representations of what should be gathered from the `values` array.
    
    ### What changes are included in this PR?
    
     - Introduce the `Gather` helper class to which fixed-width memory gathering is delegated; it supports both statically and dynamically sized values (width known at compile time or only at runtime)
     - Specialized `Take` implementation for values/indices without nulls
     - Fold the boolean, primitive, and fixed-size binary implementations of `Take` into a single one
     - Skip validity bitmap allocation when inputs (values and indices) have no 
nulls
    
    ### Are these changes tested?
    
     - Existing tests
     - New test assertions that check that `Take` guarantees null values are 
zeroed out
    
    * GitHub Issue: #41301
    
    Authored-by: Felipe Oliveira Carvalho <[email protected]>
    Signed-off-by: Felipe Oliveira Carvalho <[email protected]>
---
 cpp/src/arrow/compute/kernels/gather_internal.h    | 306 +++++++++++++++++
 .../compute/kernels/vector_selection_internal.cc   |  68 +---
 .../compute/kernels/vector_selection_internal.h    |   3 +-
 .../kernels/vector_selection_take_internal.cc      | 375 ++++++---------------
 .../arrow/compute/kernels/vector_selection_test.cc |  17 +-
 5 files changed, 435 insertions(+), 334 deletions(-)

diff --git a/cpp/src/arrow/compute/kernels/gather_internal.h 
b/cpp/src/arrow/compute/kernels/gather_internal.h
new file mode 100644
index 0000000000..4c161533a7
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/gather_internal.h
@@ -0,0 +1,306 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "arrow/array/data.h"
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/macros.h"
+
+// Implementation helpers for kernels that need to load/gather fixed-width
+// data from multiple, arbitrary indices.
+//
+// https://en.wikipedia.org/wiki/Gather/scatter_(vector_addressing)
+
+namespace arrow::internal {
+
+// CRTP [1] base class for Gather that provides a gathering loop in terms of
+// Write*() methods that must be implemented by the derived class.
+//
+// [1] https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
+template <class GatherImpl>
+class GatherBaseCRTP {
+ public:
+  // Output offset is not supported by Gather and idx is supposed to have 
offset
+  // pre-applied. idx_validity parameters on functions can use the offset they
+  // carry to read the validity bitmap as bitmaps can't have pre-applied 
offsets
+  // (they might not align to byte boundaries).
+
+  GatherBaseCRTP() = default;
+  ARROW_DISALLOW_COPY_AND_ASSIGN(GatherBaseCRTP);
+  ARROW_DEFAULT_MOVE_AND_ASSIGN(GatherBaseCRTP);
+
+ protected:
+  ARROW_FORCE_INLINE int64_t ExecuteNoNulls(int64_t idx_length) {
+    auto* self = static_cast<GatherImpl*>(this);
+    for (int64_t position = 0; position < idx_length; position++) {
+      self->WriteValue(position);
+    }
+    return idx_length;
+  }
+
+  // See derived Gather classes below for the meaning of the parameters, pre 
and
+  // post-conditions.
+  //
+  // src_validity is not necessarily the source of the values that are being
+  // gathered (e.g. the source could be a nested fixed-size list array and the
+  // values being gathered are from the innermost buffer), so the ArraySpan is
+  // used solely to check for nulls in the source values and nothing else.
+  //
+  // idx_length is the number of elements in idx and consequently the number of
+  // bits that might be written to out_is_valid. Member `Write*()` functions 
will be
+  // called with positions from 0 to idx_length - 1.
+  //
+  // If `kOutputIsZeroInitialized` is true, then `WriteZero()` or 
`WriteZeroSegment()`
+  // doesn't have to be called for resulting null positions. A position is
+  // considered null if either the index or the source value is null at that
+  // position.
+  template <bool kOutputIsZeroInitialized, typename IndexCType>
+  ARROW_FORCE_INLINE int64_t ExecuteWithNulls(const ArraySpan& src_validity,
+                                              int64_t idx_length, const 
IndexCType* idx,
+                                              const ArraySpan& idx_validity,
+                                              uint8_t* out_is_valid) {
+    auto* self = static_cast<GatherImpl*>(this);
+    OptionalBitBlockCounter indices_bit_counter(idx_validity.buffers[0].data,
+                                                idx_validity.offset, 
idx_length);
+    int64_t position = 0;
+    int64_t valid_count = 0;
+    while (position < idx_length) {
+      BitBlockCount block = indices_bit_counter.NextBlock();
+      if (!src_validity.MayHaveNulls()) {
+        // Source values are never null, so things are easier
+        valid_count += block.popcount;
+        if (block.popcount == block.length) {
+          // Fastest path: neither source values nor index nulls
+          bit_util::SetBitsTo(out_is_valid, position, block.length, true);
+          for (int64_t i = 0; i < block.length; ++i) {
+            self->WriteValue(position);
+            ++position;
+          }
+        } else if (block.popcount > 0) {
+          // Slow path: some indices but not all are null
+          for (int64_t i = 0; i < block.length; ++i) {
+            ARROW_COMPILER_ASSUME(idx_validity.buffers[0].data != nullptr);
+            if (idx_validity.IsValid(position)) {
+              // index is not null
+              bit_util::SetBit(out_is_valid, position);
+              self->WriteValue(position);
+            } else if constexpr (!kOutputIsZeroInitialized) {
+              self->WriteZero(position);
+            }
+            ++position;
+          }
+        } else {
+          self->WriteZeroSegment(position, block.length);
+          position += block.length;
+        }
+      } else {
+        // Source values may be null, so we must do random access into 
src_validity
+        if (block.popcount == block.length) {
+          // Faster path: indices are not null but source values may be
+          for (int64_t i = 0; i < block.length; ++i) {
+            ARROW_COMPILER_ASSUME(src_validity.buffers[0].data != nullptr);
+            if (src_validity.IsValid(idx[position])) {
+              // value is not null
+              self->WriteValue(position);
+              bit_util::SetBit(out_is_valid, position);
+              ++valid_count;
+            } else if constexpr (!kOutputIsZeroInitialized) {
+              self->WriteZero(position);
+            }
+            ++position;
+          }
+        } else if (block.popcount > 0) {
+          // Slow path: some but not all indices are null. Since we are doing
+          // random access in general we have to check the value nullness one 
by
+          // one.
+          for (int64_t i = 0; i < block.length; ++i) {
+            ARROW_COMPILER_ASSUME(src_validity.buffers[0].data != nullptr);
+            ARROW_COMPILER_ASSUME(idx_validity.buffers[0].data != nullptr);
+            if (idx_validity.IsValid(position) && 
src_validity.IsValid(idx[position])) {
+              // index is not null && value is not null
+              self->WriteValue(position);
+              bit_util::SetBit(out_is_valid, position);
+              ++valid_count;
+            } else if constexpr (!kOutputIsZeroInitialized) {
+              self->WriteZero(position);
+            }
+            ++position;
+          }
+        } else {
+          if constexpr (!kOutputIsZeroInitialized) {
+            self->WriteZeroSegment(position, block.length);
+          }
+          position += block.length;
+        }
+      }
+    }
+    return valid_count;
+  }
+};
+
+// A gather primitive for primitive fixed-width types with an integral byte width. If
+// `kWithFactor` is true, the actual width is a runtime multiple of `kValueWidthInBits`
+// (this can be useful for fixed-size list inputs and other input types with unusual byte
+// widths that don't deserve value specialization).
+template <int kValueWidthInBits, typename IndexCType, bool kWithFactor>
+class Gather : public GatherBaseCRTP<Gather<kValueWidthInBits, IndexCType, 
kWithFactor>> {
+ public:
+  static_assert(kValueWidthInBits >= 0 && kValueWidthInBits % 8 == 0);
+  static constexpr int kValueWidth = kValueWidthInBits / 8;
+
+ private:
+  const int64_t src_length_;  // number of elements of kValueWidth bytes in 
src_
+  const uint8_t* src_;
+  const int64_t idx_length_;  // number IndexCType elements in idx_
+  const IndexCType* idx_;
+  uint8_t* out_;
+  int64_t factor_;
+
+ public:
+  void WriteValue(int64_t position) {
+    if constexpr (kWithFactor) {
+      const int64_t scaled_factor = kValueWidth * factor_;
+      memcpy(out_ + position * scaled_factor, src_ + idx_[position] * 
scaled_factor,
+             scaled_factor);
+    } else {
+      memcpy(out_ + position * kValueWidth, src_ + idx_[position] * 
kValueWidth,
+             kValueWidth);
+    }
+  }
+
+  void WriteZero(int64_t position) {
+    if constexpr (kWithFactor) {
+      const int64_t scaled_factor = kValueWidth * factor_;
+      memset(out_ + position * scaled_factor, 0, scaled_factor);
+    } else {
+      memset(out_ + position * kValueWidth, 0, kValueWidth);
+    }
+  }
+
+  void WriteZeroSegment(int64_t position, int64_t length) {
+    if constexpr (kWithFactor) {
+      const int64_t scaled_factor = kValueWidth * factor_;
+      memset(out_ + position * scaled_factor, 0, length * scaled_factor);
+    } else {
+      memset(out_ + position * kValueWidth, 0, length * kValueWidth);
+    }
+  }
+
+ public:
+  Gather(int64_t src_length, const uint8_t* src, int64_t zero_src_offset,
+         int64_t idx_length, const IndexCType* idx, uint8_t* out, int64_t 
factor)
+      : src_length_(src_length),
+        src_(src),
+        idx_length_(idx_length),
+        idx_(idx),
+        out_(out),
+        factor_(factor) {
+    assert(zero_src_offset == 0);
+    assert(src && idx && out);
+    assert((kWithFactor || factor == 1) &&
+           "When kWithFactor is false, the factor is assumed to be 1 at 
compile time");
+  }
+
+  ARROW_FORCE_INLINE int64_t Execute() { return 
this->ExecuteNoNulls(idx_length_); }
+
+  /// \pre If kOutputIsZeroInitialized, then this->out_ must be zero-initialized.
+  /// \pre Bits in out_is_valid must always be zero-initialized.
+  /// \post The bits for the valid elements (and only those) are set in out_is_valid.
+  /// \post If !kOutputIsZeroInitialized, then positions in this->out_ containing null
+  ///       elements have 0s written to them. This might be less efficient than
+  ///       zero-initializing first and calling this->Execute() afterwards.
+  /// \return The number of valid elements in out.
+  template <bool kOutputIsZeroInitialized = false>
+  ARROW_FORCE_INLINE int64_t Execute(const ArraySpan& src_validity,
+                                     const ArraySpan& idx_validity,
+                                     uint8_t* out_is_valid) {
+    assert(src_length_ == src_validity.length);
+    assert(idx_length_ == idx_validity.length);
+    assert(out_is_valid);
+    return this->template ExecuteWithNulls<kOutputIsZeroInitialized>(
+        src_validity, idx_length_, idx_, idx_validity, out_is_valid);
+  }
+};
+
+// A gather primitive for boolean inputs. Unlike its counterpart above,
+// this does not support passing a non-trivial factor parameter.
+template <typename IndexCType>
+class Gather</*kValueWidthInBits=*/1, IndexCType, /*kWithFactor=*/false>
+    : public GatherBaseCRTP<Gather<1, IndexCType, false>> {
+ private:
+  const int64_t src_length_;  // number of elements of bits bytes in src_ 
after offset
+  const uint8_t* src_;        // the boolean array data buffer in bits
+  const int64_t src_offset_;  // offset in bits
+  const int64_t idx_length_;  // number IndexCType elements in idx_
+  const IndexCType* idx_;
+  uint8_t* out_;  // output boolean array data buffer in bits
+
+ public:
+  Gather(int64_t src_length, const uint8_t* src, int64_t src_offset, int64_t 
idx_length,
+         const IndexCType* idx, uint8_t* out, int64_t factor)
+      : src_length_(src_length),
+        src_(src),
+        src_offset_(src_offset),
+        idx_length_(idx_length),
+        idx_(idx),
+        out_(out) {
+    assert(src && idx && out);
+    assert(factor == 1 &&
+           "factor != 1 is not supported when Gather is used to gather 
bits/booleans");
+  }
+
+  void WriteValue(int64_t position) {
+    bit_util::SetBitTo(out_, position,
+                       bit_util::GetBit(src_, src_offset_ + idx_[position]));
+  }
+
+  void WriteZero(int64_t position) { bit_util::ClearBit(out_, position); }
+
+  void WriteZeroSegment(int64_t position, int64_t block_length) {
+    bit_util::SetBitsTo(out_, position, block_length, false);
+  }
+
+  ARROW_FORCE_INLINE int64_t Execute() { return 
this->ExecuteNoNulls(idx_length_); }
+
+  /// \pre If kOutputIsZeroInitialized, then this->out_ must be zero-initialized.
+  /// \pre Bits in out_is_valid must always be zero-initialized.
+  /// \post The bits for the valid elements (and only those) are set in out_is_valid.
+  /// \post If !kOutputIsZeroInitialized, then positions in this->out_ containing null
+  ///       elements have 0s written to them. This might be less efficient than
+  ///       zero-initializing first and calling this->Execute() afterwards.
+  /// \return The number of valid elements in out.
+  template <bool kOutputIsZeroInitialized = false>
+  ARROW_FORCE_INLINE int64_t Execute(const ArraySpan& src_validity,
+                                     const ArraySpan& idx_validity,
+                                     uint8_t* out_is_valid) {
+    assert(src_length_ == src_validity.length);
+    assert(idx_length_ == idx_validity.length);
+    assert(out_is_valid);
+    return this->template ExecuteWithNulls<kOutputIsZeroInitialized>(
+        src_validity, idx_length_, idx_, idx_validity, out_is_valid);
+  }
+};
+
+}  // namespace arrow::internal
diff --git a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc 
b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc
index 2ba660e49a..1009bea5e7 100644
--- a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc
+++ b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc
@@ -547,39 +547,6 @@ struct VarBinarySelectionImpl : public 
Selection<VarBinarySelectionImpl<Type>, T
   }
 };
 
-struct FSBSelectionImpl : public Selection<FSBSelectionImpl, 
FixedSizeBinaryType> {
-  using Base = Selection<FSBSelectionImpl, FixedSizeBinaryType>;
-  LIFT_BASE_MEMBERS();
-
-  TypedBufferBuilder<uint8_t> data_builder;
-
-  FSBSelectionImpl(KernelContext* ctx, const ExecSpan& batch, int64_t 
output_length,
-                   ExecResult* out)
-      : Base(ctx, batch, output_length, out), data_builder(ctx->memory_pool()) 
{}
-
-  template <typename Adapter>
-  Status GenerateOutput() {
-    FixedSizeBinaryArray typed_values(this->values.ToArrayData());
-    int32_t value_size = typed_values.byte_width();
-
-    RETURN_NOT_OK(data_builder.Reserve(value_size * output_length));
-    Adapter adapter(this);
-    return adapter.Generate(
-        [&](int64_t index) {
-          auto val = typed_values.GetView(index);
-          data_builder.UnsafeAppend(reinterpret_cast<const 
uint8_t*>(val.data()),
-                                    value_size);
-          return Status::OK();
-        },
-        [&]() {
-          data_builder.UnsafeAppend(value_size, static_cast<uint8_t>(0x00));
-          return Status::OK();
-        });
-  }
-
-  Status Finish() override { return data_builder.Finish(&out->buffers[1]); }
-};
-
 template <typename Type>
 struct ListSelectionImpl : public Selection<ListSelectionImpl<Type>, Type> {
   using offset_type = typename Type::offset_type;
@@ -939,23 +906,6 @@ Status LargeVarBinaryTakeExec(KernelContext* ctx, const 
ExecSpan& batch,
   return TakeExec<VarBinarySelectionImpl<LargeBinaryType>>(ctx, batch, out);
 }
 
-Status FSBTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) 
{
-  const ArraySpan& values = batch[0].array;
-  const auto byte_width = values.type->byte_width();
-  // Use primitive Take implementation (presumably faster) for some byte widths
-  switch (byte_width) {
-    case 1:
-    case 2:
-    case 4:
-    case 8:
-    case 16:
-    case 32:
-      return PrimitiveTakeExec(ctx, batch, out);
-    default:
-      return TakeExec<FSBSelectionImpl>(ctx, batch, out);
-  }
-}
-
 Status ListTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* 
out) {
   return TakeExec<ListSelectionImpl<ListType>>(ctx, batch, out);
 }
@@ -968,26 +918,12 @@ Status FSLTakeExec(KernelContext* ctx, const ExecSpan& 
batch, ExecResult* out) {
   const ArraySpan& values = batch[0].array;
 
   // If a FixedSizeList wraps a fixed-width type we can, in some cases, use
-  // PrimitiveTakeExec for a fixed-size list array.
+  // FixedWidthTakeExec for a fixed-size list array.
   if (util::IsFixedWidthLike(values,
                              /*force_null_count=*/true,
                              /*exclude_bool_and_dictionary=*/true)) {
-    const auto byte_width = util::FixedWidthInBytes(*values.type);
-    // Additionally, PrimitiveTakeExec is only implemented for specific byte 
widths.
-    // TODO(GH-41301): Extend PrimitiveTakeExec for any fixed-width type.
-    switch (byte_width) {
-      case 1:
-      case 2:
-      case 4:
-      case 8:
-      case 16:
-      case 32:
-        return PrimitiveTakeExec(ctx, batch, out);
-      default:
-        break;  // fallback to TakeExec<FSBSelectionImpl>
-    }
+    return FixedWidthTakeExec(ctx, batch, out);
   }
-
   return TakeExec<FSLSelectionImpl>(ctx, batch, out);
 }
 
diff --git a/cpp/src/arrow/compute/kernels/vector_selection_internal.h 
b/cpp/src/arrow/compute/kernels/vector_selection_internal.h
index a169f4b38a..c5075d6dfe 100644
--- a/cpp/src/arrow/compute/kernels/vector_selection_internal.h
+++ b/cpp/src/arrow/compute/kernels/vector_selection_internal.h
@@ -73,8 +73,7 @@ Status MapFilterExec(KernelContext*, const ExecSpan&, 
ExecResult*);
 
 Status VarBinaryTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
 Status LargeVarBinaryTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
-Status PrimitiveTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
-Status FSBTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
+Status FixedWidthTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
 Status ListTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
 Status LargeListTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
 Status FSLTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
diff --git a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc 
b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc
index 1a9af0efcd..dee80e9d25 100644
--- a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc
+++ b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc
@@ -19,6 +19,7 @@
 #include <cstring>
 #include <limits>
 #include <memory>
+#include <utility>
 #include <vector>
 
 #include "arrow/array/builder_primitive.h"
@@ -27,6 +28,7 @@
 #include "arrow/chunked_array.h"
 #include "arrow/compute/api_vector.h"
 #include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/compute/kernels/gather_internal.h"
 #include "arrow/compute/kernels/vector_selection_internal.h"
 #include "arrow/compute/kernels/vector_selection_take_internal.h"
 #include "arrow/memory_pool.h"
@@ -324,238 +326,79 @@ namespace {
 using TakeState = OptionsWrapper<TakeOptions>;
 
 // ----------------------------------------------------------------------
-// Implement optimized take for primitive types from boolean to 
1/2/4/8/16/32-byte
-// C-type based types. Use common implementation for every byte width and only
-// generate code for unsigned integer indices, since after boundschecking to
-// check for negative numbers in the indices we can safely reinterpret_cast
-// signed integers as unsigned.
-
-/// \brief The Take implementation for primitive (fixed-width) types does not
-/// use the logical Arrow type but rather the physical C type. This way we
-/// only generate one take function for each byte width.
+// Implement optimized take for primitive types from boolean to
+// 1/2/4/8/16/32-byte C-type based types and fixed-size binary (0 or more
+// bytes).
+//
+// Use one specialization for each of these primitive byte-widths so the
+// compiler can specialize the memcpy to dedicated CPU instructions and for
+// fixed-width binary use the 1-byte specialization but pass WithFactor=true
+// that makes the kernel consider the factor parameter provided at runtime.
+//
+// Only unsigned index types need to be instantiated since after
+// boundschecking to check for negative numbers in the indices we can safely
+// reinterpret_cast signed integers as unsigned.
+
+/// \brief The Take implementation for primitive types and fixed-width binary.
 ///
 /// Also note that this function can also handle fixed-size-list arrays if
 /// they fit the criteria described in fixed_width_internal.h, so use the
 /// function defined in that file to access values and destination pointers
 /// and DO NOT ASSUME `values.type()` is a primitive type.
 ///
+/// NOTE: Template parameters are types instead of values to let
+/// `TakeIndexDispatch<>` forward `typename... Args`  after the index type.
+///
 /// \pre the indices have been boundschecked
-template <typename IndexCType, typename ValueWidthConstant>
-struct PrimitiveTakeImpl {
-  static constexpr int kValueWidth = ValueWidthConstant::value;
-
-  static void Exec(const ArraySpan& values, const ArraySpan& indices,
-                   ArrayData* out_arr) {
-    DCHECK_EQ(util::FixedWidthInBytes(*values.type), kValueWidth);
-    const auto* values_data = 
util::OffsetPointerOfFixedByteWidthValues(values);
-    const uint8_t* values_is_valid = values.buffers[0].data;
-    auto values_offset = values.offset;
-
-    const auto* indices_data = indices.GetValues<IndexCType>(1);
-    const uint8_t* indices_is_valid = indices.buffers[0].data;
-    auto indices_offset = indices.offset;
-
-    DCHECK_EQ(out_arr->offset, 0);
-    auto* out = util::MutableFixedWidthValuesPointer(out_arr);
-    auto out_is_valid = out_arr->buffers[0]->mutable_data();
-
-    // If either the values or indices have nulls, we preemptively zero out the
-    // out validity bitmap so that we don't have to use ClearBit in each
-    // iteration for nulls.
-    if (values.null_count != 0 || indices.null_count != 0) {
-      bit_util::SetBitsTo(out_is_valid, 0, indices.length, false);
-    }
-
-    auto WriteValue = [&](int64_t position) {
-      memcpy(out + position * kValueWidth,
-             values_data + indices_data[position] * kValueWidth, kValueWidth);
-    };
-
-    auto WriteZero = [&](int64_t position) {
-      memset(out + position * kValueWidth, 0, kValueWidth);
-    };
-
-    auto WriteZeroSegment = [&](int64_t position, int64_t length) {
-      memset(out + position * kValueWidth, 0, kValueWidth * length);
-    };
-
-    OptionalBitBlockCounter indices_bit_counter(indices_is_valid, 
indices_offset,
-                                                indices.length);
-    int64_t position = 0;
+template <typename IndexCType, typename ValueBitWidthConstant,
+          typename OutputIsZeroInitialized = std::false_type,
+          typename WithFactor = std::false_type>
+struct FixedWidthTakeImpl {
+  static constexpr int kValueWidthInBits = ValueBitWidthConstant::value;
+
+  static Status Exec(KernelContext* ctx, const ArraySpan& values,
+                     const ArraySpan& indices, ArrayData* out_arr, int64_t 
factor) {
+#ifndef NDEBUG
+    int64_t bit_width = util::FixedWidthInBits(*values.type);
+    DCHECK(WithFactor::value || (kValueWidthInBits == bit_width && factor == 
1));
+    DCHECK(!WithFactor::value ||
+           (factor > 0 && kValueWidthInBits == 8 &&  // factors are used with 
bytes
+            static_cast<int64_t>(factor * kValueWidthInBits) == bit_width));
+#endif
+    const bool out_has_validity = values.MayHaveNulls() || 
indices.MayHaveNulls();
+
+    const uint8_t* src;
+    int64_t src_offset;
+    std::tie(src_offset, src) = 
util::OffsetPointerOfFixedBitWidthValues(values);
+    uint8_t* out = util::MutableFixedWidthValuesPointer(out_arr);
     int64_t valid_count = 0;
-    while (position < indices.length) {
-      BitBlockCount block = indices_bit_counter.NextBlock();
-      if (values.null_count == 0) {
-        // Values are never null, so things are easier
-        valid_count += block.popcount;
-        if (block.popcount == block.length) {
-          // Fastest path: neither values nor index nulls
-          bit_util::SetBitsTo(out_is_valid, position, block.length, true);
-          for (int64_t i = 0; i < block.length; ++i) {
-            WriteValue(position);
-            ++position;
-          }
-        } else if (block.popcount > 0) {
-          // Slow path: some indices but not all are null
-          for (int64_t i = 0; i < block.length; ++i) {
-            if (bit_util::GetBit(indices_is_valid, indices_offset + position)) 
{
-              // index is not null
-              bit_util::SetBit(out_is_valid, position);
-              WriteValue(position);
-            } else {
-              WriteZero(position);
-            }
-            ++position;
-          }
-        } else {
-          WriteZeroSegment(position, block.length);
-          position += block.length;
-        }
-      } else {
-        // Values have nulls, so we must do random access into the values 
bitmap
-        if (block.popcount == block.length) {
-          // Faster path: indices are not null but values may be
-          for (int64_t i = 0; i < block.length; ++i) {
-            if (bit_util::GetBit(values_is_valid,
-                                 values_offset + indices_data[position])) {
-              // value is not null
-              WriteValue(position);
-              bit_util::SetBit(out_is_valid, position);
-              ++valid_count;
-            } else {
-              WriteZero(position);
-            }
-            ++position;
-          }
-        } else if (block.popcount > 0) {
-          // Slow path: some but not all indices are null. Since we are doing
-          // random access in general we have to check the value nullness one 
by
-          // one.
-          for (int64_t i = 0; i < block.length; ++i) {
-            if (bit_util::GetBit(indices_is_valid, indices_offset + position) 
&&
-                bit_util::GetBit(values_is_valid,
-                                 values_offset + indices_data[position])) {
-              // index is not null && value is not null
-              WriteValue(position);
-              bit_util::SetBit(out_is_valid, position);
-              ++valid_count;
-            } else {
-              WriteZero(position);
-            }
-            ++position;
-          }
-        } else {
-          WriteZeroSegment(position, block.length);
-          position += block.length;
-        }
-      }
-    }
-    out_arr->null_count = out_arr->length - valid_count;
-  }
-};
-
-template <typename IndexCType>
-struct BooleanTakeImpl {
-  static void Exec(const ArraySpan& values, const ArraySpan& indices,
-                   ArrayData* out_arr) {
-    const uint8_t* values_data = values.buffers[1].data;
-    const uint8_t* values_is_valid = values.buffers[0].data;
-    auto values_offset = values.offset;
-
-    const auto* indices_data = indices.GetValues<IndexCType>(1);
-    const uint8_t* indices_is_valid = indices.buffers[0].data;
-    auto indices_offset = indices.offset;
-
-    auto out = out_arr->buffers[1]->mutable_data();
-    auto out_is_valid = out_arr->buffers[0]->mutable_data();
-    auto out_offset = out_arr->offset;
-
-    // If either the values or indices have nulls, we preemptively zero out the
-    // out validity bitmap so that we don't have to use ClearBit in each
-    // iteration for nulls.
-    if (values.null_count != 0 || indices.null_count != 0) {
-      bit_util::SetBitsTo(out_is_valid, out_offset, indices.length, false);
-    }
-    // Avoid uninitialized data in values array
-    bit_util::SetBitsTo(out, out_offset, indices.length, false);
-
-    auto PlaceDataBit = [&](int64_t loc, IndexCType index) {
-      bit_util::SetBitTo(out, out_offset + loc,
-                         bit_util::GetBit(values_data, values_offset + index));
-    };
-
-    OptionalBitBlockCounter indices_bit_counter(indices_is_valid, 
indices_offset,
-                                                indices.length);
-    int64_t position = 0;
-    int64_t valid_count = 0;
-    while (position < indices.length) {
-      BitBlockCount block = indices_bit_counter.NextBlock();
-      if (values.null_count == 0) {
-        // Values are never null, so things are easier
-        valid_count += block.popcount;
-        if (block.popcount == block.length) {
-          // Fastest path: neither values nor index nulls
-          bit_util::SetBitsTo(out_is_valid, out_offset + position, 
block.length, true);
-          for (int64_t i = 0; i < block.length; ++i) {
-            PlaceDataBit(position, indices_data[position]);
-            ++position;
-          }
-        } else if (block.popcount > 0) {
-          // Slow path: some but not all indices are null
-          for (int64_t i = 0; i < block.length; ++i) {
-            if (bit_util::GetBit(indices_is_valid, indices_offset + position)) 
{
-              // index is not null
-              bit_util::SetBit(out_is_valid, out_offset + position);
-              PlaceDataBit(position, indices_data[position]);
-            }
-            ++position;
-          }
-        } else {
-          position += block.length;
-        }
-      } else {
-        // Values have nulls, so we must do random access into the values 
bitmap
-        if (block.popcount == block.length) {
-          // Faster path: indices are not null but values may be
-          for (int64_t i = 0; i < block.length; ++i) {
-            if (bit_util::GetBit(values_is_valid,
-                                 values_offset + indices_data[position])) {
-              // value is not null
-              bit_util::SetBit(out_is_valid, out_offset + position);
-              PlaceDataBit(position, indices_data[position]);
-              ++valid_count;
-            }
-            ++position;
-          }
-        } else if (block.popcount > 0) {
-          // Slow path: some but not all indices are null. Since we are doing
-          // random access in general we have to check the value nullness one 
by
-          // one.
-          for (int64_t i = 0; i < block.length; ++i) {
-            if (bit_util::GetBit(indices_is_valid, indices_offset + position)) 
{
-              // index is not null
-              if (bit_util::GetBit(values_is_valid,
-                                   values_offset + indices_data[position])) {
-                // value is not null
-                PlaceDataBit(position, indices_data[position]);
-                bit_util::SetBit(out_is_valid, out_offset + position);
-                ++valid_count;
-              }
-            }
-            ++position;
-          }
-        } else {
-          position += block.length;
-        }
-      }
+    arrow::internal::Gather<kValueWidthInBits, IndexCType, WithFactor::value> 
gather{
+        /*src_length=*/values.length,
+        src,
+        src_offset,
+        /*idx_length=*/indices.length,
+        /*idx=*/indices.GetValues<IndexCType>(1),
+        out,
+        factor};
+    if (out_has_validity) {
+      DCHECK_EQ(out_arr->offset, 0);
+      // out_is_valid must be zero-initialized, because Gather::Execute
+      // saves time by not having to ClearBit on every null element.
+      auto out_is_valid = out_arr->GetMutableValues<uint8_t>(0);
+      memset(out_is_valid, 0, bit_util::BytesForBits(out_arr->length));
+      valid_count = gather.template Execute<OutputIsZeroInitialized::value>(
+          /*src_validity=*/values, /*idx_validity=*/indices, out_is_valid);
+    } else {
+      valid_count = gather.Execute();
     }
     out_arr->null_count = out_arr->length - valid_count;
+    return Status::OK();
   }
 };
 
 template <template <typename...> class TakeImpl, typename... Args>
-void TakeIndexDispatch(const ArraySpan& values, const ArraySpan& indices,
-                       ArrayData* out) {
+Status TakeIndexDispatch(KernelContext* ctx, const ArraySpan& values,
+                         const ArraySpan& indices, ArrayData* out, int64_t 
factor = 1) {
   // With the simplifying assumption that boundschecking has taken place
   // already at a higher level, we can now assume that the index values are all
   // non-negative. Thus, we can interpret signed integers as unsigned and avoid
@@ -563,22 +406,20 @@ void TakeIndexDispatch(const ArraySpan& values, const ArraySpan& indices,
   // width.
   switch (indices.type->byte_width()) {
     case 1:
-      return TakeImpl<uint8_t, Args...>::Exec(values, indices, out);
+      return TakeImpl<uint8_t, Args...>::Exec(ctx, values, indices, out, factor);
     case 2:
-      return TakeImpl<uint16_t, Args...>::Exec(values, indices, out);
+      return TakeImpl<uint16_t, Args...>::Exec(ctx, values, indices, out, factor);
     case 4:
-      return TakeImpl<uint32_t, Args...>::Exec(values, indices, out);
-    case 8:
-      return TakeImpl<uint64_t, Args...>::Exec(values, indices, out);
+      return TakeImpl<uint32_t, Args...>::Exec(ctx, values, indices, out, factor);
     default:
-      DCHECK(false) << "Invalid indices byte width";
-      break;
+      DCHECK_EQ(indices.type->byte_width(), 8);
+      return TakeImpl<uint64_t, Args...>::Exec(ctx, values, indices, out, factor);
   }
 }
 
 }  // namespace
 
-Status PrimitiveTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
+Status FixedWidthTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
   const ArraySpan& values = batch[0].array;
   const ArraySpan& indices = batch[1].array;
 
@@ -587,52 +428,60 @@ Status PrimitiveTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult*
   }
 
   ArrayData* out_arr = out->array_data().get();
-
   DCHECK(util::IsFixedWidthLike(values));
-  const int64_t bit_width = util::FixedWidthInBits(*values.type);
-
-  // TODO: When neither values nor indices contain nulls, we can skip
-  // allocating the validity bitmap altogether and save time and space. A
-  // streamlined PrimitiveTakeImpl would need to be written that skips all
-  // interactions with the output validity bitmap, though.
+  // When we know for sure that values nor indices contain nulls, we can skip
+  // allocating the validity bitmap altogether and save time and space.
+  const bool allocate_validity = values.MayHaveNulls() || indices.MayHaveNulls();
   RETURN_NOT_OK(util::internal::PreallocateFixedWidthArrayData(
-      ctx, indices.length, /*source=*/values,
-      /*allocate_validity=*/true, out_arr));
-  switch (bit_width) {
+      ctx, indices.length, /*source=*/values, allocate_validity, out_arr));
+  switch (util::FixedWidthInBits(*values.type)) {
+    case 0:
+      DCHECK(values.type->id() == Type::FIXED_SIZE_BINARY ||
+             values.type->id() == Type::FIXED_SIZE_LIST);
+      return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int, 0>>(
+          ctx, values, indices, out_arr);
     case 1:
-      TakeIndexDispatch<BooleanTakeImpl>(values, indices, out_arr);
-      break;
+      // Zero-initialize the data buffer for the output array when the bit-width is 1
+      // (e.g. Boolean array) to avoid having to ClearBit on every null element.
+      // This might be profitable for other types as well, but we take the most
+      // conservative approach for now.
+      memset(out_arr->buffers[1]->mutable_data(), 0, out_arr->buffers[1]->size());
+      return TakeIndexDispatch<
+          FixedWidthTakeImpl, std::integral_constant<int, 1>, /*OutputIsZeroInitialized=*/
+          std::true_type>(ctx, values, indices, out_arr);
     case 8:
-      TakeIndexDispatch<PrimitiveTakeImpl, std::integral_constant<int, 1>>(
-          values, indices, out_arr);
-      break;
+      return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int, 8>>(
+          ctx, values, indices, out_arr);
     case 16:
-      TakeIndexDispatch<PrimitiveTakeImpl, std::integral_constant<int, 2>>(
-          values, indices, out_arr);
-      break;
+      return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int, 16>>(
+          ctx, values, indices, out_arr);
     case 32:
-      TakeIndexDispatch<PrimitiveTakeImpl, std::integral_constant<int, 4>>(
-          values, indices, out_arr);
-      break;
+      return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int, 32>>(
+          ctx, values, indices, out_arr);
     case 64:
-      TakeIndexDispatch<PrimitiveTakeImpl, std::integral_constant<int, 8>>(
-          values, indices, out_arr);
-      break;
+      return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int, 64>>(
+          ctx, values, indices, out_arr);
     case 128:
       // For INTERVAL_MONTH_DAY_NANO, DECIMAL128
-      TakeIndexDispatch<PrimitiveTakeImpl, std::integral_constant<int, 16>>(
-          values, indices, out_arr);
-      break;
+      return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int, 128>>(
+          ctx, values, indices, out_arr);
     case 256:
       // For DECIMAL256
-      TakeIndexDispatch<PrimitiveTakeImpl, std::integral_constant<int, 32>>(
-          values, indices, out_arr);
-      break;
-    default:
-      return Status::NotImplemented("Unsupported primitive type for take: ",
-                                    *values.type);
+      return TakeIndexDispatch<FixedWidthTakeImpl, std::integral_constant<int, 256>>(
+          ctx, values, indices, out_arr);
   }
-  return Status::OK();
+  if (ARROW_PREDICT_TRUE(values.type->id() == Type::FIXED_SIZE_BINARY ||
+                         values.type->id() == Type::FIXED_SIZE_LIST)) {
+    int64_t byte_width = util::FixedWidthInBytes(*values.type);
+    // 0-length fixed-size binary or lists were handled above on `case 0`
+    DCHECK_GT(byte_width, 0);
+    return TakeIndexDispatch<FixedWidthTakeImpl,
+                             /*ValueBitWidth=*/std::integral_constant<int, 8>,
+                             /*OutputIsZeroInitialized=*/std::false_type,
+                             /*WithFactor=*/std::true_type>(ctx, values, indices, out_arr,
+                                                            /*factor=*/byte_width);
+  }
+  return Status::NotImplemented("Unsupported primitive type for take: ", *values.type);
 }
 
 namespace {
@@ -883,13 +732,11 @@ void PopulateTakeKernels(std::vector<SelectionKernelData>* out) {
   auto take_indices = match::Integer();
 
   *out = {
-      {InputType(match::Primitive()), take_indices, PrimitiveTakeExec},
+      {InputType(match::Primitive()), take_indices, FixedWidthTakeExec},
       {InputType(match::BinaryLike()), take_indices, VarBinaryTakeExec},
       {InputType(match::LargeBinaryLike()), take_indices, LargeVarBinaryTakeExec},
-      {InputType(Type::FIXED_SIZE_BINARY), take_indices, FSBTakeExec},
+      {InputType(match::FixedSizeBinaryLike()), take_indices, FixedWidthTakeExec},
       {InputType(null()), take_indices, NullTakeExec},
-      {InputType(Type::DECIMAL128), take_indices, PrimitiveTakeExec},
-      {InputType(Type::DECIMAL256), take_indices, PrimitiveTakeExec},
       {InputType(Type::DICTIONARY), take_indices, DictionaryTake},
       {InputType(Type::EXTENSION), take_indices, ExtensionTake},
       {InputType(Type::LIST), take_indices, ListTakeExec},
diff --git a/cpp/src/arrow/compute/kernels/vector_selection_test.cc b/cpp/src/arrow/compute/kernels/vector_selection_test.cc
index 6261fa2dae..cafd889015 100644
--- a/cpp/src/arrow/compute/kernels/vector_selection_test.cc
+++ b/cpp/src/arrow/compute/kernels/vector_selection_test.cc
@@ -1146,6 +1146,15 @@ void ValidateTakeImpl(const std::shared_ptr<Array>& values,
   for (int64_t i = 0; i < indices->length(); ++i) {
     if (typed_indices->IsNull(i) || typed_values->IsNull(typed_indices->Value(i))) {
       ASSERT_TRUE(result->IsNull(i)) << i;
+      // The value of a null element is undefined, but right
+      // out of the Take kernel it is expected to be 0.
+      if constexpr (is_primitive(ValuesType::type_id)) {
+        if constexpr (ValuesType::type_id == Type::BOOL) {
+          ASSERT_EQ(typed_result->Value(i), false);
+        } else {
+          ASSERT_EQ(typed_result->Value(i), 0);
+        }
+      }
     } else {
       ASSERT_FALSE(result->IsNull(i)) << i;
       ASSERT_EQ(typed_result->GetView(i), typed_values->GetView(typed_indices->Value(i)))
@@ -1522,9 +1531,13 @@ TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListInt32) {
   CheckTake(fixed_size_list(int32(), 3), list_json, "[3, 0, 0, 3]",
             "[[7, 8, null], null, null, [7, 8, null]]");
   CheckTake(fixed_size_list(int32(), 3), list_json, "[0, 1, 2, 3]", list_json);
+
+  // No nulls in inner list values trigger the use of FixedWidthTakeExec() in
+  // FSLTakeExec()
+  std::string no_nulls_list_json = "[[0, 0, 0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]";
   CheckTake(
-      fixed_size_list(int32(), 3), list_json, "[2, 2, 2, 2, 2, 2, 1]",
-      "[[4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [1, null, 3]]");
+      fixed_size_list(int32(), 3), no_nulls_list_json, "[2, 2, 2, 2, 2, 2, 1]",
+      "[[4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [1, 2, 3]]");
 
   this->TestNoValidityBitmapButUnknownNullCount(fixed_size_list(int32(), 3),
                                                 "[[1, null, 3], [4, 5, 6], [7, 8, null]]",


Reply via email to