wesm commented on a change in pull request #7442:
URL: https://github.com/apache/arrow/pull/7442#discussion_r441662165



##########
File path: cpp/src/arrow/compute/kernels/vector_selection.cc
##########
@@ -0,0 +1,1816 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstring>
+#include <limits>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_binary.h"
+#include "arrow/array/array_dict.h"
+#include "arrow/array/array_nested.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/extension_type.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/int_util.h"
+
+namespace arrow {
+
+using internal::BinaryBitBlockCounter;
+using internal::BitBlockCount;
+using internal::BitBlockCounter;
+using internal::BitmapReader;
+using internal::CopyBitmap;
+using internal::GetArrayView;
+using internal::IndexBoundsCheck;
+using internal::OptionalBitBlockCounter;
+using internal::OptionalBitIndexer;
+
+namespace compute {
+namespace internal {
+
+/// \brief Count how many output slots applying `filter` will produce.
+///
+/// A slot whose filter value is true (and valid) is always counted. When the
+/// filter has nulls and `null_selection` is EMIT_NULL, null filter slots are
+/// counted as well; for any other `null_selection` they are not.
+int64_t GetFilterOutputSize(const ArrayData& filter,
+                            FilterOptions::NullSelectionBehavior null_selection) {
+  const int64_t num_slots = filter.length;
+  int64_t selected = 0;
+  int64_t consumed = 0;
+
+  if (filter.GetNullCount() <= 0) {
+    // No nulls: only the data bitmap matters, so scan it with the widest
+    // block size available.
+    BitBlockCounter data_counter(filter.buffers[1]->data(), filter.offset, num_slots);
+    while (consumed < num_slots) {
+      const BitBlockCount run = data_counter.NextFourWords();
+      selected += run.popcount;
+      consumed += run.length;
+    }
+    return selected;
+  }
+
+  // Nulls present: walk the data bitmap and the validity bitmap together.
+  const uint8_t* filter_is_valid = filter.buffers[0]->data();
+  BinaryBitBlockCounter paired_counter(filter.buffers[1]->data(), filter.offset,
+                                       filter_is_valid, filter.offset, num_slots);
+  const bool nulls_are_emitted = (null_selection == FilterOptions::EMIT_NULL);
+  while (consumed < num_slots) {
+    // EMIT_NULL counts (true OR NOT valid); otherwise count (true AND valid)
+    const BitBlockCount run = nulls_are_emitted ? paired_counter.NextOrNotWord()
+                                                : paired_counter.NextAndWord();
+    selected += run.popcount;
+    consumed += run.length;
+  }
+  return selected;
+}
+
+/// \brief Convert a boolean filter into an array of the selected positions.
+///
+/// Positions are relative to the start of the filter (filter.offset is not
+/// included) and are appended in increasing order to a builder for IndexType.
+///
+/// When the filter has nulls and `null_selection` is DROP, only slots that are
+/// both valid and true produce an index. For any other `null_selection`, a
+/// null filter slot additionally produces a null index.
+///
+/// \return the ArrayData produced by the builder, or the first error returned
+/// by Reserve or FinishInternal
+template <typename IndexType>
+Result<std::shared_ptr<ArrayData>> GetTakeIndicesImpl(
+    const ArrayData& filter, FilterOptions::NullSelectionBehavior 
null_selection,
+    MemoryPool* memory_pool) {
+  using T = typename IndexType::c_type;
+  typename TypeTraits<IndexType>::BuilderType builder(memory_pool);
+
+  const uint8_t* filter_data = filter.buffers[1]->data();
+  BitBlockCounter data_counter(filter_data, filter.offset, filter.length);
+
+  // The position relative to the start of the filter; this is the value that
+  // gets appended to the builder
+  T position = 0;
+
+  // The current position taking the filter offset into account; this is the
+  // bit index used when reading filter_data / filter_is_valid directly
+  int64_t position_with_offset = filter.offset;
+  if (filter.GetNullCount() > 0) {
+    // The filter has nulls, so we scan the validity bitmap and the filter data
+    // bitmap together, branching on the null selection type.
+    const uint8_t* filter_is_valid = filter.buffers[0]->data();
+
+    // Scans filter_data and filter_is_valid together, one block at a time
+    BinaryBitBlockCounter filter_counter(filter_data, filter.offset, 
filter_is_valid,
+                                         filter.offset, filter.length);
+    if (null_selection == FilterOptions::DROP) {
+      while (position < filter.length) {
+        // true AND valid: popcount equals the number of indices appended below
+        BitBlockCount and_block = filter_counter.NextAndWord();
+        RETURN_NOT_OK(builder.Reserve(and_block.popcount));
+        if (and_block.AllSet()) {
+          // All the values are selected and non-null
+          for (int64_t i = 0; i < and_block.length; ++i) {
+            builder.UnsafeAppend(position++);
+          }
+          position_with_offset += and_block.length;
+        } else {
+          // Some of the values are false or null, so test each slot
+          for (int64_t i = 0; i < and_block.length; ++i) {
+            if (BitUtil::GetBit(filter_is_valid, position_with_offset) &&
+                BitUtil::GetBit(filter_data, position_with_offset)) {
+              builder.UnsafeAppend(position);
+            }
+            ++position;
+            ++position_with_offset;
+          }
+        }
+      }
+    } else {
+      BitBlockCounter is_valid_counter(filter_is_valid, filter.offset, 
filter.length);
+      while (position < filter.length) {
+        // true OR NOT valid: popcount equals the number of indices plus nulls
+        // appended below
+        BitBlockCount or_not_block = filter_counter.NextOrNotWord();
+        RETURN_NOT_OK(builder.Reserve(or_not_block.popcount));
+
+        // If the values are all valid and the or_not_block is full, then we
+        // can infer that all the values are true and skip the bit checking
+        BitBlockCount is_valid_block = is_valid_counter.NextWord();
+
+        if (or_not_block.AllSet() && is_valid_block.AllSet()) {
+          // All the values are selected and non-null
+          for (int64_t i = 0; i < or_not_block.length; ++i) {
+            builder.UnsafeAppend(position++);
+          }
+          position_with_offset += or_not_block.length;
+        } else {
+          // Some of the values are false or null, so test each slot
+          for (int64_t i = 0; i < or_not_block.length; ++i) {
+            if (BitUtil::GetBit(filter_is_valid, position_with_offset)) {
+              if (BitUtil::GetBit(filter_data, position_with_offset)) {
+                builder.UnsafeAppend(position);
+              }
+            } else {
+              // Null slot, so append a null
+              builder.UnsafeAppendNull();
+            }
+            ++position;
+            ++position_with_offset;
+          }
+        }
+      }
+    }
+  } else {
+    // The filter has no nulls, so we need only look for true values
+    BitBlockCount current_block = data_counter.NextWord();
+    while (position < filter.length) {
+      if (current_block.AllSet()) {
+        int64_t run_length = 0;
+
+        // If we've found an all-true block, then we scan forward until we find
+        // a block that has some false values (or we reach the end)
+        while (current_block.length > 0 && current_block.AllSet()) {
+          run_length += current_block.length;
+          current_block = data_counter.NextWord();
+        }
+
+        // Append the consecutive run of indices
+        RETURN_NOT_OK(builder.Reserve(run_length));
+        for (int64_t i = 0; i < run_length; ++i) {
+          builder.UnsafeAppend(position++);
+        }
+        position_with_offset += run_length;
+      } else {
+        // Must do bitchecking on the current block
+        RETURN_NOT_OK(builder.Reserve(current_block.popcount));
+        for (int64_t i = 0; i < current_block.length; ++i) {
+          if (BitUtil::GetBit(filter_data, position_with_offset)) {
+            builder.UnsafeAppend(position);
+          }
+          ++position;
+          ++position_with_offset;
+        }
+        current_block = data_counter.NextWord();
+      }
+    }
+  }
+  std::shared_ptr<ArrayData> result;
+  RETURN_NOT_OK(builder.FinishInternal(&result));
+  return result;
+}
+
+/// \brief Build take indices for a boolean filter, using the narrowest of
+/// uint16 / uint32 indices that can represent filter.length.
+///
+/// \return the index ArrayData, or Status::NotImplemented when filter.length
+/// is greater than UINT32_MAX
+Result<std::shared_ptr<ArrayData>> GetTakeIndices(
+    const ArrayData& filter, FilterOptions::NullSelectionBehavior null_selection,
+    MemoryPool* memory_pool) {
+  DCHECK_EQ(filter.type->id(), Type::BOOL);
+
+  const int64_t uint16_limit = std::numeric_limits<uint16_t>::max();
+  const int64_t uint32_limit = std::numeric_limits<uint32_t>::max();
+
+  if (filter.length > uint32_limit) {
+    // Arrays over 4 billion elements, not especially likely.
+    return Status::NotImplemented(
+        "Filter length exceeds UINT32_MAX, "
+        "consider a different strategy for selecting elements");
+  }
+  if (filter.length > uint16_limit) {
+    return GetTakeIndicesImpl<UInt32Type>(filter, null_selection, memory_pool);
+  }
+  return GetTakeIndicesImpl<UInt16Type>(filter, null_selection, memory_pool);
+}
+
+namespace {
+
+// Trait yielding the C type associated with an Arrow type class. The primary
+// template simply forwards the type's own c_type.
+template <typename T>
+struct GetCType {
+  using type = typename T::c_type;
+};
+
+// Boolean is the exception: yield uint8_t rather than bool
+template <>
+struct GetCType<BooleanType> {
+  using type = uint8_t;
+};
+
+// Kernel state types wrapping the options of the filter and take functions
+using FilterState = OptionsWrapper<FilterOptions>;
+using TakeState = OptionsWrapper<TakeOptions>;
+
+/// \brief Allocate the validity and data buffers of the output array.
+///
+/// Sets the output length to `length`, resizes its buffer list to two entries
+/// and allocates buffers[0] as a bitmap of `length` bits. buffers[1] is a
+/// bitmap of `length` bits when bit_width == 1, and length * bit_width / 8
+/// bytes otherwise.
+///
+/// \return the first allocation error encountered, otherwise Status::OK()
+Status PreallocateData(KernelContext* ctx, int64_t length, int bit_width, Datum* out) {
+  ArrayData* result = out->mutable_array();
+  result->length = length;
+  result->buffers.resize(2);
+
+  // Validity bitmap
+  ARROW_ASSIGN_OR_RAISE(result->buffers[0], ctx->AllocateBitmap(length));
+
+  // Values: whole bytes for multi-bit widths, bit-packed for 1-bit data
+  if (bit_width != 1) {
+    ARROW_ASSIGN_OR_RAISE(result->buffers[1], ctx->Allocate(length * bit_width / 8));
+    return Status::OK();
+  }
+  ARROW_ASSIGN_OR_RAISE(result->buffers[1], ctx->AllocateBitmap(length));
+  return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Implement optimized take for primitive types from boolean to 1/2/4/8-byte
+// C-type based types. Use common implementation for every byte width and only
+// generate code for unsigned integer indices, since after boundschecking to
+// check for negative numbers in the indices we can safely reinterpret_cast
+// signed integers as unsigned.
+
+/// \brief The Take implementation for primitive (fixed-width) types does not
+/// use the logical Arrow type but rather the physical C type. This way we
+/// only generate one take function for each byte width.
+///
+/// IndexCType is the C type used to read the indices buffer and ValueCType is
+/// the C type used to read the values buffer and write the output buffer.
+/// Output slots that end up null have their value set to ValueCType{} / zero,
+/// and the output null_count is set from the number of valid slots written.
+///
+/// This function assumes that the indices have been boundschecked.
+template <typename IndexCType, typename ValueCType>
+struct PrimitiveTakeImpl {
+  static void Exec(const PrimitiveArg& values, const PrimitiveArg& indices,
+                   Datum* out_datum) {
+    auto values_data = reinterpret_cast<const ValueCType*>(values.data);
+    auto values_is_valid = values.is_valid;
+    auto values_offset = values.offset;
+
+    auto indices_data = reinterpret_cast<const IndexCType*>(indices.data);
+    auto indices_is_valid = indices.is_valid;
+    auto indices_offset = indices.offset;
+
+    ArrayData* out_arr = out_datum->mutable_array();
+    auto out = out_arr->GetMutableValues<ValueCType>(1);
+    auto out_is_valid = out_arr->buffers[0]->mutable_data();
+    auto out_offset = out_arr->offset;
+
+    // If either the values or indices have nulls, we preemptively zero out the
+    // out validity bitmap so that we don't have to use ClearBit in each
+    // iteration for nulls.
+    if (values.null_count > 0 || indices.null_count > 0) {
+      BitUtil::SetBitsTo(out_is_valid, out_offset, indices.length, false);
+    }
+
+    OptionalBitBlockCounter indices_bit_counter(indices_is_valid, 
indices_offset,
+                                                indices.length);
+    // position indexes both the indices and the output; valid_count
+    // accumulates the number of non-null output slots
+    int64_t position = 0;
+    int64_t valid_count = 0;
+    while (position < indices.length) {
+      BitBlockCount block = indices_bit_counter.NextBlock();
+      if (values.null_count == 0) {
+        // Values are never null, so things are easier
+        valid_count += block.popcount;
+        if (block.popcount == block.length) {
+          // Fastest path: neither values nor index nulls
+          BitUtil::SetBitsTo(out_is_valid, out_offset + position, 
block.length, true);
+          for (int64_t i = 0; i < block.length; ++i) {
+            out[position] = values_data[indices_data[position]];
+            ++position;
+          }
+        } else if (block.popcount > 0) {
+          // Slow path: some indices but not all are null
+          for (int64_t i = 0; i < block.length; ++i) {
+            if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) {
+              // index is not null
+              BitUtil::SetBit(out_is_valid, out_offset + position);
+              out[position] = values_data[indices_data[position]];
+            } else {
+              out[position] = ValueCType{};
+            }
+            ++position;
+          }
+        } else {
+          // Every index in the block is null: zero the output values in bulk
+          memset(out + position, 0, sizeof(ValueCType) * block.length);
+          position += block.length;
+        }
+      } else {
+        // Values have nulls, so we must do random access into the values 
bitmap
+        if (block.popcount == block.length) {
+          // Faster path: indices are not null but values may be
+          for (int64_t i = 0; i < block.length; ++i) {
+            if (BitUtil::GetBit(values_is_valid,
+                                values_offset + indices_data[position])) {
+              // value is not null
+              out[position] = values_data[indices_data[position]];
+              BitUtil::SetBit(out_is_valid, out_offset + position);
+              ++valid_count;
+            } else {
+              out[position] = ValueCType{};
+            }
+            ++position;
+          }
+        } else if (block.popcount > 0) {
+          // Slow path: some but not all indices are null. Since we are doing
+          // random access in general we have to check the value nullness one 
by
+          // one.
+          for (int64_t i = 0; i < block.length; ++i) {
+            if (BitUtil::GetBit(indices_is_valid, indices_offset + position) &&
+                BitUtil::GetBit(values_is_valid,
+                                values_offset + indices_data[position])) {
+              // index is not null && value is not null
+              out[position] = values_data[indices_data[position]];
+              BitUtil::SetBit(out_is_valid, out_offset + position);
+              ++valid_count;
+            } else {
+              out[position] = ValueCType{};
+            }
+            ++position;
+          }
+        } else {
+          // Every index in the block is null: zero the output values in bulk
+          memset(out + position, 0, sizeof(ValueCType) * block.length);
+          position += block.length;
+        }
+      }
+    }
+    out_arr->null_count = out_arr->length - valid_count;
+  }
+};
+
+/// \brief The Take implementation for bit-packed boolean values.
+///
+/// IndexCType is the C type used to read the indices buffer. Values and output
+/// are addressed bit by bit, adding values.offset / the output offset to each
+/// bit position. The whole output data range is cleared up front, so output
+/// slots that end up null keep a false data bit. The output null_count is set
+/// from the number of valid slots written.
+template <typename IndexCType>
+struct BooleanTakeImpl {
+  static void Exec(const PrimitiveArg& values, const PrimitiveArg& indices,
+                   Datum* out_datum) {
+    const uint8_t* values_data = values.data;
+    auto values_is_valid = values.is_valid;
+    auto values_offset = values.offset;
+
+    auto indices_data = reinterpret_cast<const IndexCType*>(indices.data);
+    auto indices_is_valid = indices.is_valid;
+    auto indices_offset = indices.offset;
+
+    ArrayData* out_arr = out_datum->mutable_array();
+    auto out = out_arr->buffers[1]->mutable_data();
+    auto out_is_valid = out_arr->buffers[0]->mutable_data();
+    auto out_offset = out_arr->offset;
+
+    // If either the values or indices have nulls, we preemptively zero out the
+    // out validity bitmap so that we don't have to use ClearBit in each
+    // iteration for nulls.
+    if (values.null_count > 0 || indices.null_count > 0) {
+      BitUtil::SetBitsTo(out_is_valid, out_offset, indices.length, false);
+    }
+    // Avoid uninitialized data in values array
+    BitUtil::SetBitsTo(out, out_offset, indices.length, false);
+
+    // Copies the value bit selected by `index` into output slot `loc`
+    auto PlaceDataBit = [&](int64_t loc, IndexCType index) {
+      BitUtil::SetBitTo(out, out_offset + loc,
+                        BitUtil::GetBit(values_data, values_offset + index));
+    };
+
+    OptionalBitBlockCounter indices_bit_counter(indices_is_valid, 
indices_offset,
+                                                indices.length);
+    // position indexes both the indices and the output; valid_count
+    // accumulates the number of non-null output slots
+    int64_t position = 0;
+    int64_t valid_count = 0;
+    while (position < indices.length) {
+      BitBlockCount block = indices_bit_counter.NextBlock();
+      if (values.null_count == 0) {
+        // Values are never null, so things are easier
+        valid_count += block.popcount;
+        if (block.popcount == block.length) {
+          // Fastest path: neither values nor index nulls
+          BitUtil::SetBitsTo(out_is_valid, out_offset + position, 
block.length, true);
+          for (int64_t i = 0; i < block.length; ++i) {
+            PlaceDataBit(position, indices_data[position]);
+            ++position;
+          }
+        } else if (block.popcount > 0) {
+          // Slow path: some but not all indices are null
+          for (int64_t i = 0; i < block.length; ++i) {
+            if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) {
+              // index is not null
+              BitUtil::SetBit(out_is_valid, out_offset + position);
+              PlaceDataBit(position, indices_data[position]);
+            }
+            ++position;
+          }
+        } else {
+          // Every index in the block is null: output bits are already cleared
+          position += block.length;
+        }
+      } else {
+        // Values have nulls, so we must do random access into the values 
bitmap
+        if (block.popcount == block.length) {
+          // Faster path: indices are not null but values may be
+          for (int64_t i = 0; i < block.length; ++i) {
+            if (BitUtil::GetBit(values_is_valid,
+                                values_offset + indices_data[position])) {
+              // value is not null
+              BitUtil::SetBit(out_is_valid, out_offset + position);
+              PlaceDataBit(position, indices_data[position]);
+              ++valid_count;
+            }
+            ++position;
+          }
+        } else if (block.popcount > 0) {
+          // Slow path: some but not all indices are null. Since we are doing
+          // random access in general we have to check the value nullness one 
by
+          // one.
+          for (int64_t i = 0; i < block.length; ++i) {
+            if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) {
+              // index is not null
+              if (BitUtil::GetBit(values_is_valid,
+                                  values_offset + indices_data[position])) {
+                // value is not null
+                PlaceDataBit(position, indices_data[position]);
+                BitUtil::SetBit(out_is_valid, out_offset + position);
+                ++valid_count;
+              }
+            }
+            ++position;
+          }
+        } else {
+          // Every index in the block is null: output bits are already cleared
+          position += block.length;
+        }
+      }
+    }
+    out_arr->null_count = out_arr->length - valid_count;
+  }
+};
+
+/// \brief Invoke TakeImpl<IndexCType, Args...>::Exec with IndexCType selected
+/// from indices.bit_width.
+///
+/// Supported index bit widths are 8, 16, 32 and 64; any other width trips a
+/// DCHECK and leaves `out` untouched.
+///
+/// \param values the values to take from, forwarded to Exec
+/// \param indices the indices; only its bit_width is inspected here
+/// \param out the output datum, forwarded to Exec
+template <template <typename...> class TakeImpl, typename... Args>
+void TakeIndexDispatch(const PrimitiveArg& values, const PrimitiveArg& indices,
+                       Datum* out) {
+  // With the simplifying assumption that boundschecking has taken place
+  // already at a higher level, we can now assume that the index values are all
+  // non-negative. Thus, we can interpret signed integers as unsigned and avoid
+  // having to generate double the amount of binary code to handle each integer
+  // width.
+  switch (indices.bit_width) {
+    case 8:
+      return TakeImpl<uint8_t, Args...>::Exec(values, indices, out);
+    case 16:
+      return TakeImpl<uint16_t, Args...>::Exec(values, indices, out);
+    case 32:
+      return TakeImpl<uint32_t, Args...>::Exec(values, indices, out);
+    case 64:
+      return TakeImpl<uint64_t, Args...>::Exec(values, indices, out);
+    default:
+      // The switch is on a width in bits, so report it as such
+      DCHECK(false) << "Invalid indices bit width";
+      break;
+  }
+}
+
+/// \brief Take kernel for primitive (fixed-width) values.
+///
+/// batch[0] holds the values and batch[1] the indices. When the TakeOptions
+/// held in the kernel state have boundscheck set, the indices are first
+/// checked against the length of the values. The output buffers are then
+/// preallocated for indices.length elements and the implementation is chosen
+/// from the bit width of the values (1, 8, 16, 32 or 64); any other width
+/// trips a DCHECK.
+///
+/// Failures of the bounds check or of the preallocation are handed to
+/// KERNEL_RETURN_IF_ERROR together with ctx.
+void PrimitiveTake(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  const auto& state = checked_cast<const TakeState&>(*ctx->state());
+  if (state.options.boundscheck) {
+    KERNEL_RETURN_IF_ERROR(ctx, IndexBoundsCheck(*batch[1].array(), batch[0].length()));
+  }
+
+  PrimitiveArg values = GetPrimitiveArg(*batch[0].array());
+  PrimitiveArg indices = GetPrimitiveArg(*batch[1].array());
+  KERNEL_RETURN_IF_ERROR(ctx,
+                         PreallocateData(ctx, indices.length, values.bit_width, out));
+  switch (values.bit_width) {
+    case 1:
+      return TakeIndexDispatch<BooleanTakeImpl>(values, indices, out);
+    case 8:
+      return TakeIndexDispatch<PrimitiveTakeImpl, int8_t>(values, indices, out);
+    case 16:
+      return TakeIndexDispatch<PrimitiveTakeImpl, int16_t>(values, indices, out);
+    case 32:
+      return TakeIndexDispatch<PrimitiveTakeImpl, int32_t>(values, indices, out);
+    case 64:
+      return TakeIndexDispatch<PrimitiveTakeImpl, int64_t>(values, indices, out);
+    default:
+      // The switch is on a width in bits (it includes 1), so report it as such
+      DCHECK(false) << "Invalid values bit width";
+      break;
+  }
+}
+
+// ----------------------------------------------------------------------
+// Optimized and streamlined filter for primitive types
+
+// Block scanner over a boolean filter for the DROP selection type. When a
+// validity bitmap is supplied, each block counts the slots that are true AND
+// not null (BinaryBitBlockCounter); without one, each block counts the true
+// slots of the data bitmap alone (BitBlockCounter).
+class DropNullCounter {
+ public:
+  // validity may be nullptr, in which case only data is consulted
+  DropNullCounter(const uint8_t* validity, const uint8_t* data, int64_t offset,
+                  int64_t length)
+      : data_counter_(data, offset, length),
+        data_and_validity_counter_(data, offset, validity, offset, length),
+        has_validity_(validity != nullptr) {}
+
+  // Returns the next block from whichever counter matches the inputs given
+  // at construction
+  BitBlockCount NextBlock() {
+    return has_validity_ ? data_and_validity_counter_.NextAndWord()
+                         : data_counter_.NextWord();
+  }
+
+ private:
+  // Used when there is no validity bitmap
+  BitBlockCounter data_counter_;
+
+  // Used when a validity bitmap accompanies the data
+  BinaryBitBlockCounter data_and_validity_counter_;
+  bool has_validity_;
+};
+
+/// \brief The Filter implementation for primitive (fixed-width) types does not
+/// use the logical Arrow type but rather the physical C type. This way we
+/// only generate one filter function for each byte width. We use the same
+/// implementation here for boolean and fixed-byte-size inputs with some
+/// template specialization.
+template <typename ArrowType>
+class PrimitiveFilterImpl {
+ public:
+  using T = typename GetCType<ArrowType>::type;
+
+  PrimitiveFilterImpl(const PrimitiveArg& values, const PrimitiveArg& filter,
+                      FilterOptions::NullSelectionBehavior null_selection,
+                      Datum* out_datum)
+      : values_is_valid_(values.is_valid),
+        values_data_(reinterpret_cast<const T*>(values.data)),
+        values_null_count_(values.null_count),
+        values_offset_(values.offset),
+        values_length_(values.length),
+        filter_is_valid_(filter.is_valid),
+        filter_data_(filter.data),
+        filter_null_count_(filter.null_count),
+        filter_offset_(filter.offset),
+        null_selection_(null_selection) {
+    ArrayData* out_arr = out_datum->mutable_array();
+    out_is_valid_ = out_arr->buffers[0]->mutable_data();
+    out_data_ = reinterpret_cast<T*>(out_arr->buffers[1]->mutable_data());
+    out_offset_ = out_arr->offset;
+    out_length_ = out_arr->length;
+    out_position_ = 0;
+  }
+
+  void ExecNonNull() {
+    // The result is all not-null
+    BitUtil::SetBitsTo(out_is_valid_, out_offset_ + out_position_, 
out_length_, true);

Review comment:
       Good point, fixing




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to