felipecrv commented on code in PR #41373:
URL: https://github.com/apache/arrow/pull/41373#discussion_r1591385343


##########
cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc:
##########
@@ -324,261 +326,109 @@ namespace {
 using TakeState = OptionsWrapper<TakeOptions>;
 
 // ----------------------------------------------------------------------
-// Implement optimized take for primitive types from boolean to 
1/2/4/8/16/32-byte
-// C-type based types. Use common implementation for every byte width and only
-// generate code for unsigned integer indices, since after boundschecking to
-// check for negative numbers in the indices we can safely reinterpret_cast
-// signed integers as unsigned.
-
-/// \brief The Take implementation for primitive (fixed-width) types does not
-/// use the logical Arrow type but rather the physical C type. This way we
-/// only generate one take function for each byte width.
+// Implement optimized take for primitive types from boolean to
+// 1/2/4/8/16/32-byte C-type based types and fixed-size binary (0 or more
+// bytes).
+//
+// Use one specialization for each of these primitive byte-widths so the
+// compiler can specialize the memcpy to dedicated CPU instructions and for
+// fixed-width binary use the 1-byte specialization but pass WithFactor=true
+// that makes the kernel consider the factor parameter provided at runtime.
+//
+// Only unsigned index types need to be instantiated since after
+// boundschecking to check for negative numbers in the indices we can safely
+// reinterpret_cast signed integers as unsigned.
+
+/// \brief The Take implementation for primitive types and fixed-width binary.
 ///
 /// Also note that this function can also handle fixed-size-list arrays if
 /// they fit the criteria described in fixed_width_internal.h, so use the
 /// function defined in that file to access values and destination pointers
 /// and DO NOT ASSUME `values.type()` is a primitive type.
 ///
 /// \pre the indices have been boundschecked
-template <typename IndexCType, typename ValueWidthConstant>
-struct PrimitiveTakeImpl {
-  static constexpr int kValueWidth = ValueWidthConstant::value;
-
-  static void Exec(const ArraySpan& values, const ArraySpan& indices,
-                   ArrayData* out_arr) {
-    DCHECK_EQ(util::FixedWidthInBytes(*values.type), kValueWidth);
-    const auto* values_data = util::OffsetPointerOfFixedWidthValues(values);
-    const uint8_t* values_is_valid = values.buffers[0].data;
-    auto values_offset = values.offset;
-
-    const auto* indices_data = indices.GetValues<IndexCType>(1);
-    const uint8_t* indices_is_valid = indices.buffers[0].data;
-    auto indices_offset = indices.offset;
-
-    DCHECK_EQ(out_arr->offset, 0);
-    auto* out = util::MutableFixedWidthValuesPointer(out_arr);
-    auto out_is_valid = out_arr->buffers[0]->mutable_data();
-
-    // If either the values or indices have nulls, we preemptively zero out the
-    // out validity bitmap so that we don't have to use ClearBit in each
-    // iteration for nulls.
-    if (values.null_count != 0 || indices.null_count != 0) {
-      bit_util::SetBitsTo(out_is_valid, 0, indices.length, false);
-    }
-
-    auto WriteValue = [&](int64_t position) {
-      memcpy(out + position * kValueWidth,
-             values_data + indices_data[position] * kValueWidth, kValueWidth);
-    };
-
-    auto WriteZero = [&](int64_t position) {
-      memset(out + position * kValueWidth, 0, kValueWidth);
-    };
-
-    auto WriteZeroSegment = [&](int64_t position, int64_t length) {
-      memset(out + position * kValueWidth, 0, kValueWidth * length);
-    };
-
-    OptionalBitBlockCounter indices_bit_counter(indices_is_valid, 
indices_offset,
-                                                indices.length);
-    int64_t position = 0;
-    int64_t valid_count = 0;
-    while (position < indices.length) {
-      BitBlockCount block = indices_bit_counter.NextBlock();
-      if (values.null_count == 0) {
-        // Values are never null, so things are easier
-        valid_count += block.popcount;
-        if (block.popcount == block.length) {
-          // Fastest path: neither values nor index nulls
-          bit_util::SetBitsTo(out_is_valid, position, block.length, true);
-          for (int64_t i = 0; i < block.length; ++i) {
-            WriteValue(position);
-            ++position;
-          }
-        } else if (block.popcount > 0) {
-          // Slow path: some indices but not all are null
-          for (int64_t i = 0; i < block.length; ++i) {
-            if (bit_util::GetBit(indices_is_valid, indices_offset + position)) 
{
-              // index is not null
-              bit_util::SetBit(out_is_valid, position);
-              WriteValue(position);
-            } else {
-              WriteZero(position);
-            }
-            ++position;
-          }
-        } else {
-          WriteZeroSegment(position, block.length);
-          position += block.length;
-        }
-      } else {
-        // Values have nulls, so we must do random access into the values 
bitmap
-        if (block.popcount == block.length) {
-          // Faster path: indices are not null but values may be
-          for (int64_t i = 0; i < block.length; ++i) {
-            if (bit_util::GetBit(values_is_valid,
-                                 values_offset + indices_data[position])) {
-              // value is not null
-              WriteValue(position);
-              bit_util::SetBit(out_is_valid, position);
-              ++valid_count;
-            } else {
-              WriteZero(position);
-            }
-            ++position;
-          }
-        } else if (block.popcount > 0) {
-          // Slow path: some but not all indices are null. Since we are doing
-          // random access in general we have to check the value nullness one 
by
-          // one.
-          for (int64_t i = 0; i < block.length; ++i) {
-            if (bit_util::GetBit(indices_is_valid, indices_offset + position) 
&&
-                bit_util::GetBit(values_is_valid,
-                                 values_offset + indices_data[position])) {
-              // index is not null && value is not null
-              WriteValue(position);
-              bit_util::SetBit(out_is_valid, position);
-              ++valid_count;
-            } else {
-              WriteZero(position);
-            }
-            ++position;
-          }
-        } else {
-          WriteZeroSegment(position, block.length);
-          position += block.length;
-        }
-      }
+template <typename IndexCType, typename ValueBitWidthConstant,
+          typename OutputIsZeroInitialized = std::false_type,
+          typename WithFactor = std::false_type>
+struct FixedWidthTakeImpl {
+  static constexpr int kValueWidthInBits = ValueBitWidthConstant::value;
+
+  // offset returned is defined as number of kValueWidthInBits blocks
+  static std::pair<int64_t, const uint8_t*> SourceOffsetAndValuesPointer(
+      const ArraySpan& values) {
+    if constexpr (kValueWidthInBits == 1) {
+      DCHECK_EQ(values.type->id(), Type::BOOL);
+      return {values.offset, values.GetValues<uint8_t>(1, 0)};
+    } else {
+      static_assert(kValueWidthInBits % 8 == 0,
+                    "kValueWidthInBits must be 1 or a multiple of 8");
+      return {0, util::OffsetPointerOfFixedWidthValues(values)};
     }
-    out_arr->null_count = out_arr->length - valid_count;
   }
-};
 
-template <typename IndexCType>
-struct BooleanTakeImpl {
-  static void Exec(const ArraySpan& values, const ArraySpan& indices,
-                   ArrayData* out_arr) {
-    const uint8_t* values_data = values.buffers[1].data;
-    const uint8_t* values_is_valid = values.buffers[0].data;
-    auto values_offset = values.offset;
-
-    const auto* indices_data = indices.GetValues<IndexCType>(1);
-    const uint8_t* indices_is_valid = indices.buffers[0].data;
-    auto indices_offset = indices.offset;
-
-    auto out = out_arr->buffers[1]->mutable_data();
-    auto out_is_valid = out_arr->buffers[0]->mutable_data();
-    auto out_offset = out_arr->offset;
-
-    // If either the values or indices have nulls, we preemptively zero out the
-    // out validity bitmap so that we don't have to use ClearBit in each
-    // iteration for nulls.
-    if (values.null_count != 0 || indices.null_count != 0) {
-      bit_util::SetBitsTo(out_is_valid, out_offset, indices.length, false);
-    }
-    // Avoid uninitialized data in values array
-    bit_util::SetBitsTo(out, out_offset, indices.length, false);
-
-    auto PlaceDataBit = [&](int64_t loc, IndexCType index) {
-      bit_util::SetBitTo(out, out_offset + loc,
-                         bit_util::GetBit(values_data, values_offset + index));
-    };
-
-    OptionalBitBlockCounter indices_bit_counter(indices_is_valid, 
indices_offset,
-                                                indices.length);
-    int64_t position = 0;
+  static void Exec(KernelContext* ctx, const ArraySpan& values, const 
ArraySpan& indices,
+                   ArrayData* out_arr, size_t factor) {
+#ifndef NDEBUG
+    int64_t bit_width = util::FixedWidthInBits(*values.type);
+    DCHECK(WithFactor::value || (kValueWidthInBits == bit_width && factor == 
1));
+    DCHECK(!WithFactor::value ||
+           (factor > 0 && kValueWidthInBits == 8 &&  // factors are used with 
bytes
+            static_cast<int64_t>(factor * kValueWidthInBits) == bit_width));
+#endif
+    const bool out_has_validity = values.MayHaveNulls() || 
indices.MayHaveNulls();
+
+    const uint8_t* src;
+    int64_t src_offset;
+    std::tie(src_offset, src) = SourceOffsetAndValuesPointer(values);
+    uint8_t* out = util::MutableFixedWidthValuesPointer(out_arr);
     int64_t valid_count = 0;
-    while (position < indices.length) {
-      BitBlockCount block = indices_bit_counter.NextBlock();
-      if (values.null_count == 0) {
-        // Values are never null, so things are easier
-        valid_count += block.popcount;
-        if (block.popcount == block.length) {
-          // Fastest path: neither values nor index nulls
-          bit_util::SetBitsTo(out_is_valid, out_offset + position, 
block.length, true);
-          for (int64_t i = 0; i < block.length; ++i) {
-            PlaceDataBit(position, indices_data[position]);
-            ++position;
-          }
-        } else if (block.popcount > 0) {
-          // Slow path: some but not all indices are null
-          for (int64_t i = 0; i < block.length; ++i) {
-            if (bit_util::GetBit(indices_is_valid, indices_offset + position)) 
{
-              // index is not null
-              bit_util::SetBit(out_is_valid, out_offset + position);
-              PlaceDataBit(position, indices_data[position]);
-            }
-            ++position;
-          }
-        } else {
-          position += block.length;
-        }
-      } else {
-        // Values have nulls, so we must do random access into the values 
bitmap
-        if (block.popcount == block.length) {
-          // Faster path: indices are not null but values may be
-          for (int64_t i = 0; i < block.length; ++i) {
-            if (bit_util::GetBit(values_is_valid,
-                                 values_offset + indices_data[position])) {
-              // value is not null
-              bit_util::SetBit(out_is_valid, out_offset + position);
-              PlaceDataBit(position, indices_data[position]);
-              ++valid_count;
-            }
-            ++position;
-          }
-        } else if (block.popcount > 0) {
-          // Slow path: some but not all indices are null. Since we are doing
-          // random access in general we have to check the value nullness one 
by
-          // one.
-          for (int64_t i = 0; i < block.length; ++i) {
-            if (bit_util::GetBit(indices_is_valid, indices_offset + position)) 
{
-              // index is not null
-              if (bit_util::GetBit(values_is_valid,
-                                   values_offset + indices_data[position])) {
-                // value is not null
-                PlaceDataBit(position, indices_data[position]);
-                bit_util::SetBit(out_is_valid, out_offset + position);
-                ++valid_count;
-              }
-            }
-            ++position;
-          }
-        } else {
-          position += block.length;
-        }
-      }
+    arrow::internal::Gather<kValueWidthInBits, IndexCType, WithFactor::value> 
gather{
+        /*src_length=*/values.length,
+        src,
+        src_offset,
+        /*idx_length=*/indices.length,
+        /*idx=*/indices.GetValues<IndexCType>(1),
+        out,
+        factor};
+    if (out_has_validity) {
+      DCHECK_EQ(out_arr->offset, 0);
+      // out_is_valid must be zero-initiliazed, because Gather::Execute
+      // saves time by not having to ClearBit on every null element.
+      auto out_is_valid = out_arr->GetMutableValues<uint8_t>(0);
+      memset(out_is_valid, 0, bit_util::BytesForBits(out_arr->length));
+      valid_count = gather.template Execute<OutputIsZeroInitialized::value>(
+          /*src_validity=*/values, /*idx_validity=*/indices, out_is_valid);
+    } else {
+      valid_count = gather.Execute();
     }
     out_arr->null_count = out_arr->length - valid_count;
   }
 };
 
 template <template <typename...> class TakeImpl, typename... Args>
-void TakeIndexDispatch(const ArraySpan& values, const ArraySpan& indices,
-                       ArrayData* out) {
+void TakeIndexDispatch(KernelContext* ctx, const ArraySpan& values,
+                       const ArraySpan& indices, ArrayData* out, size_t factor 
= 1) {
   // With the simplifying assumption that boundschecking has taken place
   // already at a higher level, we can now assume that the index values are all
   // non-negative. Thus, we can interpret signed integers as unsigned and avoid
   // having to generate double the amount of binary code to handle each integer
   // width.
   switch (indices.type->byte_width()) {
     case 1:
-      return TakeImpl<uint8_t, Args...>::Exec(values, indices, out);
+      return TakeImpl<uint8_t, Args...>::Exec(ctx, values, indices, out, 
factor);
     case 2:
-      return TakeImpl<uint16_t, Args...>::Exec(values, indices, out);
+      return TakeImpl<uint16_t, Args...>::Exec(ctx, values, indices, out, 
factor);
     case 4:
-      return TakeImpl<uint32_t, Args...>::Exec(values, indices, out);
+      return TakeImpl<uint32_t, Args...>::Exec(ctx, values, indices, out, 
factor);
     case 8:
-      return TakeImpl<uint64_t, Args...>::Exec(values, indices, out);
-    default:
-      DCHECK(false) << "Invalid indices byte width";
-      break;
+      return TakeImpl<uint64_t, Args...>::Exec(ctx, values, indices, out, 
factor);
   }
+  DCHECK(false) << "Invalid indices byte width";

Review Comment:
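   No. The compiler doesn't expect a switch on an integer to handle every
possible value; `-Wswitch` only checks switches on enums. And since this
function returns `void`, it's fine for control to fall out of the switch and
reach the end of the function without a `return`. No valid index width takes
that path anyway; it only hits the `DCHECK(false)`.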



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to